-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathrunner.py
More file actions
192 lines (150 loc) · 6.09 KB
/
runner.py
File metadata and controls
192 lines (150 loc) · 6.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import asyncio
import logging
import re
from pathlib import Path
import logfire
from psycopg import AsyncConnection, AsyncCursor
from semver import Version
from tiger_agent import __version__
from tiger_agent.log_config import setup_logging
logger = logging.getLogger(__name__)

# Postgres advisory-lock key; every process that runs migrate_db() contends on
# this same key so only one of them migrates at a time.
SHARED_LOCK_KEY: int = 31_321_898_691_465_844
# Number of times migrate_db() polls for the advisory lock before giving up.
MAX_LOCK_ATTEMPTS: int = 10
# Pause, in seconds, between consecutive lock polls.
LOCK_SLEEP_SECONDS: int = 10
@logfire.instrument("try_migration_lock", extract_args=False)
async def try_migration_lock(cur: AsyncCursor) -> bool:
    """Try, without blocking, to take the transaction-level migration lock.

    Calls Postgres ``pg_try_advisory_xact_lock`` with ``SHARED_LOCK_KEY``. The
    lock is scoped to the current transaction: Postgres releases it
    automatically when that transaction commits or rolls back, so callers must
    run this inside the transaction that performs the migration.

    Args:
        cur: Cursor on the connection whose open transaction should hold the lock.

    Returns:
        True if the lock was acquired, False if another session holds it.

    Raises:
        RuntimeError: If the lock query unexpectedly returns no row.
    """
    await cur.execute(
        "select pg_try_advisory_xact_lock(%s::bigint)", (SHARED_LOCK_KEY,)
    )
    row = await cur.fetchone()
    if not row:
        # RuntimeError (a subclass of Exception, so existing handlers still
        # match) is consistent with the lock-failure error raised by migrate_db.
        raise RuntimeError(
            "attempting to get an advisory lock for migration failed to return a row"
        )
    return bool(row[0])
@logfire.instrument("run_init", extract_args=False)
async def run_init(cur: AsyncCursor) -> None:
    """Execute ``sql/init.sql`` (next to this module) to set up migration tables.

    NOTE(review): migrate_db calls this on every run, before the version check,
    so init.sql is presumably written to be safe to re-run — confirm.

    Args:
        cur: Cursor used to execute the script.
    """
    # Decode explicitly as UTF-8: read_text() without an encoding uses the
    # platform locale, which mis-decodes non-ASCII SQL on some hosts.
    sql = (
        Path(__file__)
        .parent.joinpath("sql", "init.sql")
        .read_text(encoding="utf-8")
    )
    await cur.execute(sql)
async def get_db_version(cur: AsyncCursor) -> Version:
    """Read the schema version currently recorded in ``agent.version``.

    Args:
        cur: Cursor used to run the query.

    Returns:
        The stored value (stringified) parsed with ``Version.parse``.

    Raises:
        RuntimeError: If ``agent.version`` returns no row.
        Any error raised by ``Version.parse`` if the stored value is not
        valid semver.
    """
    await cur.execute("select version from agent.version")
    row = await cur.fetchone()
    # An explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would turn a missing row into an opaque TypeError.
    if row is None:
        raise RuntimeError(
            "agent.version returned no row; cannot determine database version"
        )
    return Version.parse(str(row[0]))
@logfire.instrument("is_migration_required", extract_args=["target_version"])
async def is_migration_required(cur: AsyncCursor, target_version: Version) -> bool:
    """Decide whether the database must be migrated up to ``target_version``.

    Args:
        cur: Cursor used to read the stored database version.
        target_version: Version the application expects.

    Returns:
        True when the database is older than ``target_version``; False when
        they are equal.

    Raises:
        ValueError: If the database is newer than ``target_version``
            (downgrades are refused).
    """
    current = await get_db_version(cur)
    if current > target_version:
        # Refuse to run an older app against a newer schema.
        logger.error(
            f"target version ({target_version}) is older than the database ({current})! aborting"
        )
        raise ValueError(
            f"Cannot downgrade from version {current} to {target_version}"
        )
    return current < target_version
def sql_file_number(path: Path) -> int:
    """Return the three-digit numeric prefix of a migration script's filename.

    The name (not the directory) must look like ``NNN-lowercase-words.sql``:
    exactly three digits, a hyphen, then lowercase letters/hyphens that start
    with a letter.

    Args:
        path: Path to the SQL script; only ``path.name`` is inspected.

    Returns:
        The prefix as an int, e.g. 7 for ``007-add-index.sql``.

    Raises:
        ValueError: If the filename does not match the required pattern.
    """
    pattern = r"^(\d{3})-[a-z][a-z-]*\.sql$"
    found = re.match(pattern, path.name)
    if found is None:
        logger.error(f"{path} file name does not match the pattern {pattern}")
        raise ValueError(f"Invalid filename pattern: {path.name}")
    return int(found[1])
def check_sql_file_order(paths: list[Path]) -> None:
    """Ensure migration scripts are numbered 000, 001, 002, ... with no gaps.

    ``paths`` is expected to be sorted already (both callers sort it).
    Validation stops at the first script numbered 999, so that script and
    anything after it are exempt from the sequence check.

    Args:
        paths: Script paths in execution order.

    Raises:
        ValueError: If a filename is malformed (from ``sql_file_number``) or a
            number is duplicated or skipped.
    """
    expected = 0
    for path in paths:
        number = sql_file_number(path)
        if number == 999:
            break
        if number != expected:
            logger.error(f"sql files must be strictly ordered: {path.name}")
            raise ValueError(f"SQL files not in sequential order at {path.name}")
        expected += 1
