network/server.go

535 lines
15 KiB
Go

package network
import (
"context"
"errors"
"fmt"
"git.hpds.cc/Component/network/auth"
"git.hpds.cc/Component/network/metadata"
"git.hpds.cc/Component/network/router"
"github.com/quic-go/quic-go"
"io"
"net"
"os"
"sync"
"sync/atomic"
// authentication implementations; currently only token authentication is implemented
_ "git.hpds.cc/Component/network/auth"
"git.hpds.cc/Component/network/frame"
"git.hpds.cc/Component/network/hpds_err"
"git.hpds.cc/Component/network/log"
)
// FrameHandler is a step in the server's per-stream frame pipeline (start,
// before and after handlers). It receives the stream Context; returning a
// non-nil error makes the server close the stream with the error code of the
// stage that failed.
type FrameHandler func(c *Context) error
// ConnectionHandler is a callback that receives a QUIC connection. The server
// runs its registered connection-close handlers when the goroutine serving an
// authenticated connection exits.
type ConnectionHandler func(conn quic.Connection)
// Server is the underlying QUIC server of the Message Queue. It accepts and
// authenticates client connections, routes data frames to stream functions
// and forwards broadcast data frames to downstream servers.
type Server struct {
// name identifies the server in the "Listening on" log line.
name string
// connector tracks live streams by stream id; created in Serve.
connector *Connector
// router maps decoded metadata to a route; set via ConfigRouter, required by Serve.
router router.Router
// metadataBuilder decodes stream/frame metadata; set via ConfigMetadataBuilder, required by Serve.
metadataBuilder metadata.Builder
// counterOfDataFrame counts handled data frames; it is incremented with
// atomic.AddInt64, so any access must be atomic as well.
counterOfDataFrame int64
// downStreams holds downstream servers keyed by address; broadcast data
// frames from source streams are written to each of them.
downStreams map[string]frame.Writer
// mu is held by the Config* setters and by AddDownstreamServer.
mu sync.Mutex
// opts holds the options assembled in NewServer.
opts *serverOptions
// startHandlers run once per data stream, before any frame is read.
startHandlers []FrameHandler
// beforeHandlers run for every frame, before mainFrameHandler.
beforeHandlers []FrameHandler
// afterHandlers run for every frame, after mainFrameHandler.
afterHandlers []FrameHandler
// connectionCloseHandlers run when a connection's serving goroutine exits.
connectionCloseHandlers []ConnectionHandler
// listener is the QUIC listener created in Serve and closed by Close.
listener Listener
// NOTE(review): logger is not referenced in the code shown here; the
// package-level log functions are used instead.
logger log.Logger
}
// NewServer creates a Server called name.
//
// The default server options are built first and each option in opts is then
// applied to them in order, so later options override earlier ones. The router
// and metadata builder are left unset; Serve refuses to start until they have
// been configured with ConfigRouter and ConfigMetadataBuilder.
func NewServer(name string, opts ...ServerOption) *Server {
	options := defaultServerOptions()
	for i := range opts {
		opts[i](options)
	}
	return &Server{
		name:        name,
		opts:        options,
		downStreams: make(map[string]frame.Writer),
	}
}
// ListenAndServe binds a UDP socket on addr and serves QUIC connections on it
// via Serve. An empty addr means DefaultListenAddr.
//
// ListenAndServe owns the socket it opens and closes it once Serve returns.
// Serve returns on a setup error, or after it stops accepting connections
// (ctx cancelled, Close called). Without the close, the UDP port would stay
// bound after the server stopped.
func (s *Server) ListenAndServe(ctx context.Context, addr string) error {
	if addr == "" {
		addr = DefaultListenAddr
	}
	udpAddr, err := net.ResolveUDPAddr("udp", addr)
	if err != nil {
		return err
	}
	conn, err := net.ListenUDP("udp", udpAddr)
	if err != nil {
		return err
	}
	// NOTE(review): closing the socket also ends any QUIC connections still
	// multiplexed on it; this only happens after Serve has stopped accepting.
	defer conn.Close()
	return s.Serve(ctx, conn)
}
// Serve runs the server on an already-open net.PacketConn.
//
// Setup steps:
//   - verify that a metadata builder and a router have been configured;
//   - create the Connector and the QUIC listener;
//   - loop accepting QUIC connections.
//
// For each connection Serve checks the negotiated ALPN protocol and accepts
// the first stream as the control stream. It then verifies the client's
// authentication on that stream. Each authenticated connection is then served
// in its own goroutine by a StreamGroup. That goroutine runs until ctx is
// cancelled or the group stops, then runs the connection-close handlers.
//
// Serve returns only on a setup error or when Accept fails (for example after
// ctx is cancelled or the listener is closed), so the returned error is never
// nil. The caller keeps ownership of conn.
func (s *Server) Serve(ctx context.Context, conn net.PacketConn) error {
	if err := s.validateMetadataBuilder(); err != nil {
		return err
	}
	if err := s.validateRouter(); err != nil {
		return err
	}
	s.connector = NewConnector(ctx)
	// listen the address
	listener, err := newListener(conn, s.opts.tlsConfig, s.opts.quicConfig)
	if err != nil {
		log.Errorf("%s listener.Listen: err=%v", ServerLogPrefix, err)
		return err
	}
	s.listener = listener
	log.Printf("%s [%s][%d] Listening on: %s, QUIC: %v, AUTH: %s", ServerLogPrefix, s.name, os.Getpid(), listener.Addr(), listener.Versions(), s.authNames())
	for {
		// Named qConn so it does not shadow the conn parameter.
		qConn, err := s.listener.Accept(ctx)
		if err != nil {
			log.Errorf("%s listener accept connections error %v", ServerLogPrefix, err)
			return err
		}
		if err := s.opts.alpnHandler(qConn.ConnectionState().TLS.NegotiatedProtocol); err != nil {
			_ = qConn.CloseWithError(quic.ApplicationErrorCode(hpds_err.ErrorCodeRejected), err.Error())
			continue
		}
		// NOTE(review): the control stream is accepted and authenticated inline,
		// so a client that connects but stalls here delays acceptance of every
		// other client. Consider moving this handshake into the per-connection
		// goroutine.
		stream0, err := qConn.AcceptStream(ctx)
		if err != nil {
			// Without a control stream the connection cannot be served; close it
			// now instead of leaving it open until it idles out.
			_ = qConn.CloseWithError(quic.ApplicationErrorCode(hpds_err.ErrorCodeUnknown), err.Error())
			continue
		}
		controlStream := NewServerControlStream(qConn, NewFrameStream(stream0))
		// Auth accepts an AuthenticationFrame from the client. The first frame from
		// the client must be an AuthenticationFrame; the server responds with an
		// AuthenticationAckFrame.
		if err := controlStream.VerifyAuthentication(s.handleAuthenticationFrame); err != nil {
			log.Warnf("%s Authentication Failed, error: %s", ServerLogPrefix, err)
			// NOTE(review): the connection is not closed here. Closing right away
			// could drop the ack frame sent to the client; confirm whether
			// VerifyAuthentication closes it itself.
			continue
		}
		log.Debugf("%s Authentication Success", ServerLogPrefix)
		go func(qConn quic.Connection) {
			streamGroup := NewStreamGroup(ctx, controlStream, s.connector)
			// Deferred calls run LIFO: close handlers first, then wait for the group.
			defer streamGroup.Wait()
			defer s.doConnectionCloseHandlers(qConn)
			select {
			case <-ctx.Done():
				return
			case err := <-s.runWithStreamGroup(streamGroup):
				log.Errorf("%s Client Close, %v", ServerLogPrefix, err)
			}
		}(qConn)
	}
}
// runWithStreamGroup runs group.Run(s.handleStreamContext) in a new goroutine
// and returns a channel that receives its result exactly once.
//
// The channel is buffered (capacity 1) so the goroutine can always deliver
// its result and exit, even when the caller has stopped waiting. Serve's
// per-connection goroutine returns on ctx.Done() without reading it; with an
// unbuffered channel the sender would then block forever and leak.
func (s *Server) runWithStreamGroup(group *StreamGroup) <-chan error {
	errCh := make(chan error, 1)
	go func() {
		errCh <- group.Run(s.handleStreamContext)
	}()
	return errCh
}
// Close shuts the server down in three steps: it closes the connector, then
// the QUIC listener, then cleans the router. A component that was never set
// is skipped. The listener's close error is discarded, so Close always
// returns nil.
func (s *Server) Close() error {
	if c := s.connector; c != nil {
		c.Close()
	}
	if l := s.listener; l != nil {
		_ = l.Close()
	}
	if r := s.router; r != nil {
		r.Clean()
	}
	return nil
}
// handleRoute registers a StreamFunction data stream with the router, keyed by
// its stream id, name and observed data tags. Other stream types are accepted
// without routing.
//
// The router may report that another stream already uses the same name
// (hpds_err.DuplicateNameError). In that case the previously registered
// stream is closed and removed from the connector, and nil is returned. Any
// other failure is returned to the caller:
//   - metadata decoding;
//   - a nil route;
//   - the connector lookup;
//   - route registration.
func (s *Server) handleRoute(c *Context) error {
	if c.DataStream.StreamType() != StreamTypeStreamFunction {
		return nil
	}
	md, err := s.metadataBuilder.Decode(c.DataStream.Metadata())
	if err != nil {
		return err
	}
	route := s.router.Route(md)
	if route == nil {
		return errors.New("handleRoute route is nil")
	}
	err = route.Add(c.StreamId(), c.DataStream.Name(), c.DataStream.ObserveDataTags())
	if err == nil {
		return nil
	}
	// errors.As also matches a DuplicateNameError that the router wrapped,
	// which a plain type assertion would miss.
	var dupErr hpds_err.DuplicateNameError
	if !errors.As(err, &dupErr) {
		return err
	}
	existsConnId := dupErr.ConnId()
	log.Debugf("%s StreamFunction Duplicate Name, error: %s; sfn_name: %s, old_stream_id: %s; current_stream_id: %s",
		ServerLogPrefix, dupErr.Error(), c.DataStream.Name(), existsConnId, c.StreamId())
	// NOTE(review): the new stream is not re-added to the route after evicting
	// the old one; confirm route.Add registers it even when it reports a
	// duplicate name.
	stream, ok, err := s.connector.Get(existsConnId)
	if err != nil {
		return err
	}
	if ok {
		_ = stream.CloseWithError(dupErr.Error())
		_ = s.connector.Remove(existsConnId)
	}
	return nil
}
// handleStreamContext serves one data stream for its whole lifetime.
//
// It first registers the stream with the router (handleRoute). If that fails,
// the stream is closed with ErrorCodeRejected; otherwise cleanRoute removes
// the route entry when this function returns. The start handlers then run
// once. After that it loops:
//   - read a frame and attach it to the Context;
//   - run the before handlers, mainFrameHandler and the after handlers.
//
// A handler error closes the stream with that stage's error code and ends the
// loop. A read error also ends the loop. The stream is closed explicitly
// unless the client aborted (ErrorCodeClientAbort) or the read hit EOF.
func (s *Server) handleStreamContext(c *Context) {
	// handle route.
	if err := s.handleRoute(c); err != nil {
		c.CloseWithError(hpds_err.ErrorCodeRejected, err.Error())
		return
	}
	defer s.cleanRoute(c)
	// start frame handlers
	for _, handler := range s.startHandlers {
		if err := handler(c); err != nil {
			log.Errorf("%s startHandlers error: %v", ServerLogPrefix, err)
			c.CloseWithError(hpds_err.ErrorCodeStartHandler, err.Error())
			return
		}
	}
	// read frames until the stream fails or a handler rejects a frame
	for {
		f, err := c.DataStream.ReadFrame()
		if err != nil {
			// errors.As / errors.Is also match errors that ReadFrame wrapped,
			// which a type assertion or == comparison would miss.
			var appErr *quic.ApplicationError
			if errors.As(err, &appErr) {
				// if client close connection, will get ApplicationError with code = 0x00
				if hpds_err.Is(appErr.ErrorCode, hpds_err.ErrorCodeClientAbort) {
					// client abort
					log.Infof("%s client close the connection", ServerLogPrefix)
					return
				}
				he := hpds_err.New(hpds_err.Parse(appErr.ErrorCode), err)
				log.Errorf("%s read frame error: %v", ServerLogPrefix, he)
			} else if errors.Is(err, io.EOF) {
				log.Infof("%s connection EOF", ServerLogPrefix)
				return
			}
			if errors.Is(err, net.ErrClosed) {
				// if client close the connection, net.ErrClosed will be raise
				// by quic-go IdleTimeoutError after connection's KeepAlive config.
				log.Warnf("%s connection error, error: %v", ServerLogPrefix, net.ErrClosed)
				c.CloseWithError(hpds_err.ErrorCodeClosed, "net.ErrClosed")
				return
			}
			// any error occurred, we should close the stream
			// after this, conn.AcceptStream() will raise the error
			c.CloseWithError(hpds_err.ErrorCodeUnknown, err.Error())
			log.Warnf("%s connection close", ServerLogPrefix)
			return
		}
		// add frame to context
		c.WithFrame(f)
		// before frame handlers
		for _, handler := range s.beforeHandlers {
			if err := handler(c); err != nil {
				log.Errorf("%s beforeFrameHandler error: %v", ServerLogPrefix, err)
				c.CloseWithError(hpds_err.ErrorCodeBeforeHandler, err.Error())
				return
			}
		}
		// main handler
		if err := s.mainFrameHandler(c); err != nil {
			log.Errorf("%s mainFrameHandler error: %v", ServerLogPrefix, err)
			c.CloseWithError(hpds_err.ErrorCodeMainHandler, err.Error())
			return
		}
		// after frame handler
		for _, handler := range s.afterHandlers {
			if err := handler(c); err != nil {
				log.Errorf("%s afterFrameHandler error: %v", ServerLogPrefix, err)
				c.CloseWithError(hpds_err.ErrorCodeAfterHandler, err.Error())
				return
			}
		}
	}
}
// mainFrameHandler processes the frame currently attached to c.
//
// A data frame goes through three steps:
//   - handleDataFrame forwards it to the routed stream functions; if that
//     fails, the stream is closed with ErrorCodeData and nothing else runs;
//   - dispatchToDownStreams broadcasts it to downstream servers;
//   - handleBackFlowFrame writes it back as a BackFlowFrame, and its error
//     is ignored.
//
// Frames of any other type are logged and dropped. It currently always
// returns nil; failures are reported by closing the stream or by logging.
func (s *Server) mainFrameHandler(c *Context) error {
	frameType := c.Frame.Type()
	switch frameType {
	case frame.TagOfDataFrame:
		if err := s.handleDataFrame(c); err != nil {
			c.CloseWithError(hpds_err.ErrorCodeData, fmt.Sprintf("handleDataFrame err: %v", err))
			return nil
		}
		s.dispatchToDownStreams(c)
		// observe data tags back flow
		_ = s.handleBackFlowFrame(c)
	default:
		// There is no error value on this path; report the offending frame type.
		log.Errorf("%s unsupported frame type: %v, frame=%v", ServerLogPrefix, frameType, frame.Shortly(c.Frame.Encode()))
	}
	return nil
}
// handleAuthenticationFrame checks a client's authentication object against
// the server's configured auths (s.opts.auths) and logs the outcome. The bool
// reports whether authentication succeeded; the error is always nil.
func (s *Server) handleAuthenticationFrame(f auth.Object) (bool, error) {
	if !auth.Authenticate(s.opts.auths, f) {
		log.Warnf("%s Authentication failed, credential: %s", ServerLogPrefix, f.AuthName())
		return false, nil
	}
	log.Debugf("%s Successful authentication", ServerLogPrefix)
	return true, nil
}
// handleDataFrame forwards the DataFrame attached to c to every stream the
// router lists for the frame's data tag. It also bumps the data-frame counter.
//
// The route is chosen from the frame's own metadata. If that decodes to nil,
// the sending stream's metadata is used instead. The following are returned
// as errors:
//   - the sender lookup failed or found nothing;
//   - the attached frame is not a DataFrame;
//   - metadata could not be decoded;
//   - no route matched.
//
// Per-target failures (lookup error, missing target, write error) are logged
// and skipped so that one bad target does not block delivery to the others.
func (s *Server) handleDataFrame(c *Context) error {
	// counter +1 (atomic: StatsCounter reads it concurrently)
	atomic.AddInt64(&s.counterOfDataFrame, 1)
	fromId := c.StreamId()
	from, ok, err := s.connector.Get(fromId)
	if err != nil {
		return err
	}
	if !ok {
		log.Warnf("%s handleDataFrame connector cannot find, from_conn_id: %s", ServerLogPrefix, fromId)
		return fmt.Errorf("handleDataFrame connector cannot find %s", fromId)
	}
	f, ok := c.Frame.(*frame.DataFrame)
	if !ok {
		return fmt.Errorf("handleDataFrame unexpected frame type %T", c.Frame)
	}
	m, err := s.metadataBuilder.Decode(f.GetMetaFrame().Metadata())
	if err != nil {
		return err
	}
	if m == nil {
		m, err = s.metadataBuilder.Decode(from.Metadata())
		if err != nil {
			return err
		}
	}
	// route
	route := s.router.Route(m)
	if route == nil {
		log.Warnf("%s handleDataFrame route is nil", ServerLogPrefix)
		return errors.New("handleDataFrame route is nil")
	}
	// get stream function connection ids from route
	connIDs := route.GetForwardRoutes(f.GetDataTag())
	log.Debugf("%s Data Routing Status, sfn_stream_ids: %v; connector: %v", ServerLogPrefix, connIDs, s.connector.GetSnapshot())
	for _, toId := range connIDs {
		conn, ok, err := s.connector.Get(toId)
		if err != nil {
			log.Errorf("%s Can't get forward conn, forward_conn_id: %s; error: %v", ServerLogPrefix, toId, err)
			continue
		}
		if !ok {
			log.Errorf("%s Can't find forward conn, error: conn is nil ;forward_conn_id: %s", ServerLogPrefix, toId)
			continue
		}
		to := conn.Name()
		log.Infof("%s handleDataFrame, from_conn_name: %s; from_conn_id: %s; to_conn_name: %s; to_conn_id: %s; data_frame: %s",
			ServerLogPrefix, from.Name(), fromId, to, toId, f.String())
		// write data frame to stream
		if err := conn.WriteFrame(f); err != nil {
			log.Errorf("%s handleDataFrame conn.Write, %v", ServerLogPrefix, err)
		}
	}
	return nil
}
// handleBackFlowFrame wraps the data tag and carriage of the DataFrame in c
// into a BackFlowFrame. It writes that frame to every non-nil connection
// that connector.GetSourceConns returns for the frame's source id and tag.
//
// A failed write is logged and delivery continues with the remaining sources.
// The first write error, or the GetSourceConns error, is returned.
func (s *Server) handleBackFlowFrame(c *Context) error {
	f, ok := c.Frame.(*frame.DataFrame)
	if !ok {
		return fmt.Errorf("handleBackFlowFrame unexpected frame type %T", c.Frame)
	}
	tag := f.GetDataTag()
	sourceId := f.SourceId()
	// write to Protocol Gateway with BackFlowFrame
	bf := frame.NewBackFlowFrame(tag, f.GetCarriage())
	sourceConnList, err := s.connector.GetSourceConns(sourceId, tag)
	if err != nil {
		return err
	}
	var firstErr error
	for _, source := range sourceConnList {
		if source == nil {
			continue
		}
		log.Debugf("%s handleBackFlowFrame tag:%#v; source_conn_id: %s, back_flow_frame: %s", ServerLogPrefix, tag, sourceId, f.String())
		if err := source.WriteFrame(bf); err != nil {
			log.Errorf("%s handleBackFlowFrame tag:%#v --> Protocol Gateway:%s, error=%v", ServerLogPrefix, tag, sourceId, err)
			if firstErr == nil {
				firstErr = err
			}
		}
	}
	return firstErr
}
// StatsFunctions returns the connector's snapshot of registered streams.
// Before Serve has created the connector it returns an empty map instead of
// panicking on a nil connector.
func (s *Server) StatsFunctions() map[string]string {
	if s.connector == nil {
		return map[string]string{}
	}
	return s.connector.GetSnapshot()
}
// StatsCounter returns how many DataFrames have passed through the server.
// The counter is incremented atomically while frames are being handled, so it
// is read atomically here as well to avoid a data race.
func (s *Server) StatsCounter() int64 {
	return atomic.LoadInt64(&s.counterOfDataFrame)
}
// DownStreams returns the downstream servers registered through
// AddDownstreamServer, keyed by address.
//
// NOTE(review): this is the server's internal map, not a copy, and it is
// returned without holding s.mu. AddDownstreamServer writes it under s.mu, so
// concurrent use is racy.
func (s *Server) DownStreams() map[string]frame.Writer {
	streams := s.downStreams
	return streams
}
// ConfigRouter installs the router used to route streams and data frames; it
// is set by the Message Queue and must be present before Serve. The update and
// its debug log happen while holding s.mu.
func (s *Server) ConfigRouter(router router.Router) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.router = router
	log.Debugf("%s config router is %#v", ServerLogPrefix, router)
}
// ConfigMetadataBuilder installs the builder used to decode stream and frame
// metadata; it is set by the Message Queue and must be present before Serve.
// The update and its debug log happen while holding s.mu.
func (s *Server) ConfigMetadataBuilder(builder metadata.Builder) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.metadataBuilder = builder
	log.Debugf("%s config metadataBuilder is %#v", ServerLogPrefix, builder)
}
// ConfigAlpnHandler installs the ALPN check used by the Emitter. Serve calls
// h with each connection's negotiated ALPN protocol and rejects the connection
// when h returns an error. The update happens while holding s.mu.
func (s *Server) ConfigAlpnHandler(h func(string) error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.opts.alpnHandler = h
	log.Debugf("%s config alpnHandler", ServerLogPrefix)
}
// AddDownstreamServer registers client c as the downstream server for addr,
// replacing any earlier entry for the same address. Broadcast DataFrames from
// source streams are dispatched to every registered downstream server.
func (s *Server) AddDownstreamServer(addr string, c *Client) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.downStreams[addr] = c
}
// dispatchToDownStreams writes the DataFrame attached to c to every
// downstream server.
//
// Only broadcast frames that arrive on a Source stream are dispatched. If the
// frame carries no metadata, the source stream's metadata is attached first.
// Write failures are logged per downstream and do not stop the others.
func (s *Server) dispatchToDownStreams(c *Context) {
	stream, ok, err := s.connector.Get(c.StreamId())
	if err != nil {
		log.Errorf("%s Connector Get Error, %v", ServerLogPrefix, err)
		return
	}
	if !ok {
		log.Debugf("%s dispatchTo Down Streams failed", ServerLogPrefix)
		return
	}
	if stream.StreamType() != StreamTypeSource {
		return
	}
	f, ok := c.Frame.(*frame.DataFrame)
	if !ok || !f.IsBroadcast() {
		return
	}
	if f.GetMetaFrame().Metadata() == nil {
		f.GetMetaFrame().SetMetadata(stream.Metadata())
	}
	// Snapshot under s.mu: AddDownstreamServer writes the map concurrently, and
	// the lock should not be held across network writes.
	s.mu.Lock()
	downStreams := make(map[string]frame.Writer, len(s.downStreams))
	for addr, ds := range s.downStreams {
		downStreams[addr] = ds
	}
	s.mu.Unlock()
	for addr, ds := range downStreams {
		log.Infof("%s dispatching to, dispatch_addr: %s; tid: %s;", ServerLogPrefix, addr, f.TransactionId())
		if err := ds.WriteFrame(f); err != nil {
			log.Errorf("%s dispatch to downstream error, dispatch_addr: %s; error: %v", ServerLogPrefix, addr, err)
		}
	}
}
// GetConnId returns an identifier for a QUIC connection: the string form of
// the connection's remote address.
func GetConnId(conn quic.Connection) string {
	remote := conn.RemoteAddr()
	return remote.String()
}
// validateRouter reports an error when no router has been configured.
func (s *Server) validateRouter() error {
	if s.router != nil {
		return nil
	}
	return errors.New("server's router is nil")
}
// validateMetadataBuilder reports an error when no metadata builder has been
// configured.
func (s *Server) validateMetadataBuilder() error {
	if s.metadataBuilder != nil {
		return nil
	}
	return errors.New("server's metadataBuilder is nil")
}
// SetStartHandlers appends handlers to the start handlers. They run once per
// data stream, after the stream's route has been registered and before any
// frame is read; an error closes the stream.
func (s *Server) SetStartHandlers(handlers ...FrameHandler) {
	for _, h := range handlers {
		s.startHandlers = append(s.startHandlers, h)
	}
}
// SetBeforeHandlers appends handlers to the before handlers. They run for
// every frame read, before mainFrameHandler; an error closes the stream.
func (s *Server) SetBeforeHandlers(handlers ...FrameHandler) {
	for _, h := range handlers {
		s.beforeHandlers = append(s.beforeHandlers, h)
	}
}
// SetAfterHandlers appends handlers to the after handlers. They run for every
// frame read, after mainFrameHandler succeeds; an error closes the stream.
func (s *Server) SetAfterHandlers(handlers ...FrameHandler) {
	for _, h := range handlers {
		s.afterHandlers = append(s.afterHandlers, h)
	}
}
// SetConnectionCloseHandlers appends handlers to the connection-close
// handlers. They run, in order, when the goroutine serving an authenticated
// connection exits.
func (s *Server) SetConnectionCloseHandlers(handlers ...ConnectionHandler) {
	for _, h := range handlers {
		s.connectionCloseHandlers = append(s.connectionCloseHandlers, h)
	}
}
// authNames lists the names of the configured authentication methods for the
// startup log, or ["none"] when no authentication is configured.
func (s *Server) authNames() []string {
	auths := s.opts.auths
	if len(auths) == 0 {
		return []string{"none"}
	}
	names := make([]string, 0, len(auths))
	for _, a := range auths {
		names = append(names, a.Name())
	}
	return names
}
// doConnectionCloseHandlers logs the closed connection and then calls every
// registered connection-close handler with it, in registration order.
func (s *Server) doConnectionCloseHandlers(qConn quic.Connection) {
	log.Debugf("%s QUIC Connection Closed", ServerLogPrefix)
	handlers := s.connectionCloseHandlers
	for i := range handlers {
		handlers[i](qConn)
	}
}
// cleanRoute removes the stream in c from the route selected by the stream's
// metadata. It runs deferred when a data stream finishes.
//
// If the metadata cannot be decoded, or no route matches, there is nothing to
// remove. Without these checks the deferred call would dereference a nil route
// and panic. For StreamFunction streams, handleRoute already decoded the same
// metadata successfully, so the early return only affects streams that never
// had a route entry. The Remove error is deliberately ignored (best-effort
// cleanup).
func (s *Server) cleanRoute(c *Context) {
	md, err := s.metadataBuilder.Decode(c.DataStream.Metadata())
	if err != nil {
		log.Debugf("%s cleanRoute decode metadata error: %v", ServerLogPrefix, err)
		return
	}
	route := s.router.Route(md)
	if route == nil {
		return
	}
	_ = route.Remove(c.StreamId())
}