hpds_net_framework/tcpConn.go

258 lines
5.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package hpds_net_framework
import (
"encoding/binary"
"fmt"
"git.hpds.cc/Component/logging"
"github.com/google/uuid"
"go.uber.org/zap"
"io"
"net"
"runtime/debug"
"sync"
"sync/atomic"
"time"
)
// TCPConn is a wrapped TCP connection for luck.
// It embeds net.Conn and adds a buffered write queue, a queue of received
// packages awaiting processing, per-connection user data and a close callback.
// The embedded RWMutex is held by SetData/GetData, SetNode/GetNode and
// AfterClose.
type TCPConn struct {
sync.RWMutex
// uuid is the identifier generated in NewTcpConn.
uuid string
net.Conn
// writeQueue buffers outgoing packages; a goroutine started in NewTcpConn
// writes them to the connection. A nil entry stops that goroutine.
writeQueue chan []byte
// logicQueue holds received packages waiting to be handled by the processor.
// A nil entry stops the goroutine that consumes it.
logicQueue chan []byte
// processor wraps outgoing messages and handles received packages.
processor Processor
// userData is an arbitrary value attached through SetData.
userData interface{}
// node is the INode attached through SetNode.
node INode
// closeCb, when set through AfterClose, is invoked by Close after the
// underlying connection has been closed.
closeCb func()
// closeFlag is non-zero once Close has run; it is accessed atomically.
closeFlag int64
logger *logging.Logger
}
// KCPConn is a connection type that embeds *TCPConn and therefore exposes the
// same methods. Per the original note it is meant for reliable UDP that
// behaves like TCP (presumably KCP sessions; confirm against the caller).
type KCPConn struct {
*TCPConn
}
// NewKcpConn builds a KCPConn on top of a TCPConn created by NewTcpConn with
// the same arguments. It returns nil when NewTcpConn returns nil.
func NewKcpConn(conn net.Conn, processor Processor, logger *logging.Logger) *KCPConn {
	inner := NewTcpConn(conn, processor, logger)
	if inner == nil {
		return nil
	}
	return &KCPConn{TCPConn: inner}
}
// NewTcpConn wraps conn in a TCPConn and starts its two worker goroutines.
//
// It returns nil when conn or processor is nil. Queue capacities come from
// Cfg.ConnWriteQueueSize and Cfg.ConnUndoQueueSize.
//
// The writer goroutine sends every package taken from the write queue to the
// connection, applying a write deadline of Cfg.ConnWriteTimeout seconds when
// that value is positive. It stops on a nil package or a write error, then
// closes the connection and logs both addresses.
//
// The logic goroutine passes every package taken from the logic queue to the
// processor, recovering from and logging any panic raised by the handler. It
// stops on a nil package.
func NewTcpConn(conn net.Conn, processor Processor, logger *logging.Logger) *TCPConn {
	if conn == nil || processor == nil {
		return nil
	}
	tc := &TCPConn{
		uuid:       uuid.New().String(),
		Conn:       conn,
		writeQueue: make(chan []byte, Cfg.ConnWriteQueueSize),
		// packages received but not yet processed wait here
		logicQueue: make(chan []byte, Cfg.ConnUndoQueueSize),
		processor:  processor,
		logger:     logger,
	}

	// writer goroutine
	go func() {
		for {
			out, open := <-tc.writeQueue
			if !open || out == nil {
				break
			}
			if Cfg.ConnWriteTimeout > 0 {
				deadline := time.Now().Add(time.Second * time.Duration(Cfg.ConnWriteTimeout))
				_ = tc.SetWriteDeadline(deadline)
			}
			if _, err := tc.Write(out); err != nil {
				logger.Error("tcp write", zap.Error(err))
				break
			}
			_ = tc.SetWriteDeadline(time.Time{})
		}
		// writing finished or failed: tear the connection down
		_ = tc.Close()
		logger.Info("Conn Close",
			zap.String("local address", tc.Conn.LocalAddr().String()),
			zap.String("remote address", tc.Conn.RemoteAddr().String()),
		)
	}()

	// dispatch hands one package to the processor and keeps a panicking
	// handler from killing the logic goroutine.
	dispatch := func(in []byte) {
		defer func() {
			r := recover()
			if r == nil {
				return
			}
			logger.Error("processor panic",
				zap.Any("panic", r),
				zap.ByteString("stack", debug.Stack()),
			)
		}()
		_ = tc.processor.OnReceivedPackage(tc, in)
	}

	// logic goroutine
	go func() {
		for in := range tc.logicQueue {
			if in == nil {
				// nil marks the end of the logic stream
				return
			}
			dispatch(in)
		}
	}()

	return tc
}
// GetUuid returns the connection's identifier, the UUID string generated when
// the connection was created by NewTcpConn.
func (tc *TCPConn) GetUuid() string {
return tc.uuid
}
// ReadMsg runs the connection's blocking read loop until a read fails or a
// frame is malformed.
//
// Each frame is a 2-byte length prefix (byte order chosen by the processor's
// GetBigEndian) followed by that many payload bytes. A length of zero, or one
// larger than Cfg.MaxDataPackageSize, ends the loop. Payloads are offered to
// the logic queue without blocking; when that queue is full the package is
// dropped and logged, and the connection stays open.
//
// The first frame must arrive within Cfg.FirstPackageTimeout seconds, every
// later frame within Cfg.ConnReadTimeout seconds.
//
// On return a nil sentinel is pushed to the logic queue and then to the write
// queue so both worker goroutines stop, and the connection is closed if it is
// not closed already.
func (tc *TCPConn) ReadMsg() {
	defer func() {
		// read end -> logic end -> write end -> conn end
		// NOTE(review): both sends block while the target queue is full.
		tc.logicQueue <- nil
		tc.writeQueue <- nil
		// force close conn
		if !tc.IsClosed() {
			_ = tc.Close()
		}
	}()
	// The length prefix has its own buffer so reading it never depends on
	// Cfg.MaxDataPackageSize being at least two.
	var head [2]byte
	// timeout applied while waiting for the first package
	timeout := time.Second * time.Duration(Cfg.FirstPackageTimeout)
	for {
		// one deadline covers both the header and the payload of a frame
		_ = tc.SetReadDeadline(time.Now().Add(timeout))
		// read length
		if _, err := io.ReadFull(tc, head[:]); err != nil {
			tc.logger.Error("TCPConn read message head",
				zap.Error(err),
			)
			return
		}
		var ln uint16
		if tc.processor.GetBigEndian() {
			ln = binary.BigEndian.Uint16(head[:])
		} else {
			ln = binary.LittleEndian.Uint16(head[:])
		}
		if ln < 1 || int(ln) > Cfg.MaxDataPackageSize {
			tc.logger.Error("TCPConn message length invalid",
				zap.Uint16("length", ln),
				zap.Error(fmt.Errorf("TCPConn message length invalid")),
			)
			return
		}
		// read data straight into the slice that is handed to the logic
		// goroutine, so no shared buffer has to be copied afterwards
		pkg := make([]byte, ln)
		if _, err := io.ReadFull(tc, pkg); err != nil {
			tc.logger.Error("TCPConn read data",
				zap.Error(err),
			)
			return
		}
		// Clear only the read deadline: SetDeadline would also wipe the write
		// deadline the writer goroutine may have armed for an in-flight write.
		_ = tc.SetReadDeadline(time.Time{})
		// write to cache queue
		select {
		case tc.logicQueue <- pkg:
		default:
			// ignore overflow package not close conn
			tc.logger.Error("TCPConn logic queue overflow",
				zap.String("local address", tc.LocalAddr().String()),
				zap.String("remote address", tc.RemoteAddr().String()),
				zap.Int("queue length", len(tc.logicQueue)),
				zap.Error(fmt.Errorf("TCPConn logic queue overflow")),
			)
		}
		// after first pack | check heartbeat
		timeout = time.Second * time.Duration(Cfg.ConnReadTimeout)
	}
}
// WriteMsg wraps message with the connection's processor and queues the
// resulting package for the writer goroutine.
//
// A wrapping error is logged and the message is discarded. While the write
// queue is full the call retries every 50 milliseconds, giving up once the
// connection is reported closed.
func (tc *TCPConn) WriteMsg(message interface{}) {
	pkg, err := tc.processor.WrapMsg(message)
	if err != nil {
		tc.logger.Error("OnWrapMsg package",
			zap.Error(err),
		)
		return
	}
	for {
		select {
		case tc.writeQueue <- pkg:
			return
		default:
		}
		// queue is full: stop if the connection is gone, otherwise retry
		if tc.IsClosed() {
			return
		}
		time.Sleep(time.Millisecond * 50)
	}
}
// Close closes the underlying connection and returns the error reported by
// net.Conn.Close.
//
// After the connection is closed it marks the TCPConn as closed (see
// IsClosed), invokes the callback registered with AfterClose, if any, and
// discards whatever is still waiting in the write queue. The callback runs
// outside the lock, so it may call the connection's other methods, and it
// runs on every call to Close.
func (tc *TCPConn) Close() error {
	tc.Lock()
	defer func() {
		// Snapshot the callback while the lock is still held; AfterClose
		// writes the field under the same lock.
		cb := tc.closeCb
		tc.Unlock()
		// add close flag
		atomic.AddInt64(&tc.closeFlag, 1)
		if cb != nil {
			cb()
		}
		// Empty the write queue without blocking. A len() check followed by a
		// plain receive could block for good when the writer goroutine takes
		// the last package in between.
		for {
			select {
			case _, ok := <-tc.writeQueue:
				if !ok {
					return
				}
			default:
				return
			}
		}
	}()
	return tc.Conn.Close()
}
// IsClosed reports whether Close has already run on this connection, based on
// an atomic read of the close flag.
func (tc *TCPConn) IsClosed() bool {
	flag := atomic.LoadInt64(&tc.closeFlag)
	return flag != 0
}
// AfterClose registers cb as the callback that Close invokes once the
// underlying connection has been closed. A later call replaces the earlier
// callback.
func (tc *TCPConn) AfterClose(cb func()) {
	tc.Lock()
	tc.closeCb = cb
	tc.Unlock()
}
// SetData attaches an arbitrary value to the connection, replacing any value
// stored before. The write lock is held during the update.
func (tc *TCPConn) SetData(data interface{}) {
	tc.Lock()
	tc.userData = data
	tc.Unlock()
}
// GetData returns the value last stored with SetData, or nil when nothing has
// been stored. The read lock is held during the lookup.
func (tc *TCPConn) GetData() interface{} {
	tc.RLock()
	data := tc.userData
	tc.RUnlock()
	return data
}
// SetNode associates node with the connection, replacing any node set before.
// The write lock is held during the update.
func (tc *TCPConn) SetNode(node INode) {
	tc.Lock()
	tc.node = node
	tc.Unlock()
}
// GetNode returns the node last stored with SetNode; the result is the zero
// INode value when none has been stored. The read lock is held during the
// lookup.
func (tc *TCPConn) GetNode() INode {
	tc.RLock()
	node := tc.node
	tc.RUnlock()
	return node
}