Explore Graph Neural Networks (GNNs) for Next Iteration of AgentFarm Perception System #315

@csmangum

Description

To enhance the scalability, memory efficiency, and relational reasoning capabilities of the AgentFarm Perception & Observation System, we propose exploring Graph Neural Networks (GNNs) as the next iteration. GNNs are a SOTA approach for multi-agent reinforcement learning (MARL) perception, offering sparse, relational observation encoding, efficient spatial processing, and temporal modeling. This issue outlines the objectives, scope, and tasks to prototype and evaluate GNNs as a potential replacement or complement to the current hybrid sparse-dense architecture, focusing on channels like ALLIES_HP, ENEMIES_HP, and DAMAGE_HEAT.
Motivation
The current perception system achieves [TBD]% memory reduction via sparse storage and lazy dense construction but faces challenges in:

Python dict overhead for sparse channels (e.g., memory fragmentation, hash collisions).
Dense tensor construction costs for large perception radii (O(grid_size)).
Limited relational reasoning for agent interactions (e.g., ALLY_SIGNAL).

GNNs address these by:

Encoding observations as sparse graphs (nodes = agents/resources, edges = proximity/signals), reducing memory for entity channels.
Enabling O(k log n) spatial queries via integration with KD-trees or ANN indices.
Supporting temporal decay through recurrent or attention-based GNNs, streamlining DYNAMIC channels.
Producing NN-compatible embeddings, aligning with PyTorch/TensorFlow requirements.
Scaling linearly to 10k+ agents, as shown in SOTA MARL frameworks like DGN (arXiv:1806.08362) and LEMAE (arXiv:2410.02511).

Goals

Evaluate GNN Feasibility: Prototype GNN-based observation generation for 1-2 channels (e.g., ALLIES_HP, DAMAGE_HEAT) and compare memory/performance with current system.
Integrate with Existing System: Ensure GNN outputs are compatible with dense tensor pipelines for neural network processing.
Leverage Spatial Index: Combine GNN edge construction with existing KD-tree queries for O(log n) efficiency.
Test Scalability: Benchmark with 100, 1k, and 10k agents to verify linear scaling.
Explore Temporal Modeling: Implement decaying signals (e.g., DAMAGE_HEAT) using GNN attention or recurrent layers.

Tasks

  1. Research and Setup

Study PyTorch Geometric (https://pytorch-geometric.readthedocs.io/) and DGL (https://github.com/dmlc/dgl) for GNN implementation.
Review SOTA MARL papers (e.g., DGN, LEMAE, BUN) for perception-specific GNN patterns.
Install dependencies (PyTorch Geometric, scipy for KD-tree integration).

  2. Prototype GNN Observation Pipeline

Design graph structure:
Nodes: Agents (SELF_HP, position), resources (RESOURCES), landmarks.
Edges: Proximity-based (from cKDTree.get_nearby()), weighted by distance or signals.
Node features: HP, position, temporal signals (e.g., DAMAGE_HEAT values).

Implement a simple GNN model (e.g., GraphConv or GAT) to aggregate neighbor features into embeddings. Example:

```python
import torch
import torch_geometric.nn as gnn

class PerceptionGNN(torch.nn.Module):
    def __init__(self, in_channels=3, out_channels=32):  # e.g., HP, x, y
        super().__init__()
        self.conv1 = gnn.GraphConv(in_channels, out_channels)

    def forward(self, x, edge_index):
        return self.conv1(x, edge_index)  # outputs per-agent embeddings
```

Integrate with existing KD-tree:
Use spatial_index.get_nearby(agent.position, config.fov_radius) to generate edge_index.

Test on a single channel (e.g., ALLIES_HP) with 100 agents.
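
As a starting point for the edge-construction step above, here is a minimal sketch that builds a PyTorch Geometric-style `edge_index` from agent positions using SciPy's `cKDTree` directly. The function name `build_edge_index` is hypothetical; in AgentFarm the positions and radius queries would come from the existing spatial index (`spatial_index.get_nearby`) rather than a freshly built tree.

```python
import numpy as np
import torch
from scipy.spatial import cKDTree

def build_edge_index(positions: np.ndarray, radius: float) -> torch.Tensor:
    """Build a COO edge_index of shape (2, num_edges) connecting agents
    within `radius` of each other. `positions` is an (N, 2) array."""
    tree = cKDTree(positions)
    # Unordered pairs (i, j) with distance <= radius, as an (num_pairs, 2) array.
    pairs = tree.query_pairs(r=radius, output_type="ndarray")
    if len(pairs) == 0:
        return torch.empty((2, 0), dtype=torch.long)
    # Duplicate each pair in both directions: most GNN message passing
    # expects a directed edge list.
    edges = np.concatenate([pairs, pairs[:, ::-1]], axis=0)
    return torch.from_numpy(edges.T.copy()).long()
```

Edge weights (e.g., inverse distance or signal strength) can be attached as a parallel tensor indexed the same way.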

  3. Hybrid Integration

Combine GNN embeddings with dense channels (e.g., VISIBILITY, RESOURCES) for policy input.
Concatenate GNN output (e.g., 32-dim embedding) with dense grids (e.g., 7x7 VISIBILITY).

Ensure compatibility with _build_dense_tensor for legacy RL pipelines.
Update ChannelRegistry to support GNN-based handlers:

```python
class GNNChannelHandler(ChannelHandler):
    def process(self, observation, channel_idx, config, agent_world_pos, **kwargs):
        x, edge_index = self.build_graph(kwargs["spatial_index"], agent_world_pos)
        embedding = self.gnn_model(x, edge_index)
        observation[channel_idx] = embedding
```
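
The concatenation step above can be sketched in plain PyTorch. The shapes follow the examples in this issue (a 32-dim embedding and a 7x7 VISIBILITY grid); the helper name `combine_observation` is hypothetical:

```python
import torch

def combine_observation(gnn_embedding: torch.Tensor,
                        dense_grid: torch.Tensor) -> torch.Tensor:
    """Concatenate a per-agent GNN embedding with a flattened dense grid.

    gnn_embedding: (batch, 32); dense_grid: (batch, 7, 7) -> (batch, 81).
    """
    flat = dense_grid.flatten(start_dim=1)          # (batch, 49)
    return torch.cat([gnn_embedding, flat], dim=1)  # (batch, 32 + 49)

emb = torch.randn(4, 32)
grid = torch.randn(4, 7, 7)
policy_input = combine_observation(emb, grid)  # shape (4, 81)
```

The resulting vector can feed an MLP policy head directly, while `_build_dense_tensor` output continues to serve legacy CNN pipelines.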

  4. Temporal Decay with GNNs

Implement DYNAMIC channel decay (e.g., DAMAGE_HEAT) using attention weights or recurrent GNNs.
Example: Apply decay as `edge_weight *= gamma` in message passing.

Compare with current _decay_sparse_channel for memory and speed.
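
A minimal sketch of the `edge_weight *= gamma` decay step, assuming COO-style `edge_index`/`edge_weight` tensors as used by PyTorch Geometric; the function name, `GAMMA`, and the `eps` pruning threshold are illustrative, not part of the current system:

```python
import torch

GAMMA = 0.9  # hypothetical per-tick decay factor

def prune_decayed(edge_index: torch.Tensor,
                  edge_weight: torch.Tensor,
                  gamma: float = GAMMA,
                  eps: float = 1e-3):
    """Apply one tick of exponential decay (edge_weight *= gamma) and drop
    edges whose weight has fallen below `eps`, so stale DAMAGE_HEAT signals
    stop participating in message passing."""
    decayed = edge_weight * gamma
    keep = decayed > eps
    return edge_index[:, keep], decayed[keep]
```

Unlike the current `_decay_sparse_channel` cleanup loop, this is a single vectorized pass over the edge list, which is what makes the comparison in this task interesting.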

  5. Benchmarking

Measure memory usage and observation generation time for:
Agent counts: 100, 1,000, 10,000.
Radii: 3 (7x7), 6 (13x13), 12 (25x25).

Compare against current sparse-dense system:
Metrics: Memory (KB per agent), generation time (ms), NN processing time.
Tools: PyTorch profiler, memory_profiler.

Test on a sample MARL env (e.g., PettingZoo’s simple_spread).
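
For the generation-time metric, a minimal timing harness along these lines could complement the PyTorch profiler and memory_profiler runs (the `benchmark` helper is illustrative, not existing project code):

```python
import time
import statistics

def benchmark(fn, *args, repeats: int = 10) -> float:
    """Return the median wall-clock time of fn(*args) in milliseconds.

    Median rather than mean, to dampen warm-up and GC outliers.
    """
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        times.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(times)
```

The same harness can time both the GNN pipeline and the current sparse-dense path at each agent count and radius, keeping the comparison apples-to-apples.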

  6. Evaluate Scalability

Explore distributed GNN training with RLlib (https://docs.ray.io/en/master/rllib/) for 10k+ agents.
Test approximate nearest neighbor indices (e.g., HNSW via nmslib) for faster edge updates.

  7. Documentation and Reporting

Document findings in a report comparing GNN vs. current system (memory, speed, scalability).
Update design doc with GNN architecture if adopted.
Present trade-offs (e.g., GNN complexity vs. memory savings).

Success Criteria

GNN prototype achieves at least [TBD]% memory reduction for entity channels vs. current sparse dicts.
Observation generation remains <100ms for 10k agents.
GNN embeddings maintain or improve RL policy performance in sample env.
Temporal decay matches or outperforms current cleanup loops in efficiency.

Risks and Mitigations

Risk: GNN message-passing slower than dense ops for small graphs.
Mitigation: Use dense grids for small radii; hybridize GNN for entity channels only.

Risk: Edge construction overhead for dynamic graphs.
Mitigation: Cache edges and update incrementally with KD-tree change detection.
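
The caching mitigation above can be sketched as a small wrapper that rebuilds the pair list only when some agent has moved past a threshold. `EdgeCache` and its `threshold` parameter are hypothetical, and the sketch assumes a fixed agent count between calls:

```python
import numpy as np
from scipy.spatial import cKDTree

class EdgeCache:
    """Rebuild proximity edges only when any agent has moved more than
    `threshold` since the last rebuild; otherwise return the cached pairs."""

    def __init__(self, radius: float, threshold: float = 0.5):
        self.radius = radius
        self.threshold = threshold
        self._positions = None
        self._pairs = None

    def get_pairs(self, positions: np.ndarray) -> np.ndarray:
        stale = (self._positions is None or
                 np.max(np.linalg.norm(positions - self._positions, axis=1))
                 > self.threshold)
        if stale:
            tree = cKDTree(positions)
            self._pairs = tree.query_pairs(r=self.radius,
                                           output_type="ndarray")
            self._positions = positions.copy()
        return self._pairs
```

A production version would also handle agents joining or leaving, and could lean on the spatial index's own change detection instead of the movement check here.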

Risk: Learning curve for GNN frameworks.
Mitigation: Start with PyTorch Geometric tutorials; leverage MARLlib examples.

Timeline

Week 1-2: Research, setup, prototype GNN for ALLIES_HP.
Week 3: Integrate with dense channels, test temporal decay.
Week 4: Benchmark and evaluate scalability; document results.
Week 5: Iterate based on findings, finalize report.

References

PyTorch Geometric: https://pytorch-geometric.readthedocs.io/
DGL: https://github.com/dmlc/dgl
RLlib: https://docs.ray.io/en/master/rllib/
Papers: DGN (arXiv:1806.08362), LEMAE (arXiv:2410.02511), BUN (arXiv:2402.10831)

Labels

Enhancement
Research
Performance
MARL
