hpds_control_center/mq/index.go

459 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package mq
import (
"encoding/base64"
"encoding/json"
"fmt"
"git.hpds.cc/Component/logging"
"git.hpds.cc/Component/network/frame"
"github.com/google/uuid"
"go.uber.org/zap"
"hpds_control_center/config"
"hpds_control_center/internal/balance"
"hpds_control_center/internal/minio"
"hpds_control_center/internal/proto"
"hpds_control_center/model"
"os"
"strconv"
"strings"
"sync"
"time"
"git.hpds.cc/pavement/hpds_node"
)
// MqList holds the registered MQ endpoints; GetMqClient searches it by topic
// and MQ type.
var MqList []HpdsMqNode

// TaskList tracks per-task dispatch progress, keyed by task ID. Entries are
// added when a dataset task is fanned out and updated/removed as execution
// reports arrive.
var TaskList = make(map[int64]*TaskItem)
// HpdsMqNode describes one MQ endpoint built by NewMqClient.
type HpdsMqNode struct {
	// MqType is 2 for a stream function and 1 for an access point
	// (the values NewMqClient assigns).
	MqType uint
	// Topic is the configured function name; GetMqClient matches on it.
	Topic string
	// Node is the broker node configuration (host, port, token) used to connect.
	Node config.HpdsNode
	// EndPoint holds the client created for this topic: the result of
	// hpds_node.NewStreamFunction or hpds_node.NewAccessPoint, depending on MqType.
	EndPoint interface{}
	// Logger associated with this endpoint.
	Logger *logging.Logger
}
// TaskItem is the in-memory progress record of a dataset task stored in TaskList.
type TaskItem struct {
	// TaskId is the task ID (also the TaskList key).
	TaskId int64
	// TotalCount is the number of files dispatched for the task.
	TotalCount int64
	// CompletedCount counts execution reports with status 1.
	CompletedCount int64
	// FailingCount counts execution reports with status 2.
	FailingCount int64
	// UnfinishedCount starts at the file count and is decremented per execution report.
	UnfinishedCount int64
	// LastSendTime is a Unix timestamp in seconds: creation time or the time of
	// the last progress report.
	LastSendTime int64
}
// must terminates the process with exit status 1 when err is non-nil.
// The error is logged through logger when one is given; otherwise it is
// written to stderr, followed by a newline so it does not run into the
// shell prompt or the next log line.
func must(logger *logging.Logger, err error) {
	if err == nil {
		return
	}
	if logger != nil {
		logger.With(zap.String("web节点", "错误信息")).Error("启动错误", zap.Error(err))
	} else {
		_, _ = fmt.Fprintln(os.Stderr, err)
	}
	os.Exit(1)
}
// NewMqClient creates and connects one MQ endpoint per entry in funcs against
// the broker at node.Host:node.Port, authenticating with node.Token.
//
// Entries with MqType 2 become stream functions that observe v.DataTag. The
// "task-request", "task-response" and "task-execute-log" functions are bound
// to their handlers. Every other MqType becomes an access point that writes
// with v.DataTag.
//
// A failed connection or handler registration terminates the process via
// must. The returned error is therefore nil whenever the function returns.
func NewMqClient(funcs []config.FuncConfig, node config.HpdsNode, logger *logging.Logger) (mqList []HpdsMqNode, err error) {
	addr := fmt.Sprintf("%s:%d", node.Host, node.Port)
	mqList = make([]HpdsMqNode, 0, len(funcs))
	for _, v := range funcs {
		switch v.MqType {
		case 2:
			sf := hpds_node.NewStreamFunction(
				v.Name,
				hpds_node.WithMqAddr(addr),
				hpds_node.WithObserveDataTags(frame.Tag(v.DataTag)),
				hpds_node.WithCredential(node.Token),
			)
			err = sf.Connect()
			must(logger, err)
			// A stream function without a handler would silently drop its
			// topic's traffic, so registration failures are fatal like Connect.
			switch v.Name {
			case "task-request":
				err = sf.SetHandler(TaskRequestHandler)
			case "task-response":
				err = sf.SetHandler(TaskResponseHandler)
			case "task-execute-log":
				err = sf.SetHandler(TaskExecuteLogHandler)
			}
			must(logger, err)
			mqList = append(mqList, HpdsMqNode{
				MqType:   2,
				Topic:    v.Name,
				Node:     node,
				EndPoint: sf,
				Logger:   logger,
			})
		default:
			ap := hpds_node.NewAccessPoint(
				v.Name,
				hpds_node.WithMqAddr(addr),
				hpds_node.WithCredential(node.Token),
			)
			err = ap.Connect()
			must(logger, err)
			ap.SetDataTag(frame.Tag(v.DataTag))
			mqList = append(mqList, HpdsMqNode{
				MqType:   1,
				Topic:    v.Name,
				Node:     node,
				EndPoint: ap,
				Logger:   logger,
			})
		}
	}
	return mqList, err
}
// GetMqClient looks up the first entry of MqList registered for topic with the
// given MQ type. It returns a pointer to a copy of that entry, or nil when no
// entry matches.
func GetMqClient(topic string, mqType uint) *HpdsMqNode {
	for i := range MqList {
		if MqList[i].Topic != topic || MqList[i].MqType != mqType {
			continue
		}
		match := MqList[i]
		return &match
	}
	return nil
}
// GenerateAndSendData writes data to stream. On a write error it returns that
// error immediately. After a successful write it pauses for one second before
// returning nil, which throttles back-to-back sends.
func GenerateAndSendData(stream hpds_node.AccessPoint, data []byte) error {
	if _, writeErr := stream.Write(data); writeErr != nil {
		return writeErr
	}
	time.Sleep(time.Second)
	return nil
}
// TaskRequestHandler handles instructions received on the "task-request" topic.
//
// The frame body is decoded as an InstructionReq and dispatched on its Command:
//   - TaskAdd with a non-empty payload["subDataset"]: the task is forwarded to a
//     single execution node (see taskRequestDispatchSubDataset).
//   - TaskAdd otherwise, with a non-empty payload["datasetArr"]: one TaskExecute
//     instruction is sent per file of the listed datasets
//     (see taskRequestDispatchDatasets).
//   - ModelIssue: re-sent on "task-execute" as a ModelIssueRepeater instruction.
//   - ModelIssueResponse: the issue record for the model/node pair is inserted
//     or updated.
//
// A body that is not valid JSON yields tag 0x0B with the decode error text.
// Otherwise the command is returned as the tag with a nil body. Required
// payload fields are type-asserted without checks, so a payload missing them
// panics.
func TaskRequestHandler(data []byte) (frame.Tag, []byte) {
	cmd := new(InstructionReq)
	if err := json.Unmarshal(data, cmd); err != nil {
		return 0x0B, []byte(err.Error())
	}
	switch cmd.Command {
	case TaskAdd:
		payload := cmd.Payload.(map[string]interface{})
		if len(payload["subDataset"].(string)) > 0 {
			taskRequestDispatchSubDataset(payload)
		} else if len(payload["datasetArr"].(string)) > 0 {
			taskRequestDispatchDatasets(payload)
		}
	case ModelIssue:
		payload := cmd.Payload.(map[string]interface{})
		taskRequestPublish(&InstructionReq{
			Command: ModelIssueRepeater,
			Payload: payload,
		})
	case ModelIssueResponse:
		payload := cmd.Payload.(map[string]interface{})
		// JSON numbers decoded into interface{} are float64; asserting int64
		// directly would panic on every response.
		modelId := int64(payload["modelId"].(float64))
		nodeId := int64(payload["nodeId"].(float64))
		// Look up the existing issue record for this model/node pair.
		item := new(model.IssueModel)
		h, _ := model.DB.Where("model_id = ? and node_id = ?", modelId, nodeId).Get(item)
		pData, _ := json.Marshal(payload)
		now := time.Now().Unix()
		if h {
			item.Status = 1
			item.IssueResult = string(pData)
			item.UpdateAt = now
			_, _ = model.DB.ID(item.Id).AllCols().Update(item)
		} else {
			item.ModelId = modelId
			item.NodeId = nodeId
			item.Status = 1
			item.IssueResult = string(pData)
			item.CreateAt = now
			item.UpdateAt = now
			_, _ = model.DB.Insert(item)
		}
	default:
	}
	return frame.Tag(cmd.Command), nil
}

