diff --git a/docs/source/conf.py b/docs/source/conf.py index fe45af9..4d6618e 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -6,10 +6,10 @@ # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information -project = 'Uni-Mol tools' +project = 'Uni-Mol' copyright = '2023, cuiyaning' author = 'cuiyaning' -release = '0.1.0' +release = '0.1.1' # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/docs/source/data.rst b/docs/source/data.rst index 72aa58a..9754f62 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -1,12 +1,12 @@ Data ==== -`unimol_tools.data `_ contains functions and classes for loading, containing, and scaler data, feature. +`unimol_tools.data `_ contains functions and classes for loading, containing, and scaler data, feature. DataHub ------- -Classes and functions from `unimol_tools.data.datahub.py `_. +Classes and functions from `unimol_tools.data.datahub.py `_. .. automodule:: unimol_tools.data.datahub :members: @@ -14,7 +14,7 @@ Classes and functions from `unimol_tools.data.datahub.py `_. +Classes and functions from `unimol_tools.data.datareader.py `_. .. automodule:: unimol_tools.data.datareader :members: @@ -22,7 +22,7 @@ Classes and functions from `unimol_tools.data.datareader.py `_. +Classes and functions from `unimol_tools.data.datascaler.py `_. .. automodule:: unimol_tools.data.datascaler :members: @@ -30,7 +30,7 @@ Classes and functions from `unimol_tools.data.datascaler.py `_. +Classes and functions from `unimol_tools.data.conformer.py `_. .. 
automodule:: unimol_tools.data.conformer :members: \ No newline at end of file diff --git a/docs/source/examples.md b/docs/source/examples.md new file mode 100644 index 0000000..69aba4f --- /dev/null +++ b/docs/source/examples.md @@ -0,0 +1,22 @@ +# Examples + +Welcome to the examples section! On our platform Bohrium, we offer a variety of notebook cases for studying Uni-Mol. These notebooks provide practical examples and applications of Uni-Mol in different scientific fields. You can explore these notebooks to gain hands-on experience and deepen your understanding of Uni-Mol. + +## Uni-Mol Notebooks on Bohrium +Explore our collection of Uni-Mol notebooks on Bohrium: [Uni-Mol Notebooks](https://bohrium.dp.tech/search?searchKey=UniMol&%3BactiveTab=notebook&activeTab=notebook) + +### Uni-Mol for QSAR (Quantitative Structure-Activity Relationship) +Uni-Mol can be used to predict the biological activity of compounds based on their chemical structure. These notebooks demonstrate how to apply Uni-Mol for QSAR tasks: +- [QSAR Example 1](https://bohrium.dp.tech/notebooks/7141701322) +- [QSAR Example 2](https://bohrium.dp.tech/notebooks/9919429887) + +### Uni-Mol for OLED Properties Predictions +Organic Light Emitting Diodes (OLEDs) are used in various display technologies. Uni-Mol can predict the properties of OLED molecules, aiding in the design of more efficient materials. Check out these notebooks for detailed examples: +- [OLED Properties Prediction Example 1](https://bohrium.dp.tech/notebooks/2412844127) +- [OLED Properties Prediction Example 2](https://bohrium.dp.tech/notebooks/7637046852) + +### Uni-Mol Predicts Liquid Flow Battery Solubility +Liquid flow batteries are a promising technology for energy storage. Uni-Mol can predict the solubility of compounds used in these batteries, helping to optimize their performance. 
Explore this notebook to see how Uni-Mol is applied in this context: +- [Liquid Flow Battery Solubility Prediction](https://bohrium.dp.tech/notebooks/7941779831) + +These examples provide a glimpse into the powerful capabilities of Uni-Mol in various scientific applications. We encourage you to explore these notebooks and experiment with Uni-Mol to discover its full potential. \ No newline at end of file diff --git a/docs/source/features.md b/docs/source/features.md index 6a25603..fc3ec56 100644 --- a/docs/source/features.md +++ b/docs/source/features.md @@ -1,5 +1,8 @@ # New Features +## 2024-11-22 +Unimol V2 has been added to Unimol_tools! + ## 2024-06-25 Unimol_tools has been publish to pypi! Huggingface has been used to manage the pretrain models. diff --git a/docs/source/index.rst b/docs/source/index.rst index 9d0818b..2580812 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -3,9 +3,24 @@ You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. -Welcome to Uni-Mol tools' documentation! +Welcome to Uni-Mol's documentation! ========================================== +Uni-Mol is the first universal large-scale three-dimensional Molecular Representation Learning (MRL) framework developed by DP Technology. It expands the application scope and representation capabilities of MRL. + +This framework consists of two models, one trained on billions of molecular three-dimensional conformations and the other on millions of protein pocket data. + +It has shown excellent performance in various molecular property prediction tasks, especially in 3D-related tasks, where it demonstrates significant performance. In addition to drug design, Uni-Mol can also predict the properties of materials, such as the gas adsorption performance of MOF materials and the optical properties of OLED molecules. + +.. Important:: + + The project Uni-Mol is licensed under `MIT LICENSE `_. 
+ If you use Uni-Mol in your research, please kindly cite the following works: + + - Gengmo Zhou, Zhifeng Gao, Qiankun Ding, Hang Zheng, Hongteng Xu, Zhewei Wei, Linfeng Zhang, Guolin Ke. "Uni-Mol: A Universal 3D Molecular Representation Learning Framework." The Eleventh International Conference on Learning Representations, 2023. `https://openreview.net/forum?id=6K2RM6wVqKu `_. + - Shuqi Lu, Zhifeng Gao, Di He, Linfeng Zhang, Guolin Ke. "Data-driven quantum chemical property prediction leveraging 3D conformations with Uni-Mol+." Nature Communications, 2024. `https://www.nature.com/articles/s41467-024-51321-w `_. + + Uni-Mol tools is a easy-use wrappers for property prediction,representation and downstreams with Uni-Mol. It includes the following tools: * molecular property prediction with Uni-Mol. @@ -14,11 +29,23 @@ Uni-Mol tools is a easy-use wrappers for property prediction,representation and .. toctree:: :maxdepth: 2 - :caption: Contents: + :caption: Getting Started: requirements installation - tutorial + +.. toctree:: + :maxdepth: 2 + :caption: Tutorials: + + quickstart + school + examples + +.. toctree:: + :maxdepth: 2 + :caption: Uni-Mol tools: + train data models @@ -27,6 +54,7 @@ Uni-Mol tools is a easy-use wrappers for property prediction,representation and weight features + Indices and tables ================== diff --git a/docs/source/installation.md b/docs/source/installation.md index 862d5d7..fbeba04 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -16,7 +16,7 @@ We recommend installing ```huggingface_hub``` so that the required unimol models pip install huggingface_hub ``` -`huggingface_hub` allows you to easily download and manage models from the Hugging Face Hub, which is key for using UniMol models. +`huggingface_hub` allows you to easily download and manage models from the Hugging Face Hub, which is key for using Uni-Mol models. 
### Option 2: Installing from source @@ -25,7 +25,7 @@ pip install huggingface_hub pip install -r requirements.txt ## Clone repository -git clone https://github.com/dptech-corp/Uni-Mol.git +git clone https://github.com/deepmodeling/Uni-Mol.git cd Uni-Mol/unimol_tools ## Install @@ -34,7 +34,7 @@ python setup.py install ### Models in Huggingface -The UniMol pretrained models can be found at [dptech/Uni-Mol-Models](https://huggingface.co/dptech/Uni-Mol-Models/tree/main). +The Uni-Mol pretrained models can be found at [dptech/Uni-Mol-Models](https://huggingface.co/dptech/Uni-Mol-Models/tree/main). If the download is slow, you can use other mirrors, such as: diff --git a/docs/source/models.rst b/docs/source/models.rst index 63dcba9..ab85cd6 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -3,13 +3,13 @@ Models ====== -`unimol_tools.models `_ contains the models of Uni-Mol. +`unimol_tools.models `_ contains the models of Uni-Mol. Uni-Mol ------- -`unimol_tools.models.unimol.py `_ contains the :class:`~unimol_tools.models.UniMolModel`, which is the backbone of Uni-Mol model. +`unimol_tools.models.unimol.py `_ contains the :class:`~unimol_tools.models.UniMolModel`, which is the backbone of Uni-Mol model. .. automodule:: unimol_tools.models.unimol :members: @@ -17,7 +17,7 @@ Uni-Mol Model ----- -`unimol_tools.models.nnmodel.py `_ contains the :class:`~unimol_tools.models.NNModel`, which is responsible for initializing the model. +`unimol_tools.models.nnmodel.py `_ contains the :class:`~unimol_tools.models.NNModel`, which is responsible for initializing the model. .. automodule:: unimol_tools.models.nnmodel :members: @@ -25,7 +25,7 @@ Model Loss ----- -`unimol_tools.models.loss.py `_ contains different loss functions. +`unimol_tools.models.loss.py `_ contains different loss functions. .. 
automodule:: unimol_tools.models.loss :members: @@ -33,7 +33,7 @@ Loss Transformers ------------ -`unimol_tools.models.transformers.py `_ contains a custom Transformer Encoder module that extends `PyTorch's nn.Module `_. +`unimol_tools.models.transformers.py `_ contains a custom Transformer Encoder module that extends `PyTorch's nn.Module `_. .. automodule:: unimol_tools.models.transformers :members: \ No newline at end of file diff --git a/docs/source/tutorial.md b/docs/source/quickstart.md similarity index 59% rename from docs/source/tutorial.md rename to docs/source/quickstart.md index ac0192f..a7544a2 100644 --- a/docs/source/tutorial.md +++ b/docs/source/quickstart.md @@ -1,4 +1,6 @@ -# Tutorial +# Quick start + +Quick start for Uni-Mol tools. ## Molecule property prediction @@ -20,8 +22,7 @@ custom dict can also as the input. The dict format should be like ```python {'atoms':[['C','C'],['C','H','O']], 'coordinates':[coordinates_1,coordinates_2]} ``` -Here is an example to train a model and make a prediction. - +Here is an example to train a model and make a prediction. When using Unimol V2, set `model_name='unimolv2'`. ```python from unimol_tools import MolTrain, MolPredict clf = MolTrain(task='classification', @@ -29,6 +30,8 @@ clf = MolTrain(task='classification', epochs=10, batch_size=16, metrics='auc', + model_name='unimolv1', # available: unimolv1, unimolv2 + model_size='84m', # only used when model_name is unimolv2. available: 84m, 164m, 310m, 570m, 1.1B. ) pred = clf.fit(data = train_data) # currently support data with smiles based csv/txt file @@ -41,13 +44,45 @@ res = clf.predict(data = test_data) Uni-Mol representation can easily be achieved as follow. 
```python +import numpy as np from unimol_tools import UniMolRepr # single smiles unimol representation -clf = UniMolRepr(data_type='molecule', remove_hs=False) +clf = UniMolRepr(data_type='molecule', + remove_hs=False, + model_name='unimolv1', # available: unimolv1, unimolv2 + model_size='84m', # only used when model_name is unimolv2. available: 84m, 164m, 310m, 570m, 1.1B. + ) smiles = 'c1ccc(cc1)C2=NCC(=O)Nc3c2cc(cc3)[N+](=O)[O]' smiles_list = [smiles] unimol_repr = clf.get_repr(smiles_list, return_atomic_reprs=True) # CLS token repr print(np.array(unimol_repr['cls_repr']).shape) # atomic level repr, align with rdkit mol.GetAtoms() -print(np.array(unimol_repr['atomic_reprs']).shape) \ No newline at end of file +print(np.array(unimol_repr['atomic_reprs']).shape) +``` +## Continue training (Re-train) + +```python +clf = MolTrain(task='regression', + data_type='molecule', + epochs=10, + batch_size=16, + save_path='./model_dir', + remove_hs=False, + target_cols='TARGET', + ) +pred = clf.fit(data = train_data) +# After training a model, set load_model_dir='./model_dir' to continue training + +clf2 = MolTrain(task='regression', + data_type='molecule', + epochs=10, + batch_size=16, + save_path='./retrain_model_dir', + remove_hs=False, + target_cols='TARGET', + load_model_dir='./model_dir', + ) + +pred2 = clf2.fit(data = retrain_data) +``` \ No newline at end of file diff --git a/docs/source/school.md b/docs/source/school.md new file mode 100644 index 0000000..7119577 --- /dev/null +++ b/docs/source/school.md @@ -0,0 +1,26 @@ +# Uni-Mol School + +Welcome to Uni-Mol School! This course is designed to provide comprehensive training on Uni-Mol, a powerful tool for molecular modeling and simulations. + +## Course Introduction +The properties of drugs are determined by their three-dimensional structures, which are crucial for their efficacy and absorption. Drug design requires consideration of molecular diversity. 
Current Molecular Representation Learning (MRL) models mainly utilize one-dimensional or two-dimensional data, with limited capability to integrate 3D information. + +Uni-Mol, developed by the DP Technology team, is the first general large-scale 3D MRL framework in the field of drug design, expanding the application scope and representation capabilities of MRL. This framework consists of two models trained on billions of molecular 3D conformations and millions of protein pocket data, respectively. It has shown excellent performance in various molecular property prediction tasks, especially in 3D-related tasks. Besides drug design, Uni-Mol can also predict the properties of materials, such as gas adsorption performance of MOF materials and optical properties of OLED molecules. + +## Course Content +| Topic | Course Content | Instructor | +|-------|----------------|------------| +| Introduction to Uni-Mol | Uni-Mol molecular 3D representation learning framework and pre-trained models | Chen Letian | +| Uni-Mol for Materials Science | Case study of Uni-Mol in predicting the properties of battery materials | Chen Letian | +| | 3D Representation Learning Framework and Pre-trained Models for Nanoporous Materials | Chen Letian | +| | Efficient Screening of Ir(III) Complex Emitters: A Study Combining Machine Learning and Computational Analysis | Chen Letian | +| | Application of 3D Molecular Pre-trained Model Uni-Mol in Flow Batteries | Xie Qiming | +| | Materials Science Uni-Mol Notebook Case Study | | +| Uni-Mol for Biomedical Science | Application of Uni-Mol in Molecular Docking | Zhou Gengmo | +| | Application of Uni-Mol in Molecular Generation | Song Ke | +| | Biomedical Science Uni-Mol Notebook Case Study | | + +## How to Enroll +Enroll now and start your journey with Uni-Mol! [Click here to enroll](https://bohrium.dp.tech/courses/6134196349?tab=courses) + +Don't miss this opportunity to advance your knowledge and skills in molecular modeling with Uni-Mol! 
\ No newline at end of file diff --git a/docs/source/task.rst b/docs/source/task.rst index 2274a08..472c6ea 100644 --- a/docs/source/task.rst +++ b/docs/source/task.rst @@ -3,13 +3,13 @@ Task ====== -`unimol_tools.tasks `_ oversees the tasks related to the model, such as training and prediction. +`unimol_tools.tasks `_ oversees the tasks related to the model, such as training and prediction. Trainer ------- -`unimol_tools.tasks.trainer.py `_ contains the :class:`~unimol_tools.unimol_tools.models.tasks.Trainer`, managing the training, validation, and testing phases. +`unimol_tools.tasks.trainer.py `_ contains the :class:`~unimol_tools.unimol_tools.models.tasks.Trainer`, managing the training, validation, and testing phases. .. automodule:: unimol_tools.tasks.trainer :members: @@ -17,7 +17,7 @@ Trainer Split ------- -`unimol_tools.tasks.split.py `_ manages the split methods in the dataset. +`unimol_tools.tasks.split.py `_ manages the split methods in the dataset. .. automodule:: unimol_tools.tasks.split :members: \ No newline at end of file diff --git a/docs/source/train.rst b/docs/source/train.rst index da8ba7c..f01d7d8 100644 --- a/docs/source/train.rst +++ b/docs/source/train.rst @@ -7,7 +7,7 @@ Interface Train ----- -`unimol_tools.train.py `_ trains a Uni-Mol model. +`unimol_tools.train.py `_ trains a Uni-Mol model. .. automodule:: unimol_tools.train :members: @@ -16,7 +16,7 @@ Train Predict ------------ -`unimol_tools.predictor.py `_ predict through a Uni-Mol model. +`unimol_tools.predictor.py `_ predict through a Uni-Mol model. .. automodule:: unimol_tools.predict :members: @@ -25,7 +25,7 @@ Predict Uni-Mol representation ------------------------ -`unimol_tools.predictor.py `_ get the Uni-Mol representation. +`unimol_tools.predictor.py `_ get the Uni-Mol representation. .. 
automodule:: unimol_tools.predictor :members: diff --git a/docs/source/utils.rst b/docs/source/utils.rst index cc4cc4f..ed5f9ca 100644 --- a/docs/source/utils.rst +++ b/docs/source/utils.rst @@ -3,13 +3,13 @@ Utils ======= -`unimol_tools.utils `_ contains the utils related to the model, such as metrics and logger. +`unimol_tools.utils `_ contains the utils related to the model, such as metrics and logger. Metrics ------- -`unimol_tools.utils.metrics `_ contains the metrics included in the model. +`unimol_tools.utils.metrics `_ contains the metrics included in the model. .. automodule:: unimol_tools.utils.metrics :members: @@ -17,7 +17,7 @@ Metrics Logger ------- -`unimol_tools.utils.base_logger.py `_ control the logger. +`unimol_tools.utils.base_logger.py `_ control the logger. .. automodule:: unimol_tools.utils.base_logger :members: @@ -25,7 +25,7 @@ Logger Config ------- -`unimol_tools.utils.config_handler.py `_ manages the config input file. +`unimol_tools.utils.config_handler.py `_ manages the config input file. .. automodule:: unimol_tools.utils.config_handler :members: @@ -33,7 +33,7 @@ Config Padding ------- -`unimol_tools.utils.util.py `_ contain some padding methods. +`unimol_tools.utils.util.py `_ contain some padding methods. .. automodule:: unimol_tools.utils.util :members: \ No newline at end of file diff --git a/docs/source/weight.rst b/docs/source/weight.rst index a5e891e..218dd98 100644 --- a/docs/source/weight.rst +++ b/docs/source/weight.rst @@ -3,18 +3,18 @@ Weights ======= -We recommend installing ``huggingface_hub`` so that the required UniMol models can be automatically downloaded at runtime! It can be installed by: +We recommend installing ``huggingface_hub`` so that the required Uni-Mol models can be automatically downloaded at runtime! It can be installed by: .. code-block:: bash pip install huggingface_hub -``huggingface_hub`` allows you to easily download and manage models from the Hugging Face Hub, which is key for using UniMol models. 
+``huggingface_hub`` allows you to easily download and manage models from the Hugging Face Hub, which is key for using Uni-Mol models. Models in Huggingface --------------------- -The UniMol pretrained models can be found at `dptech/Uni-Mol-Models `_. +The Uni-Mol pretrained models can be found at `dptech/Uni-Mol-Models `_. If the download is slow, you can use other mirrors, such as: @@ -24,7 +24,7 @@ If the download is slow, you can use other mirrors, such as: Setting the ``HF_ENDPOINT`` environment variable specifies the mirror address for the Hugging Face Hub to use when downloading models. -`unimol_tools.weights.weight_hub.py `_ control the logger. +`unimol_tools.weights.weight_hub.py `_ control the logger. .. automodule:: unimol_tools.weights.weighthub :members: \ No newline at end of file diff --git a/unimol_tools/setup.py b/unimol_tools/setup.py index 29e316a..1d1ddfa 100644 --- a/unimol_tools/setup.py +++ b/unimol_tools/setup.py @@ -5,7 +5,7 @@ setup( name="unimol_tools", - version="0.1.1", + version="0.1.1.post1", description=("unimol_tools is a Python package for property prediciton with Uni-Mol in molecule, materials and protein."), long_description=open('README.md').read(), long_description_content_type='text/markdown', diff --git a/unimol_tools/unimol_tools/data/datareader.py b/unimol_tools/unimol_tools/data/datareader.py index 0745b2c..81e7fe2 100644 --- a/unimol_tools/unimol_tools/data/datareader.py +++ b/unimol_tools/unimol_tools/data/datareader.py @@ -78,6 +78,10 @@ def read_data(self, data=None, is_train=True, **params): else: if target_cols is None: target_cols = [item for item in data.columns if item.startswith(target_col_prefix)] + elif isinstance(target_cols, str): + target_cols = target_cols.split(',') + elif isinstance(target_cols, list): + pass else: for col in target_cols: if col not in data.columns: diff --git a/unimol_tools/unimol_tools/models/nnmodel.py b/unimol_tools/unimol_tools/models/nnmodel.py index d52b959..613bdaf 100644 --- 
a/unimol_tools/unimol_tools/models/nnmodel.py +++ b/unimol_tools/unimol_tools/models/nnmodel.py @@ -158,6 +158,19 @@ def run(self): if fold > 0: # need to initalize model for next fold training self.model = self._init_model(**self.model_params) + if self.model_params.get('load_model_dir', None) is not None: + load_model_path = os.path.join(self.model_params['load_model_dir'], f'model_{fold}.pth') + model_dict = torch.load(load_model_path, map_location=self.model_params['device'])["model_state_dict"] + if model_dict['classification_head.out_proj.weight'].shape[0] != self.model.output_dim: + current_model_dict = self.model.state_dict() + model_dict = {k: v for k, v in model_dict.items() if k in current_model_dict and 'classification_head.out_proj' not in k} + current_model_dict.update(model_dict) + logger.info("The output_dim of the model is different from the loaded model, only load the common part of the model") + self.model.load_state_dict(model_dict, strict=False) + else: + self.model.load_state_dict(model_dict) + + logger.info("load model success from {}".format(load_model_path)) _y_pred = self.trainer.fit_predict( self.model, traindataset, validdataset, self.loss_func, self.activation_fn, self.save_path, fold, self.target_scaler) y_pred[te_idx] = _y_pred diff --git a/unimol_tools/unimol_tools/models/transformers.py b/unimol_tools/unimol_tools/models/transformers.py index 51975d2..19c9309 100644 --- a/unimol_tools/unimol_tools/models/transformers.py +++ b/unimol_tools/unimol_tools/models/transformers.py @@ -10,6 +10,7 @@ def softmax_dropout(input, dropout_prob, is_training=True, mask=None, bias=None, inplace=True): """softmax dropout, and mask, bias are optional. 
+ Args: input (torch.Tensor): input tensor dropout_prob (float): dropout probability diff --git a/unimol_tools/unimol_tools/train.py b/unimol_tools/unimol_tools/train.py index 2bf1a4c..f1a3abc 100644 --- a/unimol_tools/unimol_tools/train.py +++ b/unimol_tools/unimol_tools/train.py @@ -34,6 +34,7 @@ def __init__(self, save_path='./exp', remove_hs=False, smiles_col='SMILES', + target_cols=None, target_col_prefix='TARGET', target_anomaly_check="filter", smiles_check="filter", @@ -43,6 +44,7 @@ def __init__(self, use_amp=True, freeze_layers=None, freeze_layers_reversed=False, + load_model_dir=None, # load model for transfer learning model_name='unimolv1', model_size='84m', **params, @@ -76,6 +78,7 @@ def __init__(self, :param save_path: str, default='./exp', path to save training results. :param remove_hs: bool, default=False, whether to remove hydrogens from molecules. :param smiles_col: str, default='SMILES', column name of SMILES. + :param target_cols: list or str, default=None, column names of target values. :param target_col_prefix: str, default='TARGET', prefix of target column name. :param target_anomaly_check: str, default='filter', how to deal with anomaly target values. currently support: filter, none. :param smiles_check: str, default='filter', how to deal with invalid SMILES. currently support: filter, none. @@ -87,11 +90,16 @@ def __init__(self, :param freeze_layers: str or list, frozen layers by startwith name list. ['encoder', 'gbf'] will freeze all the layers whose name start with 'encoder' or 'gbf'. :param freeze_layers_reversed: bool, default=False, inverse selection of frozen layers :param params: dict, default=None, other parameters. + :param load_model_dir: str, default=None, path to load model for transfer learning. :param model_name: str, default='unimolv1', currently support unimolv1, unimolv2. :param model_size: str, default='84m', model size. work when model_name is unimolv2. avaliable: 84m, 164m, 310m, 570m, 1.1B. 
""" - config_path = os.path.join(os.path.dirname(__file__), 'config/default.yaml') + if load_model_dir is not None: + config_path = os.path.join(load_model_dir, 'config.yaml') + logger.info('Load config file from {}'.format(config_path)) + else: + config_path = os.path.join(os.path.dirname(__file__), 'config/default.yaml') self.yamlhandler = YamlHandler(config_path) config = self.yamlhandler.read_yaml() config.task = task @@ -106,6 +114,7 @@ def __init__(self, config.kfold = kfold config.remove_hs = remove_hs config.smiles_col = smiles_col + config.target_cols = target_cols config.target_col_prefix = target_col_prefix config.anomaly_clean = target_anomaly_check in ['filter'] config.smi_strict = smiles_check in ['filter'] @@ -115,6 +124,7 @@ def __init__(self, config.use_amp = use_amp config.freeze_layers = freeze_layers config.freeze_layers_reversed = freeze_layers_reversed + config.load_model_dir = load_model_dir config.model_name = model_name config.model_size = model_size self.save_path = save_path