Skip to content

Commit

Permalink
v0.0.64
Browse files Browse the repository at this point in the history
  • Loading branch information
Josh-XT committed Sep 2, 2024
1 parent 022989c commit 087e656
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 23 deletions.
66 changes: 44 additions & 22 deletions agixtsdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1678,7 +1678,7 @@ def plan_task(
def _generate_detailed_schema(self, model: Type[BaseModel], depth: int = 0) -> str:
"""
Recursively generates a detailed schema representation of a Pydantic model,
including nested models.
including nested models and complex types.
"""
fields = model.__annotations__
field_descriptions = []
Expand All @@ -1687,29 +1687,55 @@ def _generate_detailed_schema(self, model: Type[BaseModel], depth: int = 0) -> s
for field, field_type in fields.items():
description = f"{indent}{field}: "

if get_origin(field_type) == Union:
field_type = get_args(field_type)[0]
origin_type = get_origin(field_type)
if origin_type is None:
origin_type = field_type

if isinstance(field_type, type) and issubclass(field_type, BaseModel):
description += f"Nested Model:\n{self._generate_detailed_schema(field_type, depth + 1)}"
elif get_origin(field_type) == List:
if issubclass(origin_type, BaseModel):
description += f"Nested Model:\n{self._generate_detailed_schema(origin_type, depth + 1)}"
elif origin_type == List:
list_type = get_args(field_type)[0]
if isinstance(list_type, type) and issubclass(list_type, BaseModel):
description += f"List of Nested Model:\n{self._generate_detailed_schema(list_type, depth + 1)}"
elif get_origin(list_type) == Union:
union_types = get_args(list_type)
description += f"List of Union:\n"
for union_type in union_types:
if issubclass(union_type, BaseModel):
description += f"{indent} - Nested Model:\n{self._generate_detailed_schema(union_type, depth + 2)}"
else:
description += f"{indent} - {union_type.__name__}\n"
else:
description += f"List[{list_type.__name__}]"
elif get_origin(field_type) == Dict:
description += f"List[{self._get_type_name(list_type)}]"
elif origin_type == Dict:
key_type, value_type = get_args(field_type)
description += f"Dict[{key_type.__name__}, {value_type.__name__}]"
elif isinstance(field_type, type) and issubclass(field_type, Enum):
enum_values = ", ".join([f"{e.name} = {e.value}" for e in field_type])
description += f"{field_type.__name__} (Enum values: {enum_values})"
description += f"Dict[{self._get_type_name(key_type)}, {self._get_type_name(value_type)}]"
elif origin_type == Union:
union_types = get_args(field_type)
description += "Union of:\n"
for union_type in union_types:
if issubclass(union_type, BaseModel):
description += f"{indent} - Nested Model:\n{self._generate_detailed_schema(union_type, depth + 2)}"
else:
description += (
f"{indent} - {self._get_type_name(union_type)}\n"
)
elif issubclass(origin_type, Enum):
enum_values = ", ".join([f"{e.name} = {e.value}" for e in origin_type])
description += f"{origin_type.__name__} (Enum values: {enum_values})"
else:
description += f"{field_type.__name__}"
description += self._get_type_name(origin_type)

field_descriptions.append(description)

return "\n".join(field_descriptions)

def _get_type_name(self, type_):
"""Helper method to get the name of a type, handling some special cases."""
if hasattr(type_, "__name__"):
return type_.__name__
return str(type_).replace("typing.", "")

def convert_to_model(
self,
input_string: str,
Expand All @@ -1732,12 +1758,12 @@ def convert_to_model(
"""
input_string = str(input_string)
schema = self._generate_detailed_schema(model)

if "user_input" in kwargs:
del kwargs["user_input"]
if "schema" in kwargs:
del kwargs["schema"]

response = self.prompt_agent(
agent_name=agent_name,
prompt_name="Convert to Model",
Expand All @@ -1747,12 +1773,12 @@ def convert_to_model(
**kwargs,
},
)

if "```json" in response:
response = response.split("```json")[1].split("```")[0].strip()
elif "```" in response:
response = response.split("```")[1].strip()

try:
response = json.loads(response)
if response_type == "json":
Expand All @@ -1766,11 +1792,7 @@ def convert_to_model(
f"Error: {e} . Failed to convert the response to the model after {max_failures} attempts. Response: {response}"
)
self.failures = 0
return (
response
if response
else "Failed to convert the response to the model."
)
return response if response else "Failed to convert the response to the model."
else:
self.failures = 1
print(
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

setup(
name="agixtsdk",
version="0.0.63",
version="0.0.64",
description="The AGiXT SDK for Python.",
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down

0 comments on commit 087e656

Please sign in to comment.