Skip to content

Commit 74c5f32

Browse files
typing: adding types to extensions outside of breakeven
1 parent 4b49909 commit 74c5f32

23 files changed

+622
-385
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ python_version = "3.12"
141141
mypy_path = "stubs"
142142

143143
# Exclude specific directories from type checking will try to add them back gradually
144-
exclude = "(?x)(^temoa/extensions/|^temoa/utilities/|^stubs/)"
144+
exclude = "(?x)(^temoa/utilities/|^stubs/|^temoa/extensions/breakeven/)"
145145

146146
# Strict typing for our own code
147147
disallow_untyped_defs = true

temoa/extensions/get_comm_tech.py

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,24 @@
1+
from __future__ import annotations
2+
13
import getopt
24
import os
35
import re
46
import sqlite3
57
import sys
68
from collections import OrderedDict
9+
from typing import Any
710

811

9-
def get_tperiods(inp_f):
12+
def get_tperiods(inp_f: str) -> dict[str, list[int]]:
1013
file_ty = re.search(r'(\w+)\.(\w+)\b', inp_f) # Extract the input filename and extension
1114

1215
if not file_ty:
13-
raise 'The file type %s is not recognized.' % inp_f
16+
raise Exception(f'The file type {inp_f} is not recognized.')
1417

1518
elif file_ty.group(2) not in ('db', 'sqlite', 'sqlite3', 'sqlitedb'):
1619
raise Exception('Please specify a database for finding scenarios')
1720

18-
periods_list = {}
19-
periods_set = set()
21+
periods_list: dict[str, list[int]] = {}
2022

2123
con = sqlite3.connect(inp_f)
2224
cur = con.cursor() # a database cursor is a control structure that enables traversal over
@@ -30,7 +32,7 @@ def get_tperiods(inp_f):
3032
x.append(row[0])
3133
for y in x:
3234
cur.execute(
33-
"SELECT DISTINCT period FROM output_flow_out WHERE scenario is '" + str(y) + "'"
35+
f"SELECT DISTINCT period FROM output_flow_out WHERE scenario is '{y}'"
3436
)
3537
periods_list[y] = []
3638
for per in cur:
@@ -42,17 +44,16 @@ def get_tperiods(inp_f):
4244
return dict(OrderedDict(sorted(periods_list.items(), key=lambda x: x[1])))
4345

4446

45-
def get_scenario(inp_f):
47+
def get_scenario(inp_f: str) -> dict[str, str]:
4648
file_ty = re.search(r'(\w+)\.(\w+)\b', inp_f) # Extract the input filename and extension
4749

4850
if not file_ty:
49-
raise 'The file type %s is not recognized.' % inp_f
51+
raise Exception(f'The file type {inp_f} is not recognized.')
5052

5153
elif file_ty.group(2) not in ('db', 'sqlite', 'sqlite3', 'sqlitedb'):
5254
raise Exception('Please specify a database for finding scenarios')
5355

54-
scene_list = {}
55-
scene_set = set()
56+
scene_list: dict[str, str] = {}
5657

5758
con = sqlite3.connect(inp_f)
5859
cur = con.cursor() # a database cursor is a control structure that enables traversal over
@@ -70,9 +71,9 @@ def get_scenario(inp_f):
7071
return dict(OrderedDict(sorted(scene_list.items(), key=lambda x: x[1])))
7172

7273

73-
def get_comm(inp_f, db_dat):
74-
comm_list = {}
75-
comm_set = set()
74+
def get_comm(inp_f: str, db_dat: bool) -> OrderedDict[str, str]:
75+
comm_list: dict[str, str] = {}
76+
comm_set: set[str] = set()
7677
is_query_empty = False
7778

7879
if not db_dat:
@@ -138,9 +139,9 @@ def get_comm(inp_f, db_dat):
138139
return OrderedDict(sorted(comm_list.items(), key=lambda x: x[1]))
139140

140141

141-
def get_tech(inp_f, db_dat):
142-
tech_list = {}
143-
tech_set = set()
142+
def get_tech(inp_f: str, db_dat: bool) -> OrderedDict[str, str]:
143+
tech_list: dict[str, str] = {}
144+
tech_set: set[str] = set()
144145
is_query_empty = False
145146

