Skip to content

Commit 0aeaf07

Browse files
refactor(conf): support env alongside file configuration
1 parent 9dd0591 commit 0aeaf07

File tree

21 files changed

+147
-199
lines changed

21 files changed

+147
-199
lines changed

config/test.config.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
# Empty config file for tests
2+
user_agent = "foo" # Will be overwritten by pytest.env

sketch_map_tool/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,23 +9,23 @@
99
# https://github.com/GIScience/sketch-map-tool/issues/503
1010
from osgeo import gdal, osr # noqa: F401
1111

12-
from sketch_map_tool.config import get_config_value
12+
from sketch_map_tool.config import CONFIG
1313
from sketch_map_tool.database import client_flask as db_client
1414
from sketch_map_tool.definitions import LANGUAGES
1515

1616
__version__ = "2025.11.12"
1717

1818
# Setup logging
19-
LEVEL = getattr(logging, get_config_value("log-level").upper())
19+
LEVEL = getattr(logging, CONFIG.log_level.upper())
2020
FORMAT = "%(asctime)s - %(levelname)s - %(filename)s - %(funcName)s - %(message)s"
2121
logging.basicConfig(
2222
level=LEVEL,
2323
format=FORMAT,
2424
)
2525

2626
CELERY_CONFIG = {
27-
"broker_url": get_config_value("broker-url"),
28-
"result_backend": get_config_value("result-backend"),
27+
"broker_url": CONFIG.broker_url,
28+
"result_backend": CONFIG.result_backend,
2929
"task_serializer": "pickle",
3030
"task_track_started": True, # report ‘started’ status worker executes task
3131
"task_send_sent_event": True,

sketch_map_tool/config.py

Lines changed: 55 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,16 @@
1+
import logging
12
import os
2-
from types import MappingProxyType
3-
from typing import Dict
43

5-
import toml
4+
from pydantic import field_validator
5+
from pydantic_settings import (
6+
BaseSettings,
7+
PydanticBaseSettingsSource,
8+
SettingsConfigDict,
9+
TomlConfigSettingsSource,
10+
)
611

712
from sketch_map_tool.helpers import get_project_root
813

9-
DEFAULT_CONFIG = {
10-
"data-dir": str(get_project_root() / "data"),
11-
"weights-dir": str(get_project_root() / "weights"),
12-
"user-agent": "sketch-map-tool",
13-
"broker-url": "redis://localhost:6379",
14-
"result-backend": "db+postgresql://smt:smt@localhost:5432",
15-
"cleanup-map-frames-interval": "12 months",
16-
"wms-url-osm": "https://maps.heigit.org/raster/osm-carto/service?SERVICE=WMS&VERSION=1.1.1",
17-
"wms-layers-osm": "heigit:osm-carto-proxy",
18-
"wms-url-esri-world-imagery": "https://maps.heigit.org/raster/sketch-map-tool/service?SERVICE=WMS&VERSION=1.1.1",
19-
"wms-url-esri-world-imagery-fallback": "https://maps.heigit.org/raster/sketch-map-tool/service?SERVICE=WMS&VERSION=1.1.1",
20-
"wms-layers-esri-world-imagery": "world_imagery",
21-
"wms-layers-esri-world-imagery-fallback": "world_imagery_fallback",
22-
"wms-read-timeout": 600,
23-
"max-nr-simultaneous-uploads": 100,
24-
"yolo_cls": "SMT-CLS",
25-
"yolo_osm_obj": "SMT-OSM",
26-
"yolo_esri_obj": "SMT-ESRI",
27-
"model_type_sam": "vit_b",
28-
"esri-api-key": "",
29-
"log-level": "INFO",
30-
"point-area-threshold": 0.00047,
31-
}
32-
3314

3415
def get_config_path() -> str:
3516
"""Get configuration file path.
@@ -41,26 +22,55 @@ def get_config_path() -> str:
4122
return os.getenv("SMT_CONFIG", default=default)
4223

4324

44-
def load_config_from_file(path: str) -> Dict[str, str]:
45-
"""Load configuration from file on disk."""
46-
if os.path.isfile(path):
47-
with open(path, "r") as f:
48-
return toml.load(f)
49-
else:
50-
return {}
25+
class Config(BaseSettings):
26+
broker_url: str = "redis://localhost:6379"
27+
cleanup_map_frames_interval: str = "12 months"
28+
data_dir: str = str(get_project_root() / "data") # TODO: make this a Path
29+
esri_api_key: str = ""
30+
log_level: str = "INFO"
31+
max_nr_simultaneous_uploads: int = 100
32+
model_type_sam: str = "vit_b"
33+
point_area_threshold: float = 0.00047
34+
result_backend: str = "db+postgresql://smt:smt@localhost:5432"
35+
user_agent: str = "sketch-map-tool"
36+
weights_dir: str = str(get_project_root() / "weights") # TODO: make this a Path
37+
wms_layers_esri_world_imagery: str = "world_imagery"
38+
wms_layers_esri_world_imagery_fallback: str = "world_imagery_fallback"
39+
wms_layers_osm: str = "heigit:osm-carto-proxy"
40+
wms_read_timeout: int = 600
41+
wms_url_esri_world_imagery: str = "https://maps.heigit.org/raster/sketch-map-tool/service?SERVICE=WMS&VERSION=1.1.1"
42+
wms_url_esri_world_imagery_fallback: str = "https://maps.heigit.org/raster/sketch-map-tool/service?SERVICE=WMS&VERSION=1.1.1"
43+
wms_url_osm: str = (
44+
"https://maps.heigit.org/raster/osm-carto/service?SERVICE=WMS&VERSION=1.1.1"
45+
)
46+
yolo_cls: str = "SMT-CLS"
47+
yolo_esri_obj: str = "SMT-ESRI"
48+
yolo_osm_obj: str = "SMT-OSM"
5149

50+
model_config = SettingsConfigDict(
51+
env_prefix="SMT_",
52+
toml_file=get_config_path(),
53+
)
5254

53-
def get_config() -> MappingProxyType:
54-
"""Get configuration variables from environment and file.
55+
@classmethod
56+
def settings_customise_sources(
57+
cls,
58+
settings_cls: type[BaseSettings],
59+
init_settings: PydanticBaseSettingsSource,
60+
env_settings: PydanticBaseSettingsSource,
61+
dotenv_settings: PydanticBaseSettingsSource,
62+
file_secret_settings: PydanticBaseSettingsSource,
63+
) -> tuple[PydanticBaseSettingsSource, ...]:
64+
# https://docs.pydantic.dev/latest/concepts/pydantic_settings/#other-settings-source
65+
# env takes precedence over file settings
66+
return (env_settings, TomlConfigSettingsSource(settings_cls))
5567

56-
Configuration values from file will be given precedence over default values.
57-
"""
58-
cfg = DEFAULT_CONFIG
59-
cfg_file = load_config_from_file(get_config_path())
60-
cfg.update(cfg_file)
61-
return MappingProxyType(cfg)
68+
@field_validator("esri_api_key", mode="before")
69+
@classmethod
70+
def check_esri_api_key(cls, value: str) -> str:
71+
if not value:
72+
logging.warning("No ESRI API Key found.")
73+
return value
6274

6375

64-
def get_config_value(key: str):
65-
config = get_config()
66-
return config[key]
76+
CONFIG = Config()

sketch_map_tool/database/client_celery.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from psycopg2.extensions import connection
88

99
from sketch_map_tool import __version__
10-
from sketch_map_tool.config import get_config_value
10+
from sketch_map_tool.config import CONFIG
1111
from sketch_map_tool.exceptions import (
1212
CustomFileDoesNotExistAnymoreError,
1313
CustomFileNotFoundError,
@@ -20,7 +20,7 @@
2020

2121
def open_connection():
2222
global db_conn
23-
raw = get_config_value("result-backend")
23+
raw = CONFIG.result_backend
2424
dns = raw[3:]
2525
db_conn = psycopg2.connect(dns)
2626
db_conn.autocommit = True
@@ -129,7 +129,7 @@ def cleanup_map_frames():
129129
"""
130130
with db_conn.cursor() as curs:
131131
try:
132-
curs.execute(query, [get_config_value("cleanup-map-frames-interval")])
132+
curs.execute(query, [CONFIG.cleanup_map_frames_interval])
133133
except UndefinedTable:
134134
logging.info("Table `map_frame` does not exist yet. Nothing todo.")
135135

sketch_map_tool/database/client_flask.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from psycopg2.extensions import connection
77
from werkzeug.utils import secure_filename
88

9-
from sketch_map_tool.config import get_config_value
9+
from sketch_map_tool.config import CONFIG
1010
from sketch_map_tool.exceptions import (
1111
CustomFileDoesNotExistAnymoreError,
1212
CustomFileNotFoundError,
@@ -18,7 +18,7 @@
1818

1919
def open_connection():
2020
if "db_conn" not in g:
21-
raw = get_config_value("result-backend")
21+
raw = CONFIG.result_backend
2222
dns = raw[3:]
2323
g.db_conn = psycopg2.connect(dns)
2424
g.db_conn.autocommit = True

sketch_map_tool/definitions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import requests
88
from werkzeug.utils import secure_filename
99

10-
from sketch_map_tool.config import get_config_value
10+
from sketch_map_tool.config import CONFIG
1111
from sketch_map_tool.models import LiteratureReference, PaperFormat
1212
from sketch_map_tool.openaerialmap import client as oam_client
1313

@@ -37,7 +37,7 @@ def get_attribution(layer: str) -> str:
3737
url = (
3838
"https://basemaps-api.arcgis.com/arcgis/rest/services/styles/ArcGIS:Imagery"
3939
)
40-
token = get_config_value("esri-api-key")
40+
token = CONFIG.esri_api_key
4141
if token == "":
4242
sources = "Esri, Maxar, Earthstar Geographics, and the GIS User Community"
4343
logging.warning(
@@ -69,7 +69,7 @@ def get_literature_references() -> list[LiteratureReference]:
6969
For image source either a web URL or a filename of a file in the publications folder
7070
is expected.
7171
"""
72-
p = Path(get_config_value("data-dir")) / "literature.json"
72+
p = Path(CONFIG.data_dir) / "literature.json"
7373
with open(p, "r") as f:
7474
raw = json.load(f)
7575

sketch_map_tool/routes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def create(lang="en") -> str:
8181
return render_template(
8282
"create.html.jinja",
8383
lang=lang,
84-
esri_api_key=config.get_config_value("esri-api-key"),
84+
esri_api_key=config.CONFIG.esri_api_key,
8585
)
8686

8787

sketch_map_tool/tasks.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
from ultralytics import YOLO
1111
from ultralytics_MB import YOLO as YOLO_MB
1212

13+
from sketch_map_tool import CONFIG, map_generation
1314
from sketch_map_tool import celery_app as celery
14-
from sketch_map_tool import get_config_value, map_generation
1515
from sketch_map_tool.database import client_celery as db_client_celery
1616
from sketch_map_tool.definitions import get_attribution
1717
from sketch_map_tool.exceptions import MarkingDetectionError
@@ -64,9 +64,9 @@ def init_worker_ml_models(**_):
6464
)
6565
sam_predictor = SAM2ImagePredictor(sam2_model)
6666

67-
yolo_obj_osm = YOLO_MB(init_model(get_config_value("yolo_osm_obj")))
68-
yolo_obj_esri = YOLO_MB(init_model(get_config_value("yolo_esri_obj")))
69-
yolo_cls = YOLO(init_model(get_config_value("yolo_cls")))
67+
yolo_obj_osm = YOLO_MB(init_model(CONFIG.yolo_osm_obj))
68+
yolo_obj_esri = YOLO_MB(init_model(CONFIG.yolo_esri_obj))
69+
yolo_cls = YOLO(init_model(CONFIG.yolo_cls))
7070

7171

7272
@worker_process_shutdown.connect
@@ -78,7 +78,7 @@ def shutdown_worker(**_):
7878

7979
@setup_logging.connect
8080
def on_setup_logging(**_):
81-
level = getattr(logging, get_config_value("log-level").upper())
81+
level = getattr(logging, CONFIG.log_level.upper())
8282
format = "%(asctime)s - %(levelname)s - %(filename)s - %(funcName)s - %(message)s"
8383
logging.basicConfig(
8484
level=level,

sketch_map_tool/upload_processing/ml_models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,20 @@
55
import torch
66
from torch._prims_common import DeviceLikeType
77

8-
from sketch_map_tool.config import get_config_value
8+
from sketch_map_tool.config import CONFIG
99

1010

1111
def init_model(id: str) -> Path:
1212
"""Initialize model. Raise error if not found."""
13-
raw = Path(get_config_value("weights-dir")) / id
13+
raw = Path(CONFIG.weights_dir) / id
1414
path = raw.with_suffix(".pt")
1515
if not path.is_file():
1616
raise FileNotFoundError("Model not found at " + str(path))
1717
return path
1818

1919

2020
def init_sam2(id: str = "sam2_hiera_base_plus") -> Path:
21-
raw = Path(get_config_value("weights-dir")) / id
21+
raw = Path(CONFIG.weights_dir) / id
2222
path = raw.with_suffix(".pt")
2323
base_url = "https://dl.fbaipublicfiles.com/segment_anything_2/072824/"
2424
url = base_url + id + ".pt"

sketch_map_tool/upload_processing/post_process.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from shapely.ops import transform, unary_union
99
from shapelysmooth import chaikin_smooth
1010

11-
from sketch_map_tool.config import get_config_value
11+
from sketch_map_tool.config import CONFIG
1212
from sketch_map_tool.definitions import COLORS
1313
from sketch_map_tool.models import Bbox
1414

@@ -157,7 +157,7 @@ def smooth(fc: FeatureCollection) -> FeatureCollection:
157157

158158
def classify_points(fc: FeatureCollection, bbox: Bbox) -> FeatureCollection:
159159
"""Classify each feature as point or polygon based area."""
160-
point_area_threshold = get_config_value("point-area-threshold")
160+
point_area_threshold = CONFIG.point_area_threshold
161161
fc_ = FeatureCollection(features=[])
162162

163163
for feature in fc["features"]:

0 commit comments

Comments
 (0)