
Commit 1ade6d7

Authored by hannanjgaws, aws-patlange, karthick gopalswamy, aws-shubhamchandak94, and aws-amerrez
Sync internal repo to external Dec 12 2024 (#103)
Sync internal repo to external Dec 12 2024

---------

Co-authored-by: Patrick Lange <patlange@amazon.com>
Co-authored-by: karthick gopalswamy <kgopalsw@amazon.com>
Co-authored-by: Shubham Chandak <chndkv@amazon.com>
Co-authored-by: Amer <amerrez@amazon.com>
Co-authored-by: Jonathan Lunt <jlunt@amazon.com>
Co-authored-by: Prithvijit Chakrabarty <prichakr@amazon.com>
Co-authored-by: Harsha Bikki <harbikh@amazon.com>
Co-authored-by: Bowen Chen <bowencc@amazon.com>
Co-authored-by: Dylan Geva <gevadyla@amazon.com>
Co-authored-by: Amulya Ballakur <amulyaab@amazon.com>
Co-authored-by: Akhil Raj Azhikodan <aazhiko@amazon.com>
Co-authored-by: Yuan Zhou <yazhom@amazon.com>
Co-authored-by: Yishan McNabb <yishanm@amazon.com>
Co-authored-by: Jiyoung An <jiyoua@amazon.com>
Co-authored-by: Ashraf Mahgoub <ashymahg@amazon.com>
Co-authored-by: Liangfu Chen <liangfc@amazon.com>
Co-authored-by: Shashwat Srijan <sssrijan@amazon.com>
Co-authored-by: yichi <yichi@amazon.com>
Co-authored-by: Abhinandan Patni <abhpat@amazon.com>
Co-authored-by: Yi-Hsiang (Sean) Lai <yihsian@amazon.com>
Co-authored-by: Seung Hun Chung <shchung@amazon.com>
Co-authored-by: Udit Deshmukh <desudit@amazon.com>
Co-authored-by: Hongbo Shi <hongbshi@amazon.com>
Co-authored-by: Lifan Shen <lifans@amazon.com>
Co-authored-by: Mike Zhang <zhanyequ@amazon.com>
Co-authored-by: Changchang Wang <wchangch@amazon.com>
Co-authored-by: jfduan <jfduan@amazon.com>
1 parent c8d6bdc commit 1ade6d7


47 files changed: +7545 −916 lines changed

src/transformers_neuronx/base.py

Lines changed: 280 additions & 54 deletions (large diff not rendered by default)

src/transformers_neuronx/bloom/config.py

Lines changed: 1 addition & 0 deletions
@@ -44,3 +44,4 @@ def __init__(
         self.batch_size = batch_size
         self.amp = amp
         self.tp_degree = tp_degree
+        self.model_type = 'bloom'

src/transformers_neuronx/bloom/hlo.py

Lines changed: 4 additions & 3 deletions
@@ -30,7 +30,7 @@ def inputs(self, scribe, dtype, n_active_tokens, batch_size):
         )
         return tensors, dims

-    def embedding(self, input_ids, cache_ids, start_ids, last_token_id, slopes, word_embeddings, ln_weight, ln_bias):
+    def embedding(self, input_ids, cache_ids, start_ids, last_token_id, block_tables, context_lens, slopes, word_embeddings, ln_weight, ln_bias):
         dtype = getattr(input_ids.scribe, self.config.amp)
         hidden = hlo.embedding(word_embeddings, input_ids, tp_degree=self.config.tp_degree, dtype=dtype)
         if self.config.hidden_size % self.config.tp_degree != 0:
@@ -41,9 +41,10 @@ def embedding(self, input_ids, cache_ids, start_ids, last_token_id, slopes, word
         return hlo.layer_norm_bsh(hidden, ln_weight, ln_bias) if is_bsh \
             else hlo.layer_norm(hidden, ln_weight, ln_bias)

-    def pre_layer(self, hidden, cache_ids, start_ids, last_token_id, *pre_layer_weights):
+    def pre_layer(self, hidden, cache_ids, start_ids, last_token_id, block_tables, context_lens, *pre_layer_weights):
         slopes, *rest = pre_layer_weights
-        mask, active_mask = hlo.attention_mask(cache_ids, start_ids, self.n_positions)
+        mask, active_mask = hlo.attention_mask(cache_ids, start_ids, self.n_positions,
+                                               last_token_id=last_token_id, neuron_config=self.neuron_config)
         prior_alibi, active_alibi = alibi.alibi(slopes, mask, active_mask)
         return hidden, last_token_id, cache_ids, mask, active_mask, prior_alibi, active_alibi

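As orientation for the masking change above, here is a minimal plain-PyTorch sketch of the causal mask that pre_layer derives from cache_ids and n_positions. This is a conceptual illustration only, not the hlo.attention_mask primitive (which emits HLO ops); the tensor values are hypothetical.

    # Conceptual sketch: a token at position i may attend to every cached position <= i.
    import torch

    n_positions = 8                            # hypothetical bucket size
    cache_ids = torch.tensor([5])              # position of the active token (hypothetical)
    key_positions = torch.arange(n_positions)
    mask = key_positions.unsqueeze(0) <= cache_ids.unsqueeze(1)   # shape [n_active_tokens, n_positions]
    print(mask.int())                          # tensor([[1, 1, 1, 1, 1, 1, 0, 0]])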

src/transformers_neuronx/bloom/model.py

Lines changed: 15 additions & 12 deletions
@@ -12,15 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-import torch
-import math
-import warnings
-
 from transformers_neuronx import decoder
-from transformers_neuronx import module
-from transformers_neuronx import ops
 from transformers_neuronx import sampling
-from transformers_neuronx import utils
 from transformers_neuronx import bucket
 from transformers_neuronx import base
 from transformers_neuronx.constants import LAYOUT_BSH, LAYOUT_HSB
@@ -67,11 +60,7 @@ def __init__(self, config, *, n_positions=2048, batch_size=1, amp='f32', tp_degr
         self.decoder_lm_head = self.decoder_param_set.init_token_decoder(unroll=self.unroll, buckets=self.token_buckets, model_obj=self)

     def load_weights(self):
-        # Materialize the embedding to CPU
-        self.chkpt_model.transformer.word_embeddings.materialize()
-        self.chkpt_model.transformer.word_embeddings_layernorm.materialize()
-
-        ops.init()
+        self.materialize_embeddings()

         n_head = self.config.n_head
         hidden_size = self.config.hidden_size
@@ -142,6 +131,7 @@ def load_weights(self):
         ln_f = self.chkpt_model.transformer.ln_f
         ln_f.materialize()
         self.decoder_lm_head.add_final_layer_norm(ln_f.weight.detach(), ln_f.bias.detach())
+        ln_f.nullify()

         lm_head = self.chkpt_model.lm_head
         lm_head.materialize()
@@ -154,7 +144,20 @@ def load_weights(self):
         self.decoder_lm_head.add_pre_layer_parameter(self.chkpt_model.transformer.word_embeddings_layernorm.weight)
         self.decoder_lm_head.add_pre_layer_parameter(self.chkpt_model.transformer.word_embeddings_layernorm.bias)
         self.decoder_lm_head.to_neuron()
+        self.init_rest_of_model()
+        self.maybe_nullify_embeddings()
+
+    def materialize_embeddings(self):
+        # Materialize the embedding to CPU
+        self.chkpt_model.transformer.word_embeddings.materialize()
+        self.chkpt_model.transformer.word_embeddings_layernorm.materialize()
+
+    def maybe_nullify_embeddings(self):
+        if self.neuron_config.on_device_embedding:
+            self.chkpt_model.transformer.word_embeddings.nullify()
+            self.chkpt_model.transformer.word_embeddings_layernorm.nullify()

+    def init_rest_of_model(self):
         if self.context_buckets:
             for context_length_estimate in self.context_buckets:
                 for batch_size in self.batch_sizes:
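As an aside, a self-contained toy (not the library's module classes) illustrating the materialize-then-nullify pattern this refactor follows: weights are brought onto CPU for loading, and the CPU copy is dropped once on-device embedding makes it redundant. The class and flag names below are illustrative stand-ins.

    # Toy sketch of materialize()/nullify(); mirrors the flow in load_weights above.
    class FakeEmbedding:
        def __init__(self):
            self.weight = None
        def materialize(self):
            self.weight = [0.0] * 4          # stand-in for loading real parameters onto CPU
        def nullify(self):
            self.weight = None               # free the CPU copy

    emb = FakeEmbedding()
    emb.materialize()
    assert emb.weight is not None
    on_device_embedding = True               # mirrors neuron_config.on_device_embedding
    if on_device_embedding:
        emb.nullify()                        # CPU copy dropped once weights live on device
    assert emb.weight is None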

src/transformers_neuronx/compiler.py

Lines changed: 52 additions & 17 deletions
@@ -14,17 +14,19 @@
 # ==============================================================================
 import os
 import shlex
+import shutil
 import subprocess
 import hashlib
 import tarfile
-import tempfile
-from contextlib import contextmanager
-import numpy as np
+from contextlib import contextmanager, nullcontext
 from textwrap import dedent
-import torch
 import logging
 import json
 import math
+
+import numpy as np
+import torch
+
 from torch_neuronx.pyhlo import xla_data_pb2
 from torch_neuronx.pyhlo.scribe import HloScribe
 from torch_neuronx.pyhlo.constant.serialize_torch import serialize_torch
@@ -35,6 +37,7 @@
 from libneuronxla.neuron_cc_cache import CacheUrl, create_compile_cache
 from neuronxcc import __version__ as compiler_version

+
 def get_hash_module(hlo_module, flags):
     # Hashing is pretty fast and neglegible compared to compilation time
     hash_gen = hashlib.sha256()
@@ -45,8 +48,29 @@ def get_hash_module(hlo_module, flags):
     hash = str(hash_gen.hexdigest())[:20]
     return hash

+
+@contextmanager
+def envvar(key, value):
+    prior = os.environ.pop(key, None)
+    if value is not None:
+        os.environ[key] = value
+    try:
+        yield
+    finally:
+        os.environ.pop(key, None)
+        if prior is not None:
+            os.environ[key] = prior
+
+
 def compile_py_func(py_func):
-    return HloScribe(serialize_torch)(py_func).module_proto
+
+    # Adds file/scope metadata during debug dump
+    context = nullcontext()
+    if "NEURONX_DUMP_TO" in os.environ and 'ENABLE_PYHLO_FILE_METADATA' not in os.environ:
+        context = envvar('ENABLE_PYHLO_FILE_METADATA', '1')
+
+    with context:
+        return HloScribe(serialize_torch)(py_func).module_proto


 def build_kernel(py_func, tp_degree):
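For reference, a standalone sketch of how the new envvar helper behaves: it temporarily sets an environment variable and restores the prior value (or removes the key) on exit. The function body is copied from the hunk above; the demo variable value and assertions are illustrative.

    import os
    from contextlib import contextmanager

    @contextmanager
    def envvar(key, value):
        prior = os.environ.pop(key, None)
        if value is not None:
            os.environ[key] = value
        try:
            yield
        finally:
            os.environ.pop(key, None)
            if prior is not None:
                os.environ[key] = prior

    os.environ['ENABLE_PYHLO_FILE_METADATA'] = '0'
    with envvar('ENABLE_PYHLO_FILE_METADATA', '1'):
        assert os.environ['ENABLE_PYHLO_FILE_METADATA'] == '1'   # overridden inside the block
    assert os.environ['ENABLE_PYHLO_FILE_METADATA'] == '0'       # prior value restored on exit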
@@ -81,9 +105,11 @@ def get_compiler_flags() -> str:
     return ' '.join(flags)


-def compile_hlo_module(hlo_module, tag=None):
+def compile_hlo_module(hlo_module, tag=None, num_exec_repetition=1):

     flags = get_compiler_flags()
+    flags = f'{flags} --execute-repetition={num_exec_repetition}'
+
     module_flag_hash = get_hash_module(hlo_module, flags)
     module_hash = get_hash_module(hlo_module, None)

@@ -97,10 +123,8 @@ def compile_hlo_module(hlo_module, tag=None):
         hlo_module_name = f'{tag}-{hlo_module.name}.{compiler_version}.{module_flag_hash}'

     if dump:
-
-
-        dump_to = os.environ.get('NEURONX_DUMP_TO', '/tmp')
-        dump_to = os.path.join(dump_to, hlo_module_name)
+        dump_to_parent = os.environ.get('NEURONX_DUMP_TO', '/tmp')
+        dump_to = os.path.join(dump_to_parent, hlo_module_name)
         os.makedirs(dump_to, exist_ok=True)
         hlo_module_path = os.path.join(dump_to, f'{hlo_module_name}.pb')
         hlo_module_path = os.path.realpath(hlo_module_path)
@@ -115,6 +139,10 @@
         subprocess.check_call(command_line, cwd=dump_to)
         with open(neff_path, 'rb') as f:
             neff_bytes = f.read()
+        try:
+            shutil.copyfile(os.path.join(dump_to_parent, 'neuron_model_config.json'), os.path.join(dump_to, 'neuron_model_config.json'))
+        except FileNotFoundError:
+            pass
     else:
         module_bytes = hlo_module.SerializeToString()
         try:
@@ -201,7 +229,10 @@ def __init__(self):
             F32 FLOAT float32
             F64 DOUBLE float64
             BF16 BFLOAT16 bfloat16
+            F8E4M3FN INT8 float8_e4m3fn
         '''
+        # Note that for FP8 we map metaneff datatype to int8, since from the runtime perspective these datatypes are functionally equivalent (for fp8 storage only)
+        # Within Tnx, we no longer use the metaneff flow, so this would not matter anyway.
         name_mapping = dedent(name_mapping)
         name_mapping = name_mapping.lstrip().strip()
         self.hlo2metaneff_mapping = {}
@@ -211,6 +242,8 @@
         for line in name_mapping.split('\n'):
             line = line.lstrip().strip()
             pname, dname, tname = line.split()
+            if not hasattr(torch, tname):
+                continue
             primitive_type = getattr(xla_data_pb2.PrimitiveType, pname)
             metaneff_dtype = getattr(metaneff_pb2.MetaTensor.DataType, dname)
             torch_dtype = getattr(torch, tname)
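To make the new hasattr guard concrete, here is a minimal self-contained parse of the same table format: torch builds that lack torch.float8_e4m3fn simply skip that row instead of raising AttributeError. The table variable and printout below are illustrative, not the class's real attributes.

    import torch
    from textwrap import dedent

    name_mapping = '''
        F32 FLOAT float32
        BF16 BFLOAT16 bfloat16
        F8E4M3FN INT8 float8_e4m3fn
    '''
    table = {}
    for line in dedent(name_mapping).strip().split('\n'):
        pname, dname, tname = line.split()
        if not hasattr(torch, tname):     # skip dtypes this torch version does not provide
            continue
        table[pname] = getattr(torch, tname)
    print(table)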
@@ -355,10 +388,10 @@ def __call__(self, inputs, return_ranks: int = -1):
             result: The output tensors from each rank concatenated along dim 0.
         """
         casted = []
-        for cpu, buf in zip(inputs, self.inputs):
+        for i, (cpu, buf) in enumerate(zip(inputs, self.inputs)):
             if cpu.shape != buf.shape:
                 raise AssertionError(
-                    f"Input shape mismatch. Expected {buf.shape}, but got {cpu.shape}"
+                    f"{i+1}th input shape mismatch. Expected {buf.shape}, but got {cpu.shape}"
                 )
             if cpu.dtype != buf.dtype:
                 cpu = cpu.to(buf.dtype)
@@ -444,7 +477,7 @@ def io_ring_cache_context(size):

 class ParallelKernel:
     hlo_snapshot_iter = 0
-    def __init__(self, hlo_module, tp_degree, g_start_device_id=0, g_device_count=None, tag=None):
+    def __init__(self, hlo_module, tp_degree, g_start_device_id=0, g_device_count=None, tag=None, num_exec_repetition=1):
         self.hlo_module = hlo_module
         self.tp_degree = tp_degree
         self.neff_bytes = None
@@ -459,6 +492,7 @@ def __init__(self, hlo_module, tp_degree, g_start_device_id=0, g_device_count=No
         self.tag = tag
         self.g_device_count = g_device_count
         self.memories = []
+        self.num_exec_repetition = num_exec_repetition
         self.total_input_tensors_size = get_total_input_tensors_size(self.hlo_module)
         logging.debug(f"Total input tensor size of the module (per rank): {self.total_input_tensors_size / (10**9)} G, whole (all ranks): {self.total_input_tensors_size * tp_degree / (10**9)} G")

@@ -467,15 +501,15 @@ def build_memory(self):
         self.memories.append(memory)
         return memory

-    def compile(self):
-        self.build()
+    def compile(self, num_exec_repetition=1):
+        self.build(num_exec_repetition)
         return self.neff_bytes

-    def build(self):
+    def build(self, num_exec_repetition=1):
         # Avoid rebuilding NEFF. This path occurs during deserialization
         if self.neff_bytes is not None:
             return
-        self.neff_bytes = compile_hlo_module(self.hlo_module, self.tag)
+        self.neff_bytes = compile_hlo_module(self.hlo_module, self.tag, num_exec_repetition)

     def load(self, io_ring_cache_size=1):
         assert self.neff_bytes is not None, f"Try to load with neff bytes as None, might due to compilation failure"
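A toy sketch of the build-once pattern kept by compile()/build() above (a stand-in class, not the real ParallelKernel): the NEFF is compiled on the first call and reused afterwards, with num_exec_repetition threaded through to the compile step.

    # Illustrative only; the real build() calls compile_hlo_module and produces NEFF bytes.
    class KernelStub:
        def __init__(self):
            self.neff_bytes = None
            self.builds = 0
        def build(self, num_exec_repetition=1):
            if self.neff_bytes is not None:      # avoid rebuilding an already-compiled NEFF
                return
            self.builds += 1
            self.neff_bytes = f'neff(reps={num_exec_repetition})'.encode()
        def compile(self, num_exec_repetition=1):
            self.build(num_exec_repetition)
            return self.neff_bytes

    k = KernelStub()
    assert k.compile(num_exec_repetition=2) == b'neff(reps=2)'
    k.compile(num_exec_repetition=5)             # no-op: NEFF already built
    assert k.builds == 1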
@@ -700,3 +734,4 @@ def setup(self, nc_input_buffers, nc_output_buffers, output_count=None):

     def run(self):
         self.kernel(self.memories)
+
