From bf804645600f3cfef657904ab0c27c9c52271798 Mon Sep 17 00:00:00 2001 From: Bernard M Beckerman Date: Tue, 8 Dec 2020 06:54:45 -0800 Subject: [PATCH] Make BotAx registries string-based by class name Summary: Make `ACQUISITION_REGISTRY` and other modular [BotAx registries](https://www.internalfb.com/intern/diffusion/FBS/browsefile/master/fbcode/ax/ax/storage/botorch_modular_registry.py?lines=73-100%2C57-72) string-based by class name Reviewed By: lena-kashtelyan Differential Revision: D25376150 fbshipit-source-id: 1bed59bf8fd2c6fa7d7e4e10312f4ee7bb31380d --- ax/storage/botorch_modular_registry.py | 75 +++++++++++++------------- 1 file changed, 37 insertions(+), 38 deletions(-) diff --git a/ax/storage/botorch_modular_registry.py b/ax/storage/botorch_modular_registry.py index a9dfc92e651..0a5e9f2bae6 100644 --- a/ax/storage/botorch_modular_registry.py +++ b/ax/storage/botorch_modular_registry.py @@ -52,57 +52,57 @@ # to `CLASS_TO_REGISTRY` and `CLASS_TO_REVERSE_REGISTRY` in this file. """ -Mapping of modular Ax `Acquisition` classes to ints. +Mapping of modular Ax `Acquisition` classes to class name strings. """ -ACQUISITION_REGISTRY: Dict[Type[Acquisition], int] = { - Acquisition: 0, - KnowledgeGradient: 1, - MultiFidelityKnowledgeGradient: 2, - MaxValueEntropySearch: 3, - MultiFidelityMaxValueEntropySearch: 4, +ACQUISITION_REGISTRY: Dict[Type[Acquisition], str] = { + Acquisition: "Acquisition", + KnowledgeGradient: "KnowledgeGradient", + MaxValueEntropySearch: "MaxValueEntropySearch", + MultiFidelityKnowledgeGradient: "MultiFidelityKnowledgeGradient", + MultiFidelityMaxValueEntropySearch: "MultiFidelityMaxValueEntropySearch", } """ -Mapping of BoTorch `Model` classes to ints. +Mapping of BoTorch `Model` classes to class name strings. """ -MODEL_REGISTRY: Dict[Type[Model], int] = { - FixedNoiseGP: 0, - SingleTaskGP: 1, - FixedNoiseMultiFidelityGP: 2, - SingleTaskMultiFidelityGP: 3, - ModelListGP: 4, - FixedNoiseMultiTaskGP: 5, - MultiTaskGP: 6, +MODEL_REGISTRY: Dict[Type[Model], str] = { + FixedNoiseGP: "FixedNoiseGP", + FixedNoiseMultiFidelityGP: "FixedNoiseMultiFidelityGP", + FixedNoiseMultiTaskGP: "FixedNoiseMultiTaskGP", + ModelListGP: "ModelListGP", + MultiTaskGP: "MultiTaskGP", + SingleTaskGP: "SingleTaskGP", + SingleTaskMultiFidelityGP: "SingleTaskMultiFidelityGP", } """ -Mapping of Botorch `AcquisitionFunction` classes to ints. +Mapping of Botorch `AcquisitionFunction` classes to class name strings. """ -ACQUISITION_FUNCTION_REGISTRY: Dict[Type[AcquisitionFunction], int] = { - qExpectedImprovement: 0, - qNoisyExpectedImprovement: 1, - qKnowledgeGradient: 2, - qMultiFidelityKnowledgeGradient: 3, - qMaxValueEntropy: 4, - qMultiFidelityMaxValueEntropy: 5, +ACQUISITION_FUNCTION_REGISTRY: Dict[Type[AcquisitionFunction], str] = { + qExpectedImprovement: "qExpectedImprovement", + qKnowledgeGradient: "qKnowledgeGradient", + qMaxValueEntropy: "qMaxValueEntropy", + qMultiFidelityKnowledgeGradient: "qMultiFidelityKnowledgeGradient", + qMultiFidelityMaxValueEntropy: "qMultiFidelityMaxValueEntropy", + qNoisyExpectedImprovement: "qNoisyExpectedImprovement", } """ -Mapping of BoTorch `MarginalLogLikelihood` classes to ints. +Mapping of BoTorch `MarginalLogLikelihood` classes to class name strings. """ -MLL_REGISTRY: Dict[Type[MarginalLogLikelihood], int] = { - ExactMarginalLogLikelihood: 0, - SumMarginalLogLikelihood: 1, +MLL_REGISTRY: Dict[Type[MarginalLogLikelihood], str] = { + ExactMarginalLogLikelihood: "ExactMarginalLogLikelihood", + SumMarginalLogLikelihood: "SumMarginalLogLikelihood", } """ Overarching mapping from encoded classes to registry map. """ -CLASS_TO_REGISTRY: Dict[Any, Dict[Type[Any], int]] = { +CLASS_TO_REGISTRY: Dict[Any, Dict[Type[Any], str]] = { Acquisition: ACQUISITION_REGISTRY, AcquisitionFunction: ACQUISITION_FUNCTION_REGISTRY, MarginalLogLikelihood: MLL_REGISTRY, @@ -113,22 +113,22 @@ """ Reverse registries for decoding. """ -REVERSE_ACQUISITION_REGISTRY: Dict[int, Type[Acquisition]] = { +REVERSE_ACQUISITION_REGISTRY: Dict[str, Type[Acquisition]] = { v: k for k, v in ACQUISITION_REGISTRY.items() } -REVERSE_MODEL_REGISTRY: Dict[int, Type[Model]] = { +REVERSE_MODEL_REGISTRY: Dict[str, Type[Model]] = { v: k for k, v in MODEL_REGISTRY.items() } -REVERSE_ACQUISITION_FUNCTION_REGISTRY: Dict[int, Type[AcquisitionFunction]] = { +REVERSE_ACQUISITION_FUNCTION_REGISTRY: Dict[str, Type[AcquisitionFunction]] = { v: k for k, v in ACQUISITION_FUNCTION_REGISTRY.items() } -REVERSE_MLL_REGISTRY: Dict[int, Type[MarginalLogLikelihood]] = { +REVERSE_MLL_REGISTRY: Dict[str, Type[MarginalLogLikelihood]] = { v: k for k, v in MLL_REGISTRY.items() } @@ -136,7 +136,7 @@ """ Overarching mapping from encoded classes to reverse registry map. """ -CLASS_TO_REVERSE_REGISTRY: Dict[Any, Dict[int, Type[Any]]] = { +CLASS_TO_REVERSE_REGISTRY: Dict[Any, Dict[str, Type[Any]]] = { Acquisition: REVERSE_ACQUISITION_REGISTRY, AcquisitionFunction: REVERSE_ACQUISITION_FUNCTION_REGISTRY, MarginalLogLikelihood: REVERSE_MLL_REGISTRY, @@ -146,7 +146,6 @@ def register_acquisition(acq_class: Type[Acquisition]) -> None: """Add a custom acquisition class to the SQA and JSON registries.""" - ACQUISITION_REGISTRY = CLASS_TO_REGISTRY[Acquisition] - index = len(ACQUISITION_REGISTRY) - CLASS_TO_REGISTRY[Acquisition].update({acq_class: index}) - CLASS_TO_REVERSE_REGISTRY[Acquisition].update({index: acq_class}) + class_name = acq_class.__name__ + CLASS_TO_REGISTRY[Acquisition].update({acq_class: class_name}) + CLASS_TO_REVERSE_REGISTRY[Acquisition].update({class_name: acq_class})