// taskRequestPublish JSON-encodes req and sends it through the "task-execute"
// access point. Nothing is sent when that client is not registered. Send
// errors are reported on stdout.
func taskRequestPublish(req *InstructionReq) {
	pData, _ := json.Marshal(req)
	cli := GetMqClient("task-execute", 1)
	if cli == nil {
		return
	}
	if err := GenerateAndSendData(cli.EndPoint.(hpds_node.AccessPoint), pData); err != nil {
		fmt.Println("send to task-execute failed", err)
	}
}

// taskRequestDispatchSubDataset forwards a single-dataset task as a TaskExecute
// instruction.
//
// When payload["nodeId"] is 0, an execution node is chosen first. It is picked
// among the nodes able to run the task's model (lightweight nodes for
// lightweight models), by weighted round robin when there are several, and the
// choice is stored via model.UpdateTaskExecuteNode. payload is modified in place
// with the selected node.
func taskRequestDispatchSubDataset(payload map[string]interface{}) {
	if payload["nodeId"].(float64) != 0 {
		// The task is already bound to a node; forward it unchanged.
		taskRequestPublish(&InstructionReq{Command: TaskExecute, Payload: payload})
		return
	}
	modelId := int64(payload["modelId"].(float64))
	m := model.GetModelById(modelId)
	// TODO: record the model distribution for the selected node.
	var nodeList []model.Node
	if m.IsLightWeight {
		nodeList = model.GetLightWeight(m.ModelId)
	} else {
		nodeList = model.GetAllNode(m.ModelId)
	}
	switch len(nodeList) {
	case 0:
		// No node can run the model; an empty non-nil list used to panic on nodeList[0].
		fmt.Println("no execution node available for model", modelId)
	case 1:
		nodeId := nodeList[0].NodeId
		payload["nodeId"] = nodeId
		// Attach the stored result of issuing this model to the node. When no
		// record exists, issue keeps its zero value.
		issue := new(model.IssueModel)
		_, _ = model.DB.Where("model_id=? and node_id =?", modelId, nodeId).Get(issue)
		payload["issueResult"] = issue.IssueResult
		taskRequestPublish(&InstructionReq{Command: TaskExecute, Payload: payload})
		model.UpdateTaskExecuteNode(int64(payload["taskId"].(float64)), nodeId)
	default:
		// Weighted round robin; the node weights come from model.GetNodeState
		// (intended to reflect CPU usage, memory usage and task execution state).
		lb := balance.LoadBalanceFactory(balance.LbWeightRoundRobin)
		for _, state := range model.GetNodeState(nodeList) {
			_ = lb.Add(state)
		}
		selected, _ := lb.Get(0)
		if selected == nil {
			// Dereferencing a nil selection used to panic here.
			fmt.Println("no execution node selected for model", modelId)
			return
		}
		payload["nodeId"] = selected.NodeId
		payload["nodeGuid"] = selected.NodeGuid
		taskRequestPublish(&InstructionReq{Command: TaskExecute, Payload: payload})
		model.UpdateTaskExecuteNode(int64(payload["taskId"].(float64)), selected.NodeId)
	}
}

