Skip to content

Commit f311c89

Browse files
committed
Add type hint for API classes
1 parent ca3f12d commit f311c89

16 files changed

+448
-115
lines changed

tdclient/bulk_import_api.py

Lines changed: 75 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,25 @@
11
#!/usr/bin/env python
22

3+
from __future__ import annotations
4+
35
import collections
46
import contextlib
57
import gzip
68
import io
79
import os
10+
from collections.abc import Iterator
11+
from typing import TYPE_CHECKING, Any
812

913
import msgpack
1014

11-
from .util import create_url
15+
if TYPE_CHECKING:
16+
from contextlib import AbstractContextManager
17+
from typing import IO
18+
19+
import urllib3
20+
21+
from tdclient.types import BulkImportParams, BytesOrStream, DataFormat, FileLike
22+
from tdclient.util import create_url
1223

1324

1425
class BulkImportAPI:
@@ -17,7 +28,27 @@ class BulkImportAPI:
1728
This class is inherited by :class:`tdclient.api.API`.
1829
"""
1930

20-
def create_bulk_import(self, name, db, table, params=None):
31+
# Methods from API class
32+
def get(
33+
self, url: str, params: dict[str, Any] | None = None
34+
) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ...
35+
def post(
36+
self, url: str, params: dict[str, Any] | None = None
37+
) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ...
38+
def put(
39+
self, url: str, stream: BytesOrStream, size: int
40+
) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ...
41+
def raise_error(
42+
self, msg: str, res: urllib3.BaseHTTPResponse, body: bytes | str
43+
) -> None: ...
44+
def checked_json(self, body: bytes, required: list[str]) -> dict[str, Any]: ...
45+
def _prepare_file(
46+
self, file: FileLike, format: str, **kwargs: Any
47+
) -> IO[bytes]: ...
48+
49+
def create_bulk_import(
50+
self, name: str, db: str, table: str, params: BulkImportParams | None = None
51+
) -> bool:
2152
"""Enable bulk importing of data to the targeted database and table and stores
2253
it in the default resource pool. Default expiration for bulk import is 30 days.
2354
@@ -30,22 +61,24 @@ def create_bulk_import(self, name, db, table, params=None):
3061
Returns:
3162
True if succeeded
3263
"""
33-
params = {} if params is None else params
64+
post_params = {} if params is None else dict(params)
3465
with self.post(
3566
create_url(
3667
"/v3/bulk_import/create/{name}/{db}/{table}",
3768
name=name,
3869
db=db,
3970
table=table,
4071
),
41-
params,
72+
post_params,
4273
) as res:
4374
code, body = res.status, res.read()
4475
if code != 200:
4576
self.raise_error("Create bulk import failed", res, body)
4677
return True
4778

48-
def delete_bulk_import(self, name, params=None):
79+
def delete_bulk_import(
80+
self, name: str, params: dict[str, Any] | None = None
81+
) -> bool:
4982
"""Delete the imported information with the specified name
5083
5184
Args:
@@ -63,7 +96,7 @@ def delete_bulk_import(self, name, params=None):
6396
self.raise_error("Delete bulk import failed", res, body)
6497
return True
6598

66-
def show_bulk_import(self, name):
99+
def show_bulk_import(self, name: str) -> dict[str, Any]:
67100
"""Show the details of the bulk import with the specified name
68101
69102
Args:
@@ -78,7 +111,9 @@ def show_bulk_import(self, name):
78111
js = self.checked_json(body, ["status"])
79112
return js
80113

81-
def list_bulk_imports(self, params=None):
114+
def list_bulk_imports(
115+
self, params: dict[str, Any] | None = None
116+
) -> list[dict[str, Any]]:
82117
"""Return the list of available bulk imports
83118
Args:
84119
params (dict, optional): Extra parameters.
@@ -93,7 +128,9 @@ def list_bulk_imports(self, params=None):
93128
js = self.checked_json(body, ["bulk_imports"])
94129
return js["bulk_imports"]
95130

96-
def list_bulk_import_parts(self, name, params=None):
131+
def list_bulk_import_parts(
132+
self, name: str, params: dict[str, Any] | None = None
133+
) -> list[str]:
97134
"""Return the list of available parts uploaded through
98135
:func:`~BulkImportAPI.bulk_import_upload_part`.
99136
@@ -114,7 +151,7 @@ def list_bulk_import_parts(self, name, params=None):
114151
return js["parts"]
115152

116153
@staticmethod
117-
def validate_part_name(part_name):
154+
def validate_part_name(part_name: str) -> None:
118155
"""Make sure the part_name is valid
119156
120157
Args:
@@ -133,7 +170,9 @@ def validate_part_name(part_name):
133170
if 0 < part_name.find("/"):
134171
raise ValueError("part name must not contain '/': %s" % (repr(part_name)))
135172

