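// network/client.go — 465 lines, 13 KiB (Go).
// Captured from a repository web view ("Raw | Normal View | History");
// blame timestamp at the top of the file: 2022-10-11 17:36:09 +08:00.
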
package network
import (
"context"
"errors"
"fmt"
"git.hpds.cc/Component/network/hpds_err"
"git.hpds.cc/Component/network/id"
pkgtls "git.hpds.cc/Component/network/tls"
"net"
"sync"
"time"
"github.com/lucas-clemente/quic-go"
"git.hpds.cc/Component/network/auth"
"git.hpds.cc/Component/network/frame"
"git.hpds.cc/Component/network/log"
)
// ClientOption configures a Client by mutating its ClientOptions.
// Options are applied in the order given, by Init.
type ClientOption func(*ClientOptions)
// ConnState describes the state of the connection. It is an alias for string;
// the ConnState* values used by Client (e.g. ConnStateConnected) are declared
// elsewhere in the package.
type ConnState = string
// Client is the abstraction of a HPDS-Client. A HPDS-Client can be a
// Protocol Gateway, Message Queue or StreamFunction.
//
// A Client owns one QUIC connection plus a single stream (opened by connect)
// that carries every frame in both directions. Create it with NewClient and
// dial with Connect.
type Client struct {
	name string // name of the client, sent in the handshake frame
	clientId string // id of the client, generated by id.New in NewClient and sent in the handshake frame
	clientType ClientType // type of the connection, sent as a byte in the handshake frame
	conn quic.Connection // current quic connection; replaced on every (re)connect
	stream quic.Stream // stream opened on conn; all frames are written to and read from it
	state ConnState // state of the connection; setState/getState access it under mu
	processor func(*frame.DataFrame) // callback for received DataFrames, set by SetDataFrameObserver
	receiver func(*frame.BackFlowFrame) // callback for received BackFlowFrames, set by SetBackFlowFrameObserver
	addr string // the address of the server connected to, set by connect
	mu sync.Mutex // held by setState/getState/setLocalAddr, by Close around closed, and by WriteFrame around stream writes
	opts ClientOptions // options applied by Init; unset fields get defaults in initOptions
	localAddr string // client local addr, it will be changed on reconnect
	logger log.Logger // opts.Logger if provided, otherwise log.Default() (see initOptions)
	errChan chan error // unbuffered; carries Rejected/Goaway messages to the handler installed by SetErrorHandler; closed by Close
	closeChan chan bool // closed by Close to stop the reconnect goroutine
	closed bool // true once Close has closed errChan and closeChan
}
// NewClient creates a new HPDS-Client.
func NewClient(appName string, connType ClientType, opts ...ClientOption) *Client {
c := &Client{
name: appName,
clientId: id.New(),
clientType: connType,
state: ConnStateReady,
opts: ClientOptions{},
errChan: make(chan error),
closeChan: make(chan bool),
}
2023-03-10 23:49:52 +08:00
_ = c.Init(opts...)
2022-10-11 17:36:09 +08:00
once.Do(func() {
c.init()
})
return c
}
// Init applies each option in opts, in order, to the client's options and
// then fills in defaults for anything still unset. It returns the error
// reported by initOptions, if any.
func (c *Client) Init(opts ...ClientOption) error {
	for i := range opts {
		apply := opts[i]
		apply(&c.opts)
	}
	return c.initOptions()
}
// Connect connects to HPDS-MessageQueue at addr.
//
// Before dialing, it starts a background goroutine (reconnect) that re-dials
// addr once per second while the connection state is ConnStateDisconnected.
// That goroutine stops when ctx is done or the client is closed. A connect
// attempt that fails while dialing or opening the stream leaves the state
// Disconnected, so it is retried automatically even if the caller ignores
// the returned error.
func (c *Client) Connect(ctx context.Context, addr string) error {
	// TODO: refactor this later as a Connection Manager
	go c.reconnect(ctx, addr)
	return c.connect(ctx, addr)
}
// connect dials addr over QUIC, opens the stream used for all frames, sends
// the handshake frame and starts the receive loop (handleFrame) in a new
// goroutine.
//
// State transitions: Connecting, then Disconnected if dialing or opening the
// stream fails (the reconnect loop retries on Disconnected), otherwise
// Authenticating, then Rejected if the handshake cannot be written, otherwise
// Connected. All state and local-address updates go through setState and
// setLocalAddr so they are made under c.mu.
func (c *Client) connect(ctx context.Context, addr string) error {
	c.addr = addr
	c.setState(ConnStateConnecting)

	// create quic connection
	conn, err := quic.DialAddrContext(ctx, addr, c.opts.TLSConfig, c.opts.QuicConfig)
	if err != nil {
		c.setState(ConnStateDisconnected)
		return err
	}

	// quic stream
	stream, err := conn.OpenStreamSync(ctx)
	if err != nil {
		// conn is never stored on c, so nothing else would ever close it.
		_ = conn.CloseWithError(hpds_err.To(hpds_err.ErrorCodeUnknown), err.Error())
		c.setState(ConnStateDisconnected)
		return err
	}
	c.stream = stream
	c.conn = conn
	c.setState(ConnStateAuthenticating)

	// send handshake
	handshake := frame.NewHandshakeFrame(
		c.name,
		c.clientId,
		byte(c.clientType),
		c.opts.ObserveDataTags,
		c.opts.Credential.Name(),
		c.opts.Credential.Payload(),
	)
	if err := c.WriteFrame(handshake); err != nil {
		c.setState(ConnStateRejected)
		return err
	}

	c.setState(ConnStateConnected)
	localAddr := conn.LocalAddr().String()
	c.setLocalAddr(localAddr)
	c.logger.Printf("%s [%s][%s](%s) is connected to HPDS-MQ %s", ClientLogPrefix, c.name, c.clientId, localAddr, addr)

	// receiving frames
	go c.handleFrame()
	return nil
}
// handleFrame handles the logic when receiving frame from server.
func (c *Client) handleFrame() {
// transform raw QUIC stream to wire format
fs := NewFrameStream(c.stream)
for {
c.logger.Debugf("%shandleFrame connection state=%v", ClientLogPrefix, c.state)
// this will block until a frame is received
f, err := fs.ReadFrame()
if err != nil {
2023-03-10 23:49:52 +08:00
defer func() {
_ = c.stream.Close()
}()
2022-10-11 17:36:09 +08:00
c.logger.Debugf("%shandleFrame(): %T | %v", ClientLogPrefix, err, err)
if e, ok := err.(*quic.IdleTimeoutError); ok {
c.logger.Errorf("%sconnection timeout, err=%v, mq addr=%s", ClientLogPrefix, e, c.addr)
c.setState(ConnStateDisconnected)
} else if e, ok := err.(*quic.ApplicationError); ok {
c.logger.Infof("%sapplication error, err=%v, errcode=%v", ClientLogPrefix, e, e.ErrorCode)
if hpds_err.Is(e.ErrorCode, hpds_err.ErrorCodeRejected) {
// if connection is rejected(eg: authenticate fails) from server
c.logger.Errorf("%sIllegal client, server rejected.", ClientLogPrefix)
c.setState(ConnStateRejected)
break
} else if hpds_err.Is(e.ErrorCode, hpds_err.ErrorCodeClientAbort) {
// client abort
c.logger.Infof("%sclient close the connection", ClientLogPrefix)
c.setState(ConnStateAborted)
break
} else if hpds_err.Is(e.ErrorCode, hpds_err.ErrorCodeGoaway) {
// server goaway
c.logger.Infof("%sserver goaway the connection", ClientLogPrefix)
c.setState(ConnStateGoaway)
break
} else if hpds_err.Is(e.ErrorCode, hpds_err.ErrorCodeHandshake) {
// handshake
c.logger.Errorf("%shandshake fails", ClientLogPrefix)
c.setState(ConnStateRejected)
break
}
} else if errors.Is(err, net.ErrClosed) {
// if client close the connection, net.ErrClosed will be raised
c.logger.Errorf("%sconnection is closed, err=%v", ClientLogPrefix, err)
c.setState(ConnStateDisconnected)
// by quic-go IdleTimeoutError after connection's KeepAlive config.
break
} else {
// any error occurred, we should close the stream
// after this, conn.AcceptStream() will raise the error
c.setState(ConnStateClosed)
2023-03-10 23:49:52 +08:00
_ = c.conn.CloseWithError(hpds_err.To(hpds_err.ErrorCodeUnknown), err.Error())
2022-10-11 17:36:09 +08:00
c.logger.Errorf("%sunknown error occurred, err=%v, state=%v", ClientLogPrefix, err, c.getState())
break
}
}
if f == nil {
break
}
// read frame
// first, get frame type
frameType := f.Type()
c.logger.Debugf("%stype=%s, frame=%# x", ClientLogPrefix, frameType, frame.Shortly(f.Encode()))
switch frameType {
case frame.TagOfHandshakeFrame:
if v, ok := f.(*frame.HandshakeFrame); ok {
c.logger.Debugf("%sreceive HandshakeFrame, name=%v", ClientLogPrefix, v.Name)
}
case frame.TagOfPongFrame:
c.setState(ConnStatePong)
case frame.TagOfAcceptedFrame:
c.setState(ConnStateAccepted)
case frame.TagOfRejectedFrame:
c.setState(ConnStateRejected)
if v, ok := f.(*frame.RejectedFrame); ok {
c.logger.Errorf("%s receive RejectedFrame, message=%s", ClientLogPrefix, v.Message())
2023-03-10 23:49:52 +08:00
_ = c.conn.CloseWithError(hpds_err.To(hpds_err.ErrorCodeRejected), v.Message())
2022-10-11 17:36:09 +08:00
c.errChan <- errors.New(v.Message())
break
}
case frame.TagOfGoawayFrame:
c.setState(ConnStateGoaway)
if v, ok := f.(*frame.GoawayFrame); ok {
c.logger.Errorf("%s receive GoawayFrame, message=%s", ClientLogPrefix, v.Message())
2023-03-10 23:49:52 +08:00
_ = c.conn.CloseWithError(hpds_err.To(hpds_err.ErrorCodeGoaway), v.Message())
2022-10-11 17:36:09 +08:00
c.errChan <- errors.New(v.Message())
break
}
case frame.TagOfDataFrame: // DataFrame carries user's data
if v, ok := f.(*frame.DataFrame); ok {
c.setState(ConnStateTransportData)
c.logger.Debugf("%sreceive DataFrame, tag=%#x, tid=%s, carry=%# x", ClientLogPrefix, v.GetDataTag(), v.TransactionId(), v.GetCarriage())
if c.processor == nil {
c.logger.Warnf("%sprocessor is nil", ClientLogPrefix)
} else {
c.processor(v)
}
}
case frame.TagOfBackFlowFrame:
if v, ok := f.(*frame.BackFlowFrame); ok {
c.logger.Debugf("%sreceive BackFlowFrame, tag=%#x, carry=%# x", ClientLogPrefix, v.GetDataTag(), v.GetCarriage())
if c.receiver == nil {
c.logger.Warnf("%sreceiver is nil", ClientLogPrefix)
} else {
c.setState(ConnStateBackFlow)
c.receiver(v)
}
}
default:
c.logger.Errorf("%sunknown signal", ClientLogPrefix)
}
}
}
// Close shuts the client down. It closes the QUIC stream, then closes the
// connection with application error code 0, then closes errChan and
// closeChan (which stops the reconnect goroutine). Closing the channels is
// done at most once, guarded by c.mu and c.closed.
//
// Failures from closing the stream or the connection are logged. The
// connection-close error is returned if there is one; otherwise the
// stream-close error is returned, so a stream failure is no longer lost when
// the connection closes cleanly.
func (c *Client) Close() (err error) {
	if c.conn != nil {
		c.logger.Printf("%sclose the connection, name:%s, id:%s, addr:%s", ClientLogPrefix, c.name, c.clientId, c.conn.RemoteAddr().String())
	}
	if c.stream != nil {
		if streamErr := c.stream.Close(); streamErr != nil {
			c.logger.Errorf("%s stream.Close(): %v", ClientLogPrefix, streamErr)
			err = streamErr
		}
	}
	if c.conn != nil {
		if connErr := c.conn.CloseWithError(0, "client-ask-to-close-this-connection"); connErr != nil {
			c.logger.Errorf("%s connection.Close(): %v", ClientLogPrefix, connErr)
			err = connErr
		}
	}
	// close channel
	c.mu.Lock()
	if !c.closed {
		close(c.errChan)
		close(c.closeChan)
		c.closed = true
	}
	c.mu.Unlock()
	return err
}
// WriteFrame encodes frm and writes it to the client's QUIC stream. Stream
// writes are serialized with c.mu, and the state is read through getState,
// so concurrent callers are safe.
//
// It returns an error when:
//   - there is no stream yet;
//   - the connection state is ConnStateDisconnected or ConnStateRejected;
//   - the stream write fails (including an idle timeout), in which case the
//     state becomes ConnStateDisconnected so the reconnect loop re-dials;
//   - fewer bytes than the encoded frame were written.
func (c *Client) WriteFrame(frm frame.Frame) error {
	// write on QUIC stream
	if c.stream == nil {
		return errors.New("stream is nil")
	}
	state := c.getState()
	if state == ConnStateDisconnected || state == ConnStateRejected {
		return fmt.Errorf("client connection state is %s", state)
	}
	c.logger.Debugf("%s[%s](%s)@%s WriteFrame() will write frame: %s", ClientLogPrefix, c.name, c.localAddr, state, frm.Type())
	data := frm.Encode()
	// emit raw bytes of Frame
	c.mu.Lock()
	n, err := c.stream.Write(data)
	c.mu.Unlock()
	c.logger.Debugf("%sWriteFrame() wrote n=%d, data=%# x", ClientLogPrefix, n, frame.Shortly(data))
	if err != nil {
		c.setState(ConnStateDisconnected)
		var idleErr *quic.IdleTimeoutError
		if errors.As(err, &idleErr) {
			c.logger.Errorf("%sWriteFrame() connection timeout, err=%v", ClientLogPrefix, idleErr)
		} else {
			c.logger.Errorf("%sWriteFrame() wrote error=%v", ClientLogPrefix, err)
		}
		// Always surface the real write error; an idle timeout must not be
		// reported as a short write (or as success).
		return err
	}
	if n != len(data) {
		err := errors.New("[client] hpds Client .Write() wrote error")
		c.logger.Errorf("%s error:%v", ClientLogPrefix, err)
		return err
	}
	return nil
}
// setState records state as the current connection state. The write is made
// while holding c.mu.
func (c *Client) setState(state ConnState) {
	c.logger.Debugf("setState to:%s", state)
	c.mu.Lock()
	defer c.mu.Unlock()
	c.state = state
}
// getState returns the current connection state, read while holding c.mu.
func (c *Client) getState() ConnState {
	c.mu.Lock()
	current := c.state
	c.mu.Unlock()
	return current
}
// setLocalAddr records addr as the client's local address while holding c.mu.
func (c *Client) setLocalAddr(addr string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.localAddr = addr
}
// SetDataFrameObserver registers fn as the callback the receive loop invokes
// for every DataFrame. A nil fn makes incoming DataFrames log a warning
// instead.
func (c *Client) SetDataFrameObserver(fn func(*frame.DataFrame)) {
	handler := fn
	c.processor = handler
	c.logger.Debugf("%sSetDataFrameObserver(%v)", ClientLogPrefix, handler)
}
// SetBackFlowFrameObserver registers fn as the callback the receive loop
// invokes for every BackFlowFrame. A nil fn makes incoming BackFlowFrames
// log a warning instead.
func (c *Client) SetBackFlowFrameObserver(fn func(*frame.BackFlowFrame)) {
	handler := fn
	c.receiver = handler
	c.logger.Debugf("%sSetBackFlowFrameObserver(%v)", ClientLogPrefix, handler)
}
// reconnect polls the connection state once per second and re-dials addr
// (via connect) whenever the state is ConnStateDisconnected. Dial failures
// are logged and retried on the next tick. It returns when ctx is done or
// when Close closes closeChan.
func (c *Client) reconnect(ctx context.Context, addr string) {
	ticker := time.NewTicker(time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			c.logger.Debugf("%s[%s](%s) context.Done()", ClientLogPrefix, c.name, c.localAddr)
			return
		case <-c.closeChan:
			c.logger.Debugf("%s[%s](%s) close channel", ClientLogPrefix, c.name, c.localAddr)
			return
		case <-ticker.C:
		}

		// Only a Disconnected client is re-dialed; Rejected/Goaway/etc. are left alone.
		if c.getState() != ConnStateDisconnected {
			continue
		}
		c.logger.Printf("%s[%s][%s](%s) is reconnecting to HPDS-MQ %s...", ClientLogPrefix, c.name, c.clientId, c.localAddr, addr)
		if err := c.connect(ctx, addr); err != nil {
			c.logger.Errorf("%s[%s][%s](%s) reconnect error:%v", ClientLogPrefix, c.name, c.clientId, c.localAddr, err)
		}
	}
}
// init is the hook for one-time setup. NewClient calls it through once.Do
// (presumably a package-level sync.Once — confirm), so only the first client
// created should trigger it. It is currently a no-op: the tracing setup below
// is commented out.
func (c *Client) init() {
	// // tracing
	// _, _, err := tracing.NewTracerProvider(c.name)
	// if err != nil {
	// logger.Errorf("tracing: %v", err)
	// }
}
// ServerAddr returns the server address most recently passed to connect
// (empty before the first Connect).
func (c *Client) ServerAddr() string {
	serverAddr := c.addr
	return serverAddr
}
// initOptions fills in a default for every option the caller left unset:
//   - logger: opts.Logger, else log.Default();
//   - observed data tags: an empty, non-nil list;
//   - credential: auth.NewCredential("");
//   - TLS config: pkgtls.CreateClientTLSConfig();
//   - QUIC config: the default settings below.
//
// If building the TLS config fails, the error is logged and returned, and
// the remaining defaults are not applied.
func (c *Client) initOptions() error {
	opts := &c.opts

	// logger
	if c.logger == nil {
		switch {
		case opts.Logger != nil:
			c.logger = opts.Logger
		default:
			c.logger = log.Default()
		}
	}

	// observe tag list
	if opts.ObserveDataTags == nil {
		opts.ObserveDataTags = []byte{}
	}

	// credential
	if opts.Credential == nil {
		opts.Credential = auth.NewCredential("")
	}

	// tls config
	if opts.TLSConfig == nil {
		tlsConfig, err := pkgtls.CreateClientTLSConfig()
		if err != nil {
			c.logger.Errorf("%sCreateClientTLSConfig: %v", ClientLogPrefix, err)
			return err
		}
		opts.TLSConfig = tlsConfig
	}

	// quic config
	if opts.QuicConfig == nil {
		opts.QuicConfig = &quic.Config{
			Versions:                       []quic.VersionNumber{quic.Version1, quic.VersionDraft29},
			MaxIdleTimeout:                 40 * time.Second,
			KeepAlivePeriod:                20 * time.Second,
			MaxIncomingStreams:             1000,
			MaxIncomingUniStreams:          1000,
			HandshakeIdleTimeout:           3 * time.Second,
			InitialStreamReceiveWindow:     2 * 1024 * 1024,
			InitialConnectionReceiveWindow: 2 * 1024 * 1024,
			TokenStore:                     quic.NewLRUTokenStore(10, 5),
			DisablePathMTUDiscovery:        true,
		}
	}

	if cred := opts.Credential; cred != nil {
		c.logger.Printf("%suse credential: [%s]", ClientLogPrefix, cred.Name())
	}
	return nil
}
// SetObserveDataTags appends tag to the list of data tags this client
// observes; the list is sent in the handshake frame on the next connect.
//
// Deprecated: use hpds.WithObserveDataTags instead
func (c *Client) SetObserveDataTags(tag ...byte) {
	observed := c.opts.ObserveDataTags
	observed = append(observed, tag...)
	c.opts.ObserveDataTags = observed
}
// Logger returns the client's logger instance. Customize it with
// `hpds.WithLogger`; otherwise it is log.Default().
func (c *Client) Logger() log.Logger {
	current := c.logger
	return current
}
// SetErrorHandler starts a goroutine that waits for one error on errChan
// (e.g. the message of a RejectedFrame or GoawayFrame) and passes it to fn.
// The goroutine exits without calling fn if errChan is closed by Close.
// A nil fn is ignored.
func (c *Client) SetErrorHandler(fn func(err error)) {
	if fn == nil {
		return
	}
	go func() {
		if err := <-c.errChan; err != nil {
			fn(err)
		}
	}()
}
// ClientId returns the client ID generated by NewClient.
func (c *Client) ClientId() string {
	id := c.clientId
	return id
}