Stf CS149 flash attention
Stf CS149 flash attention lab assignment repo
Install the libraries needed to compile the code.
It fails with an error about loading a shared object:
(cs149gpt) ➜ cs149gpt git:(main) ✗ python3 gpt149.py 4Daccess
/home/zt/.local/lib/python3.10/site-packages/torch/cuda/__init__.py:118: UserWarning: CUDA initialization: The NVIDIA driver on your system is too old (found version 11040). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver. (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:108.) return torch._C._cuda_getDeviceCount() > 0
No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'
Traceback (most recent call last):
File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 14, in <module>
import module_ref as ms
ImportError: /home/zt/stf-cs149-pp/cs149gpt/module_ref.so: undefined symbol: _ZN2at4_ops5zeros4callEN3c108ArrayRefINS2_6SymIntEEENS2_8optionalINS2_10ScalarTypeEEENS6_INS2_6LayoutEEENS6_INS2_6DeviceEEENS6_IbEE
I tried to use conda to create a new env and install a lower version of pytorch, but conda always installed the 2.3.x version of torch for me.
conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 cpuonly python=3.10 numpy=1.26 ninja tiktoken -c pytorch -c conda-forge
What is required is a 2.1.x version of torch.
So I manually used pip to uninstall torch and reinstall the 2.1.x version:
pip3 uninstall torch
pip3 install torch==2.1.2
Got another issue saying that I should use a numpy version < 2.0.
So I uninstalled numpy and reinstalled version 1.26:
pip3 uninstall numpy
pip3 install numpy==1.26
So now I can run the code successfully.
Part 1: naive attention
My code produces values that are about 0.0003 less than what the solution produces for each element.
I don’t know why.
Should I use double?
First implementation:
torch::Tensor myNaiveAttention(torch::Tensor QTensor, torch::Tensor KTensor, torch::Tensor VTensor, torch::Tensor QK_tTensor,
int B, int H, int N, int d){
// Q, K, V are passed in with Shape: (B, H, N, d)
//QK^t Intermediate Tensor has Shape (N, N)
//Make O Tensor with Shape (B, H, N, d)
at::Tensor OTensor = at::zeros({B, H, N, d}, at::kFloat);
//Format O, Q, K, and V tensors into 4D vectors
std::vector<float> O = formatTensor(OTensor);
std::vector<float> Q = formatTensor(QTensor);
std::vector<float> K = formatTensor(KTensor);
std::vector<float> V = formatTensor(VTensor);
//Format QK_t Tensor into a 2D vector.
std::vector<float> QK_t = formatTensor(QK_tTensor);
/* Here is an example of how to read/write 0's to Q (B, H, N, d) using the 4D accessors
//loop over Batch Size
for (int b = 0; b < B; b++) {
//loop over Heads
for (int h = 0; h < H; h++) {
//loop over Sequence Length
for (int i = 0; i < N; i++) {
//loop over Embedding Dimensionality
for (int j = 0; j < d; j++) {
float val = fourDimRead(Q, b, h, i, j, H, N, d);
val = 0.0;
fourDimWrite(Q, b, h, i, j, H, N, d, val);
}
}
}
}
*/
/* Here is an example of how to read/write 0's to QK_t (N, N) using the 2D accessors
for (int i = 0; i < N; i++) {
for (int j = 0; j < N; j++) {
float val = twoDimRead(QK_t, i, j, N);
val = 0.0;
twoDimWrite(QK_t, i, j, N, val);
}
}
*/
// -------- YOUR CODE HERE -------- //
for (int b = 0; b < B; b++) {
//loop over Heads
for (int h = 0; h < H; h++) {
//loop over Sequence Length
for (int i = 0; i < N; i++) {
for(int seq_i=0; seq_i < N; seq_i++) {
//loop over Embedding Dimensionality
float val = 0.0;
for (int j = 0; j < d; j++) {
int q_row = i;
int q_col = j;
int k_row = j;
int k_col = seq_i;
// float val = fourDimRead(Q, b, h, i, j, H, N, d);
float q_val = fourDimRead(Q, b, h, q_row, q_col, H, N, d);
float k_val = fourDimRead(K, b, h, k_row, k_col, H, N, d);
val += q_val * k_val;
// val = 0.0;
// fourDimWrite(Q, b, h, i, j, H, N, d, val);
}
fourDimWrite(QK_t, b, h, i, seq_i, H, N, d, val );
}
}
std::vector<float> tmp_row_res(N);
for(int row_idx=0; row_idx < N; row_idx++) {
float row_sum = 0.0;
for(int cold_idx=0; cold_idx < N ;cold_idx++) {
float val = twoDimRead(QK_t, row_idx, cold_idx, N);
float exp_val = std::exp(val);
row_sum += exp_val;
tmp_row_res[cold_idx] = exp_val;
}
for(int cold_idx=0; cold_idx < N ; cold_idx++) {
float prob = tmp_row_res[cold_idx] / row_sum;
twoDimWrite(QK_t, row_idx, cold_idx, N, prob);
}
}
for(int qkt_row_idx=0; qkt_row_idx < N; qkt_row_idx++) {
for(int output_d_idx=0; output_d_idx < d; output_d_idx++) {
float val =0.0;
for(int m_idx=0; m_idx < N ; m_idx++) {
float qkt_val = twoDimRead(QK_t, qkt_row_idx, m_idx, N);
int v_row = m_idx;
int v_col = output_d_idx;
float v_val = fourDimRead(V, b, h, v_row, v_col, H, N, d);
val += qkt_val * v_val;
}
fourDimWrite(O, b, h, qkt_row_idx, output_d_idx, H, N, d ,val);
}
}
}
}
// DO NOT EDIT THIS RETURN STATEMENT //
// It formats your C++ Vector O back into a Tensor of Shape (B, H, N, d) and returns it //
return torch::from_blob(O.data(), {B, H, N, d}, torch::TensorOptions().dtype(torch::kFloat32)).clone();
}
-----RUNNING STUDENT IMPLEMENTATION-----
first row value:
tensor([0.0771, 0.0779, 0.0787, 0.0795, 0.0803, 0.0811, 0.0819, 0.0827, 0.0835,
0.0843, 0.0851, 0.0859, 0.0867, 0.0875, 0.0883, 0.0891, 0.0899, 0.0907,
0.0915, 0.0923, 0.0931, 0.0939, 0.0947, 0.0955, 0.0963, 0.0971, 0.0979,
0.0987, 0.0995, 0.1003, 0.1011, 0.1019])
STAGE:2024-11-17 21:15:18 207308:207308 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
student first row:
tensor([0.0768, 0.0776, 0.0784, 0.0792, 0.0800, 0.0808, 0.0816, 0.0824, 0.0832,
0.0840, 0.0848, 0.0856, 0.0864, 0.0872, 0.0880, 0.0888, 0.0896, 0.0904,
0.0912, 0.0920, 0.0928, 0.0936, 0.0944, 0.0952, 0.0960, 0.0968, 0.0976,
0.0984, 0.0992, 0.1000, 0.1008, 0.1016])
STAGE:2024-11-17 21:15:19 207308:207308 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-11-17 21:15:19 207308:207308 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
Traceback (most recent call last):
File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 329, in <module>
main()
File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 311, in main
part1Test(N, d, B, H)
File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 221, in part1Test
testTemplate(attentionModuleReference.myUnfusedAttention, params, "STUDENT - NAIVE ATTENTION")
File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 182, in testTemplate
assert torch.allclose(QKV,QKS1, atol=1e-4), correctness_error_message
AssertionError:
-------------------------------------------
YOUR ATTENTION PRODUCED INCORRECT RESULTS
Try using double to store the intermediate results to fix this problem.
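For reference, here is a sketch of the softmax step with double intermediates (the blocked version in Part 2 ends up doing essentially this); it assumes the twoDimRead/twoDimWrite accessors and the QK_t vector from the starter code:
// Sketch: accumulate the row exponentials and the row sum in double, then
// write the normalized probabilities back to QK_t as float.
for (int row_idx = 0; row_idx < N; row_idx++) {
    std::vector<double> exp_row(N, 0.0);
    double row_sum = 0.0;
    for (int col_idx = 0; col_idx < N; col_idx++) {
        double exp_val = std::exp((double)twoDimRead(QK_t, row_idx, col_idx, N));
        exp_row[col_idx] = exp_val;
        row_sum += exp_val;
    }
    for (int col_idx = 0; col_idx < N; col_idx++) {
        float prob = (float)(exp_row[col_idx] / row_sum);
        twoDimWrite(QK_t, row_idx, col_idx, N, prob);
    }
}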
It turned out I had two issues in the previous version of the code:
for (int b = 0; b < B; b++) {
//loop over Heads
for (int h = 0; h < H; h++) {
//loop over Sequence Length
for (int i = 0; i < N; i++) {
for(int seq_i=0; seq_i < N; seq_i++) {
//loop over Embedding Dimensionality
float val = 0.0;
for (int j = 0; j < d; j++) {
int q_row = i;
int q_col = j;
// this is the correct indexing for the second matrix.
// Since K is not transposed.
// K should be indexed with (seq_i, j) instead of (j, seq_i) as in normal matrix multiplication
int k_row = seq_i;
int k_col = j;
float q_val = fourDimRead(Q, b, h, q_row, q_col, H, N, d);
float k_val = fourDimRead(K, b, h, k_row, k_col, H, N, d);
val += q_val * k_val;
}
// This is the second place that is fixed.
// QK_t is two-dimensional.
// Should use twoDimWrite
twoDimWrite(QK_t, i, seq_i, N, val );
}
}
Output:
REFERENCE - NAIVE ATTENTION statistics
cpu time: 293.585ms
mem usage: 4718592 bytes
-----RUNNING STUDENT IMPLEMENTATION-----
first row value:
tensor([0.0771, 0.0779, 0.0787, 0.0795, 0.0803, 0.0811, 0.0819, 0.0827, 0.0835,
0.0843, 0.0851, 0.0859, 0.0867, 0.0875, 0.0883, 0.0891, 0.0899, 0.0907,
0.0915, 0.0923, 0.0931, 0.0939, 0.0947, 0.0955, 0.0963, 0.0971, 0.0979,
0.0987, 0.0995, 0.1003, 0.1011, 0.1019])
STAGE:2024-11-18 13:10:17 1148267:1148267 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
student first row:
tensor([0.0771, 0.0779, 0.0787, 0.0795, 0.0803, 0.0811, 0.0819, 0.0827, 0.0835,
0.0843, 0.0851, 0.0859, 0.0867, 0.0875, 0.0883, 0.0891, 0.0899, 0.0907,
0.0915, 0.0923, 0.0931, 0.0939, 0.0947, 0.0955, 0.0963, 0.0971, 0.0979,
0.0987, 0.0995, 0.1003, 0.1011, 0.1019])
STAGE:2024-11-18 13:10:17 1148267:1148267 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-11-18 13:10:17 1148267:1148267 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
manual attention == pytorch attention True
Manual Execution Time: 0.2747969627380371
----------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg CPU Mem Self CPU Mem # of Calls
----------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::empty 0.01% 16.000us 0.01% 16.000us 2.000us 5.00 Mb 5.00 Mb 8
STUDENT - NAIVE ATTENTION 99.49% 273.438ms 99.68% 273.946ms 273.946ms 4.50 Mb -1.00 Mb 1
aten::zeros 0.01% 14.000us 0.05% 133.000us 66.500us 4.50 Mb 0 b 2
aten::clone 0.01% 20.000us 0.13% 346.000us 173.000us 1.00 Mb 0 b 2
aten::empty_strided 0.00% 12.000us 0.00% 12.000us 2.400us 512.51 Kb 512.51 Kb 5
model_inference 0.20% 549.000us 100.00% 274.838ms 274.838ms 512.00 Kb -4.00 Mb 1
aten::flatten 0.01% 19.000us 0.08% 209.000us 41.800us 512.00 Kb 0 b 5
aten::empty_like 0.00% 3.000us 0.00% 5.000us 5.000us 512.00 Kb 0 b 1
aten::to 0.00% 6.000us 0.01% 31.000us 5.167us 520 b 0 b 6
aten::_to_copy 0.01% 14.000us 0.01% 25.000us 6.250us 520 b 0 b 4
----------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 274.838ms
STUDENT - NAIVE ATTENTION statistics
cpu time: 273.946ms
mem usage: 4718592 bytes
Part 2: blocked matrix multiplication
Initial version
Unlike code written in CUDA, I have to manually iterate over all the rows and columns of the input matrices.
I don’t know if this is the correct way to do it.
#define TILE_SIZE 16
torch::Tensor myUnfusedAttentionBlocked(torch::Tensor QTensor, torch::Tensor KTensor, torch::Tensor VTensor, torch::Tensor QK_tTensor,
int B, int H, int N, int d){
// Q, K, V are passed in with Shape: (B, H, N, d)
//QK^t Intermediate Tensor has Shape (N, N)
//Make O Tensor with Shape (B, H, N, d)
at::Tensor OTensor = at::zeros({B, H, N, d}, at::kFloat);
//Format O, Q, K, and V tensors into 4D vectors
std::vector<float> O = formatTensor(OTensor);
std::vector<float> Q = formatTensor(QTensor);
std::vector<float> K = formatTensor(KTensor);
std::vector<float> V = formatTensor(VTensor);
//Format QK_t Tensor into a 2D vector.
std::vector<float> QK_t = formatTensor(QK_tTensor);
// -------- YOUR CODE HERE -------- //
for(int b=0; b < B; b++) {
for(int h=0; h < H; h++) {
for(int q_row_tile_idx=0; q_row_tile_idx < (N+TILE_SIZE-1)/TILE_SIZE; q_row_tile_idx++) {
// K is not transposed so we traverse k by row.
for(int k_row_tile_idx=0; k_row_tile_idx < (N+TILE_SIZE-1)/TILE_SIZE; k_row_tile_idx++ ) {
for(int d_col_tile_idx=0; d_col_tile_idx < (d+TILE_SIZE-1)/TILE_SIZE; d_col_tile_idx++ ) {
for(int tile_row_idx=0; tile_row_idx < TILE_SIZE; tile_row_idx++) {
// int out_row_idx = q_row_tile_idx * TILE_SIZE + tile_row_idx;
for(int tile_col_idx=0; tile_col_idx < TILE_SIZE; tile_col_idx++) {
// int out_col_idx = k_row_tile_idx * TILE_SIZE + tile_col_idx;
int q_col_idx = d_col_tile_idx * TILE_SIZE + tile_col_idx;
int q_row_idx =q_row_tile_idx * TILE_SIZE + tile_row_idx;
int k_row_idx = k_row_tile_idx * TILE_SIZE + tile_row_idx;
int k_col_idx = d_col_tile_idx * TILE_SIZE + tile_col_idx;
if(q_row_idx < N && q_col_idx < d && k_row_idx < N && k_col_idx < d) {
float q_tile_val = fourDimRead(Q, b, h, q_row_idx, q_col_idx, H, N, d);
float k_tile_val = fourDimRead(K, b, h, k_row_idx, k_col_idx, H, N, d);
float orig_val = twoDimRead(QK_t, q_row_tile_idx, k_row_idx, N);
float val = q_tile_val * k_tile_val + orig_val;
twoDimWrite(QK_t, q_row_tile_idx, k_row_tile_idx, N, val );
}
}
}
}
}
}
for(int row_idx=0; row_idx < N; row_idx++) {
std::vector<double> tmp_row_res(N, 0.0);
double row_sum = 0.0;
for(int cold_idx=0; cold_idx < N ;cold_idx++) {
float val = twoDimRead(QK_t, row_idx, cold_idx, N);
double exp_val = std::exp(val);
row_sum += exp_val;
tmp_row_res[cold_idx] = exp_val;
}
for(int cold_idx=0; cold_idx < N ; cold_idx++) {
float prob = tmp_row_res[cold_idx] / row_sum;
twoDimWrite(QK_t, row_idx, cold_idx, N, prob);
}
}
for(int qkt_row_idx=0; qkt_row_idx < N; qkt_row_idx++) {
for(int output_d_idx=0; output_d_idx < d; output_d_idx++) {
float val =0.0;
for(int m_idx=0; m_idx < N ; m_idx++) {
float qkt_val = twoDimRead(QK_t, qkt_row_idx, m_idx, N);
int v_row = m_idx;
int v_col = output_d_idx;
float v_val = fourDimRead(V, b, h, v_row, v_col, H, N, d);
val += qkt_val * v_val;
}
fourDimWrite(O, b, h, qkt_row_idx, output_d_idx, H, N, d ,val);
}
}
}
}
// DO NOT EDIT THIS RETURN STATEMENT //
// It formats your C++ Vector O back into a Tensor of Shape (B, H, N, d) and returns it //
return torch::from_blob(O.data(), {B, H, N, d}, torch::TensorOptions().dtype(torch::kFloat32)).clone();
}
It’s not correct
-----RUNNING STUDENT IMPLEMENTATION-----
first row value:
tensor([0.0771, 0.0779, 0.0787, 0.0795, 0.0803, 0.0811, 0.0819, 0.0827, 0.0835,
0.0843, 0.0851, 0.0859, 0.0867, 0.0875, 0.0883, 0.0891, 0.0899, 0.0907,
0.0915, 0.0923, 0.0931, 0.0939, 0.0947, 0.0955, 0.0963, 0.0971, 0.0979,
0.0987, 0.0995, 0.1003, 0.1011, 0.1019])
STAGE:2024-11-18 11:29:24 607233:607233 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
student first row:
tensor([0.0767, 0.0775, 0.0783, 0.0791, 0.0799, 0.0807, 0.0815, 0.0823, 0.0831,
0.0839, 0.0847, 0.0855, 0.0863, 0.0871, 0.0879, 0.0887, 0.0895, 0.0903,
0.0911, 0.0919, 0.0927, 0.0935, 0.0943, 0.0951, 0.0959, 0.0967, 0.0975,
0.0983, 0.0991, 0.0999, 0.1007, 0.1015])
STAGE:2024-11-18 11:29:24 607233:607233 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-11-18 11:29:24 607233:607233 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
Traceback (most recent call last):
File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 329, in <module>
main()
File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 313, in main
part2Test(N, d, B, H)
File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 233, in part2Test
testTemplate(attentionModuleReference.myUnfusedAttentionBlocked, params, "STUDENT - BLOCKED MATMUL + UNFUSED SOFTMAX")
File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 182, in testTemplate
assert torch.allclose(QKV,QKS1, atol=1e-4), correctness_error_message
AssertionError:
-------------------------------------------
YOUR ATTENTION PRODUCED INCORRECT RESULTS
Fixed the two indexing issues in the original version.
This is the correct implementation given by GPT. However, it’s too slow:
for (int q_row_tile_idx = 0; q_row_tile_idx < (N + TILE_SIZE - 1) / TILE_SIZE; q_row_tile_idx++) {
for (int k_row_tile_idx = 0; k_row_tile_idx < (N + TILE_SIZE - 1) / TILE_SIZE; k_row_tile_idx++) {
for (int d_col_tile_idx = 0; d_col_tile_idx < (d + TILE_SIZE - 1) / TILE_SIZE; d_col_tile_idx++) {
for (int tile_row_idx = 0; tile_row_idx < TILE_SIZE; tile_row_idx++) {
for (int tile_col_idx = 0; tile_col_idx < TILE_SIZE; tile_col_idx++) {
int q_row_idx = q_row_tile_idx * TILE_SIZE + tile_row_idx;
int k_row_idx = k_row_tile_idx * TILE_SIZE + tile_col_idx; // Fix indexing
for (int d_idx = d_col_tile_idx * TILE_SIZE; d_idx < (d_col_tile_idx + 1) * TILE_SIZE; d_idx++) {
if (q_row_idx < N && k_row_idx < N && d_idx < d) {
float q_tile_val = fourDimRead(Q, b, h, q_row_idx, d_idx, H, N, d);
float k_tile_val = fourDimRead(K, b, h, k_row_idx, d_idx, H, N, d);
float orig_val = twoDimRead(QK_t, q_row_idx, k_row_idx, N);
float val = q_tile_val * k_tile_val + orig_val;
twoDimWrite(QK_t, q_row_idx, k_row_idx, N, val);
}
}
}
}
}
}
}
Explanation from GPT:
You’re right to question the necessity of traversing the dimension ( d ) fully. Let’s clarify why it’s important in some cases and whether it applies here.
Understanding the Role of Dimension ( d )
In block matrix multiplication:
- Matrix Multiplication Basics:
- To compute an element ((i, j)) of the resulting matrix ( QK_t ), you need to calculate the dot product of the (i)-th row of (Q) and the (j)-th row of (K).
- This involves summing the product of corresponding elements from both rows over the shared dimension ( d ):
[ QK_t[i, j] = \sum_{k=0}^{d-1} Q[i, k] \times K[j, k] ]
- Tiling:
- In tiled matrix multiplication, the computation of a block of ( QK_t ) (e.g., a ( \text{TILE_SIZE} \times \text{TILE_SIZE} ) block) is split into smaller sub-blocks to improve memory locality.
- Each sub-block computation iterates over the shared dimension ( d ) in chunks.
Why Traverse ( d )?
If you don’t traverse ( d ), the computation will only consider partial contributions to the dot product. This would leave the values in ( QK_t ) incomplete.
For example, if ( d = 6 ), and your block size is ( \text{TILE_SIZE} = 3 ):
- You compute the dot product for ( d \in [0, 2] ), but ignore ( d \in [3, 5] ), leading to incorrect results.
Why the Code Needs It
Your original code suggests you’re updating ( QK_t ) incrementally:
float val = q_tile_val * k_tile_val + orig_val;
Here, orig_val
accumulates partial results. Without iterating fully over ( d ), the result will be incomplete because only a subset of the dot product contributions is added.
When ( d ) Traversal Isn’t Necessary
In rare cases, you might not need to traverse ( d ) if:
- Each block fully covers ( d ):
- This is possible only when ( \text{TILE_SIZE} ) equals ( d ).
- In this case, the tiling logic aligns perfectly with the shared dimension, and a single traversal suffices.
- Precomputed Values:
- If ( Q ) or ( K ) has been preprocessed (e.g., multiplied previously), the shared dimension ( d ) might not need iteration.
Conclusion
For the general case (and your code):
- Traversing ( d ) fully ensures the correctness of the dot product computation across the shared dimension.
- If you’re sure ( d ) traversal isn’t necessary, it usually means ( Q ) and ( K ) are already structured to encapsulate the full dot product for each tile.
If you’re still unsure about why it’s needed in your case, we can examine your problem more closely to confirm!
I can view the outer for loops from two perspectives.
One is from the input matrices.
One is from the output matrix. This one makes more sense to me after reading the code generated by GPT.
The two outer for loops traverse the row and column tiles of the output matrix. The third for loop traverses tiles along the shared dimension d. The fourth and fifth for loops traverse the row and column elements within the current output tile. The sixth for loop does the dot product over the current tile of the shared dimension, accumulates the sum, and then writes the sum back to the output matrix.
for (int row_tile_idx = 0; row_tile_idx < (N + TILE_SIZE - 1) / TILE_SIZE; row_tile_idx++) {
for (int col_tile_idx = 0; col_tile_idx < (N + TILE_SIZE - 1) / TILE_SIZE; col_tile_idx++) {
// Initialize the current block of the output matrix
for (int i = 0; i < TILE_SIZE; i++) {
for (int j = 0; j < TILE_SIZE; j++) {
int row_idx = row_tile_idx * TILE_SIZE + i;
int col_idx = col_tile_idx * TILE_SIZE + j;
if (row_idx < N && col_idx < N) {
twoDimWrite(C, row_idx, col_idx, N, 0.0); // Initialize to zero
}
}
}
for (int k_tile_idx = 0; k_tile_idx < (d + TILE_SIZE - 1) / TILE_SIZE; k_tile_idx++) {
for (int i = 0; i < TILE_SIZE; i++) {
for (int j = 0; j < TILE_SIZE; j++) {
int row_idx = row_tile_idx * TILE_SIZE + i;
int col_idx = col_tile_idx * TILE_SIZE + j;
if (row_idx >= N || col_idx >= N) continue;
float sum = twoDimRead(C, row_idx, col_idx, N);
for (int k = 0; k < TILE_SIZE; k++) {
int k_idx = k_tile_idx * TILE_SIZE + k;
if (k_idx >= d) break;
float a_val = twoDimRead(A, row_idx, k_idx, d);
float b_val = twoDimRead(B, col_idx, k_idx, d); // Column index now indexes rows in B
sum += a_val * b_val;
}
twoDimWrite(C, row_idx, col_idx, N, sum);
}
}
}
}
}
The code above reads and writes the global output buffer directly. Below is the same structure applied to Q, K, and QK_t:
for(int row_tile_idx=0; row_tile_idx < (N+TILE_SIZE-1)/TILE_SIZE; row_tile_idx++) {
for(int col_tile_idx=0; col_tile_idx < (N+TILE_SIZE-1)/TILE_SIZE; col_tile_idx++) {
for(int k_tile_idx=0; k_tile_idx < (d+TILE_SIZE-1)/TILE_SIZE; k_tile_idx++ ) {
for(int tile_row_idx=0; tile_row_idx < TILE_SIZE; tile_row_idx++) {
for(int tile_col_idx=0; tile_col_idx < TILE_SIZE; tile_col_idx++) {
int row_idx = row_tile_idx * TILE_SIZE + tile_row_idx;
int col_idx = col_tile_idx * TILE_SIZE + tile_col_idx;
if(row_idx >= N || col_idx >= N) {
continue;
}
float sum = twoDimRead(QK_t, row_idx, col_idx, N);
for(int k=0; k < TILE_SIZE; k++) {
int k_idx = k_tile_idx * TILE_SIZE + k;
if(k_idx >= d) break;
float q_val = fourDimRead(Q,b, h, row_idx, k_idx, H, N, d);
float k_val = fourDimRead(K, b, h, col_idx, k_idx, H, N, d);
sum += q_val * k_val;
}
twoDimWrite(QK_t, row_idx, col_idx, N, sum);
}
}
}
}
}
Output:
REFERENCE - BLOCKED MATMUL + UNFUSED SOFTMAX statistics
cpu time: 238.938ms
mem usage: 4718592 bytes
-----RUNNING STUDENT IMPLEMENTATION-----
first row value:
tensor([0.0771, 0.0779, 0.0787, 0.0795, 0.0803, 0.0811, 0.0819, 0.0827, 0.0835,
0.0843, 0.0851, 0.0859, 0.0867, 0.0875, 0.0883, 0.0891, 0.0899, 0.0907,
0.0915, 0.0923, 0.0931, 0.0939, 0.0947, 0.0955, 0.0963, 0.0971, 0.0979,
0.0987, 0.0995, 0.1003, 0.1011, 0.1019])
STAGE:2024-11-18 13:39:56 1318038:1318038 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
student first row:
tensor([0.0771, 0.0779, 0.0787, 0.0795, 0.0803, 0.0811, 0.0819, 0.0827, 0.0835,
0.0843, 0.0851, 0.0859, 0.0867, 0.0875, 0.0883, 0.0891, 0.0899, 0.0907,
0.0915, 0.0923, 0.0931, 0.0939, 0.0947, 0.0955, 0.0963, 0.0971, 0.0979,
0.0987, 0.0995, 0.1003, 0.1011, 0.1019])
STAGE:2024-11-18 13:39:56 1318038:1318038 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-11-18 13:39:56 1318038:1318038 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
manual attention == pytorch attention True
Manual Execution Time: 0.29826903343200684
---------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg CPU Mem Self CPU Mem # of Calls
---------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::empty 0.01% 21.000us 0.01% 21.000us 2.625us 5.00 Mb 5.00 Mb 8
STUDENT - BLOCKED MATMUL + UNFUSED SOFTMAX 99.35% 296.386ms 99.70% 297.443ms 297.443ms 4.50 Mb -1.00 Mb 1
aten::zeros 0.01% 15.000us 0.24% 706.000us 353.000us 4.50 Mb 0 b 2
aten::clone 0.01% 19.000us 0.11% 328.000us 164.000us 1.00 Mb 0 b 2
aten::empty_strided 0.00% 11.000us 0.00% 11.000us 2.200us 512.51 Kb 512.51 Kb 5
model_inference 0.18% 550.000us 100.00% 298.329ms 298.329ms 512.00 Kb -4.00 Mb 1
aten::flatten 0.01% 15.000us 0.07% 200.000us 40.000us 512.00 Kb 0 b 5
aten::empty_like 0.00% 3.000us 0.00% 5.000us 5.000us 512.00 Kb 0 b 1
aten::to 0.00% 6.000us 0.01% 31.000us 5.167us 520 b 0 b 6
aten::_to_copy 0.00% 14.000us 0.01% 25.000us 6.250us 520 b 0 b 4
---------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 298.329ms
STUDENT - BLOCKED MATMUL + UNFUSED SOFTMAX statistics
cpu time: 297.443ms
mem usage: 4718592 bytes
This code uses locally allocated tile buffers. It’s a little bit faster, but not by a lot.
#pragma omp parallel for collapse(2) // Parallelize the two outermost loops
for (int row_tile_idx = 0; row_tile_idx < (N + TILE_SIZE - 1) / TILE_SIZE; row_tile_idx++) {
for (int col_tile_idx = 0; col_tile_idx < (N + TILE_SIZE - 1) / TILE_SIZE; col_tile_idx++) {
for (int k_tile_idx = 0; k_tile_idx < (d + TILE_SIZE - 1) / TILE_SIZE; k_tile_idx++) {
// Buffers for tile data
float Q_tile[TILE_SIZE][TILE_SIZE];
float K_tile[TILE_SIZE][TILE_SIZE];
// Preload Q and K tiles into local buffers
for (int tile_row_idx = 0; tile_row_idx < TILE_SIZE; tile_row_idx++) {
int row_idx = row_tile_idx * TILE_SIZE + tile_row_idx;
if (row_idx >= N) continue; // Skip out-of-bound rows
for (int k = 0; k < TILE_SIZE; k++) {
int k_idx = k_tile_idx * TILE_SIZE + k;
if (k_idx < d) {
Q_tile[tile_row_idx][k] = fourDimRead(Q, b, h, row_idx, k_idx, H, N, d);
} else {
Q_tile[tile_row_idx][k] = 0.0f; // Fill with zero if out-of-bounds
}
}
}
for (int tile_col_idx = 0; tile_col_idx < TILE_SIZE; tile_col_idx++) {
int col_idx = col_tile_idx * TILE_SIZE + tile_col_idx;
if (col_idx >= N) continue; // Skip out-of-bound columns
for (int k = 0; k < TILE_SIZE; k++) {
int k_idx = k_tile_idx * TILE_SIZE + k;
if (k_idx < d) {
K_tile[tile_col_idx][k] = fourDimRead(K, b, h, col_idx, k_idx, H, N, d);
} else {
K_tile[tile_col_idx][k] = 0.0f; // Fill with zero if out-of-bounds
}
}
}
// Compute the dot product for the current tile
for (int tile_row_idx = 0; tile_row_idx < TILE_SIZE; tile_row_idx++) {
int row_idx = row_tile_idx * TILE_SIZE + tile_row_idx;
if (row_idx >= N) continue; // Skip out-of-bound rows
for (int tile_col_idx = 0; tile_col_idx < TILE_SIZE; tile_col_idx++) {
int col_idx = col_tile_idx * TILE_SIZE + tile_col_idx;
if (col_idx >= N) continue; // Skip out-of-bound columns
float sum = twoDimRead(QK_t, row_idx, col_idx, N);
// Unrolled loop for vectorized dot product
for (int k = 0; k < TILE_SIZE; k++) {
sum += Q_tile[tile_row_idx][k] * K_tile[tile_col_idx][k];
}
twoDimWrite(QK_t, row_idx, col_idx, N, sum);
}
}
}
}
}
Output:
REFERENCE - BLOCKED MATMUL + UNFUSED SOFTMAX statistics
cpu time: 226.667ms
mem usage: 4718592 bytes
-----RUNNING STUDENT IMPLEMENTATION-----
first row value:
tensor([0.0771, 0.0779, 0.0787, 0.0795, 0.0803, 0.0811, 0.0819, 0.0827, 0.0835,
0.0843, 0.0851, 0.0859, 0.0867, 0.0875, 0.0883, 0.0891, 0.0899, 0.0907,
0.0915, 0.0923, 0.0931, 0.0939, 0.0947, 0.0955, 0.0963, 0.0971, 0.0979,
0.0987, 0.0995, 0.1003, 0.1011, 0.1019])
STAGE:2024-11-18 13:44:02 1342423:1342423 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
student first row:
tensor([0.0771, 0.0779, 0.0787, 0.0795, 0.0803, 0.0811, 0.0819, 0.0827, 0.0835,
0.0843, 0.0851, 0.0859, 0.0867, 0.0875, 0.0883, 0.0891, 0.0899, 0.0907,
0.0915, 0.0923, 0.0931, 0.0939, 0.0947, 0.0955, 0.0963, 0.0971, 0.0979,
0.0987, 0.0995, 0.1003, 0.1011, 0.1019])
STAGE:2024-11-18 13:44:02 1342423:1342423 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-11-18 13:44:02 1342423:1342423 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
manual attention == pytorch attention True
Manual Execution Time: 0.2852001190185547
---------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg CPU Mem Self CPU Mem # of Calls
---------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::empty 0.00% 13.000us 0.00% 13.000us 1.625us 5.00 Mb 5.00 Mb 8
STUDENT - BLOCKED MATMUL + UNFUSED SOFTMAX 99.52% 283.895ms 99.68% 284.350ms 284.350ms 4.50 Mb -1.00 Mb 1
aten::zeros 0.01% 17.000us 0.05% 131.000us 65.500us 4.50 Mb 0 b 2
aten::clone 0.01% 23.000us 0.11% 300.000us 150.000us 1.00 Mb 0 b 2
aten::empty_strided 0.00% 12.000us 0.00% 12.000us 2.400us 512.26 Kb 512.26 Kb 5
model_inference 0.20% 565.000us 100.00% 285.254ms 285.254ms 512.00 Kb -4.00 Mb 1
aten::flatten 0.01% 16.000us 0.06% 159.000us 31.800us 512.00 Kb 0 b 5
aten::empty_like 0.00% 3.000us 0.00% 5.000us 5.000us 512.00 Kb 0 b 1
aten::to 0.00% 10.000us 0.01% 31.000us 5.167us 520 b 256 b 6
aten::_to_copy 0.01% 16.000us 0.01% 26.000us 6.500us 520 b 256 b 4
---------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 285.254ms
STUDENT - BLOCKED MATMUL + UNFUSED SOFTMAX statistics
cpu time: 284.35ms
mem usage: 4718592 bytes
Code that does blocked matrix multiplication for both Q * K_t and softmax(QK_t) * V:
#define TILE_SIZE 16
torch::Tensor myUnfusedAttentionBlocked(torch::Tensor QTensor, torch::Tensor KTensor, torch::Tensor VTensor, torch::Tensor QK_tTensor,
int B, int H, int N, int d){
// Q, K, V are passed in with Shape: (B, H, N, d)
//QK^t Intermediate Tensor has Shape (N, N)
//Make O Tensor with Shape (B, H, N, d)
at::Tensor OTensor = at::zeros({B, H, N, d}, at::kFloat);
//Format O, Q, K, and V tensors into 4D vectors
std::vector<float> O = formatTensor(OTensor);
std::vector<float> Q = formatTensor(QTensor);
std::vector<float> K = formatTensor(KTensor);
std::vector<float> V = formatTensor(VTensor);
//Format QK_t Tensor into a 2D vector.
std::vector<float> QK_t = formatTensor(QK_tTensor);
// -------- YOUR CODE HERE -------- //
for(int b=0; b < B; b++) {
for(int h=0; h < H; h++) {
//
// correct
// for(int row_tile_idx=0; row_tile_idx < (N+TILE_SIZE-1)/TILE_SIZE; row_tile_idx++) {
// for(int col_tile_idx=0; col_tile_idx < (N+TILE_SIZE-1)/TILE_SIZE; col_tile_idx++) {
// for(int k_tile_idx=0; k_tile_idx < (d+TILE_SIZE-1)/TILE_SIZE; k_tile_idx++ ) {
// for(int tile_row_idx=0; tile_row_idx < TILE_SIZE; tile_row_idx++) {
// for(int tile_col_idx=0; tile_col_idx < TILE_SIZE; tile_col_idx++) {
// int row_idx = row_tile_idx * TILE_SIZE + tile_row_idx;
// int col_idx = col_tile_idx * TILE_SIZE + tile_col_idx;
// if(row_idx >= N || col_idx >= N) {
// continue;
// }
// float sum = twoDimRead(QK_t, row_idx, col_idx, N);
// for(int k=0; k < TILE_SIZE; k++) {
// int k_idx = k_tile_idx * TILE_SIZE + k;
// if(k_idx >= d) break;
// float q_val = fourDimRead(Q,b, h, row_idx, k_idx, H, N, d);
// float k_val = fourDimRead(K, b, h, col_idx, k_idx, H, N, d);
// sum += q_val * k_val;
// }
// twoDimWrite(QK_t, row_idx, col_idx, N, sum);
// }
// }
// }
// }
// }
// correct with local buffer
for (int row_tile_idx = 0; row_tile_idx < (N + TILE_SIZE - 1) / TILE_SIZE; row_tile_idx++) {
for (int col_tile_idx = 0; col_tile_idx < (N + TILE_SIZE - 1) / TILE_SIZE; col_tile_idx++) {
for (int k_tile_idx = 0; k_tile_idx < (d + TILE_SIZE - 1) / TILE_SIZE; k_tile_idx++) {
// Buffers for tile data
float Q_tile[TILE_SIZE][TILE_SIZE];
float K_tile[TILE_SIZE][TILE_SIZE];
// Preload Q and K tiles into local buffers
for (int tile_row_idx = 0; tile_row_idx < TILE_SIZE; tile_row_idx++) {
int row_idx = row_tile_idx * TILE_SIZE + tile_row_idx;
if (row_idx >= N) continue; // Skip out-of-bound rows
for (int k = 0; k < TILE_SIZE; k++) {
int k_idx = k_tile_idx * TILE_SIZE + k;
if (k_idx < d) {
Q_tile[tile_row_idx][k] = fourDimRead(Q, b, h, row_idx, k_idx, H, N, d);
} else {
Q_tile[tile_row_idx][k] = 0.0f; // Fill with zero if out-of-bounds
}
}
}
for (int tile_col_idx = 0; tile_col_idx < TILE_SIZE; tile_col_idx++) {
int col_idx = col_tile_idx * TILE_SIZE + tile_col_idx;
if (col_idx >= N) continue; // Skip out-of-bound columns
for (int k = 0; k < TILE_SIZE; k++) {
int k_idx = k_tile_idx * TILE_SIZE + k;
if (k_idx < d) {
K_tile[tile_col_idx][k] = fourDimRead(K, b, h, col_idx, k_idx, H, N, d);
} else {
K_tile[tile_col_idx][k] = 0.0f; // Fill with zero if out-of-bounds
}
}
}
// Compute the dot product for the current tile
for (int tile_row_idx = 0; tile_row_idx < TILE_SIZE; tile_row_idx++) {
int row_idx = row_tile_idx * TILE_SIZE + tile_row_idx;
if (row_idx >= N) continue; // Skip out-of-bound rows
for (int tile_col_idx = 0; tile_col_idx < TILE_SIZE; tile_col_idx++) {
int col_idx = col_tile_idx * TILE_SIZE + tile_col_idx;
if (col_idx >= N) continue; // Skip out-of-bound columns
float sum = twoDimRead(QK_t, row_idx, col_idx, N);
// Unrolled loop for vectorized dot product
for (int k = 0; k < TILE_SIZE; k++) {
sum += Q_tile[tile_row_idx][k] * K_tile[tile_col_idx][k];
}
twoDimWrite(QK_t, row_idx, col_idx, N, sum);
}
}
}
}
}
// also correct
// for (int q_row_tile_idx = 0; q_row_tile_idx < (N + TILE_SIZE - 1) / TILE_SIZE; q_row_tile_idx++) {
// for (int k_row_tile_idx = 0; k_row_tile_idx < (N + TILE_SIZE - 1) / TILE_SIZE; k_row_tile_idx++) {
// for (int d_col_tile_idx = 0; d_col_tile_idx < (d + TILE_SIZE - 1) / TILE_SIZE; d_col_tile_idx++) {
// for (int tile_row_idx = 0; tile_row_idx < TILE_SIZE; tile_row_idx++) {
// for (int tile_col_idx = 0; tile_col_idx < TILE_SIZE; tile_col_idx++) {
// int q_row_idx = q_row_tile_idx * TILE_SIZE + tile_row_idx;
// int k_row_idx = k_row_tile_idx * TILE_SIZE + tile_col_idx; // Fix indexing
// for (int d_idx = d_col_tile_idx * TILE_SIZE; d_idx < (d_col_tile_idx + 1) * TILE_SIZE; d_idx++) {
// if (q_row_idx < N && k_row_idx < N && d_idx < d) {
// float q_tile_val = fourDimRead(Q, b, h, q_row_idx, d_idx, H, N, d);
// float k_tile_val = fourDimRead(K, b, h, k_row_idx, d_idx, H, N, d);
// float orig_val = twoDimRead(QK_t, q_row_idx, k_row_idx, N);
// float val = q_tile_val * k_tile_val + orig_val;
// twoDimWrite(QK_t, q_row_idx, k_row_idx, N, val);
// }
// }
// }
// }
// }
// }
// }
for(int row_idx=0; row_idx < N; row_idx++) {
std::vector<double> tmp_row_res(N, 0.0);
double row_sum = 0.0;
for(int cold_idx=0; cold_idx < N ;cold_idx++) {
float val = twoDimRead(QK_t, row_idx, cold_idx, N);
double exp_val = std::exp(val);
row_sum += exp_val;
tmp_row_res[cold_idx] = exp_val;
}
for(int cold_idx=0; cold_idx < N ; cold_idx++) {
float prob = tmp_row_res[cold_idx] / row_sum;
twoDimWrite(QK_t, row_idx, cold_idx, N, prob);
}
}
for(int qkt_row_tile_idx=0; qkt_row_tile_idx < (N+TILE_SIZE-1)/TILE_SIZE; qkt_row_tile_idx++) {
for(int output_d_tile_idx=0; output_d_tile_idx < (d+TILE_SIZE-1)/TILE_SIZE; output_d_tile_idx++) {
for(int k_tile_idx=0; k_tile_idx < (N+TILE_SIZE-1)/TILE_SIZE; k_tile_idx++) {
for(int tile_row_idx=0; tile_row_idx < TILE_SIZE; tile_row_idx++) {
int out_row_idx = qkt_row_tile_idx * TILE_SIZE + tile_row_idx;
if(out_row_idx >= N) continue;
for(int tile_col_idx=0; tile_col_idx < TILE_SIZE; tile_col_idx++) {
int out_col_idx = output_d_tile_idx * TILE_SIZE + tile_col_idx;
if( out_col_idx >= d) continue;
float sum = fourDimRead(O, b, h, out_row_idx, out_col_idx, H, N, d );
for(int k=0; k < TILE_SIZE; k++) {
int k_idx = k_tile_idx * TILE_SIZE + k;
float qkt_val = twoDimRead(QK_t, out_row_idx, k_idx, N);
float v_val = fourDimRead(V, b, h, k_idx, out_col_idx, H, N, d);
sum += qkt_val * v_val;
}
fourDimWrite(O, b, h, out_row_idx, out_col_idx, H, N, d, sum);
}
}
}
}
}
}
}
// DO NOT EDIT THIS RETURN STATEMENT //
// It formats your C++ Vector O back into a Tensor of Shape (B, H, N, d) and returns it //
return torch::from_blob(O.data(), {B, H, N, d}, torch::TensorOptions().dtype(torch::kFloat32)).clone();
}
Output:
It achieves about the same cpu time as the reference solution.
Self CPU time total: 215.420ms
REFERENCE - BLOCKED MATMUL + UNFUSED SOFTMAX statistics
cpu time: 214.494ms
mem usage: 4718592 bytes
-----RUNNING STUDENT IMPLEMENTATION-----
first row value:
tensor([0.0771, 0.0779, 0.0787, 0.0795, 0.0803, 0.0811, 0.0819, 0.0827, 0.0835,
0.0843, 0.0851, 0.0859, 0.0867, 0.0875, 0.0883, 0.0891, 0.0899, 0.0907,
0.0915, 0.0923, 0.0931, 0.0939, 0.0947, 0.0955, 0.0963, 0.0971, 0.0979,
0.0987, 0.0995, 0.1003, 0.1011, 0.1019])
STAGE:2024-11-18 20:06:22 3461555:3461555 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
student first row:
tensor([0.0771, 0.0779, 0.0787, 0.0795, 0.0803, 0.0811, 0.0819, 0.0827, 0.0835,
0.0843, 0.0851, 0.0859, 0.0867, 0.0875, 0.0883, 0.0891, 0.0899, 0.0907,
0.0915, 0.0923, 0.0931, 0.0939, 0.0947, 0.0955, 0.0963, 0.0971, 0.0979,
0.0987, 0.0995, 0.1003, 0.1011, 0.1019])
STAGE:2024-11-18 20:06:23 3461555:3461555 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-11-18 20:06:23 3461555:3461555 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
manual attention == pytorch attention True
Manual Execution Time: 0.2183218002319336
---------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg CPU Mem Self CPU Mem # of Calls
---------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::empty 0.01% 12.000us 0.01% 12.000us 1.500us 5.00 Mb 5.00 Mb 8
STUDENT - BLOCKED MATMUL + UNFUSED SOFTMAX 99.38% 217.015ms 99.61% 217.504ms 217.504ms 4.50 Mb -1.00 Mb 1
aten::zeros 0.01% 17.000us 0.07% 142.000us 71.000us 4.50 Mb 0 b 2
aten::clone 0.01% 19.000us 0.14% 313.000us 156.500us 1.00 Mb 0 b 2
aten::empty_strided 0.00% 9.000us 0.00% 9.000us 1.800us 512.25 Kb 512.25 Kb 5
model_inference 0.24% 532.000us 100.00% 218.359ms 218.359ms 512.00 Kb -4.00 Mb 1
aten::flatten 0.01% 25.000us 0.09% 189.000us 37.800us 512.00 Kb 0 b 5
aten::empty_like 0.00% 4.000us 0.00% 5.000us 5.000us 512.00 Kb 0 b 1
aten::to 0.00% 9.000us 0.01% 30.000us 5.000us 520 b 4 b 6
aten::_to_copy 0.01% 14.000us 0.01% 25.000us 6.250us 520 b 260 b 4
---------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 218.359ms
STUDENT - BLOCKED MATMUL + UNFUSED SOFTMAX statistics
cpu time: 217.504ms
mem usage: 4718592 bytes
Another issue with the code above.
It’s not correct when N is not a multiple of TILE_SIZE (for example, N = 103 below).
(myenv) ➜ cs149gpt git:(main) python3 gpt149.py part2 -N 103
-----RUNNING STUDENT IMPLEMENTATION-----
first row value:
tensor([0.0077, 0.0085, 0.0093, 0.0101, 0.0109, 0.0117, 0.0125, 0.0133, 0.0141,
0.0149, 0.0157, 0.0165, 0.0173, 0.0181, 0.0189, 0.0197, 0.0205, 0.0213,
0.0221, 0.0229, 0.0237, 0.0245, 0.0253, 0.0261, 0.0269, 0.0277, 0.0285,
0.0293, 0.0301, 0.0309, 0.0317, 0.0325])
STAGE:2024-11-18 23:12:28 271968:271968 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
student first row:
tensor([0.0077, 0.0086, 0.0094, 0.0103, 0.0112, 0.0121, 0.0129, 0.0138, 0.0147,
0.0155, 0.0164, 0.0173, 0.0181, 0.0190, 0.0199, 0.0208, 0.0216, 0.0225,
0.0234, 0.0242, 0.0251, 0.0260, 0.0268, 0.0277, 0.0286, 0.0295, 0.0303,
0.0312, 0.0321, 0.0329, 0.0338, 0.0347])
STAGE:2024-11-18 23:12:28 271968:271968 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-11-18 23:12:28 271968:271968 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
Traceback (most recent call last):
File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 329, in <module>
main()
File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 313, in main
part2Test(N, d, B, H)
File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 233, in part2Test
testTemplate(attentionModuleReference.myUnfusedAttentionBlocked, params, "STUDENT - BLOCKED MATMUL + UNFUSED SOFTMAX")
File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 182, in testTemplate
assert torch.allclose(QKV,QKS1, atol=1e-4), correctness_error_message
AssertionError:
-------------------------------------------
YOUR ATTENTION PRODUCED INCORRECT RESULTS
Found the root cause: an out-of-bounds check is missing in the softmax(QK_t) * V multiplication.
for(int qkt_row_tile_idx=0; qkt_row_tile_idx < (N+TILE_SIZE-1)/TILE_SIZE; qkt_row_tile_idx++) {
for(int output_d_tile_idx=0; output_d_tile_idx < (d+TILE_SIZE-1)/TILE_SIZE; output_d_tile_idx++) {
for(int k_tile_idx=0; k_tile_idx < (N+TILE_SIZE-1)/TILE_SIZE; k_tile_idx++) {
for(int tile_row_idx=0; tile_row_idx < TILE_SIZE; tile_row_idx++) {
int out_row_idx = qkt_row_tile_idx * TILE_SIZE + tile_row_idx;
if(out_row_idx >= N) continue;
for(int tile_col_idx=0; tile_col_idx < TILE_SIZE; tile_col_idx++) {
int out_col_idx = output_d_tile_idx * TILE_SIZE + tile_col_idx;
if( out_col_idx >= d) continue;
float sum = fourDimRead(O, b, h, out_row_idx, out_col_idx, H, N, d );
for(int k=0; k < TILE_SIZE; k++) {
int k_idx = k_tile_idx * TILE_SIZE + k;
// Need to add this out-of-bounds check
if(k_idx >= N) break;
float qkt_val = twoDimRead(QK_t, out_row_idx, k_idx, N);
float v_val = fourDimRead(V, b, h, k_idx, out_col_idx, H, N, d);
sum += qkt_val * v_val;
}
fourDimWrite(O, b, h, out_row_idx, out_col_idx, H, N, d, sum);
}
}
}
}
}
Now the result is correct when N is not a multiple of TILE_SIZE. The cpu time is a bit higher than that of the reference solution; I think this is because of the out-of-bounds check I added?
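If the check is the culprit, one way to avoid paying for it on every inner iteration would be to clamp the tile extent once before the loop. This is only a sketch I haven’t benchmarked; it needs <algorithm> for std::min and reuses the variable names from the loop above:
// Sketch: compute how many k entries of this tile are in range once, so the
// inner dot-product loop needs no per-iteration bound check.
int k_limit = std::min(TILE_SIZE, N - k_tile_idx * TILE_SIZE);
for (int k = 0; k < k_limit; k++) {
    int k_idx = k_tile_idx * TILE_SIZE + k;
    float qkt_val = twoDimRead(QK_t, out_row_idx, k_idx, N);
    float v_val = fourDimRead(V, b, h, k_idx, out_col_idx, H, N, d);
    sum += qkt_val * v_val;
}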
Output:
(myenv) ➜ cs149gpt git:(main) ✗ python3 gpt149.py part2 -N 1023
Self CPU time total: 223.154ms
REFERENCE - BLOCKED MATMUL + UNFUSED SOFTMAX statistics
cpu time: 222.105ms
mem usage: 4709892 bytes
-----RUNNING STUDENT IMPLEMENTATION-----
first row value:
tensor([0.0770, 0.0778, 0.0786, 0.0794, 0.0802, 0.0810, 0.0818, 0.0826, 0.0834,
0.0842, 0.0850, 0.0858, 0.0866, 0.0874, 0.0882, 0.0890, 0.0898, 0.0906,
0.0914, 0.0922, 0.0930, 0.0938, 0.0946, 0.0954, 0.0962, 0.0970, 0.0978,
0.0986, 0.0994, 0.1002, 0.1010, 0.1018])
STAGE:2024-11-19 10:48:32 4016683:4016683 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
student first row:
tensor([0.0770, 0.0778, 0.0786, 0.0794, 0.0802, 0.0810, 0.0818, 0.0826, 0.0834,
0.0842, 0.0850, 0.0858, 0.0866, 0.0874, 0.0882, 0.0890, 0.0898, 0.0906,
0.0914, 0.0922, 0.0930, 0.0938, 0.0946, 0.0954, 0.0962, 0.0970, 0.0978,
0.0986, 0.0994, 0.1002, 0.1010, 0.1018])
STAGE:2024-11-19 10:48:32 4016683:4016683 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-11-19 10:48:32 4016683:4016683 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
manual attention == pytorch attention True
Manual Execution Time: 0.23152732849121094
---------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg CPU Mem Self CPU Mem # of Calls
---------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::empty 0.01% 13.000us 0.01% 13.000us 1.625us 4.99 Mb 4.99 Mb 8
STUDENT - BLOCKED MATMUL + UNFUSED SOFTMAX 99.41% 230.208ms 99.62% 230.695ms 230.695ms 4.49 Mb -1023.00 Kb 1
aten::zeros 0.01% 32.000us 0.06% 145.000us 72.500us 4.49 Mb 0 b 2
aten::clone 0.01% 18.000us 0.14% 318.000us 159.000us 1023.00 Kb 0 b 2
aten::empty_strided 0.00% 9.000us 0.00% 9.000us 1.800us 511.75 Kb 511.75 Kb 5
model_inference 0.24% 546.000us 100.00% 231.569ms 231.569ms 511.50 Kb -3.99 Mb 1
aten::flatten 0.01% 16.000us 0.07% 170.000us 34.000us 511.50 Kb 0 b 5
aten::empty_like 0.00% 4.000us 0.00% 6.000us 6.000us 511.50 Kb 0 b 1
aten::to 0.00% 6.000us 0.01% 30.000us 5.000us 552 b 32 b 6
aten::_to_copy 0.01% 15.000us 0.01% 24.000us 6.000us 520 b 260 b 4
---------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 231.569ms
STUDENT - BLOCKED MATMUL + UNFUSED SOFTMAX statistics
cpu time: 230.695ms
mem usage: 4709892 bytes
Part 3: Fused attention
Problematic implementation: the output is all zeros. Code:
torch::Tensor myFusedAttention(torch::Tensor QTensor, torch::Tensor KTensor, torch::Tensor VTensor, torch::Tensor temp,
int B, int H, int N, int d){
// Q, K, V are passed in with Shape: (B, H, N, d)
//Make O Tensor with Shape (B, H, N, d)
//and O Row Tensor with Shape (N)
at::Tensor OTensor = at::zeros({B, H, N, d}, at::kFloat);
at::Tensor ORowTensor = at::zeros({N}, at::kFloat);
//Format Y, Q, K, and V tensors into 4D vectors
std::vector<float> O = formatTensor(OTensor);
std::vector<float> Q = formatTensor(QTensor);
std::vector<float> K = formatTensor(KTensor);
std::vector<float> V = formatTensor(VTensor);
//Format ORow Tensor into a 1D vector
// You can simply access this as ORow[i]
std::vector<float> ORow = formatTensor(ORowTensor);
// -------- YOUR CODE HERE -------- //
// We give you a template of the first three loops for your convenience
//loop over batch
#pragma omp parallel for collapse(3)
for (int b = 0; b < B; b++){
//loop over heads
for (int h = 0; h < H; h++){
for (int q_row_idx = 0; q_row_idx < N ; q_row_idx++){
// YRow is moved inside so each OpenMP thread gets a local copy.
at::Tensor ORowTensor = temp.index({torch::indexing::Slice(omp_get_thread_num(), torch::indexing::None)});
std::vector<float> ORow = formatTensor(ORowTensor);
//YOUR CODE HERE
for(int k_row_idx=0; k_row_idx < N; k_row_idx++) {
float val = 0.0;
for(int d_idx=0; d_idx < d ;d_idx++ ) {
int q_row = q_row_idx;
int q_col = d_idx;
int k_row = k_row_idx;
int k_col = d_idx;
float q_val = fourDimRead(Q, b, h, q_row, q_col, H, N, d);
float k_val = fourDimRead(K, b, h, k_row, k_col, H, N, d) ;
val += q_val * k_val;
}
ORow[k_row_idx] = val;
}
// softmax
std::vector<float> tmp_row_res(N, 0.0);
float sum = 0.0;
for(int i=0; i < N; i++) {
ORow[i] = std::exp(ORow[i]) ;
sum += ORow[i];
// tmp_row_res[i] = exp_val;
}
for(int i=0; i < N; i++) {
float prob = ORow[i] / sum;
ORow[i] = prob;
}
for(int v_col_idx=0; v_col_idx < d; v_col_idx++) {
float sum =0.0;
for(int v_row_idx=0; v_row_idx < N; v_row_idx++) {
int v_val = fourDimRead(V, b, h, v_row_idx, v_col_idx, H, N ,d);
sum += v_val * ORow[v_row_idx];
}
std::cout << "vcold_idx" << v_col_idx << "val: " << sum << std::endl;
// tmp_row_res[v_col_idx] = sum;
fourDimWrite(O, b, h, q_row_idx, v_col_idx, H, N, d, sum);
}
}
}
}
// DO NOT EDIT THIS RETURN STATEMENT //
// It formats your C++ Vector O back into a Tensor of Shape (B, H, N, d) and returns it //
return torch::from_blob(O.data(), {B, H, N, d}, torch::TensorOptions().dtype(torch::kFloat32)).clone();
}
REFERENCE - FUSED ATTENTION statistics
cpu time: 56.506ms
mem usage: 557056 bytes
-----RUNNING STUDENT IMPLEMENTATION-----
first row value:
tensor([0.0771, 0.0779, 0.0787, 0.0795, 0.0803, 0.0811, 0.0819, 0.0827, 0.0835,
0.0843, 0.0851, 0.0859, 0.0867, 0.0875, 0.0883, 0.0891, 0.0899, 0.0907,
0.0915, 0.0923, 0.0931, 0.0939, 0.0947, 0.0955, 0.0963, 0.0971, 0.0979,
0.0987, 0.0995, 0.1003, 0.1011, 0.1019])
STAGE:2024-11-19 11:53:42 171611:171611 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
student first row:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0.])
STAGE:2024-11-19 11:53:42 171611:171611 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-11-19 11:53:42 171611:171611 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
Traceback (most recent call last):
File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 329, in <module>
main()
File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 315, in main
part3Test(N, d, B, H)
File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 245, in part3Test
testTemplate(attentionModuleReference.myFusedAttention, params, "STUDENT - FUSED ATTENTION")
File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 182, in testTemplate
assert torch.allclose(QKV,QKS1, atol=1e-4), correctness_error_message
AssertionError:
-------------------------------------------
YOUR ATTENTION PRODUCED INCORRECT RESULTS
Found the root cause. It’s because of a wrong type declaration: v_val was declared as int instead of float.
for(int v_col_idx=0; v_col_idx < d; v_col_idx++) {
float sum =0.0;
for(int v_row_idx=0; v_row_idx < N; v_row_idx++) {
// This is the correct definition
float v_val = fourDimRead(V, b, h, v_row_idx, v_col_idx, H, N ,d);
sum += v_val * ORow[v_row_idx];
}
std::cout << "vcold_idx" << v_col_idx << "val: " << sum << std::endl;
fourDimWrite(O, b, h, q_row_idx, v_col_idx, H, N, d, sum);
}
Output: it’s too slow, and I don’t know why.
REFERENCE - FUSED ATTENTION statistics
cpu time: 59.932ms
mem usage: 557056 bytes
-----RUNNING STUDENT IMPLEMENTATION-----
first row value:
tensor([0.0771, 0.0779, 0.0787, 0.0795, 0.0803, 0.0811, 0.0819, 0.0827, 0.0835,
0.0843, 0.0851, 0.0859, 0.0867, 0.0875, 0.0883, 0.0891, 0.0899, 0.0907,
0.0915, 0.0923, 0.0931, 0.0939, 0.0947, 0.0955, 0.0963, 0.0971, 0.0979,
0.0987, 0.0995, 0.1003, 0.1011, 0.1019])
STAGE:2024-11-19 13:07:43 578930:578930 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
student first row:
tensor([0.0771, 0.0779, 0.0787, 0.0795, 0.0803, 0.0811, 0.0819, 0.0827, 0.0835,
0.0843, 0.0851, 0.0859, 0.0867, 0.0875, 0.0883, 0.0891, 0.0899, 0.0907,
0.0915, 0.0923, 0.0931, 0.0939, 0.0947, 0.0955, 0.0963, 0.0971, 0.0979,
0.0987, 0.0995, 0.1003, 0.1011, 0.1019])
STAGE:2024-11-19 13:07:43 578930:578930 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-11-19 13:07:43 578930:578930 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
manual attention == pytorch attention True
Manual Execution Time: 0.3012988567352295
----------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg CPU Mem Self CPU Mem # of Calls
----------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::empty 0.01% 17.000us 0.01% 17.000us 1.889us 1.04 Mb 1.04 Mb 9
aten::clone 0.00% 15.000us 0.08% 247.000us 123.500us 1.00 Mb 0 b 2
aten::zeros 0.01% 19.000us 0.02% 65.000us 21.667us 548.00 Kb 0 b 3
STUDENT - FUSED ATTENTION 95.37% 287.422ms 99.68% 300.382ms 300.382ms 544.00 Kb -1.00 Mb 1
aten::empty_strided 0.00% 9.000us 0.00% 9.000us 1.800us 512.00 Kb 512.00 Kb 5
model_inference 0.21% 636.000us 100.00% 301.361ms 301.361ms 512.00 Kb -32.63 Kb 1
aten::flatten 1.42% 4.278ms 1.48% 4.467ms 1.089us 512.00 Kb 0 b 4101
aten::empty_like 0.00% 3.000us 0.00% 5.000us 5.000us 512.00 Kb 0 b 1
aten::to 0.00% 6.000us 0.01% 32.000us 5.333us 520 b 0 b 6
aten::_to_copy 0.01% 18.000us 0.01% 26.000us 6.500us 520 b 516 b 4
----------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 301.361ms
STUDENT - FUSED ATTENTION statistics
cpu time: 300.382ms
mem usage: 557056 bytes
Found the root cause: I didn’t enable OpenMP in the original code.
The speed is comparable to the reference after enabling OpenMP. Code:
torch::Tensor myFusedAttention(torch::Tensor QTensor, torch::Tensor KTensor, torch::Tensor VTensor, torch::Tensor temp,
int B, int H, int N, int d){
// Q, K, V are passed in with Shape: (B, H, N, d)
//Make O Tensor with Shape (B, H, N, d)
//and O Row Tensor with Shape (N)
at::Tensor OTensor = at::zeros({B, H, N, d}, at::kFloat);
at::Tensor ORowTensor = at::zeros({N}, at::kFloat);
//Format Y, Q, K, and V tensors into 4D vectors
std::vector<float> O = formatTensor(OTensor);
std::vector<float> Q = formatTensor(QTensor);
std::vector<float> K = formatTensor(KTensor);
std::vector<float> V = formatTensor(VTensor);
//Format ORow Tensor into a 1D vector
// You can simply access this as ORow[i]
std::vector<float> ORow = formatTensor(ORowTensor);
// -------- YOUR CODE HERE -------- //
// We give you a template of the first three loops for your convenience
//loop over batch
// Need to enable this OpenMP pragma so the batches, heads, and query rows are processed in parallel.
#pragma omp parallel for collapse(3)
for (int b = 0; b < B; b++){
//loop over heads
for (int h = 0; h < H; h++){
for (int q_row_idx = 0; q_row_idx < N ; q_row_idx++){
// YRow is moved inside so each OpenMP thread gets a local copy.
at::Tensor ORowTensor = temp.index({torch::indexing::Slice(omp_get_thread_num(), torch::indexing::None)});
std::vector<float> ORow = formatTensor(ORowTensor);
//YOUR CODE HERE
for(int k_row_idx=0; k_row_idx < N; k_row_idx++) {
float val = 0.0;
for(int d_idx=0; d_idx < d ;d_idx++ ) {
int q_row = q_row_idx;
int q_col = d_idx;
int k_row = k_row_idx;
int k_col = d_idx;
float q_val = fourDimRead(Q, b, h, q_row, q_col, H, N, d);
float k_val = fourDimRead(K, b, h, k_row, k_col, H, N, d) ;
val += q_val * k_val;
}
ORow[k_row_idx] = val;
// std::cout << "krowidx: " << k_row_idx << " val: " << ORow[k_row_idx] << std::endl;
}
// softmax
std::vector<float> tmp_row_res(N, 0.0);
float sum = 0.0;
for(int i=0; i < N; i++) {
ORow[i] = std::exp(ORow[i]) ;
sum += ORow[i];
// tmp_row_res[i] = exp_val;
}
for(int i=0; i < N; i++) {
float prob = ORow[i] / sum;
ORow[i] = prob;
// std::cout << "softmax col: " << i << " val: " << ORow[i] << std::endl;
}
for(int v_col_idx=0; v_col_idx < d; v_col_idx++) {
float sum =0.0;
for(int v_row_idx=0; v_row_idx < N; v_row_idx++) {
float v_val = fourDimRead(V, b, h, v_row_idx, v_col_idx, H, N ,d);
sum += v_val * ORow[v_row_idx];
}
// std::cout << "vcold_idx" << v_col_idx << "val: " << sum << std::endl;
fourDimWrite(O, b, h, q_row_idx, v_col_idx, H, N, d, sum);
}
}
}
}
// DO NOT EDIT THIS RETURN STATEMENT //
// It formats your C++ Vector O back into a Tensor of Shape (B, H, N, d) and returns it //
return torch::from_blob(O.data(), {B, H, N, d}, torch::TensorOptions().dtype(torch::kFloat32)).clone();
}
REFERENCE - FUSED ATTENTION statistics
cpu time: 56.526ms
mem usage: 557056 bytes
-----RUNNING STUDENT IMPLEMENTATION-----
first row value:
tensor([0.0771, 0.0779, 0.0787, 0.0795, 0.0803, 0.0811, 0.0819, 0.0827, 0.0835,
0.0843, 0.0851, 0.0859, 0.0867, 0.0875, 0.0883, 0.0891, 0.0899, 0.0907,
0.0915, 0.0923, 0.0931, 0.0939, 0.0947, 0.0955, 0.0963, 0.0971, 0.0979,
0.0987, 0.0995, 0.1003, 0.1011, 0.1019])
STAGE:2024-11-19 17:17:41 1945272:1945272 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
student first row:
tensor([0.0771, 0.0779, 0.0787, 0.0795, 0.0803, 0.0811, 0.0819, 0.0827, 0.0835,
0.0843, 0.0851, 0.0859, 0.0867, 0.0875, 0.0883, 0.0891, 0.0899, 0.0907,
0.0915, 0.0923, 0.0931, 0.0939, 0.0947, 0.0955, 0.0963, 0.0971, 0.0979,
0.0987, 0.0995, 0.1003, 0.1011, 0.1019])
STAGE:2024-11-19 17:17:41 1945272:1945272 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-11-19 17:17:41 1945272:1945272 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
manual attention == pytorch attention True
Manual Execution Time: 0.04986262321472168
----------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg CPU Mem Self CPU Mem # of Calls
----------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::empty 0.02% 12.000us 0.02% 12.000us 1.333us 1.04 Mb 1.04 Mb 9
aten::clone 0.04% 19.000us 0.42% 208.000us 104.000us 1.00 Mb 0 b 2
aten::zeros 0.04% 19.000us 0.14% 68.000us 22.667us 548.00 Kb 0 b 3
STUDENT - FUSED ATTENTION 94.35% 47.083ms 98.27% 49.043ms 49.043ms 544.00 Kb -1.00 Mb 1
aten::empty_strided 0.02% 8.000us 0.02% 8.000us 1.600us 512.00 Kb 512.00 Kb 5
model_inference 1.11% 555.000us 100.00% 49.904ms 49.904ms 512.00 Kb -32.63 Kb 1
aten::flatten 1.10% 549.000us 1.50% 748.000us 1.447us 512.00 Kb 0 b 517
aten::empty_like 0.01% 3.000us 0.01% 4.000us 4.000us 512.00 Kb 0 b 1
aten::to 0.01% 6.000us 0.06% 29.000us 4.833us 520 b 0 b 6
aten::_to_copy 0.03% 14.000us 0.05% 23.000us 5.750us 520 b 260 b 4
----------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 49.904ms
STUDENT - FUSED ATTENTION statistics
cpu time: 49.043ms
mem usage: 557056 bytes
Part 4: Flash attention
GPT gives a good explanation of how the O_i formula is obtained.
Basically, the previous O_i has already been divided by l_i, so we need to multiply O_i by l_i again when computing the new O_i with the new l_i, according to this expression:
[ O_i = \frac{\sum_{\text{old } j} \alpha_{ij} V_j + \sum_{\text{new } j} \alpha_{ij} V_j}{\sum_{\text{old } j} \alpha_{ij} + \sum_{\text{new } j} \alpha_{ij}}. ]
The formula
[ O_i = \frac{l_i O_i + P_{ij} V_j}{l_{\text{new}}} ]
and its update mechanism for the cumulative sum ( l_i O_i ) and normalization factor ( l_i ) works because of the linearity of the softmax-weighted sum and the incremental construction of softmax normalization. Let’s derive and explain why this works step-by-step.
1. Standard Attention Mechanism
In the standard attention mechanism, the output ( O_i ) for query ( Q_i ) is defined as:
[ O_i = \frac{\sum_{j=1}^N \alpha_{ij} V_j}{\sum_{j=1}^N \alpha_{ij}}, ]
where:
- ( \alpha_{ij} = \exp(Q_i K_j^T) ): the unnormalized attention score between query ( Q_i ) and key ( K_j ).
- ( P_{ij} = \frac{\alpha_{ij}}{\sum_{k=1}^N \alpha_{ik}} ): the normalized attention score (softmax).
- ( V_j ): the value vector associated with key ( K_j ).
Rewriting ( O_i ) explicitly:
[ O_i = \frac{\sum_{j=1}^N \alpha_{ij} V_j}{Z_i}, ]
where ( Z_i = \sum_{j=1}^N \alpha_{ij} ) is the normalization constant for row ( i ).
2. Chunked Computation in Flash Attention
To handle memory constraints, Flash Attention processes the attention computation in chunks of ( j ). Suppose we process a chunk of keys/values indexed by ( j ) in each step. Let:
- ( l_i^{\text{old}} ) be the normalization constant ( Z_i ) for the previously processed chunks.
- ( O_i^{\text{old}} ) be the partial output for the previously processed chunks.
The goal is to update ( O_i ) and ( l_i ) incrementally when a new chunk is processed.
For the new chunk, compute:
- ( \alpha_{ij} ) for the current chunk.
- The sum of unnormalized scores for the chunk: ( l_{\text{chunk}} = \sum_{j \in \text{chunk}} \alpha_{ij} ).
- The weighted value contribution: ( O_{\text{chunk}} = \sum_{j \in \text{chunk}} \alpha_{ij} V_j ).
The updated values are:
[ l_i^{\text{new}} = l_i^{\text{old}} + l_{\text{chunk}} ] [ O_i^{\text{new}} = \frac{l_i^{\text{old}} O_i^{\text{old}} + O_{\text{chunk}}}{l_i^{\text{new}}}. ]
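A quick sanity check with made-up numbers (not from the assignment): suppose for one query the already-processed chunk had ( \alpha = [1, 3] ) with values ( V = [2, 4] ), so ( l_i^{\text{old}} = 4 ) and ( O_i^{\text{old}} = (1 \cdot 2 + 3 \cdot 4)/4 = 3.5 ). A new chunk with ( \alpha = [4] ) and ( V = [6] ) gives ( l_{\text{chunk}} = 4 ) and ( O_{\text{chunk}} = 24 ). The update produces ( l_i^{\text{new}} = 8 ) and ( O_i^{\text{new}} = (4 \cdot 3.5 + 24)/8 = 4.75 ), which matches the full computation ( (2 + 12 + 24)/8 = 4.75 ).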
3. Why This Works: Derivation
a. Weighted Combination of Outputs
The full output ( O_i ) over all chunks can be expressed as a weighted combination of the contributions from the old and new chunks:
[ O_i = \frac{\sum_{\text{all } j} \alpha_{ij} V_j}{\sum_{\text{all } j} \alpha_{ij}}. ]
Breaking this into previously processed chunks (old) and the current chunk:
[ O_i = \frac{\sum_{\text{old } j} \alpha_{ij} V_j + \sum_{\text{new } j} \alpha_{ij} V_j}{\sum_{\text{old } j} \alpha_{ij} + \sum_{\text{new } j} \alpha_{ij}}. ]
Let:
- ( l_i^{\text{old}} = \sum_{\text{old } j} \alpha_{ij} ),
- ( l_{\text{chunk}} = \sum_{\text{new } j} \alpha_{ij} ),
- ( O_i^{\text{old}} = \frac{\sum_{\text{old } j} \alpha_{ij} V_j}{l_i^{\text{old}}} ),
- ( O_{\text{chunk}} = \sum_{\text{new } j} \alpha_{ij} V_j ).
Substituting:
[ O_i = \frac{l_i^{\text{old}} O_i^{\text{old}} + O_{\text{chunk}}}{l_i^{\text{old}} + l_{\text{chunk}}}. ]
This matches the Flash Attention formula:
[ O_i = \frac{l_i O_i + P_{ij} V_j}{l_{\text{new}}}. ]
b. Incremental Normalization
The normalization factor ( l_i ) (equivalent to ( Z_i ) in standard attention) is computed incrementally:
[ l_i^{\text{new}} = l_i^{\text{old}} + \sum_{j \in \text{chunk}} \alpha_{ij}. ]
This ensures that the normalization is consistent across all processed chunks, allowing the final ( O_i ) to represent the output for the full sequence.
4. Benefits of This Approach
- Memory Efficiency:
- By processing chunks incrementally, there is no need to store the full ( N \times N ) attention matrix or intermediate results.
- Numerical Stability:
- The incremental updates prevent large sums from growing uncontrollably, which could cause overflow/underflow in standard softmax computations.
- Linearity of Weighted Sum:
- The linearity of summation ensures that partial results (( l_i^{\text{old}} O_i^{\text{old}} )) can be combined directly with new contributions without recomputation.
5. Conclusion
The Flash Attention update formula works because:
- It incrementally combines contributions from each chunk using the linearity of the softmax-weighted sum.
- The normalization factor ( l_i ) ensures consistency across chunks, so the final result is equivalent to the full softmax computation.
- This approach trades memory for compute, making it feasible to handle large inputs while maintaining mathematical equivalence to standard attention.
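To convince myself that the chunked update reproduces the full softmax-weighted sum, here is a tiny standalone check (plain C++, independent of the assignment harness; the toy scores, values, and variable names are all my own):
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // One query row with scalar values (d = 1) to keep the check tiny.
    // s[j] plays the role of Q_i . K_j, v[j] the role of V_j.
    std::vector<float> s = {0.1f, 0.5f, -0.3f, 0.9f, 0.2f, -0.7f};
    std::vector<float> v = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
    const int N = static_cast<int>(s.size());

    // Full softmax-weighted sum (no max subtraction, like the assignment's softmax).
    float Z = 0.0f, num = 0.0f;
    for (int j = 0; j < N; j++) {
        Z += std::exp(s[j]);
        num += std::exp(s[j]) * v[j];
    }
    float O_full = num / Z;

    // Chunked update: O = (l_old * O_old + sum_chunk alpha * v) / l_new.
    const int Bc = 2;
    float l = 0.0f, O = 0.0f;
    for (int start = 0; start < N; start += Bc) {
        float l_chunk = 0.0f, pv = 0.0f;
        for (int j = start; j < std::min(start + Bc, N); j++) {
            float alpha = std::exp(s[j]);
            l_chunk += alpha;
            pv += alpha * v[j];
        }
        float l_new = l + l_chunk;
        O = (l * O + pv) / l_new;
        l = l_new;
    }

    std::printf("full = %f, chunked = %f\n", O_full, O);
    assert(std::fabs(O_full - O) < 1e-5f);
    return 0;
}
The two results agree for any chunk size Bc, which is exactly why the per-block update in Part 4 can be done without storing the full N x N score matrix.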
Incorrect implementation:
torch::Tensor myFlashAttention(torch::Tensor QTensor, torch::Tensor KTensor, torch::Tensor VTensor,
torch::Tensor QiTensor, torch::Tensor KjTensor, torch::Tensor VjTensor,
torch::Tensor SijTensor, torch::Tensor PijTensor, torch::Tensor PVTensor,
torch::Tensor OiTensor, torch::Tensor LTensor, torch::Tensor LiTensor,
torch::Tensor LijTensor, torch::Tensor LnewTensor, int Bc, int Br,
int B, int H, int N, int d) {
// Q, K, V are passed in with Shape: (B, H, N, d)
// Sij, Pij are passed in with Shape: (Br, Bc)
// Kj, Vj are passed in with Shape: (Bc, d)
// Qi, Oi, and PV are passed in with Shape: (Br, d)
// L is passed in with Shape: (N)
// Li, Lij, and Lnew are passed in with shape (Br)
//Make O Tensor with Shape (B, H, N, d)
at::Tensor OTensor = at::zeros({B, H, N, d}, at::kFloat);
//Format All Tensors into Vectors
std::vector<float> O = formatTensor(OTensor);
std::vector<float> Q = formatTensor(QTensor);
std::vector<float> K = formatTensor(KTensor);
std::vector<float> V = formatTensor(VTensor);
std::vector<float> Sij = formatTensor(SijTensor);
std::vector<float> Pij = formatTensor(PijTensor);
std::vector<float> Kj = formatTensor(KjTensor);
std::vector<float> Vj = formatTensor(VjTensor);
std::vector<float> Qi = formatTensor(QiTensor);
std::vector<float> Oi = formatTensor(OiTensor);
std::vector<float> l = formatTensor(LTensor);
std::vector<float> PV = formatTensor(PVTensor);
std::vector<float> li = formatTensor(LiTensor);
std::vector<float> lij = formatTensor(LijTensor);
std::vector<float> lnew = formatTensor(LnewTensor);
// -------- YOUR CODE HERE -------- //
for(int b=0; b < B; b++ ) {
for(int h=0; h < H; h++) {
for(int k_block_idx=0; k_block_idx < (N+Bc-1)/Bc; k_block_idx++) {
// load Kj, Vj into local memory blocks.
for(int j=0; j < Bc; j++) {
int j_row = k_block_idx * Bc + j;
if(j_row >= N) continue;
for(int d_idx =0; d_idx < d; d_idx++) {
float k_val = fourDimRead(K, b, h, j_row, d_idx, H, N, d);
float v_val = fourDimRead(V, b, h, j_row, d_idx, H, N, d);
twoDimWrite(Kj, j, d_idx, d, k_val);
twoDimWrite(Vj, j, d_idx, d, v_val);
}
}
for(int q_block_idx=0; q_block_idx < (N+Br-1)/Br; q_block_idx++) {
// load Qi, Oi, li into local memory blocks
for(int br_idx=0; br_idx < Br; br_idx++ ) {
int q_row_idx = q_block_idx * Br + br_idx;
if(q_row_idx >= N ) continue;
for(int d_idx=0; d_idx < d; d_idx++) {
float q_val = fourDimRead(Q, b, h, q_row_idx, d_idx, H, N, d);
float o_val = fourDimRead(O, b, h, q_row_idx , d_idx, H, N, d);
twoDimWrite(Qi, br_idx, d_idx, d, q_val);
twoDimWrite(Oi, br_idx, d_idx, d, o_val);
}
float l_val = l[q_row_idx];
li[br_idx] = l_val;
}
// compute Sij = Qi * Kj_T (Br x Bc)
for(int br_idx=0; br_idx < Br; br_idx++) {
for(int bc_idx=0; bc_idx < Bc; bc_idx++) {
float sum = 0.0;
for(int d_idx=0; d_idx < d; d_idx++) {
float q_val = twoDimRead(Qi, br_idx, d_idx, d);
float k_val = twoDimRead(Kj, bc_idx, d_idx, d);
sum += q_val * k_val;
}
twoDimWrite(Sij, br_idx, bc_idx, d, sum); // BUG: stride should be Bc, not d (Sij has shape (Br, Bc))
}
}
// Compute Pij = exp(Sij) of size (Br x Bc)
for(int br_idx=0; br_idx < Br; br_idx++) {
for(int bc_idx=0; bc_idx < Bc; bc_idx++) {
float exp_val = std::exp(twoDimRead(Sij, br_idx, bc_idx, Bc));
twoDimWrite(Pij, br_idx, bc_idx, Bc, exp_val);
}
}
// Compute lij = rowsum(Pij) of size (Br)
for(int br_idx=0; br_idx < Br; br_idx++) {
float sum = 0.0;
for(int bc_idx=0; bc_idx < Bc; bc_idx++) {
sum += twoDimRead(Pij, br_idx, bc_idx, Bc);
}
lij[br_idx] = sum;
}
// compute lnew = li + lij
for(int br_idx=0; br_idx < Br; br_idx++) {
lnew[br_idx] = li[br_idx] + lij[br_idx];
}
// Compute Oi = (liOi + PijVj)/ lnew
for(int br_idx=0; br_idx < Br; br_idx++) {
for(int d_idx=0; d_idx < d; d_idx++) {
float pv_sum =0.0;
for(int bc_idx=0; bc_idx < Bc; bc_idx++) {
int p_row = br_idx;
int p_col = bc_idx;
int v_row = bc_idx;
int v_col = d_idx;
pv_sum += twoDimRead(Pij, p_row, p_col, Bc) * twoDimRead(Vj, v_row, v_col, d);
}
float li_Oi_val = li[br_idx] * twoDimRead(Oi, br_idx, d_idx, d);
float new_sum = pv_sum + li_Oi_val;
float new_Oi_val = new_sum / lnew[br_idx];
twoDimWrite(Oi, br_idx, d_idx, d, new_Oi_val);
}
}
// Write Oi and lnew back to O and l in main memory
for(int br_idx=0; br_idx < Br; br_idx++) {
int O_row = q_block_idx * Br + br_idx;
if(O_row >= N) break;
for(int d_idx=0; d_idx < d; d_idx++) {
float Oi_val = twoDimRead(Oi, br_idx, d_idx, d);
int O_col = d_idx;
fourDimWrite(O, b, h, O_row, O_col, H, N, d, Oi_val);
}
l[O_row] = lnew[O_row]; // BUG: should be lnew[br_idx]; lnew is indexed by block-local row
}
}
}
}
}
// DO NOT EDIT THIS RETURN STATEMENT //
// It formats your C++ Vector O back into a Tensor of Shape (B, H, N, d) and returns it //
return torch::from_blob(O.data(), {B, H, N, d}, torch::TensorOptions().dtype(torch::kFloat32)).clone();
}
Output:
Self CPU time total: 589.890ms
REFERENCE - FLASH ATTENTION statistics
cpu time: 588.591ms
mem usage: 524288 bytes
-----RUNNING STUDENT IMPLEMENTATION-----
first row value:
tensor([0.0771, 0.0779, 0.0787, 0.0795, 0.0803, 0.0811, 0.0819, 0.0827, 0.0835,
0.0843, 0.0851, 0.0859, 0.0867, 0.0875, 0.0883, 0.0891, 0.0899, 0.0907,
0.0915, 0.0923, 0.0931, 0.0939, 0.0947, 0.0955, 0.0963, 0.0971, 0.0979,
0.0987, 0.0995, 0.1003, 0.1011, 0.1019])
STAGE:2024-11-20 14:25:23 442505:442505 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
student first row:
tensor([0.0773, 0.0781, 0.0789, 0.0797, 0.0805, 0.0813, 0.0821, 0.0829, 0.0837,
0.0845, 0.0853, 0.0861, 0.0869, 0.0877, 0.0885, 0.0893, 0.0901, 0.0909,
0.0917, 0.0925, 0.0933, 0.0941, 0.0949, 0.0957, 0.0965, 0.0973, 0.0981,
0.0989, 0.0997, 0.1005, 0.1013, 0.1021])
STAGE:2024-11-20 14:25:23 442505:442505 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-11-20 14:25:23 442505:442505 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
Traceback (most recent call last):
File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 329, in <module>
main()
File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 317, in main
part4Test(N, d, B, H, int(args.bc), int(args.br))
File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 257, in part4Test
testTemplate(attentionModuleReference.myFlashAttention, params, "STUDENT - FLASH ATTENTION")
File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 182, in testTemplate
assert torch.allclose(QKV,QKS1, atol=1e-4), correctness_error_message
AssertionError:
-------------------------------------------
YOUR ATTENTION PRODUCED INCORRECT RESULTS
Asked GPT to find the bugs in my previous code.
Some of its findings are correct, some are not.
There are several potential bugs and inefficiencies in the provided implementation of Flash Attention. Below is an analysis of common issues based on the code structure.
1. Bug in Writing lnew to Global Memory
At the end of the block:
l[O_row] = lnew[O_row];
This line is problematic because lnew[O_row] is incorrect. You are trying to write from lnew (a local array for the current query block) to the global l. The index O_row does not map directly into lnew, as lnew uses local indices within a chunk.
Fix:
l[O_row] = lnew[br_idx];
Here, br_idx is the correct index into lnew for the current query block.
2. Mismatch in twoDimRead and twoDimWrite for Sij
In the computation of ( S_{ij} ):
twoDimWrite(Sij, br_idx, bc_idx, d, sum);
You are using d as the stride instead of Bc. Sij is declared with shape (Br, Bc), so the stride in the twoDimWrite call should be Bc.
Fix:
twoDimWrite(Sij, br_idx, bc_idx, Bc, sum);
Similarly, update other uses of Sij to use the correct stride.
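To make the stride mistake concrete, here is a small standalone sketch. The accessors below are my own minimal row-major re-implementations (index = row * stride + col), written only to mirror how a flat (Br, Bc) buffer is addressed; they are not the assignment's exact helpers.
#include <cstdio>
#include <vector>

static float twoDimReadDemo(const std::vector<float>& t, int i, int j, int stride) {
    return t[i * stride + j];
}
static void twoDimWriteDemo(std::vector<float>& t, int i, int j, int stride, float val) {
    t[i * stride + j] = val;
}

int main() {
    const int Br = 2, Bc = 3, d = 4;        // Sij block is (Br, Bc); d is a different dimension
    std::vector<float> Sij(Br * Bc, 0.0f);  // 6 floats in row-major order

    // With the wrong stride d, element (1, 2) would map to index 1*4 + 2 = 6,
    // which is outside the 6-element buffer.
    std::printf("index with stride d: %d (buffer size %d)\n", 1 * d + 2, Br * Bc);

    // With the correct stride Bc, (1, 2) maps to index 1*3 + 2 = 5, the last valid slot.
    twoDimWriteDemo(Sij, 1, 2, Bc, 1.0f);
    std::printf("Sij[1][2] = %f\n", twoDimReadDemo(Sij, 1, 2, Bc));
    return 0;
}
Writing with stride d but reading back with stride Bc (as the exp loop does) also means the writer and reader disagree about where each element lives whenever d differs from Bc.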
3. Incorrect Update of liOi
In the line:
float li_Oi_val = li[br_idx] * twoDimRead(Oi, br_idx, d_idx, d);
This reads the old Oi value from Oi, but Oi is already being updated in the loop. This introduces unintended feedback into the computation.
Fix: Make a copy of Oi at the start of each block before updates, and use that copy for the computation:
std::vector<float> Oi_copy = Oi; // Create a copy of Oi
float li_Oi_val = li[br_idx] * twoDimRead(Oi_copy, br_idx, d_idx, d);
4. Improper Memory Initialization
Variables like Kj, Vj, Qi, Oi, Sij, etc., are written to without ensuring their contents are initialized correctly. This might cause unexpected behavior if these tensors are not zeroed out.
Fix: Explicitly zero out these tensors before use:
std::fill(Kj.begin(), Kj.end(), 0.0f);
std::fill(Vj.begin(), Vj.end(), 0.0f);
std::fill(Qi.begin(), Qi.end(), 0.0f);
std::fill(Oi.begin(), Oi.end(), 0.0f);
std::fill(Sij.begin(), Sij.end(), 0.0f);
5. Incorrect Stride in twoDimRead and twoDimWrite for Block Matrices
Some block matrices (e.g., Pij, Kj, Vj, Qi) use strides inconsistently. Double-check the stride parameter in each twoDimRead and twoDimWrite call so it matches the matrix dimensions.
6. Inefficient Computation of lij and lnew
The computation of lij (row sum of Pij) and lnew involves multiple sequential loops that can be combined for better performance. For example:
for(int br_idx=0; br_idx < Br; br_idx++) {
float sum = 0.0;
for(int bc_idx=0; bc_idx < Bc; bc_idx++) {
sum += twoDimRead(Pij, br_idx, bc_idx, Bc);
}
lij[br_idx] = sum;
lnew[br_idx] = li[br_idx] + sum;
}
This eliminates one loop.
7. Numerical Stability in Softmax
The computation of ( P_{ij} = \exp(S_{ij}) ) can suffer from numerical instability if ( S_{ij} ) contains large values. This can lead to overflow or inaccurate results.
Fix: Subtract the maximum value in each row of ( S_{ij} ) before applying the exponential function:
for(int br_idx=0; br_idx < Br; br_idx++) {
float max_val = -std::numeric_limits<float>::infinity();
for(int bc_idx=0; bc_idx < Bc; bc_idx++) {
float sij_val = twoDimRead(Sij, br_idx, bc_idx, Bc);
max_val = std::max(max_val, sij_val);
}
for(int bc_idx=0; bc_idx < Bc; bc_idx++) {
float sij_val = twoDimRead(Sij, br_idx, bc_idx, Bc) - max_val;
twoDimWrite(Sij, br_idx, bc_idx, Bc, sij_val);
}
}
This ensures the softmax computation is numerically stable.
8. Inefficient Loop Nesting
The loop order might lead to poor memory access patterns, especially for Kj, Vj, and Qi. Consider transposing these matrices in local memory to improve cache efficiency.
9. Potential Misalignment in Block Indexing
When processing blocks of size Br and Bc, ensure that the indexing properly handles edge cases where N is not evenly divisible by Br or Bc. Currently, edge handling appears inconsistent.
Fix: Use bounds checking consistently:
if (q_row_idx >= N || j_row >= N) continue;
10. Output Tensor Not Updated Properly
The output tensor OTensor may not be updated properly due to indexing or memory mismanagement.
Fix: Ensure that updates to O directly correspond to OTensor:
for(int br_idx=0; br_idx < Br; br_idx++) {
int O_row = q_block_idx * Br + br_idx;
if(O_row >= N) continue;
for(int d_idx=0; d_idx < d; d_idx++) {
float Oi_val = twoDimRead(Oi, br_idx, d_idx, d);
fourDimWrite(O, b, h, O_row, d_idx, H, N, d, Oi_val);
}
}
Summary of Key Fixes
- Correct the indexing when writing lnew to l.
- Fix strides in twoDimRead and twoDimWrite for block matrices.
- Zero-initialize temporary tensors.
- Apply numerical stabilization for the softmax computation.
- Combine the loops for lij and lnew.
- Address potential misalignment when N is not a multiple of the block sizes.
These fixes ensure correctness, improve numerical stability, and enhance efficiency.
Add initialization of Kj and Vj to 0.0, to avoid the impact of leftover numbers from the previous iteration (the std::fill calls in the final code below).
I realized this problem from a Python diff check, which showed that the values only differ in the last row of the matrix.
Mismatch at index (0, 3, 1023, 22): QKV=0.13594280183315277, QKS1=0.03398570418357849
Mismatch at index (0, 3, 1023, 23): QKV=0.13674280047416687, QKS1=0.03418571129441261
Mismatch at index (0, 3, 1023, 24): QKV=0.13754287362098694, QKS1=0.03438570350408554
Mismatch at index (0, 3, 1023, 25): QKV=0.13834281265735626, QKS1=0.03458569571375847
Mismatch at index (0, 3, 1023, 26): QKV=0.13914281129837036, QKS1=0.03478570282459259
Mismatch at index (0, 3, 1023, 27): QKV=0.13994282484054565, QKS1=0.034985702484846115
Mismatch at index (0, 3, 1023, 28): QKV=0.14074280858039856, QKS1=0.03518570587038994
Mismatch at index (0, 3, 1023, 29): QKV=0.14154279232025146, QKS1=0.035385698080062866
Mismatch at index (0, 3, 1023, 30): QKV=0.14234280586242676, QKS1=0.03558569401502609
Mismatch at index (0, 3, 1023, 31): QKV=0.14314278960227966, QKS1=0.03578570485115051
Traceback (most recent call last):
File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 339, in <module>
main()
File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 327, in main
part4Test(N, d, B, H, int(args.bc), int(args.br))
File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 267, in part4Test
testTemplate(attentionModuleReference.myFlashAttention, params, "STUDENT - FLASH ATTENTION")
File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 190, in testTemplate
raise AssertionError(correctness_error_message)
AssertionError:
-------------------------------------------
YOUR ATTENTION PRODUCED INCORRECT RESULTS
Found that the result is incorrect starting from the second head.
I think it’s because some intermediate variable is not cleared before entering the next head, which leads to the incorrect answer.
Mismatch at index (0, 1, 0, 0): QKV=1.0, QKS1=0.5
Mismatch at index (0, 1, 0, 1): QKV=1.0, QKS1=0.5
Mismatch at index (0, 1, 0, 2): QKV=1.0, QKS1=0.5
Mismatch at index (0, 1, 0, 3): QKV=1.0, QKS1=0.5
Mismatch at index (0, 1, 0, 4): QKV=1.0, QKS1=0.5
Mismatch at index (0, 1, 0, 5): QKV=1.0, QKS1=0.5
Mismatch at index (0, 1, 0, 6): QKV=1.0, QKS1=0.5
Mismatch at index (0, 1, 0, 7): QKV=1.0, QKS1=0.5
Mismatch at index (0, 1, 0, 8): QKV=1.0, QKS1=0.5
Mismatch at index (0, 1, 0, 9): QKV=1.0, QKS1=0.5
Mismatch at index (0, 1, 0, 10): QKV=1.0, QKS1=0.5
Mismatch at index (0, 1, 0, 11): QKV=1.0, QKS1=0.5
Mismatch at index (0, 1, 0, 12): QKV=1.0, QKS1=0.5
Mismatch at index (0, 1, 0, 13): QKV=1.0, QKS1=0.5
M
Found the root cause and fixed it.
I need to clear l for each head, because l is read into li for every Q block.
Code:
torch::Tensor myFlashAttention(torch::Tensor QTensor, torch::Tensor KTensor, torch::Tensor VTensor,
torch::Tensor QiTensor, torch::Tensor KjTensor, torch::Tensor VjTensor,
torch::Tensor SijTensor, torch::Tensor PijTensor, torch::Tensor PVTensor,
torch::Tensor OiTensor, torch::Tensor LTensor, torch::Tensor LiTensor,
torch::Tensor LijTensor, torch::Tensor LnewTensor, int Bc, int Br,
int B, int H, int N, int d) {
// Q, K, V are passed in with Shape: (B, H, N, d)
// Sij, Pij are passed in with Shape: (Br, Bc)
// Kj, Vj are passed in with Shape: (Bc, d)
// Qi, Oi, and PV are passed in with Shape: (Br, d)
// L is passed in with Shape: (N)
// Li, Lij, and Lnew are passed in with shape (Br)
//Make O Tensor with Shape (B, H, N, d)
at::Tensor OTensor = at::zeros({B, H, N, d}, at::kFloat);
//Format All Tensors into Vectors
std::vector<float> O = formatTensor(OTensor);
std::vector<float> Q = formatTensor(QTensor);
std::vector<float> K = formatTensor(KTensor);
std::vector<float> V = formatTensor(VTensor);
std::vector<float> Sij = formatTensor(SijTensor); //clear
std::vector<float> Pij = formatTensor(PijTensor); //clear
std::vector<float> Kj = formatTensor(KjTensor); // clear
std::vector<float> Vj = formatTensor(VjTensor); // clear
std::vector<float> Qi = formatTensor(QiTensor); // clear
std::vector<float> Oi = formatTensor(OiTensor); //clear
std::vector<float> l = formatTensor(LTensor); // This should be cleared
std::vector<float> PV = formatTensor(PVTensor);
std::vector<float> li = formatTensor(LiTensor);
std::vector<float> lij = formatTensor(LijTensor);
std::vector<float> lnew = formatTensor(LnewTensor);
// std::cout << "br:" << Br << " bc:" << Bc <<std::endl;
// -------- YOUR CODE HERE -------- //
for(int b=0; b < B; b++ ) {
for(int h=0; h < H; h++) {
// This line is essential to correctness.
std::fill(l.begin(), l.end(), 0.0f);
std::fill(lnew.begin(), lnew.end(), 0.0f);
std::fill(lij.begin(), lij.end(), 0.0f);
for(int k_block_idx=0; k_block_idx < (N+Bc-1)/Bc; k_block_idx++) {
std::fill(Kj.begin(), Kj.end(), 0.0f);
std::fill(Vj.begin(), Vj.end(), 0.0f);
// load Kj, Vj into local memory blocks.
for(int j=0; j < Bc; j++) {
int j_row = k_block_idx * Bc + j;
if(j_row >= N) continue;
for(int d_idx =0; d_idx < d; d_idx++) {
float k_val = fourDimRead(K, b, h, j_row, d_idx, H, N, d);
float v_val = fourDimRead(V, b, h, j_row, d_idx, H, N, d);
twoDimWrite(Kj, j, d_idx, d, k_val);
twoDimWrite(Vj, j, d_idx, d, v_val);
// std::cout<< "j:" << j_row << " col:" << d_idx << "kj:" << k_val << " vj:" << v_val << std::endl;
}
}
for(int q_block_idx=0; q_block_idx < (N+Br-1)/Br; q_block_idx++) {
std::fill(Qi.begin(), Qi.end(), 0.0f);
std::fill(Oi.begin(), Oi.end(), 0.0f);
std::fill(Sij.begin(), Sij.end(), 0.0f);
std::fill(Pij.begin(), Pij.end(), 0.0f);
// load Qi, Oi, li into local memory blocks
for(int br_idx=0; br_idx < Br; br_idx++ ) {
int q_row_idx = q_block_idx * Br + br_idx;
if(q_row_idx >= N ) continue;
for(int d_idx=0; d_idx < d; d_idx++) {
float q_val = fourDimRead(Q, b, h, q_row_idx, d_idx, H, N, d);
float o_val = fourDimRead(O, b, h, q_row_idx , d_idx, H, N, d);
twoDimWrite(Qi, br_idx, d_idx, d, q_val);
twoDimWrite(Oi, br_idx, d_idx, d, o_val);
// std::cout << "q_row_idx:" << q_row_idx << " d_idx:" << d_idx << " Qi:" << q_val << " Oi:" << o_val <<std::endl;
}
float l_val = l[q_row_idx];
li[br_idx] = l_val;
// std::cout << "li:" << l_val << std::endl;
}
// compute Sij = Qi * Kj_T (Br x Bc)
for(int br_idx=0; br_idx < Br; br_idx++) {
for(int bc_idx=0; bc_idx < Bc; bc_idx++) {
float sum = 0.0;
for(int d_idx=0; d_idx < d; d_idx++) {
float q_val = twoDimRead(Qi, br_idx, d_idx, d);
float k_val = twoDimRead(Kj, bc_idx, d_idx, d);
sum += q_val * k_val;
}
twoDimWrite(Sij, br_idx, bc_idx, Bc, sum);
// std::cout << "sij, br:" << br_idx << " bc:" << bc_idx << " val:" << sum << std::endl;
}
}
// Compute Pij = exp(Sij) of size (Br x Bc)
for(int br_idx=0; br_idx < Br; br_idx++) {
for(int bc_idx=0; bc_idx < Bc; bc_idx++) {
float exp_val = std::exp(twoDimRead(Sij, br_idx, bc_idx, Bc));
twoDimWrite(Pij, br_idx, bc_idx, Bc, exp_val);
}
}
// Compute lij = rowsum(Pij) of size (Br)
for(int br_idx=0; br_idx < Br; br_idx++) {
float sum = 0.0;
for(int bc_idx=0; bc_idx < Bc; bc_idx++) {
sum += twoDimRead(Pij, br_idx, bc_idx, Bc);
}
lij[br_idx] = sum;
// compute lnew = li + lij
lnew[br_idx] = li[br_idx] + lij[br_idx];
}
// Compute Oi = (liOi + PijVj)/ lnew
for(int br_idx=0; br_idx < Br; br_idx++) {
for(int d_idx=0; d_idx < d; d_idx++) {
float pv_sum =0.0;
for(int bc_idx=0; bc_idx < Bc; bc_idx++) {
int p_row = br_idx;
int p_col = bc_idx;
int v_row = bc_idx;
int v_col = d_idx;
pv_sum += twoDimRead(Pij, p_row, p_col, Bc) * twoDimRead(Vj, v_row, v_col, d);
}
// twoDimWrite(PV, br_idx, d_idx, d, pv_sum);
float li_Oi_val = li[br_idx] * twoDimRead(Oi, br_idx, d_idx, d);
float new_sum = pv_sum + li_Oi_val;
float new_Oi_val = new_sum / lnew[br_idx];
twoDimWrite(Oi, br_idx, d_idx, d, new_Oi_val);
}
}
// Write Oi and lnew back to O and l in main memory
for(int br_idx=0; br_idx < Br; br_idx++) {
int O_row = q_block_idx * Br + br_idx;
if(O_row >= N) continue;
for(int d_idx=0; d_idx < d; d_idx++) {
float Oi_val = twoDimRead(Oi, br_idx, d_idx, d);
int O_col = d_idx;
fourDimWrite(O, b, h, O_row, O_col, H, N, d, Oi_val);
}
l[O_row] = lnew[br_idx];
}
}
}
}
}
// DO NOT EDIT THIS RETURN STATEMENT //
// It formats your C++ Vector O back into a Tensor of Shape (B, H, N, d) and returns it //
return torch::from_blob(O.data(), {B, H, N, d}, torch::TensorOptions().dtype(torch::kFloat32)).clone();
}
Output:
My implementation is 3x faster than the reference implementation. I don’t know why. Is it because I don’t materialize PV?
model_inference 0.11% 771.000us 100.00% 721.392ms 721.392ms 512.00 Kb -10.53 Kb 1
REFERENCE - FLASH ATTENTION 99.56% 718.217ms 99.83% 720.142ms 720.142ms 512.00 Kb -8.00 Mb 1
aten::to 0.00% 5.000us 0.00% 33.000us 5.500us 520 b 0 b 6
aten::_to_copy 0.00% 16.000us 0.00% 28.000us 7.000us 520 b 8 b 4
aten::abs 0.00% 30.000us 0.01% 54.000us 13.500us 512 b 256 b 4
aten::empty_strided 0.00% 3.000us 0.00% 3.000us 0.750us 512 b 512 b 4
aten::isfinite 0.00% 12.000us 0.01% 89.000us 89.000us 224 b 0 b 1
aten::masked_select 0.00% 14.000us 0.01% 43.000us 43.000us 128 b 120 b 1
------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 721.392ms
REFERENCE - FLASH ATTENTION statistics
cpu time: 720.142ms
mem usage: 524288 bytes
-----RUNNING STUDENT IMPLEMENTATION-----
first row value:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
student first row:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
manual attention == pytorch attention True
Manual Execution Time: 0.21289777755737305
----------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg CPU Mem Self CPU Mem # of Calls
----------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::empty 0.01% 16.000us 0.01% 16.000us 0.889us 1.00 Mb 1.00 Mb 18
aten::clone 0.01% 25.000us 0.16% 332.000us 166.000us 1.00 Mb 0 b 2
aten::zeros 0.02% 36.000us 0.04% 83.000us 6.917us 521.59 Kb 5.00 Kb 12
aten::empty_strided 0.00% 10.000us 0.00% 10.000us 2.000us 512.51 Kb 512.51 Kb 5
model_inference 0.34% 722.000us 100.00% 212.937ms 212.937ms 512.00 Kb -15.47 Kb 1
STUDENT - FLASH ATTENTION 99.25% 211.341ms 99.44% 211.746ms 211.746ms 512.00 Kb -1.00 Mb 1
aten::flatten 0.01% 26.000us 0.10% 216.000us 14.400us 512.00 Kb 0 b 15
aten::empty_like 0.00% 4.000us 0.00% 6.000us 6.000us 512.00 Kb 0 b 1
aten::zero_ 0.00% 5.000us 0.02% 35.000us 2.917us 5.00 Kb 5.00 Kb 12
aten::to 0.01% 16.000us 0.01% 29.000us 4.833us 520 b 4 b 6
----------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 212.937ms
STUDENT - FLASH ATTENTION statistics
cpu time: 211.746ms
mem usage: 524288 bytes