diff --git a/auth/auth.go b/auth/auth.go index d43b728..0fdc0f4 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -57,3 +57,33 @@ func (c *Credential) Payload() string { func (c *Credential) Name() string { return c.name } + +// Object is the object to be authenticated, +// The Object usually be pass to `Authenticate` function to be authed. +type Object interface { + // AuthName returns the auth name, the name will be used to find the auth way. + AuthName() string + + // AuthPayload returns the auth payload be passed to `auth.Authenticate`. + AuthPayload() string +} + +// Authenticate finds an authentication way in `auths` and authenticates the Object. +// +// If `auths` is nil or empty, It returns true, It think that authentication is not required. +func Authenticate(auths map[string]Authentication, obj Object) bool { + if auths == nil || len(auths) <= 0 { + return true + } + + if obj == nil { + return false + } + + auth, ok := auths[obj.AuthName()] + if !ok { + return false + } + + return auth.Authenticate(obj.AuthPayload()) +} diff --git a/client.go b/client.go index d4af3f3..712dc40 100644 --- a/client.go +++ b/client.go @@ -3,459 +3,312 @@ package network import ( "context" "errors" - "fmt" - "git.hpds.cc/Component/network/hpds_err" "git.hpds.cc/Component/network/id" - pkgtls "git.hpds.cc/Component/network/tls" - "net" - "sync" "time" - "github.com/lucas-clemente/quic-go" - - "git.hpds.cc/Component/network/auth" "git.hpds.cc/Component/network/frame" "git.hpds.cc/Component/network/log" ) // ClientOption client options -type ClientOption func(*ClientOptions) - -// ConnState describes the state of the connection. -type ConnState = string +type ClientOption func(*clientOptions) // Client is the abstraction of a HPDS-Client. a HPDS-Client can be // Protocol Gateway, Message Queue or StreamFunction. type Client struct { name string // name of the client clientId string // id of the client - clientType ClientType // type of the connection - conn quic.Connection // quic connection - stream quic.Stream // quic stream - state ConnState // state of the connection - processor func(*frame.DataFrame) // functions to invoke when data arrived - receiver func(*frame.BackFlowFrame) // functions to invoke when data is processed - addr string // the address of server connected to - mu sync.Mutex - opts ClientOptions - localAddr string // client local addr, it will be changed on reconnect - logger log.Logger - errChan chan error - closeChan chan bool - closed bool + streamType StreamType // type of the dataStream + processor func(*frame.DataFrame) // function to invoke when data arrived + receiver func(*frame.BackFlowFrame) // function to invoke when data is processed + errorfn func(error) // function to invoke when error occured + opts *clientOptions + + // ctx and ctxCancel manage the lifecycle of client. + ctx context.Context + ctxCancel context.CancelFunc + + writeFrameChan chan frame.Frame + shutdownChan chan error } // NewClient creates a new HPDS-Client. func NewClient(appName string, connType ClientType, opts ...ClientOption) *Client { - c := &Client{ - name: appName, - clientId: id.New(), - clientType: connType, - state: ConnStateReady, - opts: ClientOptions{}, - errChan: make(chan error), - closeChan: make(chan bool), - } - _ = c.Init(opts...) - once.Do(func() { - c.init() - }) + option := defaultClientOption() - return c -} - -// Init the options. -func (c *Client) Init(opts ...ClientOption) error { for _, o := range opts { - o(&c.opts) + o(option) + } + clientId := id.New() + + if option.credential != nil { + log.Infof("use credential, credential_name: %s;", option.credential.Name()) + } + + ctx, ctxCancel := context.WithCancel(context.Background()) + + return &Client{ + name: appName, + clientId: clientId, + streamType: connType, + opts: option, + errorfn: func(err error) { log.Errorf("client err, %s", err) }, + writeFrameChan: make(chan frame.Frame), + shutdownChan: make(chan error, 1), + ctx: ctx, + ctxCancel: ctxCancel, } - return c.initOptions() } // Connect connects to HPDS-MessageQueue. func (c *Client) Connect(ctx context.Context, addr string) error { - - // TODO: refactor this later as a Connection Manager - // reconnect - // for download mq - // If you do not check for errors, the connection will be automatically reconnected - go c.reconnect(ctx, addr) - - // connect - if err := c.connect(ctx, addr); err != nil { + controlStream, dataStream, err := c.openStream(ctx, addr) + if err != nil { + log.Errorf("connect error, %s", err) return err } + go c.runBackground(ctx, addr, controlStream, dataStream) + return nil } -func (c *Client) connect(ctx context.Context, addr string) error { - c.addr = addr - c.state = ConnStateConnecting +func (c *Client) runBackground(ctx context.Context, addr string, controlStream ClientControlStream, dataStream DataStream) { + reconnection := make(chan struct{}) - // create quic connection - conn, err := quic.DialAddrContext(ctx, addr, c.opts.TLSConfig, c.opts.QuicConfig) - if err != nil { - c.state = ConnStateDisconnected - return err - } + go c.processStream(controlStream, dataStream, reconnection) - // quic stream - stream, err := conn.OpenStreamSync(ctx) - if err != nil { - c.state = ConnStateDisconnected - return err - } - - c.stream = stream - c.conn = conn - - c.state = ConnStateAuthenticating - // send handshake - handshake := frame.NewHandshakeFrame( - c.name, - c.clientId, - byte(c.clientType), - c.opts.ObserveDataTags, - c.opts.Credential.Name(), - c.opts.Credential.Payload(), - ) - err = c.WriteFrame(handshake) - if err != nil { - c.state = ConnStateRejected - return err - } - c.state = ConnStateConnected - c.localAddr = c.conn.LocalAddr().String() - - c.logger.Printf("%s [%s][%s](%s) is connected to HPDS-MQ %s", ClientLogPrefix, c.name, c.clientId, c.localAddr, addr) - - // receiving frames - go c.handleFrame() - - return nil -} - -// handleFrame handles the logic when receiving frame from server. -func (c *Client) handleFrame() { - // transform raw QUIC stream to wire format - fs := NewFrameStream(c.stream) for { - c.logger.Debugf("%shandleFrame connection state=%v", ClientLogPrefix, c.state) - // this will block until a frame is received - f, err := fs.ReadFrame() - if err != nil { - defer func() { - _ = c.stream.Close() - }() - - c.logger.Debugf("%shandleFrame(): %T | %v", ClientLogPrefix, err, err) - if e, ok := err.(*quic.IdleTimeoutError); ok { - c.logger.Errorf("%sconnection timeout, err=%v, mq addr=%s", ClientLogPrefix, e, c.addr) - c.setState(ConnStateDisconnected) - } else if e, ok := err.(*quic.ApplicationError); ok { - c.logger.Infof("%sapplication error, err=%v, errcode=%v", ClientLogPrefix, e, e.ErrorCode) - if hpds_err.Is(e.ErrorCode, hpds_err.ErrorCodeRejected) { - // if connection is rejected(eg: authenticate fails) from server - c.logger.Errorf("%sIllegal client, server rejected.", ClientLogPrefix) - c.setState(ConnStateRejected) - break - } else if hpds_err.Is(e.ErrorCode, hpds_err.ErrorCodeClientAbort) { - // client abort - c.logger.Infof("%sclient close the connection", ClientLogPrefix) - c.setState(ConnStateAborted) - break - } else if hpds_err.Is(e.ErrorCode, hpds_err.ErrorCodeGoaway) { - // server goaway - c.logger.Infof("%sserver goaway the connection", ClientLogPrefix) - c.setState(ConnStateGoaway) - break - } else if hpds_err.Is(e.ErrorCode, hpds_err.ErrorCodeHandshake) { - // handshake - c.logger.Errorf("%shandshake fails", ClientLogPrefix) - c.setState(ConnStateRejected) - break - } - } else if errors.Is(err, net.ErrClosed) { - // if client close the connection, net.ErrClosed will be raised - c.logger.Errorf("%sconnection is closed, err=%v", ClientLogPrefix, err) - c.setState(ConnStateDisconnected) - // by quic-go IdleTimeoutError after connection's KeepAlive config. - break - } else { - // any error occurred, we should close the stream - // after this, conn.AcceptStream() will raise the error - c.setState(ConnStateClosed) - _ = c.conn.CloseWithError(hpds_err.To(hpds_err.ErrorCodeUnknown), err.Error()) - c.logger.Errorf("%sunknown error occurred, err=%v, state=%v", ClientLogPrefix, err, c.getState()) - break + select { + case <-c.ctx.Done(): + c.cleanStream(controlStream, nil) + return + case <-ctx.Done(): + c.cleanStream(controlStream, ctx.Err()) + return + case <-reconnection: + RECONNECT: + var err error + controlStream, dataStream, err = c.openStream(ctx, addr) + if err != nil { + log.Errorf("client reconnect error, %s", err) + time.Sleep(time.Second) + goto RECONNECT } - } - if f == nil { - break - } - // read frame - // first, get frame type - frameType := f.Type() - c.logger.Debugf("%stype=%s, frame=%# x", ClientLogPrefix, frameType, frame.Shortly(f.Encode())) - switch frameType { - case frame.TagOfHandshakeFrame: - if v, ok := f.(*frame.HandshakeFrame); ok { - c.logger.Debugf("%sreceive HandshakeFrame, name=%v", ClientLogPrefix, v.Name) - } - case frame.TagOfPongFrame: - c.setState(ConnStatePong) - case frame.TagOfAcceptedFrame: - c.setState(ConnStateAccepted) - case frame.TagOfRejectedFrame: - c.setState(ConnStateRejected) - if v, ok := f.(*frame.RejectedFrame); ok { - c.logger.Errorf("%s receive RejectedFrame, message=%s", ClientLogPrefix, v.Message()) - _ = c.conn.CloseWithError(hpds_err.To(hpds_err.ErrorCodeRejected), v.Message()) - c.errChan <- errors.New(v.Message()) - break - } - case frame.TagOfGoawayFrame: - c.setState(ConnStateGoaway) - if v, ok := f.(*frame.GoawayFrame); ok { - c.logger.Errorf("%s️ receive GoawayFrame, message=%s", ClientLogPrefix, v.Message()) - _ = c.conn.CloseWithError(hpds_err.To(hpds_err.ErrorCodeGoaway), v.Message()) - c.errChan <- errors.New(v.Message()) - break - } - case frame.TagOfDataFrame: // DataFrame carries user's data - if v, ok := f.(*frame.DataFrame); ok { - c.setState(ConnStateTransportData) - c.logger.Debugf("%sreceive DataFrame, tag=%#x, tid=%s, carry=%# x", ClientLogPrefix, v.GetDataTag(), v.TransactionId(), v.GetCarriage()) - if c.processor == nil { - c.logger.Warnf("%sprocessor is nil", ClientLogPrefix) - } else { - c.processor(v) - } - } - case frame.TagOfBackFlowFrame: - if v, ok := f.(*frame.BackFlowFrame); ok { - c.logger.Debugf("%sreceive BackFlowFrame, tag=%#x, carry=%# x", ClientLogPrefix, v.GetDataTag(), v.GetCarriage()) - if c.receiver == nil { - c.logger.Warnf("%sreceiver is nil", ClientLogPrefix) - } else { - c.setState(ConnStateBackFlow) - c.receiver(v) - } - } - default: - c.logger.Errorf("%sunknown signal", ClientLogPrefix) + go c.processStream(controlStream, dataStream, reconnection) } } } +// WriteFrame write frame to client. +func (c *Client) WriteFrame(f frame.Frame) error { + c.writeFrameChan <- f + return nil +} + +func (c *Client) cleanStream(controlStream ClientControlStream, err error) { + errString := "" + if err != nil { + errString = err.Error() + log.Errorf("client cancel with error, %s", err) + } + + // controlStream is nil represents that client is not connected. + if controlStream == nil { + return + } + + _ = controlStream.CloseWithError(0, errString) +} + // Close the client. -func (c *Client) Close() (err error) { - if c.conn != nil { - c.logger.Printf("%sclose the connection, name:%s, id:%s, addr:%s", ClientLogPrefix, c.name, c.clientId, c.conn.RemoteAddr().String()) - } - if c.stream != nil { - err = c.stream.Close() - if err != nil { - c.logger.Errorf("%s stream.Close(): %v", ClientLogPrefix, err) - } - } - if c.conn != nil { - err = c.conn.CloseWithError(0, "client-ask-to-close-this-connection") - if err != nil { - c.logger.Errorf("%s connection.Close(): %v", ClientLogPrefix, err) - } - } - // close channel - c.mu.Lock() - if !c.closed { - close(c.errChan) - close(c.closeChan) - c.closed = true - } - c.mu.Unlock() +func (c *Client) Close() error { + // break runBackgroud() for-loop. + c.ctxCancel() - return err + // non-blocking to return Wait(). + select { + case c.shutdownChan <- nil: + default: + } + + return nil } -// WriteFrame writes a frame to the connection, gurantee threadsafe. -func (c *Client) WriteFrame(frm frame.Frame) error { - // write on QUIC stream - if c.stream == nil { - return errors.New("stream is nil") - } - if c.state == ConnStateDisconnected || c.state == ConnStateRejected { - return fmt.Errorf("client connection state is %s", c.state) - } - c.logger.Debugf("%s[%s](%s)@%s WriteFrame() will write frame: %s", ClientLogPrefix, c.name, c.localAddr, c.state, frm.Type()) - - data := frm.Encode() - // emit raw bytes of Frame - c.mu.Lock() - n, err := c.stream.Write(data) - c.mu.Unlock() - c.logger.Debugf("%sWriteFrame() wrote n=%d, data=%# x", ClientLogPrefix, n, frame.Shortly(data)) +func (c *Client) openControlStream(ctx context.Context, addr string) (ClientControlStream, error) { + controlStream, err := OpenClientControlStream(ctx, addr, c.opts.tlsConfig, c.opts.quicConfig) if err != nil { - c.setState(ConnStateDisconnected) - // c.state = ConnStateDisconnected - if e, ok := err.(*quic.IdleTimeoutError); ok { - c.logger.Errorf("%sWriteFrame() connection timeout, err=%v", ClientLogPrefix, e) - } else { - c.logger.Errorf("%sWriteFrame() wrote error=%v", ClientLogPrefix, err) - return err + return nil, err + } + + if err := controlStream.Authenticate(c.opts.credential); err != nil { + return nil, err + } + + return controlStream, nil +} + +func (c *Client) openStream(ctx context.Context, addr string) (ClientControlStream, DataStream, error) { + controlStream, err := c.openControlStream(ctx, addr) + if err != nil { + return nil, nil, err + } + dataStream, err := c.openDataStream(ctx, controlStream) + if err != nil { + return nil, nil, err + } + + return controlStream, dataStream, nil +} + +func (c *Client) openDataStream(ctx context.Context, controlStream ClientControlStream) (DataStream, error) { + handshakeFrame := frame.NewHandshakeFrame( + c.name, + c.clientId, + byte(c.streamType), + c.opts.observeDataTags, + []byte{}, // The stream does not require metadata currently. + ) + dataStream, err := controlStream.OpenStream(ctx, handshakeFrame) + if err != nil { + return nil, err + } + + return dataStream, nil +} + +func (c *Client) processStream(controlStream ClientControlStream, dataStream DataStream, reconnection chan<- struct{}) { + defer func() { + _ = dataStream.Close() + }() + + var ( + controlStreamErrChan = c.receivingStreamClose(controlStream, dataStream) + readFrameChan = c.readFrame(dataStream) + ) + for { + select { + case err := <-controlStreamErrChan: + c.shutdownWithError(err) + case result := <-readFrameChan: + if err := result.err; err != nil { + c.errorfn(err) + reconnection <- struct{}{} + return + } + c.handleFrame(result.frame) + case f := <-c.writeFrameChan: + err := dataStream.WriteFrame(f) + // restore DataFrame. + if d, ok := f.(*frame.DataFrame); ok { + d.Clean() + } + if err != nil { + c.errorfn(err) + reconnection <- struct{}{} + return + } } } - if n != len(data) { - err := errors.New("[client] hpds Client .Write() wrote error") - c.logger.Errorf("%s error:%v", ClientLogPrefix, err) - return err - } +} + +// Wait waits client error returning. +func (c *Client) Wait() error { + err := <-c.shutdownChan return err } -// update connection state -func (c *Client) setState(state ConnState) { - c.logger.Debugf("setState to:%s", state) - c.mu.Lock() - c.state = state - c.mu.Unlock() +func (c *Client) shutdownWithError(err error) { + // non-blocking shutdown client. + select { + case c.shutdownChan <- err: + default: + } } -// getState get connection state -func (c *Client) getState() ConnState { - c.mu.Lock() - defer c.mu.Unlock() - return c.state +type readResult struct { + frame frame.Frame + err error } -// update connection local addr -func (c *Client) setLocalAddr(addr string) { - c.mu.Lock() - c.localAddr = addr - c.mu.Unlock() +func (c *Client) readFrame(dataStream DataStream) chan readResult { + readChan := make(chan readResult) + go func() { + for { + f, err := dataStream.ReadFrame() + readChan <- readResult{f, err} + if err != nil { + return + } + } + }() + + return readChan +} + +func (c *Client) handleFrame(f frame.Frame) { + switch ff := f.(type) { + case *frame.DataFrame: + if c.processor == nil { + log.Warnf("client processor has not been set") + } else { + c.processor(ff) + } + case *frame.BackFlowFrame: + if c.receiver == nil { + log.Warnf("client receiver has not been set") + } else { + c.receiver(ff) + } + default: + log.Warnf("client data stream receive unexcepted frame, frame_type: %v", f) + } +} + +func (c *Client) receivingStreamClose(controlStream ControlStream, dataStream DataStream) chan error { + closeStreamChan := make(chan error) + + go func() { + for { + streamID, reason, err := controlStream.ReceiveStreamClose() + if err != nil { + closeStreamChan <- err + return + } + if streamID == c.clientId { + c.ctxCancel() + _ = dataStream.Close() + closeStreamChan <- errors.New(reason) + _ = controlStream.CloseWithError(0, reason) + return + } + } + }() + + return closeStreamChan } // SetDataFrameObserver sets the data frame handler. func (c *Client) SetDataFrameObserver(fn func(*frame.DataFrame)) { c.processor = fn - c.logger.Debugf("%sSetDataFrameObserver(%v)", ClientLogPrefix, c.processor) + log.Debugf("SetDataFrameObserver") } // SetBackFlowFrameObserver sets the backflow frame handler. func (c *Client) SetBackFlowFrameObserver(fn func(*frame.BackFlowFrame)) { c.receiver = fn - c.logger.Debugf("%sSetBackFlowFrameObserver(%v)", ClientLogPrefix, c.receiver) -} - -// reconnect the connection between client and server. -func (c *Client) reconnect(ctx context.Context, addr string) { - t := time.NewTicker(1 * time.Second) - defer t.Stop() - - for { - select { - case <-ctx.Done(): - c.logger.Debugf("%s[%s](%s) context.Done()", ClientLogPrefix, c.name, c.localAddr) - return - case <-c.closeChan: - c.logger.Debugf("%s[%s](%s) close channel", ClientLogPrefix, c.name, c.localAddr) - return - case <-t.C: - if c.getState() == ConnStateDisconnected { - c.logger.Printf("%s[%s][%s](%s) is reconnecting to HPDS-MQ %s...", ClientLogPrefix, c.name, c.clientId, c.localAddr, addr) - err := c.connect(ctx, addr) - if err != nil { - c.logger.Errorf("%s[%s][%s](%s) reconnect error:%v", ClientLogPrefix, c.name, c.clientId, c.localAddr, err) - } - } - } - } -} - -func (c *Client) init() { - // // tracing - // _, _, err := tracing.NewTracerProvider(c.name) - // if err != nil { - // logger.Errorf("tracing: %v", err) - // } -} - -// ServerAddr returns the address of the server. -func (c *Client) ServerAddr() string { - return c.addr -} - -// initOptions init options defaults -func (c *Client) initOptions() error { - // logger - if c.logger == nil { - if c.opts.Logger != nil { - c.logger = c.opts.Logger - } else { - c.logger = log.Default() - } - } - // observe tag list - if c.opts.ObserveDataTags == nil { - c.opts.ObserveDataTags = make([]byte, 0) - } - // credential - if c.opts.Credential == nil { - c.opts.Credential = auth.NewCredential("") - } - // tls config - if c.opts.TLSConfig == nil { - tc, err := pkgtls.CreateClientTLSConfig() - if err != nil { - c.logger.Errorf("%sCreateClientTLSConfig: %v", ClientLogPrefix, err) - return err - } - c.opts.TLSConfig = tc - } - // quic config - if c.opts.QuicConfig == nil { - c.opts.QuicConfig = &quic.Config{ - Versions: []quic.VersionNumber{quic.Version1, quic.VersionDraft29}, - MaxIdleTimeout: time.Second * 40, - KeepAlivePeriod: time.Second * 20, - MaxIncomingStreams: 1000, - MaxIncomingUniStreams: 1000, - HandshakeIdleTimeout: time.Second * 3, - InitialStreamReceiveWindow: 1024 * 1024 * 2, - InitialConnectionReceiveWindow: 1024 * 1024 * 2, - TokenStore: quic.NewLRUTokenStore(10, 5), - DisablePathMTUDiscovery: true, - } - } - // credential - if c.opts.Credential != nil { - c.logger.Printf("%suse credential: [%s]", ClientLogPrefix, c.opts.Credential.Name()) - } - - return nil + log.Debugf("SetBackFlowFrameObserver") } // SetObserveDataTags set the data tag list that will be observed. -// Deprecated: use hpds.WithObserveDataTags instead -func (c *Client) SetObserveDataTags(tag ...byte) { - c.opts.ObserveDataTags = append(c.opts.ObserveDataTags, tag...) -} - -// Logger get client's logger instance, you can customize this using `hpds.WithLogger` -func (c *Client) Logger() log.Logger { - return c.logger +// Deprecated: use yomo.WithObserveDataTags instead +func (c *Client) SetObserveDataTags(tag ...frame.Tag) { + c.opts.observeDataTags = append(c.opts.observeDataTags, tag...) } // SetErrorHandler set error handler func (c *Client) SetErrorHandler(fn func(err error)) { - if fn != nil { - go func() { - err := <-c.errChan - if err != nil { - fn(err) - } - }() - } + c.errorfn = fn } // ClientId return the client ID diff --git a/client_options.go b/client_options.go index 00079f9..7dfb67a 100644 --- a/client_options.go +++ b/client_options.go @@ -2,52 +2,85 @@ package network import ( "crypto/tls" - "github.com/lucas-clemente/quic-go" + "git.hpds.cc/Component/network/frame" + "github.com/quic-go/quic-go" + "time" "git.hpds.cc/Component/network/auth" - "git.hpds.cc/Component/network/log" + pkgtls "git.hpds.cc/Component/network/tls" ) -// ClientOptions are the options for HPDS client. -type ClientOptions struct { - ObserveDataTags []byte - QuicConfig *quic.Config - TLSConfig *tls.Config - Credential *auth.Credential - Logger log.Logger +// clientOptions are the options for YoMo client. +type clientOptions struct { + observeDataTags []frame.Tag + quicConfig *quic.Config + tlsConfig *tls.Config + credential *auth.Credential +} + +func defaultClientOption() *clientOptions { + defaultQuicConfig := &quic.Config{ + Versions: []quic.VersionNumber{quic.VersionDraft29, quic.Version1, quic.Version2}, + MaxIdleTimeout: time.Second * 40, + KeepAlivePeriod: time.Second * 20, + MaxIncomingStreams: 1000, + MaxIncomingUniStreams: 1000, + HandshakeIdleTimeout: time.Second * 3, + InitialStreamReceiveWindow: 1024 * 1024 * 2, + InitialConnectionReceiveWindow: 1024 * 1024 * 2, + TokenStore: quic.NewLRUTokenStore(10, 5), + } + + opts := &clientOptions{ + observeDataTags: make([]frame.Tag, 0), + quicConfig: defaultQuicConfig, + tlsConfig: pkgtls.MustCreateClientTLSConfig(), + credential: auth.NewCredential(""), + } + + return opts } // WithObserveDataTags sets data tag list for the client. -func WithObserveDataTags(tags ...byte) ClientOption { - return func(o *ClientOptions) { - o.ObserveDataTags = tags +func WithObserveDataTags(tags ...frame.Tag) ClientOption { + return func(o *clientOptions) { + o.observeDataTags = tags } } // WithCredential sets the client credential method (used by client). func WithCredential(payload string) ClientOption { - return func(o *ClientOptions) { - o.Credential = auth.NewCredential(payload) + return func(o *clientOptions) { + o.credential = auth.NewCredential(payload) } } // WithClientTLSConfig sets tls config for the client. func WithClientTLSConfig(tc *tls.Config) ClientOption { - return func(o *ClientOptions) { - o.TLSConfig = tc + return func(o *clientOptions) { + if tc != nil { + o.tlsConfig = tc + } } } // WithClientQuicConfig sets quic config for the client. func WithClientQuicConfig(qc *quic.Config) ClientOption { - return func(o *ClientOptions) { - o.QuicConfig = qc + return func(o *clientOptions) { + o.quicConfig = qc } } -// WithLogger sets logger for the client. -func WithLogger(logger log.Logger) ClientOption { - return func(o *ClientOptions) { - o.Logger = logger - } -} +// ClientType is equal to StreamType. +type ClientType = StreamType + +const ( + // ClientTypeSource is equal to StreamTypeSource. + ClientTypeSource ClientType = StreamTypeSource + + // ClientTypeUpstreamEmitter is equal to StreamTypeUpstreamEmitter. + ClientTypeUpstreamEmitter ClientType = StreamTypeUpstreamEmitter + + // ClientTypeStreamFunction is equal to StreamTypeStreamFunction. + ClientTypeStreamFunction ClientType = StreamTypeStreamFunction +) diff --git a/client_type.go b/client_type.go deleted file mode 100644 index 6c2baef..0000000 --- a/client_type.go +++ /dev/null @@ -1,28 +0,0 @@ -package network - -const ( - // ClientTypeNone is connection type "None". - ClientTypeNone ClientType = 0xFF - // ClientTypeProtocolGateway is connection type "Protocol Gateway". - ClientTypeProtocolGateway ClientType = 0x5F - // ClientTypeMessageQueue is connection type "Message Queue". - ClientTypeMessageQueue ClientType = 0x5E - // ClientTypeStreamFunction is connection type "Stream Function". - ClientTypeStreamFunction ClientType = 0x5D -) - -// ClientType represents the connection type. -type ClientType byte - -func (c ClientType) String() string { - switch c { - case ClientTypeProtocolGateway: - return "Source" - case ClientTypeMessageQueue: - return "Message Queue" - case ClientTypeStreamFunction: - return "Stream Function" - default: - return "None" - } -} diff --git a/connection.go b/connection.go index 1c51705..0b8e104 100644 --- a/connection.go +++ b/connection.go @@ -3,6 +3,7 @@ package network import ( "git.hpds.cc/Component/network/frame" "git.hpds.cc/Component/network/log" + "git.hpds.cc/Component/network/metadata" "io" "sync" ) @@ -18,7 +19,7 @@ type Connection interface { // ClientType returns the type of the client (Protocol Gateway | Message Queue | Stream Function) ClientType() ClientType // Metadata returns the extra info of the application - Metadata() Metadata + Metadata() metadata.Metadata // Write should goroutine-safely send coder frames to peer side Write(f frame.Frame) error // ObserveDataTags observed data tags @@ -28,7 +29,7 @@ type Connection interface { type connection struct { name string clientType ClientType - metadata Metadata + metadata metadata.Metadata stream io.ReadWriteCloser clientId string observed []byte // observed data tags @@ -36,7 +37,7 @@ type connection struct { closed bool } -func newConnection(name string, clientId string, clientType ClientType, metadata Metadata, +func newConnection(name string, clientId string, clientType ClientType, metadata metadata.Metadata, stream io.ReadWriteCloser, observed []byte) Connection { return &connection{ name: name, @@ -68,7 +69,7 @@ func (c *connection) ClientType() ClientType { } // Metadata returns the extra info of the application -func (c *connection) Metadata() Metadata { +func (c *connection) Metadata() metadata.Metadata { return c.metadata } @@ -77,7 +78,7 @@ func (c *connection) Write(f frame.Frame) error { c.mu.Lock() defer c.mu.Unlock() if c.closed { - log.Warnf("%sclient stream is closed: %s", ServerLogPrefix, c.clientId) + log.Warnf("client stream is closed: %s", c.clientId) return nil } _, err := c.stream.Write(f.Encode()) diff --git a/connector.go b/connector.go index 756035d..0834c08 100644 --- a/connector.go +++ b/connector.go @@ -1,87 +1,135 @@ package network import ( + "context" + "errors" + "git.hpds.cc/Component/network/frame" "git.hpds.cc/Component/network/log" "sync" ) -var _ Connector = &connector{} +// ErrConnectorClosed will be returned if the connector has been closed. +var ErrConnectorClosed = errors.New("hpdsMq: connector closed") -// Connector is an interface to manage the connections and applications. -type Connector interface { - // Add a connection. - Add(connId string, conn Connection) - // Remove a connection. - Remove(connId string) - // Get a connection by connection id. - Get(connId string) Connection - // GetSnapshot gets the snapshot of all connections. - GetSnapshot() map[string]string - // GetProtocolGatewayConnections gets the connections by Protocol Gateway observe tags. - GetProtocolGatewayConnections(sourceId string, tags byte) []Connection - // Clean the connector. - Clean() +// The Connector class manages data streams and provides a centralized way to get and set streams. +type Connector struct { + // ctx and ctxCancel manage the lifescyle of Connector. + ctx context.Context + ctxCancel context.CancelFunc + + streams sync.Map } -type connector struct { - conns sync.Map -} +// NewConnector returns an initial Connector. +func NewConnector(ctx context.Context) *Connector { + ctx, ctxCancel := context.WithCancel(ctx) -func newConnector() Connector { - return &connector{conns: sync.Map{}} -} - -// Add a connection. -func (c *connector) Add(connID string, conn Connection) { - log.Debugf("%sconnector add: connId=%s", ServerLogPrefix, connID) - c.conns.Store(connID, conn) -} - -// Remove a connection. -func (c *connector) Remove(connID string) { - log.Debugf("%sconnector remove: connId=%s", ServerLogPrefix, connID) - c.conns.Delete(connID) -} - -// Get a connection by connection id. -func (c *connector) Get(connID string) Connection { - log.Debugf("%sconnector get connection: connId=%s", ServerLogPrefix, connID) - if conn, ok := c.conns.Load(connID); ok { - return conn.(Connection) + return &Connector{ + ctx: ctx, + ctxCancel: ctxCancel, } +} + +// Add adds DataStream to Connector, +// If the streamID is the same twice, the new stream will replace the old stream. +func (c *Connector) Add(streamId string, stream DataStream) error { + select { + case <-c.ctx.Done(): + return ErrConnectorClosed + default: + } + + c.streams.Store(streamId, stream) + + log.Debugf("Connector add stream, stream_id: %s", streamId) return nil } -// GetProtocolGatewayConnections gets the Protocol Gateway connection by tag. -func (c *connector) GetProtocolGatewayConnections(sourceId string, tag byte) []Connection { - conns := make([]Connection, 0) +// Remove removes the DataStream with the specified streamID. +// If the Connector does not have a stream with the given streamID, no action is taken. +func (c *Connector) Remove(streamId string) error { + select { + case <-c.ctx.Done(): + return ErrConnectorClosed + default: + } - c.conns.Range(func(key interface{}, val interface{}) bool { - conn := val.(Connection) - for _, v := range conn.ObserveDataTags() { - if v == tag && conn.ClientType() == ClientTypeProtocolGateway && conn.ClientId() == sourceId { - conns = append(conns, conn) + c.streams.Delete(streamId) + log.Debugf("Connector remove stream, stream_id: %s", streamId) + + return nil +} + +// Get retrieves the DataStream with the specified streamID. +// If the Connector does not have a stream with the given streamID, return nil and false. +func (c *Connector) Get(streamId string) (DataStream, bool, error) { + select { + case <-c.ctx.Done(): + return nil, false, ErrConnectorClosed + default: + } + + v, ok := c.streams.Load(streamId) + if !ok { + return nil, false, nil + } + + stream := v.(DataStream) + + return stream, true, nil +} + +// GetSourceConns gets the streams with the specified source observe tag. +func (c *Connector) GetSourceConns(sourceId string, tag frame.Tag) ([]DataStream, error) { + select { + case <-c.ctx.Done(): + return []DataStream{}, ErrConnectorClosed + default: + } + + streams := make([]DataStream, 0) + + c.streams.Range(func(key interface{}, val interface{}) bool { + stream := val.(DataStream) + + for _, v := range stream.ObserveDataTags() { + if v == tag && + stream.StreamType() == StreamTypeSource && + stream.ID() == sourceId { + streams = append(streams, stream) } } return true }) - return conns + return streams, nil } -// GetSnapshot gets the snapshot of all connections. -func (c *connector) GetSnapshot() map[string]string { +// GetSnapshot returnsa snapshot of all streams. +// The resulting map uses streamID as the key and stream name as the value. +// This function is typically used to monitor the status of the Connector. +func (c *Connector) GetSnapshot() map[string]string { result := make(map[string]string) - c.conns.Range(func(key interface{}, val interface{}) bool { - connID := key.(string) - conn := val.(Connection) - result[connID] = conn.Name() + + c.streams.Range(func(key interface{}, val interface{}) bool { + var ( + streamID = key.(string) + stream = val.(DataStream) + ) + result[streamID] = stream.Name() return true }) + return result } -// Clean the connector. -func (c *connector) Clean() { - c.conns = sync.Map{} +// Close cleans all stream of Connector and reset Connector to closed status. +// The Connector can't be use after close. +func (c *Connector) Close() { + c.ctxCancel() + + c.streams.Range(func(key, value any) bool { + c.streams.Delete(key) + return true + }) } diff --git a/constant.go b/constant.go index 2cddae8..57bff5f 100644 --- a/constant.go +++ b/constant.go @@ -2,37 +2,19 @@ package network import ( "math/rand" - "sync" "time" ) -var ( - once sync.Once -) +// ConnState represents the state of the connection. +type ConnState = string // ConnState represents the state of a connection. const ( - ConnStateReady ConnState = "Ready" - ConnStateDisconnected ConnState = "Disconnected" - ConnStateConnecting ConnState = "Connecting" - ConnStateConnected ConnState = "Connected" - ConnStateAuthenticating ConnState = "Authenticating" - ConnStateAccepted ConnState = "Accepted" - ConnStateRejected ConnState = "Rejected" - ConnStatePing ConnState = "Ping" - ConnStatePong ConnState = "Pong" - ConnStateTransportData ConnState = "TransportData" - ConnStateAborted ConnState = "Aborted" - ConnStateClosed ConnState = "Closed" // close connection by server - ConnStateGoaway ConnState = "Goaway" - ConnStateBackFlow ConnState = "BackFlow" -) - -// Prefix is the prefix for logger. -const ( - ClientLogPrefix = "\033[36m[network:client]\033[0m " - ServerLogPrefix = "\033[32m[network:server]\033[0m " - ParseFrameLogPrefix = "\033[36m[network:stream_parser]\033[0m " + ConnStateReady ConnState = "Ready" + ConnStateDisconnected ConnState = "Disconnected" + ConnStateConnecting ConnState = "Connecting" + ConnStateConnected ConnState = "Connected" + ConnStateClosed ConnState = "Closed" ) func init() { diff --git a/context.go b/context.go index 6d0d2cd..f4b87ba 100644 --- a/context.go +++ b/context.go @@ -1,191 +1,137 @@ package network import ( + "context" "git.hpds.cc/Component/network/hpds_err" "git.hpds.cc/Component/network/log" - "io" "sync" "time" "git.hpds.cc/Component/network/frame" - "github.com/lucas-clemente/quic-go" ) +var ctxPool sync.Pool + // Context for Network Server. type Context struct { - // Conn is the connection of client. - Conn quic.Connection - connId string - // Stream is the long-lived connection between client and server. - Stream io.ReadWriteCloser + // DataStream is the stream used for reading and writing frames. + DataStream DataStream + // Frame receives from client. Frame frame.Frame - // Keys store the key/value pairs in context. - Keys map[string]interface{} + // mu is used to protect Keys from concurrent read and write operations. mu sync.RWMutex + // Keys stores the key/value pairs in context, It is Lazy initialized. + Keys map[string]any } -func newContext(conn quic.Connection, stream quic.Stream) *Context { - return &Context{ - Conn: conn, - connId: conn.RemoteAddr().String(), - Stream: stream, - // keys: make(map[string]interface{}), - } -} - -// WithFrame sets a frame to context. -func (c *Context) WithFrame(f frame.Frame) *Context { - c.Frame = f - return c -} - -// Clean the context. -func (c *Context) Clean() { - log.Debugf("%sconn[%s] context clean", ServerLogPrefix, c.connId) - c.Stream = nil - c.Frame = nil - c.Keys = nil - c.Conn = nil -} - -// CloseWithError closes the stream and cleans the context. -func (c *Context) CloseWithError(code hpds_err.ErrorCode, msg string) { - log.Debugf("%sconn[%s] context close, errCode=%#x, msg=%s", ServerLogPrefix, c.connId, code, msg) - if c.Stream != nil { - _ = c.Stream.Close() - } - if c.Conn != nil { - _ = c.Conn.CloseWithError(quic.ApplicationErrorCode(code), msg) - } - c.Clean() -} - -// ConnId get quic connection id -func (c *Context) ConnId() string { - return c.connId -} - -// Set a key/value pair to context. -func (c *Context) Set(key string, value interface{}) { +// Set is used to store a new key/value pair exclusively for this context. +// It also lazy initializes c.Keys if it was not used previously. +func (c *Context) Set(key string, value any) { c.mu.Lock() + defer c.mu.Unlock() + if c.Keys == nil { - c.Keys = make(map[string]interface{}) + c.Keys = make(map[string]any) } c.Keys[key] = value - c.mu.Unlock() } -// Get the value by a specified key. -func (c *Context) Get(key string) (value interface{}, exists bool) { +// Get returns the value for the given key, ie: (value, true). +// If the value does not exist it returns (nil, false) +func (c *Context) Get(key string) (any, bool) { c.mu.RLock() - value, exists = c.Keys[key] - c.mu.RUnlock() + defer c.mu.RUnlock() + + value, ok := c.Keys[key] + return value, ok +} + +var _ context.Context = &Context{} + +// Done returns nil (chan which will wait forever) when c.Stream.Context() has no Context. +func (c *Context) Done() <-chan struct{} { return c.DataStream.Context().Done() } + +// Deadline returns that there is no deadline (ok==false) when c.Stream has no Context. +func (c *Context) Deadline() (deadline time.Time, ok bool) { return c.DataStream.Context().Deadline() } + +// Err returns nil when c.Request has no Context. +func (c *Context) Err() error { return c.DataStream.Context().Err() } + +// Value returns the value associated with this context for key, or nil +// if no value is associated with key. Successive calls to Value with +// the same key returns the same result. +func (c *Context) Value(key any) any { + if keyAsString, ok := key.(string); ok { + if val, exists := c.Keys[keyAsString]; exists { + return val + } + } + // There always returns nil, because quic.Stream.Context is not be allowed modify. + return c.DataStream.Context().Value(key) +} + +// newContext returns a yomo context, +// The context implements standard library `context.Context` interface, +// The lifecycle of Context is equal to stream's that be passed in. +func newContext(dataStream DataStream) (c *Context) { + v := ctxPool.Get() + if v == nil { + c = new(Context) + } else { + c = v.(*Context) + } + + log.Infof("stream_id: %s; stream_name: %s; stream_type: %s;", dataStream.ID(), + dataStream.Name(), dataStream.StreamType().String(), + ) + + c.DataStream = dataStream return } -// GetString gets a string value by a specified key. -func (c *Context) GetString(key string) (s string) { - if val, ok := c.Get(key); ok && val != nil { - s, _ = val.(string) - } - return +// WithFrame sets a frame to context. +// +// TODO: delete frame from context due to different lifecycle between stream and stream. +func (c *Context) WithFrame(f frame.Frame) { + c.Frame = f } -// GetBool gets a bool value by a specified key. -func (c *Context) GetBool(key string) (b bool) { - if val, ok := c.Get(key); ok && val != nil { - b, _ = val.(bool) +// CloseWithError close dataStream in se error, +// It tells controlStream which dataStream should be closed and close dataStream with +// returning error message to client side stream. +// +// TODO: ycode is not be transmitted. +func (c *Context) CloseWithError(hCode hpds_err.ErrorCode, errString string) { + log.Warnf("Stream Close With error", "err_code", hCode.String(), "error", errString) + + err := c.DataStream.CloseWithError(errString) + if err == nil { + return } - return + log.Errorf("Close DataStream error", err) } -// GetInt gets an int value by a specified key. -func (c *Context) GetInt(key string) (i int) { - if val, ok := c.Get(key); ok && val != nil { - i, _ = val.(int) - } - return +// Clean cleans the Context, +// Context is not available after called Clean, +// +// Warining: do not use any Context api after Clean, It maybe cause an error. +func (c *Context) Clean() { + c.reset() + ctxPool.Put(c) } -// GetInt64 gets an int64 value by a specified key. -func (c *Context) GetInt64(key string) (i64 int64) { - if val, ok := c.Get(key); ok && val != nil { - i64, _ = val.(int64) +func (c *Context) reset() { + c.DataStream = nil + c.Frame = nil + for k := range c.Keys { + delete(c.Keys, k) } - return } -// GetUint gets an uint value by a specified key. -func (c *Context) GetUint(key string) (ui uint) { - if val, ok := c.Get(key); ok && val != nil { - ui, _ = val.(uint) - } - return -} - -// GetUint64 gets an uint64 value by a specified key. -func (c *Context) GetUint64(key string) (ui64 uint64) { - if val, ok := c.Get(key); ok && val != nil { - ui64, _ = val.(uint64) - } - return -} - -// GetFloat64 gets a float64 value by a specified key. -func (c *Context) GetFloat64(key string) (f64 float64) { - if val, ok := c.Get(key); ok && val != nil { - f64, _ = val.(float64) - } - return -} - -// GetTime gets a time.Time value by a specified key. -func (c *Context) GetTime(key string) (t time.Time) { - if val, ok := c.Get(key); ok && val != nil { - t, _ = val.(time.Time) - } - return -} - -// GetDuration gets a time.Duration value by a specified key. -func (c *Context) GetDuration(key string) (d time.Duration) { - if val, ok := c.Get(key); ok && val != nil { - d, _ = val.(time.Duration) - } - return -} - -// GetStringSlice gets a []string value by a specified key. -func (c *Context) GetStringSlice(key string) (ss []string) { - if val, ok := c.Get(key); ok && val != nil { - ss, _ = val.([]string) - } - return -} - -// GetStringMap gets a map[string]interface{} value by a specified key. -func (c *Context) GetStringMap(key string) (sm map[string]interface{}) { - if val, ok := c.Get(key); ok && val != nil { - sm, _ = val.(map[string]interface{}) - } - return -} - -// GetStringMapString gets a map[string]string value by a specified key. -func (c *Context) GetStringMapString(key string) (sms map[string]string) { - if val, ok := c.Get(key); ok && val != nil { - sms, _ = val.(map[string]string) - } - return -} - -// GetStringMapStringSlice gets a map[string][]string value by a specified key. -func (c *Context) GetStringMapStringSlice(key string) (smss map[string][]string) { - if val, ok := c.Get(key); ok && val != nil { - smss, _ = val.(map[string][]string) - } - return +// StreamId gets dataStream ID. +func (c *Context) StreamId() string { + return c.DataStream.ID() } diff --git a/control_stream.go b/control_stream.go new file mode 100644 index 0000000..b42d2fa --- /dev/null +++ b/control_stream.go @@ -0,0 +1,267 @@ +package network + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "git.hpds.cc/Component/network/auth" + "git.hpds.cc/Component/network/frame" + "github.com/quic-go/quic-go" +) + +// ControlStream defines the interface for controlling a stream. +type ControlStream interface { + // CloseStream notifies the peer's control stream to close the data stream with the given streamID and error message. + CloseStream(streamId string, errString string) error + // ReceiveStreamClose is received from the peer's control stream to close the data stream according to streamID and error message. + ReceiveStreamClose() (streamId string, errString string, err error) + // CloseWithError closes the control stream. + CloseWithError(code uint64, errString string) error +} + +// ServerControlStream defines the interface of server side control stream. +type ServerControlStream interface { + ControlStream + + // VerifyAuthentication verify the Authentication from client side. + VerifyAuthentication(verifyFunc func(auth.Object) (bool, error)) error + // AcceptStream accepts data stream from the request of client. + AcceptStream(context.Context) (DataStream, error) +} + +// ClientControlStream defines the interface of client side control stream. +type ClientControlStream interface { + ControlStream + + // Authenticate with credential, the credential will be sent to ServerControlStream to authenticate the client. + Authenticate(*auth.Credential) error + // OpenStream request a ServerControlStream to create a new data stream. + OpenStream(context.Context, *frame.HandshakeFrame) (DataStream, error) +} + +var _ ServerControlStream = &serverControlStream{} + +type serverControlStream struct { + conn quic.Connection + stream frame.ReadWriter +} + +// NewServerControlStream returns ServerControlStream from quic Connection and the first stream of this Connection. +func NewServerControlStream(qConn quic.Connection, stream frame.ReadWriter) ServerControlStream { + return &serverControlStream{ + conn: qConn, + stream: stream, + } +} + +func (ss *serverControlStream) ReceiveStreamClose() (streamId string, errReason string, err error) { + return receiveStreamClose(ss.stream) +} + +func (ss *serverControlStream) CloseStream(streamId string, errString string) error { + return closeStream(ss.stream, streamId, errString) +} + +func (ss *serverControlStream) AcceptStream(context.Context) (DataStream, error) { + f, err := ss.stream.ReadFrame() + if err != nil { + return nil, err + } + + switch ff := f.(type) { + case *frame.HandshakeFrame: + stream, err := ss.conn.OpenStreamSync(context.Background()) + if err != nil { + return nil, err + } + _, err = stream.Write(frame.NewHandshakeAckFrame(ff.ID()).Encode()) + if err != nil { + return nil, err + } + dataStream := newDataStream( + ff.Name(), + ff.ID(), + StreamType(ff.StreamType()), + ff.Metadata(), + stream, + ff.ObserveDataTags(), + ss, + ) + return dataStream, nil + default: + return nil, fmt.Errorf("yomo: control stream read unexpected frame %s", f.Type()) + } +} + +func (ss *serverControlStream) CloseWithError(code uint64, errString string) error { + return closeWithError(ss.conn, code, errString) +} + +func (ss *serverControlStream) VerifyAuthentication(verifyFunc func(auth.Object) (bool, error)) error { + first, err := ss.stream.ReadFrame() + if err != nil { + return err + } + received, ok := first.(*frame.AuthenticationFrame) + if !ok { + return fmt.Errorf("yomo: read unexpected frame while waiting for authentication, frame read: %s", received.Type().String()) + } + ok, err = verifyFunc(received) + if err != nil { + return err + } + if !ok { + return ss.stream.WriteFrame( + frame.NewAuthenticationRespFrame( + false, + fmt.Sprintf("yomo: authentication failed, client credential name is %s", received.AuthName()), + ), + ) + } + return ss.stream.WriteFrame(frame.NewAuthenticationRespFrame(true, "")) +} + +var _ ClientControlStream = &clientControlStream{} + +type clientControlStream struct { + conn quic.Connection + stream frame.ReadWriter +} + +// OpenClientControlStream opens ClientControlStream from addr. +func OpenClientControlStream( + ctx context.Context, addr string, + tlsConfig *tls.Config, quicConfig *quic.Config, +) (ClientControlStream, error) { + conn, err := quic.DialAddrContext(ctx, addr, tlsConfig, quicConfig) + if err != nil { + return nil, err + } + stream, err := conn.OpenStream() + if err != nil { + return nil, err + } + + return NewClientControlStream(conn, NewFrameStream(stream)), nil +} + +// NewClientControlStream returns ClientControlStream from quic Connection and the first stream form the Connection. +func NewClientControlStream(qConn quic.Connection, stream frame.ReadWriter) ClientControlStream { + return &clientControlStream{ + conn: qConn, + stream: stream, + } +} + +func (cs *clientControlStream) ReceiveStreamClose() (streamId string, errReason string, err error) { + return receiveStreamClose(cs.stream) +} + +func (cs *clientControlStream) CloseStream(streamId string, errString string) error { + return closeStream(cs.stream, streamId, errString) +} + +func (cs *clientControlStream) Authenticate(cred *auth.Credential) error { + if err := cs.stream.WriteFrame( + frame.NewAuthenticationFrame(cred.Name(), cred.Payload())); err != nil { + return err + } + received, err := cs.stream.ReadFrame() + if err != nil { + return err + } + resp, ok := received.(*frame.AuthenticationRespFrame) + if !ok { + return fmt.Errorf( + "yomo: read unexcept frame during waiting authentication resp, frame readed: %s", + received.Type().String(), + ) + } + if !resp.OK() { + return errors.New(resp.Reason()) + } + return nil +} + +// dataStreamAcked drain HandshakeAckFrame from stream. +func dataStreamAcked(stream DataStream) error { + first, err := stream.ReadFrame() + if err != nil { + return err + } + + f, ok := first.(*frame.HandshakeAckFrame) + if !ok { + return fmt.Errorf("yomo: data stream read first frame should be HandshakeAckFrame, but got %s", first.Type().String()) + } + + if f.StreamId() != stream.ID() { + return fmt.Errorf("yomo: data stream ack exception, stream id did not match") + } + + return nil +} + +func (cs *clientControlStream) OpenStream(ctx context.Context, hf *frame.HandshakeFrame) (DataStream, error) { + err := cs.stream.WriteFrame(frame.NewHandshakeFrame( + hf.Name(), + hf.ID(), + hf.StreamType(), + hf.ObserveDataTags(), + hf.Metadata(), + )) + + if err != nil { + return nil, err + } + + quicStream, err := cs.conn.AcceptStream(ctx) + if err != nil { + return nil, err + } + + dataStream := newDataStream( + hf.Name(), + hf.ID(), + StreamType(hf.StreamType()), + hf.Metadata(), + quicStream, + hf.ObserveDataTags(), + cs, + ) + + if err := dataStreamAcked(dataStream); err != nil { + return nil, err + } + + return dataStream, nil +} + +func (cs *clientControlStream) CloseWithError(code uint64, errString string) error { + return closeWithError(cs.conn, code, errString) +} + +func closeStream(controlStream frame.Writer, streamID string, errString string) error { + f := frame.NewCloseStreamFrame(streamID, errString) + return controlStream.WriteFrame(f) +} + +func receiveStreamClose(controlStream frame.Reader) (streamID string, errString string, err error) { + f, err := controlStream.ReadFrame() + if err != nil { + return "", "", err + } + ff, ok := f.(*frame.CloseStreamFrame) + if !ok { + return "", "", errors.New("yomo: control stream only transmit close stream frame") + } + return ff.StreamID(), ff.Reason(), nil +} + +func closeWithError(qConn quic.Connection, code uint64, errString string) error { + return qConn.CloseWithError( + quic.ApplicationErrorCode(code), + errString, + ) +} diff --git a/data_stream.go b/data_stream.go new file mode 100644 index 0000000..8bb5e30 --- /dev/null +++ b/data_stream.go @@ -0,0 +1,163 @@ +package network + +import ( + "context" + "github.com/quic-go/quic-go" + "io" + "sync" + "sync/atomic" + + "git.hpds.cc/Component/network/frame" +) + +// DataStream wraps the specific io streams (typically quic.Stream) to transfer frames. +// DataStream be used to read and write frames, and be managed by Connector. +type DataStream interface { + // Context returns context.Context to manages DataStream lifecycle. + Context() context.Context + // Name returns the name of the stream, which is set by clients. + Name() string + // ID represents the dataStream ID, the ID is an unique string. + ID() string + // StreamType represents dataStream type (Source | SFN | UpstreamEmitter). + StreamType() StreamType + // Metadata returns the extra info of the application + Metadata() []byte + // Close real close DataStream, + // The controlStream calls this function, If you want close a dataStream, to use + // the CloseWithError api. + io.Closer + // CloseWithError close DataStream with an error string, + // This function do not real close the underlying stream, It notices controlStream to + // close itself, The controlStream must close underlying stream after receive CloseStreamFrame. + CloseWithError(string) error + // ReadWriter writes or reads frame to underlying stream. + // Writing and Reading are both goroutine-safely handle frames to peer side. + // ReadWriter returns stream closed error if stream is closed. + frame.ReadWriter + // ObserveDataTags observed data tags. + // TODO: There maybe a sorted list, we can find tag quickly. + ObserveDataTags() []frame.Tag +} + +// TODO: dataStream sync.Pool wrap. +type dataStream struct { + name string + id string + streamType StreamType + metadata []byte + observed []frame.Tag + + closed atomic.Bool + // mu protected stream write and close + // because of quic stream write and close is not goroutinue-safely. + mu sync.Mutex + stream quic.Stream + controlStream ControlStream +} + +// newDataStream constructures dataStream. +func newDataStream( + name string, + id string, + streamType StreamType, + metadata []byte, + stream quic.Stream, + observed []frame.Tag, + controlStream ControlStream, +) DataStream { + return &dataStream{ + name: name, + id: id, + streamType: streamType, + metadata: metadata, + stream: stream, + observed: observed, + controlStream: controlStream, + } +} + +// DataStream implements. +func (s *dataStream) Context() context.Context { return s.stream.Context() } +func (s *dataStream) ID() string { return s.id } +func (s *dataStream) Name() string { return s.name } +func (s *dataStream) Metadata() []byte { return s.metadata } +func (s *dataStream) StreamType() StreamType { return s.streamType } +func (s *dataStream) ObserveDataTags() []frame.Tag { return s.observed } + +func (s *dataStream) WriteFrame(frm frame.Frame) error { + if s.closed.Load() { + return io.EOF + } + + s.mu.Lock() + defer s.mu.Unlock() + _, err := s.stream.Write(frm.Encode()) + return err +} + +func (s *dataStream) ReadFrame() (frame.Frame, error) { + if s.closed.Load() { + return nil, io.EOF + } + return ParseFrame(s.stream) +} + +func (s *dataStream) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + + // Close the stream truly, + // This function should be called after controlStream receive a closeStreamFrame. + return s.stream.Close() +} + +func (s *dataStream) CloseWithError(errString string) error { + if s.closed.Load() { + return nil + } + s.closed.Store(true) + + s.mu.Lock() + defer s.mu.Unlock() + + // Only notice client-side controlStream the stream has been closed. + // The controlStream reads closeStreamFrame and to close dataStream. + return s.controlStream.CloseStream(s.id, errString) +} + +const ( + // StreamTypeNone is stream type "None". + // "None" stream is not supposed to be in the yomo system. + StreamTypeNone StreamType = 0xFF + + // StreamTypeSource is stream type "Source". + // "Source" type stream sends data to "Stream Function" stream generally. + StreamTypeSource StreamType = 0x5F + + // StreamTypeUpstreamEmitter is connection type "Upstream Emitter". + // "Upstream Emitter" type stream sends data from "Source" to other Emitter node. + // With "Upstream Emitter", the yomo can run in mesh mode. + StreamTypeUpstreamEmitter StreamType = 0x5E + + // StreamTypeStreamFunction is stream type "Stream Function". + // "Stream Function" handles data from source. + StreamTypeStreamFunction StreamType = 0x5D +) + +// StreamType represents the stream type. +type StreamType byte + +// String returns string for StreamType. +func (c StreamType) String() string { + switch c { + case StreamTypeSource: + return "Source" + case StreamTypeUpstreamEmitter: + return "Upstream Emitter" + case StreamTypeStreamFunction: + return "Stream Function" + default: + return "None" + } +} diff --git a/frame/authentication_frame.go b/frame/authentication_frame.go new file mode 100644 index 0000000..7eb971b --- /dev/null +++ b/frame/authentication_frame.go @@ -0,0 +1,84 @@ +package frame + +import ( + coder "git.hpds.cc/Component/mq_coder" +) + +// AuthenticationFrame is used to authenticate the client, +// Once the connection is established, the client immediately, sends information +// to the server, server gets the way to authenticate according to authName and +// use authPayload to do a authentication. +// +// AuthenticationFrame is a coder encoded. +type AuthenticationFrame struct { + authName string + authPayload string +} + +// NewAuthenticationFrame creates a new AuthenticationFrame. +func NewAuthenticationFrame(authName string, authPayload string) *AuthenticationFrame { + return &AuthenticationFrame{ + authName: authName, + authPayload: authPayload, + } +} + +// Type returns the type of AuthenticationFrame. +func (h *AuthenticationFrame) Type() Type { + return TagOfAuthenticationFrame +} + +// Encode encodes AuthenticationFrame to bytes in coder codec. +func (h *AuthenticationFrame) Encode() []byte { + // auth + authNameBlock := coder.NewPrimitivePacketEncoder(byte(TagOfAuthenticationName)) + authNameBlock.SetStringValue(h.authName) + authPayloadBlock := coder.NewPrimitivePacketEncoder(byte(TagOfAuthenticationPayload)) + authPayloadBlock.SetStringValue(h.authPayload) + // authentication frame + authentication := coder.NewNodePacketEncoder(byte(h.Type())) + authentication.AddPrimitivePacket(authNameBlock) + authentication.AddPrimitivePacket(authPayloadBlock) + + return authentication.Encode() +} + +// DecodeToAuthenticationFrame decodes coder encoded bytes to AuthenticationFrame. +func DecodeToAuthenticationFrame(buf []byte) (*AuthenticationFrame, error) { + node := coder.NodePacket{} + _, err := coder.DecodeToNodePacket(buf, &node) + if err != nil { + return nil, err + } + + authentication := &AuthenticationFrame{} + + // auth + if authNameBlock, ok := node.PrimitivePackets[byte(TagOfAuthenticationName)]; ok { + authName, err := authNameBlock.ToUTF8String() + if err != nil { + return nil, err + } + authentication.authName = authName + } + if authPayloadBlock, ok := node.PrimitivePackets[byte(TagOfAuthenticationPayload)]; ok { + authPayload, err := authPayloadBlock.ToUTF8String() + if err != nil { + return nil, err + } + authentication.authPayload = authPayload + } + + return authentication, nil +} + +// AuthPayload returns authentication payload. +func (h *AuthenticationFrame) AuthPayload() string { + return h.authPayload +} + +// AuthName returns authentication name, +// server finds the mode of authentication in AuthName. +func (h *AuthenticationFrame) AuthName() string { + return h.authName +} diff --git a/frame/authentication_frame_test.go b/frame/authentication_frame_test.go new file mode 100644 index 0000000..08cf7de --- /dev/null +++ b/frame/authentication_frame_test.go @@ -0,0 +1,23 @@ +package frame + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAuthenticationFrame(t *testing.T) { + m := NewAuthenticationFrame("token", "a") + assert.Equal(t, []byte{ + 0x80 | byte(TagOfAuthenticationFrame), 0xa, + byte(TagOfAuthenticationName), 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e, + byte(TagOfAuthenticationPayload), 0x01, 0x61, + }, + m.Encode(), + ) + + authenticate, err := DecodeToAuthenticationFrame(m.Encode()) + assert.NoError(t, err) + assert.EqualValues(t, "token", authenticate.AuthName()) + assert.EqualValues(t, "a", authenticate.AuthPayload()) +} diff --git a/frame/authentication_resp.go b/frame/authentication_resp.go new file mode 100644 index 0000000..cbebc2f --- /dev/null +++ b/frame/authentication_resp.go @@ -0,0 +1,77 @@ +package frame + +import ( + coder "git.hpds.cc/Component/mq_coder" +) + +// AuthenticationRespFrame is the response of Authentication. +// AuthenticationRespFrame is a coder encoded bytes. +type AuthenticationRespFrame struct { + ok bool + reason string +} + +// OK returns if Authentication is success. +func (f *AuthenticationRespFrame) OK() bool { return f.ok } + +// Reason returns the failed reason of Authentication. +func (f *AuthenticationRespFrame) Reason() string { return f.reason } + +// NewAuthenticationRespFrame returns a AuthenticationRespFrame. +func NewAuthenticationRespFrame(ok bool, reason string) *AuthenticationRespFrame { + return &AuthenticationRespFrame{ + ok: ok, + reason: reason, + } +} + +// Type gets the type of the AuthenticationRespFrame. +func (f *AuthenticationRespFrame) Type() Type { + return TagOfAuthenticationAckFrame +} + +// Encode encodes AuthenticationRespFrame to coder encoded bytes. +func (f *AuthenticationRespFrame) Encode() []byte { + // ok + okBlock := coder.NewPrimitivePacketEncoder(byte(TagOfAuthenticationAckOk)) + okBlock.SetBoolValue(f.ok) + // reason + reasonBlock := coder.NewPrimitivePacketEncoder(byte(TagOfAuthenticationAckReason)) + reasonBlock.SetStringValue(f.reason) + // frame + ack := coder.NewNodePacketEncoder(byte(f.Type())) + ack.AddPrimitivePacket(okBlock) + ack.AddPrimitivePacket(reasonBlock) + + return ack.Encode() +} + +// DecodeToAuthenticationRespFrame decodes coder encoded bytes to AuthenticationRespFrame. +func DecodeToAuthenticationRespFrame(buf []byte) (*AuthenticationRespFrame, error) { + node := coder.NodePacket{} + _, err := coder.DecodeToNodePacket(buf, &node) + if err != nil { + return nil, err + } + + f := &AuthenticationRespFrame{} + + // ok + if okBlock, ok := node.PrimitivePackets[byte(TagOfAuthenticationAckOk)]; ok { + ok, err := okBlock.ToBool() + if err != nil { + return nil, err + } + f.ok = ok + } + // reason + if reasonBlock, ok := node.PrimitivePackets[byte(TagOfAuthenticationAckReason)]; ok { + reason, err := reasonBlock.ToUTF8String() + if err != nil { + return nil, err + } + f.reason = reason + } + + return f, nil +} diff --git a/frame/authentication_resp_test.go b/frame/authentication_resp_test.go new file mode 100644 index 0000000..8e1aa6a --- /dev/null +++ b/frame/authentication_resp_test.go @@ -0,0 +1,20 @@ +package frame + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAuthenticationAckFrame(t *testing.T) { + f := NewAuthenticationRespFrame(false, "aabbcc") + + bytes := f.Encode() + assert.Equal(t, []byte{0x91, 0xb, 0x12, 0x1, 0x0, 0x13, 0x6, 0x61, 0x61, 0x62, 0x62, 0x63, 0x63}, bytes) + + got, err := DecodeToAuthenticationRespFrame(bytes) + assert.Equal(t, f, got) + assert.NoError(t, err) + assert.EqualValues(t, false, f.OK()) + assert.EqualValues(t, "aabbcc", f.Reason()) +} diff --git a/frame/backflow_frame.go b/frame/backflow_frame.go index 298422c..9bf2d8f 100644 --- a/frame/backflow_frame.go +++ b/frame/backflow_frame.go @@ -7,12 +7,12 @@ import ( // BackFlowFrame is a coder encoded bytes // It's used to receive stream function processed result type BackFlowFrame struct { - Tag byte + Tag Tag Carriage []byte } // NewBackFlowFrame creates a new BackFlowFrame with a given tag and carriage -func NewBackFlowFrame(tag byte, carriage []byte) *BackFlowFrame { +func NewBackFlowFrame(tag Tag, carriage []byte) *BackFlowFrame { return &BackFlowFrame{ Tag: tag, Carriage: carriage, @@ -32,17 +32,20 @@ func (f *BackFlowFrame) SetCarriage(buf []byte) *BackFlowFrame { // Encode to coder encoded bytes func (f *BackFlowFrame) Encode() []byte { - carriage := coder.NewPrimitivePacketEncoder(f.Tag) + tag := coder.NewPrimitivePacketEncoder(byte(TagOfBackFlowDataTag)) + tag.SetUInt32Value(uint32(f.Tag)) + carriage := coder.NewPrimitivePacketEncoder(byte(TagOfBackFlowCarriage)) carriage.SetBytesValue(f.Carriage) node := coder.NewNodePacketEncoder(byte(TagOfBackFlowFrame)) + node.AddPrimitivePacket(tag) node.AddPrimitivePacket(carriage) return node.Encode() } // GetDataTag return the Tag of user's data -func (f *BackFlowFrame) GetDataTag() byte { +func (f *BackFlowFrame) GetDataTag() Tag { return f.Tag } @@ -60,11 +63,15 @@ func DecodeToBackFlowFrame(buf []byte) (*BackFlowFrame, error) { } payload := &BackFlowFrame{} - for _, v := range nodeBlock.PrimitivePackets { - payload.Tag = v.SeqId() - payload.Carriage = v.GetValBuf() - break + if p, ok := nodeBlock.PrimitivePackets[byte(TagOfBackFlowDataTag)]; ok { + tag, err := p.ToUInt32() + if err != nil { + return nil, err + } + payload.Tag = Tag(tag) + } + if p, ok := nodeBlock.PrimitivePackets[byte(TagOfBackFlowCarriage)]; ok { + payload.Carriage = p.GetValBuf() } - return payload, nil } diff --git a/frame/backflow_frame_test.go b/frame/backflow_frame_test.go new file mode 100644 index 0000000..2343ab1 --- /dev/null +++ b/frame/backflow_frame_test.go @@ -0,0 +1,33 @@ +package frame + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBackFlowFrameEncode(t *testing.T) { + var ( + tag = Tag(22) + carriage = []byte("hello backflow") + ) + f := NewBackFlowFrame(tag, []byte{}) + + f.SetCarriage(carriage) + + assert.Equal(t, TagOfBackFlowFrame, f.Type()) + assert.Equal(t, f.GetCarriage(), carriage) + assert.Equal(t, f.GetDataTag(), tag) + assert.Equal(t, []byte{0xad, 0x13, 0x1, 0x1, 0x16, 0x2, 0xe, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x62, 0x61, 0x63, 0x6b, 0x66, 0x6c, 0x6f, 0x77}, f.Encode()) +} + +func TestBackflowFrameDecode(t *testing.T) { + f := NewBackFlowFrame(Tag(22), []byte("hello backflow")) + + buf := f.Encode() + + df, err := DecodeToBackFlowFrame(buf) + + assert.NoError(t, err) + assert.Equal(t, df, f) +} diff --git a/frame/close_stream_frame.go b/frame/close_stream_frame.go new file mode 100644 index 0000000..d85589e --- /dev/null +++ b/frame/close_stream_frame.go @@ -0,0 +1,78 @@ +package frame + +import ( + coder "git.hpds.cc/Component/mq_coder" +) + +// CloseStreamFrame is used to close a dataStream, controlStream +// receives CloseStreamFrame and closes dataStream according to the Frame. +// CloseStreamFrame is a coder encoded bytes. +type CloseStreamFrame struct { + streamID string + reason string +} + +// StreamID returns the ID of the stream to be closed. +func (f *CloseStreamFrame) StreamID() string { return f.streamID } + +// Reason returns the close reason. +func (f *CloseStreamFrame) Reason() string { return f.reason } + +// NewCloseStreamFrame returns a CloseStreamFrame. +func NewCloseStreamFrame(streamID, reason string) *CloseStreamFrame { + return &CloseStreamFrame{ + streamID: streamID, + reason: reason, + } +} + +// Type gets the type of the CloseStreamFrame. +func (f *CloseStreamFrame) Type() Type { + return TagOfCloseStreamFrame +} + +// Encode encodes CloseStreamFrame to coder encoded bytes. +func (f *CloseStreamFrame) Encode() []byte { + // id + idBlock := coder.NewPrimitivePacketEncoder(byte(TagOfCloseStreamID)) + idBlock.SetStringValue(f.streamID) + // reason + reasonBlock := coder.NewPrimitivePacketEncoder(byte(TagOfCloseStreamReason)) + reasonBlock.SetStringValue(f.reason) + // frame + ack := coder.NewNodePacketEncoder(byte(f.Type())) + ack.AddPrimitivePacket(idBlock) + ack.AddPrimitivePacket(reasonBlock) + + return ack.Encode() +} + +// DecodeToCloseStreamFrame decodes coder encoded bytes to CloseStreamFrame. +func DecodeToCloseStreamFrame(buf []byte) (*CloseStreamFrame, error) { + node := coder.NodePacket{} + _, err := coder.DecodeToNodePacket(buf, &node) + if err != nil { + return nil, err + } + + f := &CloseStreamFrame{} + + // id + if idBlock, ok := node.PrimitivePackets[byte(TagOfCloseStreamID)]; ok { + id, err := idBlock.ToUTF8String() + if err != nil { + return nil, err + } + f.streamID = id + } + // reason + if reasonBlock, ok := node.PrimitivePackets[byte(TagOfCloseStreamReason)]; ok { + reason, err := reasonBlock.ToUTF8String() + if err != nil { + return nil, err + } + f.reason = reason + } + + return f, nil +} diff --git a/frame/close_stream_frame_test.go b/frame/close_stream_frame_test.go new file mode 100644 index 0000000..9139b83 --- /dev/null +++ b/frame/close_stream_frame_test.go @@ -0,0 +1,20 @@ +package frame + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCloseStreamFrame(t *testing.T) { + f := NewCloseStreamFrame("eeffgg", "aabbcc") + + bytes := f.Encode() + assert.Equal(t, []byte{0x94, 0x10, 0x15, 0x6, 0x65, 0x65, 0x66, 0x66, 0x67, 0x67, 0x16, 0x6, 0x61, 0x61, 0x62, 0x62, 0x63, 0x63}, bytes) + + got, err := DecodeToCloseStreamFrame(bytes) + assert.Equal(t, f, got) + assert.NoError(t, err) + assert.EqualValues(t, "eeffgg", f.StreamID()) + assert.EqualValues(t, "aabbcc", f.Reason()) +} diff --git a/frame/data_frame.go b/frame/data_frame.go index 217d943..51ff7d2 100644 --- a/frame/data_frame.go +++ b/frame/data_frame.go @@ -1,37 +1,83 @@ package frame import ( + "fmt" coder "git.hpds.cc/Component/mq_coder" + "sync" ) +var dataFramePool sync.Pool + // DataFrame defines the data structure carried with user's data type DataFrame struct { metaFrame *MetaFrame payloadFrame *PayloadFrame } +func (d *DataFrame) String() string { + data := d.GetCarriage() + length := len(data) + if length > debugFrameSize { + data = data[:debugFrameSize] + } + return fmt.Sprintf("tid=%s | tag=%#x | source=%s | data[%d]=%# x", d.metaFrame.tid, d.Tag(), d.SourceId(), length, data) +} + // NewDataFrame create `DataFrame` with a transactionId string, // consider change transactionID to UUID type later func NewDataFrame() *DataFrame { - data := &DataFrame{ - metaFrame: NewMetaFrame(), - } + data := newDataFrame() + data.metaFrame.tid = randString() return data } +func newDataFrame() (data *DataFrame) { + v := dataFramePool.Get() + if v == nil { + data = new(DataFrame) + data.metaFrame = new(MetaFrame) + data.payloadFrame = new(PayloadFrame) + } else { + data = v.(*DataFrame) + } + + return +} + +// Clean cleans DataFrame. +// Note that: +// 1/ if the client is calling WriteFrame(), it will automatically invoke Clean(), so there is no need to call Clean() separately. +// 2/ The DataFrame will be unavailable after cleaned, do not access DataFrame after Clean() called. +func (d *DataFrame) Clean() { + // reset metadataFrame + d.metaFrame.tid = "" + d.metaFrame.metadata = d.metaFrame.metadata[:0] + d.metaFrame.sourceId = "" + d.metaFrame.broadcast = false + + // reset payloadFrame + d.payloadFrame.Tag = Tag(0) + d.payloadFrame.Carriage = d.payloadFrame.Carriage[:0] + + dataFramePool.Put(d) +} + // Type gets the type of Frame. func (d *DataFrame) Type() Type { return TagOfDataFrame } // Tag return the tag of carriage data. -func (d *DataFrame) Tag() byte { +func (d *DataFrame) Tag() Tag { return d.payloadFrame.Tag } // SetCarriage set user's raw data in `DataFrame` -func (d *DataFrame) SetCarriage(tag byte, carriage []byte) { - d.payloadFrame = NewPayloadFrame(tag).SetCarriage(carriage) +func (d *DataFrame) SetCarriage(tag Tag, carriage []byte) { + d.payloadFrame = &PayloadFrame{ + Tag: tag, + Carriage: carriage, + } } // GetCarriage return user's raw data in `DataFrame` @@ -45,8 +91,8 @@ func (d *DataFrame) TransactionId() string { } // SetTransactionId set transactionId string -func (d *DataFrame) SetTransactionId(transactionID string) { - d.metaFrame.SetTransactionId(transactionID) +func (d *DataFrame) SetTransactionId(transactionId string) { + d.metaFrame.SetTransactionId(transactionId) } // GetMetaFrame return MetaFrame. @@ -55,13 +101,13 @@ func (d *DataFrame) GetMetaFrame() *MetaFrame { } // GetDataTag return the Tag of user's data -func (d *DataFrame) GetDataTag() byte { +func (d *DataFrame) GetDataTag() Tag { return d.payloadFrame.Tag } // SetSourceId set the source id. -func (d *DataFrame) SetSourceId(sourceID string) { - d.metaFrame.SetSourceId(sourceID) +func (d *DataFrame) SetSourceId(sourceId string) { + d.metaFrame.SetSourceId(sourceId) } // SourceId returns source id @@ -69,6 +115,16 @@ func (d *DataFrame) SourceId() string { return d.metaFrame.SourceId() } +// SetBroadcast set broadcast mode +func (d *DataFrame) SetBroadcast(enabled bool) { + d.metaFrame.SetBroadcast(enabled) +} + +// IsBroadcast returns the broadcast mode is enabled +func (d *DataFrame) IsBroadcast() bool { + return d.metaFrame.IsBroadcast() +} + // Encode return coder encoded bytes of `DataFrame` func (d *DataFrame) Encode() []byte { data := coder.NewNodePacketEncoder(byte(d.Type())) @@ -88,22 +144,22 @@ func DecodeToDataFrame(buf []byte) (*DataFrame, error) { return nil, err } - data := &DataFrame{} + data := new(DataFrame) + data.metaFrame = new(MetaFrame) + data.payloadFrame = new(PayloadFrame) if metaBlock, ok := packet.NodePackets[byte(TagOfMetaFrame)]; ok { - meta, err := DecodeToMetaFrame(metaBlock.GetRawBytes()) + err := DecodeToMetaFrame(metaBlock.GetRawBytes(), data.metaFrame) if err != nil { return nil, err } - data.metaFrame = meta } if payloadBlock, ok := packet.NodePackets[byte(TagOfPayloadFrame)]; ok { - payload, err := DecodeToPayloadFrame(payloadBlock.GetRawBytes()) + err := DecodeToPayloadFrame(payloadBlock.GetRawBytes(), data.payloadFrame) if err != nil { return nil, err } - data.payloadFrame = payload } return data, nil diff --git a/frame/data_frame_test.go b/frame/data_frame_test.go index 3ed1cda..18b0045 100644 --- a/frame/data_frame_test.go +++ b/frame/data_frame_test.go @@ -7,33 +7,59 @@ import ( ) func TestDataFrameEncode(t *testing.T) { - var userDataTag byte = 0x15 + var userDataTag Tag = 0x15 d := NewDataFrame() d.SetCarriage(userDataTag, []byte("hpds")) + d.SetBroadcast(true) + + assert.EqualValues(t, "", d.SourceId()) tidBuf := []byte(d.TransactionId()) result := []byte{ - 0x80 | byte(TagOfDataFrame), byte(len(tidBuf) + 4 + 8 + 2), - 0x80 | byte(TagOfMetaFrame), byte(len(tidBuf) + 2 + 2), + 0x80 | byte(TagOfDataFrame), byte(len(tidBuf) + 4 + 8 + 5 + 3), + 0x80 | byte(TagOfMetaFrame), byte(len(tidBuf) + 2 + 2 + 3), byte(TagOfTransactionId), byte(len(tidBuf))} result = append(result, tidBuf...) result = append(result, byte(TagOfSourceId), 0x0) - result = append(result, 0x80|byte(TagOfPayloadFrame), 0x06, - userDataTag, 0x04, 0x68, 0x70, 0x64, 0x73) + result = append(result, byte(TagOfBroadcast), 0x1, 0x1) + result = append(result, 0x80|byte(TagOfPayloadFrame), 0x09, + 0x01, 0x1, 0x15, 0x02, 0x04, 0x68, 0x70, 0x64, 0x73) assert.Equal(t, result, d.Encode()) } func TestDataFrameDecode(t *testing.T) { - var userDataTag byte = 0x15 + var userDataTag Tag = 0x15 buf := []byte{ - 0x80 | byte(TagOfDataFrame), 0x10, - 0x80 | byte(TagOfMetaFrame), 0x06, + 0x80 | byte(TagOfDataFrame), 0x10 + 3, + 0x80 | byte(TagOfMetaFrame), 0x06 + 3, byte(TagOfTransactionId), 0x04, 0x31, 0x32, 0x33, 0x34, - 0x80 | byte(TagOfPayloadFrame), 0x06, - userDataTag, 0x04, 0x68, 0x70, 0x64, 0x73} + byte(TagOfBroadcast), 0x01, 0x01, + 0x80 | byte(TagOfPayloadFrame), 0x09, + 0x01, 0x1, 0x15, 0x02, 0x04, 0x68, 0x70, 0x64, 0x73} data, err := DecodeToDataFrame(buf) + defer data.Clean() assert.NoError(t, err) + + assert.EqualValues(t, 0x15, data.Tag()) assert.EqualValues(t, "1234", data.TransactionId()) assert.EqualValues(t, userDataTag, data.GetDataTag()) assert.EqualValues(t, []byte("hpds"), data.GetCarriage()) + assert.EqualValues(t, true, data.IsBroadcast()) +} + +func BenchmarkDataFramePool(b *testing.B) { + var ( + tag = Tag(0x15) + payload = []byte("hpds") + ) + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + prev := NewDataFrame() + prev.SetCarriage(tag, payload) + prev.SetBroadcast(true) + + prev.Clean() + } + }) } diff --git a/frame/frame.go b/frame/frame.go index 4ac67ee..e4a106a 100644 --- a/frame/frame.go +++ b/frame/frame.go @@ -5,6 +5,25 @@ import ( "strconv" ) +// ReadWriter is the interface that groups the ReadFrame and WriteFrame methods. +type ReadWriter interface { + Reader + Writer +} + +// Reader reads frame from underlying stream. +type Reader interface { + // ReadFrame reads frame, if error, the error returned is not empty + // and frame returned is nil. + ReadFrame() (Frame, error) +} + +// Writer is the interface that wraps the WriteFrame method, It writes +// frm to the underlying data stream. +type Writer interface { + WriteFrame(frm Frame) error +} + // debugFrameSize print frame data size on debug mode var debugFrameSize = 16 @@ -17,19 +36,45 @@ const ( TagOfMetadata Type = 0x03 TagOfTransactionId Type = 0x01 TagOfSourceId Type = 0x02 + TagOfBroadcast Type = 0x04 // PayloadFrame of DataFrame - TagOfPayloadFrame Type = 0x2E - TagOfBackFlowFrame Type = 0x2D + TagOfPayloadFrame Type = 0x2E + TagOfPayloadDataTag Type = 0x01 + TagOfPayloadCarriage Type = 0x02 + TagOfBackFlowFrame Type = 0x2D + TagOfBackFlowDataTag Type = 0x01 + TagOfBackFlowCarriage Type = 0x02 TagOfTokenFrame Type = 0x3E + + // AuthenticationFrame + TagOfAuthenticationFrame Type = 0x03 + TagOfAuthenticationName Type = 0x04 + TagOfAuthenticationPayload Type = 0x05 + + // AuthenticationAckFrame + TagOfAuthenticationAckFrame Type = 0x11 + TagOfAuthenticationAckOk Type = 0x12 + TagOfAuthenticationAckReason Type = 0x13 + + // CloseStreamFrame + TagOfCloseStreamFrame Type = 0x14 + TagOfCloseStreamID Type = 0x15 + TagOfCloseStreamReason Type = 0x16 + // HandshakeFrame TagOfHandshakeFrame Type = 0x3D TagOfHandshakeName Type = 0x01 - TagOfHandshakeType Type = 0x02 + TagOfHandshakeStreamType Type = 0x02 TagOfHandshakeId Type = 0x03 TagOfHandshakeAuthName Type = 0x04 TagOfHandshakeAuthPayload Type = 0x05 TagOfHandshakeObserveDataTags Type = 0x06 + TagOfHandshakeMetadata Type = 0x07 + + // TagOfHandshakeAckFrame + TagOfHandshakeAckFrame Type = 0x29 + TagOfHandshakeAckStreamId Type = 0x28 TagOfPingFrame Type = 0x3C TagOfPongFrame Type = 0x3B @@ -60,6 +105,14 @@ func (f Type) String() string { return "DataFrame" case TagOfTokenFrame: return "TokenFrame" + case TagOfAuthenticationFrame: + return "AuthenticationFrame" + case TagOfAuthenticationAckFrame: + return "AuthenticationAckFrame" + case TagOfHandshakeAckFrame: + return "HandshakeAckFrame" + case TagOfCloseStreamFrame: + return "CloseStreamFrame" case TagOfHandshakeFrame: return "HandshakeFrame" case TagOfPingFrame: @@ -78,11 +131,9 @@ func (f Type) String() string { return "MetaFrame" case TagOfPayloadFrame: return "PayloadFrame" - // case TagOfTransactionId: - // return "TransactionId" case TagOfHandshakeName: return "HandshakeName" - case TagOfHandshakeType: + case TagOfHandshakeStreamType: return "HandshakeType" default: return "UnknownFrame" diff --git a/frame/goaway_frame_test.go b/frame/goaway_frame_test.go new file mode 100644 index 0000000..c2081cd --- /dev/null +++ b/frame/goaway_frame_test.go @@ -0,0 +1,20 @@ +package frame + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGoawayFrameEncode(t *testing.T) { + f := NewGoawayFrame("goaway") + assert.Equal(t, "goaway", f.Message()) + assert.Equal(t, []byte{0x80 | byte(TagOfGoawayFrame), 0x8, 0x2, 0x6, 0x67, 0x6f, 0x61, 0x77, 0x61, 0x79}, f.Encode()) +} + +func TestGoawayFrameDecode(t *testing.T) { + buf := []byte{0x80 | byte(TagOfGoawayFrame), 0x8, 0x2, 0x6, 0x67, 0x6f, 0x61, 0x77, 0x61, 0x79} + f, err := DecodeToGoawayFrame(buf) + assert.NoError(t, err) + assert.Equal(t, []byte{0x80 | byte(TagOfGoawayFrame), 0x8, 0x2, 0x6, 0x67, 0x6f, 0x61, 0x77, 0x61, 0x79}, f.Encode()) +} diff --git a/frame/handshake_ack_frame.go b/frame/handshake_ack_frame.go new file mode 100644 index 0000000..8a24b6f --- /dev/null +++ b/frame/handshake_ack_frame.go @@ -0,0 +1,59 @@ +package frame + +import ( + coder "git.hpds.cc/Component/mq_coder" +) + +// HandshakeAckFrame is used to ack handshake, It is always that the first frame +// is HandshakeAckFrame after client acquire a new stream. +// HandshakeAckFrame is a coder encoded bytes. +type HandshakeAckFrame struct { + streamId string +} + +// NewHandshakeAckFrame returns a HandshakeAckFrame. +func NewHandshakeAckFrame(streamId string) *HandshakeAckFrame { + return &HandshakeAckFrame{streamId} +} + +// Type gets the type of the HandshakeAckFrame. +func (f *HandshakeAckFrame) Type() Type { + return TagOfHandshakeAckFrame +} + +// StreamId returns the id of stream be acked. +func (f *HandshakeAckFrame) StreamId() string { + return f.streamId +} + +// Encode encodes HandshakeAckFrame to coder encoded bytes. +func (f *HandshakeAckFrame) Encode() []byte { + ack := coder.NewNodePacketEncoder(byte(f.Type())) + // streamId + streamIDBlock := coder.NewPrimitivePacketEncoder(byte(TagOfHandshakeAckStreamId)) + streamIDBlock.SetStringValue(f.streamId) + + ack.AddPrimitivePacket(streamIDBlock) + + return ack.Encode() +} + +// DecodeToHandshakeAckFrame decodes coder encoded bytes to HandshakeAckFrame +func DecodeToHandshakeAckFrame(buf []byte) (*HandshakeAckFrame, error) { + node := coder.NodePacket{} + _, err := coder.DecodeToNodePacket(buf, &node) + if err != nil { + return nil, err + } + + ack := &HandshakeAckFrame{} + // streamID + if streamIDBlock, ok := node.PrimitivePackets[byte(TagOfHandshakeAckStreamId)]; ok { + streamId, err := streamIDBlock.ToUTF8String() + if err != nil { + return nil, err + } + ack.streamId = streamId + } + return ack, nil +} diff --git a/frame/handshake_ack_frame_test.go b/frame/handshake_ack_frame_test.go new file mode 100644 index 0000000..1761d61 --- /dev/null +++ b/frame/handshake_ack_frame_test.go @@ -0,0 +1,25 @@ +package frame + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +var handShakeAckTestBuf = []byte{0x80 | byte(TagOfHandshakeAckFrame), 0x8, 0x28, 0x6, 0x74, 0x68, 0x65, 0x2d, 0x69, 0x64} + +var testStreamID = "the-id" + +func TestHandshakeAckFrameEncode(t *testing.T) { + f := NewHandshakeAckFrame(testStreamID) + assert.Equal(t, TagOfHandshakeAckFrame, f.Type()) + assert.Equal(t, handShakeAckTestBuf, f.Encode()) +} + +func TestHandshakeAckFrameDecode(t *testing.T) { + f, err := DecodeToHandshakeAckFrame(handShakeAckTestBuf) + assert.NoError(t, err) + assert.Equal(t, TagOfHandshakeAckFrame, f.Type()) + assert.Equal(t, testStreamID, f.StreamId()) + assert.Equal(t, handShakeAckTestBuf, f.Encode()) +} diff --git a/frame/handshake_frame.go b/frame/handshake_frame.go index 4b4cda6..544f59d 100644 --- a/frame/handshake_frame.go +++ b/frame/handshake_frame.go @@ -1,68 +1,82 @@ package frame import ( + "encoding/binary" coder "git.hpds.cc/Component/mq_coder" ) // HandshakeFrame is a coder encoded. type HandshakeFrame struct { - // Name is client name - Name string + // name is client name + name string // ClientId represents client id - ClientId string + id string // ClientType represents client type (Protocol Gateway | Stream Function) - ClientType byte + streamType byte // ObserveDataTags are the client data tag list. - ObserveDataTags []byte - // auth - authName string - authPayload string + observeDataTags []Tag + metadata []byte } // NewHandshakeFrame creates a new HandshakeFrame. -func NewHandshakeFrame(name string, clientId string, clientType byte, observeDataTags []byte, authName string, authPayload string) *HandshakeFrame { +func NewHandshakeFrame(name string, id string, stream byte, observeDataTags []Tag, metadata []byte) *HandshakeFrame { return &HandshakeFrame{ - Name: name, - ClientId: clientId, - ClientType: clientType, - ObserveDataTags: observeDataTags, - authName: authName, - authPayload: authPayload, + name: name, + id: id, + streamType: stream, + observeDataTags: observeDataTags, + metadata: metadata, } } -// Type gets the type of Frame. -func (h *HandshakeFrame) Type() Type { - return TagOfHandshakeFrame -} +// Name is the name of dataStream. +func (h *HandshakeFrame) Name() string { return h.name } + +// ID represents the dataStream ID, the ID must be a unique string. +func (h *HandshakeFrame) ID() string { return h.id } + +// StreamType represents dataStream type (Source | SFN | UpstreamEmitter). +// different StreamType has different behaviors in server side. +func (h *HandshakeFrame) StreamType() byte { return h.streamType } + +// ObserveDataTags are the stream data tag list. +func (h *HandshakeFrame) ObserveDataTags() []Tag { return h.observeDataTags } + +// Metadata holds stream metadata, +// metadata stores information for route the data. +func (h *HandshakeFrame) Metadata() []byte { return h.metadata } + +// Type returns the type of HandshakeFrame. +func (h *HandshakeFrame) Type() Type { return TagOfHandshakeFrame } // Encode to coder encoding. func (h *HandshakeFrame) Encode() []byte { // name nameBlock := coder.NewPrimitivePacketEncoder(byte(TagOfHandshakeName)) - nameBlock.SetStringValue(h.Name) - // client ID + nameBlock.SetStringValue(h.name) + // ID idBlock := coder.NewPrimitivePacketEncoder(byte(TagOfHandshakeId)) - idBlock.SetStringValue(h.ClientId) - // client type - typeBlock := coder.NewPrimitivePacketEncoder(byte(TagOfHandshakeType)) - typeBlock.SetBytesValue([]byte{h.ClientType}) + idBlock.SetStringValue(h.id) + // stream type + typeBlock := coder.NewPrimitivePacketEncoder(byte(TagOfHandshakeStreamType)) + typeBlock.SetBytesValue([]byte{h.streamType}) // observe data tags observeDataTagsBlock := coder.NewPrimitivePacketEncoder(byte(TagOfHandshakeObserveDataTags)) - observeDataTagsBlock.SetBytesValue(h.ObserveDataTags) - // auth - authNameBlock := coder.NewPrimitivePacketEncoder(byte(TagOfHandshakeAuthName)) - authNameBlock.SetStringValue(h.authName) - authPayloadBlock := coder.NewPrimitivePacketEncoder(byte(TagOfHandshakeAuthPayload)) - authPayloadBlock.SetStringValue(h.authPayload) + buf := make([]byte, 4) + for _, v := range h.observeDataTags { + binary.LittleEndian.PutUint32(buf, uint32(v)) + observeDataTagsBlock.AddBytes(buf) + } + // metadata + metadataBlock := coder.NewPrimitivePacketEncoder(byte(TagOfHandshakeMetadata)) + metadataBlock.SetBytesValue(h.metadata) // handshake frame handshake := coder.NewNodePacketEncoder(byte(h.Type())) handshake.AddPrimitivePacket(nameBlock) handshake.AddPrimitivePacket(idBlock) handshake.AddPrimitivePacket(typeBlock) handshake.AddPrimitivePacket(observeDataTagsBlock) - handshake.AddPrimitivePacket(authNameBlock) - handshake.AddPrimitivePacket(authPayloadBlock) + handshake.AddPrimitivePacket(metadataBlock) return handshake.Encode() } @@ -82,7 +96,7 @@ func DecodeToHandshakeFrame(buf []byte) (*HandshakeFrame, error) { if err != nil { return nil, err } - handshake.Name = name + handshake.name = name } // client id if idBlock, ok := node.PrimitivePackets[byte(TagOfHandshakeId)]; ok { @@ -90,42 +104,27 @@ func DecodeToHandshakeFrame(buf []byte) (*HandshakeFrame, error) { if err != nil { return nil, err } - handshake.ClientId = id + handshake.id = id } // client type - if typeBlock, ok := node.PrimitivePackets[byte(TagOfHandshakeType)]; ok { - clientType := typeBlock.ToBytes() - handshake.ClientType = clientType[0] + if typeBlock, ok := node.PrimitivePackets[byte(TagOfHandshakeStreamType)]; ok { + streamType := typeBlock.ToBytes() + handshake.streamType = streamType[0] } // observe data tag list if observeDataTagsBlock, ok := node.PrimitivePackets[byte(TagOfHandshakeObserveDataTags)]; ok { - handshake.ObserveDataTags = observeDataTagsBlock.ToBytes() - } - // auth - if authNameBlock, ok := node.PrimitivePackets[byte(TagOfHandshakeAuthName)]; ok { - authName, err := authNameBlock.ToUTF8String() - if err != nil { - return nil, err + buf := observeDataTagsBlock.GetValBuf() + length := len(buf) / 4 + for i := 0; i < length; i++ { + pos := i * 4 + handshake.observeDataTags = append(handshake.observeDataTags, Tag(binary.LittleEndian.Uint32(buf[pos:pos+4]))) } - handshake.authName = authName } - if authPayloadBlock, ok := node.PrimitivePackets[byte(TagOfHandshakeAuthPayload)]; ok { - authPayload, err := authPayloadBlock.ToUTF8String() - if err != nil { - return nil, err - } - handshake.authPayload = authPayload + // metadata + if typeBlock, ok := node.PrimitivePackets[byte(TagOfHandshakeMetadata)]; ok { + metadata := typeBlock.ToBytes() + handshake.metadata = metadata } return handshake, nil } - -// AuthPayload authentication payload -func (h *HandshakeFrame) AuthPayload() string { - return h.authPayload -} - -// AuthName authentication name -func (h *HandshakeFrame) AuthName() string { - return h.authName -} diff --git a/frame/handshake_frame_test.go b/frame/handshake_frame_test.go index b8c5d49..8818217 100644 --- a/frame/handshake_frame_test.go +++ b/frame/handshake_frame_test.go @@ -6,25 +6,25 @@ import ( "github.com/stretchr/testify/assert" ) -func TestHandshakeFrameEncode(t *testing.T) { - expectedName := "1234" - var expectedType byte = 0xD3 - m := NewHandshakeFrame(expectedName, "", expectedType, []byte{0x01, 0x02}, "token", "a") - assert.Equal(t, []byte{ - 0x80 | byte(TagOfHandshakeFrame), 0x19, - byte(TagOfHandshakeName), 0x04, 0x31, 0x32, 0x33, 0x34, - byte(TagOfHandshakeId), 0x0, - byte(TagOfHandshakeType), 0x01, 0xD3, - byte(TagOfHandshakeObserveDataTags), 0x02, 0x01, 0x02, - // byte(TagOfHandshakeAppID), 0x0, - byte(TagOfHandshakeAuthName), 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e, - byte(TagOfHandshakeAuthPayload), 0x01, 0x61, - }, - m.Encode(), +func TestHandshakeFrame(t *testing.T) { + var ( + name = "hpds" + id = "sdfsdfsd" + streamType = byte(0x5F) + observeDataTags = []Tag{'a', 'b', 'c'} + metadata = []byte{'d', 'e', 'f'} ) - Handshake, err := DecodeToHandshakeFrame(m.Encode()) + f := NewHandshakeFrame(name, id, streamType, observeDataTags, metadata) + + buf := f.Encode() + got, err := DecodeToHandshakeFrame(buf) + assert.NoError(t, err) - assert.EqualValues(t, expectedName, Handshake.Name) - assert.EqualValues(t, expectedType, Handshake.ClientType) + + assert.Equal(t, name, got.Name()) + assert.Equal(t, id, got.ID()) + assert.Equal(t, streamType, got.StreamType()) + assert.Equal(t, observeDataTags, got.ObserveDataTags()) + assert.Equal(t, metadata, got.Metadata()) } diff --git a/frame/meta_frame.go b/frame/meta_frame.go index e9beb32..8c9bee5 100644 --- a/frame/meta_frame.go +++ b/frame/meta_frame.go @@ -11,9 +11,19 @@ import ( // MetaFrame is a coder encoded bytes, SeqId is a fixed value of TYPE_ID_TRANSACTION. // used for describes metadata for a DataFrame. type MetaFrame struct { - tid string - metadata []byte - sourceId string + tid string + metadata []byte + sourceId string + broadcast bool +} + +// randString genetates a random string. +func randString() string { + tid, err := gonanoid.New() + if err != nil { + tid = strconv.FormatInt(time.Now().UnixMicro(), 10) + } + return tid } // NewMetaFrame creates a new MetaFrame instance. @@ -55,15 +65,25 @@ func (m *MetaFrame) SourceId() string { return m.sourceId } +// SetBroadcast set broadcast mode +func (m *MetaFrame) SetBroadcast(enabled bool) { + m.broadcast = enabled +} + +// IsBroadcast returns the broadcast mode is enabled +func (m *MetaFrame) IsBroadcast() bool { + return m.broadcast +} + // Encode implements Frame.Encode method. func (m *MetaFrame) Encode() []byte { meta := coder.NewNodePacketEncoder(byte(TagOfMetaFrame)) - // transaction ID + // transaction Id transactionId := coder.NewPrimitivePacketEncoder(byte(TagOfTransactionId)) transactionId.SetStringValue(m.tid) meta.AddPrimitivePacket(transactionId) - // source ID + // source Id sourceId := coder.NewPrimitivePacketEncoder(byte(TagOfSourceId)) sourceId.SetStringValue(m.sourceId) meta.AddPrimitivePacket(sourceId) @@ -74,40 +94,47 @@ func (m *MetaFrame) Encode() []byte { metadata.SetBytesValue(m.metadata) meta.AddPrimitivePacket(metadata) } + // broadcast mode + broadcast := coder.NewPrimitivePacketEncoder(byte(TagOfBroadcast)) + broadcast.SetBoolValue(m.broadcast) + meta.AddPrimitivePacket(broadcast) return meta.Encode() } // DecodeToMetaFrame decode a MetaFrame instance from given buffer. -func DecodeToMetaFrame(buf []byte) (*MetaFrame, error) { +func DecodeToMetaFrame(buf []byte, meta *MetaFrame) error { nodeBlock := coder.NodePacket{} _, err := coder.DecodeToNodePacket(buf, &nodeBlock) if err != nil { - return nil, err + return err } - meta := &MetaFrame{} + //meta := &MetaFrame{} for k, v := range nodeBlock.PrimitivePackets { switch k { case byte(TagOfTransactionId): val, err := v.ToUTF8String() if err != nil { - return nil, err + return err } meta.tid = val - break case byte(TagOfMetadata): meta.metadata = v.ToBytes() - break case byte(TagOfSourceId): sourceId, err := v.ToUTF8String() if err != nil { - return nil, err + return err } meta.sourceId = sourceId - break + case byte(TagOfBroadcast): + broadcast, err := v.ToBool() + if err != nil { + return err + } + meta.broadcast = broadcast } } - return meta, nil + return nil } diff --git a/frame/meta_frame_test.go b/frame/meta_frame_test.go index 293c632..498186d 100644 --- a/frame/meta_frame_test.go +++ b/frame/meta_frame_test.go @@ -7,19 +7,22 @@ import ( ) func TestMetaFrameEncode(t *testing.T) { - m := NewMetaFrame() - tidbuf := []byte(m.tid) - result := []byte{0x80 | byte(TagOfMetaFrame), byte(1 + 1 + len(tidbuf) + 2), byte(TagOfTransactionId), byte(len(tidbuf))} - result = append(result, tidbuf...) + m := &MetaFrame{tid: randString()} + m.SetBroadcast(true) + tidBuf := []byte(m.tid) + result := []byte{0x80 | byte(TagOfMetaFrame), byte(1 + 1 + len(tidBuf) + 2 + 3), byte(TagOfTransactionId), byte(len(tidBuf))} + result = append(result, tidBuf...) result = append(result, byte(TagOfSourceId), 0x0) + result = append(result, byte(TagOfBroadcast), 0x1, 0x1) assert.Equal(t, result, m.Encode()) } func TestMetaFrameDecode(t *testing.T) { - buf := []byte{0x80 | byte(TagOfMetaFrame), 0x09, byte(TagOfTransactionId), 0x04, 0x31, 0x32, 0x33, 0x34, byte(TagOfSourceId), 0x01, 0x31} - meta, err := DecodeToMetaFrame(buf) + buf := []byte{0x80 | byte(TagOfMetaFrame), 0x0C, byte(TagOfTransactionId), 0x04, 0x31, 0x32, 0x33, 0x34, byte(TagOfSourceId), 0x01, 0x31, byte(TagOfBroadcast), 0x01, 0x01} + meta := &MetaFrame{} + err := DecodeToMetaFrame(buf, meta) assert.NoError(t, err) assert.EqualValues(t, "1234", meta.TransactionId()) assert.EqualValues(t, "1", meta.SourceId()) - t.Logf("%# x", buf) + assert.EqualValues(t, true, meta.IsBroadcast()) } diff --git a/frame/payload_frame.go b/frame/payload_frame.go index 2f1f43b..830abb1 100644 --- a/frame/payload_frame.go +++ b/frame/payload_frame.go @@ -4,20 +4,22 @@ import ( coder "git.hpds.cc/Component/mq_coder" ) +type Tag uint32 + // PayloadFrame is a coder encoded bytes, Tag is a fixed value TYPE_ID_PAYLOAD_FRAME // the Len is the length of Val. Val is also a coder encoded PrimitivePacket, storing // raw bytes as user's data type PayloadFrame struct { - Tag byte + Tag Tag Carriage []byte } -// NewPayloadFrame creates a new PayloadFrame with a given TagId of user's data -func NewPayloadFrame(tag byte) *PayloadFrame { - return &PayloadFrame{ - Tag: tag, - } -} +//// NewPayloadFrame creates a new PayloadFrame with a given TagId of user's data +//func NewPayloadFrame(tag byte) *PayloadFrame { +// return &PayloadFrame{ +// Tag: tag, +// } +//} // SetCarriage sets the user's raw data func (m *PayloadFrame) SetCarriage(buf []byte) *PayloadFrame { @@ -27,29 +29,38 @@ func (m *PayloadFrame) SetCarriage(buf []byte) *PayloadFrame { // Encode to coder encoded bytes func (m *PayloadFrame) Encode() []byte { - carriage := coder.NewPrimitivePacketEncoder(m.Tag) + tag := coder.NewPrimitivePacketEncoder(byte(TagOfPayloadDataTag)) + tag.SetUInt32Value(uint32(m.Tag)) + + carriage := coder.NewPrimitivePacketEncoder(byte(TagOfPayloadCarriage)) carriage.SetBytesValue(m.Carriage) payload := coder.NewNodePacketEncoder(byte(TagOfPayloadFrame)) + payload.AddPrimitivePacket(tag) payload.AddPrimitivePacket(carriage) return payload.Encode() } // DecodeToPayloadFrame decodes coder encoded bytes to PayloadFrame -func DecodeToPayloadFrame(buf []byte) (*PayloadFrame, error) { +func DecodeToPayloadFrame(buf []byte, payload *PayloadFrame) error { nodeBlock := coder.NodePacket{} _, err := coder.DecodeToNodePacket(buf, &nodeBlock) if err != nil { - return nil, err + return err } - payload := &PayloadFrame{} - for _, v := range nodeBlock.PrimitivePackets { - payload.Tag = v.SeqId() - payload.Carriage = v.GetValBuf() - break + if p, ok := nodeBlock.PrimitivePackets[byte(TagOfPayloadDataTag)]; ok { + tag, err := p.ToUInt32() + if err != nil { + return err + } + payload.Tag = Tag(tag) } - return payload, nil + if p, ok := nodeBlock.PrimitivePackets[byte(TagOfPayloadCarriage)]; ok { + payload.Carriage = p.GetValBuf() + } + + return nil } diff --git a/frame/payload_frame_test.go b/frame/payload_frame_test.go index 34b4bb7..e666506 100644 --- a/frame/payload_frame_test.go +++ b/frame/payload_frame_test.go @@ -7,14 +7,19 @@ import ( ) func TestPayloadFrameEncode(t *testing.T) { - f := NewPayloadFrame(0x13).SetCarriage([]byte("hpds")) - assert.Equal(t, []byte{0x80 | byte(TagOfPayloadFrame), 0x06, 0x13, 0x04, 0x68, 0x70, 0x64, 0x73}, f.Encode()) + f := &PayloadFrame{ + Tag(0x13), + []byte("yomo"), + } + f.SetCarriage([]byte("yomo")) + assert.Equal(t, []byte{0x80 | byte(TagOfPayloadFrame), 0x9, 0x1, 0x1, 0x13, 0x2, 0x04, 0x79, 0x6F, 0x6D, 0x6F}, f.Encode()) } func TestPayloadFrameDecode(t *testing.T) { - buf := []byte{0x80 | byte(TagOfPayloadFrame), 0x06, 0x13, 0x04, 0x68, 0x70, 0x64, 0x73} - payload, err := DecodeToPayloadFrame(buf) + buf := []byte{0x80 | byte(TagOfPayloadFrame), 0x9, 0x1, 0x1, 0x13, 0x2, 0x04, 0x79, 0x6F, 0x6D, 0x6F} + payload := new(PayloadFrame) + err := DecodeToPayloadFrame(buf, payload) assert.NoError(t, err) assert.EqualValues(t, 0x13, payload.Tag) - assert.Equal(t, []byte{0x68, 0x70, 0x64, 0x73}, payload.Carriage) + assert.Equal(t, []byte{0x79, 0x6F, 0x6D, 0x6F}, payload.Carriage) } diff --git a/frame_stream.go b/frame_stream.go index ed407d6..a1da2c6 100644 --- a/frame_stream.go +++ b/frame_stream.go @@ -8,6 +8,9 @@ import ( "git.hpds.cc/Component/network/frame" ) +// ErrStreamNil be returned if FrameStream underlying stream is nil. +var ErrStreamNil = errors.New("hpdsMq: frame stream underlying is nil") + // FrameStream is the QUIC Stream with the minimum unit Frame. type FrameStream struct { // Stream is a QUIC stream. @@ -16,27 +19,25 @@ type FrameStream struct { } // NewFrameStream creates a new FrameStream. -func NewFrameStream(s io.ReadWriter) *FrameStream { - return &FrameStream{ - stream: s, - mu: sync.Mutex{}, - } +func NewFrameStream(s io.ReadWriter) frame.ReadWriter { + return &FrameStream{stream: s} } // ReadFrame reads next frame from QUIC stream. func (fs *FrameStream) ReadFrame() (frame.Frame, error) { if fs.stream == nil { - return nil, errors.New("network.ReadStream: stream can not be nil") + return nil, ErrStreamNil } return ParseFrame(fs.stream) } // WriteFrame writes a frame into QUIC stream. -func (fs *FrameStream) WriteFrame(f frame.Frame) (int, error) { +func (fs *FrameStream) WriteFrame(frm frame.Frame) error { if fs.stream == nil { - return 0, errors.New("network.WriteFrame: stream can not be nil") + return ErrStreamNil } fs.mu.Lock() defer fs.mu.Unlock() - return fs.stream.Write(f.Encode()) + _, err := fs.stream.Write(frm.Encode()) + return err } diff --git a/go.mod b/go.mod index 0783c10..dad7534 100644 --- a/go.mod +++ b/go.mod @@ -4,8 +4,8 @@ go 1.19 require ( git.hpds.cc/Component/mq_coder v0.0.0-20221010064749-174ae7ae3340 - github.com/lucas-clemente/quic-go v0.29.1 github.com/matoous/go-nanoid/v2 v2.0.0 + github.com/quic-go/quic-go v0.33.0 github.com/stretchr/testify v1.8.0 go.uber.org/zap v1.23.0 gopkg.in/natefinch/lumberjack.v2 v2.0.0 @@ -14,24 +14,21 @@ require ( require ( github.com/BurntSushi/toml v1.2.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/fsnotify/fsnotify v1.4.9 // indirect github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect github.com/golang/mock v1.6.0 // indirect + github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect github.com/kr/pretty v0.3.1 // indirect - github.com/marten-seemann/qtls-go1-18 v0.1.2 // indirect - github.com/marten-seemann/qtls-go1-19 v0.1.0 // indirect - github.com/nxadm/tail v1.4.8 // indirect - github.com/onsi/ginkgo v1.16.4 // indirect + github.com/onsi/ginkgo/v2 v2.2.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/quic-go/qtls-go1-19 v0.2.1 // indirect + github.com/quic-go/qtls-go1-20 v0.1.1 // indirect go.uber.org/atomic v1.7.0 // indirect go.uber.org/multierr v1.6.0 // indirect - golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 // indirect - golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e // indirect - golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect - golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e // indirect - golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a // indirect - golang.org/x/tools v0.1.10 // indirect - golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect - gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect + golang.org/x/crypto v0.4.0 // indirect + golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect + golang.org/x/mod v0.6.0 // indirect + golang.org/x/net v0.4.0 // indirect + golang.org/x/sys v0.3.0 // indirect + golang.org/x/tools v0.2.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/hpds_err/errors.go b/hpds_err/errors.go index 4359c50..4e5f355 100644 --- a/hpds_err/errors.go +++ b/hpds_err/errors.go @@ -3,33 +3,44 @@ package hpds_err import ( "fmt" - quic "github.com/lucas-clemente/quic-go" + quic "github.com/quic-go/quic-go" ) // HpdsError hpds error -type HpdsError struct { +type HpdsError interface { + error + // ErrorCode getter method + ErrorCode() ErrorCode +} + +type hpdsError struct { errorCode ErrorCode err error } // New create hpds error -func New(code ErrorCode, err error) *HpdsError { - return &HpdsError{ +func New(code ErrorCode, err error) HpdsError { + return &hpdsError{ errorCode: code, err: err, } } -func (e *HpdsError) Error() string { +func (e *hpdsError) Error() string { return fmt.Sprintf("%s error: message=%s", e.errorCode, e.err.Error()) } +// ErrorCode getter method +func (e *hpdsError) ErrorCode() ErrorCode { + return e.errorCode +} + // ErrorCode error code type ErrorCode uint64 const ( // ErrorCodeClientAbort client abort - ErrorCodeClientAbort ErrorCode = 0x00 + ErrorCodeClientAbort ErrorCode = 0xC7 // ErrorCodeUnknown unknown error ErrorCodeUnknown ErrorCode = 0xC0 // ErrorCodeClosed net closed @@ -52,6 +63,7 @@ const ( ErrorCodeUnknownClient ErrorCode = 0xCD // ErrorCodeDuplicateName unknown client error ErrorCodeDuplicateName ErrorCode = 0xC6 + ErrorCodeStartHandler ErrorCode = 0xC8 ) func (e ErrorCode) String() string { @@ -80,24 +92,26 @@ func (e ErrorCode) String() string { return "UnknownClient" case ErrorCodeDuplicateName: return "DuplicateName" + case ErrorCodeStartHandler: + return "StartHandler" default: return "XXX" } } // Is parse quic ApplicationErrorCode to hpds ErrorCode -func Is(he quic.ApplicationErrorCode, yerr ErrorCode) bool { - return uint64(he) == uint64(yerr) +func Is(he quic.ApplicationErrorCode, err ErrorCode) bool { + return uint64(he) == uint64(err) } // Parse parse quic ApplicationErrorCode -func Parse(he quic.ApplicationErrorCode) ErrorCode { - return ErrorCode(he) +func Parse(err quic.ApplicationErrorCode) ErrorCode { + return ErrorCode(err) } // To convert hpds ErrorCode to quic ApplicationErrorCode -func To(code ErrorCode) quic.ApplicationErrorCode { - return quic.ApplicationErrorCode(code) +func (e ErrorCode) To() quic.ApplicationErrorCode { + return quic.ApplicationErrorCode(e) } // DuplicateNameError duplicate name(sfn) @@ -119,6 +133,11 @@ func (e DuplicateNameError) Error() string { return e.err.Error() } +// ErrorCode getter method +func (e DuplicateNameError) ErrorCode() ErrorCode { + return ErrorCodeDuplicateName +} + // ConnId duplicate connection ID func (e DuplicateNameError) ConnId() string { return e.connId diff --git a/listener.go b/listener.go index eeb4ff2..057b6e3 100644 --- a/listener.go +++ b/listener.go @@ -3,7 +3,7 @@ package network import ( "crypto/tls" "git.hpds.cc/Component/network/log" - "github.com/lucas-clemente/quic-go" + "github.com/quic-go/quic-go" "net" "time" @@ -43,7 +43,7 @@ func newListener(conn net.PacketConn, tlsConfig *tls.Config, quicConfig *quic.Co if tlsConfig == nil { tc, err := pkgtls.CreateServerTLSConfig(conn.LocalAddr().String()) if err != nil { - log.Errorf("%sCreateServerTLSConfig: %v", ServerLogPrefix, err) + log.Errorf("CreateServerTLSConfig: %v", err) return &defaultListener{}, err } tlsConfig = tc @@ -55,7 +55,7 @@ func newListener(conn net.PacketConn, tlsConfig *tls.Config, quicConfig *quic.Co quicListener, err := quic.Listen(conn, tlsConfig, quicConfig) if err != nil { - log.Errorf("%squic Listen: %v", ServerLogPrefix, err) + log.Errorf("quic Listen: %v", err) return &defaultListener{}, err } diff --git a/metadata.go b/metadata.go deleted file mode 100644 index 25eeea5..0000000 --- a/metadata.go +++ /dev/null @@ -1,17 +0,0 @@ -package network - -import "git.hpds.cc/Component/network/frame" - -// Metadata is used for storing extra info of the application -type Metadata interface { - // Encode is the serialize method - Encode() []byte -} - -// MetadataBuilder is the builder of Metadata -type MetadataBuilder interface { - // Build will return a Metadata instance according to the handshake frame passed in - Build(f *frame.HandshakeFrame) (Metadata, error) - // Decode is the deserialize method - Decode(buf []byte) (Metadata, error) -} diff --git a/metadata/default.go b/metadata/default.go new file mode 100644 index 0000000..1ef67d9 --- /dev/null +++ b/metadata/default.go @@ -0,0 +1,38 @@ +// Package metadata provides a default implements of `Metadata`. +package metadata + +import ( + "git.hpds.cc/Component/network/frame" +) + +var _ Metadata = &Default{} + +// Default returns an implement of `Metadata`, +// the default `Metadata` do not store anything. +type Default struct{} + +// Encode returns nil, It indicates the application do not have metadata. +func (m *Default) Encode() []byte { + return nil +} + +type defaultBuilder struct { + m *Default +} + +// DefaultBuilder returns an implement of `Builder`, +// the default builder only return default `Metadata`, the default `Metadata` +// do not store anything. +func DefaultBuilder() Builder { + return &defaultBuilder{ + m: &Default{}, + } +} + +func (builder *defaultBuilder) Build(f *frame.HandshakeFrame) (Metadata, error) { + return builder.m, nil +} + +func (builder *defaultBuilder) Decode(buf []byte) (Metadata, error) { + return builder.m, nil +} diff --git a/metadata/default_test.go b/metadata/default_test.go new file mode 100644 index 0000000..9106880 --- /dev/null +++ b/metadata/default_test.go @@ -0,0 +1,21 @@ +package metadata + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMetadata(t *testing.T) { + builder := DefaultBuilder() + + m, err := builder.Build(nil) + + assert.NoError(t, err) + assert.Equal(t, []uint8([]byte(nil)), m.Encode()) + + de, err := builder.Decode([]byte{}) + + assert.NoError(t, err) + assert.Equal(t, m, de) +} diff --git a/metadata/metadata.go b/metadata/metadata.go new file mode 100644 index 0000000..b0df6c4 --- /dev/null +++ b/metadata/metadata.go @@ -0,0 +1,21 @@ +// Package metadata defines `Metadata` and the `Builder`. +package metadata + +import "git.hpds.cc/Component/network/frame" + +// Metadata is used for storing extra info of the application. +type Metadata interface { + // Encode is the serialize method, + // That represents the Metadata can be transmitted. + Encode() []byte +} + +// Builder is the builder of Metadata. +// the metadata usually be built from `HandshakeFrame`, +// and It can be decode as byte array for io transmission. +type Builder interface { + // Build returns a Metadata instance according to the handshake frame passed in. + Build(f *frame.HandshakeFrame) (Metadata, error) + // Decode is the deserialize method + Decode(buf []byte) (Metadata, error) +} diff --git a/parser_stream.go b/parser_stream.go index 849dee7..0ceb826 100644 --- a/parser_stream.go +++ b/parser_stream.go @@ -15,16 +15,11 @@ func ParseFrame(stream io.Reader) (frame.Frame, error) { } frameType := buf[0] - // determine the frame type switch frameType { case 0x80 | byte(frame.TagOfHandshakeFrame): - handshakeFrame, err := readHandshakeFrame(buf) - // logger.Debugf("%sHandshakeFrame: name=%s, type=%s", ParseFrameLogPrefix, handshakeFrame.Name, handshakeFrame.Type()) - return handshakeFrame, err + return frame.DecodeToHandshakeFrame(buf) case 0x80 | byte(frame.TagOfDataFrame): - data, err := readDataFrame(buf) - // logger.Debugf("%sDataFrame: tid=%s, tag=%#x, len(carriage)=%d", ParseFrameLogPrefix, data.TransactionID(), data.GetDataTag(), len(data.GetCarriage())) - return data, err + return frame.DecodeToDataFrame(buf) case 0x80 | byte(frame.TagOfAcceptedFrame): return frame.DecodeToAcceptedFrame(buf) case 0x80 | byte(frame.TagOfRejectedFrame): @@ -33,15 +28,15 @@ func ParseFrame(stream io.Reader) (frame.Frame, error) { return frame.DecodeToGoawayFrame(buf) case 0x80 | byte(frame.TagOfBackFlowFrame): return frame.DecodeToBackFlowFrame(buf) + case 0x80 | byte(frame.TagOfHandshakeAckFrame): + return frame.DecodeToHandshakeAckFrame(buf) + case 0x80 | byte(frame.TagOfAuthenticationFrame): + return frame.DecodeToAuthenticationFrame(buf) + case 0x80 | byte(frame.TagOfAuthenticationAckFrame): + return frame.DecodeToAuthenticationRespFrame(buf) + case 0x80 | (byte(frame.TagOfCloseStreamFrame)): + return frame.DecodeToCloseStreamFrame(buf) default: return nil, fmt.Errorf("unknown frame type, buf[0]=%#x", buf[0]) } } - -func readHandshakeFrame(buf []byte) (*frame.HandshakeFrame, error) { - return frame.DecodeToHandshakeFrame(buf) -} - -func readDataFrame(buf []byte) (*frame.DataFrame, error) { - return frame.DecodeToDataFrame(buf) -} diff --git a/router/default.go b/router/default.go new file mode 100644 index 0000000..d43f79e --- /dev/null +++ b/router/default.go @@ -0,0 +1,112 @@ +// Package router providers a default implement of `router` and `Route`. +package router + +import ( + "fmt" + "sync" + + "git.hpds.cc/Component/network/frame" + herr "git.hpds.cc/Component/network/hpds_err" + "git.hpds.cc/Component/network/metadata" +) + +// DefaultRouter providers a default implement of `router`, +// It routes the data according to obverse tag or connId. +type DefaultRouter struct { + r *defaultRoute +} + +// Default return the DefaultRouter. +func Default(functions []string) Router { + return &DefaultRouter{r: newRoute(functions)} +} + +// Route get route from metadata. +func (r *DefaultRouter) Route(metadata metadata.Metadata) Route { + return r.r +} + +// Clean router. +func (r *DefaultRouter) Clean() { + r.r.mu.Lock() + defer r.r.mu.Unlock() + + for key := range r.r.data { + delete(r.r.data, key) + } +} + +type defaultRoute struct { + functions []string + data map[frame.Tag]map[string]string + mu sync.RWMutex +} + +func newRoute(functions []string) *defaultRoute { + return &defaultRoute{ + functions: functions, + data: make(map[frame.Tag]map[string]string), + } +} + +func (r *defaultRoute) Add(connId string, name string, observeDataTags []frame.Tag) (err error) { + r.mu.Lock() + defer r.mu.Unlock() + + ok := false + for _, v := range r.functions { + if v == name { + ok = true + break + } + } + if !ok { + return fmt.Errorf("SFN[%s] does not exist in config functions", name) + } + +LOOP: + for _, conn := range r.data { + for connId, n := range conn { + if n == name { + err = herr.NewDuplicateNameError(connId, fmt.Errorf("SFN[%s] is already linked to another connection", name)) + delete(conn, connId) + break LOOP + } + } + } + + for _, tag := range observeDataTags { + conn := r.data[tag] + if conn == nil { + conn = make(map[string]string) + r.data[tag] = conn + } + r.data[tag][connId] = name + } + + return err +} + +func (r *defaultRoute) Remove(connId string) error { + r.mu.Lock() + defer r.mu.Unlock() + + for _, conn := range r.data { + delete(conn, connId) + } + + return nil +} + +func (r *defaultRoute) GetForwardRoutes(tag frame.Tag) []string { + r.mu.RLock() + defer r.mu.RUnlock() + + var keys []string + if conn := r.data[tag]; conn != nil { + for k := range conn { + keys = append(keys, k) + } + } + return keys +} diff --git a/router/default_test.go b/router/default_test.go new file mode 100644 index 0000000..03d39b2 --- /dev/null +++ b/router/default_test.go @@ -0,0 +1,40 @@ +package router + +import ( + "testing" + + "git.hpds.cc/Component/network/frame" + "git.hpds.cc/Component/network/metadata" + "github.com/stretchr/testify/assert" +) + +func TestRouter(t *testing.T) { + router := Default([]string{"sfn-1"}) + + m := &metadata.Default{} + + route := router.Route(m) + + err := route.Add("conn-1", "sfn-1", []frame.Tag{frame.Tag(1)}) + assert.NoError(t, err) + + ids := route.GetForwardRoutes(frame.Tag(1)) + assert.Equal(t, []string{"conn-1"}, ids) + + err = route.Add("conn-2", "sfn-2", []frame.Tag{frame.Tag(2)}) + assert.EqualError(t, err, "SFN[sfn-2] does not exist in config functions") + + err = route.Add("conn-3", "sfn-1", []frame.Tag{frame.Tag(1)}) + assert.EqualError(t, err, "SFN[sfn-1] is already linked to another connection") + + err = route.Remove("conn-1") + assert.NoError(t, err) + + ids = route.GetForwardRoutes(frame.Tag(1)) + assert.Equal(t, []string{"conn-3"}, ids) + + router.Clean() + + ids = route.GetForwardRoutes(frame.Tag(1)) + assert.Equal(t, []string(nil), ids) +} diff --git a/router.go b/router/router.go similarity index 58% rename from router.go rename to router/router.go index 5cf79a5..3a4181d 100644 --- a/router.go +++ b/router/router.go @@ -1,9 +1,14 @@ -package network +package router + +import ( + "git.hpds.cc/Component/network/frame" + "git.hpds.cc/Component/network/metadata" +) // Router is the interface to manage the routes for applications. type Router interface { // Route gets the route - Route(metadata Metadata) Route + Route(metadata metadata.Metadata) Route // Clean the routes. Clean() } @@ -11,9 +16,9 @@ type Router interface { // Route manages data subscribers according to their observed data tags. type Route interface { // Add a route. - Add(connId string, name string, observeDataTags []byte) error + Add(connId string, name string, observeDataTags []frame.Tag) error // Remove a route. Remove(connId string) error // GetForwardRoutes returns all the subscribers by the given data tag. - GetForwardRoutes(tag byte) []string + GetForwardRoutes(tag frame.Tag) (streamIds []string) } diff --git a/server.go b/server.go index c06925a..4593290 100644 --- a/server.go +++ b/server.go @@ -4,6 +4,10 @@ import ( "context" "errors" "fmt" + "git.hpds.cc/Component/network/auth" + "git.hpds.cc/Component/network/metadata" + "git.hpds.cc/Component/network/router" + "github.com/quic-go/quic-go" "io" "net" "os" @@ -15,59 +19,47 @@ import ( "git.hpds.cc/Component/network/frame" "git.hpds.cc/Component/network/hpds_err" "git.hpds.cc/Component/network/log" - pkgtls "git.hpds.cc/Component/network/tls" - "github.com/lucas-clemente/quic-go" ) -const ( - // DefaultListenAddr is the default address to listen. - DefaultListenAddr = "0.0.0.0:9000" -) - -// ServerOption is the option for server. -type ServerOption func(*ServerOptions) - // FrameHandler is the handler for frame. type FrameHandler func(c *Context) error +// ConnectionHandler is the handler for quic connection +type ConnectionHandler func(conn quic.Connection) + // Server is the underlining server of Message Queue type Server struct { - name string - state string - connector Connector - router Router - metadataBuilder MetadataBuilder - counterOfDataFrame int64 - downStreams map[string]*Client - mu sync.Mutex - opts ServerOptions - beforeHandlers []FrameHandler - afterHandlers []FrameHandler + name string + connector *Connector + router router.Router + metadataBuilder metadata.Builder + counterOfDataFrame int64 + downStreams map[string]frame.Writer + mu sync.Mutex + opts *serverOptions + startHandlers []FrameHandler + beforeHandlers []FrameHandler + afterHandlers []FrameHandler + connectionCloseHandlers []ConnectionHandler + listener Listener } // NewServer create a Server instance. func NewServer(name string, opts ...ServerOption) *Server { + options := defaultServerOptions() + + for _, o := range opts { + o(options) + } s := &Server{ name: name, - connector: newConnector(), - downStreams: make(map[string]*Client), + downStreams: make(map[string]frame.Writer), + opts: options, } - _ = s.Init(opts...) return s } -// Init the options. -func (s *Server) Init(opts ...ServerOption) error { - for _, o := range opts { - o(&s.opts) - } - // options defaults - s.initOptions() - - return nil -} - // ListenAndServe starts the server. func (s *Server) ListenAndServe(ctx context.Context, addr string) error { if addr == "" { @@ -95,156 +87,192 @@ func (s *Server) Serve(ctx context.Context, conn net.PacketConn) error { } // listen the address - listener, err := newListener(conn, s.opts.TLSConfig, s.opts.QuicConfig) + listener, err := newListener(conn, s.opts.tlsConfig, s.opts.quicConfig) if err != nil { - log.Errorf("%slistener.Listen: err=%v", ServerLogPrefix, err) + log.Errorf("listener.Listen: err=%v", err) return err } - defer func() { - _ = listener.Close() - }() - log.Printf("%s [%s][%d] Listening on: %s, MODE: %s, QUIC: %v, AUTH: %s", ServerLogPrefix, s.name, os.Getpid(), listener.Addr(), mode(), listener.Versions(), s.authNames()) + s.listener = listener + + log.Printf("[%s][%d] Listening on: %s, QUIC: %v, AUTH: %s", s.name, os.Getpid(), listener.Addr(), listener.Versions(), s.authNames()) - s.state = ConnStateConnected for { - _ = s.createNewClientConnection(ctx, listener) + conn, err := s.listener.Accept(ctx) + if err != nil { + log.Errorf("listener accept connections error", err) + return err + } + err = s.opts.alpnHandler(conn.ConnectionState().TLS.NegotiatedProtocol) + if err != nil { + _ = conn.CloseWithError(quic.ApplicationErrorCode(hpds_err.ErrorCodeRejected), err.Error()) + continue + } + stream0, err := conn.AcceptStream(ctx) + if err != nil { + continue + } + + controlStream := NewServerControlStream(conn, NewFrameStream(stream0)) + + // Auth accepts a AuthenticationFrame from client. The first frame from client must be + // AuthenticationFrame, It returns true if auth successful otherwise return false. + // It response to client a AuthenticationAckFrame. + err = controlStream.VerifyAuthentication(s.handleAuthenticationFrame) + if err != nil { + log.Warnf("Authentication Failed", "error", err) + continue + } + log.Debugf("Authentication Success") + + go func(qConn quic.Connection) { + streamGroup := NewStreamGroup(ctx, controlStream, s.connector) + + defer streamGroup.Wait() + defer s.doConnectionCloseHandlers(qConn) + + select { + case <-ctx.Done(): + return + case err := <-s.runWithStreamGroup(streamGroup): + log.Errorf("Client Close, %v", err) + } + }(conn) } } -// createNewClientConnection create a new connection when new hpds-client connected -func (s *Server) createNewClientConnection(ctx context.Context, listener Listener) error { - sctx, cancel := context.WithCancel(ctx) - defer cancel() +func (s *Server) runWithStreamGroup(group *StreamGroup) <-chan error { + errCh := make(chan error) - connect, e := listener.Accept(sctx) - if e != nil { - log.Errorf("%screate connection error: %v", ServerLogPrefix, e) - return e - } - - connId := GetConnId(connect) - log.Infof("%s1/ new connection: %s", ServerLogPrefix, connId) - - go func(ctx context.Context, qConn quic.Connection) { - for { - err := s.handle(ctx, qConn, connId) - if err != nil { - break - } - } - }(sctx, connect) - return nil -} - -func (s *Server) handle(ctx context.Context, qConn quic.Connection, connId string) error { - log.Infof("%s2/ waiting for new stream", ServerLogPrefix) - stream, err := qConn.AcceptStream(ctx) - if err != nil { - name := "--" - conn := s.connector.Get(connId) - if conn != nil { - _ = conn.Close() - // connector - s.connector.Remove(connId) - route := s.router.Route(conn.Metadata()) - if route != nil { - _ = route.Remove(connId) - } - name = conn.Name() - } else { - _ = s.Close() - - } - log.Printf("%s [%s](%s) close the connection: %v", ServerLogPrefix, name, connId, err) - return err - } - defer func() { - _ = stream.Close() + go func() { + errCh <- group.Run(s.handleStreamContext) }() - log.Infof("%s3/ [stream:%d] created, connId=%s", ServerLogPrefix, stream.StreamID(), connId) - // process frames on stream - // c := newContext(connId, stream) - c := newContext(qConn, stream) - defer c.Clean() - s.handleConnection(c) - log.Infof("%s4/ [stream:%d] handleConnection DONE", ServerLogPrefix, stream.StreamID()) - return nil + return errCh } // Close will shut down the server. func (s *Server) Close() error { - if s.router != nil { - s.router.Clean() - } // connector if s.connector != nil { - s.connector.Clean() + s.connector.Close() + } + // listener + if s.listener != nil { + _ = s.listener.Close() + } + // router + if s.router != nil { + s.router.Clean() } return nil } -// handle streams on a connection -func (s *Server) handleConnection(c *Context) { - fs := NewFrameStream(c.Stream) +func (s *Server) handleRoute(c *Context) error { + if c.DataStream.StreamType() == StreamTypeStreamFunction { + md, err := s.metadataBuilder.Decode(c.DataStream.Metadata()) + if err != nil { + return err + } + // route + route := s.router.Route(md) + if route == nil { + return errors.New("handleHandshakeFrame route is nil") + } + if err := route.Add(c.StreamId(), c.DataStream.Name(), c.DataStream.ObserveDataTags()); err != nil { + // duplicate name + if e, ok := err.(hpds_err.DuplicateNameError); ok { + existsConnId := e.ConnId() + + log.Debugf("StreamFunction Duplicate Name, error: %s; sfn_name: %s, old_stream_id: %s; current_stream_id: %s", + e.Error(), c.DataStream.Name(), existsConnId, c.StreamId()) + + stream, ok, err := s.connector.Get(existsConnId) + if err != nil { + return err + } + if ok { + _ = stream.CloseWithError(e.Error()) + _ = s.connector.Remove(existsConnId) + } + } else { + return err + } + } + } + return nil +} + +// handleStreamContext handles data streams, +func (s *Server) handleStreamContext(c *Context) { + // handle route. + if err := s.handleRoute(c); err != nil { + c.CloseWithError(hpds_err.ErrorCodeRejected, err.Error()) + return + } + defer s.cleanRoute(c) + + // start frame handlers + for _, handler := range s.startHandlers { + if err := handler(c); err != nil { + log.Errorf("startHandlers error: %v", err) + c.CloseWithError(hpds_err.ErrorCodeStartHandler, err.Error()) + return + } + } + // check update for stream for { - log.Debugf("%shandleConnection waiting read next...", ServerLogPrefix) - f, err := fs.ReadFrame() + f, err := c.DataStream.ReadFrame() if err != nil { // if client close connection, will get ApplicationError with code = 0x00 if e, ok := err.(*quic.ApplicationError); ok { if hpds_err.Is(e.ErrorCode, hpds_err.ErrorCodeClientAbort) { // client abort - log.Infof("%sclient close the connection", ServerLogPrefix) + log.Infof("client close the connection") break - } else { - ye := hpds_err.New(hpds_err.Parse(e.ErrorCode), err) - log.Errorf("%s[ERR] %s", ServerLogPrefix, ye) } + he := hpds_err.New(hpds_err.Parse(e.ErrorCode), err) + log.Errorf("read frame error: %v", he) } else if err == io.EOF { - log.Infof("%sthe connection is EOF", ServerLogPrefix) + log.Infof("connection EOF") break } if errors.Is(err, net.ErrClosed) { - // if client close the connection, net.ErrClosed will be raised + // if client close the connection, net.ErrClosed will be raise // by quic-go IdleTimeoutError after connection's KeepAlive config. - log.Warnf("%s[ERR] net.ErrClosed on [handleConnection] %v", ServerLogPrefix, net.ErrClosed) + log.Warnf("connection error, error: %v", net.ErrClosed) c.CloseWithError(hpds_err.ErrorCodeClosed, "net.ErrClosed") break } // any error occurred, we should close the stream // after this, conn.AcceptStream() will raise the error c.CloseWithError(hpds_err.ErrorCodeUnknown, err.Error()) - log.Warnf("%sconnection.Close()", ServerLogPrefix) + log.Warnf("connection close") break } - frameType := f.Type() - data := f.Encode() - log.Debugf("%stype=%s, frame[%d]=%# x", ServerLogPrefix, frameType, len(data), frame.Shortly(data)) - // add frame to contextFrame - contextFrame := c.WithFrame(f) + // add frame to context + c.WithFrame(f) // before frame handlers for _, handler := range s.beforeHandlers { - if e := handler(contextFrame); e != nil { - log.Errorf("%safterFrameHandler e: %s", ServerLogPrefix, e) - contextFrame.CloseWithError(hpds_err.ErrorCodeBeforeHandler, e.Error()) + if err := handler(c); err != nil { + log.Errorf("beforeFrameHandler error: %v", err) + c.CloseWithError(hpds_err.ErrorCodeBeforeHandler, err.Error()) return } } // main handler - if e := s.mainFrameHandler(contextFrame); e != nil { - log.Errorf("%smainFrameHandler e: %s", ServerLogPrefix, e) - contextFrame.CloseWithError(hpds_err.ErrorCodeMainHandler, e.Error()) + if err := s.mainFrameHandler(c); err != nil { + log.Errorf("mainFrameHandler error: %v", err) + c.CloseWithError(hpds_err.ErrorCodeMainHandler, err.Error()) return } // after frame handler for _, handler := range s.afterHandlers { - if e := handler(contextFrame); e != nil { - log.Errorf("%safterFrameHandler e: %s", ServerLogPrefix, e) - contextFrame.CloseWithError(hpds_err.ErrorCodeAfterHandler, e.Error()) + if err := handler(c); err != nil { + log.Errorf("afterFrameHandler error: %v", err) + c.CloseWithError(hpds_err.ErrorCodeAfterHandler, err.Error()) return } } @@ -256,181 +284,93 @@ func (s *Server) mainFrameHandler(c *Context) error { frameType := c.Frame.Type() switch frameType { - case frame.TagOfHandshakeFrame: - if err = s.handleHandshakeFrame(c); err != nil { - log.Errorf("%shandleHandshakeFrame err: %s", ServerLogPrefix, err) - // close connections early to avoid resource consumption - if c.Stream != nil { - goawayFrame := frame.NewGoawayFrame(err.Error()) - if _, e := c.Stream.Write(goawayFrame.Encode()); e != nil { - log.Errorf("%s write to client[%s] GoawayFrame error:%v", ServerLogPrefix, c.ConnId, e) - return e - } - } - } - // case frame.TagOfPingFrame: - // s.handlePingFrame(mainStream, connection, f.(*frame.PingFrame)) case frame.TagOfDataFrame: if err = s.handleDataFrame(c); err != nil { c.CloseWithError(hpds_err.ErrorCodeData, fmt.Sprintf("handleDataFrame err: %v", err)) } else { - conn := s.connector.Get(c.connId) - if conn != nil && conn.ClientType() == ClientTypeProtocolGateway { - f := c.Frame.(*frame.DataFrame) - f.GetMetaFrame().SetMetadata(conn.Metadata().Encode()) - s.dispatchToDownStreams(f) - } + s.dispatchToDownStreams(c) // observe data tags back flow _ = s.handleBackFlowFrame(c) } default: - log.Errorf("%serr=%v, frame=%v", ServerLogPrefix, err, frame.Shortly(c.Frame.Encode())) + log.Errorf("err=%v, frame=%v", err, frame.Shortly(c.Frame.Encode())) } return nil } -// handle HandShakeFrame -func (s *Server) handleHandshakeFrame(c *Context) error { - f := c.Frame.(*frame.HandshakeFrame) +func (s *Server) handleAuthenticationFrame(f auth.Object) (bool, error) { + ok := auth.Authenticate(s.opts.auths, f) - log.Debugf("%sGOT HandshakeFrame : %# x", ServerLogPrefix, f) - // basic info - connId := c.ConnId() - clientId := f.ClientId - clientType := ClientType(f.ClientType) - stream := c.Stream - // credential - log.Debugf("%sClientType=%# x is %s, ClientId=%s, Credential=%s", ServerLogPrefix, f.ClientType, ClientType(f.ClientType), clientId, authName(f.AuthName())) - // authenticate - if !s.authenticate(f) { - err := fmt.Errorf("handshake authentication fails, client credential name is %s", authName(f.AuthName())) - // return err - log.Debugf("%s <%s> [%s](%s) is connected!", ServerLogPrefix, clientType, f.Name, connId) - rejectedFrame := frame.NewRejectedFrame(err.Error()) - if _, err = stream.Write(rejectedFrame.Encode()); err != nil { - log.Debugf("%s write to <%s> [%s](%s) RejectedFrame error:%v", ServerLogPrefix, clientType, f.Name, connId, err) - return err - } - return nil + if ok { + log.Debugf("Successful authentication") + } else { + log.Warnf("Authentication failed", "credential", f.AuthName()) } - // client type - var conn Connection - switch clientType { - case ClientTypeProtocolGateway, ClientTypeStreamFunction: - // metadata - metadata, err := s.metadataBuilder.Build(f) - if err != nil { - return err - } - conn = newConnection(f.Name, f.ClientId, clientType, metadata, stream, f.ObserveDataTags) - - if clientType == ClientTypeStreamFunction { - // route - route := s.router.Route(metadata) - if route == nil { - return errors.New("handleHandshakeFrame route is nil") - } - if e1 := route.Add(connId, f.Name, f.ObserveDataTags); e1 != nil { - // duplicate name - if e2, ok := e1.(hpds_err.DuplicateNameError); ok { - existsConnID := e2.ConnId() - if conn = s.connector.Get(existsConnID); conn != nil { - log.Debugf("%s%s, write to SFN[%s](%s) GoawayFrame", ServerLogPrefix, e2.Error(), f.Name, existsConnID) - goawayFrame := frame.NewGoawayFrame(e2.Error()) - if e3 := conn.Write(goawayFrame); e3 != nil { - log.Errorf("%s write to SFN[%s] GoawayFrame error:%v", ServerLogPrefix, f.Name, e3) - return e3 - } - } - } else { - return e1 - } - } - } - case ClientTypeMessageQueue: - conn = newConnection(f.Name, f.ClientId, clientType, nil, stream, f.ObserveDataTags) - default: - // unknown client type - s.connector.Remove(connId) - err := fmt.Errorf("Illegal ClientType: %#x ", f.ClientType) - c.CloseWithError(hpds_err.ErrorCodeUnknownClient, err.Error()) - return err - } - - s.connector.Add(connId, conn) - log.Printf("%s <%s> [%s][%s](%s) is connected!", ServerLogPrefix, clientType, f.Name, clientId, connId) - return nil + return ok, nil } - -// handle handleGoawayFrame -func (s *Server) handleGoawayFrame(c *Context) error { - f := c.Frame.(*frame.GoawayFrame) - - log.Debugf("%s GOT GoawayFrame code=%d, message==%s", ServerLogPrefix, hpds_err.ErrorCodeGoaway, f.Message()) - // c.CloseWithError(f.Code(), f.Message()) - _, err := c.Stream.Write(f.Encode()) - return err -} - -// will reuse quic-go's keep-alive feature -// func (s *Server) handlePingFrame(stream quic.Stream, conn quic.Connection, f *frame.PingFrame) error { -// log.Infof("%s------> GOT PingFrame : %# x", ServerLogPrefix, f) -// return nil -// } - func (s *Server) handleDataFrame(c *Context) error { // counter +1 atomic.AddInt64(&s.counterOfDataFrame, 1) // currentIssuer := f.GetIssuer() - fromId := c.ConnId() - from := s.connector.Get(fromId) - if from == nil { - log.Warnf("%shandleDataFrame connector cannot find %s", ServerLogPrefix, fromId) + fromId := c.StreamId() + from, ok, err := s.connector.Get(fromId) + if err != nil { + return err + } + if !ok { + log.Warnf("handleDataFrame connector cannot find, from_conn_id: %s", fromId) return fmt.Errorf("handleDataFrame connector cannot find %s", fromId) } f := c.Frame.(*frame.DataFrame) - var metadata Metadata - if from.ClientType() == ClientTypeMessageQueue { - m, err := s.metadataBuilder.Decode(f.GetMetaFrame().Metadata()) + m, err := s.metadataBuilder.Decode(f.GetMetaFrame().Metadata()) + if err != nil { + return err + } + + if m == nil { + m, err = s.metadataBuilder.Decode(from.Metadata()) if err != nil { return err } - metadata = m - } else { - metadata = from.Metadata() } // route - route := s.router.Route(metadata) + route := s.router.Route(m) if route == nil { - log.Warnf("%shandleDataFrame route is nil", ServerLogPrefix) + log.Warnf("handleDataFrame route is nil") return fmt.Errorf("handleDataFrame route is nil") } // get stream function connection ids from route - connIds := route.GetForwardRoutes(f.GetDataTag()) - for _, toId := range connIds { - conn := s.connector.Get(toId) - if conn == nil { - log.Errorf("%sconn is nil: (%s)", ServerLogPrefix, toId) + connIDs := route.GetForwardRoutes(f.GetDataTag()) + + log.Debugf("Data Routing Status, sfn_stream_ids: %v; connector: %v", connIDs, s.connector.GetSnapshot()) + + for _, toId := range connIDs { + conn, ok, err := s.connector.Get(toId) + if err != nil { + continue + } + if !ok { + log.Errorf("Can't find forward conn, error: conn is nil ;forward_conn_id: ", toId) continue } to := conn.Name() - log.Debugf("%shandleDataFrame tag=%#x tid=%s, counter=%d, from=[%s](%s), to=[%s](%s)", ServerLogPrefix, f.Tag(), f.TransactionId(), s.counterOfDataFrame, from.Name(), fromId, to, toId) + log.Infof("handleDataFrame, from_conn_name: %s; from_conn_id: %s; to_conn_name: %s; to_conn_id: %s; data_frame: %s", + from.Name(), fromId, to, toId, f.String()) // write data frame to stream - if err := conn.Write(f); err != nil { - log.Warnf("%shandleDataFrame conn.Write tag=%#x tid=%s, from=[%s](%s), to=[%s](%s), %v", ServerLogPrefix, f.Tag(), f.TransactionId(), from.Name(), fromId, to, toId, err) + if err := conn.WriteFrame(f); err != nil { + log.Errorf("handleDataFrame conn.Write, %v", err) } } return nil } - func (s *Server) handleBackFlowFrame(c *Context) error { f := c.Frame.(*frame.DataFrame) tag := f.GetDataTag() @@ -438,14 +378,17 @@ func (s *Server) handleBackFlowFrame(c *Context) error { sourceId := f.SourceId() // write to Protocol Gateway with BackFlowFrame bf := frame.NewBackFlowFrame(tag, carriage) - sourceConns := s.connector.GetProtocolGatewayConnections(sourceId, tag) + sourceConnList, err := s.connector.GetSourceConns(sourceId, tag) + if err != nil { + return err + } // conn := s.connector.Get(c.connId) // logger.Printf("%s handleBackFlowFrame tag:%#v --> source:%s, result=%s", ServerLogPrefix, tag, sourceId, carriage) - for _, source := range sourceConns { + for _, source := range sourceConnList { if source != nil { - log.Debugf("%s handleBackFlowFrame tag:%#v --> Protocol Gateway:%s, result=%# x", ServerLogPrefix, tag, sourceId, frame.Shortly(carriage)) - if err := source.Write(bf); err != nil { - log.Errorf("%s handleBackFlowFrame tag:%#v --> Protocol Gateway:%s, error=%v", ServerLogPrefix, tag, sourceId, err) + log.Debugf("handleBackFlowFrame tag:%#v; source_conn_id: %s, back_flow_frame: %s", tag, sourceId, f.String()) + if err := source.WriteFrame(bf); err != nil { + log.Errorf("handleBackFlowFrame tag:%#v --> Protocol Gateway:%s, error=%v", tag, sourceId, err) return err } } @@ -464,23 +407,32 @@ func (s *Server) StatsCounter() int64 { } // DownStreams return all the downstream servers. -func (s *Server) DownStreams() map[string]*Client { +func (s *Server) DownStreams() map[string]frame.Writer { return s.downStreams } // ConfigRouter is used to set router by Message Queue -func (s *Server) ConfigRouter(router Router) { +func (s *Server) ConfigRouter(router router.Router) { s.mu.Lock() s.router = router - log.Debugf("%sconfig router is %#v", ServerLogPrefix, router) + log.Debugf("config router is %#v", router) s.mu.Unlock() } // ConfigMetadataBuilder is used to set metadataBuilder by Message Queue -func (s *Server) ConfigMetadataBuilder(builder MetadataBuilder) { + +func (s *Server) ConfigMetadataBuilder(builder metadata.Builder) { s.mu.Lock() s.metadataBuilder = builder - log.Debugf("%sconfig metadataBuilder is %#v", ServerLogPrefix, builder) + log.Debugf("config metadataBuilder is %#v", builder) + s.mu.Unlock() +} + +// ConfigAlpnHandler is used to set alpnHandler by Emitter +func (s *Server) ConfigAlpnHandler(h func(string) error) { + s.mu.Lock() + s.opts.alpnHandler = h + log.Debugf("config alpnHandler") s.mu.Unlock() } @@ -493,10 +445,28 @@ func (s *Server) AddDownstreamServer(addr string, c *Client) { } // dispatch every DataFrames to all downStreams -func (s *Server) dispatchToDownStreams(df *frame.DataFrame) { - for addr, ds := range s.downStreams { - log.Debugf("%sdispatching to [%s]: %# x", ServerLogPrefix, addr, df.Tag()) - _ = ds.WriteFrame(df) +func (s *Server) dispatchToDownStreams(c *Context) { + stream, ok, err := s.connector.Get(c.StreamId()) + if err != nil { + log.Errorf("Connector Get Error, %v", err) + return + } + if !ok { + log.Debugf("dispatchTo Down Streams failed") + return + } + + if stream.StreamType() == StreamTypeSource { + f := c.Frame.(*frame.DataFrame) + if f.IsBroadcast() { + if f.GetMetaFrame().Metadata() == nil { + f.GetMetaFrame().SetMetadata(stream.Metadata()) + } + for addr, ds := range s.downStreams { + log.Infof("dispatching to, dispatch_addr: %s; tid: %s;", addr, "", f.TransactionId()) + _ = ds.WriteFrame(f) + } + } } } @@ -505,10 +475,6 @@ func GetConnId(conn quic.Connection) string { return conn.RemoteAddr().String() } -func (s *Server) initOptions() { - // defaults -} - func (s *Server) validateRouter() error { if s.router == nil { return errors.New("server's router is nil") @@ -523,14 +489,10 @@ func (s *Server) validateMetadataBuilder() error { return nil } -// Options returns the options of server. -func (s *Server) Options() ServerOptions { - return s.opts -} - -// Connector returns the connector of server. -func (s *Server) Connector() Connector { - return s.connector +// SetStartHandlers sets a function for operating connection, +// this function executes after handshake successful. +func (s *Server) SetStartHandlers(handlers ...FrameHandler) { + s.startHandlers = append(s.startHandlers, handlers...) } // SetBeforeHandlers set the before handlers of server. @@ -543,44 +505,30 @@ func (s *Server) SetAfterHandlers(handlers ...FrameHandler) { s.afterHandlers = append(s.afterHandlers, handlers...) } +// SetConnectionCloseHandlers set the connection close handlers of server. +func (s *Server) SetConnectionCloseHandlers(handlers ...ConnectionHandler) { + s.connectionCloseHandlers = append(s.connectionCloseHandlers, handlers...) +} + func (s *Server) authNames() []string { - if len(s.opts.Auths) == 0 { + if len(s.opts.auths) == 0 { return []string{"none"} } result := make([]string, 0) - for _, auth := range s.opts.Auths { - result = append(result, auth.Name()) + for _, a := range s.opts.auths { + result = append(result, a.Name()) } return result } -func (s *Server) authenticate(f *frame.HandshakeFrame) bool { - if len(s.opts.Auths) > 0 { - for _, auth := range s.opts.Auths { - if f.AuthName() == auth.Name() { - isAuthenticated := auth.Authenticate(f.AuthPayload()) - if isAuthenticated { - log.Debugf("%sauthenticated==%v", ServerLogPrefix, isAuthenticated) - return isAuthenticated - } - } - } - return false +func (s *Server) doConnectionCloseHandlers(qConn quic.Connection) { + log.Debugf("QUIC Connection Closed") + for _, h := range s.connectionCloseHandlers { + h(qConn) } - return true } -func mode() string { - if pkgtls.IsDev() { - return "DEVELOPMENT" - } - return "PRODUCTION" -} - -func authName(name string) string { - if name == "" { - return "empty" - } - - return name +func (s *Server) cleanRoute(c *Context) { + md, _ := s.metadataBuilder.Decode(c.DataStream.Metadata()) + _ = s.router.Route(md).Remove(c.StreamId()) } diff --git a/server_options.go b/server_options.go index 137a0f0..0fbdf29 100644 --- a/server_options.go +++ b/server_options.go @@ -2,55 +2,72 @@ package network import ( "crypto/tls" - "net" - "git.hpds.cc/Component/network/auth" - "github.com/lucas-clemente/quic-go" + "git.hpds.cc/Component/network/log" + "github.com/quic-go/quic-go" ) +const ( + // DefaultListenAddr is the default address to listen. + DefaultListenAddr = "0.0.0.0:9000" +) + +// ServerOption is the option for server. +type ServerOption func(*serverOptions) + // ServerOptions are the options for HPDS Network server. -type ServerOptions struct { - QuicConfig *quic.Config - TLSConfig *tls.Config - Addr string - Auths []auth.Authentication - Conn net.PacketConn +type serverOptions struct { + quicConfig *quic.Config + tlsConfig *tls.Config + addr string + auths map[string]auth.Authentication + alpnHandler func(proto string) error +} + +func defaultServerOptions() *serverOptions { + opts := &serverOptions{ + quicConfig: DefaultQuicConfig, + tlsConfig: nil, + addr: DefaultListenAddr, + auths: map[string]auth.Authentication{}, + } + opts.alpnHandler = func(proto string) error { + log.Infof("client alpn proto", "component", "server", "proto", proto) + return nil + } + return opts } // WithAddr sets the server address. func WithAddr(addr string) ServerOption { - return func(o *ServerOptions) { - o.Addr = addr + return func(o *serverOptions) { + o.addr = addr } } // WithAuth sets the server authentication method. func WithAuth(name string, args ...string) ServerOption { - return func(o *ServerOptions) { - if auth, ok := auth.GetAuth(name); ok { - auth.Init(args...) - o.Auths = append(o.Auths, auth) + return func(o *serverOptions) { + if a, ok := auth.GetAuth(name); ok { + a.Init(args...) + if o.auths == nil { + o.auths = make(map[string]auth.Authentication) + } + o.auths[a.Name()] = a } } } // WithServerTLSConfig sets the TLS configuration for the server. func WithServerTLSConfig(tc *tls.Config) ServerOption { - return func(o *ServerOptions) { - o.TLSConfig = tc + return func(o *serverOptions) { + o.tlsConfig = tc } } // WithServerQuicConfig sets the QUIC configuration for the server. func WithServerQuicConfig(qc *quic.Config) ServerOption { - return func(o *ServerOptions) { - o.QuicConfig = qc - } -} - -// WithConn sets the connection for the server. -func WithConn(conn net.PacketConn) ServerOption { - return func(o *ServerOptions) { - o.Conn = conn + return func(o *serverOptions) { + o.quicConfig = qc } } diff --git a/stream_group.go b/stream_group.go new file mode 100644 index 0000000..60d2fd6 --- /dev/null +++ b/stream_group.go @@ -0,0 +1,57 @@ +package network + +import ( + "context" + "sync" +) + +// StreamGroup is the group of stream includes ControlStream amd DataStream. +// One Connection has many DataStream and only one ControlStream, ControlStream authenticates +// Connection and recevies HandshakeFrame and CloseStreamFrame to create DataStream or close +// stream. the ControlStream always the first stream established between server and client. +type StreamGroup struct { + ctx context.Context + controlStream ServerControlStream + connector *Connector + group sync.WaitGroup +} + +// NewStreamGroup returns StreamGroup. +func NewStreamGroup(ctx context.Context, controlStream ServerControlStream, connector *Connector) *StreamGroup { + group := &StreamGroup{ + ctx: ctx, + controlStream: controlStream, + connector: connector, + } + return group +} + +// Run run contextFunc with connector. +// Run continus Accepts DataStream and create a Context to run with contextFunc. +// TODO: run in aop model, like setMetadata -> handleRoute -> before -> handle -> after. +func (g *StreamGroup) Run(contextFunc func(c *Context)) error { + for { + dataStream, err := g.controlStream.AcceptStream(g.ctx) + if err != nil { + return err + } + + g.group.Add(1) + _ = g.connector.Add(dataStream.ID(), dataStream) + + go func() { + defer func() { + g.group.Done() + _ = g.connector.Remove(dataStream.ID()) + }() + + c := newContext(dataStream) + defer c.Clean() + + contextFunc(c) + }() + } +} + +// Wait waits all dataStream down. +func (g *StreamGroup) Wait() { g.group.Wait() } diff --git a/tls/tls.go b/tls/tls.go index ff444a7..8531e04 100644 --- a/tls/tls.go +++ b/tls/tls.go @@ -48,6 +48,15 @@ func CreateServerTLSConfig(host string) (*tls.Config, error) { }, nil } +// MustCreateClientTLSConfig creates client tls config, It is panic If error here. +func MustCreateClientTLSConfig() *tls.Config { + conf, err := CreateClientTLSConfig() + if err != nil { + panic(err) + } + return conf +} + // CreateClientTLSConfig creates client tls config. func CreateClientTLSConfig() (*tls.Config, error) { // development mode