Stf CS149 flash attention

Stf CS149 flash attention lab assignment repo

Install library to compile code

It shows error about error loading shared object.

(cs149gpt) ➜  cs149gpt git:(main) ✗ python3 gpt149.py 4Daccess
/home/zt/.local/lib/python3.10/site-packages/torch/cuda/__init__.py:118: UserWarning: CUDA initialization: The NVIDIA driver on your system is too old (found version 11040). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver. (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:108.)                                                                                        return torch._C._cuda_getDeviceCount() > 0
No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'
Traceback (most recent call last):
  File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 14, in <module>
    import module_ref as ms
ImportError: /home/zt/stf-cs149-pp/cs149gpt/module_ref.so: undefined symbol: _ZN2at4_ops5zeros4callEN3c108ArrayRefINS2_6SymIntEEENS2_8optionalINS2_10ScalarTypeEEENS6_INS2_6LayoutEEENS6_INS2_6DeviceEEENS6_IbEE

I tried to use conda to create new env and install low version of pytorch but conda always installs 2.3.x version of torch for me.

 conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 cpuonly python=3.10 numpy=1.26 ninja tiktoken -c pytorch -c conda-forge

What it is required is 2.1.x version of torch.

Then I manually use pip to uninstall torch and reinstall 2.1.x version of torch

 pip3 uninstall torch
pip3 install torch==2.1.2

Got another issue saying that I should use numpy version < 2.0.

Then I uninstall numpy and reinstall it with version 1.2.x

pip3 uninstall numpy
 pip3 install numpy==1.26

So now I can run the code successfully.

Part 1: naive attention

My code produce the value that is 0.0003 less than what solution produces for each element.

I don’t know why.

Should I use double ?

First implementation:

torch::Tensor myNaiveAttention(torch::Tensor QTensor, torch::Tensor KTensor, torch::Tensor VTensor, torch::Tensor QK_tTensor,
                int B, int H, int N, int d){

    // Q, K, V are passed in with Shape: (B, H, N, d)
    //QK^t Intermediate Tensor has Shape (N, N)
    
    //Make O Tensor with Shape (B, H, N, d) 
    at::Tensor OTensor = at::zeros({B, H, N, d}, at::kFloat);

    //Format O, Q, K, and V tensors into 4D vectors
    std::vector<float> O = formatTensor(OTensor);
    std::vector<float> Q = formatTensor(QTensor);
    std::vector<float> K = formatTensor(KTensor);
    std::vector<float> V = formatTensor(VTensor);

    //Format QK_t Tensor into a 2D vector.
    std::vector<float> QK_t = formatTensor(QK_tTensor);
    
    /* Here is an example of how to read/write 0's to  Q (B, H, N, d) using the 4D accessors

        //loop over Batch Size
         for (int b = 0; b < B; b++) {

             //loop over Heads
             for (int h = 0; h < H; h++) {

                 //loop over Sequence Length
                 for (int i = 0; i < N; i++) {

                     //loop over Embedding Dimensionality
                     for (int j = 0; j < d; j++) {
                        float val = fourDimRead(Q, b, h, i, j, H, N, d);
                        val = 0.0;
                        fourDimWrite(Q, b, h, i, j, H, N, d, val);
                     }
                 }
             }
         }
    */

    /* Here is an example of how to read/write 0's to  QK_t (N, N) using the 2D accessors

           for (int i = 0; i < N; i++) {
	       for (int j = 0; j < N; j++) {
	           float val = twoDimRead(QK_t, i, j, N);
               val = 0.0;
	           twoDimWrite(QK_t, i, j, N, val);
             }
         }
    */
    
    // -------- YOUR CODE HERE  -------- //
    for (int b = 0; b < B; b++) {
       //loop over Heads
       for (int h = 0; h < H; h++) {
           //loop over Sequence Length
           for (int i = 0; i < N; i++) {
            for(int seq_i=0; seq_i < N; seq_i++) {
             //loop over Embedding Dimensionality
              float val = 0.0;
               for (int j = 0; j < d; j++) {
                  int q_row  = i; 
                  int q_col = j;
                  int k_row = j;
                  int k_col = seq_i;
                  // float val = fourDimRead(Q, b, h, i, j, H, N, d);
          float q_val = fourDimRead(Q, b, h, q_row, q_col, H, N, d);
          float k_val = fourDimRead(K, b, h, k_row, k_col, H, N, d);
          val += q_val * k_val;


                  // val = 0.0;
                  // fourDimWrite(Q, b, h, i, j, H, N, d, val);
               }
          fourDimWrite(QK_t, b, h, i, seq_i, H, N, d, val );

            }

           }
          std::vector<float> tmp_row_res(N);
          for(int row_idx=0; row_idx < N; row_idx++) {

            float row_sum = 0.0;
            for(int cold_idx=0; cold_idx < N ;cold_idx++) {
               float val = twoDimRead(QK_t, row_idx, cold_idx, N);
              float exp_val = std::exp(val);
              row_sum += exp_val;
              tmp_row_res[cold_idx] = exp_val;

            }

            for(int cold_idx=0; cold_idx < N ; cold_idx++) {
              float prob = tmp_row_res[cold_idx] / row_sum;
              twoDimWrite(QK_t, row_idx, cold_idx, N, prob);
            }
          }


        for(int qkt_row_idx=0; qkt_row_idx < N; qkt_row_idx++) {
        for(int output_d_idx=0; output_d_idx < d; output_d_idx++) {
          float val =0.0;
          for(int m_idx=0; m_idx < N ; m_idx++) {
            float qkt_val =  twoDimRead(QK_t, qkt_row_idx, m_idx, N);
            int v_row = m_idx;
            int v_col = output_d_idx;
            float v_val = fourDimRead(V, b, h, v_row, v_col, H, N, d);
            val += qkt_val * v_val;
          }
          fourDimWrite(O, b, h, qkt_row_idx, output_d_idx, H, N, d ,val);
        }
        }
       }
   }




    
    // DO NOT EDIT THIS RETURN STATEMENT //
    // It formats your C++ Vector O back into a Tensor of Shape (B, H, N, d) and returns it //
    return torch::from_blob(O.data(), {B, H, N, d}, torch::TensorOptions().dtype(torch::kFloat32)).clone();
}
-----RUNNING STUDENT IMPLEMENTATION-----

first row value:
 tensor([0.0771, 0.0779, 0.0787, 0.0795, 0.0803, 0.0811, 0.0819, 0.0827, 0.0835,
        0.0843, 0.0851, 0.0859, 0.0867, 0.0875, 0.0883, 0.0891, 0.0899, 0.0907,
        0.0915, 0.0923, 0.0931, 0.0939, 0.0947, 0.0955, 0.0963, 0.0971, 0.0979,
        0.0987, 0.0995, 0.1003, 0.1011, 0.1019])
STAGE:2024-11-17 21:15:18 207308:207308 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
student first row:
 tensor([0.0768, 0.0776, 0.0784, 0.0792, 0.0800, 0.0808, 0.0816, 0.0824, 0.0832,
        0.0840, 0.0848, 0.0856, 0.0864, 0.0872, 0.0880, 0.0888, 0.0896, 0.0904,
        0.0912, 0.0920, 0.0928, 0.0936, 0.0944, 0.0952, 0.0960, 0.0968, 0.0976,
        0.0984, 0.0992, 0.1000, 0.1008, 0.1016])
STAGE:2024-11-17 21:15:19 207308:207308 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-11-17 21:15:19 207308:207308 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
Traceback (most recent call last):
  File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 329, in <module>
    main()
  File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 311, in main
    part1Test(N, d, B, H)
  File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 221, in part1Test
    testTemplate(attentionModuleReference.myUnfusedAttention, params, "STUDENT - NAIVE ATTENTION")
  File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 182, in testTemplate
    assert torch.allclose(QKV,QKS1, atol=1e-4), correctness_error_message
AssertionError:
-------------------------------------------
 YOUR ATTENTION PRODUCED INCORRECT RESULTS

Try to use double to store itermediate result to fix this problem.

