
Commit 516ac3e

mamtsing authored and quic-mamta committed
Merge branch 'main' into use_logger
Signed-off-by: Mamta Singh <[email protected]>
2 parents 20000b6 + 3aaa2d8 commit 516ac3e

File tree

7 files changed: +108 -69 lines changed

QEfficient/finetune/utils/helper.py

Lines changed: 40 additions & 0 deletions
@@ -5,6 +5,15 @@
 #
 # -----------------------------------------------------------------------------
 import os
+from contextlib import nullcontext
+
+import torch
+
+try:
+    import torch_qaic.debug as qaic_debug  # noqa: F401
+except ImportError as e:
+    print(f"Warning: {e}. Moving ahead without these qaic modules.")
+
 
 TASK_TYPE = ["generation", "seq_classification"]
 PEFT_METHOD = ["lora"]
@@ -18,3 +27,34 @@ def is_rank_zero():
 
 def get_num_ddp_devices():
     return int(os.getenv("WORLD_SIZE", 1))
+
+
+def get_autocast_ctx(use_autocast, device_type, dtype=torch.float16):
+    return torch.autocast(device_type=device_type, dtype=dtype) if use_autocast else nullcontext()
+
+
+def get_op_verifier_ctx(
+    use_op_by_op_verifier,
+    train_device,
+    dump_dir,
+    step,
+    ref_device="cpu",
+    ref_dtype=torch.float32,
+    atol=1e-1,
+    rtol=1e-5,
+    use_ref_output_on_mismatch=True,
+):
+    if not use_op_by_op_verifier:
+        return nullcontext()
+
+    filter_config = qaic_debug.DispatchFilterConfig.default(train_device)
+    dump_dir = dump_dir + "/mismatches/step_" + str(step)
+    return qaic_debug.OpByOpVerifierMode(
+        ref_device=ref_device,
+        ref_dtype=ref_dtype,
+        atol=atol,
+        rtol=rtol,
+        use_ref_output_on_mismatch=use_ref_output_on_mismatch,
+        filter_config=filter_config,
+        dump_root_dir=dump_dir,
+    )
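Both helpers return `nullcontext()` when their feature flag is off, so call sites can enter them unconditionally. A minimal usage sketch under assumptions: `model`, `dataloader`, and the device strings below are illustrative stand-ins, not part of this commit.

```python
from functools import partial

import torch

from QEfficient.finetune.utils.helper import get_autocast_ctx, get_op_verifier_ctx

# Bind everything except `step`, mirroring how train_utils.py uses the helper.
autocast_ctx = get_autocast_ctx(True, "qaic", dtype=torch.float16)
op_verifier_ctx = partial(get_op_verifier_ctx, True, "qaic:0", "./output")

for step, batch in enumerate(dataloader):  # assumed training loop
    with autocast_ctx, op_verifier_ctx(step) as verifier:
        loss = model(**batch).loss
    if verifier is not None:  # nullcontext() binds None when the verifier is off
        print(verifier.get_perop_mismatch_count())
```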

QEfficient/finetune/utils/logging_utils.py

Lines changed: 2 additions & 3 deletions
@@ -25,9 +25,8 @@ def raise_error(message, errortype=RuntimeError):
         raise errortype(message)
 
     def log_rank_zero(msg: str, level: int = logging.INFO):
-        if not is_rank_zero:
-            return
-        self.logger.log(level, msg, stacklevel=2)
+        if is_rank_zero():
+            self.logger.log(level, msg, stacklevel=2)
 
     def prepare_for_logs(output_path, dump_logs=False, level=logging.INFO):
         self.logger.setLevel(level)
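The bug fixed here is the classic missing-parentheses guard: `is_rank_zero` without a call is a bare function object, which is always truthy, so `not is_rank_zero` was always `False` and `log_rank_zero` logged on every rank instead of only rank zero. A standalone sketch of the truthiness trap (the stub below is illustrative, not the repo's implementation):

```python
def is_rank_zero():
    return False  # pretend this process is a non-zero rank

# Bug: a bare function reference is always truthy, so this guard never fires.
if not is_rank_zero:
    print("unreachable")

print(bool(is_rank_zero))  # True  -> truthiness of the function object itself
print(is_rank_zero())      # False -> the actual rank check
```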

QEfficient/finetune/utils/plot_metrics.py

Lines changed: 2 additions & 2 deletions
@@ -75,8 +75,8 @@ def plot_metrics(file_path):
     with open(file_path, "r") as f:
         try:
             data = json.load(f)
-        except json.JSONDecodeError:
-            logger.raise_error("Invalid JSON file.", json.JSONDecodeError)
+        except json.JSONDecodeError as e:
+            logger.raise_error("Invalid JSON file.", e)
             return
 
     directory = os.path.dirname(file_path)

QEfficient/finetune/utils/train_utils.py

Lines changed: 39 additions & 61 deletions
@@ -8,8 +8,8 @@
 import json
 import os
 import time
-from contextlib import nullcontext
 from datetime import datetime
+from functools import partial
 from typing import Dict, List, Tuple
 
 import torch
@@ -19,7 +19,7 @@
 from tqdm import tqdm
 
 from QEfficient.finetune.configs.training import TrainConfig
-from QEfficient.finetune.utils.helper import is_rank_zero
+from QEfficient.finetune.utils.helper import get_autocast_ctx, get_op_verifier_ctx, is_rank_zero
 from QEfficient.finetune.utils.logging_utils import logger
 
 try:
@@ -85,8 +85,8 @@ def train(
     max_steps_reached = False  # Flag to indicate max training steps reached
 
     tensorboard_updates = None
-    tensorboard_log_dir = train_config.output_dir + "/runs/" + f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
     if is_rank_zero():
+        tensorboard_log_dir = train_config.output_dir + "/runs/" + f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
         tensorboard_updates = SummaryWriter(log_dir=tensorboard_log_dir)
 
     device_type = torch.device(device).type
@@ -110,6 +110,9 @@ def train(
         num_classes = model.classifier.out_features
         acc_helper = torchmetrics.classification.MulticlassAccuracy(num_classes=num_classes).to(device)
 
+    autocast_ctx = get_autocast_ctx(train_config.use_autocast, device_type, dtype=torch.float16)
+    op_verifier_ctx = partial(get_op_verifier_ctx, train_config.opByOpVerifier, device, train_config.output_dir)
+
     # Start the training loop
     for epoch in range(train_config.num_epochs):
        if loss_0_counter.item() == train_config.convergence_counter:
@@ -168,60 +171,38 @@ def train(
                 break
             batch = {k: v.to(device) for k, v in batch.items()}  # move the batch elements to qaic device
 
-            with (
-                torch.autocast(device_type=device_type, dtype=torch.float16)
-                if train_config.use_autocast
-                else nullcontext()
-            ):
-                # an additional condition can be put here to avoid opByOpVerifier getting triggered for each step
-                if train_config.opByOpVerifier:
-                    with qaic_debug.OpByOpVerifierMode(
-                        ref_device="cpu",
-                        ref_dtype=torch.float32,
-                        # adjust atol & rtol this as required
-                        atol=1e-1,
-                        use_ref_output_on_mismatch=True,
-                        filter_config=qaic_debug.DispatchFilterConfig.default(device),
-                        dump_root_dir=train_config.output_dir + "/mismatches/step_" + str(step),
-                    ) as verifier:
-                        model_outputs = model(**batch)
-                        loss = model_outputs.loss  # Forward call
-                        if (batch["labels"] != -100).sum() == 0:
-                            loss = loss.nan_to_num(nan=0.0)
-                            num_dummy_samples += train_config.train_batch_size
-                        else:
-                            num_dummy_samples_per_batch = (
-                                (torch.sum(batch["labels"] == -100, dim=1) == batch["labels"].shape[1]).sum().item()
-                            )
-                            if num_dummy_samples_per_batch > 0:
-                                num_dummy_samples += num_dummy_samples_per_batch
-                                loss = loss * train_config.train_batch_size / num_dummy_samples_per_batch
-
-                        if train_config.task_type == "seq_classification":
-                            logits = model_outputs.logits
-                            labels = batch["labels"][:, 0]
-                            preds = torch.nn.functional.softmax(logits, dim=-1)
-                            acc_helper.forward(preds, labels)
-                    logger.info("Mismatches detected:", verifier.get_perop_mismatch_count())
+            is_optimizer_step = (step + 1) % train_config.gradient_accumulation_steps == 0 or step == len(
+                train_dataloader
+            ) - 1
+            if train_config.enable_ddp:
+                # Below block derived from : https://github.com/karpathy/nanoGPT/blob/93a43d9a5c22450bbf06e78da2cb6eeef084b717/train.py#L293
+                # in DDP training we only need to sync gradients at the last micro step.
+                # the official way to do this is with model.no_sync() context manager, but
+                # using too many context managers may bloat the code and forces us to repeat code
+                # looking at the source of that context manager, it just toggles this variable
+                model.require_backward_grad_sync = is_optimizer_step
+
+            with autocast_ctx, op_verifier_ctx(step) as verifier:
+                model_outputs = model(**batch)
+                loss = model_outputs.loss  # Forward call
+                if (batch["labels"] != -100).sum() == 0:
+                    loss = loss.nan_to_num(nan=0.0)
+                    num_dummy_samples += train_config.train_batch_size
                 else:
-                    model_outputs = model(**batch)
-                    loss = model_outputs.loss  # Forward call
-                    if (batch["labels"] != -100).sum() == 0:
-                        loss = loss.nan_to_num(nan=0.0)
-                        num_dummy_samples += train_config.train_batch_size
-                    else:
-                        num_dummy_samples_per_batch = (
-                            (torch.sum(batch["labels"] == -100, dim=1) == batch["labels"].shape[1]).sum().item()
-                        )
-                        if num_dummy_samples_per_batch > 0:
-                            num_dummy_samples += num_dummy_samples_per_batch
-                            loss = loss * train_config.train_batch_size / num_dummy_samples_per_batch
+                    num_dummy_samples_per_batch = (
+                        (torch.sum(batch["labels"] == -100, dim=1) == batch["labels"].shape[1]).sum().item()
+                    )
+                    if num_dummy_samples_per_batch > 0:
+                        num_dummy_samples += num_dummy_samples_per_batch
+                        loss = loss * train_config.train_batch_size / num_dummy_samples_per_batch
 
-                if train_config.task_type == "seq_classification":
-                    logits = model_outputs.logits
-                    labels = batch["labels"][:, 0]
-                    preds = torch.nn.functional.softmax(logits, dim=-1)
-                    acc_helper.forward(preds, labels)
+                if train_config.task_type == "seq_classification":
+                    logits = model_outputs.logits
+                    labels = batch["labels"][:, 0]
+                    preds = torch.nn.functional.softmax(logits, dim=-1)
+                    acc_helper.forward(preds, labels)
+                if train_config.opByOpVerifier:
+                    logger.info("Mismatches detected:", verifier.get_perop_mismatch_count())
 
             total_loss += loss.detach().float()
             if is_rank_zero():
@@ -258,7 +239,7 @@ def train(
             else:
                 loss.backward()  # backward pass
 
-            if (step + 1) % train_config.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+            if is_optimizer_step:
                 if train_config.grad_scaler:
                     scaler.step(optimizer)
                     scaler.update()
@@ -440,6 +421,7 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
     device_type = torch.device(device).type
 
     num_dummy_samples = 0
+    autocast_ctx = get_autocast_ctx(train_config.use_autocast, device_type, dtype=torch.float16)
     for step, batch in enumerate(tqdm(eval_dataloader, colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
         # stop when the maximum number of eval steps is reached
         if train_config.max_eval_step > 0 and step > train_config.max_eval_step:
@@ -450,11 +432,7 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
         # Ensure no gradients are computed for this scope to save memory
         with torch.no_grad():
             # Forward pass and compute loss
-            with (
-                torch.autocast(device_type=device_type, dtype=torch.float16)
-                if train_config.use_autocast
-                else nullcontext()
-            ):
+            with autocast_ctx:
                 outputs = model(**batch)
                 loss = outputs.loss
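The `is_optimizer_step` flag also replaces DDP's `model.no_sync()` wrapper: as the nanoGPT-derived comment notes, that context manager only toggles `require_backward_grad_sync`, so setting the attribute directly defers gradient all-reduce until the last micro-step. A self-contained sketch of the pattern, assuming an already-initialized process group and a DDP-wrapped `model`; the division by `grad_accum_steps` is the conventional accumulation scaling, not this repo's exact normalization:

```python
from torch.nn.parallel import DistributedDataParallel as DDP


def run_micro_steps(model: DDP, dataloader, optimizer, grad_accum_steps: int = 4):
    for step, batch in enumerate(dataloader):
        is_optimizer_step = (step + 1) % grad_accum_steps == 0 or step == len(dataloader) - 1
        # no_sync() just flips this attribute; setting it directly skips the
        # gradient all-reduce on every micro-step except the last one.
        model.require_backward_grad_sync = is_optimizer_step
        loss = model(**batch).loss / grad_accum_steps
        loss.backward()
        if is_optimizer_step:
            optimizer.step()
            optimizer.zero_grad()
```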
QEfficient/generation/text_generation_inference.py

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ def __repr__(self):
         return f"Average Prefill time a.k.a TTFT is= {round(self.perf_metrics.prefill_time, 2)} sec\
         \nDecode is= {round(self.perf_metrics.decode_perf * self.batch_size, 2)} tokens/sec\
         \nTotal is= {round(self.perf_metrics.total_perf * self.batch_size, 2)} tokens/sec\
-        \nTotal (E2E) inference time is= {round(self.perf_metrics.total_time, 2)} tokens/sec"
+        \nTotal (E2E) inference time is= {round(self.perf_metrics.total_time, 2)} sec"
 
 
 @dataclass

QEfficient/utils/generate_qnn_network_specialization_config.py

Lines changed: 1 addition & 1 deletion
@@ -166,8 +166,8 @@ def generate_data_format_config(
     for output in onnx_model.graph.output:
         if "past_key" in output.name or "past_value" in output.name:
             kv_nodes.append(output.name)
-    kv_overrides = {}
 
+    kv_overrides = {}
     kv_overrides["graphs"] = [
         {
             "graph_name": model_dlc_name + "_configuration_1",

docs/source/quick_start.md

Lines changed: 23 additions & 1 deletion
@@ -94,7 +94,7 @@ python -m QEfficient.cloud.execute --model_name gpt2 --qpc_path qeff_models/gpt2
 You can run the finetune with set of predefined existing datasets on QAIC using the eager pipeline
 
 ```bash
-python -m QEfficient.cloud.finetune --device qaic:0 --use-peft --output_dir ./meta-sam --num_epochs 2 --context_length 256
+python -m QEfficient.cloud.finetune --device qaic:0 --use-peft --output_dir ./meta-sam --num_epochs 2 --context_length 256
 ```
 For more details on finetune, checkout the subsection.
 
@@ -138,6 +138,28 @@ Users can compile a model with QNN SDK by following the steps below:
 * Enabled QNN by passing enable_qnn flag, add --enable_qnn in the cli command.
 * An optional config file can be passed to override the default parameters.
 
+**Default Parameters**
+
+QNN Converter Stage:
+
+"--float_bias_bitwidth 32 --float_bitwidth 16 --preserve_io_datatype --onnx_skip_simplification --target_backend AIC"
+
+QNN Context Binary Stage:
+
+LOG_LEVEL = "error"
+COMPILER_COMPILATION_TARGET = "hardware"
+COMPILER_CONVERT_TO_FP16 = True
+COMPILER_DO_DDR_TO_MULTICAST = True
+COMPILER_HARDWARE_VERSION = "2.0"
+COMPILER_PERF_WARNINGS = False
+COMPILER_PRINT_DDR_STATS = False
+COMPILER_PRINT_PERF_METRICS = False
+COMPILER_RETAINED_STATE = True
+COMPILER_STAT_LEVEL = 10
+COMPILER_STATS_BATCH_SIZE = 1
+COMPILER_TIME_PASSES = False
+
+
 **CLI Inference Command**
 
 Without QNN Config
