Skip to content
This repository was archived by the owner on Nov 28, 2019. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 27 additions & 10 deletions src/main/scala/zio/nio/channels/AsynchronousChannel.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package zio.nio.channels

import java.io.IOException
import java.lang.{ Integer => JInteger, Long => JLong, Void => JVoid }
import java.nio.{ ByteBuffer => JByteBuffer }
import java.nio.channels.{
Expand Down Expand Up @@ -27,9 +28,13 @@ class AsynchronousByteChannel(private val channel: JAsynchronousByteChannel) {
final def read(capacity: Int): IO[Exception, Chunk[Byte]] =
for {
b <- Buffer.byte(capacity)
_ <- readBuffer(b)
l <- readBuffer(b)
a <- b.array
r = Chunk.fromArray(a)
r <- if (l == -1) {
ZIO.fail(new IOException("Connection reset by peer"))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

} else {
ZIO.succeed(Chunk.fromArray(a).take(l))
}
} yield r

/**
Expand All @@ -43,9 +48,13 @@ class AsynchronousByteChannel(private val channel: JAsynchronousByteChannel) {
final def read[A](capacity: Int, attachment: A): IO[Exception, Chunk[Byte]] =
for {
b <- Buffer.byte(capacity)
_ <- readBuffer(b, attachment)
l <- readBuffer(b, attachment)
a <- b.array
r = Chunk.fromArray(a)
r <- if (l == -1) {
ZIO.fail(new IOException("Connection reset by peer"))
} else {
ZIO.succeed(Chunk.fromArray(a).take(l))
}
} yield r

/**
Expand Down Expand Up @@ -216,9 +225,13 @@ class AsynchronousSocketChannel(private val channel: JAsynchronousSocketChannel)
final def read[A](capacity: Int, timeout: Duration, attachment: A): IO[Exception, Chunk[Byte]] =
for {
b <- Buffer.byte(capacity)
_ <- readBuffer(b, timeout, attachment)
l <- readBuffer(b, timeout, attachment)
a <- b.array
r = Chunk.fromArray(a)
r <- if (l == -1) {
ZIO.fail(new IOException("Connection reset by peer"))
} else {
ZIO.succeed(Chunk.fromArray(a).take(l))
}
} yield r

final private[nio] def readBuffer[A](
Expand Down Expand Up @@ -249,11 +262,15 @@ class AsynchronousSocketChannel(private val channel: JAsynchronousSocketChannel)
attachment: A
): IO[Exception, List[Chunk[Byte]]] =
for {
bs <- IO.collectAll(capacities.map(Buffer.byte(_)))
_ <- readBuffer(bs, offset, length, timeout, attachment)
bs <- IO.collectAll(capacities.map(Buffer.byte))
l <- readBuffer(bs, offset, length, timeout, attachment)
as <- IO.collectAll(bs.map(_.array))
ds = as.map(Chunk.fromArray(_))
} yield ds
r <- if (l == -1) {
ZIO.fail(new IOException("Connection reset by peer"))
} else {
ZIO.succeed(as.map(Chunk.fromArray))
}
} yield r

}

Expand Down
127 changes: 84 additions & 43 deletions src/test/scala/zio/nio/ChannelSuite.scala
Original file line number Diff line number Diff line change
@@ -1,54 +1,95 @@
package zio.nio

import zio.nio.channels.{ AsynchronousServerSocketChannel, AsynchronousSocketChannel }
import zio.{ DefaultRuntime, IO }
import zio.{ Chunk, DefaultRuntime, IO, ZIO }
import testz.{ Harness, assert }

object ChannelSuite extends DefaultRuntime {

def tests[T](harness: Harness[T]): T = {
import harness._
section(test("read/write") { () =>
val inetAddress = InetAddress.localHost
.flatMap(iAddr => SocketAddress.inetSocketAddress(iAddr, 13370))

def echoServer: IO[Exception, Unit] =
for {
address <- inetAddress
sink <- Buffer.byte(3)
server <- AsynchronousServerSocketChannel()
_ <- server.bind(address)
worker <- server.accept
_ <- worker.readBuffer(sink)
_ <- sink.flip
_ <- worker.writeBuffer(sink)
_ <- worker.close
_ <- server.close
} yield ()

def echoClient: IO[Exception, Boolean] =
for {
address <- inetAddress
src <- Buffer.byte(3)
client <- AsynchronousSocketChannel()
_ <- client.connect(address)
sent <- src.array
_ = sent.update(0, 1)
_ <- client.writeBuffer(src)
_ <- src.flip
_ <- client.readBuffer(src)
received <- src.array
_ <- client.close
} yield sent.sameElements(received)

val testProgram: IO[Exception, Boolean] = for {
serverFiber <- echoServer.fork
clientFiber <- echoClient.fork
_ <- serverFiber.join
same <- clientFiber.join
} yield same

assert(unsafeRun(testProgram))
})
section(
test("read/write") { () =>
val inetAddress = InetAddress.localHost
.flatMap(iAddr => SocketAddress.inetSocketAddress(iAddr, 13370))

def echoServer: IO[Exception, Unit] =
for {
address <- inetAddress
sink <- Buffer.byte(3)
server <- AsynchronousServerSocketChannel()
_ <- server.bind(address)
worker <- server.accept
_ <- worker.readBuffer(sink)
_ <- sink.flip
_ <- worker.writeBuffer(sink)
_ <- worker.close
_ <- server.close
} yield ()

def echoClient: IO[Exception, Boolean] =
for {
address <- inetAddress
src <- Buffer.byte(3)
client <- AsynchronousSocketChannel()
_ <- client.connect(address)
sent <- src.array
_ = sent.update(0, 1)
_ <- client.writeBuffer(src)
_ <- src.flip
_ <- client.readBuffer(src)
received <- src.array
_ <- client.close
} yield sent.sameElements(received)

val testProgram: IO[Exception, Boolean] = for {
serverFiber <- echoServer.fork
clientFiber <- echoClient.fork
_ <- serverFiber.join
same <- clientFiber.join
} yield same

assert(unsafeRun(testProgram))
},
test("read should fail when connection close") { () =>
val inetAddress = InetAddress.localHost
.flatMap(iAddr => SocketAddress.inetSocketAddress(iAddr, 13370))

def server: IO[Exception, Boolean] = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I gave this branch a run and noticed it doesn't close down the listening socket. This means if you execute test:run a second time in the same sbt session, you get an address already bound error. I think you need to use bracketing to make sure the cleanup happens. Very roughly:

        def server: IO[Exception, Boolean] = {
          for {
            address <- inetAddress
            _  <- AsynchronousServerSocketChannel().bracket(_.close.ignore) { server =>
              server.bind(address) *> server.accept.bracket(_.close.ignore) { worker =>
                worker.read(3) *> worker.read(3)
              }
            }
          } yield false
        }.catchSome {
          case ex: java.io.IOException if ex.getMessage == "Connection reset by peer" =>
            ZIO.succeed(true)
        }

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great catch I will update this.

for {
address <- inetAddress
server <- AsynchronousServerSocketChannel()
_ <- server.bind(address)
worker <- server.accept
_ <- worker.read(3)
_ <- worker.read(3)
_ <- worker.close
_ <- server.close
} yield false
}.catchSome {
case ex: java.io.IOException if ex.getMessage == "Connection reset by peer" =>
ZIO.succeed(true)
}

def client: IO[Exception, Unit] =
for {
address <- inetAddress
client <- AsynchronousSocketChannel()
_ <- client.connect(address)
_ = client.write(Chunk.fromArray(Array[Byte](1, 1, 1)))
_ <- client.close
} yield ()

val testProgram: IO[Exception, Boolean] = for {
serverFiber <- server.fork
clientFiber <- client.fork
same <- serverFiber.join
_ <- clientFiber.join
} yield same

assert(unsafeRun(testProgram))
}
)

}
}