async def run_incremental(cur: AsyncCursor, target_version: Version) -> None:
    """Apply every numbered script in ``incremental/`` via the migration template.

    Each script's text is substituted into ``sql/migration.sql`` with
    ``str.format`` (fields ``migration_name``, ``migration_body``, ``version``)
    and the resulting SQL is executed, in filename order, inside a logfire span.

    Args:
        cur: Cursor used to execute the generated SQL.
        target_version: Version string passed to the template as ``version``.

    Raises:
        ValueError: If script filenames are malformed or not sequentially
            numbered (raised by ``check_sql_file_order``).
    """
    base = Path(__file__).parent
    # Explicit UTF-8: read_text() without an encoding uses the platform locale.
    migration_template = base.joinpath("sql", "migration.sql").read_text(
        encoding="utf-8"
    )
    paths = sorted(base.joinpath("incremental").glob("*.sql"))
    check_sql_file_order(paths)
    for path in paths:
        with logfire.span("incremental_sql", script=path.name):
            # Only the template is parsed by str.format: literal braces in
            # migration.sql must be doubled ({{ }}), while script bodies are
            # inserted verbatim and may contain braces freely.
            sql = migration_template.format(
                migration_name=path.name,
                migration_body=path.read_text(encoding="utf-8"),
                version=str(target_version),
            )
            await cur.execute(sql)
async def run_idempotent(cur: AsyncCursor) -> None:
    """Execute every numbered script in ``idempotent/`` as-is, in filename order.

    Unlike run_incremental, these scripts are not wrapped in a template; they
    are re-run on every migration, so each must be safe to execute repeatedly.

    Args:
        cur: Cursor used to execute the scripts.

    Raises:
        ValueError: If script filenames are malformed or not sequentially
            numbered (raised by ``check_sql_file_order``).
    """
    paths = sorted(Path(__file__).parent.joinpath("idempotent").glob("*.sql"))
    check_sql_file_order(paths)
    for path in paths:
        with logfire.span("idempotent_sql", script=path.name):
            # Explicit UTF-8: read_text() without an encoding uses the locale.
            sql = path.read_text(encoding="utf-8")
            await cur.execute(sql)
@logfire.instrument("set_version", extract_args=["version"])
async def set_version(cur: AsyncCursor, version: Version) -> None:
    """Record ``version`` as the database's schema version.

    The UPDATE has no WHERE clause, so every row of ``agent.version`` is set,
    and ``at`` is stamped with Postgres ``clock_timestamp()``.

    Args:
        cur: Cursor used to run the update.
        version: Version to store (written as its string form).
    """
    statement = "update agent.version set version = %s, at = clock_timestamp()"
    params = (str(version),)
    await cur.execute(statement, params)
@logfire.instrument("migrate_db", extract_args=False)
async def migrate_db(con: AsyncConnection) -> None:
    """Bring the database schema up to the application's ``__version__``.

    Inside one transaction opened with ``con.transaction()``:
    1. poll the advisory lock up to ``MAX_LOCK_ATTEMPTS`` times, sleeping
       ``LOCK_SLEEP_SECONDS`` between attempts;
    2. run ``init.sql``;
    3. if the stored version is older than the app's, apply the incremental
       scripts, then the idempotent scripts, then record the new version.

    Args:
        con: Open async connection; the migration runs in a transaction on it.

    Raises:
        RuntimeError: If the migration lock cannot be acquired.
        ValueError: If the database is newer than the app, or script files are
            misnamed or out of order.
    """
    target_version = Version.parse(__version__)

    async with con.cursor() as cur, con.transaction():
        # Poll the non-blocking lock; another process may be migrating now.
        for attempt in range(1, MAX_LOCK_ATTEMPTS + 1):
            if await try_migration_lock(cur):
                break
            if attempt == MAX_LOCK_ATTEMPTS:
                logger.error(
                    f"failed to get an advisory lock to check database version after {attempt} attempts"
                )
                raise RuntimeError("Could not acquire migration lock")
            logger.info(
                f"sleeping {LOCK_SLEEP_SECONDS} seconds before another lock attempt"
            )
            await asyncio.sleep(LOCK_SLEEP_SECONDS)

        # Runs unconditionally, before the version table is read.
        await run_init(cur)

        if await is_migration_required(cur, target_version):
            logger.info(f"database migration to version {target_version} required...")
            await run_incremental(cur, target_version)
            await run_idempotent(cur)
            await set_version(cur, target_version)
            logger.info(f"database migration to version {target_version} complete")
        else:
            logger.info("no migration required. app and db are compatible.")
async def main():
    """Command-line entry point: load configuration, then migrate the database.

    Loads a ``.env`` file found via ``find_dotenv(usecwd=True)`` (i.e. searched
    from the current working directory, not this module's folder), configures
    logging with ``SERVICE_NAME``, enables logfire psycopg instrumentation,
    then opens a connection and runs ``migrate_db``.

    NOTE(review): ``AsyncConnection.connect()`` is called with no arguments,
    so connection settings presumably come from libpq ``PG*`` environment
    variables (possibly set by the .env file) — confirm.
    """
    import os

    from dotenv import find_dotenv, load_dotenv

    env_file = find_dotenv(usecwd=True)
    load_dotenv(dotenv_path=env_file)
    setup_logging(os.environ.get("SERVICE_NAME"))
    logfire.instrument_psycopg()

    logger.info("Starting database migration...")
    connection = await AsyncConnection.connect()
    async with connection as con:
        await migrate_db(con)
    logger.info("Database migration completed successfully")


if __name__ == "__main__":
    # Drive the async CLI to completion on a fresh event loop.
    asyncio.run(main())