136-
def bulk_import_upload_part(self, name, part_name, stream, size):
173+
def bulk_import_upload_part(
174+
self, name: str, part_name: str, stream: BytesOrStream, size: int
175+
) -> None:
137176
"""Upload bulk import having the specified name and part in the path.
138177
139178
Args:
@@ -156,7 +195,14 @@ def bulk_import_upload_part(self, name, part_name, stream, size):
156195
if code / 100 != 2:
157196
self.raise_error("Upload a part failed", res, body)
158197

159-
def bulk_import_upload_file(self, name, part_name, format, file, **kwargs):
198+
def bulk_import_upload_file(
199+
self,
200+
name: str,
201+
part_name: str,
202+
format: DataFormat,
203+
file: FileLike,
204+
**kwargs: Any,
205+
) -> None:
160206
"""Upload a file with bulk import having the specified name.
161207
162208
Args:
@@ -193,7 +239,9 @@ def bulk_import_upload_file(self, name, part_name, format, file, **kwargs):
193239
size = os.fstat(fp.fileno()).st_size
194240
return self.bulk_import_upload_part(name, part_name, fp, size)
195241

196-
def bulk_import_delete_part(self, name, part_name, params=None):
242+
def bulk_import_delete_part(
243+
self, name: str, part_name: str, params: dict[str, Any] | None = None
244+
) -> bool:
197245
"""Delete the imported information with the specified name.
198246
199247
Args:
@@ -218,7 +266,9 @@ def bulk_import_delete_part(self, name, part_name, params=None):
218266
self.raise_error("Delete a part failed", res, body)
219267
return True
220268

221-
def freeze_bulk_import(self, name, params=None):
269+
def freeze_bulk_import(
270+
self, name: str, params: dict[str, Any] | None = None
271+
) -> bool:
222272
"""Freeze the bulk import with the specified name.
223273
224274
Args:
@@ -236,7 +286,9 @@ def freeze_bulk_import(self, name, params=None):
236286
self.raise_error("Freeze bulk import failed", res, body)
237287
return True
238288

239-
def unfreeze_bulk_import(self, name, params=None):
289+
def unfreeze_bulk_import(
290+
self, name: str, params: dict[str, Any] | None = None
291+
) -> bool:
240292
"""Unfreeze bulk_import with the specified name.
241293
242294
Args:
@@ -254,7 +306,9 @@ def unfreeze_bulk_import(self, name, params=None):
254306
self.raise_error("Unfreeze bulk import failed", res, body)
255307
return True
256308

257-
def perform_bulk_import(self, name, params=None):
309+
def perform_bulk_import(
310+
self, name: str, params: dict[str, Any] | None = None
311+
) -> str:
258312
"""Execute a job to perform bulk import with the indicated priority using the
259313
resource pool if indicated, else it will use the account's default.
260314
@@ -274,7 +328,9 @@ def perform_bulk_import(self, name, params=None):
274328
js = self.checked_json(body, ["job_id"])
275329
return str(js["job_id"])
276330

277-
def commit_bulk_import(self, name, params=None):
331+
def commit_bulk_import(
332+
self, name: str, params: dict[str, Any] | None = None
333+
) -> bool:
278334
"""Commit the bulk import information having the specified name.
279335
280336
Args:
@@ -292,7 +348,9 @@ def commit_bulk_import(self, name, params=None):
292348
self.raise_error("Commit bulk import failed", res, body)
293349
return True
294350

295-
def bulk_import_error_records(self, name, params=None):
351+
def bulk_import_error_records(
352+
self, name: str, params: dict[str, Any] | None = None
353+
) -> Iterator[dict[str, Any]]:
296354
"""List the records that have errors under the specified bulk import name.
297355
298356
Args:

tdclient/client.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,17 @@
55
import datetime
66
import json
77
from collections.abc import Iterator
8-
from typing import Any
8+
from typing import Any, cast, Literal
9+
910

