From 6b168c4684800975ba15ec8966e838c3534c1b35 Mon Sep 17 00:00:00 2001 From: voorhs Date: Sun, 2 Feb 2025 08:12:54 +0300 Subject: [PATCH 1/4] merge `clear_ram` and `dump_modules` into `keep_in_ram` --- autointent/configs/_optimization.py | 6 ++---- autointent/context/_context.py | 4 ++-- tests/callback/test_callback.py | 2 +- tests/nodes/conftest.py | 2 +- tests/pipeline/test_inference.py | 4 ++-- tests/pipeline/test_optimization.py | 6 +++--- user_guides/advanced/04_reporting.py | 2 +- user_guides/basic_usage/03_automl.py | 2 +- user_guides/basic_usage/04_inference.py | 4 ++-- 9 files changed, 15 insertions(+), 17 deletions(-) diff --git a/autointent/configs/_optimization.py b/autointent/configs/_optimization.py index 610eece5..f2296284 100644 --- a/autointent/configs/_optimization.py +++ b/autointent/configs/_optimization.py @@ -28,10 +28,8 @@ class LoggingConfig(BaseModel): """Path to the directory with different runs.""" run_name: str = Field(default_factory=get_run_name) """Name of the run. If None, a random name will be generated""" - dump_modules: bool = False - """Whether to dump the modules or not""" - clear_ram: bool = False - """Whether to clear the RAM after dumping the modules""" + keep_in_ram: bool = True + """Whether to store modules in RAM or dump them into file system.""" report_to: list[str] | None = None """List of callbacks to report to. If None, no callbacks will be used""" diff --git a/autointent/context/_context.py b/autointent/context/_context.py index 1081ea3a..c8d63c17 100644 --- a/autointent/context/_context.py +++ b/autointent/context/_context.py @@ -208,7 +208,7 @@ def get_dump_dir(self) -> Path | None: :return: Path to the dump directory or None if dumping is disabled. """ - if self.logging_config.dump_modules: + if not self.logging_config.keep_in_ram: return self.logging_config.dump_dir return None @@ -234,7 +234,7 @@ def is_ram_to_clear(self) -> bool: :return: True if RAM clearing is enabled, False otherwise. """ - return self.logging_config.clear_ram + return not self.logging_config.keep_in_ram def has_saved_modules(self) -> bool: """ diff --git a/tests/callback/test_callback.py b/tests/callback/test_callback.py index f09d479a..2f36aae1 100644 --- a/tests/callback/test_callback.py +++ b/tests/callback/test_callback.py @@ -83,7 +83,7 @@ def test_pipeline_callbacks(dataset): pipeline_optimizer = Pipeline.from_search_space(search_space) context = Context() context.configure_vector_index(VectorIndexConfig(save_db=True)) - context.configure_logging(LoggingConfig(run_name="dummy_run_name", project_dir=project_dir, dump_modules=False)) + context.configure_logging(LoggingConfig(run_name="dummy_run_name", project_dir=project_dir, keep_in_ram=True)) context.callback_handler = CallbackHandler([DummyCallback]) context.set_dataset(dataset) diff --git a/tests/nodes/conftest.py b/tests/nodes/conftest.py index bb0117e3..e5095249 100644 --- a/tests/nodes/conftest.py +++ b/tests/nodes/conftest.py @@ -82,7 +82,7 @@ def get_context(multilabel): if multilabel: dataset = dataset.to_multilabel() res.set_dataset(dataset) - res.configure_logging(LoggingConfig(project_dir=project_dir, dump_modules=True)) + res.configure_logging(LoggingConfig(project_dir=project_dir, keep_in_ram=False)) res.configure_vector_index(VectorIndexConfig(), EmbedderConfig(device="cpu")) res.configure_cross_encoder(CrossEncoderConfig()) return res diff --git a/tests/pipeline/test_inference.py b/tests/pipeline/test_inference.py index 456d54e4..f389f5b7 100644 --- a/tests/pipeline/test_inference.py +++ b/tests/pipeline/test_inference.py @@ -15,7 +15,7 @@ def test_inference_config(dataset, task_type): pipeline_optimizer = Pipeline.from_search_space(search_space) - pipeline_optimizer.set_config(LoggingConfig(project_dir=project_dir, dump_modules=True, clear_ram=True)) + pipeline_optimizer.set_config(LoggingConfig(project_dir=project_dir, keep_in_ram=False)) pipeline_optimizer.set_config(VectorIndexConfig(save_db=True)) pipeline_optimizer.set_config(EmbedderConfig(batch_size=16, max_length=32, device="cpu")) pipeline_optimizer.set_config(CrossEncoderConfig()) @@ -47,7 +47,7 @@ def test_inference_context(dataset, task_type): pipeline = Pipeline.from_search_space(search_space) - pipeline.set_config(LoggingConfig(project_dir=project_dir, dump_modules=False, clear_ram=False)) + pipeline.set_config(LoggingConfig(project_dir=project_dir, keep_in_ram=True)) pipeline.set_config(VectorIndexConfig(save_db=True)) pipeline.set_config(EmbedderConfig(batch_size=16, max_length=32, device="cpu")) diff --git a/tests/pipeline/test_optimization.py b/tests/pipeline/test_optimization.py index 050eca74..0f4a4f2b 100644 --- a/tests/pipeline/test_optimization.py +++ b/tests/pipeline/test_optimization.py @@ -24,7 +24,7 @@ def test_no_context_optimization(dataset, task_type): pipeline_optimizer = Pipeline.from_search_space(search_space) - pipeline_optimizer.set_config(LoggingConfig(project_dir=project_dir, dump_modules=False)) + pipeline_optimizer.set_config(LoggingConfig(project_dir=project_dir, keep_in_ram=True)) pipeline_optimizer.set_config(VectorIndexConfig()) pipeline_optimizer.set_config(EmbedderConfig(batch_size=16, max_length=32, device="cpu")) @@ -45,7 +45,7 @@ def test_save_db(dataset, task_type): pipeline_optimizer = Pipeline.from_search_space(search_space) - pipeline_optimizer.set_config(LoggingConfig(project_dir=project_dir, dump_modules=False)) + pipeline_optimizer.set_config(LoggingConfig(project_dir=project_dir, keep_in_ram=True)) pipeline_optimizer.set_config(VectorIndexConfig(save_db=True)) pipeline_optimizer.set_config(EmbedderConfig(batch_size=16, max_length=32, device="cpu")) @@ -66,7 +66,7 @@ def test_dump_modules(dataset, task_type): pipeline_optimizer = Pipeline.from_search_space(search_space) - pipeline_optimizer.set_config(LoggingConfig(project_dir=project_dir, dump_modules=True)) + pipeline_optimizer.set_config(LoggingConfig(project_dir=project_dir, keep_in_ram=False)) pipeline_optimizer.set_config(VectorIndexConfig()) pipeline_optimizer.set_config(EmbedderConfig(batch_size=16, max_length=32, device="cpu")) diff --git a/user_guides/advanced/04_reporting.py b/user_guides/advanced/04_reporting.py index 868153d5..417d3ff6 100644 --- a/user_guides/advanced/04_reporting.py +++ b/user_guides/advanced/04_reporting.py @@ -74,7 +74,7 @@ from pathlib import Path log_config = LoggingConfig( - run_name="test_tensorboard", report_to=["tensorboard"], dirpath=Path("test_tensorboard"), dump_modules=False + run_name="test_tensorboard", report_to=["tensorboard"], dirpath=Path("test_tensorboard"), keep_in_ram=True ) pipeline_optimizer.set_config(log_config) diff --git a/user_guides/basic_usage/03_automl.py b/user_guides/basic_usage/03_automl.py index 97a93d64..51d8f739 100644 --- a/user_guides/basic_usage/03_automl.py +++ b/user_guides/basic_usage/03_automl.py @@ -124,7 +124,7 @@ from pathlib import Path from autointent.configs import LoggingConfig -logging_config = LoggingConfig(project_dir=Path.cwd() / "runs", dump_modules=False, clear_ram=False) +logging_config = LoggingConfig(project_dir=Path.cwd() / "runs", keep_in_ram=True) custom_pipeline.set_config(logging_config) # %% [markdown] diff --git a/user_guides/basic_usage/04_inference.py b/user_guides/basic_usage/04_inference.py index 0de393b4..4863c110 100644 --- a/user_guides/basic_usage/04_inference.py +++ b/user_guides/basic_usage/04_inference.py @@ -67,7 +67,7 @@ # %% from autointent.configs import LoggingConfig -logging_config = LoggingConfig(dump_modules=True, clear_ram=True) +logging_config = LoggingConfig(keep_in_ram=False) # %% [markdown] """ @@ -82,7 +82,7 @@ dataset = Dataset.from_hub("AutoIntent/clinc150_subset") pipeline = Pipeline.from_search_space(search_space) -pipeline.set_config(LoggingConfig(dump_modules=True, clear_ram=True)) +pipeline.set_config(LoggingConfig(keep_in_ram=False)) pipeline.set_config(VectorIndexConfig(save_db=True)) # %% [markdown] From 9a337f99f9b659891788dbf2804ba412317d0483 Mon Sep 17 00:00:00 2001 From: voorhs Date: Sun, 2 Feb 2025 09:02:18 +0300 Subject: [PATCH 2/4] refactor: artifacts are added only for best module --- .../optimization_info/_optimization_info.py | 11 ++++------- autointent/modules/abc/_base.py | 2 +- autointent/modules/abc/_decision.py | 9 +++++---- autointent/modules/abc/_scoring.py | 16 ++++++++-------- autointent/modules/embedding/_logreg.py | 2 +- autointent/modules/embedding/_retrieval.py | 2 +- autointent/modules/regexp/_regexp.py | 2 +- .../nodes/_optimization/_node_optimizer.py | 17 +++++++++++++---- tests/modules/embedding/test_logreg.py | 4 ++-- tests/modules/embedding/test_retrieval.py | 4 ++-- 10 files changed, 38 insertions(+), 31 deletions(-) diff --git a/autointent/context/optimization_info/_optimization_info.py b/autointent/context/optimization_info/_optimization_info.py index 7ec1e5a0..07cfd725 100644 --- a/autointent/context/optimization_info/_optimization_info.py +++ b/autointent/context/optimization_info/_optimization_info.py @@ -74,7 +74,6 @@ def log_module_optimization( module_params: dict[str, Any], metric_value: float, metric_name: str, - artifact: Artifact, module_dump_dir: str | None, module: "Module | None" = None, ) -> None: @@ -103,13 +102,11 @@ def log_module_optimization( if module: self.modules.add_module(node_type, module) - self.artifacts.add_artifact(node_type, artifact) - def _get_metrics_values(self, node_type: str) -> list[float]: """Retrieve all metric values for a specific node type.""" return [trial.metric_value for trial in self.trials.get_trials(node_type)] - def _get_best_trial_idx(self, node_type: str) -> int | None: + def get_best_trial_idx(self, node_type: str) -> int | None: """ Retrieve the index of the best trial for a node type. @@ -133,7 +130,7 @@ def _get_best_artifact(self, node_type: str) -> RetrieverArtifact | ScorerArtifa :return: The best artifact for the node type. :raises ValueError: If no best trial exists for the node type. """ - best_idx = self._get_best_trial_idx(node_type) + best_idx = self.get_best_trial_idx(node_type) if best_idx is None: msg = f"No best trial for {node_type}" raise ValueError(msg) @@ -194,7 +191,7 @@ def get_inference_nodes_config(self, asdict: bool = False) -> list[InferenceNode :return: List of `InferenceNodeConfig` objects for inference nodes. """ - trial_ids = [self._get_best_trial_idx(node_type) for node_type in NodeType] + trial_ids = [self.get_best_trial_idx(node_type) for node_type in NodeType] res = [] for idx, node_type in zip(trial_ids, NodeType, strict=True): if idx is None: @@ -216,7 +213,7 @@ def _get_best_module(self, node_type: str) -> "Module | None": :param node_type: Type of the node. :return: The best module, or None if no best trial exists. """ - idx = self._get_best_trial_idx(node_type) + idx = self.get_best_trial_idx(node_type) if idx is not None: return self.modules.get(node_type)[idx] return None diff --git a/autointent/modules/abc/_base.py b/autointent/modules/abc/_base.py index 287b5126..14f8ac26 100644 --- a/autointent/modules/abc/_base.py +++ b/autointent/modules/abc/_base.py @@ -44,7 +44,7 @@ def score(self, context: Context, split: Literal["validation", "test"], metrics: """ @abstractmethod - def get_assets(self) -> Artifact: + def get_artifact(self, context: Context) -> Artifact: """Return useful assets that represent intermediate data into context.""" @abstractmethod diff --git a/autointent/modules/abc/_decision.py b/autointent/modules/abc/_decision.py index 750ee05f..94d65607 100644 --- a/autointent/modules/abc/_decision.py +++ b/autointent/modules/abc/_decision.py @@ -49,13 +49,14 @@ def score(self, context: Context, split: Literal["validation", "test"], metrics: :return: Computed metrics value for the test set or error code of metrics """ labels, scores = get_decision_evaluation_data(context, split) - self._decisions = self.predict(scores) + decisions = self.predict(scores) chosen_metrics = {name: fn for name, fn in PREDICTION_METRICS_MULTICLASS.items() if name in metrics} - return self.score_metrics((labels, self._decisions), chosen_metrics) + return self.score_metrics((labels, decisions), chosen_metrics) - def get_assets(self) -> DecisionArtifact: + def get_artifact(self, context: Context) -> DecisionArtifact: """Return useful assets that represent intermediate data into context.""" - return DecisionArtifact(labels=self._decisions) + _, scores = get_decision_evaluation_data(context, split="test") + return DecisionArtifact(labels=self.predict(scores)) def clear_cache(self) -> None: """Clear cache.""" diff --git a/autointent/modules/abc/_scoring.py b/autointent/modules/abc/_scoring.py index f275de9d..c19b36a7 100644 --- a/autointent/modules/abc/_scoring.py +++ b/autointent/modules/abc/_scoring.py @@ -41,24 +41,24 @@ def score(self, context: Context, split: Literal["validation", "test"], metrics: scores = self.predict(utterances) - self._train_scores = self.predict(context.data_handler.train_utterances(1)) - self._validation_scores = self.predict(context.data_handler.validation_utterances(1)) - self._test_scores = self.predict(context.data_handler.test_utterances()) - metrics_dict = SCORING_METRICS_MULTILABEL if context.is_multilabel() else SCORING_METRICS_MULTICLASS chosen_metrics = {name: fn for name, fn in metrics_dict.items() if name in metrics} return self.score_metrics((labels, scores), chosen_metrics) - def get_assets(self) -> ScorerArtifact: + def get_artifact(self, context: Context) -> ScorerArtifact: """ Retrieve assets generated during scoring. :return: ScorerArtifact containing test, validation and test scores. """ + train_scores = self.predict(context.data_handler.train_utterances(1)) + validation_scores = self.predict(context.data_handler.validation_utterances(1)) + test_scores = self.predict(context.data_handler.test_utterances()) + return ScorerArtifact( - train_scores=self._train_scores, - validation_scores=self._validation_scores, - test_scores=self._test_scores, + train_scores=train_scores, + validation_scores=validation_scores, + test_scores=test_scores, ) @abstractmethod diff --git a/autointent/modules/embedding/_logreg.py b/autointent/modules/embedding/_logreg.py index d729b874..25f3ead0 100644 --- a/autointent/modules/embedding/_logreg.py +++ b/autointent/modules/embedding/_logreg.py @@ -152,7 +152,7 @@ def score(self, context: Context, split: Literal["validation", "test"], metrics: chosen_metrics = {name: fn for name, fn in metrics_dict.items() if name in metrics} return self.score_metrics((labels, probas), chosen_metrics) - def get_assets(self) -> RetrieverArtifact: + def get_artifact(self, context: Context) -> RetrieverArtifact: """ Get the classifier artifacts for this module. diff --git a/autointent/modules/embedding/_retrieval.py b/autointent/modules/embedding/_retrieval.py index b7d919ad..1d29f037 100644 --- a/autointent/modules/embedding/_retrieval.py +++ b/autointent/modules/embedding/_retrieval.py @@ -132,7 +132,7 @@ def score(self, context: Context, split: Literal["validation", "test"], metrics: chosen_metrics = {name: fn for name, fn in metrics_dict.items() if name in metrics} return self.score_metrics((labels, predictions), chosen_metrics) - def get_assets(self) -> RetrieverArtifact: + def get_artifact(self, context: Context) -> RetrieverArtifact: """ Get the retriever artifacts for this module. diff --git a/autointent/modules/regexp/_regexp.py b/autointent/modules/regexp/_regexp.py index 05d0278e..45b681ef 100644 --- a/autointent/modules/regexp/_regexp.py +++ b/autointent/modules/regexp/_regexp.py @@ -133,7 +133,7 @@ def clear_cache(self) -> None: """Clear cache.""" del self.regexp_patterns - def get_assets(self) -> Artifact: + def get_artifact(self) -> Artifact: """Get assets.""" return Artifact() diff --git a/autointent/nodes/_optimization/_node_optimizer.py b/autointent/nodes/_optimization/_node_optimizer.py index 868403f9..65e264d8 100644 --- a/autointent/nodes/_optimization/_node_optimizer.py +++ b/autointent/nodes/_optimization/_node_optimizer.py @@ -51,7 +51,7 @@ def fit(self, context: Context) -> None: :param context: Context """ self._logger.info("starting %s node optimization...", self.node_info.node_type) - + scored_modules = [] for search_space in deepcopy(self.modules_search_spaces): module_name = search_space.pop("module_name") @@ -62,7 +62,10 @@ def fit(self, context: Context) -> None: context.callback_handler.start_module( module_name=module_name, num=j_combination, module_kwargs=module_kwargs ) - module = self.node_info.modules_available[module_name].from_context(context, **module_kwargs) + module_type = self.node_info.modules_available[module_name] + module = module_type.from_context(context, **module_kwargs) + + scored_modules.append((module_type, module_kwargs)) embedder_name = module.get_embedder_name() if embedder_name is not None: @@ -92,7 +95,6 @@ def fit(self, context: Context) -> None: module_kwargs, metric_value, self.target_metric, - module.get_assets(), # retriever name / scores / predictions module_dump_dir, module=module if not context.is_ram_to_clear() else None, ) @@ -102,7 +104,14 @@ def fit(self, context: Context) -> None: gc.collect() torch.cuda.empty_cache() - self._logger.info("%s node optimization is finished!", self.node_info.node_type) + self._logger.info("%s node optimization is finished! saving best assets", self.node_info.node_type) + # TODO refactor the following code (via implementing `autointent.load_module(path)` utility) + trial_idx = context.optimization_info.get_best_trial_idx(self.node_type) + trial = context.optimization_info.trials.get_trial(self.node_type, trial_idx) + module_type, module_kwargs = scored_modules[trial_idx] + best_module: Module = module_type(**module_kwargs) + best_module.load(trial.module_dump_dir) + context.optimization_info.artifacts.add_artifact(self.node_type, best_module.get_artifact(context)) def get_module_dump_dir(self, dump_dir: Path, module_name: str, j_combination: int) -> str: """ diff --git a/tests/modules/embedding/test_logreg.py b/tests/modules/embedding/test_logreg.py index aba9071d..11550a70 100644 --- a/tests/modules/embedding/test_logreg.py +++ b/tests/modules/embedding/test_logreg.py @@ -1,9 +1,9 @@ from autointent.modules.embedding import LogregAimedEmbedding -def test_get_assets_returns_correct_artifact_for_logreg(): +def test_get_artifact_returns_correct_artifact_for_logreg(): module = LogregAimedEmbedding(embedder_name="sergeyzh/rubert-tiny-turbo") - artifact = module.get_assets() + artifact = module.get_artifact() assert artifact.embedder_name == "sergeyzh/rubert-tiny-turbo" diff --git a/tests/modules/embedding/test_retrieval.py b/tests/modules/embedding/test_retrieval.py index 7b4618a0..bcce00f1 100644 --- a/tests/modules/embedding/test_retrieval.py +++ b/tests/modules/embedding/test_retrieval.py @@ -4,9 +4,9 @@ from tests.conftest import setup_environment -def test_get_assets_returns_correct_artifact(): +def test_get_artifact_returns_correct_artifact(): module = RetrievalAimedEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo") - artifact = module.get_assets() + artifact = module.get_artifact() assert artifact.embedder_name == "sergeyzh/rubert-tiny-turbo" From 2924d092c9bb685bc2e7048d81ffcc25aaa0942e Mon Sep 17 00:00:00 2001 From: voorhs Date: Sun, 2 Feb 2025 09:03:46 +0300 Subject: [PATCH 3/4] fix typing --- autointent/nodes/_optimization/_node_optimizer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/autointent/nodes/_optimization/_node_optimizer.py b/autointent/nodes/_optimization/_node_optimizer.py index 65e264d8..945f1bbd 100644 --- a/autointent/nodes/_optimization/_node_optimizer.py +++ b/autointent/nodes/_optimization/_node_optimizer.py @@ -107,10 +107,10 @@ def fit(self, context: Context) -> None: self._logger.info("%s node optimization is finished! saving best assets", self.node_info.node_type) # TODO refactor the following code (via implementing `autointent.load_module(path)` utility) trial_idx = context.optimization_info.get_best_trial_idx(self.node_type) - trial = context.optimization_info.trials.get_trial(self.node_type, trial_idx) - module_type, module_kwargs = scored_modules[trial_idx] + trial = context.optimization_info.trials.get_trial(self.node_type, trial_idx) # type: ignore[arg-type] + module_type, module_kwargs = scored_modules[trial_idx] # type: ignore[index] best_module: Module = module_type(**module_kwargs) - best_module.load(trial.module_dump_dir) + best_module.load(trial.module_dump_dir) # type: ignore[arg-type] context.optimization_info.artifacts.add_artifact(self.node_type, best_module.get_artifact(context)) def get_module_dump_dir(self, dump_dir: Path, module_name: str, j_combination: int) -> str: From a602cd2d5bd62e309cd22cef7e0cc0270e96d209 Mon Sep 17 00:00:00 2001 From: voorhs Date: Sun, 2 Feb 2025 09:19:20 +0300 Subject: [PATCH 4/4] stage progress --- autointent/nodes/_optimization/_node_optimizer.py | 14 +++++++++----- tests/modules/embedding/test_logreg.py | 6 ------ tests/modules/embedding/test_retrieval.py | 6 ------ tests/nodes/conftest.py | 2 +- 4 files changed, 10 insertions(+), 18 deletions(-) diff --git a/autointent/nodes/_optimization/_node_optimizer.py b/autointent/nodes/_optimization/_node_optimizer.py index 945f1bbd..0d856460 100644 --- a/autointent/nodes/_optimization/_node_optimizer.py +++ b/autointent/nodes/_optimization/_node_optimizer.py @@ -107,11 +107,15 @@ def fit(self, context: Context) -> None: self._logger.info("%s node optimization is finished! saving best assets", self.node_info.node_type) # TODO refactor the following code (via implementing `autointent.load_module(path)` utility) trial_idx = context.optimization_info.get_best_trial_idx(self.node_type) - trial = context.optimization_info.trials.get_trial(self.node_type, trial_idx) # type: ignore[arg-type] - module_type, module_kwargs = scored_modules[trial_idx] # type: ignore[index] - best_module: Module = module_type(**module_kwargs) - best_module.load(trial.module_dump_dir) # type: ignore[arg-type] - context.optimization_info.artifacts.add_artifact(self.node_type, best_module.get_artifact(context)) + if context.is_ram_to_clear(): + trial = context.optimization_info.trials.get_trial(self.node_type, trial_idx) # type: ignore[arg-type] + module_type, module_kwargs = scored_modules[trial_idx] # type: ignore[index] + best_module: Module = module_type(**module_kwargs) + best_module.load(trial.module_dump_dir) # type: ignore[arg-type] + else: + best_module = context.optimization_info.modules.get(self.node_type)[trial_idx] + artifact = best_module.get_artifact(context) + context.optimization_info.artifacts.add_artifact(self.node_type, artifact) def get_module_dump_dir(self, dump_dir: Path, module_name: str, j_combination: int) -> str: """ diff --git a/tests/modules/embedding/test_logreg.py b/tests/modules/embedding/test_logreg.py index 11550a70..952d0044 100644 --- a/tests/modules/embedding/test_logreg.py +++ b/tests/modules/embedding/test_logreg.py @@ -1,12 +1,6 @@ from autointent.modules.embedding import LogregAimedEmbedding -def test_get_artifact_returns_correct_artifact_for_logreg(): - module = LogregAimedEmbedding(embedder_name="sergeyzh/rubert-tiny-turbo") - artifact = module.get_artifact() - assert artifact.embedder_name == "sergeyzh/rubert-tiny-turbo" - - def test_fit_trains_model(): module = LogregAimedEmbedding(embedder_name="sergeyzh/rubert-tiny-turbo") diff --git a/tests/modules/embedding/test_retrieval.py b/tests/modules/embedding/test_retrieval.py index bcce00f1..df4b0ee7 100644 --- a/tests/modules/embedding/test_retrieval.py +++ b/tests/modules/embedding/test_retrieval.py @@ -4,12 +4,6 @@ from tests.conftest import setup_environment -def test_get_artifact_returns_correct_artifact(): - module = RetrievalAimedEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo") - artifact = module.get_artifact() - assert artifact.embedder_name == "sergeyzh/rubert-tiny-turbo" - - def test_dump_and_load_preserves_model_state(): project_dir = setup_environment() module = RetrievalAimedEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo") diff --git a/tests/nodes/conftest.py b/tests/nodes/conftest.py index e5095249..6a09829d 100644 --- a/tests/nodes/conftest.py +++ b/tests/nodes/conftest.py @@ -82,7 +82,7 @@ def get_context(multilabel): if multilabel: dataset = dataset.to_multilabel() res.set_dataset(dataset) - res.configure_logging(LoggingConfig(project_dir=project_dir, keep_in_ram=False)) + res.configure_logging(LoggingConfig(project_dir=project_dir, keep_in_ram=True)) res.configure_vector_index(VectorIndexConfig(), EmbedderConfig(device="cpu")) res.configure_cross_encoder(CrossEncoderConfig()) return res