Skip to content

Commit

Permalink
small refactor of code
Browse files Browse the repository at this point in the history
  • Loading branch information
y1xiaoc committed May 13, 2021
1 parent 615ce5b commit b998f08
Showing 1 changed file with 20 additions and 20 deletions.
40 changes: 20 additions & 20 deletions dargs/dargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def _check_strict(self, value: dict, path=None):
if name not in allowed_keys:
raise ArgumentKeyError(path,
f"undefined key `{name}` is "
"not allowed in strict mode")
"not allowed in strict mode")

# above are type checking part
# below are normalizing part
Expand All @@ -323,9 +323,9 @@ def normalize(self, argdict: dict, inplace: bool = False,
self.traverse(argdict,
key_hook=Argument._assign_default)
if trim_pattern is not None:
self._trim_unrequired(argdict, trim_pattern, reserved=[self.name])
trim_by_pattern(argdict, trim_pattern, reserved=[self.name])
self.traverse(argdict, sub_hook=lambda a, d, p:
Argument._trim_unrequired(d, trim_pattern, a.flatten_sub(d, p).keys()))
trim_by_pattern(d, trim_pattern, a.flatten_sub(d, p).keys()))
return argdict

def normalize_value(self, value: Any, inplace: bool = False,
Expand All @@ -342,7 +342,7 @@ def normalize_value(self, value: Any, inplace: bool = False,
key_hook=Argument._assign_default)
if trim_pattern is not None:
self.traverse_value(value, sub_hook=lambda a, d, p:
Argument._trim_unrequired(d, trim_pattern, a.flatten_sub(d, p).keys()))
trim_by_pattern(d, trim_pattern, a.flatten_sub(d, p).keys()))
return value

def _assign_default(self, argdict: dict, path=None):
Expand All @@ -358,21 +358,6 @@ def _convert_alias(self, argdict: dict, path=None):
argdict[self.name] = argdict.pop(alias)
return

@staticmethod
def _trim_unrequired(argdict: dict, pattern: str,
reserved: Optional[List[str]] = None,
use_regex: bool = False):
rep = fnmatch.translate(pattern) if not use_regex else pattern
rem = re.compile(rep)
if reserved:
conflict = list(filter(rem.match, reserved))
if conflict:
raise ValueError(f"pattern `{pattern}` conflicts with the "
f"following reserved names: {', '.join(conflict)}")
unrequired = list(filter(rem.match, argdict.keys()))
for key in unrequired:
argdict.pop(key)

# above are normalizing part
# below are doc generation part

Expand Down Expand Up @@ -632,4 +617,19 @@ def update_nodup(this : dict,
raise ValueError(f"duplicate key `{k}` when updating dict"
+("" if err_msg is None else f"in {err_msg}"))
this[k] = v
return this
return this


def trim_by_pattern(argdict: dict, pattern: str,
reserved: Optional[List[str]] = None,
use_regex: bool = False):
rep = fnmatch.translate(pattern) if not use_regex else pattern
rem = re.compile(rep)
if reserved:
conflict = list(filter(rem.match, reserved))
if conflict:
raise ValueError(f"pattern `{pattern}` conflicts with the "
f"following reserved names: {', '.join(conflict)}")
unrequired = list(filter(rem.match, argdict.keys()))
for key in unrequired:
argdict.pop(key)

0 comments on commit b998f08

Please sign in to comment.