Skip to content

Commit 68fae1c

Browse files
refactor (embedding utils): refactor add_embeddings_to_source (#36)
* refactor (embedding utils): refactor add_embeddings_to_source * refactor (embedding utils): use list comprehension * refactor (transform): simplify structure of named tuple * refactor (embedding utils): use named tuple instead of tuple * refactor (embedding utils): use list comprehension instead of map * refactor (embedding utils): use list comprehension instead of map * test (embedding utils): add test for preprocess_batch * refactor (embedding utils): simplify code * refactor (embedding utils): split list comprehensions into multiple lines
1 parent a31e0e6 commit 68fae1c

File tree

4 files changed

+174
-77
lines changed

4 files changed

+174
-77
lines changed

src/tasks.py

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from config.opensearch_config import OpenSearchConfig
1414
from utils.queue_utils import HarvestEventQueue
1515
from utils.embedding_utils import preprocess_batch, add_embeddings_to_source, SourceWithEmbeddingText, \
16-
get_embedding_text_from_fields
16+
get_embedding_text_from_fields, OpenSearchSourceWithEmbedding
1717
from utils import normalize_datacite_json
1818
from typing import Any
1919
from celery.utils.log import get_task_logger
@@ -164,7 +164,6 @@ def transform_batch(self: Any, batch: list[HarvestEventQueue], index_name: str)
164164
validate(instance=normalized_record, schema=self.schema)
165165
normalized.append(SourceWithEmbeddingText(src=normalized_record,
166166
textToEmbed=get_embedding_text_from_fields(normalized_record),
167-
file=Path(''),
168167
event=harvest_event
169168
))
170169

@@ -182,10 +181,10 @@ def transform_batch(self: Any, batch: list[HarvestEventQueue], index_name: str)
182181

183182
try:
184183
logger.info(f'About to Calculate embeddings for {len(normalized)}')
185-
src_with_emb: list[tuple[dict[str, Any], SourceWithEmbeddingText]] = add_embeddings_to_source(normalized,
184+
src_with_emb: list[OpenSearchSourceWithEmbedding] = add_embeddings_to_source(normalized,
186185
self.embedding_transformer)
187186
logger.info(f'Calculated embeddings for {len(src_with_emb)}')
188-
preprocessed = preprocess_batch(list(map(lambda el: el[0], src_with_emb)), index_name)
187+
preprocessed = preprocess_batch([src_with_emb_ele.src for src_with_emb_ele in src_with_emb], index_name)
189188
except Exception as e:
190189
logger.error(f'Could not calculate embeddings: {e}')
191190
raise e
@@ -201,25 +200,20 @@ def transform_batch(self: Any, batch: list[HarvestEventQueue], index_name: str)
201200
for rec in src_with_emb:
202201
# write to records table
203202

204-
if rec[1].event is None:
205-
raise ValueError(f'Original HarvestEvent not found')
206-
207-
#logger.info(rec[1].event.record_identifier)
208-
209-
record_identifier = rec[1].event.record_identifier
210-
datestamp = rec[1].event.datestamp
211-
repository_id = rec[1].event.repository_id
212-
endpoint_id = rec[1].event.endpoint_id
203+
record_identifier = rec.harvest_event.record_identifier
204+
datestamp = rec.harvest_event.datestamp
205+
repository_id = rec.harvest_event.repository_id
206+
endpoint_id = rec.harvest_event.endpoint_id
213207
resource_type = 'Dataset' # TODO: get this information from record
214-
title = rec[0]['titles'][0]['title']
215-
xml = rec[1].event.xml
208+
title = rec.src['titles'][0]['title']
209+
xml = rec.harvest_event.xml
216210
protocol = 'OAI-PMH'
217-
doi = rec[0].get('doi')
218-
url = rec[0].get('url')
219-
embeddings = rec[0]['emb']
220-
datacite_json = json.dumps({**rec[0], 'emb': None})
211+
doi = rec.src.get('doi')
212+
url = rec.src.get('url')
213+
embeddings = rec.src['emb']
214+
datacite_json = json.dumps({**rec.src, 'emb': None})
221215
opensearch_synced = True
222-
additional_metadata = rec[1].event.additional_metadata
216+
additional_metadata = rec.harvest_event.additional_metadata
223217

224218
# https://neon.com/postgresql/postgresql-tutorial/postgresql-upsert
225219
cur.execute("""
@@ -282,7 +276,7 @@ def transform_batch(self: Any, batch: list[HarvestEventQueue], index_name: str)
282276
UPDATE harvest_events
283277
SET error_message = NULL
284278
WHERE id = %s
285-
""", [rec[1].event.id]
279+
""", [rec.harvest_event.id]
286280
)
287281

288282
except BulkIndexError as e:

src/transform.py

Lines changed: 53 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,13 @@ class AdditionalMetadataParams(BaseModel):
6262
endpoint: str
6363
protocol: str
6464

65+
6566
class HarvestParams(BaseModel):
6667
metadata_prefix: str
6768
set: Optional[list[str]]
6869
additional_metadata_params: Optional[AdditionalMetadataParams]
6970

71+
7072
class EndpointConfig(BaseModel):
7173
name: str
7274
harvest_url: str
@@ -82,8 +84,8 @@ class Config(BaseModel):
8284
class HarvestEventCreateRequest(BaseModel):
8385
record_identifier: str
8486
datestamp: datetime
85-
raw_metadata: str # XML
86-
additional_metadata: Optional[str] = None # XML or JSON (stringified)
87+
raw_metadata: str # XML
88+
additional_metadata: Optional[str] = None # XML or JSON (stringified)
8789
harvest_url: str
8890
repo_code: str
8991
harvest_run_id: str
@@ -97,6 +99,7 @@ class HarvestEventCreateResponse(BaseModel):
9799
class HarvestRunCreateRequest(BaseModel):
98100
harvest_url: str
99101

102+
100103
class HarvestRunGetResponse(BaseModel):
101104
id: Optional[str] = Field(None, description='ID of the harvest run')
102105
status: Optional[str] = Field(None, description='Status of the harvest run: open|closed|failed')
@@ -121,7 +124,8 @@ class HarvestRunCloseResponse(BaseModel):
121124

122125

123126
def get_latest_harvest_run_in_db(harvest_url: str) -> HarvestRunGetResponse:
124-
with psycopg.connect(dbname=postgres_config.user, user=postgres_config.user, host=postgres_config.address, password=postgres_config.password, port=postgres_config.port, row_factory=dict_row) as conn:
127+
with psycopg.connect(dbname=postgres_config.user, user=postgres_config.user, host=postgres_config.address,
128+
password=postgres_config.password, port=postgres_config.port, row_factory=dict_row) as conn:
125129

126130
cur = conn.cursor()
127131

@@ -141,15 +145,16 @@ def get_latest_harvest_run_in_db(harvest_url: str) -> HarvestRunGetResponse:
141145
else:
142146
return HarvestRunGetResponse(id=None, status=None)
143147

148+
144149
def create_harvest_run_in_db(harvest_url: str) -> HarvestRunCreateResponse:
145150
"""
146151
Creates a new entry in harvest_runs and returns its id.
147152
148153
:param harvest_url: The new entry to be created.
149154
"""
150155

151-
with psycopg.connect(dbname=postgres_config.user, user=postgres_config.user, host=postgres_config.address, password=postgres_config.password, port=postgres_config.port, row_factory=dict_row) as conn:
152-
156+
with psycopg.connect(dbname=postgres_config.user, user=postgres_config.user, host=postgres_config.address,
157+
password=postgres_config.password, port=postgres_config.port, row_factory=dict_row) as conn:
153158
cur = conn.cursor()
154159
# TODO: only allow one open harvest run per endpoint
155160
# TODO check (in one transaction):
@@ -192,15 +197,19 @@ def create_harvest_run_in_db(harvest_url: str) -> HarvestRunCreateResponse:
192197
id=str(new_harvest_run['id']),
193198
from_date=new_harvest_run['from_date'],
194199
until_date=new_harvest_run['until_date'],
195-
endpoint_config=EndpointConfig(name=new_harvest_run['name'], harvest_url=new_harvest_run['harvest_url'], code=new_harvest_run['code'], protocol=new_harvest_run['protocol'],
196-
harvest_params=HarvestParams(metadata_prefix=new_harvest_run['harvest_params'].get('metadata_prefix'), set=new_harvest_run['harvest_params'].get('set'), additional_metadata_params=new_harvest_run['harvest_params'].get('additional_metadata_params')))
200+
endpoint_config=EndpointConfig(name=new_harvest_run['name'], harvest_url=new_harvest_run['harvest_url'],
201+
code=new_harvest_run['code'], protocol=new_harvest_run['protocol'],
202+
harvest_params=HarvestParams(
203+
metadata_prefix=new_harvest_run['harvest_params'].get('metadata_prefix'),
204+
set=new_harvest_run['harvest_params'].get('set'),
205+
additional_metadata_params=new_harvest_run['harvest_params'].get(
206+
'additional_metadata_params')))
197207
)
198208

199209

200210
def close_harvest_run_in_db(harvest_run: HarvestRunCloseRequest) -> HarvestRunCloseResponse:
201211
with psycopg.connect(dbname=postgres_config.user, user=postgres_config.user, host=postgres_config.address,
202212
password=postgres_config.password, port=postgres_config.port, row_factory=dict_row) as conn:
203-
204213
cur = conn.cursor()
205214

206215
state = 'closed' if harvest_run.success else 'failed'
@@ -224,15 +233,16 @@ def close_harvest_run_in_db(harvest_run: HarvestRunCloseRequest) -> HarvestRunCl
224233

225234
return HarvestRunCloseResponse(id=harvest_run.id)
226235

236+
227237
def create_harvest_event_in_db(harvest_event: HarvestEventCreateRequest) -> HarvestEventCreateResponse:
228238
"""
229239
Creates a record in table harvest_events
230240
231241
:param harvest_event: The new record to be created.
232242
"""
233243

234-
with psycopg.connect(dbname=postgres_config.user, user=postgres_config.user, host=postgres_config.address, password=postgres_config.password, port=postgres_config.port, row_factory=dict_row) as conn:
235-
244+
with psycopg.connect(dbname=postgres_config.user, user=postgres_config.user, host=postgres_config.address,
245+
password=postgres_config.password, port=postgres_config.port, row_factory=dict_row) as conn:
236246
cur = conn.cursor()
237247

238248
cur.execute("""
@@ -260,8 +270,9 @@ def create_harvest_event_in_db(harvest_event: HarvestEventCreateRequest) -> Harv
260270
(SELECT id FROM harvest_runs WHERE id = %s and status = 'open'),
261271
%s
262272
);
263-
""", (harvest_event.record_identifier, harvest_event.datestamp, harvest_event.raw_metadata, harvest_event.additional_metadata, harvest_event.repo_code, harvest_event.harvest_url, 'OAI-PMH', 'XML', harvest_event.harvest_run_id, harvest_event.is_deleted))
264-
273+
""", (harvest_event.record_identifier, harvest_event.datestamp, harvest_event.raw_metadata,
274+
harvest_event.additional_metadata, harvest_event.repo_code, harvest_event.harvest_url,
275+
'OAI-PMH', 'XML', harvest_event.harvest_run_id, harvest_event.is_deleted))
265276

266277
cur.execute("""
267278
SELECT id
@@ -301,10 +312,14 @@ def get_config_from_db() -> list[EndpointConfig]:
301312
JOIN repositories r ON e.repository_id = r.id
302313
""")
303314
for doc in cur.fetchall():
304-
305315
endpoints.append(
306-
EndpointConfig(name=doc['name'], harvest_url=doc['harvest_url'], code=doc['code'], protocol=doc['protocol'],
307-
harvest_params=HarvestParams(metadata_prefix=doc['harvest_params'].get('metadata_prefix'), set=doc['harvest_params'].get('set'), additional_metadata_params=doc['harvest_params'].get('additional_metadata_params'))))
316+
EndpointConfig(name=doc['name'], harvest_url=doc['harvest_url'], code=doc['code'],
317+
protocol=doc['protocol'],
318+
harvest_params=HarvestParams(
319+
metadata_prefix=doc['harvest_params'].get('metadata_prefix'),
320+
set=doc['harvest_params'].get('set'),
321+
additional_metadata_params=doc['harvest_params'].get(
322+
'additional_metadata_params'))))
308323

