Различия в производительности вывода между Flash Attention v1 и v2.

1. История

С момента разработки преобразователя механизм внимания также проявил себя в LLM (большая языковая модель). Однако из-за вычислительных ограничений softmax процесс расчета MHA (Multi Head Attention) долгое время находился в состоянии серьезной привязки к памяти. Основываясь на математических характеристиках softmax, Flash Attention объединяет вычисление MHA в одном операторе и применяет стратегию обмена вычислениями и высокоскоростным доступом к памяти SRAM на низкоскоростной доступ к памяти HBM, что снижает нагрузку на память и значительно снижает нагрузку на память. улучшает скорость вычислений MHA.

Эта статья основана на интерфейсе C++ Flash Attention и Flash Attention v2 и посвящена изучению влияния различий в процессах вычислений между ними на производительность вывода MHA.

2 млн га

2.1 Процесс расчета

Процесс расчета самовнимания в MHA показан на рисунке выше и может быть разделен на следующие три этапа.

O = Softmax(Q * K^T) * V

Step1: S = Q * K^T
Step2: P = Softmax(S)
Step3: O = P * V

Аналогично, размеры Q, K, V и O в MHA следующие. Шаг 1 вычисляет умножение матрицы total_q. Размерность каждого умножения матрицы равна (sq * d) * (d * sk), и получается S. Шаг 2 передает softmax и вычисляет P. Шаг 3 вычисляет умножение матрицы total_q. Размерность каждого умножения матрицы равна (sq * sk) * (sk * d), и получается O.

  • Q: total_q * hq * dim
  • K: total_k * hk * дим
  • V: total_k * hk * дим
  • O: total_q * hq * dim

Код реализации MHA ЦП выглядит следующим образом, а исходный код находится в flash_attention_inference.

