Skip to content

Commit 591b115

Browse files
committed
refactor LimitedStream.exhaust
ensure exhaust accounts for returned size, not requested size add comments about exhaust/disconnect logic clean up test
1 parent 971a964 commit 591b115

File tree

3 files changed

+77
-60
lines changed

3 files changed

+77
-60
lines changed

CHANGES.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ Unreleased
1717
client. :issue:`2549`
1818
- Fix handling of header extended parameters such that they are no longer quoted.
1919
:issue:`2529`
20+
- ``LimitedStream.read`` works correctly when wrapping a stream that may not return
21+
the requested size in one ``read`` call. :issue:`2558`
2022

2123

2224
Version 2.2.2

src/werkzeug/wsgi.py

Lines changed: 51 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -928,56 +928,77 @@ def on_disconnect(self) -> bytes:
928928

929929
raise ClientDisconnected()
930930

931-
def exhaust(self, chunk_size: int = 1024 * 64) -> None:
932-
"""Exhaust the stream. This consumes all the data left until the
933-
limit is reached.
931+
def _exhaust_chunks(self, chunk_size: int = 1024 * 64) -> t.Iterator[bytes]:
932+
"""Exhaust the stream by reading until the limit is reached or the client
933+
disconnects, yielding each chunk.
934+
935+
:param chunk_size: How many bytes to read at a time.
936+
937+
:meta private:
934938
935-
:param chunk_size: the size for a chunk. It will read the chunk
936-
until the stream is exhausted and throw away
937-
the results.
939+
.. versionadded:: 2.2.3
938940
"""
939941
to_read = self.limit - self._pos
940-
chunk = chunk_size
942+
941943
while to_read > 0:
942-
chunk = min(to_read, chunk)
943-
self.read(chunk)
944-
to_read -= chunk
945-
946-
def exhaust_into(self, buf: bytearray, chunk_size: int = 1024 * 64) -> None:
947-
"""Exhaust the stream. This consumes all the data left until the
948-
limit is reached, and writes the result into the given buffer.
949-
950-
:param buf: the buffer to read the result into.
951-
:param chunk_size: the size for a chunk. It will read the chunk
952-
until the stream is exhausted and write it into
953-
the buffer.
944+
chunk = self.read(min(to_read, chunk_size))
945+
yield chunk
946+
to_read -= len(chunk)
947+
948+
def exhaust(self, chunk_size: int = 1024 * 64) -> None:
949+
"""Exhaust the stream by reading until the limit is reached or the client
950+
disconnects, discarding the data.
951+
952+
:param chunk_size: How many bytes to read at a time.
953+
954+
.. versionchanged:: 2.2.3
955+
Handle case where wrapped stream returns fewer bytes than requested.
954956
"""
955-
to_read = self.limit - self._pos
956-
chunk = chunk_size
957-
while to_read > 0:
958-
chunk = min(to_read, chunk)
959-
data = self.read(chunk)
960-
buf.extend(data)
961-
to_read -= len(data)
957+
for _ in self._exhaust_chunks(chunk_size):
958+
pass
962959

