Converting a From-Scratch GPT Architecture to Llama 2

  • In this notebook, we convert the original GPT architecture into a Llama 2 model step by step (note that GPT and GPT-2 share the same architecture)
  • Why not Llama 1 or Llama 3?
    • The Llama 1 architecture is similar to Llama 2, except that Llama 2 has a larger context window (which is nice); also, the Llama 1 weights are not readily available and come with more usage restrictions, so it makes more sense to focus on Llama 2
    • Regarding Llama 3, I will share a separate notebook that converts Llama 2 to Llama 3 (there are only a few small additional changes)
  • The explanations are purposefully kept minimal in this notebook to avoid unnecessary bloat and to keep the focus on the main code
  • For more information, please see the Llama 2 paper: Llama 2: Open Foundation and Fine-Tuned Chat Models (2023)
  • Packages that are being used in this notebook:
[1]
from importlib.metadata import version

pkgs = [
    "huggingface_hub",  # to download pretrained weights
    "sentencepiece",    # to implement the tokenizer
    "torch",            # to implement the model
]
for p in pkgs:
    print(f"{p} version: {version(p)}")
huggingface_hub version: 0.24.7
sentencepiece version: 0.2.0
torch version: 2.4.1+cu121

1. Convert the GPT model implementation step by step

  • In this section, we go through the GPT model code from chapter 4 and modify it step by step to implement the Llama 2 architecture
  • Later, we load the original Llama 2 weights shared by Meta AI

1.1 Replace LayerNorm with RMSNorm layer

  • First, we replace LayerNorm with Root Mean Square Layer Normalization (RMSNorm)
  • LayerNorm normalizes inputs using mean and variance, while RMSNorm uses only the root mean square, which improves computational efficiency
  • The RMSNorm operation is as follows, where $x$ is the input, $\gamma$ is a trainable parameter (vector), and $\epsilon$ is a small constant to avoid zero-division errors:

$$y_i = \frac{x_i}{\text{RMS}(x)} \, \gamma_i, \quad \text{where} \quad \text{RMS}(x) = \sqrt{\epsilon + \frac{1}{n} \sum_{i=1}^{n} x_i^2}$$

[2]
import torch
import torch.nn as nn


#####################################
# Chapter 4
#####################################

# class LayerNorm(nn.Module):
#     def __init__(self, emb_dim):
#         super().__init__()
#         self.eps = 1e-5
#         self.scale = nn.Parameter(torch.ones(emb_dim))
#         self.shift = nn.Parameter(torch.zeros(emb_dim))

#     def forward(self, x):
#         mean = x.mean(dim=-1, keepdim=True)
#         var = x.var(dim=-1, keepdim=True, unbiased=False)
#         norm_x = (x - mean) / torch.sqrt(var + self.eps)
#         return self.scale * norm_x + self.shift


class RMSNorm(nn.Module):
    def __init__(self, emb_dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.emb_dim = emb_dim
        self.weight = nn.Parameter(torch.ones(emb_dim)).float()

    def forward(self, x):
        means = x.pow(2).mean(dim=-1, keepdim=True)
        x_normed = x * torch.rsqrt(means + self.eps)
        return (x_normed * self.weight).to(dtype=x.dtype)
  • The following code cell checks that this implementation works the same as PyTorch's built-in implementation:
[3]
torch.manual_seed(123)

example_batch = torch.randn(2, 3, 4)

rms_norm = RMSNorm(emb_dim=example_batch.shape[-1])
rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.shape[-1], eps=1e-5)

assert torch.allclose(rms_norm(example_batch), rmsnorm_pytorch(example_batch))

1.2 Replace GELU with SiLU activation

  • Llama uses the SiLU activation function (instead of GELU), which is also known as the Swish function:

$$\text{SiLU}(x) = x \cdot \sigma(x), \quad \text{where} \quad \sigma(x) = \frac{1}{1 + e^{-x}}$$
[4]
#####################################
# Chapter 4
#####################################

# class GELU(nn.Module):
#     def __init__(self):
#         super().__init__()

#     def forward(self, x):
#         return 0.5 * x * (1 + torch.tanh(
#             torch.sqrt(torch.tensor(2.0 / torch.pi)) *
#             (x + 0.044715 * torch.pow(x, 3))
#         ))


class SiLU(nn.Module):
    def __init__(self):
        super(SiLU, self).__init__()

    def forward(self, x):
        return x * torch.sigmoid(x)
[5]
silu = SiLU()

assert torch.allclose(silu(example_batch), torch.nn.functional.silu(example_batch))

1.3 Update the FeedForward module

  • In fact, Llama uses a "Gated Linear Unit" (GLU) variant of SiLU called SwiGLU, which essentially results in a slightly differently structured FeedForward module
  • SwiGLU uses a gating mechanism in the feedforward layer, with the formula:

$$\text{SwiGLU}(x) = \text{SiLU}(\text{Linear}_1(x)) * \text{Linear}_2(x)$$

  • Here, $\text{Linear}_1$ and $\text{Linear}_2$ are two linear layers, and $*$ denotes element-wise multiplication
  • The third linear layer, $\text{Linear}_3$, is applied after this gated activation
  • For more information, see the SwiGLU paper: GLU Variants Improve Transformer (2020)

[6]
#####################################
# Chapter 4
#####################################
# class FeedForward(nn.Module):
#     def __init__(self, cfg):
#         super().__init__()
#         self.layers = nn.Sequential(
#             nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
#             GELU(),
#             nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
#         )

#     def forward(self, x):
#         return self.layers(x)
[7]
class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.fc1 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
        self.fc2 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
        self.fc3 = nn.Linear(cfg["hidden_dim"], cfg["emb_dim"], dtype=cfg["dtype"], bias=False)
        self.silu = SiLU()

    def forward(self, x):
        x_fc1 = self.fc1(x)
        x_fc2 = self.fc2(x)
        x = self.silu(x_fc1) * x_fc2
        return self.fc3(x)
  • Note that we also added a dtype=cfg["dtype"] setting above, which will allow us to load the model directly in lower-precision formats later to save memory (versus instantiating it in the original 32-bit precision format and then converting it); a short illustration follows below
  • We also set bias=False since Llama doesn't use any bias units
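  • As a short, standalone illustration (a toy example, not part of the model code; the layer shape simply matches the 7B config defined later), instantiating a linear layer directly in bfloat16 halves its parameter memory compared to float32:

import torch
import torch.nn as nn

layer_fp32 = nn.Linear(4096, 11008, bias=False)                           # float32 by default
layer_bf16 = nn.Linear(4096, 11008, bias=False, dtype=torch.bfloat16)     # instantiated directly in bfloat16

bytes_fp32 = sum(p.numel() * p.element_size() for p in layer_fp32.parameters())
bytes_bf16 = sum(p.numel() * p.element_size() for p in layer_bf16.parameters())
print(f"float32: {bytes_fp32 / 1024**2:.1f} MB, bfloat16: {bytes_bf16 / 1024**2:.1f} MB")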

1.4 Implement RoPE

  • In the GPT model, the absolute positional embeddings are implemented as follows:
self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
  • Llama instead uses rotary position embeddings (RoPE), which encode position by rotating the query and key vectors inside the attention mechanism rather than by adding a learned embedding to the token embeddings (a compact summary of the rotation is given below)
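  • For reference, the rotation implemented by the precompute_rope_params and compute_rope functions below can be summarized as follows (using the "half-split" convention, where $d$ is the head dimension, $m$ is the token position, and rotate_half negates and swaps the two halves of the vector):

$$\theta_i = 10000^{-2i/d}, \quad i = 0, \dots, \tfrac{d}{2} - 1$$

$$\text{RoPE}(\mathbf{x})_m = \mathbf{x} \odot \cos(m \boldsymbol{\theta}) + \text{rotate\_half}(\mathbf{x}) \odot \sin(m \boldsymbol{\theta}), \quad \text{rotate\_half}([\mathbf{x}_1; \mathbf{x}_2]) = [-\mathbf{x}_2; \mathbf{x}_1]$$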
[8]
def precompute_rope_params(head_dim, theta_base=10_000, context_length=4096):
    assert head_dim % 2 == 0, "Embedding dimension must be even"

    # Compute the inverse frequencies
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))

    # Generate position indices
    positions = torch.arange(context_length)

    # Compute the angles
    angles = positions[:, None] * inv_freq[None, :]  # Shape: (context_length, head_dim // 2)

    # Expand angles to match the head_dim
    angles = torch.cat([angles, angles], dim=1)  # Shape: (context_length, head_dim)

    # Precompute sine and cosine
    cos = torch.cos(angles)
    sin = torch.sin(angles)

    return cos, sin


def compute_rope(x, cos, sin):
    # x: (batch_size, num_heads, seq_len, head_dim)
    batch_size, num_heads, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, "Head dimension must be even"

    # Split x into first half and second half
    x1 = x[..., : head_dim // 2]  # First half
    x2 = x[..., head_dim // 2 :]  # Second half

    # Adjust sin and cos shapes
    cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, seq_len, head_dim)
    sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)

    # Apply the rotary transformation
    rotated = torch.cat((-x2, x1), dim=-1)
    x_rotated = (x * cos) + (rotated * sin)

    return x_rotated.to(dtype=x.dtype)
  • The following is an example of applying RoPE to the q and k tensors:
[9]
# Settings
batch_size = 2
context_len = 5
num_heads = 4
head_dim = 16

# Instantiate RoPE parameters
cos, sin = precompute_rope_params(head_dim=head_dim, context_length=context_len)

# Dummy query and key tensors
torch.manual_seed(123)
queries = torch.randn(batch_size, num_heads, context_len, head_dim)
keys = torch.randn(batch_size, num_heads, context_len, head_dim)

# Apply rotary position embeddings
queries_rot = compute_rope(queries, cos, sin)
keys_rot = compute_rope(keys, cos, sin)
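  • As a quick optional sanity check, the vectors at position 0 should be left unchanged by the rotation (since cos(0) = 1 and sin(0) = 0), and the shapes should be preserved:

assert queries_rot.shape == queries.shape
assert torch.allclose(queries_rot[:, :, 0, :], queries[:, :, 0, :])
assert torch.allclose(keys_rot[:, :, 0, :], keys[:, :, 0, :])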

1.5 Add RoPE to MultiHeadAttention module

  • It's important to note that GPT applies the positional embeddings to the inputs, whereas Llama applies rotations to the query and key vectors in the self-attention mechanism itself
  • Here, we modify the MultiHeadAttention class with the appropriate RoPE code
  • In addition, we remove the qkv_bias option and hardcode the bias=False setting
  • Also, we add a dtype setting to be able to instantiate the model with a lower precision later
  • Tip: since the TransformerBlocks (in the next section) are repeated exactly, we could simplify the code and only initialize the buffers once instead of once for each MultiHeadAttention module; however, we add the precomputed RoPE parameters to the MultiHeadAttention class so that it can function as a standalone module; one way to share the buffers is sketched after this list
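  • The following sketch (a hypothetical SharedBuffers helper, not used in the rest of this notebook) shows one way the mask and RoPE tensors could be created once and shared across all attention modules:

class SharedBuffers:
    # Cache keyed by the settings that determine the buffer contents
    _buffers = {}

    @staticmethod
    def get_buffers(context_length, head_dim, dtype=None):
        key = (context_length, head_dim, dtype)
        if key not in SharedBuffers._buffers:
            # Create the causal mask and RoPE parameters only on first request
            mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
            cos, sin = precompute_rope_params(head_dim=head_dim, context_length=context_length)
            if dtype is not None:
                cos, sin = cos.to(dtype), sin.to(dtype)
            SharedBuffers._buffers[key] = (mask, cos, sin)
        return SharedBuffers._buffers[key]

# Two attention modules configured identically would then reuse the same tensors:
mask, cos, sin = SharedBuffers.get_buffers(context_length=4096, head_dim=128)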
[10]
#####################################
# Chapter 3
#####################################
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, num_heads, dtype=None):  # , dropout, num_heads, qkv_bias=False
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by n_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # Reduce the projection dim to match desired output dim

        ################################### NEW ###################################
        # Set bias=False and dtype=dtype for all linear layers below
        ############################################################################
        self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
        self.W_key = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
        self.W_value = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
        self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)  # Linear layer to combine head outputs
        # self.dropout = nn.Dropout(dropout)
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

        ################################### NEW ###################################
        cos, sin = precompute_rope_params(head_dim=self.head_dim, context_length=context_length)
        self.register_buffer("cos", cos)
        self.register_buffer("sin", sin)
        ############################################################################

    def forward(self, x):

        b, num_tokens, d_in = x.shape

        keys = self.W_key(x)  # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        ################################### NEW ###################################
        keys = compute_rope(keys, self.cos, self.sin)
        queries = compute_rope(queries, self.cos, self.sin)
        ############################################################################

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        # attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.reshape(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)  # optional projection

        return context_vec
  • Below is an example using the MultiHeadAttention module on an example input:
[11]
# Settings
batch_size = 1
context_len = 100
max_context_len = 4096
embed_dim = 128
num_heads = 4


example_batch = torch.randn((batch_size, context_len, embed_dim))

mha = MultiHeadAttention(
    d_in=embed_dim,
    d_out=embed_dim,
    context_length=max_context_len,
    num_heads=num_heads
)

mha(example_batch)

del mha  # delete to save memory

1.6 Update the TransformerBlock module

  • At this stage, most of the hard work is already done; we can now update the TransformerBlock to use the code we implemented above
  • This means we
    • replace LayerNorm with RMSNorm
    • remove dropout
    • remove the qkv_bias setting
    • add the dtype setting
[12]
class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = MultiHeadAttention(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            context_length=cfg["context_length"],
            num_heads=cfg["n_heads"],
            dtype=cfg["dtype"]  # NEW
            # dropout=cfg["drop_rate"],
            # qkv_bias=cfg["qkv_bias"]
        )
        self.ff = FeedForward(cfg)

        ################################### NEW ###################################
        # self.norm1 = LayerNorm(cfg["emb_dim"])
        # self.norm2 = LayerNorm(cfg["emb_dim"])
        self.norm1 = RMSNorm(cfg["emb_dim"])
        self.norm2 = RMSNorm(cfg["emb_dim"])
        ############################################################################

        # self.drop_shortcut = nn.Dropout(cfg["drop_rate"])

    def forward(self, x):
        # Shortcut connection for attention block
        shortcut = x
        x = self.norm1(x)
        x = self.att(x)   # Shape [batch_size, num_tokens, emb_size]
        # x = self.drop_shortcut(x)
        x = x + shortcut  # Add the original input back

        # Shortcut connection for feed-forward block
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        # x = self.drop_shortcut(x)
        x = x + shortcut  # Add the original input back

        return x

1.7 Update the model class

  • As you may recall from chapter 5, the TransformerBlock is a repeated block within the main model
  • Our Llama model is almost complete; we just have to update the model code surrounding the TransformerBlock
  • This means we
    • remove absolute positional embeddings since we have RoPE embeddings now
    • replace LayerNorm with RMSNorm
    • remove dropout
    • add the dtype setting
[13]
# class GPTModel(nn.Module):
class Llama2Model(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])
        # self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
        # self.drop_emb = nn.Dropout(cfg["drop_rate"])

        self.trf_blocks = nn.Sequential(
            *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])

        ################################### NEW ###################################
        # self.final_norm = LayerNorm(cfg["emb_dim"])
        self.final_norm = RMSNorm(cfg["emb_dim"])
        ############################################################################
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])

    def forward(self, in_idx):
        batch_size, seq_len = in_idx.shape
        tok_embeds = self.tok_emb(in_idx)
        # pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
        x = tok_embeds  # + pos_embeds  # Shape [batch_size, num_tokens, emb_size]
        # x = self.drop_emb(x)
        x = self.trf_blocks(x)
        x = self.final_norm(x)
        logits = self.out_head(x)
        return logits

2. Initialize model

  • The model code is now complete, and we are ready to initialize it
  • In chapter 5, we used the following config file to specify the 124M-parameter GPT model:
[14]
GPT_CONFIG_124M = {
    "vocab_size": 50257,     # Vocabulary size
    "context_length": 1024,  # Context length
    "emb_dim": 768,          # Embedding dimension
    "n_heads": 12,           # Number of attention heads
    "n_layers": 12,          # Number of layers
    "drop_rate": 0.1,        # Dropout rate
    "qkv_bias": False        # Query-Key-Value bias
}
  • For reference, the 1.5B parameter GPT model config is shown below as well:
[15]
GPT_CONFIG_1558M = {
    "vocab_size": 50257,     # Vocabulary size
    "context_length": 1024,  # Context length
    "emb_dim": 1600,         # Embedding dimension
    "n_heads": 25,           # Number of attention heads
    "n_layers": 48,          # Number of layers
    "drop_rate": 0.1,        # Dropout rate
    "qkv_bias": False        # Query-Key-Value bias
}
  • Similarly, we can define a Llama 2 config file for the 7B model (we ignore the other larger models for simplicity here):
[16]
LLAMA2_CONFIG_7B = {
    "vocab_size": 32000,     # Vocabulary size
    "context_length": 4096,  # Context length
    "emb_dim": 4096,         # Embedding dimension
    "n_heads": 32,           # Number of attention heads
    "n_layers": 32,          # Number of layers
    "hidden_dim": 11008,     # NEW: Size of the intermediate dimension in FeedForward
    "dtype": torch.bfloat16  # NEW: Lower-precision dtype to save memory
}
  • Using these settings, we can now initialize a Llama 2 7B model (note that this requires ~26 GB of memory)
[17]
model = Llama2Model(LLAMA2_CONFIG_7B)
[18]
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params:,}")
Total number of parameters: 6,738,415,616
  • As shown above, the model contains 6.7 billion parameters (commonly rounded and referred to as a 7B model)
  • Additionally, we can calculate the memory requirements for this model using the code below:
[19]
def model_memory_size(model, input_dtype=torch.float32):
    total_params = 0
    total_grads = 0
    for param in model.parameters():
        # Calculate total number of elements per parameter
        param_size = param.numel()
        total_params += param_size
        # Check if gradients are stored for this parameter
        if param.requires_grad:
            total_grads += param_size

    # Calculate buffer size (non-parameters that require memory)
    total_buffers = sum(buf.numel() for buf in model.buffers())

    # Size in bytes = (Number of elements) * (Size of each element in bytes)
    # We assume parameters and gradients are stored in the same type as input dtype
    element_size = torch.tensor(0, dtype=input_dtype).element_size()
    total_memory_bytes = (total_params + total_grads + total_buffers) * element_size

    # Convert bytes to gigabytes
    total_memory_gb = total_memory_bytes / (1024**3)

    return total_memory_gb


print(f"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB")
print(f"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB")
float32 (PyTorch default): 52.33 GB
bfloat16: 26.17 GB
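  • As a rough cross-check of the bfloat16 figure above: the 6,738,415,616 parameters plus an equal number of gradient entries give about 13.5 billion elements, and the registered buffers (a 4096×4096 mask plus 4096×128 cos and sin tensors per block, times 32 blocks) add roughly another 0.57 billion elements; at 2 bytes per element:

$$(13{,}476{,}831{,}232 + 570{,}425{,}344) \times 2 \ \text{bytes} \ / \ 1024^3 \approx 26.17 \ \text{GB}$$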
  • Lastly, we can also transfer the model to an NVIDIA or Apple Silicon GPU if applicable:
[20]
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

model.to(device);

3. Load tokenizer

  • In this section, we are going to load the tokenizer for the model
  • Llama 2 uses Google's SentencePiece tokenizer instead of OpenAI's Tiktoken (but Llama 3 uses Tiktoken)
  • Meta AI shared the original Llama 2 model weights and tokenizer vocabulary on the Hugging Face Hub
  • We will download the tokenizer vocabulary from the Hub and load it into SentencePiece
  • Uncomment and run the following code to install the required libraries:
[21]
# !pip install huggingface_hub sentencepiece
  • Please note that Meta AI requires that you accept the Llama 2 licensing terms before you can download the files; to do this, you have to create a Hugging Face Hub account and visit the meta-llama/Llama-2-7b repository to accept the terms
  • Next, you will need to create an access token; to generate an access token with READ permissions, click on the profile picture in the upper right and click on "Settings"
  • Then, create and copy the access token so you can copy & paste it into the next code cell, which reads it from a local config.json file (see the example below)
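  • One way to create the expected config.json file (the token string below is a placeholder for your own token, and the key name must match what the next cell reads):

import json

with open("config.json", "w") as f:
    json.dump({"HF_ACCESS_TOKEN": "hf_your_token_here"}, f)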
[22]
from huggingface_hub import login
import json

with open("config.json", "r") as config_file:
    config = json.load(config_file)
    access_token = config["HF_ACCESS_TOKEN"]

login(token=access_token)
The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to /root/.cache/huggingface/token
Login successful
  • After login via the access token, which is necessary to verify that we accepted the Llama 2 licensing terms, we can now download the tokenizer vocabulary:
[23]
from huggingface_hub import hf_hub_download

tokenizer_file = hf_hub_download(
    repo_id="meta-llama/Llama-2-7b",
    filename="tokenizer.model",
    local_dir="Llama-2-7B"
)
/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: 
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
  warnings.warn(
  • To provide a more familiar interface for the tokenizer, we define a small LlamaTokenizer wrapper class:
[24]
import sentencepiece as spm


class LlamaTokenizer:
    def __init__(self, filepath):
        sp = spm.SentencePieceProcessor()
        sp.load(filepath)
        self.tokenizer = sp

    def encode(self, text):
        return self.tokenizer.encode_as_ids(text)

    def decode(self, ids):
        return self.tokenizer.decode_pieces(ids)


tokenizer = LlamaTokenizer(tokenizer_file)
  • We can now use the generate function to have the Llama 2 model generate new text:
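  • The generate, text_to_token_ids, and token_ids_to_text utilities are imported from the previous_chapters.py file of the book's code repository; in case that file is not at hand, minimal stand-in versions of the two helper functions (a sketch assuming they behave like the chapter 5 utilities) could look like this:

def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text)
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)  # add a batch dimension
    return encoded_tensor

def token_ids_to_text(token_ids, tokenizer):
    flat = token_ids.squeeze(0)  # remove the batch dimension
    return tokenizer.decode(flat.tolist())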
[25]
from previous_chapters import generate, text_to_token_ids, token_ids_to_text


torch.manual_seed(123)

token_ids = generate(
    model=model,
    idx=text_to_token_ids("Every effort moves", tokenizer).to(device),
    max_new_tokens=30,
    context_size=LLAMA2_CONFIG_7B["context_length"],
    top_k=1,
    temperature=0.
)

print("Output text:\n", token_ids_to_text(token_ids, tokenizer))
Output text:
 Every effort movesαllRadius deletingpretcc否']; future eer napulate lackус während inter DES издаSchéon로жа Bass differencespadxsnu ;; ctx始
  • Of course, as we can see above, the text is nonsensical since we haven't trained the Llama 2 model yet
  • In the next section, instead of training it ourselves, which would cost tens to hundreds of thousands of dollars, we load the pretrained weights from Meta AI

4. Load pretrained weights

  • We are loading the "meta-llama/Llama-2-7b" base model below, which is a simple text completion model before finetuning
  • Alternatively, you can load the instruction-finetuned and aligned "meta-llama/Llama-2-7b-chat" model by modifying the string in the next code cell accordingly
[26]
weights_file = hf_hub_download(
    repo_id="meta-llama/Llama-2-7b",
    filename="consolidated.00.pth",
    local_dir="Llama-2-7b"
)
[27]
weights = torch.load(weights_file, weights_only=True)
  • The weights dictionary contains the following tensors (only the first 15 are shown for simplicity):
[28]
list(weights.keys())[:15]
['tok_embeddings.weight',
 'norm.weight',
 'output.weight',
 'layers.0.attention.wq.weight',
 'layers.0.attention.wk.weight',
 'layers.0.attention.wv.weight',
 'layers.0.attention.wo.weight',
 'layers.0.feed_forward.w1.weight',
 'layers.0.feed_forward.w2.weight',
 'layers.0.feed_forward.w3.weight',
 'layers.0.attention_norm.weight',
 'layers.0.ffn_norm.weight',
 'layers.1.attention.wq.weight',
 'layers.1.attention.wk.weight',
 'layers.1.attention.wv.weight']
  • The following function, modeled after the load_weights_into_gpt function in chapter 5, loads the pretrained weights into our Llama 2 model:
[29]
def assign(left, right):
    if left.shape != right.shape:
        raise ValueError(f"Shape mismatch. Left: {left.shape}, Right: {right.shape}")

    if isinstance(right, torch.Tensor):
        return torch.nn.Parameter(right.clone().detach())
    else:
        return torch.nn.Parameter(torch.tensor(right))


def load_weights_into_llama(model, param_config, params):
    model.tok_emb.weight = assign(model.tok_emb.weight, params["tok_embeddings.weight"])

    for l in range(param_config["n_layers"]):

        # Load attention weights
        model.trf_blocks[l].att.W_query.weight = assign(
            model.trf_blocks[l].att.W_query.weight,
            params[f"layers.{l}.attention.wq.weight"]
        )
        model.trf_blocks[l].att.W_key.weight = assign(
            model.trf_blocks[l].att.W_key.weight,
            params[f"layers.{l}.attention.wk.weight"]
        )
        model.trf_blocks[l].att.W_value.weight = assign(
            model.trf_blocks[l].att.W_value.weight,
            params[f"layers.{l}.attention.wv.weight"]
        )
        model.trf_blocks[l].att.out_proj.weight = assign(
            model.trf_blocks[l].att.out_proj.weight,
            params[f"layers.{l}.attention.wo.weight"]
        )
        model.trf_blocks[l].norm1.weight = assign(
            model.trf_blocks[l].norm1.weight,
            params[f"layers.{l}.attention_norm.weight"]
        )

        # Load FeedForward weights
        model.trf_blocks[l].ff.fc1.weight = assign(
            model.trf_blocks[l].ff.fc1.weight,
            params[f"layers.{l}.feed_forward.w1.weight"]
        )
        # For some reason w2 and w3 are provided in the wrong order in the weights file
        model.trf_blocks[l].ff.fc2.weight = assign(
            model.trf_blocks[l].ff.fc2.weight,
            params[f"layers.{l}.feed_forward.w3.weight"]
        )
        model.trf_blocks[l].ff.fc3.weight = assign(
            model.trf_blocks[l].ff.fc3.weight,
            params[f"layers.{l}.feed_forward.w2.weight"]
        )
        model.trf_blocks[l].norm2.weight = assign(
            model.trf_blocks[l].norm2.weight,
            params[f"layers.{l}.ffn_norm.weight"]
        )

    # Load output layer weights
    model.final_norm.weight = assign(model.final_norm.weight, params["norm.weight"])
    model.out_head.weight = assign(model.out_head.weight, params["output.weight"])


load_weights_into_llama(model, LLAMA2_CONFIG_7B, weights)
model.to(device);
  • Next, we are ready to use the model for text generation
[30]
torch.manual_seed(123)

token_ids = generate(
    model=model,
    idx=text_to_token_ids("Every effort", tokenizer).to(device),
    max_new_tokens=25,
    context_size=LLAMA2_CONFIG_7B["context_length"],
    top_k=1,
    temperature=0.
)

print("Output text:\n", token_ids_to_text(token_ids, tokenizer))
Output text:
 Every effort has been made to ensure that the information contained in this website is accurate and up to date and correct at the time of publication

5. Using the instruction-finetuned model

  • As mentioned earlier, above we used the pretrained base model; if you want to use a model capable of following instructions, use the "meta-llama/Llama-2-7b-chat" model instead, as shown below
[34]
del model  # to free up memory

weights_file = hf_hub_download(
    repo_id="meta-llama/Llama-2-7b-chat",
    filename="consolidated.00.pth",
    local_dir="Llama-2-7b-chat"
)

weights = torch.load(weights_file, weights_only=True)  # load the chat-model weights
model = Llama2Model(LLAMA2_CONFIG_7B)
load_weights_into_llama(model, LLAMA2_CONFIG_7B, weights)
model.to(device);

torch.manual_seed(123)

token_ids = generate(
    model=model,
    idx=text_to_token_ids("What do llamas eat?", tokenizer).to(device),
    max_new_tokens=25,
    context_size=LLAMA2_CONFIG_7B["context_length"],
    top_k=1,
    temperature=0.
)

print("Output text:\n", token_ids_to_text(token_ids, tokenizer))
Output text:
 What do llamas eat?
Llamas and alpacas are herbivores, which means they eat grasses, leaves, grass

What's next?

  • This notebook converted the original GPT-2 architecture into a Llama 2 model
  • If you are interested in how to convert Llama 2 into Llama 3, Llama 3.1, and Llama 3.2, check out the converting-llama2-to-llama3.ipynb notebook