|
1 | 1 | import json, logging, torch, types
|
2 |
| -import nnsight |
3 | 2 | import numpy as np
|
4 | 3 | from collections import OrderedDict
|
5 | 4 | from typing import List, Optional, Tuple, Union, Dict, Any
|
|
27 | 26 | from transformers.utils import ModelOutput
|
28 | 27 | from tqdm import tqdm, trange
|
29 | 28 |
|
| 29 | +try: |
| 30 | + import nnsight |
| 31 | +except: |
| 32 | + print("nnsight is not detected. Please install via 'pip install nnsight' for nnsight backend.") |
| 33 | + |
| 34 | + |
30 | 35 | @dataclass
|
31 | 36 | class IntervenableModelOutput(ModelOutput):
|
32 | 37 | """
|
@@ -226,7 +231,7 @@ def __init__(self, config, model, backend, **kwargs):
|
226 | 231 | # cached swapped activations (hot)
|
227 | 232 | self.hot_activations = {}
|
228 | 233 |
|
229 |
| - self.aux_loss = [] |
| 234 | + self.full_intervention_outputs = [] |
230 | 235 |
|
231 | 236 | # temp fields should not be accessed outside
|
232 | 237 | self._batched_setter_activation_select = {}
|
@@ -1558,16 +1563,17 @@ def hook_callback(model, args, kwargs, output=None):
|
1558 | 1563 | else:
|
1559 | 1564 | if not isinstance(self.interventions[key][0], types.FunctionType):
|
1560 | 1565 | if intervention.is_source_constant:
|
1561 |
| - intervened_representation = do_intervention( |
| 1566 | + raw_intervened_representation = do_intervention( |
1562 | 1567 | selected_output,
|
1563 | 1568 | None,
|
1564 | 1569 | intervention,
|
1565 | 1570 | subspaces[key_i] if subspaces is not None else None,
|
1566 | 1571 | )
|
1567 |
| - if isinstance(intervened_representation, InterventionOutput): |
1568 |
| - if intervened_representation.loss is not None: |
1569 |
| - self.aux_loss.append(intervened_representation.loss) |
1570 |
| - intervened_representation = intervened_representation.output |
| 1572 | + if isinstance(raw_intervened_representation, InterventionOutput): |
| 1573 | + self.full_intervention_outputs.append(raw_intervened_representation) |
| 1574 | + intervened_representation = raw_intervened_representation.output |
| 1575 | + else: |
| 1576 | + intervened_representation = raw_intervened_representation |
1571 | 1577 | else:
|
1572 | 1578 | intervened_representation = do_intervention(
|
1573 | 1579 | selected_output,
|
@@ -1866,7 +1872,7 @@ def forward(
|
1866 | 1872 | if sources is not None and not isinstance(sources, list):
|
1867 | 1873 | sources = [sources]
|
1868 | 1874 |
|
1869 |
| - self.aux_loss.clear() |
| 1875 | + self.full_intervention_outputs.clear() |
1870 | 1876 |
|
1871 | 1877 | self._cleanup_states()
|
1872 | 1878 |
|
|
0 commit comments