1、增加任务处理进度

This commit is contained in:
wangjian 2023-06-17 09:39:13 +08:00
parent 7d09ba0286
commit 90d43d468a
21 changed files with 1180 additions and 28 deletions

View File

@ -15,17 +15,20 @@ var (
)
type ControlCenterConfig struct {
Name string `yaml:"name,omitempty"`
Host string `yaml:"host,omitempty"`
Port int `yaml:"port,omitempty"`
Mode string `yaml:"mode,omitempty"`
Consul ConsulConfig `yaml:"consul,omitempty"`
Db DbConfig `yaml:"db"`
Cache CacheConfig `yaml:"cache"`
Logging LogOptions `yaml:"logging"`
Minio MinioConfig `yaml:"minio"`
Node HpdsNode `yaml:"node,omitempty"`
Funcs []FuncConfig `yaml:"functions,omitempty"`
Name string `yaml:"name,omitempty"`
Host string `yaml:"host,omitempty"`
Port int `yaml:"port,omitempty"`
Mode string `yaml:"mode,omitempty"`
TmpTrainDir string `yaml:"tmpTrainDir"`
TrainScriptPath string `yaml:"trainScriptPath"`
ModelOutPath string `yaml:"modelOutPath"`
Consul ConsulConfig `yaml:"consul,omitempty"`
Db DbConfig `yaml:"db"`
Cache CacheConfig `yaml:"cache"`
Logging LogOptions `yaml:"logging"`
Minio MinioConfig `yaml:"minio"`
Node HpdsNode `yaml:"node,omitempty"`
Funcs []FuncConfig `yaml:"functions,omitempty"`
}
type ConsulConfig struct {

View File

@ -2,6 +2,9 @@ name: control_center
host: 0.0.0.0
port: 8088
mode: dev
tmpTrainDir: ./tmp
trainScriptPath: ./scripts/runTrainScript.sh
modelOutPath: ./out
logging:
path: ./logs
prefix: hpds-control

7
go.mod
View File

@ -7,12 +7,14 @@ require (
git.hpds.cc/Component/network v0.0.0-20230405135741-a4ea724bab76
git.hpds.cc/pavement/hpds_node v0.0.0-20230405153516-9403c4d01e12
github.com/go-sql-driver/mysql v1.7.0
github.com/google/uuid v1.3.0
github.com/hashicorp/consul/api v1.20.0
github.com/minio/minio-go v6.0.14+incompatible
github.com/minio/minio-go/v7 v7.0.52
github.com/spf13/cobra v1.6.1
github.com/spf13/viper v1.15.0
go.uber.org/zap v1.23.0
golang.org/x/image v0.0.0-20191009234506-e7c1f5e7dbb8
golang.org/x/text v0.7.0
gopkg.in/yaml.v3 v3.0.1
xorm.io/xorm v1.3.2
)
@ -30,7 +32,6 @@ require (
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/fatih/color v1.13.0 // indirect
github.com/fsnotify/fsnotify v1.6.0 // indirect
github.com/go-ini/ini v1.67.0 // indirect
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect
github.com/goccy/go-json v0.8.1 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
@ -40,7 +41,6 @@ require (
github.com/golang/snappy v0.0.4 // indirect
github.com/google/go-cmp v0.5.9 // indirect
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.2.1 // indirect
github.com/googleapis/gax-go/v2 v2.7.0 // indirect
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
@ -94,7 +94,6 @@ require (
golang.org/x/oauth2 v0.0.0-20221014153046-6fdb5e3db783 // indirect
golang.org/x/sync v0.1.0 // indirect
golang.org/x/sys v0.5.0 // indirect
golang.org/x/text v0.7.0 // indirect
golang.org/x/time v0.1.0 // indirect
golang.org/x/tools v0.2.0 // indirect
golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect

View File

@ -4,7 +4,7 @@ type TaskLog struct {
TaskLogId int64 `xorm:"not null pk autoincr BIGINT(11)" json:"taskLogId"`
TaskId int64 `xorm:"INT(11) index" json:"taskId"`
NodeId int64 `xorm:"INT(11) index" json:"nodeId"`
Content string `xorm:"LANGTEXT" json:"content"`
Content string `xorm:"LONGTEXT" json:"content"`
CreateAt int64 `xorm:"created" json:"createAt"`
UpdateAt int64 `xorm:"updated" json:"updateAt"`
}

23
model/diseaseType.go Normal file
View File

@ -0,0 +1,23 @@
package model
// DiseaseType 病害类别
type DiseaseType struct {
TypeId int64 `xorm:"not null pk autoincr INT(11)" json:"typeId"`
TypeName string `xorm:"varchar(200) not null" json:"typeName"`
CategoryId int `xorm:"not null SMALLINT default 1" json:"categoryId"` //病害分类, 1:道路 2:桥梁 3:隧道 4:边坡
Status int `xorm:"not null INT(11) default 0" json:"status"`
CreateAt int64 `xorm:"created" json:"createAt"`
UpdateAt int64 `xorm:"updated" json:"updateAt"`
}
func GetDiseaseType(name string, categoryId int) int64 {
item := new(DiseaseType)
h, err := DB.Where("type_name like ?", "%"+name+"%").
And("category_id = ?", categoryId).
And("status = 1").Get(item)
if err != nil || !h {
return 0
}
return item.TypeId
}

View File

@ -26,6 +26,11 @@ func New(driveName, dsn string, showSql bool) {
&Task{},
&TaskLog{},
&TaskResult{},
&TrainingDataset{},
&TrainingDatasetDetail{},
&TrainTask{},
&TrainTaskLog{},
&TrainTaskResult{},
)
if err != nil {
fmt.Println("同步数据库表结构", err)

23
model/project.go Normal file
View File

@ -0,0 +1,23 @@
package model
type Project struct {
ProjectId int64 `xorm:"not null pk autoincr INT(11)" json:"projectId"`
ProjectName string `xorm:"varchar(200) not null " json:"projectName"`
OwnerId int64 `xorm:"not null INT(11) default 0" json:"ownerId"`
BizType int `xorm:"SMALLINT" json:"bizType"`
LineName string `xorm:"varchar(200) not null " json:"lineName"`
StartName string `xorm:"varchar(200) not null " json:"startName"`
EndName string `xorm:"varchar(200) not null " json:"endName"`
FixedDeviceNum int `xorm:"not null INT(11) default 0" json:"fixedDeviceNum"`
Direction string `xorm:"varchar(200) not null " json:"direction"`
LaneNum int `xorm:"not null INT(4) default 0" json:"laneNum"`
StartPointLng float64 `xorm:"decimal(18,6)" json:"startPointLng"`
StartPointLat float64 `xorm:"decimal(18,6)" json:"startPointLat"`
EndPointLng float64 `xorm:"decimal(18,6)" json:"endPointLng"`
EndPointLat float64 `xorm:"decimal(18,6)" json:"endPointLat"`
Status int `xorm:"SMALLINT default 1" json:"status"`
Creator int64 `xorm:"INT(11) default 0" json:"creator"`
Modifier int64 `xorm:"INT(11) default 0" json:"modifier"`
CreateAt int64 `xorm:"created" json:"createAt"`
UpdateAt int64 `xorm:"updated" json:"updateAt"`
}

21
model/projectResult.go Normal file
View File

@ -0,0 +1,21 @@
package model
type ProjectResult struct {
Id int64 `xorm:"not null pk autoincr BIGINT" json:"id"`
ProjectId int64 `xorm:"INT(11) index" json:"projectId"` //项目编号
SourceResultId int64 `xorm:"INT(11) index" json:"sourceResultId"` //识别结果来源编号
MilepostNumber string `xorm:"VARCHAR(50)" json:"milepostNumber"` //里程桩号
UpDown string `xorm:"VARCHAR(20)" json:"upDown"` //上下行
LineNum int `xorm:"SMALLINT default 1" json:"lineNum"` //车道号
DiseaseType string `xorm:"VARCHAR(50)" json:"diseaseType"` //病害类型
DiseaseLevel string `xorm:"VARCHAR(20)" json:"diseaseLevel"` //病害等级
Length float64 `xorm:"decimal(18,6)" json:"length"` //长度
Width float64 `xorm:"decimal(18,6)" json:"width"` //宽度
Acreage float64 `xorm:"decimal(18,6)" json:"acreage"` //面积
Memo string `xorm:"VARCHAR(1000)" json:"memo"` //备注说明
Result string `xorm:"LONGTEXT" json:"result"` //识别结果
Creator int64 `xorm:"INT(11) default 0" json:"creator"`
Modifier int64 `xorm:"INT(11) default 0" json:"modifier"`
CreateAt int64 `xorm:"created" json:"createAt"`
UpdateAt int64 `xorm:"updated" json:"updateAt"`
}

View File

@ -62,12 +62,11 @@ func UpdateTaskProgress(taskProgress *proto.TaskLogProgress) {
}
}
func UpdateTaskProgressByLog(res *TaskResult, isFailing bool) float64 {
ret := -1.0
func UpdateTaskProgressByLog(res *TaskResult, isFailing bool) (int, int) {
item := new(Task)
h, err := DB.ID(res.TaskId).Get(item)
if err != nil || !h {
return ret
return 0, 0
}
if isFailing {
item.FailingCount += 1
@ -79,12 +78,11 @@ func UpdateTaskProgressByLog(res *TaskResult, isFailing bool) float64 {
item.FinishTime = time.Now().Unix()
item.UnfinishedCount = 0
item.Status = 3
ret = 1.0
}
item.UpdateAt = time.Now().Unix()
_, _ = DB.ID(res.TaskId).Cols("completed_count", "failing_count", "total_count", "unfinished_count", "update_at", "finish_time", "status").Update(item)
if item.TotalCount > 0 {
return 1 - float64(item.CompletedCount)/float64(item.TotalCount)
return int(item.CompletedCount), int(item.UnfinishedCount)
}
return ret
return int(item.CompletedCount), int(item.UnfinishedCount)
}

23
model/trainTask.go Normal file
View File

@ -0,0 +1,23 @@
package model
type TrainTask struct {
TaskId int64 `xorm:"not null pk autoincr BIGINT(11)" json:"taskId"`
TrainDatasetId int64 `xorm:"INT(11) index" json:"trainDatasetId"`
CategoryId int `xorm:"not null SMALLINT default 1" json:"categoryId"` //业务分类, 1:道路 2:桥梁 3:隧道 4:边坡
TaskName string `xorm:"VARCHAR(200)" json:"taskName"`
TaskDesc string `xorm:"VARCHAR(500)" json:"taskDesc"`
Arithmetic string `xorm:"VARCHAR(50)" json:"arithmetic"`
ImageSize int `xorm:"INT" json:"imageSize"`
BatchSize int `xorm:"INT" json:"batchSize"`
EpochsSize int `xorm:"INT" json:"epochsSize"`
OutputType string `xorm:"VARCHAR(20)" json:"outputType"`
StartTime int64 `xorm:"BIGINT" json:"startTime"`
FinishTime int64 `xorm:"BIGINT" json:"finishTime"`
Loss float64 `xorm:"DECIMAL(18,6)" json:"loss"`
Accuracy float64 `xorm:"DECIMAL(18,6)" json:"accuracy"`
ModelFilePath string `xorm:"VARCHAR(2000)" json:"modelFilePath"`
ModelFileMetricsPath string `xorm:"VARCHAR(2000)" json:"modelFileMetricsPath"`
Status int `xorm:"not null SMALLINT default 0" json:"status"` // 1:等待执行; 2:执行中; 3:执行完成; 4:任务分配失败; 5:任务执行失败
CreateAt int64 `xorm:"created" json:"createAt"`
UpdateAt int64 `xorm:"updated" json:"updateAt"`
}

12
model/trainTaskLog.go Normal file
View File

@ -0,0 +1,12 @@
package model
type TrainTaskLog struct {
LogId int64 `xorm:"not null pk autoincr BIGINT(11)" json:"logId"`
TaskId int64 `xorm:"INT(11) index" json:"taskId"`
Epoch int `xorm:"SMALLINT" json:"epoch"`
Loss float64 `xorm:"DECIMAL(18,6)" json:"loss"`
Accuracy float64 `xorm:"DECIMAL(18,6)" json:"accuracy"`
ValLoss float64 `xorm:"DECIMAL(18,6)" json:"valLoss"`
ValAccuracy float64 `xorm:"DECIMAL(18,6)" json:"valAccuracy"`
CreateAt int64 `xorm:"created" json:"createAt"`
}

11
model/trainTaskResult.go Normal file
View File

@ -0,0 +1,11 @@
package model
type TrainTaskResult struct {
ResultId int64 `xorm:"not null pk autoincr BIGINT(11)" json:"resultId"`
TaskId int64 `xorm:"INT(11) index" json:"taskId"`
Content string `xorm:"LONGTEXT" json:"content"`
Result string `xorm:"VARCHAR(200)" json:"result"`
Loss float64 `xorm:"DECIMAL(18,6)" json:"loss"`
Accuracy float64 `xorm:"DECIMAL(18,6)" json:"accuracy"`
CreateAt int64 `xorm:"created" json:"createAt"`
}

13
model/trainingDataset.go Normal file
View File

@ -0,0 +1,13 @@
package model
type TrainingDataset struct {
DatasetId int64 `xorm:"not null pk autoincr INT(11)" json:"datasetId"`
Name string `xorm:"VARCHAR(200)" json:"name"`
CategoryId int `xorm:"not null SMALLINT default 1" json:"categoryId"` //业务分类, 1:道路 2:桥梁 3:隧道 4:边坡
DatasetDesc string `xorm:"varchar(200)" json:"datasetDesc"`
StoreName string `xorm:"varchar(200)" json:"storeName"` //存储路径
ValidationNumber float64 `xorm:"DECIMAL(18,4)" json:"validationNumber"` //验证占比
TestNumber float64 `xorm:"DECIMAL(18,4)" json:"testNumber"` //测试占比
CreateAt int64 `xorm:"created" json:"createAt"`
UpdateAt int64 `xorm:"updated" json:"updateAt"`
}

View File

@ -0,0 +1,16 @@
package model
type TrainingDatasetDetail struct {
DetailId int64 `xorm:"not null pk autoincr INT(11)" json:"detailId"`
FileName string `xorm:"VARCHAR(200)" json:"fileName"`
FilePath string `xorm:"VARCHAR(1000)" json:"filePath"`
DatasetId int64 `xorm:"INT(11) index default 0" json:"datasetId"` //训练数据集
CategoryId int `xorm:"not null SMALLINT default 1" json:"categoryId"` //训练集分类1训练集;2测试集;3:验证集
FileSize int64 `xorm:"BIGINT" json:"fileSize"` //文件大小
FileMd5 string `xorm:"VARCHAR(64)" json:"fileMd5"` //文件MD5
IsDisease int `xorm:"TINYINT(1)" json:"isDisease"` //是否有病害, 1:有病害;2:无病害;
OperationLogId int64 `xorm:"INT(11) index" json:"operationLogId"` //操作日志编号
Creator int64 `xorm:"INT(11) index" json:"creator"` //上传人
CreateAt int64 `xorm:"created" json:"createAt"` //上传时间
UpdateAt int64 `xorm:"updated" json:"updateAt"` //更新时间
}

View File

@ -1,6 +1,7 @@
package mq
import (
"bufio"
"encoding/base64"
"encoding/json"
"fmt"
@ -8,12 +9,18 @@ import (
"git.hpds.cc/Component/network/frame"
"github.com/google/uuid"
"go.uber.org/zap"
"golang.org/x/text/encoding/simplifiedchinese"
"hpds_control_center/config"
"hpds_control_center/internal/balance"
"hpds_control_center/internal/minio"
"hpds_control_center/internal/proto"
"hpds_control_center/model"
"hpds_control_center/pkg/utils"
"io"
"math"
"os"
"os/exec"
"path"
"strconv"
"strings"
"sync"
@ -22,6 +29,13 @@ import (
"git.hpds.cc/pavement/hpds_node"
)
type Charset string
const (
UTF8 = Charset("UTF-8")
GB18030 = Charset("GB18030")
)
var (
MqList []HpdsMqNode
TaskList = make(map[int64]*TaskItem)
@ -379,8 +393,8 @@ func TaskRequestHandler(data []byte) (frame.Tag, []byte) {
item.UpdateAt = time.Now().Unix()
_, _ = model.DB.ID(item.Id).AllCols().Update(item)
} else {
item.ModelId = payload["modelId"].(int64)
item.NodeId = payload["nodeId"].(int64)
item.ModelId = int64(payload["modelId"].(float64))
item.NodeId = int64(payload["nodeId"].(float64))
item.Status = 1
item.IssueResult = string(pData)
item.CreateAt = time.Now().Unix()
@ -403,7 +417,16 @@ func TaskRequestHandler(data []byte) (frame.Tag, []byte) {
// _, _ = model.DB.Insert(item)
// //fn := payload["fileName"].(string)
// //dec := base64.NewDecoder(base64.StdEncoding, strings.NewReader(payload["file"].(string)))
case TrainTaskAdd:
payload := cmd.Payload.(map[string]interface{})
if itemId, ok := payload["taskId"].(float64); ok {
item := new(model.TrainTask)
h, err := model.DB.ID(int64(itemId)).Get(item)
if err != nil || !h {
}
RunTraining(item)
}
default:
}
@ -450,14 +473,16 @@ func TaskResponseHandler(data []byte) (frame.Tag, []byte) {
if err != nil {
fmt.Println("接收TaskResponse数据出错", err)
}
//处理到项目结果表
go processToProjectResult(item)
//更新运行进度
rat := model.UpdateTaskProgressByLog(item, isFailing)
processed, unProcessed := model.UpdateTaskProgressByLog(item, isFailing)
var (
ratStr string
)
if rat > 0 && rat < 1 {
ratStr = fmt.Sprintf("[已处理%2.f,剩余%2.f未处理]", 1-rat, rat)
} else if rat == 1 {
if unProcessed > 0 {
ratStr = fmt.Sprintf("[已处理[%d],剩余[%d]未处理]", processed, unProcessed)
} else {
ratStr = "[已全部处理]"
}
taskLog := new(model.TaskLog)
@ -479,6 +504,267 @@ func TaskResponseHandler(data []byte) (frame.Tag, []byte) {
return frame.Tag(cmd.Command), nil
}
type ModelResult struct {
Code int `json:"code"`
}
type InsigmaResult struct {
Code int `json:"code"`
NumOfDiseases int `json:"num_of_diseases"`
Diseases []DiseasesInfo `json:"diseases"`
Image string `json:"image"`
}
type DiseasesInfo struct {
Id int `json:"id"`
Type string `json:"type"`
Level string `json:"level"`
Param DiseasesParam `json:"param"`
}
type DiseasesParam struct {
Length float64 `json:"length"`
Area float64 `json:"area"`
MaxWidth string `json:"max_width"`
}
type LightweightResult struct {
Code int `json:"code"`
Crack bool `json:"crack"`
ImgDiscern string `json:"img_discern"`
ImgSrc string `json:"img_src"`
Pothole bool `json:"pothole"`
}
func processToProjectResult(src *model.TaskResult) {
project := new(model.Project)
h, err := model.DB.Table("project").Alias("p").Join("inner", []string{"dataset", "d"}, "d.project_id= p.project_id").Where("d.dataset_id=?", src.DatasetId).Get(project)
if !h {
err = fmt.Errorf("未能找到对应的项目信息")
}
if err != nil {
logging.L().With(zap.String("控制节点", "错误信息")).Error("获取项目信息", zap.Error(err))
return
}
var (
mr ModelResult
mrList []string
fileDiscern string
memo string
milepostNumber string
upDown string
lineNum int
width float64
)
switch project.BizType {
case 1: //道路
arr := strings.Split(src.SrcPath, " ")
if len(arr) > 1 {
milepostNumber = GetMilepost(project.StartName, arr[1], arr[2])
if arr[2] == "D" {
upDown = "下行"
} else {
upDown = "上行"
}
}
if len(arr) > 3 {
lineNum, _ = strconv.Atoi(arr[3])
}
case 2: //桥梁
case 3: //隧道
//隧道名-采集方向(D/X)-相机编号(01-22)-采集序号五位K里程桩号.bmp DAXIASHAN-D-05-00003K15069.5.bmp
arr := strings.Split(src.SrcPath, "K")
if len(arr) > 1 {
arrM := strings.Split(arr[1], ".")
milepostNumber = meter2Milepost(arrM[0])
arrD := strings.Split(arr[0], ".")
if len(arrD) > 1 {
if arrD[1] == "D" {
upDown = "下行"
} else {
upDown = "上行"
}
}
if len(arrD) > 4 {
lineNum, _ = strconv.Atoi(arrD[3])
}
}
}
if len(src.Result) > 0 && src.Result[0] == '[' {
mrList = make([]string, 0)
if err := json.Unmarshal([]byte(src.Result), &mrList); err != nil {
return
}
list := make([]*model.ProjectResult, 0)
for _, str := range mrList {
if err := json.Unmarshal([]byte(str), &mr); err != nil {
continue
}
if mr.Code == 2001 {
ir := new(InsigmaResult)
if err := json.Unmarshal([]byte(str), &ir); err != nil {
continue
}
fileDiscern = ir.Image
for key, value := range ir.Diseases {
if len(value.Param.MaxWidth) > 0 && width == 0 {
width, _ = strconv.ParseFloat(value.Param.MaxWidth, 64)
} else {
width = 0
}
memo = fmt.Sprintf("%d. 发现[%s],等级[%s],长度[%f],最大宽度[%s],面积[%f];\n", key+1, value.Type, value.Level, value.Param.Length, value.Param.MaxWidth, value.Param.Area)
item := &model.ProjectResult{
ProjectId: project.ProjectId,
SourceResultId: src.ResultId,
MilepostNumber: milepostNumber,
UpDown: upDown,
LineNum: lineNum,
DiseaseType: value.Type,
DiseaseLevel: value.Level,
Length: value.Param.Length,
Width: width,
Acreage: value.Param.Area,
Memo: memo,
Result: fileDiscern,
Creator: 0,
Modifier: 0,
CreateAt: time.Now().Unix(),
UpdateAt: time.Now().Unix(),
}
list = append(list, item)
}
}
}
_, _ = model.DB.Insert(list)
} else {
if err := json.Unmarshal([]byte(src.Result), &mr); err != nil {
return
}
switch mr.Code {
case 0: //轻量化模型返回
lr := new(LightweightResult)
if err := json.Unmarshal([]byte(src.Result), &lr); err != nil {
return
}
if lr.Crack || lr.Pothole {
if lr.Crack {
memo = "检测到裂缝"
} else {
memo = "检测到坑洼"
}
fileDiscern = lr.ImgDiscern
if len(fileDiscern) == 0 {
fileDiscern = lr.ImgSrc
}
diseaseLevelName := "重度"
diseaseTypeName := ""
switch project.BizType {
case 2:
diseaseTypeName = "结构裂缝"
case 3:
diseaseTypeName = "衬砌裂缝"
default:
diseaseTypeName = "横向裂缝"
}
item := &model.ProjectResult{
ProjectId: project.ProjectId,
SourceResultId: src.ResultId,
MilepostNumber: milepostNumber,
UpDown: upDown,
LineNum: lineNum,
DiseaseType: diseaseTypeName,
DiseaseLevel: diseaseLevelName,
Length: 0,
Width: 0,
Acreage: 0,
Memo: memo,
Result: fileDiscern,
Creator: 0,
Modifier: 0,
CreateAt: time.Now().Unix(),
UpdateAt: time.Now().Unix(),
}
_, _ = model.DB.Insert(item)
} else {
fileDiscern = lr.ImgSrc
}
//
case 2001: //网新返回有病害
ir := new(InsigmaResult)
if err := json.Unmarshal([]byte(src.Result), &ir); err != nil {
return
}
fileDiscern = ir.Image
list := make([]*model.ProjectResult, 0)
for _, val := range ir.Diseases {
if len(val.Param.MaxWidth) > 0 && width == 0 {
width, _ = strconv.ParseFloat(val.Param.MaxWidth, 64)
} else {
width = 0
}
memo = fmt.Sprintf("发现[%s],等级[%s],长度[%f],最大宽度[%s],面积[%f];\n", val.Type, val.Level, val.Param.Length, val.Param.MaxWidth, val.Param.Area)
maxWidth, _ := strconv.ParseFloat(val.Param.MaxWidth, 64)
item := &model.ProjectResult{
ProjectId: project.ProjectId,
SourceResultId: src.ResultId,
MilepostNumber: milepostNumber,
UpDown: upDown,
LineNum: lineNum,
DiseaseType: val.Type,
DiseaseLevel: val.Level,
Length: val.Param.Length,
Width: maxWidth,
Acreage: val.Param.Area,
Memo: memo,
Result: fileDiscern,
Creator: 0,
Modifier: 0,
CreateAt: time.Now().Unix(),
UpdateAt: time.Now().Unix(),
}
list = append(list, item)
}
_, _ = model.DB.Insert(list)
}
}
}
// 里程桩加减里程,返回里程桩
func GetMilepost(start, num, upDown string) string {
arr := strings.Split(start, "+")
var (
kilometre, meter, milepost, counter, res, resMilepost, resMeter float64
)
if len(arr) == 1 {
meter = 0
} else {
meter, _ = strconv.ParseFloat(arr[1], 64)
}
str := strings.Replace(arr[0], "k", "", -1)
str = strings.Replace(str, "K", "", -1)
kilometre, _ = strconv.ParseFloat(str, 64)
milepost = kilometre + meter/1000
counter, _ = strconv.ParseFloat(num, 64)
if upDown == "D" {
res = milepost - counter
} else {
res = milepost + counter
}
resMilepost = math.Floor(res)
resMeter = (res - resMilepost) * 100
return fmt.Sprintf("K%d+%.2f", int(resMilepost), resMeter)
}
// 米装换成里程桩号
func meter2Milepost(meter string) string {
meter = strings.Replace(meter, "K", "", -1)
m, _ := strconv.ParseFloat(meter, 64)
resMilepost := math.Floor(m / 1000)
resMeter := (m - resMilepost*1000) * 100
return fmt.Sprintf("K%d+%.2f", int(resMilepost), resMeter)
}
func deliver(topic string, mqType uint, payload interface{}) {
cli := GetMqClient(topic, mqType)
pData, _ := json.Marshal(payload)
@ -544,3 +830,220 @@ func TaskExecuteLogHandler(data []byte) (frame.Tag, []byte) {
l.Unlock()
return frame.Tag(cmd.Command), nil
}
func RunTraining(task *model.TrainTask) {
var (
args []string
modelPath, modelFileName, testSize string
modelAcc, modelLoss float64
)
fmt.Println("curr tmp dir====>>>>", config.Cfg.TmpTrainDir)
modelFileName = utils.GetUUIDString()
//复制训练数据集
tmpTrainDir := path.Join(config.Cfg.TmpTrainDir, fmt.Sprintf("%s_%s_%d_%d", modelFileName, task.Arithmetic, task.BatchSize, task.EpochsSize))
fileList := make([]model.TrainingDatasetDetail, 0)
_ = model.DB.Where("dataset_id = ?", task.TrainDatasetId).Find(&fileList)
_ = os.MkdirAll(tmpTrainDir, os.ModePerm)
_ = os.MkdirAll(path.Join(tmpTrainDir, "train"), os.ModePerm)
_ = os.MkdirAll(path.Join(tmpTrainDir, "train", "0"), os.ModePerm)
_ = os.MkdirAll(path.Join(tmpTrainDir, "train", "1"), os.ModePerm)
_ = os.MkdirAll(path.Join(tmpTrainDir, "val"), os.ModePerm)
_ = os.MkdirAll(path.Join(tmpTrainDir, "val", "0"), os.ModePerm)
_ = os.MkdirAll(path.Join(tmpTrainDir, "val", "1"), os.ModePerm)
_ = os.MkdirAll(path.Join(tmpTrainDir, "test"), os.ModePerm)
for _, v := range fileList {
dstFilePath := ""
switch v.CategoryId {
case 2:
dstFilePath = "test"
default:
dstFilePath = "train"
}
if v.CategoryId != 2 {
if v.IsDisease == 1 {
dstFilePath = path.Join(tmpTrainDir, dstFilePath, "0")
} else {
dstFilePath = path.Join(tmpTrainDir, dstFilePath, "1")
}
} else {
dstFilePath = path.Join(tmpTrainDir, dstFilePath)
}
err := utils.CopyFile(v.FilePath, path.Join(dstFilePath, v.FileName))
if err != nil {
fmt.Println("copy error: ", err)
}
}
modelPath = path.Join(config.Cfg.ModelOutPath, fmt.Sprintf("%s_%s_%d_%d_%s", modelFileName, task.Arithmetic, task.BatchSize, task.EpochsSize, task.OutputType))
_ = os.MkdirAll(modelPath, os.ModePerm)
dt := new(model.TrainingDataset)
_, err := model.DB.ID(task.TrainDatasetId).Get(dt)
if err != nil {
goto ReturnPoint
}
testSize = fmt.Sprintf("%.2f", dt.ValidationNumber/100)
//执行训练命令
args = []string{"--dataset=" + path.Join(tmpTrainDir, "train"),
"--img_size=" + strconv.Itoa(task.ImageSize), "--batch_size=" + strconv.Itoa(task.BatchSize), "--test_size=" + testSize,
"--epochs=" + strconv.Itoa(task.EpochsSize), "--model=" + task.Arithmetic, "--model_save=" + path.Join(modelPath, modelFileName+".h5"),
}
fmt.Println("args====>>>", args)
err = ExecCommand(config.Cfg.TrainScriptPath, args, path.Join(modelPath, modelFileName+".log"), task.TaskId)
ReturnPoint:
//返回训练结果
modelMetricsFile := modelFileName + "_model_metrics.png"
task.FinishTime = time.Now().Unix()
task.ModelFilePath = path.Join(modelPath, modelFileName+".h5")
task.Loss = GetIndicatorByLog(path.Join(modelPath, modelFileName+".log"), "[INFO] Model Validation Loss:")
task.Accuracy = GetIndicatorByLog(path.Join(modelPath, modelFileName+".log"), "[INFO] Model Validation Accuracy:")
task.Status = 3
if err != nil {
task.Status = 5
}
task.ModelFileMetricsPath = path.Join(modelPath, modelMetricsFile)
_, _ = model.DB.ID(task.TaskId).AllCols().Update(task)
if utils.PathExists(path.Join(modelPath, modelFileName+".log")) {
logContext := utils.ReadFile(path.Join(modelPath, modelFileName+".log"))
taskRes := new(model.TrainTaskResult)
taskRes.TaskId = task.TaskId
taskRes.CreateAt = time.Now().Unix()
taskRes.Content = string(logContext)
taskRes.Result = path.Join(modelPath, modelMetricsFile)
taskRes.Accuracy = modelAcc
taskRes.Loss = modelLoss
c, err := model.DB.Insert(taskRes)
if err != nil {
fmt.Println("model.DB.Insert(taskRes) error ========>>>>>>", err)
}
fmt.Println("model.DB.Insert(taskRes) count ========>>>>>>", c)
} else {
fmt.Println("logContext========>>>>>>未读取")
}
}
func GetIndicatorByLog(logFileName, indicator string) float64 {
logFn, _ := os.Open(logFileName)
defer func() {
_ = logFn.Close()
}()
buf := bufio.NewReader(logFn)
for {
line, err := buf.ReadString('\n')
if err != nil {
if err == io.EOF {
//fmt.Println("File read ok!")
break
} else {
fmt.Println("Read file error!", err)
return 0
}
}
if strings.Index(line, indicator) >= 0 {
str := strings.Replace(line, indicator, "", -1)
str = strings.Replace(str, "\n", "", -1)
value, _ := strconv.ParseFloat(strings.Trim(str, " "), 64)
return value
}
}
return 0
}
func ExecCommand(cmd string, args []string, logFileName string, taskId int64) (err error) {
logFile, _ := os.Create(logFileName)
defer func() {
_ = logFile.Close()
}()
fmt.Print("开始训练......")
c := exec.Command(cmd, args...) // mac or linux
stdout, err := c.StdoutPipe()
if err != nil {
return err
}
var (
wg sync.WaitGroup
)
wg.Add(1)
go func() {
defer wg.Done()
reader := bufio.NewReader(stdout)
var (
epoch int
//modelLoss, modelAcc float64
)
for {
readString, err := reader.ReadString('\n')
if err != nil || err == io.EOF {
fmt.Println("训练2===>>>", err)
//wg.Done()
return
}
byte2String := ConvertByte2String([]byte(readString), "GB18030")
_, _ = fmt.Fprint(logFile, byte2String)
if strings.Index(byte2String, "Epoch") >= 0 {
str := strings.Replace(byte2String, "Epoch ", "", -1)
arr := strings.Split(str, "/")
epoch, _ = strconv.Atoi(arr[0])
}
if strings.Index(byte2String, "- loss:") > 0 &&
strings.Index(byte2String, "- accuracy:") > 0 &&
strings.Index(byte2String, "- val_loss:") > 0 &&
strings.Index(byte2String, "- val_accuracy:") > 0 {
var (
loss, acc, valLoss, valAcc float64
)
arr := strings.Split(byte2String, "-")
for _, v := range arr {
if strings.Index(v, "loss:") > 0 && strings.Index(v, "val_loss:") < 0 {
strLoss := strings.Replace(v, " loss: ", "", -1)
loss, _ = strconv.ParseFloat(strings.Trim(strLoss, " "), 64)
}
if strings.Index(v, "accuracy:") > 0 && strings.Index(v, "val_accuracy:") < 0 {
strAcc := strings.Replace(v, " accuracy: ", "", -1)
acc, _ = strconv.ParseFloat(strings.Trim(strAcc, " "), 64)
}
if strings.Index(v, "val_loss:") > 0 {
strValLoss := strings.Replace(v, "val_loss: ", "", -1)
valLoss, _ = strconv.ParseFloat(strings.Trim(strValLoss, " "), 64)
}
if strings.Index(v, "val_accuracy:") > 0 {
strValAcc := strings.Replace(v, "val_accuracy: ", "", -1)
strValAcc = strings.Replace(strValAcc, "\n", "", -1)
valAcc, _ = strconv.ParseFloat(strings.Trim(strValAcc, " "), 64)
}
}
taskLog := new(model.TrainTaskLog)
taskLog.Epoch = epoch
taskLog.TaskId = taskId
taskLog.CreateAt = time.Now().Unix()
taskLog.Loss = loss
taskLog.Accuracy = acc
taskLog.ValLoss = valLoss
taskLog.ValAccuracy = valAcc
_, _ = model.DB.Insert(taskLog)
}
fmt.Print(byte2String)
}
}()
err = c.Start()
if err != nil {
fmt.Println("训练3===>>>", err)
}
wg.Wait()
return
}
func ConvertByte2String(byte []byte, charset Charset) string {
var str string
switch charset {
case GB18030:
var decodeBytes, _ = simplifiedchinese.GB18030.NewDecoder().Bytes(byte)
str = string(decodeBytes)
case UTF8:
fallthrough
default:
str = string(byte)
}
return str
}

View File

@ -9,6 +9,7 @@ const (
ModelIssueResponse
TaskExecuteLog
TaskLog
TrainTaskAdd
)
type InstructionReq struct {

94
pkg/utils/file.go Normal file
View File

@ -0,0 +1,94 @@
package utils
import (
"crypto/md5"
"encoding/base64"
"encoding/hex"
"fmt"
"git.hpds.cc/Component/logging"
"go.uber.org/zap"
"io"
"os"
"path"
"path/filepath"
"strings"
)
func CopyFile(src, dst string) error {
sourceFileStat, err := os.Stat(src)
if err != nil {
return err
}
if !sourceFileStat.Mode().IsRegular() {
return fmt.Errorf("%s is not a regular file", src)
}
source, err := os.Open(src)
if err != nil {
return err
}
defer func(source *os.File) {
_ = source.Close()
}(source)
destination, err := os.Create(dst)
if err != nil {
return err
}
defer func(destination *os.File) {
_ = destination.Close()
}(destination)
_, err = io.Copy(destination, source)
return err
}
func PathExists(path string) bool {
_, err := os.Stat(path)
if err == nil {
return true
}
if os.IsNotExist(err) {
return false
}
return false
}
// ReadFile 读取到file中再利用ioutil将file直接读取到[]byte中, 这是最优
func ReadFile(fn string) []byte {
f, err := os.Open(fn)
if err != nil {
logging.L().Error("Read File", zap.String("File Name", fn), zap.Error(err))
return nil
}
defer func(f *os.File) {
_ = f.Close()
}(f)
fd, err := io.ReadAll(f)
if err != nil {
logging.L().Error("Read File To buff", zap.String("File Name", fn), zap.Error(err))
return nil
}
return fd
}
func GetFileName(fn string) string {
fileType := path.Ext(fn)
return strings.TrimSuffix(fn, fileType)
}
func GetFileNameAndExt(fn string) string {
_, fileName := filepath.Split(fn)
return fileName
}
func GetFileMd5(data []byte) string {
hash := md5.New()
hash.Write(data)
return hex.EncodeToString(hash.Sum(nil))
}
func FileToBase64(fn string) string {
buff := ReadFile(fn)
return base64.StdEncoding.EncodeToString(buff) // 加密成base64字符串
}

126
pkg/utils/http.go Normal file
View File

@ -0,0 +1,126 @@
package utils
import (
"bytes"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/url"
"path/filepath"
"strings"
)
func HttpDo(reqUrl, method string, params map[string]string, header map[string]string) (data []byte, err error) {
var paramStr string = ""
if contentType, ok := header["Content-Type"]; ok && strings.Contains(contentType, "json") {
bytesData, _ := json.Marshal(params)
paramStr = string(bytesData)
} else {
for k, v := range params {
if len(paramStr) == 0 {
paramStr = fmt.Sprintf("%s=%s", k, url.QueryEscape(v))
} else {
paramStr = fmt.Sprintf("%s&%s=%s", paramStr, k, url.QueryEscape(v))
}
}
}
client := &http.Client{}
req, err := http.NewRequest(strings.ToUpper(method), reqUrl, strings.NewReader(paramStr))
if err != nil {
return nil, err
}
for k, v := range header {
req.Header.Set(k, v)
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer func() {
if resp.Body != nil {
err = resp.Body.Close()
if err != nil {
return
}
}
}()
var body []byte
if resp.Body != nil {
body, err = io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
}
return body, nil
}
type UploadFile struct {
// 表单名称
Name string
Filepath string
// 文件全路径
File *bytes.Buffer
}
func PostFile(reqUrl string, reqParams map[string]string, contentType string, files []UploadFile, headers map[string]string) string {
requestBody, realContentType := getReader(reqParams, contentType, files)
httpRequest, _ := http.NewRequest("POST", reqUrl, requestBody)
// 添加请求头
httpRequest.Header.Add("Content-Type", realContentType)
if headers != nil {
for k, v := range headers {
httpRequest.Header.Add(k, v)
}
}
httpClient := &http.Client{}
// 发送请求
resp, err := httpClient.Do(httpRequest)
if err != nil {
panic(err)
}
defer func(Body io.ReadCloser) {
_ = Body.Close()
}(resp.Body)
response, _ := io.ReadAll(resp.Body)
return string(response)
}
func getReader(reqParams map[string]string, contentType string, files []UploadFile) (io.Reader, string) {
if strings.Index(contentType, "json") > -1 {
bytesData, _ := json.Marshal(reqParams)
return bytes.NewReader(bytesData), contentType
} else if files != nil {
body := &bytes.Buffer{}
// 文件写入 body
writer := multipart.NewWriter(body)
for _, uploadFile := range files {
part, err := writer.CreateFormFile(uploadFile.Name, filepath.Base(uploadFile.Filepath))
if err != nil {
panic(err)
}
_, err = io.Copy(part, uploadFile.File)
}
// 其他参数列表写入 body
for k, v := range reqParams {
if err := writer.WriteField(k, v); err != nil {
panic(err)
}
}
if err := writer.Close(); err != nil {
panic(err)
}
// 上传文件需要自己专用的contentType
return body, writer.FormDataContentType()
} else {
urlValues := url.Values{}
for key, val := range reqParams {
urlValues.Set(key, val)
}
reqBody := urlValues.Encode()
return strings.NewReader(reqBody), contentType
}
}

139
pkg/utils/image.go Normal file
View File

@ -0,0 +1,139 @@
package utils
import (
"bytes"
"encoding/base64"
"golang.org/x/image/bmp"
"golang.org/x/image/tiff"
"image"
"image/color"
"image/jpeg"
"image/png"
)
func BuffToImage(in []byte) image.Image {
buff := bytes.NewBuffer(in)
m, _, _ := image.Decode(buff)
return m
}
// Clip 图片裁剪
func Clip(in []byte, wi, hi int, equalProportion bool) (out image.Image, imageType string, err error) {
buff := bytes.NewBuffer(in)
m, imgType, _ := image.Decode(buff)
rgbImg := m.(*image.YCbCr)
if equalProportion {
w := m.Bounds().Max.X
h := m.Bounds().Max.Y
if w > 0 && h > 0 && wi > 0 && hi > 0 {
wi, hi = fixSize(w, h, wi, hi)
}
}
return rgbImg.SubImage(image.Rect(0, 0, wi, hi)), imgType, nil
}
func fixSize(img1W, img2H, wi, hi int) (new1W, new2W int) {
var ( //为了方便计算,将图片的宽转为 float64
imgWidth, imgHeight = float64(img1W), float64(img2H)
ratio float64
)
if imgWidth >= imgHeight {
ratio = imgWidth / float64(wi)
return int(imgWidth * ratio), int(imgHeight * ratio)
}
ratio = imgHeight / float64(hi)
return int(imgWidth * ratio), int(imgHeight * ratio)
}
func Gray(in []byte) (out image.Image, err error) {
m := BuffToImage(in)
bounds := m.Bounds()
dx := bounds.Dx()
dy := bounds.Dy()
newRgba := image.NewRGBA(bounds)
for i := 0; i < dx; i++ {
for j := 0; j < dy; j++ {
colorRgb := m.At(i, j)
_, g, _, a := colorRgb.RGBA()
gUint8 := uint8(g >> 8)
aUint8 := uint8(a >> 8)
newRgba.SetRGBA(i, j, color.RGBA{R: gUint8, G: gUint8, B: gUint8, A: aUint8})
}
}
r := image.Rect(0, 0, dx, dy)
return newRgba.SubImage(r), nil
}
func Rotate90(in []byte) image.Image {
m := BuffToImage(in)
rotate90 := image.NewRGBA(image.Rect(0, 0, m.Bounds().Dy(), m.Bounds().Dx()))
// 矩阵旋转
for x := m.Bounds().Min.Y; x < m.Bounds().Max.Y; x++ {
for y := m.Bounds().Max.X - 1; y >= m.Bounds().Min.X; y-- {
// 设置像素点
rotate90.Set(m.Bounds().Max.Y-x, y, m.At(y, x))
}
}
return rotate90
}
// Rotate180 旋转180度
func Rotate180(in []byte) image.Image {
m := BuffToImage(in)
rotate180 := image.NewRGBA(image.Rect(0, 0, m.Bounds().Dx(), m.Bounds().Dy()))
// 矩阵旋转
for x := m.Bounds().Min.X; x < m.Bounds().Max.X; x++ {
for y := m.Bounds().Min.Y; y < m.Bounds().Max.Y; y++ {
// 设置像素点
rotate180.Set(m.Bounds().Max.X-x, m.Bounds().Max.Y-y, m.At(x, y))
}
}
return rotate180
}
// Rotate270 旋转270度
func Rotate270(in []byte) image.Image {
m := BuffToImage(in)
rotate270 := image.NewRGBA(image.Rect(0, 0, m.Bounds().Dy(), m.Bounds().Dx()))
// 矩阵旋转
for x := m.Bounds().Min.Y; x < m.Bounds().Max.Y; x++ {
for y := m.Bounds().Max.X - 1; y >= m.Bounds().Min.X; y-- {
// 设置像素点
rotate270.Set(x, m.Bounds().Max.X-y, m.At(y, x))
}
}
return rotate270
}
func ImageToBase64(img image.Image, imgType string) string {
buff := ImageToBuff(img, imgType)
return base64.StdEncoding.EncodeToString(buff.Bytes())
}
func ImageToBuff(img image.Image, imgType string) *bytes.Buffer {
buff := bytes.NewBuffer(nil)
switch imgType {
case "bmp":
imgType = "bmp"
_ = bmp.Encode(buff, img)
case "png":
imgType = "png"
_ = png.Encode(buff, img)
case "tiff":
imgType = "tiff"
_ = tiff.Encode(buff, img, nil)
default:
imgType = "jpeg"
_ = jpeg.Encode(buff, img, nil)
}
return buff
}
func ImgFileToBase64(fn string) string {
fileByte := ReadFile(fn)
buff := bytes.NewBuffer(fileByte)
m, _, _ := image.Decode(buff)
return "data:image/jpeg;base64," + ImageToBase64(m, "jpeg")
}

48
pkg/utils/utils.go Normal file
View File

@ -0,0 +1,48 @@
package utils
import (
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"github.com/google/uuid"
"math/rand"
"strings"
"time"
)
/*
RandomString 产生随机数
- size 随机码的位数
- kind 0 // 纯数字
1 // 小写字母
2 // 大写字母
3 // 数字、大小写字母
*/
func RandomString(size int, kind int) string {
iKind, kinds, rsBytes := kind, [][]int{[]int{10, 48}, []int{26, 97}, []int{26, 65}}, make([]byte, size)
isAll := kind > 2 || kind < 0
rand.Seed(time.Now().UnixNano())
for i := 0; i < size; i++ {
if isAll { // random iKind
iKind = rand.Intn(3)
}
scope, base := kinds[iKind][0], kinds[iKind][1]
rsBytes[i] = uint8(base + rand.Intn(scope))
}
return string(rsBytes)
}
func GetUserSha1Pass(pass, salt string) string {
key := []byte(salt)
mac := hmac.New(sha1.New, key)
mac.Write([]byte(pass))
//进行base64编码
res := base64.StdEncoding.EncodeToString(mac.Sum(nil))
return res
}
func GetUUIDString() string {
u, _ := uuid.NewUUID()
str := strings.Replace(u.String(), "-", "", -1)
return str
}

91
scripts/runTrainScript.sh Normal file
View File

@ -0,0 +1,91 @@
#!/bin/bash
#获取对应的参数 没有的话赋默认值
ARGS=`getopt -o d::s::b::e::m::o:: --long dataset::,img_size::,batch_size::,epochs::,model::,model_save:: -n 'example.sh' -- "$@"`
if [ $? != 0 ]; then
echo "Terminating..."
exit 1
fi
echo $ARGS
eval set -- "${ARGS}"
while true;
do
case "$1" in
-d|--dataset)
case "$2" in
"")
echo "Internal error!"
exit 1
;;
*)
Dataset=$2
shift 2;
;;
esac
;;
-s|--img_size)
case "$2" in
"")
echo "Internal error!"
exit 1
;;
*)
ImgSize=$2;
shift 2;
;;
esac
;;
-b|--batch_size)
case "$2" in
*)
BatchSize=$2;
shift 2;
;;
esac
;;
-e|--epochs)
case "$2" in
*)
Epochs=$2;
shift 2;
;;
esac
;;
-m|--model)
case "$2" in
*)
Model=$2;
shift 2;
;;
esac
;;
-o|--model_save)
case "$2" in
*)
ModelSave=$2;
shift 2;
;;
esac
;;
--)
shift
break
;;
*)
echo "Internal error!"
exit 1
;;
esac
done
#echo ${Dataset}
eval "$(conda shell.bash hook)"
conda activate hpds_train
cd /home/data/hpds_train_keras
python train.py --dataset ${Dataset} --img_size ${ImgSize} --batch_size ${BatchSize} \
--epochs ${Epochs} --model ${Model} --model_save ${ModelSave}