package mq

import (
	"bufio"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"math"
	"os"
	"os/exec"
	"path"
	"strconv"
	"strings"
	"sync"
	"time"

	"git.hpds.cc/Component/logging"
	"git.hpds.cc/Component/network/frame"
	"git.hpds.cc/pavement/hpds_node"
	"github.com/google/uuid"
	"go.uber.org/zap"
	"golang.org/x/text/encoding/simplifiedchinese"

	"hpds_control_center/config"
	"hpds_control_center/internal/balance"
	"hpds_control_center/internal/minio"
	"hpds_control_center/internal/proto"
	"hpds_control_center/model"
	"hpds_control_center/pkg/utils"
)

type Charset string

const (
	UTF8    = Charset("UTF-8")
	GB18030 = Charset("GB18030")
)

var (
	MqList   []HpdsMqNode
	TaskList = make(map[int64]*TaskItem)
	// taskListMu guards concurrent access to TaskList from the MQ handlers.
	taskListMu sync.Mutex
)

type HpdsMqNode struct {
	MqType   uint
	Topic    string
	Node     config.HpdsNode
	EndPoint interface{}
	Logger   *logging.Logger
}

type TaskItem struct {
	TaskId          int64
	TotalCount      int64
	CompletedCount  int64
	FailingCount    int64
	UnfinishedCount int64
	LastSendTime    int64
}

// must logs the error (if a logger is available) and exits the process.
func must(logger *logging.Logger, err error) {
	if err != nil {
		if logger != nil {
			logger.With(zap.String("web节点", "错误信息")).Error("启动错误", zap.Error(err))
		} else {
			_, _ = fmt.Fprint(os.Stderr, err)
		}
		os.Exit(1)
	}
}

// NewMqClient connects every configured function to the MQ node: MqType 2 becomes a stream
// function with a topic-specific handler, anything else becomes an access point.
func NewMqClient(funcs []config.FuncConfig, node config.HpdsNode, logger *logging.Logger) (mqList []HpdsMqNode, err error) {
	mqList = make([]HpdsMqNode, 0)
	for _, v := range funcs {
		switch v.MqType {
		case 2:
			sf := hpds_node.NewStreamFunction(
				v.Name,
				hpds_node.WithMqAddr(fmt.Sprintf("%s:%d", node.Host, node.Port)),
				hpds_node.WithObserveDataTags(frame.Tag(v.DataTag)),
				hpds_node.WithCredential(node.Token),
			)
			err = sf.Connect()
			must(logger, err)
			nodeInfo := HpdsMqNode{
				MqType:   2,
				Topic:    v.Name,
				Node:     node,
				EndPoint: sf,
			}
			switch v.Name {
			case "task-request":
				_ = sf.SetHandler(TaskRequestHandler)
			case "task-response":
				_ = sf.SetHandler(TaskResponseHandler)
			case "task-execute-log":
				_ = sf.SetHandler(TaskExecuteLogHandler)
			default:
			}
			mqList = append(mqList, nodeInfo)
		default:
			ap := hpds_node.NewAccessPoint(
				v.Name,
				hpds_node.WithMqAddr(fmt.Sprintf("%s:%d", node.Host, node.Port)),
				hpds_node.WithCredential(node.Token),
			)
			err = ap.Connect()
			must(logger, err)
			nodeInfo := HpdsMqNode{
				MqType:   1,
				Topic:    v.Name,
				Node:     node,
				EndPoint: ap,
			}
			ap.SetDataTag(frame.Tag(v.DataTag))
			mqList = append(mqList, nodeInfo)
		}
	}
	return mqList, err
}

// GetMqClient returns the client registered for the given topic and mqType, or nil.
func GetMqClient(topic string, mqType uint) *HpdsMqNode {
	for _, v := range MqList {
		if v.Topic == topic && v.MqType == mqType {
			return &v
		}
	}
	return nil
}

// GenerateAndSendData writes data to the access point and then pauses for one second.
func GenerateAndSendData(stream hpds_node.AccessPoint, data []byte) error {
	_, err := stream.Write(data)
	if err != nil {
		return err
	}
	time.Sleep(1000 * time.Millisecond)
	return nil
}

// TaskRequestHandler handles task scheduling requests coming from the web side.
func TaskRequestHandler(data []byte) (frame.Tag, []byte) {
	cmd := new(InstructionReq)
	err := json.Unmarshal(data, cmd)
	if err != nil {
		return 0x0B, []byte(err.Error())
	}
	switch cmd.Command {
	case TaskAdd:
		payload := cmd.Payload.(map[string]interface{})
		if len(payload["subDataset"].(string)) > 0 {
			if payload["nodeId"].(float64) == 0 {
				// Pick an execution node based on the model's business attributes.
				m := model.GetModelById(int64(payload["modelId"].(float64)))
				var nodeList []model.Node
				// TODO: record the model issuance.
				if m.IsLightWeight {
					nodeList = model.GetLightWeight(m.ModelId)
				} else {
					nodeList = model.GetAllNode(m.ModelId)
				}
				if nodeList != nil {
					if len(nodeList) > 1 {
						// Weighted round-robin selection; the weight combines CPU usage,
						// memory usage and task execution state.
						list := model.GetNodeState(nodeList)
						lb := balance.LoadBalanceFactory(balance.LbWeightRoundRobin)
						for _, v := range list {
							_ = lb.Add(v)
						}
						nodeId, _ := lb.Get(0)
						if nodeId == nil {
							// TODO: handle the case where no node could be selected.
						}
						payload["nodeId"] = nodeId.NodeId
						payload["nodeGuid"] = nodeId.NodeGuid
						cmd := &InstructionReq{
							Command: TaskExecute,
							Payload: payload,
						}
						pData, _ := json.Marshal(cmd)
						cli := GetMqClient("task-execute", 1)
						if cli != nil {
							_ = GenerateAndSendData(cli.EndPoint.(hpds_node.AccessPoint), pData)
						}
						model.UpdateTaskExecuteNode(int64(payload["taskId"].(float64)), nodeId.NodeId)
					} else {
						payload["nodeId"] = nodeList[0].NodeId
						issue := new(model.IssueModel)
						h, _ := model.DB.Where("model_id=? and node_id =?", int64(payload["modelId"].(float64)), nodeList[0].NodeId).Get(issue)
						if !h {
						}
						payload["issueResult"] = issue.IssueResult
						cmd := &InstructionReq{
							Command: TaskExecute,
							Payload: payload,
						}
						pData, _ := json.Marshal(cmd)
						cli := GetMqClient("task-execute", 1)
						if cli != nil {
							_ = GenerateAndSendData(cli.EndPoint.(hpds_node.AccessPoint), pData)
						}
						model.UpdateTaskExecuteNode(int64(payload["taskId"].(float64)), nodeList[0].NodeId)
					}
				} else {
				}
			} else {
				cmd := &InstructionReq{
					Command: TaskExecute,
					Payload: payload,
				}
				pData, _ := json.Marshal(cmd)
				cli := GetMqClient("task-execute", 1)
				if cli != nil {
					_ = GenerateAndSendData(cli.EndPoint.(hpds_node.AccessPoint), pData)
				}
			}
		} else {
			if len(payload["datasetArr"].(string)) > 0 {
				GoroutinueChan := make(chan bool, 5)
				if payload["nodeId"].(float64) == 0 {
					// Pick an execution node based on the model's business attributes.
					m := model.GetModelById(int64(payload["modelId"].(float64)))
					var nodeList []model.Node
					// TODO: record the model issuance.
					if m.IsLightWeight {
						nodeList = model.GetLightWeight(m.ModelId)
					} else {
						nodeList = model.GetAllNode(m.ModelId)
					}
					if nodeList != nil {
						if len(nodeList) > 1 {
							// Weighted round-robin selection; the weight combines CPU usage,
							// memory usage and task execution state.
							list := model.GetNodeState(nodeList)
							lb := balance.LoadBalanceFactory(balance.LbWeightRoundRobin)
							for _, v := range list {
								_ = lb.Add(v)
							}
							nodeId, _ := lb.Get(0)
							if nodeId == nil {
								// TODO: handle the case where no node could be selected.
							}
							payload["nodeId"] = nodeId.NodeId
							payload["nodeGuid"] = nodeId.NodeGuid
							cmd := &InstructionReq{
								Command: TaskExecute,
								Payload: payload,
							}
							pData, _ := json.Marshal(cmd)
							cli := GetMqClient("task-execute", 1)
							if cli != nil {
								_ = GenerateAndSendData(cli.EndPoint.(hpds_node.AccessPoint), pData)
							}
							model.UpdateTaskExecuteNode(int64(payload["taskId"].(float64)), nodeId.NodeId)
						} else {
							payload["nodeId"] = nodeList[0].NodeId
							issue := new(model.IssueModel)
							h, _ := model.DB.Where("model_id=? and node_id =?", int64(payload["modelId"].(float64)), nodeList[0].NodeId).Get(issue)
							if !h {
							}
							payload["issueResult"] = issue.IssueResult
							cmd := &InstructionReq{
								Command: TaskExecute,
								Payload: payload,
							}
							pData, _ := json.Marshal(cmd)
							cli := GetMqClient("task-execute", 1)
							if cli != nil {
								_ = GenerateAndSendData(cli.EndPoint.(hpds_node.AccessPoint), pData)
							}
							model.UpdateTaskExecuteNode(int64(payload["taskId"].(float64)), nodeList[0].NodeId)
						}
					} else {
					}
				} else {
					node := model.GetNodeById(int64(payload["nodeId"].(float64)))
					if node != nil {
						payload["nodeGuid"] = node.NodeGuid
					}
				}
				// Dataset handling.
				datasetArr := strings.Split(payload["datasetArr"].(string), ",")
				//for _, val := range datasetArr {
				//	dId, err := strconv.ParseInt(val, 10, 64)
				//	if err != nil {
				//		continue
				//	}
				//	dt := new(model.Dataset)
				//	_, _ = model.DB.ID(dId).Get(dt)
				fileList := make([]model.FileManager, 0)
				err = model.DB.In("dataset_id", datasetArr).Find(&fileList)
				if err != nil {
				}
				item := &TaskItem{
					TaskId:          int64(payload["taskId"].(float64)),
					TotalCount:      int64(len(fileList)),
					CompletedCount:  0,
					FailingCount:    0,
					UnfinishedCount: int64(len(fileList)),
					LastSendTime:    time.Now().Unix(),
				}
				TaskList[int64(payload["taskId"].(float64))] = item
				// Record the total number of sub-tasks and persist the initial progress.
				taskProgress := &proto.TaskLogProgress{
					PayloadType:     1,
					TaskId:          int64(payload["taskId"].(float64)),
					TotalCount:      int64(len(fileList)),
					CompletedCount:  0,
					FailingCount:    0,
					UnfinishedCount: int64(len(fileList)),
				}
				model.UpdateTaskProgress(taskProgress)
				taskLog := &model.TaskLog{
					TaskId:   int64(payload["taskId"].(float64)),
					NodeId:   int64(payload["nodeId"].(float64)),
					Content:  fmt.Sprintf("[%s] 在节点[%s]上开始执行任务,任务数量共[%d]", time.Now().Format("2006-01-02 15:04:05"), payload["nodeGuid"].(string), taskProgress.TotalCount),
					CreateAt: time.Now().Unix(),
					UpdateAt: time.Now().Unix(),
				}
				model.InsertLog(taskLog)
				taskProgressCmd := &InstructionReq{
					Command: TaskLog,
					Payload: taskProgress,
				}
				deliver("task-log", 1, taskProgressCmd)
				// Pull each dataset file from MinIO and dispatch it to the executing node,
				// with at most five transfers in flight at a time.
				minioCli := minio.NewClient(config.Cfg.Minio.AccessKeyId, config.Cfg.Minio.SecretAccessKey, config.Cfg.Minio.Endpoint, false, logging.L())
				for _, v := range fileList {
					GoroutinueChan <- true
					go func(fa model.FileManager, payload map[string]interface{}) {
						p := make(map[string]interface{})
						for key, val := range payload {
							p[key] = val
						}
						dstPath := strings.Replace(fa.AccessUrl, fmt.Sprintf("%s://%s/", config.Cfg.Minio.Protocol, config.Cfg.Minio.Endpoint), "", 1)
						dstPath = strings.Replace(dstPath, config.Cfg.Minio.Bucket, "", 1)
						imgByte, _ := minioCli.GetObject(dstPath, config.Cfg.Minio.Bucket)
						fc := FileCapture{
							FileId:      fa.FileId,
							FileName:    fa.FileName,
							File:        base64.StdEncoding.EncodeToString(imgByte),
							DatasetName: p["datasetName"].(string),
							CaptureTime: fa.CreateAt,
						}
						p["single"] = fc
						taskCode, _ := uuid.NewUUID()
						p["taskCode"] = taskCode.String()
						cmd := &InstructionReq{
							Command: TaskExecute,
							Payload: p,
						}
						deliver("task-execute", 1, cmd)
						<-GoroutinueChan
					}(v, payload)
				}
				//}
			}
		}
	case ModelIssue:
		payload := cmd.Payload.(map[string]interface{})
		cmd := &InstructionReq{
			Command: ModelIssueRepeater,
			Payload: payload,
		}
		pData, _ := json.Marshal(cmd)
		cli := GetMqClient("task-execute", 1)
		if cli != nil {
			_ = GenerateAndSendData(cli.EndPoint.(hpds_node.AccessPoint), pData)
		}
	case ModelIssueResponse:
		payload := cmd.Payload.(map[string]interface{})
		// Look up the existing issuance record.
		item := new(model.IssueModel)
and node_id = ?", payload["modelId"].(int64), payload["nodeId"].(int64)).Get(item) pData, _ := json.Marshal(payload) if h { item.Status = 1 item.IssueResult = string(pData) item.UpdateAt = time.Now().Unix() _, _ = model.DB.ID(item.Id).AllCols().Update(item) } else { item.ModelId = int64(payload["modelId"].(float64)) item.NodeId = int64(payload["nodeId"].(float64)) item.Status = 1 item.IssueResult = string(pData) item.CreateAt = time.Now().Unix() item.UpdateAt = time.Now().Unix() _, _ = model.DB.Insert(item) } //case TaskResponse: // payload := cmd.Payload.(map[string]interface{}) // item := new(model.TaskResult) // item.TaskId = int64(payload["taskId"].(float64)) // item.TaskCode = payload["taskCode"].(string) // item.NodeId = int64(payload["nodeId"].(float64)) // item.ModelId = int64(payload["modelId"].(float64)) // item.StartTime = int64(payload["startTime"].(float64)) // item.FinishTime = int64(payload["finishTime"].(float64)) // item.SubDataset = payload["subDataset"].(string) // item.DatasetId = int64(payload["datasetArr"].(float64)) // item.SrcPath = payload["srcPath"].(string) // item.Result = payload["body"].(string) // _, _ = model.DB.Insert(item) // //fn := payload["fileName"].(string) // //dec := base64.NewDecoder(base64.StdEncoding, strings.NewReader(payload["file"].(string))) case TrainTaskAdd: payload := cmd.Payload.(map[string]interface{}) if itemId, ok := payload["taskId"].(float64); ok { item := new(model.TrainTask) h, err := model.DB.ID(int64(itemId)).Get(item) if err != nil || !h { } RunTraining(item) } default: } return frame.Tag(cmd.Command), nil } func TaskResponseHandler(data []byte) (frame.Tag, []byte) { cmd := new(InstructionReq) err := json.Unmarshal(data, cmd) if err != nil { return 0x0B, []byte(err.Error()) } switch cmd.Command { case TaskResponse: payload := cmd.Payload.(map[string]interface{}) item := new(model.TaskResult) item.TaskId = int64(payload["taskId"].(float64)) if _, ok := payload["taskCode"]; ok && payload["taskCode"] != nil { item.TaskCode = payload["taskCode"].(string) } if _, ok := payload["fileId"]; ok { item.FileId = int64(payload["fileId"].(float64)) } item.NodeId = int64(payload["nodeId"].(float64)) item.ModelId = int64(payload["modelId"].(float64)) item.StartTime = int64(payload["startTime"].(float64)) item.FinishTime = int64(payload["finishTime"].(float64)) if _, ok := payload["subDataset"]; ok { item.SubDataset = payload["subDataset"].(string) } item.DatasetId, _ = strconv.ParseInt(payload["datasetArr"].(string), 10, 64) if _, ok := payload["srcPath"]; ok && payload["srcPath"] != nil { item.SrcPath = payload["srcPath"].(string) } if _, ok := payload["body"]; ok { item.Result = payload["body"].(string) } isFailing := false if _, ok := payload["code"]; ok && int(payload["code"].(float64)) == 500 { item.Result = payload["msg"].(string) isFailing = true } _, err = model.DB.Insert(item) if err != nil { fmt.Println("接收TaskResponse数据出错", err) } //处理到项目结果表 go processToProjectResult(item) //更新运行进度 processed, unProcessed := model.UpdateTaskProgressByLog(item, isFailing) var ( ratStr string ) if unProcessed > 0 { ratStr = fmt.Sprintf("[已处理[%d],剩余[%d]未处理]", processed, unProcessed) } else { ratStr = "[已全部处理]" } taskLog := new(model.TaskLog) taskLog.TaskId = item.TaskId taskLog.NodeId = item.NodeId if len(item.SrcPath) > 0 { taskLog.Content = fmt.Sprintf("[%s] 图片%s处理完成 %s ", time.Unix(item.FinishTime, 0).Format("2006-01-02 15:04:05"), item.SrcPath, ratStr) } else { taskLog.Content = fmt.Sprintf("[%s] %s", time.Unix(item.FinishTime, 0).Format("2006-01-02 
15:04:05"), ratStr) } model.InsertLog(taskLog) //fn := payload["fileName"].(string) //dec := base64.NewDecoder(base64.StdEncoding, strings.NewReader(payload["file"].(string))) default: } return frame.Tag(cmd.Command), nil } type ModelResult struct { Code int `json:"code"` } type InsigmaResult struct { Code int `json:"code"` NumOfDiseases int `json:"num_of_diseases"` Diseases []DiseasesInfo `json:"diseases"` Image string `json:"image"` } type DiseasesInfo struct { Id int `json:"id"` Type string `json:"type"` Level string `json:"level"` Param DiseasesParam `json:"param"` } type DiseasesParam struct { Length float64 `json:"length"` Area float64 `json:"area"` MaxWidth string `json:"max_width"` } type LightweightResult struct { Code int `json:"code"` Crack bool `json:"crack"` ImgDiscern string `json:"img_discern"` ImgSrc string `json:"img_src"` Pothole bool `json:"pothole"` } func processToProjectResult(src *model.TaskResult) { project := new(model.Project) h, err := model.DB.Table("project").Alias("p").Join("inner", []string{"dataset", "d"}, "d.project_id= p.project_id").Where("d.dataset_id=?", src.DatasetId).Get(project) if !h { err = fmt.Errorf("未能找到对应的项目信息") } if err != nil { logging.L().With(zap.String("控制节点", "错误信息")).Error("获取项目信息", zap.Error(err)) return } var ( mr ModelResult mrList []string fileDiscern string memo string milepostNumber string upDown string lineNum int width float64 ) switch project.BizType { case 1: //道路 arr := strings.Split(src.SrcPath, " ") if len(arr) > 1 { milepostNumber = GetMilepost(project.StartName, arr[1], arr[2]) if arr[2] == "D" { upDown = "下行" } else { upDown = "上行" } } if len(arr) > 3 { lineNum, _ = strconv.Atoi(arr[3]) } case 2: //桥梁 case 3: //隧道 //隧道名-采集方向(D/X)-相机编号(01-22)-采集序号(五位)K里程桩号.bmp DAXIASHAN-D-05-00003K15069.5.bmp arr := strings.Split(src.SrcPath, "K") if len(arr) > 1 { arrM := strings.Split(arr[1], ".") milepostNumber = meter2Milepost(arrM[0]) arrD := strings.Split(arr[0], ".") if len(arrD) > 1 { if arrD[1] == "D" { upDown = "下行" } else { upDown = "上行" } } if len(arrD) > 4 { lineNum, _ = strconv.Atoi(arrD[3]) } } } if len(src.Result) > 0 && src.Result[0] == '[' { mrList = make([]string, 0) if err := json.Unmarshal([]byte(src.Result), &mrList); err != nil { return } list := make([]*model.ProjectResult, 0) for _, str := range mrList { if err := json.Unmarshal([]byte(str), &mr); err != nil { continue } if mr.Code == 2001 { ir := new(InsigmaResult) if err := json.Unmarshal([]byte(str), &ir); err != nil { continue } fileDiscern = ir.Image for key, value := range ir.Diseases { if len(value.Param.MaxWidth) > 0 && width == 0 { width, _ = strconv.ParseFloat(value.Param.MaxWidth, 64) } else { width = 0 } memo = fmt.Sprintf("%d. 
					memo = fmt.Sprintf("%d. 发现[%s],等级[%s],长度[%f],最大宽度[%s],面积[%f];\n", key+1, value.Type, value.Level, value.Param.Length, value.Param.MaxWidth, value.Param.Area)
					item := &model.ProjectResult{
						ProjectId:      project.ProjectId,
						SourceResultId: src.ResultId,
						MilepostNumber: milepostNumber,
						UpDown:         upDown,
						LineNum:        lineNum,
						DiseaseType:    value.Type,
						DiseaseLevel:   value.Level,
						Length:         value.Param.Length,
						Width:          width,
						Acreage:        value.Param.Area,
						Memo:           memo,
						Result:         fileDiscern,
						Creator:        0,
						Modifier:       0,
						CreateAt:       time.Now().Unix(),
						UpdateAt:       time.Now().Unix(),
					}
					list = append(list, item)
				}
			}
		}
		_, _ = model.DB.Insert(list)
	} else {
		if err := json.Unmarshal([]byte(src.Result), &mr); err != nil {
			return
		}
		switch mr.Code {
		case 0:
			// Result from the lightweight model.
			lr := new(LightweightResult)
			if err := json.Unmarshal([]byte(src.Result), &lr); err != nil {
				return
			}
			if lr.Crack || lr.Pothole {
				if lr.Crack {
					memo = "检测到裂缝"
				} else {
					memo = "检测到坑洼"
				}
				fileDiscern = lr.ImgDiscern
				if len(fileDiscern) == 0 {
					fileDiscern = lr.ImgSrc
				}
				diseaseLevelName := "重度"
				diseaseTypeName := ""
				switch project.BizType {
				case 2:
					diseaseTypeName = "结构裂缝"
				case 3:
					diseaseTypeName = "衬砌裂缝"
				default:
					diseaseTypeName = "横向裂缝"
				}
				item := &model.ProjectResult{
					ProjectId:      project.ProjectId,
					SourceResultId: src.ResultId,
					MilepostNumber: milepostNumber,
					UpDown:         upDown,
					LineNum:        lineNum,
					DiseaseType:    diseaseTypeName,
					DiseaseLevel:   diseaseLevelName,
					Length:         0,
					Width:          0,
					Acreage:        0,
					Memo:           memo,
					Result:         fileDiscern,
					Creator:        0,
					Modifier:       0,
					CreateAt:       time.Now().Unix(),
					UpdateAt:       time.Now().Unix(),
				}
				_, _ = model.DB.Insert(item)
			} else {
				fileDiscern = lr.ImgSrc
			}
			//
		case 2001:
			// Insigma result that reports diseases.
			ir := new(InsigmaResult)
			if err := json.Unmarshal([]byte(src.Result), &ir); err != nil {
				return
			}
			fileDiscern = ir.Image
			list := make([]*model.ProjectResult, 0)
			for _, val := range ir.Diseases {
				if len(val.Param.MaxWidth) > 0 && width == 0 {
					width, _ = strconv.ParseFloat(val.Param.MaxWidth, 64)
				} else {
					width = 0
				}
				memo = fmt.Sprintf("发现[%s],等级[%s],长度[%f],最大宽度[%s],面积[%f];\n", val.Type, val.Level, val.Param.Length, val.Param.MaxWidth, val.Param.Area)
				maxWidth, _ := strconv.ParseFloat(val.Param.MaxWidth, 64)
				item := &model.ProjectResult{
					ProjectId:      project.ProjectId,
					SourceResultId: src.ResultId,
					MilepostNumber: milepostNumber,
					UpDown:         upDown,
					LineNum:        lineNum,
					DiseaseType:    val.Type,
					DiseaseLevel:   val.Level,
					Length:         val.Param.Length,
					Width:          maxWidth,
					Acreage:        val.Param.Area,
					Memo:           memo,
					Result:         fileDiscern,
					Creator:        0,
					Modifier:       0,
					CreateAt:       time.Now().Unix(),
					UpdateAt:       time.Now().Unix(),
				}
				list = append(list, item)
			}
			_, _ = model.DB.Insert(list)
		}
	}
}

// GetMilepost adds (or, for the down direction "D", subtracts) the given offset to the
// starting milepost string (e.g. "K12+300") and returns the resulting milepost.
func GetMilepost(start, num, upDown string) string {
	arr := strings.Split(start, "+")
	var (
		kilometre, meter, milepost, counter, res, resMilepost, resMeter float64
	)
	if len(arr) == 1 {
		meter = 0
	} else {
		meter, _ = strconv.ParseFloat(arr[1], 64)
	}
	str := strings.Replace(arr[0], "k", "", -1)
	str = strings.Replace(str, "K", "", -1)
	kilometre, _ = strconv.ParseFloat(str, 64)
	milepost = kilometre + meter/1000
	counter, _ = strconv.ParseFloat(num, 64)
	if upDown == "D" {
		res = milepost - counter
	} else {
		res = milepost + counter
	}
	resMilepost = math.Floor(res)
	resMeter = (res - resMilepost) * 100
	return fmt.Sprintf("K%d+%.2f", int(resMilepost), resMeter)
}

// meter2Milepost converts a distance in metres into a milepost string.
func meter2Milepost(meter string) string {
	meter = strings.Replace(meter, "K", "", -1)
	m, _ := strconv.ParseFloat(meter, 64)
	resMilepost := math.Floor(m / 1000)
	resMeter := (m - resMilepost*1000) * 100
	return fmt.Sprintf("K%d+%.2f", int(resMilepost), resMeter)
}
// deliver marshals payload and sends it on the client registered for topic and mqType.
func deliver(topic string, mqType uint, payload interface{}) {
	cli := GetMqClient(topic, mqType)
	if cli == nil {
		// No client registered for this topic; nothing to send.
		return
	}
	pData, _ := json.Marshal(payload)
	_ = GenerateAndSendData(cli.EndPoint.(hpds_node.AccessPoint), pData)
}

// TaskExecuteLogHandler updates the in-memory task progress and forwards log entries.
func TaskExecuteLogHandler(data []byte) (frame.Tag, []byte) {
	cmd := new(InstructionReq)
	err := json.Unmarshal(data, cmd)
	if err != nil {
		return 0x0B, []byte(err.Error())
	}
	payload := cmd.Payload.(map[string]interface{})
	taskListMu.Lock()
	taskId := int64(payload["taskId"].(float64))
	if item, ok := TaskList[taskId]; ok {
		item.UnfinishedCount -= 1
		if int(payload["status"].(float64)) == 1 {
			item.CompletedCount += 1
		}
		if int(payload["status"].(float64)) == 2 {
			item.FailingCount += 1
		}
		if item.UnfinishedCount <= 0 || time.Now().Unix()-item.LastSendTime > 5000 {
			// Send a progress message.
			taskProgress := &proto.TaskLogProgress{
				PayloadType:     1,
				TaskId:          item.TaskId,
				TotalCount:      item.TotalCount,
				CompletedCount:  item.CompletedCount,
				FailingCount:    item.FailingCount,
				UnfinishedCount: item.UnfinishedCount,
			}
			//model.UpdateTaskProgress(taskProgress)
			taskProgressCmd := &InstructionReq{
				Command: TaskLog,
				Payload: taskProgress,
			}
			deliver("task-log", 1, taskProgressCmd)
			if item.UnfinishedCount <= 0 {
				delete(TaskList, item.TaskId)
			} else {
				item.LastSendTime = time.Now().Unix()
			}
		}
		taskLog := &proto.TaskLogPayload{
			PayloadType: 2,
			TaskId:      item.TaskId,
			TaskCode:    payload["taskCode"].(string),
			NodeId:      int64(payload["nodeId"].(float64)),
			NodeGuid:    payload["nodeGuid"].(string),
			TaskContent: payload["taskContent"].(string),
			Status:      int(payload["status"].(float64)),
			EventTime:   int64(payload["eventTime"].(float64)),
		}
		taskLogCmd := &InstructionReq{
			Command: TaskLog,
			Payload: taskLog,
		}
		deliver("task-log", 1, taskLogCmd)
	}
	taskListMu.Unlock()
	return frame.Tag(cmd.Command), nil
}

// RunTraining copies the training dataset into a temporary directory, runs the external
// training script, and records the resulting model files and metrics.
func RunTraining(task *model.TrainTask) {
	var (
		args                               []string
		modelPath, modelFileName, testSize string
		modelAcc, modelLoss                float64
	)
	fmt.Println("curr tmp dir====>>>>", config.Cfg.TmpTrainDir)
	modelFileName = utils.GetUUIDString()
	// Copy the training dataset into a per-run working directory.
	tmpTrainDir := path.Join(config.Cfg.TmpTrainDir, fmt.Sprintf("%s_%s_%d_%d", modelFileName, task.Arithmetic, task.BatchSize, task.EpochsSize))
	fileList := make([]model.TrainingDatasetDetail, 0)
	_ = model.DB.Where("dataset_id = ?", task.TrainDatasetId).Find(&fileList)
	_ = os.MkdirAll(tmpTrainDir, os.ModePerm)
	_ = os.MkdirAll(path.Join(tmpTrainDir, "train"), os.ModePerm)
	_ = os.MkdirAll(path.Join(tmpTrainDir, "train", "0"), os.ModePerm)
	_ = os.MkdirAll(path.Join(tmpTrainDir, "train", "1"), os.ModePerm)
	_ = os.MkdirAll(path.Join(tmpTrainDir, "val"), os.ModePerm)
	_ = os.MkdirAll(path.Join(tmpTrainDir, "val", "0"), os.ModePerm)
	_ = os.MkdirAll(path.Join(tmpTrainDir, "val", "1"), os.ModePerm)
	_ = os.MkdirAll(path.Join(tmpTrainDir, "test"), os.ModePerm)
	for _, v := range fileList {
		dstFilePath := ""
		switch v.CategoryId {
		case 2:
			dstFilePath = "test"
		default:
			dstFilePath = "train"
		}
		if v.CategoryId != 2 {
			// Class sub-directory: "0" holds disease samples, "1" holds normal samples.
			if v.IsDisease == 1 {
				dstFilePath = path.Join(tmpTrainDir, dstFilePath, "0")
			} else {
				dstFilePath = path.Join(tmpTrainDir, dstFilePath, "1")
			}
		} else {
			dstFilePath = path.Join(tmpTrainDir, dstFilePath)
		}
		err := utils.CopyFile(v.FilePath, path.Join(dstFilePath, v.FileName))
		if err != nil {
			fmt.Println("copy error: ", err)
		}
	}
	modelPath = path.Join(config.Cfg.ModelOutPath, fmt.Sprintf("%s_%s_%d_%d_%s", modelFileName, task.Arithmetic, task.BatchSize, task.EpochsSize, task.OutputType))
	_ = os.MkdirAll(modelPath, os.ModePerm)
	dt := new(model.TrainingDataset)
	_, err := model.DB.ID(task.TrainDatasetId).Get(dt)
	if err != nil {
		goto ReturnPoint
	}
	testSize = fmt.Sprintf("%.2f", dt.ValidationNumber/100)
	// Build the argument list for the training script and run it.
	args = []string{
		"--dataset=" + path.Join(tmpTrainDir, "train"),
		"--img_size=" + strconv.Itoa(task.ImageSize),
		"--batch_size=" + strconv.Itoa(task.BatchSize),
		"--test_size=" + testSize,
		"--epochs=" + strconv.Itoa(task.EpochsSize),
		"--model=" + task.Arithmetic,
		"--model_save=" + path.Join(modelPath, modelFileName+".h5"),
	}
	fmt.Println("args====>>>", args)
	err = ExecCommand(config.Cfg.TrainScriptPath, args, path.Join(modelPath, modelFileName+".log"), task.TaskId)

ReturnPoint:
	// Persist the training outcome.
	modelMetricsFile := modelFileName + "_model_metrics.png"
	task.FinishTime = time.Now().Unix()
	task.ModelFilePath = path.Join(modelPath, modelFileName+".h5")
	task.Loss = GetIndicatorByLog(path.Join(modelPath, modelFileName+".log"), "[INFO] Model Validation Loss:")
	task.Accuracy = GetIndicatorByLog(path.Join(modelPath, modelFileName+".log"), "[INFO] Model Validation Accuracy:")
	task.Status = 3
	if err != nil {
		task.Status = 5
	}
	task.ModelFileMetricsPath = path.Join(modelPath, modelMetricsFile)
	_, _ = model.DB.ID(task.TaskId).AllCols().Update(task)
	if utils.PathExists(path.Join(modelPath, modelFileName+".log")) {
		logContext := utils.ReadFile(path.Join(modelPath, modelFileName+".log"))
		taskRes := new(model.TrainTaskResult)
		taskRes.TaskId = task.TaskId
		taskRes.CreateAt = time.Now().Unix()
		taskRes.Content = string(logContext)
		taskRes.Result = path.Join(modelPath, modelMetricsFile)
		taskRes.Accuracy = modelAcc
		taskRes.Loss = modelLoss
		c, err := model.DB.Insert(taskRes)
		if err != nil {
			fmt.Println("model.DB.Insert(taskRes) error ========>>>>>>", err)
		}
		fmt.Println("model.DB.Insert(taskRes) count ========>>>>>>", c)
	} else {
		fmt.Println("logContext========>>>>>>未读取")
	}
}

// GetIndicatorByLog scans a training log file for a line containing indicator and returns
// the float value that follows it, or 0 if the indicator is not found.
func GetIndicatorByLog(logFileName, indicator string) float64 {
	logFn, _ := os.Open(logFileName)
	defer func() {
		_ = logFn.Close()
	}()
	buf := bufio.NewReader(logFn)
	for {
		line, err := buf.ReadString('\n')
		if err != nil {
			if err == io.EOF {
				//fmt.Println("File read ok!")
				break
			}
			fmt.Println("Read file error!", err)
			return 0
		}
		if strings.Index(line, indicator) >= 0 {
			str := strings.Replace(line, indicator, "", -1)
			str = strings.Replace(str, "\n", "", -1)
			value, _ := strconv.ParseFloat(strings.Trim(str, " "), 64)
			return value
		}
	}
	return 0
}

// ExecCommand runs the training command, tees its stdout into logFileName, and records
// per-epoch loss/accuracy values parsed from the output.
func ExecCommand(cmd string, args []string, logFileName string, taskId int64) (err error) {
	logFile, _ := os.Create(logFileName)
	defer func() {
		_ = logFile.Close()
	}()
	fmt.Print("开始训练......")
	c := exec.Command(cmd, args...)
	// mac or linux
	stdout, err := c.StdoutPipe()
	if err != nil {
		return err
	}
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		reader := bufio.NewReader(stdout)
		var (
			epoch int
			//modelLoss, modelAcc float64
		)
		for {
			readString, err := reader.ReadString('\n')
			if err != nil {
				fmt.Println("训练2===>>>", err)
				return
			}
			// Decode the process output from GB18030 before logging.
			byte2String := ConvertByte2String([]byte(readString), "GB18030")
			_, _ = fmt.Fprint(logFile, byte2String)
			if strings.Index(byte2String, "Epoch") >= 0 {
				str := strings.Replace(byte2String, "Epoch ", "", -1)
				arr := strings.Split(str, "/")
				epoch, _ = strconv.Atoi(arr[0])
			}
			if strings.Index(byte2String, "- loss:") > 0 && strings.Index(byte2String, "- accuracy:") > 0 &&
				strings.Index(byte2String, "- val_loss:") > 0 && strings.Index(byte2String, "- val_accuracy:") > 0 {
				var (
					loss, acc, valLoss, valAcc float64
				)
				arr := strings.Split(byte2String, "-")
				for _, v := range arr {
					if strings.Index(v, "loss:") > 0 && strings.Index(v, "val_loss:") < 0 {
						strLoss := strings.Replace(v, " loss: ", "", -1)
						loss, _ = strconv.ParseFloat(strings.Trim(strLoss, " "), 64)
					}
					if strings.Index(v, "accuracy:") > 0 && strings.Index(v, "val_accuracy:") < 0 {
						strAcc := strings.Replace(v, " accuracy: ", "", -1)
						acc, _ = strconv.ParseFloat(strings.Trim(strAcc, " "), 64)
					}
					if strings.Index(v, "val_loss:") > 0 {
						strValLoss := strings.Replace(v, "val_loss: ", "", -1)
						valLoss, _ = strconv.ParseFloat(strings.Trim(strValLoss, " "), 64)
					}
					if strings.Index(v, "val_accuracy:") > 0 {
						strValAcc := strings.Replace(v, "val_accuracy: ", "", -1)
						strValAcc = strings.Replace(strValAcc, "\n", "", -1)
						valAcc, _ = strconv.ParseFloat(strings.Trim(strValAcc, " "), 64)
					}
				}
				// Persist one log row per parsed epoch line.
				taskLog := new(model.TrainTaskLog)
				taskLog.Epoch = epoch
				taskLog.TaskId = taskId
				taskLog.CreateAt = time.Now().Unix()
				taskLog.Loss = loss
				taskLog.Accuracy = acc
				taskLog.ValLoss = valLoss
				taskLog.ValAccuracy = valAcc
				_, _ = model.DB.Insert(taskLog)
			}
			fmt.Print(byte2String)
		}
	}()
	err = c.Start()
	if err != nil {
		fmt.Println("训练3===>>>", err)
	}
	wg.Wait()
	// Reap the child process after its output has been fully consumed.
	if waitErr := c.Wait(); waitErr != nil && err == nil {
		err = waitErr
	}
	return
}

// ConvertByte2String decodes raw bytes using the given charset (currently GB18030 or UTF-8).
func ConvertByte2String(data []byte, charset Charset) string {
	var str string
	switch charset {
	case GB18030:
		decodeBytes, _ := simplifiedchinese.GB18030.NewDecoder().Bytes(data)
		str = string(decodeBytes)
	case UTF8:
		fallthrough
	default:
		str = string(data)
	}
	return str
}
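
// Example wiring (illustrative sketch only): the topic names below match the handlers in this
// file, but the concrete FuncConfig and HpdsNode values are assumptions, not part of this
// package. A caller would typically build the MQ clients once at startup and keep the result
// in MqList so that GetMqClient and deliver can resolve topics:
//
//	funcs := []config.FuncConfig{
//		{Name: "task-request", MqType: 2, DataTag: 11}, // stream function, handled by TaskRequestHandler
//		{Name: "task-execute", MqType: 1, DataTag: 12}, // access point used by deliver("task-execute", 1, ...)
//		{Name: "task-log", MqType: 1, DataTag: 13},     // access point used by deliver("task-log", 1, ...)
//	}
//	node := config.HpdsNode{Host: "127.0.0.1", Port: 9000, Token: "token"}
//	clients, err := mq.NewMqClient(funcs, node, logging.L())
//	if err != nil {
//		// NewMqClient already calls must(), which logs and exits on connection errors.
//	}
//	mq.MqList = clients // GetMqClient/deliver look topics up in this package-level list.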