hpds_control_center/mq/index.go

1050 lines
31 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package mq
import (
"bufio"
"encoding/base64"
"encoding/json"
"fmt"
"git.hpds.cc/Component/logging"
"git.hpds.cc/Component/network/frame"
"github.com/google/uuid"
"go.uber.org/zap"
"golang.org/x/text/encoding/simplifiedchinese"
"hpds_control_center/config"
"hpds_control_center/internal/balance"
"hpds_control_center/internal/minio"
"hpds_control_center/internal/proto"
"hpds_control_center/model"
"hpds_control_center/pkg/utils"
"io"
"math"
"os"
"os/exec"
"path"
"strconv"
"strings"
"sync"
"time"
"git.hpds.cc/pavement/hpds_node"
)
// Charset names a text encoding used when decoding external command output.
type Charset string

const (
	// UTF8 marks input that is already valid UTF-8.
	UTF8 = Charset("UTF-8")
	// GB18030 marks simplified-Chinese console output that must be transcoded.
	GB18030 = Charset("GB18030")
)

var (
	// MqList holds every MQ endpoint created by NewMqClient; entries are
	// looked up by topic/type via GetMqClient.
	MqList []HpdsMqNode
	// TaskList tracks in-flight task progress, keyed by task id.
	// NOTE(review): written by TaskRequestHandler and mutated by
	// TaskExecuteLogHandler — confirm these run serialized, as there is no
	// effective shared lock guarding the map.
	TaskList = make(map[int64]*TaskItem)
)
// HpdsMqNode is one registered MQ endpoint: either an access point
// (MqType 1, EndPoint holds an hpds_node access point used for publishing)
// or a stream function (MqType 2, EndPoint holds a stream function with a
// frame handler attached).
type HpdsMqNode struct {
	MqType   uint            // 1 = access point (publisher), 2 = stream function (subscriber)
	Topic    string          // topic / function name this endpoint is bound to
	Node     config.HpdsNode // broker connection settings used to create the endpoint
	EndPoint interface{}     // concrete hpds_node endpoint; type depends on MqType
	Logger   *logging.Logger
}
// TaskItem accumulates per-task progress counters while a fanned-out task is
// executing; it lives in TaskList until UnfinishedCount reaches zero.
type TaskItem struct {
	TaskId          int64
	TotalCount      int64 // total number of files dispatched for the task
	CompletedCount  int64 // responses reporting success (status 1)
	FailingCount    int64 // responses reporting failure (status 2)
	UnfinishedCount int64 // responses still outstanding
	LastSendTime    int64 // unix seconds of the last progress broadcast
}
// must terminates the process when err is non-nil, reporting the error
// through the provided logger when available, otherwise on stderr.
// It is a no-op for a nil error.
func must(logger *logging.Logger, err error) {
	if err == nil {
		return
	}
	if logger == nil {
		_, _ = fmt.Fprint(os.Stderr, err)
	} else {
		logger.With(zap.String("web节点", "错误信息")).Error("启动错误", zap.Error(err))
	}
	os.Exit(1)
}
// NewMqClient creates one MQ endpoint per configured function and connects it
// to the broker described by node. MqType 2 entries become stream functions
// with a handler chosen by function name; everything else becomes an access
// point for publishing. A connection failure is fatal (see must).
//
// Fix: the default branch previously built the node entry before checking the
// Connect error; both branches now validate the connection immediately after
// Connect, matching the stream-function branch.
func NewMqClient(funcs []config.FuncConfig, node config.HpdsNode, logger *logging.Logger) (mqList []HpdsMqNode, err error) {
	mqList = make([]HpdsMqNode, 0, len(funcs))
	addr := fmt.Sprintf("%s:%d", node.Host, node.Port)
	for _, v := range funcs {
		switch v.MqType {
		case 2:
			// Stream function: subscribes to a data tag and handles frames.
			sf := hpds_node.NewStreamFunction(
				v.Name,
				hpds_node.WithMqAddr(addr),
				hpds_node.WithObserveDataTags(frame.Tag(v.DataTag)),
				hpds_node.WithCredential(node.Token),
			)
			err = sf.Connect()
			must(logger, err)
			// Attach the handler matching the function name; unknown names
			// get no handler and only relay.
			switch v.Name {
			case "task-request":
				_ = sf.SetHandler(TaskRequestHandler)
			case "task-response":
				_ = sf.SetHandler(TaskResponseHandler)
			case "task-execute-log":
				_ = sf.SetHandler(TaskExecuteLogHandler)
			}
			mqList = append(mqList, HpdsMqNode{
				MqType:   2,
				Topic:    v.Name,
				Node:     node,
				EndPoint: sf,
			})
		default:
			// Access point: used for publishing on the given data tag.
			ap := hpds_node.NewAccessPoint(
				v.Name,
				hpds_node.WithMqAddr(addr),
				hpds_node.WithCredential(node.Token),
			)
			err = ap.Connect()
			must(logger, err)
			ap.SetDataTag(frame.Tag(v.DataTag))
			mqList = append(mqList, HpdsMqNode{
				MqType:   1,
				Topic:    v.Name,
				Node:     node,
				EndPoint: ap,
			})
		}
	}
	return mqList, err
}
// GetMqClient returns the registered MQ endpoint matching topic and mqType,
// or nil when none matches.
//
// Fix: the original returned `&v`, the address of the range loop's copy, so
// any mutation made through the returned pointer was lost. Returning the
// address of the slice element aliases the real entry in MqList.
func GetMqClient(topic string, mqType uint) *HpdsMqNode {
	for i := range MqList {
		if MqList[i].Topic == topic && MqList[i].MqType == mqType {
			return &MqList[i]
		}
	}
	return nil
}
// GenerateAndSendData writes data to the given access point, then pauses for
// one second so that consecutive publishes are paced apart.
func GenerateAndSendData(stream hpds_node.AccessPoint, data []byte) error {
	if _, err := stream.Write(data); err != nil {
		return err
	}
	time.Sleep(time.Second)
	return nil
}
// TaskRequestHandler is the stream handler for the "task-request" topic.
// It decodes the frame into an InstructionReq and dispatches on the command:
//   - TaskAdd: schedules a recognition task — either forwards a sub-dataset
//     task as a single message, or fans a dataset task out file-by-file,
//     load-balancing onto a node when none was preselected;
//   - ModelIssue: relays a model-issue request to the execute topic;
//   - ModelIssueResponse: upserts the model-issue record;
//   - TrainTaskAdd: loads the training task and starts RunTraining.
//
// Returns the command value as the reply tag; a JSON decode failure is
// reported with tag 0x0B and the error text as payload.
//
// NOTE(review): payload fields are read with unchecked type assertions, so a
// frame whose payload does not match the expected shape panics the handler.
func TaskRequestHandler(data []byte) (frame.Tag, []byte) {
	cmd := new(InstructionReq)
	err := json.Unmarshal(data, cmd)
	if err != nil {
		return 0x0B, []byte(err.Error())
	}
	switch cmd.Command {
	case TaskAdd:
		payload := cmd.Payload.(map[string]interface{})
		if len(payload["subDataset"].(string)) > 0 {
			// Sub-dataset task: select a node if needed, forward one message.
			if payload["nodeId"].(float64) == 0 {
				// Assign an execution node according to the model's properties.
				m := model.GetModelById(int64(payload["modelId"].(float64)))
				var nodeList []model.Node
				// todo: a model-issue record should be added here
				if m.IsLightWeight {
					nodeList = model.GetLightWeight(m.ModelId)
				} else {
					nodeList = model.GetAllNode(m.ModelId)
				}
				if nodeList != nil {
					if len(nodeList) > 1 {
						// Weighted round-robin; the weight combines CPU load,
						// memory usage and task-execution state.
						list := model.GetNodeState(nodeList)
						lb := balance.LoadBalanceFactory(balance.LbWeightRoundRobin)
						for _, v := range list {
							_ = lb.Add(v)
						}
						nodeId, _ := lb.Get(0)
						if nodeId == nil {
							// todo: handle the case where no node was obtained
							// NOTE(review): nodeId is dereferenced right below,
							// so a nil result panics here.
						}
						payload["nodeId"] = nodeId.NodeId
						payload["nodeGuid"] = nodeId.NodeGuid
						cmd := &InstructionReq{
							Command: TaskExecute,
							Payload: payload,
						}
						pData, _ := json.Marshal(cmd)
						cli := GetMqClient("task-execute", 1)
						if cli != nil {
							_ = GenerateAndSendData(cli.EndPoint.(hpds_node.AccessPoint), pData)
						}
						model.UpdateTaskExecuteNode(int64(payload["taskId"].(float64)), nodeId.NodeId)
					} else {
						// Single candidate node: attach its recorded model-issue result.
						payload["nodeId"] = nodeList[0].NodeId
						issue := new(model.IssueModel)
						h, _ := model.DB.Where("model_id=? and node_id =?", int64(payload["modelId"].(float64)), nodeList[0].NodeId).Get(issue)
						if !h {
							// NOTE(review): a missing record is silently ignored;
							// issue.IssueResult stays the zero value below.
						}
						payload["issueResult"] = issue.IssueResult
						cmd := &InstructionReq{
							Command: TaskExecute,
							Payload: payload,
						}
						pData, _ := json.Marshal(cmd)
						cli := GetMqClient("task-execute", 1)
						if cli != nil {
							_ = GenerateAndSendData(cli.EndPoint.(hpds_node.AccessPoint), pData)
						}
						model.UpdateTaskExecuteNode(int64(payload["taskId"].(float64)), nodeList[0].NodeId)
					}
				} else {
					// No node can run this model; the request is dropped silently.
				}
			} else {
				// Node preselected by the caller: forward the task as-is.
				cmd := &InstructionReq{
					Command: TaskExecute,
					Payload: payload,
				}
				pData, _ := json.Marshal(cmd)
				cli := GetMqClient("task-execute", 1)
				if cli != nil {
					_ = GenerateAndSendData(cli.EndPoint.(hpds_node.AccessPoint), pData)
				}
			}
		} else {
			if len(payload["datasetArr"].(string)) > 0 {
				// Dataset task: fan out one message per file, at most 5 in flight.
				GoroutinueChan := make(chan bool, 5)
				if payload["nodeId"].(float64) == 0 {
					// Assign an execution node according to the model's properties.
					m := model.GetModelById(int64(payload["modelId"].(float64)))
					var nodeList []model.Node
					// todo: a model-issue record should be added here
					if m.IsLightWeight {
						nodeList = model.GetLightWeight(m.ModelId)
					} else {
						nodeList = model.GetAllNode(m.ModelId)
					}
					if nodeList != nil {
						if len(nodeList) > 1 {
							// Weighted round-robin; the weight combines CPU load,
							// memory usage and task-execution state.
							list := model.GetNodeState(nodeList)
							lb := balance.LoadBalanceFactory(balance.LbWeightRoundRobin)
							for _, v := range list {
								_ = lb.Add(v)
							}
							nodeId, _ := lb.Get(0)
							if nodeId == nil {
								// todo: handle the case where no node was obtained
							}
							payload["nodeId"] = nodeId.NodeId
							payload["nodeGuid"] = nodeId.NodeGuid
							cmd := &InstructionReq{
								Command: TaskExecute,
								Payload: payload,
							}
							pData, _ := json.Marshal(cmd)
							cli := GetMqClient("task-execute", 1)
							if cli != nil {
								_ = GenerateAndSendData(cli.EndPoint.(hpds_node.AccessPoint), pData)
							}
							model.UpdateTaskExecuteNode(int64(payload["taskId"].(float64)), nodeId.NodeId)
						} else {
							payload["nodeId"] = nodeList[0].NodeId
							issue := new(model.IssueModel)
							h, _ := model.DB.Where("model_id=? and node_id =?", int64(payload["modelId"].(float64)), nodeList[0].NodeId).Get(issue)
							if !h {
							}
							payload["issueResult"] = issue.IssueResult
							cmd := &InstructionReq{
								Command: TaskExecute,
								Payload: payload,
							}
							pData, _ := json.Marshal(cmd)
							cli := GetMqClient("task-execute", 1)
							if cli != nil {
								_ = GenerateAndSendData(cli.EndPoint.(hpds_node.AccessPoint), pData)
							}
							model.UpdateTaskExecuteNode(int64(payload["taskId"].(float64)), nodeList[0].NodeId)
						}
					} else {
					}
				} else {
					// Node preselected: only resolve its GUID for logging.
					node := model.GetNodeById(int64(payload["nodeId"].(float64)))
					if node != nil {
						payload["nodeGuid"] = node.NodeGuid
					}
				}
				// Dataset handling: resolve the file list of all requested datasets.
				datasetArr := strings.Split(payload["datasetArr"].(string), ",")
				//for _, val := range datasetArr {
				//	dId, err := strconv.ParseInt(val, 10, 64)
				//	if err != nil {
				//		continue
				//	}
				//	dt := new(model.Dataset)
				//	_, _ = model.DB.ID(dId).Get(dt)
				fileList := make([]model.FileManager, 0)
				err = model.DB.In("dataset_id", datasetArr).
					Find(&fileList)
				if err != nil {
					// NOTE(review): query errors are swallowed; fileList stays empty.
				}
				item := &TaskItem{
					TaskId:          int64(payload["taskId"].(float64)),
					TotalCount:      int64(len(fileList)),
					CompletedCount:  0,
					FailingCount:    0,
					UnfinishedCount: int64(len(fileList)),
					LastSendTime:    time.Now().Unix(),
				}
				TaskList[int64(payload["taskId"].(float64))] = item
				// Record the task total and persist the initial progress.
				taskProgress := &proto.TaskLogProgress{
					PayloadType:     1,
					TaskId:          int64(payload["taskId"].(float64)),
					TotalCount:      int64(len(fileList)),
					CompletedCount:  0,
					FailingCount:    0,
					UnfinishedCount: int64(len(fileList)),
				}
				model.UpdateTaskProgress(taskProgress)
				taskLog := &model.TaskLog{
					TaskId:   int64(payload["taskId"].(float64)),
					NodeId:   int64(payload["nodeId"].(float64)),
					Content:  fmt.Sprintf("[%s] 在节点[%s]上开始执行任务,任务数量共[%d]", time.Now().Format("2006-01-02 15:04:05"), payload["nodeGuid"].(string), taskProgress.TotalCount),
					CreateAt: time.Now().Unix(),
					UpdateAt: time.Now().Unix(),
				}
				model.InsertLog(taskLog)
				taskProgressCmd := &InstructionReq{
					Command: TaskLog,
					Payload: taskProgress,
				}
				deliver("task-log", 1, taskProgressCmd)
				// Per-file fan-out: fetch each file from MinIO, inline it
				// base64-encoded, and publish one task-execute message per file.
				minioCli := minio.NewClient(config.Cfg.Minio.AccessKeyId, config.Cfg.Minio.SecretAccessKey, config.Cfg.Minio.Endpoint, false, logging.L())
				for _, v := range fileList {
					GoroutinueChan <- true
					go func(fa model.FileManager, payload map[string]interface{}) {
						// Copy the payload so concurrent goroutines do not
						// mutate the shared map.
						p := make(map[string]interface{})
						for key, val := range payload {
							p[key] = val
						}
						// Strip the endpoint prefix and bucket to obtain the
						// object key within the bucket.
						dstPath := strings.Replace(fa.AccessUrl, fmt.Sprintf("%s://%s/", config.Cfg.Minio.Protocol, config.Cfg.Minio.Endpoint), "", 1)
						dstPath = strings.Replace(dstPath, config.Cfg.Minio.Bucket, "", 1)
						imgByte, _ := minioCli.GetObject(dstPath, config.Cfg.Minio.Bucket)
						fc := FileCapture{
							FileId:      fa.FileId,
							FileName:    fa.FileName,
							File:        base64.StdEncoding.EncodeToString(imgByte),
							DatasetName: p["datasetName"].(string),
							CaptureTime: fa.CreateAt,
						}
						p["single"] = fc
						taskCode, _ := uuid.NewUUID()
						p["taskCode"] = taskCode.String()
						cmd := &InstructionReq{
							Command: TaskExecute,
							Payload: p,
						}
						deliver("task-execute", 1, cmd)
						<-GoroutinueChan
					}(v, payload)
				}
				//}
			}
		}
	case ModelIssue:
		// Relay the model-issue request to the execute topic.
		payload := cmd.Payload.(map[string]interface{})
		cmd := &InstructionReq{
			Command: ModelIssueRepeater,
			Payload: payload,
		}
		pData, _ := json.Marshal(cmd)
		cli := GetMqClient("task-execute", 1)
		if cli != nil {
			_ = GenerateAndSendData(cli.EndPoint.(hpds_node.AccessPoint), pData)
		}
	case ModelIssueResponse:
		payload := cmd.Payload.(map[string]interface{})
		// Look up the existing issue record for this model/node pair.
		// NOTE(review): JSON numbers unmarshal as float64, so these .(int64)
		// assertions panic for JSON-decoded payloads — the insert branch
		// below correctly uses .(float64); confirm which shape arrives here.
		item := new(model.IssueModel)
		h, _ := model.DB.Where("model_id = ? and node_id = ?", payload["modelId"].(int64), payload["nodeId"].(int64)).Get(item)
		pData, _ := json.Marshal(payload)
		if h {
			item.Status = 1
			item.IssueResult = string(pData)
			item.UpdateAt = time.Now().Unix()
			_, _ = model.DB.ID(item.Id).AllCols().Update(item)
		} else {
			item.ModelId = int64(payload["modelId"].(float64))
			item.NodeId = int64(payload["nodeId"].(float64))
			item.Status = 1
			item.IssueResult = string(pData)
			item.CreateAt = time.Now().Unix()
			item.UpdateAt = time.Now().Unix()
			_, _ = model.DB.Insert(item)
		}
	//case TaskResponse:
	//	payload := cmd.Payload.(map[string]interface{})
	//	item := new(model.TaskResult)
	//	item.TaskId = int64(payload["taskId"].(float64))
	//	item.TaskCode = payload["taskCode"].(string)
	//	item.NodeId = int64(payload["nodeId"].(float64))
	//	item.ModelId = int64(payload["modelId"].(float64))
	//	item.StartTime = int64(payload["startTime"].(float64))
	//	item.FinishTime = int64(payload["finishTime"].(float64))
	//	item.SubDataset = payload["subDataset"].(string)
	//	item.DatasetId = int64(payload["datasetArr"].(float64))
	//	item.SrcPath = payload["srcPath"].(string)
	//	item.Result = payload["body"].(string)
	//	_, _ = model.DB.Insert(item)
	//	//fn := payload["fileName"].(string)
	//	//dec := base64.NewDecoder(base64.StdEncoding, strings.NewReader(payload["file"].(string)))
	case TrainTaskAdd:
		payload := cmd.Payload.(map[string]interface{})
		if itemId, ok := payload["taskId"].(float64); ok {
			item := new(model.TrainTask)
			h, err := model.DB.ID(int64(itemId)).Get(item)
			if err != nil || !h {
				// NOTE(review): lookup failures are ignored and RunTraining is
				// still invoked with a possibly zero-valued task.
			}
			RunTraining(item)
		}
	default:
	}
	return frame.Tag(cmd.Command), nil
}
// TaskResponseHandler is the stream handler for the "task-response" topic.
// For TaskResponse commands it materializes the payload into a
// model.TaskResult row, forwards it asynchronously to the project-result
// pipeline, refreshes the task's progress counters and appends a
// human-readable task log entry. Returns the command as the reply tag; JSON
// decode errors yield tag 0x0B with the error text.
func TaskResponseHandler(data []byte) (frame.Tag, []byte) {
	cmd := new(InstructionReq)
	err := json.Unmarshal(data, cmd)
	if err != nil {
		return 0x0B, []byte(err.Error())
	}
	switch cmd.Command {
	case TaskResponse:
		payload := cmd.Payload.(map[string]interface{})
		item := new(model.TaskResult)
		item.TaskId = int64(payload["taskId"].(float64))
		// Optional fields are presence-guarded; the unguarded assertions below
		// assume the sender always supplies taskId/nodeId/modelId/startTime/
		// finishTime/datasetArr.
		if _, ok := payload["taskCode"]; ok && payload["taskCode"] != nil {
			item.TaskCode = payload["taskCode"].(string)
		}
		if _, ok := payload["fileId"]; ok {
			item.FileId = int64(payload["fileId"].(float64))
		}
		item.NodeId = int64(payload["nodeId"].(float64))
		item.ModelId = int64(payload["modelId"].(float64))
		item.StartTime = int64(payload["startTime"].(float64))
		item.FinishTime = int64(payload["finishTime"].(float64))
		if _, ok := payload["subDataset"]; ok {
			item.SubDataset = payload["subDataset"].(string)
		}
		// NOTE(review): datasetArr can be a comma-separated list upstream; a
		// multi-id string fails ParseInt here and leaves DatasetId = 0.
		item.DatasetId, _ = strconv.ParseInt(payload["datasetArr"].(string), 10, 64)
		if _, ok := payload["srcPath"]; ok && payload["srcPath"] != nil {
			item.SrcPath = payload["srcPath"].(string)
		}
		if _, ok := payload["body"]; ok {
			item.Result = payload["body"].(string)
		}
		// code == 500 marks a node-side failure; the message replaces the result.
		isFailing := false
		if _, ok := payload["code"]; ok && int(payload["code"].(float64)) == 500 {
			item.Result = payload["msg"].(string)
			isFailing = true
		}
		_, err = model.DB.Insert(item)
		if err != nil {
			fmt.Println("接收TaskResponse数据出错", err)
		}
		// Project-result post-processing runs asynchronously.
		go processToProjectResult(item)
		// Update the running progress counters.
		processed, unProcessed := model.UpdateTaskProgressByLog(item, isFailing)
		var (
			ratStr string
		)
		if unProcessed > 0 {
			ratStr = fmt.Sprintf("[已处理[%d],剩余[%d]未处理]", processed, unProcessed)
		} else {
			ratStr = "[已全部处理]"
		}
		taskLog := new(model.TaskLog)
		taskLog.TaskId = item.TaskId
		taskLog.NodeId = item.NodeId
		if len(item.SrcPath) > 0 {
			taskLog.Content = fmt.Sprintf("[%s] 图片%s处理完成 %s ", time.Unix(item.FinishTime, 0).Format("2006-01-02 15:04:05"),
				item.SrcPath, ratStr)
		} else {
			taskLog.Content = fmt.Sprintf("[%s] %s", time.Unix(item.FinishTime, 0).Format("2006-01-02 15:04:05"),
				ratStr)
		}
		model.InsertLog(taskLog)
		//fn := payload["fileName"].(string)
		//dec := base64.NewDecoder(base64.StdEncoding, strings.NewReader(payload["file"].(string)))
	default:
	}
	return frame.Tag(cmd.Command), nil
}
// ModelResult carries only the status code of a model's JSON reply; it is
// used to sniff which concrete result format a payload holds before a full
// decode.
type ModelResult struct {
	Code int `json:"code"`
}

