This commit is contained in:
wangjian 2022-10-11 17:36:09 +08:00
parent 7dfeef5599
commit d4226d3b66
41 changed files with 3654 additions and 1 deletions

View File

@ -1,2 +1,4 @@
# network
# network 网络库
### 基于QUIC协议

59
auth/auth.go Normal file
View File

@ -0,0 +1,59 @@
package auth
import "strings"
var (
auths = make(map[string]Authentication)
)
// Authentication for Network server
type Authentication interface {
// Init authentication initialize arguments
Init(args ...string)
// Authenticate authentication client's credential
Authenticate(payload string) bool
// Name authentication name
Name() string
}
// Register register authentication
func Register(authentication Authentication) {
auths[authentication.Name()] = authentication
}
// GetAuth get authentication by name
func GetAuth(name string) (Authentication, bool) {
auth, ok := auths[name]
return auth, ok
}
// Credential client credential
type Credential struct {
name string
payload string
}
// NewCredential create client credential
func NewCredential(payload string) *Credential {
idx := strings.Index(payload, ":")
if idx != -1 {
authName := payload[:idx]
idx++
authPayload := payload[idx:]
return &Credential{
name: authName,
payload: authPayload,
}
}
return &Credential{name: "none"}
}
// Payload client credential payload
func (c *Credential) Payload() string {
return c.payload
}
// Name client credential name
func (c *Credential) Name() string {
return c.name
}

20
auth/auth.puml Normal file
View File

@ -0,0 +1,20 @@
@startuml
namespace auth {
interface Authentication {
+ Init(args ...string)
+ Authenticate(payload string) bool
+ Name() string
}
class Credential << (S,Aquamarine) >> {
- name string
- payload string
+ Payload() string
+ Name() string
}
}
@enduml

465
client.go Normal file
View File

@ -0,0 +1,465 @@
package network
import (
"context"
"errors"
"fmt"
"git.hpds.cc/Component/network/hpds_err"
"git.hpds.cc/Component/network/id"
pkgtls "git.hpds.cc/Component/network/tls"
"net"
"sync"
"time"
"github.com/lucas-clemente/quic-go"
"git.hpds.cc/Component/network/auth"
"git.hpds.cc/Component/network/frame"
"git.hpds.cc/Component/network/log"
)
// ClientOption client options
type ClientOption func(*ClientOptions)
// ConnState describes the state of the connection.
type ConnState = string
// Client is the abstraction of a HPDS-Client. a HPDS-Client can be
// Protocol Gateway, Message Queue or StreamFunction.
type Client struct {
name string // name of the client
clientId string // id of the client
clientType ClientType // type of the connection
conn quic.Connection // quic connection
stream quic.Stream // quic stream
state ConnState // state of the connection
processor func(*frame.DataFrame) // functions to invoke when data arrived
receiver func(*frame.BackFlowFrame) // functions to invoke when data is processed
addr string // the address of server connected to
mu sync.Mutex
opts ClientOptions
localAddr string // client local addr, it will be changed on reconnect
logger log.Logger
errChan chan error
closeChan chan bool
closed bool
}
// NewClient creates a new HPDS-Client.
func NewClient(appName string, connType ClientType, opts ...ClientOption) *Client {
c := &Client{
name: appName,
clientId: id.New(),
clientType: connType,
state: ConnStateReady,
opts: ClientOptions{},
errChan: make(chan error),
closeChan: make(chan bool),
}
c.Init(opts...)
once.Do(func() {
c.init()
})
return c
}
// Init the options.
func (c *Client) Init(opts ...ClientOption) error {
for _, o := range opts {
o(&c.opts)
}
return c.initOptions()
}
// Connect connects to HPDS-MessageQueue.
func (c *Client) Connect(ctx context.Context, addr string) error {
// TODO: refactor this later as a Connection Manager
// reconnect
// for download mq
// If you do not check for errors, the connection will be automatically reconnected
go c.reconnect(ctx, addr)
// connect
if err := c.connect(ctx, addr); err != nil {
return err
}
return nil
}
func (c *Client) connect(ctx context.Context, addr string) error {
c.addr = addr
c.state = ConnStateConnecting
// create quic connection
conn, err := quic.DialAddrContext(ctx, addr, c.opts.TLSConfig, c.opts.QuicConfig)
if err != nil {
c.state = ConnStateDisconnected
return err
}
// quic stream
stream, err := conn.OpenStreamSync(ctx)
if err != nil {
c.state = ConnStateDisconnected
return err
}
c.stream = stream
c.conn = conn
c.state = ConnStateAuthenticating
// send handshake
handshake := frame.NewHandshakeFrame(
c.name,
c.clientId,
byte(c.clientType),
c.opts.ObserveDataTags,
c.opts.Credential.Name(),
c.opts.Credential.Payload(),
)
err = c.WriteFrame(handshake)
if err != nil {
c.state = ConnStateRejected
return err
}
c.state = ConnStateConnected
c.localAddr = c.conn.LocalAddr().String()
c.logger.Printf("%s [%s][%s](%s) is connected to HPDS-MQ %s", ClientLogPrefix, c.name, c.clientId, c.localAddr, addr)
// receiving frames
go c.handleFrame()
return nil
}
// handleFrame handles the logic when receiving frame from server.
func (c *Client) handleFrame() {
// transform raw QUIC stream to wire format
fs := NewFrameStream(c.stream)
for {
c.logger.Debugf("%shandleFrame connection state=%v", ClientLogPrefix, c.state)
// this will block until a frame is received
f, err := fs.ReadFrame()
if err != nil {
defer c.stream.Close()
// defer c.conn.CloseWithError(0xD0, err.Error())
c.logger.Debugf("%shandleFrame(): %T | %v", ClientLogPrefix, err, err)
if e, ok := err.(*quic.IdleTimeoutError); ok {
c.logger.Errorf("%sconnection timeout, err=%v, mq addr=%s", ClientLogPrefix, e, c.addr)
c.setState(ConnStateDisconnected)
} else if e, ok := err.(*quic.ApplicationError); ok {
c.logger.Infof("%sapplication error, err=%v, errcode=%v", ClientLogPrefix, e, e.ErrorCode)
if hpds_err.Is(e.ErrorCode, hpds_err.ErrorCodeRejected) {
// if connection is rejected(eg: authenticate fails) from server
c.logger.Errorf("%sIllegal client, server rejected.", ClientLogPrefix)
c.setState(ConnStateRejected)
break
} else if hpds_err.Is(e.ErrorCode, hpds_err.ErrorCodeClientAbort) {
// client abort
c.logger.Infof("%sclient close the connection", ClientLogPrefix)
c.setState(ConnStateAborted)
break
} else if hpds_err.Is(e.ErrorCode, hpds_err.ErrorCodeGoaway) {
// server goaway
c.logger.Infof("%sserver goaway the connection", ClientLogPrefix)
c.setState(ConnStateGoaway)
break
} else if hpds_err.Is(e.ErrorCode, hpds_err.ErrorCodeHandshake) {
// handshake
c.logger.Errorf("%shandshake fails", ClientLogPrefix)
c.setState(ConnStateRejected)
break
}
} else if errors.Is(err, net.ErrClosed) {
// if client close the connection, net.ErrClosed will be raised
c.logger.Errorf("%sconnection is closed, err=%v", ClientLogPrefix, err)
c.setState(ConnStateDisconnected)
// by quic-go IdleTimeoutError after connection's KeepAlive config.
break
} else {
// any error occurred, we should close the stream
// after this, conn.AcceptStream() will raise the error
c.setState(ConnStateClosed)
c.conn.CloseWithError(hpds_err.To(hpds_err.ErrorCodeUnknown), err.Error())
c.logger.Errorf("%sunknown error occurred, err=%v, state=%v", ClientLogPrefix, err, c.getState())
break
}
}
if f == nil {
break
}
// read frame
// first, get frame type
frameType := f.Type()
c.logger.Debugf("%stype=%s, frame=%# x", ClientLogPrefix, frameType, frame.Shortly(f.Encode()))
switch frameType {
case frame.TagOfHandshakeFrame:
if v, ok := f.(*frame.HandshakeFrame); ok {
c.logger.Debugf("%sreceive HandshakeFrame, name=%v", ClientLogPrefix, v.Name)
}
case frame.TagOfPongFrame:
c.setState(ConnStatePong)
case frame.TagOfAcceptedFrame:
c.setState(ConnStateAccepted)
case frame.TagOfRejectedFrame:
c.setState(ConnStateRejected)
if v, ok := f.(*frame.RejectedFrame); ok {
c.logger.Errorf("%s receive RejectedFrame, message=%s", ClientLogPrefix, v.Message())
c.conn.CloseWithError(hpds_err.To(hpds_err.ErrorCodeRejected), v.Message())
c.errChan <- errors.New(v.Message())
break
}
case frame.TagOfGoawayFrame:
c.setState(ConnStateGoaway)
if v, ok := f.(*frame.GoawayFrame); ok {
c.logger.Errorf("%s receive GoawayFrame, message=%s", ClientLogPrefix, v.Message())
c.conn.CloseWithError(hpds_err.To(hpds_err.ErrorCodeGoaway), v.Message())
c.errChan <- errors.New(v.Message())
break
}
case frame.TagOfDataFrame: // DataFrame carries user's data
if v, ok := f.(*frame.DataFrame); ok {
c.setState(ConnStateTransportData)
c.logger.Debugf("%sreceive DataFrame, tag=%#x, tid=%s, carry=%# x", ClientLogPrefix, v.GetDataTag(), v.TransactionId(), v.GetCarriage())
if c.processor == nil {
c.logger.Warnf("%sprocessor is nil", ClientLogPrefix)
} else {
// TODO: should c.processor accept a DataFrame as parameter?
// c.processor(v.GetDataTagID(), v.GetCarriage(), v.GetMetaFrame())
c.processor(v)
}
}
case frame.TagOfBackFlowFrame:
if v, ok := f.(*frame.BackFlowFrame); ok {
c.logger.Debugf("%sreceive BackFlowFrame, tag=%#x, carry=%# x", ClientLogPrefix, v.GetDataTag(), v.GetCarriage())
if c.receiver == nil {
c.logger.Warnf("%sreceiver is nil", ClientLogPrefix)
} else {
c.setState(ConnStateBackFlow)
c.receiver(v)
}
}
default:
c.logger.Errorf("%sunknown signal", ClientLogPrefix)
}
}
}
// Close the client.
func (c *Client) Close() (err error) {
if c.conn != nil {
c.logger.Printf("%sclose the connection, name:%s, id:%s, addr:%s", ClientLogPrefix, c.name, c.clientId, c.conn.RemoteAddr().String())
}
if c.stream != nil {
err = c.stream.Close()
if err != nil {
c.logger.Errorf("%s stream.Close(): %v", ClientLogPrefix, err)
}
}
if c.conn != nil {
err = c.conn.CloseWithError(0, "client-ask-to-close-this-connection")
if err != nil {
c.logger.Errorf("%s connection.Close(): %v", ClientLogPrefix, err)
}
}
// close channel
c.mu.Lock()
if !c.closed {
close(c.errChan)
close(c.closeChan)
c.closed = true
}
c.mu.Unlock()
return err
}
// WriteFrame writes a frame to the connection, gurantee threadsafe.
func (c *Client) WriteFrame(frm frame.Frame) error {
// write on QUIC stream
if c.stream == nil {
return errors.New("stream is nil")
}
if c.state == ConnStateDisconnected || c.state == ConnStateRejected {
return fmt.Errorf("client connection state is %s", c.state)
}
c.logger.Debugf("%s[%s](%s)@%s WriteFrame() will write frame: %s", ClientLogPrefix, c.name, c.localAddr, c.state, frm.Type())
data := frm.Encode()
// emit raw bytes of Frame
c.mu.Lock()
n, err := c.stream.Write(data)
c.mu.Unlock()
c.logger.Debugf("%sWriteFrame() wrote n=%d, data=%# x", ClientLogPrefix, n, frame.Shortly(data))
if err != nil {
c.setState(ConnStateDisconnected)
// c.state = ConnStateDisconnected
if e, ok := err.(*quic.IdleTimeoutError); ok {
c.logger.Errorf("%sWriteFrame() connection timeout, err=%v", ClientLogPrefix, e)
} else {
c.logger.Errorf("%sWriteFrame() wrote error=%v", ClientLogPrefix, err)
return err
}
}
if n != len(data) {
err := errors.New("[client] hpds Client .Write() wrote error")
c.logger.Errorf("%s error:%v", ClientLogPrefix, err)
return err
}
return err
}
// update connection state
func (c *Client) setState(state ConnState) {
c.logger.Debugf("setState to:%s", state)
c.mu.Lock()
c.state = state
c.mu.Unlock()
}
// getState get connection state
func (c *Client) getState() ConnState {
c.mu.Lock()
defer c.mu.Unlock()
return c.state
}
// update connection local addr
func (c *Client) setLocalAddr(addr string) {
c.mu.Lock()
c.localAddr = addr
c.mu.Unlock()
}
// SetDataFrameObserver sets the data frame handler.
func (c *Client) SetDataFrameObserver(fn func(*frame.DataFrame)) {
c.processor = fn
c.logger.Debugf("%sSetDataFrameObserver(%v)", ClientLogPrefix, c.processor)
}
// SetBackFlowFrameObserver sets the backflow frame handler.
func (c *Client) SetBackFlowFrameObserver(fn func(*frame.BackFlowFrame)) {
c.receiver = fn
c.logger.Debugf("%sSetBackFlowFrameObserver(%v)", ClientLogPrefix, c.receiver)
}
// reconnect the connection between client and server.
func (c *Client) reconnect(ctx context.Context, addr string) {
t := time.NewTicker(1 * time.Second)
defer t.Stop()
for {
select {
case <-ctx.Done():
c.logger.Debugf("%s[%s](%s) context.Done()", ClientLogPrefix, c.name, c.localAddr)
return
case <-c.closeChan:
c.logger.Debugf("%s[%s](%s) close channel", ClientLogPrefix, c.name, c.localAddr)
return
case <-t.C:
if c.getState() == ConnStateDisconnected {
c.logger.Printf("%s[%s][%s](%s) is reconnecting to HPDS-MQ %s...", ClientLogPrefix, c.name, c.clientId, c.localAddr, addr)
err := c.connect(ctx, addr)
if err != nil {
c.logger.Errorf("%s[%s][%s](%s) reconnect error:%v", ClientLogPrefix, c.name, c.clientId, c.localAddr, err)
}
}
}
}
}
func (c *Client) init() {
// // tracing
// _, _, err := tracing.NewTracerProvider(c.name)
// if err != nil {
// logger.Errorf("tracing: %v", err)
// }
}
// ServerAddr returns the address of the server.
func (c *Client) ServerAddr() string {
return c.addr
}
// initOptions init options defaults
func (c *Client) initOptions() error {
// logger
if c.logger == nil {
if c.opts.Logger != nil {
c.logger = c.opts.Logger
} else {
c.logger = log.Default()
}
}
// observe tag list
if c.opts.ObserveDataTags == nil {
c.opts.ObserveDataTags = make([]byte, 0)
}
// credential
if c.opts.Credential == nil {
c.opts.Credential = auth.NewCredential("")
}
// tls config
if c.opts.TLSConfig == nil {
tc, err := pkgtls.CreateClientTLSConfig()
if err != nil {
c.logger.Errorf("%sCreateClientTLSConfig: %v", ClientLogPrefix, err)
return err
}
c.opts.TLSConfig = tc
}
// quic config
if c.opts.QuicConfig == nil {
c.opts.QuicConfig = &quic.Config{
Versions: []quic.VersionNumber{quic.Version1, quic.VersionDraft29},
MaxIdleTimeout: time.Second * 40,
KeepAlivePeriod: time.Second * 20,
MaxIncomingStreams: 1000,
MaxIncomingUniStreams: 1000,
HandshakeIdleTimeout: time.Second * 3,
InitialStreamReceiveWindow: 1024 * 1024 * 2,
InitialConnectionReceiveWindow: 1024 * 1024 * 2,
TokenStore: quic.NewLRUTokenStore(10, 5),
DisablePathMTUDiscovery: true,
}
}
// credential
if c.opts.Credential != nil {
c.logger.Printf("%suse credential: [%s]", ClientLogPrefix, c.opts.Credential.Name())
}
return nil
}
// SetObserveDataTags set the data tag list that will be observed.
// Deprecated: use hpds.WithObserveDataTags instead
func (c *Client) SetObserveDataTags(tag ...byte) {
c.opts.ObserveDataTags = append(c.opts.ObserveDataTags, tag...)
}
// Logger get client's logger instance, you can customize this using `hpds.WithLogger`
func (c *Client) Logger() log.Logger {
return c.logger
}
// SetErrorHandler set error handler
func (c *Client) SetErrorHandler(fn func(err error)) {
if fn != nil {
go func() {
err := <-c.errChan
if err != nil {
fn(err)
}
}()
}
}
// ClientId return the client ID
func (c *Client) ClientId() string {
return c.clientId
}