// taskRequestDispatchDatasets fans a task out over every file of the datasets
// whose IDs are listed, comma separated, in payload["datasetArr"].
//
// It first registers a TaskItem in TaskList and records the initial progress.
// It then fetches each file from MinIO and sends it as its own TaskExecute
// instruction on "task-execute", with at most five fetch-and-send goroutines
// running at a time. It returns without waiting for the last goroutines to
// finish.
func taskRequestDispatchDatasets(payload map[string]interface{}) {
	datasetArr := strings.Split(payload["datasetArr"].(string), ",")
	fileList := make([]model.FileManager, 0)
	if err := model.DB.In("dataset_id", datasetArr).Find(&fileList); err != nil {
		fmt.Println("query dataset files failed", err)
	}
	taskId := int64(payload["taskId"].(float64))
	total := int64(len(fileList))
	// NOTE(review): TaskList is also mutated by TaskExecuteLogHandler; this write
	// is not synchronized with it.
	TaskList[taskId] = &TaskItem{
		TaskId:          taskId,
		TotalCount:      total,
		CompletedCount:  0,
		FailingCount:    0,
		UnfinishedCount: total,
		LastSendTime:    time.Now().Unix(),
	}
	// Persist the task total before any file is dispatched.
	model.UpdateTaskProgress(&proto.TaskLogProgress{
		PayloadType:     1,
		TaskId:          taskId,
		TotalCount:      total,
		CompletedCount:  0,
		FailingCount:    0,
		UnfinishedCount: total,
	})

	minioCli := minio.NewClient(config.Cfg.Minio.AccessKeyId, config.Cfg.Minio.SecretAccessKey, config.Cfg.Minio.Endpoint, false, logging.L())
	urlPrefix := fmt.Sprintf("%s://%s/", config.Cfg.Minio.Protocol, config.Cfg.Minio.Endpoint)
	sem := make(chan struct{}, 5)
	for _, file := range fileList {
		// Each goroutine gets its own shallow copy of payload. It adds per-file
		// keys ("single", "taskCode"), and writing them into one shared map from
		// several goroutines is a data race.
		filePayload := make(map[string]interface{}, len(payload)+2)
		for k, val := range payload {
			filePayload[k] = val
		}
		sem <- struct{}{}
		go func(fa model.FileManager, fp map[string]interface{}) {
			defer func() { <-sem }()
			// Turn the public access URL into the object key inside the bucket.
			dstPath := strings.Replace(fa.AccessUrl, urlPrefix, "", 1)
			dstPath = strings.Replace(dstPath, config.Cfg.Minio.Bucket, "", 1)
			imgByte, err := minioCli.GetObject(dstPath, config.Cfg.Minio.Bucket)
			if err != nil {
				// The file is still dispatched (with empty content) as before.
				// NOTE(review): confirm the executor tolerates an empty file.
				fmt.Println("fetch object from minio failed", dstPath, err)
			}
			taskCode, _ := uuid.NewUUID()
			fp["single"] = FileCapture{
				FileId:      fa.FileId,
				FileName:    fa.FileName,
				File:        base64.StdEncoding.EncodeToString(imgByte),
				DatasetName: fp["datasetName"].(string),
				CaptureTime: fa.CreateAt,
			}
			fp["taskCode"] = taskCode.String()
			deliver("task-execute", 1, &InstructionReq{
				Command: TaskExecute,
				Payload: fp,
			})
		}(file, filePayload)
	}
}
// TaskResponseHandler handles TaskResponse instructions received on the
// "task-response" topic. It performs three steps:
//   - stores the execution result as a model.TaskResult;
//   - updates the task progress through model.UpdateTaskProgressByLog;
//   - appends a human-readable model.TaskLog entry.
//
// Other commands are ignored.
//
// A body that is not valid JSON yields tag 0x0B with the decode error text.
// Otherwise the command is returned as the tag with a nil body.
func TaskResponseHandler(data []byte) (frame.Tag, []byte) {
	cmd := new(InstructionReq)
	if err := json.Unmarshal(data, cmd); err != nil {
		return 0x0B, []byte(err.Error())
	}
	switch cmd.Command {
	case TaskResponse:
		payload := cmd.Payload.(map[string]interface{})
		item := new(model.TaskResult)
		// Required fields: a missing or mistyped value panics.
		item.TaskId = int64(payload["taskId"].(float64))
		item.NodeId = int64(payload["nodeId"].(float64))
		item.ModelId = int64(payload["modelId"].(float64))
		item.StartTime = int64(payload["startTime"].(float64))
		item.FinishTime = int64(payload["finishTime"].(float64))
		item.Result = payload["body"].(string)
		// Optional fields: absent, null or mistyped values leave the zero value.
		if v, ok := payload["taskCode"].(string); ok {
			item.TaskCode = v
		}
		if v, ok := payload["fileId"].(float64); ok {
			item.FileId = int64(v)
		}
		if v, ok := payload["subDataset"].(string); ok {
			item.SubDataset = v
		}
		if v, ok := payload["srcPath"].(string); ok {
			item.SrcPath = v
		}
		// NOTE(review): TaskRequestHandler treats datasetArr as a comma-separated
		// ID list; for more than one ID, ParseInt fails and DatasetId stays 0.
		if v, ok := payload["datasetArr"].(string); ok {
			item.DatasetId, _ = strconv.ParseInt(v, 10, 64)
		}
		if _, err := model.DB.Insert(item); err != nil {
			fmt.Println("接收TaskResponse数据出错", err)
		}

		// Update the task progress. rat == 1 is reported as "all processed", so
		// rat is read as the completed fraction of the task.
		rat := model.UpdateTaskProgressByLog(item)
		var ratStr string
		switch {
		case rat > 0 && rat < 1:
			// Percentages with two decimals. The old "%2.f" had precision 0 and
			// printed fractions as 0 or 1, with processed/remaining swapped.
			ratStr = fmt.Sprintf("[已处理%.2f%%,剩余%.2f%%未处理]", rat*100, (1-rat)*100)
		case rat == 1:
			ratStr = "[已全部处理]"
		}

		finishedAt := time.Unix(item.FinishTime, 0).Format("2006-01-02 15:04:05")
		taskLog := new(model.TaskLog)
		taskLog.TaskId = item.TaskId
		taskLog.NodeId = item.NodeId
		if item.SrcPath != "" {
			taskLog.Content = fmt.Sprintf("[%s] %s 图片%s处理完成", finishedAt, ratStr, item.SrcPath)
		} else {
			taskLog.Content = fmt.Sprintf("[%s] %s", finishedAt, ratStr)
		}
		model.InsertLog(taskLog)
	default:
	}
	return frame.Tag(cmd.Command), nil
}
// deliver JSON-encodes payload and sends it through the access point registered
// for topic and mqType (see GetMqClient).
//
// The message is dropped, and the reason reported on stdout, instead of
// panicking in these cases:
//   - no such client is registered;
//   - the registered endpoint is not an access point;
//   - encoding the payload fails.
//
// Send errors are reported the same way.
func deliver(topic string, mqType uint, payload interface{}) {
	cli := GetMqClient(topic, mqType)
	if cli == nil {
		fmt.Println("deliver: no mq client registered for topic", topic)
		return
	}
	ap, ok := cli.EndPoint.(hpds_node.AccessPoint)
	if !ok {
		fmt.Println("deliver: mq client for topic", topic, "is not an access point")
		return
	}
	pData, err := json.Marshal(payload)
	if err != nil {
		fmt.Println("deliver: encode payload failed", topic, err)
		return
	}
	if err := GenerateAndSendData(ap, pData); err != nil {
		fmt.Println("deliver: send failed", topic, err)
	}
}
// taskExecuteLogMu serializes the read-modify-write of TaskList entries done
// by TaskExecuteLogHandler. It must be package-level: a mutex created inside
// each call excludes nothing.
var taskExecuteLogMu sync.Mutex

