Skip to content

Commit 3ddcc64

Browse files
authored
Add config source interface and environment variable source for config resolution (#640)
* Add config source interface and environment variable source
1 parent 027ebb5 commit 3ddcc64

File tree

4 files changed

+155
-0
lines changed

4 files changed

+155
-0
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
from .sources import EnvironmentSource
4+
5+
__all__ = [
6+
"EnvironmentSource",
7+
]
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
import os
4+
5+
6+
class EnvironmentSource:
7+
"""Configuration from environment variables."""
8+
9+
def __init__(self, prefix: str = "AWS_"):
10+
"""Initialize the EnvironmentSource with environment variable prefix.
11+
12+
:param prefix: Prefix for environment variables (default: 'AWS_')
13+
"""
14+
self._prefix = prefix
15+
16+
@property
17+
def name(self) -> str:
18+
"""Returns the source name."""
19+
return "environment"
20+
21+
def get(self, key: str) -> str | None:
22+
"""Returns a configuration value from environment variables.
23+
24+
The key is transformed to uppercase and prefixed (e.g., 'region' → 'AWS_REGION').
25+
26+
:param key: The standard configuration key (e.g., 'region', 'retry_mode').
27+
28+
:returns: The value from the environment variable (or empty string if set to empty),
29+
or None if the variable is not set.
30+
"""
31+
env_var = f"{self._prefix}{key.upper()}"
32+
config_value = os.environ.get(env_var)
33+
if config_value is None:
34+
return None
35+
return config_value.strip()
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
import os
4+
from unittest.mock import patch
5+
6+
from smithy_aws_core.config.sources import EnvironmentSource
7+
8+
9+
class TestEnvironmentSource:
10+
def test_source_name(self):
11+
source = EnvironmentSource()
12+
assert source.name == "environment"
13+
14+
def test_get_region_from_aws_region(self):
15+
with patch.dict(os.environ, {"AWS_REGION": "us-west-2"}, clear=True):
16+
source = EnvironmentSource()
17+
value = source.get("region")
18+
assert value == "us-west-2"
19+
20+
def test_get_returns_none_when_env_var_not_set(self):
21+
with patch.dict(os.environ, {}, clear=True):
22+
source = EnvironmentSource()
23+
value = source.get("region")
24+
assert value is None
25+
26+
def test_get_returns_none_for_unknown_key(self):
27+
source = EnvironmentSource()
28+
value = source.get("unknown_config_key")
29+
assert value is None
30+
31+
def test_get_handles_empty_string_env_var(self):
32+
with patch.dict(os.environ, {"AWS_REGION": ""}, clear=True):
33+
source = EnvironmentSource()
34+
value = source.get("region")
35+
assert value == ""
36+
37+
def test_get_handles_whitespace_env_var(self):
38+
with patch.dict(os.environ, {"AWS_REGION": " us-west-2 "}, clear=True):
39+
source = EnvironmentSource()
40+
value = source.get("region")
41+
# Whitespaces should be stripped
42+
assert value == "us-west-2"
43+
44+
def test_get_handles_whole_whitespace_env_var(self):
45+
with patch.dict(os.environ, {"AWS_REGION": " "}, clear=True):
46+
source = EnvironmentSource()
47+
value = source.get("region")
48+
# Whitespaces should be stripped
49+
assert value == ""
50+
51+
def test_multiple_keys_with_different_env_vars(self):
52+
env_vars = {"AWS_REGION": "eu-west-1", "AWS_RETRY_MODE": "standard"}
53+
with patch.dict(os.environ, env_vars, clear=True):
54+
source = EnvironmentSource()
55+
56+
region = source.get("region")
57+
retry_mode = source.get("retry_mode")
58+
59+
assert region == "eu-west-1"
60+
assert retry_mode == "standard"
61+
62+
def test_get_is_idempotent(self):
63+
with patch.dict(os.environ, {"AWS_REGION": "ap-south-1"}, clear=True):
64+
source = EnvironmentSource()
65+
# Calling get on source multiple times should return the same value
66+
value1 = source.get("region")
67+
value2 = source.get("region")
68+
value3 = source.get("region")
69+
70+
assert value1 == value2 == value3 == "ap-south-1"
71+
72+
def test_source_does_not_cache_env_vars(self):
73+
source = EnvironmentSource()
74+
75+
# First read
76+
with patch.dict(os.environ, {"AWS_REGION": "us-east-1"}, clear=True):
77+
value1 = source.get("region")
78+
assert value1 == "us-east-1"
79+
80+
# Environment changes
81+
with patch.dict(os.environ, {"AWS_REGION": "us-west-2"}, clear=False):
82+
value2 = source.get("region")
83+
assert value2 == "us-west-2"
84+
85+
# Source reads from os.environ and not from cache
86+
assert value1 != value2
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
from typing import Any, Protocol, runtime_checkable
4+
5+
6+
@runtime_checkable
7+
class ConfigSource(Protocol):
8+
"""Protocol for configuration sources that provide values from various locations
9+
like environment variables and configuration files.
10+
"""
11+
12+
@property
13+
def name(self) -> str:
14+
"""Returns a string identifying the source.
15+
16+
:returns: A string identifier for this source.
17+
"""
18+
...
19+
20+
def get(self, key: str) -> Any | None:
21+
"""Returns a configuration value from the source.
22+
23+
:param key: The configuration key to retrieve (e.g., 'region')
24+
25+
:returns: The value associated with the key, or None if not found.
26+
"""
27+
...

0 commit comments

Comments
 (0)