package network import ( "context" "errors" "fmt" "io" "net" "os" "sync" "sync/atomic" // authentication implements, Currently, only token authentication is implemented _ "git.hpds.cc/Component/network/auth" "git.hpds.cc/Component/network/frame" "git.hpds.cc/Component/network/hpds_err" "git.hpds.cc/Component/network/log" pkgtls "git.hpds.cc/Component/network/tls" "github.com/lucas-clemente/quic-go" ) const ( // DefaultListenAddr is the default address to listen. DefaultListenAddr = "0.0.0.0:9000" ) // ServerOption is the option for server. type ServerOption func(*ServerOptions) // FrameHandler is the handler for frame. type FrameHandler func(c *Context) error // Server is the underlining server of Message Queue type Server struct { name string state string connector Connector router Router metadataBuilder MetadataBuilder counterOfDataFrame int64 downStreams map[string]*Client mu sync.Mutex opts ServerOptions beforeHandlers []FrameHandler afterHandlers []FrameHandler } // NewServer create a Server instance. func NewServer(name string, opts ...ServerOption) *Server { s := &Server{ name: name, connector: newConnector(), downStreams: make(map[string]*Client), } _ = s.Init(opts...) return s } // Init the options. func (s *Server) Init(opts ...ServerOption) error { for _, o := range opts { o(&s.opts) } // options defaults s.initOptions() return nil } // ListenAndServe starts the server. func (s *Server) ListenAndServe(ctx context.Context, addr string) error { if addr == "" { addr = DefaultListenAddr } udpAddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { return err } conn, err := net.ListenUDP("udp", udpAddr) if err != nil { return err } return s.Serve(ctx, conn) } // Serve the server with a net.PacketConn. 
func (s *Server) Serve(ctx context.Context, conn net.PacketConn) error { if err := s.validateMetadataBuilder(); err != nil { return err } if err := s.validateRouter(); err != nil { return err } // listen the address listener, err := newListener(conn, s.opts.TLSConfig, s.opts.QuicConfig) if err != nil { log.Errorf("%slistener.Listen: err=%v", ServerLogPrefix, err) return err } defer func() { _ = listener.Close() }() log.Printf("%s [%s][%d] Listening on: %s, MODE: %s, QUIC: %v, AUTH: %s", ServerLogPrefix, s.name, os.Getpid(), listener.Addr(), mode(), listener.Versions(), s.authNames()) s.state = ConnStateConnected for { _ = s.createNewClientConnection(ctx, listener) } } // createNewClientConnection create a new connection when new hpds-client connected func (s *Server) createNewClientConnection(ctx context.Context, listener Listener) error { sctx, cancel := context.WithCancel(ctx) defer cancel() connect, e := listener.Accept(sctx) if e != nil { log.Errorf("%screate connection error: %v", ServerLogPrefix, e) return e } connId := GetConnId(connect) log.Infof("%s1/ new connection: %s", ServerLogPrefix, connId) go func(ctx context.Context, qConn quic.Connection) { for { err := s.handle(ctx, qConn, connect, connId) if err != nil { break } } }(sctx, connect) return nil } func (s *Server) handle(ctx context.Context, qConn quic.Connection, conn quic.Connection, connId string) error { log.Infof("%s2/ waiting for new stream", ServerLogPrefix) stream, err := qConn.AcceptStream(ctx) if err != nil { name := "--" if conn := s.connector.Get(connId); conn != nil { _ = conn.Close() // connector s.connector.Remove(connId) route := s.router.Route(conn.Metadata()) if route != nil { _ = route.Remove(connId) } name = conn.Name() } log.Printf("%s [%s](%s) close the connection: %v", ServerLogPrefix, name, connId, err) return err } defer func() { _ = stream.Close() }() log.Infof("%s3/ [stream:%d] created, connId=%s", ServerLogPrefix, stream.StreamID(), connId) // process frames on stream // 
c := newContext(connId, stream) c := newContext(conn, stream) defer c.Clean() s.handleConnection(c) log.Infof("%s4/ [stream:%d] handleConnection DONE", ServerLogPrefix, stream.StreamID()) return nil } // Close will shut down the server. func (s *Server) Close() error { if s.router != nil { s.router.Clean() } // connector if s.connector != nil { s.connector.Clean() } return nil } // handle streams on a connection func (s *Server) handleConnection(c *Context) { fs := NewFrameStream(c.Stream) // check update for stream for { log.Debugf("%shandleConnection waiting read next...", ServerLogPrefix) f, err := fs.ReadFrame() if err != nil { // if client close connection, will get ApplicationError with code = 0x00 if e, ok := err.(*quic.ApplicationError); ok { if hpds_err.Is(e.ErrorCode, hpds_err.ErrorCodeClientAbort) { // client abort log.Infof("%sclient close the connection", ServerLogPrefix) break } else { ye := hpds_err.New(hpds_err.Parse(e.ErrorCode), err) log.Errorf("%s[ERR] %s", ServerLogPrefix, ye) } } else if err == io.EOF { log.Infof("%sthe connection is EOF", ServerLogPrefix) break } if errors.Is(err, net.ErrClosed) { // if client close the connection, net.ErrClosed will be raised // by quic-go IdleTimeoutError after connection's KeepAlive config. 
log.Warnf("%s[ERR] net.ErrClosed on [handleConnection] %v", ServerLogPrefix, net.ErrClosed) c.CloseWithError(hpds_err.ErrorCodeClosed, "net.ErrClosed") break } // any error occurred, we should close the stream // after this, conn.AcceptStream() will raise the error c.CloseWithError(hpds_err.ErrorCodeUnknown, err.Error()) log.Warnf("%sconnection.Close()", ServerLogPrefix) break } frameType := f.Type() data := f.Encode() log.Debugf("%stype=%s, frame[%d]=%# x", ServerLogPrefix, frameType, len(data), frame.Shortly(data)) // add frame to contextFrame contextFrame := c.WithFrame(f) // before frame handlers for _, handler := range s.beforeHandlers { if e := handler(contextFrame); e != nil { log.Errorf("%safterFrameHandler e: %s", ServerLogPrefix, e) contextFrame.CloseWithError(hpds_err.ErrorCodeBeforeHandler, e.Error()) return } } // main handler if e := s.mainFrameHandler(contextFrame); e != nil { log.Errorf("%smainFrameHandler e: %s", ServerLogPrefix, e) contextFrame.CloseWithError(hpds_err.ErrorCodeMainHandler, e.Error()) return } // after frame handler for _, handler := range s.afterHandlers { if e := handler(contextFrame); e != nil { log.Errorf("%safterFrameHandler e: %s", ServerLogPrefix, e) contextFrame.CloseWithError(hpds_err.ErrorCodeAfterHandler, e.Error()) return } } } } func (s *Server) mainFrameHandler(c *Context) error { var err error frameType := c.Frame.Type() switch frameType { case frame.TagOfHandshakeFrame: if err = s.handleHandshakeFrame(c); err != nil { log.Errorf("%shandleHandshakeFrame err: %s", ServerLogPrefix, err) // close connections early to avoid resource consumption if c.Stream != nil { goawayFrame := frame.NewGoawayFrame(err.Error()) if _, e := c.Stream.Write(goawayFrame.Encode()); e != nil { log.Errorf("%s write to client[%s] GoawayFrame error:%v", ServerLogPrefix, c.ConnId, e) return e } } } // case frame.TagOfPingFrame: // s.handlePingFrame(mainStream, connection, f.(*frame.PingFrame)) case frame.TagOfDataFrame: if err = 
s.handleDataFrame(c); err != nil { c.CloseWithError(hpds_err.ErrorCodeData, fmt.Sprintf("handleDataFrame err: %v", err)) } else { conn := s.connector.Get(c.connId) if conn != nil && conn.ClientType() == ClientTypeProtocolGateway { f := c.Frame.(*frame.DataFrame) f.GetMetaFrame().SetMetadata(conn.Metadata().Encode()) s.dispatchToDownStreams(f) } // observe data tags back flow _ = s.handleBackFlowFrame(c) } default: log.Errorf("%serr=%v, frame=%v", ServerLogPrefix, err, frame.Shortly(c.Frame.Encode())) } return nil } // handle HandShakeFrame func (s *Server) handleHandshakeFrame(c *Context) error { f := c.Frame.(*frame.HandshakeFrame) log.Debugf("%sGOT HandshakeFrame : %# x", ServerLogPrefix, f) // basic info connId := c.ConnId() clientId := f.ClientId clientType := ClientType(f.ClientType) stream := c.Stream // credential log.Debugf("%sClientType=%# x is %s, ClientId=%s, Credential=%s", ServerLogPrefix, f.ClientType, ClientType(f.ClientType), clientId, authName(f.AuthName())) // authenticate if !s.authenticate(f) { err := fmt.Errorf("handshake authentication fails, client credential name is %s", authName(f.AuthName())) // return err log.Debugf("%s <%s> [%s](%s) is connected!", ServerLogPrefix, clientType, f.Name, connId) rejectedFrame := frame.NewRejectedFrame(err.Error()) if _, err = stream.Write(rejectedFrame.Encode()); err != nil { log.Debugf("%s write to <%s> [%s](%s) RejectedFrame error:%v", ServerLogPrefix, clientType, f.Name, connId, err) return err } return nil } // client type var conn Connection switch clientType { case ClientTypeProtocolGateway, ClientTypeStreamFunction: // metadata metadata, err := s.metadataBuilder.Build(f) if err != nil { return err } conn = newConnection(f.Name, f.ClientId, clientType, metadata, stream, f.ObserveDataTags) if clientType == ClientTypeStreamFunction { // route route := s.router.Route(metadata) if route == nil { return errors.New("handleHandshakeFrame route is nil") } if e1 := route.Add(connId, f.Name, f.ObserveDataTags); 
e1 != nil { // duplicate name if e2, ok := e1.(hpds_err.DuplicateNameError); ok { existsConnID := e2.ConnId() if conn = s.connector.Get(existsConnID); conn != nil { log.Debugf("%s%s, write to SFN[%s](%s) GoawayFrame", ServerLogPrefix, e2.Error(), f.Name, existsConnID) goawayFrame := frame.NewGoawayFrame(e2.Error()) if e3 := conn.Write(goawayFrame); e3 != nil { log.Errorf("%s write to SFN[%s] GoawayFrame error:%v", ServerLogPrefix, f.Name, e3) return e3 } } } else { return e1 } } } case ClientTypeMessageQueue: conn = newConnection(f.Name, f.ClientId, clientType, nil, stream, f.ObserveDataTags) default: // unknown client type s.connector.Remove(connId) err := fmt.Errorf("Illegal ClientType: %#x ", f.ClientType) c.CloseWithError(hpds_err.ErrorCodeUnknownClient, err.Error()) return err } s.connector.Add(connId, conn) log.Printf("%s <%s> [%s][%s](%s) is connected!", ServerLogPrefix, clientType, f.Name, clientId, connId) return nil } // handle handleGoawayFrame func (s *Server) handleGoawayFrame(c *Context) error { f := c.Frame.(*frame.GoawayFrame) log.Debugf("%s GOT GoawayFrame code=%d, message==%s", ServerLogPrefix, hpds_err.ErrorCodeGoaway, f.Message()) // c.CloseWithError(f.Code(), f.Message()) _, err := c.Stream.Write(f.Encode()) return err } // will reuse quic-go's keep-alive feature // func (s *Server) handlePingFrame(stream quic.Stream, conn quic.Connection, f *frame.PingFrame) error { // log.Infof("%s------> GOT PingFrame : %# x", ServerLogPrefix, f) // return nil // } func (s *Server) handleDataFrame(c *Context) error { // counter +1 atomic.AddInt64(&s.counterOfDataFrame, 1) // currentIssuer := f.GetIssuer() fromId := c.ConnId() from := s.connector.Get(fromId) if from == nil { log.Warnf("%shandleDataFrame connector cannot find %s", ServerLogPrefix, fromId) return fmt.Errorf("handleDataFrame connector cannot find %s", fromId) } f := c.Frame.(*frame.DataFrame) var metadata Metadata if from.ClientType() == ClientTypeMessageQueue { m, err := 
s.metadataBuilder.Decode(f.GetMetaFrame().Metadata()) if err != nil { return err } metadata = m } else { metadata = from.Metadata() } // route route := s.router.Route(metadata) if route == nil { log.Warnf("%shandleDataFrame route is nil", ServerLogPrefix) return fmt.Errorf("handleDataFrame route is nil") } // get stream function connection ids from route connIds := route.GetForwardRoutes(f.GetDataTag()) for _, toId := range connIds { conn := s.connector.Get(toId) if conn == nil { log.Errorf("%sconn is nil: (%s)", ServerLogPrefix, toId) continue } to := conn.Name() log.Debugf("%shandleDataFrame tag=%#x tid=%s, counter=%d, from=[%s](%s), to=[%s](%s)", ServerLogPrefix, f.Tag(), f.TransactionId(), s.counterOfDataFrame, from.Name(), fromId, to, toId) // write data frame to stream if err := conn.Write(f); err != nil { log.Warnf("%shandleDataFrame conn.Write tag=%#x tid=%s, from=[%s](%s), to=[%s](%s), %v", ServerLogPrefix, f.Tag(), f.TransactionId(), from.Name(), fromId, to, toId, err) } } return nil } func (s *Server) handleBackFlowFrame(c *Context) error { f := c.Frame.(*frame.DataFrame) tag := f.GetDataTag() carriage := f.GetCarriage() sourceId := f.SourceId() // write to Protocol Gateway with BackFlowFrame bf := frame.NewBackFlowFrame(tag, carriage) sourceConns := s.connector.GetProtocolGatewayConnections(sourceId, tag) // conn := s.connector.Get(c.connId) // logger.Printf("%s handleBackFlowFrame tag:%#v --> source:%s, result=%s", ServerLogPrefix, tag, sourceId, carriage) for _, source := range sourceConns { if source != nil { log.Debugf("%s handleBackFlowFrame tag:%#v --> Protocol Gateway:%s, result=%# x", ServerLogPrefix, tag, sourceId, frame.Shortly(carriage)) if err := source.Write(bf); err != nil { log.Errorf("%s handleBackFlowFrame tag:%#v --> Protocol Gateway:%s, error=%v", ServerLogPrefix, tag, sourceId, err) return err } } } return nil } // StatsFunctions returns the sfn stats of server. 
func (s *Server) StatsFunctions() map[string]string {
	return s.connector.GetSnapshot()
}

// StatsCounter returns how many DataFrames have passed through the server.
func (s *Server) StatsCounter() int64 {
	// The counter is incremented with atomic.AddInt64 while frames are being
	// handled, so it must also be read atomically to avoid a data race.
	return atomic.LoadInt64(&s.counterOfDataFrame)
}

// DownStreams returns all the downstream clients, keyed by address.
// NOTE(review): this is the server's internal map, not a copy. Callers must
// not modify it, and reading it concurrently with AddDownstreamServer races.
func (s *Server) DownStreams() map[string]*Client {
	return s.downStreams
}

// ConfigRouter sets the router used to route frames (used by the Message Queue).
func (s *Server) ConfigRouter(router Router) {
	s.mu.Lock()
	s.router = router
	log.Debugf("%sconfig router is %#v", ServerLogPrefix, router)
	s.mu.Unlock()
}

// ConfigMetadataBuilder sets the builder used to build and decode client
// metadata (used by the Message Queue).
func (s *Server) ConfigMetadataBuilder(builder MetadataBuilder) {
	s.mu.Lock()
	s.metadataBuilder = builder
	log.Debugf("%sconfig metadataBuilder is %#v", ServerLogPrefix, builder)
	s.mu.Unlock()
}

// AddDownstreamServer registers c as the downstream client for addr,
// replacing any previous entry for that address. DataFrames are then
// dispatched to every registered downstream.
func (s *Server) AddDownstreamServer(addr string, c *Client) {
	s.mu.Lock()
	s.downStreams[addr] = c
	s.mu.Unlock()
}

// dispatchToDownStreams writes df to every registered downstream client.
// Delivery is best effort: write errors are ignored.
func (s *Server) dispatchToDownStreams(df *frame.DataFrame) {
	// Snapshot the map under mu because AddDownstreamServer writes it
	// concurrently. The network writes below happen outside the lock so a
	// slow downstream cannot block configuration calls.
	s.mu.Lock()
	if len(s.downStreams) == 0 {
		s.mu.Unlock()
		return
	}
	targets := make(map[string]*Client, len(s.downStreams))
	for addr, ds := range s.downStreams {
		targets[addr] = ds
	}
	s.mu.Unlock()

	for addr, ds := range targets {
		log.Debugf("%sdispatching to [%s]: %# x", ServerLogPrefix, addr, df.Tag())
		_ = ds.WriteFrame(df)
	}
}

// GetConnId returns the identifier the server uses for conn: the string form
// of its remote address.
func GetConnId(conn quic.Connection) string {
	return conn.RemoteAddr().String()
}

// initOptions fills in option defaults; currently there are none.
func (s *Server) initOptions() {
	// defaults
}

// validateRouter reports an error when no router has been configured.
func (s *Server) validateRouter() error {
	if s.router == nil {
		return errors.New("server's router is nil")
	}
	return nil
}

// validateMetadataBuilder reports an error when no metadata builder has been configured.
func (s *Server) validateMetadataBuilder() error {
	if s.metadataBuilder == nil {
		return errors.New("server's metadataBuilder is nil")
	}
	return nil
}

// Options returns a copy of the server's options.
func (s *Server) Options() ServerOptions {
	return s.opts
}

// Connector returns the server's connector, which tracks its client connections.
func (s *Server) Connector() Connector {
	return s.connector
}

// SetBeforeHandlers appends handlers to the list run on every frame before
// the main frame handler.
func (s *Server) SetBeforeHandlers(handlers ...FrameHandler) {
	s.beforeHandlers = append(s.beforeHandlers, handlers...)
}

// SetAfterHandlers appends handlers to the list run on every frame after the
// main frame handler.
func (s *Server) SetAfterHandlers(handlers ...FrameHandler) {
	s.afterHandlers = append(s.afterHandlers, handlers...)
}

// authNames lists the names of the configured authentications, or the single
// entry "none" when no authentication is configured.
func (s *Server) authNames() []string {
	configured := s.opts.Auths
	if len(configured) == 0 {
		return []string{"none"}
	}
	names := make([]string, 0, len(configured))
	for _, a := range configured {
		names = append(names, a.Name())
	}
	return names
}

// authenticate reports whether the handshake frame f passes authentication.
// With no authentications configured every client is accepted. Otherwise f
// is accepted as soon as an authentication with the same name accepts its
// payload.
func (s *Server) authenticate(f *frame.HandshakeFrame) bool {
	if len(s.opts.Auths) == 0 {
		return true
	}
	for _, a := range s.opts.Auths {
		if f.AuthName() != a.Name() {
			continue
		}
		if ok := a.Authenticate(f.AuthPayload()); ok {
			log.Debugf("%sauthenticated==%v", ServerLogPrefix, ok)
			return ok
		}
	}
	return false
}

// mode reports the TLS mode name shown in the startup log.
func mode() string {
	name := "PRODUCTION"
	if pkgtls.IsDev() {
		name = "DEVELOPMENT"
	}
	return name
}

// authName returns name, or "empty" when no credential name was given.
func authName(name string) string {
	if name != "" {
		return name
	}
	return "empty"
}