// File: hpds_jkw_web/internal/service/task.go

package service
import (
	"bytes"
	"context"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"image"
	"net/http"
	"strconv"
	"time"

	"git.hpds.cc/Component/logging"
	"git.hpds.cc/pavement/hpds_node"
	"hpds-iot-web/internal/proto"
	"hpds-iot-web/model"
	"hpds-iot-web/mq"
	"xorm.io/xorm"
)
type TaskService interface {
TaskList(ctx context.Context, req proto.TaskRequest) (rsp *proto.BaseResponse, err error)
TaskInfo(ctx context.Context, req proto.TaskItemRequest) (rsp *proto.BaseResponse, err error)
2023-03-23 18:03:09 +08:00
AddTask(ctx context.Context, req proto.TaskItemRequest) (rsp *proto.BaseResponse, err error)
ReRunTask(ctx context.Context, req proto.TaskItemRequest) (rsp *proto.BaseResponse, err error)
2023-03-23 18:03:09 +08:00
//EditTask(ctx context.Context, req proto.ModelItemRequest) (rsp *proto.BaseResponse, err error)
TaskResult(ctx context.Context, req proto.ReportRequest) (rsp *proto.BaseResponse, err error)
2023-05-14 18:23:12 +08:00
TaskLog(ctx context.Context, req proto.TaskLogItem) (rsp *proto.BaseResponse, err error)
TrainingTaskList(ctx context.Context, req proto.TaskRequest) (rsp *proto.BaseResponse, err error)
TrainingTaskInfo(ctx context.Context, req proto.TaskItemRequest) (rsp *proto.BaseResponse, err error)
TrainingTaskLog(ctx context.Context, req proto.TaskItemRequest) (rsp *proto.BaseResponse, err error)
2023-06-17 09:38:26 +08:00
TrainingTaskResult(ctx context.Context, req proto.TaskItemRequest) (rsp *proto.BaseResponse, err error)
CreateTrainingTask(ctx context.Context, req proto.TrainingTaskItemRequest) (rsp *proto.BaseResponse, err error)
EditTrainingTask(ctx context.Context, req proto.TrainingTaskItemRequest) (rsp *proto.BaseResponse, err error)
ReRunTrainingTask(ctx context.Context, req proto.TrainingTaskItemRequest) (rsp *proto.BaseResponse, err error)
2023-03-23 18:03:09 +08:00
}
// NewTaskService builds the TaskService backed by repo.
// engine is used for all database queries. logger is handed to the MQ send
// helpers.
func NewTaskService(engine *xorm.Engine, logger *logging.Logger) TaskService {
	svc := &repo{logger: logger, engine: engine}
	return svc
}
func (rp *repo) TaskList(ctx context.Context, req proto.TaskRequest) (rsp *proto.BaseResponse, err error) {
rsp = new(proto.BaseResponse)
select {
case <-ctx.Done():
err = fmt.Errorf("超时/取消")
rsp.Code = http.StatusInternalServerError
rsp.Status = http.StatusText(http.StatusInternalServerError)
rsp.Message = "超时/取消"
rsp.Err = ctx.Err()
return rsp, ctx.Err()
default:
data := make([]proto.TaskDetail, 0)
count, err := rp.engine.Table("task").Alias("t").
Join("inner", []string{"model", "m"}, "t.model_id = m.model_id").
Join("inner", []string{"node", "n"}, "t.node_id = n.node_id").
Cols("t.*", "m.model_name", "n.node_name").
Where("(? = 0 or m.biz_type = ?)", req.BizType, req.BizType).
And("(?='' or task_name like ?)", req.TaskName, "%"+req.TaskName+"%").
And("t.start_time >= unix_timestamp(?)", req.StartTime).
And("? = 0 or t.start_time <= unix_timestamp(?)", req.FinishTime, req.FinishTime).
And("t.status > 0").Limit(int(req.Size), int(((req.Page)-1)*req.Size)).
Desc("start_time").
2023-03-23 18:03:09 +08:00
FindAndCount(&data)
if err != nil {
goto ReturnPoint
}
rsp.Code = http.StatusOK
rsp.Status = http.StatusText(http.StatusOK)
rsp.Message = "成功"
rsp = FillPaging(count, req.Page, req.Size, data, rsp)
rsp.Err = err
return rsp, err
}
ReturnPoint:
if err != nil {
rsp.Code = http.StatusInternalServerError
rsp.Status = http.StatusText(http.StatusInternalServerError)
rsp.Err = err
rsp.Message = "失败"
}
return rsp, err
}
func (rp *repo) AddTask(ctx context.Context, req proto.TaskItemRequest) (rsp *proto.BaseResponse, err error) {
rsp = new(proto.BaseResponse)
select {
case <-ctx.Done():
err = fmt.Errorf("超时/取消")
rsp.Code = http.StatusInternalServerError
rsp.Status = http.StatusText(http.StatusInternalServerError)
rsp.Message = "超时/取消"
rsp.Err = ctx.Err()
return rsp, ctx.Err()
default:
var h bool
m := new(model.Model)
h, err = rp.engine.ID(req.ModelId).Get(m)
if err != nil {
goto ReturnPoint
}
if !h {
err = fmt.Errorf("未能找到对应的模型")
goto ReturnPoint
}
ds := new(model.Dataset)
h, err = rp.engine.ID(req.DatasetArr).Get(ds)
2023-03-23 18:03:09 +08:00
if err != nil {
goto ReturnPoint
}
if !h {
err = fmt.Errorf("未能找到对应的数据集")
goto ReturnPoint
}
var node *model.Node
if req.NodeId > 0 {
node = new(model.Node)
h, err = rp.engine.ID(req.NodeId).Get(node)
if err != nil {
goto ReturnPoint
}
if !h {
err = fmt.Errorf("未能找到对应的节点")
goto ReturnPoint
}
}
2023-03-23 18:03:09 +08:00
item := &model.Task{
ModelId: req.ModelId,
NodeId: req.NodeId,
TaskName: req.TaskName,
TaskDesc: req.TaskDesc,
DatasetArr: fmt.Sprintf("%d", req.DatasetArr),
SubDataset: req.SubDataset,
SubDataTag: req.SubDataTag,
AppointmentTime: req.AppointmentTime,
2023-05-14 18:23:12 +08:00
Status: 2,
2023-03-23 18:03:09 +08:00
CreateAt: time.Now().Unix(),
UpdateAt: time.Now().Unix(),
}
if len(req.AppointmentTime) > 0 {
var appTime time.Time
appTime, err = time.ParseInLocation("2006-01-02 15:04:05", req.AppointmentTime, time.Local)
if err != nil {
err = fmt.Errorf("时间格式不匹配")
goto ReturnPoint
}
item.StartTime = appTime.Unix()
2023-05-14 18:23:12 +08:00
item.Status = 1
2023-03-23 18:03:09 +08:00
} else {
item.StartTime = time.Now().Unix()
2023-05-14 18:23:12 +08:00
item.Status = 2
2023-03-23 18:03:09 +08:00
}
_, err = rp.engine.Insert(item)
if err != nil {
goto ReturnPoint
}
//reg, _ := regexp.Compile("\\[.*?\\]")
//if ok := reg.FindAll([]byte(item.ResultStorage), 2); len(ok) > 0 {
// item.ResultStorage = reg.ReplaceAllString(item.ResultStorage, fmt.Sprintf("%d_%d", item.TaskId, item.ModelId))
// _, err = rp.engine.ID(item.TaskId).Cols("result_storage").Update(item)
// if err != nil {
// goto ReturnPoint
// }
//}
payload := make(map[string]interface{})
payload["taskId"] = item.TaskId
payload["modelId"] = item.ModelId
payload["modelVersion"] = m.ModelVersion
payload["modelCommand"] = m.ModelCommand
payload["nodeId"] = item.NodeId
if item.NodeId > 0 && node != nil {
payload["nodeGuid"] = node.NodeGuid
}
2023-03-23 18:03:09 +08:00
payload["inPath"] = m.InPath
payload["outPath"] = m.OutPath
payload["httpUrl"] = m.HttpUrl
payload["datasetArr"] = item.DatasetArr
payload["datasetPath"] = ds.StoreName
payload["datasetName"] = ds.DatasetName
payload["subDataset"] = item.SubDataset
payload["subDataTag"] = item.SubDataTag
payload["workflow"] = m.Workflow
issue := new(model.IssueModel)
h, _ = model.DB.Where("model_id=? and node_id =?", req.ModelId, item.NodeId).Get(issue)
if h {
payload["issueResult"] = issue.IssueResult
}
2023-03-23 18:03:09 +08:00
mqClient := mq.GetMqClient("task-request", 1)
mqPayload := &mq.InstructionReq{
Command: mq.TaskAdd,
Payload: payload,
}
pData, _ := json.Marshal(mqPayload)
err = mq.GenerateAndSendData(mqClient.EndPoint.(hpds_node.AccessPoint), pData, rp.logger)
rsp.Code = http.StatusOK
rsp.Status = http.StatusText(http.StatusOK)
rsp.Message = "新增任务成功"
rsp.Err = ctx.Err()
rsp.Data = item
return rsp, err
}
ReturnPoint:
if err != nil {
rsp.Code = http.StatusInternalServerError
rsp.Status = http.StatusText(http.StatusInternalServerError)
rsp.Err = err
rsp.Message = "失败"
}
return rsp, err
}
// ReRunTask re-publishes an existing inference task (req.TaskId) as a
// mq.TaskAdd instruction through the "task-request" MQ client.
//
// The model, dataset and node are re-read from the stored task, so only
// req.TaskId is needed from the request. The task row itself is not modified.
// The task is returned in rsp.Data on success.
func (rp *repo) ReRunTask(ctx context.Context, req proto.TaskItemRequest) (rsp *proto.BaseResponse, err error) {
	rsp = new(proto.BaseResponse)
	select {
	case <-ctx.Done():
		err = fmt.Errorf("超时/取消")
		rsp.Code = http.StatusInternalServerError
		rsp.Status = http.StatusText(http.StatusInternalServerError)
		rsp.Message = "超时/取消"
		rsp.Err = ctx.Err()
		return rsp, ctx.Err()
	default:
		item := new(model.Task)
		var h bool
		h, err = rp.engine.ID(req.TaskId).Get(item)
		if err != nil {
			goto ReturnPoint
		}
		if !h {
			err = fmt.Errorf("未能找到任务")
			goto ReturnPoint
		}
		m := new(model.Model)
		h, err = rp.engine.ID(item.ModelId).Get(m)
		if err != nil {
			goto ReturnPoint
		}
		if !h {
			err = fmt.Errorf("未能找到对应的模型")
			goto ReturnPoint
		}
		ds := new(model.Dataset)
		// item.DatasetArr holds the dataset ID stored by AddTask.
		h, err = rp.engine.ID(item.DatasetArr).Get(ds)
		if err != nil {
			goto ReturnPoint
		}
		if !h {
			err = fmt.Errorf("未能找到对应的数据集")
			goto ReturnPoint
		}
		var node *model.Node
		if item.NodeId > 0 {
			node = new(model.Node)
			h, err = rp.engine.ID(item.NodeId).Get(node)
			if err != nil {
				goto ReturnPoint
			}
			if !h {
				err = fmt.Errorf("未能找到对应的节点")
				goto ReturnPoint
			}
		}
		payload := map[string]interface{}{
			"taskId":       item.TaskId,
			"modelId":      item.ModelId,
			"modelVersion": m.ModelVersion,
			"modelCommand": m.ModelCommand,
			"nodeId":       item.NodeId,
			"inPath":       m.InPath,
			"outPath":      m.OutPath,
			"httpUrl":      m.HttpUrl,
			"datasetArr":   item.DatasetArr,
			"datasetPath":  ds.StoreName,
			"datasetName":  ds.DatasetName,
			"subDataset":   item.SubDataset,
			"subDataTag":   item.SubDataTag,
			"workflow":     m.Workflow,
		}
		if node != nil {
			payload["nodeGuid"] = node.NodeGuid
		}
		// Look the issue up by the stored task's model. The rerun request
		// carries only the task ID, so req.ModelId is not reliable here.
		// Best effort: a lookup error is treated like "not found".
		issue := new(model.IssueModel)
		h, _ = model.DB.Where("model_id=? and node_id =?", item.ModelId, item.NodeId).Get(issue)
		if h {
			payload["issueResult"] = issue.IssueResult
		}
		mqClient := mq.GetMqClient("task-request", 1)
		mqPayload := &mq.InstructionReq{
			Command: mq.TaskAdd,
			Payload: payload,
		}
		var pData []byte
		if pData, err = json.Marshal(mqPayload); err != nil {
			goto ReturnPoint
		}
		ap, ok := mqClient.EndPoint.(hpds_node.AccessPoint)
		if !ok {
			err = fmt.Errorf("消息队列客户端不可用")
			goto ReturnPoint
		}
		if err = mq.GenerateAndSendData(ap, pData, rp.logger); err != nil {
			goto ReturnPoint
		}
		rsp.Code = http.StatusOK
		rsp.Status = http.StatusText(http.StatusOK)
		rsp.Message = "重新任务成功"
		rsp.Err = err
		rsp.Data = item
		return rsp, err
	}
ReturnPoint:
	if err != nil {
		rsp.Code = http.StatusInternalServerError
		rsp.Status = http.StatusText(http.StatusInternalServerError)
		rsp.Err = err
		rsp.Message = "失败"
	}
	return rsp, err
}
// TaskInfo loads the inference task identified by req.TaskId and returns it
// as rsp.Data. A missing task and a database error both yield a 500 response
// together with the error.
func (rp *repo) TaskInfo(ctx context.Context, req proto.TaskItemRequest) (rsp *proto.BaseResponse, err error) {
	rsp = new(proto.BaseResponse)
	select {
	case <-ctx.Done():
		rsp.Code = http.StatusInternalServerError
		rsp.Status = http.StatusText(http.StatusInternalServerError)
		rsp.Message = "超时/取消"
		rsp.Err = ctx.Err()
		return rsp, ctx.Err()
	default:
	}

	// fail marks the response as a server error carrying cause.
	fail := func(cause error) (*proto.BaseResponse, error) {
		rsp.Code = http.StatusInternalServerError
		rsp.Status = http.StatusText(http.StatusInternalServerError)
		rsp.Err = cause
		rsp.Message = "失败"
		return rsp, cause
	}

	task := new(model.Task)
	found, dbErr := rp.engine.ID(req.TaskId).Get(task)
	switch {
	case dbErr != nil:
		return fail(dbErr)
	case !found:
		return fail(fmt.Errorf("未能找到对应的任务"))
	}

	rsp.Code = http.StatusOK
	rsp.Status = http.StatusText(http.StatusOK)
	rsp.Message = "成功"
	rsp.Err = nil
	rsp.Data = task
	return rsp, nil
}
func (rp *repo) TaskResult(ctx context.Context, req proto.ReportRequest) (rsp *proto.BaseResponse, err error) {
rsp = new(proto.BaseResponse)
select {
case <-ctx.Done():
err = fmt.Errorf("超时/取消")
rsp.Code = http.StatusInternalServerError
rsp.Status = http.StatusText(http.StatusInternalServerError)
rsp.Message = "超时/取消"
rsp.Err = ctx.Err()
return rsp, ctx.Err()
default:
taskResultList := make([]model.TaskResult, 0)
err = rp.engine.Where("task_id = ?", req.TaskId).
Limit(int(req.Size), int(((req.Page)-1)*req.Size)).
Find(&taskResultList)
if err != nil {
err = fmt.Errorf("未能找到对应的结果")
goto ReturnPoint
}
list := make([]proto.TaskResultItem, 0)
for _, v := range taskResultList {
var h bool
file := new(model.FileManager)
h, err = rp.engine.ID(v.FileId).Get(file)
if err != nil || !h {
continue
}
md := new(model.Model)
h, err = rp.engine.ID(v.ModelId).Get(md)
if err != nil || !h {
continue
}
var (
mr mq.ModelResult
mrList []string
fileDiscern string
memo string
diseaseType int64
diseaseLevel int
length float64
area float64
width float64
diseaseTypeName string
diseaseLevelName string
)
if len(v.Result) > 0 && v.Result[0] == '[' {
mrList = make([]string, 0)
if err := json.Unmarshal([]byte(v.Result), &mrList); err != nil {
continue
}
for _, str := range mrList {
if err := json.Unmarshal([]byte(str), &mr); err != nil {
continue
}
switch md.BizType {
case 1: //道路
case 2: //桥梁
case 3: //隧道
}
switch mr.Code {
case 0: //轻量化模型返回
lr := new(mq.LightweightResult)
if err := json.Unmarshal([]byte(v.Result), lr); err != nil {
continue
}
//for _, val := range lrList {
if lr.Crack || lr.Pothole {
if lr.Crack {
memo = "检测到裂缝"
} else {
memo = "检测到坑洼"
}
fileDiscern = lr.ImgDiscern
diseaseLevel = 3
diseaseLevelName = "重度"
switch md.BizType {
case 2:
diseaseType = 8
diseaseTypeName = "结构裂缝"
case 3:
diseaseType = 15
diseaseTypeName = "衬砌裂缝"
default:
diseaseType = 4
diseaseTypeName = "横向裂缝"
}
}
fn, _ := base64.StdEncoding.DecodeString(fileDiscern)
buff := bytes.NewBuffer(fn)
_, imgType, _ := image.Decode(buff)
if len(fileDiscern) == 0 {
fileDiscern = lr.ImgSrc
}
fileDiscern = fmt.Sprintf("data:image/%s;base64,%s", imgType, fileDiscern)
item := proto.TaskResultItem{
FileId: v.FileId,
FileName: v.SrcPath,
SrcFile: file.AccessUrl,
DistFile: fileDiscern,
DiseaseType: int(diseaseType),
DiseaseTypeName: diseaseTypeName,
DiseaseLevel: diseaseLevel,
DiseaseLevelName: diseaseLevelName,
KPile: "",
UpDown: 0,
LineNum: 0,
Length: length,
Width: width,
Area: area,
HorizontalPositions: 0,
Memo: memo,
Stat: false,
}
list = append(list, item)
//}
case 2000:
ir := new(mq.InsigmaResult)
if err := json.Unmarshal([]byte(str), &ir); err != nil {
continue
}
fileDiscern = ir.Image
fn, _ := base64.StdEncoding.DecodeString(fileDiscern)
buff := bytes.NewBuffer(fn)
_, imgType, _ := image.Decode(buff)
fileDiscern = fmt.Sprintf("data:image/%s;base64,%s", imgType, fileDiscern)
item := proto.TaskResultItem{
FileId: v.FileId,
FileName: v.SrcPath,
SrcFile: file.AccessUrl,
DistFile: fileDiscern,
DiseaseType: int(diseaseType),
DiseaseTypeName: diseaseTypeName,
DiseaseLevel: diseaseLevel,
DiseaseLevelName: diseaseLevelName,
KPile: "",
UpDown: 0,
LineNum: 0,
Length: length,
Width: width,
Area: area,
HorizontalPositions: 0,
Memo: memo,
Stat: false,
}
list = append(list, item)
case 2001:
ir := new(mq.InsigmaResult)
if err := json.Unmarshal([]byte(str), &ir); err != nil {
continue
}
fileDiscern = ir.Image
2023-06-17 09:38:26 +08:00
for key, value := range ir.Diseases {
diseaseType = model.GetDiseaseType(value.Type, md.BizType)
if len(value.Param.MaxWidth) > 0 && width == 0 {
width, _ = strconv.ParseFloat(value.Param.MaxWidth, 64)
}
length = value.Param.Length
area = value.Param.Area
diseaseLevelName = value.Level
diseaseTypeName = value.Type
switch value.Level {
case "重度":
diseaseLevel = 3
case "中度":
diseaseLevel = 2
case "轻度":
diseaseLevel = 1
}
2023-06-17 09:38:26 +08:00
memo += fmt.Sprintf("%d. 发现[%s],等级[%s],长度[%f],最大宽度[%s],面积[%f];\n", key+1, value.Type, value.Level, value.Param.Length, value.Param.MaxWidth, value.Param.Area)
}
fn, _ := base64.StdEncoding.DecodeString(fileDiscern)
buff := bytes.NewBuffer(fn)
_, imgType, _ := image.Decode(buff)
fileDiscern = fmt.Sprintf("data:image/%s;base64,%s", imgType, fileDiscern)
item := proto.TaskResultItem{
FileId: v.FileId,
FileName: v.SrcPath,
SrcFile: file.AccessUrl,
DistFile: fileDiscern,
DiseaseType: int(diseaseType),
DiseaseTypeName: diseaseTypeName,
DiseaseLevel: diseaseLevel,
DiseaseLevelName: diseaseLevelName,
KPile: "",
UpDown: 0,
LineNum: 0,
Length: length,
Width: width,
Area: area,
HorizontalPositions: 0,
Memo: memo,
Stat: false,
}
list = append(list, item)
}
}
} else {
if err := json.Unmarshal([]byte(v.Result), &mr); err != nil {
continue
}
switch mr.Code {
case 0: //轻量化模型返回
lr := new(mq.LightweightResult)
if err := json.Unmarshal([]byte(v.Result), &lr); err != nil {
continue
}
if lr.Crack || lr.Pothole {
if lr.Crack {
memo = "检测到裂缝"
} else {
memo = "检测到坑洼"
}
fileDiscern = lr.ImgDiscern
if len(fileDiscern) == 0 {
fileDiscern = lr.ImgSrc
}
diseaseLevel = 3
diseaseLevelName = "重度"
switch md.BizType {
case 2:
diseaseType = 8
diseaseTypeName = "结构裂缝"
case 3:
diseaseType = 15
diseaseTypeName = "衬砌裂缝"
default:
diseaseType = 4
diseaseTypeName = "横向裂缝"
}
} else {
fileDiscern = lr.ImgSrc
}
//
case 2000: //网新返回没有病害
ir := new(mq.InsigmaResult)
if err := json.Unmarshal([]byte(v.Result), &ir); err != nil {
continue
}
fileDiscern = ir.Image
case 2001: //网新返回有病害
ir := new(mq.InsigmaResult)
if err := json.Unmarshal([]byte(v.Result), &ir); err != nil {
continue
}
fileDiscern = ir.Image
for _, val := range ir.Diseases {
diseaseType = model.GetDiseaseType(val.Type, md.BizType)
if len(val.Param.MaxWidth) > 0 && width == 0 {
width, _ = strconv.ParseFloat(val.Param.MaxWidth, 64)
}
length = val.Param.Length
area = val.Param.Area
diseaseLevelName = val.Level
diseaseTypeName = val.Type
switch val.Level {
case "重度":
diseaseLevel = 3
case "中度":
diseaseLevel = 2
case "轻度":
diseaseLevel = 1
}
memo += fmt.Sprintf("发现[%s],等级[%s],长度[%f],最大宽度[%s],面积[%f];\n", val.Type, val.Level, val.Param.Length, val.Param.MaxWidth, val.Param.Area)
}
}
fn, _ := base64.StdEncoding.DecodeString(fileDiscern)
buff := bytes.NewBuffer(fn)
_, imgType, _ := image.Decode(buff)
fileDiscern = fmt.Sprintf("data:image/%s;base64,%s", imgType, fileDiscern)
item := proto.TaskResultItem{
FileId: v.FileId,
FileName: v.SrcPath,
SrcFile: file.AccessUrl,
DistFile: fileDiscern,
DiseaseType: int(diseaseType),
DiseaseTypeName: diseaseTypeName,
DiseaseLevel: diseaseLevel,
DiseaseLevelName: diseaseLevelName,
KPile: "",
UpDown: 0,
LineNum: 0,
Length: length,
Width: width,
Area: area,
HorizontalPositions: 0,
Memo: memo,
Stat: false,
}
list = append(list, item)
}
}
rsp.Code = http.StatusOK
rsp.Status = http.StatusText(http.StatusOK)
rsp.Message = "成功"
rsp.Err = err
rsp.Data = list
return rsp, err
}
ReturnPoint:
if err != nil {
rsp.Code = http.StatusInternalServerError
rsp.Status = http.StatusText(http.StatusInternalServerError)
rsp.Err = err
rsp.Message = "失败"
}
return rsp, err
}
2023-05-14 18:23:12 +08:00
func (rp *repo) TaskLog(ctx context.Context, req proto.TaskLogItem) (rsp *proto.BaseResponse, err error) {
rsp = new(proto.BaseResponse)
select {
case <-ctx.Done():
err = fmt.Errorf("超时/取消")
rsp.Code = http.StatusInternalServerError
rsp.Status = http.StatusText(http.StatusInternalServerError)
rsp.Message = "超时/取消"
rsp.Err = ctx.Err()
return rsp, ctx.Err()
default:
list := make([]model.TaskLog, 0)
2023-05-14 18:23:12 +08:00
err = rp.engine.Where("task_id = ?", req.TaskId).And("task_log_id>?", req.LogId).Find(&list)
if err != nil {
goto ReturnPoint
}
rsp.Code = http.StatusOK
rsp.Status = http.StatusText(http.StatusOK)
rsp.Message = "成功"
rsp.Err = err
rsp.Data = list
return rsp, err
}
ReturnPoint:
if err != nil {
rsp.Code = http.StatusInternalServerError
rsp.Status = http.StatusText(http.StatusInternalServerError)
rsp.Err = err
rsp.Message = "失败"
}
return rsp, err
}
// TrainingTaskList returns one page (req.Page, req.Size) of training tasks.
// The category filter is req.BizType (0 disables it). The name filter is a
// substring match on req.TaskName ("" disables it). The paging metadata is
// filled in by FillPaging.
func (rp *repo) TrainingTaskList(ctx context.Context, req proto.TaskRequest) (rsp *proto.BaseResponse, err error) {
	rsp = new(proto.BaseResponse)
	select {
	case <-ctx.Done():
		rsp.Code = http.StatusInternalServerError
		rsp.Status = http.StatusText(http.StatusInternalServerError)
		rsp.Message = "超时/取消"
		rsp.Err = ctx.Err()
		return rsp, ctx.Err()
	default:
	}

	tasks := make([]model.TrainTask, 0)
	offset := int((req.Page - 1) * req.Size)
	total, queryErr := rp.engine.
		Where("(? = 0 or category_id = ?)", req.BizType, req.BizType).
		And("(? = '' or task_name like ?)", req.TaskName, "%"+req.TaskName+"%").
		Limit(int(req.Size), offset).
		FindAndCount(&tasks)
	if queryErr != nil {
		rsp.Code = http.StatusInternalServerError
		rsp.Status = http.StatusText(http.StatusInternalServerError)
		rsp.Err = queryErr
		rsp.Message = "失败"
		return rsp, queryErr
	}

	rsp.Code = http.StatusOK
	rsp.Status = http.StatusText(http.StatusOK)
	rsp.Message = "成功"
	rsp.Err = nil
	return FillPaging(total, req.Page, req.Size, tasks, rsp), nil
}
func (rp *repo) TrainingTaskInfo(ctx context.Context, req proto.TaskItemRequest) (rsp *proto.BaseResponse, err error) {
rsp = new(proto.BaseResponse)
select {
case <-ctx.Done():
err = fmt.Errorf("超时/取消")
rsp.Code = http.StatusInternalServerError
rsp.Status = http.StatusText(http.StatusInternalServerError)
rsp.Message = "超时/取消"
rsp.Err = ctx.Err()
return rsp, ctx.Err()
default:
var (
h bool
)
2023-06-17 09:38:26 +08:00
task := new(model.TrainTask)
h, err = rp.engine.ID(req.TaskId).Get(task)
if err != nil {
goto ReturnPoint
}
if !h {
err = fmt.Errorf("未能找到对应的任务信息")
goto ReturnPoint
}
2023-06-17 09:38:26 +08:00
item := new(proto.TrainTaskInfoItem)
item.TaskId = task.TaskId
item.TrainDatasetId = task.TrainDatasetId
item.TrainDatasetName = model.GetTrainDatasetName(task.TrainDatasetId)
item.CategoryId = task.CategoryId
item.CategoryName = model.GetBizType(task.CategoryId)
item.TaskName = task.TaskName
item.TaskDesc = task.TaskDesc
item.Arithmetic = task.Arithmetic
item.ImageSize = task.ImageSize
item.BatchSize = task.BatchSize
item.EpochsSize = task.EpochsSize
item.OutputType = task.OutputType
item.StartTime = task.StartTime
item.FinishTime = task.FinishTime
item.Loss = task.Loss
item.Accuracy = task.Accuracy
item.ModelFilePath = task.ModelFilePath
item.ModelFileMetricsPath = task.ModelFileMetricsPath
item.Status = task.Status
item.CreateAt = task.CreateAt
item.UpdateAt = task.UpdateAt
rsp.Code = http.StatusOK
rsp.Status = http.StatusText(http.StatusOK)
rsp.Message = "成功"
rsp.Err = err
rsp.Data = item
return rsp, err
}
ReturnPoint:
if err != nil {
rsp.Code = http.StatusInternalServerError
rsp.Status = http.StatusText(http.StatusInternalServerError)
rsp.Err = err
rsp.Message = "失败"
}
return rsp, err
}
// TrainingTaskLog returns every log row of training task req.TaskId, ordered
// by epoch ascending. The rows are wrapped with FillPaging as a single page
// (page 1, size 1000) whose total is the number of rows found.
func (rp *repo) TrainingTaskLog(ctx context.Context, req proto.TaskItemRequest) (rsp *proto.BaseResponse, err error) {
	rsp = new(proto.BaseResponse)
	select {
	case <-ctx.Done():
		rsp.Code = http.StatusInternalServerError
		rsp.Status = http.StatusText(http.StatusInternalServerError)
		rsp.Message = "超时/取消"
		rsp.Err = ctx.Err()
		return rsp, ctx.Err()
	default:
	}

	epochs := make([]model.TrainTaskLog, 0)
	if findErr := rp.engine.Where("task_id = ?", req.TaskId).Asc("epoch").Find(&epochs); findErr != nil {
		rsp.Code = http.StatusInternalServerError
		rsp.Status = http.StatusText(http.StatusInternalServerError)
		rsp.Err = findErr
		rsp.Message = "失败"
		return rsp, findErr
	}

	rsp.Code = http.StatusOK
	rsp.Status = http.StatusText(http.StatusOK)
	rsp.Message = "成功"
	rsp.Err = nil
	total := int64(len(epochs))
	return FillPaging(total, 1, 1000, epochs, rsp), nil
}
2023-06-17 09:38:26 +08:00
func (rp *repo) TrainingTaskResult(ctx context.Context, req proto.TaskItemRequest) (rsp *proto.BaseResponse, err error) {
rsp = new(proto.BaseResponse)
select {
case <-ctx.Done():
err = fmt.Errorf("超时/取消")
rsp.Code = http.StatusInternalServerError
rsp.Status = http.StatusText(http.StatusInternalServerError)
rsp.Message = "超时/取消"
rsp.Err = ctx.Err()
return rsp, ctx.Err()
default:
var (
h bool
)
item := new(model.TrainTaskResult)
h, err = rp.engine.Where("task_id = ?", req.TaskId).Desc("create_at").Get(item)
if err != nil {
goto ReturnPoint
}
if !h {
err = fmt.Errorf("任务还未训练完成")
goto ReturnPoint
}
rsp.Code = http.StatusOK
rsp.Status = http.StatusText(http.StatusOK)
rsp.Message = "成功"
rsp.Err = err
rsp.Data = item
return rsp, err
}
ReturnPoint:
if err != nil {
rsp.Code = http.StatusInternalServerError
rsp.Status = http.StatusText(http.StatusInternalServerError)
rsp.Err = err
rsp.Message = "失败"
}
return rsp, err
}
// CreateTrainingTask inserts a training task named req.ModelName and publishes
// it as a mq.TrainTaskAdd instruction through the "task-request" MQ client.
//
// A task with the same name must not already exist. The new task gets
// StartTime = now and Status = 2.
//
// Failures of the duplicate check, the insert, payload encoding, or MQ
// publishing produce a 500 response and are returned. A publishing failure is
// reported after the task row has already been inserted.
func (rp *repo) CreateTrainingTask(ctx context.Context, req proto.TrainingTaskItemRequest) (rsp *proto.BaseResponse, err error) {
	rsp = new(proto.BaseResponse)
	select {
	case <-ctx.Done():
		err = fmt.Errorf("超时/取消")
		rsp.Code = http.StatusInternalServerError
		rsp.Status = http.StatusText(http.StatusInternalServerError)
		rsp.Message = "超时/取消"
		rsp.Err = ctx.Err()
		return rsp, ctx.Err()
	default:
		var h bool
		item := new(model.TrainTask)
		h, err = rp.engine.Where("task_name=?", req.ModelName).Get(item)
		if err != nil {
			goto ReturnPoint
		}
		if h {
			err = fmt.Errorf("已存在同名任务,请修改后继续")
			goto ReturnPoint
		}
		item.TaskName = req.ModelName
		item.TrainDatasetId = req.DatasetId
		item.CategoryId = req.BizType
		item.TaskDesc = req.ModelDesc
		item.Arithmetic = req.Arithmetic
		item.ImageSize = req.ImageSize
		item.BatchSize = req.BatchSize
		item.EpochsSize = req.EpochsSize
		item.OutputType = req.OutputType
		item.StartTime = time.Now().Unix()
		item.Status = 2
		_, err = rp.engine.Insert(item)
		if err != nil {
			goto ReturnPoint
		}
		payload := map[string]interface{}{
			"taskId":         item.TaskId,
			"taskName":       item.TaskName,
			"trainDatasetId": item.TrainDatasetId,
			"arithmetic":     item.Arithmetic,
			"imageSize":      item.ImageSize,
			"batchSize":      item.BatchSize,
			"epochsSize":     item.EpochsSize,
			"outputType":     item.OutputType,
			// NOTE(review): testSize reuses OutputType. This looks like a
			// copy-paste; confirm what the training consumer expects here.
			"testSize": item.OutputType,
		}
		mqClient := mq.GetMqClient("task-request", 1)
		mqPayload := &mq.InstructionReq{
			Command: mq.TrainTaskAdd,
			Payload: payload,
		}
		var pData []byte
		if pData, err = json.Marshal(mqPayload); err != nil {
			goto ReturnPoint
		}
		// Checked assertion: report a misconfigured endpoint instead of panicking.
		ap, ok := mqClient.EndPoint.(hpds_node.AccessPoint)
		if !ok {
			err = fmt.Errorf("消息队列客户端不可用")
			goto ReturnPoint
		}
		if err = mq.GenerateAndSendData(ap, pData, rp.logger); err != nil {
			goto ReturnPoint
		}
		rsp.Code = http.StatusOK
		rsp.Status = http.StatusText(http.StatusOK)
		rsp.Message = "成功"
		rsp.Err = err
		return rsp, err
	}
ReturnPoint:
	if err != nil {
		rsp.Code = http.StatusInternalServerError
		rsp.Status = http.StatusText(http.StatusInternalServerError)
		rsp.Err = err
		rsp.Message = "失败"
	}
	return rsp, err
}
// EditTrainingTask overwrites the editable fields of training task req.TaskId
// from req. It then resets StartTime to now and Status to 2, saves all columns,
// and re-publishes the task as a mq.TrainTaskAdd instruction through the
// "task-request" MQ client.
//
// Failures of the lookup, the update, payload encoding, or MQ publishing
// produce a 500 response and are returned. A publishing failure is reported
// after the row has already been updated.
func (rp *repo) EditTrainingTask(ctx context.Context, req proto.TrainingTaskItemRequest) (rsp *proto.BaseResponse, err error) {
	rsp = new(proto.BaseResponse)
	select {
	case <-ctx.Done():
		err = fmt.Errorf("超时/取消")
		rsp.Code = http.StatusInternalServerError
		rsp.Status = http.StatusText(http.StatusInternalServerError)
		rsp.Message = "超时/取消"
		rsp.Err = ctx.Err()
		return rsp, ctx.Err()
	default:
		var h bool
		item := new(model.TrainTask)
		h, err = rp.engine.ID(req.TaskId).Get(item)
		if err != nil {
			goto ReturnPoint
		}
		if !h {
			// "未找到" (not found); the previous text "为找到" was a typo.
			err = fmt.Errorf("未找到对应任务")
			goto ReturnPoint
		}
		item.TaskName = req.ModelName
		item.TrainDatasetId = req.DatasetId
		item.CategoryId = req.BizType
		item.TaskDesc = req.ModelDesc
		item.Arithmetic = req.Arithmetic
		item.ImageSize = req.ImageSize
		item.BatchSize = req.BatchSize
		item.EpochsSize = req.EpochsSize
		item.OutputType = req.OutputType
		item.StartTime = time.Now().Unix()
		item.Status = 2
		_, err = rp.engine.ID(item.TaskId).AllCols().Update(item)
		if err != nil {
			goto ReturnPoint
		}
		payload := map[string]interface{}{
			"taskId":         item.TaskId,
			"taskName":       item.TaskName,
			"trainDatasetId": item.TrainDatasetId,
			"arithmetic":     item.Arithmetic,
			"imageSize":      item.ImageSize,
			"batchSize":      item.BatchSize,
			"epochsSize":     item.EpochsSize,
			"outputType":     item.OutputType,
			// NOTE(review): testSize reuses OutputType; confirm with the consumer.
			"testSize": item.OutputType,
		}
		mqClient := mq.GetMqClient("task-request", 1)
		mqPayload := &mq.InstructionReq{
			Command: mq.TrainTaskAdd,
			Payload: payload,
		}
		var pData []byte
		if pData, err = json.Marshal(mqPayload); err != nil {
			goto ReturnPoint
		}
		ap, ok := mqClient.EndPoint.(hpds_node.AccessPoint)
		if !ok {
			err = fmt.Errorf("消息队列客户端不可用")
			goto ReturnPoint
		}
		if err = mq.GenerateAndSendData(ap, pData, rp.logger); err != nil {
			goto ReturnPoint
		}
		rsp.Code = http.StatusOK
		rsp.Status = http.StatusText(http.StatusOK)
		rsp.Message = "成功"
		rsp.Err = err
		return rsp, err
	}
ReturnPoint:
	if err != nil {
		rsp.Code = http.StatusInternalServerError
		rsp.Status = http.StatusText(http.StatusInternalServerError)
		rsp.Err = err
		rsp.Message = "失败"
	}
	return rsp, err
}
// ReRunTrainingTask restarts training task req.TaskId with its stored
// parameters. It resets StartTime to now and Status to 2, saves all columns,
// and re-publishes the task as a mq.TrainTaskAdd instruction through the
// "task-request" MQ client.
//
// Failures of the lookup, the update, payload encoding, or MQ publishing
// produce a 500 response and are returned.
func (rp *repo) ReRunTrainingTask(ctx context.Context, req proto.TrainingTaskItemRequest) (rsp *proto.BaseResponse, err error) {
	rsp = new(proto.BaseResponse)
	select {
	case <-ctx.Done():
		err = fmt.Errorf("超时/取消")
		rsp.Code = http.StatusInternalServerError
		rsp.Status = http.StatusText(http.StatusInternalServerError)
		rsp.Message = "超时/取消"
		rsp.Err = ctx.Err()
		return rsp, ctx.Err()
	default:
		var h bool
		item := new(model.TrainTask)
		h, err = rp.engine.ID(req.TaskId).Get(item)
		if err != nil {
			goto ReturnPoint
		}
		if !h {
			// "未找到" (not found); the previous text "为找到" was a typo.
			err = fmt.Errorf("未找到对应任务")
			goto ReturnPoint
		}
		item.StartTime = time.Now().Unix()
		item.Status = 2
		_, err = rp.engine.ID(item.TaskId).AllCols().Update(item)
		if err != nil {
			goto ReturnPoint
		}
		payload := map[string]interface{}{
			"taskId":         item.TaskId,
			"taskName":       item.TaskName,
			"trainDatasetId": item.TrainDatasetId,
			"arithmetic":     item.Arithmetic,
			"imageSize":      item.ImageSize,
			"batchSize":      item.BatchSize,
			"epochsSize":     item.EpochsSize,
			"outputType":     item.OutputType,
			// NOTE(review): testSize reuses OutputType; confirm with the consumer.
			"testSize": item.OutputType,
		}
		mqClient := mq.GetMqClient("task-request", 1)
		mqPayload := &mq.InstructionReq{
			Command: mq.TrainTaskAdd,
			Payload: payload,
		}
		var pData []byte
		if pData, err = json.Marshal(mqPayload); err != nil {
			goto ReturnPoint
		}
		ap, ok := mqClient.EndPoint.(hpds_node.AccessPoint)
		if !ok {
			err = fmt.Errorf("消息队列客户端不可用")
			goto ReturnPoint
		}
		if err = mq.GenerateAndSendData(ap, pData, rp.logger); err != nil {
			goto ReturnPoint
		}
		rsp.Code = http.StatusOK
		rsp.Status = http.StatusText(http.StatusOK)
		rsp.Message = "成功"
		rsp.Err = err
		return rsp, err
	}
ReturnPoint:
	if err != nil {
		rsp.Code = http.StatusInternalServerError
		rsp.Status = http.StatusText(http.StatusInternalServerError)
		rsp.Err = err
		rsp.Message = "失败"
	}
	return rsp, err
}