
Commit 4c363ca

Merge branch 'main' into use_logger
Signed-off-by: Mamta Singh <[email protected]>
2 parents: 3d8a53e + db38927

30 files changed: +414 additions, -156 deletions

QEfficient/base/modeling_qeff.py

Lines changed: 31 additions & 34 deletions

@@ -7,7 +7,6 @@
 import hashlib
 import inspect
-import json
 import logging
 import shutil
 import subprocess
@@ -23,7 +22,7 @@
 from QEfficient.base.pytorch_transforms import PytorchTransform
 from QEfficient.compile.qnn_compiler import compile as qnn_compile
 from QEfficient.generation.cloud_infer import QAICInferenceSession
-from QEfficient.utils import constants, dump_qconfig
+from QEfficient.utils import constants, create_json, dump_qconfig, generate_mdp_partition_config, load_json
 from QEfficient.utils.cache import QEFF_HOME, to_hashable

 logger = logging.getLogger(__name__)
@@ -269,17 +268,17 @@ def _compile(
                 specializations=specializations,
                 custom_io=custom_io,
                 device_group=list(range(mdp_ts_num_devices)),
-                num_cores=compiler_options.get("aic_num_cores", 16),
-                mxfp6=compiler_options.get("mxfp6_matmul", False),
+                num_cores=compiler_options.get("aic_num_cores", constants.DEFAULT_AIC_NUM_CORES),
+                mxfp6=compiler_options.get("mxfp6_matmul", constants.DEFAULT_AIC_MXPF6_MATMUL),
                 mxint8=mxint8_kv_cache,
                 qnn_config=qnn_config,
             )

             return self.qpc_path

         command = constants.COMPILER + [f"-m={onnx_path}"]
-        if mdp_ts_json_path := compiler_options.pop("mdp_ts_json_path", None):
-            mdp_ts_num_devices = None
+
+        if mdp_ts_json_path := compiler_options.pop("mdp_load_partition_config", None):
             command.append(f"-mdp-load-partition-config={mdp_ts_json_path}")

         for key, value in compiler_options.items():
@@ -289,6 +288,17 @@ def _compile(
                 command.append(option)
                 continue
             command.append(f"{option}={value}")
+
+        # Create a dummy mdp_ts_json if mdp-load-partition-config not provided and num_devices > 1
+        if mdp_ts_json_path is not None:
+            mdp_ts_json = load_json(str(mdp_ts_json_path))
+        elif mdp_ts_num_devices > 1:
+            mdp_ts_json = generate_mdp_partition_config(
+                mdp_ts_num_devices, compiler_options.get("aic_num_cores", constants.DEFAULT_AIC_NUM_CORES)
+            )
+        else:
+            mdp_ts_json = None
+
         compile_hash = hashlib.sha256(to_hashable(command))

         if specializations is not None:
@@ -299,30 +309,37 @@ def _compile(

         if num_speculative_tokens:
             compile_hash.update(to_hashable({"num_speculative_tokens": num_speculative_tokens}))
-        # Hash num_devices too, since default value would always be 1.
-        compile_hash.update(to_hashable(mdp_ts_num_devices))
+
+        # Hash the MDP partition config and the number of devices.
+        compile_hash.update(to_hashable(mdp_ts_json))
+        compile_hash.update(to_hashable({"mdp_ts_num_devices": mdp_ts_num_devices}))

         # Check if already compiled
         compile_hash = compile_hash.hexdigest()[:16]
         compile_dir = qpc_path.with_name(qpc_path.name + "-" + compile_hash)
         qpc_path = compile_dir / "qpc"
         qpc_path.mkdir(parents=True, exist_ok=True)
+
         if qpc_path.is_dir():
             if (qpc_path / "programqpc.bin").is_file():
                 self.qpc_path = qpc_path
                 return qpc_path
             # Probably compilation failure last time, delete directory to start over
             shutil.rmtree(qpc_path)

+        # write the MDP partition config file if not provided
+        if mdp_ts_json is not None:
+            mdp_ts_json_path = compile_dir / f"mdp_ts_{mdp_ts_num_devices}.json"
+            create_json(str(mdp_ts_json_path), mdp_ts_json)
+            command.append(f"-mdp-load-partition-config={mdp_ts_json_path}")
+
         # Write specializations.json file
         if specializations is not None:
             specializations_json = compile_dir / "specializations.json"
-            with open(specializations_json, "w") as fp:
-                json.dump(
-                    {"specializations": [{k: str(v) for k, v in spec.items()} for spec in specializations]},
-                    fp,
-                    indent=4,
-                )
+            specializations_data = {
+                "specializations": [{k: str(v) for k, v in spec.items()} for spec in specializations]
+            }
+            create_json(str(specializations_json), specializations_data)
             command.append(f"-network-specialization-config={specializations_json}")

         # Write custom_io.yaml file
@@ -333,26 +350,6 @@ def _compile(
             fp.write(f" - IOName: {io_name}\n   Precision: {dtype}\n\n")
         command.append(f"-custom-IO-list-file={custom_io_yaml}")

-        # Write mdp_config.json file
-        if not mdp_ts_json_path and mdp_ts_num_devices > 1:
-            num_cores = compiler_options.get("aic_num_cores", 16)
-            mdp_ts_json = compile_dir / f"mdp_ts_{mdp_ts_num_devices}.json"
-            with open(mdp_ts_json, "w") as fp:
-                json.dump(
-                    {
-                        "connections": [{"devices": list(range(mdp_ts_num_devices)), "type": "p2p"}],
-                        "partitions": [
-                            {
-                                "name": "Partition0",
-                                "devices": [{"deviceId": d, "numCores": num_cores} for d in range(mdp_ts_num_devices)],
-                            }
-                        ],
-                    },
-                    fp,
-                    indent=4,
-                )
-            command.append(f"-mdp-load-partition-config={mdp_ts_json}")
-
         command.append(f"-aic-binary-dir={qpc_path}")
         logger.info(f"Running compiler: {' '.join(command)}")
         try:
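Net effect of the modeling_qeff.py changes: the MDP (multi-device partitioning) config is now built as an in-memory dict before hashing, so a change in partition layout correctly invalidates the compile cache, and the same dict is later written out with create_json and passed to the compiler via -mdp-load-partition-config. Judging from the inline writer removed at the end of this file, generate_mdp_partition_config presumably returns a structure like the following (a sketch for 4 devices and 16 cores; the helper's actual output may differ):

# Hypothetical reconstruction of generate_mdp_partition_config(4, 16), based on
# the JSON the removed inline block used to write; not the helper's actual code.
mdp_ts_num_devices, num_cores = 4, 16
mdp_ts_json = {
    "connections": [{"devices": list(range(mdp_ts_num_devices)), "type": "p2p"}],
    "partitions": [
        {
            "name": "Partition0",
            "devices": [{"deviceId": d, "numCores": num_cores} for d in range(mdp_ts_num_devices)],
        }
    ],
}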

QEfficient/exporter/export_hf_to_cloud_ai_100.py

Lines changed: 0 additions & 1 deletion

@@ -129,7 +129,6 @@ def export_bertstyle_model_to_onnx(model_name, model, tokenizer, onnx_dir_path,
     )

     # Generate inputFiles
-    # todo(ochougul):rename to bert_style_input_list.txt
     input_list_file = os.path.join(onnx_dir_path, "input_list.txt")
     generate_input_files(
         input_files_path=os.path.join(onnx_dir_path, "inputFiles"),

QEfficient/exporter/export_utils.py

Lines changed: 0 additions & 4 deletions

@@ -218,8 +218,6 @@ def fix_onnx_fp16(
     :str: Updated base name of exported ONNX model.
     """
     model = onnx.load(os.path.join(gen_models_path, f"{model_base_name}.onnx"))
-    # TODO: Remove this `fix_onnx_fp16` function and replace with this transform
-    # as we're not utilizing the validations done in this function
     model, fp16_fix = FP16ClipTransform.apply(model, onnx_base_dir=gen_models_path)

     if fp16_fix:
@@ -256,8 +254,6 @@ def fix_onnx_fp16(
     if ort_outputs is not None:
         for oname, orto, ortof in zip(output_names, ort_outputs, ort_outputs_fixed):
             fix_diff = np.abs(orto.astype(np.float32) - ortof.astype(np.float32)).max()
-            # TODO: need to the debug this
-            # info(oname, fix_diff)
             close_outputs.append(fix_diff < 1e-5)
     else:
         info("No constants out of FP16 range")
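For intuition on the retained check: it flags any output that moved after out-of-range constants were clipped into FP16 range. A minimal self-contained sketch of the same comparison, with invented values:

import numpy as np

# "orto" stands in for the original ONNX Runtime output, "ortof" for the output
# after FP16 clipping; 65504 is the largest finite FP16 value.
orto = np.array([1.0, 2.0, 70000.0], dtype=np.float32)
ortof = np.array([1.0, 2.0, 65504.0], dtype=np.float32)

fix_diff = np.abs(orto.astype(np.float32) - ortof.astype(np.float32)).max()
print(fix_diff < 1e-5)  # False: clipping changed this output beyond tolerance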

QEfficient/finetune/data/sampler.py

Lines changed: 10 additions & 6 deletions

@@ -4,11 +4,9 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
-
 import random
 from itertools import islice

-import numpy as np
 import torch


@@ -22,14 +20,14 @@ def __init__(self, data_source, batch_size: int, drop_last: bool, shuffle: bool
         self.batch_size = batch_size
         self.drop_last = drop_last
         self.shuffle = shuffle
+        self.data_source = data_source

     def __iter__(self):
-        ids = np.argsort(self.lengths, kind="mergesort")
+        ids = list(range(len(self.data_source)))
         if self.drop_last:
             ids = ids[: len(ids) // self.batch_size * self.batch_size]

         batches = [ids[i : i + self.batch_size] for i in range(0, len(ids), self.batch_size)]
-
         if self.shuffle:
             random.shuffle(batches)

@@ -45,11 +43,17 @@ def __len__(self):

 class DistributedLengthBasedBatchSampler(torch.utils.data.BatchSampler):
     def __init__(
-        self, data_source, batch_size: int, num_replicas: int, rank: int, shuffle: bool = True, seed: int = 0
+        self,
+        data_source,
+        batch_size: int,
+        num_replicas: int,
+        rank: int,
+        shuffle: bool = True,
+        seed: int = 0,
     ) -> None:
         random.seed(seed)
         self.batch_sampler = LengthBasedBatchSampler(
-            data_source, batch_size=batch_size, drop_last=True, shuffle=shuffle
+            data_source, batch_size=batch_size, drop_last=False, shuffle=shuffle
         )
         self.num_replicas = num_replicas
         self.rank = rank
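Two behavioral changes here: batches are now formed in dataset order (list(range(...))) rather than by sorted length (the removed np.argsort on self.lengths), and the inner sampler no longer drops the last partial batch, consistent with the dataset padding added in dataset_utils.py below. The sharding step of the distributed sampler is not part of this hunk; given the islice import kept at the top, a plausible scheme is each rank taking every num_replicas-th batch (a sketch under that assumption, not the file's actual __iter__):

from itertools import islice

batches = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]  # output of LengthBasedBatchSampler
num_replicas, rank = 2, 1

# Rank r consumes batches r, r + num_replicas, r + 2 * num_replicas, ...
for batch in islice(batches, rank, None, num_replicas):
    print(batch)  # rank 1 sees [2, 3] and [6, 7]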

QEfficient/finetune/dataset/samsum_dataset.py

Lines changed: 1 addition & 1 deletion

@@ -9,7 +9,7 @@


 def get_preprocessed_samsum(dataset_config, tokenizer, split, context_length=None):
-    dataset = datasets.load_dataset("Samsung/samsum", split=split, trust_remote_code=True)
+    dataset = datasets.load_dataset("knkarthick/samsum", split=split, trust_remote_code=True)

     prompt = "Summarize this dialog:\n{dialog}\n---\nSummary:\n"
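The switch to knkarthick/samsum presumably avoids the loading-script dependency of the original Samsung/samsum repo; the prompt template itself is unchanged. For reference, a sketch of how one record is rendered (column names assumed from the SamSum schema, record invented):

prompt = "Summarize this dialog:\n{dialog}\n---\nSummary:\n"
example = {"dialogue": "Amanda: lunch at 12?\nJerry: sure!", "summary": "Amanda and Jerry agree to lunch at 12."}
print(prompt.format(dialog=example["dialogue"]))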

QEfficient/finetune/utils/dataset_utils.py

Lines changed: 35 additions & 3 deletions

@@ -4,14 +4,15 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
-
+import datasets
 import torch
 import torch.distributed as dist
 from transformers.data import DataCollatorForSeq2Seq

 from QEfficient.finetune.data.sampler import DistributedLengthBasedBatchSampler
 from QEfficient.finetune.dataset.dataset_config import DATALOADER_COLLATE_FUNC, DATASET_PREPROC
 from QEfficient.finetune.utils.logging_utils import logger
+from QEfficient.finetune.utils.helper import get_num_ddp_devices


 def get_preprocessed_dataset(
@@ -56,20 +57,51 @@ def get_dataloader_kwargs(train_config, dataset, dataset_processer, split):
             dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=False
         )
         kwargs["batch_size"] = batch_size
-        kwargs["drop_last"] = True
+        kwargs["drop_last"] = False
     else:
         kwargs["batch_size"] = batch_size
-        kwargs["drop_last"] = True
+        kwargs["drop_last"] = False
         kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
     return kwargs


+def padding_dataset(train_config, dataset, batch_size):
+    if train_config.enable_ddp and train_config.enable_sorting_for_ddp:
+        if isinstance(dataset, datasets.Dataset):
+            # Hugging Face Dataset transformation
+            dataset = dataset.map(lambda x: {"input_length": len(x["input_ids"])})
+            dataset = dataset.sort("input_length")
+
+        else:
+            dataset = sorted(dataset, key=lambda x: len(x["input_ids"]))
+
+    dummy_row = next(iter(dataset))
+    dummy_row["labels"] = torch.tensor([-100] * len(dummy_row["labels"]))
+    padding_size = 0
+    num_replicas = get_num_ddp_devices()
+    remainder = len(dataset) % (num_replicas * batch_size)
+    padding_size = (num_replicas * batch_size) - remainder
+
+    dummy_data = [dummy_row.copy() for _ in range(padding_size)]
+    dummy_dataset = datasets.Dataset.from_list(dummy_data)
+    if isinstance(dataset, datasets.Dataset):
+        combined_dataset = datasets.concatenate_datasets([dataset, dummy_dataset])
+    else:
+        combined_dataset = dataset + list(dummy_dataset)
+    return combined_dataset
+
+
 def get_dataloader(tokenizer, dataset_config, train_config, split: str = "train"):
     dataset = get_preprocessed_dataset(tokenizer, dataset_config, split, context_length=train_config.context_length)
+
+    batch_size = train_config.train_batch_size if split == "train" else train_config.val_batch_size
+    dataset = padding_dataset(train_config, dataset, batch_size)
+
     dl_kwargs = get_dataloader_kwargs(train_config, dataset, tokenizer, split)

     # FIXME (Meet): Add custom data collator registration from the outside by the user.
     custom_data_collator = get_custom_data_collator(tokenizer, dataset_config)
+
     if custom_data_collator:
         print("custom_data_collator is used")
         dl_kwargs["collate_fn"] = custom_data_collator
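A worked example of the padding arithmetic in padding_dataset, which pads the dataset so every rank sees the same number of full batches (values invented). Note one edge case visible in the code: when the length already divides evenly (remainder == 0), the expression still appends a full extra group of num_replicas * batch_size dummy rows.

num_replicas, batch_size = 4, 8          # e.g. WORLD_SIZE=4 under torchrun
dataset_len = 103
group = num_replicas * batch_size        # 32 samples consumed per global step
remainder = dataset_len % group          # 103 % 32 = 7
padding_size = group - remainder         # 25 dummy rows, labels all -100 (ignored by the loss)
assert (dataset_len + padding_size) % group == 0  # padded length 128 divides evenly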

QEfficient/finetune/utils/helper.py

Lines changed: 4 additions & 0 deletions

@@ -4,6 +4,7 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
+import os

 import os

@@ -15,3 +16,6 @@

 def is_rank_zero():
     return int(os.getenv("LOCAL_RANK", 0)) == 0
+
+def get_num_ddp_devices():
+    return int(os.getenv("WORLD_SIZE", 1))
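Both helpers read environment variables exported by the distributed launcher. A usage sketch (values set manually here; torchrun --nproc_per_node 4 would export them per worker):

import os

os.environ.setdefault("WORLD_SIZE", "4")
os.environ.setdefault("LOCAL_RANK", "0")

print(int(os.getenv("WORLD_SIZE", 1)))       # 4, what get_num_ddp_devices() returns
print(int(os.getenv("LOCAL_RANK", 0)) == 0)  # True, what is_rank_zero() returns

In a plain single-process run neither variable is set, so the defaults (1 device, rank zero) apply.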
