Skip to content

Commit

Permalink
fix multilabel datareader bug (#205)
Browse files Browse the repository at this point in the history
  • Loading branch information
liuzan-info authored Jan 10, 2024
1 parent 6066381 commit 0575fb0
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions unimol_tools/unimol_tools/data/datareader.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,10 @@ def read_data(self, data=None, is_train=True, **params):
else:
for i in range(label.shape[1]):
data[target_col_prefix + str(i)] = label[:,i]

_ = data.pop('target')
data = pd.DataFrame(data).rename(columns={smiles_col: 'SMILES'})
if 'target' in data:
data = data.drop(columns=['target'])

elif isinstance(data, list):
# load from smiles list
data = pd.DataFrame(data, columns=['SMILES'])
Expand Down

0 comments on commit 0575fb0

Please sign in to comment.