diff --git a/tests/test_token_plugins.py b/tests/test_token_plugins.py index e1b967b7..1e8e7e5b 100644 --- a/tests/test_token_plugins.py +++ b/tests/test_token_plugins.py @@ -60,6 +60,37 @@ def test_simple(self): self.assertEqual(result[0], "remote_host") self.assertEqual(result[1], "remote_port") + def test_dir_with_files(self): + mock_file1 = MagicMock() + mock_file1.is_file.return_value = True + mock_file1.is_dir.return_value = False + mock_file1.open.return_value.__enter__.return_value.readlines.return_value = ["testhost1: remote_host1:remote_port1"] + + mock_file2 = MagicMock() + mock_file2.is_file.return_value = True + mock_file2.is_dir.return_value = False + mock_file2.open.return_value.__enter__.return_value.readlines.return_value = ["testhost2: remote_host2:remote_port2"] + + mock_dir = MagicMock() + mock_dir.is_dir.return_value = True + mock_dir.is_file.return_value = False + + mock_source_dir = MagicMock() + mock_source_dir.is_dir.return_value = True + mock_source_dir.iterdir.return_value = [mock_file1, mock_file2, mock_dir] + + with patch("websockify.token_plugins.Path") as mock_path: + mock_path.return_value = mock_source_dir + plugin = ReadOnlyTokenFile('configdir') + result1 = plugin.lookup('testhost1') + result2 = plugin.lookup('testhost2') + + mock_path.assert_called_once_with('configdir') + self.assertIsNotNone(result1) + self.assertIsNotNone(result2) + self.assertEqual(result1, ["remote_host1", "remote_port1"]) + self.assertEqual(result2, ["remote_host2", "remote_port2"]) + def test_tabs(self): mock_source_file = MagicMock() mock_source_file.is_dir.return_value = False diff --git a/websockify/token_plugins.py b/websockify/token_plugins.py index d582032f..a0c17012 100644 --- a/websockify/token_plugins.py +++ b/websockify/token_plugins.py @@ -45,7 +45,7 @@ def __init__(self, *args, **kwargs): def _load_targets(self): source = Path(self.source) if source.is_dir(): - cfg_files = [file for file in source if file.is_file()] + cfg_files = [file for file in source.iterdir() if file.is_file()] else: cfg_files = [source]