// taskProgressReportInterval is the minimum age of a task's last progress
// report before another intermediate report is sent. TaskItem.LastSendTime is
// stored in Unix seconds.
const taskProgressReportInterval = 5 * time.Second

// TaskExecuteLogHandler handles per-file execution reports received on the
// "task-execute-log" topic. Reports for tasks not present in TaskList are
// ignored.
//
// For a tracked task, it decrements the unfinished count and increments the
// completed (status 1) or failed (status 2) count. A progress snapshot is sent
// as a TaskLog instruction on "task-log" in either of these cases:
//   - no unfinished files remain (the task is then removed from TaskList);
//   - the previous report is older than taskProgressReportInterval.
//
// The report itself is then forwarded as a TaskLog payload on "task-log".
//
// A body that is not valid JSON yields tag 0x0B with the decode error text.
// Otherwise the command is returned as the tag with a nil body.
func TaskExecuteLogHandler(data []byte) (frame.Tag, []byte) {
	cmd := new(InstructionReq)
	if err := json.Unmarshal(data, cmd); err != nil {
		return 0x0B, []byte(err.Error())
	}
	payload := cmd.Payload.(map[string]interface{})
	taskId := int64(payload["taskId"].(float64))
	progress, tracked := taskExecuteLogRecord(taskId, payload)
	if !tracked {
		return frame.Tag(cmd.Command), nil
	}
	// Messages are sent outside the lock: every deliver waits about a second
	// (GenerateAndSendData sleeps after each write), which must not stall
	// other reports.
	if progress != nil {
		deliver("task-log", 1, &InstructionReq{
			Command: TaskLog,
			Payload: progress,
		})
	}
	taskLog := &proto.TaskLogPayload{
		PayloadType: 2,
		TaskId:      taskId,
		TaskCode:    payload["taskCode"].(string),
		NodeId:      int64(payload["nodeId"].(float64)),
		NodeGuid:    payload["nodeGuid"].(string),
		TaskContent: payload["taskContent"].(string),
		Status:      int(payload["status"].(float64)),
		EventTime:   int64(payload["eventTime"].(float64)),
	}
	deliver("task-log", 1, &InstructionReq{
		Command: TaskLog,
		Payload: taskLog,
	})
	return frame.Tag(cmd.Command), nil
}