I have two issues in previous version of code

    for (int b = 0; b < B; b++) {
       //loop over Heads
       for (int h = 0; h < H; h++) {
          //loop over Sequence Length
          for (int i = 0; i < N; i++) {
            for(int seq_i=0; seq_i < N; seq_i++) {
             //loop over Embedding Dimensionality
              float val = 0.0;
               for (int j = 0; j < d; j++) {
                  int q_row  = i; 
                  int q_col = j;

                    // this is the correct indexing for the second matrix.
                    // Since K is not transposed.
                     // K should be indexed with (seq_i, j) instead of (j, seq_i) like normal matrix multiplciation
                  int k_row = seq_i;
                  int k_col = j;
                  float q_val = fourDimRead(Q, b, h, q_row, q_col, H, N, d);
                  float k_val = fourDimRead(K, b, h, k_row, k_col, H, N, d);
                  val += q_val * k_val;
               }
                // This is the second place that is fixed. 
                 // QK_t is two dimenional. 
                 // Should use twoDimWrite
              twoDimWrite(QK_t, i, seq_i, N, val );
            }

           }
 

Output:


REFERENCE - NAIVE ATTENTION statistics
cpu time:  293.585ms
mem usage:  4718592 bytes
-----RUNNING STUDENT IMPLEMENTATION-----

first row value:
 tensor([0.0771, 0.0779, 0.0787, 0.0795, 0.0803, 0.0811, 0.0819, 0.0827, 0.0835,
        0.0843, 0.0851, 0.0859, 0.0867, 0.0875, 0.0883, 0.0891, 0.0899, 0.0907,
        0.0915, 0.0923, 0.0931, 0.0939, 0.0947, 0.0955, 0.0963, 0.0971, 0.0979,
        0.0987, 0.0995, 0.1003, 0.1011, 0.1019])
STAGE:2024-11-18 13:10:17 1148267:1148267 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
student first row:
 tensor([0.0771, 0.0779, 0.0787, 0.0795, 0.0803, 0.0811, 0.0819, 0.0827, 0.0835,
        0.0843, 0.0851, 0.0859, 0.0867, 0.0875, 0.0883, 0.0891, 0.0899, 0.0907,
        0.0915, 0.0923, 0.0931, 0.0939, 0.0947, 0.0955, 0.0963, 0.0971, 0.0979,
        0.0987, 0.0995, 0.1003, 0.1011, 0.1019])
STAGE:2024-11-18 13:10:17 1148267:1148267 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-11-18 13:10:17 1148267:1148267 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
manual attention == pytorch attention True
Manual Execution Time:  0.2747969627380371

-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                  aten::empty         0.01%      16.000us         0.01%      16.000us       2.000us       5.00 Mb       5.00 Mb             8
    STUDENT - NAIVE ATTENTION        99.49%     273.438ms        99.68%     273.946ms     273.946ms       4.50 Mb      -1.00 Mb             1
                  aten::zeros         0.01%      14.000us         0.05%     133.000us      66.500us       4.50 Mb           0 b             2
                  aten::clone         0.01%      20.000us         0.13%     346.000us     173.000us       1.00 Mb           0 b             2
          aten::empty_strided         0.00%      12.000us         0.00%      12.000us       2.400us     512.51 Kb     512.51 Kb             5
              model_inference         0.20%     549.000us       100.00%     274.838ms     274.838ms     512.00 Kb      -4.00 Mb             1
                aten::flatten         0.01%      19.000us         0.08%     209.000us      41.800us     512.00 Kb           0 b             5
             aten::empty_like         0.00%       3.000us         0.00%       5.000us       5.000us     512.00 Kb           0 b             1
                     aten::to         0.00%       6.000us         0.01%      31.000us       5.167us         520 b           0 b             6
               aten::_to_copy         0.01%      14.000us         0.01%      25.000us       6.250us         520 b           0 b             4
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 274.838ms

STUDENT - NAIVE ATTENTION statistics
cpu time:  273.946ms
mem usage:  4718592 bytes

Part2: blocked matrix multiplication

Initial version

Unlike that of code written in cuda, I have to manually iterate all rows and cols of input matrix.

Don’t know if this is correct way to do that.

#define TILE_SIZE 16
torch::Tensor myUnfusedAttentionBlocked(torch::Tensor QTensor, torch::Tensor KTensor, torch::Tensor VTensor, torch::Tensor QK_tTensor,
                int B, int H, int N, int d){
  
  // Q, K, V are passed in with Shape: (B, H, N, d)
  //QK^t Intermediate Tensor has Shape (N, N)

  //Make O Tensor with Shape (B, H, N, d) 
  at::Tensor OTensor = at::zeros({B, H, N, d}, at::kFloat);

  //Format O, Q, K, and V tensors into 4D vectors
  std::vector<float> O = formatTensor(OTensor);
  std::vector<float> Q = formatTensor(QTensor);
  std::vector<float> K = formatTensor(KTensor);
  std::vector<float> V = formatTensor(VTensor);

  //Format QK_t Tensor into a 2D vector.
  std::vector<float> QK_t = formatTensor(QK_tTensor);

  // -------- YOUR CODE HERE  -------- //
  for(int b=0; b < B; b++) {
    for(int h=0; h < H; h++) {
      for(int q_row_tile_idx=0; q_row_tile_idx < (N+TILE_SIZE-1)/TILE_SIZE; q_row_tile_idx++) {
        // K is not transposed so we traverse k by row.
        for(int k_row_tile_idx=0; k_row_tile_idx < (N+TILE_SIZE-1)/TILE_SIZE; k_row_tile_idx++ ) {
          for(int d_col_tile_idx=0; d_col_tile_idx < (d+TILE_SIZE-1)/TILE_SIZE; d_col_tile_idx++ ) {
            for(int tile_row_idx=0; tile_row_idx < TILE_SIZE; tile_row_idx++) {
              // int out_row_idx = q_row_tile_idx * TILE_SIZE + tile_row_idx;
              for(int tile_col_idx=0; tile_col_idx < TILE_SIZE; tile_col_idx++) {
                // int out_col_idx = k_row_tile_idx * TILE_SIZE + tile_col_idx;
                int q_col_idx = d_col_tile_idx * TILE_SIZE + tile_col_idx;
                int q_row_idx =q_row_tile_idx * TILE_SIZE + tile_row_idx; 
                int k_row_idx = k_row_tile_idx * TILE_SIZE + tile_row_idx;
                int k_col_idx = d_col_tile_idx * TILE_SIZE + tile_col_idx;
                if(q_row_idx < N && q_col_idx < d && k_row_idx < N && k_col_idx < d) {
                  float q_tile_val = fourDimRead(Q, b, h, q_row_idx, q_col_idx, H, N, d);
                  float k_tile_val = fourDimRead(K, b, h, k_row_idx, k_col_idx, H, N, d);
                  float orig_val = twoDimRead(QK_t, q_row_tile_idx, k_row_idx, N);
                  float val = q_tile_val * k_tile_val + orig_val;
                  twoDimWrite(QK_t, q_row_tile_idx, k_row_tile_idx, N, val );
                }
              }
            }
          }

        }
      }

      for(int row_idx=0; row_idx < N; row_idx++) {
        std::vector<double> tmp_row_res(N, 0.0);
        double row_sum = 0.0;
        for(int cold_idx=0; cold_idx < N ;cold_idx++) {
           float val = twoDimRead(QK_t, row_idx, cold_idx, N);
          double exp_val = std::exp(val);
          row_sum += exp_val;
          tmp_row_res[cold_idx] = exp_val;
        }
        for(int cold_idx=0; cold_idx < N ; cold_idx++) {
          float prob = tmp_row_res[cold_idx] / row_sum;
          twoDimWrite(QK_t, row_idx, cold_idx, N, prob);
        }
      }

      for(int qkt_row_idx=0; qkt_row_idx < N; qkt_row_idx++) {
        for(int output_d_idx=0; output_d_idx < d; output_d_idx++) {
          float val =0.0;
          for(int m_idx=0; m_idx < N ; m_idx++) {
            float qkt_val =  twoDimRead(QK_t, qkt_row_idx, m_idx, N);
            int v_row = m_idx;
            int v_col = output_d_idx;
            float v_val = fourDimRead(V, b, h, v_row, v_col, H, N, d);
            val += qkt_val * v_val;
          }
          fourDimWrite(O, b, h, qkt_row_idx, output_d_idx, H, N, d ,val);
        }
      }

    }

  }


  // DO NOT EDIT THIS RETURN STATEMENT //
  // It formats your C++ Vector O back into a Tensor of Shape (B, H, N, d) and returns it //
  return torch::from_blob(O.data(), {B, H, N, d}, torch::TensorOptions().dtype(torch::kFloat32)).clone();
}

It’s not correct

-----RUNNING STUDENT IMPLEMENTATION-----

first row value:
 tensor([0.0771, 0.0779, 0.0787, 0.0795, 0.0803, 0.0811, 0.0819, 0.0827, 0.0835,
        0.0843, 0.0851, 0.0859, 0.0867, 0.0875, 0.0883, 0.0891, 0.0899, 0.0907,
        0.0915, 0.0923, 0.0931, 0.0939, 0.0947, 0.0955, 0.0963, 0.0971, 0.0979,
        0.0987, 0.0995, 0.1003, 0.1011, 0.1019])
STAGE:2024-11-18 11:29:24 607233:607233 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
student first row:
 tensor([0.0767, 0.0775, 0.0783, 0.0791, 0.0799, 0.0807, 0.0815, 0.0823, 0.0831,
        0.0839, 0.0847, 0.0855, 0.0863, 0.0871, 0.0879, 0.0887, 0.0895, 0.0903,
        0.0911, 0.0919, 0.0927, 0.0935, 0.0943, 0.0951, 0.0959, 0.0967, 0.0975,
        0.0983, 0.0991, 0.0999, 0.1007, 0.1015])
STAGE:2024-11-18 11:29:24 607233:607233 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-11-18 11:29:24 607233:607233 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
Traceback (most recent call last):
  File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 329, in <module>
    main()
  File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 313, in main
    part2Test(N, d, B, H)
  File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 233, in part2Test
    testTemplate(attentionModuleReference.myUnfusedAttentionBlocked, params, "STUDENT - BLOCKED MATMUL + UNFUSED SOFTMAX")
  File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 182, in testTemplate
    assert torch.allclose(QKV,QKS1, atol=1e-4), correctness_error_message
AssertionError:
-------------------------------------------
 YOUR ATTENTION PRODUCED INCORRECT RESULTS

Fixed the two indexing issue in original issue.

This is the correct implementation given by gpt. However, it’s too slow

      for (int q_row_tile_idx = 0; q_row_tile_idx < (N + TILE_SIZE - 1) / TILE_SIZE; q_row_tile_idx++) {
    for (int k_row_tile_idx = 0; k_row_tile_idx < (N + TILE_SIZE - 1) / TILE_SIZE; k_row_tile_idx++) {
        for (int d_col_tile_idx = 0; d_col_tile_idx < (d + TILE_SIZE - 1) / TILE_SIZE; d_col_tile_idx++) {
            for (int tile_row_idx = 0; tile_row_idx < TILE_SIZE; tile_row_idx++) {
                for (int tile_col_idx = 0; tile_col_idx < TILE_SIZE; tile_col_idx++) {
                    int q_row_idx = q_row_tile_idx * TILE_SIZE + tile_row_idx;
                    int k_row_idx = k_row_tile_idx * TILE_SIZE + tile_col_idx; // Fix indexing
                    for (int d_idx = d_col_tile_idx * TILE_SIZE; d_idx < (d_col_tile_idx + 1) * TILE_SIZE; d_idx++) {
                        if (q_row_idx < N && k_row_idx < N && d_idx < d) {
                            float q_tile_val = fourDimRead(Q, b, h, q_row_idx, d_idx, H, N, d);
                            float k_tile_val = fourDimRead(K, b, h, k_row_idx, d_idx, H, N, d);
                            float orig_val = twoDimRead(QK_t, q_row_idx, k_row_idx, N);
                            float val = q_tile_val * k_tile_val + orig_val;
                            twoDimWrite(QK_t, q_row_idx, k_row_idx, N, val);
                        }
                    }
                }
            }
        }
    }
}

Explanation from gpt

You’re right to question the necessity of traversing the dimension ( d ) fully. Let’s clarify why it’s important in some cases and whether it applies here.

Understanding the Role of Dimension ( d )

In block matrix multiplication:

  1. Matrix Multiplication Basics:
    • To compute an element ((i, j)) of the resulting matrix ( QK_t ), you need to calculate the dot product of the (i)-th row of (Q) and the (j)-th row of (K).
    • This involves summing the product of corresponding elements from both rows over the shared dimension ( d ):
      [ QK_t[i, j] = \sum_{k=0}^{d-1} Q[i, k] \times K[j, k] ]
  2. Tiling:
    • In tiled matrix multiplication, the computation of a block of ( QK_t ) (e.g., a ( \text{TILE_SIZE} \times \text{TILE_SIZE} ) block) is split into smaller sub-blocks to improve memory locality.
    • Each sub-block computation iterates over the shared dimension ( d ) in chunks.

Why Traverse ( d )?

If you don’t traverse ( d ), the computation will only consider partial contributions to the dot product. This would leave the values in ( QK_t ) incomplete.

For example, if ( d = 6 ), and your block size is ( \text{TILE_SIZE} = 3 ):

  • You compute the dot product for ( d \in [0, 2] ), but ignore ( d \in [3, 5] ), leading to incorrect results.

Why the Code Needs It

Your original code suggests you’re updating ( QK_t ) incrementally:

float val = q_tile_val * k_tile_val + orig_val;

Here, orig_val accumulates partial results. Without iterating fully over ( d ), the result will be incomplete because only a subset of the dot product contributions is added.

When ( d ) Traversal Isn’t Necessary

In rare cases, you might not need to traverse ( d ) if:

  1. Each block fully covers ( d ):
    • This is possible only when ( \text{TILE_SIZE} ) equals ( d ).
    • In this case, the tiling logic aligns perfectly with the shared dimension, and a single traversal suffices.
  2. Precomputed Values:
    • If ( Q ) or ( K ) has been preprocessed (e.g., multiplied previously), the shared dimension ( d ) might not need iteration.

Conclusion

For the general case (and your code):

  • Traversing ( d ) fully ensures the correctness of the dot product computation across the shared dimension.
  • If you’re sure ( d ) traversal isn’t necessary, it usually means ( Q ) and ( K ) are already structured to encapsulate the full dot product for each tile.

If you’re still unsure about why it’s needed in your case, we can examine your problem more closely to confirm!

I can view the outside for loop from two kind of perspective.

One is from input matrix.

One is from output matrix. This one makes more sense after reading code generated by gpt

Two outer for loop traverse row and column tiles of output matrix. The third for loop traverse the row in input matrix and the column in second input matrix The fourth and fifth for loop traverse the row and column elements in current tile. The sixth for loop does dot product for a single row in input tile and single column in output tile and accumulate the sum. And then write back the sum to output matrix.

for (int row_tile_idx = 0; row_tile_idx < (N + TILE_SIZE - 1) / TILE_SIZE; row_tile_idx++) {
    for (int col_tile_idx = 0; col_tile_idx < (N + TILE_SIZE - 1) / TILE_SIZE; col_tile_idx++) {
        // Initialize the current block of the output matrix
        for (int i = 0; i < TILE_SIZE; i++) {
            for (int j = 0; j < TILE_SIZE; j++) {
                int row_idx = row_tile_idx * TILE_SIZE + i;
                int col_idx = col_tile_idx * TILE_SIZE + j;
                if (row_idx < N && col_idx < N) {
                    twoDimWrite(C, row_idx, col_idx, N, 0.0); // Initialize to zero
                }
            }
        }

        for (int k_tile_idx = 0; k_tile_idx < (d + TILE_SIZE - 1) / TILE_SIZE; k_tile_idx++) {
            for (int i = 0; i < TILE_SIZE; i++) {
                for (int j = 0; j < TILE_SIZE; j++) {
                    int row_idx = row_tile_idx * TILE_SIZE + i;
                    int col_idx = col_tile_idx * TILE_SIZE + j;
                    if (row_idx >= N || col_idx >= N) continue;

                    float sum = twoDimRead(C, row_idx, col_idx, N);

                    for (int k = 0; k < TILE_SIZE; k++) {
                        int k_idx = k_tile_idx * TILE_SIZE + k;
                        if (k_idx >= d) break;

                        float a_val = twoDimRead(A, row_idx, k_idx, d);
                        float b_val = twoDimRead(B, col_idx, k_idx, d); // Column index now indexes rows in B
                        sum += a_val * b_val;
                    }

                    twoDimWrite(C, row_idx, col_idx, N, sum);
                }
            }
        }
    }
}

The code above use global buffer.

      for(int row_tile_idx=0; row_tile_idx < (N+TILE_SIZE-1)/TILE_SIZE; row_tile_idx++) {
        for(int col_tile_idx=0; col_tile_idx < (N+TILE_SIZE-1)/TILE_SIZE; col_tile_idx++) {
          for(int k_tile_idx=0; k_tile_idx < (d+TILE_SIZE-1)/TILE_SIZE; k_tile_idx++  ) {
            for(int tile_row_idx=0; tile_row_idx < TILE_SIZE; tile_row_idx++) {
              for(int tile_col_idx=0; tile_col_idx < TILE_SIZE; tile_col_idx++) {
                int row_idx = row_tile_idx * TILE_SIZE + tile_row_idx;
                int col_idx = col_tile_idx * TILE_SIZE + tile_col_idx;
                if(row_idx >= N || col_idx >= N) {
                  continue;
                }
                float sum = twoDimRead(QK_t, row_idx, col_idx, N);

                for(int k=0; k < TILE_SIZE; k++) {
                  int k_idx = k_tile_idx * TILE_SIZE + k;
                  if(k_idx >= d) break;
                  float q_val =  fourDimRead(Q,b, h, row_idx, k_idx, H, N, d);
                  float k_val = fourDimRead(K, b, h, col_idx, k_idx, H, N, d);
                  sum += q_val * k_val;
                }
                twoDimWrite(QK_t, row_idx, col_idx, N, sum);
              }

            }
          }
        }
      }

Output:


REFERENCE - BLOCKED MATMUL + UNFUSED SOFTMAX statistics
cpu time:  238.938ms
mem usage:  4718592 bytes
-----RUNNING STUDENT IMPLEMENTATION-----

first row value:
 tensor([0.0771, 0.0779, 0.0787, 0.0795, 0.0803, 0.0811, 0.0819, 0.0827, 0.0835,
        0.0843, 0.0851, 0.0859, 0.0867, 0.0875, 0.0883, 0.0891, 0.0899, 0.0907,
        0.0915, 0.0923, 0.0931, 0.0939, 0.0947, 0.0955, 0.0963, 0.0971, 0.0979,
        0.0987, 0.0995, 0.1003, 0.1011, 0.1019])
STAGE:2024-11-18 13:39:56 1318038:1318038 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
student first row:
 tensor([0.0771, 0.0779, 0.0787, 0.0795, 0.0803, 0.0811, 0.0819, 0.0827, 0.0835,
        0.0843, 0.0851, 0.0859, 0.0867, 0.0875, 0.0883, 0.0891, 0.0899, 0.0907,
        0.0915, 0.0923, 0.0931, 0.0939, 0.0947, 0.0955, 0.0963, 0.0971, 0.0979,
        0.0987, 0.0995, 0.1003, 0.1011, 0.1019])
STAGE:2024-11-18 13:39:56 1318038:1318038 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-11-18 13:39:56 1318038:1318038 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
manual attention == pytorch attention True
Manual Execution Time:  0.29826903343200684

----------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                          Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls
----------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                   aten::empty         0.01%      21.000us         0.01%      21.000us       2.625us       5.00 Mb       5.00 Mb             8
    STUDENT - BLOCKED MATMUL + UNFUSED SOFTMAX        99.35%     296.386ms        99.70%     297.443ms     297.443ms       4.50 Mb      -1.00 Mb             1
                                   aten::zeros         0.01%      15.000us         0.24%     706.000us     353.000us       4.50 Mb           0 b             2
                                   aten::clone         0.01%      19.000us         0.11%     328.000us     164.000us       1.00 Mb           0 b             2
                           aten::empty_strided         0.00%      11.000us         0.00%      11.000us       2.200us     512.51 Kb     512.51 Kb             5
                               model_inference         0.18%     550.000us       100.00%     298.329ms     298.329ms     512.00 Kb      -4.00 Mb             1
                                 aten::flatten         0.01%      15.000us         0.07%     200.000us      40.000us     512.00 Kb           0 b             5
                              aten::empty_like         0.00%       3.000us         0.00%       5.000us       5.000us     512.00 Kb           0 b             1
                                      aten::to         0.00%       6.000us         0.01%      31.000us       5.167us         520 b           0 b             6
                                aten::_to_copy         0.00%      14.000us         0.01%      25.000us       6.250us         520 b           0 b             4
----------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 298.329ms

STUDENT - BLOCKED MATMUL + UNFUSED SOFTMAX statistics
cpu time:  297.443ms
mem usage:  4718592 bytes

This code use local buffer that is allocated. A little bit faster but not a lot.

#pragma omp parallel for collapse(2) // Parallelize the two outermost loops
for (int row_tile_idx = 0; row_tile_idx < (N + TILE_SIZE - 1) / TILE_SIZE; row_tile_idx++) {
    for (int col_tile_idx = 0; col_tile_idx < (N + TILE_SIZE - 1) / TILE_SIZE; col_tile_idx++) {
        for (int k_tile_idx = 0; k_tile_idx < (d + TILE_SIZE - 1) / TILE_SIZE; k_tile_idx++) {

            // Buffers for tile data
            float Q_tile[TILE_SIZE][TILE_SIZE];
            float K_tile[TILE_SIZE][TILE_SIZE];

            // Preload Q and K tiles into local buffers
            for (int tile_row_idx = 0; tile_row_idx < TILE_SIZE; tile_row_idx++) {
                int row_idx = row_tile_idx * TILE_SIZE + tile_row_idx;
                if (row_idx >= N) continue; // Skip out-of-bound rows

                for (int k = 0; k < TILE_SIZE; k++) {
                    int k_idx = k_tile_idx * TILE_SIZE + k;
                    if (k_idx < d) {
                        Q_tile[tile_row_idx][k] = fourDimRead(Q, b, h, row_idx, k_idx, H, N, d);
                    } else {
                        Q_tile[tile_row_idx][k] = 0.0f; // Fill with zero if out-of-bounds
                    }
                }
            }

            for (int tile_col_idx = 0; tile_col_idx < TILE_SIZE; tile_col_idx++) {
                int col_idx = col_tile_idx * TILE_SIZE + tile_col_idx;
                if (col_idx >= N) continue; // Skip out-of-bound columns

                for (int k = 0; k < TILE_SIZE; k++) {
                    int k_idx = k_tile_idx * TILE_SIZE + k;
                    if (k_idx < d) {
                        K_tile[tile_col_idx][k] = fourDimRead(K, b, h, col_idx, k_idx, H, N, d);
                    } else {
                        K_tile[tile_col_idx][k] = 0.0f; // Fill with zero if out-of-bounds
                    }
                }
            }

            // Compute the dot product for the current tile
            for (int tile_row_idx = 0; tile_row_idx < TILE_SIZE; tile_row_idx++) {
                int row_idx = row_tile_idx * TILE_SIZE + tile_row_idx;
                if (row_idx >= N) continue; // Skip out-of-bound rows

                for (int tile_col_idx = 0; tile_col_idx < TILE_SIZE; tile_col_idx++) {
                    int col_idx = col_tile_idx * TILE_SIZE + tile_col_idx;
                    if (col_idx >= N) continue; // Skip out-of-bound columns

                    float sum = twoDimRead(QK_t, row_idx, col_idx, N);

                    // Unrolled loop for vectorized dot product
                    for (int k = 0; k < TILE_SIZE; k++) {
                        sum += Q_tile[tile_row_idx][k] * K_tile[tile_col_idx][k];
                    }

                    twoDimWrite(QK_t, row_idx, col_idx, N, sum);
                }
            }
        }
    }
}

OUtput:

REFERENCE - BLOCKED MATMUL + UNFUSED SOFTMAX statistics
cpu time:  226.667ms
mem usage:  4718592 bytes
-----RUNNING STUDENT IMPLEMENTATION-----

first row value:
 tensor([0.0771, 0.0779, 0.0787, 0.0795, 0.0803, 0.0811, 0.0819, 0.0827, 0.0835,
        0.0843, 0.0851, 0.0859, 0.0867, 0.0875, 0.0883, 0.0891, 0.0899, 0.0907,
        0.0915, 0.0923, 0.0931, 0.0939, 0.0947, 0.0955, 0.0963, 0.0971, 0.0979,
        0.0987, 0.0995, 0.1003, 0.1011, 0.1019])
STAGE:2024-11-18 13:44:02 1342423:1342423 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
student first row:
 tensor([0.0771, 0.0779, 0.0787, 0.0795, 0.0803, 0.0811, 0.0819, 0.0827, 0.0835,
        0.0843, 0.0851, 0.0859, 0.0867, 0.0875, 0.0883, 0.0891, 0.0899, 0.0907,
        0.0915, 0.0923, 0.0931, 0.0939, 0.0947, 0.0955, 0.0963, 0.0971, 0.0979,
        0.0987, 0.0995, 0.1003, 0.1011, 0.1019])
STAGE:2024-11-18 13:44:02 1342423:1342423 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-11-18 13:44:02 1342423:1342423 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
manual attention == pytorch attention True
Manual Execution Time:  0.2852001190185547

----------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                          Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls
----------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                   aten::empty         0.00%      13.000us         0.00%      13.000us       1.625us       5.00 Mb       5.00 Mb             8
    STUDENT - BLOCKED MATMUL + UNFUSED SOFTMAX        99.52%     283.895ms        99.68%     284.350ms     284.350ms       4.50 Mb      -1.00 Mb             1
                                   aten::zeros         0.01%      17.000us         0.05%     131.000us      65.500us       4.50 Mb           0 b             2
                                   aten::clone         0.01%      23.000us         0.11%     300.000us     150.000us       1.00 Mb           0 b             2
                           aten::empty_strided         0.00%      12.000us         0.00%      12.000us       2.400us     512.26 Kb     512.26 Kb             5
                               model_inference         0.20%     565.000us       100.00%     285.254ms     285.254ms     512.00 Kb      -4.00 Mb             1
                                 aten::flatten         0.01%      16.000us         0.06%     159.000us      31.800us     512.00 Kb           0 b             5
                              aten::empty_like         0.00%       3.000us         0.00%       5.000us       5.000us     512.00 Kb           0 b             1
                                      aten::to         0.00%      10.000us         0.01%      31.000us       5.167us         520 b         256 b             6
                                aten::_to_copy         0.01%      16.000us         0.01%      26.000us       6.500us         520 b         256 b             4
----------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 285.254ms

STUDENT - BLOCKED MATMUL + UNFUSED SOFTMAX statistics
cpu time:  284.35ms
mem usage:  4718592 bytes

Code that do block matrix multiplication for both (Q * K_t) and (softmax(QK_t ) * V):

#define TILE_SIZE 16
torch::Tensor myUnfusedAttentionBlocked(torch::Tensor QTensor, torch::Tensor KTensor, torch::Tensor VTensor, torch::Tensor QK_tTensor,
                int B, int H, int N, int d){
  
  // Q, K, V are passed in with Shape: (B, H, N, d)
  //QK^t Intermediate Tensor has Shape (N, N)

  //Make O Tensor with Shape (B, H, N, d) 
  at::Tensor OTensor = at::zeros({B, H, N, d}, at::kFloat);

  //Format O, Q, K, and V tensors into 4D vectors
  std::vector<float> O = formatTensor(OTensor);
  std::vector<float> Q = formatTensor(QTensor);
  std::vector<float> K = formatTensor(KTensor);
  std::vector<float> V = formatTensor(VTensor);

  //Format QK_t Tensor into a 2D vector.
  std::vector<float> QK_t = formatTensor(QK_tTensor);

  // -------- YOUR CODE HERE  -------- //
  for(int b=0; b < B; b++) {
    for(int h=0; h < H; h++) {

      //
      // correct
      // for(int row_tile_idx=0; row_tile_idx < (N+TILE_SIZE-1)/TILE_SIZE; row_tile_idx++) {
      //   for(int col_tile_idx=0; col_tile_idx < (N+TILE_SIZE-1)/TILE_SIZE; col_tile_idx++) {
      //     for(int k_tile_idx=0; k_tile_idx < (d+TILE_SIZE-1)/TILE_SIZE; k_tile_idx++  ) {
      //       for(int tile_row_idx=0; tile_row_idx < TILE_SIZE; tile_row_idx++) {
      //         for(int tile_col_idx=0; tile_col_idx < TILE_SIZE; tile_col_idx++) {
      //           int row_idx = row_tile_idx * TILE_SIZE + tile_row_idx;
      //           int col_idx = col_tile_idx * TILE_SIZE + tile_col_idx;
      //           if(row_idx >= N || col_idx >= N) {
      //             continue;
      //           }
      //           float sum = twoDimRead(QK_t, row_idx, col_idx, N);

      //           for(int k=0; k < TILE_SIZE; k++) {
      //             int k_idx = k_tile_idx * TILE_SIZE + k;
      //             if(k_idx >= d) break;
      //             float q_val =  fourDimRead(Q,b, h, row_idx, k_idx, H, N, d);
      //             float k_val = fourDimRead(K, b, h, col_idx, k_idx, H, N, d);
      //             sum += q_val * k_val;
      //           }
      //           twoDimWrite(QK_t, row_idx, col_idx, N, sum);
      //         }

      //       }
      //     }
      //   }
      // }

      // correct with local buffer
for (int row_tile_idx = 0; row_tile_idx < (N + TILE_SIZE - 1) / TILE_SIZE; row_tile_idx++) {
    for (int col_tile_idx = 0; col_tile_idx < (N + TILE_SIZE - 1) / TILE_SIZE; col_tile_idx++) {
        for (int k_tile_idx = 0; k_tile_idx < (d + TILE_SIZE - 1) / TILE_SIZE; k_tile_idx++) {

            // Buffers for tile data
            float Q_tile[TILE_SIZE][TILE_SIZE];
            float K_tile[TILE_SIZE][TILE_SIZE];

            // Preload Q and K tiles into local buffers
            for (int tile_row_idx = 0; tile_row_idx < TILE_SIZE; tile_row_idx++) {
                int row_idx = row_tile_idx * TILE_SIZE + tile_row_idx;
                if (row_idx >= N) continue; // Skip out-of-bound rows

                for (int k = 0; k < TILE_SIZE; k++) {
                    int k_idx = k_tile_idx * TILE_SIZE + k;
                    if (k_idx < d) {
                        Q_tile[tile_row_idx][k] = fourDimRead(Q, b, h, row_idx, k_idx, H, N, d);
                    } else {
                        Q_tile[tile_row_idx][k] = 0.0f; // Fill with zero if out-of-bounds
                    }
                }
            }

            for (int tile_col_idx = 0; tile_col_idx < TILE_SIZE; tile_col_idx++) {
                int col_idx = col_tile_idx * TILE_SIZE + tile_col_idx;
                if (col_idx >= N) continue; // Skip out-of-bound columns

                for (int k = 0; k < TILE_SIZE; k++) {
                    int k_idx = k_tile_idx * TILE_SIZE + k;
                    if (k_idx < d) {
                        K_tile[tile_col_idx][k] = fourDimRead(K, b, h, col_idx, k_idx, H, N, d);
                    } else {
                        K_tile[tile_col_idx][k] = 0.0f; // Fill with zero if out-of-bounds
                    }
                }
            }

            // Compute the dot product for the current tile
            for (int tile_row_idx = 0; tile_row_idx < TILE_SIZE; tile_row_idx++) {
                int row_idx = row_tile_idx * TILE_SIZE + tile_row_idx;
                if (row_idx >= N) continue; // Skip out-of-bound rows

                for (int tile_col_idx = 0; tile_col_idx < TILE_SIZE; tile_col_idx++) {
                    int col_idx = col_tile_idx * TILE_SIZE + tile_col_idx;
                    if (col_idx >= N) continue; // Skip out-of-bound columns

                    float sum = twoDimRead(QK_t, row_idx, col_idx, N);

                    // Unrolled loop for vectorized dot product
                    for (int k = 0; k < TILE_SIZE; k++) {
                        sum += Q_tile[tile_row_idx][k] * K_tile[tile_col_idx][k];
                    }

                    twoDimWrite(QK_t, row_idx, col_idx, N, sum);
                }
            }
        }
    }
}



      // also correct
//       for (int q_row_tile_idx = 0; q_row_tile_idx < (N + TILE_SIZE - 1) / TILE_SIZE; q_row_tile_idx++) {
//     for (int k_row_tile_idx = 0; k_row_tile_idx < (N + TILE_SIZE - 1) / TILE_SIZE; k_row_tile_idx++) {
//         for (int d_col_tile_idx = 0; d_col_tile_idx < (d + TILE_SIZE - 1) / TILE_SIZE; d_col_tile_idx++) {
//             for (int tile_row_idx = 0; tile_row_idx < TILE_SIZE; tile_row_idx++) {
//                 for (int tile_col_idx = 0; tile_col_idx < TILE_SIZE; tile_col_idx++) {
//                     int q_row_idx = q_row_tile_idx * TILE_SIZE + tile_row_idx;
//                     int k_row_idx = k_row_tile_idx * TILE_SIZE + tile_col_idx; // Fix indexing
//                     for (int d_idx = d_col_tile_idx * TILE_SIZE; d_idx < (d_col_tile_idx + 1) * TILE_SIZE; d_idx++) {
//                         if (q_row_idx < N && k_row_idx < N && d_idx < d) {
//                             float q_tile_val = fourDimRead(Q, b, h, q_row_idx, d_idx, H, N, d);
//                             float k_tile_val = fourDimRead(K, b, h, k_row_idx, d_idx, H, N, d);
//                             float orig_val = twoDimRead(QK_t, q_row_idx, k_row_idx, N);
//                             float val = q_tile_val * k_tile_val + orig_val;
//                             twoDimWrite(QK_t, q_row_idx, k_row_idx, N, val);
//                         }
//                     }
//                 }
//             }
//         }
//     }
// }




      for(int row_idx=0; row_idx < N; row_idx++) {
        std::vector<double> tmp_row_res(N, 0.0);
        double row_sum = 0.0;
        for(int cold_idx=0; cold_idx < N ;cold_idx++) {
           float val = twoDimRead(QK_t, row_idx, cold_idx, N);
          double exp_val = std::exp(val);
          row_sum += exp_val;
          tmp_row_res[cold_idx] = exp_val;
        }
        for(int cold_idx=0; cold_idx < N ; cold_idx++) {
          float prob = tmp_row_res[cold_idx] / row_sum;
          twoDimWrite(QK_t, row_idx, cold_idx, N, prob);
        }
      }


      
      for(int qkt_row_tile_idx=0; qkt_row_tile_idx < (N+TILE_SIZE-1)/TILE_SIZE; qkt_row_tile_idx++) {
        for(int output_d_tile_idx=0; output_d_tile_idx < (d+TILE_SIZE-1)/TILE_SIZE; output_d_tile_idx++) {

          for(int k_tile_idx=0; k_tile_idx < (N+TILE_SIZE-1)/TILE_SIZE; k_tile_idx++) {
            for(int tile_row_idx=0; tile_row_idx < TILE_SIZE; tile_row_idx++) {
                int out_row_idx = qkt_row_tile_idx * TILE_SIZE + tile_row_idx;
              if(out_row_idx >= N) continue;
              for(int tile_col_idx=0; tile_col_idx < TILE_SIZE; tile_col_idx++) {
                int out_col_idx = output_d_tile_idx * TILE_SIZE + tile_col_idx;
                if( out_col_idx >= d) continue;

                float sum = fourDimRead(O, b, h, out_row_idx, out_col_idx, H, N, d );
                for(int k=0; k < TILE_SIZE; k++) {
                  int k_idx = k_tile_idx * TILE_SIZE + k;
                  float qkt_val = twoDimRead(QK_t, out_row_idx, k_idx, N);
                  float v_val = fourDimRead(V, b, h, k_idx, out_col_idx, H, N, d);
                  sum += qkt_val * v_val; 
                }
                fourDimWrite(O, b, h, out_row_idx, out_col_idx, H, N, d, sum);
              }
            }
          }
        }
      }


    }

  }


  // DO NOT EDIT THIS RETURN STATEMENT //
  // It formats your C++ Vector O back into a Tensor of Shape (B, H, N, d) and returns it //
  return torch::from_blob(O.data(), {B, H, N, d}, torch::TensorOptions().dtype(torch::kFloat32)).clone();
}

Output:

Achieve the same cpu time as ref solution.

Self CPU time total: 215.420ms

REFERENCE - BLOCKED MATMUL + UNFUSED SOFTMAX statistics
cpu time:  214.494ms
mem usage:  4718592 bytes
-----RUNNING STUDENT IMPLEMENTATION-----

first row value:
 tensor([0.0771, 0.0779, 0.0787, 0.0795, 0.0803, 0.0811, 0.0819, 0.0827, 0.0835,
        0.0843, 0.0851, 0.0859, 0.0867, 0.0875, 0.0883, 0.0891, 0.0899, 0.0907,
        0.0915, 0.0923, 0.0931, 0.0939, 0.0947, 0.0955, 0.0963, 0.0971, 0.0979,
        0.0987, 0.0995, 0.1003, 0.1011, 0.1019])
STAGE:2024-11-18 20:06:22 3461555:3461555 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
student first row:
 tensor([0.0771, 0.0779, 0.0787, 0.0795, 0.0803, 0.0811, 0.0819, 0.0827, 0.0835,
        0.0843, 0.0851, 0.0859, 0.0867, 0.0875, 0.0883, 0.0891, 0.0899, 0.0907,
        0.0915, 0.0923, 0.0931, 0.0939, 0.0947, 0.0955, 0.0963, 0.0971, 0.0979,
        0.0987, 0.0995, 0.1003, 0.1011, 0.1019])
STAGE:2024-11-18 20:06:23 3461555:3461555 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-11-18 20:06:23 3461555:3461555 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
manual attention == pytorch attention True
Manual Execution Time:  0.2183218002319336

----------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                          Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls
----------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                   aten::empty         0.01%      12.000us         0.01%      12.000us       1.500us       5.00 Mb       5.00 Mb             8
    STUDENT - BLOCKED MATMUL + UNFUSED SOFTMAX        99.38%     217.015ms        99.61%     217.504ms     217.504ms       4.50 Mb      -1.00 Mb             1
                                   aten::zeros         0.01%      17.000us         0.07%     142.000us      71.000us       4.50 Mb           0 b             2
                                   aten::clone         0.01%      19.000us         0.14%     313.000us     156.500us       1.00 Mb           0 b             2
                           aten::empty_strided         0.00%       9.000us         0.00%       9.000us       1.800us     512.25 Kb     512.25 Kb             5
                               model_inference         0.24%     532.000us       100.00%     218.359ms     218.359ms     512.00 Kb      -4.00 Mb             1
                                 aten::flatten         0.01%      25.000us         0.09%     189.000us      37.800us     512.00 Kb           0 b             5
                              aten::empty_like         0.00%       4.000us         0.00%       5.000us       5.000us     512.00 Kb           0 b             1
                                      aten::to         0.00%       9.000us         0.01%      30.000us       5.000us         520 b           4 b             6
                                aten::_to_copy         0.01%      14.000us         0.01%      25.000us       6.250us         520 b         260 b             4
----------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 218.359ms

STUDENT - BLOCKED MATMUL + UNFUSED SOFTMAX statistics
cpu time:  217.504ms
mem usage:  4718592 bytes

Another issue of code above.

It’s not correct when output matrix size is not power of 2

(myenv) ➜  cs149gpt git:(main) python3 gpt149.py part2 -N 103
-----RUNNING STUDENT IMPLEMENTATION-----

first row value:
 tensor([0.0077, 0.0085, 0.0093, 0.0101, 0.0109, 0.0117, 0.0125, 0.0133, 0.0141,
        0.0149, 0.0157, 0.0165, 0.0173, 0.0181, 0.0189, 0.0197, 0.0205, 0.0213,
        0.0221, 0.0229, 0.0237, 0.0245, 0.0253, 0.0261, 0.0269, 0.0277, 0.0285,
        0.0293, 0.0301, 0.0309, 0.0317, 0.0325])
STAGE:2024-11-18 23:12:28 271968:271968 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
student first row:
 tensor([0.0077, 0.0086, 0.0094, 0.0103, 0.0112, 0.0121, 0.0129, 0.0138, 0.0147,
        0.0155, 0.0164, 0.0173, 0.0181, 0.0190, 0.0199, 0.0208, 0.0216, 0.0225,
        0.0234, 0.0242, 0.0251, 0.0260, 0.0268, 0.0277, 0.0286, 0.0295, 0.0303,
        0.0312, 0.0321, 0.0329, 0.0338, 0.0347])
STAGE:2024-11-18 23:12:28 271968:271968 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-11-18 23:12:28 271968:271968 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
Traceback (most recent call last):
  File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 329, in <module>
    main()
  File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 313, in main
    part2Test(N, d, B, H)
  File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 233, in part2Test
    testTemplate(attentionModuleReference.myUnfusedAttentionBlocked, params, "STUDENT - BLOCKED MATMUL + UNFUSED SOFTMAX")
  File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 182, in testTemplate
    assert torch.allclose(QKV,QKS1, atol=1e-4), correctness_error_message
AssertionError:
-------------------------------------------
 YOUR ATTENTION PRODUCED INCORRECT RESULTS

Found the root cause. It’s because there is out of bound check for softmax(QK_t) * V

      for(int qkt_row_tile_idx=0; qkt_row_tile_idx < (N+TILE_SIZE-1)/TILE_SIZE; qkt_row_tile_idx++) {
        for(int output_d_tile_idx=0; output_d_tile_idx < (d+TILE_SIZE-1)/TILE_SIZE; output_d_tile_idx++) {

          for(int k_tile_idx=0; k_tile_idx < (N+TILE_SIZE-1)/TILE_SIZE; k_tile_idx++) {
            for(int tile_row_idx=0; tile_row_idx < TILE_SIZE; tile_row_idx++) {
                int out_row_idx = qkt_row_tile_idx * TILE_SIZE + tile_row_idx;
              if(out_row_idx >= N) continue;
              for(int tile_col_idx=0; tile_col_idx < TILE_SIZE; tile_col_idx++) {
                int out_col_idx = output_d_tile_idx * TILE_SIZE + tile_col_idx;
                if( out_col_idx >= d) continue;

                float sum = fourDimRead(O, b, h, out_row_idx, out_col_idx, H, N, d );
                for(int k=0; k < TILE_SIZE; k++) {
                  int k_idx = k_tile_idx * TILE_SIZE + k;
                // Needto add this out of bound check
                  if(k_idx >= N) break;
                  float qkt_val = twoDimRead(QK_t, out_row_idx, k_idx, N);
                  float v_val = fourDimRead(V, b, h, k_idx, out_col_idx, H, N, d);
                  sum += qkt_val * v_val; 
                }
                fourDimWrite(O, b, h, out_row_idx, out_col_idx, H, N, d, sum);
              }
            }
          }
        }
      }

Now the result is correct when N is not power of 2. The cpu time is higher than aht of ref solution. I think this is because I add this out of bound check ?

Output:

(myenv) ➜  cs149gpt git:(main) ✗ python3 gpt149.py part2 -N 1023
Self CPU time total: 223.154ms

REFERENCE - BLOCKED MATMUL + UNFUSED SOFTMAX statistics
cpu time:  222.105ms
mem usage:  4709892 bytes
-----RUNNING STUDENT IMPLEMENTATION-----

first row value:
 tensor([0.0770, 0.0778, 0.0786, 0.0794, 0.0802, 0.0810, 0.0818, 0.0826, 0.0834,
        0.0842, 0.0850, 0.0858, 0.0866, 0.0874, 0.0882, 0.0890, 0.0898, 0.0906,
        0.0914, 0.0922, 0.0930, 0.0938, 0.0946, 0.0954, 0.0962, 0.0970, 0.0978,
        0.0986, 0.0994, 0.1002, 0.1010, 0.1018])
STAGE:2024-11-19 10:48:32 4016683:4016683 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
student first row:
 tensor([0.0770, 0.0778, 0.0786, 0.0794, 0.0802, 0.0810, 0.0818, 0.0826, 0.0834,
        0.0842, 0.0850, 0.0858, 0.0866, 0.0874, 0.0882, 0.0890, 0.0898, 0.0906,
        0.0914, 0.0922, 0.0930, 0.0938, 0.0946, 0.0954, 0.0962, 0.0970, 0.0978,
        0.0986, 0.0994, 0.1002, 0.1010, 0.1018])
STAGE:2024-11-19 10:48:32 4016683:4016683 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-11-19 10:48:32 4016683:4016683 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
manual attention == pytorch attention True
Manual Execution Time:  0.23152732849121094

----------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                          Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls
----------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                   aten::empty         0.01%      13.000us         0.01%      13.000us       1.625us       4.99 Mb       4.99 Mb             8
    STUDENT - BLOCKED MATMUL + UNFUSED SOFTMAX        99.41%     230.208ms        99.62%     230.695ms     230.695ms       4.49 Mb   -1023.00 Kb             1
                                   aten::zeros         0.01%      32.000us         0.06%     145.000us      72.500us       4.49 Mb           0 b             2
                                   aten::clone         0.01%      18.000us         0.14%     318.000us     159.000us    1023.00 Kb           0 b             2
                           aten::empty_strided         0.00%       9.000us         0.00%       9.000us       1.800us     511.75 Kb     511.75 Kb             5
                               model_inference         0.24%     546.000us       100.00%     231.569ms     231.569ms     511.50 Kb      -3.99 Mb             1
                                 aten::flatten         0.01%      16.000us         0.07%     170.000us      34.000us     511.50 Kb           0 b             5
                              aten::empty_like         0.00%       4.000us         0.00%       6.000us       6.000us     511.50 Kb           0 b             1
                                      aten::to         0.00%       6.000us         0.01%      30.000us       5.000us         552 b          32 b             6
                                aten::_to_copy         0.01%      15.000us         0.01%      24.000us       6.000us         520 b         260 b             4
----------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 231.569ms

STUDENT - BLOCKED MATMUL + UNFUSED SOFTMAX statistics
cpu time:  230.695ms
mem usage:  4709892 bytes

Part3: Fused attention

Problematic implementation: Output is all zero. Code:

torch::Tensor myFusedAttention(torch::Tensor QTensor, torch::Tensor KTensor, torch::Tensor VTensor, torch::Tensor temp,
              int B, int H, int N, int d){

  // Q, K, V are passed in with Shape: (B, H, N, d)

  //Make O Tensor with Shape (B, H, N, d)
  //and O Row Tensor with Shape (N)
  at::Tensor OTensor = at::zeros({B, H, N, d}, at::kFloat);
  at::Tensor ORowTensor = at::zeros({N}, at::kFloat);

  //Format Y, Q, K, and V tensors into 4D vectors
  std::vector<float> O = formatTensor(OTensor);
  std::vector<float> Q = formatTensor(QTensor);
  std::vector<float> K = formatTensor(KTensor);
  std::vector<float> V = formatTensor(VTensor);
  
  //Format ORow Tensor into a 1D vector
  // You can simply access this as ORow[i]
  std::vector<float> ORow = formatTensor(ORowTensor);


  // -------- YOUR CODE HERE  -------- //
  // We give you a template of the first three loops for your convenience
  //loop over batch

  #pragma omp parallel for collapse(3)
  for (int b = 0; b < B; b++){
    //loop over heads
    for (int h = 0; h < H; h++){
        for (int q_row_idx = 0; q_row_idx < N ; q_row_idx++){

  // YRow is moved inside so each OpenMP thread gets a local copy.
            at::Tensor ORowTensor = temp.index({torch::indexing::Slice(omp_get_thread_num(), torch::indexing::None)});      
            std::vector<float> ORow = formatTensor(ORowTensor);
  //YOUR CODE HERE
        for(int k_row_idx=0; k_row_idx < N; k_row_idx++) {
          float val = 0.0;
          for(int d_idx=0; d_idx < d ;d_idx++ ) {
            int q_row = q_row_idx;
            int q_col = d_idx;
            int k_row = k_row_idx;
            int k_col = d_idx;
            float q_val = fourDimRead(Q, b, h, q_row, q_col, H, N, d);
            float k_val = fourDimRead(K, b, h, k_row, k_col, H, N, d) ;
            val += q_val * k_val;
          }
          ORow[k_row_idx] = val;

        }
        // softmax
        std::vector<float> tmp_row_res(N, 0.0);
        float sum = 0.0;
        for(int i=0; i < N; i++) {
          ORow[i]  = std::exp(ORow[i]) ;
          sum += ORow[i];
          // tmp_row_res[i] = exp_val;
        }
        for(int i=0; i < N; i++) {
          float prob = ORow[i]  /  sum;
          ORow[i] = prob;
        }

        for(int v_col_idx=0; v_col_idx < d; v_col_idx++) {
          float sum =0.0;
          for(int v_row_idx=0; v_row_idx < N; v_row_idx++) {
            int v_val = fourDimRead(V, b, h, v_row_idx, v_col_idx, H, N ,d);
            sum += v_val * ORow[v_row_idx];
          }
          std::cout << "vcold_idx" << v_col_idx << "val: " << sum << std::endl;
          // tmp_row_res[v_col_idx] = sum;
          fourDimWrite(O, b, h, q_row_idx, v_col_idx, H, N, d, sum);
        }
          
        
      }
    }
  }
    

  // DO NOT EDIT THIS RETURN STATEMENT //
  // It formats your C++ Vector O back into a Tensor of Shape (B, H, N, d) and returns it //
  return torch::from_blob(O.data(), {B, H, N, d}, torch::TensorOptions().dtype(torch::kFloat32)).clone();
}
REFERENCE - FUSED ATTENTION statistics
cpu time:  56.506ms
mem usage:  557056 bytes
-----RUNNING STUDENT IMPLEMENTATION-----

first row value:
 tensor([0.0771, 0.0779, 0.0787, 0.0795, 0.0803, 0.0811, 0.0819, 0.0827, 0.0835,
        0.0843, 0.0851, 0.0859, 0.0867, 0.0875, 0.0883, 0.0891, 0.0899, 0.0907,
        0.0915, 0.0923, 0.0931, 0.0939, 0.0947, 0.0955, 0.0963, 0.0971, 0.0979,
        0.0987, 0.0995, 0.1003, 0.1011, 0.1019])
STAGE:2024-11-19 11:53:42 171611:171611 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
student first row:
 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0.])
