 import sys
 from datetime import datetime
 from pathlib import Path
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional

 import torch
 import torch.nn.functional as F
 from torch import Tensor
 from torch import nn
 from torch.utils.tensorboard import SummaryWriter
-from tqdm import tqdm
-from tqdm import trange
 from transformers import AutoModelForCausalLM
 from transformers import AutoTokenizer
 from whowhatbench import TextEvaluator

 import nncf
+from nncf.common.logging.track_progress import track
 from nncf.data.dataset import Dataset
 from nncf.parameters import CompressionFormat
 from nncf.parameters import CompressWeightsMode
 from nncf.parameters import StripFormat
+from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
 from nncf.quantization.quantize_model import compress_weights
 from nncf.torch.model_creation import load_from_config
 from nncf.torch.quantization.layers import AsymmetricLoraQuantizer
 from nncf.torch.quantization.layers import SymmetricLoraQuantizer


-def get_wikitext2(nsamples: int, seqlen: int, tokenizer: Any, device: torch.device) -> List[Tensor]:
+def get_wikitext2(num_samples: int, seqlen: int, tokenizer: Any, device: torch.device) -> List[Tensor]:
     """
     Loads and processes the Wikitext-2 dataset for training.

-    :param nsamples: Number of samples to generate.
+    :param num_samples: Number of samples to generate.
     :param seqlen: Sequence length for each sample.
     :param tokenizer: Tokenizer to encode the text.
     :param device: Device to move the tensors to (e.g., 'cpu' or 'cuda').
     :return: A list of tensors containing the tokenized text samples.
     """
     traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
-    limit = nsamples * seqlen // 4  # ~1k for 128 samples with seqlen=32 to be aligned with optimum
+    limit = num_samples * seqlen // 4  # ~1k for 128 samples with seqlen=32 to be aligned with optimum
     text = "".join(["\n" if s == "" else s for s in traindata["text"][:limit]])
     trainenc = tokenizer(text, return_tensors="pt")
     trainloader = []
-    for _ in range(nsamples):
+    for _ in range(num_samples):
         # Crop a sequence of tokens of length seqlen starting at a random position
         i = torch.randint(0, trainenc.input_ids.shape[1] - seqlen - 1, (1,)).item()
         j = i + seqlen
@@ -66,7 +66,7 @@ def get_wikitext2(nsamples: int, seqlen: int, tokenizer: Any, device: torch.devi


 @torch.no_grad()
-def save_wwb_ref(model: str, tokenizer: Any, wwb_ref_file: Path) -> None:
+def save_wwb_ref(model: str, tokenizer: Any, wwb_ref_file: Path, num_samples: Optional[int] = None) -> None:
     """
     Save the reference answers for the WWB (WhoWhatBenchmark) evaluation.

@@ -76,12 +76,14 @@ def save_wwb_ref(model: str, tokenizer: Any, wwb_ref_file: Path) -> None:
     """
     if not wwb_ref_file.exists():
         print("#" * 50 + " Collect reference answers for WWB " + "#" * 50)
-        wwb_eval = TextEvaluator(base_model=model, tokenizer=tokenizer, use_chat_template=True)
+        wwb_eval = TextEvaluator(base_model=model, tokenizer=tokenizer, use_chat_template=True, num_samples=num_samples)
         wwb_eval.dump_gt(str(wwb_ref_file))
     torch.cuda.empty_cache()


-def measure_similarity(model_for_eval: OVModelForCausalLM, tokenizer: Any, wwb_ref_file: Path) -> float:
+def measure_similarity(
+    model_for_eval: OVModelForCausalLM, tokenizer: Any, wwb_ref_file: Path, num_samples: Optional[int] = None
+) -> float:
     """
     Measures the similarity of a model's output to a reference outputs from a given file using WWB evaluation.

@@ -92,7 +94,11 @@ def measure_similarity(model_for_eval: OVModelForCausalLM, tokenizer: Any, wwb_r
     """
     print("#" * 50 + " Evaluate via WWB " + "#" * 50)
     wwb_eval = TextEvaluator(
-        tokenizer=tokenizer, gt_data=wwb_ref_file, test_data=str(wwb_ref_file), use_chat_template=True
+        tokenizer=tokenizer,
+        gt_data=wwb_ref_file,
+        test_data=str(wwb_ref_file),
+        use_chat_template=True,
+        num_samples=num_samples,
     )
     _, all_metrics = wwb_eval.score(model_for_eval)
     return float(all_metrics["similarity"].iloc[0])
@@ -108,8 +114,8 @@ def calc_hiddens(model: nn.Module, dataloader: List[Tensor]) -> List[Tensor]:
     :return: A list of hidden states for each input in the dataloader.
     """
     orig_hiddens = []
-    for i in trange(len(dataloader), total=len(dataloader), desc="Calculating original hiddens", leave=False):
-        model_input = get_model_input(dataloader[i])
+    for data in track(dataloader, description="Calculating original hiddens"):
+        model_input = get_model_input(data)
         orig_hiddens.append(model.model(**model_input).last_hidden_state)
     torch.cuda.empty_cache()
     return orig_hiddens
@@ -260,10 +266,12 @@ def get_argument_parser() -> argparse.ArgumentParser:
         help="Whether to start from previously saved checkpoint. If not specified or checkpoint does not exist, "
         "start from scratch by post-training weight compression initialization.",
     )
+    parser.add_argument("--lora_rank", type=int, default=256, help="Rank of lora adapters")

     # Data params
-    parser.add_argument("--nsamples", type=int, default=1024, help="Number of training samples")
+    parser.add_argument("--num_train_samples", type=int, default=1024, help="Number of training samples")
     parser.add_argument("--seqlen", type=int, default=1024, help="Calibration data context length.")
+    parser.add_argument("--num_val_samples", type=int, default=None, help="Number of validation samples for WWB.")

     # Training params
     parser.add_argument(
@@ -286,7 +294,7 @@ def get_argument_parser() -> argparse.ArgumentParser:

 def main(argv) -> float:
     """
-    Fine-tunes the specified model and returns the best validation similarity score.
+    Fine-tunes the specified model and returns the difference between initial and best validation similarity scores.
     """
     parser = get_argument_parser()
     args = parser.parse_args(argv)
@@ -295,7 +303,10 @@ def main(argv) -> float:
     device = "cuda"
     torch_dtype = torch.bfloat16
     compression_config = dict(
-        mode=CompressWeightsMode.INT4_ASYM, group_size=64, compression_format=CompressionFormat.FQ_LORA
+        mode=CompressWeightsMode.INT4_ASYM,
+        group_size=64,
+        compression_format=CompressionFormat.FQ_LORA,
+        advanced_parameters=AdvancedCompressionParameters(lora_adapter_rank=args.lora_rank),
     )

     # Configure output and log files.
@@ -320,11 +331,13 @@ def main(argv) -> float:
     # computed by for data generated by two models, original floating-point one and optimized.
     # TODO: (nlyalyus) Use original model for collecting reference, once the bug in WWB resolved.
     wwb_ref_model = AutoModelForCausalLM.from_pretrained(args.pretrained, torch_dtype=torch_dtype, device_map="cpu")
-    save_wwb_ref(wwb_ref_model, tokenizer, wwb_ref_file)
+    save_wwb_ref(wwb_ref_model, tokenizer, wwb_ref_file, args.num_val_samples)
    del wwb_ref_model

     # Prepare training data and pre-compute hiddens of teacher model for distillation loss.
-    train_loader = get_wikitext2(nsamples=args.nsamples, seqlen=args.seqlen, tokenizer=tokenizer, device=device)
+    train_loader = get_wikitext2(
+        num_samples=args.num_train_samples, seqlen=args.seqlen, tokenizer=tokenizer, device=device
+    )
     orig_hiddens = calc_hiddens(model, train_loader)

     # Create or load model to tune with Fake Quantizers and absorbable LoRA adapters.
@@ -341,9 +354,11 @@ def main(argv) -> float:

     # Convert torch checkpoint to an OpenVINO model and evaluate it via WWB.
     model_for_eval = export_to_openvino(args.pretrained, train_loader[0], ckpt_file, last_dir)
-    best_similarity = measure_similarity(model_for_eval, tokenizer, wwb_ref_file)
-    tb.add_scalar("similarity", best_similarity, 0)
-    print(f"Initial WWB similarity= {best_similarity:.4f}")
+    initial_similarity = best_similarity = measure_similarity(
+        model_for_eval, tokenizer, wwb_ref_file, args.num_val_samples
+    )
+    tb.add_scalar("similarity", initial_similarity, 0)
+    print(f"Initial WWB similarity= {initial_similarity:.4f}")

     # Run tuning with distillation loss and validation on WWB after each epoch.
     grad_accumulation_steps = args.batch_size // args.microbatch_size
@@ -354,7 +369,7 @@
     loss_numerator = grad_steps = total_microbatches = 0
     for epoch in range(args.epochs):
         batch_indices_epoch = torch.randperm(num_samples)[:epoch_samples].chunk(microbatches_per_epoch)
-        for indices in tqdm(batch_indices_epoch, desc=f"Train epoch {epoch}", leave=False):
+        for indices in track(batch_indices_epoch, description=f"Train epoch {epoch}"):
             indices = indices.tolist()
             total_microbatches += 1

@@ -393,16 +408,16 @@ def form_batch(inputs: List[Tensor], model_input: bool):
         # Save the best checkpoint and OpenVINO IR for the highest similarity score obtained from WWB.
         save_checkpoint(model, ckpt_file)
         model_for_eval = export_to_openvino(args.pretrained, train_loader[0], ckpt_file, last_dir)
-        similarity = measure_similarity(model_for_eval, tokenizer, wwb_ref_file)
-        print(f"[Epoch {epoch}], WWB similarity = {similarity:.4f}")
+        similarity = measure_similarity(model_for_eval, tokenizer, wwb_ref_file, args.num_val_samples)
+        print(f"[Epoch {epoch + 1}], WWB similarity = {similarity:.4f}")
         tb.add_scalar("similarity", similarity, total_microbatches)
         if similarity > best_similarity:
             print(f"New best WWB similarity = {similarity:.4f}")
             best_similarity = similarity
             shutil.copytree(last_dir, best_dir, dirs_exist_ok=True)

-    print(f"The finetuned OV model with the best similarity={best_similarity} saved to: {best_dir}")
-    return best_similarity
+    print(f"The finetuned OV model with the best similarity={best_similarity:.4f} saved to: {best_dir}")
+    return best_similarity - initial_similarity


 if __name__ == "__main__":
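A minimal usage sketch (not part of the diff) of how the updated example might be driven with the renamed and newly added arguments. Only --lora_rank, --num_train_samples, --num_val_samples, and --seqlen come from the diff above; the module name "main", the --pretrained flag name, and the placeholder model path are assumptions.

# Hypothetical driver for the updated example; assumes the script above is saved as main.py.
from main import main

# main() now returns the WWB similarity gain of the tuned model over the
# post-training weight-compression initialization (best - initial).
similarity_gain = main(
    [
        "--pretrained", "<model_id_or_path>",  # placeholder; replace with a real HF id or local path
        "--lora_rank", "16",  # new argument: rank of the absorbable LoRA adapters
        "--num_train_samples", "256",  # renamed from --nsamples
        "--num_val_samples", "100",  # new argument: limits the number of WWB validation samples
        "--seqlen", "1024",
    ]
)
print(f"WWB similarity gain over initialization: {similarity_gain:.4f}")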