146147
if not db_dat:
@@ -199,13 +200,13 @@ def get_tech(inp_f, db_dat):
199200
return OrderedDict(sorted(tech_list.items(), key=lambda x: x[1]))
200201

201202

202-
def is_db_overwritten(db_file, inp_dat_file):
203+
def is_db_overwritten(db_file: str, inp_dat_file: str) -> bool:
203204
if os.path.basename(db_file) == '0':
204205
return False
205206

206207
try:
207208
con = sqlite3.connect(db_file)
208-
except:
209+
except Exception:
209210
return False
210211
cur = con.cursor() # A database cursor enables traversal over DB records
211212
con.text_factory = str # This ensures data is explored with UTF-8 encoding
@@ -214,15 +215,15 @@ def is_db_overwritten(db_file, inp_dat_file):
214215
# IF output file is empty database.
215216
cur.execute('SELECT * FROM Technology')
216217
is_db_empty = False # False for empty db file
217-
for elem in cur:
218+
for _ in cur:
218219
is_db_empty = True # True for non-empty db file
219220
break
220221
# This file could be schema with populated results from previous run. Or it could be a normal
221222
# db file.
222223
if is_db_empty:
223224
cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='input_file';")
224225
does_input_file_table_exist = False
225-
for i in cur: # This means that the 'input_file' table exists in db.
226+
for _ in cur: # This means that the 'input_file' table exists in db.
226227
does_input_file_table_exist = True
227228
if does_input_file_table_exist: # This block distinguishes normal database from schema.
228229
# This is schema file.
@@ -247,7 +248,7 @@ def is_db_overwritten(db_file, inp_dat_file):
247248
return False
248249

249250

