|
16 | 16 | CPUOffloadEnabled = False |
17 | 17 |
|
18 | 18 |
|
19 | | -def set_offloading_param(tensor, param_name, value): |
20 | | -    """Set the type of the offloading needed for a tensor.""" |
21 | | -    assert param_name in ["weight_offloading", "activation_offloading"] |
22 | | -    if tensor is None: |
23 | | -        return |
24 | | -    if type(tensor) in [torch.Tensor, torch.nn.Parameter]: |
25 | | -        setattr(tensor, param_name, value) |
26 | | -    else: |
27 | | -        data_tensors = tensor.get_data_tensors() |
28 | | -        for tensor in data_tensors: |
29 | | -            if tensor is not None: |
30 | | -                setattr(tensor, param_name, value) |
| 19 | +def mark_activation_offload(*tensors): |
| 20 | +    """Mark the given tensors for activation offloading.""" |
| 21 | +    for tensor in tensors: |
| 22 | +        if tensor is None: |
| 23 | +            continue |
| 24 | +        if type(tensor) in [torch.Tensor, torch.nn.Parameter]: |
| 25 | +            tensor.activation_offloading = True |
| 26 | +        else: |
| 27 | +            data_tensors = tensor.get_data_tensors() |
| 28 | +            for data_tensor in data_tensors: |
| 29 | +                if data_tensor is not None: |
| 30 | +                    data_tensor.activation_offloading = True |
| 31 | +                    # This is a hack to force-clear the tensor after it is |
| 32 | +                    # offloaded; it is needed because *TensorBase classes are saved |
| 33 | +                    # in the ctx and hold references to their data tensors. |
| 34 | +                    data_tensor.needs_force_clear = True |
31 | 35 |
|
32 | 36 |
|
33 | 37 | def is_cpu_offload_enabled() -> bool: |
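For reference, a minimal sketch of how a caller might use the renamed helper, assuming `mark_activation_offload` is imported from this module; the tensors here are hypothetical stand-ins for a layer's saved activations, not code from this change:

```python
import torch

# Hypothetical stand-ins for activations a layer is about to save for backward.
inp = torch.randn(8, 1024)
ln_out = torch.nn.functional.layer_norm(inp, (1024,))

# Mark both so the offload handler's checker (defined later in this file)
# selects them; plain tensors simply get .activation_offloading = True.
mark_activation_offload(inp, ln_out)
assert getattr(ln_out, "activation_offloading", False)
```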
@@ -459,8 +463,15 @@ def synchronize_on_group_commit_forward(self, current_group): |
459 | 463 |             torch.cuda.current_stream().wait_stream(self.d2h_stream) |
460 | 464 |

461 | 465 |             # Time to free the activation memory after usage |
462 | | -            for tensor_tag, _ in self.tensor_tag_to_buf.items(): |
| 466 | +            for tensor_tag, tensor_buf in self.tensor_tag_to_buf.items(): |
463 | 467 |                 if tensor_tag[0] == self.offloaded_group_count: |
| 468 | +                    if hasattr(tensor_buf, "needs_force_clear"): |
| 469 | +                        # The activation buffer must be cleared explicitly because |
| 470 | +                        # references to it can persist elsewhere. Float8TensorBase, for |
| 471 | +                        # example, is saved directly inside the ctx while its internal |
| 472 | +                        # tensors are saved via save_for_backward. |
| 473 | +                        tensor_buf.data = torch.Tensor() |
| 474 | +                    # Release the pointer to the tensor |
464 | 475 |                     self.tensor_tag_to_buf[tensor_tag] = None |
465 | 476 |
|
466 | 477 |             # Time to offload the next group |
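Why the force-clear works: reassigning a tensor's `.data` swaps out its underlying storage in place, so the memory is released even while other objects still hold the tensor. A standalone sketch in plain PyTorch, with a hypothetical wrapper class standing in for a `*TensorBase` object saved in the ctx:

```python
import torch

class CtxSavedWrapper:
    """Hypothetical stand-in for a *TensorBase object lingering in an autograd ctx."""
    def __init__(self, data):
        self.data_tensor = data  # this reference keeps the buffer reachable

buf = torch.empty(1024, 1024)   # ~4 MiB activation buffer
wrapper = CtxSavedWrapper(buf)  # lingering reference, as described above

# Dropping the handler's pointer alone would not free the memory, because
# `wrapper` still reaches the tensor. Swapping .data releases the storage
# while the Python objects stay alive.
buf.data = torch.Tensor()
assert buf.numel() == 0
assert wrapper.data_tensor.numel() == 0  # same object, now empty
```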
@@ -538,7 +549,7 @@ def get_cpu_offload_context( |
538 | 549 |     num_layers: int = 1, |
539 | 550 |     model_layers: int = 1, |
540 | 551 |     offload_activations: bool = True, |
541 | | -    offload_weights: bool = True, |
| 552 | +    offload_weights: bool = False, |
542 | 553 | ): |
543 | 554 |     """ |
544 | 555 |     This function returns the CPU Offload context and the synchronizer function that needs to be |
@@ -570,28 +581,30 @@ def get_cpu_offload_context( |
570 | 581 |
|
571 | 582 |     """ |
572 | 583 |
|
573 | | -    def tensor_need_offloading_checker_activations(tensor): |
574 | | -        return hasattr(tensor, "activation_offloading") |
575 | | - |
576 | | -    # This includes the Gradient Accumulation Buffer |
577 | | -    def tensor_need_offloading_checker_weights(tensor): |
578 | | -        return hasattr(tensor, "weight_offloading") |
579 | | - |
580 | | -    def tensor_need_offloading_checker_all(tensor): |
581 | | -        return hasattr(tensor, "activation_offloading") or hasattr(tensor, "weight_offloading") |
582 | | - |
583 | | -    if offload_activations and offload_weights: |
584 | | -        tensor_need_offloading_checker = tensor_need_offloading_checker_all |
585 | | -    elif offload_activations: |
586 | | -        tensor_need_offloading_checker = tensor_need_offloading_checker_activations |
587 | | -    elif offload_weights: |
588 | | -        tensor_need_offloading_checker = tensor_need_offloading_checker_weights |
589 | | -    else: |
590 | | -        raise ValueError( |
591 | | -            "CPU Offloading is enabled while it is not " |
592 | | -            "mentioned what to offload (weights/activations)" |
593 | | -        ) |
| 584 | +    if not offload_weights and not offload_activations: |
| 585 | +        raise ValueError( |
| 586 | +            "CPU offloading is enabled, but neither activations nor weights " |
| 587 | +            "were selected for offloading" |
| 588 | +        ) |
594 | 589 |
|
| 590 | +    if offload_weights: |
| 591 | +        import warnings |
| 592 | + |
| 593 | +        warnings.warn( |
| 594 | +            "Offloading weights is deprecated. Using offload_weights=True does not have any" |
| 595 | +            " effect.", |
| 596 | +            DeprecationWarning, |
| 597 | +        ) |
| 598 | + |
| 599 | +        # If only weight offloading was requested, keep backward compatibility |
| 600 | +        # by returning a no-op context and synchronizer. |
| 601 | +        if not offload_activations: |
| 602 | +            return nullcontext(), lambda x: x |
| 603 | + |
| 604 | +    def tensor_need_offloading_checker_activations(tensor): |
| 605 | +        return hasattr(tensor, "activation_offloading") |
| 606 | + |
| 607 | +    tensor_need_offloading_checker = tensor_need_offloading_checker_activations |
| 607 | + |
595 | 608 |     cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler( |
596 | 609 |         num_offload_group=num_layers, |
597 | 610 |         num_model_group=model_layers, |
|
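A usage sketch of the updated entry point; the `enabled` argument, the import path, and the training-loop names (`layers`, `hidden`) are assumptions for illustration, not part of this diff:

```python
from transformer_engine.pytorch import get_cpu_offload_context

# Offload activations for 2 layers at a time in a 12-layer model.
# offload_weights=True now only emits a DeprecationWarning, and combining it
# with offload_activations=False returns a no-op context and synchronizer.
offload_ctx, sync_fn = get_cpu_offload_context(
    enabled=True,
    num_layers=2,
    model_layers=12,
    offload_activations=True,
    offload_weights=False,
)

for layer in layers:            # hypothetical list of transformer layers
    with offload_ctx:
        hidden = layer(hidden)  # activations saved under the offload context
    hidden = sync_fn(hidden)    # commit the group and sync the d2h copies
```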