void mha_cpu(Tensor<cutlass::half_t> *Q, Tensor<cutlass::half_t> *K, Tensor<cutlass::half_t> *V,
             Tensor<cutlass::half_t> *O, int *cu_seq_q, int *cu_seq_k, size_t batch, size_t max_seq_q, size_t max_seq_k,
             bool is_causal, bool is_alibi) {
    size_t total_q = Q->getShape()[0];
    size_t head_q = Q->getShape()[1];
    size_t dim = Q->getShape()[2];
    size_t head_k = K->getShape()[1];

    FAI_CHECK_EQ(head_q % head_k, 0);
    const size_t head_ratio = head_q / head_k;

    cutlass::half_t *q_ptr = Q->getHostPtr();
    cutlass::half_t *k_ptr = K->getHostPtr();
    cutlass::half_t *v_ptr = V->getHostPtr();
    cutlass::half_t *o_ptr = O->getHostPtr();

    // S = Q * K^T
    Tensor<float> *S = new Tensor<float>({total_q, head_q, max_seq_k}, "Tensor S");
    FAI_CHECK(S);
    float *s_ptr = S->getHostPtr();
    for (size_t b = 0; b < batch; ++b) {
        size_t seq_q = cu_seq_q[b + 1] - cu_seq_q[b];
        size_t seq_k = cu_seq_k[b + 1] - cu_seq_k[b];
        for (size_t h = 0; h < head_q; ++h) {
            for (size_t sq = 0; sq < seq_q; ++sq) {
                for (size_t sk = 0; sk < seq_k; ++sk) {
                    float tmp = 0.0;
                    for (size_t d = 0; d < dim; ++d) {
                        tmp +=
                            static_cast<cutlass::half_t>(
                                q_ptr[cu_seq_q[b] * (head_q * dim) + sq * (head_q * dim) + h * dim + d]) *
                            static_cast<cutlass::half_t>(
                                k_ptr[cu_seq_k[b] * (head_k * dim) + sk * (head_k * dim) + (h / head_ratio) * dim + d]);
                    }
                    s_ptr[cu_seq_q[b] * (head_q * seq_k) + sq * (head_q * seq_k) + h * seq_k + sk] = tmp;
                }
            }
        }
    }

    // P = Softmax(S)
    Tensor<cutlass::half_t> *P = new Tensor<cutlass::half_t>({total_q, head_q, max_seq_k}, "Tensor P");
    FAI_CHECK(P);
    cutlass::half_t *p_ptr = P->getHostPtr();
    float scale = 1.0 / std::sqrt(dim);
    for (size_t b = 0; b < batch; ++b) {
        size_t seq_q = cu_seq_q[b + 1] - cu_seq_q[b];
        size_t seq_k = cu_seq_k[b + 1] - cu_seq_k[b];
        size_t row_shift = seq_k - seq_q;
        for (size_t h = 0; h < head_q; ++h) {
            float h_slope = is_alibi ? (1.0 / exp2(8.0 * (h + 1) / head_q)) : 0.0;
            for (size_t sq = 0; sq < seq_q; ++sq) {
                size_t col_limit = is_causal ? std::min(seq_k, sq + row_shift + 1) : seq_k;

                // Max(S)
                std::vector<float> tmp_s(seq_k, 0.0);
                float max_s = -std::numeric_limits<float>::max();
                for (size_t sk = 0; sk < col_limit; ++sk) {
                    tmp_s[sk] = s_ptr[cu_seq_q[b] * (head_q * seq_k) + sq * (head_q * seq_k) + h * seq_k + sk] * scale;
                    if (is_alibi && sk < sq + row_shift) {
                        tmp_s[sk] +=
                            (h_slope * (static_cast<int>(sk) - static_cast<int>(sq) - static_cast<int>(row_shift)));
                    }
                    max_s = std::max(max_s, tmp_s[sk]);
                }

                // Sum(S)
                float sum_s = 0.0;
                for (size_t sk = 0; sk < col_limit; ++sk) {
                    tmp_s[sk] = std::exp(tmp_s[sk] - max_s);
                    sum_s += tmp_s[sk];
                }

                // Softmax(S)
                for (size_t sk = 0; sk < col_limit; ++sk) {
                    p_ptr[cu_seq_q[b] * (head_q * seq_k) + sq * (head_q * seq_k) + h * seq_k + sk] =
                        static_cast<cutlass::half_t>(tmp_s[sk] / sum_s);
                }

                // Causal(S)
                if (is_causal) {
                    for (size_t sk = col_limit; sk < seq_k; ++sk) {
                        p_ptr[cu_seq_q[b] * (head_q * seq_k) + sq * (head_q * seq_k) + h * seq_k + sk] = 0_hf;
                    }
                }
            }
        }
    }

    // O = P * V
    for (size_t b = 0; b < batch; ++b) {
        size_t seq_q = cu_seq_q[b + 1] - cu_seq_q[b];
        size_t seq_k = cu_seq_k[b + 1] - cu_seq_k[b];
        for (size_t h = 0; h < head_q; ++h) {
            for (size_t sq = 0; sq < seq_q; ++sq) {
                for (size_t d = 0; d < dim; ++d) {
                    float tmp = 0.0;
                    for (size_t sk = 0; sk < seq_k; ++sk) {
                        tmp +=
                            static_cast<cutlass::half_t>(
                                p_ptr[cu_seq_q[b] * (head_q * seq_k) + sq * (head_q * seq_k) + h * seq_k + sk]) *
                            static_cast<cutlass::half_t>(
                                v_ptr[cu_seq_k[b] * (head_k * dim) + sk * (head_k * dim) + (h / head_ratio) * dim + d]);
                    }
                    o_ptr[cu_seq_q[b] * (head_q * dim) + sq * (head_q * dim) + h * dim + d] =
                        static_cast<cutlass::half_t>(tmp);
                }
            }
        }
    }

    if (S) {
        delete S;
        S = nullptr;
    }

    if (P) {
        delete P;
        P = nullptr;
    }
}

2.2 Вспышка внимания

В этой статье основное внимание уделяется только процессу расчета MHA Flash Attention. Для получения дополнительной информации, пожалуйста, просмотрите документ и исходный код.

