Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 30 additions & 12 deletions rpcclient/infrastructure.go
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,7 @@ out:
// handleSendPostMessage handles performing the passed HTTP request, reading the
// result, unmarshalling it, and delivering the unmarshalled result to the
// provided response channel.
func (c *Client) handleSendPostMessage(jReq *jsonRequest) {
func (c *Client) handleSendPostMessage(ctx context.Context, jReq *jsonRequest) {
var (
lastErr error
backoff time.Duration
Expand All @@ -782,12 +782,17 @@ func (c *Client) handleSendPostMessage(jReq *jsonRequest) {
}

tries := 10
retryloop:
for i := 0; i < tries; i++ {
var httpReq *http.Request

bodyReader := bytes.NewReader(jReq.marshalledJSON)
httpReq, err = http.NewRequest("POST", httpURL, bodyReader)
httpReq, err = http.NewRequestWithContext(ctx, "POST", httpURL, bodyReader)
if err != nil {
// We must observe the contract that shutdown returns ErrClientShutdown.
if errors.Is(err, context.Canceled) && errors.Is(context.Cause(ctx), ErrClientShutdown) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check is unneeded since http.NewRequestWithContext will never return a error that's of type context.Canceled.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed in f87d914

err = ErrClientShutdown
}
jReq.responseChan <- &Response{result: nil, err: err}
return
}
Expand All @@ -812,6 +817,11 @@ func (c *Client) handleSendPostMessage(jReq *jsonRequest) {
break
}

// We must observe the contract that shutdown returns ErrClientShutdown.
if errors.Is(err, context.Canceled) && errors.Is(context.Cause(ctx), ErrClientShutdown) {
err = ErrClientShutdown
}

// Save the last error for the case where we backoff further,
// retry and get an invalid response but no error. If this
// happens the saved last error will be used to enrich the error
Expand All @@ -830,8 +840,13 @@ func (c *Client) handleSendPostMessage(jReq *jsonRequest) {
select {
case <-time.After(backoff):

case <-c.shutdown:
return
case <-ctx.Done():
err = ctx.Err()
// maintain our contract: shutdown errors are ErrClientShutdown
if errors.Is(context.Cause(ctx), ErrClientShutdown) {
err = ErrClientShutdown
}
break retryloop
}
}
if err != nil {
Expand Down Expand Up @@ -891,30 +906,28 @@ func (c *Client) handleSendPostMessage(jReq *jsonRequest) {
// in HTTP POST mode. It uses a buffered channel to serialize output messages
// while allowing the sender to continue running asynchronously. It must be run
// as a goroutine.
func (c *Client) sendPostHandler() {
func (c *Client) sendPostHandler(ctx context.Context) {
out:
for {
// Send any messages ready for send until the shutdown channel
// is closed.
select {
case jReq := <-c.sendPostChan:
c.handleSendPostMessage(jReq)
c.handleSendPostMessage(ctx, jReq)

case <-c.shutdown:
case <-ctx.Done():
break out
}
}

err := context.Cause(ctx)
// Drain any wait channels before exiting so nothing is left waiting
// around to send.
cleanup:
for {
select {
case jReq := <-c.sendPostChan:
jReq.responseChan <- &Response{
result: nil,
err: ErrClientShutdown,
}
jReq.responseChan <- &Response{result: nil, err: err}

default:
break cleanup
Expand Down Expand Up @@ -1178,8 +1191,13 @@ func (c *Client) start() {
// Start the I/O processing handlers depending on whether the client is
// in HTTP POST mode or the default websocket mode.
if c.config.HTTPPostMode {
ctx, cancel := context.WithCancelCause(context.Background())
c.wg.Add(1)
go c.sendPostHandler()
go c.sendPostHandler(ctx)
go func() {
<-c.shutdown
cancel(ErrClientShutdown)
}()
} else {
c.wg.Add(3)
go func() {
Expand Down
83 changes: 83 additions & 0 deletions rpcclient/infrastructure_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
package rpcclient

import (
"io"
"net"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -108,3 +112,82 @@ func TestParseAddressString(t *testing.T) {
})
}
}

// TestHTTPPostShutdownInterruptsPendingRequest ensures that a client operating
// in HTTP POST mode can interrupt an in-flight request during shutdown.
func TestHTTPPostShutdownInterruptsPendingRequest(t *testing.T) {
t.Parallel()

listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)

requestAccepted := make(chan struct{})
serverDone := make(chan struct{})

go func() {
defer close(serverDone)

conn, err := listener.Accept()
if err != nil {
return
}
defer func() {
err := conn.Close()
assert.NoError(t, err)
}()

close(requestAccepted)

_, _ = io.Copy(io.Discard, conn)
}()

t.Cleanup(func() {
err := listener.Close()
require.NoError(t, err)
<-serverDone
})

connCfg := &ConnConfig{
Host: listener.Addr().String(),
User: "user",
Pass: "pass",
DisableTLS: true,
HTTPPostMode: true,
}

client, err := New(connCfg, nil)
require.NoError(t, err)
t.Cleanup(client.Shutdown)

future := client.GetBlockCountAsync()

select {
case <-requestAccepted:
case <-time.After(2 * time.Second):
t.Fatalf("server did not accept client connection")
}

select {
case <-future:
t.Fatalf("expected request to remain pending until shutdown")
case <-time.After(100 * time.Millisecond):
}

client.Shutdown()

waitDone := make(chan struct{})
go func() {
client.WaitForShutdown()
close(waitDone)
}()

select {
case <-waitDone:
case <-time.After(5 * time.Second):
t.Fatalf("client shutdown did not complete")
}

result, err := future.Receive()
require.Zero(t, result)
require.ErrorContains(t, err, ErrClientShutdown.Error())
}
Loading