Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 123 additions & 0 deletions tests/test_rubric_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,3 +416,126 @@ def reward_func(completion, parser, answer, **_):

assert score.reward == 1.0
assert recorded_parsers == [xml_parser]

def test_rubric_group_weights_validation(self):
    """A weights list whose length differs from the rubric count is rejected."""

    def reward_one(completion, **kwargs):
        return 1.0

    first_rubric = Rubric(funcs=[reward_one], weights=[1.0])
    second_rubric = Rubric(funcs=[reward_one], weights=[1.0])

    # Three weights for two rubrics must raise with an explicit count message.
    expected_message = (
        "Number of weights must match number of rubrics: "
        "got 3 weights for 2 rubrics"
    )
    with pytest.raises(ValueError, match=expected_message):
        RubricGroup(rubrics=[first_rubric, second_rubric], weights=[1.0, 0.5, 0.3])

@pytest.mark.asyncio
async def test_rubric_group_with_custom_weights_score_rollout(self):
    """score_rollout must scale each rubric's reward and metrics by its weight."""

    def func1(completion, **kwargs):
        return 1.0

    def func2(completion, **kwargs):
        return 0.5

    rubric1 = Rubric(funcs=[func1], weights=[1.0])
    rubric2 = Rubric(funcs=[func2], weights=[1.0])

    # Non-uniform group weights: 0.7 for rubric1, 0.3 for rubric2.
    group = RubricGroup(rubrics=[rubric1, rubric2], weights=[0.7, 0.3])

    prompt = [{"role": "user", "content": "Test prompt"}]
    completion = [{"role": "assistant", "content": "Test completion"}]
    state = {"timing": {"generation_ms": 0.0, "scoring_ms": 0.0, "total_ms": 0.0}}

    score = await group.score_rollout(prompt, completion, "answer", state)

    # Reward is the weighted sum (1.0 * 0.7) + (0.5 * 0.3) = 0.85.
    # pytest.approx avoids brittle exact equality on accumulated floats.
    assert score.reward == pytest.approx(0.85)
    # Per-function metrics are scaled by the owning rubric's weight.
    assert score.metrics["func1"] == pytest.approx(0.7)  # 1.0 * 0.7
    assert score.metrics["func2"] == pytest.approx(0.15)  # 0.5 * 0.3

@pytest.mark.asyncio
async def test_rubric_group_with_custom_weights_score_rollouts(self):
    """Batch scoring must apply the group's per-rubric weights element-wise."""

    def func1(completion, **kwargs):
        return 1.0

    def func2(completion, **kwargs):
        return 0.5

    rubric1 = Rubric(funcs=[func1], weights=[1.0])
    rubric2 = Rubric(funcs=[func2], weights=[1.0])

    # Non-uniform group weights: 0.6 for rubric1, 0.4 for rubric2.
    group = RubricGroup(rubrics=[rubric1, rubric2], weights=[0.6, 0.4])

    prompts = [
        [{"role": "user", "content": "Test 1"}],
        [{"role": "user", "content": "Test 2"}],
    ]
    completions = [
        [{"role": "assistant", "content": "Response 1"}],
        [{"role": "assistant", "content": "Response 2"}],
    ]
    answers = ["answer1", "answer2"]
    tasks = ["default", "default"]
    infos = [{}, {}]
    states = [
        {"timing": {"generation_ms": 0.0, "scoring_ms": 0.0, "total_ms": 0.0}},
        {"timing": {"generation_ms": 0.0, "scoring_ms": 0.0, "total_ms": 0.0}},
    ]

    scores = await group.score_rollouts(
        prompts=prompts,
        completions=completions,
        answers=answers,
        states=states,
        tasks=tasks,
        infos=infos,
    )

    # Every rollout's reward is (1.0 * 0.6) + (0.5 * 0.4) = 0.8.
    # pytest.approx avoids brittle exact equality on accumulated floats.
    assert len(scores.reward) == 2
    assert scores.reward[0] == pytest.approx(0.8)
    assert scores.reward[1] == pytest.approx(0.8)
    # Metric lists are weighted element-wise per rubric.
    assert scores.metrics["func1"][0] == pytest.approx(0.6)  # 1.0 * 0.6
    assert scores.metrics["func1"][1] == pytest.approx(0.6)
    assert scores.metrics["func2"][0] == pytest.approx(0.2)  # 0.5 * 0.4
    assert scores.metrics["func2"][1] == pytest.approx(0.2)