// InsigmaResult is the "Insigma" model reply: detected diseases plus the
// annotated image reference.
type InsigmaResult struct {
	Code          int            `json:"code"`
	NumOfDiseases int            `json:"num_of_diseases"`
	Diseases      []DiseasesInfo `json:"diseases"`
	Image         string         `json:"image"`
}

// DiseasesInfo describes one detected disease instance.
type DiseasesInfo struct {
	Id    int           `json:"id"`
	Type  string        `json:"type"`
	Level string        `json:"level"`
	Param DiseasesParam `json:"param"`
}

// DiseasesParam holds the measurements of a detected disease.
// MaxWidth arrives as a string and is parsed to float64 where needed.
type DiseasesParam struct {
	Length   float64 `json:"length"`
	Area     float64 `json:"area"`
	MaxWidth string  `json:"max_width"`
}

// LightweightResult is the lightweight model reply: crack/pothole flags plus
// the source and annotated image references.
type LightweightResult struct {
	Code       int    `json:"code"`
	Crack      bool   `json:"crack"`
	ImgDiscern string `json:"img_discern"`
	ImgSrc     string `json:"img_src"`
	Pothole    bool   `json:"pothole"`
}
// processToProjectResult turns a raw TaskResult into ProjectResult rows.
// It resolves the owning project through the result's dataset, derives
// milepost / direction / lane information from the source file name
// (per-business naming conventions for roads and tunnels), then decodes the
// model output — either a JSON array of result strings or a single result
// object in lightweight or Insigma format — and inserts the derived disease
// records.
func processToProjectResult(src *model.TaskResult) {
	project := new(model.Project)
	h, err := model.DB.Table("project").Alias("p").Join("inner", []string{"dataset", "d"}, "d.project_id= p.project_id").Where("d.dataset_id=?", src.DatasetId).Get(project)
	if !h {
		err = fmt.Errorf("未能找到对应的项目信息")
	}
	if err != nil {
		logging.L().With(zap.String("控制节点", "错误信息")).Error("获取项目信息", zap.Error(err))
		return
	}
	var (
		mr             ModelResult
		mrList         []string
		fileDiscern    string
		memo           string
		milepostNumber string
		upDown         string
		lineNum        int
		width          float64
	)
	switch project.BizType {
	case 1: // road
		// Space-separated name; arr[1] = offset, arr[2] = direction, arr[3] = lane.
		// NOTE(review): the len(arr) > 1 check does not guarantee arr[2],
		// which is accessed below — confirm the expected token count.
		arr := strings.Split(src.SrcPath, " ")
		if len(arr) > 1 {
			milepostNumber = GetMilepost(project.StartName, arr[1], arr[2])
			if arr[2] == "D" {
				upDown = "下行"
			} else {
				upDown = "上行"
			}
		}
		if len(arr) > 3 {
			lineNum, _ = strconv.Atoi(arr[3])
		}
	case 2: // bridge (no per-file metadata is parsed yet)
	case 3: // tunnel
		// Convention: name-direction(D/X)-camera(01-22)-sequence, five digits,
		// then "K" metre mark, e.g. DAXIASHAN-D-05-00003K15069.5.bmp
		arr := strings.Split(src.SrcPath, "K")
		if len(arr) > 1 {
			arrM := strings.Split(arr[1], ".")
			milepostNumber = meter2Milepost(arrM[0])
			arrD := strings.Split(arr[0], ".")
			if len(arrD) > 1 {
				if arrD[1] == "D" {
					upDown = "下行"
				} else {
					upDown = "上行"
				}
			}
			// NOTE(review): arrD is split on "." but indexed like the "-"
			// separated convention; confirm arrD[3] is ever populated.
			if len(arrD) > 4 {
				lineNum, _ = strconv.Atoi(arrD[3])
			}
		}
	}
	if len(src.Result) > 0 && src.Result[0] == '[' {
		// Array form: a JSON list of per-model result strings.
		mrList = make([]string, 0)
		if err := json.Unmarshal([]byte(src.Result), &mrList); err != nil {
			return
		}
		list := make([]*model.ProjectResult, 0)
		for _, str := range mrList {
			if err := json.Unmarshal([]byte(str), &mr); err != nil {
				continue
			}
			if mr.Code == 2001 {
				// 2001: Insigma result containing disease detections.
				ir := new(InsigmaResult)
				if err := json.Unmarshal([]byte(str), &ir); err != nil {
					continue
				}
				fileDiscern = ir.Image
				for key, value := range ir.Diseases {
					if len(value.Param.MaxWidth) > 0 && width == 0 {
						width, _ = strconv.ParseFloat(value.Param.MaxWidth, 64)
					} else {
						width = 0
					}
					memo = fmt.Sprintf("%d. 发现[%s],等级[%s],长度[%f],最大宽度[%s],面积[%f];\n", key+1, value.Type, value.Level, value.Param.Length, value.Param.MaxWidth, value.Param.Area)
					item := &model.ProjectResult{
						ProjectId:      project.ProjectId,
						SourceResultId: src.ResultId,
						MilepostNumber: milepostNumber,
						UpDown:         upDown,
						LineNum:        lineNum,
						DiseaseType:    value.Type,
						DiseaseLevel:   value.Level,
						Length:         value.Param.Length,
						Width:          width,
						Acreage:        value.Param.Area,
						Memo:           memo,
						Result:         fileDiscern,
						Creator:        0,
						Modifier:       0,
						CreateAt:       time.Now().Unix(),
						UpdateAt:       time.Now().Unix(),
					}
					list = append(list, item)
				}
			}
		}
		_, _ = model.DB.Insert(list)
	} else {
		// Single-object form: sniff the code, then decode the concrete type.
		if err := json.Unmarshal([]byte(src.Result), &mr); err != nil {
			return
		}
		switch mr.Code {
		case 0: // lightweight model reply
			lr := new(LightweightResult)
			if err := json.Unmarshal([]byte(src.Result), &lr); err != nil {
				return
			}
			if lr.Crack || lr.Pothole {
				if lr.Crack {
					memo = "检测到裂缝"
				} else {
					memo = "检测到坑洼"
				}
				fileDiscern = lr.ImgDiscern
				if len(fileDiscern) == 0 {
					fileDiscern = lr.ImgSrc
				}
				// The lightweight model reports no measurements, so the record
				// carries a fixed severity and a biz-type-specific disease name.
				diseaseLevelName := "重度"
				diseaseTypeName := ""
				switch project.BizType {
				case 2:
					diseaseTypeName = "结构裂缝"
				case 3:
					diseaseTypeName = "衬砌裂缝"
				default:
					diseaseTypeName = "横向裂缝"
				}
				item := &model.ProjectResult{
					ProjectId:      project.ProjectId,
					SourceResultId: src.ResultId,
					MilepostNumber: milepostNumber,
					UpDown:         upDown,
					LineNum:        lineNum,
					DiseaseType:    diseaseTypeName,
					DiseaseLevel:   diseaseLevelName,
					Length:         0,
					Width:          0,
					Acreage:        0,
					Memo:           memo,
					Result:         fileDiscern,
					Creator:        0,
					Modifier:       0,
					CreateAt:       time.Now().Unix(),
					UpdateAt:       time.Now().Unix(),
				}
				_, _ = model.DB.Insert(item)
			} else {
				fileDiscern = lr.ImgSrc
			}
			//
		case 2001: // Insigma reply with diseases present
			ir := new(InsigmaResult)
			if err := json.Unmarshal([]byte(src.Result), &ir); err != nil {
				return
			}
			fileDiscern = ir.Image
			list := make([]*model.ProjectResult, 0)
			for _, val := range ir.Diseases {
				if len(val.Param.MaxWidth) > 0 && width == 0 {
					width, _ = strconv.ParseFloat(val.Param.MaxWidth, 64)
				} else {
					width = 0
				}
				memo = fmt.Sprintf("发现[%s],等级[%s],长度[%f],最大宽度[%s],面积[%f];\n", val.Type, val.Level, val.Param.Length, val.Param.MaxWidth, val.Param.Area)
				maxWidth, _ := strconv.ParseFloat(val.Param.MaxWidth, 64)
				item := &model.ProjectResult{
					ProjectId:      project.ProjectId,
					SourceResultId: src.ResultId,
					MilepostNumber: milepostNumber,
					UpDown:         upDown,
					LineNum:        lineNum,
					DiseaseType:    val.Type,
					DiseaseLevel:   val.Level,
					Length:         val.Param.Length,
					Width:          maxWidth,
					Acreage:        val.Param.Area,
					Memo:           memo,
					Result:         fileDiscern,
					Creator:        0,
					Modifier:       0,
					CreateAt:       time.Now().Unix(),
					UpdateAt:       time.Now().Unix(),
				}
				list = append(list, item)
			}
			_, _ = model.DB.Insert(list)
		}
	}
}
// GetMilepost offsets a starting milepost string ("K<km>+<m>", e.g.
// "K10+500") by num kilometres and returns the resulting milepost string.
// upDown "D" (down direction) subtracts the offset; anything else adds it.
//
// Fix: the fractional kilometre must be scaled by 1000 to express metres.
// The original multiplied by 100, so "K10+500" plus a zero offset came back
// as "K10+50.00" and milepost values no longer round-tripped.
func GetMilepost(start, num, upDown string) string {
	arr := strings.Split(start, "+")
	var (
		kilometre, meter, milepost, counter, res, resMilepost, resMeter float64
	)
	if len(arr) == 1 {
		meter = 0
	} else {
		meter, _ = strconv.ParseFloat(arr[1], 64)
	}
	// Strip the leading kilometre marker (either case) before parsing.
	str := strings.NewReplacer("k", "", "K", "").Replace(arr[0])
	kilometre, _ = strconv.ParseFloat(str, 64)
	milepost = kilometre + meter/1000
	counter, _ = strconv.ParseFloat(num, 64)
	if upDown == "D" {
		res = milepost - counter
	} else {
		res = milepost + counter
	}
	resMilepost = math.Floor(res)
	resMeter = (res - resMilepost) * 1000
	return fmt.Sprintf("K%d+%.2f", int(resMilepost), resMeter)
}
// meter2Milepost converts a distance given in metres (optionally prefixed
// with "K") into a milepost string "K<km>+<m>".
//
// Fix: the remainder after removing whole kilometres is already in metres;
// the original's extra *100 turned e.g. "15069.5" into "K15+6950.00"
// instead of "K15+69.50".
func meter2Milepost(meter string) string {
	meter = strings.Replace(meter, "K", "", -1)
	m, _ := strconv.ParseFloat(meter, 64)
	resMilepost := math.Floor(m / 1000)
	resMeter := m - resMilepost*1000
	return fmt.Sprintf("K%d+%.2f", int(resMilepost), resMeter)
}
// deliver marshals payload and publishes it on the access point registered
// for (topic, mqType). The send is skipped when no matching client is
// registered or the payload cannot be marshaled.
//
// Fix: the original dereferenced GetMqClient's result unconditionally and
// panicked with a nil pointer whenever the topic was not registered.
func deliver(topic string, mqType uint, payload interface{}) {
	cli := GetMqClient(topic, mqType)
	if cli == nil {
		return
	}
	pData, err := json.Marshal(payload)
	if err != nil {
		return
	}
	_ = GenerateAndSendData(cli.EndPoint.(hpds_node.AccessPoint), pData)
}
// taskListMu guards TaskList mutations performed by TaskExecuteLogHandler.
// The original code locked a sync.Mutex declared inside the function, which
// serialized nothing because every invocation created its own lock.
var taskListMu sync.Mutex