1011
from tdclient import api, models
1112
from tdclient.types import (
1213
BulkImportParams,
14+
BytesOrStream,
1315
DataFormat,
1416
ExportParams,
1517
FileLike,
18+
Priority,
1619
ResultFormat,
1720
ResultParams,
1821
ScheduleParams,
@@ -254,7 +257,7 @@ def query(
254257
db_name: str,
255258
q: str,
256259
result_url: str | None = None,
257-
priority: int | str | None = None,
260+
priority: Priority | None = None,
258261
retry_limit: int | None = None,
259262
type: str = "hive",
260263
**kwargs: Any,
@@ -279,9 +282,11 @@ def query(
279282
# for compatibility, assume type is hive unless specifically specified
280283
if type not in ["hive", "pig", "impala", "presto", "trino"]:
281284
raise ValueError("The specified query type is not supported: %s" % (type))
285+
# Cast for the type checker only; NOTE: the check above also accepts "pig"/"impala", which this Literal omits
286+
query_type = cast(Literal["hive", "presto", "trino", "bulkload"], type)
282287
job_id = self.api.query(
283288
q,
284-
type=type,
289+
type=query_type,
285290
db=db_name,
286291
result_url=result_url,
287292
priority=priority,
@@ -295,7 +300,7 @@ def jobs(
295300
_from: int | None = None,
296301
to: int | None = None,
297302
status: str | None = None,
298-
conditions: str | None = None,
303+
conditions: dict[str, Any] | None = None,
299304
) -> list[models.Job]:
300305
"""List jobs
301306
@@ -304,7 +309,7 @@ def jobs(
304309
to (int, optional): Gets the Job up to the nth index in the list.
305310
By default, the first 20 jobs in the list are displayed
306311
status (str, optional): Filter by given status. {"queued", "running", "success", "error"}
307-
conditions (str, optional): Condition for ``TIMESTAMPDIFF()`` to search for slow queries.
312+
conditions (dict[str, Any], optional): Condition for ``TIMESTAMPDIFF()`` to search for slow queries.
308313
Avoid using this parameter as it can be dangerous.
309314
310315
Returns:
@@ -334,7 +339,7 @@ def job_status(self, job_id: str | int) -> str:
334339
Returns:
335340
a string representing the status of the job ("success", "error", "killed", "queued", "running")
336341
"""
337-
return self.api.job_status(job_id)
342+
return self.api.job_status(str(job_id))
338343

339344
def job_result(self, job_id: str | int) -> list[Any]:
340345
"""
@@ -344,7 +349,7 @@ def job_result(self, job_id: str | int) -> list[Any]:
344349
Returns:
345350
a list of the rows in the result set
346351
"""
347-
return self.api.job_result(job_id)
352+
return self.api.job_result(str(job_id))
348353

349354
def job_result_each(self, job_id: str | int) -> Iterator[dict[str, Any]]:
350355
"""
@@ -354,7 +359,7 @@ def job_result_each(self, job_id: str | int) -> Iterator[dict[str, Any]]:
354359
Returns:
355360
an iterator over the rows of the result set
356361
"""
357-
for row in self.api.job_result_each(job_id):
362+
for row in self.api.job_result_each(str(job_id)):
358363
yield row
359364

360365
def job_result_format(
@@ -368,7 +373,7 @@ def job_result_format(
368373
Returns:
369374
a list of the rows in the result set
370375
"""
371-
return self.api.job_result_format(job_id, format, header=header)
376+
return self.api.job_result_format(str(job_id), format, header=header)
372377

373378
def job_result_format_each(
374379
self,
@@ -393,7 +398,7 @@ def job_result_format_each(
393398
an iterator of rows in result set
394399
"""
395400
for row in self.api.job_result_format_each(
396-
job_id,
401+
str(job_id),
397402
format,
398403
header=header,
399404
store_tmpfile=store_tmpfile,
@@ -414,17 +419,17 @@ def download_job_result(
414419
Returns:
415420
`True` if success
416421
"""
417-
return self.api.download_job_result(job_id, path, num_threads=num_threads)
422+
return self.api.download_job_result(str(job_id), path, num_threads=num_threads)
418423

419-
def kill(self, job_id: str | int) -> str:
424+
def kill(self, job_id: str | int) -> str | None:
420425
"""
421426
Args:
422427
job_id (str): job id
423428
424429
Returns:
425430
a string representing the status of the killed job ("queued", "running")
426431
"""
427-
return self.api.kill(job_id)
432+
return self.api.kill(str(job_id))
428433

429434
def export_data(
430435
self,
@@ -582,7 +587,7 @@ def bulk_imports(self) -> list[models.BulkImport]:
582587
]
583588

584589
def bulk_import_upload_part(
585-
self, name: str, part_name: str, bytes_or_stream: FileLike, size: int
590+
self, name: str, part_name: str, bytes_or_stream: BytesOrStream, size: int
586591
) -> None:
587592
"""Upload a part to a bulk import session
588593
@@ -849,7 +854,7 @@ def import_data(
849854
db_name: str,
850855
table_name: str,
851856
format: DataFormat,
852-
bytes_or_stream: FileLike,
857+
bytes_or_stream: BytesOrStream,
853858
size: int,
854859
unique_id: str | None = None,
855860
) -> float:

0 commit comments

Comments
 (0)