package network import ( "context" "errors" "fmt" "git.hpds.cc/Component/network/auth" "git.hpds.cc/Component/network/metadata" "git.hpds.cc/Component/network/router" "github.com/quic-go/quic-go" "io" "net" "os" "sync" "sync/atomic" // authentication implements, Currently, only token authentication is implemented _ "git.hpds.cc/Component/network/auth" "git.hpds.cc/Component/network/frame" "git.hpds.cc/Component/network/hpds_err" "git.hpds.cc/Component/network/log" ) // FrameHandler is the handler for frame. type FrameHandler func(c *Context) error // ConnectionHandler is the handler for quic connection type ConnectionHandler func(conn quic.Connection) // Server is the underlining server of Message Queue type Server struct { name string connector *Connector router router.Router metadataBuilder metadata.Builder counterOfDataFrame int64 downStreams map[string]frame.Writer mu sync.Mutex opts *serverOptions startHandlers []FrameHandler beforeHandlers []FrameHandler afterHandlers []FrameHandler connectionCloseHandlers []ConnectionHandler listener Listener logger log.Logger } // NewServer create a Server instance. func NewServer(name string, opts ...ServerOption) *Server { options := defaultServerOptions() for _, o := range opts { o(options) } s := &Server{ name: name, downStreams: make(map[string]frame.Writer), opts: options, } return s } // ListenAndServe starts the server. func (s *Server) ListenAndServe(ctx context.Context, addr string) error { if addr == "" { addr = DefaultListenAddr } udpAddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { return err } conn, err := net.ListenUDP("udp", udpAddr) if err != nil { return err } return s.Serve(ctx, conn) } // Serve the server with a net.PacketConn. func (s *Server) Serve(ctx context.Context, conn net.PacketConn) error { if err := s.validateMetadataBuilder(); err != nil { return err } if err := s.validateRouter(); err != nil { return err } s.connector = NewConnector(ctx) // listen the address listener, err := newListener(conn, s.opts.tlsConfig, s.opts.quicConfig) if err != nil { log.Errorf("%s listener.Listen: err=%v", ServerLogPrefix, err) return err } s.listener = listener log.Printf("%s [%s][%d] Listening on: %s, QUIC: %v, AUTH: %s", ServerLogPrefix, s.name, os.Getpid(), listener.Addr(), listener.Versions(), s.authNames()) for { conn, err := s.listener.Accept(ctx) if err != nil { log.Errorf("%s listener accept connections error %v", ServerLogPrefix, err) return err } err = s.opts.alpnHandler(conn.ConnectionState().TLS.NegotiatedProtocol) if err != nil { _ = conn.CloseWithError(quic.ApplicationErrorCode(hpds_err.ErrorCodeRejected), err.Error()) continue } stream0, err := conn.AcceptStream(ctx) if err != nil { continue } controlStream := NewServerControlStream(conn, NewFrameStream(stream0)) // Auth accepts a AuthenticationFrame from client. The first frame from client must be // AuthenticationFrame, It returns true if auth successful otherwise return false. // It response to client a AuthenticationAckFrame. err = controlStream.VerifyAuthentication(s.handleAuthenticationFrame) if err != nil { log.Warnf("%s Authentication Failed, error: %s", ServerLogPrefix, err) continue } log.Debugf("%s Authentication Success", ServerLogPrefix) go func(qConn quic.Connection) { streamGroup := NewStreamGroup(ctx, controlStream, s.connector) defer streamGroup.Wait() defer s.doConnectionCloseHandlers(qConn) select { case <-ctx.Done(): return case err := <-s.runWithStreamGroup(streamGroup): log.Errorf("%s Client Close, %v", ServerLogPrefix, err) } }(conn) } } func (s *Server) runWithStreamGroup(group *StreamGroup) <-chan error { errCh := make(chan error) go func() { errCh <- group.Run(s.handleStreamContext) }() return errCh } // Close will shut down the server. func (s *Server) Close() error { // connector if s.connector != nil { s.connector.Close() } // listener if s.listener != nil { _ = s.listener.Close() } // router if s.router != nil { s.router.Clean() } return nil } func (s *Server) handleRoute(c *Context) error { if c.DataStream.StreamType() == StreamTypeStreamFunction { md, err := s.metadataBuilder.Decode(c.DataStream.Metadata()) if err != nil { return err } // route route := s.router.Route(md) if route == nil { return errors.New("handleHandshakeFrame route is nil") } if err := route.Add(c.StreamId(), c.DataStream.Name(), c.DataStream.ObserveDataTags()); err != nil { // duplicate name if e, ok := err.(hpds_err.DuplicateNameError); ok { existsConnId := e.ConnId() log.Debugf("%s StreamFunction Duplicate Name, error: %s; sfn_name: %s, old_stream_id: %s; current_stream_id: %s", ServerLogPrefix, e.Error(), c.DataStream.Name(), existsConnId, c.StreamId()) stream, ok, err := s.connector.Get(existsConnId) if err != nil { return err } if ok { _ = stream.CloseWithError(e.Error()) _ = s.connector.Remove(existsConnId) } } else { return err } } } return nil } // handleStreamContext handles data streams, func (s *Server) handleStreamContext(c *Context) { // handle route. if err := s.handleRoute(c); err != nil { c.CloseWithError(hpds_err.ErrorCodeRejected, err.Error()) return } defer s.cleanRoute(c) // start frame handlers for _, handler := range s.startHandlers { if err := handler(c); err != nil { log.Errorf("%s startHandlers error: %v", ServerLogPrefix, err) c.CloseWithError(hpds_err.ErrorCodeStartHandler, err.Error()) return } } // check update for stream for { f, err := c.DataStream.ReadFrame() if err != nil { // if client close connection, will get ApplicationError with code = 0x00 if e, ok := err.(*quic.ApplicationError); ok { if hpds_err.Is(e.ErrorCode, hpds_err.ErrorCodeClientAbort) { // client abort log.Infof("%s client close the connection", ServerLogPrefix) break } he := hpds_err.New(hpds_err.Parse(e.ErrorCode), err) log.Errorf("%s read frame error: %v", ServerLogPrefix, he) } else if err == io.EOF { log.Infof("%s connection EOF", ServerLogPrefix) break } if errors.Is(err, net.ErrClosed) { // if client close the connection, net.ErrClosed will be raise // by quic-go IdleTimeoutError after connection's KeepAlive config. log.Warnf("%s connection error, error: %v", ServerLogPrefix, net.ErrClosed) c.CloseWithError(hpds_err.ErrorCodeClosed, "net.ErrClosed") break } // any error occurred, we should close the stream // after this, conn.AcceptStream() will raise the error c.CloseWithError(hpds_err.ErrorCodeUnknown, err.Error()) log.Warnf("%s connection close", ServerLogPrefix) break } // add frame to context c.WithFrame(f) // before frame handlers for _, handler := range s.beforeHandlers { if err := handler(c); err != nil { log.Errorf("%s beforeFrameHandler error: %v", ServerLogPrefix, err) c.CloseWithError(hpds_err.ErrorCodeBeforeHandler, err.Error()) return } } // main handler if err := s.mainFrameHandler(c); err != nil { log.Errorf("%s mainFrameHandler error: %v", ServerLogPrefix, err) c.CloseWithError(hpds_err.ErrorCodeMainHandler, err.Error()) return } // after frame handler for _, handler := range s.afterHandlers { if err := handler(c); err != nil { log.Errorf("%s afterFrameHandler error: %v", ServerLogPrefix, err) c.CloseWithError(hpds_err.ErrorCodeAfterHandler, err.Error()) return } } } } func (s *Server) mainFrameHandler(c *Context) error { var err error frameType := c.Frame.Type() switch frameType { case frame.TagOfDataFrame: if err = s.handleDataFrame(c); err != nil { c.CloseWithError(hpds_err.ErrorCodeData, fmt.Sprintf("handleDataFrame err: %v", err)) } else { s.dispatchToDownStreams(c) // observe data tags back flow _ = s.handleBackFlowFrame(c) } default: log.Errorf("%s err=%v, frame=%v", ServerLogPrefix, err, frame.Shortly(c.Frame.Encode())) } return nil } func (s *Server) handleAuthenticationFrame(f auth.Object) (bool, error) { ok := auth.Authenticate(s.opts.auths, f) if ok { log.Debugf("%s Successful authentication", ServerLogPrefix) } else { log.Warnf("%s Authentication failed, credential: %s", ServerLogPrefix, f.AuthName()) } return ok, nil } func (s *Server) handleDataFrame(c *Context) error { // counter +1 atomic.AddInt64(&s.counterOfDataFrame, 1) // currentIssuer := f.GetIssuer() fromId := c.StreamId() from, ok, err := s.connector.Get(fromId) if err != nil { return err } if !ok { log.Warnf("%s handleDataFrame connector cannot find, from_conn_id: %s", ServerLogPrefix, fromId) return fmt.Errorf("handleDataFrame connector cannot find %s", fromId) } f := c.Frame.(*frame.DataFrame) m, err := s.metadataBuilder.Decode(f.GetMetaFrame().Metadata()) if err != nil { return err } if m == nil { m, err = s.metadataBuilder.Decode(from.Metadata()) if err != nil { return err } } // route route := s.router.Route(m) if route == nil { log.Warnf("%s handleDataFrame route is nil", ServerLogPrefix) return fmt.Errorf("handleDataFrame route is nil") } // get stream function connection ids from route connIDs := route.GetForwardRoutes(f.GetDataTag()) log.Debugf("%s Data Routing Status, sfn_stream_ids: %v; connector: %v", ServerLogPrefix, connIDs, s.connector.GetSnapshot()) for _, toId := range connIDs { conn, ok, err := s.connector.Get(toId) if err != nil { continue } if !ok { log.Errorf("%s Can't find forward conn, error: conn is nil ;forward_conn_id: ", ServerLogPrefix, toId) continue } to := conn.Name() log.Infof("%s handleDataFrame, from_conn_name: %s; from_conn_id: %s; to_conn_name: %s; to_conn_id: %s; data_frame: %s", ServerLogPrefix, from.Name(), fromId, to, toId, f.String()) // write data frame to stream if err := conn.WriteFrame(f); err != nil { log.Errorf("%s handleDataFrame conn.Write, %v", ServerLogPrefix, err) } } return nil } func (s *Server) handleBackFlowFrame(c *Context) error { f := c.Frame.(*frame.DataFrame) tag := f.GetDataTag() carriage := f.GetCarriage() sourceId := f.SourceId() // write to Protocol Gateway with BackFlowFrame bf := frame.NewBackFlowFrame(tag, carriage) sourceConnList, err := s.connector.GetSourceConns(sourceId, tag) if err != nil { return err } // conn := s.connector.Get(c.connId) // logger.Printf("%s handleBackFlowFrame tag:%#v --> source:%s, result=%s", ServerLogPrefix, tag, sourceId, carriage) for _, source := range sourceConnList { if source != nil { log.Debugf("%s handleBackFlowFrame tag:%#v; source_conn_id: %s, back_flow_frame: %s", ServerLogPrefix, tag, sourceId, f.String()) if err := source.WriteFrame(bf); err != nil { log.Errorf("%s handleBackFlowFrame tag:%#v --> Protocol Gateway:%s, error=%v", ServerLogPrefix, tag, sourceId, err) return err } } } return nil } // StatsFunctions returns the sfn stats of server. func (s *Server) StatsFunctions() map[string]string { return s.connector.GetSnapshot() } // StatsCounter returns how many DataFrames pass through server. func (s *Server) StatsCounter() int64 { return s.counterOfDataFrame } // DownStreams return all the downstream servers. func (s *Server) DownStreams() map[string]frame.Writer { return s.downStreams } // ConfigRouter is used to set router by Message Queue func (s *Server) ConfigRouter(router router.Router) { s.mu.Lock() s.router = router log.Debugf("%s config router is %#v", ServerLogPrefix, router) s.mu.Unlock() } // ConfigMetadataBuilder is used to set metadataBuilder by Message Queue func (s *Server) ConfigMetadataBuilder(builder metadata.Builder) { s.mu.Lock() s.metadataBuilder = builder log.Debugf("%s config metadataBuilder is %#v", ServerLogPrefix, builder) s.mu.Unlock() } // ConfigAlpnHandler is used to set alpnHandler by Emitter func (s *Server) ConfigAlpnHandler(h func(string) error) { s.mu.Lock() s.opts.alpnHandler = h log.Debugf("%s config alpnHandler", ServerLogPrefix) s.mu.Unlock() } // AddDownstreamServer add a downstream server to this server. all the DataFrames will be // dispatch to all the downStreams. func (s *Server) AddDownstreamServer(addr string, c *Client) { s.mu.Lock() s.downStreams[addr] = c s.mu.Unlock() } // dispatch every DataFrames to all downStreams func (s *Server) dispatchToDownStreams(c *Context) { stream, ok, err := s.connector.Get(c.StreamId()) if err != nil { log.Errorf("%s Connector Get Error, %v", ServerLogPrefix, err) return } if !ok { log.Debugf("%s dispatchTo Down Streams failed", ServerLogPrefix) return } if stream.StreamType() == StreamTypeSource { f := c.Frame.(*frame.DataFrame) if f.IsBroadcast() { if f.GetMetaFrame().Metadata() == nil { f.GetMetaFrame().SetMetadata(stream.Metadata()) } for addr, ds := range s.downStreams { log.Infof("%s dispatching to, dispatch_addr: %s; tid: %s;", ServerLogPrefix, addr, "", f.TransactionId()) _ = ds.WriteFrame(f) } } } } // GetConnId get quic connection id func GetConnId(conn quic.Connection) string { return conn.RemoteAddr().String() } func (s *Server) validateRouter() error { if s.router == nil { return errors.New("server's router is nil") } return nil } func (s *Server) validateMetadataBuilder() error { if s.metadataBuilder == nil { return errors.New("server's metadataBuilder is nil") } return nil } // SetStartHandlers sets a function for operating connection, // this function executes after handshake successful. func (s *Server) SetStartHandlers(handlers ...FrameHandler) { s.startHandlers = append(s.startHandlers, handlers...) } // SetBeforeHandlers set the before handlers of server. func (s *Server) SetBeforeHandlers(handlers ...FrameHandler) { s.beforeHandlers = append(s.beforeHandlers, handlers...) } // SetAfterHandlers set the after handlers of server. func (s *Server) SetAfterHandlers(handlers ...FrameHandler) { s.afterHandlers = append(s.afterHandlers, handlers...) } // SetConnectionCloseHandlers set the connection close handlers of server. func (s *Server) SetConnectionCloseHandlers(handlers ...ConnectionHandler) { s.connectionCloseHandlers = append(s.connectionCloseHandlers, handlers...) } func (s *Server) authNames() []string { if len(s.opts.auths) == 0 { return []string{"none"} } result := make([]string, 0) for _, a := range s.opts.auths { result = append(result, a.Name()) } return result } func (s *Server) doConnectionCloseHandlers(qConn quic.Connection) { log.Debugf("%s QUIC Connection Closed", ServerLogPrefix) for _, h := range s.connectionCloseHandlers { h(qConn) } } func (s *Server) cleanRoute(c *Context) { md, _ := s.metadataBuilder.Decode(c.DataStream.Metadata()) _ = s.router.Route(md).Remove(c.StreamId()) }