diff --git a/flaml/tune/searcher/search_thread.py b/flaml/tune/searcher/search_thread.py index 5ab7846aa7..56fd0e10ef 100644 --- a/flaml/tune/searcher/search_thread.py +++ b/flaml/tune/searcher/search_thread.py @@ -25,6 +25,19 @@ logger = logging.getLogger(__name__) +def recursive_update(d:dict, u:dict): + """ + Args: + d (dict): The target dictionary to be updated. + u (dict): A dictionary containing values to be merged into `d`. + """ + for k, v in u.items(): + if isinstance(v, dict) and k in d and isinstance(d[k], dict): + recursive_update(d[k], v) + else: + d[k] = v + + class SearchThread: """Class of global or local search thread.""" @@ -65,7 +78,7 @@ def suggest(self, trial_id: str) -> Optional[Dict]: try: config = self._search_alg.suggest(trial_id) if isinstance(self._search_alg._space, dict): - config.update(self._const) + recursive_update(config, self._const) else: # define by run config, self.space = unflatten_hierarchical(config, self._space)