Added 'pass @ k' to alg (#202)
jamesbraza authored Jan 7, 2025
1 parent a236544 commit c932768
Showing 3 changed files with 43 additions and 2 deletions.
3 changes: 2 additions & 1 deletion ldp/alg/__init__.py
@@ -1,4 +1,4 @@
from .algorithms import evaluate_consensus, to_network
from .algorithms import compute_pass_at_k, evaluate_consensus, to_network
from .beam_search import Beam, BeamSearchRollout
from .callbacks import (
    Callback,
@@ -45,6 +45,7 @@
"TrajectoryMetricsCallback",
"TreeSearchRollout",
"WandBLoggingCallback",
"compute_pass_at_k",
"evaluate_consensus",
"to_network",
]
23 changes: 23 additions & 0 deletions ldp/alg/algorithms.py
@@ -6,6 +6,7 @@
from typing import Any, Literal, TypeVar, cast

import networkx as nx
import numpy as np
from aviary.core import Message, Tool, ToolRequestMessage, is_coroutine_callable, join

from ldp.graph import OpResult
@@ -190,3 +191,25 @@ async def evaluate_consensus(
ideal_count += consensus == ideal_answer_fn(group[0])

return grouped_consensus, ideal_count / len(groups) if groups else 0.0


def compute_pass_at_k(n: int, c: int, k: int) -> float:
    """Compute an unbiased estimate of 'pass @ k'.

    Source: https://doi.org/10.48550/arXiv.2107.03374

    If there are multiple tasks, one aggregation, used in
    https://doi.org/10.48550/arXiv.2407.21787, is to average pass @ k across the tasks.

    Args:
        n: Total number of samples.
        c: Number of correct (pass a verifier) samples.
        k: k term (number of attempts) used in pass @ k.

    Returns:
        Unbiased estimate of pass @ k, the probability of getting at least one
        successful outcome in k attempts.
    """
    if n - c < k:
        return 1.0
    return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
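
The np.prod expression above is a numerically stable rewrite of the closed-form estimator 1 - C(n - c, k) / C(n, k) from the source cited in the docstring (arXiv 2107.03374). A minimal cross-check sketch follows, assuming only the Python standard library and NumPy; pass_at_k_closed_form and pass_at_k_product_form are hypothetical helpers written for illustration, not part of this commit:

from math import comb

import numpy as np


def pass_at_k_closed_form(n: int, c: int, k: int) -> float:
    # Hypothetical reference form: 1 - C(n - c, k) / C(n, k).
    if n - c < k:
        return 1.0
    return 1.0 - comb(n - c, k) / comb(n, k)


def pass_at_k_product_form(n: int, c: int, k: int) -> float:
    # Same estimator written as the product form used in the commit,
    # which avoids computing large binomial coefficients for big n.
    if n - c < k:
        return 1.0
    return 1.0 - float(np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))


# Worked example matching a parametrized test case below: n=10 samples,
# c=5 correct, k=3 attempts gives 1 - C(5, 3) / C(10, 3) = 1 - 10/120 = 1 - 1/12.
assert abs(pass_at_k_closed_form(10, 5, 3) - (1 - 1 / 12)) < 1e-12
assert abs(pass_at_k_product_form(10, 5, 3) - (1 - 1 / 12)) < 1e-12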
19 changes: 18 additions & 1 deletion tests/test_algorithms.py
Expand Up @@ -5,7 +5,7 @@
from aviary.utils import MultipleChoiceQuestion

from ldp.agent import SimpleAgent
from ldp.alg import evaluate_consensus
from ldp.alg import compute_pass_at_k, evaluate_consensus
from ldp.utils import discounted_returns


@@ -107,3 +107,20 @@ async def test_evaluate_consensus() -> None:
num_samples=5,
seed=42,
) == (expected_consensus, 2 / 3)


@pytest.mark.parametrize(
    ("n", "c", "k", "expected"),
    [
        pytest.param(10, 10, 3, 1.0, id="all-correct-k>1"),
        pytest.param(10, 10, 1, 1.0, id="all-correct-k=1"),
        pytest.param(10, 0, 3, 0.0, id="all-incorrect-k>1"),
        pytest.param(10, 0, 1, 0.0, id="all-incorrect-k=1"),
        pytest.param(3, 1, 3, 1.0, id="n-c<k"),
        (10, 5, 3, 1 - 1 / 12),
        (2, 1, 1, 1 / 2),  # Match https://ai.stackexchange.com/a/40396
        # SEE: https://github.com/parker-research/pass-at-k/blob/037c5d477486f9e95e1c21fc349576447cd6ce8b/tests/test_pass_at_k.py
    ],
)
def test_compute_pass_at_k(n: int, c: int, k: int, expected: float) -> None:
    assert compute_pass_at_k(n, c, k) == pytest.approx(expected)
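
The two plain-tuple cases cross-check the closed form directly: for (10, 5, 3), 1 - C(5, 3) / C(10, 3) = 1 - 10/120 = 1 - 1/12, and for (2, 1, 1), 1 - C(1, 1) / C(2, 1) = 1/2.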