53
client_options.go Normal file
View File

@ -0,0 +1,53 @@
package network
import (
"crypto/tls"
"github.com/lucas-clemente/quic-go"
"git.hpds.cc/Component/network/auth"
"git.hpds.cc/Component/network/log"
)
// ClientOptions are the options for HPDS client.
type ClientOptions struct {
ObserveDataTags []byte
QuicConfig *quic.Config
TLSConfig *tls.Config
Credential *auth.Credential
Logger log.Logger
}
// WithObserveDataTags sets data tag list for the client.
func WithObserveDataTags(tags ...byte) ClientOption {
return func(o *ClientOptions) {
o.ObserveDataTags = tags
}
}
// WithCredential sets the client credential method (used by client).
func WithCredential(payload string) ClientOption {
return func(o *ClientOptions) {
o.Credential = auth.NewCredential(payload)
}
}
// WithClientTLSConfig sets tls config for the client.
func WithClientTLSConfig(tc *tls.Config) ClientOption {
return func(o *ClientOptions) {
o.TLSConfig = tc
}
}
// WithClientQuicConfig sets quic config for the client.
func WithClientQuicConfig(qc *quic.Config) ClientOption {
return func(o *ClientOptions) {
o.QuicConfig = qc
}
}
// WithLogger sets logger for the client.
func WithLogger(logger log.Logger) ClientOption {
return func(o *ClientOptions) {
o.Logger = logger
}
}

28
client_type.go Normal file
View File

@ -0,0 +1,28 @@
package network
const (
// ClientTypeNone is connection type "None".
ClientTypeNone ClientType = 0xFF
// ClientTypeProtocolGateway is connection type "Protocol Gateway".
ClientTypeProtocolGateway ClientType = 0x5F
// ClientTypeMessageQueue is connection type "Message Queue".
ClientTypeMessageQueue ClientType = 0x5E
// ClientTypeStreamFunction is connection type "Stream Function".
ClientTypeStreamFunction ClientType = 0x5D
)
// ClientType represents the connection type.
type ClientType byte
func (c ClientType) String() string {
switch c {
case ClientTypeProtocolGateway:
return "Source"
case ClientTypeMessageQueue:
return "Message Queue"
case ClientTypeStreamFunction:
return "Stream Function"
default:
return "None"
}
}

95
connection.go Normal file
View File

@ -0,0 +1,95 @@
package network
import (
"git.hpds.cc/Component/network/frame"
"git.hpds.cc/Component/network/log"
"io"
"sync"
)
// Connection wraps the specific io connections (typically quic.Connection) to transfer coder frames
type Connection interface {
io.Closer
// Name returns the name of the connection, which is set by clients
Name() string
// ClientId connection client ID
ClientId() string
// ClientType returns the type of the client (Protocol Gateway | Message Queue | Stream Function)
ClientType() ClientType
// Metadata returns the extra info of the application
Metadata() Metadata
// Write should goroutine-safely send coder frames to peer side
Write(f frame.Frame) error
// ObserveDataTags observed data tags
ObserveDataTags() []byte
}
type connection struct {
name string
clientType ClientType
metadata Metadata
stream io.ReadWriteCloser
clientId string
observed []byte // observed data tags
mu sync.Mutex
closed bool
}
func newConnection(name string, clientId string, clientType ClientType, metadata Metadata,
stream io.ReadWriteCloser, observed []byte) Connection {
return &connection{
name: name,
clientId: clientId,
clientType: clientType,
observed: observed,
metadata: metadata,
stream: stream,
closed: false,
}
}
// Close implements io.Close interface
func (c *connection) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
c.closed = true
return c.stream.Close()
}
// Name returns the name of the connection, which is set by clients
func (c *connection) Name() string {
return c.name
}
// ClientType returns the type of the connection (Protocol Gateway | Message Queue | Stream Function )
func (c *connection) ClientType() ClientType {
return c.clientType
}
// Metadata returns the extra info of the application
func (c *connection) Metadata() Metadata {
return c.metadata
}
// Write should goroutine-safely send coder frames to peer side
func (c *connection) Write(f frame.Frame) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
log.Warnf("%sclient stream is closed: %s", ServerLogPrefix, c.clientId)
return nil
}
_, err := c.stream.Write(f.Encode())
return err
}
// ObserveDataTags observed data tags
func (c *connection) ObserveDataTags() []byte {
return c.observed
}
// ClientId connection client id
func (c *connection) ClientId() string {
return c.clientId
}

87
connector.go Normal file
View File

@ -0,0 +1,87 @@
package network
import (
"git.hpds.cc/Component/network/log"
"sync"
)
var _ Connector = &connector{}
// Connector is an interface to manage the connections and applications.
type Connector interface {
// Add a connection.
Add(connId string, conn Connection)
// Remove a connection.
Remove(connId string)
// Get a connection by connection id.
Get(connId string) Connection
// GetSnapshot gets the snapshot of all connections.
GetSnapshot() map[string]string
// GetProtocolGatewayConnections gets the connections by Protocol Gateway observe tags.
GetProtocolGatewayConnections(sourceId string, tags byte) []Connection
// Clean the connector.
Clean()
}
type connector struct {
conns sync.Map
}
func newConnector() Connector {
return &connector{conns: sync.Map{}}
}
// Add a connection.
func (c *connector) Add(connID string, conn Connection) {
log.Debugf("%sconnector add: connId=%s", ServerLogPrefix, connID)
c.conns.Store(connID, conn)
}
// Remove a connection.
func (c *connector) Remove(connID string) {
log.Debugf("%sconnector remove: connId=%s", ServerLogPrefix, connID)
c.conns.Delete(connID)
}
// Get a connection by connection id.
func (c *connector) Get(connID string) Connection {
log.Debugf("%sconnector get connection: connId=%s", ServerLogPrefix, connID)
if conn, ok := c.conns.Load(connID); ok {
return conn.(Connection)
}
return nil
}
// GetProtocolGatewayConnections gets the Protocol Gateway connection by tag.
func (c *connector) GetProtocolGatewayConnections(sourceId string, tag byte) []Connection {
conns := make([]Connection, 0)
c.conns.Range(func(key interface{}, val interface{}) bool {
conn := val.(Connection)
for _, v := range conn.ObserveDataTags() {
if v == tag && conn.ClientType() == ClientTypeProtocolGateway && conn.ClientId() == sourceId {
conns = append(conns, conn)
}
}
return true
})
return conns
}
// GetSnapshot gets the snapshot of all connections.
func (c *connector) GetSnapshot() map[string]string {
result := make(map[string]string)
c.conns.Range(func(key interface{}, val interface{}) bool {
connID := key.(string)
conn := val.(Connection)
result[connID] = conn.Name()
return true
})
return result
}
// Clean the connector.
func (c *connector) Clean() {
c.conns = sync.Map{}
}

40
constant.go Normal file
View File

@ -0,0 +1,40 @@
package network
import (
"math/rand"
"sync"
"time"
)
var (
once sync.Once
)
// ConnState represents the state of a connection.
const (
ConnStateReady ConnState = "Ready"
ConnStateDisconnected ConnState = "Disconnected"
ConnStateConnecting ConnState = "Connecting"
ConnStateConnected ConnState = "Connected"
ConnStateAuthenticating ConnState = "Authenticating"
ConnStateAccepted ConnState = "Accepted"
ConnStateRejected ConnState = "Rejected"
ConnStatePing ConnState = "Ping"
ConnStatePong ConnState = "Pong"
ConnStateTransportData ConnState = "TransportData"
ConnStateAborted ConnState = "Aborted"
ConnStateClosed ConnState = "Closed" // close connection by server
ConnStateGoaway ConnState = "Goaway"
ConnStateBackFlow ConnState = "BackFlow"
)
// Prefix is the prefix for logger.
const (
ClientLogPrefix = "\033[36m[network:client]\033[0m "
ServerLogPrefix = "\033[32m[network:server]\033[0m "
ParseFrameLogPrefix = "\033[36m[network:stream_parser]\033[0m "
)
func init() {
rand.Seed(time.Now().Unix())
}

191
context.go Normal file
View File

