Различия в производительности вывода между Flash Attention v1 и v2.
1. История
С момента разработки преобразователя механизм внимания также проявил себя в LLM (большая языковая модель). Однако из-за вычислительных ограничений softmax процесс расчета MHA (Multi Head Attention) долгое время находился в состоянии серьезной привязки к памяти. Основываясь на математических характеристиках softmax, Flash Attention объединяет вычисление MHA в одном операторе и применяет стратегию обмена вычислениями и высокоскоростным доступом к памяти SRAM на низкоскоростной доступ к памяти HBM, что снижает нагрузку на память и значительно снижает нагрузку на память. улучшает скорость вычислений MHA.
Эта статья основана на интерфейсе C++ Flash Attention и Flash Attention v2 и посвящена изучению влияния различий в процессах вычислений между ними на производительность вывода MHA.
2 млн га
2.1 Процесс расчета
Процесс расчета самовнимания в MHA показан на рисунке выше и может быть разделен на следующие три этапа.
O = Softmax(Q * K^T) * V Step1: S = Q * K^T Step2: P = Softmax(S) Step3: O = P * V
Аналогично, размеры Q, K, V и O в MHA следующие. Шаг 1 вычисляет умножение матрицы total_q. Размерность каждого умножения матрицы равна (sq * d) * (d * sk), и получается S. Шаг 2 передает softmax и вычисляет P. Шаг 3 вычисляет умножение матрицы total_q. Размерность каждого умножения матрицы равна (sq * sk) * (sk * d), и получается O.
- Q: total_q * hq * dim
- K: total_k * hk * дим
- V: total_k * hk * дим
- O: total_q * hq * dim
Код реализации MHA ЦП выглядит следующим образом, а исходный код находится в flash_attention_inference.
void mha_cpu(Tensor<cutlass::half_t> *Q, Tensor<cutlass::half_t> *K, Tensor<cutlass::half_t> *V, Tensor<cutlass::half_t> *O, int *cu_seq_q, int *cu_seq_k, size_t batch, size_t max_seq_q, size_t max_seq_k, bool is_causal, bool is_alibi) { size_t total_q = Q->getShape()[0]; size_t head_q = Q->getShape()[1]; size_t dim = Q->getShape()[2]; size_t head_k = K->getShape()[1]; FAI_CHECK_EQ(head_q % head_k, 0); const size_t head_ratio = head_q / head_k; cutlass::half_t *q_ptr = Q->getHostPtr(); cutlass::half_t *k_ptr = K->getHostPtr(); cutlass::half_t *v_ptr = V->getHostPtr(); cutlass::half_t *o_ptr = O->getHostPtr(); // S = Q * K^T Tensor<float> *S = new Tensor<float>({total_q, head_q, max_seq_k}, "Tensor S"); FAI_CHECK(S); float *s_ptr = S->getHostPtr(); for (size_t b = 0; b < batch; ++b) { size_t seq_q = cu_seq_q[b + 1] - cu_seq_q[b]; size_t seq_k = cu_seq_k[b + 1] - cu_seq_k[b]; for (size_t h = 0; h < head_q; ++h) { for (size_t sq = 0; sq < seq_q; ++sq) { for (size_t sk = 0; sk < seq_k; ++sk) { float tmp = 0.0; for (size_t d = 0; d < dim; ++d) { tmp += static_cast<cutlass::half_t>( q_ptr[cu_seq_q[b] * (head_q * dim) + sq * (head_q * dim) + h * dim + d]) * static_cast<cutlass::half_t>( k_ptr[cu_seq_k[b] * (head_k * dim) + sk * (head_k * dim) + (h / head_ratio) * dim + d]); } s_ptr[cu_seq_q[b] * (head_q * seq_k) + sq * (head_q * seq_k) + h * seq_k + sk] = tmp; } } } } // P = Softmax(S) Tensor<cutlass::half_t> *P = new Tensor<cutlass::half_t>({total_q, head_q, max_seq_k}, "Tensor P"); FAI_CHECK(P); cutlass::half_t *p_ptr = P->getHostPtr(); float scale = 1.0 / std::sqrt(dim); for (size_t b = 0; b < batch; ++b) { size_t seq_q = cu_seq_q[b + 1] - cu_seq_q[b]; size_t seq_k = cu_seq_k[b + 1] - cu_seq_k[b]; size_t row_shift = seq_k - seq_q; for (size_t h = 0; h < head_q; ++h) { float h_slope = is_alibi ? (1.0 / exp2(8.0 * (h + 1) / head_q)) : 0.0; for (size_t sq = 0; sq < seq_q; ++sq) { size_t col_limit = is_causal ? std::min(seq_k, sq + row_shift + 1) : seq_k; // Max(S) std::vector<float> tmp_s(seq_k, 0.0); float max_s = -std::numeric_limits<float>::max(); for (size_t sk = 0; sk < col_limit; ++sk) { tmp_s[sk] = s_ptr[cu_seq_q[b] * (head_q * seq_k) + sq * (head_q * seq_k) + h * seq_k + sk] * scale; if (is_alibi && sk < sq + row_shift) { tmp_s[sk] += (h_slope * (static_cast<int>(sk) - static_cast<int>(sq) - static_cast<int>(row_shift))); } max_s = std::max(max_s, tmp_s[sk]); } // Sum(S) float sum_s = 0.0; for (size_t sk = 0; sk < col_limit; ++sk) { tmp_s[sk] = std::exp(tmp_s[sk] - max_s); sum_s += tmp_s[sk]; } // Softmax(S) for (size_t sk = 0; sk < col_limit; ++sk) { p_ptr[cu_seq_q[b] * (head_q * seq_k) + sq * (head_q * seq_k) + h * seq_k + sk] = static_cast<cutlass::half_t>(tmp_s[sk] / sum_s); } // Causal(S) if (is_causal) { for (size_t sk = col_limit; sk < seq_k; ++sk) { p_ptr[cu_seq_q[b] * (head_q * seq_k) + sq * (head_q * seq_k) + h * seq_k + sk] = 0_hf; } } } } } // O = P * V for (size_t b = 0; b < batch; ++b) { size_t seq_q = cu_seq_q[b + 1] - cu_seq_q[b]; size_t seq_k = cu_seq_k[b + 1] - cu_seq_k[b]; for (size_t h = 0; h < head_q; ++h) { for (size_t sq = 0; sq < seq_q; ++sq) { for (size_t d = 0; d < dim; ++d) { float tmp = 0.0; for (size_t sk = 0; sk < seq_k; ++sk) { tmp += static_cast<cutlass::half_t>( p_ptr[cu_seq_q[b] * (head_q * seq_k) + sq * (head_q * seq_k) + h * seq_k + sk]) * static_cast<cutlass::half_t>( v_ptr[cu_seq_k[b] * (head_k * dim) + sk * (head_k * dim) + (h / head_ratio) * dim + d]); } o_ptr[cu_seq_q[b] * (head_q * dim) + sq * (head_q * dim) + h * dim + d] = static_cast<cutlass::half_t>(tmp); } } } } if (S) { delete S; S = nullptr; } if (P) { delete P; P = nullptr; } }
2.2 Вспышка внимания
В этой статье основное внимание уделяется только процессу расчета MHA Flash Attention. Для получения дополнительной информации, пожалуйста, просмотрите документ и исходный код.
Расчет Flash Attention для MHA заключается в разделении блоков по пакету, заголовку и разделению-seq_q. При вычислении Q * K^T вычисление внутренней деформации делится в соответствии с размерностью seq_k матрицы K^T, то есть каждая деформация может получить только матрицу S. Результат частичной блокировки определенной строки. Поэтому при вычислении softmax блока сначала необходимо синхронизировать варп. С другой стороны, при окончательном вычислении P*V используется метод разделения-K. Промежуточные результаты каждого расчета деформации должны быть уменьшены и суммированы, прежде чем можно будет получить результат блока O. Перед уменьшением деформации все равно необходимо синхронизировать.
2.3 Вспышка внимания v2
Flash Attention v2 также разделяет блоки по пакетам, заголовку и Split-seq_q для расчета MHA. Однако при вычислении Q * K^T вычисление внутренней деформации делится в соответствии с размерностью seq_q матрицы Q, то есть каждая деформация может получить определенное значение матрицы S. Все результаты блока в одной строке. Поэтому при расчете softmax блока нет необходимости синхронизировать варп. С другой стороны, при окончательном вычислении P*V каждый варп также может напрямую вычислить результат варпа O, без необходимости сокращения или дополнительной синхронизации варпов.
3. Производительность вывода
3.1 Условия испытаний
Код с открытым исходным кодом находится в flash_attention_inference, а ядро взято из flash-attention. Коды обратной зависимости, исключения, bf16 и факела, которые не имеют отношения к выводу, удалены и могут быть легко интегрированы в сценарии вывода LLM. Основываясь на мгновенном внимании, этот код также полностью поддерживает сценарии вывода GQA (Group Query Attention)/MQA (Multi Query Attention), сценарии гибридного вывода предварительного заполнения/декодирования и сценарии вывода ALiBi (Внимание с линейными смещениями).
- MHA: O = Softmax(Q * K^T) * V
- КУДА: 11,8
- Графический процессор: RTX3090
- Флэш-внимание: v1.0.9
- Вспышка внимания v2: v2.1.0
- Абордаж: v3.1.0
- Номер головы: 32
- Размер головы: 128
3.2 Производительность вывода предварительного заполнения
(1) Сек Лен
Когда дело доходит до коротких сцен, их производительность эквивалентна; когда дело доходит до длинных последовательностей, Flash Attention v2 работает лучше и может быть улучшен примерно на 50%. Причина, по которой Flash Attention v2 хорошо работает в длинных последовательностях, в основном связана с уменьшением многократной синхронизации деформаций между данными блоков.
- Размер партии: 128
- Seq Q: Seq Len
- Seq K: Seq Len
(2) Размер партии
Когда размер пакета меньше, Flash Attention v2 работает лучше; когда размер партии больше, производительность обоих эквивалентна.
- Размер партии: Размер партии
- Последующий вопрос: 128
- Последовательность К: 128
3.3. Производительность декодирования вывода
(1) Сек Лен
Когда последовательность короткая, производительность обоих эквивалентна; когда последовательность длинная, эффективность Flash Attention выше. Причина, по которой Flash Attention хорошо работает в длинных последовательностях, в основном связана с деформационным разделением труда в измерении seq_k, которое улучшает параллелизм вычислений.
- Размер партии: 128
- Последовательность вопросов: 1
- Seq K: Seq Len
(2) Размер партии
Независимо от размера пакета, Flash Attention работает лучше.
- Размер партии: Размер партии
- Последовательность вопросов: 1
- Последовательность К: 128
3.4. Производительность гибридного вывода
Независимо от того, как меняется соотношение предварительного заполнения и декодирования, производительность Flash Attention и Flash Attention v2 относительно близка.
- Размер партии: 100
- Seq Q: 128 (предварительное заполнение) + 1 (декодирование)
- Последовательность К: 128
4 Другое
4.1 Сценарии вывода GQA/MQA
Поддерживаются все сценарии вывода GQA/MQA, а код обновляется в flash_attention_inference.
4.2 Сценарии гибридного вывода
Поддерживаются все сценарии гибридного вывода предварительного заполнения и декодирования. Код обновлен в flash_attention_inference, а производительность показана в версии 3.4.
4.3 Сценарии вывода ALiBi
Поддерживаются все сценарии вывода ALiBi, а код обновлен в flash_attention_inference.