// File: network/server.go (583 lines, 16 KiB, Go)

package network
import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"sync"
"sync/atomic"
// Blank-imported for its side effects (authentication implementations);
// currently, only token authentication is implemented.
_ "git.hpds.cc/Component/network/auth"
"git.hpds.cc/Component/network/frame"
"git.hpds.cc/Component/network/hpds_err"
"git.hpds.cc/Component/network/log"
pkgtls "git.hpds.cc/Component/network/tls"
"github.com/lucas-clemente/quic-go"
)
// DefaultListenAddr is the address ListenAndServe binds to when it is called
// with an empty address.
const DefaultListenAddr = "0.0.0.0:9000"
type (
	// ServerOption mutates a ServerOptions value; Init applies each one to the
	// server's options in the order given.
	ServerOption func(*ServerOptions)

	// FrameHandler processes the frame carried by c. A non-nil error makes the
	// server close the context with an error and stop reading the stream.
	FrameHandler func(c *Context) error
)
// Server is the underlying server of Message Queue.
type Server struct {
// name identifies this server; it is printed in the startup log line.
name string
// state is set to ConnStateConnected once Serve has created its listener.
state string
// connector holds the registered client connections, keyed by connection id.
connector Connector
// router resolves the route for a given metadata; it must be set (see
// ConfigRouter) before Serve is called.
router Router
// metadataBuilder builds metadata from handshake frames and decodes it from
// data frames; it must be set (see ConfigMetadataBuilder) before Serve.
metadataBuilder MetadataBuilder
// counterOfDataFrame counts the data frames handled; it is incremented with
// sync/atomic in handleDataFrame.
counterOfDataFrame int64
// downStreams maps a downstream server address to the client used to reach it.
downStreams map[string]*Client
// mu guards writes to router, metadataBuilder and downStreams.
mu sync.Mutex
// opts holds the options applied by Init.
opts ServerOptions
// beforeHandlers run on every frame before the main frame handler.
beforeHandlers []FrameHandler
// afterHandlers run on every frame after the main frame handler.
afterHandlers []FrameHandler
}
// NewServer returns a Server called name with a fresh connector and an empty
// downstream table, after applying opts through Init.
func NewServer(name string, opts ...ServerOption) *Server {
	srv := new(Server)
	srv.name = name
	srv.connector = newConnector()
	srv.downStreams = map[string]*Client{}
	// NewServer has no error result, so Init's return value is dropped.
	_ = srv.Init(opts...)
	return srv
}
// Init applies every option in opts to the server's options, in order, and
// then runs initOptions to fill in defaults. It always returns nil.
func (s *Server) Init(opts ...ServerOption) error {
	for i := range opts {
		opts[i](&s.opts)
	}
	// options defaults
	s.initOptions()
	return nil
}
// ListenAndServe opens a UDP socket on addr and serves on it until Serve
// returns.
//
// An empty addr selects DefaultListenAddr. The returned error is either the
// address-resolution or listen failure, or whatever Serve returns. The socket
// is opened here, so it is also closed here once Serve returns; this covers
// the paths where Serve fails before or while creating its listener.
func (s *Server) ListenAndServe(ctx context.Context, addr string) error {
	if addr == "" {
		addr = DefaultListenAddr
	}
	udpAddr, err := net.ResolveUDPAddr("udp", addr)
	if err != nil {
		return err
	}
	conn, err := net.ListenUDP("udp", udpAddr)
	if err != nil {
		return err
	}
	defer func() {
		// Best effort: closing an already-closed socket only yields an error
		// that carries no useful information for the caller.
		_ = conn.Close()
	}()
	return s.Serve(ctx, conn)
}
// Serve accepts client connections on conn until accepting fails.
//
// It returns early with an error when the metadata builder or the router has
// not been configured, or when the listener cannot be created. Otherwise it
// loops accepting connections and returns the first accept error, e.g. when
// ctx is cancelled. Returning instead of retrying prevents a busy loop on a
// permanently failing listener and lets the deferred listener.Close run.
func (s *Server) Serve(ctx context.Context, conn net.PacketConn) error {
	if err := s.validateMetadataBuilder(); err != nil {
		return err
	}
	if err := s.validateRouter(); err != nil {
		return err
	}
	// listen the address
	listener, err := newListener(conn, s.opts.TLSConfig, s.opts.QuicConfig)
	if err != nil {
		log.Errorf("%slistener.Listen: err=%v", ServerLogPrefix, err)
		return err
	}
	defer func() {
		_ = listener.Close()
	}()
	log.Printf("%s [%s][%d] Listening on: %s, MODE: %s, QUIC: %v, AUTH: %s", ServerLogPrefix, s.name, os.Getpid(), listener.Addr(), mode(), listener.Versions(), s.authNames())
	s.state = ConnStateConnected
	for {
		// createNewClientConnection already logs the failure.
		if err := s.createNewClientConnection(ctx, listener); err != nil {
			return err
		}
	}
}
// createNewClientConnection accepts one client connection from listener and
// starts a goroutine that serves its streams until handle reports an error.
//
// It returns the accept error, if any; otherwise nil as soon as the goroutine
// has been started. The per-connection context derived from ctx is cancelled
// only when accepting fails or when the serving goroutine exits. Cancelling it
// on return from this function would abort the goroutine's AcceptStream calls.
func (s *Server) createNewClientConnection(ctx context.Context, listener Listener) error {
	sctx, cancel := context.WithCancel(ctx)
	connect, e := listener.Accept(sctx)
	if e != nil {
		cancel()
		log.Errorf("%screate connection error: %v", ServerLogPrefix, e)
		return e
	}
	connId := GetConnId(connect)
	log.Infof("%s1/ new connection: %s", ServerLogPrefix, connId)
	go func(ctx context.Context, qConn quic.Connection) {
		// Release the derived context once this connection is done.
		defer cancel()
		for {
			if err := s.handle(ctx, qConn, connect, connId); err != nil {
				break
			}
		}
	}(sctx, connect)
	return nil
}
// handle waits for the next stream on qConn and serves it until
// handleConnection returns.
//
// When no stream can be accepted, the connection registered under connId (if
// any) is closed, removed from the connector and from its route, and the
// accept error is returned. Otherwise handle returns nil after the stream has
// been served, its context cleaned and the stream closed.
func (s *Server) handle(ctx context.Context, qConn quic.Connection, conn quic.Connection, connId string) error {
	log.Infof("%s2/ waiting for new stream", ServerLogPrefix)
	stream, acceptErr := qConn.AcceptStream(ctx)
	if acceptErr != nil {
		clientName := "--"
		registered := s.connector.Get(connId)
		if registered != nil {
			_ = registered.Close()
			// connector
			s.connector.Remove(connId)
			if route := s.router.Route(registered.Metadata()); route != nil {
				_ = route.Remove(connId)
			}
			clientName = registered.Name()
		}
		log.Printf("%s [%s](%s) close the connection: %v", ServerLogPrefix, clientName, connId, acceptErr)
		return acceptErr
	}
	defer func() {
		_ = stream.Close()
	}()
	log.Infof("%s3/ [stream:%d] created, connId=%s", ServerLogPrefix, stream.StreamID(), connId)

	// process frames on stream
	streamCtx := newContext(conn, stream)
	defer streamCtx.Clean()
	s.handleConnection(streamCtx)
	log.Infof("%s4/ [stream:%d] handleConnection DONE", ServerLogPrefix, stream.StreamID())
	return nil
}
// Close cleans the router and then the connector, skipping whichever of the
// two is nil. It always returns nil.
func (s *Server) Close() error {
	if router := s.router; router != nil {
		router.Clean()
	}
	if connector := s.connector; connector != nil {
		connector.Clean()
	}
	return nil
}
// handleConnection reads frames from c.Stream in a loop and runs each one
// through the before handlers, the main frame handler and the after handlers.
//
// The loop ends when reading a frame fails or when any handler returns an
// error. A client abort or io.EOF ends it silently; every other read error
// closes the context with an error code first. A handler error closes the
// context with the code of the stage that failed.
func (s *Server) handleConnection(c *Context) {
	fs := NewFrameStream(c.Stream)
	// check update for stream
	for {
		log.Debugf("%shandleConnection waiting read next...", ServerLogPrefix)
		f, err := fs.ReadFrame()
		if err != nil {
			var appErr *quic.ApplicationError
			switch {
			case errors.As(err, &appErr):
				// if client close connection, will get ApplicationError with code = 0x00
				if hpds_err.Is(appErr.ErrorCode, hpds_err.ErrorCodeClientAbort) {
					// client abort
					log.Infof("%sclient close the connection", ServerLogPrefix)
					return
				}
				ye := hpds_err.New(hpds_err.Parse(appErr.ErrorCode), err)
				log.Errorf("%s[ERR] %s", ServerLogPrefix, ye)
			case errors.Is(err, io.EOF):
				log.Infof("%sthe connection is EOF", ServerLogPrefix)
				return
			}
			if errors.Is(err, net.ErrClosed) {
				// if client close the connection, net.ErrClosed will be raised
				// by quic-go IdleTimeoutError after connection's KeepAlive config.
				log.Warnf("%s[ERR] net.ErrClosed on [handleConnection] %v", ServerLogPrefix, net.ErrClosed)
				c.CloseWithError(hpds_err.ErrorCodeClosed, "net.ErrClosed")
				return
			}
			// any error occurred, we should close the stream
			// after this, conn.AcceptStream() will raise the error
			c.CloseWithError(hpds_err.ErrorCodeUnknown, err.Error())
			log.Warnf("%sconnection.Close()", ServerLogPrefix)
			return
		}
		frameType := f.Type()
		data := f.Encode()
		log.Debugf("%stype=%s, frame[%d]=%# x", ServerLogPrefix, frameType, len(data), frame.Shortly(data))
		// add frame to contextFrame
		contextFrame := c.WithFrame(f)
		// before frame handlers
		for _, handler := range s.beforeHandlers {
			if e := handler(contextFrame); e != nil {
				// Name the stage that actually failed; this used to be
				// logged as an after-handler failure.
				log.Errorf("%sbeforeFrameHandler e: %s", ServerLogPrefix, e)
				contextFrame.CloseWithError(hpds_err.ErrorCodeBeforeHandler, e.Error())
				return
			}
		}
		// main handler
		if e := s.mainFrameHandler(contextFrame); e != nil {
			log.Errorf("%smainFrameHandler e: %s", ServerLogPrefix, e)
			contextFrame.CloseWithError(hpds_err.ErrorCodeMainHandler, e.Error())
			return
		}
		// after frame handler
		for _, handler := range s.afterHandlers {
			if e := handler(contextFrame); e != nil {
				log.Errorf("%safterFrameHandler e: %s", ServerLogPrefix, e)
				contextFrame.CloseWithError(hpds_err.ErrorCodeAfterHandler, e.Error())
				return
			}
		}
	}
}
// mainFrameHandler dispatches the frame held by c according to its type.
//
// Handshake frames go to handleHandshakeFrame; if that fails and a stream is
// present, a GoawayFrame carrying the error text is written to the client.
// Data frames go to handleDataFrame; on success they are also dispatched to
// the downstream servers when the sender is a protocol gateway, and then
// passed to handleBackFlowFrame. Any other frame type is only logged.
//
// The only error returned is a failure to write the GoawayFrame; handshake
// and data-frame errors themselves are handled here and yield nil.
func (s *Server) mainFrameHandler(c *Context) error {
	var err error
	frameType := c.Frame.Type()
	switch frameType {
	case frame.TagOfHandshakeFrame:
		if err = s.handleHandshakeFrame(c); err != nil {
			log.Errorf("%shandleHandshakeFrame err: %s", ServerLogPrefix, err)
			// close connections early to avoid resource consumption
			if c.Stream != nil {
				goawayFrame := frame.NewGoawayFrame(err.Error())
				if _, e := c.Stream.Write(goawayFrame.Encode()); e != nil {
					// ConnId must be called; passing the method value would
					// log a function address instead of the id.
					log.Errorf("%s write to client[%s] GoawayFrame error:%v", ServerLogPrefix, c.ConnId(), e)
					return e
				}
			}
		}
	// case frame.TagOfPingFrame:
	// s.handlePingFrame(mainStream, connection, f.(*frame.PingFrame))
	case frame.TagOfDataFrame:
		if err = s.handleDataFrame(c); err != nil {
			c.CloseWithError(hpds_err.ErrorCodeData, fmt.Sprintf("handleDataFrame err: %v", err))
		} else {
			conn := s.connector.Get(c.connId)
			if conn != nil && conn.ClientType() == ClientTypeProtocolGateway {
				f := c.Frame.(*frame.DataFrame)
				f.GetMetaFrame().SetMetadata(conn.Metadata().Encode())
				s.dispatchToDownStreams(f)
			}
			// observe data tags back flow
			_ = s.handleBackFlowFrame(c)
		}
	default:
		// err has not been assigned on this path, so it is logged as nil.
		log.Errorf("%serr=%v, frame=%v", ServerLogPrefix, err, frame.Shortly(c.Frame.Encode()))
	}
	return nil
}
// handleHandshakeFrame authenticates the client that sent the handshake frame
// in c and registers a connection for it in the connector.
//
// When authentication fails, a RejectedFrame is written to the stream and nil
// is returned (or the write error, if writing fails); nothing is registered.
// For protocol gateways and stream functions, metadata is built from the
// frame. Stream functions are additionally added to the route for that
// metadata. If the route reports a duplicate name, the connection already
// registered under the conflicting id is sent a GoawayFrame, and the new
// connection is still registered under connId. An unknown client type closes
// the context with ErrorCodeUnknownClient and returns an error.
func (s *Server) handleHandshakeFrame(c *Context) error {
	f := c.Frame.(*frame.HandshakeFrame)
	log.Debugf("%sGOT HandshakeFrame : %# x", ServerLogPrefix, f)
	// basic info
	connId := c.ConnId()
	clientId := f.ClientId
	clientType := ClientType(f.ClientType)
	stream := c.Stream
	// credential
	log.Debugf("%sClientType=%# x is %s, ClientId=%s, Credential=%s", ServerLogPrefix, f.ClientType, ClientType(f.ClientType), clientId, authName(f.AuthName()))
	// authenticate
	if !s.authenticate(f) {
		err := fmt.Errorf("handshake authentication fails, client credential name is %s", authName(f.AuthName()))
		// This path rejects the client, so the log must not claim that it
		// is connected.
		log.Debugf("%s <%s> [%s](%s) is rejected: %v", ServerLogPrefix, clientType, f.Name, connId, err)
		rejectedFrame := frame.NewRejectedFrame(err.Error())
		if _, err = stream.Write(rejectedFrame.Encode()); err != nil {
			log.Debugf("%s write to <%s> [%s](%s) RejectedFrame error:%v", ServerLogPrefix, clientType, f.Name, connId, err)
			return err
		}
		return nil
	}
	// client type
	var conn Connection
	switch clientType {
	case ClientTypeProtocolGateway, ClientTypeStreamFunction:
		// metadata
		metadata, err := s.metadataBuilder.Build(f)
		if err != nil {
			return err
		}
		conn = newConnection(f.Name, f.ClientId, clientType, metadata, stream, f.ObserveDataTags)
		if clientType == ClientTypeStreamFunction {
			// route
			route := s.router.Route(metadata)
			if route == nil {
				return errors.New("handleHandshakeFrame route is nil")
			}
			if e1 := route.Add(connId, f.Name, f.ObserveDataTags); e1 != nil {
				e2, ok := e1.(hpds_err.DuplicateNameError)
				if !ok {
					return e1
				}
				// duplicate name: notify the connection already holding it.
				// The lookup result lives in its own variable so that conn,
				// the connection being registered below, is not overwritten.
				existsConnID := e2.ConnId()
				if existing := s.connector.Get(existsConnID); existing != nil {
					log.Debugf("%s%s, write to SFN[%s](%s) GoawayFrame", ServerLogPrefix, e2.Error(), f.Name, existsConnID)
					goawayFrame := frame.NewGoawayFrame(e2.Error())
					if e3 := existing.Write(goawayFrame); e3 != nil {
						log.Errorf("%s write to SFN[%s] GoawayFrame error:%v", ServerLogPrefix, f.Name, e3)
						return e3
					}
				}
			}
		}
	case ClientTypeMessageQueue:
		conn = newConnection(f.Name, f.ClientId, clientType, nil, stream, f.ObserveDataTags)
	default:
		// unknown client type
		s.connector.Remove(connId)
		err := fmt.Errorf("Illegal ClientType: %#x ", f.ClientType)
		c.CloseWithError(hpds_err.ErrorCodeUnknownClient, err.Error())
		return err
	}
	s.connector.Add(connId, conn)
	log.Printf("%s <%s> [%s][%s](%s) is connected!", ServerLogPrefix, clientType, f.Name, clientId, connId)
	return nil
}
// handleGoawayFrame logs the GoawayFrame held by c and writes its encoded
// form to c.Stream, returning the write error, if any.
func (s *Server) handleGoawayFrame(c *Context) error {
	goaway := c.Frame.(*frame.GoawayFrame)
	log.Debugf("%s GOT GoawayFrame code=%d, message==%s", ServerLogPrefix, hpds_err.ErrorCodeGoaway, goaway.Message())
	// c.CloseWithError(f.Code(), f.Message())
	if _, writeErr := c.Stream.Write(goaway.Encode()); writeErr != nil {
		return writeErr
	}
	return nil
}
// will reuse quic-go's keep-alive feature
// func (s *Server) handlePingFrame(stream quic.Stream, conn quic.Connection, f *frame.PingFrame) error {
// log.Infof("%s------> GOT PingFrame : %# x", ServerLogPrefix, f)
// return nil
// }
// handleDataFrame forwards the data frame held by c to every connection the
// route selects for the frame's data tag.
//
// The sender is looked up in the connector by the context's connection id.
// For a message-queue sender the metadata is decoded from the frame's meta
// frame; for every other sender the connection's own metadata is used. An
// error is returned when the sender is not registered, the metadata cannot be
// decoded, or no route exists. Write failures to individual targets are only
// logged, and targets missing from the connector are skipped.
func (s *Server) handleDataFrame(c *Context) error {
	// counter +1; keep the value returned by the atomic add so the log line
	// below does not read the shared field without synchronization.
	counter := atomic.AddInt64(&s.counterOfDataFrame, 1)
	// currentIssuer := f.GetIssuer()
	fromId := c.ConnId()
	from := s.connector.Get(fromId)
	if from == nil {
		log.Warnf("%shandleDataFrame connector cannot find %s", ServerLogPrefix, fromId)
		return fmt.Errorf("handleDataFrame connector cannot find %s", fromId)
	}
	f := c.Frame.(*frame.DataFrame)
	var metadata Metadata
	if from.ClientType() == ClientTypeMessageQueue {
		m, err := s.metadataBuilder.Decode(f.GetMetaFrame().Metadata())
		if err != nil {
			return err
		}
		metadata = m
	} else {
		metadata = from.Metadata()
	}
	// route
	route := s.router.Route(metadata)
	if route == nil {
		log.Warnf("%shandleDataFrame route is nil", ServerLogPrefix)
		return fmt.Errorf("handleDataFrame route is nil")
	}
	// get stream function connection ids from route
	connIds := route.GetForwardRoutes(f.GetDataTag())
	for _, toId := range connIds {
		conn := s.connector.Get(toId)
		if conn == nil {
			log.Errorf("%sconn is nil: (%s)", ServerLogPrefix, toId)
			continue
		}
		to := conn.Name()
		log.Debugf("%shandleDataFrame tag=%#x tid=%s, counter=%d, from=[%s](%s), to=[%s](%s)", ServerLogPrefix, f.Tag(), f.TransactionId(), counter, from.Name(), fromId, to, toId)
		// write data frame to stream
		if err := conn.Write(f); err != nil {
			log.Warnf("%shandleDataFrame conn.Write tag=%#x tid=%s, from=[%s](%s), to=[%s](%s), %v", ServerLogPrefix, f.Tag(), f.TransactionId(), from.Name(), fromId, to, toId, err)
		}
	}
	return nil
}
// handleBackFlowFrame wraps the tag and carriage of the data frame held by c
// in a BackFlowFrame and writes it to each connection the connector returns
// from GetProtocolGatewayConnections for the frame's source id and tag.
//
// Nil entries are skipped. The first write failure is logged and returned,
// which leaves the remaining connections unwritten.
func (s *Server) handleBackFlowFrame(c *Context) error {
	df := c.Frame.(*frame.DataFrame)
	tag, carriage, sourceId := df.GetDataTag(), df.GetCarriage(), df.SourceId()
	// write to Protocol Gateway with BackFlowFrame
	backFlow := frame.NewBackFlowFrame(tag, carriage)
	for _, gateway := range s.connector.GetProtocolGatewayConnections(sourceId, tag) {
		if gateway == nil {
			continue
		}
		log.Debugf("%s handleBackFlowFrame tag:%#v --> Protocol Gateway:%s, result=%# x", ServerLogPrefix, tag, sourceId, frame.Shortly(carriage))
		writeErr := gateway.Write(backFlow)
		if writeErr != nil {
			log.Errorf("%s handleBackFlowFrame tag:%#v --> Protocol Gateway:%s, error=%v", ServerLogPrefix, tag, sourceId, writeErr)
			return writeErr
		}
	}
	return nil
}
// StatsFunctions returns the snapshot reported by the server's connector.
func (s *Server) StatsFunctions() map[string]string {
	snapshot := s.connector.GetSnapshot()
	return snapshot
}
// StatsCounter returns how many DataFrames have passed through the server.
//
// The counter is incremented with sync/atomic in handleDataFrame, so it is
// read atomically here as well to avoid a data race with concurrent streams.
func (s *Server) StatsCounter() int64 {
	return atomic.LoadInt64(&s.counterOfDataFrame)
}
// DownStreams returns the downstream servers keyed by address. The result is
// the server's own map, not a copy.
func (s *Server) DownStreams() map[string]*Client {
	downstreams := s.downStreams
	return downstreams
}
// ConfigRouter installs router as the server's router while holding the
// server mutex, and logs the new value at debug level.
func (s *Server) ConfigRouter(router Router) {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.router = router
	log.Debugf("%sconfig router is %#v", ServerLogPrefix, router)
}
// ConfigMetadataBuilder installs builder as the server's metadata builder
// while holding the server mutex, and logs the new value at debug level.
func (s *Server) ConfigMetadataBuilder(builder MetadataBuilder) {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.metadataBuilder = builder
	log.Debugf("%sconfig metadataBuilder is %#v", ServerLogPrefix, builder)
}
// AddDownstreamServer registers c as the downstream server for addr,
// replacing any client already stored under that address. The downstream
// table is the set of targets dispatchToDownStreams writes DataFrames to.
func (s *Server) AddDownstreamServer(addr string, c *Client) {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.downStreams[addr] = c
}
// dispatchToDownStreams writes df to every registered downstream server.
// Write errors are ignored (best effort).
//
// AddDownstreamServer mutates the downstream table under s.mu, so the table
// is copied under the same lock before it is iterated; ranging over the map
// while it is being written would be a data race. The writes themselves
// happen after the lock is released so a slow downstream does not block
// configuration calls.
func (s *Server) dispatchToDownStreams(df *frame.DataFrame) {
	type target struct {
		addr   string
		client *Client
	}

	s.mu.Lock()
	targets := make([]target, 0, len(s.downStreams))
	for addr, ds := range s.downStreams {
		targets = append(targets, target{addr: addr, client: ds})
	}
	s.mu.Unlock()

	for _, t := range targets {
		log.Debugf("%sdispatching to [%s]: %# x", ServerLogPrefix, t.addr, df.Tag())
		_ = t.client.WriteFrame(df)
	}
}
// GetConnId returns the id used for conn, which is the string form of its
// remote address.
func GetConnId(conn quic.Connection) string {
	remote := conn.RemoteAddr()
	return remote.String()
}
// initOptions is called by Init after the options have been applied. It is
// the place for option defaults; at present it sets none.
func (s *Server) initOptions() {
// defaults (none yet)
}
// validateRouter reports an error when no router has been configured.
func (s *Server) validateRouter() error {
	if s.router != nil {
		return nil
	}
	return errors.New("server's router is nil")
}
// validateMetadataBuilder reports an error when no metadata builder has been
// configured.
func (s *Server) validateMetadataBuilder() error {
	if s.metadataBuilder != nil {
		return nil
	}
	return errors.New("server's metadataBuilder is nil")
}
// Options returns the server's options by value.
func (s *Server) Options() ServerOptions {
	current := s.opts
	return current
}
// Connector returns the connector the server registers its client
// connections in.
func (s *Server) Connector() Connector {
	connector := s.connector
	return connector
}
// SetBeforeHandlers appends handlers to the list run on every frame before
// the main frame handler. Despite the name, handlers registered earlier are
// kept.
func (s *Server) SetBeforeHandlers(handlers ...FrameHandler) {
	for _, h := range handlers {
		s.beforeHandlers = append(s.beforeHandlers, h)
	}
}
// SetAfterHandlers appends handlers to the list run on every frame after the
// main frame handler. Despite the name, handlers registered earlier are kept.
func (s *Server) SetAfterHandlers(handlers ...FrameHandler) {
	for _, h := range handlers {
		s.afterHandlers = append(s.afterHandlers, h)
	}
}
// authNames lists the names of the configured authentication methods, or the
// single entry "none" when there are none.
func (s *Server) authNames() []string {
	auths := s.opts.Auths
	if len(auths) == 0 {
		return []string{"none"}
	}
	names := make([]string, 0, len(auths))
	for _, a := range auths {
		names = append(names, a.Name())
	}
	return names
}
// authenticate reports whether the handshake frame f is accepted.
//
// With no authentication methods configured every handshake is accepted.
// Otherwise f is accepted as soon as a configured method whose name equals
// f.AuthName() authenticates the frame's payload; it is rejected when no such
// method does.
func (s *Server) authenticate(f *frame.HandshakeFrame) bool {
	if len(s.opts.Auths) == 0 {
		return true
	}
	for _, candidate := range s.opts.Auths {
		if f.AuthName() != candidate.Name() {
			continue
		}
		if ok := candidate.Authenticate(f.AuthPayload()); ok {
			log.Debugf("%sauthenticated==%v", ServerLogPrefix, ok)
			return true
		}
	}
	return false
}
// mode returns "DEVELOPMENT" when pkgtls.IsDev reports true and "PRODUCTION"
// otherwise.
func mode() string {
	if !pkgtls.IsDev() {
		return "PRODUCTION"
	}
	return "DEVELOPMENT"
}
// authName returns name unchanged, or the placeholder "empty" when name is
// the empty string, so log lines never show a blank credential name.
func authName(name string) string {
	switch name {
	case "":
		return "empty"
	default:
		return name
	}
}