@ -0,0 +1,191 @@
package network
import (
"git.hpds.cc/Component/network/hpds_err"
"git.hpds.cc/Component/network/log"
"io"
"sync"
"time"
"git.hpds.cc/Component/network/frame"
"github.com/lucas-clemente/quic-go"
)
// Context for Network Server.
type Context struct {
// Conn is the connection of client.
Conn quic.Connection
connId string
// Stream is the long-lived connection between client and server.
Stream io.ReadWriteCloser
// Frame receives from client.
Frame frame.Frame
// Keys store the key/value pairs in context.
Keys map[string]interface{}
mu sync.RWMutex
}
func newContext(conn quic.Connection, stream quic.Stream) *Context {
return &Context{
Conn: conn,
connId: conn.RemoteAddr().String(),
Stream: stream,
// keys: make(map[string]interface{}),
}
}
// WithFrame sets a frame to context.
func (c *Context) WithFrame(f frame.Frame) *Context {
c.Frame = f
return c
}
// Clean the context.
func (c *Context) Clean() {
log.Debugf("%sconn[%s] context clean", ServerLogPrefix, c.connId)
c.Stream = nil
c.Frame = nil
c.Keys = nil
c.Conn = nil
}
// CloseWithError closes the stream and cleans the context.
func (c *Context) CloseWithError(code hpds_err.ErrorCode, msg string) {
log.Debugf("%sconn[%s] context close, errCode=%#x, msg=%s", ServerLogPrefix, c.connId, code, msg)
if c.Stream != nil {
c.Stream.Close()
}
if c.Conn != nil {
c.Conn.CloseWithError(quic.ApplicationErrorCode(code), msg)
}
c.Clean()
}
// ConnId get quic connection id
func (c *Context) ConnId() string {
return c.connId
}
// Set a key/value pair to context.
func (c *Context) Set(key string, value interface{}) {
c.mu.Lock()
if c.Keys == nil {
c.Keys = make(map[string]interface{})
}
c.Keys[key] = value
c.mu.Unlock()
}
// Get the value by a specified key.
func (c *Context) Get(key string) (value interface{}, exists bool) {
c.mu.RLock()
value, exists = c.Keys[key]
c.mu.RUnlock()
return
}
// GetString gets a string value by a specified key.
func (c *Context) GetString(key string) (s string) {
if val, ok := c.Get(key); ok && val != nil {
s, _ = val.(string)
}
return
}
// GetBool gets a bool value by a specified key.
func (c *Context) GetBool(key string) (b bool) {
if val, ok := c.Get(key); ok && val != nil {
b, _ = val.(bool)
}
return
}
// GetInt gets an int value by a specified key.
func (c *Context) GetInt(key string) (i int) {
if val, ok := c.Get(key); ok && val != nil {
i, _ = val.(int)
}
return
}
// GetInt64 gets an int64 value by a specified key.
func (c *Context) GetInt64(key string) (i64 int64) {
if val, ok := c.Get(key); ok && val != nil {
i64, _ = val.(int64)
}
return
}
// GetUint gets an uint value by a specified key.
func (c *Context) GetUint(key string) (ui uint) {
if val, ok := c.Get(key); ok && val != nil {
ui, _ = val.(uint)
}
return
}
// GetUint64 gets an uint64 value by a specified key.
func (c *Context) GetUint64(key string) (ui64 uint64) {
if val, ok := c.Get(key); ok && val != nil {
ui64, _ = val.(uint64)
}
return
}
// GetFloat64 gets a float64 value by a specified key.
func (c *Context) GetFloat64(key string) (f64 float64) {
if val, ok := c.Get(key); ok && val != nil {
f64, _ = val.(float64)
}
return
}
// GetTime gets a time.Time value by a specified key.
func (c *Context) GetTime(key string) (t time.Time) {
if val, ok := c.Get(key); ok && val != nil {
t, _ = val.(time.Time)
}
return
}
// GetDuration gets a time.Duration value by a specified key.
func (c *Context) GetDuration(key string) (d time.Duration) {
if val, ok := c.Get(key); ok && val != nil {
d, _ = val.(time.Duration)
}
return
}
// GetStringSlice gets a []string value by a specified key.
func (c *Context) GetStringSlice(key string) (ss []string) {
if val, ok := c.Get(key); ok && val != nil {
ss, _ = val.([]string)
}
return
}
// GetStringMap gets a map[string]interface{} value by a specified key.
func (c *Context) GetStringMap(key string) (sm map[string]interface{}) {
if val, ok := c.Get(key); ok && val != nil {
sm, _ = val.(map[string]interface{})
}
return
}
// GetStringMapString gets a map[string]string value by a specified key.
func (c *Context) GetStringMapString(key string) (sms map[string]string) {
if val, ok := c.Get(key); ok && val != nil {
sms, _ = val.(map[string]string)
}
return
}
// GetStringMapStringSlice gets a map[string][]string value by a specified key.
func (c *Context) GetStringMapStringSlice(key string) (smss map[string][]string) {
if val, ok := c.Get(key); ok && val != nil {
smss, _ = val.(map[string][]string)
}
return
}

36
frame/accepted_frame.go Normal file
View File

@ -0,0 +1,36 @@
package frame
import (
coder "git.hpds.cc/Component/mq_coder"
)
// AcceptedFrame is encoded bytes, Tag is a fixed value TYPE_ID_ACCEPTED_FRAME
type AcceptedFrame struct{}
// NewAcceptedFrame creates a new AcceptedFrame with a given TagId of user's data
func NewAcceptedFrame() *AcceptedFrame {
return &AcceptedFrame{}
}
// Type gets the type of Frame.
func (m *AcceptedFrame) Type() Type {
return TagOfAcceptedFrame
}
// Encode to coder encoded bytes.
func (m *AcceptedFrame) Encode() []byte {
accepted := coder.NewNodePacketEncoder(byte(m.Type()))
accepted.AddBytes(nil)
return accepted.Encode()
}
// DecodeToAcceptedFrame decodes coder encoded bytes to AcceptedFrame.
func DecodeToAcceptedFrame(buf []byte) (*AcceptedFrame, error) {
nodeBlock := coder.NodePacket{}
_, err := coder.DecodeToNodePacket(buf, &nodeBlock)
if err != nil {
return nil, err
}
return &AcceptedFrame{}, nil
}

View File

@ -0,0 +1,19 @@
package frame
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestAcceptedFrameEncode(t *testing.T) {
f := NewAcceptedFrame()
assert.Equal(t, []byte{0x80 | byte(TagOfAcceptedFrame), 0x00}, f.Encode())
}
func TestAcceptedFrameDecode(t *testing.T) {
buf := []byte{0x80 | byte(TagOfAcceptedFrame), 0x00}
ping, err := DecodeToAcceptedFrame(buf)
assert.NoError(t, err)
assert.Equal(t, []byte{0x80 | byte(TagOfAcceptedFrame), 0x00}, ping.Encode())
}

70
frame/backflow_frame.go Normal file
View File

@ -0,0 +1,70 @@
package frame
import (
coder "git.hpds.cc/Component/mq_coder"
)
// BackFlowFrame is a coder encoded bytes
// It's used to receive stream function processed result
type BackFlowFrame struct {
Tag byte
Carriage []byte
}
// NewBackFlowFrame creates a new BackFlowFrame with a given tag and carriage
func NewBackFlowFrame(tag byte, carriage []byte) *BackFlowFrame {
return &BackFlowFrame{
Tag: tag,
Carriage: carriage,
}
}
// Type gets the type of Frame.
func (f *BackFlowFrame) Type() Type {
return TagOfBackFlowFrame
}
// SetCarriage sets the user's raw data
func (f *BackFlowFrame) SetCarriage(buf []byte) *BackFlowFrame {
f.Carriage = buf
return f
}
// Encode to coder encoded bytes
func (f *BackFlowFrame) Encode() []byte {
carriage := coder.NewPrimitivePacketEncoder(f.Tag)
carriage.SetBytesValue(f.Carriage)
node := coder.NewNodePacketEncoder(byte(TagOfBackFlowFrame))
node.AddPrimitivePacket(carriage)
return node.Encode()
}
// GetDataTag return the Tag of user's data
func (f *BackFlowFrame) GetDataTag() byte {
return f.Tag
}
// GetCarriage return data
func (f *BackFlowFrame) GetCarriage() []byte {
return f.Carriage
}
// DecodeToBackFlowFrame decodes coder encoded bytes to BackFlowFrame
func DecodeToBackFlowFrame(buf []byte) (*BackFlowFrame, error) {
nodeBlock := coder.NodePacket{}
_, err := coder.DecodeToNodePacket(buf, &nodeBlock)
if err != nil {
return nil, err
}
payload := &BackFlowFrame{}
for _, v := range nodeBlock.PrimitivePackets {
payload.Tag = v.SeqId()
payload.Carriage = v.GetValBuf()
break
}
return payload, nil
}

110
frame/data_frame.go Normal file
View File

@ -0,0 +1,110 @@
package frame
import (
coder "git.hpds.cc/Component/mq_coder"
)
// DataFrame defines the data structure carried with user's data
type DataFrame struct {
metaFrame *MetaFrame
payloadFrame *PayloadFrame
}
// NewDataFrame create `DataFrame` with a transactionId string,
// consider change transactionID to UUID type later
func NewDataFrame() *DataFrame {
data := &DataFrame{
metaFrame: NewMetaFrame(),
}
return data
}
// Type gets the type of Frame.
func (d *DataFrame) Type() Type {
return TagOfDataFrame
}
// Tag return the tag of carriage data.
func (d *DataFrame) Tag() byte {
return d.payloadFrame.Tag
}
// SetCarriage set user's raw data in `DataFrame`
func (d *DataFrame) SetCarriage(tag byte, carriage []byte) {
d.payloadFrame = NewPayloadFrame(tag).SetCarriage(carriage)
}
// GetCarriage return user's raw data in `DataFrame`
func (d *DataFrame) GetCarriage() []byte {
return d.payloadFrame.Carriage
}
// TransactionId return transactionId string
func (d *DataFrame) TransactionId() string {
return d.metaFrame.TransactionId()
}
// SetTransactionId set transactionId string
func (d *DataFrame) SetTransactionId(transactionID string) {
d.metaFrame.SetTransactionId(transactionID)
}
// GetMetaFrame return MetaFrame.
func (d *DataFrame) GetMetaFrame() *MetaFrame {
return d.metaFrame
}
// GetDataTag return the Tag of user's data
func (d *DataFrame) GetDataTag() byte {
return d.payloadFrame.Tag
}
// SetSourceId set the source id.
func (d *DataFrame) SetSourceId(sourceID string) {
d.metaFrame.SetSourceId(sourceID)
}
// SourceId returns source id
func (d *DataFrame) SourceId() string {
return d.metaFrame.SourceId()
}
// Encode return coder encoded bytes of `DataFrame`
func (d *DataFrame) Encode() []byte {
data := coder.NewNodePacketEncoder(byte(d.Type()))
// MetaFrame
data.AddBytes(d.metaFrame.Encode())
// PayloadFrame
data.AddBytes(d.payloadFrame.Encode())
return data.Encode()
}
// DecodeToDataFrame decode coder encoded bytes to `DataFrame`
func DecodeToDataFrame(buf []byte) (*DataFrame, error) {
packet := coder.NodePacket{}
_, err := coder.DecodeToNodePacket(buf, &packet)
if err != nil {
return nil, err
}
data := &DataFrame{}
if metaBlock, ok := packet.NodePackets[byte(TagOfMetaFrame)]; ok {
meta, err := DecodeToMetaFrame(metaBlock.GetRawBytes())
if err != nil {
return nil, err
}
data.metaFrame = meta
}
if payloadBlock, ok := packet.NodePackets[byte(TagOfPayloadFrame)]; ok {
payload, err := DecodeToPayloadFrame(payloadBlock.GetRawBytes())
if err != nil {
return nil, err
}
data.payloadFrame = payload
}
return data, nil
}

39
frame/data_frame_test.go Normal file
View File

@ -0,0 +1,39 @@
package frame
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestDataFrameEncode(t *testing.T) {
var userDataTag byte = 0x15
d := NewDataFrame()
d.SetCarriage(userDataTag, []byte("hpds"))
tidbuf := []byte(d.TransactionId())
result := []byte{
0x80 | byte(TagOfDataFrame), byte(len(tidbuf) + 4 + 8 + 2),
0x80 | byte(TagOfMetaFrame), byte(len(tidbuf) + 2 + 2),
byte(TagOfTransactionId), byte(len(tidbuf))}
result = append(result, tidbuf...)
result = append(result, byte(TagOfSourceId), 0x0)
result = append(result, 0x80|byte(TagOfPayloadFrame), 0x06,
userDataTag, 0x04, 0x79, 0x6F, 0x6D, 0x6F)
assert.Equal(t, result, d.Encode())
}
func TestDataFrameDecode(t *testing.T) {
var userDataTag byte = 0x15
buf := []byte{
0x80 | byte(TagOfDataFrame), 0x10,
0x80 | byte(TagOfMetaFrame), 0x06,
byte(TagOfTransactionId), 0x04, 0x31, 0x32, 0x33, 0x34,
0x80 | byte(TagOfPayloadFrame), 0x06,
userDataTag, 0x04, 0x79, 0x6F, 0x6D, 0x6F}
data, err := DecodeToDataFrame(buf)
assert.NoError(t, err)
assert.EqualValues(t, "1234", data.TransactionId())
assert.EqualValues(t, userDataTag, data.GetDataTag())
assert.EqualValues(t, []byte("hpds"), data.GetCarriage())
}

106
frame/frame.go Normal file
View File

