
Commit b9fc079

Merge pull request #52 from gdcc/fix-chunk-read
Fix multipart direct upload buffering chunks in memory
2 parents 0dd0592 + b26a1ce commit b9fc079
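
In short: the old chunked-upload path read every part fully into memory (`chunk = await f.read(chunk_size)` wrapped in a `BytesIO`) before sending it, so peak memory grew with the configured part size. The diff below streams instead: the open `aiofiles` handle is passed straight to `_upload_chunk`, and `upload_bytes` reads it in 1MB pieces, stopping once the part's `chunk_size` bytes have been yielded.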

File tree

1 file changed (+49, −31):

dvuploader/directupload.py

Lines changed: 49 additions & 31 deletions
```diff
@@ -2,11 +2,12 @@
 import json
 import os
 from io import BytesIO
-from typing import AsyncGenerator, Dict, List, Optional, Tuple
+from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
 from urllib.parse import urljoin
 
 import aiofiles
 import httpx
+from aiofiles.threadpool.binary import AsyncBufferedReader
 from rich.progress import Progress, TaskID
 
 from dvuploader.file import File
```

```diff
@@ -33,6 +34,8 @@
 UPLOAD_ENDPOINT = "/api/datasets/:persistentId/addFiles?persistentId="
 REPLACE_ENDPOINT = "/api/datasets/:persistentId/replaceFiles?persistentId="
 
+DEFAULT_CHUNK_SIZE = 1024 * 1024  # 1MB
+
 # Initialize logging
 init_logging()
 
```

```diff
@@ -399,34 +402,25 @@ async def _chunked_upload(
     )
 
     async with aiofiles.open(file.filepath, "rb") as f:
-        chunk = await f.read(chunk_size)
-        e_tags.append(
-            await _upload_chunk(
-                session=session,
-                url=next(urls),
-                file=BytesIO(chunk),
-                progress=progress,
-                pbar=pbar,
-                hash_func=file.checksum._hash_fun,
+        filesize = os.path.getsize(file.filepath)
+        current_position = 0
+
+        while current_position < filesize:
+            current_chunk_size = min(chunk_size, filesize - current_position)
+
+            e_tags.append(
+                await _upload_chunk(
+                    session=session,
+                    url=next(urls),
+                    file=f,
+                    progress=progress,
+                    pbar=pbar,
+                    hash_func=file.checksum._hash_fun,
+                    chunk_size=current_chunk_size,
+                )
             )
-        )
-
-        while chunk:
-            chunk = await f.read(chunk_size)
 
-            if not chunk:
-                break
-            else:
-                e_tags.append(
-                    await _upload_chunk(
-                        session=session,
-                        url=next(urls),
-                        file=BytesIO(chunk),
-                        progress=progress,
-                        pbar=pbar,
-                        hash_func=file.checksum._hash_fun,
-                    )
-                )
+            current_position += current_chunk_size
 
     return e_tags
 
```
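
The new loop never materializes a whole part: it only tracks byte offsets and hands the open handle to `_upload_chunk`. A minimal standalone sketch of the same chunking arithmetic (the function name is illustrative, not part of dvuploader):

```python
# Every part is `chunk_size` bytes except possibly the last,
# which is whatever remains of the file.
def part_sizes(filesize: int, chunk_size: int) -> list[int]:
    sizes = []
    position = 0
    while position < filesize:
        current = min(chunk_size, filesize - position)
        sizes.append(current)
        position += current
    return sizes

# A 2.5 MB file split into 1 MB parts:
print(part_sizes(filesize=2_621_440, chunk_size=1_048_576))
# [1048576, 1048576, 524288]
```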

```diff
@@ -452,10 +446,11 @@ def _validate_ticket_response(response: Dict) -> None:
 async def _upload_chunk(
     session: httpx.AsyncClient,
     url: str,
-    file: BytesIO,
+    file: Union[BytesIO, AsyncBufferedReader],
     progress: Progress,
     pbar: TaskID,
     hash_func,
+    chunk_size: int,
 ):
     """
     Upload a single chunk of data.
```

```diff
@@ -467,6 +462,7 @@ async def _upload_chunk(
         progress (Progress): Progress tracking object.
         pbar (TaskID): Progress bar task ID.
         hash_func: Hash function for checksum.
+        chunk_size (int): Size of chunk to upload.
 
     Returns:
         str: ETag from server response.
```

```diff
@@ -475,8 +471,13 @@
     if TESTING:
         url = url.replace("localstack", "localhost", 1)
 
+    if isinstance(file, BytesIO):
+        file_size = len(file.getvalue())
+    else:
+        file_size = chunk_size
+
     headers = {
-        "Content-length": str(len(file.getvalue())),
+        "Content-length": str(file_size),
     }
 
     params = {
```
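
Since an `AsyncBufferedReader` is consumed lazily, its part length cannot be taken from the object itself the way `len(file.getvalue())` works for a `BytesIO`; the caller-supplied `chunk_size` stands in for it. A hedged sketch of that dispatch (the helper is hypothetical, not in the module):

```python
from io import BytesIO
from typing import Union

from aiofiles.threadpool.binary import AsyncBufferedReader

def part_content_length(
    file: Union[BytesIO, AsyncBufferedReader], chunk_size: int
) -> int:
    # A BytesIO holds the entire part, so its length is authoritative;
    # a streamed handle has no cheap length, so trust the caller.
    if isinstance(file, BytesIO):
        return len(file.getvalue())
    return chunk_size
```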
params = {
```diff
@@ -487,6 +488,7 @@
             progress=progress,
             pbar=pbar,
             hash_func=hash_func,
+            chunk_size=chunk_size,
         ),
     }
 
```

```diff
@@ -664,10 +666,11 @@ async def _multipart_json_data_request(
 
 
 async def upload_bytes(
-    file: BytesIO,
+    file: Union[BytesIO, AsyncBufferedReader],
     progress: Progress,
     pbar: TaskID,
     hash_func,
+    chunk_size: Optional[int] = None,
 ) -> AsyncGenerator[bytes, None]:
     """
     Generate chunks of file data for upload.
```

```diff
@@ -681,12 +684,27 @@
     Yields:
         bytes: Next chunk of file data.
     """
+    read_bytes = 0
     while True:
-        data = file.read(1024 * 1024)  # 1MB
+        if chunk_size is not None and read_bytes >= chunk_size:
+            break
+
+        if isinstance(file, AsyncBufferedReader):
+            data = await file.read(DEFAULT_CHUNK_SIZE)  # 1MB
+        else:
+            data = file.read(DEFAULT_CHUNK_SIZE)  # 1MB
 
         if not data:
             break
 
+        if chunk_size is not None:
+            remaining = chunk_size - read_bytes
+            if remaining <= 0:
+                break
+            if len(data) > remaining:
+                data = data[:remaining]
+            read_bytes += len(data)
+
         # Update the hash function with the data
         hash_func.update(data)
 
```