Refactor/data transfer between nodes #114

Draft · wants to merge 4 commits into base: dev
6 changes: 2 additions & 4 deletions autointent/configs/_optimization.py
@@ -28,10 +28,8 @@ class LoggingConfig(BaseModel):
"""Path to the directory with different runs."""
run_name: str = Field(default_factory=get_run_name)
"""Name of the run. If None, a random name will be generated"""
dump_modules: bool = False
"""Whether to dump the modules or not"""
clear_ram: bool = False
"""Whether to clear the RAM after dumping the modules"""
keep_in_ram: bool = True
"""Whether to store modules in RAM or dump them into file system."""
report_to: list[str] | None = None
"""List of callbacks to report to. If None, no callbacks will be used"""

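The hunk above is the core configuration change: two booleans collapse into one flag. A minimal migration sketch (assuming LoggingConfig keeps the autointent.configs import path used in the user guides below):

from autointent.configs import LoggingConfig

# Old flags (removed by this PR): dump_modules=True, clear_ram=True
# New single flag: keep_in_ram=False means "dump modules to disk and free RAM"
disk_config = LoggingConfig(keep_in_ram=False)

# keep_in_ram=True (the new default) keeps fitted modules in memory and skips dumping
ram_config = LoggingConfig(keep_in_ram=True)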
4 changes: 2 additions & 2 deletions autointent/context/_context.py
@@ -208,7 +208,7 @@ def get_dump_dir(self) -> Path | None:

:return: Path to the dump directory or None if dumping is disabled.
"""
if self.logging_config.dump_modules:
if not self.logging_config.keep_in_ram:
return self.logging_config.dump_dir
return None

@@ -234,7 +234,7 @@ def is_ram_to_clear(self) -> bool:

:return: True if RAM clearing is enabled, False otherwise.
"""
return self.logging_config.clear_ram
return not self.logging_config.keep_in_ram

def has_saved_modules(self) -> bool:
"""
11 changes: 4 additions & 7 deletions autointent/context/optimization_info/_optimization_info.py
@@ -74,7 +74,6 @@ def log_module_optimization(
module_params: dict[str, Any],
metric_value: float,
metric_name: str,
artifact: Artifact,
module_dump_dir: str | None,
module: "Module | None" = None,
) -> None:
@@ -103,13 +102,11 @@
if module:
self.modules.add_module(node_type, module)

self.artifacts.add_artifact(node_type, artifact)

def _get_metrics_values(self, node_type: str) -> list[float]:
"""Retrieve all metric values for a specific node type."""
return [trial.metric_value for trial in self.trials.get_trials(node_type)]

def _get_best_trial_idx(self, node_type: str) -> int | None:
def get_best_trial_idx(self, node_type: str) -> int | None:
"""
Retrieve the index of the best trial for a node type.

@@ -133,7 +130,7 @@ def _get_best_artifact(self, node_type: str) -> RetrieverArtifact | ScorerArtifa
:return: The best artifact for the node type.
:raises ValueError: If no best trial exists for the node type.
"""
best_idx = self._get_best_trial_idx(node_type)
best_idx = self.get_best_trial_idx(node_type)
if best_idx is None:
msg = f"No best trial for {node_type}"
raise ValueError(msg)
@@ -194,7 +191,7 @@ def get_inference_nodes_config(self, asdict: bool = False) -> list[InferenceNode

:return: List of `InferenceNodeConfig` objects for inference nodes.
"""
trial_ids = [self._get_best_trial_idx(node_type) for node_type in NodeType]
trial_ids = [self.get_best_trial_idx(node_type) for node_type in NodeType]
res = []
for idx, node_type in zip(trial_ids, NodeType, strict=True):
if idx is None:
@@ -216,7 +213,7 @@ def _get_best_module(self, node_type: str) -> "Module | None":
:param node_type: Type of the node.
:return: The best module, or None if no best trial exists.
"""
idx = self._get_best_trial_idx(node_type)
idx = self.get_best_trial_idx(node_type)
if idx is not None:
return self.modules.get(node_type)[idx]
return None
2 changes: 1 addition & 1 deletion autointent/modules/abc/_base.py
@@ -44,7 +44,7 @@ def score(self, context: Context, split: Literal["validation", "test"], metrics:
"""

@abstractmethod
def get_assets(self) -> Artifact:
def get_artifact(self, context: Context) -> Artifact:
"""Return useful assets that represent intermediate data into context."""

@abstractmethod
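For module implementers, the renamed hook also changes its calling convention: the context is passed in explicitly instead of relying on state cached during score(). A rough sketch of the call-site difference (illustrative only; module and context stand for any concrete Module instance and its Context):

# Before: artifact built from attributes the module cached while scoring
artifact = module.get_assets()

# After: artifact built on demand from the context, so it only needs to be
# computed for the module that actually wins the node optimization
artifact = module.get_artifact(context)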
9 changes: 5 additions & 4 deletions autointent/modules/abc/_decision.py
@@ -49,13 +49,14 @@ def score(self, context: Context, split: Literal["validation", "test"], metrics:
:return: Computed metrics value for the test set or error code of metrics
"""
labels, scores = get_decision_evaluation_data(context, split)
self._decisions = self.predict(scores)
decisions = self.predict(scores)
chosen_metrics = {name: fn for name, fn in PREDICTION_METRICS_MULTICLASS.items() if name in metrics}
return self.score_metrics((labels, self._decisions), chosen_metrics)
return self.score_metrics((labels, decisions), chosen_metrics)

def get_assets(self) -> DecisionArtifact:
def get_artifact(self, context: Context) -> DecisionArtifact:
"""Return useful assets that represent intermediate data into context."""
return DecisionArtifact(labels=self._decisions)
_, scores = get_decision_evaluation_data(context, split="test")
return DecisionArtifact(labels=self.predict(scores))

def clear_cache(self) -> None:
"""Clear cache."""
16 changes: 8 additions & 8 deletions autointent/modules/abc/_scoring.py
@@ -41,24 +41,24 @@ def score(self, context: Context, split: Literal["validation", "test"], metrics:

scores = self.predict(utterances)

self._train_scores = self.predict(context.data_handler.train_utterances(1))
self._validation_scores = self.predict(context.data_handler.validation_utterances(1))
self._test_scores = self.predict(context.data_handler.test_utterances())

metrics_dict = SCORING_METRICS_MULTILABEL if context.is_multilabel() else SCORING_METRICS_MULTICLASS
chosen_metrics = {name: fn for name, fn in metrics_dict.items() if name in metrics}
return self.score_metrics((labels, scores), chosen_metrics)

def get_assets(self) -> ScorerArtifact:
def get_artifact(self, context: Context) -> ScorerArtifact:
"""
Retrieve assets generated during scoring.

:return: ScorerArtifact containing test, validation and test scores.
"""
train_scores = self.predict(context.data_handler.train_utterances(1))
validation_scores = self.predict(context.data_handler.validation_utterances(1))
test_scores = self.predict(context.data_handler.test_utterances())

return ScorerArtifact(
train_scores=self._train_scores,
validation_scores=self._validation_scores,
test_scores=self._test_scores,
train_scores=train_scores,
validation_scores=validation_scores,
test_scores=test_scores,
)

@abstractmethod
2 changes: 1 addition & 1 deletion autointent/modules/embedding/_logreg.py
@@ -152,7 +152,7 @@ def score(self, context: Context, split: Literal["validation", "test"], metrics:
chosen_metrics = {name: fn for name, fn in metrics_dict.items() if name in metrics}
return self.score_metrics((labels, probas), chosen_metrics)

def get_assets(self) -> RetrieverArtifact:
def get_artifact(self, context: Context) -> RetrieverArtifact:
"""
Get the classifier artifacts for this module.

2 changes: 1 addition & 1 deletion autointent/modules/embedding/_retrieval.py
@@ -132,7 +132,7 @@ def score(self, context: Context, split: Literal["validation", "test"], metrics:
chosen_metrics = {name: fn for name, fn in metrics_dict.items() if name in metrics}
return self.score_metrics((labels, predictions), chosen_metrics)

def get_assets(self) -> RetrieverArtifact:
def get_artifact(self, context: Context) -> RetrieverArtifact:
"""
Get the retriever artifacts for this module.

2 changes: 1 addition & 1 deletion autointent/modules/regexp/_regexp.py
@@ -133,7 +133,7 @@ def clear_cache(self) -> None:
"""Clear cache."""
del self.regexp_patterns

def get_assets(self) -> Artifact:
def get_artifact(self) -> Artifact:
"""Get assets."""
return Artifact()

21 changes: 17 additions & 4 deletions autointent/nodes/_optimization/_node_optimizer.py
@@ -51,7 +51,7 @@ def fit(self, context: Context) -> None:
:param context: Context
"""
self._logger.info("starting %s node optimization...", self.node_info.node_type)

scored_modules = []
for search_space in deepcopy(self.modules_search_spaces):
module_name = search_space.pop("module_name")

@@ -62,7 +62,10 @@
context.callback_handler.start_module(
module_name=module_name, num=j_combination, module_kwargs=module_kwargs
)
module = self.node_info.modules_available[module_name].from_context(context, **module_kwargs)
module_type = self.node_info.modules_available[module_name]
module = module_type.from_context(context, **module_kwargs)

scored_modules.append((module_type, module_kwargs))

embedder_name = module.get_embedder_name()
if embedder_name is not None:
@@ -92,7 +95,6 @@ def fit(self, context: Context) -> None:
module_kwargs,
metric_value,
self.target_metric,
module.get_assets(), # retriever name / scores / predictions
module_dump_dir,
module=module if not context.is_ram_to_clear() else None,
)
@@ -102,7 +104,18 @@
gc.collect()
torch.cuda.empty_cache()

self._logger.info("%s node optimization is finished!", self.node_info.node_type)
self._logger.info("%s node optimization is finished! saving best assets", self.node_info.node_type)
# TODO refactor the following code (via implementing `autointent.load_module(path)` utility)
trial_idx = context.optimization_info.get_best_trial_idx(self.node_type)
if context.is_ram_to_clear():
trial = context.optimization_info.trials.get_trial(self.node_type, trial_idx) # type: ignore[arg-type]
module_type, module_kwargs = scored_modules[trial_idx] # type: ignore[index]
best_module: Module = module_type(**module_kwargs)
best_module.load(trial.module_dump_dir) # type: ignore[arg-type]
else:
best_module = context.optimization_info.modules.get(self.node_type)[trial_idx]
artifact = best_module.get_artifact(context)
context.optimization_info.artifacts.add_artifact(self.node_type, artifact)

def get_module_dump_dir(self, dump_dir: Path, module_name: str, j_combination: int) -> str:
"""
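The TODO in this hunk proposes folding the keep_in_ram=False branch into an autointent.load_module(path) utility. No such helper exists yet; purely as a sketch of the idea (note that loading from a bare path would additionally require the dump to record the module class and kwargs):

def load_module(module_type, module_kwargs, dump_dir):
    # Hypothetical helper mirroring the branch above: re-instantiate the module
    # from its search-space kwargs and restore its fitted state from disk.
    module = module_type(**module_kwargs)
    module.load(dump_dir)
    return module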
2 changes: 1 addition & 1 deletion tests/callback/test_callback.py
@@ -83,7 +83,7 @@ def test_pipeline_callbacks(dataset):
pipeline_optimizer = Pipeline.from_search_space(search_space)
context = Context()
context.configure_vector_index(VectorIndexConfig(save_db=True))
context.configure_logging(LoggingConfig(run_name="dummy_run_name", project_dir=project_dir, dump_modules=False))
context.configure_logging(LoggingConfig(run_name="dummy_run_name", project_dir=project_dir, keep_in_ram=True))
context.callback_handler = CallbackHandler([DummyCallback])
context.set_dataset(dataset)

6 changes: 0 additions & 6 deletions tests/modules/embedding/test_logreg.py
@@ -1,12 +1,6 @@
from autointent.modules.embedding import LogregAimedEmbedding


def test_get_assets_returns_correct_artifact_for_logreg():
module = LogregAimedEmbedding(embedder_name="sergeyzh/rubert-tiny-turbo")
artifact = module.get_assets()
assert artifact.embedder_name == "sergeyzh/rubert-tiny-turbo"


def test_fit_trains_model():
module = LogregAimedEmbedding(embedder_name="sergeyzh/rubert-tiny-turbo")

6 changes: 0 additions & 6 deletions tests/modules/embedding/test_retrieval.py
@@ -4,12 +4,6 @@
from tests.conftest import setup_environment


def test_get_assets_returns_correct_artifact():
module = RetrievalAimedEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo")
artifact = module.get_assets()
assert artifact.embedder_name == "sergeyzh/rubert-tiny-turbo"


def test_dump_and_load_preserves_model_state():
project_dir = setup_environment()
module = RetrievalAimedEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo")
2 changes: 1 addition & 1 deletion tests/nodes/conftest.py
@@ -82,7 +82,7 @@ def get_context(multilabel):
if multilabel:
dataset = dataset.to_multilabel()
res.set_dataset(dataset)
res.configure_logging(LoggingConfig(project_dir=project_dir, dump_modules=True))
res.configure_logging(LoggingConfig(project_dir=project_dir, keep_in_ram=True))
res.configure_vector_index(VectorIndexConfig(), EmbedderConfig(device="cpu"))
res.configure_cross_encoder(CrossEncoderConfig())
return res
4 changes: 2 additions & 2 deletions tests/pipeline/test_inference.py
@@ -15,7 +15,7 @@ def test_inference_config(dataset, task_type):

pipeline_optimizer = Pipeline.from_search_space(search_space)

pipeline_optimizer.set_config(LoggingConfig(project_dir=project_dir, dump_modules=True, clear_ram=True))
pipeline_optimizer.set_config(LoggingConfig(project_dir=project_dir, keep_in_ram=False))
pipeline_optimizer.set_config(VectorIndexConfig(save_db=True))
pipeline_optimizer.set_config(EmbedderConfig(batch_size=16, max_length=32, device="cpu"))
pipeline_optimizer.set_config(CrossEncoderConfig())
@@ -47,7 +47,7 @@ def test_inference_context(dataset, task_type):

pipeline = Pipeline.from_search_space(search_space)

pipeline.set_config(LoggingConfig(project_dir=project_dir, dump_modules=False, clear_ram=False))
pipeline.set_config(LoggingConfig(project_dir=project_dir, keep_in_ram=True))
pipeline.set_config(VectorIndexConfig(save_db=True))
pipeline.set_config(EmbedderConfig(batch_size=16, max_length=32, device="cpu"))

6 changes: 3 additions & 3 deletions tests/pipeline/test_optimization.py
@@ -24,7 +24,7 @@ def test_no_context_optimization(dataset, task_type):

pipeline_optimizer = Pipeline.from_search_space(search_space)

pipeline_optimizer.set_config(LoggingConfig(project_dir=project_dir, dump_modules=False))
pipeline_optimizer.set_config(LoggingConfig(project_dir=project_dir, keep_in_ram=True))
pipeline_optimizer.set_config(VectorIndexConfig())
pipeline_optimizer.set_config(EmbedderConfig(batch_size=16, max_length=32, device="cpu"))

@@ -45,7 +45,7 @@ def test_save_db(dataset, task_type):

pipeline_optimizer = Pipeline.from_search_space(search_space)

pipeline_optimizer.set_config(LoggingConfig(project_dir=project_dir, dump_modules=False))
pipeline_optimizer.set_config(LoggingConfig(project_dir=project_dir, keep_in_ram=True))
pipeline_optimizer.set_config(VectorIndexConfig(save_db=True))
pipeline_optimizer.set_config(EmbedderConfig(batch_size=16, max_length=32, device="cpu"))

@@ -66,7 +66,7 @@ def test_dump_modules(dataset, task_type):

pipeline_optimizer = Pipeline.from_search_space(search_space)

pipeline_optimizer.set_config(LoggingConfig(project_dir=project_dir, dump_modules=True))
pipeline_optimizer.set_config(LoggingConfig(project_dir=project_dir, keep_in_ram=False))
pipeline_optimizer.set_config(VectorIndexConfig())
pipeline_optimizer.set_config(EmbedderConfig(batch_size=16, max_length=32, device="cpu"))

2 changes: 1 addition & 1 deletion user_guides/advanced/04_reporting.py
@@ -74,7 +74,7 @@
from pathlib import Path

log_config = LoggingConfig(
run_name="test_tensorboard", report_to=["tensorboard"], dirpath=Path("test_tensorboard"), dump_modules=False
run_name="test_tensorboard", report_to=["tensorboard"], dirpath=Path("test_tensorboard"), keep_in_ram=True
)

pipeline_optimizer.set_config(log_config)
2 changes: 1 addition & 1 deletion user_guides/basic_usage/03_automl.py
@@ -124,7 +124,7 @@
from pathlib import Path
from autointent.configs import LoggingConfig

logging_config = LoggingConfig(project_dir=Path.cwd() / "runs", dump_modules=False, clear_ram=False)
logging_config = LoggingConfig(project_dir=Path.cwd() / "runs", keep_in_ram=True)
custom_pipeline.set_config(logging_config)

# %% [markdown]
4 changes: 2 additions & 2 deletions user_guides/basic_usage/04_inference.py
@@ -67,7 +67,7 @@
# %%
from autointent.configs import LoggingConfig

logging_config = LoggingConfig(dump_modules=True, clear_ram=True)
logging_config = LoggingConfig(keep_in_ram=False)

# %% [markdown]
"""
@@ -82,7 +82,7 @@

dataset = Dataset.from_hub("AutoIntent/clinc150_subset")
pipeline = Pipeline.from_search_space(search_space)
pipeline.set_config(LoggingConfig(dump_modules=True, clear_ram=True))
pipeline.set_config(LoggingConfig(keep_in_ram=False))
pipeline.set_config(VectorIndexConfig(save_db=True))

# %% [markdown]