File tree 2 files changed +6
-2
lines changed
2 files changed +6
-2
lines changed Original file line number Diff line number Diff line change @@ -1317,6 +1317,7 @@ def forward(
1317
1317
unit_locations : Optional [Dict ] = None ,
1318
1318
source_representations : Optional [Dict ] = None ,
1319
1319
subspaces : Optional [List ] = None ,
1320
+ labels : Optional [torch .LongTensor ] = None ,
1320
1321
output_original_output : Optional [bool ] = False ,
1321
1322
return_dict : Optional [bool ] = None ,
1322
1323
):
@@ -1438,7 +1439,10 @@ def forward(
1438
1439
)
1439
1440
1440
1441
# run intervened forward
1441
- counterfactual_outputs = self .model (** base )
1442
+ if labels is not None :
1443
+ counterfactual_outputs = self .model (** base , labels = labels )
1444
+ else :
1445
+ counterfactual_outputs = self .model (** base )
1442
1446
set_handlers_to_remove .remove ()
1443
1447
1444
1448
self ._output_validation ()
Original file line number Diff line number Diff line change 10
10
11
11
setup (
12
12
name = "pyvene" ,
13
- version = "0.0.8dev " ,
13
+ version = "0.0.8 " ,
14
14
description = "Use Activation Intervention to Interpret Causal Mechanism of Model" ,
15
15
long_description = long_description ,
16
16
long_description_content_type = 'text/markdown' ,
You can’t perform that action at this time.
0 commit comments