1+ import logging
12import os
2- from types import MappingProxyType
3- from typing import Dict
43
5- import toml
4+ from pydantic import field_validator
5+ from pydantic_settings import (
6+ BaseSettings ,
7+ PydanticBaseSettingsSource ,
8+ SettingsConfigDict ,
9+ TomlConfigSettingsSource ,
10+ )
611
712from sketch_map_tool .helpers import get_project_root
813
9- DEFAULT_CONFIG = {
10- "data-dir" : str (get_project_root () / "data" ),
11- "weights-dir" : str (get_project_root () / "weights" ),
12- "user-agent" : "sketch-map-tool" ,
13- "broker-url" : "redis://localhost:6379" ,
14- "result-backend" : "db+postgresql://smt:smt@localhost:5432" ,
15- "cleanup-map-frames-interval" : "12 months" ,
16- "wms-url-osm" : "https://maps.heigit.org/raster/osm-carto/service?SERVICE=WMS&VERSION=1.1.1" ,
17- "wms-layers-osm" : "heigit:osm-carto-proxy" ,
18- "wms-url-esri-world-imagery" : "https://maps.heigit.org/raster/sketch-map-tool/service?SERVICE=WMS&VERSION=1.1.1" ,
19- "wms-url-esri-world-imagery-fallback" : "https://maps.heigit.org/raster/sketch-map-tool/service?SERVICE=WMS&VERSION=1.1.1" ,
20- "wms-layers-esri-world-imagery" : "world_imagery" ,
21- "wms-layers-esri-world-imagery-fallback" : "world_imagery_fallback" ,
22- "wms-read-timeout" : 600 ,
23- "max-nr-simultaneous-uploads" : 100 ,
24- "yolo_cls" : "SMT-CLS" ,
25- "yolo_osm_obj" : "SMT-OSM" ,
26- "yolo_esri_obj" : "SMT-ESRI" ,
27- "model_type_sam" : "vit_b" ,
28- "esri-api-key" : "" ,
29- "log-level" : "INFO" ,
30- "point-area-threshold" : 0.00047 ,
31- }
32-
3314
3415def get_config_path () -> str :
3516 """Get configuration file path.
@@ -41,26 +22,55 @@ def get_config_path() -> str:
4122 return os .getenv ("SMT_CONFIG" , default = default )
4223
4324
44- def load_config_from_file (path : str ) -> Dict [str , str ]:
45- """Load configuration from file on disk."""
46- if os .path .isfile (path ):
47- with open (path , "r" ) as f :
48- return toml .load (f )
49- else :
50- return {}
25+ class Config (BaseSettings ):
26+ broker_url : str = "redis://localhost:6379"
27+ cleanup_map_frames_interval : str = "12 months"
28+ data_dir : str = str (get_project_root () / "data" ) # TODO: make this a Path
29+ esri_api_key : str = ""
30+ log_level : str = "INFO"
31+ max_nr_simultaneous_uploads : int = 100
32+ model_type_sam : str = "vit_b"
33+ point_area_threshold : float = 0.00047
34+ result_backend : str = "db+postgresql://smt:smt@localhost:5432"
35+ user_agent : str = "sketch-map-tool"
36+ weights_dir : str = str (get_project_root () / "weights" ) # TODO: make this a Path
37+ wms_layers_esri_world_imagery : str = "world_imagery"
38+ wms_layers_esri_world_imagery_fallback : str = "world_imagery_fallback"
39+ wms_layers_osm : str = "heigit:osm-carto-proxy"
40+ wms_read_timeout : int = 600
41+ wms_url_esri_world_imagery : str = "https://maps.heigit.org/raster/sketch-map-tool/service?SERVICE=WMS&VERSION=1.1.1"
42+ wms_url_esri_world_imagery_fallback : str = "https://maps.heigit.org/raster/sketch-map-tool/service?SERVICE=WMS&VERSION=1.1.1"
43+ wms_url_osm : str = (
44+ "https://maps.heigit.org/raster/osm-carto/service?SERVICE=WMS&VERSION=1.1.1"
45+ )
46+ yolo_cls : str = "SMT-CLS"
47+ yolo_esri_obj : str = "SMT-ESRI"
48+ yolo_osm_obj : str = "SMT-OSM"
5149
50+ model_config = SettingsConfigDict (
51+ env_prefix = "SMT_" ,
52+ toml_file = get_config_path (),
53+ )
5254
53- def get_config () -> MappingProxyType :
54- """Get configuration variables from environment and file.
55+ @classmethod
56+ def settings_customise_sources (
57+ cls ,
58+ settings_cls : type [BaseSettings ],
59+ init_settings : PydanticBaseSettingsSource ,
60+ env_settings : PydanticBaseSettingsSource ,
61+ dotenv_settings : PydanticBaseSettingsSource ,
62+ file_secret_settings : PydanticBaseSettingsSource ,
63+ ) -> tuple [PydanticBaseSettingsSource , ...]:
64+ # https://docs.pydantic.dev/latest/concepts/pydantic_settings/#other-settings-source
65+ # env takes precedence over file settings
66+ return (env_settings , TomlConfigSettingsSource (settings_cls ))
5567
56- Configuration values from file will be given precedence over default values.
57- """
58- cfg = DEFAULT_CONFIG
59- cfg_file = load_config_from_file ( get_config_path ())
60- cfg . update ( cfg_file )
61- return MappingProxyType ( cfg )
68+ @ field_validator ( "esri_api_key" , mode = "before" )
69+ @ classmethod
70+ def check_esri_api_key ( cls , value : str ) -> str :
71+ if not value :
72+ logging . warning ( "No ESRI API Key found." )
73+ return value
6274
6375
64- def get_config_value (key : str ):
65- config = get_config ()
66- return config [key ]
76+ CONFIG = Config ()
0 commit comments