Skip to content

Commit e637434

Browse files
End-to-end test of the complete AutoShield pipeline
1 parent 0be5355 commit e637434

File tree

1 file changed

+175
-0
lines changed

1 file changed

+175
-0
lines changed

tests/test_end_to_end.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
# tests/test_end_to_end.py
2+
"""
3+
End-to-end test of the complete AutoShield pipeline.
4+
"""
5+
import pytest
6+
import asyncio
7+
import json
8+
from datetime import datetime
9+
10+
from autoshield.orchestrator import AutoShieldOrchestrator
11+
from autoshield.models import FlowWindow
12+
13+
@pytest.fixture
14+
def orchestrator():
15+
"""Create orchestrator for testing"""
16+
# Use test mode (no actual Kubernetes actions)
17+
orchestrator = AutoShieldOrchestrator(
18+
model_path="data/models/cnn-lstm/latest/final_model.pth",
19+
policy_file="config/policies/test.yaml",
20+
enable_actuation=False # Test mode - no actual actions
21+
)
22+
return orchestrator
23+
24+
@pytest.fixture
25+
def normal_window():
26+
"""Create normal traffic window"""
27+
return {
28+
"window_id": "test_normal_001",
29+
"src_pod": "frontend",
30+
"dst_pod": "backend",
31+
"flow_count": 15,
32+
"bytes_sent": 25000,
33+
"bytes_received": 15000,
34+
"syn_count": 8,
35+
"ack_count": 14,
36+
"rst_count": 1,
37+
"fin_count": 6,
38+
"total_duration_ms": 1500,
39+
"avg_interarrival_ms": 100,
40+
"std_interarrival_ms": 20,
41+
"failed_conn_ratio": 0.05,
42+
"unique_ports": 2,
43+
"start_time": datetime.now().isoformat(),
44+
"end_time": datetime.now().isoformat()
45+
}
46+
47+
@pytest.fixture
48+
def attack_window():
49+
"""Create attack traffic window (port scan)"""
50+
return {
51+
"window_id": "test_attack_001",
52+
"src_pod": "attacker-pod",
53+
"dst_pod": "target-service",
54+
"flow_count": 35,
55+
"bytes_sent": 5000,
56+
"bytes_received": 1000,
57+
"syn_count": 32,
58+
"ack_count": 5,
59+
"rst_count": 25,
60+
"fin_count": 2,
61+
"total_duration_ms": 2500,
62+
"avg_interarrival_ms": 70,
63+
"std_interarrival_ms": 10,
64+
"failed_conn_ratio": 0.75,
65+
"unique_ports": 22,
66+
"start_time": datetime.now().isoformat(),
67+
"end_time": datetime.now().isoformat()
68+
}
69+
70+
@pytest.mark.asyncio
71+
async def test_normal_traffic(orchestrator, normal_window):
72+
"""Test processing of normal traffic"""
73+
result = await orchestrator.process_window(normal_window)
74+
75+
assert result["window_id"] == normal_window["window_id"]
76+
assert "detection_result" in result
77+
assert "policy_decision" in result
78+
79+
# Normal traffic should not trigger actions
80+
if result["detection_result"]["predicted_class"] == "NORMAL":
81+
assert result["policy_decision"] is None
82+
83+
@pytest.mark.asyncio
84+
async def test_attack_traffic(orchestrator, attack_window):
85+
"""Test processing of attack traffic"""
86+
result = await orchestrator.process_window(attack_window)
87+
88+
assert result["window_id"] == attack_window["window_id"]
89+
assert "detection_result" in result
90+
91+
detection = result["detection_result"]
92+
# Attack should be detected (though actual class depends on model)
93+
assert detection["predicted_class"] != "NORMAL" or detection["confidence"] < 0.5
94+
95+
# Should have enhanced explanation
96+
assert "enhanced_explanation" in result
97+
98+
def test_orchestrator_stats(orchestrator, normal_window, attack_window):
99+
"""Test orchestrator statistics"""
100+
# Process a few windows
101+
asyncio.run(orchestrator.process_window(normal_window))
102+
asyncio.run(orchestrator.process_window(attack_window))
103+
104+
stats = orchestrator.get_stats()
105+
106+
assert stats["orchestrator"]["total_windows_processed"] >= 2
107+
assert "inference" in stats
108+
assert "policy" in stats
109+
110+
def test_safety_features(orchestrator):
111+
"""Test safety controller functionality"""
112+
# Try to process window with protected namespace
113+
protected_window = {
114+
"window_id": "test_protected_001",
115+
"src_pod": "kube-system/coredns",
116+
"dst_pod": "backend",
117+
"flow_count": 100, # Very high to trigger
118+
"bytes_sent": 50000,
119+
"bytes_received": 50000,
120+
"syn_count": 50,
121+
"ack_count": 50,
122+
"rst_count": 0,
123+
"fin_count": 0,
124+
"total_duration_ms": 1000,
125+
"avg_interarrival_ms": 10,
126+
"std_interarrival_ms": 2,
127+
"failed_conn_ratio": 0.0,
128+
"unique_ports": 1,
129+
"start_time": datetime.now().isoformat(),
130+
"end_time": datetime.now().isoformat()
131+
}
132+
133+
result = asyncio.run(orchestrator.process_window(protected_window))
134+
135+
# Even if detected as attack, safety should prevent action on kube-system
136+
# (but actuation is disabled anyway in test mode)
137+
assert "detection_result" in result
138+
139+
@pytest.mark.integration
140+
def test_full_pipeline():
141+
"""Integration test with all components"""
142+
# This would test the complete pipeline with actual services
143+
# Requires all services to be deployed
144+
pass
145+
146+
if __name__ == "__main__":
147+
# Run quick demo
148+
orchestrator = AutoShieldOrchestrator(
149+
model_path="data/models/cnn-lstm/latest/final_model.pth",
150+
enable_actuation=False
151+
)
152+
153+
# Test with sample data
154+
test_data = {
155+
"window_id": "demo_001",
156+
"src_pod": "suspicious-pod",
157+
"dst_pod": "database",
158+
"flow_count": 28,
159+
"bytes_sent": 10000,
160+
"bytes_received": 2000,
161+
"syn_count": 20,
162+
"ack_count": 8,
163+
"rst_count": 15,
164+
"fin_count": 3,
165+
"total_duration_ms": 1800,
166+
"avg_interarrival_ms": 65,
167+
"std_interarrival_ms": 15,
168+
"failed_conn_ratio": 0.65,
169+
"unique_ports": 18,
170+
"start_time": datetime.now().isoformat(),
171+
"end_time": datetime.now().isoformat()
172+
}
173+
174+
result = asyncio.run(orchestrator.process_window(test_data))
175+
print(json.dumps(result, indent=2, default=str))

0 commit comments

Comments
 (0)