Расчет Flash Attention для MHA заключается в разделении блоков по пакету, заголовку и разделению-seq_q. При вычислении Q * K^T вычисление внутренней деформации делится в соответствии с размерностью seq_k матрицы K^T, то есть каждая деформация может получить только матрицу S. Результат частичной блокировки определенной строки. Поэтому при вычислении softmax блока сначала необходимо синхронизировать варп. С другой стороны, при окончательном вычислении P*V используется метод разделения-K. Промежуточные результаты каждого расчета деформации должны быть уменьшены и суммированы, прежде чем можно будет получить результат блока O. Перед уменьшением деформации все равно необходимо синхронизировать.

2.3 Вспышка внимания v2

Flash Attention v2 также разделяет блоки по пакетам, заголовку и Split-seq_q для расчета MHA. Однако при вычислении Q * K^T вычисление внутренней деформации делится в соответствии с размерностью seq_q матрицы Q, то есть каждая деформация может получить определенное значение матрицы S. Все результаты блока в одной строке. Поэтому при расчете softmax блока нет необходимости синхронизировать варп. С другой стороны, при окончательном вычислении P*V каждый варп также может напрямую вычислить результат варпа O, без необходимости сокращения или дополнительной синхронизации варпов.

3. Производительность вывода

3.1 Условия испытаний

Код с открытым исходным кодом находится в flash_attention_inference, а ядро ​​взято из flash-attention. Коды обратной зависимости, исключения, bf16 и факела, которые не имеют отношения к выводу, удалены и могут быть легко интегрированы в сценарии вывода LLM. Основываясь на мгновенном внимании, этот код также полностью поддерживает сценарии вывода GQA (Group Query Attention)/MQA (Multi Query Attention), сценарии гибридного вывода предварительного заполнения/декодирования и сценарии вывода ALiBi (Внимание с линейными смещениями).

  • MHA: O = Softmax(Q * K^T) * V
  • КУДА: 11,8
  • Графический процессор: RTX3090
  • Флэш-внимание: v1.0.9
  • Вспышка внимания v2: v2.1.0
  • Абордаж: v3.1.0
  • Номер головы: 32
  • Размер головы: 128

3.2 Производительность вывода предварительного заполнения

(1) Сек Лен

Когда дело доходит до коротких сцен, их производительность эквивалентна; когда дело доходит до длинных последовательностей, Flash Attention v2 работает лучше и может быть улучшен примерно на 50%. Причина, по которой Flash Attention v2 хорошо работает в длинных последовательностях, в основном связана с уменьшением многократной синхронизации деформаций между данными блоков.

  • Размер партии: 128
  • Seq Q: Seq Len
  • Seq K: Seq Len

(2) Размер партии

Когда размер пакета меньше, Flash Attention v2 работает лучше; когда размер партии больше, производительность обоих эквивалентна.

  • Размер партии: Размер партии
  • Последующий вопрос: 128
  • Последовательность К: 128

3.3. Производительность декодирования вывода

(1) Сек Лен

Когда последовательность короткая, производительность обоих эквивалентна; когда последовательность длинная, эффективность Flash Attention выше. Причина, по которой Flash Attention хорошо работает в длинных последовательностях, в основном связана с деформационным разделением труда в измерении seq_k, которое улучшает параллелизм вычислений.

  • Размер партии: 128
  • Последовательность вопросов: 1
  • Seq K: Seq Len

(2) Размер партии

Независимо от размера пакета, Flash Attention работает лучше.

  • Размер партии: Размер партии
  • Последовательность вопросов: 1
  • Последовательность К: 128

3.4. Производительность гибридного вывода

Независимо от того, как меняется соотношение предварительного заполнения и декодирования, производительность Flash Attention и Flash Attention v2 относительно близка.

  • Размер партии: 100
  • Seq Q: 128 (предварительное заполнение) + 1 (декодирование)
  • Последовательность К: 128

4 Другое

4.1 Сценарии вывода GQA/MQA

Поддерживаются все сценарии вывода GQA/MQA, а код обновляется в flash_attention_inference.

4.2 Сценарии гибридного вывода

Поддерживаются все сценарии гибридного вывода предварительного заполнения и декодирования. Код обновлен в flash_attention_inference, а производительность показана в версии 3.4.

4.3 Сценарии вывода ALiBi

Поддерживаются все сценарии вывода ALiBi, а код обновлен в flash_attention_inference.