Skip to content

Commit

Permalink
Merge branch 'master' into yaml_config_space
Browse files Browse the repository at this point in the history
  • Loading branch information
danrgll committed Nov 25, 2023
2 parents 65e340f + 521f021 commit 155ed07
Show file tree
Hide file tree
Showing 8 changed files with 68 additions and 35 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
runs-on: ${{ matrix.os }}

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4

- uses: actions/setup-python@v4
with:
Expand Down
51 changes: 29 additions & 22 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ classifiers = [
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"Natural Language :: English",
"License :: OSI Approved :: Apache Software License",
"Operating System :: POSIX :: Linux",
"Operating System :: Microsoft :: Windows",
"Operating System :: MacOS",
Expand All @@ -46,40 +47,46 @@ packages = [

[tool.poetry.dependencies]
python = ">=3.8,<3.12"
ConfigSpace = "^0.4.19"
grakel = "^0.1.9"
numpy = "^1.23.0"
pandas = "^1.3.1"
ConfigSpace = "^0.6"
grakel = "^0.1"
numpy = "^1"
pandas = "^2"
networkx = "^2.6.3"
nltk = "^3.6.4"
path = "^16.2.0"
termcolor = "^1.1.0"
scipy = "^1.8"
#path = "^16.2.0"
#termcolor = "^1.1.0"
scipy = "^1"
torch = ">=1.7.0,<=2.1, !=2.0.1, !=2.1.0" # fix from: https://stackoverflow.com/a/76647180
# torch = [
# {version = ">=1.7.0,<=2.1", markers = "sys_platform == 'darwin'"}, # Segfaults for macOS on github actions
# {version = ">=1.7.0,<=2.1", markers = "sys_platform != 'darwin'"},
# ]
matplotlib = "^3.4"
statsmodels = "^0.13.2"
more-itertools = "^9.0.0"
portalocker = "^2.6.0"
seaborn = "^0.12.1"
pyyaml = "^6.0"
tensorboard = "^2.13"
cython = "^3.0.4"
torchvision = "<0.16.0"
matplotlib = "^3"
# statsmodels = "^0.13.2"
more-itertools = "^10"
portalocker = "^2"
seaborn = "^0.13"
pyyaml = "^6"
tensorboard = "^2"
# cython = "^3.0.4"

[tool.poetry.group.dev.dependencies]
pre-commit = "^2.10"
mypy = "^0.930"
pytest = "^6.2.5"
types-PyYAML = "^6.0.12"
typing-extensions = "^4.0.1"
types-termcolor = "^1.1.2"
pre-commit = "^3"
mypy = "^1"
pytest = "^7"
types-PyYAML = "^6"
typing-extensions = "^4"
#types-termcolor = "^1.1.2"
# jahs-bench = {git = "https://github.com/automl/jahs_bench_201.git", rev = "v1.0.2"}
mkdocs-material = "^8.1.3"
mike = "^1.1.2"
torchvision = "<0.16.0" # Used in examples


[tool.poetry.group.experimental]
optional = true

[tool.poetry.group.experimental.dependencies]
gpytorch = "1.8.0"

[build-system]
Expand Down
20 changes: 19 additions & 1 deletion src/metahyper/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,21 @@ def get_data_representation(data: Any):
return data


class MissingDependencyError(Exception):
def __init__(self, dep: str, cause: Exception, *args: Any):
super().__init__(dep, cause, *args)
self.dep = dep
self.__cause__ = cause # This is what `raise a from b` does

def __str__(self) -> str:
return (
f"Some required dependency-({self.dep}) to use this optional feature is "
f"missing. Please, include neps[experimental] dependency group in your "
f"installation of neps to be able to use all the optional features."
f" Otherwise, just install ({self.dep})"
)


class YamlSerializer:
SUFFIX = ".yaml"
PRE_SERIALIZE = True
Expand Down Expand Up @@ -97,7 +112,7 @@ def instance_from_map(
name: str = "mapping",
allow_any: bool = True,
as_class: bool = False,
kwargs: dict = None,
kwargs: dict | None = None,
):
"""Get an instance of an class from a mapping.
Expand Down Expand Up @@ -140,6 +155,9 @@ def instance_from_map(
else:
raise ValueError(f"Object {request} invalid key for {name}")

if isinstance(instance, MissingDependencyError):
raise instance

# Check if the request is a class if it is mandatory
if (args_dict or as_class) and not is_partial_class(instance):
raise ValueError(
Expand Down
4 changes: 2 additions & 2 deletions src/neps/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import logging
import warnings
from pathlib import Path
from typing import Callable, Literal, List
from typing import Callable, List, Literal

import ConfigSpace as CS

Expand Down Expand Up @@ -297,7 +297,7 @@ def run(
pre_load_hooks=pre_load_hooks,
)

if post_run_csv:
if post_run_summary:
post_run_csv(root_directory, logger)


Expand Down
19 changes: 12 additions & 7 deletions src/neps/optimizers/bayesian_optimization/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
from .deepGP import DeepGP
from metahyper.utils import MissingDependencyError

from .gp import ComprehensiveGP
from .gp_hierarchy import ComprehensiveGPHierarchy

try:
from .deepGP import DeepGP
except ImportError as e:
DeepGP = MissingDependencyError("gpytorch", e)

try:
from .pfn import PFN_SURROGATE # only if available locally
except Exception as e:
PFN_SURROGATE = MissingDependencyError("pfn", e)

SurrogateModelMapping = {
"deep_gp": DeepGP,
"gp": ComprehensiveGP,
"gp_hierarchy": ComprehensiveGPHierarchy,
"pfn": PFN_SURROGATE,
}

try:
from .pfn import PFN_SURROGATE # only if available locally
SurrogateModelMapping.update({"pfn": PFN_SURROGATE})
except:
pass
2 changes: 1 addition & 1 deletion src/neps/plot/tensorboard_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def _write_image_config(
if resize_images is None:
resize_images = [32, 32]

if tblogger.current_epoch and tblogger.current_epoch % counter == 0:
if tblogger.current_epoch >= 0 and tblogger.current_epoch % counter == 0:
# Log every multiple of "counter"

if num_images > len(image):
Expand Down
2 changes: 1 addition & 1 deletion src/neps/search_spaces/architecture/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import networkx as nx
import torch
from networkx.algorithms.dag import lexicographical_topological_sort
from path import Path
from pathlib import Path
from torch import nn

from ...utils.common import AttrDict
Expand Down
3 changes: 3 additions & 0 deletions src/neps/status/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,3 +339,6 @@ def post_run_csv(root_directory: str | Path, logger=None) -> None:
df_config_data,
df_run_data,
)

def get_run_summary_csv(root_directory: str | Path):
post_run_csv(root_directory=root_directory)

0 comments on commit 155ed07

Please sign in to comment.