Skip to content

Commit dc75288

Browse files
committed
[tests] Add WebSocket consumer tests #677
Implement tests for RadiusBatchConsumer to verify authentication, authorization, and group messaging logic. Increases coverage for consumers.py to 84%. Fixes #677
1 parent 33bbd55 commit dc75288

File tree

8 files changed

+599
-8
lines changed

8 files changed

+599
-8
lines changed

openwisp_radius/consumers.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from asgiref.sync import sync_to_async
22
from channels.generic.websocket import AsyncJsonWebsocketConsumer
3-
from django.core.exceptions import ObjectDoesNotExist
43

54
from .utils import load_model
65

@@ -12,13 +11,9 @@ def _user_can_access_batch(self, user, batch_id):
1211
if user.is_superuser:
1312
return RadiusBatch.objects.filter(pk=batch_id).exists()
1413
# For non-superusers, check their managed organizations
15-
try:
16-
RadiusBatch.objects.filter(
17-
pk=batch_id, organization__in=user.organizations_managed
18-
).exists()
19-
return True
20-
except ObjectDoesNotExist:
21-
return False
14+
return RadiusBatch.objects.filter(
15+
pk=batch_id, organization__in=user.organizations_managed
16+
).exists()
2217

2318
async def connect(self):
2419
self.batch_id = self.scope["url_route"]["kwargs"]["batch_id"]