STAGE:2024-11-19 11:53:42 171611:171611 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-11-19 11:53:42 171611:171611 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
Traceback (most recent call last):
  File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 329, in <module>
    main()
  File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 315, in main
    part3Test(N, d, B, H)
  File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 245, in part3Test
    testTemplate(attentionModuleReference.myFusedAttention, params, "STUDENT - FUSED ATTENTION")
  File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 182, in testTemplate
    assert torch.allclose(QKV,QKS1, atol=1e-4), correctness_error_message
AssertionError:
-------------------------------------------
 YOUR ATTENTION PRODUCED INCORRECT RESULTS

Found the root cause . It’s becaue of wrong type definition.

        for(int v_col_idx=0; v_col_idx < d; v_col_idx++) {
          float sum =0.0;
          for(int v_row_idx=0; v_row_idx < N; v_row_idx++) {
            // This is the correct definition
            float v_val = fourDimRead(V, b, h, v_row_idx, v_col_idx, H, N ,d);
            sum += v_val * ORow[v_row_idx];
          }
          std::cout << "vcold_idx" << v_col_idx << "val: " << sum << std::endl;
          fourDimWrite(O, b, h, q_row_idx, v_col_idx, H, N, d, sum);
        }

Output:

It’s too slow. I don’t know why

REFERENCE - FUSED ATTENTION statistics
cpu time:  59.932ms
mem usage:  557056 bytes
-----RUNNING STUDENT IMPLEMENTATION-----

