535 lines
15 KiB
Go
535 lines
15 KiB
Go
package network
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"git.hpds.cc/Component/network/auth"
|
|
"git.hpds.cc/Component/network/metadata"
|
|
"git.hpds.cc/Component/network/router"
|
|
"github.com/quic-go/quic-go"
|
|
"io"
|
|
"net"
|
|
"os"
|
|
"sync"
|
|
"sync/atomic"
|
|
|
|
// authentication implementations; currently only token authentication is implemented.
|
|
_ "git.hpds.cc/Component/network/auth"
|
|
"git.hpds.cc/Component/network/frame"
|
|
"git.hpds.cc/Component/network/hpds_err"
|
|
"git.hpds.cc/Component/network/log"
|
|
)
|
|
|
|
// FrameHandler is the handler for frame.
|
|
type FrameHandler func(c *Context) error
|
|
|
|
// ConnectionHandler is the handler for quic connection
|
|
type ConnectionHandler func(conn quic.Connection)
|
|
|
|
// Server is the underlying server of the Message Queue. It accepts QUIC
// connections, authenticates them, and routes DataFrames between streams.
type Server struct {
	// name identifies this server in log output.
	name string
	// connector is created by Serve; streams are looked up in it by stream id.
	connector *Connector
	// router is set via ConfigRouter; Serve refuses to start while it is nil.
	router router.Router
	// metadataBuilder is set via ConfigMetadataBuilder; Serve refuses to start while it is nil.
	metadataBuilder metadata.Builder
	// counterOfDataFrame counts DataFrames seen by handleDataFrame; it is incremented atomically.
	counterOfDataFrame int64
	// downStreams maps a registration address to a writer that receives broadcast DataFrames.
	downStreams map[string]frame.Writer
	// mu is held by the Config* setters and AddDownstreamServer while they mutate state.
	mu sync.Mutex
	// opts holds the options built by NewServer.
	opts *serverOptions
	// startHandlers run once per data stream, before any frame is read.
	startHandlers []FrameHandler
	// beforeHandlers run for every frame before mainFrameHandler.
	beforeHandlers []FrameHandler
	// afterHandlers run for every frame after mainFrameHandler.
	afterHandlers []FrameHandler
	// connectionCloseHandlers run when a connection's serving goroutine exits.
	connectionCloseHandlers []ConnectionHandler
	// listener is the QUIC listener created by Serve.
	listener Listener
	// logger is not referenced by the code visible in this file — TODO confirm usage.
	logger log.Logger
}
|
|
|
|
// NewServer create a Server instance.
|
|
func NewServer(name string, opts ...ServerOption) *Server {
|
|
options := defaultServerOptions()
|
|
|
|
for _, o := range opts {
|
|
o(options)
|
|
}
|
|
s := &Server{
|
|
name: name,
|
|
downStreams: make(map[string]frame.Writer),
|
|
opts: options,
|
|
}
|
|
|
|
return s
|
|
}
|
|
|
|
// ListenAndServe starts the server.
|
|
func (s *Server) ListenAndServe(ctx context.Context, addr string) error {
|
|
if addr == "" {
|
|
addr = DefaultListenAddr
|
|
}
|
|
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
conn, err := net.ListenUDP("udp", udpAddr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return s.Serve(ctx, conn)
|
|
}
|
|
|
|
// Serve the server with a net.PacketConn.
|
|
func (s *Server) Serve(ctx context.Context, conn net.PacketConn) error {
|
|
if err := s.validateMetadataBuilder(); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := s.validateRouter(); err != nil {
|
|
return err
|
|
}
|
|
s.connector = NewConnector(ctx)
|
|
// listen the address
|
|
listener, err := newListener(conn, s.opts.tlsConfig, s.opts.quicConfig)
|
|
if err != nil {
|
|
log.Errorf("%s listener.Listen: err=%v", ServerLogPrefix, err)
|
|
return err
|
|
}
|
|
s.listener = listener
|
|
|
|
log.Printf("%s [%s][%d] Listening on: %s, QUIC: %v, AUTH: %s", ServerLogPrefix, s.name, os.Getpid(), listener.Addr(), listener.Versions(), s.authNames())
|
|
|
|
for {
|
|
conn, err := s.listener.Accept(ctx)
|
|
if err != nil {
|
|
log.Errorf("%s listener accept connections error %v", ServerLogPrefix, err)
|
|
return err
|
|
}
|
|
err = s.opts.alpnHandler(conn.ConnectionState().TLS.NegotiatedProtocol)
|
|
if err != nil {
|
|
_ = conn.CloseWithError(quic.ApplicationErrorCode(hpds_err.ErrorCodeRejected), err.Error())
|
|
continue
|
|
}
|
|
stream0, err := conn.AcceptStream(ctx)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
controlStream := NewServerControlStream(conn, NewFrameStream(stream0))
|
|
|
|
// Auth accepts a AuthenticationFrame from client. The first frame from client must be
|
|
// AuthenticationFrame, It returns true if auth successful otherwise return false.
|
|
// It response to client a AuthenticationAckFrame.
|
|
err = controlStream.VerifyAuthentication(s.handleAuthenticationFrame)
|
|
if err != nil {
|
|
log.Warnf("%s Authentication Failed, error: %s", ServerLogPrefix, err)
|
|
continue
|
|
}
|
|
log.Debugf("%s Authentication Success", ServerLogPrefix)
|
|
|
|
go func(qConn quic.Connection) {
|
|
streamGroup := NewStreamGroup(ctx, controlStream, s.connector)
|
|
|
|
defer streamGroup.Wait()
|
|
defer s.doConnectionCloseHandlers(qConn)
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case err := <-s.runWithStreamGroup(streamGroup):
|
|
log.Errorf("%s Client Close, %v", ServerLogPrefix, err)
|
|
}
|
|
}(conn)
|
|
}
|
|
}
|
|
|
|
func (s *Server) runWithStreamGroup(group *StreamGroup) <-chan error {
|
|
errCh := make(chan error)
|
|
|
|
go func() {
|
|
errCh <- group.Run(s.handleStreamContext)
|
|
}()
|
|
|
|
return errCh
|
|
}
|
|
|
|
// Close will shut down the server.
|
|
func (s *Server) Close() error {
|
|
// connector
|
|
if s.connector != nil {
|
|
s.connector.Close()
|
|
}
|
|
// listener
|
|
if s.listener != nil {
|
|
_ = s.listener.Close()
|
|
}
|
|
// router
|
|
if s.router != nil {
|
|
s.router.Clean()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Server) handleRoute(c *Context) error {
|
|
if c.DataStream.StreamType() == StreamTypeStreamFunction {
|
|
md, err := s.metadataBuilder.Decode(c.DataStream.Metadata())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// route
|
|
route := s.router.Route(md)
|
|
if route == nil {
|
|
return errors.New("handleHandshakeFrame route is nil")
|
|
}
|
|
if err := route.Add(c.StreamId(), c.DataStream.Name(), c.DataStream.ObserveDataTags()); err != nil {
|
|
// duplicate name
|
|
if e, ok := err.(hpds_err.DuplicateNameError); ok {
|
|
existsConnId := e.ConnId()
|
|
|
|
log.Debugf("%s StreamFunction Duplicate Name, error: %s; sfn_name: %s, old_stream_id: %s; current_stream_id: %s",
|
|
ServerLogPrefix, e.Error(), c.DataStream.Name(), existsConnId, c.StreamId())
|
|
|
|
stream, ok, err := s.connector.Get(existsConnId)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if ok {
|
|
_ = stream.CloseWithError(e.Error())
|
|
_ = s.connector.Remove(existsConnId)
|
|
}
|
|
} else {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// handleStreamContext handles data streams,
|
|
func (s *Server) handleStreamContext(c *Context) {
|
|
// handle route.
|
|
if err := s.handleRoute(c); err != nil {
|
|
c.CloseWithError(hpds_err.ErrorCodeRejected, err.Error())
|
|
return
|
|
}
|
|
defer s.cleanRoute(c)
|
|
|
|
// start frame handlers
|
|
for _, handler := range s.startHandlers {
|
|
if err := handler(c); err != nil {
|
|
log.Errorf("%s startHandlers error: %v", ServerLogPrefix, err)
|
|
c.CloseWithError(hpds_err.ErrorCodeStartHandler, err.Error())
|
|
return
|
|
}
|
|
}
|
|
// check update for stream
|
|
for {
|
|
f, err := c.DataStream.ReadFrame()
|
|
if err != nil {
|
|
// if client close connection, will get ApplicationError with code = 0x00
|
|
if e, ok := err.(*quic.ApplicationError); ok {
|
|
if hpds_err.Is(e.ErrorCode, hpds_err.ErrorCodeClientAbort) {
|
|
// client abort
|
|
log.Infof("%s client close the connection", ServerLogPrefix)
|
|
break
|
|
}
|
|
he := hpds_err.New(hpds_err.Parse(e.ErrorCode), err)
|
|
log.Errorf("%s read frame error: %v", ServerLogPrefix, he)
|
|
} else if err == io.EOF {
|
|
log.Infof("%s connection EOF", ServerLogPrefix)
|
|
break
|
|
}
|
|
if errors.Is(err, net.ErrClosed) {
|
|
// if client close the connection, net.ErrClosed will be raise
|
|
// by quic-go IdleTimeoutError after connection's KeepAlive config.
|
|
log.Warnf("%s connection error, error: %v", ServerLogPrefix, net.ErrClosed)
|
|
c.CloseWithError(hpds_err.ErrorCodeClosed, "net.ErrClosed")
|
|
break
|
|
}
|
|
// any error occurred, we should close the stream
|
|
// after this, conn.AcceptStream() will raise the error
|
|
c.CloseWithError(hpds_err.ErrorCodeUnknown, err.Error())
|
|
log.Warnf("%s connection close", ServerLogPrefix)
|
|
break
|
|
}
|
|
|
|
// add frame to context
|
|
c.WithFrame(f)
|
|
|
|
// before frame handlers
|
|
for _, handler := range s.beforeHandlers {
|
|
if err := handler(c); err != nil {
|
|
log.Errorf("%s beforeFrameHandler error: %v", ServerLogPrefix, err)
|
|
c.CloseWithError(hpds_err.ErrorCodeBeforeHandler, err.Error())
|
|
return
|
|
}
|
|
}
|
|
// main handler
|
|
if err := s.mainFrameHandler(c); err != nil {
|
|
log.Errorf("%s mainFrameHandler error: %v", ServerLogPrefix, err)
|
|
c.CloseWithError(hpds_err.ErrorCodeMainHandler, err.Error())
|
|
return
|
|
}
|
|
// after frame handler
|
|
for _, handler := range s.afterHandlers {
|
|
if err := handler(c); err != nil {
|
|
log.Errorf("%s afterFrameHandler error: %v", ServerLogPrefix, err)
|
|
c.CloseWithError(hpds_err.ErrorCodeAfterHandler, err.Error())
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Server) mainFrameHandler(c *Context) error {
|
|
var err error
|
|
frameType := c.Frame.Type()
|
|
|
|
switch frameType {
|
|
case frame.TagOfDataFrame:
|
|
if err = s.handleDataFrame(c); err != nil {
|
|
c.CloseWithError(hpds_err.ErrorCodeData, fmt.Sprintf("handleDataFrame err: %v", err))
|
|
} else {
|
|
s.dispatchToDownStreams(c)
|
|
// observe data tags back flow
|
|
_ = s.handleBackFlowFrame(c)
|
|
}
|
|
default:
|
|
log.Errorf("%s err=%v, frame=%v", ServerLogPrefix, err, frame.Shortly(c.Frame.Encode()))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Server) handleAuthenticationFrame(f auth.Object) (bool, error) {
|
|
ok := auth.Authenticate(s.opts.auths, f)
|
|
|
|
if ok {
|
|
log.Debugf("%s Successful authentication", ServerLogPrefix)
|
|
} else {
|
|
log.Warnf("%s Authentication failed, credential: %s", ServerLogPrefix, f.AuthName())
|
|
}
|
|
|
|
return ok, nil
|
|
}
|
|
func (s *Server) handleDataFrame(c *Context) error {
|
|
// counter +1
|
|
atomic.AddInt64(&s.counterOfDataFrame, 1)
|
|
// currentIssuer := f.GetIssuer()
|
|
fromId := c.StreamId()
|
|
from, ok, err := s.connector.Get(fromId)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !ok {
|
|
log.Warnf("%s handleDataFrame connector cannot find, from_conn_id: %s", ServerLogPrefix, fromId)
|
|
return fmt.Errorf("handleDataFrame connector cannot find %s", fromId)
|
|
}
|
|
|
|
f := c.Frame.(*frame.DataFrame)
|
|
|
|
m, err := s.metadataBuilder.Decode(f.GetMetaFrame().Metadata())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if m == nil {
|
|
m, err = s.metadataBuilder.Decode(from.Metadata())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// route
|
|
route := s.router.Route(m)
|
|
if route == nil {
|
|
log.Warnf("%s handleDataFrame route is nil", ServerLogPrefix)
|
|
return fmt.Errorf("handleDataFrame route is nil")
|
|
}
|
|
|
|
// get stream function connection ids from route
|
|
connIDs := route.GetForwardRoutes(f.GetDataTag())
|
|
|
|
log.Debugf("%s Data Routing Status, sfn_stream_ids: %v; connector: %v", ServerLogPrefix, connIDs, s.connector.GetSnapshot())
|
|
|
|
for _, toId := range connIDs {
|
|
conn, ok, err := s.connector.Get(toId)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
if !ok {
|
|
log.Errorf("%s Can't find forward conn, error: conn is nil ;forward_conn_id: ", ServerLogPrefix, toId)
|
|
continue
|
|
}
|
|
|
|
to := conn.Name()
|
|
log.Infof("%s handleDataFrame, from_conn_name: %s; from_conn_id: %s; to_conn_name: %s; to_conn_id: %s; data_frame: %s",
|
|
ServerLogPrefix, from.Name(), fromId, to, toId, f.String())
|
|
|
|
// write data frame to stream
|
|
if err := conn.WriteFrame(f); err != nil {
|
|
log.Errorf("%s handleDataFrame conn.Write, %v", ServerLogPrefix, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
func (s *Server) handleBackFlowFrame(c *Context) error {
|
|
f := c.Frame.(*frame.DataFrame)
|
|
tag := f.GetDataTag()
|
|
carriage := f.GetCarriage()
|
|
sourceId := f.SourceId()
|
|
// write to Protocol Gateway with BackFlowFrame
|
|
bf := frame.NewBackFlowFrame(tag, carriage)
|
|
sourceConnList, err := s.connector.GetSourceConns(sourceId, tag)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// conn := s.connector.Get(c.connId)
|
|
// logger.Printf("%s handleBackFlowFrame tag:%#v --> source:%s, result=%s", ServerLogPrefix, tag, sourceId, carriage)
|
|
for _, source := range sourceConnList {
|
|
if source != nil {
|
|
log.Debugf("%s handleBackFlowFrame tag:%#v; source_conn_id: %s, back_flow_frame: %s", ServerLogPrefix, tag, sourceId, f.String())
|
|
if err := source.WriteFrame(bf); err != nil {
|
|
log.Errorf("%s handleBackFlowFrame tag:%#v --> Protocol Gateway:%s, error=%v", ServerLogPrefix, tag, sourceId, err)
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// StatsFunctions returns the sfn stats of server.
|
|
func (s *Server) StatsFunctions() map[string]string {
|
|
return s.connector.GetSnapshot()
|
|
}
|
|
|
|
// StatsCounter returns how many DataFrames pass through server.
|
|
func (s *Server) StatsCounter() int64 {
|
|
return s.counterOfDataFrame
|
|
}
|
|
|
|
// DownStreams return all the downstream servers.
|
|
func (s *Server) DownStreams() map[string]frame.Writer {
|
|
return s.downStreams
|
|
}
|
|
|
|
// ConfigRouter is used to set router by Message Queue
|
|
func (s *Server) ConfigRouter(router router.Router) {
|
|
s.mu.Lock()
|
|
s.router = router
|
|
log.Debugf("%s config router is %#v", ServerLogPrefix, router)
|
|
s.mu.Unlock()
|
|
}
|
|
|
|
// ConfigMetadataBuilder is used to set metadataBuilder by Message Queue
|
|
|
|
func (s *Server) ConfigMetadataBuilder(builder metadata.Builder) {
|
|
s.mu.Lock()
|
|
s.metadataBuilder = builder
|
|
log.Debugf("%s config metadataBuilder is %#v", ServerLogPrefix, builder)
|
|
s.mu.Unlock()
|
|
}
|
|
|
|
// ConfigAlpnHandler is used to set alpnHandler by Emitter
|
|
func (s *Server) ConfigAlpnHandler(h func(string) error) {
|
|
s.mu.Lock()
|
|
s.opts.alpnHandler = h
|
|
log.Debugf("%s config alpnHandler", ServerLogPrefix)
|
|
s.mu.Unlock()
|
|
}
|
|
|
|
// AddDownstreamServer add a downstream server to this server. all the DataFrames will be
|
|
// dispatch to all the downStreams.
|
|
func (s *Server) AddDownstreamServer(addr string, c *Client) {
|
|
s.mu.Lock()
|
|
s.downStreams[addr] = c
|
|
s.mu.Unlock()
|
|
}
|
|
|
|
// dispatch every DataFrames to all downStreams
|
|
func (s *Server) dispatchToDownStreams(c *Context) {
|
|
stream, ok, err := s.connector.Get(c.StreamId())
|
|
if err != nil {
|
|
log.Errorf("%s Connector Get Error, %v", ServerLogPrefix, err)
|
|
return
|
|
}
|
|
if !ok {
|
|
log.Debugf("%s dispatchTo Down Streams failed", ServerLogPrefix)
|
|
return
|
|
}
|
|
|
|
if stream.StreamType() == StreamTypeSource {
|
|
f := c.Frame.(*frame.DataFrame)
|
|
if f.IsBroadcast() {
|
|
if f.GetMetaFrame().Metadata() == nil {
|
|
f.GetMetaFrame().SetMetadata(stream.Metadata())
|
|
}
|
|
for addr, ds := range s.downStreams {
|
|
log.Infof("%s dispatching to, dispatch_addr: %s; tid: %s;", ServerLogPrefix, addr, "", f.TransactionId())
|
|
_ = ds.WriteFrame(f)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// GetConnId get quic connection id
|
|
func GetConnId(conn quic.Connection) string {
|
|
return conn.RemoteAddr().String()
|
|
}
|
|
|
|
func (s *Server) validateRouter() error {
|
|
if s.router == nil {
|
|
return errors.New("server's router is nil")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Server) validateMetadataBuilder() error {
|
|
if s.metadataBuilder == nil {
|
|
return errors.New("server's metadataBuilder is nil")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// SetStartHandlers sets a function for operating connection,
|
|
// this function executes after handshake successful.
|
|
func (s *Server) SetStartHandlers(handlers ...FrameHandler) {
|
|
s.startHandlers = append(s.startHandlers, handlers...)
|
|
}
|
|
|
|
// SetBeforeHandlers set the before handlers of server.
|
|
func (s *Server) SetBeforeHandlers(handlers ...FrameHandler) {
|
|
s.beforeHandlers = append(s.beforeHandlers, handlers...)
|
|
}
|
|
|
|
// SetAfterHandlers set the after handlers of server.
|
|
func (s *Server) SetAfterHandlers(handlers ...FrameHandler) {
|
|
s.afterHandlers = append(s.afterHandlers, handlers...)
|
|
}
|
|
|
|
// SetConnectionCloseHandlers set the connection close handlers of server.
|
|
func (s *Server) SetConnectionCloseHandlers(handlers ...ConnectionHandler) {
|
|
s.connectionCloseHandlers = append(s.connectionCloseHandlers, handlers...)
|
|
}
|
|
|
|
func (s *Server) authNames() []string {
|
|
if len(s.opts.auths) == 0 {
|
|
return []string{"none"}
|
|
}
|
|
result := make([]string, 0)
|
|
for _, a := range s.opts.auths {
|
|
result = append(result, a.Name())
|
|
}
|
|
return result
|
|
}
|
|
|
|
func (s *Server) doConnectionCloseHandlers(qConn quic.Connection) {
|
|
log.Debugf("%s QUIC Connection Closed", ServerLogPrefix)
|
|
for _, h := range s.connectionCloseHandlers {
|
|
h(qConn)
|
|
}
|
|
}
|
|
|
|
func (s *Server) cleanRoute(c *Context) {
|
|
md, _ := s.metadataBuilder.Decode(c.DataStream.Metadata())
|
|
_ = s.router.Route(md).Remove(c.StreamId())
|
|
}
|