[BUG] Wrong calculation of feat["pair_type"] in Uni-Mol2 #368

@xuanliugit

Describe the bug

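The calculation of feat["pair_type"] at line 64 of unimol2/unimol2/data/unimol2_dataset.py in Uni-Mol2 should pair the atom numbers of each atom-atom pair. Instead, it pairs a different feature: atom_feat is built from x[:, 1:], so feat["atom_feat"][..., 0] is the first remaining feature column rather than the atom number. Is this the expected behavior?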

```python
# unimol2/unimol2/data/unimol2_dataset.py

def get_graph_features(...):
    atom_feat = convert_to_single_emb(x[:, 1:], atom_feat_sizes)
    # The atom number has been removed in the above step

    ...
    feat = {}
    feat["atom_feat"] = torch.from_numpy(atom_feat).long()
    ...
    atoms = feat["atom_feat"][..., 0]
    # As a result, the above line does not retrieve the atom number.

    pair_type = torch.cat(
        [
            atoms.view(-1, 1, 1).expand(-1, N, -1),
            atoms.view(1, -1, 1).expand(N, -1, -1),
        ],
        dim=-1,
    )
```
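To make the indexing problem concrete, here is a minimal NumPy sketch on a toy feature matrix. It assumes that column 0 of `x` holds the atom number and that `convert_to_single_emb` shifts each remaining column by a per-feature offset; `toy_convert_to_single_emb` and `feat_sizes` are hypothetical stand-ins for illustration, not the Uni-Mol2 code.

```python
import numpy as np

def toy_convert_to_single_emb(feats, feat_sizes):
    # Hypothetical stand-in: shift column k by the total size of columns 0..k-1,
    # so each feature's values land in their own slice of one embedding table.
    offsets = np.concatenate([[0], np.cumsum(feat_sizes[:-1])])
    return feats + offsets

# Toy molecule with 3 atoms; columns: [atom number, feature A, feature B]
x = np.array([
    [6, 0, 2],  # carbon
    [8, 1, 1],  # oxygen
    [7, 0, 3],  # nitrogen
])
feat_sizes = np.array([4, 5])  # hypothetical vocabulary sizes for features A and B

# Mirrors the excerpt: column 0 (atom number) is sliced away before embedding
atom_feat = toy_convert_to_single_emb(x[:, 1:], feat_sizes)

atoms_buggy = atom_feat[..., 0]  # feature A, not the atom number
atoms_expected = x[:, 0]         # the atom numbers

print(atoms_buggy)     # [0 1 0]
print(atoms_expected)  # [6 8 7]
```

Any pair_type built from `atoms_buggy` therefore encodes pairs of feature A values, regardless of which elements the atoms are.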

Uni-Mol Version

Uni-Mol2

Expected behavior

atoms should be taken directly from x[:, 0] instead of from atom_feat.
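A minimal sketch of the expected behavior, written with NumPy broadcasting in place of `torch.cat`/`expand` so it runs without PyTorch. It assumes, as this report does, that `x[:, 0]` holds the atom number; `build_pair_type` is a hypothetical helper name, not a function in the Uni-Mol2 code base.

```python
import numpy as np

def build_pair_type(x):
    """Return an (N, N, 2) array whose entry [i, j] is (atom number of i, atom number of j)."""
    atoms = x[:, 0]  # atom numbers taken directly from x, not from atom_feat
    n = atoms.shape[0]
    # Equivalent of atoms.view(-1, 1, 1).expand(-1, N, -1): row i repeats atoms[i]
    row = np.broadcast_to(atoms.reshape(-1, 1, 1), (n, n, 1))
    # Equivalent of atoms.view(1, -1, 1).expand(N, -1, -1): column j repeats atoms[j]
    col = np.broadcast_to(atoms.reshape(1, -1, 1), (n, n, 1))
    return np.concatenate([row, col], axis=-1)

x = np.array([[6, 0, 2], [8, 1, 1], [7, 0, 3]])
pair_type = build_pair_type(x)
print(pair_type.shape)  # (3, 3, 2)
print(pair_type[0, 1])  # [6 8]
```

In the PyTorch code, the corresponding change would presumably be to define atoms from x[:, 0] (for example via torch.from_numpy(x[:, 0]).long(), assuming x is a NumPy array as the torch.from_numpy(atom_feat) call suggests) and leave the existing torch.cat call unchanged.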

To Reproduce

No response

Environment

No response

Additional Context

No response

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