first row value:
 tensor([0.0771, 0.0779, 0.0787, 0.0795, 0.0803, 0.0811, 0.0819, 0.0827, 0.0835,
        0.0843, 0.0851, 0.0859, 0.0867, 0.0875, 0.0883, 0.0891, 0.0899, 0.0907,
        0.0915, 0.0923, 0.0931, 0.0939, 0.0947, 0.0955, 0.0963, 0.0971, 0.0979,
        0.0987, 0.0995, 0.1003, 0.1011, 0.1019])
STAGE:2024-11-19 13:07:43 578930:578930 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
student first row:
 tensor([0.0771, 0.0779, 0.0787, 0.0795, 0.0803, 0.0811, 0.0819, 0.0827, 0.0835,
        0.0843, 0.0851, 0.0859, 0.0867, 0.0875, 0.0883, 0.0891, 0.0899, 0.0907,
        0.0915, 0.0923, 0.0931, 0.0939, 0.0947, 0.0955, 0.0963, 0.0971, 0.0979,
        0.0987, 0.0995, 0.1003, 0.1011, 0.1019])
STAGE:2024-11-19 13:07:43 578930:578930 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-11-19 13:07:43 578930:578930 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
manual attention == pytorch attention True
Manual Execution Time:  0.3012988567352295

-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                  aten::empty         0.01%      17.000us         0.01%      17.000us       1.889us       1.04 Mb       1.04 Mb             9
                  aten::clone         0.00%      15.000us         0.08%     247.000us     123.500us       1.00 Mb           0 b             2
                  aten::zeros         0.01%      19.000us         0.02%      65.000us      21.667us     548.00 Kb           0 b             3
    STUDENT - FUSED ATTENTION        95.37%     287.422ms        99.68%     300.382ms     300.382ms     544.00 Kb      -1.00 Mb             1
          aten::empty_strided         0.00%       9.000us         0.00%       9.000us       1.800us     512.00 Kb     512.00 Kb             5
              model_inference         0.21%     636.000us       100.00%     301.361ms     301.361ms     512.00 Kb     -32.63 Kb             1
                aten::flatten         1.42%       4.278ms         1.48%       4.467ms       1.089us     512.00 Kb           0 b          4101
             aten::empty_like         0.00%       3.000us         0.00%       5.000us       5.000us     512.00 Kb           0 b             1
                     aten::to         0.00%       6.000us         0.01%      32.000us       5.333us         520 b           0 b             6
               aten::_to_copy         0.01%      18.000us         0.01%      26.000us       6.500us         520 b         516 b             4
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 301.361ms

STUDENT - FUSED ATTENTION statistics
cpu time:  300.382ms
mem usage:  557056 bytes

Found the root cause. I don’t enable openmp in original code

The speed is comparable after enabling openmp Code:

