Skip to content

Commit 37782e6

Browse files
authored
Merge pull request #136 from stanfordnlp/zen/add_labels
[Minor] Accepting `labels` field for loss calculation
2 parents db7c676 + 9c5a2ff commit 37782e6

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

pyvene/models/intervenable_base.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1317,6 +1317,7 @@ def forward(
13171317
unit_locations: Optional[Dict] = None,
13181318
source_representations: Optional[Dict] = None,
13191319
subspaces: Optional[List] = None,
1320+
labels: Optional[torch.LongTensor] = None,
13201321
output_original_output: Optional[bool] = False,
13211322
return_dict: Optional[bool] = None,
13221323
):
@@ -1438,7 +1439,10 @@ def forward(
14381439
)
14391440

14401441
# run intervened forward
1441-
counterfactual_outputs = self.model(**base)
1442+
if labels is not None:
1443+
counterfactual_outputs = self.model(**base, labels=labels)
1444+
else:
1445+
counterfactual_outputs = self.model(**base)
14421446
set_handlers_to_remove.remove()
14431447

14441448
self._output_validation()

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
setup(
1212
name="pyvene",
13-
version="0.0.8dev",
13+
version="0.0.8",
1414
description="Use Activation Intervention to Interpret Causal Mechanism of Model",
1515
long_description=long_description,
1616
long_description_content_type='text/markdown',

0 commit comments

Comments
 (0)