diff --git a/README.md b/README.md index 5e8a65d..1af40cd 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,4 @@ -# network +# network 网络库 + +### 基于QUIC协议 diff --git a/auth/auth.go b/auth/auth.go new file mode 100644 index 0000000..d43b728 --- /dev/null +++ b/auth/auth.go @@ -0,0 +1,59 @@ +package auth + +import "strings" + +var ( + auths = make(map[string]Authentication) +) + +// Authentication for Network server +type Authentication interface { + // Init authentication initialize arguments + Init(args ...string) + // Authenticate authentication client's credential + Authenticate(payload string) bool + // Name authentication name + Name() string +} + +// Register register authentication +func Register(authentication Authentication) { + auths[authentication.Name()] = authentication +} + +// GetAuth get authentication by name +func GetAuth(name string) (Authentication, bool) { + auth, ok := auths[name] + return auth, ok +} + +// Credential client credential +type Credential struct { + name string + payload string +} + +// NewCredential create client credential +func NewCredential(payload string) *Credential { + idx := strings.Index(payload, ":") + if idx != -1 { + authName := payload[:idx] + idx++ + authPayload := payload[idx:] + return &Credential{ + name: authName, + payload: authPayload, + } + } + return &Credential{name: "none"} +} + +// Payload client credential payload +func (c *Credential) Payload() string { + return c.payload +} + +// Name client credential name +func (c *Credential) Name() string { + return c.name +} diff --git a/auth/auth.puml b/auth/auth.puml new file mode 100644 index 0000000..fec3081 --- /dev/null +++ b/auth/auth.puml @@ -0,0 +1,20 @@ +@startuml +namespace auth { + interface Authentication { + + Init(args ...string) + + Authenticate(payload string) bool + + Name() string + + } + class Credential << (S,Aquamarine) >> { + - name string + - payload string + + + Payload() string + + Name() string + + } +} + + +@enduml \ No newline at end of file diff --git a/client.go b/client.go new file mode 100644 index 0000000..993e7a9 --- /dev/null +++ b/client.go @@ -0,0 +1,465 @@ +package network + +import ( + "context" + "errors" + "fmt" + "git.hpds.cc/Component/network/hpds_err" + "git.hpds.cc/Component/network/id" + pkgtls "git.hpds.cc/Component/network/tls" + "net" + "sync" + "time" + + "github.com/lucas-clemente/quic-go" + + "git.hpds.cc/Component/network/auth" + "git.hpds.cc/Component/network/frame" + "git.hpds.cc/Component/network/log" +) + +// ClientOption client options +type ClientOption func(*ClientOptions) + +// ConnState describes the state of the connection. +type ConnState = string + +// Client is the abstraction of a HPDS-Client. a HPDS-Client can be +// Protocol Gateway, Message Queue or StreamFunction. +type Client struct { + name string // name of the client + clientId string // id of the client + clientType ClientType // type of the connection + conn quic.Connection // quic connection + stream quic.Stream // quic stream + state ConnState // state of the connection + processor func(*frame.DataFrame) // functions to invoke when data arrived + receiver func(*frame.BackFlowFrame) // functions to invoke when data is processed + addr string // the address of server connected to + mu sync.Mutex + opts ClientOptions + localAddr string // client local addr, it will be changed on reconnect + logger log.Logger + errChan chan error + closeChan chan bool + closed bool +} + +// NewClient creates a new HPDS-Client. +func NewClient(appName string, connType ClientType, opts ...ClientOption) *Client { + c := &Client{ + name: appName, + clientId: id.New(), + clientType: connType, + state: ConnStateReady, + opts: ClientOptions{}, + errChan: make(chan error), + closeChan: make(chan bool), + } + c.Init(opts...) + once.Do(func() { + c.init() + }) + + return c +} + +// Init the options. +func (c *Client) Init(opts ...ClientOption) error { + for _, o := range opts { + o(&c.opts) + } + return c.initOptions() +} + +// Connect connects to HPDS-MessageQueue. +func (c *Client) Connect(ctx context.Context, addr string) error { + + // TODO: refactor this later as a Connection Manager + // reconnect + // for download mq + // If you do not check for errors, the connection will be automatically reconnected + go c.reconnect(ctx, addr) + + // connect + if err := c.connect(ctx, addr); err != nil { + return err + } + + return nil +} + +func (c *Client) connect(ctx context.Context, addr string) error { + c.addr = addr + c.state = ConnStateConnecting + + // create quic connection + conn, err := quic.DialAddrContext(ctx, addr, c.opts.TLSConfig, c.opts.QuicConfig) + if err != nil { + c.state = ConnStateDisconnected + return err + } + + // quic stream + stream, err := conn.OpenStreamSync(ctx) + if err != nil { + c.state = ConnStateDisconnected + return err + } + + c.stream = stream + c.conn = conn + + c.state = ConnStateAuthenticating + // send handshake + handshake := frame.NewHandshakeFrame( + c.name, + c.clientId, + byte(c.clientType), + c.opts.ObserveDataTags, + c.opts.Credential.Name(), + c.opts.Credential.Payload(), + ) + err = c.WriteFrame(handshake) + if err != nil { + c.state = ConnStateRejected + return err + } + c.state = ConnStateConnected + c.localAddr = c.conn.LocalAddr().String() + + c.logger.Printf("%s [%s][%s](%s) is connected to HPDS-MQ %s", ClientLogPrefix, c.name, c.clientId, c.localAddr, addr) + + // receiving frames + go c.handleFrame() + + return nil +} + +// handleFrame handles the logic when receiving frame from server. +func (c *Client) handleFrame() { + // transform raw QUIC stream to wire format + fs := NewFrameStream(c.stream) + for { + c.logger.Debugf("%shandleFrame connection state=%v", ClientLogPrefix, c.state) + // this will block until a frame is received + f, err := fs.ReadFrame() + if err != nil { + defer c.stream.Close() + // defer c.conn.CloseWithError(0xD0, err.Error()) + + c.logger.Debugf("%shandleFrame(): %T | %v", ClientLogPrefix, err, err) + if e, ok := err.(*quic.IdleTimeoutError); ok { + c.logger.Errorf("%sconnection timeout, err=%v, mq addr=%s", ClientLogPrefix, e, c.addr) + c.setState(ConnStateDisconnected) + } else if e, ok := err.(*quic.ApplicationError); ok { + c.logger.Infof("%sapplication error, err=%v, errcode=%v", ClientLogPrefix, e, e.ErrorCode) + if hpds_err.Is(e.ErrorCode, hpds_err.ErrorCodeRejected) { + // if connection is rejected(eg: authenticate fails) from server + c.logger.Errorf("%sIllegal client, server rejected.", ClientLogPrefix) + c.setState(ConnStateRejected) + break + } else if hpds_err.Is(e.ErrorCode, hpds_err.ErrorCodeClientAbort) { + // client abort + c.logger.Infof("%sclient close the connection", ClientLogPrefix) + c.setState(ConnStateAborted) + break + } else if hpds_err.Is(e.ErrorCode, hpds_err.ErrorCodeGoaway) { + // server goaway + c.logger.Infof("%sserver goaway the connection", ClientLogPrefix) + c.setState(ConnStateGoaway) + break + } else if hpds_err.Is(e.ErrorCode, hpds_err.ErrorCodeHandshake) { + // handshake + c.logger.Errorf("%shandshake fails", ClientLogPrefix) + c.setState(ConnStateRejected) + break + } + } else if errors.Is(err, net.ErrClosed) { + // if client close the connection, net.ErrClosed will be raised + c.logger.Errorf("%sconnection is closed, err=%v", ClientLogPrefix, err) + c.setState(ConnStateDisconnected) + // by quic-go IdleTimeoutError after connection's KeepAlive config. + break + } else { + // any error occurred, we should close the stream + // after this, conn.AcceptStream() will raise the error + c.setState(ConnStateClosed) + c.conn.CloseWithError(hpds_err.To(hpds_err.ErrorCodeUnknown), err.Error()) + c.logger.Errorf("%sunknown error occurred, err=%v, state=%v", ClientLogPrefix, err, c.getState()) + break + } + } + if f == nil { + break + } + // read frame + // first, get frame type + frameType := f.Type() + c.logger.Debugf("%stype=%s, frame=%# x", ClientLogPrefix, frameType, frame.Shortly(f.Encode())) + switch frameType { + case frame.TagOfHandshakeFrame: + if v, ok := f.(*frame.HandshakeFrame); ok { + c.logger.Debugf("%sreceive HandshakeFrame, name=%v", ClientLogPrefix, v.Name) + } + case frame.TagOfPongFrame: + c.setState(ConnStatePong) + case frame.TagOfAcceptedFrame: + c.setState(ConnStateAccepted) + case frame.TagOfRejectedFrame: + c.setState(ConnStateRejected) + if v, ok := f.(*frame.RejectedFrame); ok { + c.logger.Errorf("%s receive RejectedFrame, message=%s", ClientLogPrefix, v.Message()) + c.conn.CloseWithError(hpds_err.To(hpds_err.ErrorCodeRejected), v.Message()) + c.errChan <- errors.New(v.Message()) + break + } + case frame.TagOfGoawayFrame: + c.setState(ConnStateGoaway) + if v, ok := f.(*frame.GoawayFrame); ok { + c.logger.Errorf("%s️ receive GoawayFrame, message=%s", ClientLogPrefix, v.Message()) + c.conn.CloseWithError(hpds_err.To(hpds_err.ErrorCodeGoaway), v.Message()) + c.errChan <- errors.New(v.Message()) + break + } + case frame.TagOfDataFrame: // DataFrame carries user's data + if v, ok := f.(*frame.DataFrame); ok { + c.setState(ConnStateTransportData) + c.logger.Debugf("%sreceive DataFrame, tag=%#x, tid=%s, carry=%# x", ClientLogPrefix, v.GetDataTag(), v.TransactionId(), v.GetCarriage()) + if c.processor == nil { + c.logger.Warnf("%sprocessor is nil", ClientLogPrefix) + } else { + // TODO: should c.processor accept a DataFrame as parameter? + // c.processor(v.GetDataTagID(), v.GetCarriage(), v.GetMetaFrame()) + c.processor(v) + } + } + case frame.TagOfBackFlowFrame: + if v, ok := f.(*frame.BackFlowFrame); ok { + c.logger.Debugf("%sreceive BackFlowFrame, tag=%#x, carry=%# x", ClientLogPrefix, v.GetDataTag(), v.GetCarriage()) + if c.receiver == nil { + c.logger.Warnf("%sreceiver is nil", ClientLogPrefix) + } else { + c.setState(ConnStateBackFlow) + c.receiver(v) + } + } + default: + c.logger.Errorf("%sunknown signal", ClientLogPrefix) + } + } +} + +// Close the client. +func (c *Client) Close() (err error) { + if c.conn != nil { + c.logger.Printf("%sclose the connection, name:%s, id:%s, addr:%s", ClientLogPrefix, c.name, c.clientId, c.conn.RemoteAddr().String()) + } + if c.stream != nil { + err = c.stream.Close() + if err != nil { + c.logger.Errorf("%s stream.Close(): %v", ClientLogPrefix, err) + } + } + if c.conn != nil { + err = c.conn.CloseWithError(0, "client-ask-to-close-this-connection") + if err != nil { + c.logger.Errorf("%s connection.Close(): %v", ClientLogPrefix, err) + } + } + // close channel + c.mu.Lock() + if !c.closed { + close(c.errChan) + close(c.closeChan) + c.closed = true + } + c.mu.Unlock() + + return err +} + +// WriteFrame writes a frame to the connection, gurantee threadsafe. +func (c *Client) WriteFrame(frm frame.Frame) error { + // write on QUIC stream + if c.stream == nil { + return errors.New("stream is nil") + } + if c.state == ConnStateDisconnected || c.state == ConnStateRejected { + return fmt.Errorf("client connection state is %s", c.state) + } + c.logger.Debugf("%s[%s](%s)@%s WriteFrame() will write frame: %s", ClientLogPrefix, c.name, c.localAddr, c.state, frm.Type()) + + data := frm.Encode() + // emit raw bytes of Frame + c.mu.Lock() + n, err := c.stream.Write(data) + c.mu.Unlock() + c.logger.Debugf("%sWriteFrame() wrote n=%d, data=%# x", ClientLogPrefix, n, frame.Shortly(data)) + if err != nil { + c.setState(ConnStateDisconnected) + // c.state = ConnStateDisconnected + if e, ok := err.(*quic.IdleTimeoutError); ok { + c.logger.Errorf("%sWriteFrame() connection timeout, err=%v", ClientLogPrefix, e) + } else { + c.logger.Errorf("%sWriteFrame() wrote error=%v", ClientLogPrefix, err) + return err + } + } + if n != len(data) { + err := errors.New("[client] hpds Client .Write() wrote error") + c.logger.Errorf("%s error:%v", ClientLogPrefix, err) + return err + } + return err +} + +// update connection state +func (c *Client) setState(state ConnState) { + c.logger.Debugf("setState to:%s", state) + c.mu.Lock() + c.state = state + c.mu.Unlock() +} + +// getState get connection state +func (c *Client) getState() ConnState { + c.mu.Lock() + defer c.mu.Unlock() + return c.state +} + +// update connection local addr +func (c *Client) setLocalAddr(addr string) { + c.mu.Lock() + c.localAddr = addr + c.mu.Unlock() +} + +// SetDataFrameObserver sets the data frame handler. +func (c *Client) SetDataFrameObserver(fn func(*frame.DataFrame)) { + c.processor = fn + c.logger.Debugf("%sSetDataFrameObserver(%v)", ClientLogPrefix, c.processor) +} + +// SetBackFlowFrameObserver sets the backflow frame handler. +func (c *Client) SetBackFlowFrameObserver(fn func(*frame.BackFlowFrame)) { + c.receiver = fn + c.logger.Debugf("%sSetBackFlowFrameObserver(%v)", ClientLogPrefix, c.receiver) +} + +// reconnect the connection between client and server. +func (c *Client) reconnect(ctx context.Context, addr string) { + t := time.NewTicker(1 * time.Second) + defer t.Stop() + + for { + select { + case <-ctx.Done(): + c.logger.Debugf("%s[%s](%s) context.Done()", ClientLogPrefix, c.name, c.localAddr) + return + case <-c.closeChan: + c.logger.Debugf("%s[%s](%s) close channel", ClientLogPrefix, c.name, c.localAddr) + return + case <-t.C: + if c.getState() == ConnStateDisconnected { + c.logger.Printf("%s[%s][%s](%s) is reconnecting to HPDS-MQ %s...", ClientLogPrefix, c.name, c.clientId, c.localAddr, addr) + err := c.connect(ctx, addr) + if err != nil { + c.logger.Errorf("%s[%s][%s](%s) reconnect error:%v", ClientLogPrefix, c.name, c.clientId, c.localAddr, err) + } + } + } + } +} + +func (c *Client) init() { + // // tracing + // _, _, err := tracing.NewTracerProvider(c.name) + // if err != nil { + // logger.Errorf("tracing: %v", err) + // } +} + +// ServerAddr returns the address of the server. +func (c *Client) ServerAddr() string { + return c.addr +} + +// initOptions init options defaults +func (c *Client) initOptions() error { + // logger + if c.logger == nil { + if c.opts.Logger != nil { + c.logger = c.opts.Logger + } else { + c.logger = log.Default() + } + } + // observe tag list + if c.opts.ObserveDataTags == nil { + c.opts.ObserveDataTags = make([]byte, 0) + } + // credential + if c.opts.Credential == nil { + c.opts.Credential = auth.NewCredential("") + } + // tls config + if c.opts.TLSConfig == nil { + tc, err := pkgtls.CreateClientTLSConfig() + if err != nil { + c.logger.Errorf("%sCreateClientTLSConfig: %v", ClientLogPrefix, err) + return err + } + c.opts.TLSConfig = tc + } + // quic config + if c.opts.QuicConfig == nil { + c.opts.QuicConfig = &quic.Config{ + Versions: []quic.VersionNumber{quic.Version1, quic.VersionDraft29}, + MaxIdleTimeout: time.Second * 40, + KeepAlivePeriod: time.Second * 20, + MaxIncomingStreams: 1000, + MaxIncomingUniStreams: 1000, + HandshakeIdleTimeout: time.Second * 3, + InitialStreamReceiveWindow: 1024 * 1024 * 2, + InitialConnectionReceiveWindow: 1024 * 1024 * 2, + TokenStore: quic.NewLRUTokenStore(10, 5), + DisablePathMTUDiscovery: true, + } + } + // credential + if c.opts.Credential != nil { + c.logger.Printf("%suse credential: [%s]", ClientLogPrefix, c.opts.Credential.Name()) + } + + return nil +} + +// SetObserveDataTags set the data tag list that will be observed. +// Deprecated: use hpds.WithObserveDataTags instead +func (c *Client) SetObserveDataTags(tag ...byte) { + c.opts.ObserveDataTags = append(c.opts.ObserveDataTags, tag...) +} + +// Logger get client's logger instance, you can customize this using `hpds.WithLogger` +func (c *Client) Logger() log.Logger { + return c.logger +} + +// SetErrorHandler set error handler +func (c *Client) SetErrorHandler(fn func(err error)) { + if fn != nil { + go func() { + err := <-c.errChan + if err != nil { + fn(err) + } + }() + } +} + +// ClientId return the client ID +func (c *Client) ClientId() string { + return c.clientId +} diff --git a/client_options.go b/client_options.go new file mode 100644 index 0000000..00079f9 --- /dev/null +++ b/client_options.go @@ -0,0 +1,53 @@ +package network + +import ( + "crypto/tls" + "github.com/lucas-clemente/quic-go" + + "git.hpds.cc/Component/network/auth" + "git.hpds.cc/Component/network/log" +) + +// ClientOptions are the options for HPDS client. +type ClientOptions struct { + ObserveDataTags []byte + QuicConfig *quic.Config + TLSConfig *tls.Config + Credential *auth.Credential + Logger log.Logger +} + +// WithObserveDataTags sets data tag list for the client. +func WithObserveDataTags(tags ...byte) ClientOption { + return func(o *ClientOptions) { + o.ObserveDataTags = tags + } +} + +// WithCredential sets the client credential method (used by client). +func WithCredential(payload string) ClientOption { + return func(o *ClientOptions) { + o.Credential = auth.NewCredential(payload) + } +} + +// WithClientTLSConfig sets tls config for the client. +func WithClientTLSConfig(tc *tls.Config) ClientOption { + return func(o *ClientOptions) { + o.TLSConfig = tc + } +} + +// WithClientQuicConfig sets quic config for the client. +func WithClientQuicConfig(qc *quic.Config) ClientOption { + return func(o *ClientOptions) { + o.QuicConfig = qc + } +} + +// WithLogger sets logger for the client. +func WithLogger(logger log.Logger) ClientOption { + return func(o *ClientOptions) { + o.Logger = logger + } +} diff --git a/client_type.go b/client_type.go new file mode 100644 index 0000000..6c2baef --- /dev/null +++ b/client_type.go @@ -0,0 +1,28 @@ +package network + +const ( + // ClientTypeNone is connection type "None". + ClientTypeNone ClientType = 0xFF + // ClientTypeProtocolGateway is connection type "Protocol Gateway". + ClientTypeProtocolGateway ClientType = 0x5F + // ClientTypeMessageQueue is connection type "Message Queue". + ClientTypeMessageQueue ClientType = 0x5E + // ClientTypeStreamFunction is connection type "Stream Function". + ClientTypeStreamFunction ClientType = 0x5D +) + +// ClientType represents the connection type. +type ClientType byte + +func (c ClientType) String() string { + switch c { + case ClientTypeProtocolGateway: + return "Source" + case ClientTypeMessageQueue: + return "Message Queue" + case ClientTypeStreamFunction: + return "Stream Function" + default: + return "None" + } +} diff --git a/connection.go b/connection.go new file mode 100644 index 0000000..1c51705 --- /dev/null +++ b/connection.go @@ -0,0 +1,95 @@ +package network + +import ( + "git.hpds.cc/Component/network/frame" + "git.hpds.cc/Component/network/log" + "io" + "sync" +) + +// Connection wraps the specific io connections (typically quic.Connection) to transfer coder frames +type Connection interface { + io.Closer + + // Name returns the name of the connection, which is set by clients + Name() string + // ClientId connection client ID + ClientId() string + // ClientType returns the type of the client (Protocol Gateway | Message Queue | Stream Function) + ClientType() ClientType + // Metadata returns the extra info of the application + Metadata() Metadata + // Write should goroutine-safely send coder frames to peer side + Write(f frame.Frame) error + // ObserveDataTags observed data tags + ObserveDataTags() []byte +} + +type connection struct { + name string + clientType ClientType + metadata Metadata + stream io.ReadWriteCloser + clientId string + observed []byte // observed data tags + mu sync.Mutex + closed bool +} + +func newConnection(name string, clientId string, clientType ClientType, metadata Metadata, + stream io.ReadWriteCloser, observed []byte) Connection { + return &connection{ + name: name, + clientId: clientId, + clientType: clientType, + observed: observed, + metadata: metadata, + stream: stream, + closed: false, + } +} + +// Close implements io.Close interface +func (c *connection) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + c.closed = true + return c.stream.Close() +} + +// Name returns the name of the connection, which is set by clients +func (c *connection) Name() string { + return c.name +} + +// ClientType returns the type of the connection (Protocol Gateway | Message Queue | Stream Function ) +func (c *connection) ClientType() ClientType { + return c.clientType +} + +// Metadata returns the extra info of the application +func (c *connection) Metadata() Metadata { + return c.metadata +} + +// Write should goroutine-safely send coder frames to peer side +func (c *connection) Write(f frame.Frame) error { + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + log.Warnf("%sclient stream is closed: %s", ServerLogPrefix, c.clientId) + return nil + } + _, err := c.stream.Write(f.Encode()) + return err +} + +// ObserveDataTags observed data tags +func (c *connection) ObserveDataTags() []byte { + return c.observed +} + +// ClientId connection client id +func (c *connection) ClientId() string { + return c.clientId +} diff --git a/connector.go b/connector.go new file mode 100644 index 0000000..756035d --- /dev/null +++ b/connector.go @@ -0,0 +1,87 @@ +package network + +import ( + "git.hpds.cc/Component/network/log" + "sync" +) + +var _ Connector = &connector{} + +// Connector is an interface to manage the connections and applications. +type Connector interface { + // Add a connection. + Add(connId string, conn Connection) + // Remove a connection. + Remove(connId string) + // Get a connection by connection id. + Get(connId string) Connection + // GetSnapshot gets the snapshot of all connections. + GetSnapshot() map[string]string + // GetProtocolGatewayConnections gets the connections by Protocol Gateway observe tags. + GetProtocolGatewayConnections(sourceId string, tags byte) []Connection + // Clean the connector. + Clean() +} + +type connector struct { + conns sync.Map +} + +func newConnector() Connector { + return &connector{conns: sync.Map{}} +} + +// Add a connection. +func (c *connector) Add(connID string, conn Connection) { + log.Debugf("%sconnector add: connId=%s", ServerLogPrefix, connID) + c.conns.Store(connID, conn) +} + +// Remove a connection. +func (c *connector) Remove(connID string) { + log.Debugf("%sconnector remove: connId=%s", ServerLogPrefix, connID) + c.conns.Delete(connID) +} + +// Get a connection by connection id. +func (c *connector) Get(connID string) Connection { + log.Debugf("%sconnector get connection: connId=%s", ServerLogPrefix, connID) + if conn, ok := c.conns.Load(connID); ok { + return conn.(Connection) + } + return nil +} + +// GetProtocolGatewayConnections gets the Protocol Gateway connection by tag. +func (c *connector) GetProtocolGatewayConnections(sourceId string, tag byte) []Connection { + conns := make([]Connection, 0) + + c.conns.Range(func(key interface{}, val interface{}) bool { + conn := val.(Connection) + for _, v := range conn.ObserveDataTags() { + if v == tag && conn.ClientType() == ClientTypeProtocolGateway && conn.ClientId() == sourceId { + conns = append(conns, conn) + } + } + return true + }) + + return conns +} + +// GetSnapshot gets the snapshot of all connections. +func (c *connector) GetSnapshot() map[string]string { + result := make(map[string]string) + c.conns.Range(func(key interface{}, val interface{}) bool { + connID := key.(string) + conn := val.(Connection) + result[connID] = conn.Name() + return true + }) + return result +} + +// Clean the connector. +func (c *connector) Clean() { + c.conns = sync.Map{} +} diff --git a/constant.go b/constant.go new file mode 100644 index 0000000..2cddae8 --- /dev/null +++ b/constant.go @@ -0,0 +1,40 @@ +package network + +import ( + "math/rand" + "sync" + "time" +) + +var ( + once sync.Once +) + +// ConnState represents the state of a connection. +const ( + ConnStateReady ConnState = "Ready" + ConnStateDisconnected ConnState = "Disconnected" + ConnStateConnecting ConnState = "Connecting" + ConnStateConnected ConnState = "Connected" + ConnStateAuthenticating ConnState = "Authenticating" + ConnStateAccepted ConnState = "Accepted" + ConnStateRejected ConnState = "Rejected" + ConnStatePing ConnState = "Ping" + ConnStatePong ConnState = "Pong" + ConnStateTransportData ConnState = "TransportData" + ConnStateAborted ConnState = "Aborted" + ConnStateClosed ConnState = "Closed" // close connection by server + ConnStateGoaway ConnState = "Goaway" + ConnStateBackFlow ConnState = "BackFlow" +) + +// Prefix is the prefix for logger. +const ( + ClientLogPrefix = "\033[36m[network:client]\033[0m " + ServerLogPrefix = "\033[32m[network:server]\033[0m " + ParseFrameLogPrefix = "\033[36m[network:stream_parser]\033[0m " +) + +func init() { + rand.Seed(time.Now().Unix()) +} diff --git a/context.go b/context.go new file mode 100644 index 0000000..2aceb6b --- /dev/null +++ b/context.go @@ -0,0 +1,191 @@ +package network + +import ( + "git.hpds.cc/Component/network/hpds_err" + "git.hpds.cc/Component/network/log" + "io" + "sync" + "time" + + "git.hpds.cc/Component/network/frame" + "github.com/lucas-clemente/quic-go" +) + +// Context for Network Server. +type Context struct { + // Conn is the connection of client. + Conn quic.Connection + connId string + // Stream is the long-lived connection between client and server. + Stream io.ReadWriteCloser + // Frame receives from client. + Frame frame.Frame + // Keys store the key/value pairs in context. + Keys map[string]interface{} + + mu sync.RWMutex +} + +func newContext(conn quic.Connection, stream quic.Stream) *Context { + return &Context{ + Conn: conn, + connId: conn.RemoteAddr().String(), + Stream: stream, + // keys: make(map[string]interface{}), + } +} + +// WithFrame sets a frame to context. +func (c *Context) WithFrame(f frame.Frame) *Context { + c.Frame = f + return c +} + +// Clean the context. +func (c *Context) Clean() { + log.Debugf("%sconn[%s] context clean", ServerLogPrefix, c.connId) + c.Stream = nil + c.Frame = nil + c.Keys = nil + c.Conn = nil +} + +// CloseWithError closes the stream and cleans the context. +func (c *Context) CloseWithError(code hpds_err.ErrorCode, msg string) { + log.Debugf("%sconn[%s] context close, errCode=%#x, msg=%s", ServerLogPrefix, c.connId, code, msg) + if c.Stream != nil { + c.Stream.Close() + } + if c.Conn != nil { + c.Conn.CloseWithError(quic.ApplicationErrorCode(code), msg) + } + c.Clean() +} + +// ConnId get quic connection id +func (c *Context) ConnId() string { + return c.connId +} + +// Set a key/value pair to context. +func (c *Context) Set(key string, value interface{}) { + c.mu.Lock() + if c.Keys == nil { + c.Keys = make(map[string]interface{}) + } + + c.Keys[key] = value + c.mu.Unlock() +} + +// Get the value by a specified key. +func (c *Context) Get(key string) (value interface{}, exists bool) { + c.mu.RLock() + value, exists = c.Keys[key] + c.mu.RUnlock() + return +} + +// GetString gets a string value by a specified key. +func (c *Context) GetString(key string) (s string) { + if val, ok := c.Get(key); ok && val != nil { + s, _ = val.(string) + } + return +} + +// GetBool gets a bool value by a specified key. +func (c *Context) GetBool(key string) (b bool) { + if val, ok := c.Get(key); ok && val != nil { + b, _ = val.(bool) + } + return +} + +// GetInt gets an int value by a specified key. +func (c *Context) GetInt(key string) (i int) { + if val, ok := c.Get(key); ok && val != nil { + i, _ = val.(int) + } + return +} + +// GetInt64 gets an int64 value by a specified key. +func (c *Context) GetInt64(key string) (i64 int64) { + if val, ok := c.Get(key); ok && val != nil { + i64, _ = val.(int64) + } + return +} + +// GetUint gets an uint value by a specified key. +func (c *Context) GetUint(key string) (ui uint) { + if val, ok := c.Get(key); ok && val != nil { + ui, _ = val.(uint) + } + return +} + +// GetUint64 gets an uint64 value by a specified key. +func (c *Context) GetUint64(key string) (ui64 uint64) { + if val, ok := c.Get(key); ok && val != nil { + ui64, _ = val.(uint64) + } + return +} + +// GetFloat64 gets a float64 value by a specified key. +func (c *Context) GetFloat64(key string) (f64 float64) { + if val, ok := c.Get(key); ok && val != nil { + f64, _ = val.(float64) + } + return +} + +// GetTime gets a time.Time value by a specified key. +func (c *Context) GetTime(key string) (t time.Time) { + if val, ok := c.Get(key); ok && val != nil { + t, _ = val.(time.Time) + } + return +} + +// GetDuration gets a time.Duration value by a specified key. +func (c *Context) GetDuration(key string) (d time.Duration) { + if val, ok := c.Get(key); ok && val != nil { + d, _ = val.(time.Duration) + } + return +} + +// GetStringSlice gets a []string value by a specified key. +func (c *Context) GetStringSlice(key string) (ss []string) { + if val, ok := c.Get(key); ok && val != nil { + ss, _ = val.([]string) + } + return +} + +// GetStringMap gets a map[string]interface{} value by a specified key. +func (c *Context) GetStringMap(key string) (sm map[string]interface{}) { + if val, ok := c.Get(key); ok && val != nil { + sm, _ = val.(map[string]interface{}) + } + return +} + +// GetStringMapString gets a map[string]string value by a specified key. +func (c *Context) GetStringMapString(key string) (sms map[string]string) { + if val, ok := c.Get(key); ok && val != nil { + sms, _ = val.(map[string]string) + } + return +} + +// GetStringMapStringSlice gets a map[string][]string value by a specified key. +func (c *Context) GetStringMapStringSlice(key string) (smss map[string][]string) { + if val, ok := c.Get(key); ok && val != nil { + smss, _ = val.(map[string][]string) + } + return +} diff --git a/frame/accepted_frame.go b/frame/accepted_frame.go new file mode 100644 index 0000000..5e8571a --- /dev/null +++ b/frame/accepted_frame.go @@ -0,0 +1,36 @@ +package frame + +import ( + coder "git.hpds.cc/Component/mq_coder" +) + +// AcceptedFrame is encoded bytes, Tag is a fixed value TYPE_ID_ACCEPTED_FRAME +type AcceptedFrame struct{} + +// NewAcceptedFrame creates a new AcceptedFrame with a given TagId of user's data +func NewAcceptedFrame() *AcceptedFrame { + return &AcceptedFrame{} +} + +// Type gets the type of Frame. +func (m *AcceptedFrame) Type() Type { + return TagOfAcceptedFrame +} + +// Encode to coder encoded bytes. +func (m *AcceptedFrame) Encode() []byte { + accepted := coder.NewNodePacketEncoder(byte(m.Type())) + accepted.AddBytes(nil) + + return accepted.Encode() +} + +// DecodeToAcceptedFrame decodes coder encoded bytes to AcceptedFrame. +func DecodeToAcceptedFrame(buf []byte) (*AcceptedFrame, error) { + nodeBlock := coder.NodePacket{} + _, err := coder.DecodeToNodePacket(buf, &nodeBlock) + if err != nil { + return nil, err + } + return &AcceptedFrame{}, nil +} diff --git a/frame/accepted_frame_test.go b/frame/accepted_frame_test.go new file mode 100644 index 0000000..6f0a713 --- /dev/null +++ b/frame/accepted_frame_test.go @@ -0,0 +1,19 @@ +package frame + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAcceptedFrameEncode(t *testing.T) { + f := NewAcceptedFrame() + assert.Equal(t, []byte{0x80 | byte(TagOfAcceptedFrame), 0x00}, f.Encode()) +} + +func TestAcceptedFrameDecode(t *testing.T) { + buf := []byte{0x80 | byte(TagOfAcceptedFrame), 0x00} + ping, err := DecodeToAcceptedFrame(buf) + assert.NoError(t, err) + assert.Equal(t, []byte{0x80 | byte(TagOfAcceptedFrame), 0x00}, ping.Encode()) +} diff --git a/frame/backflow_frame.go b/frame/backflow_frame.go new file mode 100644 index 0000000..298422c --- /dev/null +++ b/frame/backflow_frame.go @@ -0,0 +1,70 @@ +package frame + +import ( + coder "git.hpds.cc/Component/mq_coder" +) + +// BackFlowFrame is a coder encoded bytes +// It's used to receive stream function processed result +type BackFlowFrame struct { + Tag byte + Carriage []byte +} + +// NewBackFlowFrame creates a new BackFlowFrame with a given tag and carriage +func NewBackFlowFrame(tag byte, carriage []byte) *BackFlowFrame { + return &BackFlowFrame{ + Tag: tag, + Carriage: carriage, + } +} + +// Type gets the type of Frame. +func (f *BackFlowFrame) Type() Type { + return TagOfBackFlowFrame +} + +// SetCarriage sets the user's raw data +func (f *BackFlowFrame) SetCarriage(buf []byte) *BackFlowFrame { + f.Carriage = buf + return f +} + +// Encode to coder encoded bytes +func (f *BackFlowFrame) Encode() []byte { + carriage := coder.NewPrimitivePacketEncoder(f.Tag) + carriage.SetBytesValue(f.Carriage) + + node := coder.NewNodePacketEncoder(byte(TagOfBackFlowFrame)) + node.AddPrimitivePacket(carriage) + + return node.Encode() +} + +// GetDataTag return the Tag of user's data +func (f *BackFlowFrame) GetDataTag() byte { + return f.Tag +} + +// GetCarriage return data +func (f *BackFlowFrame) GetCarriage() []byte { + return f.Carriage +} + +// DecodeToBackFlowFrame decodes coder encoded bytes to BackFlowFrame +func DecodeToBackFlowFrame(buf []byte) (*BackFlowFrame, error) { + nodeBlock := coder.NodePacket{} + _, err := coder.DecodeToNodePacket(buf, &nodeBlock) + if err != nil { + return nil, err + } + + payload := &BackFlowFrame{} + for _, v := range nodeBlock.PrimitivePackets { + payload.Tag = v.SeqId() + payload.Carriage = v.GetValBuf() + break + } + + return payload, nil +} diff --git a/frame/data_frame.go b/frame/data_frame.go new file mode 100644 index 0000000..217d943 --- /dev/null +++ b/frame/data_frame.go @@ -0,0 +1,110 @@ +package frame + +import ( + coder "git.hpds.cc/Component/mq_coder" +) + +// DataFrame defines the data structure carried with user's data +type DataFrame struct { + metaFrame *MetaFrame + payloadFrame *PayloadFrame +} + +// NewDataFrame create `DataFrame` with a transactionId string, +// consider change transactionID to UUID type later +func NewDataFrame() *DataFrame { + data := &DataFrame{ + metaFrame: NewMetaFrame(), + } + return data +} + +// Type gets the type of Frame. +func (d *DataFrame) Type() Type { + return TagOfDataFrame +} + +// Tag return the tag of carriage data. +func (d *DataFrame) Tag() byte { + return d.payloadFrame.Tag +} + +// SetCarriage set user's raw data in `DataFrame` +func (d *DataFrame) SetCarriage(tag byte, carriage []byte) { + d.payloadFrame = NewPayloadFrame(tag).SetCarriage(carriage) +} + +// GetCarriage return user's raw data in `DataFrame` +func (d *DataFrame) GetCarriage() []byte { + return d.payloadFrame.Carriage +} + +// TransactionId return transactionId string +func (d *DataFrame) TransactionId() string { + return d.metaFrame.TransactionId() +} + +// SetTransactionId set transactionId string +func (d *DataFrame) SetTransactionId(transactionID string) { + d.metaFrame.SetTransactionId(transactionID) +} + +// GetMetaFrame return MetaFrame. +func (d *DataFrame) GetMetaFrame() *MetaFrame { + return d.metaFrame +} + +// GetDataTag return the Tag of user's data +func (d *DataFrame) GetDataTag() byte { + return d.payloadFrame.Tag +} + +// SetSourceId set the source id. +func (d *DataFrame) SetSourceId(sourceID string) { + d.metaFrame.SetSourceId(sourceID) +} + +// SourceId returns source id +func (d *DataFrame) SourceId() string { + return d.metaFrame.SourceId() +} + +// Encode return coder encoded bytes of `DataFrame` +func (d *DataFrame) Encode() []byte { + data := coder.NewNodePacketEncoder(byte(d.Type())) + // MetaFrame + data.AddBytes(d.metaFrame.Encode()) + // PayloadFrame + data.AddBytes(d.payloadFrame.Encode()) + + return data.Encode() +} + +// DecodeToDataFrame decode coder encoded bytes to `DataFrame` +func DecodeToDataFrame(buf []byte) (*DataFrame, error) { + packet := coder.NodePacket{} + _, err := coder.DecodeToNodePacket(buf, &packet) + if err != nil { + return nil, err + } + + data := &DataFrame{} + + if metaBlock, ok := packet.NodePackets[byte(TagOfMetaFrame)]; ok { + meta, err := DecodeToMetaFrame(metaBlock.GetRawBytes()) + if err != nil { + return nil, err + } + data.metaFrame = meta + } + + if payloadBlock, ok := packet.NodePackets[byte(TagOfPayloadFrame)]; ok { + payload, err := DecodeToPayloadFrame(payloadBlock.GetRawBytes()) + if err != nil { + return nil, err + } + data.payloadFrame = payload + } + + return data, nil +} diff --git a/frame/data_frame_test.go b/frame/data_frame_test.go new file mode 100644 index 0000000..ba9019b --- /dev/null +++ b/frame/data_frame_test.go @@ -0,0 +1,39 @@ +package frame + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDataFrameEncode(t *testing.T) { + var userDataTag byte = 0x15 + d := NewDataFrame() + d.SetCarriage(userDataTag, []byte("hpds")) + + tidbuf := []byte(d.TransactionId()) + result := []byte{ + 0x80 | byte(TagOfDataFrame), byte(len(tidbuf) + 4 + 8 + 2), + 0x80 | byte(TagOfMetaFrame), byte(len(tidbuf) + 2 + 2), + byte(TagOfTransactionId), byte(len(tidbuf))} + result = append(result, tidbuf...) + result = append(result, byte(TagOfSourceId), 0x0) + result = append(result, 0x80|byte(TagOfPayloadFrame), 0x06, + userDataTag, 0x04, 0x79, 0x6F, 0x6D, 0x6F) + assert.Equal(t, result, d.Encode()) +} + +func TestDataFrameDecode(t *testing.T) { + var userDataTag byte = 0x15 + buf := []byte{ + 0x80 | byte(TagOfDataFrame), 0x10, + 0x80 | byte(TagOfMetaFrame), 0x06, + byte(TagOfTransactionId), 0x04, 0x31, 0x32, 0x33, 0x34, + 0x80 | byte(TagOfPayloadFrame), 0x06, + userDataTag, 0x04, 0x79, 0x6F, 0x6D, 0x6F} + data, err := DecodeToDataFrame(buf) + assert.NoError(t, err) + assert.EqualValues(t, "1234", data.TransactionId()) + assert.EqualValues(t, userDataTag, data.GetDataTag()) + assert.EqualValues(t, []byte("hpds"), data.GetCarriage()) +} diff --git a/frame/frame.go b/frame/frame.go new file mode 100644 index 0000000..4ac67ee --- /dev/null +++ b/frame/frame.go @@ -0,0 +1,106 @@ +package frame + +import ( + "os" + "strconv" +) + +// debugFrameSize print frame data size on debug mode +var debugFrameSize = 16 + +// Kinds of frames transferable within HPDS +const ( + // DataFrame + TagOfDataFrame Type = 0x3F + // MetaFrame of DataFrame + TagOfMetaFrame Type = 0x2F + TagOfMetadata Type = 0x03 + TagOfTransactionId Type = 0x01 + TagOfSourceId Type = 0x02 + // PayloadFrame of DataFrame + TagOfPayloadFrame Type = 0x2E + TagOfBackFlowFrame Type = 0x2D + + TagOfTokenFrame Type = 0x3E + // HandshakeFrame + TagOfHandshakeFrame Type = 0x3D + TagOfHandshakeName Type = 0x01 + TagOfHandshakeType Type = 0x02 + TagOfHandshakeId Type = 0x03 + TagOfHandshakeAuthName Type = 0x04 + TagOfHandshakeAuthPayload Type = 0x05 + TagOfHandshakeObserveDataTags Type = 0x06 + + TagOfPingFrame Type = 0x3C + TagOfPongFrame Type = 0x3B + TagOfAcceptedFrame Type = 0x3A + TagOfRejectedFrame Type = 0x39 + TagOfRejectedMessage Type = 0x02 + // GoawayFrame + TagOfGoawayFrame Type = 0x30 + TagOfGoawayCode Type = 0x01 + TagOfGoawayMessage Type = 0x02 +) + +// Type represents the type of frame. +type Type uint8 + +// Frame is the interface for frame. +type Frame interface { + // Type gets the type of Frame. + Type() Type + + // Encode the frame into []byte. + Encode() []byte +} + +func (f Type) String() string { + switch f { + case TagOfDataFrame: + return "DataFrame" + case TagOfTokenFrame: + return "TokenFrame" + case TagOfHandshakeFrame: + return "HandshakeFrame" + case TagOfPingFrame: + return "PingFrame" + case TagOfPongFrame: + return "PongFrame" + case TagOfAcceptedFrame: + return "AcceptedFrame" + case TagOfRejectedFrame: + return "RejectedFrame" + case TagOfGoawayFrame: + return "GoawayFrame" + case TagOfBackFlowFrame: + return "BackFlowFrame" + case TagOfMetaFrame: + return "MetaFrame" + case TagOfPayloadFrame: + return "PayloadFrame" + // case TagOfTransactionId: + // return "TransactionId" + case TagOfHandshakeName: + return "HandshakeName" + case TagOfHandshakeType: + return "HandshakeType" + default: + return "UnknownFrame" + } +} + +// Shortly reduce data size for easy viewing +func Shortly(data []byte) []byte { + if len(data) > debugFrameSize { + return data[:debugFrameSize] + } + return data +} + +func init() { + if envFrameSize := os.Getenv("DEBUG_FRAME_SIZE"); envFrameSize != "" { + if val, err := strconv.Atoi(envFrameSize); err == nil { + debugFrameSize = val + } + } +} diff --git a/frame/frame.puml b/frame/frame.puml new file mode 100644 index 0000000..a1c192f --- /dev/null +++ b/frame/frame.puml @@ -0,0 +1,110 @@ +@startuml +namespace frame { + class AcceptedFrame << (S,Aquamarine) >> { + + Type() Type + + Encode() []byte + + } + class BackFlowFrame << (S,Aquamarine) >> { + + Tag byte + + Carriage []byte + + + Type() Type + + SetCarriage(buf []byte) *BackFlowFrame + + Encode() []byte + + GetDataTag() byte + + GetCarriage() []byte + + } + class DataFrame << (S,Aquamarine) >> { + - metaFrame *MetaFrame + - payloadFrame *PayloadFrame + + + Type() Type + + Tag() byte + + SetCarriage(tag byte, carriage []byte) + + GetCarriage() []byte + + TransactionId() string + + SetTransactionId(transactionID string) + + GetMetaFrame() *MetaFrame + + GetDataTag() byte + + SetSourceId(sourceID string) + + SourceId() string + + Encode() []byte + + } + interface Frame { + + Type() Type + + Encode() []byte + + } + class GoawayFrame << (S,Aquamarine) >> { + - message string + + + Type() Type + + Encode() []byte + + Message() string + + } + class HandshakeFrame << (S,Aquamarine) >> { + - authName string + - authPayload string + + + Name string + + ClientId string + + ClientType byte + + ObserveDataTags []byte + + + Type() Type + + Encode() []byte + + AuthPayload() string + + AuthName() string + + } + class MetaFrame << (S,Aquamarine) >> { + - tid string + - metadata []byte + - sourceId string + + + SetTransactionId(transactionId string) + + TransactionId() string + + SetMetadata(metadata []byte) + + Metadata() []byte + + SetSourceId(sourceId string) + + SourceId() string + + Encode() []byte + + } + class PayloadFrame << (S,Aquamarine) >> { + + Tag byte + + Carriage []byte + + + SetCarriage(buf []byte) *PayloadFrame + + Encode() []byte + + } + class RejectedFrame << (S,Aquamarine) >> { + - message string + + + Type() Type + + Encode() []byte + + Message() string + + } + class Type << (S,Aquamarine) >> { + + String() string + + } + class frame.Type << (T, #FF7700) >> { + } +} + +"frame.Frame" <|-- "frame.AcceptedFrame" +"frame.Frame" <|-- "frame.BackFlowFrame" +"frame.Frame" <|-- "frame.DataFrame" +"frame.Frame" <|-- "frame.GoawayFrame" +"frame.Frame" <|-- "frame.HandshakeFrame" +"frame.Frame" <|-- "frame.RejectedFrame" + +"__builtin__.uint8" #.. "frame.Type" +@enduml \ No newline at end of file diff --git a/frame/goaway_frame.go b/frame/goaway_frame.go new file mode 100644 index 0000000..dd82a2c --- /dev/null +++ b/frame/goaway_frame.go @@ -0,0 +1,57 @@ +package frame + +import ( + coder "git.hpds.cc/Component/mq_coder" +) + +// GoawayFrame is a coder encoded bytes, Tag is a fixed value TYPE_ID_GOAWAY_FRAME +type GoawayFrame struct { + message string +} + +// NewGoawayFrame creates a new GoawayFrame +func NewGoawayFrame(msg string) *GoawayFrame { + return &GoawayFrame{message: msg} +} + +// Type gets the type of Frame. +func (f *GoawayFrame) Type() Type { + return TagOfGoawayFrame +} + +// Encode to coder encoded bytes +func (f *GoawayFrame) Encode() []byte { + goaway := coder.NewNodePacketEncoder(byte(f.Type())) + // message + msgBlock := coder.NewPrimitivePacketEncoder(byte(TagOfGoawayMessage)) + msgBlock.SetStringValue(f.message) + + goaway.AddPrimitivePacket(msgBlock) + + return goaway.Encode() +} + +// Message goaway message +func (f *GoawayFrame) Message() string { + return f.message +} + +// DecodeToGoawayFrame decodes coder encoded bytes to GoawayFrame +func DecodeToGoawayFrame(buf []byte) (*GoawayFrame, error) { + node := coder.NodePacket{} + _, err := coder.DecodeToNodePacket(buf, &node) + if err != nil { + return nil, err + } + + goaway := &GoawayFrame{} + // message + if msgBlock, ok := node.PrimitivePackets[byte(TagOfGoawayMessage)]; ok { + msg, err := msgBlock.ToUTF8String() + if err != nil { + return nil, err + } + goaway.message = msg + } + return goaway, nil +} diff --git a/frame/handshake_frame.go b/frame/handshake_frame.go new file mode 100644 index 0000000..4b4cda6 --- /dev/null +++ b/frame/handshake_frame.go @@ -0,0 +1,131 @@ +package frame + +import ( + coder "git.hpds.cc/Component/mq_coder" +) + +// HandshakeFrame is a coder encoded. +type HandshakeFrame struct { + // Name is client name + Name string + // ClientId represents client id + ClientId string + // ClientType represents client type (Protocol Gateway | Stream Function) + ClientType byte + // ObserveDataTags are the client data tag list. + ObserveDataTags []byte + // auth + authName string + authPayload string +} + +// NewHandshakeFrame creates a new HandshakeFrame. +func NewHandshakeFrame(name string, clientId string, clientType byte, observeDataTags []byte, authName string, authPayload string) *HandshakeFrame { + return &HandshakeFrame{ + Name: name, + ClientId: clientId, + ClientType: clientType, + ObserveDataTags: observeDataTags, + authName: authName, + authPayload: authPayload, + } +} + +// Type gets the type of Frame. +func (h *HandshakeFrame) Type() Type { + return TagOfHandshakeFrame +} + +// Encode to coder encoding. +func (h *HandshakeFrame) Encode() []byte { + // name + nameBlock := coder.NewPrimitivePacketEncoder(byte(TagOfHandshakeName)) + nameBlock.SetStringValue(h.Name) + // client ID + idBlock := coder.NewPrimitivePacketEncoder(byte(TagOfHandshakeId)) + idBlock.SetStringValue(h.ClientId) + // client type + typeBlock := coder.NewPrimitivePacketEncoder(byte(TagOfHandshakeType)) + typeBlock.SetBytesValue([]byte{h.ClientType}) + // observe data tags + observeDataTagsBlock := coder.NewPrimitivePacketEncoder(byte(TagOfHandshakeObserveDataTags)) + observeDataTagsBlock.SetBytesValue(h.ObserveDataTags) + // auth + authNameBlock := coder.NewPrimitivePacketEncoder(byte(TagOfHandshakeAuthName)) + authNameBlock.SetStringValue(h.authName) + authPayloadBlock := coder.NewPrimitivePacketEncoder(byte(TagOfHandshakeAuthPayload)) + authPayloadBlock.SetStringValue(h.authPayload) + // handshake frame + handshake := coder.NewNodePacketEncoder(byte(h.Type())) + handshake.AddPrimitivePacket(nameBlock) + handshake.AddPrimitivePacket(idBlock) + handshake.AddPrimitivePacket(typeBlock) + handshake.AddPrimitivePacket(observeDataTagsBlock) + handshake.AddPrimitivePacket(authNameBlock) + handshake.AddPrimitivePacket(authPayloadBlock) + + return handshake.Encode() +} + +// DecodeToHandshakeFrame decodes coder encoded bytes to HandshakeFrame. +func DecodeToHandshakeFrame(buf []byte) (*HandshakeFrame, error) { + node := coder.NodePacket{} + _, err := coder.DecodeToNodePacket(buf, &node) + if err != nil { + return nil, err + } + + handshake := &HandshakeFrame{} + // name + if nameBlock, ok := node.PrimitivePackets[byte(TagOfHandshakeName)]; ok { + name, err := nameBlock.ToUTF8String() + if err != nil { + return nil, err + } + handshake.Name = name + } + // client id + if idBlock, ok := node.PrimitivePackets[byte(TagOfHandshakeId)]; ok { + id, err := idBlock.ToUTF8String() + if err != nil { + return nil, err + } + handshake.ClientId = id + } + // client type + if typeBlock, ok := node.PrimitivePackets[byte(TagOfHandshakeType)]; ok { + clientType := typeBlock.ToBytes() + handshake.ClientType = clientType[0] + } + // observe data tag list + if observeDataTagsBlock, ok := node.PrimitivePackets[byte(TagOfHandshakeObserveDataTags)]; ok { + handshake.ObserveDataTags = observeDataTagsBlock.ToBytes() + } + // auth + if authNameBlock, ok := node.PrimitivePackets[byte(TagOfHandshakeAuthName)]; ok { + authName, err := authNameBlock.ToUTF8String() + if err != nil { + return nil, err + } + handshake.authName = authName + } + if authPayloadBlock, ok := node.PrimitivePackets[byte(TagOfHandshakeAuthPayload)]; ok { + authPayload, err := authPayloadBlock.ToUTF8String() + if err != nil { + return nil, err + } + handshake.authPayload = authPayload + } + + return handshake, nil +} + +// AuthPayload authentication payload +func (h *HandshakeFrame) AuthPayload() string { + return h.authPayload +} + +// AuthName authentication name +func (h *HandshakeFrame) AuthName() string { + return h.authName +} diff --git a/frame/handshake_frame_test.go b/frame/handshake_frame_test.go new file mode 100644 index 0000000..b8c5d49 --- /dev/null +++ b/frame/handshake_frame_test.go @@ -0,0 +1,30 @@ +package frame + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestHandshakeFrameEncode(t *testing.T) { + expectedName := "1234" + var expectedType byte = 0xD3 + m := NewHandshakeFrame(expectedName, "", expectedType, []byte{0x01, 0x02}, "token", "a") + assert.Equal(t, []byte{ + 0x80 | byte(TagOfHandshakeFrame), 0x19, + byte(TagOfHandshakeName), 0x04, 0x31, 0x32, 0x33, 0x34, + byte(TagOfHandshakeId), 0x0, + byte(TagOfHandshakeType), 0x01, 0xD3, + byte(TagOfHandshakeObserveDataTags), 0x02, 0x01, 0x02, + // byte(TagOfHandshakeAppID), 0x0, + byte(TagOfHandshakeAuthName), 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e, + byte(TagOfHandshakeAuthPayload), 0x01, 0x61, + }, + m.Encode(), + ) + + Handshake, err := DecodeToHandshakeFrame(m.Encode()) + assert.NoError(t, err) + assert.EqualValues(t, expectedName, Handshake.Name) + assert.EqualValues(t, expectedType, Handshake.ClientType) +} diff --git a/frame/meta_frame.go b/frame/meta_frame.go new file mode 100644 index 0000000..e9beb32 --- /dev/null +++ b/frame/meta_frame.go @@ -0,0 +1,113 @@ +package frame + +import ( + "strconv" + "time" + + coder "git.hpds.cc/Component/mq_coder" + gonanoid "github.com/matoous/go-nanoid/v2" +) + +// MetaFrame is a coder encoded bytes, SeqId is a fixed value of TYPE_ID_TRANSACTION. +// used for describes metadata for a DataFrame. +type MetaFrame struct { + tid string + metadata []byte + sourceId string +} + +// NewMetaFrame creates a new MetaFrame instance. +func NewMetaFrame() *MetaFrame { + tid, err := gonanoid.New() + if err != nil { + tid = strconv.FormatInt(time.Now().UnixMicro(), 10) + } + return &MetaFrame{tid: tid} +} + +// SetTransactionId set the transaction id. +func (m *MetaFrame) SetTransactionId(transactionId string) { + m.tid = transactionId +} + +// TransactionId returns transactionId +func (m *MetaFrame) TransactionId() string { + return m.tid +} + +// SetMetadata set the extra info of the application +func (m *MetaFrame) SetMetadata(metadata []byte) { + m.metadata = metadata +} + +// Metadata returns the extra info of the application +func (m *MetaFrame) Metadata() []byte { + return m.metadata +} + +// SetSourceId set the source ID. +func (m *MetaFrame) SetSourceId(sourceId string) { + m.sourceId = sourceId +} + +// SourceId returns source ID +func (m *MetaFrame) SourceId() string { + return m.sourceId +} + +// Encode implements Frame.Encode method. +func (m *MetaFrame) Encode() []byte { + meta := coder.NewNodePacketEncoder(byte(TagOfMetaFrame)) + // transaction ID + transactionId := coder.NewPrimitivePacketEncoder(byte(TagOfTransactionId)) + transactionId.SetStringValue(m.tid) + meta.AddPrimitivePacket(transactionId) + + // source ID + sourceId := coder.NewPrimitivePacketEncoder(byte(TagOfSourceId)) + sourceId.SetStringValue(m.sourceId) + meta.AddPrimitivePacket(sourceId) + + // metadata + if m.metadata != nil { + metadata := coder.NewPrimitivePacketEncoder(byte(TagOfMetadata)) + metadata.SetBytesValue(m.metadata) + meta.AddPrimitivePacket(metadata) + } + + return meta.Encode() +} + +// DecodeToMetaFrame decode a MetaFrame instance from given buffer. +func DecodeToMetaFrame(buf []byte) (*MetaFrame, error) { + nodeBlock := coder.NodePacket{} + _, err := coder.DecodeToNodePacket(buf, &nodeBlock) + if err != nil { + return nil, err + } + + meta := &MetaFrame{} + for k, v := range nodeBlock.PrimitivePackets { + switch k { + case byte(TagOfTransactionId): + val, err := v.ToUTF8String() + if err != nil { + return nil, err + } + meta.tid = val + break + case byte(TagOfMetadata): + meta.metadata = v.ToBytes() + break + case byte(TagOfSourceId): + sourceId, err := v.ToUTF8String() + if err != nil { + return nil, err + } + meta.sourceId = sourceId + break + } + } + + return meta, nil +} diff --git a/frame/meta_frame_test.go b/frame/meta_frame_test.go new file mode 100644 index 0000000..293c632 --- /dev/null +++ b/frame/meta_frame_test.go @@ -0,0 +1,25 @@ +package frame + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMetaFrameEncode(t *testing.T) { + m := NewMetaFrame() + tidbuf := []byte(m.tid) + result := []byte{0x80 | byte(TagOfMetaFrame), byte(1 + 1 + len(tidbuf) + 2), byte(TagOfTransactionId), byte(len(tidbuf))} + result = append(result, tidbuf...) + result = append(result, byte(TagOfSourceId), 0x0) + assert.Equal(t, result, m.Encode()) +} + +func TestMetaFrameDecode(t *testing.T) { + buf := []byte{0x80 | byte(TagOfMetaFrame), 0x09, byte(TagOfTransactionId), 0x04, 0x31, 0x32, 0x33, 0x34, byte(TagOfSourceId), 0x01, 0x31} + meta, err := DecodeToMetaFrame(buf) + assert.NoError(t, err) + assert.EqualValues(t, "1234", meta.TransactionId()) + assert.EqualValues(t, "1", meta.SourceId()) + t.Logf("%# x", buf) +} diff --git a/frame/payload_frame.go b/frame/payload_frame.go new file mode 100644 index 0000000..2f1f43b --- /dev/null +++ b/frame/payload_frame.go @@ -0,0 +1,55 @@ +package frame + +import ( + coder "git.hpds.cc/Component/mq_coder" +) + +// PayloadFrame is a coder encoded bytes, Tag is a fixed value TYPE_ID_PAYLOAD_FRAME +// the Len is the length of Val. Val is also a coder encoded PrimitivePacket, storing +// raw bytes as user's data +type PayloadFrame struct { + Tag byte + Carriage []byte +} + +// NewPayloadFrame creates a new PayloadFrame with a given TagId of user's data +func NewPayloadFrame(tag byte) *PayloadFrame { + return &PayloadFrame{ + Tag: tag, + } +} + +// SetCarriage sets the user's raw data +func (m *PayloadFrame) SetCarriage(buf []byte) *PayloadFrame { + m.Carriage = buf + return m +} + +// Encode to coder encoded bytes +func (m *PayloadFrame) Encode() []byte { + carriage := coder.NewPrimitivePacketEncoder(m.Tag) + carriage.SetBytesValue(m.Carriage) + + payload := coder.NewNodePacketEncoder(byte(TagOfPayloadFrame)) + payload.AddPrimitivePacket(carriage) + + return payload.Encode() +} + +// DecodeToPayloadFrame decodes coder encoded bytes to PayloadFrame +func DecodeToPayloadFrame(buf []byte) (*PayloadFrame, error) { + nodeBlock := coder.NodePacket{} + _, err := coder.DecodeToNodePacket(buf, &nodeBlock) + if err != nil { + return nil, err + } + + payload := &PayloadFrame{} + for _, v := range nodeBlock.PrimitivePackets { + payload.Tag = v.SeqId() + payload.Carriage = v.GetValBuf() + break + } + + return payload, nil +} diff --git a/frame/payload_frame_test.go b/frame/payload_frame_test.go new file mode 100644 index 0000000..dc520a7 --- /dev/null +++ b/frame/payload_frame_test.go @@ -0,0 +1,20 @@ +package frame + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPayloadFrameEncode(t *testing.T) { + f := NewPayloadFrame(0x13).SetCarriage([]byte("hpds")) + assert.Equal(t, []byte{0x80 | byte(TagOfPayloadFrame), 0x06, 0x13, 0x04, 0x79, 0x6F, 0x6D, 0x6F}, f.Encode()) +} + +func TestPayloadFrameDecode(t *testing.T) { + buf := []byte{0x80 | byte(TagOfPayloadFrame), 0x06, 0x13, 0x04, 0x79, 0x6F, 0x6D, 0x6F} + payload, err := DecodeToPayloadFrame(buf) + assert.NoError(t, err) + assert.EqualValues(t, 0x13, payload.Tag) + assert.Equal(t, []byte{0x79, 0x6F, 0x6D, 0x6F}, payload.Carriage) +} diff --git a/frame/rejected_frame.go b/frame/rejected_frame.go new file mode 100644 index 0000000..9cb5db4 --- /dev/null +++ b/frame/rejected_frame.go @@ -0,0 +1,56 @@ +package frame + +import ( + coder "git.hpds.cc/Component/mq_coder" +) + +// RejectedFrame is a coder encoded bytes, Tag is a fixed value TYPE_ID_REJECTED_FRAME +type RejectedFrame struct { + message string +} + +// NewRejectedFrame creates a new RejectedFrame with a given TagId of user's data +func NewRejectedFrame(msg string) *RejectedFrame { + return &RejectedFrame{message: msg} +} + +// Type gets the type of Frame. +func (f *RejectedFrame) Type() Type { + return TagOfRejectedFrame +} + +// Encode to coder encoded bytes +func (f *RejectedFrame) Encode() []byte { + rejected := coder.NewNodePacketEncoder(byte(f.Type())) + // message + msgBlock := coder.NewPrimitivePacketEncoder(byte(TagOfRejectedMessage)) + msgBlock.SetStringValue(f.message) + + rejected.AddPrimitivePacket(msgBlock) + + return rejected.Encode() +} + +// Message rejected message +func (f *RejectedFrame) Message() string { + return f.message +} + +// DecodeToRejectedFrame decodes coder encoded bytes to RejectedFrame +func DecodeToRejectedFrame(buf []byte) (*RejectedFrame, error) { + node := coder.NodePacket{} + _, err := coder.DecodeToNodePacket(buf, &node) + if err != nil { + return nil, err + } + rejected := &RejectedFrame{} + // message + if msgBlock, ok := node.PrimitivePackets[byte(TagOfRejectedMessage)]; ok { + msg, e := msgBlock.ToUTF8String() + if e != nil { + return nil, e + } + rejected.message = msg + } + return rejected, nil +} diff --git a/frame/rejected_frame_test.go b/frame/rejected_frame_test.go new file mode 100644 index 0000000..ee983c8 --- /dev/null +++ b/frame/rejected_frame_test.go @@ -0,0 +1,19 @@ +package frame + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRejectedFrameEncode(t *testing.T) { + f := NewRejectedFrame("") + assert.Equal(t, []byte{0x80 | byte(TagOfRejectedFrame), 0x02, 0x02, 0x00}, f.Encode()) +} + +func TestRejectedFrameDecode(t *testing.T) { + buf := []byte{0x80 | byte(TagOfRejectedFrame), 0x00} + ping, err := DecodeToRejectedFrame(buf) + assert.NoError(t, err) + assert.Equal(t, []byte{0x80 | byte(TagOfRejectedFrame), 0x2, 0x2, 0x0}, ping.Encode()) +} diff --git a/frame_stream.go b/frame_stream.go new file mode 100644 index 0000000..ed407d6 --- /dev/null +++ b/frame_stream.go @@ -0,0 +1,42 @@ +package network + +import ( + "errors" + "io" + "sync" + + "git.hpds.cc/Component/network/frame" +) + +// FrameStream is the QUIC Stream with the minimum unit Frame. +type FrameStream struct { + // Stream is a QUIC stream. + stream io.ReadWriter + mu sync.Mutex +} + +// NewFrameStream creates a new FrameStream. +func NewFrameStream(s io.ReadWriter) *FrameStream { + return &FrameStream{ + stream: s, + mu: sync.Mutex{}, + } +} + +// ReadFrame reads next frame from QUIC stream. +func (fs *FrameStream) ReadFrame() (frame.Frame, error) { + if fs.stream == nil { + return nil, errors.New("network.ReadStream: stream can not be nil") + } + return ParseFrame(fs.stream) +} + +// WriteFrame writes a frame into QUIC stream. +func (fs *FrameStream) WriteFrame(f frame.Frame) (int, error) { + if fs.stream == nil { + return 0, errors.New("network.WriteFrame: stream can not be nil") + } + fs.mu.Lock() + defer fs.mu.Unlock() + return fs.stream.Write(f.Encode()) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..0783c10 --- /dev/null +++ b/go.mod @@ -0,0 +1,37 @@ +module git.hpds.cc/Component/network + +go 1.19 + +require ( + git.hpds.cc/Component/mq_coder v0.0.0-20221010064749-174ae7ae3340 + github.com/lucas-clemente/quic-go v0.29.1 + github.com/matoous/go-nanoid/v2 v2.0.0 + github.com/stretchr/testify v1.8.0 + go.uber.org/zap v1.23.0 + gopkg.in/natefinch/lumberjack.v2 v2.0.0 +) + +require ( + github.com/BurntSushi/toml v1.2.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/fsnotify/fsnotify v1.4.9 // indirect + github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect + github.com/golang/mock v1.6.0 // indirect + github.com/kr/pretty v0.3.1 // indirect + github.com/marten-seemann/qtls-go1-18 v0.1.2 // indirect + github.com/marten-seemann/qtls-go1-19 v0.1.0 // indirect + github.com/nxadm/tail v1.4.8 // indirect + github.com/onsi/ginkgo v1.16.4 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + go.uber.org/atomic v1.7.0 // indirect + go.uber.org/multierr v1.6.0 // indirect + golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 // indirect + golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e // indirect + golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect + golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e // indirect + golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a // indirect + golang.org/x/tools v0.1.10 // indirect + golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect + gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/handler_type.go b/handler_type.go new file mode 100644 index 0000000..84f138c --- /dev/null +++ b/handler_type.go @@ -0,0 +1,9 @@ +package network + +import "git.hpds.cc/Component/network/frame" + +// AsyncHandler is the request-response mode (async) +type AsyncHandler func([]byte) (byte, []byte) + +// PipeHandler is the bidirectional stream mode (blocking). +type PipeHandler func(in <-chan []byte, out chan<- *frame.PayloadFrame) diff --git a/hpds_err/errors.go b/hpds_err/errors.go new file mode 100644 index 0000000..4359c50 --- /dev/null +++ b/hpds_err/errors.go @@ -0,0 +1,125 @@ +package hpds_err + +import ( + "fmt" + + quic "github.com/lucas-clemente/quic-go" +) + +// HpdsError hpds error +type HpdsError struct { + errorCode ErrorCode + err error +} + +// New create hpds error +func New(code ErrorCode, err error) *HpdsError { + return &HpdsError{ + errorCode: code, + err: err, + } +} + +func (e *HpdsError) Error() string { + return fmt.Sprintf("%s error: message=%s", e.errorCode, e.err.Error()) +} + +// ErrorCode error code +type ErrorCode uint64 + +const ( + // ErrorCodeClientAbort client abort + ErrorCodeClientAbort ErrorCode = 0x00 + // ErrorCodeUnknown unknown error + ErrorCodeUnknown ErrorCode = 0xC0 + // ErrorCodeClosed net closed + ErrorCodeClosed ErrorCode = 0xC1 + // ErrorCodeBeforeHandler before handler + ErrorCodeBeforeHandler ErrorCode = 0xC2 + // ErrorCodeMainHandler main handler + ErrorCodeMainHandler ErrorCode = 0xC3 + // ErrorCodeAfterHandler after handler + ErrorCodeAfterHandler ErrorCode = 0xC4 + // ErrorCodeHandshake handshake frame + ErrorCodeHandshake ErrorCode = 0xC5 + // ErrorCodeRejected server rejected + ErrorCodeRejected ErrorCode = 0xCC + // ErrorCodeGoaway goaway frame + ErrorCodeGoaway ErrorCode = 0xCF + // ErrorCodeData data frame + ErrorCodeData ErrorCode = 0xCE + // ErrorCodeUnknownClient unknown client error + ErrorCodeUnknownClient ErrorCode = 0xCD + // ErrorCodeDuplicateName unknown client error + ErrorCodeDuplicateName ErrorCode = 0xC6 +) + +func (e ErrorCode) String() string { + switch e { + case ErrorCodeClientAbort: + return "ClientAbort" + case ErrorCodeUnknown: + return "UnknownError" + case ErrorCodeClosed: + return "NetClosed" + case ErrorCodeBeforeHandler: + return "BeforeHandler" + case ErrorCodeMainHandler: + return "MainHandler" + case ErrorCodeAfterHandler: + return "AfterHandler" + case ErrorCodeHandshake: + return "Handshake" + case ErrorCodeRejected: + return "Rejected" + case ErrorCodeGoaway: + return "Goaway" + case ErrorCodeData: + return "DataFrame" + case ErrorCodeUnknownClient: + return "UnknownClient" + case ErrorCodeDuplicateName: + return "DuplicateName" + default: + return "XXX" + } +} + +// Is parse quic ApplicationErrorCode to hpds ErrorCode +func Is(he quic.ApplicationErrorCode, yerr ErrorCode) bool { + return uint64(he) == uint64(yerr) +} + +// Parse parse quic ApplicationErrorCode +func Parse(he quic.ApplicationErrorCode) ErrorCode { + return ErrorCode(he) +} + +// To convert hpds ErrorCode to quic ApplicationErrorCode +func To(code ErrorCode) quic.ApplicationErrorCode { + return quic.ApplicationErrorCode(code) +} + +// DuplicateNameError duplicate name(sfn) +type DuplicateNameError struct { + connId string + err error +} + +// NewDuplicateNameError create a duplicate name error +func NewDuplicateNameError(connId string, err error) DuplicateNameError { + return DuplicateNameError{ + connId: connId, + err: err, + } +} + +// Error raw error +func (e DuplicateNameError) Error() string { + return e.err.Error() +} + +// ConnId duplicate connection ID +func (e DuplicateNameError) ConnId() string { + return e.connId +} diff --git a/id/id.go b/id/id.go new file mode 100644 index 0000000..7f6c5d3 --- /dev/null +++ b/id/id.go @@ -0,0 +1,16 @@ +package id + +import ( + "git.hpds.cc/Component/network/log" + gonanoid "github.com/matoous/go-nanoid/v2" +) + +// New generate id +func New() string { + id, err := gonanoid.New() + if err != nil { + log.Errorf("generated id err=%v", err) + return "" + } + return id +} diff --git a/listener.go b/listener.go new file mode 100644 index 0000000..eeb4ff2 --- /dev/null +++ b/listener.go @@ -0,0 +1,73 @@ +package network + +import ( + "crypto/tls" + "git.hpds.cc/Component/network/log" + "github.com/lucas-clemente/quic-go" + "net" + "time" + + pkgtls "git.hpds.cc/Component/network/tls" +) + +// A Listener for incoming connections +type Listener interface { + quic.Listener + // Name Listener's name + Name() string + // Versions get Version + Versions() []string +} + +var _ Listener = (*defaultListener)(nil) + +type defaultListener struct { + conf *quic.Config + quic.Listener +} + +// DefaultQuicConfig be used when `quicConfig` is nil. +var DefaultQuicConfig = &quic.Config{ + Versions: []quic.VersionNumber{quic.Version1, quic.VersionDraft29}, + MaxIdleTimeout: time.Second * 5, + KeepAlivePeriod: time.Second * 2, + MaxIncomingStreams: 1000, + MaxIncomingUniStreams: 1000, + HandshakeIdleTimeout: time.Second * 3, + InitialStreamReceiveWindow: 1024 * 1024 * 2, + InitialConnectionReceiveWindow: 1024 * 1024 * 2, + // DisablePathMTUDiscovery: true, +} + +func newListener(conn net.PacketConn, tlsConfig *tls.Config, quicConfig *quic.Config) (*defaultListener, error) { + if tlsConfig == nil { + tc, err := pkgtls.CreateServerTLSConfig(conn.LocalAddr().String()) + if err != nil { + log.Errorf("%sCreateServerTLSConfig: %v", ServerLogPrefix, err) + return &defaultListener{}, err + } + tlsConfig = tc + } + + if quicConfig == nil { + quicConfig = DefaultQuicConfig + } + + quicListener, err := quic.Listen(conn, tlsConfig, quicConfig) + if err != nil { + log.Errorf("%squic Listen: %v", ServerLogPrefix, err) + return &defaultListener{}, err + } + + return &defaultListener{conf: quicConfig, Listener: quicListener}, nil +} + +func (l *defaultListener) Name() string { return "QUIC-Server" } + +func (l *defaultListener) Versions() []string { + versions := make([]string, len(l.conf.Versions)) + for k, v := range l.conf.Versions { + versions[k] = v.String() + } + return versions +} diff --git a/log/logger.go b/log/logger.go new file mode 100644 index 0000000..f9ebf06 --- /dev/null +++ b/log/logger.go @@ -0,0 +1,143 @@ +package log + +import ( + "os" + "strings" +) + +// Level of log +type Level uint8 + +const ( + // DebugLevel defines debug log level. + DebugLevel Level = iota + 1 + // InfoLevel defines info log level. + InfoLevel + // WarnLevel defines warn log level. + WarnLevel + // ErrorLevel defines error log level. + ErrorLevel + // NoLevel defines an absent log level. + NoLevel Level = 254 + // Disabled disables the logger. + Disabled Level = 255 +) + +// Logger is the interface for logger. +type Logger interface { + // SetLevel sets the logger level + SetLevel(Level) + // SetEncoding sets the logger's encoding + SetEncoding(encoding string) + // Printf logs a message without level + Printf(template string, args ...interface{}) + // Debugf logs a message at DebugLevel + Debugf(template string, args ...interface{}) + // Infof logs a message at InfoLevel + Infof(template string, args ...interface{}) + // Warnf logs a message at WarnLevel + Warnf(template string, args ...interface{}) + // Errorf logs a message at ErrorLevel + Errorf(template string, args ...interface{}) + // Output file path to write log message + Output(file string) + // ErrorOutput file path to write error message + ErrorOutput(file string) +} + +// String the logger level +func (l Level) String() string { + switch l { + case DebugLevel: + return "DEBUG" + case ErrorLevel: + return "ERROR" + case WarnLevel: + return "WARN" + case InfoLevel: + return "INFO" + default: + return "" + } +} + +// 实例 +var logger Logger + +func init() { + logger = Default(isEnableDebug()) +} + +// SetLogger allows developers to customize the logger instance. +func SetLogger(l Logger) { + logger = l +} + +// EnableDebug enables the development model for logging. +// Deprecated +func EnableDebug() { + logger = Default(true) +} + +// Printf prints a formatted message without a specified level. +func Printf(format string, v ...interface{}) { + logger.Printf(format, v...) +} + +// Debugf logs a message at DebugLevel. +func Debugf(template string, args ...interface{}) { + logger.Debugf(template, args...) +} + +// Infof logs a message at InfoLevel. +func Infof(template string, args ...interface{}) { + logger.Infof(template, args...) +} + +// Warnf logs a message at WarnLevel. +func Warnf(template string, args ...interface{}) { + logger.Warnf(template, args...) +} + +// Errorf logs a message at ErrorLevel. +func Errorf(template string, args ...interface{}) { + logger.Errorf(template, args...) +} + +// isEnableDebug indicates whether the debug is enabled. +func isEnableDebug() bool { + return os.Getenv("HPDS_ENABLE_DEBUG") == "true" +} + +// isJSONFormat indicates whether the log is in JSON format. +func isJSONFormat() bool { + return os.Getenv("HPDS_LOG_FORMAT") == "json" +} + +func logFormat() string { + return os.Getenv("HPDS_LOG_FORMAT") +} + +func logLevel() Level { + envLevel := strings.ToLower(os.Getenv("HPDS_LOG_LEVEL")) + level := ErrorLevel + switch envLevel { + case "debug": + return DebugLevel + case "info": + return InfoLevel + case "warn": + return WarnLevel + case "error": + return ErrorLevel + } + return level +} + +func output() string { + return strings.ToLower(os.Getenv("HPDS_LOG_OUTPUT")) +} + +func errorOutput() string { + return strings.ToLower(os.Getenv("HPDS_LOG_ERROR_OUTPUT")) +} diff --git a/log/zap.go b/log/zap.go new file mode 100644 index 0000000..c81043c --- /dev/null +++ b/log/zap.go @@ -0,0 +1,227 @@ +package log + +import ( + stdlog "log" + "os" + "time" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "gopkg.in/natefinch/lumberjack.v2" +) + +const ( + timeFormat = "2006-01-02 15:04:05.000" +) + +// zapLogger is the logger implementation in go.uber.org/zap +type zapLogger struct { + level zapcore.Level + debug bool + encoding string + opts []zap.Option + logger *zap.Logger + instance *zap.SugaredLogger + output string + errorOutput string +} + +// Default the default logger instance +func Default(debug ...bool) Logger { + z := New() + z.SetLevel(logLevel()) + if isJSONFormat() { + z.SetEncoding("json") + } + // env debug + if isEnableDebug() { + z.SetLevel(DebugLevel) + } + if len(debug) > 0 { + if debug[0] { + z.SetLevel(DebugLevel) + } + } + z.Output(output()) + z.ErrorOutput(errorOutput()) + + return z +} + +// New create new logger instance +func New(opts ...zap.Option) Logger { + // std logger + stdlog.Default().SetFlags(0) + stdlog.Default().SetOutput(new(logWriter)) + + z := zapLogger{ + level: zap.ErrorLevel, + debug: false, + encoding: "console", + opts: opts, + } + + return &z +} + +func openSinks(cfg zap.Config) (zapcore.WriteSyncer, zapcore.WriteSyncer, error) { + sink, closeOut, err := zap.Open(cfg.OutputPaths...) + if err != nil { + return nil, nil, err + } + errSink, _, err := zap.Open(cfg.ErrorOutputPaths...) + if err != nil { + closeOut() + return nil, nil, err + } + return sink, errSink, nil +} + +// SetEncoding set logger message coding +func (z *zapLogger) SetEncoding(enc string) { + z.encoding = enc +} + +// SetLevel set logger level +func (z *zapLogger) SetLevel(lvl Level) { + isDebug := lvl == DebugLevel + level := zap.ErrorLevel + switch lvl { + case DebugLevel: + level = zap.DebugLevel + case InfoLevel: + level = zap.InfoLevel + case WarnLevel: + level = zap.WarnLevel + case ErrorLevel: + level = zap.ErrorLevel + } + z.level = level + z.debug = isDebug +} + +// Output file path to write log message +func (z *zapLogger) Output(file string) { + if file != "" { + z.output = file + } +} + +// ErrorOutput file path to write log message +func (z *zapLogger) ErrorOutput(file string) { + if file != "" { + z.errorOutput = file + } +} + +// Printf logs a message wihout level +func (z *zapLogger) Printf(format string, v ...interface{}) { + stdlog.Printf(format, v...) +} + +// Debugf logs a message at DebugLevel +func (z *zapLogger) Debugf(template string, args ...interface{}) { + z.Instance().Debugf(template, args...) +} + +// Infof logs a message at InfoLevel +func (z *zapLogger) Infof(template string, args ...interface{}) { + z.Instance().Infof(template, args...) +} + +// Warnf logs a message at WarnLevel +func (z zapLogger) Warnf(template string, args ...interface{}) { + z.Instance().Warnf(template, args...) +} + +// Errorf logs a message at ErrorLevel +func (z zapLogger) Errorf(template string, args ...interface{}) { + z.Instance().Errorf(template, args...) +} + +func (z *zapLogger) Instance() *zap.SugaredLogger { + if z.instance == nil { + // zap + encoderConfig := zapcore.EncoderConfig{ + TimeKey: "ts", + LevelKey: "level", + NameKey: "logger", + CallerKey: "caller", + FunctionKey: zapcore.OmitKey, + MessageKey: "msg", + StacktraceKey: "stacktrace", + LineEnding: zapcore.DefaultLineEnding, + EncodeLevel: zapcore.CapitalColorLevelEncoder, + EncodeTime: timeEncoder, + EncodeDuration: zapcore.SecondsDurationEncoder, + EncodeCaller: zapcore.ShortCallerEncoder, + } + cfg := zap.Config{ + Level: zap.NewAtomicLevelAt(zap.ErrorLevel), + Development: z.debug, + DisableCaller: true, + DisableStacktrace: true, + Encoding: z.encoding, + EncoderConfig: encoderConfig, + OutputPaths: []string{"stderr"}, + ErrorOutputPaths: []string{"stderr"}, + } + cfg.Level.SetLevel(z.level) + if z.debug { + // set the minimal level to debug + cfg.Level.SetLevel(zap.DebugLevel) + } + // output + if z.output != "" { + cfg.OutputPaths = append(cfg.OutputPaths, z.output) + } + encoder := zapcore.NewConsoleEncoder(encoderConfig) + sink, _, err := openSinks(cfg) + if err != nil { + panic(err) + } + core := zapcore.NewCore(encoder, sink, cfg.Level) + // error output + if z.errorOutput != "" { + rotatedLogger := errorRotatedLogger(z.errorOutput, 10, 30, 7) + errorOutputOption := zap.Hooks(func(entry zapcore.Entry) error { + if entry.Level == zap.ErrorLevel { + msg, err := encoder.EncodeEntry(entry, nil) + if err != nil { + return err + } + rotatedLogger.Write(msg.Bytes()) + } + return nil + }) + z.opts = append(z.opts, errorOutputOption) + } + l := zap.New(core, z.opts...) + + z.logger = l + z.instance = z.logger.Sugar() + } + return z.instance +} + +func errorRotatedLogger(file string, maxSize, maxBacukups, maxAge int) *lumberjack.Logger { + return &lumberjack.Logger{ + Filename: file, + MaxSize: maxSize, + MaxBackups: maxBacukups, + MaxAge: maxAge, + Compress: false, + } +} + +func timeEncoder(t time.Time, enc zapcore.PrimitiveArrayEncoder) { + enc.AppendString(t.Format(timeFormat)) +} + +type logWriter struct{} + +func (l logWriter) Write(bytes []byte) (int, error) { + os.Stderr.WriteString(time.Now().Format(timeFormat)) + os.Stderr.Write([]byte("\t")) + return os.Stderr.Write(bytes) +} diff --git a/metadata.go b/metadata.go new file mode 100644 index 0000000..25eeea5 --- /dev/null +++ b/metadata.go @@ -0,0 +1,17 @@ +package network + +import "git.hpds.cc/Component/network/frame" + +// Metadata is used for storing extra info of the application +type Metadata interface { + // Encode is the serialize method + Encode() []byte +} + +// MetadataBuilder is the builder of Metadata +type MetadataBuilder interface { + // Build will return a Metadata instance according to the handshake frame passed in + Build(f *frame.HandshakeFrame) (Metadata, error) + // Decode is the deserialize method + Decode(buf []byte) (Metadata, error) +} diff --git a/parser_stream.go b/parser_stream.go new file mode 100644 index 0000000..849dee7 --- /dev/null +++ b/parser_stream.go @@ -0,0 +1,47 @@ +package network + +import ( + "fmt" + coder "git.hpds.cc/Component/mq_coder" + "git.hpds.cc/Component/network/frame" + "io" +) + +// ParseFrame parses the frame from QUIC stream. +func ParseFrame(stream io.Reader) (frame.Frame, error) { + buf, err := coder.ReadPacket(stream) + if err != nil { + return nil, err + } + + frameType := buf[0] + // determine the frame type + switch frameType { + case 0x80 | byte(frame.TagOfHandshakeFrame): + handshakeFrame, err := readHandshakeFrame(buf) + // logger.Debugf("%sHandshakeFrame: name=%s, type=%s", ParseFrameLogPrefix, handshakeFrame.Name, handshakeFrame.Type()) + return handshakeFrame, err + case 0x80 | byte(frame.TagOfDataFrame): + data, err := readDataFrame(buf) + // logger.Debugf("%sDataFrame: tid=%s, tag=%#x, len(carriage)=%d", ParseFrameLogPrefix, data.TransactionID(), data.GetDataTag(), len(data.GetCarriage())) + return data, err + case 0x80 | byte(frame.TagOfAcceptedFrame): + return frame.DecodeToAcceptedFrame(buf) + case 0x80 | byte(frame.TagOfRejectedFrame): + return frame.DecodeToRejectedFrame(buf) + case 0x80 | byte(frame.TagOfGoawayFrame): + return frame.DecodeToGoawayFrame(buf) + case 0x80 | byte(frame.TagOfBackFlowFrame): + return frame.DecodeToBackFlowFrame(buf) + default: + return nil, fmt.Errorf("unknown frame type, buf[0]=%#x", buf[0]) + } +} + +func readHandshakeFrame(buf []byte) (*frame.HandshakeFrame, error) { + return frame.DecodeToHandshakeFrame(buf) +} + +func readDataFrame(buf []byte) (*frame.DataFrame, error) { + return frame.DecodeToDataFrame(buf) +} diff --git a/router.go b/router.go new file mode 100644 index 0000000..5cf79a5 --- /dev/null +++ b/router.go @@ -0,0 +1,19 @@ +package network + +// Router is the interface to manage the routes for applications. +type Router interface { + // Route gets the route + Route(metadata Metadata) Route + // Clean the routes. + Clean() +} + +// Route manages data subscribers according to their observed data tags. +type Route interface { + // Add a route. + Add(connId string, name string, observeDataTags []byte) error + // Remove a route. + Remove(connId string) error + // GetForwardRoutes returns all the subscribers by the given data tag. + GetForwardRoutes(tag byte) []string +} diff --git a/server.go b/server.go new file mode 100644 index 0000000..87362ba --- /dev/null +++ b/server.go @@ -0,0 +1,567 @@ +package network + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "os" + "sync" + "sync/atomic" + + // authentication implements, Currently, only token authentication is implemented + _ "git.hpds.cc/Component/network/auth" + "git.hpds.cc/Component/network/frame" + "git.hpds.cc/Component/network/hpds_err" + "git.hpds.cc/Component/network/log" + pkgtls "git.hpds.cc/Component/network/tls" + "github.com/lucas-clemente/quic-go" +) + +const ( + // DefaultListenAddr is the default address to listen. + DefaultListenAddr = "0.0.0.0:9000" +) + +// ServerOption is the option for server. +type ServerOption func(*ServerOptions) + +// FrameHandler is the handler for frame. +type FrameHandler func(c *Context) error + +// Server is the underlining server of Message Queue +type Server struct { + name string + state string + connector Connector + router Router + metadataBuilder MetadataBuilder + counterOfDataFrame int64 + downStreams map[string]*Client + mu sync.Mutex + opts ServerOptions + beforeHandlers []FrameHandler + afterHandlers []FrameHandler +} + +// NewServer create a Server instance. +func NewServer(name string, opts ...ServerOption) *Server { + s := &Server{ + name: name, + connector: newConnector(), + downStreams: make(map[string]*Client), + } + s.Init(opts...) + + return s +} + +// Init the options. +func (s *Server) Init(opts ...ServerOption) error { + for _, o := range opts { + o(&s.opts) + } + // options defaults + s.initOptions() + + return nil +} + +// ListenAndServe starts the server. +func (s *Server) ListenAndServe(ctx context.Context, addr string) error { + if addr == "" { + addr = DefaultListenAddr + } + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return err + } + conn, err := net.ListenUDP("udp", udpAddr) + if err != nil { + return err + } + return s.Serve(ctx, conn) +} + +// Serve the server with a net.PacketConn. +func (s *Server) Serve(ctx context.Context, conn net.PacketConn) error { + if err := s.validateMetadataBuilder(); err != nil { + return err + } + + if err := s.validateRouter(); err != nil { + return err + } + + // listen the address + listener, err := newListener(conn, s.opts.TLSConfig, s.opts.QuicConfig) + if err != nil { + log.Errorf("%slistener.Listen: err=%v", ServerLogPrefix, err) + return err + } + defer listener.Close() + log.Printf("%s [%s][%d] Listening on: %s, MODE: %s, QUIC: %v, AUTH: %s", ServerLogPrefix, s.name, os.Getpid(), listener.Addr(), mode(), listener.Versions(), s.authNames()) + + s.state = ConnStateConnected + for { + // create a new connection when new hpds-client connected + sctx, cancel := context.WithCancel(ctx) + defer cancel() + + connect, e := listener.Accept(sctx) + if e != nil { + log.Errorf("%screate connection error: %v", ServerLogPrefix, e) + return e + } + + connID := GetConnId(connect) + log.Infof("%s1/ new connection: %s", ServerLogPrefix, connID) + + go func(ctx context.Context, qconn quic.Connection) { + for { + log.Infof("%s2/ waiting for new stream", ServerLogPrefix) + stream, err := qconn.AcceptStream(ctx) + if err != nil { + // if client close the connection, then we should close the connection + // @CC: when Source close the connection, it won't affect connectors + name := "--" + if conn := s.connector.Get(connID); conn != nil { + conn.Close() + // connector + s.connector.Remove(connID) + route := s.router.Route(conn.Metadata()) + if route != nil { + route.Remove(connID) + } + name = conn.Name() + } + log.Printf("%s [%s](%s) close the connection: %v", ServerLogPrefix, name, connID, err) + break + } + defer stream.Close() + + log.Infof("%s3/ [stream:%d] created, connId=%s", ServerLogPrefix, stream.StreamID(), connID) + // process frames on stream + // c := newContext(connId, stream) + c := newContext(connect, stream) + defer c.Clean() + s.handleConnection(c) + log.Infof("%s4/ [stream:%d] handleConnection DONE", ServerLogPrefix, stream.StreamID()) + } + }(sctx, connect) + } +} + +// Close will shut down the server. +func (s *Server) Close() error { + if s.router != nil { + s.router.Clean() + } + // connector + if s.connector != nil { + s.connector.Clean() + } + return nil +} + +// handle streams on a connection +func (s *Server) handleConnection(c *Context) { + fs := NewFrameStream(c.Stream) + // check update for stream + for { + log.Debugf("%shandleConnection waiting read next...", ServerLogPrefix) + f, err := fs.ReadFrame() + if err != nil { + // if client close connection, will get ApplicationError with code = 0x00 + if e, ok := err.(*quic.ApplicationError); ok { + if hpds_err.Is(e.ErrorCode, hpds_err.ErrorCodeClientAbort) { + // client abort + log.Infof("%sclient close the connection", ServerLogPrefix) + break + } else { + ye := hpds_err.New(hpds_err.Parse(e.ErrorCode), err) + log.Errorf("%s[ERR] %s", ServerLogPrefix, ye) + } + } else if err == io.EOF { + log.Infof("%sthe connection is EOF", ServerLogPrefix) + break + } + if errors.Is(err, net.ErrClosed) { + // if client close the connection, net.ErrClosed will be raised + // by quic-go IdleTimeoutError after connection's KeepAlive config. + log.Warnf("%s[ERR] net.ErrClosed on [handleConnection] %v", ServerLogPrefix, net.ErrClosed) + c.CloseWithError(hpds_err.ErrorCodeClosed, "net.ErrClosed") + break + } + // any error occurred, we should close the stream + // after this, conn.AcceptStream() will raise the error + c.CloseWithError(hpds_err.ErrorCodeUnknown, err.Error()) + log.Warnf("%sconnection.Close()", ServerLogPrefix) + break + } + + frameType := f.Type() + data := f.Encode() + log.Debugf("%stype=%s, frame[%d]=%# x", ServerLogPrefix, frameType, len(data), frame.Shortly(data)) + // add frame to context + context := c.WithFrame(f) + + // before frame handlers + for _, handler := range s.beforeHandlers { + if e := handler(context); e != nil { + log.Errorf("%safterFrameHandler e: %s", ServerLogPrefix, e) + context.CloseWithError(hpds_err.ErrorCodeBeforeHandler, e.Error()) + return + } + } + // main handler + if e := s.mainFrameHandler(context); e != nil { + log.Errorf("%smainFrameHandler e: %s", ServerLogPrefix, e) + context.CloseWithError(hpds_err.ErrorCodeMainHandler, e.Error()) + return + } + // after frame handler + for _, handler := range s.afterHandlers { + if e := handler(context); e != nil { + log.Errorf("%safterFrameHandler e: %s", ServerLogPrefix, e) + context.CloseWithError(hpds_err.ErrorCodeAfterHandler, e.Error()) + return + } + } + } +} + +func (s *Server) mainFrameHandler(c *Context) error { + var err error + frameType := c.Frame.Type() + + switch frameType { + case frame.TagOfHandshakeFrame: + if err = s.handleHandshakeFrame(c); err != nil { + log.Errorf("%shandleHandshakeFrame err: %s", ServerLogPrefix, err) + // close connections early to avoid resource consumption + if c.Stream != nil { + goawayFrame := frame.NewGoawayFrame(err.Error()) + if _, e := c.Stream.Write(goawayFrame.Encode()); e != nil { + log.Errorf("%s write to client[%s] GoawayFrame error:%v", ServerLogPrefix, c.ConnId, e) + return e + } + } + } + // case frame.TagOfPingFrame: + // s.handlePingFrame(mainStream, connection, f.(*frame.PingFrame)) + case frame.TagOfDataFrame: + if err = s.handleDataFrame(c); err != nil { + c.CloseWithError(hpds_err.ErrorCodeData, fmt.Sprintf("handleDataFrame err: %v", err)) + } else { + conn := s.connector.Get(c.connId) + if conn != nil && conn.ClientType() == ClientTypeProtocolGateway { + f := c.Frame.(*frame.DataFrame) + f.GetMetaFrame().SetMetadata(conn.Metadata().Encode()) + s.dispatchToDownStreams(f) + } + // observe data tags back flow + s.handleBackFlowFrame(c) + } + default: + log.Errorf("%serr=%v, frame=%v", ServerLogPrefix, err, frame.Shortly(c.Frame.Encode())) + } + return nil +} + +// handle HandShakeFrame +func (s *Server) handleHandshakeFrame(c *Context) error { + f := c.Frame.(*frame.HandshakeFrame) + + log.Debugf("%sGOT HandshakeFrame : %# x", ServerLogPrefix, f) + // basic info + connId := c.ConnId() + clientId := f.ClientId + clientType := ClientType(f.ClientType) + stream := c.Stream + // credential + log.Debugf("%sClientType=%# x is %s, ClientId=%s, Credential=%s", ServerLogPrefix, f.ClientType, ClientType(f.ClientType), clientId, authName(f.AuthName())) + // authenticate + if !s.authenticate(f) { + err := fmt.Errorf("handshake authentication fails, client credential name is %s", authName(f.AuthName())) + // return err + log.Debugf("%s <%s> [%s](%s) is connected!", ServerLogPrefix, clientType, f.Name, connId) + rejectedFrame := frame.NewRejectedFrame(err.Error()) + if _, err = stream.Write(rejectedFrame.Encode()); err != nil { + log.Debugf("%s write to <%s> [%s](%s) RejectedFrame error:%v", ServerLogPrefix, clientType, f.Name, connId, err) + return err + } + return nil + } + + // client type + var conn Connection + switch clientType { + case ClientTypeProtocolGateway, ClientTypeStreamFunction: + // metadata + metadata, err := s.metadataBuilder.Build(f) + if err != nil { + return err + } + conn = newConnection(f.Name, f.ClientId, clientType, metadata, stream, f.ObserveDataTags) + + if clientType == ClientTypeStreamFunction { + // route + route := s.router.Route(metadata) + if route == nil { + return errors.New("handleHandshakeFrame route is nil") + } + if e1 := route.Add(connId, f.Name, f.ObserveDataTags); e1 != nil { + // duplicate name + if e2, ok := e1.(hpds_err.DuplicateNameError); ok { + existsConnID := e2.ConnId() + if conn = s.connector.Get(existsConnID); conn != nil { + log.Debugf("%s%s, write to SFN[%s](%s) GoawayFrame", ServerLogPrefix, e2.Error(), f.Name, existsConnID) + goawayFrame := frame.NewGoawayFrame(e2.Error()) + if e3 := conn.Write(goawayFrame); e3 != nil { + log.Errorf("%s write to SFN[%s] GoawayFrame error:%v", ServerLogPrefix, f.Name, e3) + return e3 + } + } + } else { + return e1 + } + } + } + case ClientTypeMessageQueue: + conn = newConnection(f.Name, f.ClientId, clientType, nil, stream, f.ObserveDataTags) + default: + // unknown client type + s.connector.Remove(connId) + err := fmt.Errorf("Illegal ClientType: %#x", f.ClientType) + c.CloseWithError(hpds_err.ErrorCodeUnknownClient, err.Error()) + return err + } + + s.connector.Add(connId, conn) + log.Printf("%s <%s> [%s][%s](%s) is connected!", ServerLogPrefix, clientType, f.Name, clientId, connId) + return nil +} + +// handle handleGoawayFrame +func (s *Server) handleGoawayFrame(c *Context) error { + f := c.Frame.(*frame.GoawayFrame) + + log.Debugf("%s GOT GoawayFrame code=%d, message==%s", ServerLogPrefix, hpds_err.ErrorCodeGoaway, f.Message()) + // c.CloseWithError(f.Code(), f.Message()) + _, err := c.Stream.Write(f.Encode()) + return err +} + +// will reuse quic-go's keep-alive feature +// func (s *Server) handlePingFrame(stream quic.Stream, conn quic.Connection, f *frame.PingFrame) error { +// log.Infof("%s------> GOT PingFrame : %# x", ServerLogPrefix, f) +// return nil +// } + +func (s *Server) handleDataFrame(c *Context) error { + // counter +1 + atomic.AddInt64(&s.counterOfDataFrame, 1) + // currentIssuer := f.GetIssuer() + fromId := c.ConnId() + from := s.connector.Get(fromId) + if from == nil { + log.Warnf("%shandleDataFrame connector cannot find %s", ServerLogPrefix, fromId) + return fmt.Errorf("handleDataFrame connector cannot find %s", fromId) + } + + f := c.Frame.(*frame.DataFrame) + + var metadata Metadata + if from.ClientType() == ClientTypeMessageQueue { + m, err := s.metadataBuilder.Decode(f.GetMetaFrame().Metadata()) + if err != nil { + return err + } + metadata = m + } else { + metadata = from.Metadata() + } + + // route + route := s.router.Route(metadata) + if route == nil { + log.Warnf("%shandleDataFrame route is nil", ServerLogPrefix) + return fmt.Errorf("handleDataFrame route is nil") + } + + // get stream function connection ids from route + connIds := route.GetForwardRoutes(f.GetDataTag()) + for _, toId := range connIds { + conn := s.connector.Get(toId) + if conn == nil { + log.Errorf("%sconn is nil: (%s)", ServerLogPrefix, toId) + continue + } + + to := conn.Name() + log.Debugf("%shandleDataFrame tag=%#x tid=%s, counter=%d, from=[%s](%s), to=[%s](%s)", ServerLogPrefix, f.Tag(), f.TransactionId(), s.counterOfDataFrame, from.Name(), fromId, to, toId) + + // write data frame to stream + if err := conn.Write(f); err != nil { + log.Warnf("%shandleDataFrame conn.Write tag=%#x tid=%s, from=[%s](%s), to=[%s](%s), %v", ServerLogPrefix, f.Tag(), f.TransactionId(), from.Name(), fromId, to, toId, err) + } + } + + return nil +} + +func (s *Server) handleBackFlowFrame(c *Context) error { + f := c.Frame.(*frame.DataFrame) + tag := f.GetDataTag() + carriage := f.GetCarriage() + sourceId := f.SourceId() + // write to Protocol Gateway with BackFlowFrame + bf := frame.NewBackFlowFrame(tag, carriage) + sourceConns := s.connector.GetProtocolGatewayConnections(sourceId, tag) + // conn := s.connector.Get(c.connId) + // logger.Printf("%s handleBackFlowFrame tag:%#v --> source:%s, result=%s", ServerLogPrefix, tag, sourceId, carriage) + for _, source := range sourceConns { + if source != nil { + log.Debugf("%s handleBackFlowFrame tag:%#v --> Protocol Gateway:%s, result=%# x", ServerLogPrefix, tag, sourceId, frame.Shortly(carriage)) + if err := source.Write(bf); err != nil { + log.Errorf("%s handleBackFlowFrame tag:%#v --> Protocol Gateway:%s, error=%v", ServerLogPrefix, tag, sourceId, err) + return err + } + } + } + return nil +} + +// StatsFunctions returns the sfn stats of server. +func (s *Server) StatsFunctions() map[string]string { + return s.connector.GetSnapshot() +} + +// StatsCounter returns how many DataFrames pass through server. +func (s *Server) StatsCounter() int64 { + return s.counterOfDataFrame +} + +// DownStreams return all the downstream servers. +func (s *Server) DownStreams() map[string]*Client { + return s.downStreams +} + +// ConfigRouter is used to set router by Message Queue +func (s *Server) ConfigRouter(router Router) { + s.mu.Lock() + s.router = router + log.Debugf("%sconfig router is %#v", ServerLogPrefix, router) + s.mu.Unlock() +} + +// ConfigMetadataBuilder is used to set metadataBuilder by Message Queue +func (s *Server) ConfigMetadataBuilder(builder MetadataBuilder) { + s.mu.Lock() + s.metadataBuilder = builder + log.Debugf("%sconfig metadataBuilder is %#v", ServerLogPrefix, builder) + s.mu.Unlock() +} + +// AddDownstreamServer add a downstream server to this server. all the DataFrames will be +// dispatch to all the downStreams. +func (s *Server) AddDownstreamServer(addr string, c *Client) { + s.mu.Lock() + s.downStreams[addr] = c + s.mu.Unlock() +} + +// dispatch every DataFrames to all downStreams +func (s *Server) dispatchToDownStreams(df *frame.DataFrame) { + for addr, ds := range s.downStreams { + log.Debugf("%sdispatching to [%s]: %# x", ServerLogPrefix, addr, df.Tag()) + ds.WriteFrame(df) + } +} + +// GetConnId get quic connection id +func GetConnId(conn quic.Connection) string { + return conn.RemoteAddr().String() +} + +func (s *Server) initOptions() { + // defaults +} + +func (s *Server) validateRouter() error { + if s.router == nil { + return errors.New("server's router is nil") + } + return nil +} + +func (s *Server) validateMetadataBuilder() error { + if s.metadataBuilder == nil { + return errors.New("server's metadataBuilder is nil") + } + return nil +} + +// Options returns the options of server. +func (s *Server) Options() ServerOptions { + return s.opts +} + +// Connector returns the connector of server. +func (s *Server) Connector() Connector { + return s.connector +} + +// SetBeforeHandlers set the before handlers of server. +func (s *Server) SetBeforeHandlers(handlers ...FrameHandler) { + s.beforeHandlers = append(s.beforeHandlers, handlers...) +} + +// SetAfterHandlers set the after handlers of server. +func (s *Server) SetAfterHandlers(handlers ...FrameHandler) { + s.afterHandlers = append(s.afterHandlers, handlers...) +} + +func (s *Server) authNames() []string { + if len(s.opts.Auths) == 0 { + return []string{"none"} + } + result := make([]string, 0) + for _, auth := range s.opts.Auths { + result = append(result, auth.Name()) + } + return result +} + +func (s *Server) authenticate(f *frame.HandshakeFrame) bool { + if len(s.opts.Auths) > 0 { + for _, auth := range s.opts.Auths { + if f.AuthName() == auth.Name() { + isAuthenticated := auth.Authenticate(f.AuthPayload()) + if isAuthenticated { + log.Debugf("%sauthenticated==%v", ServerLogPrefix, isAuthenticated) + return isAuthenticated + } + } + } + return false + } + return true +} + +func mode() string { + if pkgtls.IsDev() { + return "DEVELOPMENT" + } + return "PRODUCTION" +} + +func authName(name string) string { + if name == "" { + return "empty" + } + + return name +} diff --git a/server_options.go b/server_options.go new file mode 100644 index 0000000..137a0f0 --- /dev/null +++ b/server_options.go @@ -0,0 +1,56 @@ +package network + +import ( + "crypto/tls" + "net" + + "git.hpds.cc/Component/network/auth" + "github.com/lucas-clemente/quic-go" +) + +// ServerOptions are the options for HPDS Network server. +type ServerOptions struct { + QuicConfig *quic.Config + TLSConfig *tls.Config + Addr string + Auths []auth.Authentication + Conn net.PacketConn +} + +// WithAddr sets the server address. +func WithAddr(addr string) ServerOption { + return func(o *ServerOptions) { + o.Addr = addr + } +} + +// WithAuth sets the server authentication method. +func WithAuth(name string, args ...string) ServerOption { + return func(o *ServerOptions) { + if auth, ok := auth.GetAuth(name); ok { + auth.Init(args...) + o.Auths = append(o.Auths, auth) + } + } +} + +// WithServerTLSConfig sets the TLS configuration for the server. +func WithServerTLSConfig(tc *tls.Config) ServerOption { + return func(o *ServerOptions) { + o.TLSConfig = tc + } +} + +// WithServerQuicConfig sets the QUIC configuration for the server. +func WithServerQuicConfig(qc *quic.Config) ServerOption { + return func(o *ServerOptions) { + o.QuicConfig = qc + } +} + +// WithConn sets the connection for the server. +func WithConn(conn net.PacketConn) ServerOption { + return func(o *ServerOptions) { + o.Conn = conn + } +} diff --git a/tls/tls.go b/tls/tls.go new file mode 100644 index 0000000..30de2c2 --- /dev/null +++ b/tls/tls.go @@ -0,0 +1,229 @@ +package tls + +import ( + "bytes" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "errors" + "io/ioutil" + "math/big" + "net" + "os" + "time" +) + +var isDev bool + +// CreateServerTLSConfig creates server tls config. +func CreateServerTLSConfig(host string) (*tls.Config, error) { + // development mode + if isDev { + tc, err := developmentTLSConfig(host) + if err != nil { + return nil, err + } + return tc, nil + } + // production mode + // ca pool + pool, err := getCACertPool() + if err != nil { + return nil, err + } + // server certificate + tlsCert, err := getCertAndKey() + if err != nil { + return nil, err + } + + return &tls.Config{ + Certificates: []tls.Certificate{*tlsCert}, + ClientCAs: pool, + ClientAuth: tls.RequireAndVerifyClientCert, + NextProtos: []string{"hpds"}, + }, nil +} + +// CreateClientTLSConfig creates client tls config. +func CreateClientTLSConfig() (*tls.Config, error) { + // development mode + if isDev { + return &tls.Config{ + InsecureSkipVerify: true, + NextProtos: []string{"hpds"}, + ClientSessionCache: tls.NewLRUClientSessionCache(64), + }, nil + } + // production mode + pool, err := getCACertPool() + if err != nil { + return nil, err + } + + tlsCert, err := getCertAndKey() + if err != nil { + return nil, err + } + + return &tls.Config{ + InsecureSkipVerify: false, + Certificates: []tls.Certificate{*tlsCert}, + RootCAs: pool, + NextProtos: []string{"hpds"}, + ClientSessionCache: tls.NewLRUClientSessionCache(0), + }, nil +} + +func getCACertPool() (*x509.CertPool, error) { + var err error + var caCert []byte + + caCertPath := os.Getenv("HPDS_TLS_CACERT_FILE") + if len(caCertPath) == 0 { + return nil, errors.New("tls: must provide CA certificate on production mode, you can configure this via environment variables: `HPDS_TLS_CACERT_FILE`") + } + + caCert, err = ioutil.ReadFile(caCertPath) + if err != nil { + return nil, err + } + + if len(caCert) == 0 { + return nil, errors.New("tls: cannot load CA cert") + } + + pool := x509.NewCertPool() + if ok := pool.AppendCertsFromPEM(caCert); !ok { + return nil, errors.New("tls: cannot append CA cert to pool") + } + + return pool, nil +} + +func getCertAndKey() (*tls.Certificate, error) { + var err error + var cert, key []byte + + certPath := os.Getenv("HPDS_TLS_CERT_FILE") + keyPath := os.Getenv("HPDS_TLS_KEY_FILE") + if len(certPath) == 0 || len(keyPath) == 0 { + return nil, errors.New("tls: must provide certificate on production mode, you can configure this via environment variables: `HPDS_TLS_CERT_FILE` and `HPDS_TLS_KEY_FILE`") + } + + // certificate + cert, err = ioutil.ReadFile(certPath) + if err != nil { + return nil, err + } + // private key + key, err = ioutil.ReadFile(keyPath) + if err != nil { + return nil, err + } + + if len(cert) == 0 || len(key) == 0 { + return nil, errors.New("tls: cannot load tls cert/key") + } + + tlsCert, err := tls.X509KeyPair(cert, key) + if err != nil { + return nil, err + } + + return &tlsCert, nil +} + +// IsDev development mode +func IsDev() bool { + return isDev +} + +// developmentTLSConfig Setup a bare-bones TLS config for the server +func developmentTLSConfig(host ...string) (*tls.Config, error) { + tlsCert, err := generateCertificate(host...) + if err != nil { + return nil, err + } + + return &tls.Config{ + Certificates: []tls.Certificate{tlsCert}, + ClientSessionCache: tls.NewLRUClientSessionCache(1), + NextProtos: []string{"hpds"}, + }, nil +} + +func generateCertificate(host ...string) (tls.Certificate, error) { + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return tls.Certificate{}, err + } + + notBefore := time.Now() + notAfter := notBefore.Add(time.Hour * 24 * 365) + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return tls.Certificate{}, err + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"HPDS"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: []string{"localhost"}, + } + + for _, h := range host { + if ip := net.ParseIP(h); ip != nil { + template.IPAddresses = append(template.IPAddresses, ip) + } else { + template.DNSNames = append(template.DNSNames, h) + } + } + + template.IsCA = true + template.KeyUsage |= x509.KeyUsageCertSign + + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return tls.Certificate{}, err + } + + // create public key + certOut := bytes.NewBuffer(nil) + err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + if err != nil { + return tls.Certificate{}, err + } + + // create private key + keyOut := bytes.NewBuffer(nil) + b, err := x509.MarshalECPrivateKey(priv) + if err != nil { + return tls.Certificate{}, err + } + err = pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: b}) + if err != nil { + return tls.Certificate{}, err + } + + return tls.X509KeyPair(certOut.Bytes(), keyOut.Bytes()) +} + +func init() { + env := os.Getenv("HPDS_ENV") + isDev = len(env) == 0 || env != "production" +} diff --git a/workflow.go b/workflow.go new file mode 100644 index 0000000..1e63ba5 --- /dev/null +++ b/workflow.go @@ -0,0 +1,10 @@ +package network + +// Workflow describes stream function workflows. +type Workflow struct { + // Seq represents the sequence id when executing workflows. + Seq int + + // Token represents the name of workflow. + Name string +}