// taskExecuteLogRecord applies one execution report to the TaskList entry for
// taskId while holding taskExecuteLogMu. tracked is false when the task is not
// in TaskList. progress is non-nil when a progress report is due. Finished
// tasks are then removed from TaskList; otherwise LastSendTime is refreshed.
func taskExecuteLogRecord(taskId int64, payload map[string]interface{}) (progress *proto.TaskLogProgress, tracked bool) {
	taskExecuteLogMu.Lock()
	defer taskExecuteLogMu.Unlock()
	item, ok := TaskList[taskId]
	if !ok {
		return nil, false
	}
	item.UnfinishedCount--
	switch int(payload["status"].(float64)) {
	case 1:
		item.CompletedCount++
	case 2:
		item.FailingCount++
	}
	if item.UnfinishedCount > 0 && time.Since(time.Unix(item.LastSendTime, 0)) <= taskProgressReportInterval {
		return nil, true
	}
	progress = &proto.TaskLogProgress{
		PayloadType:     1,
		TaskId:          item.TaskId,
		TotalCount:      item.TotalCount,
		CompletedCount:  item.CompletedCount,
		FailingCount:    item.FailingCount,
		UnfinishedCount: item.UnfinishedCount,
	}
	if item.UnfinishedCount <= 0 {
		delete(TaskList, item.TaskId)
	} else {
		item.LastSendTime = time.Now().Unix()
	}
	return progress, true
}