From a750964220b67fafec8b35563a00ab0f7cf79f50 Mon Sep 17 00:00:00 2001 From: wangjian Date: Fri, 10 Mar 2023 23:49:52 +0800 Subject: [PATCH] fix bug --- client.go | 15 +++--- context.go | 4 +- frame/data_frame.go | 6 +-- frame/goaway_frame.go | 6 +-- frame/handshake_frame.go | 26 +++++----- frame/meta_frame.go | 14 ++--- frame/payload_frame.go | 2 +- frame/rejected_frame.go | 6 +-- log/zap.go | 14 ++--- server.go | 131 ++++++++++++++++++++++++++--------------------- tls/tls.go | 7 ++- 11 files changed, 122 insertions(+), 109 deletions(-) diff --git a/client.go b/client.go index 993e7a9..d4af3f3 100644 --- a/client.go +++ b/client.go @@ -56,7 +56,7 @@ func NewClient(appName string, connType ClientType, opts ...ClientOption) *Clien errChan: make(chan error), closeChan: make(chan bool), } - c.Init(opts...) + _ = c.Init(opts...) once.Do(func() { c.init() }) @@ -145,8 +145,9 @@ func (c *Client) handleFrame() { // this will block until a frame is received f, err := fs.ReadFrame() if err != nil { - defer c.stream.Close() - // defer c.conn.CloseWithError(0xD0, err.Error()) + defer func() { + _ = c.stream.Close() + }() c.logger.Debugf("%shandleFrame(): %T | %v", ClientLogPrefix, err, err) if e, ok := err.(*quic.IdleTimeoutError); ok { @@ -185,7 +186,7 @@ func (c *Client) handleFrame() { // any error occurred, we should close the stream // after this, conn.AcceptStream() will raise the error c.setState(ConnStateClosed) - c.conn.CloseWithError(hpds_err.To(hpds_err.ErrorCodeUnknown), err.Error()) + _ = c.conn.CloseWithError(hpds_err.To(hpds_err.ErrorCodeUnknown), err.Error()) c.logger.Errorf("%sunknown error occurred, err=%v, state=%v", ClientLogPrefix, err, c.getState()) break } @@ -210,7 +211,7 @@ func (c *Client) handleFrame() { c.setState(ConnStateRejected) if v, ok := f.(*frame.RejectedFrame); ok { c.logger.Errorf("%s receive RejectedFrame, message=%s", ClientLogPrefix, v.Message()) - c.conn.CloseWithError(hpds_err.To(hpds_err.ErrorCodeRejected), v.Message()) + _ = c.conn.CloseWithError(hpds_err.To(hpds_err.ErrorCodeRejected), v.Message()) c.errChan <- errors.New(v.Message()) break } @@ -218,7 +219,7 @@ func (c *Client) handleFrame() { c.setState(ConnStateGoaway) if v, ok := f.(*frame.GoawayFrame); ok { c.logger.Errorf("%s️ receive GoawayFrame, message=%s", ClientLogPrefix, v.Message()) - c.conn.CloseWithError(hpds_err.To(hpds_err.ErrorCodeGoaway), v.Message()) + _ = c.conn.CloseWithError(hpds_err.To(hpds_err.ErrorCodeGoaway), v.Message()) c.errChan <- errors.New(v.Message()) break } @@ -229,8 +230,6 @@ func (c *Client) handleFrame() { if c.processor == nil { c.logger.Warnf("%sprocessor is nil", ClientLogPrefix) } else { - // TODO: should c.processor accept a DataFrame as parameter? - // c.processor(v.GetDataTagID(), v.GetCarriage(), v.GetMetaFrame()) c.processor(v) } } diff --git a/context.go b/context.go index 2aceb6b..6d0d2cd 100644 --- a/context.go +++ b/context.go @@ -54,10 +54,10 @@ func (c *Context) Clean() { func (c *Context) CloseWithError(code hpds_err.ErrorCode, msg string) { log.Debugf("%sconn[%s] context close, errCode=%#x, msg=%s", ServerLogPrefix, c.connId, code, msg) if c.Stream != nil { - c.Stream.Close() + _ = c.Stream.Close() } if c.Conn != nil { - c.Conn.CloseWithError(quic.ApplicationErrorCode(code), msg) + _ = c.Conn.CloseWithError(quic.ApplicationErrorCode(code), msg) } c.Clean() } diff --git a/frame/data_frame.go b/frame/data_frame.go index 217d943..ec0b547 100644 --- a/frame/data_frame.go +++ b/frame/data_frame.go @@ -71,7 +71,7 @@ func (d *DataFrame) SourceId() string { // Encode return coder encoded bytes of `DataFrame` func (d *DataFrame) Encode() []byte { - data := coder.NewNodePacketEncoder(byte(d.Type())) + data := coder.NewNodePacketEncoder(d.Type()) // MetaFrame data.AddBytes(d.metaFrame.Encode()) // PayloadFrame @@ -90,7 +90,7 @@ func DecodeToDataFrame(buf []byte) (*DataFrame, error) { data := &DataFrame{} - if metaBlock, ok := packet.NodePackets[byte(TagOfMetaFrame)]; ok { + if metaBlock, ok := packet.NodePackets[TagOfMetaFrame]; ok { meta, err := DecodeToMetaFrame(metaBlock.GetRawBytes()) if err != nil { return nil, err @@ -98,7 +98,7 @@ func DecodeToDataFrame(buf []byte) (*DataFrame, error) { data.metaFrame = meta } - if payloadBlock, ok := packet.NodePackets[byte(TagOfPayloadFrame)]; ok { + if payloadBlock, ok := packet.NodePackets[TagOfPayloadFrame]; ok { payload, err := DecodeToPayloadFrame(payloadBlock.GetRawBytes()) if err != nil { return nil, err diff --git a/frame/goaway_frame.go b/frame/goaway_frame.go index dd82a2c..33c0a8f 100644 --- a/frame/goaway_frame.go +++ b/frame/goaway_frame.go @@ -21,9 +21,9 @@ func (f *GoawayFrame) Type() Type { // Encode to coder encoded bytes func (f *GoawayFrame) Encode() []byte { - goaway := coder.NewNodePacketEncoder(byte(f.Type())) + goaway := coder.NewNodePacketEncoder(f.Type()) // message - msgBlock := coder.NewPrimitivePacketEncoder(byte(TagOfGoawayMessage)) + msgBlock := coder.NewPrimitivePacketEncoder(TagOfGoawayMessage) msgBlock.SetStringValue(f.message) goaway.AddPrimitivePacket(msgBlock) @@ -46,7 +46,7 @@ func DecodeToGoawayFrame(buf []byte) (*GoawayFrame, error) { goaway := &GoawayFrame{} // message - if msgBlock, ok := node.PrimitivePackets[byte(TagOfGoawayMessage)]; ok { + if msgBlock, ok := node.PrimitivePackets[TagOfGoawayMessage]; ok { msg, err := msgBlock.ToUTF8String() if err != nil { return nil, err diff --git a/frame/handshake_frame.go b/frame/handshake_frame.go index 4b4cda6..e00672e 100644 --- a/frame/handshake_frame.go +++ b/frame/handshake_frame.go @@ -39,24 +39,24 @@ func (h *HandshakeFrame) Type() Type { // Encode to coder encoding. func (h *HandshakeFrame) Encode() []byte { // name - nameBlock := coder.NewPrimitivePacketEncoder(byte(TagOfHandshakeName)) + nameBlock := coder.NewPrimitivePacketEncoder(TagOfHandshakeName) nameBlock.SetStringValue(h.Name) // client ID - idBlock := coder.NewPrimitivePacketEncoder(byte(TagOfHandshakeId)) + idBlock := coder.NewPrimitivePacketEncoder(TagOfHandshakeId) idBlock.SetStringValue(h.ClientId) // client type - typeBlock := coder.NewPrimitivePacketEncoder(byte(TagOfHandshakeType)) + typeBlock := coder.NewPrimitivePacketEncoder(TagOfHandshakeType) typeBlock.SetBytesValue([]byte{h.ClientType}) // observe data tags - observeDataTagsBlock := coder.NewPrimitivePacketEncoder(byte(TagOfHandshakeObserveDataTags)) + observeDataTagsBlock := coder.NewPrimitivePacketEncoder(TagOfHandshakeObserveDataTags) observeDataTagsBlock.SetBytesValue(h.ObserveDataTags) // auth - authNameBlock := coder.NewPrimitivePacketEncoder(byte(TagOfHandshakeAuthName)) + authNameBlock := coder.NewPrimitivePacketEncoder(TagOfHandshakeAuthName) authNameBlock.SetStringValue(h.authName) - authPayloadBlock := coder.NewPrimitivePacketEncoder(byte(TagOfHandshakeAuthPayload)) + authPayloadBlock := coder.NewPrimitivePacketEncoder(TagOfHandshakeAuthPayload) authPayloadBlock.SetStringValue(h.authPayload) // handshake frame - handshake := coder.NewNodePacketEncoder(byte(h.Type())) + handshake := coder.NewNodePacketEncoder(h.Type()) handshake.AddPrimitivePacket(nameBlock) handshake.AddPrimitivePacket(idBlock) handshake.AddPrimitivePacket(typeBlock) @@ -77,7 +77,7 @@ func DecodeToHandshakeFrame(buf []byte) (*HandshakeFrame, error) { handshake := &HandshakeFrame{} // name - if nameBlock, ok := node.PrimitivePackets[byte(TagOfHandshakeName)]; ok { + if nameBlock, ok := node.PrimitivePackets[TagOfHandshakeName]; ok { name, err := nameBlock.ToUTF8String() if err != nil { return nil, err @@ -85,7 +85,7 @@ func DecodeToHandshakeFrame(buf []byte) (*HandshakeFrame, error) { handshake.Name = name } // client id - if idBlock, ok := node.PrimitivePackets[byte(TagOfHandshakeId)]; ok { + if idBlock, ok := node.PrimitivePackets[TagOfHandshakeId]; ok { id, err := idBlock.ToUTF8String() if err != nil { return nil, err @@ -93,23 +93,23 @@ func DecodeToHandshakeFrame(buf []byte) (*HandshakeFrame, error) { handshake.ClientId = id } // client type - if typeBlock, ok := node.PrimitivePackets[byte(TagOfHandshakeType)]; ok { + if typeBlock, ok := node.PrimitivePackets[TagOfHandshakeType]; ok { clientType := typeBlock.ToBytes() handshake.ClientType = clientType[0] } // observe data tag list - if observeDataTagsBlock, ok := node.PrimitivePackets[byte(TagOfHandshakeObserveDataTags)]; ok { + if observeDataTagsBlock, ok := node.PrimitivePackets[TagOfHandshakeObserveDataTags]; ok { handshake.ObserveDataTags = observeDataTagsBlock.ToBytes() } // auth - if authNameBlock, ok := node.PrimitivePackets[byte(TagOfHandshakeAuthName)]; ok { + if authNameBlock, ok := node.PrimitivePackets[TagOfHandshakeAuthName]; ok { authName, err := authNameBlock.ToUTF8String() if err != nil { return nil, err } handshake.authName = authName } - if authPayloadBlock, ok := node.PrimitivePackets[byte(TagOfHandshakeAuthPayload)]; ok { + if authPayloadBlock, ok := node.PrimitivePackets[TagOfHandshakeAuthPayload]; ok { authPayload, err := authPayloadBlock.ToUTF8String() if err != nil { return nil, err diff --git a/frame/meta_frame.go b/frame/meta_frame.go index e9beb32..8f95d2a 100644 --- a/frame/meta_frame.go +++ b/frame/meta_frame.go @@ -57,20 +57,20 @@ func (m *MetaFrame) SourceId() string { // Encode implements Frame.Encode method. func (m *MetaFrame) Encode() []byte { - meta := coder.NewNodePacketEncoder(byte(TagOfMetaFrame)) + meta := coder.NewNodePacketEncoder(TagOfMetaFrame) // transaction ID - transactionId := coder.NewPrimitivePacketEncoder(byte(TagOfTransactionId)) + transactionId := coder.NewPrimitivePacketEncoder(TagOfTransactionId) transactionId.SetStringValue(m.tid) meta.AddPrimitivePacket(transactionId) // source ID - sourceId := coder.NewPrimitivePacketEncoder(byte(TagOfSourceId)) + sourceId := coder.NewPrimitivePacketEncoder(TagOfSourceId) sourceId.SetStringValue(m.sourceId) meta.AddPrimitivePacket(sourceId) // metadata if m.metadata != nil { - metadata := coder.NewPrimitivePacketEncoder(byte(TagOfMetadata)) + metadata := coder.NewPrimitivePacketEncoder(TagOfMetadata) metadata.SetBytesValue(m.metadata) meta.AddPrimitivePacket(metadata) } @@ -89,17 +89,17 @@ func DecodeToMetaFrame(buf []byte) (*MetaFrame, error) { meta := &MetaFrame{} for k, v := range nodeBlock.PrimitivePackets { switch k { - case byte(TagOfTransactionId): + case TagOfTransactionId: val, err := v.ToUTF8String() if err != nil { return nil, err } meta.tid = val break - case byte(TagOfMetadata): + case TagOfMetadata: meta.metadata = v.ToBytes() break - case byte(TagOfSourceId): + case TagOfSourceId: sourceId, err := v.ToUTF8String() if err != nil { return nil, err diff --git a/frame/payload_frame.go b/frame/payload_frame.go index 2f1f43b..2282682 100644 --- a/frame/payload_frame.go +++ b/frame/payload_frame.go @@ -30,7 +30,7 @@ func (m *PayloadFrame) Encode() []byte { carriage := coder.NewPrimitivePacketEncoder(m.Tag) carriage.SetBytesValue(m.Carriage) - payload := coder.NewNodePacketEncoder(byte(TagOfPayloadFrame)) + payload := coder.NewNodePacketEncoder(TagOfPayloadFrame) payload.AddPrimitivePacket(carriage) return payload.Encode() diff --git a/frame/rejected_frame.go b/frame/rejected_frame.go index 9cb5db4..bc27bfe 100644 --- a/frame/rejected_frame.go +++ b/frame/rejected_frame.go @@ -21,9 +21,9 @@ func (f *RejectedFrame) Type() Type { // Encode to coder encoded bytes func (f *RejectedFrame) Encode() []byte { - rejected := coder.NewNodePacketEncoder(byte(f.Type())) + rejected := coder.NewNodePacketEncoder(f.Type()) // message - msgBlock := coder.NewPrimitivePacketEncoder(byte(TagOfRejectedMessage)) + msgBlock := coder.NewPrimitivePacketEncoder(TagOfRejectedMessage) msgBlock.SetStringValue(f.message) rejected.AddPrimitivePacket(msgBlock) @@ -45,7 +45,7 @@ func DecodeToRejectedFrame(buf []byte) (*RejectedFrame, error) { } rejected := &RejectedFrame{} // message - if msgBlock, ok := node.PrimitivePackets[byte(TagOfRejectedMessage)]; ok { + if msgBlock, ok := node.PrimitivePackets[TagOfRejectedMessage]; ok { msg, e := msgBlock.ToUTF8String() if e != nil { return nil, e diff --git a/log/zap.go b/log/zap.go index c81043c..6e94f01 100644 --- a/log/zap.go +++ b/log/zap.go @@ -130,12 +130,12 @@ func (z *zapLogger) Infof(template string, args ...interface{}) { } // Warnf logs a message at WarnLevel -func (z zapLogger) Warnf(template string, args ...interface{}) { +func (z *zapLogger) Warnf(template string, args ...interface{}) { z.Instance().Warnf(template, args...) } // Errorf logs a message at ErrorLevel -func (z zapLogger) Errorf(template string, args ...interface{}) { +func (z *zapLogger) Errorf(template string, args ...interface{}) { z.Instance().Errorf(template, args...) } @@ -190,7 +190,7 @@ func (z *zapLogger) Instance() *zap.SugaredLogger { if err != nil { return err } - rotatedLogger.Write(msg.Bytes()) + _, _ = rotatedLogger.Write(msg.Bytes()) } return nil }) @@ -204,11 +204,11 @@ func (z *zapLogger) Instance() *zap.SugaredLogger { return z.instance } -func errorRotatedLogger(file string, maxSize, maxBacukups, maxAge int) *lumberjack.Logger { +func errorRotatedLogger(file string, maxSize, maxBackups, maxAge int) *lumberjack.Logger { return &lumberjack.Logger{ Filename: file, MaxSize: maxSize, - MaxBackups: maxBacukups, + MaxBackups: maxBackups, MaxAge: maxAge, Compress: false, } @@ -221,7 +221,7 @@ func timeEncoder(t time.Time, enc zapcore.PrimitiveArrayEncoder) { type logWriter struct{} func (l logWriter) Write(bytes []byte) (int, error) { - os.Stderr.WriteString(time.Now().Format(timeFormat)) - os.Stderr.Write([]byte("\t")) + _, _ = os.Stderr.WriteString(time.Now().Format(timeFormat)) + _, _ = os.Stderr.Write([]byte("\t")) return os.Stderr.Write(bytes) } diff --git a/server.go b/server.go index 87362ba..b0b9a56 100644 --- a/server.go +++ b/server.go @@ -52,7 +52,7 @@ func NewServer(name string, opts ...ServerOption) *Server { connector: newConnector(), downStreams: make(map[string]*Client), } - s.Init(opts...) + _ = s.Init(opts...) return s } @@ -100,59 +100,74 @@ func (s *Server) Serve(ctx context.Context, conn net.PacketConn) error { log.Errorf("%slistener.Listen: err=%v", ServerLogPrefix, err) return err } - defer listener.Close() + defer func() { + _ = listener.Close() + }() log.Printf("%s [%s][%d] Listening on: %s, MODE: %s, QUIC: %v, AUTH: %s", ServerLogPrefix, s.name, os.Getpid(), listener.Addr(), mode(), listener.Versions(), s.authNames()) s.state = ConnStateConnected for { - // create a new connection when new hpds-client connected - sctx, cancel := context.WithCancel(ctx) - defer cancel() - - connect, e := listener.Accept(sctx) - if e != nil { - log.Errorf("%screate connection error: %v", ServerLogPrefix, e) - return e - } - - connID := GetConnId(connect) - log.Infof("%s1/ new connection: %s", ServerLogPrefix, connID) - - go func(ctx context.Context, qconn quic.Connection) { - for { - log.Infof("%s2/ waiting for new stream", ServerLogPrefix) - stream, err := qconn.AcceptStream(ctx) - if err != nil { - // if client close the connection, then we should close the connection - // @CC: when Source close the connection, it won't affect connectors - name := "--" - if conn := s.connector.Get(connID); conn != nil { - conn.Close() - // connector - s.connector.Remove(connID) - route := s.router.Route(conn.Metadata()) - if route != nil { - route.Remove(connID) - } - name = conn.Name() - } - log.Printf("%s [%s](%s) close the connection: %v", ServerLogPrefix, name, connID, err) - break - } - defer stream.Close() - - log.Infof("%s3/ [stream:%d] created, connId=%s", ServerLogPrefix, stream.StreamID(), connID) - // process frames on stream - // c := newContext(connId, stream) - c := newContext(connect, stream) - defer c.Clean() - s.handleConnection(c) - log.Infof("%s4/ [stream:%d] handleConnection DONE", ServerLogPrefix, stream.StreamID()) - } - }(sctx, connect) + _ = s.createNewClientConnection(ctx, listener) } } +// createNewClientConnection create a new connection when new hpds-client connected +func (s *Server) createNewClientConnection(ctx context.Context, listener Listener) error { + sctx, cancel := context.WithCancel(ctx) + defer cancel() + + connect, e := listener.Accept(sctx) + if e != nil { + log.Errorf("%screate connection error: %v", ServerLogPrefix, e) + return e + } + + connId := GetConnId(connect) + log.Infof("%s1/ new connection: %s", ServerLogPrefix, connId) + + go func(ctx context.Context, qConn quic.Connection) { + for { + err := s.handle(ctx, qConn, connect, connId) + if err != nil { + break + } + } + }(sctx, connect) + return nil +} + +func (s *Server) handle(ctx context.Context, qConn quic.Connection, conn quic.Connection, connId string) error { + log.Infof("%s2/ waiting for new stream", ServerLogPrefix) + stream, err := qConn.AcceptStream(ctx) + if err != nil { + name := "--" + if conn := s.connector.Get(connId); conn != nil { + _ = conn.Close() + // connector + s.connector.Remove(connId) + route := s.router.Route(conn.Metadata()) + if route != nil { + _ = route.Remove(connId) + } + name = conn.Name() + } + log.Printf("%s [%s](%s) close the connection: %v", ServerLogPrefix, name, connId, err) + return err + } + defer func() { + _ = stream.Close() + }() + + log.Infof("%s3/ [stream:%d] created, connId=%s", ServerLogPrefix, stream.StreamID(), connId) + // process frames on stream + // c := newContext(connId, stream) + c := newContext(conn, stream) + defer c.Clean() + s.handleConnection(c) + log.Infof("%s4/ [stream:%d] handleConnection DONE", ServerLogPrefix, stream.StreamID()) + return nil +} + // Close will shut down the server. func (s *Server) Close() error { if s.router != nil { @@ -204,28 +219,28 @@ func (s *Server) handleConnection(c *Context) { frameType := f.Type() data := f.Encode() log.Debugf("%stype=%s, frame[%d]=%# x", ServerLogPrefix, frameType, len(data), frame.Shortly(data)) - // add frame to context - context := c.WithFrame(f) + // add frame to contextFrame + contextFrame := c.WithFrame(f) // before frame handlers for _, handler := range s.beforeHandlers { - if e := handler(context); e != nil { + if e := handler(contextFrame); e != nil { log.Errorf("%safterFrameHandler e: %s", ServerLogPrefix, e) - context.CloseWithError(hpds_err.ErrorCodeBeforeHandler, e.Error()) + contextFrame.CloseWithError(hpds_err.ErrorCodeBeforeHandler, e.Error()) return } } // main handler - if e := s.mainFrameHandler(context); e != nil { + if e := s.mainFrameHandler(contextFrame); e != nil { log.Errorf("%smainFrameHandler e: %s", ServerLogPrefix, e) - context.CloseWithError(hpds_err.ErrorCodeMainHandler, e.Error()) + contextFrame.CloseWithError(hpds_err.ErrorCodeMainHandler, e.Error()) return } // after frame handler for _, handler := range s.afterHandlers { - if e := handler(context); e != nil { + if e := handler(contextFrame); e != nil { log.Errorf("%safterFrameHandler e: %s", ServerLogPrefix, e) - context.CloseWithError(hpds_err.ErrorCodeAfterHandler, e.Error()) + contextFrame.CloseWithError(hpds_err.ErrorCodeAfterHandler, e.Error()) return } } @@ -262,7 +277,7 @@ func (s *Server) mainFrameHandler(c *Context) error { s.dispatchToDownStreams(f) } // observe data tags back flow - s.handleBackFlowFrame(c) + _ = s.handleBackFlowFrame(c) } default: log.Errorf("%serr=%v, frame=%v", ServerLogPrefix, err, frame.Shortly(c.Frame.Encode())) @@ -334,7 +349,7 @@ func (s *Server) handleHandshakeFrame(c *Context) error { default: // unknown client type s.connector.Remove(connId) - err := fmt.Errorf("Illegal ClientType: %#x", f.ClientType) + err := fmt.Errorf("Illegal ClientType: %#x ", f.ClientType) c.CloseWithError(hpds_err.ErrorCodeUnknownClient, err.Error()) return err } @@ -477,7 +492,7 @@ func (s *Server) AddDownstreamServer(addr string, c *Client) { func (s *Server) dispatchToDownStreams(df *frame.DataFrame) { for addr, ds := range s.downStreams { log.Debugf("%sdispatching to [%s]: %# x", ServerLogPrefix, addr, df.Tag()) - ds.WriteFrame(df) + _ = ds.WriteFrame(df) } } diff --git a/tls/tls.go b/tls/tls.go index 30de2c2..ff444a7 100644 --- a/tls/tls.go +++ b/tls/tls.go @@ -10,7 +10,6 @@ import ( "crypto/x509/pkix" "encoding/pem" "errors" - "io/ioutil" "math/big" "net" "os" @@ -88,7 +87,7 @@ func getCACertPool() (*x509.CertPool, error) { return nil, errors.New("tls: must provide CA certificate on production mode, you can configure this via environment variables: `HPDS_TLS_CACERT_FILE`") } - caCert, err = ioutil.ReadFile(caCertPath) + caCert, err = os.ReadFile(caCertPath) if err != nil { return nil, err } @@ -116,12 +115,12 @@ func getCertAndKey() (*tls.Certificate, error) { } // certificate - cert, err = ioutil.ReadFile(certPath) + cert, err = os.ReadFile(certPath) if err != nil { return nil, err } // private key - key, err = ioutil.ReadFile(keyPath) + key, err = os.ReadFile(keyPath) if err != nil { return nil, err }