From 26f3a3d0ac75018045591989e1e1b335ddb3b4fe Mon Sep 17 00:00:00 2001 From: przemyslaw wierzbicki Date: Wed, 10 Jul 2019 00:37:00 -0400 Subject: [PATCH] #109 detect closed client connection --- .../nio/channels/AsynchronousChannel.scala | 37 +++-- src/test/scala/zio/nio/ChannelSuite.scala | 127 ++++++++++++------ 2 files changed, 111 insertions(+), 53 deletions(-) diff --git a/src/main/scala/zio/nio/channels/AsynchronousChannel.scala b/src/main/scala/zio/nio/channels/AsynchronousChannel.scala index 566b785..82425b8 100644 --- a/src/main/scala/zio/nio/channels/AsynchronousChannel.scala +++ b/src/main/scala/zio/nio/channels/AsynchronousChannel.scala @@ -1,5 +1,6 @@ package zio.nio.channels +import java.io.IOException import java.lang.{ Integer => JInteger, Long => JLong, Void => JVoid } import java.nio.{ ByteBuffer => JByteBuffer } import java.nio.channels.{ @@ -27,9 +28,13 @@ class AsynchronousByteChannel(private val channel: JAsynchronousByteChannel) { final def read(capacity: Int): IO[Exception, Chunk[Byte]] = for { b <- Buffer.byte(capacity) - _ <- readBuffer(b) + l <- readBuffer(b) a <- b.array - r = Chunk.fromArray(a) + r <- if (l == -1) { + ZIO.fail(new IOException("Connection reset by peer")) + } else { + ZIO.succeed(Chunk.fromArray(a).take(l)) + } } yield r /** @@ -43,9 +48,13 @@ class AsynchronousByteChannel(private val channel: JAsynchronousByteChannel) { final def read[A](capacity: Int, attachment: A): IO[Exception, Chunk[Byte]] = for { b <- Buffer.byte(capacity) - _ <- readBuffer(b, attachment) + l <- readBuffer(b, attachment) a <- b.array - r = Chunk.fromArray(a) + r <- if (l == -1) { + ZIO.fail(new IOException("Connection reset by peer")) + } else { + ZIO.succeed(Chunk.fromArray(a).take(l)) + } } yield r /** @@ -216,9 +225,13 @@ class AsynchronousSocketChannel(private val channel: JAsynchronousSocketChannel) final def read[A](capacity: Int, timeout: Duration, attachment: A): IO[Exception, Chunk[Byte]] = for { b <- Buffer.byte(capacity) - _ <- readBuffer(b, timeout, attachment) + l <- readBuffer(b, timeout, attachment) a <- b.array - r = Chunk.fromArray(a) + r <- if (l == -1) { + ZIO.fail(new IOException("Connection reset by peer")) + } else { + ZIO.succeed(Chunk.fromArray(a).take(l)) + } } yield r final private[nio] def readBuffer[A]( @@ -249,11 +262,15 @@ class AsynchronousSocketChannel(private val channel: JAsynchronousSocketChannel) attachment: A ): IO[Exception, List[Chunk[Byte]]] = for { - bs <- IO.collectAll(capacities.map(Buffer.byte(_))) - _ <- readBuffer(bs, offset, length, timeout, attachment) + bs <- IO.collectAll(capacities.map(Buffer.byte)) + l <- readBuffer(bs, offset, length, timeout, attachment) as <- IO.collectAll(bs.map(_.array)) - ds = as.map(Chunk.fromArray(_)) - } yield ds + r <- if (l == -1) { + ZIO.fail(new IOException("Connection reset by peer")) + } else { + ZIO.succeed(as.map(Chunk.fromArray)) + } + } yield r } diff --git a/src/test/scala/zio/nio/ChannelSuite.scala b/src/test/scala/zio/nio/ChannelSuite.scala index a86c300..f45e2dc 100644 --- a/src/test/scala/zio/nio/ChannelSuite.scala +++ b/src/test/scala/zio/nio/ChannelSuite.scala @@ -1,54 +1,95 @@ package zio.nio import zio.nio.channels.{ AsynchronousServerSocketChannel, AsynchronousSocketChannel } -import zio.{ DefaultRuntime, IO } +import zio.{ Chunk, DefaultRuntime, IO, ZIO } import testz.{ Harness, assert } object ChannelSuite extends DefaultRuntime { def tests[T](harness: Harness[T]): T = { import harness._ - section(test("read/write") { () => - val inetAddress = InetAddress.localHost - .flatMap(iAddr => SocketAddress.inetSocketAddress(iAddr, 13370)) - - def echoServer: IO[Exception, Unit] = - for { - address <- inetAddress - sink <- Buffer.byte(3) - server <- AsynchronousServerSocketChannel() - _ <- server.bind(address) - worker <- server.accept - _ <- worker.readBuffer(sink) - _ <- sink.flip - _ <- worker.writeBuffer(sink) - _ <- worker.close - _ <- server.close - } yield () - - def echoClient: IO[Exception, Boolean] = - for { - address <- inetAddress - src <- Buffer.byte(3) - client <- AsynchronousSocketChannel() - _ <- client.connect(address) - sent <- src.array - _ = sent.update(0, 1) - _ <- client.writeBuffer(src) - _ <- src.flip - _ <- client.readBuffer(src) - received <- src.array - _ <- client.close - } yield sent.sameElements(received) - - val testProgram: IO[Exception, Boolean] = for { - serverFiber <- echoServer.fork - clientFiber <- echoClient.fork - _ <- serverFiber.join - same <- clientFiber.join - } yield same - - assert(unsafeRun(testProgram)) - }) + section( + test("read/write") { () => + val inetAddress = InetAddress.localHost + .flatMap(iAddr => SocketAddress.inetSocketAddress(iAddr, 13370)) + + def echoServer: IO[Exception, Unit] = + for { + address <- inetAddress + sink <- Buffer.byte(3) + server <- AsynchronousServerSocketChannel() + _ <- server.bind(address) + worker <- server.accept + _ <- worker.readBuffer(sink) + _ <- sink.flip + _ <- worker.writeBuffer(sink) + _ <- worker.close + _ <- server.close + } yield () + + def echoClient: IO[Exception, Boolean] = + for { + address <- inetAddress + src <- Buffer.byte(3) + client <- AsynchronousSocketChannel() + _ <- client.connect(address) + sent <- src.array + _ = sent.update(0, 1) + _ <- client.writeBuffer(src) + _ <- src.flip + _ <- client.readBuffer(src) + received <- src.array + _ <- client.close + } yield sent.sameElements(received) + + val testProgram: IO[Exception, Boolean] = for { + serverFiber <- echoServer.fork + clientFiber <- echoClient.fork + _ <- serverFiber.join + same <- clientFiber.join + } yield same + + assert(unsafeRun(testProgram)) + }, + test("read should fail when connection close") { () => + val inetAddress = InetAddress.localHost + .flatMap(iAddr => SocketAddress.inetSocketAddress(iAddr, 13370)) + + def server: IO[Exception, Boolean] = { + for { + address <- inetAddress + server <- AsynchronousServerSocketChannel() + _ <- server.bind(address) + worker <- server.accept + _ <- worker.read(3) + _ <- worker.read(3) + _ <- worker.close + _ <- server.close + } yield false + }.catchSome { + case ex: java.io.IOException if ex.getMessage == "Connection reset by peer" => + ZIO.succeed(true) + } + + def client: IO[Exception, Unit] = + for { + address <- inetAddress + client <- AsynchronousSocketChannel() + _ <- client.connect(address) + _ = client.write(Chunk.fromArray(Array[Byte](1, 1, 1))) + _ <- client.close + } yield () + + val testProgram: IO[Exception, Boolean] = for { + serverFiber <- server.fork + clientFiber <- client.fork + same <- serverFiber.join + _ <- clientFiber.join + } yield same + + assert(unsafeRun(testProgram)) + } + ) + } }