@ -0,0 +1,106 @@
package frame
import (
"os"
"strconv"
)
// debugFrameSize print frame data size on debug mode
var debugFrameSize = 16
// Kinds of frames transferable within HPDS
const (
// DataFrame
TagOfDataFrame Type = 0x3F
// MetaFrame of DataFrame
TagOfMetaFrame Type = 0x2F
TagOfMetadata Type = 0x03
TagOfTransactionId Type = 0x01
TagOfSourceId Type = 0x02
// PayloadFrame of DataFrame
TagOfPayloadFrame Type = 0x2E
TagOfBackFlowFrame Type = 0x2D
TagOfTokenFrame Type = 0x3E
// HandshakeFrame
TagOfHandshakeFrame Type = 0x3D
TagOfHandshakeName Type = 0x01
TagOfHandshakeType Type = 0x02
TagOfHandshakeId Type = 0x03
TagOfHandshakeAuthName Type = 0x04
TagOfHandshakeAuthPayload Type = 0x05
TagOfHandshakeObserveDataTags Type = 0x06
TagOfPingFrame Type = 0x3C
TagOfPongFrame Type = 0x3B
TagOfAcceptedFrame Type = 0x3A
TagOfRejectedFrame Type = 0x39
TagOfRejectedMessage Type = 0x02
// GoawayFrame
TagOfGoawayFrame Type = 0x30
TagOfGoawayCode Type = 0x01
TagOfGoawayMessage Type = 0x02
)
// Type represents the type of frame.
type Type uint8
// Frame is the interface for frame.
type Frame interface {
// Type gets the type of Frame.
Type() Type
// Encode the frame into []byte.
Encode() []byte
}
func (f Type) String() string {
switch f {
case TagOfDataFrame:
return "DataFrame"
case TagOfTokenFrame:
return "TokenFrame"
case TagOfHandshakeFrame:
return "HandshakeFrame"
case TagOfPingFrame:
return "PingFrame"
case TagOfPongFrame:
return "PongFrame"
case TagOfAcceptedFrame:
return "AcceptedFrame"
case TagOfRejectedFrame:
return "RejectedFrame"
case TagOfGoawayFrame:
return "GoawayFrame"
case TagOfBackFlowFrame:
return "BackFlowFrame"
case TagOfMetaFrame:
return "MetaFrame"
case TagOfPayloadFrame:
return "PayloadFrame"
// case TagOfTransactionId:
// return "TransactionId"
case TagOfHandshakeName:
return "HandshakeName"
case TagOfHandshakeType:
return "HandshakeType"
default:
return "UnknownFrame"
}
}
// Shortly reduce data size for easy viewing
func Shortly(data []byte) []byte {
if len(data) > debugFrameSize {
return data[:debugFrameSize]
}
return data
}
func init() {
if envFrameSize := os.Getenv("DEBUG_FRAME_SIZE"); envFrameSize != "" {
if val, err := strconv.Atoi(envFrameSize); err == nil {
debugFrameSize = val
}
}
}

110
frame/frame.puml Normal file
View File

@ -0,0 +1,110 @@
@startuml
namespace frame {
class AcceptedFrame << (S,Aquamarine) >> {
+ Type() Type
+ Encode() []byte
}
class BackFlowFrame << (S,Aquamarine) >> {
+ Tag byte
+ Carriage []byte
+ Type() Type
+ SetCarriage(buf []byte) *BackFlowFrame
+ Encode() []byte
+ GetDataTag() byte
+ GetCarriage() []byte
}
class DataFrame << (S,Aquamarine) >> {
- metaFrame *MetaFrame
- payloadFrame *PayloadFrame
+ Type() Type
+ Tag() byte
+ SetCarriage(tag byte, carriage []byte)
+ GetCarriage() []byte
+ TransactionId() string
+ SetTransactionId(transactionID string)
+ GetMetaFrame() *MetaFrame
+ GetDataTag() byte
+ SetSourceId(sourceID string)
+ SourceId() string
+ Encode() []byte
}
interface Frame {
+ Type() Type
+ Encode() []byte
}
class GoawayFrame << (S,Aquamarine) >> {
- message string
+ Type() Type
+ Encode() []byte
+ Message() string
}
class HandshakeFrame << (S,Aquamarine) >> {
- authName string
- authPayload string
+ Name string
+ ClientId string
+ ClientType byte
+ ObserveDataTags []byte
+ Type() Type
+ Encode() []byte
+ AuthPayload() string
+ AuthName() string
}
class MetaFrame << (S,Aquamarine) >> {
- tid string
- metadata []byte
- sourceId string
+ SetTransactionId(transactionId string)
+ TransactionId() string
+ SetMetadata(metadata []byte)
+ Metadata() []byte
+ SetSourceId(sourceId string)
+ SourceId() string
+ Encode() []byte
}
class PayloadFrame << (S,Aquamarine) >> {
+ Tag byte
+ Carriage []byte
+ SetCarriage(buf []byte) *PayloadFrame
+ Encode() []byte
}
class RejectedFrame << (S,Aquamarine) >> {
- message string
+ Type() Type
+ Encode() []byte
+ Message() string
}
class Type << (S,Aquamarine) >> {
+ String() string
}
class frame.Type << (T, #FF7700) >> {
}
}
"frame.Frame" <|-- "frame.AcceptedFrame"
"frame.Frame" <|-- "frame.BackFlowFrame"
"frame.Frame" <|-- "frame.DataFrame"
"frame.Frame" <|-- "frame.GoawayFrame"
"frame.Frame" <|-- "frame.HandshakeFrame"
"frame.Frame" <|-- "frame.RejectedFrame"
"__builtin__.uint8" #.. "frame.Type"
@enduml

57
frame/goaway_frame.go Normal file
View File

@ -0,0 +1,57 @@
package frame
import (
coder "git.hpds.cc/Component/mq_coder"
)
// GoawayFrame is a coder encoded bytes, Tag is a fixed value TYPE_ID_GOAWAY_FRAME
type GoawayFrame struct {
message string
}
// NewGoawayFrame creates a new GoawayFrame
func NewGoawayFrame(msg string) *GoawayFrame {
return &GoawayFrame{message: msg}
}
// Type gets the type of Frame.
func (f *GoawayFrame) Type() Type {
return TagOfGoawayFrame
}
// Encode to coder encoded bytes
func (f *GoawayFrame) Encode() []byte {
goaway := coder.NewNodePacketEncoder(byte(f.Type()))
// message
msgBlock := coder.NewPrimitivePacketEncoder(byte(TagOfGoawayMessage))
msgBlock.SetStringValue(f.message)
goaway.AddPrimitivePacket(msgBlock)
return goaway.Encode()
}
// Message goaway message
func (f *GoawayFrame) Message() string {
return f.message
}
// DecodeToGoawayFrame decodes coder encoded bytes to GoawayFrame
func DecodeToGoawayFrame(buf []byte) (*GoawayFrame, error) {
node := coder.NodePacket{}
_, err := coder.DecodeToNodePacket(buf, &node)
if err != nil {
return nil, err
}
goaway := &GoawayFrame{}
// message
if msgBlock, ok := node.PrimitivePackets[byte(TagOfGoawayMessage)]; ok {
msg, err := msgBlock.ToUTF8String()
if err != nil {
return nil, err
}
goaway.message = msg
}
return goaway, nil
}

131
frame/handshake_frame.go Normal file
View File

@ -0,0 +1,131 @@
package frame
import (
coder "git.hpds.cc/Component/mq_coder"
)
// HandshakeFrame is a coder encoded.
type HandshakeFrame struct {
// Name is client name
Name string
// ClientId represents client id
ClientId string
// ClientType represents client type (Protocol Gateway | Stream Function)
ClientType byte
// ObserveDataTags are the client data tag list.
ObserveDataTags []byte
// auth
authName string
authPayload string
}
// NewHandshakeFrame creates a new HandshakeFrame.
func NewHandshakeFrame(name string, clientId string, clientType byte, observeDataTags []byte, authName string, authPayload string) *HandshakeFrame {
return &HandshakeFrame{
Name: name,
ClientId: clientId,
ClientType: clientType,
ObserveDataTags: observeDataTags,
authName: authName,
authPayload: authPayload,
}
}
// Type gets the type of Frame.
func (h *HandshakeFrame) Type() Type {
return TagOfHandshakeFrame
}
// Encode to coder encoding.
func (h *HandshakeFrame) Encode() []byte {
// name
nameBlock := coder.NewPrimitivePacketEncoder(byte(TagOfHandshakeName))
nameBlock.SetStringValue(h.Name)
// client ID
idBlock := coder.NewPrimitivePacketEncoder(byte(TagOfHandshakeId))
idBlock.SetStringValue(h.ClientId)
// client type
typeBlock := coder.NewPrimitivePacketEncoder(byte(TagOfHandshakeType))
typeBlock.SetBytesValue([]byte{h.ClientType})
// observe data tags
observeDataTagsBlock := coder.NewPrimitivePacketEncoder(byte(TagOfHandshakeObserveDataTags))
observeDataTagsBlock.SetBytesValue(h.ObserveDataTags)
// auth
authNameBlock := coder.NewPrimitivePacketEncoder(byte(TagOfHandshakeAuthName))
authNameBlock.SetStringValue(h.authName)
authPayloadBlock := coder.NewPrimitivePacketEncoder(byte(TagOfHandshakeAuthPayload))
authPayloadBlock.SetStringValue(h.authPayload)
// handshake frame
handshake := coder.NewNodePacketEncoder(byte(h.Type()))
handshake.AddPrimitivePacket(nameBlock)
handshake.AddPrimitivePacket(idBlock)
handshake.AddPrimitivePacket(typeBlock)
handshake.AddPrimitivePacket(observeDataTagsBlock)
handshake.AddPrimitivePacket(authNameBlock)
handshake.AddPrimitivePacket(authPayloadBlock)
return handshake.Encode()
}
// DecodeToHandshakeFrame decodes coder encoded bytes to HandshakeFrame.
func DecodeToHandshakeFrame(buf []byte) (*HandshakeFrame, error) {
node := coder.NodePacket{}
_, err := coder.DecodeToNodePacket(buf, &node)
if err != nil {
return nil, err
}
handshake := &HandshakeFrame{}
// name
if nameBlock, ok := node.PrimitivePackets[byte(TagOfHandshakeName)]; ok {
name, err := nameBlock.ToUTF8String()
if err != nil {
return nil, err
}
handshake.Name = name
}
// client id
if idBlock, ok := node.PrimitivePackets[byte(TagOfHandshakeId)]; ok {
id, err := idBlock.ToUTF8String()
if err != nil {
return nil, err
}
handshake.ClientId = id
}
// client type
if typeBlock, ok := node.PrimitivePackets[byte(TagOfHandshakeType)]; ok {
clientType := typeBlock.ToBytes()
handshake.ClientType = clientType[0]
}
// observe data tag list
if observeDataTagsBlock, ok := node.PrimitivePackets[byte(TagOfHandshakeObserveDataTags)]; ok {
handshake.ObserveDataTags = observeDataTagsBlock.ToBytes()
}
// auth
if authNameBlock, ok := node.PrimitivePackets[byte(TagOfHandshakeAuthName)]; ok {
authName, err := authNameBlock.ToUTF8String()
if err != nil {
return nil, err
}
handshake.authName = authName
}
if authPayloadBlock, ok := node.PrimitivePackets[byte(TagOfHandshakeAuthPayload)]; ok {
authPayload, err := authPayloadBlock.ToUTF8String()
if err != nil {
return nil, err
}
handshake.authPayload = authPayload
}
return handshake, nil
}
// AuthPayload authentication payload
func (h *HandshakeFrame) AuthPayload() string {
return h.authPayload
}
// AuthName authentication name
func (h *HandshakeFrame) AuthName() string {
return h.authName
}

View File

@ -0,0 +1,30 @@
package frame
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestHandshakeFrameEncode(t *testing.T) {
expectedName := "1234"
var expectedType byte = 0xD3
m := NewHandshakeFrame(expectedName, "", expectedType, []byte{0x01, 0x02}, "token", "a")
assert.Equal(t, []byte{
0x80 | byte(TagOfHandshakeFrame), 0x19,
byte(TagOfHandshakeName), 0x04, 0x31, 0x32, 0x33, 0x34,
byte(TagOfHandshakeId), 0x0,
byte(TagOfHandshakeType), 0x01, 0xD3,
byte(TagOfHandshakeObserveDataTags), 0x02, 0x01, 0x02,
// byte(TagOfHandshakeAppID), 0x0,
byte(TagOfHandshakeAuthName), 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e,
byte(TagOfHandshakeAuthPayload), 0x01, 0x61,
},
m.Encode(),
)
Handshake, err := DecodeToHandshakeFrame(m.Encode())
assert.NoError(t, err)
assert.EqualValues(t, expectedName, Handshake.Name)
assert.EqualValues(t, expectedType, Handshake.ClientType)
}

113
frame/meta_frame.go Normal file
View File

@ -0,0 +1,113 @@
package frame
import (
"strconv"
"time"
coder "git.hpds.cc/Component/mq_coder"
gonanoid "github.com/matoous/go-nanoid/v2"
)
// MetaFrame is a coder encoded bytes, SeqId is a fixed value of TYPE_ID_TRANSACTION.
// used for describes metadata for a DataFrame.
type MetaFrame struct {
tid string
metadata []byte
sourceId string
}
// NewMetaFrame creates a new MetaFrame instance.
func NewMetaFrame() *MetaFrame {
tid, err := gonanoid.New()
if err != nil {
tid = strconv.FormatInt(time.Now().UnixMicro(), 10)
}
return &MetaFrame{tid: tid}
}
// SetTransactionId set the transaction id.
func (m *MetaFrame) SetTransactionId(transactionId string) {
m.tid = transactionId
}
// TransactionId returns transactionId
func (m *MetaFrame) TransactionId() string {
return m.tid
}
// SetMetadata set the extra info of the application
func (m *MetaFrame) SetMetadata(metadata []byte) {
m.metadata = metadata
}
// Metadata returns the extra info of the application
func (m *MetaFrame) Metadata() []byte {
return m.metadata
}
// SetSourceId set the source ID.
func (m *MetaFrame) SetSourceId(sourceId string) {
m.sourceId = sourceId
}
// SourceId returns source ID
func (m *MetaFrame) SourceId() string {
return m.sourceId
}
// Encode implements Frame.Encode method.
func (m *MetaFrame) Encode() []byte {
meta := coder.NewNodePacketEncoder(byte(TagOfMetaFrame))
// transaction ID
transactionId := coder.NewPrimitivePacketEncoder(byte(TagOfTransactionId))
transactionId.SetStringValue(m.tid)
meta.AddPrimitivePacket(transactionId)
// source ID
sourceId := coder.NewPrimitivePacketEncoder(byte(TagOfSourceId))
sourceId.SetStringValue(m.sourceId)
meta.AddPrimitivePacket(sourceId)
// metadata
if m.metadata != nil {
metadata := coder.NewPrimitivePacketEncoder(byte(TagOfMetadata))
metadata.SetBytesValue(m.metadata)
meta.AddPrimitivePacket(metadata)
}
return meta.Encode()
}
// DecodeToMetaFrame decode a MetaFrame instance from given buffer.
func DecodeToMetaFrame(buf []byte) (*MetaFrame, error) {
nodeBlock := coder.NodePacket{}
_, err := coder.DecodeToNodePacket(buf, &nodeBlock)
if err != nil {
return nil, err
}
meta := &MetaFrame{}
for k, v := range nodeBlock.PrimitivePackets {
switch k {
case byte(TagOfTransactionId):
val, err := v.ToUTF8String()
if err != nil {
return nil, err
}
meta.tid = val
break
case byte(TagOfMetadata):
meta.metadata = v.ToBytes()
break
case byte(TagOfSourceId):
sourceId, err := v.ToUTF8String()
if err != nil {
return nil, err
}
meta.sourceId = sourceId
break
}
}
return meta, nil
}

