Commit 75aa605

More type hinting
1 parent 7f7a118 commit 75aa605

25 files changed (+217 / -124 lines)

.github/workflows/pythontest.yml

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ jobs:
       with:
         python-version: "3.13"
     - name: Install dependencies
-      run: uv sync --extra dev
+      run: uv sync --dev
     - name: lint with ruff
       run: |
         uv run ruff format tdclient --diff --exit-non-zero-on-fix

pyproject.toml

Lines changed: 14 additions & 11 deletions
@@ -34,11 +34,6 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
-dev = [
-    "ruff",
-    "pyright",
-    "tox>=4",
-]
 docs = [
     "sphinx",
     "sphinx_rtd_theme",
@@ -78,15 +73,23 @@ known-third-party = ["dateutil","msgpack","pkg_resources","pytest","setuptools",
 [tool.pyright]
 include = ["tdclient"]
 exclude = ["**/__pycache__", "tdclient/test", "docs"]
-typeCheckingMode = "basic"
+typeCheckingMode = "strict"
 pythonVersion = "3.10"
 pythonPlatform = "All"
-reportMissingTypeStubs = false
-reportUnknownMemberType = false
-reportUnknownArgumentType = false
-reportUnknownVariableType = false
-reportMissingImports = "warning"
+reportMissingTypeStubs = "warning"
+reportUnknownMemberType = "error"
+reportUnknownArgumentType = "error"
+reportMissingImports = "error"
 
 # Pre-commit venv configuration
 venvPath = "."
 venv = ".venv"
+
+[dependency-groups]
+dev = [
+    "ruff",
+    "pyright",
+    "tox>=4",
+    "msgpack-types>=0.5.0",
+    "types-certifi",
+]
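
The dev tooling moves from [project.optional-dependencies] into a PEP 735 [dependency-groups] table (hence uv sync --dev in the workflow change above), and pyright goes from basic to strict mode with the unknown-type reports treated as errors. As a rough, hypothetical illustration of what those reports flag when a dependency ships without type information (snippet not from the repo):

# Hypothetical snippet: under typeCheckingMode = "strict", pyright reports
# members whose types it cannot resolve (reportUnknownMemberType = "error").
import msgpack  # may be reported if no stub package is installed

packer = msgpack.Packer()        # Packer may resolve to an Unknown type
payload = packer.pack({"a": 1})  # and the member "pack" is then Unknown as well

# The stub packages added to the dev group (msgpack-types, types-certifi)
# give pyright concrete signatures, so code like this checks cleanly.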

tdclient/api.py

Lines changed: 35 additions & 24 deletions
@@ -14,11 +14,11 @@
 import time
 import urllib.parse as urlparse
 from array import array
+from collections.abc import Iterator
 from typing import IO, Any, cast
 
 import msgpack
 import urllib3
-import urllib3.util
 
 from tdclient import errors, version
 from tdclient.bulk_import_api import BulkImportAPI
@@ -31,7 +31,7 @@
 from tdclient.schedule_api import ScheduleAPI
 from tdclient.server_status_api import ServerStatusAPI
 from tdclient.table_api import TableAPI
-from tdclient.types import BytesOrStream, StreamBody
+from tdclient.types import BytesOrStream, DataFormat, FileLike, StreamBody
 from tdclient.user_api import UserAPI
 from tdclient.util import (
     csv_dict_record_reader,
@@ -42,7 +42,7 @@
 )
 
 try:
-    import certifi
+    import certifi  # type: ignore[reportMissingImports]
 except ImportError:
     certifi = None
 
@@ -576,7 +576,9 @@ def close(self) -> None:
         # all connections in pool will be closed eventually during gc.
         self.http.clear()
 
-    def _prepare_file(self, file_like, fmt, **kwargs):
+    def _prepare_file(
+        self, file_like: FileLike, fmt: DataFormat, **kwargs: Any
+    ) -> IO[bytes]:
         fp = tempfile.TemporaryFile()
         with contextlib.closing(gzip.GzipFile(mode="wb", fileobj=fp)) as gz:
             packer = msgpack.Packer()
@@ -591,34 +593,41 @@ def _prepare_file(self, file_like, fmt, **kwargs):
         fp.seek(0)
         return fp
 
-    def _read_file(self, file_like, fmt, **kwargs):
+    def _read_file(self, file_like: FileLike, fmt: DataFormat, **kwargs: Any) -> Any:
         compressed = fmt.endswith(".gz")
+        fmt_str = str(fmt)
         if compressed:
-            fmt = fmt[0 : len(fmt) - len(".gz")]
-        reader_name = f"_read_{fmt}_file"
+            fmt_str = fmt_str[0 : len(fmt_str) - len(".gz")]
+        reader_name = f"_read_{fmt_str}_file"
         if hasattr(self, reader_name):
             reader = getattr(self, reader_name)
         else:
             raise TypeError(f"unknown format: {fmt}")
         if hasattr(file_like, "read"):
             if compressed:
-                file_like = gzip.GzipFile(fileobj=file_like)
+                file_like = gzip.GzipFile(fileobj=file_like)  # type: ignore[arg-type]
             return reader(file_like, **kwargs)
         else:
+            # At this point, file_like must be str or bytes (not IO[bytes])
+            file_path = cast("str | bytes", file_like)
             if compressed:
-                file_like = gzip.GzipFile(fileobj=open(file_like, "rb"))
+                file_like = gzip.GzipFile(fileobj=open(file_path, "rb"))  # type: ignore[arg-type]
             else:
-                file_like = open(file_like, "rb")
+                file_like = open(file_path, "rb")
             return reader(file_like, **kwargs)
 
-    def _read_msgpack_file(self, file_like, **kwargs):
+    def _read_msgpack_file(
+        self, file_like: IO[bytes], **kwargs: Any
+    ) -> Iterator[dict[str, Any]]:
         # current impl doesn't tolerate any unpack error
-        unpacker = msgpack.Unpacker(file_like, raw=False)
+        unpacker = msgpack.Unpacker(file_like, raw=False)  # type: ignore[arg-type]
         for record in unpacker:
             validate_record(record)
             yield record
 
-    def _read_json_file(self, file_like, **kwargs):
+    def _read_json_file(
+        self, file_like: IO[bytes], **kwargs: Any
+    ) -> Iterator[dict[str, Any]]:
         # current impl doesn't tolerate any JSON parse error
         for s in file_like:
             record = json.loads(s.decode("utf-8"))
@@ -627,20 +636,22 @@ def _read_json_file(self, file_like, **kwargs):
 
     def _read_csv_file(
         self,
-        file_like,
-        dialect=csv.excel,
-        columns=None,
-        encoding="utf-8",
-        dtypes=None,
-        converters=None,
-        **kwargs,
-    ):
+        file_like: IO[bytes],
+        dialect: type[csv.Dialect] = csv.excel,
+        columns: list[str] | None = None,
+        encoding: str = "utf-8",
+        dtypes: dict[str, Any] | None = None,
+        converters: dict[str, Any] | None = None,
+        **kwargs: Any,
+    ) -> Iterator[dict[str, Any]]:
         if columns is None:
-            reader = csv_dict_record_reader(file_like, encoding, dialect)
+            reader = csv_dict_record_reader(file_like, encoding, dialect)  # type: ignore[arg-type]
         else:
-            reader = csv_text_record_reader(file_like, encoding, dialect, columns)
+            reader = csv_text_record_reader(file_like, encoding, dialect, columns)  # type: ignore[arg-type]
 
         return read_csv_records(reader, dtypes, converters, **kwargs)
 
-    def _read_tsv_file(self, file_like, **kwargs):
+    def _read_tsv_file(
+        self, file_like: IO[bytes], **kwargs: Any
+    ) -> Iterator[dict[str, Any]]:
         return self._read_csv_file(file_like, dialect=csv.excel_tab, **kwargs)
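
For readers skimming the diff: _read_file strips a trailing ".gz", picks a _read_<fmt>_file method by name, and hands it a binary stream. Below is a minimal, standalone sketch of that dispatch pattern; the RecordReader class and the "lines" format are hypothetical and not the tdclient implementation.

import gzip
from collections.abc import Iterator
from typing import IO, Any


class RecordReader:
    """Toy version of the _read_{fmt}_file dispatch shown above."""

    def read_file(self, path: str, fmt: str) -> Iterator[Any]:
        compressed = fmt.endswith(".gz")
        fmt_str = fmt[: -len(".gz")] if compressed else fmt
        reader = getattr(self, f"_read_{fmt_str}_file", None)
        if reader is None:
            raise TypeError(f"unknown format: {fmt}")
        # gzip.open in "rb" mode and open in "rb" mode both yield binary streams.
        fp: IO[bytes] = gzip.open(path, "rb") if compressed else open(path, "rb")
        with fp:
            yield from reader(fp)

    def _read_lines_file(self, fp: IO[bytes]) -> Iterator[bytes]:
        for line in fp:
            yield line.rstrip(b"\n")


# e.g. list(RecordReader().read_file("data.txt.gz", "lines.gz"))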

tdclient/bulk_import_api.py

Lines changed: 4 additions & 4 deletions
@@ -26,7 +26,7 @@ class BulkImportAPI:
     def get(
         self,
         path: str,
-        params: dict[str, Any] | bytes | None = None,
+        params: dict[str, Any] | None = None,
         headers: dict[str, str] | None = None,
         **kwargs: Any,
     ) -> AbstractContextManager[urllib3.BaseHTTPResponse]: ...
@@ -50,7 +50,7 @@ def raise_error(
     ) -> None: ...
     def checked_json(self, body: bytes, required: list[str]) -> dict[str, Any]: ...
     def _prepare_file(
-        self, file_like: FileLike, fmt: str, **kwargs: Any
+        self, file_like: FileLike, fmt: DataFormat, **kwargs: Any
     ) -> IO[bytes]: ...
 
     def create_bulk_import(
@@ -165,7 +165,7 @@ def validate_part_name(part_name: str) -> None:
         part_name (str): The part name the user is trying to use
     """
     # Check for duplicate periods
-    d = collections.defaultdict(int)
+    d: collections.defaultdict[str, int] = collections.defaultdict(int)
     for char in part_name:
         d[char] += 1
 
@@ -378,5 +378,5 @@ def bulk_import_error_records(
     body = io.BytesIO(res.read())
     decompressor = gzip.GzipFile(fileobj=body)
 
-    unpacker = msgpack.Unpacker(decompressor, raw=False)
+    unpacker = msgpack.Unpacker(decompressor, raw=False)  # type: ignore[arg-type]
     yield from unpacker
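
The explicit collections.defaultdict[str, int] annotation is what strict mode needs here: defaultdict(int) alone leaves the key type unknown. A small self-contained sketch of the same counting pattern, for illustration only:

import collections


def count_chars(part_name: str) -> dict[str, int]:
    # Annotating the defaultdict keeps the key type known under strict checking.
    counts: collections.defaultdict[str, int] = collections.defaultdict(int)
    for char in part_name:
        counts[char] += 1
    return dict(counts)


# count_chars("a.b.c") == {"a": 1, ".": 2, "b": 1, "c": 1}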

tdclient/bulk_import_model.py

Lines changed: 7 additions & 5 deletions
@@ -5,7 +5,7 @@
 from typing import TYPE_CHECKING, Any
 
 from tdclient.model import Model
-from tdclient.types import FileLike
+from tdclient.types import BytesOrStream, DataFormat, FileLike
 
 if TYPE_CHECKING:
     from tdclient.client import Client
@@ -112,7 +112,7 @@ def perform(
         self,
         wait: bool = False,
         wait_interval: int = 5,
-        wait_callback: Callable[[], None] | None = None,
+        wait_callback: Callable[["Job"], None] | None = None,
         timeout: float | None = None,
     ) -> "Job":
         """Perform bulk import
@@ -162,7 +162,9 @@ def error_record_items(self) -> Iterator[dict[str, Any]]:
         """
         yield from self._client.bulk_import_error_records(self.name)
 
-    def upload_part(self, part_name: str, bytes_or_stream: FileLike, size: int) -> bool:
+    def upload_part(
+        self, part_name: str, bytes_or_stream: BytesOrStream, size: int
+    ) -> None:
         """Upload a part to bulk import session
 
         Args:
@@ -177,8 +179,8 @@ def upload_part(self, part_name: str, bytes_or_stream: FileLike, size: int) -> b
         return response
 
     def upload_file(
-        self, part_name: str, fmt: str, file_like: FileLike, **kwargs: Any
-    ) -> float:
+        self, part_name: str, fmt: DataFormat, file_like: FileLike, **kwargs: Any
+    ) -> None:
         """Upload a part to Bulk Import session, from an existing file on filesystem.
 
         Args:
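
For context, a hedged usage sketch of the methods whose signatures change here: upload_file now takes a DataFormat and, like upload_part, is annotated to return None, and perform's wait_callback receives the running Job. Only those signatures come from this diff; the client construction and surrounding calls are assumptions about the public tdclient API, not part of this commit.

import tdclient

# Assumed entry point and session setup; only the upload_file/perform
# signatures below are taken from the diff in this commit.
with tdclient.Client(apikey="YOUR_API_KEY") as client:
    session = client.create_bulk_import("session_1", "my_db", "my_table")
    with open("records.msgpack.gz", "rb") as f:
        session.upload_file("part_1", "msgpack.gz", f)  # annotated as returning None
    session.freeze()
    job = session.perform(
        wait=True,
        wait_callback=lambda job: print("waiting on job", job.job_id),  # Job-typed callback
    )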

tdclient/client.py

Lines changed: 9 additions & 6 deletions
@@ -790,12 +790,12 @@ def history(
         """
         result = self.api.history(name, _from or 0, to)
 
-        def scheduled_job(m):
+        def scheduled_job(m: tuple[Any, ...]) -> models.ScheduledJob:
             (
                 scheduled_at,
                 job_id,
                 type,
-                status,
+                _status,
                 query,
                 start_at,
                 end_at,
@@ -837,8 +837,11 @@ def run_schedule(self, name: str, time: int, num: int) -> list[models.ScheduledJ
         """
         results = self.api.run_schedule(name, time, num)
 
-        def scheduled_job(m):
+        def scheduled_job(
+            m: tuple[Any, str, datetime.datetime | None],
+        ) -> models.ScheduledJob:
             job_id, type, scheduled_at = m
+            assert scheduled_at is not None
             return models.ScheduledJob(self, scheduled_at, job_id, type, None)
 
         return [scheduled_job(m) for m in results]
@@ -904,7 +907,7 @@ def results(self) -> list[models.Result]:
         """
         results = self.api.list_result()
 
-        def result(m):
+        def result(m: tuple[str, str, None]) -> models.Result:
             name, url, organizations = m
             return models.Result(self, name, url, organizations)
 
@@ -943,7 +946,7 @@ def users(self):
         """
         results = self.api.list_users()
 
-        def user(m):
+        def user(m: tuple[str, None, None, str]) -> models.User:
             name, org, roles, email = m
             return models.User(self, name, org, roles, email)
 
@@ -1011,7 +1014,7 @@ def close(self) -> None:
 
 
 def job_from_dict(client: Client, dd: dict[str, Any], **values: Any) -> models.Job:
-    d = dict()
+    d: dict[str, Any] = dict()
     d.update(dd)
     d.update(values)
     return models.Job(
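
The row-to-model closures now carry tuple annotations describing the positional records they unpack. A minimal standalone illustration of the pattern follows; the record shape and User model are hypothetical, not the Treasure Data API response.

from typing import NamedTuple


class User(NamedTuple):
    name: str
    email: str


def user_from_row(row: tuple[str, None, None, str]) -> User:
    # Unpack a positional API row into a typed model; unused fields get
    # underscore names so strict linting stays quiet.
    name, _org, _roles, email = row
    return User(name, email)


rows: list[tuple[str, None, None, str]] = [("alice", None, None, "alice@example.com")]
users = [user_from_row(r) for r in rows]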

tdclient/connection.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ def __init__(
         wait_callback: Callable[["Cursor"], None] | None = None,
         **kwargs: Any,
     ) -> None:
-        cursor_kwargs = dict()
+        cursor_kwargs: dict[str, Any] = dict()
         if type is not None:
             cursor_kwargs["type"] = type
         if db is not None:
