Skip to content

Commit 2f54ec9

Browse files
committed
Optimizing code
1 parent 9d9789f commit 2f54ec9

File tree

10 files changed

+406
-151
lines changed

10 files changed

+406
-151
lines changed

network/kcp/client_conn.go

Lines changed: 99 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -6,42 +6,73 @@ import (
66
"sync/atomic"
77
"time"
88

9+
"github.com/dobyte/due/v2/core/buffer"
910
"github.com/dobyte/due/v2/errors"
1011
"github.com/dobyte/due/v2/log"
1112
"github.com/dobyte/due/v2/network"
1213
"github.com/dobyte/due/v2/packet"
1314
"github.com/dobyte/due/v2/utils/xcall"
1415
"github.com/dobyte/due/v2/utils/xnet"
1516
"github.com/dobyte/due/v2/utils/xtime"
17+
"github.com/xtaci/kcp-go/v5"
1618
)
1719

1820
type clientConn struct {
1921
rw sync.RWMutex
20-
id int64 // 连接ID
21-
uid int64 // 用户ID
22-
attr *attr // 连接属性
23-
conn net.Conn // TCP源连接
24-
state int32 // 连接状态
25-
client *client // 客户端
26-
chWrite chan chWrite // 写入队列
27-
done chan struct{} // 写入完成信号
28-
close chan struct{} // 关闭信号
29-
lastHeartbeatTime int64 // 上次心跳时间
22+
id int64 // 连接ID
23+
uid int64 // 用户ID
24+
attr *attr // 连接属性
25+
conn *kcp.UDPSession // UDP源连接
26+
state atomic.Int32 // 连接状态
27+
client *client // 客户端
28+
chWrite chan chWrite // 写入队列
29+
done chan struct{} // 写入完成信号
30+
close chan struct{} // 关闭信号
31+
lastHeartbeatTime atomic.Int64 // 上次心跳时间
3032
}
3133

3234
var _ network.Conn = &clientConn{}
3335