25
frame/meta_frame_test.go Normal file
View File

@ -0,0 +1,25 @@
package frame
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestMetaFrameEncode(t *testing.T) {
m := NewMetaFrame()
tidbuf := []byte(m.tid)
result := []byte{0x80 | byte(TagOfMetaFrame), byte(1 + 1 + len(tidbuf) + 2), byte(TagOfTransactionId), byte(len(tidbuf))}
result = append(result, tidbuf...)
result = append(result, byte(TagOfSourceId), 0x0)
assert.Equal(t, result, m.Encode())
}
func TestMetaFrameDecode(t *testing.T) {
buf := []byte{0x80 | byte(TagOfMetaFrame), 0x09, byte(TagOfTransactionId), 0x04, 0x31, 0x32, 0x33, 0x34, byte(TagOfSourceId), 0x01, 0x31}
meta, err := DecodeToMetaFrame(buf)
assert.NoError(t, err)
assert.EqualValues(t, "1234", meta.TransactionId())
assert.EqualValues(t, "1", meta.SourceId())
t.Logf("%# x", buf)
}

55
frame/payload_frame.go Normal file
View File

@ -0,0 +1,55 @@
package frame
import (
coder "git.hpds.cc/Component/mq_coder"
)
// PayloadFrame is a coder encoded bytes, Tag is a fixed value TYPE_ID_PAYLOAD_FRAME
// the Len is the length of Val. Val is also a coder encoded PrimitivePacket, storing
// raw bytes as user's data
type PayloadFrame struct {
Tag byte
Carriage []byte
}
// NewPayloadFrame creates a new PayloadFrame with a given TagId of user's data
func NewPayloadFrame(tag byte) *PayloadFrame {
return &PayloadFrame{
Tag: tag,
}
}
// SetCarriage sets the user's raw data
func (m *PayloadFrame) SetCarriage(buf []byte) *PayloadFrame {
m.Carriage = buf
return m
}
// Encode to coder encoded bytes
func (m *PayloadFrame) Encode() []byte {
carriage := coder.NewPrimitivePacketEncoder(m.Tag)
carriage.SetBytesValue(m.Carriage)
payload := coder.NewNodePacketEncoder(byte(TagOfPayloadFrame))
payload.AddPrimitivePacket(carriage)
return payload.Encode()
}
// DecodeToPayloadFrame decodes coder encoded bytes to PayloadFrame
func DecodeToPayloadFrame(buf []byte) (*PayloadFrame, error) {
nodeBlock := coder.NodePacket{}
_, err := coder.DecodeToNodePacket(buf, &nodeBlock)
if err != nil {
return nil, err
}
payload := &PayloadFrame{}
for _, v := range nodeBlock.PrimitivePackets {
payload.Tag = v.SeqId()
payload.Carriage = v.GetValBuf()
break
}
return payload, nil
}

View File

@ -0,0 +1,20 @@
package frame
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestPayloadFrameEncode(t *testing.T) {
f := NewPayloadFrame(0x13).SetCarriage([]byte("hpds"))
assert.Equal(t, []byte{0x80 | byte(TagOfPayloadFrame), 0x06, 0x13, 0x04, 0x79, 0x6F, 0x6D, 0x6F}, f.Encode())
}
func TestPayloadFrameDecode(t *testing.T) {
buf := []byte{0x80 | byte(TagOfPayloadFrame), 0x06, 0x13, 0x04, 0x79, 0x6F, 0x6D, 0x6F}
payload, err := DecodeToPayloadFrame(buf)
assert.NoError(t, err)
assert.EqualValues(t, 0x13, payload.Tag)
assert.Equal(t, []byte{0x79, 0x6F, 0x6D, 0x6F}, payload.Carriage)
}

56
frame/rejected_frame.go Normal file
View File

@ -0,0 +1,56 @@
package frame
import (
coder "git.hpds.cc/Component/mq_coder"
)
// RejectedFrame is a coder encoded bytes, Tag is a fixed value TYPE_ID_REJECTED_FRAME
type RejectedFrame struct {
message string
}
// NewRejectedFrame creates a new RejectedFrame with a given TagId of user's data
func NewRejectedFrame(msg string) *RejectedFrame {
return &RejectedFrame{message: msg}
}
// Type gets the type of Frame.
func (f *RejectedFrame) Type() Type {
return TagOfRejectedFrame
}
// Encode to coder encoded bytes
func (f *RejectedFrame) Encode() []byte {
rejected := coder.NewNodePacketEncoder(byte(f.Type()))
// message
msgBlock := coder.NewPrimitivePacketEncoder(byte(TagOfRejectedMessage))
msgBlock.SetStringValue(f.message)
rejected.AddPrimitivePacket(msgBlock)
return rejected.Encode()
}
// Message rejected message
func (f *RejectedFrame) Message() string {
return f.message
}
// DecodeToRejectedFrame decodes coder encoded bytes to RejectedFrame
func DecodeToRejectedFrame(buf []byte) (*RejectedFrame, error) {
node := coder.NodePacket{}
_, err := coder.DecodeToNodePacket(buf, &node)
if err != nil {
return nil, err
}
rejected := &RejectedFrame{}
// message
if msgBlock, ok := node.PrimitivePackets[byte(TagOfRejectedMessage)]; ok {
msg, e := msgBlock.ToUTF8String()
if e != nil {
return nil, e
}
rejected.message = msg
}
return rejected, nil
}

View File

@ -0,0 +1,19 @@
package frame
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestRejectedFrameEncode(t *testing.T) {
f := NewRejectedFrame("")
assert.Equal(t, []byte{0x80 | byte(TagOfRejectedFrame), 0x02, 0x02, 0x00}, f.Encode())
}
func TestRejectedFrameDecode(t *testing.T) {
buf := []byte{0x80 | byte(TagOfRejectedFrame), 0x00}
ping, err := DecodeToRejectedFrame(buf)
assert.NoError(t, err)
assert.Equal(t, []byte{0x80 | byte(TagOfRejectedFrame), 0x2, 0x2, 0x0}, ping.Encode())
}

42
frame_stream.go Normal file
View File

@ -0,0 +1,42 @@
package network
import (
"errors"
"io"
"sync"
"git.hpds.cc/Component/network/frame"
)
// FrameStream is the QUIC Stream with the minimum unit Frame.
type FrameStream struct {
// Stream is a QUIC stream.
stream io.ReadWriter
mu sync.Mutex
}
// NewFrameStream creates a new FrameStream.
func NewFrameStream(s io.ReadWriter) *FrameStream {
return &FrameStream{
stream: s,
mu: sync.Mutex{},
}
}
// ReadFrame reads next frame from QUIC stream.
func (fs *FrameStream) ReadFrame() (frame.Frame, error) {
if fs.stream == nil {
return nil, errors.New("network.ReadStream: stream can not be nil")
}
return ParseFrame(fs.stream)
}
// WriteFrame writes a frame into QUIC stream.
func (fs *FrameStream) WriteFrame(f frame.Frame) (int, error) {
if fs.stream == nil {
return 0, errors.New("network.WriteFrame: stream can not be nil")
}
fs.mu.Lock()
defer fs.mu.Unlock()
return fs.stream.Write(f.Encode())
}

37
go.mod Normal file
View File

