Skip to content

Commit

Permalink
Add localized expressions in AST (#130)
Browse files Browse the repository at this point in the history
  • Loading branch information
mandel authored Nov 22, 2024
1 parent 53004df commit 4500f2f
Show file tree
Hide file tree
Showing 10 changed files with 257 additions and 57 deletions.
35 changes: 4 additions & 31 deletions pdl-live/src/pdl_ast.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3235,7 +3235,7 @@ export interface CallBlock {
result?: unknown;
location?: LocationType | null;
kind?: Kind18;
call: Call;
call: unknown;
args?: Args;
trace?: Trace6;
}
Expand Down Expand Up @@ -3654,7 +3654,7 @@ export interface DataBlock {
result?: unknown;
location?: LocationType | null;
kind?: Kind13;
data: Data;
data: unknown;
raw?: Raw;
}
/**
Expand Down Expand Up @@ -3736,7 +3736,7 @@ export interface IfBlock {
result?: unknown;
location?: LocationType | null;
kind?: Kind12;
if: If;
if: unknown;
then: Then;
else?: Else;
if_result?: IfResult;
Expand Down Expand Up @@ -3905,7 +3905,7 @@ export interface RepeatUntilBlock {
location?: LocationType | null;
kind?: Kind10;
repeat: Repeat1;
until: Until;
until: unknown;
join?: Join1;
trace?: Trace2;
}
Expand Down Expand Up @@ -4822,26 +4822,6 @@ export interface JoinArray {
export interface JoinLastOf {
as: As2;
}
/**
* Condition of the loop.
*
*/
export interface Until {
[k: string]: unknown;
}
/**
* Condition.
*
*/
export interface If {
[k: string]: unknown;
}
/**
* Value defined.
*/
export interface Data {
[k: string]: unknown;
}
export interface BamTextGenerationParameters {
beam_width?: BeamWidth;
decoding_method?: DecodingMethod | null;
Expand Down Expand Up @@ -4954,13 +4934,6 @@ export interface LitellmParameters {
max_retries?: MaxRetries;
[k: string]: unknown;
}
/**
* Function to call.
*
*/
export interface Call {
[k: string]: unknown;
}
/**
* Arguments of the function with their values.
*
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ dependencies = [
"jsonschema~=4.0",
"litellm~=1.49",
"termcolor~=2.0",
"ipython~=8.0"
"ipython~=8.0",
]
authors = [
{ name="Mandana Vaziri", email="[email protected]" },
Expand Down
76 changes: 72 additions & 4 deletions src/pdl/pdl-schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -1077,7 +1077,10 @@
{
"type": "string"
},
{}
{},
{
"$ref": "#/$defs/LocalizedExpression"
}
],
"title": "Model"
},
Expand Down Expand Up @@ -1359,8 +1362,9 @@
{
"$ref": "#/$defs/BamTextGenerationParameters"
},
{},
{
"type": "object"
"$ref": "#/$defs/LocalizedExpression"
},
{
"type": "null"
Expand Down Expand Up @@ -2081,6 +2085,12 @@
"type": "string"
},
"call": {
"anyOf": [
{},
{
"$ref": "#/$defs/LocalizedExpression"
}
],
"description": "Function to call.\n ",
"title": "Call"
},
Expand Down Expand Up @@ -3370,6 +3380,12 @@
"type": "string"
},
"data": {
"anyOf": [
{},
{
"$ref": "#/$defs/LocalizedExpression"
}
],
"description": "Value defined.",
"title": "Data"
},
Expand Down Expand Up @@ -4928,6 +4944,14 @@
"type": "string"
},
"for": {
"additionalProperties": {
"anyOf": [
{},
{
"$ref": "#/$defs/LocalizedExpression"
}
]
},
"description": "Arrays to iterate over.\n ",
"title": "For",
"type": "object"
Expand Down Expand Up @@ -6859,6 +6883,12 @@
"type": "string"
},
"if": {
"anyOf": [
{},
{
"$ref": "#/$defs/LocalizedExpression"
}
],
"description": "Condition.\n ",
"title": "If"
},
Expand Down Expand Up @@ -8999,7 +9029,10 @@
{
"type": "string"
},
{}
{},
{
"$ref": "#/$defs/LocalizedExpression"
}
],
"title": "Model"
},
Expand Down Expand Up @@ -9270,8 +9303,9 @@
{
"$ref": "#/$defs/LitellmParameters"
},
{},
{
"type": "object"
"$ref": "#/$defs/LocalizedExpression"
},
{
"type": "null"
Expand Down Expand Up @@ -9685,6 +9719,31 @@
"title": "LitellmParameters",
"type": "object"
},
"LocalizedExpression": {
"additionalProperties": false,
"description": "Expression with location information",
"properties": {
"expr": {
"title": "Expr"
},
"location": {
"anyOf": [
{
"$ref": "#/$defs/LocationType"
},
{
"type": "null"
}
],
"default": null
}
},
"required": [
"expr"
],
"title": "LocalizedExpression",
"type": "object"
},
"LocationType": {
"additionalProperties": false,
"properties": {
Expand Down Expand Up @@ -12398,6 +12457,9 @@
"read": {
"anyOf": [
{},
{
"$ref": "#/$defs/LocalizedExpression"
},
{
"type": "null"
}
Expand Down Expand Up @@ -13917,6 +13979,12 @@
"title": "Repeat"
},
"until": {
"anyOf": [
{},
{
"$ref": "#/$defs/LocalizedExpression"
}
],
"description": "Condition of the loop.\n ",
"title": "Until"
},
Expand Down
43 changes: 29 additions & 14 deletions src/pdl/pdl_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,6 @@

ScopeType: TypeAlias = dict[str, Any]

ExpressionType: TypeAlias = Any
# (
# str
# | int
# | float
# | bool
# | None
# | list["ExpressionType"]
# | dict[str, "ExpressionType"]
# )


Message: TypeAlias = dict[str, Any]
Messages: TypeAlias = list[Message]
Expand Down Expand Up @@ -64,6 +53,28 @@ class LocationType(BaseModel):
empty_block_location = LocationType(file="", path=[], table={})


class LocalizedExpression(BaseModel):
"""Expression with location information"""

model_config = ConfigDict(
extra="forbid", use_attribute_docstrings=True, arbitrary_types_allowed=True
)
expr: Any
location: Optional[LocationType] = None


ExpressionType: TypeAlias = Any | LocalizedExpression
# (
# str
# | int
# | float
# | bool
# | None
# | list["ExpressionType"]
# | dict[str, "ExpressionType"]
# )


class Parser(BaseModel):
model_config = ConfigDict(extra="forbid")
description: Optional[str] = None
Expand Down Expand Up @@ -96,7 +107,11 @@ class ContributeValue(BaseModel):
class Block(BaseModel):
"""Common fields for all PDL blocks."""

model_config = ConfigDict(extra="forbid", use_attribute_docstrings=True)
model_config = ConfigDict(
extra="forbid",
use_attribute_docstrings=True,
arbitrary_types_allowed=True,
)

description: Optional[str] = None
"""Documentation associated to the block.
Expand Down Expand Up @@ -265,7 +280,7 @@ class ModelBlock(Block):
class BamModelBlock(ModelBlock):
platform: Literal[ModelPlatform.BAM]
prompt_id: Optional[str] = None
parameters: Optional[BamTextGenerationParameters | dict] = None
parameters: Optional[BamTextGenerationParameters | ExpressionType] = None
moderations: Optional[ModerationParameters] = None
data: Optional[PromptTemplateData] = None
constraints: Any = None # TODO
Expand All @@ -275,7 +290,7 @@ class LitellmModelBlock(ModelBlock):
"""Call a LLM through the LiteLLM API: https://docs.litellm.ai/."""

platform: Literal[ModelPlatform.LITELLM] = ModelPlatform.LITELLM
parameters: Optional[LitellmParameters | dict] = None
parameters: Optional[LitellmParameters | ExpressionType] = None


class CodeBlock(Block):
Expand Down
5 changes: 2 additions & 3 deletions src/pdl/pdl_ast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ def iter_block_children(f: Callable[[BlocksType], None], block: BlockType) -> No
f(blocks)
match block:
case FunctionBlock():
if block.returns is not None:
f(block.returns)
f(block.returns)
case CallBlock():
if block.trace is not None:
f(block.trace)
Expand Down Expand Up @@ -208,7 +207,7 @@ def map_block_children(f: MappedFunctions, block: BlockType) -> BlockType:

def map_blocks(f: MappedFunctions, blocks: BlocksType) -> BlocksType:
if not isinstance(blocks, str) and isinstance(blocks, Sequence):
# is a list of blocks
# Is a list of blocks
blocks = [f.f_block(block) for block in blocks]
else:
blocks = f.f_block(blocks)
Expand Down
13 changes: 10 additions & 3 deletions src/pdl/pdl_compilers/to_regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
IncludeBlock,
LitellmModelBlock,
LitellmParameters,
LocalizedExpression,
ModelBlock,
ReadBlock,
RepeatBlock,
Expand Down Expand Up @@ -273,10 +274,14 @@ def compile_block(
"include_stop_sequence", False
)
else:
stop_sequences = block.parameters.stop_sequences or []
if isinstance(block.parameters, LocalizedExpression):
parameters = block.parameters.expr
else:
parameters = block.parameters
stop_sequences = parameters.stop_sequences or []
include_stop_sequence = (
block.parameters.include_stop_sequence is None
or block.parameters.include_stop_sequence
parameters.include_stop_sequence is None
or parameters.include_stop_sequence
)
case LitellmModelBlock():
if block.parameters is None:
Expand All @@ -285,6 +290,8 @@ def compile_block(
else:
if isinstance(block.parameters, LitellmParameters):
parameters = block.parameters.model_dump()
elif isinstance(block.parameters, LocalizedExpression):
parameters = block.parameters.expr
else:
parameters = block.parameters
stop_sequences = parameters.get("stop", [])
Expand Down
1 change: 1 addition & 0 deletions src/pdl/pdl_dumper.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ def blocks_to_dict(
) -> DumpedBlockType | list[DumpedBlockType]:
result: DumpedBlockType | list[DumpedBlockType]
if not isinstance(blocks, str) and isinstance(blocks, Sequence):
# Is a list of blocks
result = [block_to_dict(block, json_compatible) for block in blocks]
else:
result = block_to_dict(blocks, json_compatible)
Expand Down
Loading

0 comments on commit 4500f2f

Please sign in to comment.