From 406ddce390780a81771bc01a8018ff8b68a81175 Mon Sep 17 00:00:00 2001 From: Novak Boskov Date: Tue, 9 Sep 2025 11:30:31 +0200 Subject: [PATCH] relative_hidden_states in RepReadingPipeline.get_directions should involve T_f^+ and T_f^- --- repe/rep_reading_pipeline.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/repe/rep_reading_pipeline.py b/repe/rep_reading_pipeline.py index 16d5f47..2451d90 100644 --- a/repe/rep_reading_pipeline.py +++ b/repe/rep_reading_pipeline.py @@ -136,12 +136,19 @@ def get_directions( if direction_finder.needs_hiddens: # get raw hidden states for the train inputs hidden_states = self._batched_string_to_hiddens(train_inputs, rep_token, hidden_layers, batch_size, which_hidden_states, **tokenizer_args) - + # get differences between pairs relative_hidden_states = {k: np.copy(v) for k, v in hidden_states.items()} for layer in hidden_layers: for _ in range(n_difference): - relative_hidden_states[layer] = relative_hidden_states[layer][::2] - relative_hidden_states[layer][1::2] + positive_template_hidden, negative_template_hidden = [], [] + for hidden, train_label in zip(relative_hidden_states[layer], np.concatenate(train_labels)): + if train_label: + positive_template_hidden.append(hidden) + else: + negative_template_hidden.append(hidden) + + relative_hidden_states[layer] = np.array(positive_template_hidden) - np.array(negative_template_hidden) # get the directions direction_finder.directions = direction_finder.get_rep_directions(