2022-10-11 17:36:09 +08:00
package network
import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"sync"
"sync/atomic"
// authentication implementations; currently only token authentication is implemented
_ "git.hpds.cc/Component/network/auth"
"git.hpds.cc/Component/network/frame"
"git.hpds.cc/Component/network/hpds_err"
"git.hpds.cc/Component/network/log"
pkgtls "git.hpds.cc/Component/network/tls"
"github.com/lucas-clemente/quic-go"
)
// DefaultListenAddr is the UDP address ListenAndServe binds when it is given
// an empty addr: every IPv4 interface, port 9000.
const DefaultListenAddr = "0.0.0.0:9000"
type (
	// ServerOption configures a Server by mutating its ServerOptions.
	// Options are applied in the order given by (*Server).Init.
	ServerOption func(*ServerOptions)

	// FrameHandler processes the frame attached to a Context. The server runs
	// handlers of this type before and after its main frame handler; a non-nil
	// error stops processing of the current stream.
	FrameHandler func(c *Context) error
)
// Server is the underlying QUIC server of the Message Queue.
//
// NOTE(review): counterOfDataFrame is accessed with sync/atomic. On 32-bit
// platforms 64-bit atomic operations require 8-byte alignment, so take care
// when reordering or inserting fields in front of it.
type Server struct {
	// name identifies the server in log output.
	name string
	// state is set to ConnStateConnected once Serve starts listening.
	state string
	// connector tracks client connections keyed by connection id.
	connector Connector
	// router resolves the Route for a connection's metadata. Serve returns an
	// error while it is nil; set it with ConfigRouter.
	router Router
	// metadataBuilder builds metadata from handshakes and decodes metadata
	// carried in data frames. Serve returns an error while it is nil; set it
	// with ConfigMetadataBuilder.
	metadataBuilder MetadataBuilder
	// counterOfDataFrame counts DataFrames seen by handleDataFrame.
	counterOfDataFrame int64
	// downStreams holds downstream clients keyed by address. DataFrames from
	// protocol gateway clients are dispatched to them.
	downStreams map[string]*Client
	// mu is held by ConfigRouter, ConfigMetadataBuilder and AddDownstreamServer.
	mu sync.Mutex
	// opts holds the options applied by Init.
	opts ServerOptions
	// beforeHandlers run, in order, on every frame before mainFrameHandler.
	beforeHandlers []FrameHandler
	// afterHandlers run, in order, on every frame after mainFrameHandler.
	afterHandlers []FrameHandler
}
// NewServer create a Server instance.
func NewServer ( name string , opts ... ServerOption ) * Server {
s := & Server {
name : name ,
connector : newConnector ( ) ,
downStreams : make ( map [ string ] * Client ) ,
}
2023-03-10 23:49:52 +08:00
_ = s . Init ( opts ... )
2022-10-11 17:36:09 +08:00
return s
}
// Init applies each option in opts to the server options, in order, and then
// fills in defaults via initOptions. It always returns nil.
func (s *Server) Init(opts ...ServerOption) error {
	for i := range opts {
		apply := opts[i]
		apply(&s.opts)
	}
	// options defaults
	s.initOptions()
	return nil
}
// ListenAndServe opens a UDP socket on addr (DefaultListenAddr when addr is
// empty) and serves on it with Serve, returning Serve's error.
//
// The socket is owned by this function, so it is closed once Serve returns.
// That includes the case where Serve fails validation before it ever creates
// a listener; without the close, the socket would leak.
func (s *Server) ListenAndServe(ctx context.Context, addr string) error {
	if addr == "" {
		addr = DefaultListenAddr
	}
	udpAddr, err := net.ResolveUDPAddr("udp", addr)
	if err != nil {
		return err
	}
	conn, err := net.ListenUDP("udp", udpAddr)
	if err != nil {
		return err
	}
	defer func() {
		// Best effort: the listener may already have closed the socket, in
		// which case Close just reports an error we have no use for.
		_ = conn.Close()
	}()
	return s.Serve(ctx, conn)
}
// Serve the server with a net.PacketConn.
func ( s * Server ) Serve ( ctx context . Context , conn net . PacketConn ) error {
if err := s . validateMetadataBuilder ( ) ; err != nil {
return err
}
if err := s . validateRouter ( ) ; err != nil {
return err
}
// listen the address
listener , err := newListener ( conn , s . opts . TLSConfig , s . opts . QuicConfig )
if err != nil {
log . Errorf ( "%slistener.Listen: err=%v" , ServerLogPrefix , err )
return err
}
2023-03-10 23:49:52 +08:00
defer func ( ) {
_ = listener . Close ( )
} ( )
2022-10-11 17:36:09 +08:00
log . Printf ( "%s [%s][%d] Listening on: %s, MODE: %s, QUIC: %v, AUTH: %s" , ServerLogPrefix , s . name , os . Getpid ( ) , listener . Addr ( ) , mode ( ) , listener . Versions ( ) , s . authNames ( ) )
s . state = ConnStateConnected
for {
2023-03-10 23:49:52 +08:00
_ = s . createNewClientConnection ( ctx , listener )
}
}
// createNewClientConnection create a new connection when new hpds-client connected
func ( s * Server ) createNewClientConnection ( ctx context . Context , listener Listener ) error {
sctx , cancel := context . WithCancel ( ctx )
defer cancel ( )
connect , e := listener . Accept ( sctx )
if e != nil {
log . Errorf ( "%screate connection error: %v" , ServerLogPrefix , e )
return e
}
connId := GetConnId ( connect )
log . Infof ( "%s1/ new connection: %s" , ServerLogPrefix , connId )
go func ( ctx context . Context , qConn quic . Connection ) {
for {
err := s . handle ( ctx , qConn , connect , connId )
if err != nil {
2023-04-02 23:23:27 +08:00
continue
2023-03-10 23:49:52 +08:00
}
2022-10-11 17:36:09 +08:00
}
2023-03-10 23:49:52 +08:00
} ( sctx , connect )
return nil
}
2022-10-11 17:36:09 +08:00
2023-03-10 23:49:52 +08:00
func ( s * Server ) handle ( ctx context . Context , qConn quic . Connection , conn quic . Connection , connId string ) error {
log . Infof ( "%s2/ waiting for new stream" , ServerLogPrefix )
stream , err := qConn . AcceptStream ( ctx )
if err != nil {
name := "--"
if conn := s . connector . Get ( connId ) ; conn != nil {
_ = conn . Close ( )
// connector
s . connector . Remove ( connId )
route := s . router . Route ( conn . Metadata ( ) )
if route != nil {
_ = route . Remove ( connId )
2022-10-11 17:36:09 +08:00
}
2023-03-10 23:49:52 +08:00
name = conn . Name ( )
}
log . Printf ( "%s [%s](%s) close the connection: %v" , ServerLogPrefix , name , connId , err )
return err
2022-10-11 17:36:09 +08:00
}
2023-03-10 23:49:52 +08:00
defer func ( ) {
_ = stream . Close ( )
} ( )
log . Infof ( "%s3/ [stream:%d] created, connId=%s" , ServerLogPrefix , stream . StreamID ( ) , connId )
// process frames on stream
// c := newContext(connId, stream)
c := newContext ( conn , stream )
defer c . Clean ( )
s . handleConnection ( c )
log . Infof ( "%s4/ [stream:%d] handleConnection DONE" , ServerLogPrefix , stream . StreamID ( ) )
return nil
2022-10-11 17:36:09 +08:00
}
// Close shuts the server down by calling Clean on the router and on the
// connector, skipping either one that is nil. It does not touch the listener
// created inside Serve, and it always returns nil.
func (s *Server) Close() error {
	if r := s.router; r != nil {
		r.Clean()
	}
	if cn := s.connector; cn != nil {
		cn.Clean()
	}
	return nil
}
// handle streams on a connection
func ( s * Server ) handleConnection ( c * Context ) {
fs := NewFrameStream ( c . Stream )
// check update for stream
for {
log . Debugf ( "%shandleConnection waiting read next..." , ServerLogPrefix )
f , err := fs . ReadFrame ( )
if err != nil {
// if client close connection, will get ApplicationError with code = 0x00
if e , ok := err . ( * quic . ApplicationError ) ; ok {
if hpds_err . Is ( e . ErrorCode , hpds_err . ErrorCodeClientAbort ) {
// client abort
log . Infof ( "%sclient close the connection" , ServerLogPrefix )
break
} else {
ye := hpds_err . New ( hpds_err . Parse ( e . ErrorCode ) , err )
log . Errorf ( "%s[ERR] %s" , ServerLogPrefix , ye )
}
} else if err == io . EOF {
log . Infof ( "%sthe connection is EOF" , ServerLogPrefix )
break
}
if errors . Is ( err , net . ErrClosed ) {
// if client close the connection, net.ErrClosed will be raised
// by quic-go IdleTimeoutError after connection's KeepAlive config.
log . Warnf ( "%s[ERR] net.ErrClosed on [handleConnection] %v" , ServerLogPrefix , net . ErrClosed )
c . CloseWithError ( hpds_err . ErrorCodeClosed , "net.ErrClosed" )
break
}
// any error occurred, we should close the stream
// after this, conn.AcceptStream() will raise the error
c . CloseWithError ( hpds_err . ErrorCodeUnknown , err . Error ( ) )
log . Warnf ( "%sconnection.Close()" , ServerLogPrefix )
break
}
frameType := f . Type ( )
data := f . Encode ( )
log . Debugf ( "%stype=%s, frame[%d]=%# x" , ServerLogPrefix , frameType , len ( data ) , frame . Shortly ( data ) )
2023-03-10 23:49:52 +08:00
// add frame to contextFrame
contextFrame := c . WithFrame ( f )
2022-10-11 17:36:09 +08:00
// before frame handlers
for _ , handler := range s . beforeHandlers {
2023-03-10 23:49:52 +08:00
if e := handler ( contextFrame ) ; e != nil {
2022-10-11 17:36:09 +08:00
log . Errorf ( "%safterFrameHandler e: %s" , ServerLogPrefix , e )
2023-03-10 23:49:52 +08:00
contextFrame . CloseWithError ( hpds_err . ErrorCodeBeforeHandler , e . Error ( ) )
2022-10-11 17:36:09 +08:00
return
}
}
// main handler
2023-03-10 23:49:52 +08:00
if e := s . mainFrameHandler ( contextFrame ) ; e != nil {
2022-10-11 17:36:09 +08:00
log . Errorf ( "%smainFrameHandler e: %s" , ServerLogPrefix , e )
2023-03-10 23:49:52 +08:00
contextFrame . CloseWithError ( hpds_err . ErrorCodeMainHandler , e . Error ( ) )
2022-10-11 17:36:09 +08:00
return
}
// after frame handler
for _ , handler := range s . afterHandlers {
2023-03-10 23:49:52 +08:00
if e := handler ( contextFrame ) ; e != nil {
2022-10-11 17:36:09 +08:00
log . Errorf ( "%safterFrameHandler e: %s" , ServerLogPrefix , e )
2023-03-10 23:49:52 +08:00
contextFrame . CloseWithError ( hpds_err . ErrorCodeAfterHandler , e . Error ( ) )
2022-10-11 17:36:09 +08:00
return
}
}
}
}
func ( s * Server ) mainFrameHandler ( c * Context ) error {
var err error
frameType := c . Frame . Type ( )
switch frameType {
case frame . TagOfHandshakeFrame :
if err = s . handleHandshakeFrame ( c ) ; err != nil {
log . Errorf ( "%shandleHandshakeFrame err: %s" , ServerLogPrefix , err )
// close connections early to avoid resource consumption
if c . Stream != nil {
goawayFrame := frame . NewGoawayFrame ( err . Error ( ) )
if _ , e := c . Stream . Write ( goawayFrame . Encode ( ) ) ; e != nil {
log . Errorf ( "%s write to client[%s] GoawayFrame error:%v" , ServerLogPrefix , c . ConnId , e )
return e
}
}
}
// case frame.TagOfPingFrame:
// s.handlePingFrame(mainStream, connection, f.(*frame.PingFrame))
case frame . TagOfDataFrame :
if err = s . handleDataFrame ( c ) ; err != nil {
c . CloseWithError ( hpds_err . ErrorCodeData , fmt . Sprintf ( "handleDataFrame err: %v" , err ) )
} else {
conn := s . connector . Get ( c . connId )
if conn != nil && conn . ClientType ( ) == ClientTypeProtocolGateway {
f := c . Frame . ( * frame . DataFrame )
f . GetMetaFrame ( ) . SetMetadata ( conn . Metadata ( ) . Encode ( ) )
s . dispatchToDownStreams ( f )
}
// observe data tags back flow
2023-03-10 23:49:52 +08:00
_ = s . handleBackFlowFrame ( c )
2022-10-11 17:36:09 +08:00
}
default :
log . Errorf ( "%serr=%v, frame=%v" , ServerLogPrefix , err , frame . Shortly ( c . Frame . Encode ( ) ) )
}
return nil
}
// handle HandShakeFrame
func ( s * Server ) handleHandshakeFrame ( c * Context ) error {
f := c . Frame . ( * frame . HandshakeFrame )
log . Debugf ( "%sGOT HandshakeFrame : %# x" , ServerLogPrefix , f )
// basic info
connId := c . ConnId ( )
clientId := f . ClientId
clientType := ClientType ( f . ClientType )
stream := c . Stream
// credential
log . Debugf ( "%sClientType=%# x is %s, ClientId=%s, Credential=%s" , ServerLogPrefix , f . ClientType , ClientType ( f . ClientType ) , clientId , authName ( f . AuthName ( ) ) )
// authenticate
if ! s . authenticate ( f ) {
err := fmt . Errorf ( "handshake authentication fails, client credential name is %s" , authName ( f . AuthName ( ) ) )
// return err
log . Debugf ( "%s <%s> [%s](%s) is connected!" , ServerLogPrefix , clientType , f . Name , connId )
rejectedFrame := frame . NewRejectedFrame ( err . Error ( ) )
if _ , err = stream . Write ( rejectedFrame . Encode ( ) ) ; err != nil {
log . Debugf ( "%s write to <%s> [%s](%s) RejectedFrame error:%v" , ServerLogPrefix , clientType , f . Name , connId , err )
return err
}
return nil
}
// client type
var conn Connection
switch clientType {
case ClientTypeProtocolGateway , ClientTypeStreamFunction :
// metadata
metadata , err := s . metadataBuilder . Build ( f )
if err != nil {
return err
}
conn = newConnection ( f . Name , f . ClientId , clientType , metadata , stream , f . ObserveDataTags )
if clientType == ClientTypeStreamFunction {
// route
route := s . router . Route ( metadata )
if route == nil {
return errors . New ( "handleHandshakeFrame route is nil" )
}
if e1 := route . Add ( connId , f . Name , f . ObserveDataTags ) ; e1 != nil {
// duplicate name
if e2 , ok := e1 . ( hpds_err . DuplicateNameError ) ; ok {
existsConnID := e2 . ConnId ( )
if conn = s . connector . Get ( existsConnID ) ; conn != nil {
log . Debugf ( "%s%s, write to SFN[%s](%s) GoawayFrame" , ServerLogPrefix , e2 . Error ( ) , f . Name , existsConnID )
goawayFrame := frame . NewGoawayFrame ( e2 . Error ( ) )
if e3 := conn . Write ( goawayFrame ) ; e3 != nil {
log . Errorf ( "%s write to SFN[%s] GoawayFrame error:%v" , ServerLogPrefix , f . Name , e3 )
return e3
}
}
} else {
return e1
}
}
}
case ClientTypeMessageQueue :
conn = newConnection ( f . Name , f . ClientId , clientType , nil , stream , f . ObserveDataTags )
default :
// unknown client type
s . connector . Remove ( connId )
2023-03-10 23:49:52 +08:00
err := fmt . Errorf ( "Illegal ClientType: %#x " , f . ClientType )
2022-10-11 17:36:09 +08:00
c . CloseWithError ( hpds_err . ErrorCodeUnknownClient , err . Error ( ) )
return err
}
s . connector . Add ( connId , conn )
log . Printf ( "%s <%s> [%s][%s](%s) is connected!" , ServerLogPrefix , clientType , f . Name , clientId , connId )
return nil
}
// handleGoawayFrame writes the received GoawayFrame back onto c.Stream and
// returns the write error, if any.
//
// NOTE(review): the logged code is the constant hpds_err.ErrorCodeGoaway, not
// a code read from the frame. This method is not called anywhere in this file.
func (s *Server) handleGoawayFrame(c *Context) error {
	goaway := c.Frame.(*frame.GoawayFrame)
	log.Debugf("%s GOT GoawayFrame code=%d, message==%s", ServerLogPrefix, hpds_err.ErrorCodeGoaway, goaway.Message())
	encoded := goaway.Encode()
	_, writeErr := c.Stream.Write(encoded)
	return writeErr
}
// will reuse quic-go's keep-alive feature
// func (s *Server) handlePingFrame(stream quic.Stream, conn quic.Connection, f *frame.PingFrame) error {
// log.Infof("%s------> GOT PingFrame : %# x", ServerLogPrefix, f)
// return nil
// }
// handleDataFrame routes a DataFrame from the connection c.ConnId() to every
// stream-function connection the route forwards its data tag to.
//
// Metadata source:
//   - For MessageQueue senders, the routing metadata is decoded from the
//     frame's meta frame.
//   - For other senders, the connection's own metadata is used.
//
// Errors: it returns an error when the sender is unknown, the metadata cannot
// be decoded, or no route exists. Per-target write failures are only logged.
func (s *Server) handleDataFrame(c *Context) error {
	// Keep the post-increment value for logging, so the counter is never read
	// non-atomically while other goroutines increment it.
	counter := atomic.AddInt64(&s.counterOfDataFrame, 1)
	fromId := c.ConnId()
	from := s.connector.Get(fromId)
	if from == nil {
		log.Warnf("%shandleDataFrame connector cannot find %s", ServerLogPrefix, fromId)
		return fmt.Errorf("handleDataFrame connector cannot find %s", fromId)
	}
	f := c.Frame.(*frame.DataFrame)
	var metadata Metadata
	if from.ClientType() == ClientTypeMessageQueue {
		m, err := s.metadataBuilder.Decode(f.GetMetaFrame().Metadata())
		if err != nil {
			return err
		}
		metadata = m
	} else {
		metadata = from.Metadata()
	}
	// route
	route := s.router.Route(metadata)
	if route == nil {
		log.Warnf("%shandleDataFrame route is nil", ServerLogPrefix)
		return errors.New("handleDataFrame route is nil")
	}
	// get stream function connection ids from route
	connIds := route.GetForwardRoutes(f.GetDataTag())
	for _, toId := range connIds {
		conn := s.connector.Get(toId)
		if conn == nil {
			log.Errorf("%sconn is nil: (%s)", ServerLogPrefix, toId)
			continue
		}
		to := conn.Name()
		log.Debugf("%shandleDataFrame tag=%#x tid=%s, counter=%d, from=[%s](%s), to=[%s](%s)", ServerLogPrefix, f.Tag(), f.TransactionId(), counter, from.Name(), fromId, to, toId)
		// write data frame to stream; a failed target does not stop the others
		if err := conn.Write(f); err != nil {
			log.Warnf("%shandleDataFrame conn.Write tag=%#x tid=%s, from=[%s](%s), to=[%s](%s), %v", ServerLogPrefix, f.Tag(), f.TransactionId(), from.Name(), fromId, to, toId, err)
		}
	}
	return nil
}
// handleBackFlowFrame sends the result carried by a DataFrame back to the
// protocol gateways.
//
// It builds a BackFlowFrame from the frame's data tag and carriage and writes
// it to every non-nil connection the connector returns for the frame's source
// id and tag. It stops at, and returns, the first write error.
func (s *Server) handleBackFlowFrame(c *Context) error {
	dataFrame := c.Frame.(*frame.DataFrame)
	tag, carriage, sourceId := dataFrame.GetDataTag(), dataFrame.GetCarriage(), dataFrame.SourceId()
	backFlow := frame.NewBackFlowFrame(tag, carriage)
	for _, gateway := range s.connector.GetProtocolGatewayConnections(sourceId, tag) {
		if gateway == nil {
			continue
		}
		log.Debugf("%s handleBackFlowFrame tag:%#v --> Protocol Gateway:%s, result=%# x", ServerLogPrefix, tag, sourceId, frame.Shortly(carriage))
		if writeErr := gateway.Write(backFlow); writeErr != nil {
			log.Errorf("%s handleBackFlowFrame tag:%#v --> Protocol Gateway:%s, error=%v", ServerLogPrefix, tag, sourceId, writeErr)
			return writeErr
		}
	}
	return nil
}
// StatsFunctions returns the connector's snapshot of registered connections,
// as produced by Connector.GetSnapshot.
func (s *Server) StatsFunctions() map[string]string {
	snapshot := s.connector.GetSnapshot()
	return snapshot
}
// StatsCounter returns how many DataFrames have passed through the server.
//
// handleDataFrame increments the counter with sync/atomic, so it is loaded
// atomically here too; a plain read would be a data race.
func (s *Server) StatsCounter() int64 {
	return atomic.LoadInt64(&s.counterOfDataFrame)
}
// DownStreams returns the downstream clients keyed by the address they were
// registered with in AddDownstreamServer.
//
// NOTE(review): this is the server's internal map, not a copy, and it is
// returned without holding s.mu. Reading it while AddDownstreamServer runs is
// a data race.
func (s *Server) DownStreams() map[string]*Client {
	streams := s.downStreams
	return streams
}
// ConfigRouter installs the router used to route frames by metadata. It must
// be called before Serve, which refuses to start without a router.
func (s *Server) ConfigRouter(router Router) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.router = router
	log.Debugf("%sconfig router is %#v", ServerLogPrefix, router)
}
// ConfigMetadataBuilder installs the builder used to create and decode
// connection metadata. It must be called before Serve, which refuses to start
// without a metadata builder.
func (s *Server) ConfigMetadataBuilder(builder MetadataBuilder) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.metadataBuilder = builder
	log.Debugf("%sconfig metadataBuilder is %#v", ServerLogPrefix, builder)
}
// AddDownstreamServer registers c as a downstream server under addr, replacing
// any client already stored for that address. DataFrames from protocol gateway
// clients are dispatched to every registered downstream.
func (s *Server) AddDownstreamServer(addr string, c *Client) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.downStreams[addr] = c
}
// dispatch every DataFrames to all downStreams
func ( s * Server ) dispatchToDownStreams ( df * frame . DataFrame ) {
for addr , ds := range s . downStreams {
log . Debugf ( "%sdispatching to [%s]: %# x" , ServerLogPrefix , addr , df . Tag ( ) )
2023-03-10 23:49:52 +08:00
_ = ds . WriteFrame ( df )
2022-10-11 17:36:09 +08:00
}
}
// GetConnId returns the identifier the server uses for a QUIC connection: the
// string form of the connection's remote address.
func GetConnId(conn quic.Connection) string {
	remote := conn.RemoteAddr()
	return remote.String()
}
// initOptions is the hook for filling in default server options after Init
// has applied the caller's options. No defaults are set at present.
func (s *Server) initOptions() {}
// validateRouter reports an error when no router has been configured.
func (s *Server) validateRouter() error {
	var err error
	if s.router == nil {
		err = errors.New("server's router is nil")
	}
	return err
}
// validateMetadataBuilder reports an error when no metadata builder has been
// configured.
func (s *Server) validateMetadataBuilder() error {
	var err error
	if s.metadataBuilder == nil {
		err = errors.New("server's metadataBuilder is nil")
	}
	return err
}
// Options returns a copy of the server's options as applied by Init.
func (s *Server) Options() ServerOptions {
	current := s.opts
	return current
}
// Connector returns the connector that tracks the server's client connections.
func (s *Server) Connector() Connector {
	cn := s.connector
	return cn
}
// SetBeforeHandlers appends handlers to the list run on every frame before
// the main frame handler. Existing handlers are kept.
func (s *Server) SetBeforeHandlers(handlers ...FrameHandler) {
	for _, h := range handlers {
		s.beforeHandlers = append(s.beforeHandlers, h)
	}
}
// SetAfterHandlers appends handlers to the list run on every frame after the
// main frame handler. Existing handlers are kept.
func (s *Server) SetAfterHandlers(handlers ...FrameHandler) {
	for _, h := range handlers {
		s.afterHandlers = append(s.afterHandlers, h)
	}
}
// authNames returns the names of the configured authentication methods, in
// configuration order, or []string{"none"} when none are configured. It is
// used for the startup log line.
func (s *Server) authNames() []string {
	if len(s.opts.Auths) == 0 {
		return []string{"none"}
	}
	// The final length is known, so allocate once.
	names := make([]string, 0, len(s.opts.Auths))
	for _, a := range s.opts.Auths {
		names = append(names, a.Name())
	}
	return names
}
// authenticate reports whether the handshake's credentials are accepted.
//
// With no authentication methods configured, every handshake is accepted.
// Otherwise the handshake is accepted only if some configured method whose
// name equals f.AuthName() authenticates f.AuthPayload().
func (s *Server) authenticate(f *frame.HandshakeFrame) bool {
	if len(s.opts.Auths) == 0 {
		return true
	}
	for _, method := range s.opts.Auths {
		if f.AuthName() != method.Name() {
			continue
		}
		if ok := method.Authenticate(f.AuthPayload()); ok {
			log.Debugf("%sauthenticated==%v", ServerLogPrefix, ok)
			return ok
		}
	}
	return false
}
// mode returns the label printed in the startup log line: "DEVELOPMENT" when
// pkgtls.IsDev() reports true, "PRODUCTION" otherwise.
func mode() string {
	label := "PRODUCTION"
	if pkgtls.IsDev() {
		label = "DEVELOPMENT"
	}
	return label
}
// authName returns name unchanged, or "empty" when name is the empty string,
// so log lines never show a blank credential name.
func authName(name string) string {
	if len(name) > 0 {
		return name
	}
	return "empty"
}