openwisp_radius/tests/test_commands.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,3 +564,27 @@ def test_convert_called_station_id_command_with_slug(self, *args):
564564
call_command("convert_called_station_id")
565565
radius_acc.refresh_from_db()
566566
self.assertEqual(radius_acc.called_station_id, "CC-CC-CC-CC-CC-0C")
567+
568+
def test_convert_called_station_id_command_wrapper(self):
569+
from ..management.commands.convert_called_station_id import Command
570+
571+
command = Command()
572+
self.assertIsNotNone(command)
573+
from ..management.commands.base.convert_called_station_id import (
574+
BaseConvertCalledStationIdCommand,
575+
)
576+
577+
self.assertIsInstance(command, BaseConvertCalledStationIdCommand)
578+
579+
def test_prefix_add_users_command_wrapper(self):
580+
from ..management.commands.prefix_add_users import Command
581+
582+
command = Command()
583+
self.assertIsNotNone(command)
584+
from ..management.commands.base import BatchAddMixin
585+
from ..management.commands.base.prefix_add_users import (
586+
BasePrefixAddUsersCommand,
587+
)
588+
589+
self.assertIsInstance(command, BatchAddMixin)
590+
self.assertIsInstance(command, BasePrefixAddUsersCommand)
Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
from asgiref.sync import async_to_sync
2+
from channels.routing import URLRouter
3+
from channels.testing import WebsocketCommunicator
4+
from django.contrib.auth import get_user_model
5+
from django.test import TransactionTestCase
6+
from django.urls import re_path
7+
8+
from openwisp_users.tests.utils import TestOrganizationMixin
9+
10+
from ..consumers import RadiusBatchConsumer
11+
from ..utils import load_model
12+
from . import CreateRadiusObjectsMixin
13+
14+
User = get_user_model()
15+
RadiusBatch = load_model("RadiusBatch")
16+
17+
application = URLRouter(
18+
[
19+
re_path(
20+
r"^ws/radius/batch/(?P<batch_id>[^/]+)/$",
21+
RadiusBatchConsumer.as_asgi(),
22+
),
23+
]
24+
)
25+
26+
27+
class TestRadiusBatchConsumer(
28+
CreateRadiusObjectsMixin, TestOrganizationMixin, TransactionTestCase
29+
):
30+
31+
TEST_PASSWORD = "test_password" # noqa: S105
32+
33+
def _create_test_data(self):
34+
org = self._create_org()
35+
user = self._create_admin(password=self.TEST_PASSWORD)
36+
batch = self._create_radius_batch(
37+
name="test-batch",
38+
strategy="prefix",
39+
prefix="test-",
40+
organization=org,
41+
)
42+
return org, user, batch
43+
44+
def test_websocket_connect_superuser(self):
45+
_, user, batch = self._create_test_data()
46+
47+
async def test():
48+
communicator = WebsocketCommunicator(
49+
application,
50+
f"/ws/radius/batch/{batch.pk}/",
51+
)
52+
communicator.scope["user"] = user
53+
communicator.scope["url_route"] = {"kwargs": {"batch_id": str(batch.pk)}}
54+
55+
connected, _ = await communicator.connect()
56+
assert connected is True
57+
await communicator.disconnect()
58+
59+
async_to_sync(test)()
60+
61+
def test_websocket_connect_staff_with_permission(self):
62+
org, _, batch = self._create_test_data()
63+
staff_user = self._create_administrator(
64+
organizations=[org], password=self.TEST_PASSWORD
65+
)
66+
67+
async def test():
68+
communicator = WebsocketCommunicator(
69+
application,
70+
f"/ws/radius/batch/{batch.pk}/",
71+
)
72+
communicator.scope["user"] = staff_user
73+
communicator.scope["url_route"] = {"kwargs": {"batch_id": str(batch.pk)}}
74+
75+
connected, _ = await communicator.connect()
76+
assert connected is True
77+
await communicator.disconnect()
78+
79+
async_to_sync(test)()
80+
81+
def test_websocket_reject_unauthenticated(self):
82+
_, _, batch = self._create_test_data()
83+
84+
async def test():
85+
communicator = WebsocketCommunicator(
86+
application,
87+
f"/ws/radius/batch/{batch.pk}/",
88+
)
89+
from django.contrib.auth.models import AnonymousUser
90+
91+
communicator.scope["user"] = AnonymousUser()
92+
communicator.scope["url_route"] = {"kwargs": {"batch_id": str(batch.pk)}}
93+
94+
connected, _ = await communicator.connect()
95+
assert connected is False
96+
97+
async_to_sync(test)()
98+
99+
def test_websocket_reject_non_staff(self):
100+
_, _, batch = self._create_test_data()
101+
regular_user = self._create_user(is_staff=False, password=self.TEST_PASSWORD)
102+
103+
async def test():
104+
communicator = WebsocketCommunicator(
105+
application,
106+
f"/ws/radius/batch/{batch.pk}/",
107+
)
108+
communicator.scope["user"] = regular_user
109+
communicator.scope["url_route"] = {"kwargs": {"batch_id": str(batch.pk)}}
110+
111+
connected, _ = await communicator.connect()
112+
assert connected is False
113+
114+
async_to_sync(test)()
115+
116+
def test_websocket_reject_no_permission(self):
117+
_, _, batch = self._create_test_data()
118+
119+
staff_user = self._create_user(is_staff=True, password=self.TEST_PASSWORD)
120+
121+
async def test():
122+
communicator = WebsocketCommunicator(
123+
application,
124+
f"/ws/radius/batch/{batch.pk}/",
125+
)
126+
communicator.scope["user"] = staff_user
127+
communicator.scope["url_route"] = {"kwargs": {"batch_id": str(batch.pk)}}
128+
129+
connected, _ = await communicator.connect()
130+
assert connected is False
131+
132+
async_to_sync(test)()
133+
134+
def test_websocket_group_connection(self):
135+
_, user, batch = self._create_test_data()
136+
137+
async def test():
138+
communicator = WebsocketCommunicator(
139+
application,
140+
f"/ws/radius/batch/{batch.pk}/",
141+
)
142+
communicator.scope["user"] = user
143+
communicator.scope["url_route"] = {"kwargs": {"batch_id": str(batch.pk)}}
144+
145+
connected, _ = await communicator.connect()
146+
assert connected is True
147+
await communicator.disconnect()
148+
149+
async_to_sync(test)()
150+
151+
def test_batch_status_update(self):
152+
_, user, batch = self._create_test_data()
153+
154+
async def test():
155+
communicator = WebsocketCommunicator(
156+
application,
157+
f"/ws/radius/batch/{batch.pk}/",
158+
)
159+
communicator.scope["user"] = user
160+
communicator.scope["url_route"] = {"kwargs": {"batch_id": str(batch.pk)}}
161+
162+
connected, _ = await communicator.connect()
163+
assert connected is True
164+
await communicator.send_json_to(
165+
{"type": "batch.status.update", "status": "completed"}
166+
)
167+
168+
from channels.layers import get_channel_layer
169+
170+
channel_layer = get_channel_layer()
171+
172+
await channel_layer.group_send(
173+
f"radius_batch_{batch.pk}",
174+
{"type": "batch_status_update", "status": "processing"},
175+
)
176+
177+
response = await communicator.receive_json_from()
178+
assert response == {"status": "processing"}
179+
180+
await communicator.disconnect()
181+
182+
async_to_sync(test)()
183+
184+
def test_disconnect_cleanup(self):
185+
_, user, batch = self._create_test_data()
186+
187+
async def test():
188+
communicator = WebsocketCommunicator(
189+
application,
190+
f"/ws/radius/batch/{batch.pk}/",
191+
)
192+
communicator.scope["user"] = user
193+
communicator.scope["url_route"] = {"kwargs": {"batch_id": str(batch.pk)}}
194+
195+
connected, _ = await communicator.connect()
196+
assert connected is True
197+
198+
await communicator.disconnect()
199+
200+
from channels.layers import get_channel_layer
201+
202+
channel_layer = get_channel_layer()
203+
204+
await channel_layer.group_send(
205+
f"radius_batch_{batch.pk}",
206+
{"type": "batch_status_update", "status": "completed"},
207+
)
208+
209+
async_to_sync(test)()
210+
211+
def test_user_can_access_batch_method(self):
212+
_, user, batch = self._create_test_data()
213+
consumer = RadiusBatchConsumer()
214+
215+
self.assertTrue(consumer._user_can_access_batch(user, batch.pk))
216+
217+
org = self._create_org(name="test-org-2", slug="test-org-2")
218+
staff_user = self._create_administrator(
219+
organizations=[org],
220+
password=self.TEST_PASSWORD,
221+
username="staff_user_2",
222+
email="staff2@example.com",
223+
)
224+
batch2 = self._create_radius_batch(
225+
name="test2",
226+
organization=org,
227+
strategy="prefix",
228+
prefix="test-prefix-2",
229+
)
230+
self.assertTrue(consumer._user_can_access_batch(staff_user, batch2.pk))
231+
232+
other_org = self._create_org(name="other", slug="other")
233+
other_user = self._create_administrator(
234+
organizations=[other_org],
235+
password=self.TEST_PASSWORD,
236+
username="other_user",
237+
email="other@example.com",
238+
)
239+
self.assertFalse(consumer._user_can_access_batch(other_user, batch2.pk))
240+
241+
def test_invalid_batch_id(self):
242+
_, user, _ = self._create_test_data()
243+
244+
async def test():
245+
invalid_batch_id = "00000000-0000-0000-0000-000000000000"
246+
communicator = WebsocketCommunicator(
247+
application,
248+
f"/ws/radius/batch/{invalid_batch_id}/",
249+
)
250+
communicator.scope["user"] = user
251+
communicator.scope["url_route"] = {"kwargs": {"batch_id": invalid_batch_id}}
252+
253+
connected, _ = await communicator.connect()
254+
assert connected is False
255+
256+
async_to_sync(test)()
257+
258+
def test_user_can_access_batch_with_invalid_uuid(self):
259+
_, user, _ = self._create_test_data()
260+
consumer = RadiusBatchConsumer()
261+
262+
result = consumer._user_can_access_batch(user, "00000000-0000-0000-0000-000000000000")
263+
self.assertFalse(result)
264+

openwisp_radius/tests/test_counters/test_base_counter.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,13 @@ def test_resets(self):
114114
self.assertEqual(str(datetime.fromtimestamp(start)), "2021-10-22 00:00:00")
115115
self.assertEqual(str(datetime.fromtimestamp(end)), "2021-11-22 00:00:00")
116116

117+
with self.subTest("monthly_subscription future start date logic"):
118+
user.date_joined = datetime.fromisoformat("2021-07-04 12:34:58")
119+
user.save(update_fields=["date_joined"])
120+
start, end = resets["monthly_subscription"](user)
121+
self.assertEqual(str(datetime.fromtimestamp(start)), "2021-10-04 00:00:00")
122+
self.assertEqual(str(datetime.fromtimestamp(end)), "2021-11-04 00:00:00")
123+
117124
with self.subTest("never"):
118125
start, end = resets["never"]()
119126
self.assertEqual(start, 0)
@@ -131,5 +138,33 @@ class MaxInputOctetsCounter(BaseDailyCounter):
131138
self.assertEqual(BaseMontlhyTrafficCounter.get_attribute_type(), "bytes")
132139
self.assertEqual(MaxInputOctetsCounter.get_attribute_type(), "bytes")
133140

141+
def test_base_exception_logging(self):
142+
from unittest.mock import MagicMock
143+
144+
from ...counters.exceptions import BaseException
145+
146+
logger = MagicMock()
147+
BaseException("message", "error", logger)
148+
logger.error.assert_called_with("message")
149+
with self.assertRaises(AssertionError):
150+
BaseException("message", "invalid_level", logger)
151+
152+
def test_consumed_method(self):
153+
opts = self._get_kwargs("Max-Daily-Session")
154+
from ...counters.sqlite.daily_counter import DailyCounter
155+
156+
counter = DailyCounter(**opts)
157+
consumed = counter.consumed()
158+
self.assertEqual(consumed, 0)
159+
self.assertIsInstance(consumed, int)
160+
161+
from .utils import _acct_data
162+
163+
self._create_radius_accounting(**_acct_data)
164+
consumed = counter.consumed()
165+
self.assertEqual(consumed, int(_acct_data["session_time"]))
166+
self.assertIsInstance(consumed, int)
167+
168+
134169

135170
del BaseTransactionTestCase

0 commit comments

Comments
 (0)