Skip to content

Commit d4ca094

Browse files
authored
Merge pull request #185 from stanfordnlp/zen/dependency_clean
[Minor] Update dependency
2 parents 4f70e10 + 4b14b6e commit d4ca094

File tree

3 files changed

+15
-10
lines changed

3 files changed

+15
-10
lines changed

pyvene/models/intervenable_base.py

+14-8
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import json, logging, torch, types
2-
import nnsight
32
import numpy as np
43
from collections import OrderedDict
54
from typing import List, Optional, Tuple, Union, Dict, Any
@@ -27,6 +26,12 @@
2726
from transformers.utils import ModelOutput
2827
from tqdm import tqdm, trange
2928

29+
try:
30+
import nnsight
31+
except:
32+
print("nnsight is not detected. Please install via 'pip install nnsight' for nnsight backend.")
33+
34+
3035
@dataclass
3136
class IntervenableModelOutput(ModelOutput):
3237
"""
@@ -226,7 +231,7 @@ def __init__(self, config, model, backend, **kwargs):
226231
# cached swapped activations (hot)
227232
self.hot_activations = {}
228233

229-
self.aux_loss = []
234+
self.full_intervention_outputs = []
230235

231236
# temp fields should not be accessed outside
232237
self._batched_setter_activation_select = {}
@@ -1558,16 +1563,17 @@ def hook_callback(model, args, kwargs, output=None):
15581563
else:
15591564
if not isinstance(self.interventions[key][0], types.FunctionType):
15601565
if intervention.is_source_constant:
1561-
intervened_representation = do_intervention(
1566+
raw_intervened_representation = do_intervention(
15621567
selected_output,
15631568
None,
15641569
intervention,
15651570
subspaces[key_i] if subspaces is not None else None,
15661571
)
1567-
if isinstance(intervened_representation, InterventionOutput):
1568-
if intervened_representation.loss is not None:
1569-
self.aux_loss.append(intervened_representation.loss)
1570-
intervened_representation = intervened_representation.output
1572+
if isinstance(raw_intervened_representation, InterventionOutput):
1573+
self.full_intervention_outputs.append(raw_intervened_representation)
1574+
intervened_representation = raw_intervened_representation.output
1575+
else:
1576+
intervened_representation = raw_intervened_representation
15711577
else:
15721578
intervened_representation = do_intervention(
15731579
selected_output,
@@ -1866,7 +1872,7 @@ def forward(
18661872
if sources is not None and not isinstance(sources, list):
18671873
sources = [sources]
18681874

1869-
self.aux_loss.clear()
1875+
self.full_intervention_outputs.clear()
18701876

18711877
self._cleanup_states()
18721878

requirements.txt

-1
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,3 @@ numpy>=1.23.5
1010
fsspec>=2023.6.0
1111
accelerate>=0.29.1
1212
sentencepiece>=0.1.96
13-
nnsight>=0.1.0

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
setup(
1212
name="pyvene",
13-
version="0.1.4",
13+
version="0.1.5",
1414
description="Use Activation Intervention to Interpret Causal Mechanism of Model",
1515
long_description=long_description,
1616
long_description_content_type='text/markdown',

0 commit comments

Comments
 (0)