583 lines
16 KiB
Go
583 lines
16 KiB
Go
package network
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"os"
|
|
"sync"
|
|
"sync/atomic"
|
|
|
|
// authentication implementations; currently only token authentication is implemented
|
|
_ "git.hpds.cc/Component/network/auth"
|
|
"git.hpds.cc/Component/network/frame"
|
|
"git.hpds.cc/Component/network/hpds_err"
|
|
"git.hpds.cc/Component/network/log"
|
|
pkgtls "git.hpds.cc/Component/network/tls"
|
|
"github.com/lucas-clemente/quic-go"
|
|
)
|
|
|
|
// DefaultListenAddr is the UDP address ListenAndServe binds to when it is
// given an empty address: all IPv4 interfaces, port 9000.
const DefaultListenAddr = "0.0.0.0:9000"
|
|
|
|
// ServerOption is the option for server.
|
|
type ServerOption func(*ServerOptions)
|
|
|
|
// FrameHandler is the handler for frame.
|
|
type FrameHandler func(c *Context) error
|
|
|
|
// Server is the underlining server of Message Queue
|
|
type Server struct {
|
|
name string
|
|
state string
|
|
connector Connector
|
|
router Router
|
|
metadataBuilder MetadataBuilder
|
|
counterOfDataFrame int64
|
|
downStreams map[string]*Client
|
|
mu sync.Mutex
|
|
opts ServerOptions
|
|
beforeHandlers []FrameHandler
|
|
afterHandlers []FrameHandler
|
|
}
|
|
|
|
// NewServer create a Server instance.
|
|
func NewServer(name string, opts ...ServerOption) *Server {
|
|
s := &Server{
|
|
name: name,
|
|
connector: newConnector(),
|
|
downStreams: make(map[string]*Client),
|
|
}
|
|
_ = s.Init(opts...)
|
|
|
|
return s
|
|
}
|
|
|
|
// Init the options.
|
|
func (s *Server) Init(opts ...ServerOption) error {
|
|
for _, o := range opts {
|
|
o(&s.opts)
|
|
}
|
|
// options defaults
|
|
s.initOptions()
|
|
|
|
return nil
|
|
}
|
|
|
|
// ListenAndServe starts the server.
|
|
func (s *Server) ListenAndServe(ctx context.Context, addr string) error {
|
|
if addr == "" {
|
|
addr = DefaultListenAddr
|
|
}
|
|
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
conn, err := net.ListenUDP("udp", udpAddr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return s.Serve(ctx, conn)
|
|
}
|
|
|
|
// Serve the server with a net.PacketConn.
|
|
func (s *Server) Serve(ctx context.Context, conn net.PacketConn) error {
|
|
if err := s.validateMetadataBuilder(); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := s.validateRouter(); err != nil {
|
|
return err
|
|
}
|
|
|
|
// listen the address
|
|
listener, err := newListener(conn, s.opts.TLSConfig, s.opts.QuicConfig)
|
|
if err != nil {
|
|
log.Errorf("%slistener.Listen: err=%v", ServerLogPrefix, err)
|
|
return err
|
|
}
|
|
defer func() {
|
|
_ = listener.Close()
|
|
}()
|
|
log.Printf("%s [%s][%d] Listening on: %s, MODE: %s, QUIC: %v, AUTH: %s", ServerLogPrefix, s.name, os.Getpid(), listener.Addr(), mode(), listener.Versions(), s.authNames())
|
|
|
|
s.state = ConnStateConnected
|
|
for {
|
|
_ = s.createNewClientConnection(ctx, listener)
|
|
}
|
|
}
|
|
|
|
// createNewClientConnection create a new connection when new hpds-client connected
|
|
func (s *Server) createNewClientConnection(ctx context.Context, listener Listener) error {
|
|
sctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
connect, e := listener.Accept(sctx)
|
|
if e != nil {
|
|
log.Errorf("%screate connection error: %v", ServerLogPrefix, e)
|
|
return e
|
|
}
|
|
|
|
connId := GetConnId(connect)
|
|
log.Infof("%s1/ new connection: %s", ServerLogPrefix, connId)
|
|
|
|
go func(ctx context.Context, qConn quic.Connection) {
|
|
for {
|
|
err := s.handle(ctx, qConn, connect, connId)
|
|
if err != nil {
|
|
break
|
|
}
|
|
}
|
|
}(sctx, connect)
|
|
return nil
|
|
}
|
|
|
|
func (s *Server) handle(ctx context.Context, qConn quic.Connection, conn quic.Connection, connId string) error {
|
|
log.Infof("%s2/ waiting for new stream", ServerLogPrefix)
|
|
stream, err := qConn.AcceptStream(ctx)
|
|
if err != nil {
|
|
name := "--"
|
|
if conn := s.connector.Get(connId); conn != nil {
|
|
_ = conn.Close()
|
|
// connector
|
|
s.connector.Remove(connId)
|
|
route := s.router.Route(conn.Metadata())
|
|
if route != nil {
|
|
_ = route.Remove(connId)
|
|
}
|
|
name = conn.Name()
|
|
}
|
|
log.Printf("%s [%s](%s) close the connection: %v", ServerLogPrefix, name, connId, err)
|
|
return err
|
|
}
|
|
defer func() {
|
|
_ = stream.Close()
|
|
}()
|
|
|
|
log.Infof("%s3/ [stream:%d] created, connId=%s", ServerLogPrefix, stream.StreamID(), connId)
|
|
// process frames on stream
|
|
// c := newContext(connId, stream)
|
|
c := newContext(conn, stream)
|
|
defer c.Clean()
|
|
s.handleConnection(c)
|
|
log.Infof("%s4/ [stream:%d] handleConnection DONE", ServerLogPrefix, stream.StreamID())
|
|
return nil
|
|
}
|
|
|
|
// Close will shut down the server.
|
|
func (s *Server) Close() error {
|
|
if s.router != nil {
|
|
s.router.Clean()
|
|
}
|
|
// connector
|
|
if s.connector != nil {
|
|
s.connector.Clean()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// handle streams on a connection
|
|
func (s *Server) handleConnection(c *Context) {
|
|
fs := NewFrameStream(c.Stream)
|
|
// check update for stream
|
|
for {
|
|
log.Debugf("%shandleConnection waiting read next...", ServerLogPrefix)
|
|
f, err := fs.ReadFrame()
|
|
if err != nil {
|
|
// if client close connection, will get ApplicationError with code = 0x00
|
|
if e, ok := err.(*quic.ApplicationError); ok {
|
|
if hpds_err.Is(e.ErrorCode, hpds_err.ErrorCodeClientAbort) {
|
|
// client abort
|
|
log.Infof("%sclient close the connection", ServerLogPrefix)
|
|
break
|
|
} else {
|
|
ye := hpds_err.New(hpds_err.Parse(e.ErrorCode), err)
|
|
log.Errorf("%s[ERR] %s", ServerLogPrefix, ye)
|
|
}
|
|
} else if err == io.EOF {
|
|
log.Infof("%sthe connection is EOF", ServerLogPrefix)
|
|
break
|
|
}
|
|
if errors.Is(err, net.ErrClosed) {
|
|
// if client close the connection, net.ErrClosed will be raised
|
|
// by quic-go IdleTimeoutError after connection's KeepAlive config.
|
|
log.Warnf("%s[ERR] net.ErrClosed on [handleConnection] %v", ServerLogPrefix, net.ErrClosed)
|
|
c.CloseWithError(hpds_err.ErrorCodeClosed, "net.ErrClosed")
|
|
break
|
|
}
|
|
// any error occurred, we should close the stream
|
|
// after this, conn.AcceptStream() will raise the error
|
|
c.CloseWithError(hpds_err.ErrorCodeUnknown, err.Error())
|
|
log.Warnf("%sconnection.Close()", ServerLogPrefix)
|
|
break
|
|
}
|
|
|
|
frameType := f.Type()
|
|
data := f.Encode()
|
|
log.Debugf("%stype=%s, frame[%d]=%# x", ServerLogPrefix, frameType, len(data), frame.Shortly(data))
|
|
// add frame to contextFrame
|
|
contextFrame := c.WithFrame(f)
|
|
|
|
// before frame handlers
|
|
for _, handler := range s.beforeHandlers {
|
|
if e := handler(contextFrame); e != nil {
|
|
log.Errorf("%safterFrameHandler e: %s", ServerLogPrefix, e)
|
|
contextFrame.CloseWithError(hpds_err.ErrorCodeBeforeHandler, e.Error())
|
|
return
|
|
}
|
|
}
|
|
// main handler
|
|
if e := s.mainFrameHandler(contextFrame); e != nil {
|
|
log.Errorf("%smainFrameHandler e: %s", ServerLogPrefix, e)
|
|
contextFrame.CloseWithError(hpds_err.ErrorCodeMainHandler, e.Error())
|
|
return
|
|
}
|
|
// after frame handler
|
|
for _, handler := range s.afterHandlers {
|
|
if e := handler(contextFrame); e != nil {
|
|
log.Errorf("%safterFrameHandler e: %s", ServerLogPrefix, e)
|
|
contextFrame.CloseWithError(hpds_err.ErrorCodeAfterHandler, e.Error())
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Server) mainFrameHandler(c *Context) error {
|
|
var err error
|
|
frameType := c.Frame.Type()
|
|
|
|
switch frameType {
|
|
case frame.TagOfHandshakeFrame:
|
|
if err = s.handleHandshakeFrame(c); err != nil {
|
|
log.Errorf("%shandleHandshakeFrame err: %s", ServerLogPrefix, err)
|
|
// close connections early to avoid resource consumption
|
|
if c.Stream != nil {
|
|
goawayFrame := frame.NewGoawayFrame(err.Error())
|
|
if _, e := c.Stream.Write(goawayFrame.Encode()); e != nil {
|
|
log.Errorf("%s write to client[%s] GoawayFrame error:%v", ServerLogPrefix, c.ConnId, e)
|
|
return e
|
|
}
|
|
}
|
|
}
|
|
// case frame.TagOfPingFrame:
|
|
// s.handlePingFrame(mainStream, connection, f.(*frame.PingFrame))
|
|
case frame.TagOfDataFrame:
|
|
if err = s.handleDataFrame(c); err != nil {
|
|
c.CloseWithError(hpds_err.ErrorCodeData, fmt.Sprintf("handleDataFrame err: %v", err))
|
|
} else {
|
|
conn := s.connector.Get(c.connId)
|
|
if conn != nil && conn.ClientType() == ClientTypeProtocolGateway {
|
|
f := c.Frame.(*frame.DataFrame)
|
|
f.GetMetaFrame().SetMetadata(conn.Metadata().Encode())
|
|
s.dispatchToDownStreams(f)
|
|
}
|
|
// observe data tags back flow
|
|
_ = s.handleBackFlowFrame(c)
|
|
}
|
|
default:
|
|
log.Errorf("%serr=%v, frame=%v", ServerLogPrefix, err, frame.Shortly(c.Frame.Encode()))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// handle HandShakeFrame
|
|
func (s *Server) handleHandshakeFrame(c *Context) error {
|
|
f := c.Frame.(*frame.HandshakeFrame)
|
|
|
|
log.Debugf("%sGOT HandshakeFrame : %# x", ServerLogPrefix, f)
|
|
// basic info
|
|
connId := c.ConnId()
|
|
clientId := f.ClientId
|
|
clientType := ClientType(f.ClientType)
|
|
stream := c.Stream
|
|
// credential
|
|
log.Debugf("%sClientType=%# x is %s, ClientId=%s, Credential=%s", ServerLogPrefix, f.ClientType, ClientType(f.ClientType), clientId, authName(f.AuthName()))
|
|
// authenticate
|
|
if !s.authenticate(f) {
|
|
err := fmt.Errorf("handshake authentication fails, client credential name is %s", authName(f.AuthName()))
|
|
// return err
|
|
log.Debugf("%s <%s> [%s](%s) is connected!", ServerLogPrefix, clientType, f.Name, connId)
|
|
rejectedFrame := frame.NewRejectedFrame(err.Error())
|
|
if _, err = stream.Write(rejectedFrame.Encode()); err != nil {
|
|
log.Debugf("%s write to <%s> [%s](%s) RejectedFrame error:%v", ServerLogPrefix, clientType, f.Name, connId, err)
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// client type
|
|
var conn Connection
|
|
switch clientType {
|
|
case ClientTypeProtocolGateway, ClientTypeStreamFunction:
|
|
// metadata
|
|
metadata, err := s.metadataBuilder.Build(f)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
conn = newConnection(f.Name, f.ClientId, clientType, metadata, stream, f.ObserveDataTags)
|
|
|
|
if clientType == ClientTypeStreamFunction {
|
|
// route
|
|
route := s.router.Route(metadata)
|
|
if route == nil {
|
|
return errors.New("handleHandshakeFrame route is nil")
|
|
}
|
|
if e1 := route.Add(connId, f.Name, f.ObserveDataTags); e1 != nil {
|
|
// duplicate name
|
|
if e2, ok := e1.(hpds_err.DuplicateNameError); ok {
|
|
existsConnID := e2.ConnId()
|
|
if conn = s.connector.Get(existsConnID); conn != nil {
|
|
log.Debugf("%s%s, write to SFN[%s](%s) GoawayFrame", ServerLogPrefix, e2.Error(), f.Name, existsConnID)
|
|
goawayFrame := frame.NewGoawayFrame(e2.Error())
|
|
if e3 := conn.Write(goawayFrame); e3 != nil {
|
|
log.Errorf("%s write to SFN[%s] GoawayFrame error:%v", ServerLogPrefix, f.Name, e3)
|
|
return e3
|
|
}
|
|
}
|
|
} else {
|
|
return e1
|
|
}
|
|
}
|
|
}
|
|
case ClientTypeMessageQueue:
|
|
conn = newConnection(f.Name, f.ClientId, clientType, nil, stream, f.ObserveDataTags)
|
|
default:
|
|
// unknown client type
|
|
s.connector.Remove(connId)
|
|
err := fmt.Errorf("Illegal ClientType: %#x ", f.ClientType)
|
|
c.CloseWithError(hpds_err.ErrorCodeUnknownClient, err.Error())
|
|
return err
|
|
}
|
|
|
|
s.connector.Add(connId, conn)
|
|
log.Printf("%s <%s> [%s][%s](%s) is connected!", ServerLogPrefix, clientType, f.Name, clientId, connId)
|
|
return nil
|
|
}
|
|
|
|
// handle handleGoawayFrame
|
|
func (s *Server) handleGoawayFrame(c *Context) error {
|
|
f := c.Frame.(*frame.GoawayFrame)
|
|
|
|
log.Debugf("%s GOT GoawayFrame code=%d, message==%s", ServerLogPrefix, hpds_err.ErrorCodeGoaway, f.Message())
|
|
// c.CloseWithError(f.Code(), f.Message())
|
|
_, err := c.Stream.Write(f.Encode())
|
|
return err
|
|
}
|
|
|
|
// will reuse quic-go's keep-alive feature
|
|
// func (s *Server) handlePingFrame(stream quic.Stream, conn quic.Connection, f *frame.PingFrame) error {
|
|
// log.Infof("%s------> GOT PingFrame : %# x", ServerLogPrefix, f)
|
|
// return nil
|
|
// }
|
|
|
|
func (s *Server) handleDataFrame(c *Context) error {
|
|
// counter +1
|
|
atomic.AddInt64(&s.counterOfDataFrame, 1)
|
|
// currentIssuer := f.GetIssuer()
|
|
fromId := c.ConnId()
|
|
from := s.connector.Get(fromId)
|
|
if from == nil {
|
|
log.Warnf("%shandleDataFrame connector cannot find %s", ServerLogPrefix, fromId)
|
|
return fmt.Errorf("handleDataFrame connector cannot find %s", fromId)
|
|
}
|
|
|
|
f := c.Frame.(*frame.DataFrame)
|
|
|
|
var metadata Metadata
|
|
if from.ClientType() == ClientTypeMessageQueue {
|
|
m, err := s.metadataBuilder.Decode(f.GetMetaFrame().Metadata())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
metadata = m
|
|
} else {
|
|
metadata = from.Metadata()
|
|
}
|
|
|
|
// route
|
|
route := s.router.Route(metadata)
|
|
if route == nil {
|
|
log.Warnf("%shandleDataFrame route is nil", ServerLogPrefix)
|
|
return fmt.Errorf("handleDataFrame route is nil")
|
|
}
|
|
|
|
// get stream function connection ids from route
|
|
connIds := route.GetForwardRoutes(f.GetDataTag())
|
|
for _, toId := range connIds {
|
|
conn := s.connector.Get(toId)
|
|
if conn == nil {
|
|
log.Errorf("%sconn is nil: (%s)", ServerLogPrefix, toId)
|
|
continue
|
|
}
|
|
|
|
to := conn.Name()
|
|
log.Debugf("%shandleDataFrame tag=%#x tid=%s, counter=%d, from=[%s](%s), to=[%s](%s)", ServerLogPrefix, f.Tag(), f.TransactionId(), s.counterOfDataFrame, from.Name(), fromId, to, toId)
|
|
|
|
// write data frame to stream
|
|
if err := conn.Write(f); err != nil {
|
|
log.Warnf("%shandleDataFrame conn.Write tag=%#x tid=%s, from=[%s](%s), to=[%s](%s), %v", ServerLogPrefix, f.Tag(), f.TransactionId(), from.Name(), fromId, to, toId, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Server) handleBackFlowFrame(c *Context) error {
|
|
f := c.Frame.(*frame.DataFrame)
|
|
tag := f.GetDataTag()
|
|
carriage := f.GetCarriage()
|
|
sourceId := f.SourceId()
|
|
// write to Protocol Gateway with BackFlowFrame
|
|
bf := frame.NewBackFlowFrame(tag, carriage)
|
|
sourceConns := s.connector.GetProtocolGatewayConnections(sourceId, tag)
|
|
// conn := s.connector.Get(c.connId)
|
|
// logger.Printf("%s handleBackFlowFrame tag:%#v --> source:%s, result=%s", ServerLogPrefix, tag, sourceId, carriage)
|
|
for _, source := range sourceConns {
|
|
if source != nil {
|
|
log.Debugf("%s handleBackFlowFrame tag:%#v --> Protocol Gateway:%s, result=%# x", ServerLogPrefix, tag, sourceId, frame.Shortly(carriage))
|
|
if err := source.Write(bf); err != nil {
|
|
log.Errorf("%s handleBackFlowFrame tag:%#v --> Protocol Gateway:%s, error=%v", ServerLogPrefix, tag, sourceId, err)
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// StatsFunctions returns the sfn stats of server.
|
|
func (s *Server) StatsFunctions() map[string]string {
|
|
return s.connector.GetSnapshot()
|
|
}
|
|
|
|
// StatsCounter returns how many DataFrames pass through server.
|
|
func (s *Server) StatsCounter() int64 {
|
|
return s.counterOfDataFrame
|
|
}
|
|
|
|
// DownStreams return all the downstream servers.
|
|
func (s *Server) DownStreams() map[string]*Client {
|
|
return s.downStreams
|
|
}
|
|
|
|
// ConfigRouter is used to set router by Message Queue
|
|
func (s *Server) ConfigRouter(router Router) {
|
|
s.mu.Lock()
|
|
s.router = router
|
|
log.Debugf("%sconfig router is %#v", ServerLogPrefix, router)
|
|
s.mu.Unlock()
|
|
}
|
|
|
|
// ConfigMetadataBuilder is used to set metadataBuilder by Message Queue
|
|
func (s *Server) ConfigMetadataBuilder(builder MetadataBuilder) {
|
|
s.mu.Lock()
|
|
s.metadataBuilder = builder
|
|
log.Debugf("%sconfig metadataBuilder is %#v", ServerLogPrefix, builder)
|
|
s.mu.Unlock()
|
|
}
|
|
|
|
// AddDownstreamServer add a downstream server to this server. all the DataFrames will be
|
|
// dispatch to all the downStreams.
|
|
func (s *Server) AddDownstreamServer(addr string, c *Client) {
|
|
s.mu.Lock()
|
|
s.downStreams[addr] = c
|
|
s.mu.Unlock()
|
|
}
|
|
|
|
// dispatch every DataFrames to all downStreams
|
|
func (s *Server) dispatchToDownStreams(df *frame.DataFrame) {
|
|
for addr, ds := range s.downStreams {
|
|
log.Debugf("%sdispatching to [%s]: %# x", ServerLogPrefix, addr, df.Tag())
|
|
_ = ds.WriteFrame(df)
|
|
}
|
|
}
|
|
|
|
// GetConnId get quic connection id
|
|
func GetConnId(conn quic.Connection) string {
|
|
return conn.RemoteAddr().String()
|
|
}
|
|
|
|
func (s *Server) initOptions() {
|
|
// defaults
|
|
}
|
|
|
|
func (s *Server) validateRouter() error {
|
|
if s.router == nil {
|
|
return errors.New("server's router is nil")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Server) validateMetadataBuilder() error {
|
|
if s.metadataBuilder == nil {
|
|
return errors.New("server's metadataBuilder is nil")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Options returns the options of server.
|
|
func (s *Server) Options() ServerOptions {
|
|
return s.opts
|
|
}
|
|
|
|
// Connector returns the connector of server.
|
|
func (s *Server) Connector() Connector {
|
|
return s.connector
|
|
}
|
|
|
|
// SetBeforeHandlers set the before handlers of server.
|
|
func (s *Server) SetBeforeHandlers(handlers ...FrameHandler) {
|
|
s.beforeHandlers = append(s.beforeHandlers, handlers...)
|
|
}
|
|
|
|
// SetAfterHandlers set the after handlers of server.
|
|
func (s *Server) SetAfterHandlers(handlers ...FrameHandler) {
|
|
s.afterHandlers = append(s.afterHandlers, handlers...)
|
|
}
|
|
|
|
func (s *Server) authNames() []string {
|
|
if len(s.opts.Auths) == 0 {
|
|
return []string{"none"}
|
|
}
|
|
result := make([]string, 0)
|
|
for _, auth := range s.opts.Auths {
|
|
result = append(result, auth.Name())
|
|
}
|
|
return result
|
|
}
|
|
|
|
func (s *Server) authenticate(f *frame.HandshakeFrame) bool {
|
|
if len(s.opts.Auths) > 0 {
|
|
for _, auth := range s.opts.Auths {
|
|
if f.AuthName() == auth.Name() {
|
|
isAuthenticated := auth.Authenticate(f.AuthPayload())
|
|
if isAuthenticated {
|
|
log.Debugf("%sauthenticated==%v", ServerLogPrefix, isAuthenticated)
|
|
return isAuthenticated
|
|
}
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
func mode() string {
|
|
if pkgtls.IsDev() {
|
|
return "DEVELOPMENT"
|
|
}
|
|
return "PRODUCTION"
|
|
}
|
|
|
|
// authName returns name for logging, substituting "empty" when it is blank.
func authName(name string) string {
	switch name {
	case "":
		return "empty"
	default:
		return name
	}
}
|