Skip to content

Commit

Permalink
Rename yield_output into yield_background
Browse files Browse the repository at this point in the history
  • Loading branch information
mandel committed Sep 5, 2024
1 parent ea00db2 commit 64afc6c
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 25 deletions.
2 changes: 1 addition & 1 deletion pdl/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def process_sample(self, q, data, index, json_str):
answer = 0.0 # pylint: disable=invalid-name
exception = None
try:
state = InterpreterState(yield_output=False)
state = InterpreterState(yield_background=False)
scope = empty_scope
scope["question"] = question
document, _, _, _ = process_prog(state, scope, data)
Expand Down
2 changes: 1 addition & 1 deletion pdl/bugfinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def process():
correct = True
wrongs = 0
try:
state = InterpreterState(yield_output=True)
state = InterpreterState(yield_background=True)
scope = empty_scope
scope["question"] = question
scope["code"] = code
Expand Down
7 changes: 5 additions & 2 deletions pdl/pdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@
class InterpreterConfig(TypedDict, total=False):
"""Configuration parameters of the PDL interpreter."""

yield_output: bool
"""Print the program messages during the execution.
yield_result: bool
"""Print incrementally result of the execution.
"""
yield_background: bool
"""Print the program background messages during the execution.
"""
batch: int
"""Execution type:
Expand Down
48 changes: 28 additions & 20 deletions pdl/pdl_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,19 @@ class PDLRuntimeParserError(PDLException):


class InterpreterState(BaseModel):
yield_output: bool = True
yield_result: bool = False
yield_background: bool = False
log: list[str] = []
batch: int = 0
# batch=0: streaming
# batch=1: call to generate with `input`
role: RoleType = None

def with_yield_output(self: "InterpreterState", b: bool) -> "InterpreterState":
return self.model_copy(update={"yield_output": b})
def with_yield_result(self: "InterpreterState", b: bool) -> "InterpreterState":
return self.model_copy(update={"yield_result": b})

def with_yield_background(self: "InterpreterState", b: bool) -> "InterpreterState":
return self.model_copy(update={"yield_background": b})

