You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
This repository contains the Cosmo neural network lift and convolution layers. For a usage example and reproduction of the results of the RECOMB 2026 submission "Gaining mechanistic insight from geometric deep learning on molecule structures through equivariant convolution", see https://github.com/BorgwardtLab/RECOMB2026Cosmo.
2
6
7
+
Installation: `pip install cosmic-torch` or `pip install git+https://github.com/BorgwardtLab/Cosmo`
8
+
9
+
### Cosmo
10
+
11
+
Cosmo is a neural network architecture based on message passing on geometric graphs of molecule structures. It applies a convolutional filter by translating it to vertices and rotating it towards neighbors. The resulting feature activation (message) is passed to the neighbor that the filter was pointed at. This way, large geometric patterns can be modeled with a template-matching objective by using multiple Cosmo layers. A Cosmo network is equivariant to translation and rotation, and highly interpretable as its weight matrices can be linearly combined and its filter poses can be reconstructed geometrically. For more details, please see the paper.
12
+
13
+
### Example Usage
14
+
15
+
Cosmo layers operate on lifted geometric graphs. These are computed from an adjacency matrix of the data, either given by e.g. atomic bond connectivity, or constructed by e.g. k-NN:
where `coords` are the input point coordinates of the data, `k` is a hyperparameter, and `batch_index` assigns each node to an instance in the batch (compare the computing principles of [PyG](https://pytorch-geometric.readthedocs.io/en/2.4.0/index.html), which we highly recommend to use).
20
+
21
+
Given coordinates, node features (e.g. one-hot encoded atom type), and the adjacency we can lift the input graph:
22
+
23
+
```
24
+
L = Lift2D()(features, coords, adj, batch_index) # or Lift3D()
25
+
```
26
+
27
+
The `L` namespace contains everything that we need to compute in subsequent Cosmo layers:
28
+
29
+
```
30
+
features = layer(L.source, L.target, L.features, L.hood_coords)
31
+
```
32
+
33
+
After the Cosmo layers we need to undo the lift operation (lowering) to obtain features on the input graph. This is done by aggregating the edge/triangle features to the nodes, which yields a standard graph object that can be further computed on with PyG layers, for example.
from .utilitiesimportscatter_add, scatter_mean, scatter_softmax
5
+
"""
6
+
Cosmo can be implemented with various filter functions. The underlying principle is always to compute the filter under transformation of a local reference frame (hood_coords) which is derived from neighboring input points. The forward signature of the layer is always the same and inputs can be obtained from a Lift2D or Lift3D module.
7
+
"""
5
8
6
9
7
10
classKernelPointCosmo(nn.Module):
11
+
"""
12
+
Implements Kernel Point Convolution (Thomas et al. 2019) in the Cosmo framework.
13
+
Note that we implement an optimization trick from KPConvX (Thomas et al. 2024) which uses only the closest kernel point for each input point.
returnfeatures# Updated features of shape (m, out_channels)
30
45
31
46
32
47
classNeuralFieldCosmo(nn.Module):
48
+
"""
49
+
Implements Neural Field Convolution (Proposed with Cosmo in Kucera et al. 2026) in the Cosmo framework. Weight matrices are computed from input coordinates in the local reference frame using a neural field (parameterized by a neural network).
50
+
"""
51
+
33
52
def__init__(
34
53
self,
35
-
in_channels,
36
-
out_channels,
37
-
hidden_channels=32,
38
-
num_layers=3,
39
-
dropout=0.0,
40
-
radius=1.0,
41
-
dim=3,
42
-
field_activation=nn.Tanh,
54
+
in_channels,# Number of input channels
55
+
out_channels,# Number of output channels
56
+
field_channels=32,# Number of channels in the neural field
57
+
field_layers=3,# Number of layers in the neural field
58
+
field_dropout=0.0,# Dropout rate in the neural field
59
+
field_activation=nn.Tanh, # Activation function in the neural field
60
+
radius=1.0, # Scale parameter for input coordinates
Parameter-free module to lift a 2D geometric graph. Given node features, global coordinates, and a graph adjacency matrix it computes the lifted adjacency and coordinates of neighborhoods in the local reference frame of the edges, together with some helper variables. These build the input to a Cosmo layer.
coords=coords,# Global coordinates of shape (n, 2)
49
+
hood_coords=hood_coords,# Local coordinates of shape (m, 2)
50
+
features=edge_features,# Edge features of shape (m, in_channels)
51
+
bases=bases,# Bases of shape (m, 2, 2)
52
+
i=i,# Node indices i (m,)
53
+
j=j,# Node indices j (m,)
54
+
edges=edges,# Tuples of i,j (m, 2)
55
+
node2inst=node2inst,# Node-to-instance mapping of shape (n,)
56
+
lifted2node=edge2node,# Edge-to-node mapping of shape (m,)
57
+
lifted2inst=edge2inst,# Edge-to-instance mapping of shape (m,)
48
58
)
49
59
50
60
51
61
classLift3D:
62
+
"""
63
+
Parameter-free module to lift a 3D geometric graph. Given node features, global coordinates, and a graph adjacency matrix it computes the lifted adjacency and coordinates of neighborhoods in the local reference frame of the triangles, together with some helper variables. These build the input to a Cosmo layer.
coords=coords,# Global coordinates of shape (n, 3)
103
+
hood_coords=hood_coords,# Local coordinates of shape (m, 3)
104
+
features=tri_features,# Triangle features of shape (m, in_channels)
105
+
bases=bases,# Bases of shape (m, 3, 3)
106
+
i=i,# Node indices i (m,)
107
+
j=j,# Node indices j (m,)
108
+
triangles=triangles,# Tuples of i,j,k (m, 3)
109
+
node2inst=node2inst,# Node-to-instance mapping of shape (n,)
110
+
lifted2node=tri2node,# Triangle-to-node mapping of shape (m,)
111
+
lifted2edge=tri2edge,# Triangle-to-edge mapping of shape (m,)
112
+
lifted2inst=tri2inst,# Triangle-to-instance mapping of shape (m,)
93
113
)
114
+
115
+
116
+
classLower:
117
+
"""
118
+
Parameter-free module to lower a lifted geometric graph back to the input graph. Given edge/triangle features and the corresponding index it aggregates the features to the input graph.
0 commit comments