-
Notifications
You must be signed in to change notification settings - Fork 131
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
用unimol+模型 inference出的pos_pred取同一个id的均值吗? #269
Comments
您好,请问第一个问题的id是什么,我可以先复现一下。然后也确认下这个id对应的构象是8个吗,因为不是所有的id都有8个构象;第二个问题确实是的,是把mol对象的坐标换成预测的坐标来生成的 |
1:不确定这里是怎么输出的每个id对应的构象坐标shape,如果可以的话可以提供一下相应的代码。不过不是所有id对应的构象都是8个,如果按照给出的图里全都是按照8个来切分,可能混淆了不同id的分子,所以原子数不一样。 |
1.发现相同id对应的预测构象坐标shape不同后,我检查了模型预测后得到的直接输出test-dev_0.pkl,以防是我输出shape的代码有误,check预测的坐标值(以id3379091为例子)结果如下:
我想代码中是以 df_grouped = df.groupby("id")按id分组的,并不是以8个为一组,如果是切分有误,那么问题可能发生在我的模型inference过程中,但是inference.py文件我并无改动,想知道您的test-dev测试集中该id的pos_pred结果原子数量有异吗? |
我重新检查了unimol+模型的inference过程。在unimol_plus文件夹下的pcq.py中的load_dataset函数部分,涉及到pcq_dataset.py文件中的PCQDataset函数,其中以下代码:
max_node_num 是每个batch里最大分子原子数,然后将其调整到最接近的 4 的倍数减去 1。atom_mask对分子的真实原子位赋1,新添加的虚拟原子位赋0,包括后面涉及到的attn_mask对新添的虚拟原子位赋-inf。经过这种处理,新添的虚拟原子坐标值一开始都是0,但是在经过inference后,虚拟原子(mask标记为0)的坐标也有了预测值,导致我上面出现的同一个id(相同分子)对应的原子数不相同的情况(原子个数为每个batch里的max_node_num)。 |
您好,请问复现有结果了吗?不知道我上述猜想是否正确? @ShuqiLu |
不好意思没有看到您的回复,你理解的其实也差不多,这里把所有分子的原子数都置为max_node_num是为了能用pytorch并行处理一个batch内的所有分子,需要所有tensor的shape一致,所以这里用padding操作,把batch的的分子的原子数补充成相同的数目;为了使得padding的内容不实际影响模型的运算结果,所以使用atom_mask和attn_mask让padding的内容不参与实际运算;因为这篇工作我们只预测分子的能量并不取出分子坐标独立研究,所以没有对返回的分子坐标处理padding的部分,所以看起来同一个分子生成的坐标在不同batch内shape不一样。实际上如果需要取出真实原子的坐标,去掉padding的部分,可以利用atom_mask,把每个分子的atom_mask=1的位置对应的坐标取出,就是所有真实原子的坐标; 或者假设真实原子数为k,可以取出前k个坐标,利用数据中的smiles生成rdkit初始构象,再将前k个坐标填入即得到预测的3d构象(需要原始数据中的smiles不然可能没法对应)。 至于show case中选择的标准是什么,这里其实我们就选择了几个能量预测误差相对小的case展示了一下,没有过多的特殊筛选。 |
在make_pcq_test_dev_submission.py文件中gap_pred是取同一个id对应值的平均,大多数id是8个值的平均。然而在预测出的pos_pred(原子坐标信息中),为什么相同id得到的pos_pred形状会不一样?也就是说同一个分子的原子数量对不上?如下是某个id预测出的pos_pred
形状为Shape mismatch for id group: [(19, 3), (19, 3), (19, 3), (19, 3), (19, 3), (19, 3), (19, 3), (15, 3)],第8个预测出来的pos_pred只有15个原子对应的坐标信息,而其它为19个原子。
想请问这是什么情况?以及想得到某个id的最终pos_pred值,是采用什么方法?除去有异常原子数量的预测值,然后其余取均值吗?
The text was updated successfully, but these errors were encountered: