Skip to content

Commit

Permalink
Fix some lint
Browse files Browse the repository at this point in the history
  • Loading branch information
claudiosv committed Sep 6, 2024
1 parent cbb4851 commit 1deacb2
Show file tree
Hide file tree
Showing 11 changed files with 263 additions and 149 deletions.
3 changes: 2 additions & 1 deletion pdl/optimize/bam_logprobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def get_seq_logprobs(

return None


def process_logprobs():
gsm8k = load_from_disk("var/gsm8k_logprobs_agg")

Expand All @@ -183,7 +184,7 @@ def mapper(row):
seq,
max_new_tokens=None,
)
answer = (lp.generated_text)
answer = lp.generated_text
return {
"generated_probs": lp.generated_probs,
"generated_text": lp.generated_text,
Expand Down
16 changes: 12 additions & 4 deletions pdl/optimize/config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@

class OptimizationConfig(BaseModel):
benchmark: Literal[
"gsm8k", "gsm8k-baseline", "gsm8k-bench", "fever", "evalplus",
"gsm8k",
"gsm8k-baseline",
"gsm8k-bench",
"fever",
"evalplus",
] = Field()
num_candidates: int = Field(default=30)
num_demonstrations: int = Field(default=5)
Expand Down Expand Up @@ -51,7 +55,11 @@ def get_variable_names(self) -> list[str]:
print(config)
print(config.get_variable_names())
Path("opticonfig1.yml").write_text(
yaml.dump(config.model_dump(
exclude_defaults=False, exclude_none=False, exclude_unset=False,
)),
yaml.dump(
config.model_dump(
exclude_defaults=False,
exclude_none=False,
exclude_unset=False,
),
),
)
4 changes: 1 addition & 3 deletions pdl/optimize/fever_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@

from pdl.optimize.util import PDLThread
from pdl.pdl_ast import ScopeType
from pdl.pdl_interpreter import (
empty_scope,
)
from pdl.pdl_interpreter import empty_scope


class FEVERTrialThread(PDLThread):
Expand Down
4 changes: 1 addition & 3 deletions pdl/optimize/gsm8k_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
from pdl.optimize.parse_number import extract_math_answer
from pdl.optimize.util import PDLThread
from pdl.pdl_ast import ScopeType
from pdl.pdl_interpreter import (
empty_scope,
)
from pdl.pdl_interpreter import empty_scope


class Gsm8kTrialThread(PDLThread):
Expand Down
8 changes: 2 additions & 6 deletions pdl/optimize/mbpp_thread.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import ast
from typing import Any

from evalplus.evaluate import (
check_correctness,
)
from evalplus.evaluate import check_correctness

from pdl.optimize.util import PDLThread
from pdl.pdl_ast import ScopeType
Expand Down Expand Up @@ -77,9 +75,7 @@ def answer_correct(self, document: str, answer: Any, truth: Any) -> bool:

task_id = self.example["task_id"]

solution = (
self.example["prompt"] + answer
)
solution = self.example["prompt"] + answer

result = check_correctness(
dataset="mbpp",
Expand Down
14 changes: 4 additions & 10 deletions pdl/optimize/optimize.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import argparse
from copy import deepcopy
from enum import Enum
from pathlib import Path

import yaml
from datasets import load_dataset, load_from_disk
from evalplus.evaluate import (
MBPP_OUTPUT_NOT_NONE_TASKS,
get_groundtruth,
get_mbpp_plus_hash,
)
from datasets import concatenate_datasets, load_dataset, load_from_disk
from evalplus.data import get_mbpp_plus, get_mbpp_plus_hash
from evalplus.evaluate import MBPP_OUTPUT_NOT_NONE_TASKS, get_groundtruth

from pdl.optimize.config_parser import OptimizationConfig
from pdl.optimize.mbpp_thread import MBPPTrialThread
Expand Down Expand Up @@ -90,10 +88,6 @@ class SamplingMethods(Enum):
config=config,
).run()
elif config.benchmark == "evalplus":
from copy import deepcopy

from datasets import concatenate_datasets
from evalplus.data import get_mbpp_plus, get_mbpp_plus_hash

class SelectableList(list):
def select(self, iterable):
Expand Down
4 changes: 1 addition & 3 deletions pdl/optimize/pdl_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,7 @@ def parse_budget(self):
self.time_budget = duration

def load_pdl(self, path: Path) -> Program:
with (
path.open(encoding="utf-8") as pdl,
):
with (path.open(encoding="utf-8") as pdl,):
return Program.model_validate(yaml.safe_load(pdl))

def parse_signature(self):
Expand Down
9 changes: 7 additions & 2 deletions pdl/optimize/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
from pdl.optimize.bam_logprobs import ModelResponse, get_seq_logprobs
from pdl.optimize.config_parser import OptimizationConfig
from pdl.pdl_ast import Program, ScopeType
from pdl.pdl_interpreter import InterpreterState, contains_error, messages_to_str, process_prog
from pdl.pdl_interpreter import (
InterpreterState,
contains_error,
messages_to_str,
process_prog,
)

console = Console()

Expand Down Expand Up @@ -104,7 +109,7 @@ def run(
else:
if self.index == 0 and self.return_logprobs:
model_input = get_seq_logprobs(
self.model,
self.scope["model"],
scope[self.config.demonstrations_variable_name],
)
answer = self.extract_answer(document)
Expand Down
Loading

0 comments on commit 1deacb2

Please sign in to comment.