Description
To enhance the scalability, memory efficiency, and relational reasoning capabilities of the AgentFarm Perception & Observation System, we propose exploring Graph Neural Networks (GNNs) as the next iteration. GNNs are a SOTA approach for multi-agent reinforcement learning (MARL) perception, offering sparse, relational observation encoding, efficient spatial processing, and temporal modeling. This issue outlines the objectives, scope, and tasks to prototype and evaluate GNNs as a potential replacement or complement to the current hybrid sparse-dense architecture, focusing on channels like ALLIES_HP, ENEMIES_HP, and DAMAGE_HEAT.
Motivation
The current perception system achieves [TBD]% memory reduction via sparse storage and lazy dense construction but faces challenges in:
Python dict overhead for sparse channels (e.g., memory fragmentation, hash collisions).
Dense tensor construction costs for large perception radii (O(grid_size)).
Limited relational reasoning for agent interactions (e.g., ALLY_SIGNAL).
GNNs address these by:
Encoding observations as sparse graphs (nodes = agents/resources, edges = proximity/signals), reducing memory for entity channels.
Enabling O(k log n) spatial queries via integration with KD-trees or ANN indices.
Supporting temporal decay through recurrent or attention-based GNNs, streamlining DYNAMIC channels.
Producing NN-compatible embeddings, aligning with PyTorch/TensorFlow requirements.
Scaling linearly to 10k+ agents, as shown in SOTA MARL frameworks like DGN (arXiv:1806.08362) and LEMAE (arXiv:2410.02511).
Goals
Evaluate GNN Feasibility: Prototype GNN-based observation generation for 1-2 channels (e.g., ALLIES_HP, DAMAGE_HEAT) and compare memory/performance with current system.
Integrate with Existing System: Ensure GNN outputs are compatible with dense tensor pipelines for neural network processing.
Leverage Spatial Index: Combine GNN edge construction with existing KD-tree queries for O(log n) efficiency.
Test Scalability: Benchmark with 100, 1k, and 10k agents to verify linear scaling.
Explore Temporal Modeling: Implement decaying signals (e.g., DAMAGE_HEAT) using GNN attention or recurrent layers.
Tasks
- Research and Setup
Study PyTorch Geometric (https://pytorch-geometric.readthedocs.io/) and DGL (https://github.com/dmlc/dgl) for GNN implementation.
Review SOTA MARL papers (e.g., DGN, LEMAE, BUN) for perception-specific GNN patterns.
Install dependencies (PyTorch Geometric, scipy for KD-tree integration).
- Prototype GNN Observation Pipeline
Design graph structure:
Nodes: Agents (SELF_HP, position), resources (RESOURCES), landmarks.
Edges: Proximity-based (from cKDTree.get_nearby()), weighted by distance or signals.
Node features: HP, position, temporal signals (e.g., DAMAGE_HEAT values).
Implement a simple GNN model (e.g., GraphConv or GAT) to aggregate neighbor features into embeddings.
Example:

```python
import torch
import torch_geometric.nn as gnn

class PerceptionGNN(torch.nn.Module):
    def __init__(self, in_channels=3, out_channels=32):  # e.g., HP, x, y
        super().__init__()
        self.conv1 = gnn.GraphConv(in_channels, out_channels)

    def forward(self, x, edge_index):
        return self.conv1(x, edge_index)  # outputs per-agent embeddings
```
Integrate with existing KD-tree:
Use spatial_index.get_nearby(agent.position, config.fov_radius) to generate edge_index.
Test on a single channel (e.g., ALLIES_HP) with 100 agents.
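The KD-tree edge construction above can be sketched with scipy's `cKDTree` directly. The `build_edge_index` helper and the plain NumPy position array are illustrative assumptions; in the actual pipeline, the project's `spatial_index.get_nearby()` wrapper would replace the raw query:

```python
import numpy as np
from scipy.spatial import cKDTree

def build_edge_index(positions, fov_radius):
    """Build a COO edge list (2 x E) connecting agents within fov_radius."""
    tree = cKDTree(positions)
    # query_pairs returns each unordered pair once, as an (E, 2) array.
    pairs = tree.query_pairs(r=fov_radius, output_type="ndarray")
    # Duplicate in both directions for symmetric message passing.
    edge_index = np.concatenate([pairs, pairs[:, ::-1]], axis=0).T
    return edge_index  # shape (2, 2E), ready for a PyG-style GNN

positions = np.array([[0.0, 0.0], [1.0, 0.0], [10.0, 10.0]])
edge_index = build_edge_index(positions, fov_radius=2.0)
# Only agents 0 and 1 fall within radius 2, leaving two directed edges.
```

The `(2, E)` layout matches the `edge_index` convention expected by PyTorch Geometric layers such as `GraphConv`.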
- Hybrid Integration
Combine GNN embeddings with dense channels (e.g., VISIBILITY, RESOURCES) for policy input.
Concatenate GNN output (e.g., 32-dim embedding) with dense grids (e.g., 7x7 VISIBILITY).
Ensure compatibility with _build_dense_tensor for legacy RL pipelines.
Update ChannelRegistry to support GNN-based handlers:

```python
class GNNChannelHandler(ChannelHandler):
    def process(self, observation, channel_idx, config, agent_world_pos, **kwargs):
        # Build the local graph from the spatial index, then embed it.
        x, edge_index = self.build_graph(kwargs["spatial_index"], agent_world_pos)
        embedding = self.gnn_model(x, edge_index)
        observation[channel_idx] = embedding
```
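The concatenation step described above can be sketched as follows. The `build_policy_input` helper is an assumption; the 32-dim embedding and 7x7 VISIBILITY grid are the example sizes from this section:

```python
import numpy as np

def build_policy_input(gnn_embedding, dense_grids):
    """Concatenate a per-agent GNN embedding with flattened dense channels."""
    flat = np.concatenate([grid.ravel() for grid in dense_grids])
    return np.concatenate([gnn_embedding, flat])

embedding = np.zeros(32, dtype=np.float32)       # e.g., 32-dim GNN output
visibility = np.ones((7, 7), dtype=np.float32)   # e.g., 7x7 VISIBILITY grid
policy_input = build_policy_input(embedding, [visibility])
# 32 + 49 = 81 features per agent
```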
- Temporal Decay with GNNs
Implement DYNAMIC channel decay (e.g., DAMAGE_HEAT) using attention weights or recurrent GNNs.
Example: Apply decay as edge_weight *= gamma in message-passing.
Compare with current _decay_sparse_channel for memory and speed.
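The `edge_weight *= gamma` decay can be sketched in plain NumPy. The decay rate and pruning threshold below are illustrative assumptions; pruning near-zero weights is what keeps the graph sparse, mirroring the current cleanup loops:

```python
import numpy as np

GAMMA = 0.9        # assumed per-tick decay rate
PRUNE_EPS = 1e-3   # assumed threshold: drop cold signals to stay sparse

def decay_edge_weights(edge_weight, gamma=GAMMA, eps=PRUNE_EPS):
    """Geometrically decay signal-carrying edge weights, pruning cold ones."""
    edge_weight = edge_weight * gamma
    keep = edge_weight > eps
    # Return the keep-mask so callers can filter edge_index in lockstep.
    return edge_weight[keep], keep

weights = np.array([1.0, 0.5, 0.0005])
decayed, keep = decay_edge_weights(weights)
# [0.9, 0.45] survive; the third edge decays below the threshold and is pruned.
```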
- Benchmarking
Measure memory usage and observation generation time for:
Agent counts: 100, 1,000, 10,000.
Radii: 3 (7x7), 6 (13x13), 12 (25x25).
Compare against current sparse-dense system:
Metrics: Memory (KB per agent), generation time (ms), NN processing time.
Tools: PyTorch profiler, memory_profiler.
Test on a sample MARL env (e.g., PettingZoo’s simple_spread).
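A stdlib-only harness for the time/memory measurements above can look like this. It is a rough sketch complementing the PyTorch profiler and memory_profiler tooling named in the task; the repeat count is arbitrary:

```python
import time
import tracemalloc

def benchmark(fn, *args, repeats=10):
    """Return (mean seconds per call, peak KiB allocated) for fn(*args)."""
    tracemalloc.start()
    t0 = time.perf_counter()
    for _ in range(repeats):
        fn(*args)
    elapsed = (time.perf_counter() - t0) / repeats
    _, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    return elapsed, peak / 1024

# Example: time a stand-in for observation generation.
elapsed_s, peak_kib = benchmark(lambda: [0.0] * 10_000)
```

Per-agent memory (KB per agent) then follows by dividing `peak_kib` by the agent count at each scale (100 / 1,000 / 10,000).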
- Evaluate Scalability
Explore distributed GNN training with RLlib (https://docs.ray.io/en/master/rllib/) for 10k+ agents.
Test approximate nearest neighbor indices (e.g., HNSW via nmslib) for faster edge updates.
- Documentation and Reporting
Document findings in a report comparing GNN vs. current system (memory, speed, scalability).
Update design doc with GNN architecture if adopted.
Present trade-offs (e.g., GNN complexity vs. memory savings).
Success Criteria
GNN prototype achieves at least [TBD]% memory reduction for entity channels vs. current sparse dicts.
Observation generation remains <100ms for 10k agents.
GNN embeddings maintain or improve RL policy performance in sample env.
Temporal decay matches or outperforms current cleanup loops in efficiency.
Risks and Mitigations
Risk: GNN message-passing slower than dense ops for small graphs.
Mitigation: Use dense grids for small radii; hybridize GNN for entity channels only.
Risk: Edge construction overhead for dynamic graphs.
Mitigation: Cache edges and update incrementally with KD-tree change detection.
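The cache-and-rebuild mitigation can be sketched as follows. The `EdgeCache` class and its movement tolerance are hypothetical; the idea is to skip KD-tree rebuilds while no agent has moved far enough to plausibly change the edge set:

```python
import numpy as np
from scipy.spatial import cKDTree

class EdgeCache:
    """Rebuild proximity edges only when some agent moved beyond `tolerance`."""

    def __init__(self, radius, tolerance=0.5):
        self.radius = radius
        self.tolerance = tolerance
        self._positions = None
        self._edges = None

    def get(self, positions):
        stale = (
            self._positions is None
            or self._positions.shape != positions.shape  # agents added/removed
            or np.abs(positions - self._positions).max() > self.tolerance
        )
        if stale:
            pairs = cKDTree(positions).query_pairs(
                r=self.radius, output_type="ndarray")
            self._edges = np.concatenate([pairs, pairs[:, ::-1]], axis=0).T
            self._positions = positions.copy()
        return self._edges

cache = EdgeCache(radius=2.0)
p = np.array([[0.0, 0.0], [1.0, 0.0]])
e1 = cache.get(p)          # first call: build
e2 = cache.get(p + 0.1)    # small jitter: cached edges reused
e3 = cache.get(p + 10.0)   # large move: rebuild
```

Note the tolerance trades accuracy for speed: edges can be briefly stale near the FOV boundary, which may be acceptable for decaying signal channels but should be validated per channel.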
Risk: Learning curve for GNN frameworks.
Mitigation: Start with PyTorch Geometric tutorials; leverage MARLlib examples.
Timeline
Week 1-2: Research, setup, prototype GNN for ALLIES_HP.
Week 3: Integrate with dense channels, test temporal decay.
Week 4: Benchmark and evaluate scalability; document results.
Week 5: Iterate based on findings, finalize report.
References
PyTorch Geometric: https://pytorch-geometric.readthedocs.io/
DGL: https://github.com/dmlc/dgl
RLlib: https://docs.ray.io/en/master/rllib/
Papers: DGN (arXiv:1806.08362), LEMAE (arXiv:2410.02511), BUN (arXiv:2402.10831)
Labels
Enhancement
Research
Performance
MARL