Add HF_MODEL to load models directly from huggingface #17801

Merged 5 commits on Feb 12, 2025
16 changes: 8 additions & 8 deletions models/demos/llama3/PERF.md
@@ -11,16 +11,16 @@ This configuration uses bfp4 MLP FF1+FF3 for all models.
| Model | Device | Top-1 (%) | Top-5 (%) | Speed (t/s/u) |
|----------------|--------|-----------|-----------|---------------|
| Llama3.2-1B | N150 | 89 | 98 | 86.9 |
-| Llama3.2-1B | N300 | 91 | 98 | 104.3 |
-| Llama3.2-1B | T3K | 91 | 98 | 118.5 |
+| Llama3.2-1B | N300 | 90 | 98 | 104.3 |
+| Llama3.2-1B | T3K | 87 | 98 | 118.5 |
| Llama3.2-1B | TG | | | 72.3 |
-| Llama3.2-3B | N150 | 92 | 96 | 53.3 |
+| Llama3.2-3B | N150 | 91 | 96 | 53.3 |
| Llama3.2-3B | N300 | 91 | 96 | 66.1 |
| Llama3.2-3B | T3K | 91 | 96 | 66.9 |
| Llama3.2-3B | TG | | | 48.5 |
| Llama3.1-8B | N150 | 87 | 99 | 27.9 |
| Llama3.1-8B | N300 | 88 | 99 | 43.7 |
-| Llama3.1-8B | T3K | 88 | 100 | 64.2 |
+| Llama3.1-8B | T3K | 88 | 99 | 64.2 |
| Llama3.1-8B | TG | | | 41.0 |
| Llama3.2-11B | N300 | 89 | 99 | 43.5 |
| Llama3.2-11B | T3K | 88 | 99 | 63.4 |
@@ -37,12 +37,12 @@ This configuration uses bfp4 MLP FF1+FF3 only for the Llama-3.1-70B model and th
| Model | Device | Top-1 (%) | Top-5 (%) | Speed (t/s/u) |
|----------------|--------|-----------|-----------|---------------|
| Llama3.2-1B | N150 | 88 | 98 | 86.8 |
-| Llama3.2-1B | N300 | 90 | 98 | 98.1 |
-| Llama3.2-1B | T3K | 90 | 98 | 97.5 |
+| Llama3.2-1B | N300 | 88 | 98 | 98.1 |
+| Llama3.2-1B | T3K | 89 | 99 | 97.5 |
| Llama3.2-1B | TG | 87 | 98 | 51.3 |
-| Llama3.2-3B | N150 | 93 | 99 | 44.2 |
+| Llama3.2-3B | N150 | 92 | 99 | 44.2 |
| Llama3.2-3B | N300 | 92 | 98 | 54.2 |
-| Llama3.2-3B | T3K | 93 | 98 | 55.6 |
+| Llama3.2-3B | T3K | 91 | 100 | 55.6 |
| Llama3.2-3B | TG | 91 | 98 | 33.6 |
| Llama3.1-8B | N150 | 93 | 100 | 23.6 |
| Llama3.1-8B | N300 | 93 | 100 | 34.5 |
29 changes: 15 additions & 14 deletions models/demos/llama3/tests/test_llama_accuracy.py
@@ -157,7 +157,7 @@ def test_tt_model_acc(
text = f.read()

# Encode text to tokens
-encoded_tokens = tokenizer.encode(text, bos=True, eos=False)
+encoded_tokens = model_args.encode_prompt(text, system_prompt_text=None, instruct=False)
total_length = prefill_len + decode_len + 1
reference_tokens = torch.tensor(encoded_tokens[:total_length]).unsqueeze(0)
top5_tokens = None # Will be computed during inference
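The test now delegates tokenization to `model_args.encode_prompt`, so the same accuracy test can drive either the Meta tokenizer or a Hugging Face one. A minimal sketch of what such a wrapper can look like, assuming an HF-style tokenizer (illustrative only, not the repo's actual implementation):

```python
# Illustrative sketch only, assuming a Hugging Face tokenizer; the real
# model_args.encode_prompt in this repo may differ.
def encode_prompt_sketch(tokenizer, text, system_prompt_text=None, instruct=False):
    """Encode text to token ids, optionally wrapping it in a chat template."""
    if instruct:
        messages = []
        if system_prompt_text:
            messages.append({"role": "system", "content": system_prompt_text})
        messages.append({"role": "user", "content": text})
        # apply_chat_template returns token ids when tokenize=True
        return tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)
    # Plain completion-style encoding: BOS only, no EOS (mirrors the old
    # tokenizer.encode(text, bos=True, eos=False) call).
    return tokenizer.encode(text, add_special_tokens=True)
```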
@@ -439,17 +439,18 @@ def test_tt_model_acc(
true_word = sanitize(tokenizer.decode([true_token]))
logger.info(f"{error['position']}: {context}[{incorrect}] != [{expected}], true: [{true_word}]")

-# Get accuracy thresholds from PERF.md
-min_top1_acc, min_top5_acc = get_accuracy_thresholds(
-    model_args.base_model_name,
-    model_args.device_name,
-    optimizations,
-)
+if use_reference_file:
+    # Get accuracy thresholds from PERF.md
+    min_top1_acc, min_top5_acc = get_accuracy_thresholds(
+        model_args.base_model_name,
+        model_args.device_name,
+        optimizations,
+    )

-logger.info(f"Top-1: {total_top1_acc:.0f}% | Top-5: {total_top5_acc:.0f}%")
-assert (
-    total_top1_acc >= min_top1_acc
-), f"Top-1 accuracy {total_top1_acc:.1f}% is too low (expected >={min_top1_acc}%)"
-assert (
-    total_top5_acc >= min_top5_acc
-), f"Top-5 accuracy {total_top5_acc:.1f}% is too low (expected >={min_top5_acc}%)"
+    logger.info(f"Top-1: {total_top1_acc:.0f}% | Top-5: {total_top5_acc:.0f}%")
+    assert (
+        total_top1_acc >= min_top1_acc
+    ), f"Top-1 accuracy {total_top1_acc:.1f}% is too low (expected >={min_top1_acc}%)"
+    assert (
+        total_top5_acc >= min_top5_acc
+    ), f"Top-5 accuracy {total_top5_acc:.1f}% is too low (expected >={min_top5_acc}%)"
6 changes: 3 additions & 3 deletions models/demos/llama3/tt/llama_attention.py
@@ -8,8 +8,6 @@
import ttnn
from models.common.lightweightmodule import LightweightModule
from models.demos.llama3.tt.llama_ccl import tt_all_reduce, tt_all_gather
-from models.demos.llama3.tt.llama_common import first_five
-from models.demos.llama3.tt.load_checkpoints import permute


class TtLlamaAttention(LightweightModule):
@@ -138,7 +136,9 @@ def __init__(
)
# as_tensor returns (32, dim) which is incorrect, this reshape updates the padded size to the correct size
self.wqkv_bias_prefill = ttnn.reshape(
-    self.wqkv_bias_prefill, ttnn.Shape([1, 1, 1, self.wqkv_bias_prefill.shape[-1]])
+    self.wqkv_bias_prefill,
+    (1, 1, 1, self.wqkv_bias_prefill.shape[-1]),
+    (1, 1, self.wqkv_bias_prefill.shape[-2], self.wqkv_bias_prefill.shape[-1]),
)

# Broadcasting does not seem to be supported inside execute_trace so expand to the whole batch size
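On the `ttnn.reshape` change above: tile-layout tensors are stored in 32x32 tiles, so the bias returned by `as_tensor` carries a row dimension padded out to 32, and the updated call makes both the logical and the padded shape explicit. A plain-Python illustration of the two shapes (the hidden size below is an example value, not taken from the diff):

```python
TILE = 32  # tile height/width used by tile-layout tensors

def round_up_to_tile(n, tile=TILE):
    return ((n + tile - 1) // tile) * tile

dim = 3072  # example hidden size; the real value comes from the model config
logical_shape = (1, 1, 1, dim)                    # a single bias row
padded_shape = (1, 1, round_up_to_tile(1), dim)   # row dimension padded to a full tile
assert padded_shape == (1, 1, 32, dim)
```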
11 changes: 7 additions & 4 deletions models/demos/llama3/tt/load_checkpoints.py
@@ -37,13 +37,16 @@ def load_hf_state_dict(ckpt_dir):
raise FileNotFoundError(f"Neither model.safetensors.index.json nor model.safetensors found in {ckpt_dir}")
loaded_weights = safetensors_load_file(safetensor_path)

if not "lm_head.weight" in loaded_weights:
# Assume tied to the embeddings if not present
loaded_weights["lm_head.weight"] = loaded_weights["model.embed_tokens.weight"]

return loaded_weights


+def standardize_hf_keys(state_dict):
+    if not "lm_head.weight" in state_dict:
+        # Assume tied to the embeddings if not present
+        state_dict["lm_head.weight"] = state_dict["model.embed_tokens.weight"]
+    return state_dict


def convert_hf_to_meta(state_dict, head_dim):
state_dict = convert_hf_qkv_to_meta_format(state_dict, head_dim)
state_dict = map_hf_to_meta_keys(state_dict)
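Moving the tied-embedding fallback out of `load_hf_state_dict` and into `standardize_hf_keys` lets the same fix-up run on state dicts that come from `transformers` directly (see `model_config.py` below). A quick usage sketch on a toy state dict:

```python
import torch

# Toy state dict for a model with tied input/output embeddings: there is no
# separate lm_head.weight entry, only the input embedding table.
toy_state_dict = {"model.embed_tokens.weight": torch.zeros(8, 4)}

toy_state_dict = standardize_hf_keys(toy_state_dict)
assert toy_state_dict["lm_head.weight"] is toy_state_dict["model.embed_tokens.weight"]
```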
42 changes: 34 additions & 8 deletions models/demos/llama3/tt/model_config.py
@@ -31,6 +31,7 @@
convert_hf_to_meta,
convert_meta_to_hf,
reverse_permute,
+standardize_hf_keys,
)


@@ -114,8 +115,10 @@ def __init__(
self.max_batch_size = max_batch_size
self.tile_size = 32
self.is_70b = False
+self.from_hf_url = False # updated below if true

LLAMA_DIR = os.getenv("LLAMA_DIR")
+HF_MODEL = os.getenv("HF_MODEL")
if LLAMA_DIR:
if any([os.getenv("LLAMA_CKPT_DIR"), os.getenv("LLAMA_TOKENIZER_PATH"), os.getenv("LLAMA_CACHE_PATH")]):
logger.warning(
@@ -125,10 +128,18 @@
self.DEFAULT_TOKENIZER_PATH = LLAMA_DIR
self.DEFAULT_CACHE_PATH = os.path.join(LLAMA_DIR, self.device_name)
self.model_name = os.path.basename(LLAMA_DIR) # May be overridden by config
+elif HF_MODEL:
+    self.DEFAULT_CKPT_DIR = HF_MODEL
+    self.DEFAULT_TOKENIZER_PATH = HF_MODEL
+    self.DEFAULT_CACHE_PATH = os.getenv("LLAMA_CACHE_PATH")
+    if not self.DEFAULT_CACHE_PATH:
+        self.DEFAULT_CACHE_PATH = os.path.join("model_cache", HF_MODEL, self.device_name)
+    self.model_name = HF_MODEL # May be overridden by config
+    self.from_hf_url = True
else:
assert "Please set $LLAMA_DIR to a valid checkpoint directory"

-if not dummy_weights:
+if not dummy_weights and not HF_MODEL:
# Assert if all folders and files exist
assert os.path.exists(
self.DEFAULT_CKPT_DIR
@@ -157,7 +168,10 @@ def __init__(
self.instruct = True

# Load model params
-if not dummy_weights:
+if HF_MODEL:
+    self.checkpoint_type = CheckpointType.HuggingFace
+    self._set_hf_params(self.DEFAULT_CKPT_DIR)
+elif not dummy_weights:
self.checkpoint_type = self.detect_checkpoint_type()
self._set_model_params(self.DEFAULT_CKPT_DIR)
else: # With Dummy weights, set the params from the local copy inside the model folder. This is required for CI pipeline that doesn't mount the external folders.
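With these two branches in place, pointing the demo at a Hugging Face checkpoint only requires setting environment variables before the model config object is constructed. A usage sketch (the repo id and cache path below are examples, not requirements):

```python
import os

# Any causal-LM repo id supported by this Llama codepath should work here.
os.environ["HF_MODEL"] = "meta-llama/Llama-3.2-1B-Instruct"
# Optional: where to keep weight caches. When unset, the config above falls
# back to model_cache/<HF_MODEL>/<device_name>.
os.environ["LLAMA_CACHE_PATH"] = "/tmp/llama_cache"

# ... then construct the model args object from model_config.py as usual;
# it reads HF_MODEL at __init__ time and switches to the HuggingFace path.
```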
@@ -1107,10 +1121,15 @@ def _set_llama_params(self, checkpoint_dir):
self.orig_context_len = 8192

def _set_hf_params(self, checkpoint_dir):
-config_file = os.path.join(checkpoint_dir, "config.json")
-assert os.path.exists(config_file), f"config.json file not found at {config_file}"
-with open(config_file, "r") as f:
-    config = json.load(f)
+if self.from_hf_url:
+    from transformers import AutoConfig
+
+    config = AutoConfig.from_pretrained(self.model_name).to_dict()
+else:
+    config_file = os.path.join(checkpoint_dir, "config.json")
+    assert os.path.exists(config_file), f"config.json file not found at {config_file}"
+    with open(config_file, "r") as f:
+        config = json.load(f)
self._set_params_from_dict(config)

def __repr__(self):
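When loading from a repo id there is no local config.json, so the parameters now come from `transformers.AutoConfig`. Its dict carries the same keys `_set_params_from_dict` already consumes from config.json, for example:

```python
from transformers import AutoConfig

# Example repo id; resolving it needs network access or a populated HF cache.
config = AutoConfig.from_pretrained("meta-llama/Llama-3.2-1B").to_dict()
print(config["hidden_size"], config["num_hidden_layers"], config["num_key_value_heads"])
```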
@@ -1172,7 +1191,14 @@ def load_state_dict(self):
state_dict = load_meta_state_dict(self.DEFAULT_CKPT_DIR, self.n_layers)
else:
assert self.checkpoint_type == CheckpointType.HuggingFace
-state_dict = load_hf_state_dict(self.DEFAULT_CKPT_DIR)
+if self.from_hf_url:
+    from transformers import AutoModelForCausalLM
+
+    model = AutoModelForCausalLM.from_pretrained(self.DEFAULT_CKPT_DIR)
+    state_dict = model.state_dict()
+else:
+    state_dict = load_hf_state_dict(self.DEFAULT_CKPT_DIR)
+state_dict = standardize_hf_keys(state_dict)
state_dict = convert_hf_to_meta(state_dict, self.head_dim)
keys_dict = list(state_dict.keys())[:]
remv = [f"layers.{i}." for i in list(range(self.n_layers, self.full_model_n_layers))]
@@ -1210,7 +1236,7 @@ def matmul_config(
) # TODO: Needed for TG hang workaround

if in0_block_w is None:
-    in0_block_w = min(4, max(1, k // (self.tile_size * grid_size[0])))
+    in0_block_w = self.find_largest_divisor(k // (self.tile_size * grid_size[1]))

return ttnn.MatmulMultiCoreReuseMultiCastProgramConfig(
compute_with_storage_grid_size=grid_size,
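`find_largest_divisor` is not part of this diff; presumably it returns the largest block width (up to some cap) that evenly divides the per-core count of K tiles, which avoids the invalid `in0_block_w` values the old `min(4, max(1, ...))` expression could pick when the K-tile count was not a multiple of the block width. A hedged sketch of such a helper:

```python
def find_largest_divisor_sketch(n, max_divisor=8):
    """Largest d <= max_divisor that evenly divides n; falls back to 1.
    Sketch of the presumed behaviour; the real method lives in model_config.py."""
    for d in range(max_divisor, 0, -1):
        if n % d == 0:
            return d
    return 1

assert find_largest_divisor_sketch(24) == 8   # 24 K-tiles per core -> block width 8
assert find_largest_divisor_sketch(10) == 5
assert find_largest_divisor_sketch(7) == 7
```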