Description
Describe the bug
When using Uni-Mol2, I observed that the molecular representations differ when molecules are processed in a batch versus individually. For the same molecule the results are inconsistent, except for the molecule with the most atoms in the batch.
Since batch inference requires padding all inputs to the maximum length, I suspect that the longer molecules force the shorter ones to be extended with padding tokens, and that these padding tokens are not being correctly masked out in the self-attention mechanism.
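To make the suspicion concrete, here is a small toy example (plain scaled dot-product attention, not Uni-Mol2 code) showing that an all-zero additive mask lets padded key positions receive softmax weight, while an additive -inf mask recovers the unpadded result:
# Toy illustration only: with an all-zero additive mask, padded key positions
# still receive softmax weight, so a short molecule's representation changes
# once it is batched with a longer one; -inf at the padded positions restores
# the unpadded output.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(1, 3, 8)                             # 3 real query tokens
k, v = torch.randn(1, 5, 8), torch.randn(1, 5, 8)    # 3 real + 2 padded key tokens

scores = q @ k.transpose(-1, -2) / 8 ** 0.5

zero_mask = torch.zeros(1, 3, 5)      # what an all-zero attn_bias amounts to
pad_mask = zero_mask.clone()
pad_mask[..., 3:] = float("-inf")     # what the padded positions should contribute

out_zero = F.softmax(scores + zero_mask, dim=-1) @ v
out_pad = F.softmax(scores + pad_mask, dim=-1) @ v
out_ref = F.softmax(q @ k[:, :3].transpose(-1, -2) / 8 ** 0.5, dim=-1) @ v[:, :3]

print(torch.allclose(out_zero, out_ref))  # False: padding leaked into the output
print(torch.allclose(out_pad, out_ref))   # True: -inf mask recovers the unpadded result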
Debugging & Analysis:
I debugged the code and traced the issue to unimol_tools.models.transformersv2.py. Specifically, inside the forward method of class TransformerEncoderLayerV2, during the first self-attention operation:
self.self_attn(
    x,
    x,
    x,
    pair=pair,
    mask=self_attn_mask,
)
I found that the representation x of the same molecule diverges after this step when comparing batch mode with single-sample mode. Furthermore, during batch input, even when padding tokens are present, self_attn_mask remains an all-zero matrix, which indicates that the padding tokens are not being masked out at all.
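One way to reproduce this observation is a forward pre-hook on the attention module. The snippet below is only a sketch of that technique: StandInAttention is a dummy module, not the Uni-Mol2 attention class, and to check the real model the same hook would have to be attached to the actual self_attn module of TransformerEncoderLayerV2 (attribute path not shown here).
# Sketch of the hook technique used to observe the mask that reaches self_attn.
# StandInAttention is a placeholder module, not part of unimol_tools.
import torch
import torch.nn as nn

class StandInAttention(nn.Module):
    def forward(self, q, k, v, pair=None, mask=None):
        return q  # placeholder; only the incoming `mask` kwarg matters here

def inspect_mask(module, args, kwargs):
    mask = kwargs.get("mask")
    if mask is not None:
        # An all-zero additive mask shows min == max == 0.0 even for a padded batch.
        print("mask min:", mask.min().item(), "max:", mask.max().item())

attn = StandInAttention()
handle = attn.register_forward_pre_hook(inspect_mask, with_kwargs=True)
attn(torch.randn(2, 5, 8), torch.randn(2, 5, 8), torch.randn(2, 5, 8),
     mask=torch.zeros(2, 5, 5))   # prints: mask min: 0.0 max: 0.0
handle.remove()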
The self_attn_mask is defined in unimol_tools.models.unimolv2.UniMolV2Model.forward as:
attn_mask = attn_bias.clone()
However, attn_bias originates from get_graph_features in unimol_tools.data.conformer.py, where it is defined as:
feat["attn_bias"] = np.zeros((mask.sum() + 1, mask.sum() + 1), dtype=np.float32)
This initializes it as all zeros by default, ignoring the padding requirements for batch processing.
Proposed Solution:
I believe the issue lies in attn_mask = attn_bias.clone(). I modified the code to explicitly handle masking for padding tokens:
attn_mask = torch.cat([torch.ones(atom_mask.size(0), 1).to(atom_mask.device), atom_mask.clone()], dim=1)
attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))
attn_mask = attn_mask.unsqueeze(1).repeat(1, attn_mask.size(1), 1)
After applying this fix, the inconsistency was resolved; the minor numerical differences that remain may be related to the op_mask and op_norm operations. A quick shape check of this mask construction is sketched below.
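The snippet runs the same three lines on a dummy atom_mask (two molecules, a maximum of 4 atoms, the second molecule with only 2 real atoms); the values are invented for illustration. The extra leading column of ones matches the `mask.sum() + 1` sizing of attn_bias above, and, assuming the mask is applied additively like attn_bias, the constant 1.0 on valid positions cancels in the softmax while the -inf entries remove the padded keys.
# Dummy shape check of the proposed mask (values invented for illustration).
import torch

atom_mask = torch.tensor([[1., 1., 1., 1.],
                          [1., 1., 0., 0.]])   # molecule 2 has two padded atoms

attn_mask = torch.cat(
    [torch.ones(atom_mask.size(0), 1).to(atom_mask.device), atom_mask.clone()], dim=1
)
attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))
attn_mask = attn_mask.unsqueeze(1).repeat(1, attn_mask.size(1), 1)

print(attn_mask.shape)  # torch.Size([2, 5, 5])
print(attn_mask[1, 0])  # tensor([1., 1., 1., -inf, -inf]): padded keys masked for every query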
Thank you for looking into this.
Uni-Mol Version
Uni-Mol2
Expected behavior
The molecular representations obtained during batch inference should be consistent with those obtained during single-sample inference. The presence of padding tokens in a batch should not affect the embeddings of other molecules.
To Reproduce
from unimol_tools import UniMolRepr

clf = UniMolRepr(
    data_type='molecule',
    remove_hs=False,
    model_name='unimolv2',  # available: unimolv1, unimolv2
    model_size='164m',      # used when model_name is unimolv2; available: 84m, 164m, 310m, 570m, 1.1B
)
# batch SMILES Uni-Mol representation
l = ['C(CC(O)(P(=O)(O)O)P(=O)(O)[O-])CN.O.O.O.[Na+]', 'c1ccc(cc1)C2=NCC(=O)Nc3c2cc(cc3)[N+](=O)[O]']
r = clf.get_repr(l)
print(r[0][:4])

# single SMILES Uni-Mol representation
l = ['C(CC(O)(P(=O)(O)O)P(=O)(O)[O-])CN.O.O.O.[Na+]']
r = clf.get_repr(l)
print(r[0][:4])
The output before fixing the mask problem:
[ -9.299404 15.081673 -25.31543 -11.924848]
[ -7.4061403 17.232798 -15.8836975 -4.884043 ]
The output after fixing the mask problem:
[ -7.503006 17.204712 -15.879755 -4.7829766]
[ -7.4061403 17.232798 -15.883697 -4.884042 ]
Fix: change "attn_mask = attn_bias.clone()" to:
attn_mask = torch.cat([torch.ones(atom_mask.size(0), 1).to(atom_mask.device), atom_mask.clone()], dim=1)
attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))
attn_mask = attn_mask.unsqueeze(1).repeat(1, attn_mask.size(1), 1)
Environment
Ubuntu 24.04.2 LTS
Additional Context
While this fixes the inference logic, I am unsure whether the released pre-trained models were trained with this bug present (i.e., trained on padded batches without proper masking). If so, this fix might affect compatibility with the pre-trained weights.