Skip to content

Commit

Permalink
Fix Accidental Override of Boolean Value (#66)
Browse files Browse the repository at this point in the history
* fix default value override

* incrementing dolma version in rust
  • Loading branch information
soldni authored Oct 26, 2023
1 parent 490280a commit 2ee1ae2
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "dolma"
version = "0.9.0"
version = "0.9.1"
edition = "2021"
license = "Apache-2.0"

Expand Down
28 changes: 17 additions & 11 deletions python/dolma/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,13 @@

def _field_nargs(default: Any) -> Union[Literal["?"], Literal["*"]]:
# return '+' if _default is iterable but not string/bytes, else 1
if isinstance(default, str) or isinstance(default, bytes):
if isinstance(default, (str, bytes)):
return "?"
elif isinstance(default, Iterable):

if isinstance(default, Iterable):
return "*"
else:
return "?"

return "?"


def field(default: T = MISSING, help: Optional[str] = None, **extra: Any) -> T:
Expand All @@ -68,9 +69,9 @@ class DataClass(Protocol):


def make_parser(parser: A, config: Type[DataClass], prefix: Optional[str] = None) -> A:
for field_name, field in config.__dataclass_fields__.items():
for field_name, dt_field in config.__dataclass_fields__.items():
# get type from annotations or metadata
typ_ = config.__annotations__.get(field_name, field.metadata.get("type", MISSING))
typ_ = config.__annotations__.get(field_name, dt_field.metadata.get("type", MISSING))

if typ_ is MISSING:
warn(f"No type annotation for field {field_name} in {config.__name__}")
Expand Down Expand Up @@ -101,22 +102,24 @@ def make_parser(parser: A, config: Type[DataClass], prefix: Optional[str] = None
# for boolean values, we add two arguments: --field_name and --no-field_name
parser.add_argument(
f"--{field_name}",
help=field.metadata.get("help"),
help=dt_field.metadata.get("help"),
dest=field_name,
action="store_true",
default=MISSING,
)
parser.add_argument(
f"--no-{field_name}",
help=f"Disable {field_name}",
dest=field_name,
action="store_false",
default=MISSING,
)
else:
# else it's just a normal argument
parser.add_argument(
f"--{field_name}",
help=field.metadata.get("help"),
nargs=field.metadata.get("nargs", "?"),
help=dt_field.metadata.get("help"),
nargs=dt_field.metadata.get("nargs", "?"),
default=MISSING,
)

Expand Down Expand Up @@ -154,9 +157,12 @@ def namespace_to_nested_omegaconf(args: Namespace, structured: Type[T], config:
def print_config(config: Any, console: Optional[Console] = None) -> None:
if not isinstance(config, (DictConfig, ListConfig)):
config = om.create(config)

# print the config as yaml using a rich syntax highlighter
console = console or Console()
syntax = Syntax(code=om.to_yaml(config, sort_keys=True).strip(), lexer="yaml", theme="ansi_dark")
console.print(syntax)
yaml_config = om.to_yaml(config, sort_keys=True).strip()
highlighted = Syntax(code=yaml_config, lexer="yaml", theme="ansi_dark")
console.print(highlighted)


class BaseCli(Generic[D]):
Expand Down

0 comments on commit 2ee1ae2

Please sign in to comment.