963960
def read(self, size: t.Optional[int] = None) -> bytes:
964-
"""Read `size` bytes or if size is not provided everything is read.
961+
"""Read up to ``size`` bytes from the underlying stream. If size is not
962+
provided, read until the limit.
965963
966-
:param size: the number of bytes read.
964+
If the limit is reached, :meth:`on_exhausted` is called, which returns empty
965+
bytes.
966+
967+
If no bytes are read and the limit is not reached, or if an error occurs during
968+
the read, :meth:`on_disconnect` is called, which raises
969+
:exc:`.ClientDisconnected`.
970+
971+
:param size: The number of bytes to read. ``None``, default, reads until the
972+
limit is reached.
973+
974+
.. versionchanged:: 2.2.3
975+
Handle case where wrapped stream returns fewer bytes than requested.
967976
"""
968977
if self._pos >= self.limit:
969978
return self.on_exhausted()
970-
if size is None or size == -1: # -1 is for consistence with file
979+
980+
if size is None or size == -1: # -1 is for consistency with file
981+
# Keep reading from the wrapped stream until the limit is reached. Can't
982+
# rely on stream.read(size) because it's not guaranteed to return size.
971983
buf = bytearray()
972-
self.exhaust_into(buf)
984+
985+
for chunk in self._exhaust_chunks():
986+
buf.extend(chunk)
987+
973988
return bytes(buf)
989+
974990
to_read = min(self.limit - self._pos, size)
991+
975992
try:
976993
read = self._read(to_read)
977994
except (OSError, ValueError):
978995
return self.on_disconnect()
996+
979997
if to_read and not len(read):
998+
# If no data was read, treat it as a disconnect. As long as some data was
999+
# read, a subsequent call can still return more before reaching the limit.
9801000
return self.on_disconnect()
1001+
9811002
self._pos += len(read)
9821003
return read
9831004

tests/test_wsgi.py

Lines changed: 24 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
from __future__ import annotations
2+
13
import io
24
import json
35
import os
6+
import typing as t
47

58
import pytest
69

@@ -165,69 +168,60 @@ def test_limited_stream_json_load():
165168

166169

167170
def test_limited_stream_disconnection():
168-
io_ = io.BytesIO(b"A bit of content")
169-
170-
# disconnect detection on out of bytes
171-
stream = wsgi.LimitedStream(io_, 255)
171+
# disconnect because stream returns zero bytes
172+
stream = wsgi.LimitedStream(io.BytesIO(), 255)
172173
with pytest.raises(ClientDisconnected):
173174
stream.read()
174175

175-
# disconnect detection because file close
176-
io_ = io.BytesIO(b"x" * 255)
177-
io_.close()
178-
stream = wsgi.LimitedStream(io_, 255)
176+
# disconnect because stream is closed
177+
data = io.BytesIO(b"x" * 255)
178+
data.close()
179+
stream = wsgi.LimitedStream(data, 255)
180+
179181
with pytest.raises(ClientDisconnected):
180182
stream.read()
181183

182184

183185
def test_limited_stream_read_with_raw_io():
184-
class FakeRawIOStream:
185-
"""
186-
Fakes Raw IO behavior where fewer bytes can be returned by ``read`` than what
187-
are asked for through `size`.
188-
"""
189-
190-
buf: bytes
191-
192-
def __init__(self, buf: bytes):
186+
class OneByteStream(t.BinaryIO):
187+
def __init__(self, buf: bytes) -> None:
193188
self.buf = buf
194189
self.pos = 0
195190

196-
def read(self, size: int) -> bytes:
191+
def read(self, size: int | None = None) -> bytes:
192+
"""Return one byte at a time regardless of requested size."""
193+
197194
if size is None or size == -1:
198195
raise ValueError("expected read to be called with specific limit")
199-
if size == 0:
200-
return b""
201196

202-
if len(self.buf) < self.pos:
197+
if size == 0 or len(self.buf) < self.pos:
203198
return b""
204199

205200
b = self.buf[self.pos : self.pos + 1]
206201
self.pos += 1
207202
return b
208203

209-
def readline(self):
210-
raise NotImplementedError
211-
212-
data = b"foo"
213-
stream = wsgi.LimitedStream(FakeRawIOStream(data), 4) # noqa
204+
stream = wsgi.LimitedStream(OneByteStream(b"foo"), 4)
214205
assert stream.read(5) == b"f"
215206
assert stream.read(5) == b"o"
216207
assert stream.read(5) == b"o"
217-
# the underlying stream has fewer bytes than the expected limit
208+
209+
# The stream has fewer bytes (3) than the limit (4), therefore the read returns 0
210+
# bytes before the limit is reached.
218211
with pytest.raises(ClientDisconnected):
219212
stream.read(5)
220213

221-
stream = wsgi.LimitedStream(FakeRawIOStream(data), 3) # noqa
214+
stream = wsgi.LimitedStream(OneByteStream(b"foo123"), 3)
222215
assert stream.read(5) == b"f"
223216
assert stream.read(5) == b"o"
224217
assert stream.read(5) == b"o"
218+
# The limit was reached, therefore the wrapper is exhausted, not disconnected.
225219
assert stream.read(5) == b""
226220

227-
stream = wsgi.LimitedStream(FakeRawIOStream(data), 3) # noqa
221+
stream = wsgi.LimitedStream(OneByteStream(b"foo"), 3)
228222
assert stream.read() == b"foo"
229223

230-
stream = wsgi.LimitedStream(FakeRawIOStream(data), 2) # noqa
224+
stream = wsgi.LimitedStream(OneByteStream(b"foo"), 2)
231225
assert stream.read() == b"fo"
232226

233227

0 commit comments

Comments
 (0)