34-
func newClientConn(client *client, id int64, conn net.Conn) network.Conn {
36+
func newClientConn(client *client, id int64, conn *kcp.UDPSession) network.Conn {
3537
c := &clientConn{
36-
id: id,
37-
attr: &attr{},
38-
conn: conn,
39-
state: int32(network.ConnOpened),
40-
client: client,
41-
chWrite: make(chan chWrite, 4096),
42-
done: make(chan struct{}),
43-
close: make(chan struct{}),
44-
lastHeartbeatTime: xtime.Now().UnixNano(),
38+
id: id,
39+
attr: &attr{},
40+
conn: conn,
41+
client: client,
42+
chWrite: make(chan chWrite, 4096),
43+
done: make(chan struct{}),
44+
close: make(chan struct{}),
45+
}
46+
47+
c.state.Store(int32(network.ConnOpened))
48+
c.lastHeartbeatTime.Store(xtime.Now().UnixNano())
49+
50+
if c.client.opts.mtu > 0 {
51+
conn.SetMtu(c.client.opts.mtu)
52+
}
53+
54+
if len(c.client.opts.noDelay) == 4 {
55+
conn.SetNoDelay(c.client.opts.noDelay[0], c.client.opts.noDelay[1], c.client.opts.noDelay[2], c.client.opts.noDelay[3])
56+
}
57+
58+
if c.client.opts.ackNoDelay {
59+
conn.SetACKNoDelay(c.client.opts.ackNoDelay)
60+
}
61+
62+
if c.client.opts.writeDelay {
63+
conn.SetWriteDelay(c.client.opts.writeDelay)
64+
}
65+
66+
if len(c.client.opts.windowSize) == 2 {
67+
conn.SetWindowSize(c.client.opts.windowSize[0], c.client.opts.windowSize[1])
68+
}
69+
70+
if c.client.opts.readBuffer > 0 {
71+
conn.SetReadBuffer(c.client.opts.readBuffer)
72+
}
73+
74+
if c.client.opts.writeBuffer > 0 {
75+
conn.SetWriteBuffer(c.client.opts.writeBuffer)
4576
}
4677

4778
xcall.Go(c.read)
@@ -99,21 +130,26 @@ func (c *clientConn) Send(msg []byte) error {
99130
}
100131

101132
// Push 发送消息(异步)
102-
func (c *clientConn) Push(msg []byte) (err error) {
103-
if err = c.checkState(); err != nil {
104-
return
133+
func (c *clientConn) Push(msg []byte) error {
134+
if err := c.checkState(); err != nil {
135+
return err
105136
}
106137

107138
c.rw.RLock()
139+
defer c.rw.RUnlock()
140+
141+
if c.conn == nil {
142+
return errors.ErrConnectionClosed
143+
}
144+
108145
c.chWrite <- chWrite{typ: dataPacket, msg: msg}
109-
c.rw.RUnlock()
110146

111-
return
147+
return nil
112148
}
113149

114150
// State 获取连接状态
115151
func (c *clientConn) State() network.ConnState {
116-
return network.ConnState(atomic.LoadInt32(&c.state))
152+
return network.ConnState(c.state.Load())
117153
}
118154

119155
// Close 关闭连接
@@ -181,7 +217,7 @@ func (c *clientConn) RemoteAddr() (net.Addr, error) {
181217

182218
// 检测连接状态
183219
func (c *clientConn) checkState() error {
184-
switch network.ConnState(atomic.LoadInt32(&c.state)) {
220+
switch c.State() {
185221
case network.ConnHanged:
186222
return errors.ErrConnectionHanged
187223
case network.ConnClosed:
@@ -193,46 +229,47 @@ func (c *clientConn) checkState() error {
193229

194230
// 优雅关闭
195231
func (c *clientConn) graceClose() error {
196-
if !atomic.CompareAndSwapInt32(&c.state, int32(network.ConnOpened), int32(network.ConnHanged)) {
232+
if !c.state.CompareAndSwap(int32(network.ConnOpened), int32(network.ConnHanged)) {
197233
return errors.ErrConnectionNotOpened
198234
}
199235

200236
c.rw.RLock()
237+
if c.conn == nil {
238+
c.rw.RUnlock()
239+
return errors.ErrConnectionClosed
240+
}
201241
c.chWrite <- chWrite{typ: closeSig}
202242
c.rw.RUnlock()
203243

204244
<-c.done
205245

206-
if !atomic.CompareAndSwapInt32(&c.state, int32(network.ConnHanged), int32(network.ConnClosed)) {
246+
if !c.state.CompareAndSwap(int32(network.ConnHanged), int32(network.ConnClosed)) {
207247
return errors.ErrConnectionNotHanged
208248
}
209249

210-
c.rw.Lock()
211-
close(c.chWrite)
212-
close(c.close)
213-
close(c.done)
214-
conn := c.conn
215-
c.conn = nil
216-
c.rw.Unlock()
217-
218-
err := conn.Close()
219-
220-
if c.client.disconnectHandler != nil {
221-
c.client.disconnectHandler(c)
222-
}
223-
224-
return err
250+
return c.doClose()
225251
}
226252

227253
// 强制关闭
228254
func (c *clientConn) forceClose() error {
229-
if !atomic.CompareAndSwapInt32(&c.state, int32(network.ConnOpened), int32(network.ConnClosed)) {
230-
if !atomic.CompareAndSwapInt32(&c.state, int32(network.ConnHanged), int32(network.ConnClosed)) {
255+
if !c.state.CompareAndSwap(int32(network.ConnOpened), int32(network.ConnClosed)) {
256+
if !c.state.CompareAndSwap(int32(network.ConnHanged), int32(network.ConnClosed)) {
231257
return errors.ErrConnectionClosed
232258
}
233259
}
234260

261+
return c.doClose()
262+
}
263+
264+
// 执行关闭操作
265+
func (c *clientConn) doClose() error {
235266
c.rw.Lock()
267+
268+
if c.conn == nil {
269+
c.rw.Unlock()
270+
return errors.ErrConnectionClosed
271+
}
272+
236273
close(c.chWrite)
237274
close(c.close)
238275
close(c.done)
@@ -265,7 +302,7 @@ func (c *clientConn) read() {
265302
}
266303

267304
if c.client.opts.heartbeatInterval > 0 {
268-
atomic.StoreInt64(&c.lastHeartbeatTime, xtime.Now().UnixNano())
305+
c.lastHeartbeatTime.Store(xtime.Now().UnixNano())
269306
}
270307

271308
switch c.State() {
@@ -277,6 +314,11 @@ func (c *clientConn) read() {
277314
// ignore
278315
}
279316

317+
// ignore empty packet
318+
if len(msg) == 0 {
319+
continue
320+
}
321+
280322
isHeartbeat, err := packet.CheckHeartbeat(msg)
281323
if err != nil {
282324
log.Errorf("check heartbeat message error: %v", err)
@@ -288,13 +330,8 @@ func (c *clientConn) read() {
288330
continue
289331
}
290332

291-
// ignore empty packet
292-
if len(msg) == 0 {
293-
continue
294-
}
295-
296333
if c.client.receiveHandler != nil {
297-
c.client.receiveHandler(c, msg)
334+
c.client.receiveHandler(c, buffer.NewBytes(msg))
298335
}
299336
}
300337
}
@@ -335,9 +372,14 @@ func (c *clientConn) write() {
335372
if _, err := conn.Write(r.msg); err != nil {
336373
log.Errorf("write data message error: %v", err)
337374
}
338-
case <-ticker.C:
339-
deadline := xtime.Now().Add(-2 * c.client.opts.heartbeatInterval).UnixNano()
340-
if atomic.LoadInt64(&c.lastHeartbeatTime) < deadline {
375+
case t, ok := <-ticker.C:
376+
if !ok {
377+
return
378+
}
379+
380+
deadline := t.Add(-2 * c.client.opts.heartbeatInterval).UnixNano()
381+
382+
if c.lastHeartbeatTime.Load() < deadline {
341383
log.Debugf("connection heartbeat timeout")
342384
_ = c.forceClose()
343385
return
@@ -361,5 +403,5 @@ func (c *clientConn) write() {
361403

362404
// 是否已关闭
363405
func (c *clientConn) isClosed() bool {
364-
return network.ConnState(atomic.LoadInt32(&c.state)) == network.ConnClosed
406+
return network.ConnState(c.state.Load()) == network.ConnClosed
365407
}

network/kcp/client_options.go

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
package kcp
22

33
import (
4-
"github.com/dobyte/due/v2/etc"
54
"time"
5+
6+
"github.com/dobyte/due/v2/etc"
67
)
78

89
const (
@@ -12,9 +13,16 @@ const (
1213
)
1314

1415
const (
15-
defaultClientDialAddrKey = "etc.network.tcp.client.addr"
16-
defaultClientDialTimeoutKey = "etc.network.tcp.client.timeout"
17-
defaultClientHeartbeatIntervalKey = "etc.network.tcp.client.heartbeatInterval"
16+
defaultClientDialAddrKey = "etc.network.kcp.client.addr"
17+
defaultClientDialTimeoutKey = "etc.network.kcp.client.timeout"
18+
defaultClientHeartbeatIntervalKey = "etc.network.kcp.client.heartbeatInterval"
19+
defaultClientMtuKey = "etc.network.kcp.client.mtu"
20+
defaultClientNoDelayKey = "etc.network.kcp.client.noDelay"
21+
defaultClientAckNoDelayKey = "etc.network.kcp.client.ackNoDelay"
22+
defaultClientWriteDelayKey = "etc.network.kcp.client.writeDelay"
23+
defaultClientWindowSizeKey = "etc.network.kcp.client.windowSize"
24+
defaultClientReadBufferKey = "etc.network.kcp.client.readBuffer"
25+
defaultClientWriteBufferKey = "etc.network.kcp.client.writeBuffer"
1826
)
1927

2028
type ClientOption func(o *clientOptions)
@@ -23,13 +31,27 @@ type clientOptions struct {
2331
addr string // 地址
2432
timeout time.Duration // 拨号超时时间,默认5s
2533
heartbeatInterval time.Duration // 心跳间隔时间,默认10s
34+
mtu int // 最大传输单元,默认不设置
35+
noDelay []int // 是否开启无延迟模式,默认不设置
36+
ackNoDelay bool // 是否开启ACK延迟确认,默认不设置
37+
writeDelay bool // 是否开启写延迟,默认不设置
38+
windowSize []int // 窗口大小,默认不设置
39+
readBuffer int // 读取缓冲区大小,默认不设置
40+
writeBuffer int // 写入缓冲区大小,默认不设置
2641
}
2742

2843
func defaultClientOptions() *clientOptions {
2944
return &clientOptions{
3045
addr: etc.Get(defaultClientDialAddrKey, defaultClientDialAddr).String(),
3146
timeout: etc.Get(defaultClientDialTimeoutKey, defaultClientDialTimeout).Duration(),
3247
heartbeatInterval: etc.Get(defaultClientHeartbeatIntervalKey, defaultClientHeartbeatInterval).Duration(),
48+
mtu: etc.Get(defaultClientMtuKey).Int(),
49+
noDelay: etc.Get(defaultClientNoDelayKey).Ints(),
50+
ackNoDelay: etc.Get(defaultClientAckNoDelayKey).Bool(),
51+
writeDelay: etc.Get(defaultClientWriteDelayKey).Bool(),
52+
windowSize: etc.Get(defaultClientWindowSizeKey).Ints(),
53+
readBuffer: int(etc.Get(defaultClientReadBufferKey).B()),
54+
writeBuffer: int(etc.Get(defaultClientWriteBufferKey).B()),
3355
}
3456
}
3557

@@ -47,3 +69,38 @@ func WithClientDialTimeout(timeout time.Duration) ClientOption {
4769
func WithClientHeartbeatInterval(heartbeatInterval time.Duration) ClientOption {
4870
return func(o *clientOptions) { o.heartbeatInterval = heartbeatInterval }
4971
}
72+
73+
// WithClientMtu 设置最大传输单元
74+
func WithClientMtu(mtu int) ClientOption {
75+
return func(o *clientOptions) { o.mtu = mtu }
76+
}
77+
78+
// WithClientNoDelay 设置是否开启无延迟模式
79+
func WithClientNoDelay(noDelay int) ClientOption {
80+
return func(o *clientOptions) { o.noDelay = append(o.noDelay, noDelay) }
81+
}
82+
83+
// WithClientAckNoDelay 设置是否开启ACK延迟确认
84+
func WithClientAckNoDelay(ackNoDelay bool) ClientOption {
85+
return func(o *clientOptions) { o.ackNoDelay = ackNoDelay }
86+
}
87+
88+
// WithClientWriteDelay 设置是否开启写延迟
89+
func WithClientWriteDelay(writeDelay bool) ClientOption {
90+
return func(o *clientOptions) { o.writeDelay = writeDelay }
91+
}
92+
93+
// WithClientWindowSize 设置窗口大小
94+
func WithClientWindowSize(windowSize int) ClientOption {
95+
return func(o *clientOptions) { o.windowSize = append(o.windowSize, windowSize) }
96+
}
97+
98+
// WithClientReadBuffer 设置读取缓冲区大小
99+
func WithClientReadBuffer(readBuffer int) ClientOption {
100+
return func(o *clientOptions) { o.readBuffer = readBuffer }
101+
}
102+
103+
// WithClientWriteBuffer 设置写入缓冲区大小
104+
func WithClientWriteBuffer(writeBuffer int) ClientOption {
105+
return func(o *clientOptions) { o.writeBuffer = writeBuffer }
106+
}

network/kcp/client_test.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"time"
1010

1111
"github.com/dobyte/due/network/kcp/v2"
12+
"github.com/dobyte/due/v2/core/buffer"
1213
"github.com/dobyte/due/v2/log"
1314
"github.com/dobyte/due/v2/network"
1415
"github.com/dobyte/due/v2/packet"
@@ -26,8 +27,10 @@ func TestClient_Simple(t *testing.T) {
2627
log.Info("connection is closed")
2728
})
2829

29-
client.OnReceive(func(conn network.Conn, msg []byte) {
30-
message, err := packet.UnpackMessage(msg)
30+
client.OnReceive(func(conn network.Conn, buf buffer.Buffer) {
31+
defer buf.Release()
32+
33+
message, err := packet.UnpackMessage(buf.Bytes())
3134
if err != nil {
3235
log.Errorf("unpack message failed: %v", err)
3336
return
@@ -145,7 +148,9 @@ func doPressureTest(c int, n int, size int) {
145148

146149
client := kcp.NewClient(kcp.WithClientHeartbeatInterval(0))
147150

148-
client.OnReceive(func(conn network.Conn, msg []byte) {
151+
client.OnReceive(func(conn network.Conn, buf buffer.Buffer) {
152+
defer buf.Release()
153+
149154
atomic.AddInt64(&totalRecv, 1)
150155

151156
wg.Done()

0 commit comments

Comments
 (0)