
Commit ae16ed2: update readme
1 parent: 69bd6ae

6 files changed (+259, -160 lines)

README.md

Lines changed: 76 additions & 1 deletion
@@ -1,4 +1,79 @@
+<p align="center">
+<img src="assets/logo.png" width="100"/>
+</p>
+
 This repository contains the Cosmo neural network lift and convolution layers. For a usage example and a reproduction of the results of the RECOMB 2026 submission "Gaining mechanistic insight from geometric deep learning on molecule structures through equivariant convolution", see https://github.com/BorgwardtLab/RECOMB2026Cosmo.
 
+Installation: `pip install cosmic-torch` or `pip install git+https://github.com/BorgwardtLab/Cosmo`
+
+### Cosmo
+
+Cosmo is a neural network architecture based on message passing on geometric graphs of molecule structures. It applies a convolutional filter by translating it to each vertex and rotating it towards a neighbor; the resulting feature activation (message) is passed to the neighbor the filter was pointed at. In this way, large geometric patterns can be modeled with a template-matching objective by stacking multiple Cosmo layers. A Cosmo network is equivariant to translation and rotation, and highly interpretable, since its weight matrices can be linearly combined and its filter poses can be reconstructed geometrically. For more details, please see the paper.
+
+### Example Usage
+
+Cosmo layers operate on lifted geometric graphs. These are computed from an adjacency matrix of the data, either given by, e.g., atomic bond connectivity, or constructed by, e.g., k-NN:
+
+`adj = torch_geometric.nn.knn_graph(coords, k, batch_index)`
+
+where `coords` are the input point coordinates of the data, `k` is a hyperparameter, and `batch_index` assigns each node to an instance in the batch (compare the computing principles of [PyG](https://pytorch-geometric.readthedocs.io/en/2.4.0/index.html), which we highly recommend using).
+
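For concreteness, here is a minimal sketch of how these inputs might be assembled from a PyG `Batch` object (assuming it carries `x`, `pos`, and `batch` attributes; the helper name and `k=8` are illustrative choices, not part of the package):

```
import torch_geometric


def prepare_inputs(batch, k=8):
    coords = batch.pos  # (n, dim) point coordinates
    features = batch.x  # (n, in_channels) node features, e.g. one-hot atom types
    batch_index = batch.batch  # (n,) instance assignment of each node
    adj = torch_geometric.nn.knn_graph(coords, k, batch_index)  # edge index of shape (2, num_edges)
    return features, coords, adj, batch_index
```
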
+Given coordinates, node features (e.g. one-hot encoded atom types), and the adjacency, we can lift the input graph:
+
+```
+L = Lift2D()(features, coords, adj, batch_index) # or Lift3D()
+```
+
+The `L` namespace contains everything needed by the subsequent Cosmo layers:
+
+```
+features = layer(L.source, L.target, L.features, L.hood_coords)
+```
+
+After the Cosmo layers, we need to undo the lift operation (lowering) to obtain features on the input graph. This is done by aggregating the edge/triangle features onto the nodes, which yields a standard graph representation that can be processed further, for example with PyG layers.
+
+```
+node_features = Lower(agg="max")(features, L.lifted2node, num_nodes)
+```
+
+Or, if features should be aggregated directly to the instance (graph) level:
+
+```
+graph_features = Lower(agg="max")(features, L.lifted2inst, num_instances)
+```
+
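For a graph-level task, the pooled `graph_features` can then feed a standard prediction head; a minimal sketch (the linear head and class count are illustrative, not part of the package):

```
import torch.nn as nn

num_classes = 2  # illustrative
head = nn.Linear(128, num_classes)  # assumes 128 output channels from the last Cosmo layer
logits = head(graph_features)  # (num_instances, num_classes)
```
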
+An entire Cosmo network for a node classification task could look like this:
+
+```
+from cosmic import *
+import torch.nn as nn
+
+class CosmoModel(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.lift = Lift3D()
+        self.lower = Lower()
+        self.cosmo_layers = nn.ModuleList([
+            NeuralFieldCosmo(in_channels=5, out_channels=128, dim=3),
+            NeuralFieldCosmo(in_channels=128, out_channels=128, dim=3),
+            NeuralFieldCosmo(in_channels=128, out_channels=10, dim=3)
+        ])
+
+    def forward(self, node_features, coords, adj, batch_index, num_nodes):
+        L = self.lift(node_features, coords, adj, batch_index)
+        features = L.features
+        for layer in self.cosmo_layers:
+            features = layer(L.source, L.target, features, L.hood_coords)
+        node_features = self.lower(features, L.lifted2node, num_nodes)
+        # there could be some classic GNN layers here, or an MLP head
+        return node_features
+```
+
+
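Because a Cosmo network is equivariant to translation and rotation, the lowered scalar features should be invariant to rigid motions of the input. A quick sanity check might look like the following sketch (input names follow the example above; the random rotation and tolerance are illustrative, and agreement holds only up to floating-point error):

```
import torch

model = CosmoModel()
model.eval()  # disable dropout for a deterministic comparison

with torch.no_grad():
    out_original = model(node_features, coords, adj, batch_index, num_nodes)

    # apply a random rigid motion: a proper rotation (det = +1) plus a translation
    A = torch.randn(3, 3)
    R = torch.matrix_exp(A - A.T)  # exponential of a skew-symmetric matrix is a rotation
    t = torch.randn(3)
    out_transformed = model(node_features, coords @ R.T + t, adj, batch_index, num_nodes)

print(torch.allclose(out_original, out_transformed, atol=1e-4))  # expected: True up to numerics
```
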
+### Citation
+
+TBD
+
+### License
 
-Installation: This package depends on torch and torch-scatter. Please install according to your system and their instructions.
+TBD

assets/logo.png

517 KB

cosmic/cosmo.py

Lines changed: 86 additions & 46 deletions
@@ -1,80 +1,114 @@
 import torch
 from torch import nn
+from torch_scatter import scatter_add, scatter_mean, scatter_softmax
 
-from .utilities import scatter_add, scatter_mean, scatter_softmax
+"""
+Cosmo can be implemented with various filter functions. The underlying principle is always to compute the filter under a transformation of a local reference frame (hood_coords), which is derived from neighboring input points. The forward signature of every layer is the same, and its inputs can be obtained from a Lift2D or Lift3D module.
+"""
 
 
 class KernelPointCosmo(nn.Module):
+    """
+    Implements Kernel Point Convolution (Thomas et al. 2019) in the Cosmo framework.
+    Note that we adopt an optimization trick from KPConvX (Thomas et al. 2024), which uses only the closest kernel point for each input point.
+    """
 
-    def __init__(self, in_channels, out_channels, filter):
+    def __init__(
+        self,
+        in_channels,  # Number of input channels
+        out_channels,  # Number of output channels
+        kernel_points,  # Kernel points of shape (k, dim)
+    ):
         super().__init__()
         self.out_channels = out_channels
-        mu = filter.unsqueeze(0).float()  # out_channels x k x d
+        mu = kernel_points.unsqueeze(0).float()  # 1 x k x dim
         self.register_buffer("mu", mu)
-        self.w = nn.Parameter(
-            torch.rand(out_channels, mu.shape[1], in_channels)
-        )  # out_channels x k x in_channels
+        self.w = nn.Parameter(torch.rand(out_channels, mu.shape[1], in_channels))
         nn.init.xavier_uniform_(self.w)
 
-    def forward(self, ijk, jkl, triangle_features, hood_coords):
+    def forward(
+        self,
+        source,  # Source edges ij or triangles ijk (m,)
+        target,  # Target edges jk or triangles jkl (m,)
+        features,  # Edge or triangle features of shape (m, in_channels)
+        hood_coords,  # Locally transformed coordinates of shape (m, dim)
+    ):
+        m = features.shape[0]
         with torch.no_grad():
-            dist = torch.cdist(hood_coords.unsqueeze(0), self.mu)  # n x k
+            dist = torch.cdist(hood_coords.unsqueeze(0), self.mu)  # m x k
             nn_idx = dist.argmin(dim=2).squeeze(0)
         w = self.w[:, nn_idx]  # use closest kernel point
-        f = triangle_features[ijk]
-        out_channels = torch.einsum("ni,oni->no", f, w)  # n x out
-        triangle_features = scatter_add(
-            out_channels, jkl, dim=0, dim_size=triangle_features.shape[0]
-        )
-        return triangle_features
+        f = features[source]
+        out_channels = torch.einsum("ni,oni->no", f, w)  # m x out
+        features = scatter_add(out_channels, target, dim=0, dim_size=m)
+        return features  # Updated features of shape (m, out_channels)
 
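A brief usage sketch for `KernelPointCosmo` (the random kernel-point layout and channel sizes are illustrative placeholders, and `L` is a lifted graph from `Lift3D` as in the README; in practice the kernel points should match the scale of `hood_coords`):

```
import torch

# Placeholder layout: 15 kernel points inside the unit ball.
kernel_points = torch.randn(15, 3)
kernel_points = kernel_points / kernel_points.norm(dim=1, keepdim=True).clamp(min=1.0)

layer = KernelPointCosmo(in_channels=5, out_channels=64, kernel_points=kernel_points)
features = layer(L.source, L.target, L.features, L.hood_coords)  # (m, 64)
```
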

 class NeuralFieldCosmo(nn.Module):
+    """
+    Implements Neural Field Convolution (proposed with Cosmo in Kucera et al. 2026) in the Cosmo framework. Weight matrices are computed from the input coordinates in the local reference frame using a neural field (parameterized by a neural network).
+    """
+
     def __init__(
         self,
-        in_channels,
-        out_channels,
-        hidden_channels=32,
-        num_layers=3,
-        dropout=0.0,
-        radius=1.0,
-        dim=3,
-        field_activation=nn.Tanh,
+        in_channels,  # Number of input channels
+        out_channels,  # Number of output channels
+        field_channels=32,  # Number of channels in the neural field
+        field_layers=3,  # Number of layers in the neural field
+        field_dropout=0.0,  # Dropout rate in the neural field
+        field_activation=nn.Tanh,  # Activation function in the neural field
+        radius=1.0,  # Scale parameter for the input coordinates
+        dim=3,  # Dimension of the input data (2 or 3)
     ):
         super().__init__()
         self.register_buffer("radius", torch.tensor(radius))
         self.register_buffer("in_channels", torch.tensor(in_channels))
         self.register_buffer("out_channels", torch.tensor(out_channels))
         self.neural_field = nn.Sequential(
-            nn.Linear(dim, hidden_channels),
-            nn.LayerNorm(hidden_channels),
+            nn.Linear(dim, field_channels),
+            nn.LayerNorm(field_channels),
             nn.ReLU(),
-            nn.Dropout(dropout),
+            nn.Dropout(field_dropout),
             *[
-                nn.Linear(hidden_channels, hidden_channels),
-                nn.LayerNorm(hidden_channels),
+                nn.Linear(field_channels, field_channels),
+                nn.LayerNorm(field_channels),
                 nn.ReLU(),
-                nn.Dropout(dropout),
+                nn.Dropout(field_dropout),
             ]
-            * (num_layers - 2),
-            nn.Linear(hidden_channels, in_channels * out_channels),
+            * (field_layers - 2),
+            nn.Linear(field_channels, in_channels * out_channels),
             field_activation(),
         )
 
-    def forward(self, in_edges, out_edges, edge_features, hood_coords):
+    def forward(
+        self,
+        source,  # Source edges ij or triangles ijk (m,)
+        target,  # Target edges jk or triangles jkl (m,)
+        features,  # Edge or triangle features of shape (m, in_channels)
+        hood_coords,  # Locally transformed coordinates of shape (m, dim)
+    ):
+        m = features.shape[0]
         w = self.neural_field(hood_coords / self.radius).view(
             -1, self.out_channels, self.in_channels
         )
-        f = edge_features[in_edges]
-        out_channels = torch.einsum("ni,noi->no", f, w)  # n x out
-        edge_features = scatter_mean(
-            out_channels, out_edges, dim_size=edge_features.shape[0], dim=0
-        )
-        return edge_features
+        f = features[source]
+        out_channels = torch.einsum("ni,noi->no", f, w)  # m x out
+        features = scatter_mean(out_channels, target, dim_size=m, dim=0)
+        return features  # Updated features of shape (m, out_channels)
 
 
 class PointTransformerCosmo(nn.Module):
-    def __init__(self, in_channels, out_channels, radius, dim=3):
+    """
+    Implements Point Transformer Convolution (Zhao et al. 2020) in the Cosmo framework.
+    """
+
+    def __init__(
+        self,
+        in_channels,  # Number of input channels
+        out_channels,  # Number of output channels
+        radius=1.0,  # Scale parameter for the input coordinates
+        dim=3,  # Dimension of the input data (2 or 3)
+    ):
         super().__init__()
         self.register_buffer("radius", torch.tensor(radius))
         self.delta = nn.Sequential(
@@ -89,12 +123,18 @@ def __init__(self, in_channels, out_channels, radius, dim=3):
         self.w2 = nn.Linear(in_channels, out_channels, bias=False)
         self.w3 = nn.Linear(in_channels, out_channels, bias=False)
 
-    def forward(self, ijk, jkl, tri_features, hood_coords):
-        n = tri_features.shape[0]
+    def forward(
+        self,
+        source,  # Source edges ij or triangles ijk (m,)
+        target,  # Target edges jk or triangles jkl (m,)
+        features,  # Edge or triangle features of shape (m, in_channels)
+        hood_coords,  # Locally transformed coordinates of shape (m, dim)
+    ):
+        m = features.shape[0]
         d = self.delta(hood_coords / self.radius)
-        w1 = self.w1(tri_features)
-        w2 = self.w2(tri_features)
-        w3 = self.w3(tri_features)
-        a = scatter_softmax(w1[jkl] - w2[ijk] + d, jkl, dim=0, dim_size=n)
-        tri_features = scatter_add(a * (w3[ijk] + d), jkl, dim=0, dim_size=n)
-        return tri_features
+        w1 = self.w1(features)
+        w2 = self.w2(features)
+        w3 = self.w3(features)
+        a = scatter_softmax(w1[target] - w2[source] + d, target, dim=0, dim_size=m)
+        features = scatter_add(a * (w3[source] + d), target, dim=0, dim_size=m)
+        return features  # Updated features of shape (m, out_channels)
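All three convolution variants share the same forward signature, so they can be swapped without touching the rest of a model. A minimal sketch, assuming the classes are imported from `cosmic.cosmo`, a lifted graph `L` from `Lift3D` with 5 input channels, and the placeholder `kernel_points` from the sketch above:

```
from cosmic.cosmo import KernelPointCosmo, NeuralFieldCosmo, PointTransformerCosmo

layers = [
    NeuralFieldCosmo(in_channels=5, out_channels=32, dim=3),
    PointTransformerCosmo(in_channels=5, out_channels=32, dim=3),
    KernelPointCosmo(in_channels=5, out_channels=32, kernel_points=kernel_points),
]
for layer in layers:
    out = layer(L.source, L.target, L.features, L.hood_coords)  # (m, 32) for each variant
```
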

cosmic/lift.py

Lines changed: 74 additions & 29 deletions
@@ -1,14 +1,24 @@
 from types import SimpleNamespace
 
 import torch
+from torch_scatter import scatter_max, scatter_mean, scatter_softmax, scatter_sum
 
 from .utilities import *
 
 
 class Lift2D:
+    """
+    Parameter-free module to lift a 2D geometric graph. Given node features, global coordinates, and a graph adjacency matrix, it computes the lifted adjacency and the coordinates of the neighborhoods in the local reference frames of the edges, together with some helper variables. These build the input to a Cosmo layer.
+    """
 
     @torch.compiler.disable
-    def __call__(self, node_features, coords, edge_index, node2inst):
+    def __call__(
+        self,
+        node_features,  # Node features of shape (n, in_channels)
+        coords,  # Global coordinates of shape (n, 2)
+        edge_index,  # Edge index of shape (2, m)
+        node2inst,  # Node-to-instance mapping of shape (n,)
+    ):
         n = coords.shape[0]
         adj = (
             torch.sparse_coo_tensor(
@@ -32,26 +42,36 @@ def __call__(self, node_features, coords, edge_index, node2inst):
         edge2inst = node2inst[edge2node]
         edge_features = node_features[edge2node]
         return SimpleNamespace(
-            adj=adj,
-            ij=ij,
-            jk=jk,
-            coords=coords,
-            hood_coords=hood_coords,
-            edge_features=edge_features,
-            bases=bases,
-            i=i,
-            j=j,
-            edges=edges,
-            node2inst=node2inst,
-            edge2node=edge2node,
-            edge2inst=edge2inst,
+            adj=adj,  # Sorted adjacency matrix of shape (n, n)
+            source=ij,  # Lifted source edges (m,)
+            target=jk,  # Lifted target edges (m,)
+            coords=coords,  # Global coordinates of shape (n, 2)
+            hood_coords=hood_coords,  # Local coordinates of shape (m, 2)
+            features=edge_features,  # Edge features of shape (m, in_channels)
+            bases=bases,  # Bases of shape (m, 2, 2)
+            i=i,  # Node indices i (m,)
+            j=j,  # Node indices j (m,)
+            edges=edges,  # Tuples of i,j (m, 2)
+            node2inst=node2inst,  # Node-to-instance mapping of shape (n,)
+            lifted2node=edge2node,  # Edge-to-node mapping of shape (m,)
+            lifted2inst=edge2inst,  # Edge-to-instance mapping of shape (m,)
         )
 
 
 class Lift3D:
+    """
+    Parameter-free module to lift a 3D geometric graph. Given node features, global coordinates, and a graph adjacency matrix, it computes the lifted adjacency and the coordinates of the neighborhoods in the local reference frames of the triangles, together with some helper variables. These build the input to a Cosmo layer.
+    """
 
     @torch.compiler.disable
-    def __call__(self, node_features, coords, edge_index, node2inst, minimum_angle=0.0):
+    def __call__(
+        self,
+        node_features,  # Node features of shape (n, in_channels)
+        coords,  # Global coordinates of shape (n, 3)
+        edge_index,  # Edge index of shape (2, m)
+        node2inst,  # Node-to-instance mapping of shape (n,)
+        minimum_angle=0.0,  # Minimum angle to filter nearly collinear triangles (default: 0.0)
+    ):
         n = coords.shape[0]
         adj = (
             torch.sparse_coo_tensor(
@@ -76,18 +96,43 @@ def __call__(self, node_features, coords, edge_index, node2inst, minimum_angle=0
         tri2inst = node2inst[tri2node]
         tri_features = node_features[tri2node]
         return SimpleNamespace(
-            adj=adj,
-            ijk=ijk,
-            jkl=jkl,
-            coords=coords,
-            hood_coords=hood_coords,
-            tri_features=tri_features,
-            bases=bases,
-            i=i,
-            j=j,
-            triangles=triangles,
-            node2inst=node2inst,
-            tri2node=tri2node,
-            tri2edge=tri2edge,
-            tri2inst=tri2inst,
+            adj=adj,  # Sorted adjacency matrix of shape (n, n)
+            source=ijk,  # Lifted source triangles (m,)
+            target=jkl,  # Lifted target triangles (m,)
+            coords=coords,  # Global coordinates of shape (n, 3)
+            hood_coords=hood_coords,  # Local coordinates of shape (m, 3)
+            features=tri_features,  # Triangle features of shape (m, in_channels)
+            bases=bases,  # Bases of shape (m, 3, 3)
+            i=i,  # Node indices i (m,)
+            j=j,  # Node indices j (m,)
+            triangles=triangles,  # Tuples of i,j,k (m, 3)
+            node2inst=node2inst,  # Node-to-instance mapping of shape (n,)
+            lifted2node=tri2node,  # Triangle-to-node mapping of shape (m,)
+            lifted2edge=tri2edge,  # Triangle-to-edge mapping of shape (m,)
+            lifted2inst=tri2inst,  # Triangle-to-instance mapping of shape (m,)
         )
+
+
+class Lower:
+    """
+    Parameter-free module to lower a lifted geometric graph back to the input graph. Given edge/triangle features and the corresponding index, it aggregates the features onto the input graph.
+    """
+
+    def __init__(self, agg="mean"):
+        assert agg in ["sum", "mean", "max", "softmax"]
+        self.agg = agg
+
+    def __call__(self, features, index, size, return_index=False):
+        if self.agg == "sum":
+            return scatter_sum(features, index, dim_size=size, dim=0)
+        elif self.agg == "mean":
+            return scatter_mean(features, index, dim_size=size, dim=0)
+        elif self.agg == "max":
+            val, idx = scatter_max(features, index, dim_size=size, dim=0)
+            if return_index:
+                return val, idx
+            else:
+                return val
+        elif self.agg == "softmax":
+            a = scatter_softmax(features, index, dim_size=size, dim=0)
+            return scatter_sum(a * features, index, dim_size=size, dim=0)
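
A short usage sketch for `Lower` (assuming `features`, `L`, and `num_nodes` as in the README example; channel sizes are illustrative):

```
# Node-level lowering with max aggregation; return_index additionally recovers
# which lifted element (edge/triangle) produced each output feature, which can
# be useful for interpretation.
node_features, argmax_idx = Lower(agg="max")(
    features, L.lifted2node, num_nodes, return_index=True
)

# Lowering onto edges via the triangle-to-edge mapping returned by Lift3D
# (assumes edge indices in the mapping are contiguous from 0).
num_edges = int(L.lifted2edge.max()) + 1
edge_features = Lower(agg="mean")(features, L.lifted2edge, num_edges)
```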