@ -0,0 +1,37 @@
module git.hpds.cc/Component/network
go 1.19
require (
git.hpds.cc/Component/mq_coder v0.0.0-20221010064749-174ae7ae3340
github.com/lucas-clemente/quic-go v0.29.1
github.com/matoous/go-nanoid/v2 v2.0.0
github.com/stretchr/testify v1.8.0
go.uber.org/zap v1.23.0
gopkg.in/natefinch/lumberjack.v2 v2.0.0
)
require (
github.com/BurntSushi/toml v1.2.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/fsnotify/fsnotify v1.4.9 // indirect
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect
github.com/golang/mock v1.6.0 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/marten-seemann/qtls-go1-18 v0.1.2 // indirect
github.com/marten-seemann/qtls-go1-19 v0.1.0 // indirect
github.com/nxadm/tail v1.4.8 // indirect
github.com/onsi/ginkgo v1.16.4 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
go.uber.org/atomic v1.7.0 // indirect
go.uber.org/multierr v1.6.0 // indirect
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 // indirect
golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e // indirect
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect
golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e // indirect
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a // indirect
golang.org/x/tools v0.1.10 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

9
handler_type.go Normal file
View File

@ -0,0 +1,9 @@
package network
import "git.hpds.cc/Component/network/frame"
// AsyncHandler is the request-response mode (async)
type AsyncHandler func([]byte) (byte, []byte)
// PipeHandler is the bidirectional stream mode (blocking).
type PipeHandler func(in <-chan []byte, out chan<- *frame.PayloadFrame)

125
hpds_err/errors.go Normal file
View File

@ -0,0 +1,125 @@
package hpds_err
import (
"fmt"
quic "github.com/lucas-clemente/quic-go"
)
// HpdsError hpds error
type HpdsError struct {
errorCode ErrorCode
err error
}
// New create hpds error
func New(code ErrorCode, err error) *HpdsError {
return &HpdsError{
errorCode: code,
err: err,
}
}
func (e *HpdsError) Error() string {
return fmt.Sprintf("%s error: message=%s", e.errorCode, e.err.Error())
}
// ErrorCode error code
type ErrorCode uint64
const (
// ErrorCodeClientAbort client abort
ErrorCodeClientAbort ErrorCode = 0x00
// ErrorCodeUnknown unknown error
ErrorCodeUnknown ErrorCode = 0xC0
// ErrorCodeClosed net closed
ErrorCodeClosed ErrorCode = 0xC1
// ErrorCodeBeforeHandler before handler
ErrorCodeBeforeHandler ErrorCode = 0xC2
// ErrorCodeMainHandler main handler
ErrorCodeMainHandler ErrorCode = 0xC3
// ErrorCodeAfterHandler after handler
ErrorCodeAfterHandler ErrorCode = 0xC4
// ErrorCodeHandshake handshake frame
ErrorCodeHandshake ErrorCode = 0xC5
// ErrorCodeRejected server rejected
ErrorCodeRejected ErrorCode = 0xCC
// ErrorCodeGoaway goaway frame
ErrorCodeGoaway ErrorCode = 0xCF
// ErrorCodeData data frame
ErrorCodeData ErrorCode = 0xCE
// ErrorCodeUnknownClient unknown client error
ErrorCodeUnknownClient ErrorCode = 0xCD
// ErrorCodeDuplicateName unknown client error
ErrorCodeDuplicateName ErrorCode = 0xC6
)
func (e ErrorCode) String() string {
switch e {
case ErrorCodeClientAbort:
return "ClientAbort"
case ErrorCodeUnknown:
return "UnknownError"
case ErrorCodeClosed:
return "NetClosed"
case ErrorCodeBeforeHandler:
return "BeforeHandler"
case ErrorCodeMainHandler:
return "MainHandler"
case ErrorCodeAfterHandler:
return "AfterHandler"
case ErrorCodeHandshake:
return "Handshake"
case ErrorCodeRejected:
return "Rejected"
case ErrorCodeGoaway:
return "Goaway"
case ErrorCodeData:
return "DataFrame"
case ErrorCodeUnknownClient:
return "UnknownClient"
case ErrorCodeDuplicateName:
return "DuplicateName"
default:
return "XXX"
}
}
// Is parse quic ApplicationErrorCode to hpds ErrorCode
func Is(he quic.ApplicationErrorCode, yerr ErrorCode) bool {
return uint64(he) == uint64(yerr)
}
// Parse parse quic ApplicationErrorCode
func Parse(he quic.ApplicationErrorCode) ErrorCode {
return ErrorCode(he)
}
// To convert hpds ErrorCode to quic ApplicationErrorCode
func To(code ErrorCode) quic.ApplicationErrorCode {
return quic.ApplicationErrorCode(code)
}
// DuplicateNameError duplicate name(sfn)
type DuplicateNameError struct {
connId string
err error
}
// NewDuplicateNameError create a duplicate name error
func NewDuplicateNameError(connId string, err error) DuplicateNameError {
return DuplicateNameError{
connId: connId,
err: err,
}
}
// Error raw error
func (e DuplicateNameError) Error() string {
return e.err.Error()
}
// ConnId duplicate connection ID
func (e DuplicateNameError) ConnId() string {
return e.connId
}

16
id/id.go Normal file
View File

@ -0,0 +1,16 @@
package id
import (
"git.hpds.cc/Component/network/log"
gonanoid "github.com/matoous/go-nanoid/v2"
)
// New generate id
func New() string {
id, err := gonanoid.New()
if err != nil {
log.Errorf("generated id err=%v", err)
return ""
}
return id
}

73
listener.go Normal file
View File

@ -0,0 +1,73 @@
package network
import (
"crypto/tls"
"git.hpds.cc/Component/network/log"
"github.com/lucas-clemente/quic-go"
"net"
"time"
pkgtls "git.hpds.cc/Component/network/tls"
)
// A Listener for incoming connections
type Listener interface {
quic.Listener
// Name Listener's name
Name() string
// Versions get Version
Versions() []string
}
var _ Listener = (*defaultListener)(nil)
type defaultListener struct {
conf *quic.Config
quic.Listener
}
// DefaultQuicConfig be used when `quicConfig` is nil.
var DefaultQuicConfig = &quic.Config{
Versions: []quic.VersionNumber{quic.Version1, quic.VersionDraft29},
MaxIdleTimeout: time.Second * 5,
KeepAlivePeriod: time.Second * 2,
MaxIncomingStreams: 1000,
MaxIncomingUniStreams: 1000,
HandshakeIdleTimeout: time.Second * 3,
InitialStreamReceiveWindow: 1024 * 1024 * 2,
InitialConnectionReceiveWindow: 1024 * 1024 * 2,
// DisablePathMTUDiscovery: true,
}
func newListener(conn net.PacketConn, tlsConfig *tls.Config, quicConfig *quic.Config) (*defaultListener, error) {
if tlsConfig == nil {
tc, err := pkgtls.CreateServerTLSConfig(conn.LocalAddr().String())
if err != nil {
log.Errorf("%sCreateServerTLSConfig: %v", ServerLogPrefix, err)
return &defaultListener{}, err
}
tlsConfig = tc
}
if quicConfig == nil {
quicConfig = DefaultQuicConfig
}
quicListener, err := quic.Listen(conn, tlsConfig, quicConfig)
if err != nil {
log.Errorf("%squic Listen: %v", ServerLogPrefix, err)
return &defaultListener{}, err
}
return &defaultListener{conf: quicConfig, Listener: quicListener}, nil
}
func (l *defaultListener) Name() string { return "QUIC-Server" }
func (l *defaultListener) Versions() []string {
versions := make([]string, len(l.conf.Versions))
for k, v := range l.conf.Versions {
versions[k] = v.String()
}
return versions
}

143
log/logger.go Normal file
View File

@ -0,0 +1,143 @@
package log
import (
"os"
"strings"
)
// Level of log
type Level uint8
const (
// DebugLevel defines debug log level.
DebugLevel Level = iota + 1
// InfoLevel defines info log level.
InfoLevel
// WarnLevel defines warn log level.
WarnLevel
// ErrorLevel defines error log level.
ErrorLevel
// NoLevel defines an absent log level.
NoLevel Level = 254
// Disabled disables the logger.
Disabled Level = 255
)
// Logger is the interface for logger.
type Logger interface {
// SetLevel sets the logger level
SetLevel(Level)
// SetEncoding sets the logger's encoding
SetEncoding(encoding string)
// Printf logs a message without level
Printf(template string, args ...interface{})
// Debugf logs a message at DebugLevel
Debugf(template string, args ...interface{})
// Infof logs a message at InfoLevel
Infof(template string, args ...interface{})
// Warnf logs a message at WarnLevel
Warnf(template string, args ...interface{})
// Errorf logs a message at ErrorLevel
Errorf(template string, args ...interface{})
// Output file path to write log message
Output(file string)
// ErrorOutput file path to write error message
ErrorOutput(file string)
}
// String the logger level
func (l Level) String() string {
switch l {
case DebugLevel:
return "DEBUG"
case ErrorLevel:
return "ERROR"
case WarnLevel:
return "WARN"
case InfoLevel:
return "INFO"
default:
return ""
}
}
// 实例
var logger Logger
func init() {
logger = Default(isEnableDebug())
}
// SetLogger allows developers to customize the logger instance.
func SetLogger(l Logger) {
logger = l
}
// EnableDebug enables the development model for logging.
// Deprecated
func EnableDebug() {
logger = Default(true)
}
// Printf prints a formatted message without a specified level.
func Printf(format string, v ...interface{}) {
logger.Printf(format, v...)
}
// Debugf logs a message at DebugLevel.
func Debugf(template string, args ...interface{}) {
logger.Debugf(template, args...)
}
// Infof logs a message at InfoLevel.
func Infof(template string, args ...interface{}) {
logger.Infof(template, args...)
}
// Warnf logs a message at WarnLevel.
func Warnf(template string, args ...interface{}) {
logger.Warnf(template, args...)
}
// Errorf logs a message at ErrorLevel.
func Errorf(template string, args ...interface{}) {
logger.Errorf(template, args...)
}
// isEnableDebug indicates whether the debug is enabled.
func isEnableDebug() bool {
return os.Getenv("HPDS_ENABLE_DEBUG") == "true"
}
// isJSONFormat indicates whether the log is in JSON format.
func isJSONFormat() bool {
return os.Getenv("HPDS_LOG_FORMAT") == "json"
}
func logFormat() string {
return os.Getenv("HPDS_LOG_FORMAT")
}
func logLevel() Level {
envLevel := strings.ToLower(os.Getenv("HPDS_LOG_LEVEL"))
level := ErrorLevel
switch envLevel {
case "debug":
return DebugLevel
case "info":
return InfoLevel
case "warn":
return WarnLevel
case "error":
return ErrorLevel
}
return level
}
func output() string {
return strings.ToLower(os.Getenv("HPDS_LOG_OUTPUT"))
}
func errorOutput() string {
return strings.ToLower(os.Getenv("HPDS_LOG_ERROR_OUTPUT"))
}

227
log/zap.go Normal file
View File

@ -0,0 +1,227 @@
package log
import (
stdlog "log"
"os"
"time"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"gopkg.in/natefinch/lumberjack.v2"
)
const (
timeFormat = "2006-01-02 15:04:05.000"
)
// zapLogger is the logger implementation in go.uber.org/zap
type zapLogger struct {
level zapcore.Level
debug bool
encoding string
opts []zap.Option
logger *zap.Logger
instance *zap.SugaredLogger
output string
errorOutput string
}
// Default the default logger instance
func Default(debug ...bool) Logger {
z := New()
z.SetLevel(logLevel())
if isJSONFormat() {
z.SetEncoding("json")
}
// env debug
if isEnableDebug() {
z.SetLevel(DebugLevel)
}
if len(debug) > 0 {
if debug[0] {
z.SetLevel(DebugLevel)
}
}
z.Output(output())
z.ErrorOutput(errorOutput())
return z
}
// New create new logger instance
func New(opts ...zap.Option) Logger {
// std logger
stdlog.Default().SetFlags(0)
stdlog.Default().SetOutput(new(logWriter))
z := zapLogger{
level: zap.ErrorLevel,
debug: false,
encoding: "console",
opts: opts,
}
return &z
}
func openSinks(cfg zap.Config) (zapcore.WriteSyncer, zapcore.WriteSyncer, error) {
sink, closeOut, err := zap.Open(cfg.OutputPaths...)
if err != nil {
return nil, nil, err
}
errSink, _, err := zap.Open(cfg.ErrorOutputPaths...)
if err != nil {
closeOut()
return nil, nil, err
}
return sink, errSink, nil
}
// SetEncoding set logger message coding
func (z *zapLogger) SetEncoding(enc string) {
z.encoding = enc
}
// SetLevel set logger level
func (z *zapLogger) SetLevel(lvl Level) {
isDebug := lvl == DebugLevel
level := zap.ErrorLevel
switch lvl {
case DebugLevel:
level = zap.DebugLevel
case InfoLevel:
level = zap.InfoLevel
case WarnLevel:
level = zap.WarnLevel
case ErrorLevel:
level = zap.ErrorLevel
}
z.level = level
z.debug = isDebug
}
// Output file path to write log message
func (z *zapLogger) Output(file string) {
if file != "" {
z.output = file
}
}
// ErrorOutput file path to write log message
func (z *zapLogger) ErrorOutput(file string) {
if file != "" {
z.errorOutput = file
}
}
// Printf logs a message wihout level
func (z *zapLogger) Printf(format string, v ...interface{}) {
stdlog.Printf(format, v...)
}
// Debugf logs a message at DebugLevel
func (z *zapLogger) Debugf(template string, args ...interface{}) {
z.Instance().Debugf(template, args...)
}
// Infof logs a message at InfoLevel
func (z *zapLogger) Infof(template string, args ...interface{}) {
z.Instance().Infof(template, args...)
}
// Warnf logs a message at WarnLevel
func (z zapLogger) Warnf(template string, args ...interface{}) {
z.Instance().Warnf(template, args...)
}
// Errorf logs a message at ErrorLevel
func (z zapLogger) Errorf(template string, args ...interface{}) {
z.Instance().Errorf(template, args...)
}
func (z *zapLogger) Instance() *zap.SugaredLogger {
if z.instance == nil {
// zap
encoderConfig := zapcore.EncoderConfig{
TimeKey: "ts",
LevelKey: "level",
NameKey: "logger",
CallerKey: "caller",
FunctionKey: zapcore.OmitKey,
MessageKey: "msg",
StacktraceKey: "stacktrace",
LineEnding: zapcore.DefaultLineEnding,
EncodeLevel: zapcore.CapitalColorLevelEncoder,
EncodeTime: timeEncoder,
EncodeDuration: zapcore.SecondsDurationEncoder,
EncodeCaller: zapcore.ShortCallerEncoder,
}
cfg := zap.Config{
Level: zap.NewAtomicLevelAt(zap.ErrorLevel),
Development: z.debug,
DisableCaller: true,
DisableStacktrace: true,
Encoding: z.encoding,
EncoderConfig: encoderConfig,
OutputPaths: []string{"stderr"},
ErrorOutputPaths: []string{"stderr"},
}
cfg.Level.SetLevel(z.level)
if z.debug {
// set the minimal level to debug
cfg.Level.SetLevel(zap.DebugLevel)
}
// output
if z.output != "" {
cfg.OutputPaths = append(cfg.OutputPaths, z.output)
}
encoder := zapcore.NewConsoleEncoder(encoderConfig)
sink, _, err := openSinks(cfg)
if err != nil {
panic(err)
}
core := zapcore.NewCore(encoder, sink, cfg.Level)
// error output
if z.errorOutput != "" {
rotatedLogger := errorRotatedLogger(z.errorOutput, 10, 30, 7)
errorOutputOption := zap.Hooks(func(entry zapcore.Entry) error {
if entry.Level == zap.ErrorLevel {
msg, err := encoder.EncodeEntry(entry, nil)
if err != nil {
return err
}
rotatedLogger.Write(msg.Bytes())
}
return nil
})
z.opts = append(z.opts, errorOutputOption)
}
l := zap.New(core, z.opts...)
z.logger = l
z.instance = z.logger.Sugar()
}
return z.instance
}
func errorRotatedLogger(file string, maxSize, maxBacukups, maxAge int) *lumberjack.Logger {
return &lumberjack.Logger{
Filename: file,
MaxSize: maxSize,
MaxBackups: maxBacukups,
MaxAge: maxAge,
Compress: false,
}
}
func timeEncoder(t time.Time, enc zapcore.PrimitiveArrayEncoder) {
enc.AppendString(t.Format(timeFormat))
}
type logWriter struct{}
func (l logWriter) Write(bytes []byte) (int, error) {
os.Stderr.WriteString(time.Now().Format(timeFormat))
os.Stderr.Write([]byte("\t"))
return os.Stderr.Write(bytes)
}

17
metadata.go Normal file
View File

@ -0,0 +1,17 @@
package network
import "git.hpds.cc/Component/network/frame"
// Metadata is used for storing extra info of the application
type Metadata interface {
// Encode is the serialize method
Encode() []byte
}
// MetadataBuilder is the builder of Metadata
type MetadataBuilder interface {
// Build will return a Metadata instance according to the handshake frame passed in
Build(f *frame.HandshakeFrame) (Metadata, error)
// Decode is the deserialize method
Decode(buf []byte) (Metadata, error)
}

47
parser_stream.go Normal file
View File

@ -0,0 +1,47 @@
package network
import (
"fmt"
coder "git.hpds.cc/Component/mq_coder"
"git.hpds.cc/Component/network/frame"
"io"
)
// ParseFrame parses the frame from QUIC stream.
func ParseFrame(stream io.Reader) (frame.Frame, error) {
buf, err := coder.ReadPacket(stream)
if err != nil {
return nil, err
}
frameType := buf[0]
// determine the frame type
switch frameType {
case 0x80 | byte(frame.TagOfHandshakeFrame):
handshakeFrame, err := readHandshakeFrame(buf)
// logger.Debugf("%sHandshakeFrame: name=%s, type=%s", ParseFrameLogPrefix, handshakeFrame.Name, handshakeFrame.Type())
return handshakeFrame, err
case 0x80 | byte(frame.TagOfDataFrame):
data, err := readDataFrame(buf)
// logger.Debugf("%sDataFrame: tid=%s, tag=%#x, len(carriage)=%d", ParseFrameLogPrefix, data.TransactionID(), data.GetDataTag(), len(data.GetCarriage()))
return data, err
case 0x80 | byte(frame.TagOfAcceptedFrame):
return frame.DecodeToAcceptedFrame(buf)
case 0x80 | byte(frame.TagOfRejectedFrame):
return frame.DecodeToRejectedFrame(buf)
case 0x80 | byte(frame.TagOfGoawayFrame):
return frame.DecodeToGoawayFrame(buf)
case 0x80 | byte(frame.TagOfBackFlowFrame):
return frame.DecodeToBackFlowFrame(buf)
default:
return nil, fmt.Errorf("unknown frame type, buf[0]=%#x", buf[0])
}
}
func readHandshakeFrame(buf []byte) (*frame.HandshakeFrame, error) {
return frame.DecodeToHandshakeFrame(buf)
}
func readDataFrame(buf []byte) (*frame.DataFrame, error) {
return frame.DecodeToDataFrame(buf)
}

19
router.go Normal file
View File

@ -0,0 +1,19 @@
package network
// Router is the interface to manage the routes for applications.
type Router interface {
// Route gets the route
Route(metadata Metadata) Route
// Clean the routes.
Clean()
}
// Route manages data subscribers according to their observed data tags.
type Route interface {
// Add a route.
Add(connId string, name string, observeDataTags []byte) error
// Remove a route.
Remove(connId string) error
// GetForwardRoutes returns all the subscribers by the given data tag.
GetForwardRoutes(tag byte) []string
}

567
server.go Normal file
View File

@ -0,0 +1,567 @@
package network
import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"sync"
"sync/atomic"
// authentication implements, Currently, only token authentication is implemented
_ "git.hpds.cc/Component/network/auth"
"git.hpds.cc/Component/network/frame"
"git.hpds.cc/Component/network/hpds_err"
"git.hpds.cc/Component/network/log"
pkgtls "git.hpds.cc/Component/network/tls"
"github.com/lucas-clemente/quic-go"
)
const (
// DefaultListenAddr is the default address to listen.
DefaultListenAddr = "0.0.0.0:9000"
)
// ServerOption is the option for server.
type ServerOption func(*ServerOptions)
// FrameHandler is the handler for frame.
type FrameHandler func(c *Context) error
// Server is the underlining server of Message Queue
type Server struct {
name string
state string
connector Connector
router Router
metadataBuilder MetadataBuilder
counterOfDataFrame int64
downStreams map[string]*Client
mu sync.Mutex
opts ServerOptions
beforeHandlers []FrameHandler
afterHandlers []FrameHandler
}
// NewServer create a Server instance.
func NewServer(name string, opts ...ServerOption) *Server {
s := &Server{
name: name,
connector: newConnector(),
downStreams: make(map[string]*Client),
}
s.Init(opts...)
return s
}
// Init the options.
func (s *Server) Init(opts ...ServerOption) error {
for _, o := range opts {
o(&s.opts)
}
// options defaults
s.initOptions()
return nil
}
// ListenAndServe starts the server.
func (s *Server) ListenAndServe(ctx context.Context, addr string) error {
if addr == "" {
addr = DefaultListenAddr
}
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return err
}
conn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
return err
}
return s.Serve(ctx, conn)
}
// Serve the server with a net.PacketConn.
func (s *Server) Serve(ctx context.Context, conn net.PacketConn) error {
if err := s.validateMetadataBuilder(); err != nil {
return err
}
if err := s.validateRouter(); err != nil {
return err
}
// listen the address
listener, err := newListener(conn, s.opts.TLSConfig, s.opts.QuicConfig)
if err != nil {
log.Errorf("%slistener.Listen: err=%v", ServerLogPrefix, err)
return err
}
defer listener.Close()
log.Printf("%s [%s][%d] Listening on: %s, MODE: %s, QUIC: %v, AUTH: %s", ServerLogPrefix, s.name, os.Getpid(), listener.Addr(), mode(), listener.Versions(), s.authNames())
s.state = ConnStateConnected
for {
// create a new connection when new hpds-client connected
sctx, cancel := context.WithCancel(ctx)
defer cancel()
connect, e := listener.Accept(sctx)
if e != nil {
log.Errorf("%screate connection error: %v", ServerLogPrefix, e)
return e
}
connID := GetConnId(connect)
log.Infof("%s1/ new connection: %s", ServerLogPrefix, connID)
go func(ctx context.Context, qconn quic.Connection) {
for {
log.Infof("%s2/ waiting for new stream", ServerLogPrefix)
stream, err := qconn.AcceptStream(ctx)
if err != nil {
// if client close the connection, then we should close the connection
// @CC: when Source close the connection, it won't affect connectors
name := "--"
if conn := s.connector.Get(connID); conn != nil {
conn.Close()
// connector
s.connector.Remove(connID)
route := s.router.Route(conn.Metadata())
if route != nil {
route.Remove(connID)
}
name = conn.Name()
}
log.Printf("%s [%s](%s) close the connection: %v", ServerLogPrefix, name, connID, err)
break
}
defer stream.Close()
log.Infof("%s3/ [stream:%d] created, connId=%s", ServerLogPrefix, stream.StreamID(), connID)
// process frames on stream
// c := newContext(connId, stream)
c := newContext(connect, stream)
defer c.Clean()
s.handleConnection(c)
log.Infof("%s4/ [stream:%d] handleConnection DONE", ServerLogPrefix, stream.StreamID())
}
}(sctx, connect)
}
}
// Close will shut down the server.
func (s *Server) Close() error {
if s.router != nil {
s.router.Clean()
}
// connector
if s.connector != nil {
s.connector.Clean()
}
return nil
}
// handle streams on a connection
func (s *Server) handleConnection(c *Context) {
fs := NewFrameStream(c.Stream)
// check update for stream
for {
log.Debugf("%shandleConnection waiting read next...", ServerLogPrefix)
f, err := fs.ReadFrame()
if err != nil {
// if client close connection, will get ApplicationError with code = 0x00
if e, ok := err.(*quic.ApplicationError); ok {
if hpds_err.Is(e.ErrorCode, hpds_err.ErrorCodeClientAbort) {
// client abort
log.Infof("%sclient close the connection", ServerLogPrefix)
break
} else {
ye := hpds_err.New(hpds_err.Parse(e.ErrorCode), err)
log.Errorf("%s[ERR] %s", ServerLogPrefix, ye)
}
} else if err == io.EOF {
log.Infof("%sthe connection is EOF", ServerLogPrefix)
break
}
if errors.Is(err, net.ErrClosed) {
// if client close the connection, net.ErrClosed will be raised
// by quic-go IdleTimeoutError after connection's KeepAlive config.
log.Warnf("%s[ERR] net.ErrClosed on [handleConnection] %v", ServerLogPrefix, net.ErrClosed)
c.CloseWithError(hpds_err.ErrorCodeClosed, "net.ErrClosed")
break
}
// any error occurred, we should close the stream
// after this, conn.AcceptStream() will raise the error
c.CloseWithError(hpds_err.ErrorCodeUnknown, err.Error())
log.Warnf("%sconnection.Close()", ServerLogPrefix)
break
}
frameType := f.Type()
data := f.Encode()
log.Debugf("%stype=%s, frame[%d]=%# x", ServerLogPrefix, frameType, len(data), frame.Shortly(data))
// add frame to context
context := c.WithFrame(f)
// before frame handlers
for _, handler := range s.beforeHandlers {
if e := handler(context); e != nil {
log.Errorf("%safterFrameHandler e: %s", ServerLogPrefix, e)
context.CloseWithError(hpds_err.ErrorCodeBeforeHandler, e.Error())
return
}
}
// main handler
if e := s.mainFrameHandler(context); e != nil {
log.Errorf("%smainFrameHandler e: %s", ServerLogPrefix, e)
context.CloseWithError(hpds_err.ErrorCodeMainHandler, e.Error())
return
}
// after frame handler
for _, handler := range s.afterHandlers {
if e := handler(context); e != nil {
log.Errorf("%safterFrameHandler e: %s", ServerLogPrefix, e)
context.CloseWithError(hpds_err.ErrorCodeAfterHandler, e.Error())
return
}
}
}
}
func (s *Server) mainFrameHandler(c *Context) error {
var err error
frameType := c.Frame.Type()
switch frameType {
case frame.TagOfHandshakeFrame:
if err = s.handleHandshakeFrame(c); err != nil {
log.Errorf("%shandleHandshakeFrame err: %s", ServerLogPrefix, err)
// close connections early to avoid resource consumption
if c.Stream != nil {
goawayFrame := frame.NewGoawayFrame(err.Error())
if _, e := c.Stream.Write(goawayFrame.Encode()); e != nil {
log.Errorf("%s write to client[%s] GoawayFrame error:%v", ServerLogPrefix, c.ConnId, e)
return e
}
}
}
// case frame.TagOfPingFrame:
// s.handlePingFrame(mainStream, connection, f.(*frame.PingFrame))
case frame.TagOfDataFrame:
if err = s.handleDataFrame(c); err != nil {
c.CloseWithError(hpds_err.ErrorCodeData, fmt.Sprintf("handleDataFrame err: %v", err))
} else {
conn := s.connector.Get(c.connId)
if conn != nil && conn.ClientType() == ClientTypeProtocolGateway {
f := c.Frame.(*frame.DataFrame)
f.GetMetaFrame().SetMetadata(conn.Metadata().Encode())
s.dispatchToDownStreams(f)
}
// observe data tags back flow
s.handleBackFlowFrame(c)
}
default:
log.Errorf("%serr=%v, frame=%v", ServerLogPrefix, err, frame.Shortly(c.Frame.Encode()))
}
return nil
}
// handle HandShakeFrame
func (s *Server) handleHandshakeFrame(c *Context) error {
f := c.Frame.(*frame.HandshakeFrame)
log.Debugf("%sGOT HandshakeFrame : %# x", ServerLogPrefix, f)
// basic info
connId := c.ConnId()
clientId := f.ClientId
clientType := ClientType(f.ClientType)
stream := c.Stream
// credential
log.Debugf("%sClientType=%# x is %s, ClientId=%s, Credential=%s", ServerLogPrefix, f.ClientType, ClientType(f.ClientType), clientId, authName(f.AuthName()))
// authenticate
if !s.authenticate(f) {
err := fmt.Errorf("handshake authentication fails, client credential name is %s", authName(f.AuthName()))
// return err
log.Debugf("%s <%s> [%s](%s) is connected!", ServerLogPrefix, clientType, f.Name, connId)
rejectedFrame := frame.NewRejectedFrame(err.Error())
if _, err = stream.Write(rejectedFrame.Encode()); err != nil {
log.Debugf("%s write to <%s> [%s](%s) RejectedFrame error:%v", ServerLogPrefix, clientType, f.Name, connId, err)
return err
}
return nil
}
// client type
var conn Connection
switch clientType {
case ClientTypeProtocolGateway, ClientTypeStreamFunction:
// metadata
metadata, err := s.metadataBuilder.Build(f)
if err != nil {
return err
}
conn = newConnection(f.Name, f.ClientId, clientType, metadata, stream, f.ObserveDataTags)
if clientType == ClientTypeStreamFunction {
// route
route := s.router.Route(metadata)
if route == nil {
return errors.New("handleHandshakeFrame route is nil")
}
if e1 := route.Add(connId, f.Name, f.ObserveDataTags); e1 != nil {
// duplicate name
if e2, ok := e1.(hpds_err.DuplicateNameError); ok {
existsConnID := e2.ConnId()
if conn = s.connector.Get(existsConnID); conn != nil {
log.Debugf("%s%s, write to SFN[%s](%s) GoawayFrame", ServerLogPrefix, e2.Error(), f.Name, existsConnID)
goawayFrame := frame.NewGoawayFrame(e2.Error())
if e3 := conn.Write(goawayFrame); e3 != nil {
log.Errorf("%s write to SFN[%s] GoawayFrame error:%v", ServerLogPrefix, f.Name, e3)
return e3
}
}
} else {
return e1
}
}
}
case ClientTypeMessageQueue:
conn = newConnection(f.Name, f.ClientId, clientType, nil, stream, f.ObserveDataTags)
default:
// unknown client type
s.connector.Remove(connId)
err := fmt.Errorf("Illegal ClientType: %#x", f.ClientType)
c.CloseWithError(hpds_err.ErrorCodeUnknownClient, err.Error())
return err
}
s.connector.Add(connId, conn)
log.Printf("%s <%s> [%s][%s](%s) is connected!", ServerLogPrefix, clientType, f.Name, clientId, connId)
return nil
}
// handle handleGoawayFrame
func (s *Server) handleGoawayFrame(c *Context) error {
f := c.Frame.(*frame.GoawayFrame)
log.Debugf("%s GOT GoawayFrame code=%d, message==%s", ServerLogPrefix, hpds_err.ErrorCodeGoaway, f.Message())
// c.CloseWithError(f.Code(), f.Message())
_, err := c.Stream.Write(f.Encode())
return err
}
// will reuse quic-go's keep-alive feature
// func (s *Server) handlePingFrame(stream quic.Stream, conn quic.Connection, f *frame.PingFrame) error {
// log.Infof("%s------> GOT PingFrame : %# x", ServerLogPrefix, f)
// return nil
// }
func (s *Server) handleDataFrame(c *Context) error {
// counter +1
atomic.AddInt64(&s.counterOfDataFrame, 1)
// currentIssuer := f.GetIssuer()
fromId := c.ConnId()
from := s.connector.Get(fromId)
if from == nil {
log.Warnf("%shandleDataFrame connector cannot find %s", ServerLogPrefix, fromId)
return fmt.Errorf("handleDataFrame connector cannot find %s", fromId)
}
f := c.Frame.(*frame.DataFrame)
var metadata Metadata
if from.ClientType() == ClientTypeMessageQueue {
m, err := s.metadataBuilder.Decode(f.GetMetaFrame().Metadata())
if err != nil {
return err
}
metadata = m
} else {
metadata = from.Metadata()
}
// route
route := s.router.Route(metadata)
if route == nil {
log.Warnf("%shandleDataFrame route is nil", ServerLogPrefix)
return fmt.Errorf("handleDataFrame route is nil")
}
// get stream function connection ids from route
connIds := route.GetForwardRoutes(f.GetDataTag())
for _, toId := range connIds {
conn := s.connector.Get(toId)
if conn == nil {
log.Errorf("%sconn is nil: (%s)", ServerLogPrefix, toId)
continue
}
to := conn.Name()
log.Debugf("%shandleDataFrame tag=%#x tid=%s, counter=%d, from=[%s](%s), to=[%s](%s)", ServerLogPrefix, f.Tag(), f.TransactionId(), s.counterOfDataFrame, from.Name(), fromId, to, toId)
// write data frame to stream
if err := conn.Write(f); err != nil {
log.Warnf("%shandleDataFrame conn.Write tag=%#x tid=%s, from=[%s](%s), to=[%s](%s), %v", ServerLogPrefix, f.Tag(), f.TransactionId(), from.Name(), fromId, to, toId, err)
}
}
return nil
}
func (s *Server) handleBackFlowFrame(c *Context) error {
f := c.Frame.(*frame.DataFrame)
tag := f.GetDataTag()
carriage := f.GetCarriage()
sourceId := f.SourceId()
// write to Protocol Gateway with BackFlowFrame
bf := frame.NewBackFlowFrame(tag, carriage)
sourceConns := s.connector.GetProtocolGatewayConnections(sourceId, tag)
// conn := s.connector.Get(c.connId)
// logger.Printf("%s handleBackFlowFrame tag:%#v --> source:%s, result=%s", ServerLogPrefix, tag, sourceId, carriage)
for _, source := range sourceConns {
if source != nil {
log.Debugf("%s handleBackFlowFrame tag:%#v --> Protocol Gateway:%s, result=%# x", ServerLogPrefix, tag, sourceId, frame.Shortly(carriage))
if err := source.Write(bf); err != nil {
log.Errorf("%s handleBackFlowFrame tag:%#v --> Protocol Gateway:%s, error=%v", ServerLogPrefix, tag, sourceId, err)
return err
}
}
}
return nil
}
// StatsFunctions returns the sfn stats of server.
func (s *Server) StatsFunctions() map[string]string {
return s.connector.GetSnapshot()
}
// StatsCounter returns how many DataFrames pass through server.
func (s *Server) StatsCounter() int64 {
return s.counterOfDataFrame
}
// DownStreams return all the downstream servers.
func (s *Server) DownStreams() map[string]*Client {
return s.downStreams
}
// ConfigRouter is used to set router by Message Queue
func (s *Server) ConfigRouter(router Router) {
s.mu.Lock()
s.router = router
log.Debugf("%sconfig router is %#v", ServerLogPrefix, router)
s.mu.Unlock()
}
// ConfigMetadataBuilder is used to set metadataBuilder by Message Queue
func (s *Server) ConfigMetadataBuilder(builder MetadataBuilder) {
s.mu.Lock()
s.metadataBuilder = builder
log.Debugf("%sconfig metadataBuilder is %#v", ServerLogPrefix, builder)
s.mu.Unlock()
}
// AddDownstreamServer add a downstream server to this server. all the DataFrames will be
// dispatch to all the downStreams.
func (s *Server) AddDownstreamServer(addr string, c *Client) {
s.mu.Lock()
s.downStreams[addr] = c
s.mu.Unlock()
}
// dispatch every DataFrames to all downStreams
func (s *Server) dispatchToDownStreams(df *frame.DataFrame) {
for addr, ds := range s.downStreams {
log.Debugf("%sdispatching to [%s]: %# x", ServerLogPrefix, addr, df.Tag())
ds.WriteFrame(df)
}
}
// GetConnId get quic connection id
func GetConnId(conn quic.Connection) string {
return conn.RemoteAddr().String()
}
func (s *Server) initOptions() {
// defaults
}
func (s *Server) validateRouter() error {
if s.router == nil {
return errors.New("server's router is nil")
}
return nil
}
func (s *Server) validateMetadataBuilder() error {
if s.metadataBuilder == nil {
return errors.New("server's metadataBuilder is nil")
}
return nil
}
// Options returns the options of server.
func (s *Server) Options() ServerOptions {
return s.opts
}
// Connector returns the connector of server.
func (s *Server) Connector() Connector {
return s.connector
}
// SetBeforeHandlers set the before handlers of server.
func (s *Server) SetBeforeHandlers(handlers ...FrameHandler) {
s.beforeHandlers = append(s.beforeHandlers, handlers...)
}
// SetAfterHandlers set the after handlers of server.
func (s *Server) SetAfterHandlers(handlers ...FrameHandler) {
s.afterHandlers = append(s.afterHandlers, handlers...)
}
func (s *Server) authNames() []string {
if len(s.opts.Auths) == 0 {
return []string{"none"}
}
result := make([]string, 0)
for _, auth := range s.opts.Auths {
result = append(result, auth.Name())
}
return result
}
func (s *Server) authenticate(f *frame.HandshakeFrame) bool {
if len(s.opts.Auths) > 0 {
for _, auth := range s.opts.Auths {
if f.AuthName() == auth.Name() {
isAuthenticated := auth.Authenticate(f.AuthPayload())
if isAuthenticated {
log.Debugf("%sauthenticated==%v", ServerLogPrefix, isAuthenticated)
return isAuthenticated
}
}
}
return false
}
return true
}
func mode() string {
if pkgtls.IsDev() {
return "DEVELOPMENT"
}
return "PRODUCTION"
}
func authName(name string) string {
if name == "" {
return "empty"
}
return name
}

