@@ -87,18 +87,27 @@ def quantize(self):
 
         layer_groups = self._get_model_layer_groups()
         self._add_mods_to_model_size(self.excluded_mods)
+
+        print(f'Calibration set size: {self.n_samples}')
         calibration_set = self._get_calibration_set()
 
         # Run calibration set through model
         first_inputs, self.layer_args, self.layer_kwargs = self._gather_first_inputs(layer_groups, calibration_set)
 
-        self.model = self.model.to('cpu')
+        del calibration_set
+        gc.collect()
+        torch.cuda.empty_cache()
+
+
+        self.model.to('cpu')
         for layer_group, modules in layer_groups.items():
             self.inps = first_inputs[layer_group]
 
             # quantize layer-by-layer
             for i in tqdm(range(len(modules)), desc=f"Quantizing {layer_group}"):
 
+                pass
+
                 # move layer inputs to gpu
                 self.inps = self.inps.to(self.device)
 
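The `del` / `gc.collect()` / `torch.cuda.empty_cache()` sequence added above is the same teardown that the `clear_memory` helper used throughout this diff presumably wraps. A minimal sketch of such a helper, assuming it is a local utility (its definition is not shown in this commit):

```python
import gc
import torch

def clear_memory(tensor=None):
    """Hypothetical helper: drop a reference, collect garbage, and
    return cached CUDA blocks to the driver."""
    if tensor is not None:
        del tensor  # removes only this local reference; callers must drop theirs too
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
```

Note that `empty_cache()` can only release blocks with no live references, which is why the diff consistently moves tensors to the CPU (or `del`s them) before calling it.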
@@ -129,6 +138,9 @@ def quantize(self):
                 assert torch.all(scale)
                 self._apply_scales(scale, group['prev_op'], group['modules'], layer)
 
+                scale = scale.to('cpu')
+                clear_memory(scale)
+
 
                 # solve for and apply clipping
                 clips = self._search_best_clip(named_linears, linear_inputs, w_bits_dict)
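`_apply_scales` itself is not shown in this commit; the sketch below illustrates the standard AWQ-style scaling identity it presumably implements, where the producing op's output channels are divided by the scales and the consuming linears' input channels are multiplied by them, leaving the full-precision computation unchanged. All names here are assumptions:

```python
import torch

@torch.no_grad()
def apply_scales_sketch(prev_op: torch.nn.Linear, modules, scales: torch.Tensor):
    # scale the producer's output channels down...
    prev_op.weight.div_(scales.view(-1, 1))
    if prev_op.bias is not None:
        prev_op.bias.div_(scales.view(-1))
    # ...and each consumer's input channels up, so prev_op -> fc
    # composes to the same function in floating point
    for fc in modules:
        fc.weight.mul_(scales.view(1, -1))
```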
@@ -149,7 +161,14 @@ def quantize(self):
                 layer = layer.to('cpu')
                 clear_memory(layer)
 
-            modules = modules.to('cpu')
+            self.model.to('cpu')
+            gc.collect()
+            torch.cuda.empty_cache()
+            pass
+
+
+            clear_memory(first_inputs[layer_group])
+
 
 
 
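A quick way to verify that these per-group offloads actually lower VRAM pressure is to log allocator statistics before and after each `clear_memory` call; a small sketch using PyTorch's built-in counters:

```python
import torch

def report_vram(tag: str) -> None:
    # memory_allocated: bytes held by live tensors;
    # memory_reserved: that plus the allocator's cache
    alloc = torch.cuda.memory_allocated() / 2**30
    reserved = torch.cuda.memory_reserved() / 2**30
    print(f"[{tag}] allocated={alloc:.2f} GiB, reserved={reserved:.2f} GiB")

# e.g. report_vram(layer_group) before and after clear_memory(first_inputs[layer_group])
```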
@@ -217,6 +236,12 @@ def forward(self, *args, **kwargs):
         if isinstance(calibration_set, torch.Tensor):
             calibration_set = calibration_set.cpu()
             clear_memory(calibration_set)
+        else:
+            for key in calibration_set.keys():
+                calibration_set[key] = calibration_set[key].to('cpu')
+                clear_memory(calibration_set[key])
+
+        del calibration_set
 
         for _, modules in layer_groups.items():
             # restore proper module at beginning of layer group
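The tensor-or-mapping split matters because tokenized calibration data is usually a dict-like object (`input_ids`, `attention_mask`, ...) rather than a bare tensor. A tiny illustration with made-up shapes and vocabulary size:

```python
import torch

batch = {'input_ids': torch.randint(0, 32000, (4, 128)),
         'attention_mask': torch.ones(4, 128, dtype=torch.long)}

if isinstance(batch, torch.Tensor):   # raw token-id tensor
    batch = batch.cpu()
else:                                 # tokenizer-style mapping of tensors
    batch = {k: v.cpu() for k, v in batch.items()}
```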
@@ -247,8 +272,13 @@ def input_hook(module, input, output, module_name, inputs):
             )
 
             # compute next set of inputs, grabbing linear inputs through the hooks
-            self.inps = layer(self.inps, *self.layer_args[layer_group], **self.layer_kwargs[layer_group])
-            self.inps = self.inps[0].to('cpu')
+            # self.inps = layer(self.inps, *self.layer_args[layer_group], **self.layer_kwargs[layer_group])
+            out = layer(self.inps, *self.layer_args[layer_group], **self.layer_kwargs[layer_group])[0].to('cpu')
+            self.inps = self.inps.to('cpu')
+            clear_memory(self.inps)
+
+            # self.inps = self.inps[0].to('cpu')
+            self.inps = out
 
             # remove hooks from model
             for hook in hooks:
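The `input_hook` named in the hunk context follows PyTorch's forward-hook pattern: each linear's input is captured while the calibration activations flow through the layer, then the hooks are removed. A self-contained sketch of that pattern (the real hook also receives `module_name` and an `inputs` store, per the signature above):

```python
import torch

def attach_input_hooks(named_linears, store):
    """Register forward hooks that record each linear's input on the CPU."""
    hooks = []
    for name, module in named_linears.items():
        def input_hook(mod, inputs, output, name=name):
            store.setdefault(name, []).append(inputs[0].detach().cpu())
        hooks.append(module.register_forward_hook(input_hook))
    return hooks

# usage: hooks = attach_input_hooks(named_linears, {}); run the layer once;
# then: for hook in hooks: hook.remove()
```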
@@ -282,7 +312,7 @@ def _compute_scales(self, layer, prev_op, modules, inp, parent_module, layer_kwa
         clear_memory(W)
 
         # per channel mean of input (activation)
-        X_mean = inp.abs().view(-1, inp.shape[-1]).mean(0)
+        X_mean = inp.cpu().abs().view(-1, inp.shape[-1]).mean(0)
         X_mean = X_mean.view(-1)
 
         kwargs = sanitize_kwargs(layer_kwargs, parent_module)
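Moving `inp` to the CPU before the reduction avoids materializing the `|X|` temporary on the GPU at the cost of a slower mean. A quick shape check with made-up dimensions shows what the reduction produces:

```python
import torch

inp = torch.randn(4, 512, 4096)                            # (batch, seq, hidden), invented sizes
X_mean = inp.cpu().abs().view(-1, inp.shape[-1]).mean(0)   # flatten tokens, average per channel
assert X_mean.shape == (4096,)                             # one statistic per input channel
```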
@@ -354,6 +384,16 @@ def _compute_scales(self, layer, prev_op, modules, inp, parent_module, layer_kwa
         assert best_ratio != -1, "best scales ratio never set"
         assert torch.isnan(best_scales).sum() == 0, best_scales
 
+        # NOTE: trying this to save memory...
+        inp = inp.to('cpu')
+        clear_memory(inp)
+        fp_output = fp_output.to('cpu')
+        clear_memory(fp_output)
+        q_output = q_output.to('cpu')
+        clear_memory(q_output)
+        scales_view = scales_view.to('cpu')
+        clear_memory(scales_view)
+
         return best_scales.detach().cpu()
 
 
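The `best_ratio` / `best_scales` asserts imply a grid search over a mixing ratio, as in the usual AWQ formulation `s = X_mean^ratio / W_mean^(1-ratio)`. A hedged sketch of that loop; `W_mean`, `eval_loss`, and the normalization are assumptions, since the search body sits outside this diff:

```python
import torch

def search_scales_sketch(X_mean, W_mean, eval_loss, n_grid=20):
    best_loss, best_ratio, best_scales = float('inf'), -1, None
    for i in range(n_grid):
        ratio = i / n_grid
        scales = (X_mean.pow(ratio) / W_mean.pow(1 - ratio)).clamp(min=1e-4)
        scales = scales / (scales.max() * scales.min()).sqrt()  # keep magnitudes centered
        loss = eval_loss(scales)   # e.g. MSE between fp_output and q_output
        if loss < best_loss:
            best_loss, best_ratio, best_scales = loss, ratio, scales.clone()
    assert best_ratio != -1, "best scales ratio never set"
    return best_scales.detach().cpu()
```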