Skip to content

Commit

Permalink
Implemented consensus sampling free function, with test
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesbraza committed Dec 18, 2024
1 parent 7e83c8e commit 2625679
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 4 deletions.
3 changes: 2 additions & 1 deletion ldp/alg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .algorithms import to_network
from .algorithms import evaluate_consensus, to_network
from .beam_search import Beam, BeamSearchRollout
from .callbacks import (
Callback,
Expand Down Expand Up @@ -45,5 +45,6 @@
"TrajectoryMetricsCallback",
"TreeSearchRollout",
"WandBLoggingCallback",
"evaluate_consensus",
"to_network",
]
76 changes: 73 additions & 3 deletions ldp/alg/algorithms.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import asyncio
import collections
import itertools
from collections.abc import Sequence
from typing import Any
import random
from collections.abc import Awaitable, Callable, Hashable, Iterable, Sequence
from typing import Any, Literal, TypeVar, cast

import networkx as nx
from aviary.core import Message, Tool, ToolRequestMessage, join
from aviary.core import Message, Tool, ToolRequestMessage, is_coroutine_callable, join

from ldp.graph import OpResult
from ldp.graph.ops import GradOutType
Expand Down Expand Up @@ -120,3 +123,70 @@ def gvizify(x: Any) -> str:
G.add_edge(op_call_str, arg_str, label=gvizify(grad), style="dotted")

return G


TData = TypeVar("TData")
TGroupKey = TypeVar("TGroupKey", bound=Hashable)
TAnswer = TypeVar("TAnswer")
NO_IDEAL_ANSWER_FN: Literal["NO_IDEAL_ANSWER_FN"] = "NO_IDEAL_ANSWER_FN" # Sentinel


async def evaluate_consensus(
data: Iterable[TData],
grouping_fn: Callable[[TData], TGroupKey],
extract_answer_fn: Callable[[TData], TAnswer | Awaitable[TAnswer]],
ideal_answer_fn: (
Callable[[TData], TAnswer] | Literal["NO_IDEAL_ANSWER_FN"]
) = NO_IDEAL_ANSWER_FN,
num_samples: int = 1,
seed: int | None = None,
) -> tuple[dict[TGroupKey, list[tuple[TAnswer, int]]], float]:
"""
Create consensus groups and evaluate the consensus accuracy for each one.
Args:
data: Data to evaluate consensus upon, length is called N.
grouping_fn: Function to extract the group key from a datum.
extract_answer_fn: Function to extract the actual answer from a datum. It can
be async so this can be done using a LLM call.
ideal_answer_fn: Optional function to extract the ideal answer from a datum to
compute accuracy with, or pass NO_IDEAL_ANSWER to skip this calculation.
num_samples: Number of samples to choose from the N total.
seed: Optional seed for sampling.
Returns:
Two-tuple of consensus list generated by collections.Counter.most_common and
the proportion of groups for which the consensus matches the ideal.
"""
groups = collections.defaultdict(list)
for x in data:
groups[grouping_fn(x)].append(x)

ideal_count = 0
grouped_consensus: dict[TGroupKey, list[tuple[TAnswer, int]]] = {}
rand = random.Random(seed) if seed is not None else random
for group_key, group in groups.items():
if len(group) < num_samples: # Too few items, sample with replacement
sampled = [rand.choice(group) for _ in range(num_samples)]
else: # Sample without replacement
sampled = rand.sample(group, num_samples)

# Get answers for the sampled data
if is_coroutine_callable(extract_answer_fn):
extract_answer_fn = cast(
Callable[[TData], Awaitable[TAnswer]], extract_answer_fn
)
answers = await asyncio.gather(*(extract_answer_fn(x) for x in sampled))
else:
extract_answer_fn = cast(Callable[[TData], TAnswer], extract_answer_fn)
answers = [extract_answer_fn(x) for x in sampled]

# Compute consensus: mode of the sampled answers
grouped_consensus[group_key] = collections.Counter(answers).most_common()
# NOTE: If there are multiple modes, just use the first one
consensus: TAnswer = grouped_consensus[group_key][0][0]
if ideal_answer_fn != NO_IDEAL_ANSWER_FN:
# Assume all items in the group have the same ideal answer
ideal_count += consensus == ideal_answer_fn(group[0])

return grouped_consensus, ideal_count / len(groups) if groups else 0.0
75 changes: 75 additions & 0 deletions tests/test_algorithms.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import operator

import pytest
from aviary.core import DummyEnv
from aviary.utils import MultipleChoiceQuestion

from ldp.agent import SimpleAgent
from ldp.alg import evaluate_consensus
from ldp.utils import discounted_returns


Expand Down Expand Up @@ -32,3 +36,74 @@ async def test_rollout_and_discounting(dummy_env: DummyEnv) -> None:
print(terms)
d_returns = discounted_returns(rewards, terms, 0.5)
print(d_returns)


@pytest.mark.asyncio
async def test_evaluate_consensus() -> None:
# We have two questions, so let's group based on question
question_1 = MultipleChoiceQuestion(
question="What is the meaning of life?",
options=["-84", "11", "cheesecake"],
ideal_answer="42",
)
question_2 = MultipleChoiceQuestion(
question="What is a healthy fruit?",
options=["brownie", "chocolate bar", "french fry"],
ideal_answer="apple",
)
question_3 = MultipleChoiceQuestion(
question="What is the highest number?",
options=["1", "2", "4"],
ideal_answer="8",
)
data_with_several_groups: list[tuple[MultipleChoiceQuestion, str]] = [
# Correct consensus
(question_1, "-84"),
(question_1, "11"),
(question_1, "11"),
(question_1, "cheesecake"),
(question_1, "42"),
(question_1, "42"),
(question_1, "42"),
(question_1, "42"),
(question_1, "42"),
(question_1, "42"),
# Correct consensus
(question_2, "brownie"),
(question_2, "brownie"),
(question_2, "apple"),
(question_2, "apple"),
(question_2, "apple"),
(question_2, "apple"),
(question_2, "apple"),
(question_2, "apple"),
# Incorrect consensus
(question_3, "1"),
(question_3, "2"),
(question_3, "1"),
(question_3, "2"),
]
# NOTE: this consensus is sensitive to seed
expected_consensus = {
question_1.question: [("42", 3), ("11", 1), ("-84", 1)],
question_2.question: [("apple", 4), ("brownie", 1)],
question_3.question: [("1", 3), ("2", 2)],
}

# Check accuracy is 0% without an ideal answer
assert await evaluate_consensus(
data_with_several_groups,
grouping_fn=lambda x: x[0].question,
extract_answer_fn=operator.itemgetter(1),
num_samples=5,
seed=42,
) == (expected_consensus, 0.0)
# Check accuracy is present when we can get an ideal answer
assert await evaluate_consensus(
data_with_several_groups,
grouping_fn=lambda x: x[0].question,
extract_answer_fn=operator.itemgetter(1),
ideal_answer_fn=lambda x: x[0].ideal_answer,
num_samples=5,
seed=42,
) == (expected_consensus, 2 / 3)

0 comments on commit 2625679

Please sign in to comment.