56
server_options.go Normal file
View File

@ -0,0 +1,56 @@
package network
import (
"crypto/tls"
"net"
"git.hpds.cc/Component/network/auth"
"github.com/lucas-clemente/quic-go"
)
// ServerOptions are the options for HPDS Network server.
type ServerOptions struct {
QuicConfig *quic.Config
TLSConfig *tls.Config
Addr string
Auths []auth.Authentication
Conn net.PacketConn
}
// WithAddr sets the server address.
func WithAddr(addr string) ServerOption {
return func(o *ServerOptions) {
o.Addr = addr
}
}
// WithAuth sets the server authentication method.
func WithAuth(name string, args ...string) ServerOption {
return func(o *ServerOptions) {
if auth, ok := auth.GetAuth(name); ok {
auth.Init(args...)
o.Auths = append(o.Auths, auth)
}
}
}
// WithServerTLSConfig sets the TLS configuration for the server.
func WithServerTLSConfig(tc *tls.Config) ServerOption {
return func(o *ServerOptions) {
o.TLSConfig = tc
}
}
// WithServerQuicConfig sets the QUIC configuration for the server.
func WithServerQuicConfig(qc *quic.Config) ServerOption {
return func(o *ServerOptions) {
o.QuicConfig = qc
}
}
// WithConn sets the connection for the server.
func WithConn(conn net.PacketConn) ServerOption {
return func(o *ServerOptions) {
o.Conn = conn
}
}

