From 0575fb0c61ca86fd61a7e7e8dbbd0a3b8b0939ae Mon Sep 17 00:00:00 2001 From: liuzan <62529552+liuzan-info@users.noreply.github.com> Date: Wed, 10 Jan 2024 11:21:03 +0800 Subject: [PATCH] fix multilabel datareader bug (#205) --- unimol_tools/unimol_tools/data/datareader.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/unimol_tools/unimol_tools/data/datareader.py b/unimol_tools/unimol_tools/data/datareader.py index 93a7c29..b1ded29 100644 --- a/unimol_tools/unimol_tools/data/datareader.py +++ b/unimol_tools/unimol_tools/data/datareader.py @@ -65,9 +65,10 @@ def read_data(self, data=None, is_train=True, **params): else: for i in range(label.shape[1]): data[target_col_prefix + str(i)] = label[:,i] + + _ = data.pop('target') data = pd.DataFrame(data).rename(columns={smiles_col: 'SMILES'}) - if 'target' in data: - data = data.drop(columns=['target']) + elif isinstance(data, list): # load from smiles list data = pd.DataFrame(data, columns=['SMILES'])