torch::Tensor myFusedAttention(torch::Tensor QTensor, torch::Tensor KTensor, torch::Tensor VTensor, torch::Tensor temp,
              int B, int H, int N, int d){

  // Q, K, V are passed in with Shape: (B, H, N, d)

  //Make O Tensor with Shape (B, H, N, d)
  //and O Row Tensor with Shape (N)
  at::Tensor OTensor = at::zeros({B, H, N, d}, at::kFloat);
  at::Tensor ORowTensor = at::zeros({N}, at::kFloat);

  //Format Y, Q, K, and V tensors into 4D vectors
  std::vector<float> O = formatTensor(OTensor);
  std::vector<float> Q = formatTensor(QTensor);
  std::vector<float> K = formatTensor(KTensor);
  std::vector<float> V = formatTensor(VTensor);
  
  //Format ORow Tensor into a 1D vector
  // You can simply access this as ORow[i]
  std::vector<float> ORow = formatTensor(ORowTensor);


  // -------- YOUR CODE HERE  -------- //
  // We give you a template of the first three loops for your convenience
  //loop over batch

// Need to enable this openmp for parallel execution for all rows in batches and heads.
  #pragma omp parallel for collapse(3)
  for (int b = 0; b < B; b++){
    //loop over heads
    for (int h = 0; h < H; h++){
        for (int q_row_idx = 0; q_row_idx < N ; q_row_idx++){

  // YRow is moved inside so each OpenMP thread gets a local copy.
            at::Tensor ORowTensor = temp.index({torch::indexing::Slice(omp_get_thread_num(), torch::indexing::None)});      
            std::vector<float> ORow = formatTensor(ORowTensor);
  //YOUR CODE HERE
        for(int k_row_idx=0; k_row_idx < N; k_row_idx++) {
          float val = 0.0;
          for(int d_idx=0; d_idx < d ;d_idx++ ) {
            int q_row = q_row_idx;
            int q_col = d_idx;
            int k_row = k_row_idx;
            int k_col = d_idx;
            float q_val = fourDimRead(Q, b, h, q_row, q_col, H, N, d);
            float k_val = fourDimRead(K, b, h, k_row, k_col, H, N, d) ;
            val += q_val * k_val;
          }
          ORow[k_row_idx] = val;
          // std::cout << "krowidx: " << k_row_idx << " val: " << ORow[k_row_idx] << std::endl;

        }
        // softmax
        std::vector<float> tmp_row_res(N, 0.0);
        float sum = 0.0;
        for(int i=0; i < N; i++) {
          ORow[i]  = std::exp(ORow[i]) ;
          sum += ORow[i];
          // tmp_row_res[i] = exp_val;
        }
        for(int i=0; i < N; i++) {
          float prob = ORow[i]  /  sum;
          ORow[i] = prob;
          // std::cout << "softmax col: "  << i << " val: " << ORow[i] << std::endl;
        }

        for(int v_col_idx=0; v_col_idx < d; v_col_idx++) {
          float sum =0.0;
          for(int v_row_idx=0; v_row_idx < N; v_row_idx++) {
            float v_val = fourDimRead(V, b, h, v_row_idx, v_col_idx, H, N ,d);
            sum += v_val * ORow[v_row_idx];
          }
          // std::cout << "vcold_idx" << v_col_idx << "val: " << sum << std::endl;
          fourDimWrite(O, b, h, q_row_idx, v_col_idx, H, N, d, sum);
        }
          
        
      }
    }
  }
    

  // DO NOT EDIT THIS RETURN STATEMENT //
  // It formats your C++ Vector O back into a Tensor of Shape (B, H, N, d) and returns it //
  return torch::from_blob(O.data(), {B, H, N, d}, torch::TensorOptions().dtype(torch::kFloat32)).clone();
}
REFERENCE - FUSED ATTENTION statistics
cpu time:  56.526ms
mem usage:  557056 bytes
-----RUNNING STUDENT IMPLEMENTATION-----

first row value:
 tensor([0.0771, 0.0779, 0.0787, 0.0795, 0.0803, 0.0811, 0.0819, 0.0827, 0.0835,
        0.0843, 0.0851, 0.0859, 0.0867, 0.0875, 0.0883, 0.0891, 0.0899, 0.0907,
        0.0915, 0.0923, 0.0931, 0.0939, 0.0947, 0.0955, 0.0963, 0.0971, 0.0979,
        0.0987, 0.0995, 0.1003, 0.1011, 0.1019])
STAGE:2024-11-19 17:17:41 1945272:1945272 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
student first row:
 tensor([0.0771, 0.0779, 0.0787, 0.0795, 0.0803, 0.0811, 0.0819, 0.0827, 0.0835,
        0.0843, 0.0851, 0.0859, 0.0867, 0.0875, 0.0883, 0.0891, 0.0899, 0.0907,
        0.0915, 0.0923, 0.0931, 0.0939, 0.0947, 0.0955, 0.0963, 0.0971, 0.0979,
        0.0987, 0.0995, 0.1003, 0.1011, 0.1019])
STAGE:2024-11-19 17:17:41 1945272:1945272 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-11-19 17:17:41 1945272:1945272 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
manual attention == pytorch attention True
Manual Execution Time:  0.04986262321472168

-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                  aten::empty         0.02%      12.000us         0.02%      12.000us       1.333us       1.04 Mb       1.04 Mb             9
                  aten::clone         0.04%      19.000us         0.42%     208.000us     104.000us       1.00 Mb           0 b             2
                  aten::zeros         0.04%      19.000us         0.14%      68.000us      22.667us     548.00 Kb           0 b             3
    STUDENT - FUSED ATTENTION        94.35%      47.083ms        98.27%      49.043ms      49.043ms     544.00 Kb      -1.00 Mb             1
          aten::empty_strided         0.02%       8.000us         0.02%       8.000us       1.600us     512.00 Kb     512.00 Kb             5
              model_inference         1.11%     555.000us       100.00%      49.904ms      49.904ms     512.00 Kb     -32.63 Kb             1
                aten::flatten         1.10%     549.000us         1.50%     748.000us       1.447us     512.00 Kb           0 b           517
             aten::empty_like         0.01%       3.000us         0.01%       4.000us       4.000us     512.00 Kb           0 b             1
                     aten::to         0.01%       6.000us         0.06%      29.000us       4.833us         520 b           0 b             6
               aten::_to_copy         0.03%      14.000us         0.05%      23.000us       5.750us         520 b         260 b             4
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 49.904ms

STUDENT - FUSED ATTENTION statistics
cpu time:  49.043ms
mem usage:  557056 bytes

Part4: Flash attention

gpt gives good explanation about how Oi forumala is obtained.

Basically previous O_i is divided by l_i . So we need to multiply O_i by l_i when calculating new O_i with new l_i according to this expression.

[ O_i = \frac{\sum_{\text{old } j} \alpha_{ij} V_j + \sum_{\text{new } j} \alpha_{ij} V_j}{\sum_{\text{old } j} \alpha_{ij} + \sum_{\text{new } j} \alpha_{ij}}. ]

The formula

[ O_i = \frac{l_i O_i + P_{ij} V_j}{l_{\text{new}}} ]

and its update mechanism for the cumulative sum ( l_i O_i ) and normalization factor ( l_i ) works because of the linearity of the softmax-weighted sum and the incremental construction of softmax normalization. Let’s derive and explain why this works step-by-step.


1. Standard Attention Mechanism

In the standard attention mechanism, the output ( O_i ) for query ( Q_i ) is defined as:

[ O_i = \frac{\sum_{j=1}^N \alpha_{ij} V_j}{\sum_{j=1}^N \alpha_{ij}}, ]

where:

  • ( \alpha_{ij} = \exp(Q_i K_j^T) ): the unnormalized attention score between query ( Q_i ) and key ( K_j ).
  • ( P_{ij} = \frac{\alpha_{ij}}{\sum_{k=1}^N \alpha_{ik}} ): the normalized attention score (softmax).
  • ( V_j ): the value vector associated with key ( K_j ).

Rewriting ( O_i ) explicitly:

[ O_i = \frac{\sum_{j=1}^N \alpha_{ij} V_j}{Z_i}, ]

where ( Z_i = \sum_{j=1}^N \alpha_{ij} ) is the normalization constant for row ( i ).


2. Chunked Computation in Flash Attention

To handle memory constraints, Flash Attention processes the attention computation in chunks of ( j ). Suppose we process a chunk of keys/values indexed by ( j ) in each step. Let:

  • ( l_i^{\text{old}} ) be the normalization constant ( Z_i ) for the previously processed chunks.
  • ( O_i^{\text{old}} ) be the partial output for the previously processed chunks.

The goal is to update ( O_i ) and ( l_i ) incrementally when a new chunk is processed.

For the new chunk, compute:

  • ( \alpha_{ij} ) for the current chunk.
  • The sum of unnormalized scores for the chunk: ( l_{\text{chunk}} = \sum_{j \in \text{chunk}} \alpha_{ij} ).
  • The weighted value contribution: ( O_{\text{chunk}} = \sum_{j \in \text{chunk}} \alpha_{ij} V_j ).

The updated values are:

[ l_i^{\text{new}} = l_i^{\text{old}} + l_{\text{chunk}} ] [ O_i^{\text{new}} = \frac{l_i^{\text{old}} O_i^{\text{old}} + O_{\text{chunk}}}{l_i^{\text{new}}}. ]


3. Why This Works: Derivation

a. Weighted Combination of Outputs

The full output ( O_i ) over all chunks can be expressed as a weighted combination of the contributions from the old and new chunks:

[ O_i = \frac{\sum_{\text{all } j} \alpha_{ij} V_j}{\sum_{\text{all } j} \alpha_{ij}}. ]

Breaking this into previously processed chunks (old) and the current chunk:

[ O_i = \frac{\sum_{\text{old } j} \alpha_{ij} V_j + \sum_{\text{new } j} \alpha_{ij} V_j}{\sum_{\text{old } j} \alpha_{ij} + \sum_{\text{new } j} \alpha_{ij}}. ]

Let:

  • ( l_i^{\text{old}} = \sum_{\text{old } j} \alpha_{ij} ),
  • ( l_{\text{chunk}} = \sum_{\text{new } j} \alpha_{ij} ),
  • ( O_i^{\text{old}} = \frac{\sum_{\text{old } j} \alpha_{ij} V_j}{l_i^{\text{old}}} ),
  • ( O_{\text{chunk}} = \sum_{\text{new } j} \alpha_{ij} V_j ).

Substituting:

[ O_i = \frac{l_i^{\text{old}} O_i^{\text{old}} + O_{\text{chunk}}}{l_i^{\text{old}} + l_{\text{chunk}}}. ]

This matches the Flash Attention formula:

[ O_i = \frac{l_i O_i + P_{ij} V_j}{l_{\text{new}}}. ]


b. Incremental Normalization

The normalization factor ( l_i ) (equivalent to ( Z_i ) in standard attention) is computed incrementally:

[ l_i^{\text{new}} = l_i^{\text{old}} + \sum_{j \in \text{chunk}} \alpha_{ij}. ]

This ensures that the normalization is consistent across all processed chunks, allowing the final ( O_i ) to represent the output for the full sequence.


4. Benefits of This Approach

  1. Memory Efficiency:
    • By processing chunks incrementally, there is no need to store the full ( N \times N ) attention matrix or intermediate results.
  2. Numerical Stability:
    • The incremental updates prevent large sums from growing uncontrollably, which could cause overflow/underflow in standard softmax computations.
  3. Linearity of Weighted Sum:
    • The linearity of summation ensures that partial results (( l_i^{\text{old}} O_i^{\text{old}} )) can be combined directly with new contributions without recomputation.

5. Conclusion

The Flash Attention update formula works because:

  • It incrementally combines contributions from each chunk using the linearity of the softmax-weighted sum.
  • The normalization factor ( l_i ) ensures consistency across chunks, so the final result is equivalent to the full softmax computation.
  • This approach trades memory for compute, making it feasible to handle large inputs while maintaining mathematical equivalence to standard attention.

Incorrect implementation:

torch::Tensor myFlashAttention(torch::Tensor QTensor, torch::Tensor KTensor, torch::Tensor VTensor,
               torch::Tensor QiTensor, torch::Tensor KjTensor, torch::Tensor VjTensor,
               torch::Tensor SijTensor, torch::Tensor PijTensor, torch::Tensor PVTensor,
               torch::Tensor OiTensor, torch::Tensor LTensor,  torch::Tensor LiTensor, 
	       torch::Tensor LijTensor, torch::Tensor LnewTensor, int Bc, int Br,
                int B, int H, int N, int d) {
      
  // Q, K, V are passed in with Shape: (B, H, N, d)
  // Sij, Pij are passed in with Shape: (Br, Bc)
  // Kj, Vj are passed in with Shape: (Bc, d)
  // Qi, Oi, and PV  are passed in with Shape: (Br, d)
  // L in passed in with Shape: (N)
  // Li, Lij, and Lnew are passed in with shape (Br)

  //Make O Tensor with Shape (B, H, N, d)
  at::Tensor OTensor = at::zeros({B, H, N, d}, at::kFloat);
 
  //Format All Tensors into Vectors
  std::vector<float> O = formatTensor(OTensor);
  std::vector<float> Q = formatTensor(QTensor);
  std::vector<float> K = formatTensor(KTensor);
  std::vector<float> V = formatTensor(VTensor);
  std::vector<float> Sij = formatTensor(SijTensor);
  std::vector<float> Pij = formatTensor(PijTensor);
  std::vector<float> Kj = formatTensor(KjTensor);
  std::vector<float> Vj = formatTensor(VjTensor);
  std::vector<float> Qi = formatTensor(QiTensor);
  std::vector<float> Oi = formatTensor(OiTensor);
  std::vector<float> l = formatTensor(LTensor);
  std::vector<float> PV = formatTensor(PVTensor);
  std::vector<float> li = formatTensor(LiTensor);
  std::vector<float> lij = formatTensor(LijTensor);
  std::vector<float> lnew = formatTensor(LnewTensor);

  // -------- YOUR CODE HERE  -------- //
  for(int b=0; b < B; b++ ) {
    for(int h=0; h < H; h++) {
  for(int k_block_idx=0; k_block_idx < (N+Bc-1)/Bc; k_block_idx++) {
    // load Kj, Vj into local memory blocks.
    for(int j=0; j < Bc; j++) {
      int j_row = k_block_idx * Bc + j;
      if(j_row >= N) continue;
      for(int d_idx =0; d_idx < d; d_idx++) {
        float k_val = fourDimRead(K, b, h, j_row, d_idx, H, N, d);
        float v_val = fourDimRead(V, b, h, j_row, d_idx, H, N, d);
        twoDimWrite(Kj, j, d_idx, d, k_val);
        twoDimWrite(Vj, j, d_idx, d, v_val);
      }
    }

    for(int q_block_idx=0; q_block_idx < (N+Br-1)/Br; q_block_idx++) {
      // load Qi, Oi, li into local memory blocks
      for(int br_idx=0; br_idx < Br; br_idx++ ) {
        int q_row_idx = q_block_idx * Br + br_idx; 
        if(q_row_idx >= N ) continue;
        for(int d_idx=0; d_idx < d; d_idx++) {
          float q_val = fourDimRead(Q, b, h, q_row_idx, d_idx, H, N, d);
          float o_val = fourDimRead(O, b, h, q_row_idx , d_idx, H, N, d);
          twoDimWrite(Qi, br_idx, d_idx, d, q_val);
          twoDimWrite(Oi, br_idx, d_idx, d, o_val);

        }
        float l_val = l[q_row_idx];
        li[br_idx] = l_val;

      }

      // compute Sij  = Qi * Kj_T (Br x Bc) 
      for(int br_idx=0; br_idx < Br; br_idx++) {
        for(int bc_idx=0; bc_idx < Bc; bc_idx++) {
          float sum = 0.0;
          for(int d_idx=0; d_idx < d; d_idx++) {
            float q_val = twoDimRead(Qi, br_idx, d_idx, d);
            float k_val = twoDimRead(Kj, bc_idx, d_idx, d);
            sum += q_val * k_val;

          }
          twoDimWrite(Sij, br_idx, bc_idx, d, sum);
        }
      }

      // Compute Pij = exp(Sij) of size (Br x Bc)
      for(int br_idx=0; br_idx < Br; br_idx++) {
        for(int bc_idx=0; bc_idx < Bc; bc_idx++) {
          float exp_val = std::exp(twoDimRead(Sij, br_idx, bc_idx, Bc));
          twoDimWrite(Pij, br_idx, bc_idx, Bc, exp_val);
        }
      }

      // Compute lij = rowsum(Pij) of size (Br)
      for(int br_idx=0; br_idx < Br; br_idx++) {
        float sum = 0.0;
        for(int bc_idx=0; bc_idx < Bc; bc_idx++) {
          sum += twoDimRead(Pij, br_idx, bc_idx, Bc);
        }
        lij[br_idx] = sum;

      }

      // compute lnew = li + lij
      for(int br_idx=0; br_idx < Br; br_idx++) {
        lnew[br_idx] = li[br_idx] + lij[br_idx];
      }

      // Compute Oi = (liOi + PijVj)/ lnew
      for(int br_idx=0; br_idx < Br; br_idx++) {
        for(int d_idx=0; d_idx < d; d_idx++) {
          float pv_sum =0.0;
          for(int bc_idx=0; bc_idx < Bc; bc_idx++) {
            int p_row = br_idx;
            int p_col = bc_idx;
            int v_row = bc_idx;
            int v_col = d_idx;
            pv_sum += twoDimRead(Pij, p_row, p_col, Bc) * twoDimRead(Vj, v_row, v_col, d);

          }

          float li_Oi_val = li[br_idx] * twoDimRead(Oi, br_idx, d_idx, d);
          float new_sum = pv_sum + li_Oi_val;
          float new_Oi_val = new_sum / lnew[br_idx];
          twoDimWrite(Oi, br_idx, d_idx, d, new_Oi_val);
        }
      }

      // Write Oi and lnew back to O and l in main memory
      for(int br_idx=0; br_idx < Br; br_idx++) {
        int O_row = q_block_idx * Br + br_idx;
        if(O_row >= N) break;
        for(int d_idx=0; d_idx < d; d_idx++) {
          float Oi_val = twoDimRead(Oi, br_idx, d_idx, d);
                  int O_col = d_idx;
          fourDimWrite(O, b, h, O_row, O_col, H, N, d, Oi_val);

        }

        l[O_row] = lnew[O_row];

      }


    }
  }

    }

  }


  // DO NOT EDIT THIS RETURN STATEMENT //
  // It formats your C++ Vector O back into a Tensor of Shape (B, H, N, d) and returns it //
  return torch::from_blob(O.data(), {B, H, N, d}, torch::TensorOptions().dtype(torch::kFloat32)).clone();
}

Output:


Self CPU time total: 589.890ms

REFERENCE - FLASH ATTENTION statistics
cpu time:  588.591ms
mem usage:  524288 bytes
-----RUNNING STUDENT IMPLEMENTATION-----

first row value:
 tensor([0.0771, 0.0779, 0.0787, 0.0795, 0.0803, 0.0811, 0.0819, 0.0827, 0.0835,
        0.0843, 0.0851, 0.0859, 0.0867, 0.0875, 0.0883, 0.0891, 0.0899, 0.0907,
        0.0915, 0.0923, 0.0931, 0.0939, 0.0947, 0.0955, 0.0963, 0.0971, 0.0979,
        0.0987, 0.0995, 0.1003, 0.1011, 0.1019])
STAGE:2024-11-20 14:25:23 442505:442505 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
student first row:
 tensor([0.0773, 0.0781, 0.0789, 0.0797, 0.0805, 0.0813, 0.0821, 0.0829, 0.0837,
        0.0845, 0.0853, 0.0861, 0.0869, 0.0877, 0.0885, 0.0893, 0.0901, 0.0909,
        0.0917, 0.0925, 0.0933, 0.0941, 0.0949, 0.0957, 0.0965, 0.0973, 0.0981,
        0.0989, 0.0997, 0.1005, 0.1013, 0.1021])
STAGE:2024-11-20 14:25:23 442505:442505 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-11-20 14:25:23 442505:442505 ActivityProfilerController.cpp:322] Completed Stage: Post Processing
Traceback (most recent call last):
  File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 329, in <module>
    main()
  File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 317, in main
    part4Test(N, d, B, H, int(args.bc), int(args.br))
  File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 257, in part4Test
    testTemplate(attentionModuleReference.myFlashAttention, params, "STUDENT - FLASH ATTENTION")
  File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 182, in testTemplate
    assert torch.allclose(QKV,QKS1, atol=1e-4), correctness_error_message
AssertionError:
-------------------------------------------
 YOUR ATTENTION PRODUCED INCORRECT RESULTS

Asked gpt to find bug in my previous code

Some are corret , some are not.

There are several potential bugs and inefficiencies in the provided implementation of Flash Attention. Below is an analysis of common issues based on the code structure.


1. Bug in Writing lnew to Global Memory

At the end of the block:

l[O_row] = lnew[O_row];

This line is problematic because lnew[O_row] is incorrect. You are trying to write from lnew (a local array for the current query block) to the global l. The index O_row does not map directly to lnew, as lnew uses local indices within a chunk.

Fix:

l[O_row] = lnew[br_idx];

Here, br_idx is the correct index for lnew in the current query block.


2. Mismatch in twoDimRead and twoDimWrite for Sij

In the computation of ( S_{ij} ):

twoDimWrite(Sij, br_idx, bc_idx, d, sum);

You are using d as the stride instead of Bc. Sij is declared as having a shape (Br, Bc), so the stride in the twoDimWrite call should be Bc.

Fix:

twoDimWrite(Sij, br_idx, bc_idx, Bc, sum);

Similarly, update other uses of Sij to use the correct stride.


3. Incorrect Update of liOi

In the line:

float li_Oi_val = li[br_idx] * twoDimRead(Oi, br_idx, d_idx, d);

This reads the old Oi value from Oi, but Oi is already being updated in the loop. This introduces unintended feedback into the computation.

Fix: Make a copy of Oi at the start of each block before updates, and use that copy for the computation:

std::vector<float> Oi_copy = Oi; // Create a copy of Oi
float li_Oi_val = li[br_idx] * twoDimRead(Oi_copy, br_idx, d_idx, d);

4. Improper Memory Initialization

Variables like Kj, Vj, Qi, Oi, Sij, etc., are written to without ensuring their contents are initialized correctly. This might cause unexpected behavior if these tensors are not zeroed out.

Fix: Explicitly zero out these tensors before use:

std::fill(Kj.begin(), Kj.end(), 0.0f);
std::fill(Vj.begin(), Vj.end(), 0.0f);
std::fill(Qi.begin(), Qi.end(), 0.0f);
std::fill(Oi.begin(), Oi.end(), 0.0f);
std::fill(Sij.begin(), Sij.end(), 0.0f);

5. Incorrect Stride in twoDimRead and twoDimWrite for Block Matrices

Some block matrices (e.g., Pij, Kj, Vj, Qi) use strides inconsistently. Double-check the stride parameter in each twoDimRead and twoDimWrite call to match the matrix dimensions.


6. Inefficient Computation of lij and lnew

The computation of lij (row sum of Pij) and lnew involves multiple sequential loops that can be combined for better performance. For example:

for(int br_idx=0; br_idx < Br; br_idx++) {
  float sum = 0.0;
  for(int bc_idx=0; bc_idx < Bc; bc_idx++) {
    sum += twoDimRead(Pij, br_idx, bc_idx, Bc);
  }
  lij[br_idx] = sum;
  lnew[br_idx] = li[br_idx] + sum;
}

This eliminates one loop.


7. Numerical Stability in Softmax

The computation of ( P_{ij} = \exp(S_{ij}) ) can suffer from numerical instability if ( S_{ij} ) contains large values. This can lead to overflow or inaccurate results.

Fix: Subtract the maximum value in each row of ( S_{ij} ) before applying the exponential function:

for(int br_idx=0; br_idx < Br; br_idx++) {
  float max_val = -std::numeric_limits<float>::infinity();
  for(int bc_idx=0; bc_idx < Bc; bc_idx++) {
    float sij_val = twoDimRead(Sij, br_idx, bc_idx, Bc);
    max_val = std::max(max_val, sij_val);
  }
  for(int bc_idx=0; bc_idx < Bc; bc_idx++) {
    float sij_val = twoDimRead(Sij, br_idx, bc_idx, Bc) - max_val;
    twoDimWrite(Sij, br_idx, bc_idx, Bc, sij_val);
  }
}

This ensures the softmax computation is numerically stable.


8. Inefficient Loop Nesting

The loop order might lead to poor memory access patterns, especially for Kj, Vj, and Qi. Consider transposing these matrices in local memory to improve cache efficiency.


9. Potential Misalignment in Block Indexing

When processing blocks of size Br and Bc, ensure that the indexing properly handles edge cases where N is not evenly divisible by Br or Bc. Currently, edge handling appears inconsistent.

Fix: Use bounds checking consistently:

if (q_row_idx >= N || j_row >= N) continue;

10. Output Tensor Not Updated Properly

The output tensor OTensor may not be updated properly due to indexing or memory mismanagement.

Fix: Ensure that updates to O directly correspond to OTensor:

for(int br_idx=0; br_idx < Br; br_idx++) {
  int O_row = q_block_idx * Br + br_idx;
  if(O_row >= N) continue;
  for(int d_idx=0; d_idx < d; d_idx++) {
    float Oi_val = twoDimRead(Oi, br_idx, d_idx, d);
    fourDimWrite(O, b, h, O_row, d_idx, H, N, d, Oi_val);
  }
}

Summary of Key Fixes

  1. Correct indexing when writing lnew to l.
  2. Fix strides in twoDimRead and twoDimWrite for block matrices.
  3. Zero-initialize temporary tensors.
  4. Apply numerical stabilization for the softmax computation.
  5. Combine loops for lij and lnew.
  6. Address potential misalignment when N is not a multiple of block sizes.

These fixes ensure correctness, improve numerical stability, and enhance efficiency.

Add initialization of 0.0 to Kj and Vj, to avoid impact of leftover numnber in last iteration

I realize this problem from python diff check and found that value is not equal at the last row of matrix.

Mismatch at index (0, 3, 1023, 22): QKV=0.13594280183315277, QKS1=0.03398570418357849
Mismatch at index (0, 3, 1023, 23): QKV=0.13674280047416687, QKS1=0.03418571129441261
Mismatch at index (0, 3, 1023, 24): QKV=0.13754287362098694, QKS1=0.03438570350408554
Mismatch at index (0, 3, 1023, 25): QKV=0.13834281265735626, QKS1=0.03458569571375847
Mismatch at index (0, 3, 1023, 26): QKV=0.13914281129837036, QKS1=0.03478570282459259
Mismatch at index (0, 3, 1023, 27): QKV=0.13994282484054565, QKS1=0.034985702484846115
Mismatch at index (0, 3, 1023, 28): QKV=0.14074280858039856, QKS1=0.03518570587038994
Mismatch at index (0, 3, 1023, 29): QKV=0.14154279232025146, QKS1=0.035385698080062866
Mismatch at index (0, 3, 1023, 30): QKV=0.14234280586242676, QKS1=0.03558569401502609
Mismatch at index (0, 3, 1023, 31): QKV=0.14314278960227966, QKS1=0.03578570485115051
Traceback (most recent call last):
  File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 339, in <module>
    main()
  File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 327, in main
    part4Test(N, d, B, H, int(args.bc), int(args.br))
  File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 267, in part4Test
    testTemplate(attentionModuleReference.myFlashAttention, params, "STUDENT - FLASH ATTENTION")
  File "/home/zt/stf-cs149-pp/cs149gpt/gpt149.py", line 190, in testTemplate
    raise AssertionError(correctness_error_message)