@pytest.mark.asyncio
async def test_rubric_group_default_weights_behavior(self):
    """Omitting weights falls back to unit weights, i.e. the legacy sum behavior."""

    def func1(completion, **kwargs):
        return 1.0

    def func2(completion, **kwargs):
        return 0.5

    # No explicit weights: each rubric implicitly gets weight 1.0.
    group = RubricGroup(
        rubrics=[
            Rubric(funcs=[func1], weights=[1.0]),
            Rubric(funcs=[func2], weights=[1.0]),
        ]
    )

    chat_prompt = [{"role": "user", "content": "Test prompt"}]
    chat_completion = [{"role": "assistant", "content": "Test completion"}]
    rollout_state = {
        "timing": {"generation_ms": 0.0, "scoring_ms": 0.0, "total_ms": 0.0}
    }

    result = await group.score_rollout(
        chat_prompt, chat_completion, "answer", rollout_state
    )

    # (1.0 * 1.0) + (0.5 * 1.0) = 1.5 — the plain sum of both rubrics.
    assert result.reward == 1.5
    # Metrics pass through unscaled under unit weights.
    assert result.metrics["func1"] == 1.0
    assert result.metrics["func2"] == 0.5
36 changes: 25 additions & 11 deletions verifiers/rubrics/rubric_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,19 @@ class RubricGroup(Rubric):
Class for aggregating multiple rubrics.
"""

def __init__(
    self, rubrics: list[Rubric], weights: list[float] | None = None, **kwargs
):
    """Aggregate several rubrics, optionally weighting each rubric's scores.

    Args:
        rubrics: Non-empty list of rubrics whose scores are combined.
        weights: Optional per-rubric weights; must be the same length as
            ``rubrics``. When omitted, every rubric gets weight 1.0, which
            preserves the original plain-sum behavior.
        **kwargs: Forwarded to the ``Rubric`` base initializer.

    Raises:
        ValueError: If ``rubrics`` is empty, or ``weights`` is given with a
            length different from ``rubrics``.
    """
    if not rubrics:
        raise ValueError("RubricGroup must have at least one rubric")
    if weights is not None and len(weights) != len(rubrics):
        raise ValueError(
            f"Number of weights must match number of rubrics: "
            f"got {len(weights)} weights for {len(rubrics)} rubrics"
        )
    super().__init__(**kwargs)
    self.rubrics = rubrics
    # Default to unit weights so existing callers keep the summing behavior.
    self.rubric_weights = weights if weights is not None else [1.0] * len(rubrics)
    self.logger.info(f"Initialized RubricGroup with {len(rubrics)} rubrics")

def get_reward_func_names(self) -> list[str]:
Expand All @@ -39,6 +47,9 @@ def get_reward_weights(self) -> list[float]:
weights.extend(rubric.get_reward_weights())
return weights

def get_rubric_weights(self) -> list[float]:
    """Return the per-rubric weights applied when aggregating rollout scores."""
    return self.rubric_weights

def add_reward_func(self, func: RewardFunc, weight: float = 1.0):
assert len(self.rubrics) > 0, "RubricGroup must have at least one rubric"
self.logger.warning("Adding reward function to the first rubric in the group.")
Expand All @@ -56,7 +67,7 @@ async def score_rollout(
) -> RolloutScore:
total_reward = 0.0
aggregated_metrics: dict[str, float] = {}
for rubric in self.rubrics:
for rubric, weight in zip(self.rubrics, self.rubric_weights):
score = await rubric.score_rollout(
prompt,
completion,
Expand All @@ -66,9 +77,11 @@ async def score_rollout(
info,
**kwargs,
)
total_reward += score.reward
total_reward += score.reward * weight
for key, value in score.metrics.items():
aggregated_metrics[key] = aggregated_metrics.get(key, 0.0) + value
aggregated_metrics[key] = (
aggregated_metrics.get(key, 0.0) + value * weight
)
return RolloutScore(reward=total_reward, metrics=aggregated_metrics)

async def score_rollouts(
Expand All @@ -91,7 +104,7 @@ async def score_rollouts(
reward=[],
metrics={},
)
for rubric in self.rubrics:
for rubric, weight in zip(self.rubrics, self.rubric_weights):
rubric_scores = await rubric.score_rollouts(
prompts,
completions,
Expand All @@ -102,19 +115,20 @@ async def score_rollouts(
max_concurrent,
**kwargs,
)
# aggregate reward (element-wise sum across rubrics)
# aggregate reward (element-wise weighted sum across rubrics)
if not all_scores.reward:
all_scores.reward = rubric_scores.reward
all_scores.reward = [r * weight for r in rubric_scores.reward]
else:
all_scores.reward = [
a + b for a, b in zip(all_scores.reward, rubric_scores.reward)
a + b * weight
for a, b in zip(all_scores.reward, rubric_scores.reward)
]
for key, value in rubric_scores.metrics.items():
if key in all_scores.metrics:
# element-wise sum
# element-wise weighted sum
all_scores.metrics[key] = [
a + b for a, b in zip(all_scores.metrics[key], value)
a + b * weight for a, b in zip(all_scores.metrics[key], value)
]
else:
all_scores.metrics[key] = value
all_scores.metrics[key] = [v * weight for v in value]
return all_scores