package service import ( "bytes" "context" "encoding/base64" "encoding/json" "fmt" "git.hpds.cc/Component/logging" "git.hpds.cc/pavement/hpds_node" "hpds-iot-web/internal/proto" "hpds-iot-web/model" "hpds-iot-web/mq" "image" "net/http" "strconv" "time" "xorm.io/xorm" ) type TaskService interface { TaskList(ctx context.Context, req proto.TaskRequest) (rsp *proto.BaseResponse, err error) TaskInfo(ctx context.Context, req proto.TaskItemRequest) (rsp *proto.BaseResponse, err error) AddTask(ctx context.Context, req proto.TaskItemRequest) (rsp *proto.BaseResponse, err error) ReRunTask(ctx context.Context, req proto.TaskItemRequest) (rsp *proto.BaseResponse, err error) //EditTask(ctx context.Context, req proto.ModelItemRequest) (rsp *proto.BaseResponse, err error) TaskResult(ctx context.Context, req proto.ReportRequest) (rsp *proto.BaseResponse, err error) TaskLog(ctx context.Context, req proto.TaskLogItem) (rsp *proto.BaseResponse, err error) TrainingTaskList(ctx context.Context, req proto.TaskRequest) (rsp *proto.BaseResponse, err error) TrainingTaskInfo(ctx context.Context, req proto.TaskItemRequest) (rsp *proto.BaseResponse, err error) TrainingTaskLog(ctx context.Context, req proto.TaskItemRequest) (rsp *proto.BaseResponse, err error) TrainingTaskResult(ctx context.Context, req proto.TaskItemRequest) (rsp *proto.BaseResponse, err error) CreateTrainingTask(ctx context.Context, req proto.TrainingTaskItemRequest) (rsp *proto.BaseResponse, err error) EditTrainingTask(ctx context.Context, req proto.TrainingTaskItemRequest) (rsp *proto.BaseResponse, err error) ReRunTrainingTask(ctx context.Context, req proto.TrainingTaskItemRequest) (rsp *proto.BaseResponse, err error) } func NewTaskService(engine *xorm.Engine, logger *logging.Logger) TaskService { return &repo{ engine: engine, logger: logger, } } func (rp *repo) TaskList(ctx context.Context, req proto.TaskRequest) (rsp *proto.BaseResponse, err error) { rsp = new(proto.BaseResponse) select { case <-ctx.Done(): err = fmt.Errorf("超时/取消") rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Message = "超时/取消" rsp.Err = ctx.Err() return rsp, ctx.Err() default: data := make([]proto.TaskDetail, 0) count, err := rp.engine.Table("task").Alias("t"). Join("inner", []string{"model", "m"}, "t.model_id = m.model_id"). Join("inner", []string{"node", "n"}, "t.node_id = n.node_id"). Cols("t.*", "m.model_name", "n.node_name"). Where("(? = 0 or m.biz_type = ?)", req.BizType, req.BizType). And("(?='' or task_name like ?)", req.TaskName, "%"+req.TaskName+"%"). And("t.start_time >= unix_timestamp(?)", req.StartTime). And("? = 0 or t.start_time <= unix_timestamp(?)", req.FinishTime, req.FinishTime). And("t.status > 0").Limit(int(req.Size), int(((req.Page)-1)*req.Size)). Desc("start_time"). FindAndCount(&data) if err != nil { goto ReturnPoint } rsp.Code = http.StatusOK rsp.Status = http.StatusText(http.StatusOK) rsp.Message = "成功" rsp = FillPaging(count, req.Page, req.Size, data, rsp) rsp.Err = err return rsp, err } ReturnPoint: if err != nil { rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Err = err rsp.Message = "失败" } return rsp, err } func (rp *repo) AddTask(ctx context.Context, req proto.TaskItemRequest) (rsp *proto.BaseResponse, err error) { rsp = new(proto.BaseResponse) select { case <-ctx.Done(): err = fmt.Errorf("超时/取消") rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Message = "超时/取消" rsp.Err = ctx.Err() return rsp, ctx.Err() default: var h bool m := new(model.Model) h, err = rp.engine.ID(req.ModelId).Get(m) if err != nil { goto ReturnPoint } if !h { err = fmt.Errorf("未能找到对应的模型") goto ReturnPoint } ds := new(model.Dataset) h, err = rp.engine.ID(req.DatasetArr).Get(ds) if err != nil { goto ReturnPoint } if !h { err = fmt.Errorf("未能找到对应的数据集") goto ReturnPoint } var node *model.Node if req.NodeId > 0 { node = new(model.Node) h, err = rp.engine.ID(req.NodeId).Get(node) if err != nil { goto ReturnPoint } if !h { err = fmt.Errorf("未能找到对应的节点") goto ReturnPoint } } item := &model.Task{ ModelId: req.ModelId, NodeId: req.NodeId, TaskName: req.TaskName, TaskDesc: req.TaskDesc, DatasetArr: fmt.Sprintf("%d", req.DatasetArr), SubDataset: req.SubDataset, SubDataTag: req.SubDataTag, AppointmentTime: req.AppointmentTime, Status: 2, CreateAt: time.Now().Unix(), UpdateAt: time.Now().Unix(), } if len(req.AppointmentTime) > 0 { var appTime time.Time appTime, err = time.ParseInLocation("2006-01-02 15:04:05", req.AppointmentTime, time.Local) if err != nil { err = fmt.Errorf("时间格式不匹配") goto ReturnPoint } item.StartTime = appTime.Unix() item.Status = 1 } else { item.StartTime = time.Now().Unix() item.Status = 2 } _, err = rp.engine.Insert(item) if err != nil { goto ReturnPoint } //reg, _ := regexp.Compile("\\[.*?\\]") //if ok := reg.FindAll([]byte(item.ResultStorage), 2); len(ok) > 0 { // item.ResultStorage = reg.ReplaceAllString(item.ResultStorage, fmt.Sprintf("%d_%d", item.TaskId, item.ModelId)) // _, err = rp.engine.ID(item.TaskId).Cols("result_storage").Update(item) // if err != nil { // goto ReturnPoint // } //} payload := make(map[string]interface{}) payload["taskId"] = item.TaskId payload["modelId"] = item.ModelId payload["modelVersion"] = m.ModelVersion payload["modelCommand"] = m.ModelCommand payload["nodeId"] = item.NodeId if item.NodeId > 0 && node != nil { payload["nodeGuid"] = node.NodeGuid } payload["inPath"] = m.InPath payload["outPath"] = m.OutPath payload["httpUrl"] = m.HttpUrl payload["datasetArr"] = item.DatasetArr payload["datasetPath"] = ds.StoreName payload["datasetName"] = ds.DatasetName payload["subDataset"] = item.SubDataset payload["subDataTag"] = item.SubDataTag payload["workflow"] = m.Workflow issue := new(model.IssueModel) h, _ = model.DB.Where("model_id=? and node_id =?", req.ModelId, item.NodeId).Get(issue) if h { payload["issueResult"] = issue.IssueResult } mqClient := mq.GetMqClient("task-request", 1) mqPayload := &mq.InstructionReq{ Command: mq.TaskAdd, Payload: payload, } pData, _ := json.Marshal(mqPayload) err = mq.GenerateAndSendData(mqClient.EndPoint.(hpds_node.AccessPoint), pData, rp.logger) rsp.Code = http.StatusOK rsp.Status = http.StatusText(http.StatusOK) rsp.Message = "新增任务成功" rsp.Err = ctx.Err() rsp.Data = item return rsp, err } ReturnPoint: if err != nil { rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Err = err rsp.Message = "失败" } return rsp, err } func (rp *repo) ReRunTask(ctx context.Context, req proto.TaskItemRequest) (rsp *proto.BaseResponse, err error) { rsp = new(proto.BaseResponse) select { case <-ctx.Done(): err = fmt.Errorf("超时/取消") rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Message = "超时/取消" rsp.Err = ctx.Err() return rsp, ctx.Err() default: item := new(model.Task) var h bool h, err = rp.engine.ID(req.TaskId).Get(item) if err != nil { goto ReturnPoint } if !h { err = fmt.Errorf("未能找到任务") goto ReturnPoint } m := new(model.Model) h, err = rp.engine.ID(item.ModelId).Get(m) if err != nil { goto ReturnPoint } if !h { err = fmt.Errorf("未能找到对应的模型") goto ReturnPoint } ds := new(model.Dataset) h, err = rp.engine.ID(item.DatasetArr).Get(ds) if err != nil { goto ReturnPoint } if !h { err = fmt.Errorf("未能找到对应的数据集") goto ReturnPoint } var node *model.Node if item.NodeId > 0 { node = new(model.Node) h, err = rp.engine.ID(item.NodeId).Get(node) if err != nil { goto ReturnPoint } if !h { err = fmt.Errorf("未能找到对应的节点") goto ReturnPoint } } payload := make(map[string]interface{}) payload["taskId"] = item.TaskId payload["modelId"] = item.ModelId payload["modelVersion"] = m.ModelVersion payload["modelCommand"] = m.ModelCommand payload["nodeId"] = item.NodeId if item.NodeId > 0 && node != nil { payload["nodeGuid"] = node.NodeGuid } payload["inPath"] = m.InPath payload["outPath"] = m.OutPath payload["httpUrl"] = m.HttpUrl payload["datasetArr"] = item.DatasetArr payload["datasetPath"] = ds.StoreName payload["datasetName"] = ds.DatasetName payload["subDataset"] = item.SubDataset payload["subDataTag"] = item.SubDataTag payload["workflow"] = m.Workflow issue := new(model.IssueModel) h, _ = model.DB.Where("model_id=? and node_id =?", req.ModelId, item.NodeId).Get(issue) if h { payload["issueResult"] = issue.IssueResult } mqClient := mq.GetMqClient("task-request", 1) mqPayload := &mq.InstructionReq{ Command: mq.TaskAdd, Payload: payload, } pData, _ := json.Marshal(mqPayload) err = mq.GenerateAndSendData(mqClient.EndPoint.(hpds_node.AccessPoint), pData, rp.logger) if err != nil { goto ReturnPoint } rsp.Code = http.StatusOK rsp.Status = http.StatusText(http.StatusOK) rsp.Message = "重新任务成功" rsp.Err = ctx.Err() rsp.Data = item return rsp, err } ReturnPoint: if err != nil { rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Err = err rsp.Message = "失败" } return rsp, err } func (rp *repo) TaskInfo(ctx context.Context, req proto.TaskItemRequest) (rsp *proto.BaseResponse, err error) { rsp = new(proto.BaseResponse) select { case <-ctx.Done(): err = fmt.Errorf("超时/取消") rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Message = "超时/取消" rsp.Err = ctx.Err() return rsp, ctx.Err() default: item := new(model.Task) var b bool b, err = rp.engine.ID(req.TaskId).Get(item) if err != nil { goto ReturnPoint } if !b { err = fmt.Errorf("未能找到对应的任务") goto ReturnPoint } rsp.Code = http.StatusOK rsp.Status = http.StatusText(http.StatusOK) rsp.Message = "成功" rsp.Err = err rsp.Data = item return rsp, err } ReturnPoint: if err != nil { rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Err = err rsp.Message = "失败" } return rsp, err } func (rp *repo) TaskResult(ctx context.Context, req proto.ReportRequest) (rsp *proto.BaseResponse, err error) { rsp = new(proto.BaseResponse) select { case <-ctx.Done(): err = fmt.Errorf("超时/取消") rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Message = "超时/取消" rsp.Err = ctx.Err() return rsp, ctx.Err() default: taskResultList := make([]model.TaskResult, 0) err = rp.engine.Where("task_id = ?", req.TaskId). Limit(int(req.Size), int(((req.Page)-1)*req.Size)). Find(&taskResultList) if err != nil { err = fmt.Errorf("未能找到对应的结果") goto ReturnPoint } list := make([]proto.TaskResultItem, 0) for _, v := range taskResultList { var h bool file := new(model.FileManager) h, err = rp.engine.ID(v.FileId).Get(file) if err != nil || !h { continue } md := new(model.Model) h, err = rp.engine.ID(v.ModelId).Get(md) if err != nil || !h { continue } var ( mr mq.ModelResult mrList []string fileDiscern string memo string diseaseType int64 diseaseLevel int length float64 area float64 width float64 diseaseTypeName string diseaseLevelName string ) if len(v.Result) > 0 && v.Result[0] == '[' { mrList = make([]string, 0) if err := json.Unmarshal([]byte(v.Result), &mrList); err != nil { continue } for _, str := range mrList { if err := json.Unmarshal([]byte(str), &mr); err != nil { continue } switch md.BizType { case 1: //道路 case 2: //桥梁 case 3: //隧道 } switch mr.Code { case 0: //轻量化模型返回 lr := new(mq.LightweightResult) if err := json.Unmarshal([]byte(v.Result), lr); err != nil { continue } //for _, val := range lrList { if lr.Crack || lr.Pothole { if lr.Crack { memo = "检测到裂缝" } else { memo = "检测到坑洼" } fileDiscern = lr.ImgDiscern diseaseLevel = 3 diseaseLevelName = "重度" switch md.BizType { case 2: diseaseType = 8 diseaseTypeName = "结构裂缝" case 3: diseaseType = 15 diseaseTypeName = "衬砌裂缝" default: diseaseType = 4 diseaseTypeName = "横向裂缝" } } fn, _ := base64.StdEncoding.DecodeString(fileDiscern) buff := bytes.NewBuffer(fn) _, imgType, _ := image.Decode(buff) if len(fileDiscern) == 0 { fileDiscern = lr.ImgSrc } fileDiscern = fmt.Sprintf("data:image/%s;base64,%s", imgType, fileDiscern) item := proto.TaskResultItem{ FileId: v.FileId, FileName: v.SrcPath, SrcFile: file.AccessUrl, DistFile: fileDiscern, DiseaseType: int(diseaseType), DiseaseTypeName: diseaseTypeName, DiseaseLevel: diseaseLevel, DiseaseLevelName: diseaseLevelName, KPile: "", UpDown: 0, LineNum: 0, Length: length, Width: width, Area: area, HorizontalPositions: 0, Memo: memo, Stat: false, } list = append(list, item) //} case 2000: ir := new(mq.InsigmaResult) if err := json.Unmarshal([]byte(str), &ir); err != nil { continue } fileDiscern = ir.Image fn, _ := base64.StdEncoding.DecodeString(fileDiscern) buff := bytes.NewBuffer(fn) _, imgType, _ := image.Decode(buff) fileDiscern = fmt.Sprintf("data:image/%s;base64,%s", imgType, fileDiscern) item := proto.TaskResultItem{ FileId: v.FileId, FileName: v.SrcPath, SrcFile: file.AccessUrl, DistFile: fileDiscern, DiseaseType: int(diseaseType), DiseaseTypeName: diseaseTypeName, DiseaseLevel: diseaseLevel, DiseaseLevelName: diseaseLevelName, KPile: "", UpDown: 0, LineNum: 0, Length: length, Width: width, Area: area, HorizontalPositions: 0, Memo: memo, Stat: false, } list = append(list, item) case 2001: ir := new(mq.InsigmaResult) if err := json.Unmarshal([]byte(str), &ir); err != nil { continue } fileDiscern = ir.Image for key, value := range ir.Diseases { diseaseType = model.GetDiseaseType(value.Type, md.BizType) if len(value.Param.MaxWidth) > 0 && width == 0 { width, _ = strconv.ParseFloat(value.Param.MaxWidth, 64) } length = value.Param.Length area = value.Param.Area diseaseLevelName = value.Level diseaseTypeName = value.Type switch value.Level { case "重度": diseaseLevel = 3 case "中度": diseaseLevel = 2 case "轻度": diseaseLevel = 1 } memo += fmt.Sprintf("%d. 发现[%s],等级[%s],长度[%f],最大宽度[%s],面积[%f];\n", key+1, value.Type, value.Level, value.Param.Length, value.Param.MaxWidth, value.Param.Area) } fn, _ := base64.StdEncoding.DecodeString(fileDiscern) buff := bytes.NewBuffer(fn) _, imgType, _ := image.Decode(buff) fileDiscern = fmt.Sprintf("data:image/%s;base64,%s", imgType, fileDiscern) item := proto.TaskResultItem{ FileId: v.FileId, FileName: v.SrcPath, SrcFile: file.AccessUrl, DistFile: fileDiscern, DiseaseType: int(diseaseType), DiseaseTypeName: diseaseTypeName, DiseaseLevel: diseaseLevel, DiseaseLevelName: diseaseLevelName, KPile: "", UpDown: 0, LineNum: 0, Length: length, Width: width, Area: area, HorizontalPositions: 0, Memo: memo, Stat: false, } list = append(list, item) } } } else { if err := json.Unmarshal([]byte(v.Result), &mr); err != nil { continue } switch mr.Code { case 0: //轻量化模型返回 lr := new(mq.LightweightResult) if err := json.Unmarshal([]byte(v.Result), &lr); err != nil { continue } if lr.Crack || lr.Pothole { if lr.Crack { memo = "检测到裂缝" } else { memo = "检测到坑洼" } fileDiscern = lr.ImgDiscern if len(fileDiscern) == 0 { fileDiscern = lr.ImgSrc } diseaseLevel = 3 diseaseLevelName = "重度" switch md.BizType { case 2: diseaseType = 8 diseaseTypeName = "结构裂缝" case 3: diseaseType = 15 diseaseTypeName = "衬砌裂缝" default: diseaseType = 4 diseaseTypeName = "横向裂缝" } } else { fileDiscern = lr.ImgSrc } // case 2000: //网新返回没有病害 ir := new(mq.InsigmaResult) if err := json.Unmarshal([]byte(v.Result), &ir); err != nil { continue } fileDiscern = ir.Image case 2001: //网新返回有病害 ir := new(mq.InsigmaResult) if err := json.Unmarshal([]byte(v.Result), &ir); err != nil { continue } fileDiscern = ir.Image for _, val := range ir.Diseases { diseaseType = model.GetDiseaseType(val.Type, md.BizType) if len(val.Param.MaxWidth) > 0 && width == 0 { width, _ = strconv.ParseFloat(val.Param.MaxWidth, 64) } length = val.Param.Length area = val.Param.Area diseaseLevelName = val.Level diseaseTypeName = val.Type switch val.Level { case "重度": diseaseLevel = 3 case "中度": diseaseLevel = 2 case "轻度": diseaseLevel = 1 } memo += fmt.Sprintf("发现[%s],等级[%s],长度[%f],最大宽度[%s],面积[%f];\n", val.Type, val.Level, val.Param.Length, val.Param.MaxWidth, val.Param.Area) } } fn, _ := base64.StdEncoding.DecodeString(fileDiscern) buff := bytes.NewBuffer(fn) _, imgType, _ := image.Decode(buff) fileDiscern = fmt.Sprintf("data:image/%s;base64,%s", imgType, fileDiscern) item := proto.TaskResultItem{ FileId: v.FileId, FileName: v.SrcPath, SrcFile: file.AccessUrl, DistFile: fileDiscern, DiseaseType: int(diseaseType), DiseaseTypeName: diseaseTypeName, DiseaseLevel: diseaseLevel, DiseaseLevelName: diseaseLevelName, KPile: "", UpDown: 0, LineNum: 0, Length: length, Width: width, Area: area, HorizontalPositions: 0, Memo: memo, Stat: false, } list = append(list, item) } } rsp.Code = http.StatusOK rsp.Status = http.StatusText(http.StatusOK) rsp.Message = "成功" rsp.Err = err rsp.Data = list return rsp, err } ReturnPoint: if err != nil { rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Err = err rsp.Message = "失败" } return rsp, err } func (rp *repo) TaskLog(ctx context.Context, req proto.TaskLogItem) (rsp *proto.BaseResponse, err error) { rsp = new(proto.BaseResponse) select { case <-ctx.Done(): err = fmt.Errorf("超时/取消") rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Message = "超时/取消" rsp.Err = ctx.Err() return rsp, ctx.Err() default: list := make([]model.TaskLog, 0) err = rp.engine.Where("task_id = ?", req.TaskId).And("task_log_id>?", req.LogId).Find(&list) if err != nil { goto ReturnPoint } rsp.Code = http.StatusOK rsp.Status = http.StatusText(http.StatusOK) rsp.Message = "成功" rsp.Err = err rsp.Data = list return rsp, err } ReturnPoint: if err != nil { rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Err = err rsp.Message = "失败" } return rsp, err } func (rp *repo) TrainingTaskList(ctx context.Context, req proto.TaskRequest) (rsp *proto.BaseResponse, err error) { rsp = new(proto.BaseResponse) select { case <-ctx.Done(): err = fmt.Errorf("超时/取消") rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Message = "超时/取消" rsp.Err = ctx.Err() return rsp, ctx.Err() default: var ( count int64 ) list := make([]model.TrainTask, 0) count, err = rp.engine.Where("(? = 0 or category_id = ?)", req.BizType, req.BizType). And("(? = '' or task_name like ?)", req.TaskName, "%"+req.TaskName+"%"). Limit(int(req.Size), int(((req.Page)-1)*req.Size)).FindAndCount(&list) if err != nil { goto ReturnPoint } rsp.Code = http.StatusOK rsp.Status = http.StatusText(http.StatusOK) rsp.Message = "成功" rsp.Err = err rsp = FillPaging(count, req.Page, req.Size, list, rsp) return rsp, err } ReturnPoint: if err != nil { rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Err = err rsp.Message = "失败" } return rsp, err } func (rp *repo) TrainingTaskInfo(ctx context.Context, req proto.TaskItemRequest) (rsp *proto.BaseResponse, err error) { rsp = new(proto.BaseResponse) select { case <-ctx.Done(): err = fmt.Errorf("超时/取消") rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Message = "超时/取消" rsp.Err = ctx.Err() return rsp, ctx.Err() default: var ( h bool ) task := new(model.TrainTask) h, err = rp.engine.ID(req.TaskId).Get(task) if err != nil { goto ReturnPoint } if !h { err = fmt.Errorf("未能找到对应的任务信息") goto ReturnPoint } item := new(proto.TrainTaskInfoItem) item.TaskId = task.TaskId item.TrainDatasetId = task.TrainDatasetId item.TrainDatasetName = model.GetTrainDatasetName(task.TrainDatasetId) item.CategoryId = task.CategoryId item.CategoryName = model.GetBizType(task.CategoryId) item.TaskName = task.TaskName item.TaskDesc = task.TaskDesc item.Arithmetic = task.Arithmetic item.ImageSize = task.ImageSize item.BatchSize = task.BatchSize item.EpochsSize = task.EpochsSize item.OutputType = task.OutputType item.StartTime = task.StartTime item.FinishTime = task.FinishTime item.Loss = task.Loss item.Accuracy = task.Accuracy item.ModelFilePath = task.ModelFilePath item.ModelFileMetricsPath = task.ModelFileMetricsPath item.Status = task.Status item.CreateAt = task.CreateAt item.UpdateAt = task.UpdateAt rsp.Code = http.StatusOK rsp.Status = http.StatusText(http.StatusOK) rsp.Message = "成功" rsp.Err = err rsp.Data = item return rsp, err } ReturnPoint: if err != nil { rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Err = err rsp.Message = "失败" } return rsp, err } func (rp *repo) TrainingTaskLog(ctx context.Context, req proto.TaskItemRequest) (rsp *proto.BaseResponse, err error) { rsp = new(proto.BaseResponse) select { case <-ctx.Done(): err = fmt.Errorf("超时/取消") rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Message = "超时/取消" rsp.Err = ctx.Err() return rsp, ctx.Err() default: list := make([]model.TrainTaskLog, 0) err = rp.engine.Where("task_id = ?", req.TaskId).Asc("epoch").Find(&list) if err != nil { goto ReturnPoint } rsp.Code = http.StatusOK rsp.Status = http.StatusText(http.StatusOK) rsp.Message = "成功" rsp.Err = err rsp = FillPaging(int64(len(list)), 1, 1000, list, rsp) return rsp, err } ReturnPoint: if err != nil { rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Err = err rsp.Message = "失败" } return rsp, err } func (rp *repo) TrainingTaskResult(ctx context.Context, req proto.TaskItemRequest) (rsp *proto.BaseResponse, err error) { rsp = new(proto.BaseResponse) select { case <-ctx.Done(): err = fmt.Errorf("超时/取消") rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Message = "超时/取消" rsp.Err = ctx.Err() return rsp, ctx.Err() default: var ( h bool ) item := new(model.TrainTaskResult) h, err = rp.engine.Where("task_id = ?", req.TaskId).Desc("create_at").Get(item) if err != nil { goto ReturnPoint } if !h { err = fmt.Errorf("任务还未训练完成") goto ReturnPoint } rsp.Code = http.StatusOK rsp.Status = http.StatusText(http.StatusOK) rsp.Message = "成功" rsp.Err = err rsp.Data = item return rsp, err } ReturnPoint: if err != nil { rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Err = err rsp.Message = "失败" } return rsp, err } func (rp *repo) CreateTrainingTask(ctx context.Context, req proto.TrainingTaskItemRequest) (rsp *proto.BaseResponse, err error) { rsp = new(proto.BaseResponse) select { case <-ctx.Done(): err = fmt.Errorf("超时/取消") rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Message = "超时/取消" rsp.Err = ctx.Err() return rsp, ctx.Err() default: var ( h bool ) item := new(model.TrainTask) h, err = rp.engine.Where("task_name=?", req.ModelName).Get(item) if err != nil { goto ReturnPoint } if h { err = fmt.Errorf("已存在同名任务,请修改后继续") goto ReturnPoint } item.TaskName = req.ModelName item.TrainDatasetId = req.DatasetId item.CategoryId = req.BizType item.TaskDesc = req.ModelDesc item.Arithmetic = req.Arithmetic item.ImageSize = req.ImageSize item.BatchSize = req.BatchSize item.EpochsSize = req.EpochsSize item.OutputType = req.OutputType item.StartTime = time.Now().Unix() item.Status = 2 _, err = rp.engine.Insert(item) if err != nil { goto ReturnPoint } payload := make(map[string]interface{}) payload["taskId"] = item.TaskId payload["taskName"] = item.TaskName payload["trainDatasetId"] = item.TrainDatasetId payload["arithmetic"] = item.Arithmetic payload["imageSize"] = item.ImageSize payload["batchSize"] = item.BatchSize payload["epochsSize"] = item.EpochsSize payload["outputType"] = item.OutputType payload["testSize"] = item.OutputType mqClient := mq.GetMqClient("task-request", 1) mqPayload := &mq.InstructionReq{ Command: mq.TrainTaskAdd, Payload: payload, } pData, _ := json.Marshal(mqPayload) err = mq.GenerateAndSendData(mqClient.EndPoint.(hpds_node.AccessPoint), pData, rp.logger) rsp.Code = http.StatusOK rsp.Status = http.StatusText(http.StatusOK) rsp.Message = "成功" rsp.Err = err return rsp, err } ReturnPoint: if err != nil { rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Err = err rsp.Message = "失败" } return rsp, err } func (rp *repo) EditTrainingTask(ctx context.Context, req proto.TrainingTaskItemRequest) (rsp *proto.BaseResponse, err error) { rsp = new(proto.BaseResponse) select { case <-ctx.Done(): err = fmt.Errorf("超时/取消") rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Message = "超时/取消" rsp.Err = ctx.Err() return rsp, ctx.Err() default: var ( h bool ) item := new(model.TrainTask) h, err = rp.engine.ID(req.TaskId).Get(item) if err != nil { goto ReturnPoint } if !h { err = fmt.Errorf("为找到对应任务") goto ReturnPoint } item.TaskName = req.ModelName item.TrainDatasetId = req.DatasetId item.CategoryId = req.BizType item.TaskDesc = req.ModelDesc item.Arithmetic = req.Arithmetic item.ImageSize = req.ImageSize item.BatchSize = req.BatchSize item.EpochsSize = req.EpochsSize item.OutputType = req.OutputType item.StartTime = time.Now().Unix() item.Status = 2 _, err = rp.engine.ID(item.TaskId).AllCols().Update(item) if err != nil { goto ReturnPoint } payload := make(map[string]interface{}) payload["taskId"] = item.TaskId payload["taskName"] = item.TaskName payload["trainDatasetId"] = item.TrainDatasetId payload["arithmetic"] = item.Arithmetic payload["imageSize"] = item.ImageSize payload["batchSize"] = item.BatchSize payload["epochsSize"] = item.EpochsSize payload["outputType"] = item.OutputType payload["testSize"] = item.OutputType mqClient := mq.GetMqClient("task-request", 1) mqPayload := &mq.InstructionReq{ Command: mq.TrainTaskAdd, Payload: payload, } pData, _ := json.Marshal(mqPayload) err = mq.GenerateAndSendData(mqClient.EndPoint.(hpds_node.AccessPoint), pData, rp.logger) rsp.Code = http.StatusOK rsp.Status = http.StatusText(http.StatusOK) rsp.Message = "成功" rsp.Err = err return rsp, err } ReturnPoint: if err != nil { rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Err = err rsp.Message = "失败" } return rsp, err } func (rp *repo) ReRunTrainingTask(ctx context.Context, req proto.TrainingTaskItemRequest) (rsp *proto.BaseResponse, err error) { rsp = new(proto.BaseResponse) select { case <-ctx.Done(): err = fmt.Errorf("超时/取消") rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Message = "超时/取消" rsp.Err = ctx.Err() return rsp, ctx.Err() default: var ( h bool ) item := new(model.TrainTask) h, err = rp.engine.ID(req.TaskId).Get(item) if err != nil { goto ReturnPoint } if !h { err = fmt.Errorf("为找到对应任务") goto ReturnPoint } item.StartTime = time.Now().Unix() item.Status = 2 _, err = rp.engine.ID(item.TaskId).AllCols().Update(item) if err != nil { goto ReturnPoint } payload := make(map[string]interface{}) payload["taskId"] = item.TaskId payload["taskName"] = item.TaskName payload["trainDatasetId"] = item.TrainDatasetId payload["arithmetic"] = item.Arithmetic payload["imageSize"] = item.ImageSize payload["batchSize"] = item.BatchSize payload["epochsSize"] = item.EpochsSize payload["outputType"] = item.OutputType payload["testSize"] = item.OutputType mqClient := mq.GetMqClient("task-request", 1) mqPayload := &mq.InstructionReq{ Command: mq.TrainTaskAdd, Payload: payload, } pData, _ := json.Marshal(mqPayload) err = mq.GenerateAndSendData(mqClient.EndPoint.(hpds_node.AccessPoint), pData, rp.logger) rsp.Code = http.StatusOK rsp.Status = http.StatusText(http.StatusOK) rsp.Message = "成功" rsp.Err = err return rsp, err } ReturnPoint: if err != nil { rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Err = err rsp.Message = "失败" } return rsp, err }