Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions tornado/test/util_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import gzip
import re
import sys
import datetime
Expand All @@ -14,6 +15,7 @@
timedelta_to_seconds,
import_object,
re_unescape,
GzipDecompressor,
)

from typing import cast, Dict, Any
Expand Down Expand Up @@ -371,3 +373,57 @@ def test_version_info_compatible(self):

def test_current_version(self):
self.assert_version_info_compatible(tornado.version, tornado.version_info)


class GzipDecompressorTest(unittest.TestCase):
def test_concatenated_gzip_members(self):
"""Test that concatenated gzip members are fully decompressed."""
data1 = b"First gzip member content."
data2 = b"Second gzip member content."

member1 = gzip.compress(data1)
member2 = gzip.compress(data2)

concatenated = member1 + member2
decompressor = GzipDecompressor()
result = decompressor.decompress(concatenated)

expected = data1 + data2
self.assertEqual(
result, expected, "Concatenated gzip members should be fully decompressed"
)

def test_single_gzip_member(self):
"""Test that single gzip member is decompressed correctly."""
data = b"This is some example data that will be compressed using gzip."
compressed = gzip.compress(data)

decompressor = GzipDecompressor()
result = decompressor.decompress(compressed)

self.assertEqual(result, data)

def test_multiple_concatenated_members(self):
"""Test that three or more concatenated gzip members are fully decompressed."""
data1 = b"First member."
data2 = b"Second member."
data3 = b"Third member."

concatenated = gzip.compress(data1) + gzip.compress(data2) + gzip.compress(data3)
decompressor = GzipDecompressor()
result = decompressor.decompress(concatenated)

expected = data1 + data2 + data3
self.assertEqual(result, expected)

def test_decompress_after_flush_raises(self):
"""Test that decompress() raises RuntimeError after flush()."""
data = b"Test data"
compressed = gzip.compress(data)

decompressor = GzipDecompressor()
decompressor.decompress(compressed)
decompressor.flush()

with self.assertRaises(RuntimeError):
decompressor.decompress(gzip.compress(b"More data"))
36 changes: 34 additions & 2 deletions tornado/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def __init__(self) -> None:
# http://stackoverflow.com/questions/1838699/how-can-i-decompress-a-gzip-stream-with-zlib
# This works on cpython and pypy, but not jython.
self.decompressobj = zlib.decompressobj(16 + zlib.MAX_WBITS)
self._flushed = False

def decompress(self, value: bytes, max_length: int = 0) -> bytes:
"""Decompress a chunk, returning newly-available data.
Expand All @@ -89,7 +90,36 @@ def decompress(self, value: bytes, max_length: int = 0) -> bytes:
in ``unconsumed_tail``; you must retrieve this value and pass
it back to a future call to `decompress` if it is not empty.
"""
return self.decompressobj.decompress(value, max_length)
if self._flushed:
raise RuntimeError("Cannot call decompress() after flush()")

data = value
out = bytearray()
remaining = max_length

while True:
if remaining:
chunk = self.decompressobj.decompress(data, remaining)
else:
chunk = self.decompressobj.decompress(data)

out.extend(chunk)

if remaining:
remaining = max(0, max_length - len(out))
if remaining == 0:
break

# Handle concatenated gzip members
unused = getattr(self.decompressobj, "unused_data", b"")
if unused:
data = unused
self.decompressobj = zlib.decompressobj(16 + zlib.MAX_WBITS)
continue

break

return bytes(out)

@property
def unconsumed_tail(self) -> bytes:
Expand All @@ -102,7 +132,9 @@ def flush(self) -> bytes:
Also checks for errors such as truncated input.
No other methods may be called on this object after `flush`.
"""
return self.decompressobj.flush()
result = self.decompressobj.flush()
self._flushed = True
return result


def import_object(name: str) -> Any:
Expand Down