// TaskExecuteLogHandler is the stream handler for the "task-execute-log"
// topic. It updates the in-memory progress counters for the referenced task,
// periodically broadcasts a TaskLogProgress snapshot (and removes the task
// once everything finished), and forwards the individual log entry as a
// TaskLogPayload on the "task-log" topic. Returns the command as reply tag;
// JSON decode errors yield tag 0x0B.
func TaskExecuteLogHandler(data []byte) (frame.Tag, []byte) {
	cmd := new(InstructionReq)
	err := json.Unmarshal(data, cmd)
	if err != nil {
		return 0x0B, []byte(err.Error())
	}
	// NOTE(review): payload fields are read with unchecked type assertions.
	payload := cmd.Payload.(map[string]interface{})
	taskListMu.Lock()
	defer taskListMu.Unlock()
	taskId := int64(payload["taskId"].(float64))
	if item, ok := TaskList[taskId]; ok {
		item.UnfinishedCount--
		// status 1 = completed, 2 = failed.
		switch int(payload["status"].(float64)) {
		case 1:
			item.CompletedCount++
		case 2:
			item.FailingCount++
		}
		// NOTE(review): LastSendTime is unix seconds, so this threshold is
		// ~83 minutes; if a 5-second resend window was intended the constant
		// should be 5.
		if item.UnfinishedCount <= 0 || time.Now().Unix()-item.LastSendTime > 5000 {
			// Broadcast the progress snapshot.
			taskProgress := &proto.TaskLogProgress{
				PayloadType:     1,
				TaskId:          item.TaskId,
				TotalCount:      item.TotalCount,
				CompletedCount:  item.CompletedCount,
				FailingCount:    item.FailingCount,
				UnfinishedCount: item.UnfinishedCount,
			}
			//model.UpdateTaskProgress(taskProgress)
			taskProgressCmd := &InstructionReq{
				Command: TaskLog,
				Payload: taskProgress,
			}
			deliver("task-log", 1, taskProgressCmd)
			if item.UnfinishedCount <= 0 {
				delete(TaskList, item.TaskId)
			} else {
				item.LastSendTime = time.Now().Unix()
			}
		}
		// Forward the individual log entry.
		taskLog := &proto.TaskLogPayload{
			PayloadType: 2,
			TaskId:      item.TaskId,
			TaskCode:    payload["taskCode"].(string),
			NodeId:      int64(payload["nodeId"].(float64)),
			NodeGuid:    payload["nodeGuid"].(string),
			TaskContent: payload["taskContent"].(string),
			Status:      int(payload["status"].(float64)),
			EventTime:   int64(payload["eventTime"].(float64)),
		}
		taskLogCmd := &InstructionReq{
			Command: TaskLog,
			Payload: taskLog,
		}
		deliver("task-log", 1, taskLogCmd)
	}
	return frame.Tag(cmd.Command), nil
}
// RunTraining prepares the on-disk training layout for task, launches the
// external training script, and persists the outcome (model path, metrics
// parsed from the log, status and the captured log text) back to the
// database. On dataset-lookup failure it skips execution and jumps straight
// to the result-persisting section.
//
// Scratch layout: <tmp>/<uuid>_<arch>_<batch>_<epochs>/{train/{0,1},
// val/{0,1}, test}; subdirectory "0" holds diseased samples, "1" clean ones.
func RunTraining(task *model.TrainTask) {
	var (
		args                               []string
		modelPath, modelFileName, testSize string
		// NOTE(review): modelAcc/modelLoss are never assigned, so the
		// TrainTaskResult row below always stores 0 — task.Loss/task.Accuracy
		// (parsed from the log) look like the intended values.
		modelAcc, modelLoss float64
	)
	fmt.Println("curr tmp dir====>>>>", config.Cfg.TmpTrainDir)
	modelFileName = utils.GetUUIDString()
	// Copy the training dataset into the scratch directory tree.
	tmpTrainDir := path.Join(config.Cfg.TmpTrainDir, fmt.Sprintf("%s_%s_%d_%d", modelFileName, task.Arithmetic, task.BatchSize, task.EpochsSize))
	fileList := make([]model.TrainingDatasetDetail, 0)
	_ = model.DB.Where("dataset_id = ?", task.TrainDatasetId).Find(&fileList)
	_ = os.MkdirAll(tmpTrainDir, os.ModePerm)
	_ = os.MkdirAll(path.Join(tmpTrainDir, "train"), os.ModePerm)
	_ = os.MkdirAll(path.Join(tmpTrainDir, "train", "0"), os.ModePerm)
	_ = os.MkdirAll(path.Join(tmpTrainDir, "train", "1"), os.ModePerm)
	_ = os.MkdirAll(path.Join(tmpTrainDir, "val"), os.ModePerm)
	_ = os.MkdirAll(path.Join(tmpTrainDir, "val", "0"), os.ModePerm)
	_ = os.MkdirAll(path.Join(tmpTrainDir, "val", "1"), os.ModePerm)
	_ = os.MkdirAll(path.Join(tmpTrainDir, "test"), os.ModePerm)
	for _, v := range fileList {
		dstFilePath := ""
		// CategoryId 2 = test split; everything else goes to train.
		switch v.CategoryId {
		case 2:
			dstFilePath = "test"
		default:
			dstFilePath = "train"
		}
		if v.CategoryId != 2 {
			// Class subdirectory: "0" = disease present, "1" = no disease.
			if v.IsDisease == 1 {
				dstFilePath = path.Join(tmpTrainDir, dstFilePath, "0")
			} else {
				dstFilePath = path.Join(tmpTrainDir, dstFilePath, "1")
			}
		} else {
			dstFilePath = path.Join(tmpTrainDir, dstFilePath)
		}
		err := utils.CopyFile(v.FilePath, path.Join(dstFilePath, v.FileName))
		if err != nil {
			fmt.Println("copy error: ", err)
		}
	}
	modelPath = path.Join(config.Cfg.ModelOutPath, fmt.Sprintf("%s_%s_%d_%d_%s", modelFileName, task.Arithmetic, task.BatchSize, task.EpochsSize, task.OutputType))
	_ = os.MkdirAll(modelPath, os.ModePerm)
	dt := new(model.TrainingDataset)
	_, err := model.DB.ID(task.TrainDatasetId).Get(dt)
	if err != nil {
		goto ReturnPoint
	}
	// NOTE(review): if ValidationNumber is an integer type this division
	// truncates before formatting — confirm the field's type.
	testSize = fmt.Sprintf("%.2f", dt.ValidationNumber/100)
	// Launch the external training command.
	args = []string{"--dataset=" + path.Join(tmpTrainDir, "train"),
		"--img_size=" + strconv.Itoa(task.ImageSize), "--batch_size=" + strconv.Itoa(task.BatchSize), "--test_size=" + testSize,
		"--epochs=" + strconv.Itoa(task.EpochsSize), "--model=" + task.Arithmetic, "--model_save=" + path.Join(modelPath, modelFileName+".h5"),
	}
	fmt.Println("args====>>>", args)
	err = ExecCommand(config.Cfg.TrainScriptPath, args, path.Join(modelPath, modelFileName+".log"), task.TaskId)
ReturnPoint:
	// Persist the training outcome; also reached directly on lookup failure.
	modelMetricsFile := modelFileName + "_model_metrics.png"
	task.FinishTime = time.Now().Unix()
	task.ModelFilePath = path.Join(modelPath, modelFileName+".h5")
	task.Loss = GetIndicatorByLog(path.Join(modelPath, modelFileName+".log"), "[INFO] Model Validation Loss:")
	task.Accuracy = GetIndicatorByLog(path.Join(modelPath, modelFileName+".log"), "[INFO] Model Validation Accuracy:")
	// Status 3 = finished, 5 = failed.
	task.Status = 3
	if err != nil {
		task.Status = 5
	}
	task.ModelFileMetricsPath = path.Join(modelPath, modelMetricsFile)
	_, _ = model.DB.ID(task.TaskId).AllCols().Update(task)
	if utils.PathExists(path.Join(modelPath, modelFileName+".log")) {
		logContext := utils.ReadFile(path.Join(modelPath, modelFileName+".log"))
		taskRes := new(model.TrainTaskResult)
		taskRes.TaskId = task.TaskId
		taskRes.CreateAt = time.Now().Unix()
		taskRes.Content = string(logContext)
		taskRes.Result = path.Join(modelPath, modelMetricsFile)
		taskRes.Accuracy = modelAcc
		taskRes.Loss = modelLoss
		c, err := model.DB.Insert(taskRes)
		if err != nil {
			fmt.Println("model.DB.Insert(taskRes) error ========>>>>>>", err)
		}
		fmt.Println("model.DB.Insert(taskRes) count ========>>>>>>", c)
	} else {
		fmt.Println("logContext========>>>>>>未读取")
	}
}
// GetIndicatorByLog scans a training log file line by line and returns the
// float value following the first line that contains indicator (e.g.
// "[INFO] Model Validation Loss:"). It returns 0 when the file cannot be
// opened, the indicator is absent, or the value fails to parse.
//
// Fix: the original ignored the os.Open error and scanned the (invalid)
// nil file handle; the open error is now checked up front.
func GetIndicatorByLog(logFileName, indicator string) float64 {
	logFn, err := os.Open(logFileName)
	if err != nil {
		return 0
	}
	defer func() {
		_ = logFn.Close()
	}()
	buf := bufio.NewReader(logFn)
	for {
		line, err := buf.ReadString('\n')
		if err != nil {
			if err != io.EOF {
				fmt.Println("Read file error!", err)
			}
			return 0
		}
		if strings.Contains(line, indicator) {
			// Strip the indicator prefix and surrounding whitespace, then parse.
			str := strings.Replace(line, indicator, "", -1)
			value, _ := strconv.ParseFloat(strings.TrimSpace(str), 64)
			return value
		}
	}
}
func ExecCommand(cmd string, args []string, logFileName string, taskId int64) (err error) {
logFile, _ := os.Create(logFileName)
defer func() {
_ = logFile.Close()
}()
fmt.Print("开始训练......")
c := exec.Command(cmd, args...) // mac or linux
stdout, err := c.StdoutPipe()
if err != nil {
return err
}
var (
wg sync.WaitGroup
)
wg.Add(1)
go func() {
defer wg.Done()
reader := bufio.NewReader(stdout)
var (
epoch int
//modelLoss, modelAcc float64
)
for {
readString, err := reader.ReadString('\n')
if err != nil || err == io.EOF {
fmt.Println("训练2===>>>", err)
//wg.Done()
return
}
byte2String := ConvertByte2String([]byte(readString), "GB18030")
_, _ = fmt.Fprint(logFile, byte2String)
if strings.Index(byte2String, "Epoch") >= 0 {
str := strings.Replace(byte2String, "Epoch ", "", -1)
arr := strings.Split(str, "/")
epoch, _ = strconv.Atoi(arr[0])
}
if strings.Index(byte2String, "- loss:") > 0 &&
strings.Index(byte2String, "- accuracy:") > 0 &&
strings.Index(byte2String, "- val_loss:") > 0 &&
strings.Index(byte2String, "- val_accuracy:") > 0 {
var (
loss, acc, valLoss, valAcc float64
)
arr := strings.Split(byte2String, "-")
for _, v := range arr {
if strings.Index(v, "loss:") > 0 && strings.Index(v, "val_loss:") < 0 {
strLoss := strings.Replace(v, " loss: ", "", -1)
loss, _ = strconv.ParseFloat(strings.Trim(strLoss, " "), 64)
}
if strings.Index(v, "accuracy:") > 0 && strings.Index(v, "val_accuracy:") < 0 {
strAcc := strings.Replace(v, " accuracy: ", "", -1)
acc, _ = strconv.ParseFloat(strings.Trim(strAcc, " "), 64)
}
if strings.Index(v, "val_loss:") > 0 {
strValLoss := strings.Replace(v, "val_loss: ", "", -1)
valLoss, _ = strconv.ParseFloat(strings.Trim(strValLoss, " "), 64)
}
if strings.Index(v, "val_accuracy:") > 0 {
strValAcc := strings.Replace(v, "val_accuracy: ", "", -1)
strValAcc = strings.Replace(strValAcc, "\n", "", -1)
valAcc, _ = strconv.ParseFloat(strings.Trim(strValAcc, " "), 64)
}
}
taskLog := new(model.TrainTaskLog)
taskLog.Epoch = epoch
taskLog.TaskId = taskId
taskLog.CreateAt = time.Now().Unix()
taskLog.Loss = loss
taskLog.Accuracy = acc
taskLog.ValLoss = valLoss
taskLog.ValAccuracy = valAcc
_, _ = model.DB.Insert(taskLog)
}
fmt.Print(byte2String)
}
}()
err = c.Start()
if err != nil {
fmt.Println("训练3===>>>", err)
}
wg.Wait()
return
}
// ConvertByte2String decodes raw bytes to a string according to charset.
// GB18030 input is transcoded to UTF-8; UTF8 and any unrecognized charset
// are converted directly (assumed to already be valid UTF-8).
//
// Fix: the original named its parameter `byte`, shadowing the builtin type.
func ConvertByte2String(data []byte, charset Charset) string {
	switch charset {
	case GB18030:
		// Decode errors are ignored; the decoder substitutes invalid bytes.
		decoded, _ := simplifiedchinese.GB18030.NewDecoder().Bytes(data)
		return string(decoded)
	case UTF8:
		fallthrough
	default:
		return string(data)
	}
}