// tcpConn.go — TCP (and KCP) connection wrapper for the hpds net framework.
package hpds_net_framework
import (
"encoding/binary"
"fmt"
"git.hpds.cc/Component/logging"
"github.com/google/uuid"
"go.uber.org/zap"
"io"
"net"
"runtime/debug"
"sync"
"sync/atomic"
"time"
)
// TCPConn wraps a net.Conn with a buffered write queue and a logic
// (received-package) queue. The embedded RWMutex guards userData, node
// and closeCb; closeFlag is accessed atomically.
type TCPConn struct {
	sync.RWMutex
	uuid string
	net.Conn
	// buffered outgoing-write queue; a nil entry stops the writer goroutine
	writeQueue chan []byte
	// queue of received packages awaiting processing; a nil entry stops the logic goroutine
	logicQueue chan []byte
	// message processor: wraps outgoing messages, handles received packages
	processor Processor
	userData  interface{}
	node      INode
	// callback invoked after the connection is closed
	closeCb func()
	// non-zero once Close has run; read atomically via IsClosed
	closeFlag int64
	logger    *logging.Logger
}
// KCPConn is a reliable-UDP (KCP) connection. It embeds *TCPConn and
// reuses its entire read/write/queue pipeline unchanged.
type KCPConn struct {
	*TCPConn
}
// NewKcpConn builds a KCP connection on top of the shared TCPConn
// pipeline. It returns nil when the underlying TCPConn cannot be created
// (nil conn or nil processor).
func NewKcpConn(conn net.Conn, processor Processor, logger *logging.Logger) *KCPConn {
	inner := NewTcpConn(conn, processor, logger)
	if inner == nil {
		return nil
	}
	return &KCPConn{inner}
}
// NewTcpConn returns a new *TCPConn wrapping conn, or nil when conn or
// processor is nil. It starts two goroutines: a writer draining writeQueue
// onto the socket, and a logic worker feeding logicQueue packages to the
// processor. Both goroutines exit when a nil entry is queued (ReadMsg does
// this on its way out).
func NewTcpConn(conn net.Conn, processor Processor, logger *logging.Logger) *TCPConn {
	if processor == nil || conn == nil {
		return nil
	}
	tc := &TCPConn{
		uuid:       uuid.New().String(),
		Conn:       conn,
		writeQueue: make(chan []byte, Cfg.ConnWriteQueueSize),
		processor:  processor,
		// buffer up to ConnUndoQueueSize unprocessed packages per connection
		logicQueue: make(chan []byte, Cfg.ConnUndoQueueSize),
		logger:     logger,
	}
	// writer goroutine: drains writeQueue onto the socket
	go func() {
		for pkg := range tc.writeQueue {
			// nil is the shutdown sentinel
			if pkg == nil {
				break
			}
			if Cfg.ConnWriteTimeout > 0 {
				_ = tc.SetWriteDeadline(time.Now().Add(time.Second * time.Duration(Cfg.ConnWriteTimeout)))
			}
			_, err := tc.Write(pkg)
			if err != nil {
				logger.Error("tcp write", zap.Error(err))
				break
			}
			// clear the write deadline after each successful write
			_ = tc.SetWriteDeadline(time.Time{})
		}
		// write side finished or failed: close the connection
		_ = tc.Close()
		logger.Info("Conn Close",
			zap.String("local address", tc.Conn.LocalAddr().String()),
			zap.String("remote address", tc.Conn.RemoteAddr().String()),
		)
	}()
	// logic goroutine: hands received packages to the processor
	go func() {
		for pkg := range tc.logicQueue {
			// nil is the shutdown sentinel queued by ReadMsg
			if pkg == nil {
				break
			}
			// run the handler inside a closure so a panic is recovered
			// per package instead of killing the goroutine
			func() {
				defer func() {
					if r := recover(); r != nil {
						logger.Error("processor panic",
							zap.Any("panic", r),
							zap.ByteString("stack", debug.Stack()),
						)
					}
				}()
				_ = tc.processor.OnReceivedPackage(tc, pkg)
			}()
		}
	}()
	return tc
}
// GetUuid reports the unique identifier assigned to this connection at
// construction time.
func (tc *TCPConn) GetUuid() string {
	id := tc.uuid
	return id
}
// ReadMsg runs the blocking read loop: it repeatedly reads one
// length-prefixed package (2-byte length header, endianness chosen by the
// processor) and queues it for the logic goroutine. On any error it stops
// both worker goroutines by queuing nil sentinels, then force-closes the
// connection. Shutdown order: read end -> write end -> conn end.
func (tc *TCPConn) ReadMsg() {
	defer func() {
		// stop logic and writer goroutines, then force close the conn
		tc.logicQueue <- nil
		tc.writeQueue <- nil
		if !tc.IsClosed() {
			_ = tc.Close()
		}
	}()
	bf := make([]byte, Cfg.MaxDataPackageSize)
	// the first package gets a short configurable deadline (default noted as 5s)
	timeout := time.Second * time.Duration(Cfg.FirstPackageTimeout)
	for {
		_ = tc.SetReadDeadline(time.Now().Add(timeout))
		// read the 2-byte length prefix
		_, err := io.ReadAtLeast(tc, bf[:2], 2)
		if err != nil {
			tc.logger.Error("TCPConn read message head",
				zap.Error(err),
			)
			return
		}
		var ln uint16
		if tc.processor.GetBigEndian() {
			ln = binary.BigEndian.Uint16(bf[:2])
		} else {
			ln = binary.LittleEndian.Uint16(bf[:2])
		}
		// reject zero-length and oversized payloads
		if ln < 1 || int(ln) > Cfg.MaxDataPackageSize {
			tc.logger.Error("TCPConn message length invalid",
				zap.Uint16("length", ln),
				zap.Error(fmt.Errorf("TCPConn message length invalid")),
			)
			return
		}
		// read the payload of exactly ln bytes
		_, err = io.ReadFull(tc, bf[:ln])
		if err != nil {
			tc.logger.Error("TCPConn read data",
				zap.Error(err),
			)
			return
		}
		// clear deadlines (NOTE(review): SetDeadline also clears any write
		// deadline the writer goroutine may have set concurrently — confirm intended)
		_ = tc.SetDeadline(time.Time{})
		// hand a copy of the payload to the logic queue; on overflow the
		// package is deliberately dropped instead of closing the conn
		select {
		case tc.logicQueue <- append(make([]byte, 0), bf[:ln]...):
		default:
			tc.logger.Error("TCPConn logic queue overflow",
				zap.String("local address", tc.LocalAddr().String()),
				zap.String("remote address", tc.RemoteAddr().String()),
				zap.Int("queue length", len(tc.logicQueue)),
				zap.Error(fmt.Errorf("TCPConn logic queue overflow")),
			)
		}
		// after the first package, switch to the heartbeat/read timeout
		timeout = time.Second * time.Duration(Cfg.ConnReadTimeout)
	}
}
// WriteMsg wraps message with the connection's processor and enqueues the
// resulting package on the write queue. When the queue is full it backs off
// 50ms and retries until the package is queued or the connection is closed;
// wrap failures are logged and the message is dropped.
func (tc *TCPConn) WriteMsg(message interface{}) {
	pkg, err := tc.processor.WrapMsg(message)
	if err != nil {
		tc.logger.Error("OnWrapMsg package",
			zap.Error(err),
		)
		return
	}
	// retry loop replaces the original goto-label; same backoff behavior
	for {
		select {
		case tc.writeQueue <- pkg:
			return
		default:
			// give up once the connection is closed
			if tc.IsClosed() {
				return
			}
			time.Sleep(time.Millisecond * 50)
		}
	}
}
// Close closes the underlying connection. Only the first call marks the
// connection closed, fires the AfterClose callback and drains the write
// queue — the original ran these on every call, so a double Close (writer
// goroutine + ReadMsg defer) invoked the user callback twice. Subsequent
// calls just forward to the underlying Conn.Close.
func (tc *TCPConn) Close() error {
	tc.Lock()
	err := tc.Conn.Close()
	tc.Unlock()
	// first-close-wins: callback and drain must run exactly once,
	// and outside the lock so the callback may safely re-enter tc
	if atomic.CompareAndSwapInt64(&tc.closeFlag, 0, 1) {
		if tc.closeCb != nil {
			tc.closeCb()
		}
		// discard any packages still pending in the write queue
		for len(tc.writeQueue) > 0 {
			<-tc.writeQueue
		}
	}
	return err
}
// IsClosed reports whether Close has been called on this connection.
func (tc *TCPConn) IsClosed() bool {
	flag := atomic.LoadInt64(&tc.closeFlag)
	return flag != 0
}
// AfterClose registers cb to be invoked after the connection is closed.
func (tc *TCPConn) AfterClose(cb func()) {
	tc.Lock()
	tc.closeCb = cb
	tc.Unlock()
}
// SetData attaches arbitrary user data to the connection.
func (tc *TCPConn) SetData(data interface{}) {
	tc.Lock()
	tc.userData = data
	tc.Unlock()
}
// GetData returns the user data previously stored with SetData.
func (tc *TCPConn) GetData() interface{} {
	tc.RLock()
	d := tc.userData
	tc.RUnlock()
	return d
}
// SetNode associates a logical node with the connection.
func (tc *TCPConn) SetNode(node INode) {
	tc.Lock()
	tc.node = node
	tc.Unlock()
}
// GetNode returns the node previously stored with SetNode.
func (tc *TCPConn) GetNode() INode {
	tc.RLock()
	n := tc.node
	tc.RUnlock()
	return n
}