250-
def help_user():
251+
def help_user() -> None:
251252
print(
252253
"""Use as:
253254
python get_comm_tech.py -i (or --input) <input filename>
@@ -259,8 +260,8 @@ def help_user():
259260
)
260261

261262

262-
def get_info(inputs):
263-
inp_file = None
263+
def get_info(inputs: dict[str, str]) -> Any:
264+
inp_file: str | None = None
264265
tech_flag = False
265266
comm_flag = False
266267
scene = False
@@ -317,8 +318,8 @@ def get_info(inputs):
317318

318319
else:
319320
print(
320-
'The input file type %s is not recognized. Please specify a database or a text file.'
321-
% inp_file
321+
f'The input file type {inp_file} is not recognized. Please specify a database '
322+
'or a text file.'
322323
)
323324
sys.exit(2)
324325

temoa/extensions/method_of_morris/morris.py

Lines changed: 30 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,28 @@
1-
# from __future__ import division
2-
import time
1+
from __future__ import annotations
2+
3+
import csv
4+
import sqlite3
35
from importlib import resources
46
from pathlib import Path
7+
from typing import Any
58

9+
from joblib import Parallel, delayed # type: ignore[import-untyped]
10+
from numpy import array
611
from pyomo.dataportal import DataPortal
12+
from SALib.analyze import morris # type: ignore[import-untyped]
13+
from SALib.sample.morris import sample # type: ignore[import-untyped]
14+
from SALib.util import compute_groups_matrix, read_param_file # type: ignore[import-untyped]
715

816
from temoa._internal import run_actions
917
from temoa._internal.table_writer import TableWriter
1018
from temoa.core.config import TemoaConfig
1119
from temoa.data_io.hybrid_loader import HybridLoader
1220

13-
start_time = time.time()
14-
import sqlite3
15-
16-
from joblib import Parallel, delayed
17-
from numpy import array
18-
from SALib.analyze import morris
19-
from SALib.sample.morris import sample
20-
from SALib.util import compute_groups_matrix, read_param_file
21-
2221
seed = 42
2322

2423

25-
def evaluate(param_names, param_values, data: dict, k):
24+
def evaluate(param_names: dict[int, list[Any]], param_values: Any,
25+
data: dict[str, Any], k: int) -> list[Any]:
2626
m = len(param_values)
2727
for j in range(0, m):
2828
names = param_names[j]
@@ -51,16 +51,16 @@ def evaluate(param_names, param_values, data: dict, k):
5151
cur.execute('SELECT * FROM output_objective')
5252
output_query = cur.fetchall()
5353
for row in output_query:
54-
Y_OF = row[-1]
54+
y_of = row[-1]
5555
cur.execute("SELECT emis_comm, SUM(emission) FROM output_emissionn WHERE emis_comm='co2'")
5656
output_query = cur.fetchall()
5757
for row in output_query:
58-
Y_CumulativeCO2 = row[-1]
59-
Morris_Objectives = []
60-
Morris_Objectives.append(Y_OF)
61-
Morris_Objectives.append(Y_CumulativeCO2)
58+
y_cumulative_co2 = row[-1]
59+
morris_objectives = []
60+
morris_objectives.append(y_of)
61+
morris_objectives.append(y_cumulative_co2)
6262
con.close()
63-
return Morris_Objectives
63+
return morris_objectives
6464

6565

6666
morris_root = Path(__file__).parent
@@ -137,7 +137,7 @@ def evaluate(param_names, param_values, data: dict, k):
137137
file.write('\n')
138138

139139
# load a data portal, retrieve the data dict for the problem
140-
config = TemoaConfig.build_config(config_file=config_path, output_path='.')
140+
config = TemoaConfig.build_config(config_file=config_path, output_path=Path('.'))
141141
loader = HybridLoader(db_connection=con, config=config)
142142
loader.load_data_portal()
143143
data = loader.data
@@ -157,7 +157,7 @@ def evaluate(param_names, param_values, data: dict, k):
157157
)
158158
Morris_Objectives = array(Morris_Objectives)
159159
print(Morris_Objectives)
160-
Si_OF = morris.analyze(
160+
si_of = morris.analyze(
161161
problem,
162162
param_values,
163163
Morris_Objectives[:, 0],
@@ -168,7 +168,7 @@ def evaluate(param_names, param_values, data: dict, k):
168168
seed=seed + 1,
169169
)
170170

171-
Si_CumulativeCO2 = morris.analyze(
171+
si_cumulative_co2 = morris.analyze(
172172
problem,
173173
param_values,
174174
Morris_Objectives[:, 1],
@@ -189,19 +189,18 @@ def evaluate(param_names, param_values, data: dict, k):
189189
for j in list(range(number_of_groups)):
190190
print(
191191
'{:30} {:10.3f} {:10.3f} {:15.3f} {:10.3f}'.format(
192-
Si_OF['names'][j],
193-
Si_OF['mu_star'][j],
194-
Si_OF['mu'][j],
195-
Si_OF['mu_star_conf'][j],
196-
Si_OF['sigma'][j],
192+
si_of['names'][j],
193+
si_of['mu_star'][j],
194+
si_of['mu'][j],
195+
si_of['mu_star_conf'][j],
196+
si_of['sigma'][j],
197197
)
198198
)
199-
import csv
200199

201-
line1 = Si_OF['mu_star']
202-
line2 = Si_OF['mu_star_conf']
203-
line3 = Si_CumulativeCO2['mu_star']
204-
line4 = Si_CumulativeCO2['mu_star_conf']
200+
line1 = si_of['mu_star']
201+
line2 = si_of['mu_star_conf']
202+
line3 = si_cumulative_co2['mu_star']
203+
line4 = si_cumulative_co2['mu_star_conf']
205204
with open('MMResults.csv', 'w') as f:
206205
writer = csv.writer(f, delimiter=',')
207206
writer.writerow(unique_group_names)
@@ -211,6 +210,3 @@ def evaluate(param_names, param_values, data: dict, k):
211210
writer.writerow('Cumulative CO2 Emissions')
212211
writer.writerow(line3)
213212
writer.writerow(line4)
214-
215-
f.close
216-
print('--- %s seconds ---' % (time.time() - start_time))

temoa/extensions/method_of_morris/morris_evaluate.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,25 @@
22
This module contains the core "evaluation" function for Method Of Morris. It needs to be isolated
33
(outside of class) to enable parallelization.
44
"""
5+
from __future__ import annotations
56

67
import logging
78
import sqlite3
89
import sys
910
from logging.handlers import QueueHandler
11+
from typing import TYPE_CHECKING, Any
1012

1113
from pyomo.dataportal import DataPortal
1214

1315
from temoa._internal import run_actions
1416
from temoa._internal.table_writer import TableWriter
15-
from temoa.core.config import TemoaConfig
1617

18+
if TYPE_CHECKING:
19+
from temoa.core.config import TemoaConfig
1720

18-
def configure_worker_logger(log_queue, log_level):
21+
22+
23+
def configure_worker_logger(log_queue: Any, log_level: int) -> logging.Logger:
1924
"""configure the logger"""
2025
worker_logger = logging.getLogger('MM evaluate')
2126
if not worker_logger.hasHandlers():
@@ -30,7 +35,8 @@ def configure_worker_logger(log_queue, log_level):
3035
return worker_logger
3136

3237

33-
def evaluate(param_info, mm_sample, data, i, config: TemoaConfig, log_queue, log_level):
38+
def evaluate(param_info: dict[int, list[Any]], mm_sample: Any, data: dict[str, Any],
39+
i: int, config: TemoaConfig, log_queue: Any, log_level: int) -> list[float]:
3440
"""
3541
Run model for params provided and return objective value and emission value
3642
Note: This function needs to be a static instance to enable the parallel
@@ -49,19 +55,14 @@ def evaluate(param_info, mm_sample, data, i, config: TemoaConfig, log_queue, log
4955
log_entry = ['']
5056
for j in range(0, len(mm_sample)):
5157
param_name, *set_idx, _ = param_info[j]
52-
set_idx = tuple(set_idx)
58+
set_idx_tuple = tuple(set_idx)
5359
# tweak the parameter
5460
if data.get(param_name) is None:
5561
raise ValueError(f'Unrecognized parameter: {param_name}')
56-
if data[param_name].get(set_idx) is None:
62+
if data[param_name].get(set_idx_tuple) is None:
5763
raise ValueError('index mismatch from data read-in')
58-
data[param_name][set_idx] = mm_sample[j]
59-
setting_entry = 'run # %d: Setting param %s[%s] to value: %f' % (
60-
i + 1,
61-
param_name,
62-
set_idx,
63-
mm_sample[j],
64-
)
64+
data[param_name][set_idx_tuple] = mm_sample[j]
65+
setting_entry = 'run # %d: Setting param %s[%s] to value: %f'
6566
log_entry.append(setting_entry)
6667
logger.debug('\n '.join(log_entry))
6768

@@ -87,23 +88,23 @@ def evaluate(param_info, mm_sample, data, i, config: TemoaConfig, log_queue, log
8788
'Multiple outputs found in Objective table matching scenario name. Coding error.'
8889
)
8990
else:
90-
Y_OF = output_query[0][0]
91+
y_of = output_query[0][0]
9192
cur.execute(
9293
"SELECT SUM(emission) FROM output_emission WHERE emis_comm='co2' AND scenario=?",
9394
(scenario_name,),
9495
)
9596
output_query = cur.fetchall()
9697
if len(output_query) == 0:
97-
Y_CumulativeCO2 = 0.0
98+
y_cumulative_co2 = 0.0
9899
elif len(output_query) > 1:
99100
raise RuntimeError(
100101
'Multiple outputs found in output_emissions table matching scenario name. Coding '
101102
'error.'
102103
)
103104
else:
104-
Y_CumulativeCO2 = output_query[0][0]
105-
morris_objectives = [float(Y_OF), float(Y_CumulativeCO2)]
106-
logger.info('Finished MM evaluation # %d with OBJ value: %0.2f ', i + 1, Y_OF)
105+
y_cumulative_co2 = output_query[0][0]
106+
morris_objectives = [float(y_of), float(y_cumulative_co2)]
107+
logger.info('Finished MM evaluation # %d with OBJ value: %0.2f ', i + 1, y_of)
107108
if not config.silent:
108109
sys.stdout.write(f'Completed MM run {i + 1}\n')
109110
sys.stdout.flush()

0 commit comments

Comments
 (0)