309324
return endpoints
310325
except JSONDecodeError as e:
@@ -331,7 +346,8 @@ def create_jobs_in_queue(harvest_run_id: str) -> int:
331346

332347
logger.info(f'Preparing jobs')
333348

334-
with psycopg.connect(dbname=postgres_config.user, user=postgres_config.user, host=postgres_config.address, password=postgres_config.password, port=postgres_config.port, row_factory=dict_row) as conn:
349+
with psycopg.connect(dbname=postgres_config.user, user=postgres_config.user, host=postgres_config.address,
350+
password=postgres_config.password, port=postgres_config.port, row_factory=dict_row) as conn:
335351

336352
cur = conn.cursor()
337353

@@ -363,13 +379,15 @@ def create_jobs_in_queue(harvest_run_id: str) -> int:
363379
""", (harvest_run_id, limit, offset))
364380

365381
for doc in cur.fetchall():
366-
367382
# https://www.psycopg.org/psycopg3/docs/basic/adapt.html#uuid-adaptation
368383
# https://docs.python.org/3/library/uuid.html#uuid.UUID
369384
# str(uuid) returns a string in the form 12345678-1234-5678-1234-567812345678 where the 32 hexadecimal digits represent the UUID.
370385
batch.append(
371386
HarvestEventQueue(id=str(doc['id']), xml=doc['record'], repository_id=str(doc['repository_id']),
372-
endpoint_id=str(doc['endpoint_id']), record_identifier=doc['record_identifier'], code=doc['code'], harvest_url=doc['harvest_url'], additional_metadata=doc['additional_metadata'], is_deleted=doc['is_deleted'], datestamp=doc['datestamp'].strftime('%Y-%m-%d %H:%M:%S.%f%z'))
387+
endpoint_id=str(doc['endpoint_id']), record_identifier=doc['record_identifier'],
388+
code=doc['code'], harvest_url=doc['harvest_url'],
389+
additional_metadata=doc['additional_metadata'], is_deleted=doc['is_deleted'],
390+
datestamp=doc['datestamp'].strftime('%Y-%m-%d %H:%M:%S.%f%z'))
373391
)
374392

375393
if len(batch) == 0:
@@ -385,13 +403,15 @@ def create_jobs_in_queue(harvest_run_id: str) -> int:
385403
offset += limit
386404
# will be false if query returned fewer results than limit
387405
fetch = len(batch) == limit
388-
#fetch = False
406+
# fetch = False
389407
batch = []
390408

391409
return tasks
392410

411+
393412
@app.get('/index', tags=['index'])
394-
def init_index(harvest_run_id: str = Query(default=None, description='Id of the harvest run to be indexed')) -> IndexGetResponse:
413+
def init_index(
414+
harvest_run_id: str = Query(default=None, description='Id of the harvest run to be indexed')) -> IndexGetResponse:
395415
# this long-running method is synchronous and runs in an external threadpool, see https://fastapi.tiangolo.com/async/#path-operation-functions
396416
# this way, it does not block the server
397417
try:
@@ -422,24 +442,31 @@ def get_config() -> Config:
422442
@app.post('/harvest_event', tags=['harvest_event'], summary='Register a new harvest event')
423443
def create_harvest_event(harvest_event: HarvestEventCreateRequest) -> HarvestEventCreateResponse:
424444
try:
425-
#logger.debug(harvest_event)
445+
# logger.debug(harvest_event)
426446
return create_harvest_event_in_db(harvest_event)
427447
except psycopg_errors.UniqueViolation as e:
428448
logger.exception(f'Harvest event could not be created for given harvest run')
429-
raise HTTPException(status_code=400, detail='Harvest event could not be created for the given harvest run because the record identifier already exists.')
449+
raise HTTPException(status_code=400,
450+
detail='Harvest event could not be created for the given harvest run because the record identifier already exists.')
430451
except Exception as e:
431452
logger.exception(f'An error occurred when creating harvest event: {e}')
432453
raise HTTPException(status_code=500, detail=str(e))
433454

434-
@app.get('/harvest_run', tags=['harvest_run'], summary='Get id and status of the latest harvest run for a given endpoint.', description='If no harvest run exists for the given endpoint, id and status will be null in the response.')
435-
def get_harvest_run(harvest_url: str = Query(default=None, description='harvest url of the endpoint')) -> HarvestRunGetResponse:
455+
456+
@app.get('/harvest_run', tags=['harvest_run'],
457+
summary='Get id and status of the latest harvest run for a given endpoint.',
458+
description='If no harvest run exists for the given endpoint, id and status will be null in the response.')
459+
def get_harvest_run(
460+
harvest_url: str = Query(default=None, description='harvest url of the endpoint')) -> HarvestRunGetResponse:
436461
try:
437462
return get_latest_harvest_run_in_db(harvest_url)
438463
except Exception as e:
439464
logger.exception(f'An error occurred when getting harvest run: {e}')
440465
raise HTTPException(status_code=500, detail=str(e))
441466

442-
@app.post('/harvest_run', tags=['harvest_run'], summary='Create a new havest run for a given endpoint.', description='A new harvest run can only be created if no other open harvest run exists for the same endpoint.')
467+
468+
@app.post('/harvest_run', tags=['harvest_run'], summary='Create a new havest run for a given endpoint.',
469+
description='A new harvest run can only be created if no other open harvest run exists for the same endpoint.')
443470
def create_harvest_run(harvest_run: HarvestRunCreateRequest) -> HarvestRunCreateResponse:
444471
try:
445472
logger.debug(harvest_run)
@@ -451,11 +478,11 @@ def create_harvest_run(harvest_run: HarvestRunCreateRequest) -> HarvestRunCreate
451478
logger.exception(f'An error occurred when creating harvest event: {e}')
452479
raise HTTPException(status_code=500, detail=str(e))
453480

481+
454482
@app.put('/harvest_run', tags=['harvest_run'], summary='Close an open harvest run for a given endpoint.')
455483
def close_harvest_run(harvest_run: HarvestRunCloseRequest) -> HarvestRunCloseResponse:
456484
try:
457485
return close_harvest_run_in_db(harvest_run)
458486
except Exception as e:
459487
logger.exception(f'An error occurred when closing harvest event: {e}')
460488
raise HTTPException(status_code=500, detail=str(e))
461-

0 commit comments

Comments
 (0)