Skip to content

Commit 0743890

Browse files
committed
Added support for streaming multipart decoding
1 parent 31e8a16 commit 0743890

File tree

2 files changed

+149
-5
lines changed

2 files changed

+149
-5
lines changed

requests_toolbelt/multipart/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
"""
1010

1111
from .encoder import MultipartEncoder, MultipartEncoderMonitor
12-
from .decoder import MultipartDecoder
12+
from .decoder import MultipartDecoder, MultipartStreamDecoder
1313
from .decoder import ImproperBodyPartContentException
1414
from .decoder import NonMultipartContentTypeException
1515

requests_toolbelt/multipart/decoder.py

Lines changed: 148 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,11 +107,12 @@ def __init__(self, content, content_type, encoding='utf-8'):
107107
self.encoding = encoding
108108
#: Parsed parts of the multipart response body
109109
self.parts = tuple()
110-
self._find_boundary()
110+
self.boundary = MultipartDecoder._find_boundary(content_type, encoding)
111111
self._parse_body(content)
112112

113-
def _find_boundary(self):
114-
ct_info = tuple(x.strip() for x in self.content_type.split(';'))
113+
@staticmethod
114+
def _find_boundary(content_type, encoding):
115+
ct_info = tuple(x.strip() for x in content_type.split(';'))
115116
mimetype = ct_info[0]
116117
if mimetype.split('/')[0].lower() != 'multipart':
117118
raise NonMultipartContentTypeException(
@@ -123,7 +124,8 @@ def _find_boundary(self):
123124
'='
124125
)
125126
if attr.lower() == 'boundary':
126-
self.boundary = encode_with(value.strip('"'), self.encoding)
127+
boundary = encode_with(value.strip('"'), encoding)
128+
return boundary
127129

128130
@staticmethod
129131
def _fix_first_part(part, boundary_marker):
@@ -154,3 +156,145 @@ def from_response(cls, response, encoding='utf-8'):
154156
content = response.content
155157
content_type = response.headers.get('content-type', None)
156158
return cls(content, content_type, encoding)
159+
160+
161+
class AlreadyIteratedException(Exception):
    """Raised when a single-pass streaming object is iterated a second time."""
163+
164+
165+
# Currently .text is not implemented
166+
class StreamPart(object):
    """One body part of a multipart stream, readable exactly once.

    :param headers: the headers object describing this part; stored as-is
        on ``self.headers``.
    :param iterator: a zero-argument callable returning an iterable of
        ``(kind, payload)`` tuples. ``('stream', chunk)`` carries a chunk of
        the part's body and ``('done', False)`` marks the end of the part.

    Iterating the part yields its body chunks. ``content`` joins all chunks
    into one bytes object and caches the result. ``.text`` is not
    implemented.
    """

    def __init__(self, headers, iterator):
        self.headers = headers
        self._iterator = iterator
        self._started = False
        self._consumed = False
        # Cache filled by ``content``; initialised here so the attribute
        # always exists, even before the part has been read.
        self._content = None

    def __iter__(self):
        """Yield the body chunks of this part.

        :raises AlreadyIteratedException: if iteration was already started.
        :raises ImproperBodyPartContentException: if the underlying iterable
            produces anything other than a ``'stream'`` event or the
            ``('done', False)`` end marker.
        """
        if self._started:
            raise AlreadyIteratedException()
        self._started = True
        for typ, data in self._iterator():
            # TODO break if data is True as well ?
            # ``False`` is a sentinel, so compare by identity.
            if typ == 'done' and data is False:
                break
            elif typ == 'stream':
                yield data
            else:
                raise ImproperBodyPartContentException()

    @property
    def content(self):
        """Return the whole body of this part as a single bytes object.

        The first access drains the part and caches the result; later
        accesses return the cached value.

        :raises AlreadyIteratedException: if the part was already iterated
            directly, since the chunks consumed that way cannot be replayed.
        """
        if self._consumed:
            return self._content
        if self._started:
            raise AlreadyIteratedException()
        self._content = b''.join(self)
        self._consumed = True
        return self._content
195+
196+
197+
# On error this will not consume all data, afaik requests will handle this and deplete the stream
198+
# part_test is different than before (this is a stream, so we can't know in advance what a part will contain)
199+
class MultipartStreamDecoder(object):
    """Decode a multipart body incrementally from a chunked byte source.

    Iterating the decoder yields one ``StreamPart`` per body part. Each
    part must be fully read (by iterating it or via ``.content``) before
    the next one is requested, because all parts are fed by the same
    underlying stream and the same parser state held on this object.

    :param stream_read_func: zero-argument callable returning the next
        chunk of the body as bytes; an empty (or ``None``) result is
        treated as end of stream.
    :param content_type: value of the Content-Type header; the multipart
        boundary is extracted from it.
    :param encoding: encoding used for the boundary and passed to the
        header parser.
    """

    @classmethod
    def from_response(cls, response, encoding='utf-8', chunk_size=10 * 1024):
        """Build a decoder that reads ``response.raw`` in ``chunk_size`` pieces."""
        def read_chunk():
            return response.raw.read(chunk_size)

        content_type = response.headers.get('content-type', None)
        return cls(read_chunk, content_type, encoding)

    def __init__(self, stream_read_func, content_type, encoding='utf-8'):
        self.content_type = content_type
        self.encoding = encoding
        self._stream_read_func = stream_read_func
        self._boundary = MultipartDecoder._find_boundary(content_type, encoding)
        self._splitter = StreamSplitter()
        self._boundary = b''.join((b'--', self._boundary))
        self._boundary_split = b''.join((b'\r\n', self._boundary))
        # A body usually starts directly with "--boundary", with no CRLF in
        # front of it. Seed the buffer with a virtual CRLF so that the very
        # first boundary matches ``_boundary_split``; otherwise the first
        # real part would be discarded together with the preamble.
        self._splitter.leftover = b'\r\n'
        # Parser state: 0 = skipping the preamble, 1 = reading a part's
        # headers, 2 = streaming a part's body.
        self._state = 0
        # Whether the last split found its delimiter (i.e. buffered data
        # may still be waiting to be processed).
        self._found = False
        self._started = False

    def __iter__(self):
        """Yield a ``StreamPart`` for every part found in the stream.

        :raises AlreadyIteratedException: if iteration was already started.
        :raises ImproperBodyPartContentException: if body data shows up
            where headers are expected, e.g. because the previous part was
            not fully consumed before asking for the next one.
        """
        if self._started:
            raise AlreadyIteratedException()
        self._started = True
        for typ, data in self._stream():
            if typ == 'headers':
                # The part gets the bound method, not a generator: it starts
                # a fresh generator that resumes from the shared state.
                yield StreamPart(data, self._stream)
            else:
                raise ImproperBodyPartContentException()

    def _stream(self):
        """Generate parser events as ``(kind, payload)`` tuples.

        Emits ``('headers', CaseInsensitiveDict)`` when a part begins,
        ``('stream', bytes)`` for body chunks and ``('done', False)`` when
        a part ends. All state lives on ``self``, so several generators
        created from this method continue one another's work.
        """
        while True:
            # Treat a ``None`` chunk like an empty one so it can safely be
            # appended to the splitter's buffer.
            data = self._stream_read_func() or b''
            # This presumes that once the source returns an empty chunk it
            # will never return data again (end of stream). Keep looping
            # while the last split succeeded: the buffer may hold more.
            if not self._found and not data:
                break
            # Remove everything up to and including the first boundary
            if self._state == 0:
                # TODO can this be non empty?
                _, self._found = self._splitter.stream(data, self._boundary_split)
                if self._found:
                    self._state = 1
                continue
            # Parse the headers
            elif self._state == 1:
                headers, self._found = self._splitter.stream(data, b'\r\n\r\n', True)
                if headers:
                    # A boundary directly followed by "--" is the close
                    # delimiter: there are no more parts, and whatever
                    # follows is epilogue, not a header block.
                    if headers.startswith(b'--'):
                        break
                    # TODO should headers be utf8 ? in python3 they are binary
                    headers = _header_parser(headers.lstrip(), self.encoding)
                    headers = CaseInsensitiveDict(headers)
                    self._state = 2
                    yield 'headers', headers
                    continue
                # No headers found
                if self._found:
                    headers = CaseInsensitiveDict({})
                    self._state = 2
                    yield 'headers', headers
                continue
            # Stream the part
            elif self._state == 2:
                stream, self._found = self._splitter.stream(data, self._boundary_split)
                if stream:
                    yield 'stream', stream
                # boundary_split found, end of part
                if self._found:
                    self._state = 1
                    yield 'done', False
                continue
270+
271+
272+
# TODO this can be implemented with less copying
273+
class StreamSplitter(object):
    """Split an incrementally supplied byte stream on a delimiter.

    Bytes that cannot be released yet are kept in ``leftover`` between
    calls to ``stream``.
    """

    def __init__(self):
        self.leftover = b''

    def stream(self, data, split_data, return_only_full=False):
        """Append ``data`` to the buffer and split it on ``split_data``.

        :param data: newly received bytes.
        :param split_data: the delimiter to look for.
        :param return_only_full: when true, release nothing until the
            delimiter has been found; otherwise release everything except
            the last ``len(split_data)`` buffered bytes.
        :returns: ``(released_bytes, found)``. When ``found`` is true the
            released bytes are those before the delimiter, and the bytes
            after it stay buffered.
        """
        self.leftover += data
        buffered = self.leftover
        delimiter_len = len(split_data)

        position = buffered.find(split_data)
        if position != -1:
            # Delimiter located: release what precedes it, keep what follows.
            self.leftover = buffered[position + delimiter_len:]
            return buffered[:position], True

        if return_only_full or len(buffered) < delimiter_len:
            # Nothing can be released yet; the buffer stays as it is.
            return b'', False

        # Hold back the tail, since it may be the start of a delimiter that
        # is completed by the next chunk.
        self.leftover = buffered[-delimiter_len:]
        return buffered[:-delimiter_len], False

0 commit comments

Comments
 (0)