 import json
 import os
 from io import BytesIO
-from typing import AsyncGenerator, Dict, List, Optional, Tuple
+from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
 from urllib.parse import urljoin
 
 import aiofiles
 import httpx
+from aiofiles.threadpool.binary import AsyncBufferedReader
 from rich.progress import Progress, TaskID
 
 from dvuploader.file import File
@@ -33,6 +34,8 @@
 UPLOAD_ENDPOINT = "/api/datasets/:persistentId/addFiles?persistentId="
 REPLACE_ENDPOINT = "/api/datasets/:persistentId/replaceFiles?persistentId="
 
+DEFAULT_CHUNK_SIZE = 1024 * 1024  # 1MB
+
 # Initialize logging
 init_logging()
 
@@ -399,34 +402,25 @@ async def _chunked_upload(
     )
 
     async with aiofiles.open(file.filepath, "rb") as f:
-        chunk = await f.read(chunk_size)
-        e_tags.append(
-            await _upload_chunk(
-                session=session,
-                url=next(urls),
-                file=BytesIO(chunk),
-                progress=progress,
-                pbar=pbar,
-                hash_func=file.checksum._hash_fun,
+        filesize = os.path.getsize(file.filepath)
+        current_position = 0
+
+        while current_position < filesize:
+            current_chunk_size = min(chunk_size, filesize - current_position)
+
+            e_tags.append(
+                await _upload_chunk(
+                    session=session,
+                    url=next(urls),
+                    file=f,
+                    progress=progress,
+                    pbar=pbar,
+                    hash_func=file.checksum._hash_fun,
+                    chunk_size=current_chunk_size,
+                )
             )
-        )
-
-        while chunk:
-            chunk = await f.read(chunk_size)
 
-            if not chunk:
-                break
-            else:
-                e_tags.append(
-                    await _upload_chunk(
-                        session=session,
-                        url=next(urls),
-                        file=BytesIO(chunk),
-                        progress=progress,
-                        pbar=pbar,
-                        hash_func=file.checksum._hash_fun,
-                    )
-                )
+            current_position += current_chunk_size
 
     return e_tags
 
432426
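The rewritten loop replaces the old read-ahead pattern with position arithmetic: every part is `min(chunk_size, filesize - current_position)` bytes, so the final part simply carries the remainder and no sentinel reads are needed. A minimal standalone sketch of that arithmetic (`part_sizes` is a hypothetical helper for illustration, not part of the codebase):

```python
# Hypothetical helper, for illustration only: the per-part sizes the loop
# above requests for a file of `filesize` bytes split into chunk_size parts.
def part_sizes(filesize: int, chunk_size: int) -> list:
    sizes = []
    position = 0
    while position < filesize:
        current = min(chunk_size, filesize - position)  # never overshoot EOF
        sizes.append(current)
        position += current
    return sizes

assert part_sizes(10, 4) == [4, 4, 2]  # last part carries the remainder
assert part_sizes(8, 4) == [4, 4]      # exact multiple: no short tail
assert part_sizes(0, 4) == []          # empty file: no parts, no requests
```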
@@ -452,10 +446,11 @@ def _validate_ticket_response(response: Dict) -> None:
 async def _upload_chunk(
     session: httpx.AsyncClient,
     url: str,
-    file: BytesIO,
+    file: Union[BytesIO, AsyncBufferedReader],
     progress: Progress,
     pbar: TaskID,
     hash_func,
+    chunk_size: int,
 ):
     """
     Upload a single chunk of data.
@@ -467,6 +462,7 @@ async def _upload_chunk(
         progress (Progress): Progress tracking object.
         pbar (TaskID): Progress bar task ID.
         hash_func: Hash function for checksum.
+        chunk_size (int): Size of the chunk to upload, in bytes.
 
     Returns:
         str: ETag from server response.
@@ -475,8 +471,13 @@ async def _upload_chunk(
     if TESTING:
         url = url.replace("localstack", "localhost", 1)
 
+    if isinstance(file, BytesIO):
+        file_size = len(file.getvalue())
+    else:
+        file_size = chunk_size
+
     headers = {
-        "Content-length": str(len(file.getvalue())),
+        "Content-length": str(file_size),
     }
 
     params = {
@@ -487,6 +488,7 @@ async def _upload_chunk(
             progress=progress,
             pbar=pbar,
             hash_func=hash_func,
+            chunk_size=chunk_size,
         ),
     }
 
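Because the chunk is no longer copied into its own `BytesIO`, its length cannot always be measured directly: the `Content-length` header now comes either from the buffer or from the caller-supplied `chunk_size`. A sketch of that selection under the same two input types the signature names (`content_length` and `_Reader` are illustrative stand-ins, not the library's API):

```python
from io import BytesIO

class _Reader:  # stand-in for aiofiles' AsyncBufferedReader
    pass

# A BytesIO already holds the whole chunk, so its length is measurable;
# a shared file handle has no per-chunk length, so the caller's declared
# chunk_size is used instead.
def content_length(file, chunk_size: int) -> int:
    if isinstance(file, BytesIO):
        return len(file.getvalue())  # in-memory buffer: measure directly
    return chunk_size                # file handle: trust the declared size

assert content_length(BytesIO(b"abc"), 999) == 3
assert content_length(_Reader(), 1024) == 1024
```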
@@ -664,10 +666,11 @@ async def _multipart_json_data_request(
 
 
 async def upload_bytes(
-    file: BytesIO,
+    file: Union[BytesIO, AsyncBufferedReader],
     progress: Progress,
     pbar: TaskID,
     hash_func,
+    chunk_size: Optional[int] = None,
 ) -> AsyncGenerator[bytes, None]:
     """
     Generate chunks of file data for upload.
@@ -681,12 +684,27 @@ async def upload_bytes(
     Yields:
         bytes: Next chunk of file data.
     """
+    read_bytes = 0
     while True:
-        data = file.read(1024 * 1024)  # 1MB
+        if chunk_size is not None and read_bytes >= chunk_size:
+            break
+
+        if isinstance(file, AsyncBufferedReader):
+            data = await file.read(DEFAULT_CHUNK_SIZE)  # 1MB
+        else:
+            data = file.read(DEFAULT_CHUNK_SIZE)  # 1MB
 
         if not data:
             break
 
+        if chunk_size is not None:
+            remaining = chunk_size - read_bytes
+            if remaining <= 0:
+                break
+            if len(data) > remaining:
+                data = data[:remaining]
+            read_bytes += len(data)
+
         # Update the hash function with the data
         hash_func.update(data)
 
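The budget tracking is the key change: the generator now stops after yielding `chunk_size` bytes, which is what lets several parts share one open file handle instead of each getting a private `BytesIO` copy. A synchronous, self-contained sketch of the same logic (`bounded_chunks` is illustrative; the real generator is async and also drives the progress bar):

```python
import hashlib
from io import BytesIO
from typing import Iterator, Optional

# Sync stand-in for the async generator above: read fixed-size blocks, but
# never yield more than `chunk_size` bytes in total for this part.
def bounded_chunks(
    file: BytesIO,
    hash_func,
    chunk_size: Optional[int] = None,
    block: int = 4,  # stands in for DEFAULT_CHUNK_SIZE
) -> Iterator[bytes]:
    read_bytes = 0
    while True:
        if chunk_size is not None and read_bytes >= chunk_size:
            break  # this part's budget is spent; leave the rest in the stream
        data = file.read(block)
        if not data:
            break  # underlying stream exhausted
        if chunk_size is not None:
            remaining = chunk_size - read_bytes
            if len(data) > remaining:
                data = data[:remaining]  # trim the final block to the budget
            read_bytes += len(data)
        hash_func.update(data)
        yield data

stream = BytesIO(b"a" * 10)
part = b"".join(bounded_chunks(stream, hashlib.md5(), chunk_size=8))
assert part == b"a" * 8           # exactly one part's worth was consumed
assert stream.read() == b"a" * 2  # the remainder is left for the next part
```

One caveat worth noting: when the part size is not a multiple of the read block, the trim happens after the stream position has already advanced past the budget, so part sizes are safest kept as multiples of `DEFAULT_CHUNK_SIZE`.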