Skip to content

Commit 6c2c799

Browse files
committed
Add config source interface and environment variable source
1 parent fa101ae commit 6c2c799

File tree

4 files changed

+164
-0
lines changed

4 files changed

+164
-0
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .sources import EnvironmentSource
2+
3+
__all__ = [
4+
"EnvironmentSource",
5+
]
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import os
2+
from typing import Any
3+
4+
5+
class EnvironmentSource:
6+
"""Configuration from environment variables."""
7+
8+
SOURCE = "environment"
9+
10+
def __init__(self, prefix: str = "AWS_"):
11+
"""Initialize the EnvironmentSource with environment variable prefix.
12+
13+
:param prefix: Prefix for environment variables (default: 'AWS_')
14+
"""
15+
self._prefix = prefix
16+
17+
@property
18+
def name(self) -> str:
19+
"""Returns the source name."""
20+
return self.SOURCE
21+
22+
def get(self, key: str) -> Any | None:
23+
"""Returns a configuration value from environment variables.
24+
25+
:param key: The standard configuration key (e.g., 'region', 'retry_mode').
26+
27+
:returns: The value from the corresponding environment variable, or None if not set or empty.
28+
"""
29+
env_var = f"{self._prefix}{key.upper()}"
30+
config_value = os.environ.get(env_var)
31+
if config_value is None:
32+
return None
33+
stripped = config_value.strip()
34+
return stripped if stripped else None
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import os
2+
from unittest.mock import patch
3+
4+
from smithy_aws_core.config.sources import EnvironmentSource
5+
from smithy_core.interfaces.source import ConfigSource
6+
7+
8+
class TestEnvironmentSource:
9+
def test_implements_config_source_protocol(self):
10+
source = EnvironmentSource()
11+
assert isinstance(source, ConfigSource)
12+
assert hasattr(source, "name")
13+
assert hasattr(source, "get")
14+
assert callable(source.get)
15+
16+
def test_source_name(self):
17+
source = EnvironmentSource()
18+
assert source.name == "environment"
19+
20+
def test_get_region_from_aws_region(self):
21+
with patch.dict(os.environ, {"AWS_REGION": "us-west-2"}, clear=False):
22+
source = EnvironmentSource()
23+
value = source.get("region")
24+
assert value == "us-west-2"
25+
26+
def test_get_returns_none_when_env_var_not_set(self):
27+
with patch.dict(os.environ, {}, clear=True):
28+
source = EnvironmentSource()
29+
value = source.get("region")
30+
assert value is None
31+
32+
def test_get_returns_none_for_unknown_key(self):
33+
source = EnvironmentSource()
34+
value = source.get("unknown_config_key")
35+
assert value is None
36+
37+
def test_get_handles_empty_string_env_var(self):
38+
with patch.dict(os.environ, {"AWS_REGION": ""}, clear=False):
39+
source = EnvironmentSource()
40+
value = source.get("region")
41+
# Empty string should be treated as None
42+
assert value is None
43+
44+
def test_get_handles_whitespace_env_var(self):
45+
with patch.dict(os.environ, {"AWS_REGION": " us-west-2 "}, clear=False):
46+
source = EnvironmentSource()
47+
value = source.get("region")
48+
# Whitespaces should be stripped
49+
assert value == "us-west-2"
50+
51+
def test_get_handles_whole_whitespace_env_var(self):
52+
with patch.dict(os.environ, {"AWS_REGION": " "}, clear=False):
53+
source = EnvironmentSource()
54+
value = source.get("region")
55+
# Whitespaces should be stripped
56+
assert value is None
57+
58+
def test_multiple_keys_with_different_env_vars(self):
59+
env_vars = {"AWS_REGION": "eu-west-1", "AWS_RETRY_MODE": "standard"}
60+
with patch.dict(os.environ, env_vars, clear=False):
61+
source = EnvironmentSource()
62+
63+
region = source.get("region")
64+
retry_mode = source.get("retry_mode")
65+
66+
assert region == "eu-west-1"
67+
assert retry_mode == "standard"
68+
69+
def test_get_is_idempotent(self):
70+
with patch.dict(os.environ, {"AWS_REGION": "ap-south-1"}, clear=False):
71+
source = EnvironmentSource()
72+
# Calling get on source multiple times should return the same value
73+
value1 = source.get("region")
74+
value2 = source.get("region")
75+
value3 = source.get("region")
76+
77+
assert value1 == value2 == value3 == "ap-south-1"
78+
79+
def test_source_does_not_cache_env_vars(self):
80+
source = EnvironmentSource()
81+
82+
# First read
83+
with patch.dict(os.environ, {"AWS_REGION": "us-east-1"}, clear=False):
84+
value1 = source.get("region")
85+
assert value1 == "us-east-1"
86+
87+
# Environment changes
88+
with patch.dict(os.environ, {"AWS_REGION": "us-west-2"}, clear=False):
89+
value2 = source.get("region")
90+
assert value2 == "us-west-2"
91+
92+
# Source reads from os.environ and not from cache
93+
assert value1 != value2
94+
95+
def test_env_var_names_are_case_sensative(self):
96+
with patch.dict(os.environ, {"aws_region": "us-west-2"}, clear=False):
97+
source = EnvironmentSource()
98+
value = source.get("region")
99+
# Should not find 'aws_region' (lowercase), only 'AWS_REGION'
100+
assert value is None
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from typing import Any, Protocol, runtime_checkable
2+
3+
4+
@runtime_checkable
5+
class ConfigSource(Protocol):
6+
"""Protocol for configuration sources that provide values from various locations
7+
like environment variables and configuration files.
8+
"""
9+
10+
@property
11+
def name(self) -> str:
12+
"""Returns a string identifying the source.
13+
14+
:returns: A string identifier for this source.
15+
"""
16+
...
17+
18+
def get(self, key: str) -> Any | None:
19+
"""Returns a configuration value from the source.
20+
21+
:param key: The configuration key to retrieve (e.g., 'region')
22+
23+
:returns: The value associated with the key, or None if not found.
24+
"""
25+
...

0 commit comments

Comments
 (0)