AssertionError:
-------------------------------------------
 YOUR ATTENTION PRODUCED INCORRECT RESULTS

Found that result is incorrect starting from second head.

I think it’s becase that some itermediate variable is not cleared before entering next head which leads to incorrect answer.

Mismatch at index (0, 1, 0, 0): QKV=1.0, QKS1=0.5
Mismatch at index (0, 1, 0, 1): QKV=1.0, QKS1=0.5
Mismatch at index (0, 1, 0, 2): QKV=1.0, QKS1=0.5
Mismatch at index (0, 1, 0, 3): QKV=1.0, QKS1=0.5
Mismatch at index (0, 1, 0, 4): QKV=1.0, QKS1=0.5
Mismatch at index (0, 1, 0, 5): QKV=1.0, QKS1=0.5
Mismatch at index (0, 1, 0, 6): QKV=1.0, QKS1=0.5
Mismatch at index (0, 1, 0, 7): QKV=1.0, QKS1=0.5
Mismatch at index (0, 1, 0, 8): QKV=1.0, QKS1=0.5
Mismatch at index (0, 1, 0, 9): QKV=1.0, QKS1=0.5
Mismatch at index (0, 1, 0, 10): QKV=1.0, QKS1=0.5
Mismatch at index (0, 1, 0, 11): QKV=1.0, QKS1=0.5
Mismatch at index (0, 1, 0, 12): QKV=1.0, QKS1=0.5
Mismatch at index (0, 1, 0, 13): QKV=1.0, QKS1=0.5
M

Found the root cause and fix it.

I need to clear l for each head. This is because l is read to li for each Q- block

Code:

torch::Tensor myFlashAttention(torch::Tensor QTensor, torch::Tensor KTensor, torch::Tensor VTensor,
               torch::Tensor QiTensor, torch::Tensor KjTensor, torch::Tensor VjTensor,
               torch::Tensor SijTensor, torch::Tensor PijTensor, torch::Tensor PVTensor,
               torch::Tensor OiTensor, torch::Tensor LTensor,  torch::Tensor LiTensor, 
	       torch::Tensor LijTensor, torch::Tensor LnewTensor, int Bc, int Br,
                int B, int H, int N, int d) {
      
  // Q, K, V are passed in with Shape: (B, H, N, d)
  // Sij, Pij are passed in with Shape: (Br, Bc)
  // Kj, Vj are passed in with Shape: (Bc, d)
  // Qi, Oi, and PV  are passed in with Shape: (Br, d)
  // L in passed in with Shape: (N)
  // Li, Lij, and Lnew are passed in with shape (Br)

  //Make O Tensor with Shape (B, H, N, d)
  at::Tensor OTensor = at::zeros({B, H, N, d}, at::kFloat);
 
  //Format All Tensors into Vectors
  std::vector<float> O = formatTensor(OTensor);
  std::vector<float> Q = formatTensor(QTensor);
  std::vector<float> K = formatTensor(KTensor);
  std::vector<float> V = formatTensor(VTensor);
  std::vector<float> Sij = formatTensor(SijTensor); //clear
  std::vector<float> Pij = formatTensor(PijTensor); //clear
  std::vector<float> Kj = formatTensor(KjTensor); // clear
  std::vector<float> Vj = formatTensor(VjTensor); // clear
  std::vector<float> Qi = formatTensor(QiTensor); // clear
  std::vector<float> Oi = formatTensor(OiTensor); //clear
  std::vector<float> l = formatTensor(LTensor); // This should be cleared
  std::vector<float> PV = formatTensor(PVTensor);
  std::vector<float> li = formatTensor(LiTensor);
  std::vector<float> lij = formatTensor(LijTensor);
  std::vector<float> lnew = formatTensor(LnewTensor);

  // std::cout << "br:" << Br << " bc:" << Bc <<std::endl;
  // -------- YOUR CODE HERE  -------- //
  for(int b=0; b < B; b++ ) {
    for(int h=0; h < H; h++) {

    // This line is essential to correctness.
    std::fill(l.begin(), l.end(), 0.0f);
    std::fill(lnew.begin(), lnew.end(), 0.0f);
    std::fill(lij.begin(), lij.end(), 0.0f);
  for(int k_block_idx=0; k_block_idx < (N+Bc-1)/Bc; k_block_idx++) {
    std::fill(Kj.begin(), Kj.end(), 0.0f);
    std::fill(Vj.begin(), Vj.end(), 0.0f);
    // load Kj, Vj into local memory blocks.
    for(int j=0; j < Bc; j++) {
      int j_row = k_block_idx * Bc + j;
      if(j_row >= N) continue;
      for(int d_idx =0; d_idx < d; d_idx++) {
        float k_val = fourDimRead(K, b, h, j_row, d_idx, H, N, d);
        float v_val = fourDimRead(V, b, h, j_row, d_idx, H, N, d);
        twoDimWrite(Kj, j, d_idx, d, k_val);
        twoDimWrite(Vj, j, d_idx, d, v_val);
          // std::cout<< "j:" << j_row << " col:" << d_idx << "kj:" << k_val << " vj:" << v_val << std::endl;
      }
    }

    for(int q_block_idx=0; q_block_idx < (N+Br-1)/Br; q_block_idx++) {
      std::fill(Qi.begin(), Qi.end(), 0.0f);
      std::fill(Oi.begin(), Oi.end(), 0.0f);
      std::fill(Sij.begin(), Sij.end(), 0.0f);
      std::fill(Pij.begin(), Pij.end(), 0.0f);


      // load Qi, Oi, li into local memory blocks
      for(int br_idx=0; br_idx < Br; br_idx++ ) {
        int q_row_idx = q_block_idx * Br + br_idx; 
        if(q_row_idx >= N ) continue;
        for(int d_idx=0; d_idx < d; d_idx++) {
          float q_val = fourDimRead(Q, b, h, q_row_idx, d_idx, H, N, d);
          float o_val = fourDimRead(O, b, h, q_row_idx , d_idx, H, N, d);
          twoDimWrite(Qi, br_idx, d_idx, d, q_val);
          twoDimWrite(Oi, br_idx, d_idx, d, o_val);
            // std::cout << "q_row_idx:" << q_row_idx << " d_idx:" << d_idx << " Qi:" << q_val << " Oi:" << o_val <<std::endl;

        }
        float l_val = l[q_row_idx];
        li[br_idx] = l_val;
            // std::cout << "li:" << l_val << std::endl;

      }

      // compute Sij  = Qi * Kj_T (Br x Bc) 
      for(int br_idx=0; br_idx < Br; br_idx++) {
        for(int bc_idx=0; bc_idx < Bc; bc_idx++) {
          float sum = 0.0;
          for(int d_idx=0; d_idx < d; d_idx++) {
            float q_val = twoDimRead(Qi, br_idx, d_idx, d);
            float k_val = twoDimRead(Kj, bc_idx, d_idx, d);
            sum += q_val * k_val;

          }
          twoDimWrite(Sij, br_idx, bc_idx, Bc, sum);
              // std::cout << "sij, br:" << br_idx << " bc:" << bc_idx << " val:" << sum << std::endl;
        }
      }

      // Compute Pij = exp(Sij) of size (Br x Bc)
      for(int br_idx=0; br_idx < Br; br_idx++) {
        for(int bc_idx=0; bc_idx < Bc; bc_idx++) {
          float exp_val = std::exp(twoDimRead(Sij, br_idx, bc_idx, Bc));
          twoDimWrite(Pij, br_idx, bc_idx, Bc, exp_val);
        }
      }

      // Compute lij = rowsum(Pij) of size (Br)
      for(int br_idx=0; br_idx < Br; br_idx++) {
        float sum = 0.0;
        for(int bc_idx=0; bc_idx < Bc; bc_idx++) {
          sum += twoDimRead(Pij, br_idx, bc_idx, Bc);
        }
        lij[br_idx] = sum;
        // compute lnew = li + lij
        lnew[br_idx] = li[br_idx] + lij[br_idx];

      }



      // Compute Oi = (liOi + PijVj)/ lnew
      for(int br_idx=0; br_idx < Br; br_idx++) {
        for(int d_idx=0; d_idx < d; d_idx++) {
          float pv_sum =0.0;
          for(int bc_idx=0; bc_idx < Bc; bc_idx++) {
            int p_row = br_idx;
            int p_col = bc_idx;
            int v_row = bc_idx;
            int v_col = d_idx;
            pv_sum += twoDimRead(Pij, p_row, p_col, Bc) * twoDimRead(Vj, v_row, v_col, d);

          }
          // twoDimWrite(PV, br_idx, d_idx, d, pv_sum);

          float li_Oi_val = li[br_idx] * twoDimRead(Oi, br_idx, d_idx, d);
          float new_sum = pv_sum + li_Oi_val;
          float new_Oi_val = new_sum / lnew[br_idx];
          twoDimWrite(Oi, br_idx, d_idx, d, new_Oi_val);
        }
      }

      // Write Oi and lnew back to O and l in main memory
      for(int br_idx=0; br_idx < Br; br_idx++) {
        int O_row = q_block_idx * Br + br_idx;
        if(O_row >= N) continue;
        for(int d_idx=0; d_idx < d; d_idx++) {
          float Oi_val = twoDimRead(Oi, br_idx, d_idx, d);
                  int O_col = d_idx;
          fourDimWrite(O, b, h, O_row, O_col, H, N, d, Oi_val);

        }

        l[O_row] = lnew[br_idx];

      }


    }
  }

    }

  }


  // DO NOT EDIT THIS RETURN STATEMENT //
  // It formats your C++ Vector O back into a Tensor of Shape (B, H, N, d) and returns it //
  return torch::from_blob(O.data(), {B, H, N, d}, torch::TensorOptions().dtype(torch::kFloat32)).clone();
}

OUtput:

My implementation is 3x faster than ref implementation. I don’t know why. Is it because that I don’t materialize PV?

               model_inference         0.11%     771.000us       100.00%     721.392ms     721.392ms     512.00 Kb     -10.53 Kb             1
    REFERENCE - FLASH ATTENTION        99.56%     718.217ms        99.83%     720.142ms     720.142ms     512.00 Kb      -8.00 Mb             1
                       aten::to         0.00%       5.000us         0.00%      33.000us       5.500us         520 b           0 b             6
                 aten::_to_copy         0.00%      16.000us         0.00%      28.000us       7.000us         520 b           8 b             4
                      aten::abs         0.00%      30.000us         0.01%      54.000us      13.500us         512 b         256 b             4
            aten::empty_strided         0.00%       3.000us         0.00%       3.000us       0.750us         512 b         512 b             4
                 aten::isfinite         0.00%      12.000us         0.01%      89.000us      89.000us         224 b           0 b             1
            aten::masked_select         0.00%      14.000us         0.01%      43.000us      43.000us         128 b         120 b             1
-------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 721.392ms

REFERENCE - FLASH ATTENTION statistics
cpu time:  720.142ms
mem usage:  524288 bytes
-----RUNNING STUDENT IMPLEMENTATION-----

first row value:
 tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
student first row:
 tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
manual attention == pytorch attention True
Manual Execution Time:  0.21289777755737305

-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                  aten::empty         0.01%      16.000us         0.01%      16.000us       0.889us       1.00 Mb       1.00 Mb            18
                  aten::clone         0.01%      25.000us         0.16%     332.000us     166.000us       1.00 Mb           0 b             2
                  aten::zeros         0.02%      36.000us         0.04%      83.000us       6.917us     521.59 Kb       5.00 Kb            12
          aten::empty_strided         0.00%      10.000us         0.00%      10.000us       2.000us     512.51 Kb     512.51 Kb             5
              model_inference         0.34%     722.000us       100.00%     212.937ms     212.937ms     512.00 Kb     -15.47 Kb             1
    STUDENT - FLASH ATTENTION        99.25%     211.341ms        99.44%     211.746ms     211.746ms     512.00 Kb      -1.00 Mb             1
                aten::flatten         0.01%      26.000us         0.10%     216.000us      14.400us     512.00 Kb           0 b            15
             aten::empty_like         0.00%       4.000us         0.00%       6.000us       6.000us     512.00 Kb           0 b             1
                  aten::zero_         0.00%       5.000us         0.02%      35.000us       2.917us       5.00 Kb       5.00 Kb            12
                     aten::to         0.01%      16.000us         0.01%      29.000us       4.833us         520 b           4 b             6
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 212.937ms

STUDENT - FLASH ATTENTION statistics
cpu time:  211.746ms
mem usage:  524288 bytes



Enjoy Reading This Article?

Here are some more articles you might like to read next:

  • Learning-based memory allocation for C++ server workloads summary
  • my question:
  • Binary search algorithm variant
  • Docker Rocksdb build
  • Difference between Dockerfile and Docker Compose