You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
534 lines
15 KiB
534 lines
15 KiB
package network |
|
|
|
import ( |
|
"context" |
|
"errors" |
|
"fmt" |
|
"git.hpds.cc/Component/network/auth" |
|
"git.hpds.cc/Component/network/metadata" |
|
"git.hpds.cc/Component/network/router" |
|
"github.com/quic-go/quic-go" |
|
"io" |
|
"net" |
|
"os" |
|
"sync" |
|
"sync/atomic" |
|
|
|
// authentication implements, Currently, only token authentication is implemented |
|
_ "git.hpds.cc/Component/network/auth" |
|
"git.hpds.cc/Component/network/frame" |
|
"git.hpds.cc/Component/network/hpds_err" |
|
"git.hpds.cc/Component/network/log" |
|
) |
|
|
|
// FrameHandler is the handler for frame. |
|
type FrameHandler func(c *Context) error |
|
|
|
// ConnectionHandler is the handler for quic connection |
|
type ConnectionHandler func(conn quic.Connection) |
|
|
|
// Server is the underlining server of Message Queue |
|
type Server struct { |
|
name string |
|
connector *Connector |
|
router router.Router |
|
metadataBuilder metadata.Builder |
|
counterOfDataFrame int64 |
|
downStreams map[string]frame.Writer |
|
mu sync.Mutex |
|
opts *serverOptions |
|
startHandlers []FrameHandler |
|
beforeHandlers []FrameHandler |
|
afterHandlers []FrameHandler |
|
connectionCloseHandlers []ConnectionHandler |
|
listener Listener |
|
logger log.Logger |
|
} |
|
|
|
// NewServer create a Server instance. |
|
func NewServer(name string, opts ...ServerOption) *Server { |
|
options := defaultServerOptions() |
|
|
|
for _, o := range opts { |
|
o(options) |
|
} |
|
s := &Server{ |
|
name: name, |
|
downStreams: make(map[string]frame.Writer), |
|
opts: options, |
|
} |
|
|
|
return s |
|
} |
|
|
|
// ListenAndServe starts the server. |
|
func (s *Server) ListenAndServe(ctx context.Context, addr string) error { |
|
if addr == "" { |
|
addr = DefaultListenAddr |
|
} |
|
udpAddr, err := net.ResolveUDPAddr("udp", addr) |
|
if err != nil { |
|
return err |
|
} |
|
conn, err := net.ListenUDP("udp", udpAddr) |
|
if err != nil { |
|
return err |
|
} |
|
return s.Serve(ctx, conn) |
|
} |
|
|
|
// Serve the server with a net.PacketConn. |
|
func (s *Server) Serve(ctx context.Context, conn net.PacketConn) error { |
|
if err := s.validateMetadataBuilder(); err != nil { |
|
return err |
|
} |
|
|
|
if err := s.validateRouter(); err != nil { |
|
return err |
|
} |
|
s.connector = NewConnector(ctx) |
|
// listen the address |
|
listener, err := newListener(conn, s.opts.tlsConfig, s.opts.quicConfig) |
|
if err != nil { |
|
log.Errorf("%s listener.Listen: err=%v", ServerLogPrefix, err) |
|
return err |
|
} |
|
s.listener = listener |
|
|
|
log.Printf("%s [%s][%d] Listening on: %s, QUIC: %v, AUTH: %s", ServerLogPrefix, s.name, os.Getpid(), listener.Addr(), listener.Versions(), s.authNames()) |
|
|
|
for { |
|
conn, err := s.listener.Accept(ctx) |
|
if err != nil { |
|
log.Errorf("%s listener accept connections error %v", ServerLogPrefix, err) |
|
return err |
|
} |
|
err = s.opts.alpnHandler(conn.ConnectionState().TLS.NegotiatedProtocol) |
|
if err != nil { |
|
_ = conn.CloseWithError(quic.ApplicationErrorCode(hpds_err.ErrorCodeRejected), err.Error()) |
|
continue |
|
} |
|
stream0, err := conn.AcceptStream(ctx) |
|
if err != nil { |
|
continue |
|
} |
|
|
|
controlStream := NewServerControlStream(conn, NewFrameStream(stream0)) |
|
|
|
// Auth accepts a AuthenticationFrame from client. The first frame from client must be |
|
// AuthenticationFrame, It returns true if auth successful otherwise return false. |
|
// It response to client a AuthenticationAckFrame. |
|
err = controlStream.VerifyAuthentication(s.handleAuthenticationFrame) |
|
if err != nil { |
|
log.Warnf("%s Authentication Failed, error: %s", ServerLogPrefix, err) |
|
continue |
|
} |
|
log.Debugf("%s Authentication Success", ServerLogPrefix) |
|
|
|
go func(qConn quic.Connection) { |
|
streamGroup := NewStreamGroup(ctx, controlStream, s.connector) |
|
|
|
defer streamGroup.Wait() |
|
defer s.doConnectionCloseHandlers(qConn) |
|
|
|
select { |
|
case <-ctx.Done(): |
|
return |
|
case err := <-s.runWithStreamGroup(streamGroup): |
|
log.Errorf("%s Client Close, %v", ServerLogPrefix, err) |
|
} |
|
}(conn) |
|
} |
|
} |
|
|
|
func (s *Server) runWithStreamGroup(group *StreamGroup) <-chan error { |
|
errCh := make(chan error) |
|
|
|
go func() { |
|
errCh <- group.Run(s.handleStreamContext) |
|
}() |
|
|
|
return errCh |
|
} |
|
|
|
// Close will shut down the server. |
|
func (s *Server) Close() error { |
|
// connector |
|
if s.connector != nil { |
|
s.connector.Close() |
|
} |
|
// listener |
|
if s.listener != nil { |
|
_ = s.listener.Close() |
|
} |
|
// router |
|
if s.router != nil { |
|
s.router.Clean() |
|
} |
|
return nil |
|
} |
|
|
|
func (s *Server) handleRoute(c *Context) error { |
|
if c.DataStream.StreamType() == StreamTypeStreamFunction { |
|
md, err := s.metadataBuilder.Decode(c.DataStream.Metadata()) |
|
if err != nil { |
|
return err |
|
} |
|
// route |
|
route := s.router.Route(md) |
|
if route == nil { |
|
return errors.New("handleHandshakeFrame route is nil") |
|
} |
|
if err := route.Add(c.StreamId(), c.DataStream.Name(), c.DataStream.ObserveDataTags()); err != nil { |
|
// duplicate name |
|
if e, ok := err.(hpds_err.DuplicateNameError); ok { |
|
existsConnId := e.ConnId() |
|
|
|
log.Debugf("%s StreamFunction Duplicate Name, error: %s; sfn_name: %s, old_stream_id: %s; current_stream_id: %s", |
|
ServerLogPrefix, e.Error(), c.DataStream.Name(), existsConnId, c.StreamId()) |
|
|
|
stream, ok, err := s.connector.Get(existsConnId) |
|
if err != nil { |
|
return err |
|
} |
|
if ok { |
|
_ = stream.CloseWithError(e.Error()) |
|
_ = s.connector.Remove(existsConnId) |
|
} |
|
} else { |
|
return err |
|
} |
|
} |
|
} |
|
return nil |
|
} |
|
|
|
// handleStreamContext handles data streams, |
|
func (s *Server) handleStreamContext(c *Context) { |
|
// handle route. |
|
if err := s.handleRoute(c); err != nil { |
|
c.CloseWithError(hpds_err.ErrorCodeRejected, err.Error()) |
|
return |
|
} |
|
defer s.cleanRoute(c) |
|
|
|
// start frame handlers |
|
for _, handler := range s.startHandlers { |
|
if err := handler(c); err != nil { |
|
log.Errorf("%s startHandlers error: %v", ServerLogPrefix, err) |
|
c.CloseWithError(hpds_err.ErrorCodeStartHandler, err.Error()) |
|
return |
|
} |
|
} |
|
// check update for stream |
|
for { |
|
f, err := c.DataStream.ReadFrame() |
|
if err != nil { |
|
// if client close connection, will get ApplicationError with code = 0x00 |
|
if e, ok := err.(*quic.ApplicationError); ok { |
|
if hpds_err.Is(e.ErrorCode, hpds_err.ErrorCodeClientAbort) { |
|
// client abort |
|
log.Infof("%s client close the connection", ServerLogPrefix) |
|
break |
|
} |
|
he := hpds_err.New(hpds_err.Parse(e.ErrorCode), err) |
|
log.Errorf("%s read frame error: %v", ServerLogPrefix, he) |
|
} else if err == io.EOF { |
|
log.Infof("%s connection EOF", ServerLogPrefix) |
|
break |
|
} |
|
if errors.Is(err, net.ErrClosed) { |
|
// if client close the connection, net.ErrClosed will be raise |
|
// by quic-go IdleTimeoutError after connection's KeepAlive config. |
|
log.Warnf("%s connection error, error: %v", ServerLogPrefix, net.ErrClosed) |
|
c.CloseWithError(hpds_err.ErrorCodeClosed, "net.ErrClosed") |
|
break |
|
} |
|
// any error occurred, we should close the stream |
|
// after this, conn.AcceptStream() will raise the error |
|
c.CloseWithError(hpds_err.ErrorCodeUnknown, err.Error()) |
|
log.Warnf("%s connection close", ServerLogPrefix) |
|
break |
|
} |
|
|
|
// add frame to context |
|
c.WithFrame(f) |
|
|
|
// before frame handlers |
|
for _, handler := range s.beforeHandlers { |
|
if err := handler(c); err != nil { |
|
log.Errorf("%s beforeFrameHandler error: %v", ServerLogPrefix, err) |
|
c.CloseWithError(hpds_err.ErrorCodeBeforeHandler, err.Error()) |
|
return |
|
} |
|
} |
|
// main handler |
|
if err := s.mainFrameHandler(c); err != nil { |
|
log.Errorf("%s mainFrameHandler error: %v", ServerLogPrefix, err) |
|
c.CloseWithError(hpds_err.ErrorCodeMainHandler, err.Error()) |
|
return |
|
} |
|
// after frame handler |
|
for _, handler := range s.afterHandlers { |
|
if err := handler(c); err != nil { |
|
log.Errorf("%s afterFrameHandler error: %v", ServerLogPrefix, err) |
|
c.CloseWithError(hpds_err.ErrorCodeAfterHandler, err.Error()) |
|
return |
|
} |
|
} |
|
} |
|
} |
|
|
|
func (s *Server) mainFrameHandler(c *Context) error { |
|
var err error |
|
frameType := c.Frame.Type() |
|
|
|
switch frameType { |
|
case frame.TagOfDataFrame: |
|
if err = s.handleDataFrame(c); err != nil { |
|
c.CloseWithError(hpds_err.ErrorCodeData, fmt.Sprintf("handleDataFrame err: %v", err)) |
|
} else { |
|
s.dispatchToDownStreams(c) |
|
// observe data tags back flow |
|
_ = s.handleBackFlowFrame(c) |
|
} |
|
default: |
|
log.Errorf("%s err=%v, frame=%v", ServerLogPrefix, err, frame.Shortly(c.Frame.Encode())) |
|
} |
|
return nil |
|
} |
|
|
|
func (s *Server) handleAuthenticationFrame(f auth.Object) (bool, error) { |
|
ok := auth.Authenticate(s.opts.auths, f) |
|
|
|
if ok { |
|
log.Debugf("%s Successful authentication", ServerLogPrefix) |
|
} else { |
|
log.Warnf("%s Authentication failed, credential: %s", ServerLogPrefix, f.AuthName()) |
|
} |
|
|
|
return ok, nil |
|
} |
|
func (s *Server) handleDataFrame(c *Context) error { |
|
// counter +1 |
|
atomic.AddInt64(&s.counterOfDataFrame, 1) |
|
// currentIssuer := f.GetIssuer() |
|
fromId := c.StreamId() |
|
from, ok, err := s.connector.Get(fromId) |
|
if err != nil { |
|
return err |
|
} |
|
if !ok { |
|
log.Warnf("%s handleDataFrame connector cannot find, from_conn_id: %s", ServerLogPrefix, fromId) |
|
return fmt.Errorf("handleDataFrame connector cannot find %s", fromId) |
|
} |
|
|
|
f := c.Frame.(*frame.DataFrame) |
|
|
|
m, err := s.metadataBuilder.Decode(f.GetMetaFrame().Metadata()) |
|
if err != nil { |
|
return err |
|
} |
|
|
|
if m == nil { |
|
m, err = s.metadataBuilder.Decode(from.Metadata()) |
|
if err != nil { |
|
return err |
|
} |
|
} |
|
|
|
// route |
|
route := s.router.Route(m) |
|
if route == nil { |
|
log.Warnf("%s handleDataFrame route is nil", ServerLogPrefix) |
|
return fmt.Errorf("handleDataFrame route is nil") |
|
} |
|
|
|
// get stream function connection ids from route |
|
connIDs := route.GetForwardRoutes(f.GetDataTag()) |
|
|
|
log.Debugf("%s Data Routing Status, sfn_stream_ids: %v; connector: %v", ServerLogPrefix, connIDs, s.connector.GetSnapshot()) |
|
|
|
for _, toId := range connIDs { |
|
conn, ok, err := s.connector.Get(toId) |
|
if err != nil { |
|
continue |
|
} |
|
if !ok { |
|
log.Errorf("%s Can't find forward conn, error: conn is nil ;forward_conn_id: ", ServerLogPrefix, toId) |
|
continue |
|
} |
|
|
|
to := conn.Name() |
|
log.Infof("%s handleDataFrame, from_conn_name: %s; from_conn_id: %s; to_conn_name: %s; to_conn_id: %s; data_frame: %s", |
|
ServerLogPrefix, from.Name(), fromId, to, toId, f.String()) |
|
|
|
// write data frame to stream |
|
if err := conn.WriteFrame(f); err != nil { |
|
log.Errorf("%s handleDataFrame conn.Write, %v", ServerLogPrefix, err) |
|
} |
|
} |
|
|
|
return nil |
|
} |
|
func (s *Server) handleBackFlowFrame(c *Context) error { |
|
f := c.Frame.(*frame.DataFrame) |
|
tag := f.GetDataTag() |
|
carriage := f.GetCarriage() |
|
sourceId := f.SourceId() |
|
// write to Protocol Gateway with BackFlowFrame |
|
bf := frame.NewBackFlowFrame(tag, carriage) |
|
sourceConnList, err := s.connector.GetSourceConns(sourceId, tag) |
|
if err != nil { |
|
return err |
|
} |
|
// conn := s.connector.Get(c.connId) |
|
// logger.Printf("%s handleBackFlowFrame tag:%#v --> source:%s, result=%s", ServerLogPrefix, tag, sourceId, carriage) |
|
for _, source := range sourceConnList { |
|
if source != nil { |
|
log.Debugf("%s handleBackFlowFrame tag:%#v; source_conn_id: %s, back_flow_frame: %s", ServerLogPrefix, tag, sourceId, f.String()) |
|
if err := source.WriteFrame(bf); err != nil { |
|
log.Errorf("%s handleBackFlowFrame tag:%#v --> Protocol Gateway:%s, error=%v", ServerLogPrefix, tag, sourceId, err) |
|
return err |
|
} |
|
} |
|
} |
|
return nil |
|
} |
|
|
|
// StatsFunctions returns the sfn stats of server. |
|
func (s *Server) StatsFunctions() map[string]string { |
|
return s.connector.GetSnapshot() |
|
} |
|
|
|
// StatsCounter returns how many DataFrames pass through server. |
|
func (s *Server) StatsCounter() int64 { |
|
return s.counterOfDataFrame |
|
} |
|
|
|
// DownStreams return all the downstream servers. |
|
func (s *Server) DownStreams() map[string]frame.Writer { |
|
return s.downStreams |
|
} |
|
|
|
// ConfigRouter is used to set router by Message Queue |
|
func (s *Server) ConfigRouter(router router.Router) { |
|
s.mu.Lock() |
|
s.router = router |
|
log.Debugf("%s config router is %#v", ServerLogPrefix, router) |
|
s.mu.Unlock() |
|
} |
|
|
|
// ConfigMetadataBuilder is used to set metadataBuilder by Message Queue |
|
|
|
func (s *Server) ConfigMetadataBuilder(builder metadata.Builder) { |
|
s.mu.Lock() |
|
s.metadataBuilder = builder |
|
log.Debugf("%s config metadataBuilder is %#v", ServerLogPrefix, builder) |
|
s.mu.Unlock() |
|
} |
|
|
|
// ConfigAlpnHandler is used to set alpnHandler by Emitter |
|
func (s *Server) ConfigAlpnHandler(h func(string) error) { |
|
s.mu.Lock() |
|
s.opts.alpnHandler = h |
|
log.Debugf("%s config alpnHandler", ServerLogPrefix) |
|
s.mu.Unlock() |
|
} |
|
|
|
// AddDownstreamServer add a downstream server to this server. all the DataFrames will be |
|
// dispatch to all the downStreams. |
|
func (s *Server) AddDownstreamServer(addr string, c *Client) { |
|
s.mu.Lock() |
|
s.downStreams[addr] = c |
|
s.mu.Unlock() |
|
} |
|
|
|
// dispatch every DataFrames to all downStreams |
|
func (s *Server) dispatchToDownStreams(c *Context) { |
|
stream, ok, err := s.connector.Get(c.StreamId()) |
|
if err != nil { |
|
log.Errorf("%s Connector Get Error, %v", ServerLogPrefix, err) |
|
return |
|
} |
|
if !ok { |
|
log.Debugf("%s dispatchTo Down Streams failed", ServerLogPrefix) |
|
return |
|
} |
|
|
|
if stream.StreamType() == StreamTypeSource { |
|
f := c.Frame.(*frame.DataFrame) |
|
if f.IsBroadcast() { |
|
if f.GetMetaFrame().Metadata() == nil { |
|
f.GetMetaFrame().SetMetadata(stream.Metadata()) |
|
} |
|
for addr, ds := range s.downStreams { |
|
log.Infof("%s dispatching to, dispatch_addr: %s; tid: %s;", ServerLogPrefix, addr, "", f.TransactionId()) |
|
_ = ds.WriteFrame(f) |
|
} |
|
} |
|
} |
|
} |
|
|
|
// GetConnId get quic connection id |
|
func GetConnId(conn quic.Connection) string { |
|
return conn.RemoteAddr().String() |
|
} |
|
|
|
func (s *Server) validateRouter() error { |
|
if s.router == nil { |
|
return errors.New("server's router is nil") |
|
} |
|
return nil |
|
} |
|
|
|
func (s *Server) validateMetadataBuilder() error { |
|
if s.metadataBuilder == nil { |
|
return errors.New("server's metadataBuilder is nil") |
|
} |
|
return nil |
|
} |
|
|
|
// SetStartHandlers sets a function for operating connection, |
|
// this function executes after handshake successful. |
|
func (s *Server) SetStartHandlers(handlers ...FrameHandler) { |
|
s.startHandlers = append(s.startHandlers, handlers...) |
|
} |
|
|
|
// SetBeforeHandlers set the before handlers of server. |
|
func (s *Server) SetBeforeHandlers(handlers ...FrameHandler) { |
|
s.beforeHandlers = append(s.beforeHandlers, handlers...) |
|
} |
|
|
|
// SetAfterHandlers set the after handlers of server. |
|
func (s *Server) SetAfterHandlers(handlers ...FrameHandler) { |
|
s.afterHandlers = append(s.afterHandlers, handlers...) |
|
} |
|
|
|
// SetConnectionCloseHandlers set the connection close handlers of server. |
|
func (s *Server) SetConnectionCloseHandlers(handlers ...ConnectionHandler) { |
|
s.connectionCloseHandlers = append(s.connectionCloseHandlers, handlers...) |
|
} |
|
|
|
func (s *Server) authNames() []string { |
|
if len(s.opts.auths) == 0 { |
|
return []string{"none"} |
|
} |
|
result := make([]string, 0) |
|
for _, a := range s.opts.auths { |
|
result = append(result, a.Name()) |
|
} |
|
return result |
|
} |
|
|
|
func (s *Server) doConnectionCloseHandlers(qConn quic.Connection) { |
|
log.Debugf("%s QUIC Connection Closed", ServerLogPrefix) |
|
for _, h := range s.connectionCloseHandlers { |
|
h(qConn) |
|
} |
|
} |
|
|
|
func (s *Server) cleanRoute(c *Context) { |
|
md, _ := s.metadataBuilder.Decode(c.DataStream.Metadata()) |
|
_ = s.router.Route(md).Remove(c.StreamId()) |
|
}
|
|
|