Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parse json data from token if object is dict-like. #285

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion cyclopts/argument.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import inspect
import itertools
import json
from contextlib import suppress
from functools import partial
from typing import Any, Callable, Optional, Union, get_args, get_origin
Expand Down Expand Up @@ -55,6 +56,7 @@
negative=None,
accepts_keys=None,
consume_multiple=None,
env_var=None,
)

_SHOW_DEFAULT_BLOCKLIST = (
Expand Down Expand Up @@ -972,10 +974,28 @@
converter = partial(
convert, converter=_identity_converter, name_transform=self.parameter.name_transform
)

if self.tokens:
# Dictionary-like structures may have incoming json data from an environment variable.
for token in self.tokens:
try:
data.update(json.loads(token.value))
except json.JSONDecodeError as e:
raise CoercionError(token=token, target_type=self.hint, msg=e.msg) from None

Check warning on line 984 in cyclopts/argument.py

View check run for this annotation

Codecov / codecov/patch

cyclopts/argument.py#L983-L984

Added lines #L983 - L984 were not covered by tests

for child in self.children:
assert len(child.keys) == (len(self.keys) + 1)
if child.has_tokens:
if child.has_tokens: # Either the child directly has tokens, or a nested child has tokens.
data[child.keys[-1]] = child.convert_and_validate(converter=converter)
elif child.required:
# Check if the required fields are already populated.
obj = data
for k in child.keys:
try:
obj = obj[k]
except Exception:
raise MissingArgumentError(argument=child) from None
child._marked = True

if self._missing_keys_checker and (self.required or data):
if missing_keys := self._missing_keys_checker(self, data):
Expand Down
24 changes: 24 additions & 0 deletions tests/test_bind_dataclasses.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import sys
from dataclasses import dataclass, field
from textwrap import dedent
Expand Down Expand Up @@ -42,6 +43,29 @@ def foo(some_number: int, user: User):
)


def test_bind_dataclass_from_json(app, assert_parse_args, monkeypatch):
@app.command
def foo(some_number: int, user: Annotated[User, Parameter(env_var="USER")]):
pass

external_data = {
"id": 123,
# "name" is purposely missing.
"tastes": {
"wine": 9,
"cheese": 7,
"cabbage": 1,
},
}
monkeypatch.setenv("USER", json.dumps(external_data))
assert_parse_args(
foo,
"foo 100",
100,
User(**external_data),
)


@pytest.mark.skipif(sys.version_info < (3, 10), reason="field(kw_only=True) doesn't exist.")
def test_bind_dataclass_recursive(app, assert_parse_args, console):
@dataclass
Expand Down
35 changes: 33 additions & 2 deletions tests/test_pydantic.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import json
from datetime import datetime
from textwrap import dedent
from typing import Dict, Optional, Union
from typing import Annotated, Dict, Optional, Union

import pytest
from pydantic import BaseModel, Field, PositiveInt, validate_call
from pydantic import ValidationError as PydanticValidationError

from cyclopts import MissingArgumentError
from cyclopts import MissingArgumentError, Parameter


def test_pydantic_error_msg(app, console):
Expand Down Expand Up @@ -74,6 +75,36 @@ def foo(user: User):
"has_socks": True,
},
}

assert_parse_args(
foo,
'foo --user.id=123 --user.signup-ts="2019-06-01 12:22" --user.tastes.wine=9 --user.tastes.cheese=7 --user.tastes.cabbage=1 --user.outfit.body=t-shirt --user.outfit.head=baseball-cap --user.outfit.has-socks',
User(**external_data),
)


def test_bind_pydantic_basemodel_from_json(app, assert_parse_args, monkeypatch):
@app.command
def foo(user: Annotated[User, Parameter(env_var="USER")]):
pass

external_data = {
"id": 123,
"signup_ts": "2019-06-01 12:22",
"tastes": {
"wine": 9,
"cheese": 7,
"cabbage": "1",
},
"outfit": {
"body": "t-shirt",
"head": "baseball-cap",
"has_socks": True,
},
}

monkeypatch.setenv("USER", json.dumps(external_data))

assert_parse_args(
foo,
'foo --user.id=123 --user.signup-ts="2019-06-01 12:22" --user.tastes.wine=9 --user.tastes.cheese=7 --user.tastes.cabbage=1 --user.outfit.body=t-shirt --user.outfit.head=baseball-cap --user.outfit.has-socks',
Expand Down
Loading