Merged
5 changes: 1 addition & 4 deletions src/zeroband/models/llama/__init__.py
@@ -6,7 +6,6 @@
#
# Llama 2 is licensed under the LLAMA 2 Community License,
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.
import torch

from zeroband.config import Config
from zeroband.models.llama.model import ModelArgs, Transformer
@@ -95,8 +94,6 @@
def make_model(
config: Config,
vocab_size: int,
dtype: torch.dtype,
device: torch.device
) -> tuple[Transformer, ModelArgs]:
"""
Constructs a model instance according to the supplied configuration and target vocab size
@@ -114,4 +111,4 @@ def make_model(
model_config.max_seq_len = config.data.seq_length
model_config.attn_fn = config.hardware.attn_fn

return Transformer(model_config, dtype=dtype, device=device), model_config
return Transformer(model_config), model_config
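
For readers skimming the diff: a minimal sketch of how a `make_model` call site looks after this change. The wrapper function and the bfloat16 cast below are illustrative assumptions, not code from this PR.

```python
import torch

from zeroband.config import Config
from zeroband.models.llama import Transformer, make_model


def build_training_model(config: Config, vocab_size: int, device: torch.device) -> Transformer:
    # make_model no longer accepts dtype/device; the model is constructed with
    # PyTorch's default dtype on the default device.
    model, _model_config = make_model(config, vocab_size=vocab_size)
    # Precision and placement become the caller's responsibility
    # (assumed follow-up step, not part of this diff).
    return model.to(dtype=torch.bfloat16, device=device)
```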
65 changes: 39 additions & 26 deletions src/zeroband/models/llama/model.py
@@ -11,7 +11,7 @@
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.


import contextlib

Check failure (GitHub Actions / ruff): src/zeroband/models/llama/model.py:14:8: F401 `contextlib` imported but unused
from dataclasses import dataclass
from typing import Optional, Tuple

@@ -21,7 +21,7 @@
from zeroband.config import AttnFnType

from torch.nn.attention.flex_attention import create_block_mask, flex_attention, BlockMask, _DEFAULT_SPARSE_BLOCK_SIZE
from torch.nn.attention import SDPBackend, sdpa_kernel

Check failure (GitHub Actions / ruff): src/zeroband/models/llama/model.py:24:32: F401 `torch.nn.attention.SDPBackend` imported but unused
Check failure (GitHub Actions / ruff): src/zeroband/models/llama/model.py:24:44: F401 `torch.nn.attention.sdpa_kernel` imported but unused

from zeroband.utils.mfu_tracker import FlopCounter

@@ -170,7 +170,7 @@
return torch.stack([torch.repeat_interleave(torch.arange(len(seq), device=seq.device), seq) for seq in seqlens])


def create_block_mask_from_seqlens(seqlens: list[torch.Tensor], dtype: torch.dtype, device: torch.device) -> BlockMask:
def create_block_mask_from_seqlens(seqlens: list[torch.Tensor]) -> BlockMask:
"""Creates a block mask from a list of sequence lengths.

Example:
@@ -183,7 +183,7 @@
[0 0 1 1 0] # Second token of doc 1 can see both tokens of doc 1
[0 0 0 0 1]] # Token of doc 2 can only see itself
"""
docs = seqlens_to_docs_tensor(seqlens).to(dtype=dtype, device=device)
docs = seqlens_to_docs_tensor(seqlens).to("cuda")
batch_size, max_seq_len = docs.shape

def document_causal_mask(b, h, q_idx, kv_idx):
@@ -197,7 +197,7 @@
None,
max_seq_len,
max_seq_len,
device=device.type,
device="cuda",
_compile=True,
BLOCK_SIZE=max_seq_len if max_seq_len < _DEFAULT_SPARSE_BLOCK_SIZE else _DEFAULT_SPARSE_BLOCK_SIZE,
)
@@ -222,17 +222,17 @@

"""

def __init__(self, model_args: ModelArgs, dtype: torch.dtype, device: torch.device):
def __init__(self, model_args: ModelArgs):
super().__init__()
self.n_heads = model_args.n_heads
self.n_kv_heads = model_args.n_heads if model_args.n_kv_heads is None else model_args.n_kv_heads
self.n_rep = self.n_heads // self.n_kv_heads
self.head_dim = model_args.dim // model_args.n_heads

self.wq = nn.Linear(model_args.dim, model_args.n_heads * self.head_dim, bias=False, dtype=dtype, device=device)
self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False, dtype=dtype, device=device)
self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False, dtype=dtype, device=device)
self.wo = nn.Linear(model_args.n_heads * self.head_dim, model_args.dim, bias=False, dtype=dtype, device=device)
self.wq = nn.Linear(model_args.dim, model_args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(model_args.n_heads * self.head_dim, model_args.dim, bias=False)

self.attn_fn = model_args.attn_fn

@@ -342,8 +342,6 @@
hidden_dim: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float],
dtype: torch.dtype,
device: torch.device
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
@@ -352,9 +350,9 @@
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

self.w1 = nn.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)
self.w2 = nn.Linear(hidden_dim, dim, bias=False, dtype=dtype, device=device)
self.w3 = nn.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)

def forward(self, x: torch.Tensor, flop_counter: FlopCounter = FlopCounter()):
flop_counter.track_linear(self.w1, x)
@@ -400,23 +398,22 @@

"""

def __init__(self, layer_id: int, model_args: ModelArgs, dtype: torch.dtype, device: torch.device):
def __init__(self, layer_id: int, model_args: ModelArgs):
super().__init__()
self.n_heads = model_args.n_heads
self.dim = model_args.dim
self.attention = Attention(model_args, dtype=dtype, device=device)
self.attention = Attention(model_args)
self.feed_forward = FeedForward(
dim=model_args.dim,
hidden_dim=4 * model_args.dim,
multiple_of=model_args.multiple_of,
ffn_dim_multiplier=model_args.ffn_dim_multiplier,
dtype=dtype, device=device
)
self.layer_id = layer_id
self.num_layers = model_args.n_layers

self.attention_norm = build_norm(model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps, dtype=dtype, device=device)
self.ffn_norm = build_norm(model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps, dtype=dtype, device=device)
self.attention_norm = build_norm(model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps)
self.ffn_norm = build_norm(model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps)

if model_args.depth_init:
self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5
@@ -484,7 +481,7 @@

"""

def __init__(self, model_args: ModelArgs, dtype: torch.dtype, device: torch.device):
def __init__(self, model_args: ModelArgs):
super().__init__()
self.model_args = model_args
self.vocab_size = model_args.vocab_size
@@ -499,15 +496,15 @@
# a seed checkpoint rather than calling init_weights, we need freqs_cis to be
# initialized by the checkpoint, or we need to add a separate initializer for
# just the non-persistent buffers that is called after loading checkpoints.
self.register_buffer("freqs_cis", self._precompute_freqs_cis(dtype=dtype, device=device), persistent=True)
self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)

self.layers = torch.nn.ModuleDict()
for layer_id in range(model_args.n_layers):
self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args, dtype=dtype, device=device)
self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)

self.norm = build_norm(model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps, dtype=dtype, device=device)
self.norm = build_norm(model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps)

self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False, dtype=dtype, device=device)
self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
self.init_weights()

def init_weights(self):
@@ -522,6 +519,8 @@
``init_weights``. We only call it in the constructor of this
``Transformer`` root module to avoid reinitializing tensors.
"""
with torch.device(self.freqs_cis.device):
self.freqs_cis = self._precompute_freqs_cis()
if self.tok_embeddings is not None:
nn.init.normal_(self.tok_embeddings.weight)
for layer in self.layers.values():
@@ -540,14 +539,14 @@
b=cutoff_factor * final_out_std,
)

def _precompute_freqs_cis(self, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
def _precompute_freqs_cis(self) -> torch.Tensor:
return precompute_freqs_cis(
self.model_args.dim // self.model_args.n_heads,
# Need to compute until at least the max token limit for generation
# (use 2x max sequence length to be safe)
self.model_args.max_seq_len * 2,
self.model_args.rope_theta
).to(dtype=dtype, device=device)
self.model_args.rope_theta,
)

def forward(self, tokens: torch.Tensor, block_mask: BlockMask | None = None, flop_counter: FlopCounter = FlopCounter()):
"""
@@ -576,6 +575,20 @@

return output

@classmethod
def from_model_args(cls, model_args: ModelArgs) -> "Transformer":
"""
Initialize a Transformer model from a ModelArgs object.

Args:
model_args (ModelArgs): Model configuration arguments.

Returns:
Transformer: Transformer model.

"""
return cls(model_args)

def count_parameters(self, exclude_embedding: bool = False) -> int:
"""
Counts the number of parameters.
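
The model.py changes all follow one pattern: parameters are created with PyTorch's defaults instead of an explicit dtype/device, and `freqs_cis` is rebuilt inside a `torch.device` context in `init_weights`. A runnable, torch-only sketch of that pattern; the shapes and the bfloat16 choice are examples, not values from this PR.

```python
import torch
import torch.nn as nn

# Layers are built with the default dtype (float32) on the default device (CPU)...
proj = nn.Linear(1024, 1024, bias=False)
# ...and any cast or move happens afterwards, outside the constructor.
proj = proj.to(dtype=torch.bfloat16)

# torch.device also works as a context manager: factory calls inside it allocate on
# that device, which is how the new init_weights recreates freqs_cis on the buffer's device.
with torch.device("cpu"):  # "cuda" in the real training run
    freqs = torch.arange(0, 64, 2, dtype=torch.float32)

print(proj.weight.dtype, freqs.device)  # torch.bfloat16 cpu
```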
28 changes: 11 additions & 17 deletions src/zeroband/models/norms.py
@@ -21,17 +21,15 @@
from torch.distributed.tensor.experimental import local_map


def build_norm(norm_type: str, dim: int, eps: float, dtype: torch.dtype, device: torch.device):
def build_norm(norm_type: str, dim: int, eps: float = 1e-6):
"""
Builds the specified normalization layer based on the norm_type.

Args:
norm_type (str): The type of normalization layer to build.
Supported types: layernorm, np_layernorm, rmsnorm, fused_rmsnorm
dim (int): The dimension of the normalization layer.
eps (float, optional): The epsilon value for numerical stability.
dtype: The data type to use for the parameter tensor
device: The device to place the layer on
eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6.

Returns:
The built normalization layer.
@@ -42,13 +40,13 @@ def build_norm(norm_type: str, dim: int, eps: float, dtype: torch.dtype, device:
norm_type = norm_type.lower() # Normalize to lowercase

if norm_type == "layernorm":
return nn.LayerNorm(dim, eps=eps, bias=False, dtype=dtype, device=device)
return nn.LayerNorm(dim, eps=eps, bias=False)
elif norm_type == "np_layernorm":
return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False, dtype=dtype, device=device)
return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False)
elif norm_type == "rmsnorm":
return RMSNorm(dim, eps=eps, dtype=dtype, device=device)
return RMSNorm(dim, eps=eps)
elif norm_type == "fused_rmsnorm":
return FusedRMSNorm(dim, eps=eps, dtype=dtype, device=device)
return FusedRMSNorm(dim, eps=eps)
else:
raise NotImplementedError(f"Unknown norm_type: '{norm_type}'")

@@ -59,13 +57,11 @@ class FusedRMSNorm(nn.Module):
def __init__(
self,
dim: int,
eps: float,
dtype: torch.dtype,
device: torch.device
eps: float = 1e-6,
):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim, dtype=dtype, device=device))
self.weight = nn.Parameter(torch.ones(dim))
self.fused_rms_norm_fn = fused_rms_norm_fn

def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -86,20 +82,18 @@ class RMSNorm(nn.Module):

Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability.
dtype: The data type to use
device: The torch device to place the layer on
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

Attributes:
eps (float): A small value added to the denominator for numerical stability.
weight (nn.Parameter): Learnable scaling parameter.

"""

def __init__(self, dim: int, eps: float, dtype: torch.dtype, device: torch.device):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim, dtype=dtype, device=device))
self.weight = nn.Parameter(torch.ones(dim))

def _norm(self, x: torch.Tensor):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
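
A short usage sketch for the new `build_norm` signature; the dimension, input shape, and cast below are illustrative assumptions.

```python
import torch

from zeroband.models.norms import RMSNorm, build_norm

norm = build_norm("rmsnorm", dim=4096)   # eps now defaults to 1e-6
assert isinstance(norm, RMSNorm)

x = torch.randn(2, 8, 4096)
y = norm(x)                              # shape preserved: (2, 8, 4096)

# dtype/device are no longer constructor arguments; cast or move the layer afterwards.
norm_bf16 = norm.to(torch.bfloat16)
```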
21 changes: 7 additions & 14 deletions src/zeroband/train.py
@@ -4,7 +4,7 @@
import zlib
from dataclasses import asdict
from logging import Logger
from typing import TYPE_CHECKING, Optional, Iterator, List, Dict, Tuple

Check failure (GitHub Actions / ruff): src/zeroband/train.py:7:67: F401 `typing.Tuple` imported but unused

import torch
import torch.distributed as dist
@@ -124,7 +124,6 @@
num_param_scalars = model.count_parameters()

for _inner_step in range(num_inner_steps):
#torch.cuda.memory._record_memory_history(max_entries=100000)
train_profiler.start_session("inner_step")

flop_counter = FlopCounter()
@@ -234,9 +233,6 @@
memory_profiler.step()
train_profiler.end_session()

#torch.cuda.memory._dump_snapshot('snapshot.pickle')
#torch.cuda.memory._record_memory_history(enabled=None)


def compute_crc32(tensor: torch.Tensor) -> int:
tensor_cpu = tensor.detach().cpu()
@@ -522,7 +518,7 @@
return shared_state


def train(logger: Logger, config: Config, mpi_config: Optional[MPIConfig], dtype: torch.dtype, device: torch.device):
def train(logger: Logger, config: Config, mpi_config: Optional[MPIConfig], device: torch.device):
grad_accum_steps = calc_gradient_accumulation_steps(
config.train.batch_size, config.hardware.micro_batch_size, mpi_config
)
@@ -546,8 +542,6 @@
model, model_config = make_model(
config,
vocab_size=tokenizer_info.vocab_size,
dtype=dtype,
device=device,
)
num_param_scalars = model.count_parameters()
logger.info(f"Number of parameters: {num_param_scalars}")
@@ -711,16 +705,16 @@
continue

local_world_size = communicator.get_attribute(Attribute.LOCAL_WORLD_SIZE)
#if local_world_size < 2:
# logger.info("Waiting for more workers to join...")
# time.sleep(1)
# continue
if local_world_size < 2:
logger.info("Waiting for more workers to join...")
time.sleep(1)
continue

if topology_updated:
logger.info("Optimizing Topology...")
while True:
try:
# communicator.optimize_topology() # may raise an error if it fails
communicator.optimize_topology() # may raise an error if it fails
break
except PCCLError as e:
print(f"[Peer] OptimizeTopology failed => {e}. Retrying...")
@@ -835,8 +829,7 @@
device = torch.device(f'cuda:{torch.cuda.current_device()}')
logger.info(f"Using device: {torch.cuda.get_device_name(device)}")

dtype = torch.bfloat16 # TODO: MAKE CONFIGURABLE
train(logger, config, mpi_config, dtype, device)
train(logger, config, mpi_config, device)


if __name__ == "__main__":
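
The train.py hunks re-enable the peer-wait check and the `optimize_topology` retry loop. Below is a minimal, self-contained sketch of that retry pattern; the communicator object and the retryable error type are passed in as parameters because their concrete types and import paths (the PCCL communicator and PCCLError) are not shown in this diff.

```python
import time
from logging import Logger


def optimize_topology_with_retry(communicator, logger: Logger,
                                 retryable_error: type[Exception] = Exception,
                                 delay_s: float = 1.0) -> None:
    # Keep calling optimize_topology until it succeeds, matching the loop in train.py.
    while True:
        try:
            communicator.optimize_topology()  # may raise if the peer set changed mid-call
            return
        except retryable_error as exc:
            logger.warning("OptimizeTopology failed => %s. Retrying...", exc)
            time.sleep(delay_s)
```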