nanoGPT kv cache first attempt

1. Run basic nanoGPT

git clone https://github.com/karpathy/nanoGPT.git

Install necessary packages

pip install -r requirements.txt

My requirements.txt contains these packages:

blobfile==2.0.1
certifi==2022.12.7
charset-normalizer==3.0.1
filelock==3.9.0
idna==3.4
lxml==4.9.2
numpy==1.24.2
pycryptodomex==3.17
pytz==2022.7.1
regex==2022.10.31
requests==2.28.2
tokenizers==0.13.2
torch==2.0.0
typing_extensions==4.4.0
urllib3==1.26.14
transformers==4.28.1
datasets==2.11.0
tiktoken==0.3.3
wandb==0.14.2
tqdm==4.65.0

Follow the quick start guide in the nanoGPT repo to make sure that both training and inference run successfully:

python data/shakespeare_char/prepare.py
python train.py --compile=False config/train_shakespeare_char.py
python sample.py --out_dir=out-shakespeare-char

My Python version is 3.11, which torch.compile in PyTorch 2.0 does not support, so I added --compile=False to the train command.

On my A800 GPU, I get a training loss of 0.0449 after 5000 iterations:

iter 4970: loss 0.0461, time 18.12ms, mfu 20.21%
iter 4980: loss 0.0441, time 18.14ms, mfu 20.24%
iter 4990: loss 0.0464, time 18.13ms, mfu 20.27%
step 5000: train loss 0.0383, val loss 4.7262
iter 5000: loss 0.0449, time 3352.84ms, mfu 18.26%

2. Load GPT-2 model checkpoints and test performance

I got an SSL/proxy error while trying to download the gpt2 model from Hugging Face. These two threads describe the same problem:

https://stackoverflow.com/questions/75110981/sslerror-httpsconnectionpoolhost-huggingface-co-port-443-max-retries-exce

https://github.com/huggingface/transformers/issues/17611

First, downgrade requests to version 2.27.1:

pip install requests==2.27.1

Then, adding these two lines of code to train.py and sample.py fixed the proxy connection issue for me:

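import os

# set these before transformers / huggingface_hub is imported, since the endpoint is read at import time
os.environ['CURL_CA_BUNDLE'] = ''  # with requests 2.27.1 this effectively disables SSL verification (insecure workaround)
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'  # download from the hf-mirror.com mirror instead of huggingface.co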

Run sample.py to test the gpt2 model with weights downloaded from Hugging Face:

 python sample.py --init_from='gpt2'

I started with the prompt “please tell me a joke.” The output is nothing like a joke, but it is still very readable.

please tell me a joke

[…]

My name is Zarek, but I am extremely sad for you.

You can't even come to my house anymore

I'm sorry, I know

I have a dream

I don't know how long this thing will last

My name Is Zarek

I'm an adult who believes that

The problem with your friend is that he doesnt know

He doesn't know how to act

Running time for 10 inference runs:

---------------
Elapsed time: 25.4s

3. Implement KV cache for faster inference

Commit history for kv cache implementation

Please check the commits linked above for implementation details.
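For reference, here is a minimal, self-contained sketch of causal self-attention with a KV cache in PyTorch. It is not the code from the commits: the class name `CachedSelfAttention`, the `past_kv` argument and the returned `(k, v)` tuple are hypothetical, modeled loosely on nanoGPT's `CausalSelfAttention` (no dropout, no flash attention). New keys and values are appended to the cached ones along the sequence dimension, and the causal mask is shifted by the number of cached positions so each new query sees all cached positions plus the earlier new ones.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CachedSelfAttention(nn.Module):
    """Causal self-attention that can reuse cached keys/values (hypothetical sketch)."""

    def __init__(self, n_embd: int, n_head: int):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)  # fused q, k, v projection
        self.c_proj = nn.Linear(n_embd, n_embd)

    def forward(self, x, past_kv=None):
        B, T, C = x.size()  # T = number of *new* tokens in this call
        hs = C // self.n_head
        q, k, v = self.c_attn(x).split(C, dim=2)
        # (B, T, C) -> (B, n_head, T, head_size)
        q = q.view(B, T, self.n_head, hs).transpose(1, 2)
        k = k.view(B, T, self.n_head, hs).transpose(1, 2)
        v = v.view(B, T, self.n_head, hs).transpose(1, 2)
        if past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat((past_k, k), dim=2)  # (B, n_head, past + T, head_size)
            v = torch.cat((past_v, v), dim=2)
        KV = k.size(2)
        past_len = KV - T
        # new query i sits at absolute position past_len + i and may attend keys 0 .. past_len + i
        mask = torch.ones(T, KV, dtype=torch.bool, device=x.device).tril(diagonal=past_len)
        att = (q @ k.transpose(-2, -1)) / (hs ** 0.5)
        att = att.masked_fill(~mask, float("-inf"))
        y = F.softmax(att, dim=-1) @ v  # (B, n_head, T, head_size)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y), (k, v)  # (k, v) is the updated cache for the next call
```

A quick way to check such an implementation: run the prompt in one call, then feed one token at a time with the returned cache, and compare against a single call on the full sequence. The two should match up to floating-point noise.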

Issue:

shape of past k proj is  torch.Size([1, 12, 946, 64])
shape of k is  torch.Size([1, 12, 44, 64]) shape of v is  torch.Size([1, 12, 44, 64])
q len is  45
shape of past k proj is  torch.Size([1, 12, 990, 64])
shape of k is  torch.Size([1, 12, 45, 64]) shape of v is  torch.Size([1, 12, 45, 64])
Traceback (most recent call last):
  File "/GPUFS/nsccgz_qylin_1/zt/nano-gpt-kv-cache/sample.py", line 93, in <module>
    y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/GPUFS/nsccgz_qylin_1/miniconda3/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/GPUFS/nsccgz_qylin_1/zt/nano-gpt-kv-cache/model.py", line 359, in generate
    logits, _, past_kv_proj = self(idx_cond, past_kv_proj=past_kv_proj,start_pos=start_pos)
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/GPUFS/nsccgz_qylin_1/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/GPUFS/nsccgz_qylin_1/zt/nano-gpt-kv-cache/model.py", line 204, in forward
    x, layer_kv_proj = block(x, past_kv_proj=past_kv_proj[i])
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/GPUFS/nsccgz_qylin_1/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/GPUFS/nsccgz_qylin_1/zt/nano-gpt-kv-cache/model.py", line 122, in forward
    attn_res, present_kv_proj = self.attn(self.ln_1(x), past_kv_proj=past_kv_proj)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/GPUFS/nsccgz_qylin_1/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/GPUFS/nsccgz_qylin_1/zt/nano-gpt-kv-cache/model.py", line 78, in forward
    assert KV < self.block_size, f"KV: {KV} >= block_size: {self.block_size}"
           ^^^^^^^^^^^^^^^^^^^^
AssertionError: KV: 1035 >= block_size: 1024
yhrun: error: gpu73: task 0: Exited with exit code 1
(nano-gpt-kv-cache) [nsccgz_qylin_1@ln101 nano-gpt-kv-cache]$

Fix

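    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        T = idx.size(1)  # prompt length (assumed: this is what T refers to)
        for _ in range(max_new_tokens):
            # This is the right condition: only the first step sees the full prompt
            if idx.size(1) == T:
                # first step: run the whole prompt through the model and fill the kv cache
                idx_cond = idx
                start_pos = 0
            else:
                # later steps: only the newest token is new; earlier keys/values come from the cache
                idx_cond = idx[:, [-1]]
                # absolute position of that token, used for its position embedding
                start_pos = idx.size(1) - 1
            # ... forward pass on idx_cond with past_kv_proj and start_pos, then sampling, as before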

The limitation of this code is that it only handles the case where max_new_tokens < self.config.block_size (more precisely, where the prompt plus the generated tokens fit within self.config.block_size).

I don’t know why yet.
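One likely reason, assuming the model keeps nanoGPT's learned absolute position embedding (`wpe`, an `nn.Embedding` with `block_size` rows): there is simply no embedding for position 1024 or later, so the total sequence length (prompt plus generated tokens) cannot exceed `block_size`. The `KV >= block_size` assert in the traceback above acts as a guard for this. A tiny check, where `position_embeddings` is a hypothetical helper:

```python
import torch
import torch.nn as nn

block_size, n_embd = 1024, 768  # GPT-2 small sizes
wpe = nn.Embedding(block_size, n_embd)  # one learned vector per position 0 .. block_size - 1


def position_embeddings(start_pos: int, num_new: int) -> torch.Tensor:
    """Look up embeddings for absolute positions start_pos .. start_pos + num_new - 1."""
    pos = torch.arange(start_pos, start_pos + num_new, dtype=torch.long)
    return wpe(pos)


print(position_embeddings(1023, 1).shape)  # the last valid position works
try:
    position_embeddings(1024, 1)  # position 1024 has no row in wpe
except IndexError as err:
    print("IndexError:", err)
```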

4. Test KV cache performance

The commit mentions that kv cache only speeds up inference on CPU, not on an A100 GPU. Why is that? Is it because the linear projections are computed very quickly with fast GPU matrix multiplication?

This commit and this discussion talk about how to handle long text generation. I have not fully understood yet how it handles long text generation.

This commit mentions a technique called rotary positional embeddings, but I don’t know how it works yet. For now, all I want to do is test how much kv cache speeds up inference.
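For the record, the core idea of rotary positional embeddings (RoPE) is to rotate each pair of query/key channels by an angle proportional to the token position, so the dot product between a query at position m and a key at position n depends only on m - n. Below is a minimal sketch of that rotation, assuming the common base of 10000 and the convention that pairs channel i with channel i + d/2 (other implementations interleave adjacent channels instead). `apply_rope` is a hypothetical helper, not code from the commit.

```python
import torch


def apply_rope(x: torch.Tensor, positions: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Rotate x of shape (..., T, d) by position-dependent angles; d must be even."""
    d = x.size(-1)
    half = d // 2
    # one frequency per channel pair: base^(-2i/d) for i = 0 .. d/2 - 1
    inv_freq = base ** (-torch.arange(half, dtype=torch.float32) * 2 / d)
    angles = positions.to(torch.float32)[:, None] * inv_freq[None, :]  # (T, d/2)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., :half], x[..., half:]  # pair channel i with channel i + d/2
    # standard 2D rotation of each (x1, x2) pair
    return torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
```

Because the rotation is applied to queries and keys before attention, a rotated key can be cached as-is, and shifting both positions by the same amount leaves the attention score unchanged.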

My naive solution for now is to simply cut past_kv_proj down to the latest self.config.block_size tokens:

            if past_kv_proj is not None:
                past_k_proj, past_v_proj = past_kv_proj
                print('shape of past k proj is ', past_k_proj.shape )
                print('shape of k is ', k.shape, 'shape of v is ', v.shape)
                if KV >= self.block_size:
                    past_k_proj = past_k_proj[:, :, -self.block_size:, :]
                    past_v_proj = past_v_proj[:, :, -self.block_size:, :]
                k = torch.cat((past_k_proj, k), dim=2)
                v = torch.cat((past_v_proj, v), dim=2)
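A slightly safer variant of the same naive idea, as a sketch: trim the cache so that the cached positions plus the new ones fit inside `block_size`, instead of keeping `block_size` cached positions and then appending more on top. `trim_kv_cache` is a hypothetical helper. With learned absolute position embeddings this is still only an approximation: the surviving cached keys and values were computed with their old positions and with context that has since been dropped, which is not the same as re-running the model on the cropped window the way nanoGPT's original `generate` does, and `start_pos` would still have to stay below `block_size`.

```python
import torch


def trim_kv_cache(past_k: torch.Tensor, past_v: torch.Tensor, new_len: int, block_size: int):
    """Keep only the most recent cached positions so that cache + new_len <= block_size.

    past_k / past_v have shape (B, n_head, past_len, head_size).
    """
    keep = max(block_size - new_len, 0)  # how many cached positions still fit
    start = max(past_k.size(2) - keep, 0)  # slice start; also correct when keep == 0
    return past_k[:, :, start:, :], past_v[:, :, start:, :]
```

For a single new token and `block_size=1024`, a 1030-position cache is cut to its last 1023 positions, so the concatenated keys have exactly 1024 positions.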

GPU: V100

with kv cache, no flash attention

yhrun -p gpu_v100 python sample.py --init_from='gpt2' --use_kv_cache=True --dtype=float32 --num_samples=10 --max_new_tokens=1000

time:

---------------
Elapsed time: 102.6s

memory: peak memory usage looks almost the same?

without kv cache, no flash attention

python   sample.py --init_from='gpt2'  --use_kv_cache=False --dtype=float32  --num_samples=10 --max_new_tokens=1000

time:

Elapsed time: 151.8s

memory:

Saves about 32% of the time (102.6s vs 151.8s). Not bad.

Function profile: I ran it 10 times, each generating 1000 tokens.
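The tables below have the format of Python's standard `cProfile` output sorted by internal time. A minimal sketch of producing such a report with the standard library; `profile_call` is a hypothetical helper:

```python
import cProfile
import io
import pstats


def profile_call(fn, *args, top: int = 10, **kwargs):
    """Run fn(*args, **kwargs) under cProfile; return (result, report sorted by tottime)."""
    profiler = cProfile.Profile()
    profiler.enable()
    try:
        result = fn(*args, **kwargs)
    finally:
        profiler.disable()
    stream = io.StringIO()
    # "tottime" = time spent inside each function itself, excluding sub-calls
    # (pstats labels this ordering "internal time")
    pstats.Stats(profiler, stream=stream).sort_stats("tottime").print_stats(top)
    return result, stream.getvalue()
```

In sample.py one could, for example, wrap the call from the traceback above as `y, report = profile_call(model.generate, x, max_new_tokens, temperature=temperature, top_k=top_k)` and then `print(report)`.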

with kv cache:

      15990616 function calls (14380614 primitive calls) in 110.086 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   120000   21.130    0.000   56.527    0.000 /GPUFS/nsccgz_qylin_1/zt/nano-gpt-kv-cache/model.py:58(forward)
   490000   14.545    0.000   14.545    0.000 {built-in method torch._C._nn.linear}
1620000/10000   13.119    0.000  103.105    0.010 /GPUFS/nsccgz_qylin_1/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1494(_call_impl)
  3420000    9.789    0.000    9.789    0.000 /GPUFS/nsccgz_qylin_1/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1601(__getattr__)
   120000    5.108    0.000   97.206    0.001 /GPUFS/nsccgz_qylin_1/zt/nano-gpt-kv-cache/model.py:135(forward)
   250000    4.809    0.000    4.809    0.000 {built-in method torch.layer_norm}
   120000    3.840    0.000    3.840    0.000 {method 'masked_fill' of 'torch._C._TensorBase' objects}
       10    3.737    0.374  109.808   10.981 /GPUFS/nsccgz_qylin_1/zt/nano-gpt-kv-cache/model.py:350(generate)
   250000    3.694    0.000    3.694    0.000 {built-in method torch.cat}
   120000    2.772    0.000   20.640    0.000 /GPUFS/nsccgz_qylin_1/zt/nano-gpt-kv-cache/model.py:113(forward)
   370000    2.268    0.000    3.704    0.000 /GPUFS/nsccgz_qylin_1/miniconda3/lib/python3.11/site-packages/torch/nn/functional.py:1235(dropout)
    10000    2.075    0.000  102.922    0.010 /GPUFS/nsccgz_qylin_1/zt/nano-gpt-kv-cache/model.py:206(forward)

without kv cache

       15380126 function calls (13770124 primitive calls) in 154.658 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    10000   45.431    0.005  148.326    0.015 /GPUFS/nsccgz_qylin_1/zt/nano-gpt-kv-cache/model.py:206(forward)
   120000   22.252    0.000   58.175    0.000 /GPUFS/nsccgz_qylin_1/zt/nano-gpt-kv-cache/model.py:58(forward)
   490000   17.891    0.000   17.891    0.000 {built-in method torch._C._nn.linear}
1620000/10000   13.188    0.000  148.485    0.015 /GPUFS/nsccgz_qylin_1/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1494(_call_impl)
  3420000    9.488    0.000    9.488    0.000 /GPUFS/nsccgz_qylin_1/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py:1601(__getattr__)
   120000    5.289    0.000   99.160    0.001 /GPUFS/nsccgz_qylin_1/zt/nano-gpt-kv-cache/model.py:135(forward)
   250000    5.037    0.000    5.037    0.000 {built-in method torch.layer_norm}
   120000    4.087    0.000    4.087    0.000 {method 'masked_fill' of 'torch._C._TensorBase' objects}
       10    2.952    0.295  154.334   15.433 /GPUFS/nsccgz_qylin_1/zt/nano-gpt-kv-cache/model.py:350(generate)

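The main time difference comes from the model's top-level forward call (model.py:206, called once per generated token): it takes 0.015s per call without kv cache and 0.010s with kv cache, and most of that gap is its own time (tottime 45.4s vs 2.1s). The attention forward (model.py:58, 12 calls per token) has about the same cumulative time in both runs (58.2s vs 56.5s). Since CUDA kernels run asynchronously, cProfile may attribute GPU time to whichever Python call ends up waiting for the GPU, so this breakdown should be read with care. I guess this also explains why the benefit of kv cache for short sequence generation on an A100 is negligible: a more powerful GPU computes the key and value projections very quickly.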
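A rough way to see how much work the cache removes is to count how many token positions go through the network in one run. This sketch assumes a 1-token prompt and, for the no-cache case, nanoGPT's usual cropping of the context to the last `block_size` tokens; `positions_processed` is a hypothetical helper.

```python
def positions_processed(prompt_len: int, new_tokens: int, use_cache: bool, block_size: int = 1024) -> int:
    """Count token positions fed through the transformer while generating new_tokens tokens."""
    total = 0
    for step in range(new_tokens):
        context_len = prompt_len + step  # tokens available before this step
        if use_cache:
            # first step processes the whole prompt, later steps only the newest token
            total += context_len if step == 0 else 1
        else:
            # without a cache every step re-runs the (cropped) context
            total += min(context_len, block_size)
    return total


no_cache = positions_processed(1, 1000, use_cache=False)
with_cache = positions_processed(1, 1000, use_cache=True)
print(no_cache, with_cache, no_cache / with_cache)
```

For 1000 new tokens this gives 500500 vs 1000 positions, roughly 500x fewer for the projection and MLP layers (attention still reads every cached key), while the measured wall-clock saving on the V100 was only about 1.5x. One plausible reading, consistent with the guess above, is that for a model as small as GPT-2 at batch size 1, each GPU step is dominated by fixed per-step costs such as kernel launches and Python overhead rather than by arithmetic.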

500 tokens, CPU

with kv cache (two runs):

The law gives the government access to consumer information only if the government's purpose is to provide health care to the general public. If those
---------------
Elapsed time: 218.9s


The law gives the government access to consumer information only if the government's purpose is to provide health care to the general public. If those
---------------
Elapsed time: 251.4s

without kv cache

The law gives the government access to consumer information only if the government's purpose is to provide health care to the general public. If those
---------------
Elapsed time: 1191.4s

About 5x faster inference (218.9s and 251.4s with kv cache vs 1191.4s without). Not bad.

Peak memory usage is nearly the same with and without kv cache. This is because the sequence length is the same with or without kv cache. However, kv cache does bring some advantages. Here’s the answer I got from GPT:

Actually, there is a difference in memory usage when using KV cache for LLM inference. While it’s true that the maximum memory usage might be similar, the way memory is utilized and managed can vary significantly.

  1. Memory Allocation: With KV cache, memory is allocated for storing key-value pairs from previous computations. This can lead to more efficient memory usage as the model doesn’t need to recompute values, reducing the overall memory footprint during inference.
  2. Memory Management: KV cache helps in better memory management by reusing previously computed values. This can lead to more stable memory usage patterns, avoiding spikes in memory consumption that might occur without caching.
  3. Performance Optimization: By reducing redundant computations, KV cache can lead to faster inference times, which indirectly affects memory usage. Faster computations mean less time spent holding intermediate values in memory, leading to more efficient memory utilization.
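As a rough check on the peak-memory observation, the size of the cache itself can be estimated: each layer stores one key and one value vector of size n_embd per cached token. The sketch below assumes GPT-2 small (12 layers, n_embd 768, about 124M parameters), float32, batch size 1 and 1000 cached tokens; `kv_cache_bytes` is a hypothetical helper.

```python
def kv_cache_bytes(n_layer: int, n_embd: int, seq_len: int, batch_size: int = 1, bytes_per_elem: int = 4) -> int:
    """Bytes needed to cache keys and values for every layer (factor 2 = one K and one V tensor)."""
    return 2 * n_layer * batch_size * seq_len * n_embd * bytes_per_elem


cache = kv_cache_bytes(n_layer=12, n_embd=768, seq_len=1000)
weights = 124_000_000 * 4  # ~124M float32 parameters
print(f"KV cache: {cache / 1e6:.1f} MB, weights: {weights / 1e6:.0f} MB")
```

Under these assumptions the cache is about 74 MB next to roughly 500 MB of weights, which may be why it hardly shows in peak memory here. It grows linearly with sequence length and batch size, so it matters much more for long contexts and large batches.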

References

YouTube video: LLM kv cache explanation

requirements.txt to run nanoGPT

nanoGPT kv cache PR example

Hugging Face transformers kv cache source code on GitHub

https://zhuanlan.zhihu.com/p/646577898

https://zhuanlan.zhihu.com/p/624740065

Hugging Face transformers API documentation