def with_role(self: "InterpreterState", role: RoleType) -> "InterpreterState":
return self.model_copy(update={"role": role})
Expand All @@ -106,8 +110,12 @@ def generate(
log_file = "log.txt"
try:
prog, loc = parse_file(pdl_file)
state = InterpreterState(yield_output=True)
_, _, _, trace = process_prog(state, initial_scope, prog, loc)
state = InterpreterState()
result, _, _, trace = process_prog(state, initial_scope, prog, loc)
if not state.yield_result:
if state.yield_background:
print("----------------")
print(stringify(result))
with open(log_file, "w", encoding="utf-8") as log_fp:
for line in state.log:
log_fp.write(line)
Expand Down Expand Up @@ -172,7 +180,7 @@ def process_prog(
# incremental_document += output
# print()
# result, document, scope, trace = doc_generator.value
# assert document == incremental_document or not state.yield_output
# assert document == incremental_document or not state.yield_background
# return result, document, scope, trace


Expand Down Expand Up @@ -210,7 +218,7 @@ def step_block(
else:
background = [{"role": state.role, "content": stringify(result)}]
trace = result
if state.yield_output:
if state.yield_background:
yield OutputMessage(background)
append_log(state, "Document", background)
else:
Expand All @@ -233,7 +241,7 @@ def step_advanced_block(
scope, defs_trace = yield from step_defs(state, scope, block.defs, loc)
else:
defs_trace = block.defs
state = state.with_yield_output(state.yield_output and block.show_result)
state = state.with_yield_background(state.yield_background and block.show_result)
result, background, scope, trace = yield from step_block_body(
state, scope, block, loc
)
Expand Down Expand Up @@ -278,7 +286,7 @@ def step_block_body(
result, background, scope, trace = yield from step_call_code(
state, scope, block, loc
)
if state.yield_output:
if state.yield_background:
yield OutputMessage(background)
case GetBlock(get=var):
result = get_var(var, scope)
Expand All @@ -294,7 +302,7 @@ def step_block_body(
else:
background = [{"role": state.role, "content": stringify(result)}]
trace = block.model_copy()
if state.yield_output:
if state.yield_background:
yield OutputMessage(background)
case DataBlock(data=v):
block.location = append(loc, "data")
Expand All @@ -308,13 +316,13 @@ def step_block_body(
else:
background = [{"role": state.role, "content": stringify(result)}]
trace = block.model_copy()
if state.yield_output:
if state.yield_background:
yield OutputMessage(background)
case ApiBlock():
result, background, scope, trace = yield from step_call_api(
state, scope, block, loc
)
if state.yield_output:
if state.yield_background:
yield OutputMessage(background)
case DocumentBlock():
result, background, scope, document = yield from step_blocks(
Expand Down Expand Up @@ -498,7 +506,7 @@ def step_block_body(
trace = block.model_copy(update={"trace": iterations_trace})
case ReadBlock():
result, background, scope, trace = process_input(state, scope, block, loc)
if state.yield_output:
if state.yield_background:
yield OutputMessage(background)

case IncludeBlock():
Expand Down Expand Up @@ -569,7 +577,7 @@ def step_defs(
defloc = append(loc, "defs")
for x, blocks in defs.items():
newloc = append(defloc, x)
state = state.with_yield_output(False)
state = state.with_yield_background(False)
result, _, _, blocks_trace = yield from step_blocks(
IterationType.SEQUENCE, state, scope, blocks, newloc
)
Expand Down Expand Up @@ -711,7 +719,7 @@ def step_call_model(
if block.input is not None: # If not implicit, then input must be a block
model_input_result, _, _, input_trace = yield from step_blocks(
IterationType.SEQUENCE,
state.with_yield_output(False),
state.with_yield_background(False),
scope,
block.input,
append(loc, "input"),
Expand Down Expand Up @@ -846,7 +854,7 @@ def generate_client_response_streaming(
complete_msg: Optional[Message] = None
role = None
for chunk in msg_stream:
if state.yield_output:
if state.yield_background:
yield OutputMessage([chunk])
if complete_msg is None:
complete_msg = chunk
Expand Down Expand Up @@ -899,7 +907,7 @@ def generate_client_response_single(
msg = LitellmModel.generate_text(
model_id=block.model, messages=model_input, parameters=parameters
)
if state.yield_output:
if state.yield_background:
yield OutputMessage([msg])
return msg

Expand All @@ -921,7 +929,7 @@ def generate_client_response_batching( # pylint: disable=too-many-arguments
moderations=block.moderations,
data=block.data,
)
if state.yield_output:
if state.yield_background:
yield OutputMessage(msg)
case WatsonxModelBlock():
assert False # XXX TODO
Expand All @@ -940,7 +948,7 @@ def step_call_api(
background: Messages
input_value, _, _, input_trace = yield from step_blocks(
IterationType.SEQUENCE,
state.with_yield_output(False),
state.with_yield_background(False),
scope,
block.input,
append(loc, "input"),
Expand Down Expand Up @@ -974,7 +982,7 @@ def step_call_code(
background: Messages
code_s, _, _, code_trace = yield from step_blocks(
IterationType.SEQUENCE,
state.with_yield_output(False),
state.with_yield_background(False),
scope,
block.code,
append(loc, "code"),
Expand Down
2 changes: 1 addition & 1 deletion pdl/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def process(file, mode):
result = ""
answer = 0.0 # pylint: disable=invalid-name
try:
state = InterpreterState(yield_output=True)
state = InterpreterState(yield_background=True)
scope = empty_scope
scope["question"] = question
result, _, _, _ = process_prog(state, scope, data)
Expand Down

0 comments on commit 64afc6c

Please sign in to comment.