Skip to content

Commit

Permalink
Make BotAx registries string-based by class name
Browse files Browse the repository at this point in the history
Summary: Make `ACQUISITION_REGISTRY` and other modular [BotAx registries](https://www.internalfb.com/intern/diffusion/FBS/browsefile/master/fbcode/ax/ax/storage/botorch_modular_registry.py?lines=73-100%2C57-72) string-based by class name

Reviewed By: lena-kashtelyan

Differential Revision: D25376150

fbshipit-source-id: 1bed59bf8fd2c6fa7d7e4e10312f4ee7bb31380d
  • Loading branch information
bernardbeckerman authored and facebook-github-bot committed Dec 8, 2020
1 parent 5546981 commit bf80464
Showing 1 changed file with 37 additions and 38 deletions.
75 changes: 37 additions & 38 deletions ax/storage/botorch_modular_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,57 +52,57 @@
# to `CLASS_TO_REGISTRY` and `CLASS_TO_REVERSE_REGISTRY` in this file.

"""
Mapping of modular Ax `Acquisition` classes to ints.
Mapping of modular Ax `Acquisition` classes to class name strings.
"""
ACQUISITION_REGISTRY: Dict[Type[Acquisition], int] = {
Acquisition: 0,
KnowledgeGradient: 1,
MultiFidelityKnowledgeGradient: 2,
MaxValueEntropySearch: 3,
MultiFidelityMaxValueEntropySearch: 4,
ACQUISITION_REGISTRY: Dict[Type[Acquisition], str] = {
Acquisition: "Acquisition",
KnowledgeGradient: "KnowledgeGradient",
MaxValueEntropySearch: "MaxValueEntropySearch",
MultiFidelityKnowledgeGradient: "MultiFidelityKnowledgeGradient",
MultiFidelityMaxValueEntropySearch: "MultiFidelityMaxValueEntropySearch",
}


"""
Mapping of BoTorch `Model` classes to ints.
Mapping of BoTorch `Model` classes to class name strings.
"""
MODEL_REGISTRY: Dict[Type[Model], int] = {
FixedNoiseGP: 0,
SingleTaskGP: 1,
FixedNoiseMultiFidelityGP: 2,
SingleTaskMultiFidelityGP: 3,
ModelListGP: 4,
FixedNoiseMultiTaskGP: 5,
MultiTaskGP: 6,
MODEL_REGISTRY: Dict[Type[Model], str] = {
FixedNoiseGP: "FixedNoiseGP",
FixedNoiseMultiFidelityGP: "FixedNoiseMultiFidelityGP",
FixedNoiseMultiTaskGP: "FixedNoiseMultiTaskGP",
ModelListGP: "ModelListGP",
MultiTaskGP: "MultiTaskGP",
SingleTaskGP: "SingleTaskGP",
SingleTaskMultiFidelityGP: "SingleTaskMultiFidelityGP",
}


"""
Mapping of Botorch `AcquisitionFunction` classes to ints.
Mapping of Botorch `AcquisitionFunction` classes to class name strings.
"""
ACQUISITION_FUNCTION_REGISTRY: Dict[Type[AcquisitionFunction], int] = {
qExpectedImprovement: 0,
qNoisyExpectedImprovement: 1,
qKnowledgeGradient: 2,
qMultiFidelityKnowledgeGradient: 3,
qMaxValueEntropy: 4,
qMultiFidelityMaxValueEntropy: 5,
ACQUISITION_FUNCTION_REGISTRY: Dict[Type[AcquisitionFunction], str] = {
qExpectedImprovement: "qExpectedImprovement",
qKnowledgeGradient: "qKnowledgeGradient",
qMaxValueEntropy: "qMaxValueEntropy",
qMultiFidelityKnowledgeGradient: "qMultiFidelityKnowledgeGradient",
qMultiFidelityMaxValueEntropy: "qMultiFidelityMaxValueEntropy",
qNoisyExpectedImprovement: "qNoisyExpectedImprovement",
}


"""
Mapping of BoTorch `MarginalLogLikelihood` classes to ints.
Mapping of BoTorch `MarginalLogLikelihood` classes to class name strings.
"""
MLL_REGISTRY: Dict[Type[MarginalLogLikelihood], int] = {
ExactMarginalLogLikelihood: 0,
SumMarginalLogLikelihood: 1,
MLL_REGISTRY: Dict[Type[MarginalLogLikelihood], str] = {
ExactMarginalLogLikelihood: "ExactMarginalLogLikelihood",
SumMarginalLogLikelihood: "SumMarginalLogLikelihood",
}


"""
Overarching mapping from encoded classes to registry map.
"""
CLASS_TO_REGISTRY: Dict[Any, Dict[Type[Any], int]] = {
CLASS_TO_REGISTRY: Dict[Any, Dict[Type[Any], str]] = {
Acquisition: ACQUISITION_REGISTRY,
AcquisitionFunction: ACQUISITION_FUNCTION_REGISTRY,
MarginalLogLikelihood: MLL_REGISTRY,
Expand All @@ -113,30 +113,30 @@
"""
Reverse registries for decoding.
"""
REVERSE_ACQUISITION_REGISTRY: Dict[int, Type[Acquisition]] = {
REVERSE_ACQUISITION_REGISTRY: Dict[str, Type[Acquisition]] = {
v: k for k, v in ACQUISITION_REGISTRY.items()
}


REVERSE_MODEL_REGISTRY: Dict[int, Type[Model]] = {
REVERSE_MODEL_REGISTRY: Dict[str, Type[Model]] = {
v: k for k, v in MODEL_REGISTRY.items()
}


REVERSE_ACQUISITION_FUNCTION_REGISTRY: Dict[int, Type[AcquisitionFunction]] = {
REVERSE_ACQUISITION_FUNCTION_REGISTRY: Dict[str, Type[AcquisitionFunction]] = {
v: k for k, v in ACQUISITION_FUNCTION_REGISTRY.items()
}


REVERSE_MLL_REGISTRY: Dict[int, Type[MarginalLogLikelihood]] = {
REVERSE_MLL_REGISTRY: Dict[str, Type[MarginalLogLikelihood]] = {
v: k for k, v in MLL_REGISTRY.items()
}


"""
Overarching mapping from encoded classes to reverse registry map.
"""
CLASS_TO_REVERSE_REGISTRY: Dict[Any, Dict[int, Type[Any]]] = {
CLASS_TO_REVERSE_REGISTRY: Dict[Any, Dict[str, Type[Any]]] = {
Acquisition: REVERSE_ACQUISITION_REGISTRY,
AcquisitionFunction: REVERSE_ACQUISITION_FUNCTION_REGISTRY,
MarginalLogLikelihood: REVERSE_MLL_REGISTRY,
Expand All @@ -146,7 +146,6 @@

def register_acquisition(acq_class: Type[Acquisition]) -> None:
"""Add a custom acquisition class to the SQA and JSON registries."""
ACQUISITION_REGISTRY = CLASS_TO_REGISTRY[Acquisition]
index = len(ACQUISITION_REGISTRY)
CLASS_TO_REGISTRY[Acquisition].update({acq_class: index})
CLASS_TO_REVERSE_REGISTRY[Acquisition].update({index: acq_class})
class_name = acq_class.__name__
CLASS_TO_REGISTRY[Acquisition].update({acq_class: class_name})
CLASS_TO_REVERSE_REGISTRY[Acquisition].update({class_name: acq_class})

0 comments on commit bf80464

Please sign in to comment.