package hpds_net_framework import ( "encoding/binary" "fmt" "git.hpds.cc/Component/logging" "github.com/google/uuid" "go.uber.org/zap" "io" "net" "runtime/debug" "sync" "sync/atomic" "time" ) // TCPConn is warped tcp conn for luck type TCPConn struct { sync.RWMutex uuid string net.Conn // 缓写队列 writeQueue chan []byte // 逻辑消息队列 logicQueue chan []byte // 消息处理器 processor Processor userData interface{} node INode // after close closeCb func() closeFlag int64 logger *logging.Logger } // KCPConn 可靠的UDP,like TCP type KCPConn struct { *TCPConn } // NewKcpConn get new kcp conn func NewKcpConn(conn net.Conn, processor Processor, logger *logging.Logger) *KCPConn { tcpConn := NewTcpConn(conn, processor, logger) if tcpConn != nil { return &KCPConn{tcpConn} } return nil } // NewTcpConn return new tcp conn func NewTcpConn(conn net.Conn, processor Processor, logger *logging.Logger) *TCPConn { if processor == nil || conn == nil { return nil } tc := &TCPConn{ uuid: uuid.New().String(), Conn: conn, writeQueue: make(chan []byte, Cfg.ConnWriteQueueSize), processor: processor, // 单个缓存100个为处理的包 logicQueue: make(chan []byte, Cfg.ConnUndoQueueSize), logger: logger, } // write q go func() { for pkg := range tc.writeQueue { if pkg == nil { break } if Cfg.ConnWriteTimeout > 0 { _ = tc.SetWriteDeadline(time.Now().Add(time.Second * time.Duration(Cfg.ConnWriteTimeout))) } _, err := tc.Write(pkg) if err != nil { logger.Error("tcp write", zap.Error(err)) break } _ = tc.SetWriteDeadline(time.Time{}) } // write over or error _ = tc.Close() logger.Info("Conn Close", zap.String("local address", tc.Conn.LocalAddr().String()), zap.String("remote address", tc.Conn.RemoteAddr().String()), ) }() // logic q go func() { for pkg := range tc.logicQueue { // logic over if pkg == nil { break } // processor handle the package func() { defer func() { if r := recover(); r != nil { logger.Error("processor panic", zap.Any("panic", r), zap.ByteString("stack", debug.Stack()), ) } }() _ = tc.processor.OnReceivedPackage(tc, pkg) }() } }() return tc } // GetUuid get uuid of conn func (tc *TCPConn) GetUuid() string { return tc.uuid } // ReadMsg read | write end -> write | read end -> conn end func (tc *TCPConn) ReadMsg() { defer func() { tc.logicQueue <- nil tc.writeQueue <- nil // force close conn if !tc.IsClosed() { _ = tc.Close() } }() bf := make([]byte, Cfg.MaxDataPackageSize) // 第一个包默认5秒 timeout := time.Second * time.Duration(Cfg.FirstPackageTimeout) for { _ = tc.SetReadDeadline(time.Now().Add(timeout)) // read length _, err := io.ReadAtLeast(tc, bf[:2], 2) if err != nil { tc.logger.Error("TCPConn read message head", zap.Error(err), ) return } var ln uint16 if tc.processor.GetBigEndian() { ln = binary.BigEndian.Uint16(bf[:2]) } else { ln = binary.LittleEndian.Uint16(bf[:2]) } if ln < 1 || int(ln) > Cfg.MaxDataPackageSize { tc.logger.Error("TCPConn message length invalid", zap.Uint16("length", ln), zap.Error(fmt.Errorf("TCPConn message length invalid")), ) return } // read data _, err = io.ReadFull(tc, bf[:ln]) if err != nil { tc.logger.Error("TCPConn read data", zap.Error(err), ) return } // clean _ = tc.SetDeadline(time.Time{}) // write to cache queue select { case tc.logicQueue <- append(make([]byte, 0), bf[:ln]...): default: // ignore overflow package not close conn tc.logger.Error("TCPConn logic queue overflow", zap.String("local address", tc.LocalAddr().String()), zap.String("remote address", tc.RemoteAddr().String()), zap.Int("queue length", len(tc.logicQueue)), zap.Error(fmt.Errorf("TCPConn logic queue overflow")), ) } // after first pack | check heartbeat timeout = time.Second * time.Duration(Cfg.ConnReadTimeout) } } // WriteMsg warp msg base on connection's processor func (tc *TCPConn) WriteMsg(message interface{}) { pkg, err := tc.processor.WrapMsg(message) if err != nil { tc.logger.Error("OnWrapMsg package", zap.Error(err), ) } else { push: select { case tc.writeQueue <- pkg: default: if tc.IsClosed() { return } time.Sleep(time.Millisecond * 50) // re push goto push } } } // Close the connection func (tc *TCPConn) Close() error { tc.Lock() defer func() { tc.Unlock() // add close flag atomic.AddInt64(&tc.closeFlag, 1) if tc.closeCb != nil { tc.closeCb() } // clean write q if not empty for len(tc.writeQueue) > 0 { <-tc.writeQueue } }() return tc.Conn.Close() } // IsClosed return the status of conn func (tc *TCPConn) IsClosed() bool { return atomic.LoadInt64(&tc.closeFlag) != 0 } // AfterClose conn call back func (tc *TCPConn) AfterClose(cb func()) { tc.Lock() defer tc.Unlock() tc.closeCb = cb } // SetData for conn func (tc *TCPConn) SetData(data interface{}) { tc.Lock() defer tc.Unlock() tc.userData = data } // GetData from conn func (tc *TCPConn) GetData() interface{} { tc.RLock() defer tc.RUnlock() return tc.userData } // SetNode for conn func (tc *TCPConn) SetNode(node INode) { tc.Lock() defer tc.Unlock() tc.node = node } // GetNode from conn func (tc *TCPConn) GetNode() INode { tc.RLock() defer tc.RUnlock() return tc.node }