465 lines
13 KiB
Go
465 lines
13 KiB
Go
package network
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"fmt"
|
||
"git.hpds.cc/Component/network/hpds_err"
|
||
"git.hpds.cc/Component/network/id"
|
||
pkgtls "git.hpds.cc/Component/network/tls"
|
||
"net"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/lucas-clemente/quic-go"
|
||
|
||
"git.hpds.cc/Component/network/auth"
|
||
"git.hpds.cc/Component/network/frame"
|
||
"git.hpds.cc/Component/network/log"
|
||
)
|
||
|
||
// ClientOption client options
|
||
type ClientOption func(*ClientOptions)
|
||
|
||
// ConnState describes the state of the connection.
|
||
type ConnState = string
|
||
|
||
// Client is the abstraction of a HPDS-Client. a HPDS-Client can be
|
||
// Protocol Gateway, Message Queue or StreamFunction.
|
||
type Client struct {
|
||
name string // name of the client
|
||
clientId string // id of the client
|
||
clientType ClientType // type of the connection
|
||
conn quic.Connection // quic connection
|
||
stream quic.Stream // quic stream
|
||
state ConnState // state of the connection
|
||
processor func(*frame.DataFrame) // functions to invoke when data arrived
|
||
receiver func(*frame.BackFlowFrame) // functions to invoke when data is processed
|
||
addr string // the address of server connected to
|
||
mu sync.Mutex
|
||
opts ClientOptions
|
||
localAddr string // client local addr, it will be changed on reconnect
|
||
logger log.Logger
|
||
errChan chan error
|
||
closeChan chan bool
|
||
closed bool
|
||
}
|
||
|
||
// NewClient creates a new HPDS-Client.
|
||
func NewClient(appName string, connType ClientType, opts ...ClientOption) *Client {
|
||
c := &Client{
|
||
name: appName,
|
||
clientId: id.New(),
|
||
clientType: connType,
|
||
state: ConnStateReady,
|
||
opts: ClientOptions{},
|
||
errChan: make(chan error),
|
||
closeChan: make(chan bool),
|
||
}
|
||
_ = c.Init(opts...)
|
||
once.Do(func() {
|
||
c.init()
|
||
})
|
||
|
||
return c
|
||
}
|
||
|
||
// Init the options.
|
||
func (c *Client) Init(opts ...ClientOption) error {
|
||
for _, o := range opts {
|
||
o(&c.opts)
|
||
}
|
||
return c.initOptions()
|
||
}
|
||
|
||
// Connect connects to HPDS-MessageQueue.
|
||
func (c *Client) Connect(ctx context.Context, addr string) error {
|
||
|
||
// TODO: refactor this later as a Connection Manager
|
||
// reconnect
|
||
// for download mq
|
||
// If you do not check for errors, the connection will be automatically reconnected
|
||
go c.reconnect(ctx, addr)
|
||
|
||
// connect
|
||
if err := c.connect(ctx, addr); err != nil {
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (c *Client) connect(ctx context.Context, addr string) error {
|
||
c.addr = addr
|
||
c.state = ConnStateConnecting
|
||
|
||
// create quic connection
|
||
conn, err := quic.DialAddrContext(ctx, addr, c.opts.TLSConfig, c.opts.QuicConfig)
|
||
if err != nil {
|
||
c.state = ConnStateDisconnected
|
||
return err
|
||
}
|
||
|
||
// quic stream
|
||
stream, err := conn.OpenStreamSync(ctx)
|
||
if err != nil {
|
||
c.state = ConnStateDisconnected
|
||
return err
|
||
}
|
||
|
||
c.stream = stream
|
||
c.conn = conn
|
||
|
||
c.state = ConnStateAuthenticating
|
||
// send handshake
|
||
handshake := frame.NewHandshakeFrame(
|
||
c.name,
|
||
c.clientId,
|
||
byte(c.clientType),
|
||
c.opts.ObserveDataTags,
|
||
c.opts.Credential.Name(),
|
||
c.opts.Credential.Payload(),
|
||
)
|
||
err = c.WriteFrame(handshake)
|
||
if err != nil {
|
||
c.state = ConnStateRejected
|
||
return err
|
||
}
|
||
c.state = ConnStateConnected
|
||
c.localAddr = c.conn.LocalAddr().String()
|
||
|
||
c.logger.Printf("%s [%s][%s](%s) is connected to HPDS-MQ %s", ClientLogPrefix, c.name, c.clientId, c.localAddr, addr)
|
||
|
||
// receiving frames
|
||
go c.handleFrame()
|
||
|
||
return nil
|
||
}
|
||
|
||
// handleFrame handles the logic when receiving frame from server.
|
||
func (c *Client) handleFrame() {
|
||
// transform raw QUIC stream to wire format
|
||
fs := NewFrameStream(c.stream)
|
||
for {
|
||
c.logger.Debugf("%shandleFrame connection state=%v", ClientLogPrefix, c.state)
|
||
// this will block until a frame is received
|
||
f, err := fs.ReadFrame()
|
||
if err != nil {
|
||
defer func() {
|
||
_ = c.stream.Close()
|
||
}()
|
||
|
||
c.logger.Debugf("%shandleFrame(): %T | %v", ClientLogPrefix, err, err)
|
||
if e, ok := err.(*quic.IdleTimeoutError); ok {
|
||
c.logger.Errorf("%sconnection timeout, err=%v, mq addr=%s", ClientLogPrefix, e, c.addr)
|
||
c.setState(ConnStateDisconnected)
|
||
} else if e, ok := err.(*quic.ApplicationError); ok {
|
||
c.logger.Infof("%sapplication error, err=%v, errcode=%v", ClientLogPrefix, e, e.ErrorCode)
|
||
if hpds_err.Is(e.ErrorCode, hpds_err.ErrorCodeRejected) {
|
||
// if connection is rejected(eg: authenticate fails) from server
|
||
c.logger.Errorf("%sIllegal client, server rejected.", ClientLogPrefix)
|
||
c.setState(ConnStateRejected)
|
||
break
|
||
} else if hpds_err.Is(e.ErrorCode, hpds_err.ErrorCodeClientAbort) {
|
||
// client abort
|
||
c.logger.Infof("%sclient close the connection", ClientLogPrefix)
|
||
c.setState(ConnStateAborted)
|
||
break
|
||
} else if hpds_err.Is(e.ErrorCode, hpds_err.ErrorCodeGoaway) {
|
||
// server goaway
|
||
c.logger.Infof("%sserver goaway the connection", ClientLogPrefix)
|
||
c.setState(ConnStateGoaway)
|
||
break
|
||
} else if hpds_err.Is(e.ErrorCode, hpds_err.ErrorCodeHandshake) {
|
||
// handshake
|
||
c.logger.Errorf("%shandshake fails", ClientLogPrefix)
|
||
c.setState(ConnStateRejected)
|
||
break
|
||
}
|
||
} else if errors.Is(err, net.ErrClosed) {
|
||
// if client close the connection, net.ErrClosed will be raised
|
||
c.logger.Errorf("%sconnection is closed, err=%v", ClientLogPrefix, err)
|
||
c.setState(ConnStateDisconnected)
|
||
// by quic-go IdleTimeoutError after connection's KeepAlive config.
|
||
break
|
||
} else {
|
||
// any error occurred, we should close the stream
|
||
// after this, conn.AcceptStream() will raise the error
|
||
c.setState(ConnStateClosed)
|
||
_ = c.conn.CloseWithError(hpds_err.To(hpds_err.ErrorCodeUnknown), err.Error())
|
||
c.logger.Errorf("%sunknown error occurred, err=%v, state=%v", ClientLogPrefix, err, c.getState())
|
||
break
|
||
}
|
||
}
|
||
if f == nil {
|
||
break
|
||
}
|
||
// read frame
|
||
// first, get frame type
|
||
frameType := f.Type()
|
||
c.logger.Debugf("%stype=%s, frame=%# x", ClientLogPrefix, frameType, frame.Shortly(f.Encode()))
|
||
switch frameType {
|
||
case frame.TagOfHandshakeFrame:
|
||
if v, ok := f.(*frame.HandshakeFrame); ok {
|
||
c.logger.Debugf("%sreceive HandshakeFrame, name=%v", ClientLogPrefix, v.Name)
|
||
}
|
||
case frame.TagOfPongFrame:
|
||
c.setState(ConnStatePong)
|
||
case frame.TagOfAcceptedFrame:
|
||
c.setState(ConnStateAccepted)
|
||
case frame.TagOfRejectedFrame:
|
||
c.setState(ConnStateRejected)
|
||
if v, ok := f.(*frame.RejectedFrame); ok {
|
||
c.logger.Errorf("%s receive RejectedFrame, message=%s", ClientLogPrefix, v.Message())
|
||
_ = c.conn.CloseWithError(hpds_err.To(hpds_err.ErrorCodeRejected), v.Message())
|
||
c.errChan <- errors.New(v.Message())
|
||
break
|
||
}
|
||
case frame.TagOfGoawayFrame:
|
||
c.setState(ConnStateGoaway)
|
||
if v, ok := f.(*frame.GoawayFrame); ok {
|
||
c.logger.Errorf("%s️ receive GoawayFrame, message=%s", ClientLogPrefix, v.Message())
|
||
_ = c.conn.CloseWithError(hpds_err.To(hpds_err.ErrorCodeGoaway), v.Message())
|
||
c.errChan <- errors.New(v.Message())
|
||
break
|
||
}
|
||
case frame.TagOfDataFrame: // DataFrame carries user's data
|
||
if v, ok := f.(*frame.DataFrame); ok {
|
||
c.setState(ConnStateTransportData)
|
||
c.logger.Debugf("%sreceive DataFrame, tag=%#x, tid=%s, carry=%# x", ClientLogPrefix, v.GetDataTag(), v.TransactionId(), v.GetCarriage())
|
||
if c.processor == nil {
|
||
c.logger.Warnf("%sprocessor is nil", ClientLogPrefix)
|
||
} else {
|
||
c.processor(v)
|
||
}
|
||
}
|
||
case frame.TagOfBackFlowFrame:
|
||
if v, ok := f.(*frame.BackFlowFrame); ok {
|
||
c.logger.Debugf("%sreceive BackFlowFrame, tag=%#x, carry=%# x", ClientLogPrefix, v.GetDataTag(), v.GetCarriage())
|
||
if c.receiver == nil {
|
||
c.logger.Warnf("%sreceiver is nil", ClientLogPrefix)
|
||
} else {
|
||
c.setState(ConnStateBackFlow)
|
||
c.receiver(v)
|
||
}
|
||
}
|
||
default:
|
||
c.logger.Errorf("%sunknown signal", ClientLogPrefix)
|
||
}
|
||
}
|
||
}
|
||
|
||
// Close the client.
|
||
func (c *Client) Close() (err error) {
|
||
if c.conn != nil {
|
||
c.logger.Printf("%sclose the connection, name:%s, id:%s, addr:%s", ClientLogPrefix, c.name, c.clientId, c.conn.RemoteAddr().String())
|
||
}
|
||
if c.stream != nil {
|
||
err = c.stream.Close()
|
||
if err != nil {
|
||
c.logger.Errorf("%s stream.Close(): %v", ClientLogPrefix, err)
|
||
}
|
||
}
|
||
if c.conn != nil {
|
||
err = c.conn.CloseWithError(0, "client-ask-to-close-this-connection")
|
||
if err != nil {
|
||
c.logger.Errorf("%s connection.Close(): %v", ClientLogPrefix, err)
|
||
}
|
||
}
|
||
// close channel
|
||
c.mu.Lock()
|
||
if !c.closed {
|
||
close(c.errChan)
|
||
close(c.closeChan)
|
||
c.closed = true
|
||
}
|
||
c.mu.Unlock()
|
||
|
||
return err
|
||
}
|
||
|
||
// WriteFrame writes a frame to the connection, gurantee threadsafe.
|
||
func (c *Client) WriteFrame(frm frame.Frame) error {
|
||
// write on QUIC stream
|
||
if c.stream == nil {
|
||
return errors.New("stream is nil")
|
||
}
|
||
if c.state == ConnStateDisconnected || c.state == ConnStateRejected {
|
||
return fmt.Errorf("client connection state is %s", c.state)
|
||
}
|
||
c.logger.Debugf("%s[%s](%s)@%s WriteFrame() will write frame: %s", ClientLogPrefix, c.name, c.localAddr, c.state, frm.Type())
|
||
|
||
data := frm.Encode()
|
||
// emit raw bytes of Frame
|
||
c.mu.Lock()
|
||
n, err := c.stream.Write(data)
|
||
c.mu.Unlock()
|
||
c.logger.Debugf("%sWriteFrame() wrote n=%d, data=%# x", ClientLogPrefix, n, frame.Shortly(data))
|
||
if err != nil {
|
||
c.setState(ConnStateDisconnected)
|
||
// c.state = ConnStateDisconnected
|
||
if e, ok := err.(*quic.IdleTimeoutError); ok {
|
||
c.logger.Errorf("%sWriteFrame() connection timeout, err=%v", ClientLogPrefix, e)
|
||
} else {
|
||
c.logger.Errorf("%sWriteFrame() wrote error=%v", ClientLogPrefix, err)
|
||
return err
|
||
}
|
||
}
|
||
if n != len(data) {
|
||
err := errors.New("[client] hpds Client .Write() wrote error")
|
||
c.logger.Errorf("%s error:%v", ClientLogPrefix, err)
|
||
return err
|
||
}
|
||
return err
|
||
}
|
||
|
||
// update connection state
|
||
func (c *Client) setState(state ConnState) {
|
||
c.logger.Debugf("setState to:%s", state)
|
||
c.mu.Lock()
|
||
c.state = state
|
||
c.mu.Unlock()
|
||
}
|
||
|
||
// getState get connection state
|
||
func (c *Client) getState() ConnState {
|
||
c.mu.Lock()
|
||
defer c.mu.Unlock()
|
||
return c.state
|
||
}
|
||
|
||
// update connection local addr
|
||
func (c *Client) setLocalAddr(addr string) {
|
||
c.mu.Lock()
|
||
c.localAddr = addr
|
||
c.mu.Unlock()
|
||
}
|
||
|
||
// SetDataFrameObserver sets the data frame handler.
|
||
func (c *Client) SetDataFrameObserver(fn func(*frame.DataFrame)) {
|
||
c.processor = fn
|
||
c.logger.Debugf("%sSetDataFrameObserver(%v)", ClientLogPrefix, c.processor)
|
||
}
|
||
|
||
// SetBackFlowFrameObserver sets the backflow frame handler.
|
||
func (c *Client) SetBackFlowFrameObserver(fn func(*frame.BackFlowFrame)) {
|
||
c.receiver = fn
|
||
c.logger.Debugf("%sSetBackFlowFrameObserver(%v)", ClientLogPrefix, c.receiver)
|
||
}
|
||
|
||
// reconnect the connection between client and server.
|
||
func (c *Client) reconnect(ctx context.Context, addr string) {
|
||
t := time.NewTicker(1 * time.Second)
|
||
defer t.Stop()
|
||
|
||
for {
|
||
select {
|
||
case <-ctx.Done():
|
||
c.logger.Debugf("%s[%s](%s) context.Done()", ClientLogPrefix, c.name, c.localAddr)
|
||
return
|
||
case <-c.closeChan:
|
||
c.logger.Debugf("%s[%s](%s) close channel", ClientLogPrefix, c.name, c.localAddr)
|
||
return
|
||
case <-t.C:
|
||
if c.getState() == ConnStateDisconnected {
|
||
c.logger.Printf("%s[%s][%s](%s) is reconnecting to HPDS-MQ %s...", ClientLogPrefix, c.name, c.clientId, c.localAddr, addr)
|
||
err := c.connect(ctx, addr)
|
||
if err != nil {
|
||
c.logger.Errorf("%s[%s][%s](%s) reconnect error:%v", ClientLogPrefix, c.name, c.clientId, c.localAddr, err)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
func (c *Client) init() {
|
||
// // tracing
|
||
// _, _, err := tracing.NewTracerProvider(c.name)
|
||
// if err != nil {
|
||
// logger.Errorf("tracing: %v", err)
|
||
// }
|
||
}
|
||
|
||
// ServerAddr returns the address of the server.
|
||
func (c *Client) ServerAddr() string {
|
||
return c.addr
|
||
}
|
||
|
||
// initOptions init options defaults
|
||
func (c *Client) initOptions() error {
|
||
// logger
|
||
if c.logger == nil {
|
||
if c.opts.Logger != nil {
|
||
c.logger = c.opts.Logger
|
||
} else {
|
||
c.logger = log.Default()
|
||
}
|
||
}
|
||
// observe tag list
|
||
if c.opts.ObserveDataTags == nil {
|
||
c.opts.ObserveDataTags = make([]byte, 0)
|
||
}
|
||
// credential
|
||
if c.opts.Credential == nil {
|
||
c.opts.Credential = auth.NewCredential("")
|
||
}
|
||
// tls config
|
||
if c.opts.TLSConfig == nil {
|
||
tc, err := pkgtls.CreateClientTLSConfig()
|
||
if err != nil {
|
||
c.logger.Errorf("%sCreateClientTLSConfig: %v", ClientLogPrefix, err)
|
||
return err
|
||
}
|
||
c.opts.TLSConfig = tc
|
||
}
|
||
// quic config
|
||
if c.opts.QuicConfig == nil {
|
||
c.opts.QuicConfig = &quic.Config{
|
||
Versions: []quic.VersionNumber{quic.Version1, quic.VersionDraft29},
|
||
MaxIdleTimeout: time.Second * 40,
|
||
KeepAlivePeriod: time.Second * 20,
|
||
MaxIncomingStreams: 1000,
|
||
MaxIncomingUniStreams: 1000,
|
||
HandshakeIdleTimeout: time.Second * 3,
|
||
InitialStreamReceiveWindow: 1024 * 1024 * 2,
|
||
InitialConnectionReceiveWindow: 1024 * 1024 * 2,
|
||
TokenStore: quic.NewLRUTokenStore(10, 5),
|
||
DisablePathMTUDiscovery: true,
|
||
}
|
||
}
|
||
// credential
|
||
if c.opts.Credential != nil {
|
||
c.logger.Printf("%suse credential: [%s]", ClientLogPrefix, c.opts.Credential.Name())
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// SetObserveDataTags set the data tag list that will be observed.
|
||
// Deprecated: use hpds.WithObserveDataTags instead
|
||
func (c *Client) SetObserveDataTags(tag ...byte) {
|
||
c.opts.ObserveDataTags = append(c.opts.ObserveDataTags, tag...)
|
||
}
|
||
|
||
// Logger get client's logger instance, you can customize this using `hpds.WithLogger`
|
||
func (c *Client) Logger() log.Logger {
|
||
return c.logger
|
||
}
|
||
|
||
// SetErrorHandler set error handler
|
||
func (c *Client) SetErrorHandler(fn func(err error)) {
|
||
if fn != nil {
|
||
go func() {
|
||
err := <-c.errChan
|
||
if err != nil {
|
||
fn(err)
|
||
}
|
||
}()
|
||
}
|
||
}
|
||
|
||
// ClientId return the client ID
|
||
func (c *Client) ClientId() string {
|
||
return c.clientId
|
||
}
|