229
tls/tls.go Normal file
View File

@ -0,0 +1,229 @@
package tls
import (
"bytes"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"errors"
"io/ioutil"
"math/big"
"net"
"os"
"time"
)
var isDev bool
// CreateServerTLSConfig creates server tls config.
func CreateServerTLSConfig(host string) (*tls.Config, error) {
// development mode
if isDev {
tc, err := developmentTLSConfig(host)
if err != nil {
return nil, err
}
return tc, nil
}
// production mode
// ca pool
pool, err := getCACertPool()
if err != nil {
return nil, err
}
// server certificate
tlsCert, err := getCertAndKey()
if err != nil {
return nil, err
}
return &tls.Config{
Certificates: []tls.Certificate{*tlsCert},
ClientCAs: pool,
ClientAuth: tls.RequireAndVerifyClientCert,
NextProtos: []string{"hpds"},
}, nil
}
// CreateClientTLSConfig creates client tls config.
func CreateClientTLSConfig() (*tls.Config, error) {
// development mode
if isDev {
return &tls.Config{
InsecureSkipVerify: true,
NextProtos: []string{"hpds"},
ClientSessionCache: tls.NewLRUClientSessionCache(64),
}, nil
}
// production mode
pool, err := getCACertPool()
if err != nil {
return nil, err
}
tlsCert, err := getCertAndKey()
if err != nil {
return nil, err
}
return &tls.Config{
InsecureSkipVerify: false,
Certificates: []tls.Certificate{*tlsCert},
RootCAs: pool,
NextProtos: []string{"hpds"},
ClientSessionCache: tls.NewLRUClientSessionCache(0),
}, nil
}
func getCACertPool() (*x509.CertPool, error) {
var err error
var caCert []byte
caCertPath := os.Getenv("HPDS_TLS_CACERT_FILE")
if len(caCertPath) == 0 {
return nil, errors.New("tls: must provide CA certificate on production mode, you can configure this via environment variables: `HPDS_TLS_CACERT_FILE`")
}
caCert, err = ioutil.ReadFile(caCertPath)
if err != nil {
return nil, err
}
if len(caCert) == 0 {
return nil, errors.New("tls: cannot load CA cert")
}
pool := x509.NewCertPool()
if ok := pool.AppendCertsFromPEM(caCert); !ok {
return nil, errors.New("tls: cannot append CA cert to pool")
}
return pool, nil
}
func getCertAndKey() (*tls.Certificate, error) {
var err error
var cert, key []byte
certPath := os.Getenv("HPDS_TLS_CERT_FILE")
keyPath := os.Getenv("HPDS_TLS_KEY_FILE")
if len(certPath) == 0 || len(keyPath) == 0 {
return nil, errors.New("tls: must provide certificate on production mode, you can configure this via environment variables: `HPDS_TLS_CERT_FILE` and `HPDS_TLS_KEY_FILE`")
}
// certificate
cert, err = ioutil.ReadFile(certPath)
if err != nil {
return nil, err
}
// private key
key, err = ioutil.ReadFile(keyPath)
if err != nil {
return nil, err
}
if len(cert) == 0 || len(key) == 0 {
return nil, errors.New("tls: cannot load tls cert/key")
}
tlsCert, err := tls.X509KeyPair(cert, key)
if err != nil {
return nil, err
}
return &tlsCert, nil
}
// IsDev development mode
func IsDev() bool {
return isDev
}
// developmentTLSConfig Setup a bare-bones TLS config for the server
func developmentTLSConfig(host ...string) (*tls.Config, error) {
tlsCert, err := generateCertificate(host...)
if err != nil {
return nil, err
}
return &tls.Config{
Certificates: []tls.Certificate{tlsCert},
ClientSessionCache: tls.NewLRUClientSessionCache(1),
NextProtos: []string{"hpds"},
}, nil
}
func generateCertificate(host ...string) (tls.Certificate, error) {
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return tls.Certificate{}, err
}
notBefore := time.Now()
notAfter := notBefore.Add(time.Hour * 24 * 365)
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return tls.Certificate{}, err
}
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"HPDS"},
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
DNSNames: []string{"localhost"},
}
for _, h := range host {
if ip := net.ParseIP(h); ip != nil {
template.IPAddresses = append(template.IPAddresses, ip)
} else {
template.DNSNames = append(template.DNSNames, h)
}
}
template.IsCA = true
template.KeyUsage |= x509.KeyUsageCertSign
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
if err != nil {
return tls.Certificate{}, err
}
// create public key
certOut := bytes.NewBuffer(nil)
err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
if err != nil {
return tls.Certificate{}, err
}
// create private key
keyOut := bytes.NewBuffer(nil)
b, err := x509.MarshalECPrivateKey(priv)
if err != nil {
return tls.Certificate{}, err
}
err = pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: b})
if err != nil {
return tls.Certificate{}, err
}
return tls.X509KeyPair(certOut.Bytes(), keyOut.Bytes())
}
func init() {
env := os.Getenv("HPDS_ENV")
isDev = len(env) == 0 || env != "production"
}

10
workflow.go Normal file
View File

@ -0,0 +1,10 @@
package network
// Workflow describes stream function workflows.
type Workflow struct {
// Seq represents the sequence id when executing workflows.
Seq int
// Token represents the name of workflow.
Name string
}