diff --git a/config/config.go b/config/config.go index 0e25aa8..74bd419 100644 --- a/config/config.go +++ b/config/config.go @@ -15,17 +15,20 @@ var ( ) type ControlCenterConfig struct { - Name string `yaml:"name,omitempty"` - Host string `yaml:"host,omitempty"` - Port int `yaml:"port,omitempty"` - Mode string `yaml:"mode,omitempty"` - Consul ConsulConfig `yaml:"consul,omitempty"` - Db DbConfig `yaml:"db"` - Cache CacheConfig `yaml:"cache"` - Logging LogOptions `yaml:"logging"` - Minio MinioConfig `yaml:"minio"` - Node HpdsNode `yaml:"node,omitempty"` - Funcs []FuncConfig `yaml:"functions,omitempty"` + Name string `yaml:"name,omitempty"` + Host string `yaml:"host,omitempty"` + Port int `yaml:"port,omitempty"` + Mode string `yaml:"mode,omitempty"` + TmpTrainDir string `yaml:"tmpTrainDir"` + TrainScriptPath string `yaml:"trainScriptPath"` + ModelOutPath string `yaml:"modelOutPath"` + Consul ConsulConfig `yaml:"consul,omitempty"` + Db DbConfig `yaml:"db"` + Cache CacheConfig `yaml:"cache"` + Logging LogOptions `yaml:"logging"` + Minio MinioConfig `yaml:"minio"` + Node HpdsNode `yaml:"node,omitempty"` + Funcs []FuncConfig `yaml:"functions,omitempty"` } type ConsulConfig struct { diff --git a/config/config.yaml b/config/config.yaml index b10d85d..96c262d 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -2,6 +2,9 @@ name: control_center host: 0.0.0.0 port: 8088 mode: dev +tmpTrainDir: ./tmp +trainScriptPath: ./scripts/runTrainScript.sh +modelOutPath: ./out logging: path: ./logs prefix: hpds-control diff --git a/go.mod b/go.mod index 363ccf6..b0a0344 100644 --- a/go.mod +++ b/go.mod @@ -7,12 +7,14 @@ require ( git.hpds.cc/Component/network v0.0.0-20230405135741-a4ea724bab76 git.hpds.cc/pavement/hpds_node v0.0.0-20230405153516-9403c4d01e12 github.com/go-sql-driver/mysql v1.7.0 + github.com/google/uuid v1.3.0 github.com/hashicorp/consul/api v1.20.0 - github.com/minio/minio-go v6.0.14+incompatible github.com/minio/minio-go/v7 v7.0.52 github.com/spf13/cobra v1.6.1 github.com/spf13/viper v1.15.0 go.uber.org/zap v1.23.0 + golang.org/x/image v0.0.0-20191009234506-e7c1f5e7dbb8 + golang.org/x/text v0.7.0 gopkg.in/yaml.v3 v3.0.1 xorm.io/xorm v1.3.2 ) @@ -30,7 +32,6 @@ require ( github.com/dustin/go-humanize v1.0.1 // indirect github.com/fatih/color v1.13.0 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect - github.com/go-ini/ini v1.67.0 // indirect github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect github.com/goccy/go-json v0.8.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect @@ -40,7 +41,6 @@ require ( github.com/golang/snappy v0.0.4 // indirect github.com/google/go-cmp v0.5.9 // indirect github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect - github.com/google/uuid v1.3.0 // indirect github.com/googleapis/enterprise-certificate-proxy v0.2.1 // indirect github.com/googleapis/gax-go/v2 v2.7.0 // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect @@ -94,7 +94,6 @@ require ( golang.org/x/oauth2 v0.0.0-20221014153046-6fdb5e3db783 // indirect golang.org/x/sync v0.1.0 // indirect golang.org/x/sys v0.5.0 // indirect - golang.org/x/text v0.7.0 // indirect golang.org/x/time v0.1.0 // indirect golang.org/x/tools v0.2.0 // indirect golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect diff --git a/model/TaskLog.go b/model/TaskLog.go index fab94c7..646f3a9 100644 --- a/model/TaskLog.go +++ b/model/TaskLog.go @@ -4,7 +4,7 @@ type TaskLog struct { TaskLogId int64 `xorm:"not null pk autoincr BIGINT(11)" json:"taskLogId"` TaskId int64 `xorm:"INT(11) index" json:"taskId"` NodeId int64 `xorm:"INT(11) index" json:"nodeId"` - Content string `xorm:"LANGTEXT" json:"content"` + Content string `xorm:"LONGTEXT" json:"content"` CreateAt int64 `xorm:"created" json:"createAt"` UpdateAt int64 `xorm:"updated" json:"updateAt"` } diff --git a/model/diseaseType.go b/model/diseaseType.go new file mode 100644 index 0000000..0ab4db5 --- /dev/null +++ b/model/diseaseType.go @@ -0,0 +1,23 @@ +package model + +// DiseaseType 病害类别 +type DiseaseType struct { + TypeId int64 `xorm:"not null pk autoincr INT(11)" json:"typeId"` + TypeName string `xorm:"varchar(200) not null" json:"typeName"` + CategoryId int `xorm:"not null SMALLINT default 1" json:"categoryId"` //病害分类, 1:道路 2:桥梁 3:隧道 4:边坡 + Status int `xorm:"not null INT(11) default 0" json:"status"` + CreateAt int64 `xorm:"created" json:"createAt"` + UpdateAt int64 `xorm:"updated" json:"updateAt"` +} + +func GetDiseaseType(name string, categoryId int) int64 { + item := new(DiseaseType) + h, err := DB.Where("type_name like ?", "%"+name+"%"). + And("category_id = ?", categoryId). + And("status = 1").Get(item) + if err != nil || !h { + return 0 + } + return item.TypeId + +} diff --git a/model/index.go b/model/index.go index d20c6c3..9fd37b8 100644 --- a/model/index.go +++ b/model/index.go @@ -26,6 +26,11 @@ func New(driveName, dsn string, showSql bool) { &Task{}, &TaskLog{}, &TaskResult{}, + &TrainingDataset{}, + &TrainingDatasetDetail{}, + &TrainTask{}, + &TrainTaskLog{}, + &TrainTaskResult{}, ) if err != nil { fmt.Println("同步数据库表结构", err) diff --git a/model/project.go b/model/project.go new file mode 100644 index 0000000..b43cd85 --- /dev/null +++ b/model/project.go @@ -0,0 +1,23 @@ +package model + +type Project struct { + ProjectId int64 `xorm:"not null pk autoincr INT(11)" json:"projectId"` + ProjectName string `xorm:"varchar(200) not null " json:"projectName"` + OwnerId int64 `xorm:"not null INT(11) default 0" json:"ownerId"` + BizType int `xorm:"SMALLINT" json:"bizType"` + LineName string `xorm:"varchar(200) not null " json:"lineName"` + StartName string `xorm:"varchar(200) not null " json:"startName"` + EndName string `xorm:"varchar(200) not null " json:"endName"` + FixedDeviceNum int `xorm:"not null INT(11) default 0" json:"fixedDeviceNum"` + Direction string `xorm:"varchar(200) not null " json:"direction"` + LaneNum int `xorm:"not null INT(4) default 0" json:"laneNum"` + StartPointLng float64 `xorm:"decimal(18,6)" json:"startPointLng"` + StartPointLat float64 `xorm:"decimal(18,6)" json:"startPointLat"` + EndPointLng float64 `xorm:"decimal(18,6)" json:"endPointLng"` + EndPointLat float64 `xorm:"decimal(18,6)" json:"endPointLat"` + Status int `xorm:"SMALLINT default 1" json:"status"` + Creator int64 `xorm:"INT(11) default 0" json:"creator"` + Modifier int64 `xorm:"INT(11) default 0" json:"modifier"` + CreateAt int64 `xorm:"created" json:"createAt"` + UpdateAt int64 `xorm:"updated" json:"updateAt"` +} diff --git a/model/projectResult.go b/model/projectResult.go new file mode 100644 index 0000000..e97453c --- /dev/null +++ b/model/projectResult.go @@ -0,0 +1,21 @@ +package model + +type ProjectResult struct { + Id int64 `xorm:"not null pk autoincr BIGINT" json:"id"` + ProjectId int64 `xorm:"INT(11) index" json:"projectId"` //项目编号 + SourceResultId int64 `xorm:"INT(11) index" json:"sourceResultId"` //识别结果来源编号 + MilepostNumber string `xorm:"VARCHAR(50)" json:"milepostNumber"` //里程桩号 + UpDown string `xorm:"VARCHAR(20)" json:"upDown"` //上下行 + LineNum int `xorm:"SMALLINT default 1" json:"lineNum"` //车道号 + DiseaseType string `xorm:"VARCHAR(50)" json:"diseaseType"` //病害类型 + DiseaseLevel string `xorm:"VARCHAR(20)" json:"diseaseLevel"` //病害等级 + Length float64 `xorm:"decimal(18,6)" json:"length"` //长度 + Width float64 `xorm:"decimal(18,6)" json:"width"` //宽度 + Acreage float64 `xorm:"decimal(18,6)" json:"acreage"` //面积 + Memo string `xorm:"VARCHAR(1000)" json:"memo"` //备注说明 + Result string `xorm:"LONGTEXT" json:"result"` //识别结果 + Creator int64 `xorm:"INT(11) default 0" json:"creator"` + Modifier int64 `xorm:"INT(11) default 0" json:"modifier"` + CreateAt int64 `xorm:"created" json:"createAt"` + UpdateAt int64 `xorm:"updated" json:"updateAt"` +} diff --git a/model/task.go b/model/task.go index 134dce7..ec5b798 100644 --- a/model/task.go +++ b/model/task.go @@ -62,12 +62,11 @@ func UpdateTaskProgress(taskProgress *proto.TaskLogProgress) { } } -func UpdateTaskProgressByLog(res *TaskResult, isFailing bool) float64 { - ret := -1.0 +func UpdateTaskProgressByLog(res *TaskResult, isFailing bool) (int, int) { item := new(Task) h, err := DB.ID(res.TaskId).Get(item) if err != nil || !h { - return ret + return 0, 0 } if isFailing { item.FailingCount += 1 @@ -79,12 +78,11 @@ func UpdateTaskProgressByLog(res *TaskResult, isFailing bool) float64 { item.FinishTime = time.Now().Unix() item.UnfinishedCount = 0 item.Status = 3 - ret = 1.0 } item.UpdateAt = time.Now().Unix() _, _ = DB.ID(res.TaskId).Cols("completed_count", "failing_count", "total_count", "unfinished_count", "update_at", "finish_time", "status").Update(item) if item.TotalCount > 0 { - return 1 - float64(item.CompletedCount)/float64(item.TotalCount) + return int(item.CompletedCount), int(item.UnfinishedCount) } - return ret + return int(item.CompletedCount), int(item.UnfinishedCount) } diff --git a/model/trainTask.go b/model/trainTask.go new file mode 100644 index 0000000..6ed157f --- /dev/null +++ b/model/trainTask.go @@ -0,0 +1,23 @@ +package model + +type TrainTask struct { + TaskId int64 `xorm:"not null pk autoincr BIGINT(11)" json:"taskId"` + TrainDatasetId int64 `xorm:"INT(11) index" json:"trainDatasetId"` + CategoryId int `xorm:"not null SMALLINT default 1" json:"categoryId"` //业务分类, 1:道路 2:桥梁 3:隧道 4:边坡 + TaskName string `xorm:"VARCHAR(200)" json:"taskName"` + TaskDesc string `xorm:"VARCHAR(500)" json:"taskDesc"` + Arithmetic string `xorm:"VARCHAR(50)" json:"arithmetic"` + ImageSize int `xorm:"INT" json:"imageSize"` + BatchSize int `xorm:"INT" json:"batchSize"` + EpochsSize int `xorm:"INT" json:"epochsSize"` + OutputType string `xorm:"VARCHAR(20)" json:"outputType"` + StartTime int64 `xorm:"BIGINT" json:"startTime"` + FinishTime int64 `xorm:"BIGINT" json:"finishTime"` + Loss float64 `xorm:"DECIMAL(18,6)" json:"loss"` + Accuracy float64 `xorm:"DECIMAL(18,6)" json:"accuracy"` + ModelFilePath string `xorm:"VARCHAR(2000)" json:"modelFilePath"` + ModelFileMetricsPath string `xorm:"VARCHAR(2000)" json:"modelFileMetricsPath"` + Status int `xorm:"not null SMALLINT default 0" json:"status"` // 1:等待执行; 2:执行中; 3:执行完成; 4:任务分配失败; 5:任务执行失败 + CreateAt int64 `xorm:"created" json:"createAt"` + UpdateAt int64 `xorm:"updated" json:"updateAt"` +} diff --git a/model/trainTaskLog.go b/model/trainTaskLog.go new file mode 100644 index 0000000..9f6188d --- /dev/null +++ b/model/trainTaskLog.go @@ -0,0 +1,12 @@ +package model + +type TrainTaskLog struct { + LogId int64 `xorm:"not null pk autoincr BIGINT(11)" json:"logId"` + TaskId int64 `xorm:"INT(11) index" json:"taskId"` + Epoch int `xorm:"SMALLINT" json:"epoch"` + Loss float64 `xorm:"DECIMAL(18,6)" json:"loss"` + Accuracy float64 `xorm:"DECIMAL(18,6)" json:"accuracy"` + ValLoss float64 `xorm:"DECIMAL(18,6)" json:"valLoss"` + ValAccuracy float64 `xorm:"DECIMAL(18,6)" json:"valAccuracy"` + CreateAt int64 `xorm:"created" json:"createAt"` +} diff --git a/model/trainTaskResult.go b/model/trainTaskResult.go new file mode 100644 index 0000000..d987999 --- /dev/null +++ b/model/trainTaskResult.go @@ -0,0 +1,11 @@ +package model + +type TrainTaskResult struct { + ResultId int64 `xorm:"not null pk autoincr BIGINT(11)" json:"resultId"` + TaskId int64 `xorm:"INT(11) index" json:"taskId"` + Content string `xorm:"LONGTEXT" json:"content"` + Result string `xorm:"VARCHAR(200)" json:"result"` + Loss float64 `xorm:"DECIMAL(18,6)" json:"loss"` + Accuracy float64 `xorm:"DECIMAL(18,6)" json:"accuracy"` + CreateAt int64 `xorm:"created" json:"createAt"` +} diff --git a/model/trainingDataset.go b/model/trainingDataset.go new file mode 100644 index 0000000..8500834 --- /dev/null +++ b/model/trainingDataset.go @@ -0,0 +1,13 @@ +package model + +type TrainingDataset struct { + DatasetId int64 `xorm:"not null pk autoincr INT(11)" json:"datasetId"` + Name string `xorm:"VARCHAR(200)" json:"name"` + CategoryId int `xorm:"not null SMALLINT default 1" json:"categoryId"` //业务分类, 1:道路 2:桥梁 3:隧道 4:边坡 + DatasetDesc string `xorm:"varchar(200)" json:"datasetDesc"` + StoreName string `xorm:"varchar(200)" json:"storeName"` //存储路径 + ValidationNumber float64 `xorm:"DECIMAL(18,4)" json:"validationNumber"` //验证占比 + TestNumber float64 `xorm:"DECIMAL(18,4)" json:"testNumber"` //测试占比 + CreateAt int64 `xorm:"created" json:"createAt"` + UpdateAt int64 `xorm:"updated" json:"updateAt"` +} diff --git a/model/trainingDatasetDetail.go b/model/trainingDatasetDetail.go new file mode 100644 index 0000000..35dc098 --- /dev/null +++ b/model/trainingDatasetDetail.go @@ -0,0 +1,16 @@ +package model + +type TrainingDatasetDetail struct { + DetailId int64 `xorm:"not null pk autoincr INT(11)" json:"detailId"` + FileName string `xorm:"VARCHAR(200)" json:"fileName"` + FilePath string `xorm:"VARCHAR(1000)" json:"filePath"` + DatasetId int64 `xorm:"INT(11) index default 0" json:"datasetId"` //训练数据集 + CategoryId int `xorm:"not null SMALLINT default 1" json:"categoryId"` //训练集分类,1:训练集;2:测试集;3:验证集 + FileSize int64 `xorm:"BIGINT" json:"fileSize"` //文件大小 + FileMd5 string `xorm:"VARCHAR(64)" json:"fileMd5"` //文件MD5 + IsDisease int `xorm:"TINYINT(1)" json:"isDisease"` //是否有病害, 1:有病害;2:无病害; + OperationLogId int64 `xorm:"INT(11) index" json:"operationLogId"` //操作日志编号 + Creator int64 `xorm:"INT(11) index" json:"creator"` //上传人 + CreateAt int64 `xorm:"created" json:"createAt"` //上传时间 + UpdateAt int64 `xorm:"updated" json:"updateAt"` //更新时间 +} diff --git a/mq/index.go b/mq/index.go index bfc1bac..ecc7ba5 100644 --- a/mq/index.go +++ b/mq/index.go @@ -1,6 +1,7 @@ package mq import ( + "bufio" "encoding/base64" "encoding/json" "fmt" @@ -8,12 +9,18 @@ import ( "git.hpds.cc/Component/network/frame" "github.com/google/uuid" "go.uber.org/zap" + "golang.org/x/text/encoding/simplifiedchinese" "hpds_control_center/config" "hpds_control_center/internal/balance" "hpds_control_center/internal/minio" "hpds_control_center/internal/proto" "hpds_control_center/model" + "hpds_control_center/pkg/utils" + "io" + "math" "os" + "os/exec" + "path" "strconv" "strings" "sync" @@ -22,6 +29,13 @@ import ( "git.hpds.cc/pavement/hpds_node" ) +type Charset string + +const ( + UTF8 = Charset("UTF-8") + GB18030 = Charset("GB18030") +) + var ( MqList []HpdsMqNode TaskList = make(map[int64]*TaskItem) @@ -379,8 +393,8 @@ func TaskRequestHandler(data []byte) (frame.Tag, []byte) { item.UpdateAt = time.Now().Unix() _, _ = model.DB.ID(item.Id).AllCols().Update(item) } else { - item.ModelId = payload["modelId"].(int64) - item.NodeId = payload["nodeId"].(int64) + item.ModelId = int64(payload["modelId"].(float64)) + item.NodeId = int64(payload["nodeId"].(float64)) item.Status = 1 item.IssueResult = string(pData) item.CreateAt = time.Now().Unix() @@ -403,7 +417,16 @@ func TaskRequestHandler(data []byte) (frame.Tag, []byte) { // _, _ = model.DB.Insert(item) // //fn := payload["fileName"].(string) // //dec := base64.NewDecoder(base64.StdEncoding, strings.NewReader(payload["file"].(string))) + case TrainTaskAdd: + payload := cmd.Payload.(map[string]interface{}) + if itemId, ok := payload["taskId"].(float64); ok { + item := new(model.TrainTask) + h, err := model.DB.ID(int64(itemId)).Get(item) + if err != nil || !h { + } + RunTraining(item) + } default: } @@ -450,14 +473,16 @@ func TaskResponseHandler(data []byte) (frame.Tag, []byte) { if err != nil { fmt.Println("接收TaskResponse数据出错", err) } + //处理到项目结果表 + go processToProjectResult(item) //更新运行进度 - rat := model.UpdateTaskProgressByLog(item, isFailing) + processed, unProcessed := model.UpdateTaskProgressByLog(item, isFailing) var ( ratStr string ) - if rat > 0 && rat < 1 { - ratStr = fmt.Sprintf("[已处理%2.f,剩余%2.f未处理]", 1-rat, rat) - } else if rat == 1 { + if unProcessed > 0 { + ratStr = fmt.Sprintf("[已处理[%d],剩余[%d]未处理]", processed, unProcessed) + } else { ratStr = "[已全部处理]" } taskLog := new(model.TaskLog) @@ -479,6 +504,267 @@ func TaskResponseHandler(data []byte) (frame.Tag, []byte) { return frame.Tag(cmd.Command), nil } +type ModelResult struct { + Code int `json:"code"` +} + +type InsigmaResult struct { + Code int `json:"code"` + NumOfDiseases int `json:"num_of_diseases"` + Diseases []DiseasesInfo `json:"diseases"` + Image string `json:"image"` +} + +type DiseasesInfo struct { + Id int `json:"id"` + Type string `json:"type"` + Level string `json:"level"` + Param DiseasesParam `json:"param"` +} + +type DiseasesParam struct { + Length float64 `json:"length"` + Area float64 `json:"area"` + MaxWidth string `json:"max_width"` +} + +type LightweightResult struct { + Code int `json:"code"` + Crack bool `json:"crack"` + ImgDiscern string `json:"img_discern"` + ImgSrc string `json:"img_src"` + Pothole bool `json:"pothole"` +} + +func processToProjectResult(src *model.TaskResult) { + project := new(model.Project) + h, err := model.DB.Table("project").Alias("p").Join("inner", []string{"dataset", "d"}, "d.project_id= p.project_id").Where("d.dataset_id=?", src.DatasetId).Get(project) + if !h { + err = fmt.Errorf("未能找到对应的项目信息") + } + if err != nil { + logging.L().With(zap.String("控制节点", "错误信息")).Error("获取项目信息", zap.Error(err)) + return + } + var ( + mr ModelResult + mrList []string + fileDiscern string + memo string + milepostNumber string + upDown string + lineNum int + width float64 + ) + + switch project.BizType { + case 1: //道路 + arr := strings.Split(src.SrcPath, " ") + if len(arr) > 1 { + milepostNumber = GetMilepost(project.StartName, arr[1], arr[2]) + if arr[2] == "D" { + upDown = "下行" + } else { + upDown = "上行" + } + } + if len(arr) > 3 { + lineNum, _ = strconv.Atoi(arr[3]) + } + case 2: //桥梁 + case 3: //隧道 + //隧道名-采集方向(D/X)-相机编号(01-22)-采集序号(五位)K里程桩号.bmp DAXIASHAN-D-05-00003K15069.5.bmp + arr := strings.Split(src.SrcPath, "K") + if len(arr) > 1 { + arrM := strings.Split(arr[1], ".") + milepostNumber = meter2Milepost(arrM[0]) + arrD := strings.Split(arr[0], ".") + if len(arrD) > 1 { + if arrD[1] == "D" { + upDown = "下行" + } else { + upDown = "上行" + } + } + if len(arrD) > 4 { + lineNum, _ = strconv.Atoi(arrD[3]) + } + } + } + if len(src.Result) > 0 && src.Result[0] == '[' { + mrList = make([]string, 0) + if err := json.Unmarshal([]byte(src.Result), &mrList); err != nil { + return + } + list := make([]*model.ProjectResult, 0) + for _, str := range mrList { + if err := json.Unmarshal([]byte(str), &mr); err != nil { + continue + } + if mr.Code == 2001 { + ir := new(InsigmaResult) + if err := json.Unmarshal([]byte(str), &ir); err != nil { + continue + } + fileDiscern = ir.Image + for key, value := range ir.Diseases { + if len(value.Param.MaxWidth) > 0 && width == 0 { + width, _ = strconv.ParseFloat(value.Param.MaxWidth, 64) + } else { + width = 0 + } + memo = fmt.Sprintf("%d. 发现[%s],等级[%s],长度[%f],最大宽度[%s],面积[%f];\n", key+1, value.Type, value.Level, value.Param.Length, value.Param.MaxWidth, value.Param.Area) + item := &model.ProjectResult{ + ProjectId: project.ProjectId, + SourceResultId: src.ResultId, + MilepostNumber: milepostNumber, + UpDown: upDown, + LineNum: lineNum, + DiseaseType: value.Type, + DiseaseLevel: value.Level, + Length: value.Param.Length, + Width: width, + Acreage: value.Param.Area, + Memo: memo, + Result: fileDiscern, + Creator: 0, + Modifier: 0, + CreateAt: time.Now().Unix(), + UpdateAt: time.Now().Unix(), + } + list = append(list, item) + } + } + } + _, _ = model.DB.Insert(list) + } else { + if err := json.Unmarshal([]byte(src.Result), &mr); err != nil { + return + } + switch mr.Code { + case 0: //轻量化模型返回 + lr := new(LightweightResult) + if err := json.Unmarshal([]byte(src.Result), &lr); err != nil { + return + } + if lr.Crack || lr.Pothole { + if lr.Crack { + memo = "检测到裂缝" + } else { + memo = "检测到坑洼" + } + fileDiscern = lr.ImgDiscern + if len(fileDiscern) == 0 { + fileDiscern = lr.ImgSrc + } + diseaseLevelName := "重度" + diseaseTypeName := "" + switch project.BizType { + case 2: + diseaseTypeName = "结构裂缝" + case 3: + diseaseTypeName = "衬砌裂缝" + default: + diseaseTypeName = "横向裂缝" + } + item := &model.ProjectResult{ + ProjectId: project.ProjectId, + SourceResultId: src.ResultId, + MilepostNumber: milepostNumber, + UpDown: upDown, + LineNum: lineNum, + DiseaseType: diseaseTypeName, + DiseaseLevel: diseaseLevelName, + Length: 0, + Width: 0, + Acreage: 0, + Memo: memo, + Result: fileDiscern, + Creator: 0, + Modifier: 0, + CreateAt: time.Now().Unix(), + UpdateAt: time.Now().Unix(), + } + _, _ = model.DB.Insert(item) + } else { + fileDiscern = lr.ImgSrc + } + // + case 2001: //网新返回有病害 + ir := new(InsigmaResult) + if err := json.Unmarshal([]byte(src.Result), &ir); err != nil { + return + } + fileDiscern = ir.Image + list := make([]*model.ProjectResult, 0) + for _, val := range ir.Diseases { + if len(val.Param.MaxWidth) > 0 && width == 0 { + width, _ = strconv.ParseFloat(val.Param.MaxWidth, 64) + } else { + width = 0 + } + memo = fmt.Sprintf("发现[%s],等级[%s],长度[%f],最大宽度[%s],面积[%f];\n", val.Type, val.Level, val.Param.Length, val.Param.MaxWidth, val.Param.Area) + maxWidth, _ := strconv.ParseFloat(val.Param.MaxWidth, 64) + item := &model.ProjectResult{ + ProjectId: project.ProjectId, + SourceResultId: src.ResultId, + MilepostNumber: milepostNumber, + UpDown: upDown, + LineNum: lineNum, + DiseaseType: val.Type, + DiseaseLevel: val.Level, + Length: val.Param.Length, + Width: maxWidth, + Acreage: val.Param.Area, + Memo: memo, + Result: fileDiscern, + Creator: 0, + Modifier: 0, + CreateAt: time.Now().Unix(), + UpdateAt: time.Now().Unix(), + } + list = append(list, item) + } + _, _ = model.DB.Insert(list) + } + } +} + +// 里程桩加减里程,返回里程桩 +func GetMilepost(start, num, upDown string) string { + arr := strings.Split(start, "+") + var ( + kilometre, meter, milepost, counter, res, resMilepost, resMeter float64 + ) + if len(arr) == 1 { + meter = 0 + } else { + meter, _ = strconv.ParseFloat(arr[1], 64) + } + str := strings.Replace(arr[0], "k", "", -1) + str = strings.Replace(str, "K", "", -1) + kilometre, _ = strconv.ParseFloat(str, 64) + milepost = kilometre + meter/1000 + counter, _ = strconv.ParseFloat(num, 64) + if upDown == "D" { + res = milepost - counter + + } else { + res = milepost + counter + } + resMilepost = math.Floor(res) + resMeter = (res - resMilepost) * 100 + return fmt.Sprintf("K%d+%.2f", int(resMilepost), resMeter) +} + +// 米装换成里程桩号 +func meter2Milepost(meter string) string { + meter = strings.Replace(meter, "K", "", -1) + m, _ := strconv.ParseFloat(meter, 64) + resMilepost := math.Floor(m / 1000) + resMeter := (m - resMilepost*1000) * 100 + return fmt.Sprintf("K%d+%.2f", int(resMilepost), resMeter) +} func deliver(topic string, mqType uint, payload interface{}) { cli := GetMqClient(topic, mqType) pData, _ := json.Marshal(payload) @@ -544,3 +830,220 @@ func TaskExecuteLogHandler(data []byte) (frame.Tag, []byte) { l.Unlock() return frame.Tag(cmd.Command), nil } + +func RunTraining(task *model.TrainTask) { + var ( + args []string + modelPath, modelFileName, testSize string + modelAcc, modelLoss float64 + ) + fmt.Println("curr tmp dir====>>>>", config.Cfg.TmpTrainDir) + modelFileName = utils.GetUUIDString() + //复制训练数据集 + tmpTrainDir := path.Join(config.Cfg.TmpTrainDir, fmt.Sprintf("%s_%s_%d_%d", modelFileName, task.Arithmetic, task.BatchSize, task.EpochsSize)) + fileList := make([]model.TrainingDatasetDetail, 0) + _ = model.DB.Where("dataset_id = ?", task.TrainDatasetId).Find(&fileList) + _ = os.MkdirAll(tmpTrainDir, os.ModePerm) + _ = os.MkdirAll(path.Join(tmpTrainDir, "train"), os.ModePerm) + _ = os.MkdirAll(path.Join(tmpTrainDir, "train", "0"), os.ModePerm) + _ = os.MkdirAll(path.Join(tmpTrainDir, "train", "1"), os.ModePerm) + _ = os.MkdirAll(path.Join(tmpTrainDir, "val"), os.ModePerm) + _ = os.MkdirAll(path.Join(tmpTrainDir, "val", "0"), os.ModePerm) + _ = os.MkdirAll(path.Join(tmpTrainDir, "val", "1"), os.ModePerm) + _ = os.MkdirAll(path.Join(tmpTrainDir, "test"), os.ModePerm) + for _, v := range fileList { + dstFilePath := "" + switch v.CategoryId { + case 2: + dstFilePath = "test" + default: + dstFilePath = "train" + } + if v.CategoryId != 2 { + if v.IsDisease == 1 { + dstFilePath = path.Join(tmpTrainDir, dstFilePath, "0") + } else { + dstFilePath = path.Join(tmpTrainDir, dstFilePath, "1") + } + } else { + dstFilePath = path.Join(tmpTrainDir, dstFilePath) + } + err := utils.CopyFile(v.FilePath, path.Join(dstFilePath, v.FileName)) + if err != nil { + fmt.Println("copy error: ", err) + } + + } + + modelPath = path.Join(config.Cfg.ModelOutPath, fmt.Sprintf("%s_%s_%d_%d_%s", modelFileName, task.Arithmetic, task.BatchSize, task.EpochsSize, task.OutputType)) + _ = os.MkdirAll(modelPath, os.ModePerm) + + dt := new(model.TrainingDataset) + _, err := model.DB.ID(task.TrainDatasetId).Get(dt) + if err != nil { + goto ReturnPoint + } + testSize = fmt.Sprintf("%.2f", dt.ValidationNumber/100) + //执行训练命令 + args = []string{"--dataset=" + path.Join(tmpTrainDir, "train"), + "--img_size=" + strconv.Itoa(task.ImageSize), "--batch_size=" + strconv.Itoa(task.BatchSize), "--test_size=" + testSize, + "--epochs=" + strconv.Itoa(task.EpochsSize), "--model=" + task.Arithmetic, "--model_save=" + path.Join(modelPath, modelFileName+".h5"), + } + fmt.Println("args====>>>", args) + err = ExecCommand(config.Cfg.TrainScriptPath, args, path.Join(modelPath, modelFileName+".log"), task.TaskId) +ReturnPoint: + //返回训练结果 + modelMetricsFile := modelFileName + "_model_metrics.png" + task.FinishTime = time.Now().Unix() + task.ModelFilePath = path.Join(modelPath, modelFileName+".h5") + task.Loss = GetIndicatorByLog(path.Join(modelPath, modelFileName+".log"), "[INFO] Model Validation Loss:") + task.Accuracy = GetIndicatorByLog(path.Join(modelPath, modelFileName+".log"), "[INFO] Model Validation Accuracy:") + task.Status = 3 + if err != nil { + task.Status = 5 + } + task.ModelFileMetricsPath = path.Join(modelPath, modelMetricsFile) + _, _ = model.DB.ID(task.TaskId).AllCols().Update(task) + if utils.PathExists(path.Join(modelPath, modelFileName+".log")) { + logContext := utils.ReadFile(path.Join(modelPath, modelFileName+".log")) + taskRes := new(model.TrainTaskResult) + taskRes.TaskId = task.TaskId + taskRes.CreateAt = time.Now().Unix() + taskRes.Content = string(logContext) + taskRes.Result = path.Join(modelPath, modelMetricsFile) + taskRes.Accuracy = modelAcc + taskRes.Loss = modelLoss + c, err := model.DB.Insert(taskRes) + if err != nil { + fmt.Println("model.DB.Insert(taskRes) error ========>>>>>>", err) + } + fmt.Println("model.DB.Insert(taskRes) count ========>>>>>>", c) + } else { + fmt.Println("logContext========>>>>>>未读取") + } +} + +func GetIndicatorByLog(logFileName, indicator string) float64 { + logFn, _ := os.Open(logFileName) + defer func() { + _ = logFn.Close() + }() + buf := bufio.NewReader(logFn) + for { + line, err := buf.ReadString('\n') + if err != nil { + if err == io.EOF { + //fmt.Println("File read ok!") + break + } else { + fmt.Println("Read file error!", err) + return 0 + } + } + if strings.Index(line, indicator) >= 0 { + str := strings.Replace(line, indicator, "", -1) + str = strings.Replace(str, "\n", "", -1) + value, _ := strconv.ParseFloat(strings.Trim(str, " "), 64) + return value + } + } + return 0 +} + +func ExecCommand(cmd string, args []string, logFileName string, taskId int64) (err error) { + logFile, _ := os.Create(logFileName) + defer func() { + _ = logFile.Close() + }() + fmt.Print("开始训练......") + + c := exec.Command(cmd, args...) // mac or linux + stdout, err := c.StdoutPipe() + if err != nil { + return err + } + var ( + wg sync.WaitGroup + ) + wg.Add(1) + go func() { + defer wg.Done() + reader := bufio.NewReader(stdout) + var ( + epoch int + //modelLoss, modelAcc float64 + ) + for { + readString, err := reader.ReadString('\n') + if err != nil || err == io.EOF { + fmt.Println("训练2===>>>", err) + //wg.Done() + return + } + byte2String := ConvertByte2String([]byte(readString), "GB18030") + _, _ = fmt.Fprint(logFile, byte2String) + if strings.Index(byte2String, "Epoch") >= 0 { + str := strings.Replace(byte2String, "Epoch ", "", -1) + arr := strings.Split(str, "/") + epoch, _ = strconv.Atoi(arr[0]) + } + if strings.Index(byte2String, "- loss:") > 0 && + strings.Index(byte2String, "- accuracy:") > 0 && + strings.Index(byte2String, "- val_loss:") > 0 && + strings.Index(byte2String, "- val_accuracy:") > 0 { + var ( + loss, acc, valLoss, valAcc float64 + ) + arr := strings.Split(byte2String, "-") + for _, v := range arr { + if strings.Index(v, "loss:") > 0 && strings.Index(v, "val_loss:") < 0 { + strLoss := strings.Replace(v, " loss: ", "", -1) + loss, _ = strconv.ParseFloat(strings.Trim(strLoss, " "), 64) + } + if strings.Index(v, "accuracy:") > 0 && strings.Index(v, "val_accuracy:") < 0 { + strAcc := strings.Replace(v, " accuracy: ", "", -1) + acc, _ = strconv.ParseFloat(strings.Trim(strAcc, " "), 64) + } + if strings.Index(v, "val_loss:") > 0 { + strValLoss := strings.Replace(v, "val_loss: ", "", -1) + valLoss, _ = strconv.ParseFloat(strings.Trim(strValLoss, " "), 64) + } + if strings.Index(v, "val_accuracy:") > 0 { + strValAcc := strings.Replace(v, "val_accuracy: ", "", -1) + strValAcc = strings.Replace(strValAcc, "\n", "", -1) + valAcc, _ = strconv.ParseFloat(strings.Trim(strValAcc, " "), 64) + } + } + taskLog := new(model.TrainTaskLog) + taskLog.Epoch = epoch + taskLog.TaskId = taskId + taskLog.CreateAt = time.Now().Unix() + taskLog.Loss = loss + taskLog.Accuracy = acc + taskLog.ValLoss = valLoss + taskLog.ValAccuracy = valAcc + _, _ = model.DB.Insert(taskLog) + } + fmt.Print(byte2String) + } + }() + err = c.Start() + if err != nil { + fmt.Println("训练3===>>>", err) + } + wg.Wait() + return +} +func ConvertByte2String(byte []byte, charset Charset) string { + var str string + switch charset { + case GB18030: + var decodeBytes, _ = simplifiedchinese.GB18030.NewDecoder().Bytes(byte) + str = string(decodeBytes) + case UTF8: + fallthrough + default: + str = string(byte) + } + return str +} diff --git a/mq/instruction.go b/mq/instruction.go index bdf5a30..8acc9ce 100644 --- a/mq/instruction.go +++ b/mq/instruction.go @@ -9,6 +9,7 @@ const ( ModelIssueResponse TaskExecuteLog TaskLog + TrainTaskAdd ) type InstructionReq struct { diff --git a/pkg/utils/file.go b/pkg/utils/file.go new file mode 100644 index 0000000..5d0f1d3 --- /dev/null +++ b/pkg/utils/file.go @@ -0,0 +1,94 @@ +package utils + +import ( + "crypto/md5" + "encoding/base64" + "encoding/hex" + "fmt" + "git.hpds.cc/Component/logging" + "go.uber.org/zap" + "io" + "os" + "path" + "path/filepath" + "strings" +) + +func CopyFile(src, dst string) error { + sourceFileStat, err := os.Stat(src) + if err != nil { + return err + } + + if !sourceFileStat.Mode().IsRegular() { + return fmt.Errorf("%s is not a regular file", src) + } + + source, err := os.Open(src) + if err != nil { + return err + } + defer func(source *os.File) { + _ = source.Close() + }(source) + + destination, err := os.Create(dst) + if err != nil { + return err + } + defer func(destination *os.File) { + _ = destination.Close() + }(destination) + _, err = io.Copy(destination, source) + return err +} + +func PathExists(path string) bool { + _, err := os.Stat(path) + if err == nil { + return true + } + if os.IsNotExist(err) { + return false + } + return false +} + +// ReadFile 读取到file中,再利用ioutil将file直接读取到[]byte中, 这是最优 +func ReadFile(fn string) []byte { + f, err := os.Open(fn) + if err != nil { + logging.L().Error("Read File", zap.String("File Name", fn), zap.Error(err)) + return nil + } + defer func(f *os.File) { + _ = f.Close() + }(f) + + fd, err := io.ReadAll(f) + if err != nil { + logging.L().Error("Read File To buff", zap.String("File Name", fn), zap.Error(err)) + return nil + } + + return fd +} + +func GetFileName(fn string) string { + fileType := path.Ext(fn) + return strings.TrimSuffix(fn, fileType) +} +func GetFileNameAndExt(fn string) string { + _, fileName := filepath.Split(fn) + return fileName +} +func GetFileMd5(data []byte) string { + hash := md5.New() + hash.Write(data) + return hex.EncodeToString(hash.Sum(nil)) +} + +func FileToBase64(fn string) string { + buff := ReadFile(fn) + return base64.StdEncoding.EncodeToString(buff) // 加密成base64字符串 +} diff --git a/pkg/utils/http.go b/pkg/utils/http.go new file mode 100644 index 0000000..81783b2 --- /dev/null +++ b/pkg/utils/http.go @@ -0,0 +1,126 @@ +package utils + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "mime/multipart" + "net/http" + "net/url" + "path/filepath" + "strings" +) + +func HttpDo(reqUrl, method string, params map[string]string, header map[string]string) (data []byte, err error) { + var paramStr string = "" + if contentType, ok := header["Content-Type"]; ok && strings.Contains(contentType, "json") { + bytesData, _ := json.Marshal(params) + paramStr = string(bytesData) + } else { + for k, v := range params { + if len(paramStr) == 0 { + paramStr = fmt.Sprintf("%s=%s", k, url.QueryEscape(v)) + } else { + paramStr = fmt.Sprintf("%s&%s=%s", paramStr, k, url.QueryEscape(v)) + } + } + } + + client := &http.Client{} + req, err := http.NewRequest(strings.ToUpper(method), reqUrl, strings.NewReader(paramStr)) + if err != nil { + return nil, err + } + for k, v := range header { + req.Header.Set(k, v) + } + resp, err := client.Do(req) + if err != nil { + return nil, err + } + + defer func() { + if resp.Body != nil { + err = resp.Body.Close() + if err != nil { + return + } + } + }() + var body []byte + if resp.Body != nil { + body, err = io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + } + return body, nil +} + +type UploadFile struct { + // 表单名称 + Name string + Filepath string + // 文件全路径 + File *bytes.Buffer +} + +func PostFile(reqUrl string, reqParams map[string]string, contentType string, files []UploadFile, headers map[string]string) string { + requestBody, realContentType := getReader(reqParams, contentType, files) + httpRequest, _ := http.NewRequest("POST", reqUrl, requestBody) + // 添加请求头 + httpRequest.Header.Add("Content-Type", realContentType) + if headers != nil { + for k, v := range headers { + httpRequest.Header.Add(k, v) + } + } + httpClient := &http.Client{} + // 发送请求 + resp, err := httpClient.Do(httpRequest) + if err != nil { + panic(err) + } + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(resp.Body) + response, _ := io.ReadAll(resp.Body) + return string(response) +} + +func getReader(reqParams map[string]string, contentType string, files []UploadFile) (io.Reader, string) { + if strings.Index(contentType, "json") > -1 { + bytesData, _ := json.Marshal(reqParams) + return bytes.NewReader(bytesData), contentType + } else if files != nil { + body := &bytes.Buffer{} + // 文件写入 body + writer := multipart.NewWriter(body) + for _, uploadFile := range files { + part, err := writer.CreateFormFile(uploadFile.Name, filepath.Base(uploadFile.Filepath)) + if err != nil { + panic(err) + } + _, err = io.Copy(part, uploadFile.File) + } + // 其他参数列表写入 body + for k, v := range reqParams { + if err := writer.WriteField(k, v); err != nil { + panic(err) + } + } + if err := writer.Close(); err != nil { + panic(err) + } + // 上传文件需要自己专用的contentType + return body, writer.FormDataContentType() + } else { + urlValues := url.Values{} + for key, val := range reqParams { + urlValues.Set(key, val) + } + reqBody := urlValues.Encode() + return strings.NewReader(reqBody), contentType + } +} diff --git a/pkg/utils/image.go b/pkg/utils/image.go new file mode 100644 index 0000000..025119a --- /dev/null +++ b/pkg/utils/image.go @@ -0,0 +1,139 @@ +package utils + +import ( + "bytes" + "encoding/base64" + "golang.org/x/image/bmp" + "golang.org/x/image/tiff" + "image" + "image/color" + "image/jpeg" + "image/png" +) + +func BuffToImage(in []byte) image.Image { + buff := bytes.NewBuffer(in) + m, _, _ := image.Decode(buff) + return m +} + +// Clip 图片裁剪 +func Clip(in []byte, wi, hi int, equalProportion bool) (out image.Image, imageType string, err error) { + buff := bytes.NewBuffer(in) + m, imgType, _ := image.Decode(buff) + rgbImg := m.(*image.YCbCr) + if equalProportion { + w := m.Bounds().Max.X + h := m.Bounds().Max.Y + if w > 0 && h > 0 && wi > 0 && hi > 0 { + wi, hi = fixSize(w, h, wi, hi) + } + } + return rgbImg.SubImage(image.Rect(0, 0, wi, hi)), imgType, nil +} + +func fixSize(img1W, img2H, wi, hi int) (new1W, new2W int) { + var ( //为了方便计算,将图片的宽转为 float64 + imgWidth, imgHeight = float64(img1W), float64(img2H) + ratio float64 + ) + if imgWidth >= imgHeight { + ratio = imgWidth / float64(wi) + return int(imgWidth * ratio), int(imgHeight * ratio) + } + ratio = imgHeight / float64(hi) + return int(imgWidth * ratio), int(imgHeight * ratio) +} + +func Gray(in []byte) (out image.Image, err error) { + m := BuffToImage(in) + bounds := m.Bounds() + dx := bounds.Dx() + dy := bounds.Dy() + newRgba := image.NewRGBA(bounds) + for i := 0; i < dx; i++ { + for j := 0; j < dy; j++ { + colorRgb := m.At(i, j) + _, g, _, a := colorRgb.RGBA() + gUint8 := uint8(g >> 8) + aUint8 := uint8(a >> 8) + newRgba.SetRGBA(i, j, color.RGBA{R: gUint8, G: gUint8, B: gUint8, A: aUint8}) + } + } + r := image.Rect(0, 0, dx, dy) + return newRgba.SubImage(r), nil +} + +func Rotate90(in []byte) image.Image { + m := BuffToImage(in) + rotate90 := image.NewRGBA(image.Rect(0, 0, m.Bounds().Dy(), m.Bounds().Dx())) + // 矩阵旋转 + for x := m.Bounds().Min.Y; x < m.Bounds().Max.Y; x++ { + for y := m.Bounds().Max.X - 1; y >= m.Bounds().Min.X; y-- { + // 设置像素点 + rotate90.Set(m.Bounds().Max.Y-x, y, m.At(y, x)) + } + } + return rotate90 +} + +// Rotate180 旋转180度 +func Rotate180(in []byte) image.Image { + m := BuffToImage(in) + rotate180 := image.NewRGBA(image.Rect(0, 0, m.Bounds().Dx(), m.Bounds().Dy())) + // 矩阵旋转 + for x := m.Bounds().Min.X; x < m.Bounds().Max.X; x++ { + for y := m.Bounds().Min.Y; y < m.Bounds().Max.Y; y++ { + // 设置像素点 + rotate180.Set(m.Bounds().Max.X-x, m.Bounds().Max.Y-y, m.At(x, y)) + } + } + return rotate180 +} + +// Rotate270 旋转270度 +func Rotate270(in []byte) image.Image { + m := BuffToImage(in) + rotate270 := image.NewRGBA(image.Rect(0, 0, m.Bounds().Dy(), m.Bounds().Dx())) + // 矩阵旋转 + for x := m.Bounds().Min.Y; x < m.Bounds().Max.Y; x++ { + for y := m.Bounds().Max.X - 1; y >= m.Bounds().Min.X; y-- { + // 设置像素点 + rotate270.Set(x, m.Bounds().Max.X-y, m.At(y, x)) + } + } + return rotate270 + +} + +func ImageToBase64(img image.Image, imgType string) string { + buff := ImageToBuff(img, imgType) + return base64.StdEncoding.EncodeToString(buff.Bytes()) +} + +func ImageToBuff(img image.Image, imgType string) *bytes.Buffer { + buff := bytes.NewBuffer(nil) + switch imgType { + case "bmp": + imgType = "bmp" + _ = bmp.Encode(buff, img) + case "png": + imgType = "png" + _ = png.Encode(buff, img) + case "tiff": + imgType = "tiff" + _ = tiff.Encode(buff, img, nil) + default: + imgType = "jpeg" + _ = jpeg.Encode(buff, img, nil) + } + return buff +} + +func ImgFileToBase64(fn string) string { + fileByte := ReadFile(fn) + buff := bytes.NewBuffer(fileByte) + m, _, _ := image.Decode(buff) + return "data:image/jpeg;base64," + ImageToBase64(m, "jpeg") + +} diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go new file mode 100644 index 0000000..e5f99b0 --- /dev/null +++ b/pkg/utils/utils.go @@ -0,0 +1,48 @@ +package utils + +import ( + "crypto/hmac" + "crypto/sha1" + "encoding/base64" + "github.com/google/uuid" + "math/rand" + "strings" + "time" +) + +/* + RandomString 产生随机数 + - size 随机码的位数 + - kind 0 // 纯数字 + 1 // 小写字母 + 2 // 大写字母 + 3 // 数字、大小写字母 +*/ +func RandomString(size int, kind int) string { + iKind, kinds, rsBytes := kind, [][]int{[]int{10, 48}, []int{26, 97}, []int{26, 65}}, make([]byte, size) + isAll := kind > 2 || kind < 0 + rand.Seed(time.Now().UnixNano()) + for i := 0; i < size; i++ { + if isAll { // random iKind + iKind = rand.Intn(3) + } + scope, base := kinds[iKind][0], kinds[iKind][1] + rsBytes[i] = uint8(base + rand.Intn(scope)) + } + return string(rsBytes) +} + +func GetUserSha1Pass(pass, salt string) string { + key := []byte(salt) + mac := hmac.New(sha1.New, key) + mac.Write([]byte(pass)) + //进行base64编码 + res := base64.StdEncoding.EncodeToString(mac.Sum(nil)) + return res +} + +func GetUUIDString() string { + u, _ := uuid.NewUUID() + str := strings.Replace(u.String(), "-", "", -1) + return str +} diff --git a/scripts/runTrainScript.sh b/scripts/runTrainScript.sh new file mode 100644 index 0000000..c850b99 --- /dev/null +++ b/scripts/runTrainScript.sh @@ -0,0 +1,91 @@ +#!/bin/bash + +#获取对应的参数 没有的话赋默认值 +ARGS=`getopt -o d::s::b::e::m::o:: --long dataset::,img_size::,batch_size::,epochs::,model::,model_save:: -n 'example.sh' -- "$@"` +if [ $? != 0 ]; then + echo "Terminating..." + exit 1 +fi +echo $ARGS +eval set -- "${ARGS}" +while true; +do + case "$1" in + -d|--dataset) + case "$2" in + "") + echo "Internal error!" + exit 1 + ;; + *) + Dataset=$2 + shift 2; + ;; + esac + ;; + -s|--img_size) + case "$2" in + "") + echo "Internal error!" + exit 1 + ;; + + *) + ImgSize=$2; + shift 2; + ;; + esac + ;; + -b|--batch_size) + case "$2" in + *) + BatchSize=$2; + shift 2; + ;; + esac + ;; + -e|--epochs) + case "$2" in + *) + Epochs=$2; + shift 2; + ;; + esac + ;; + -m|--model) + case "$2" in + *) + Model=$2; + shift 2; + ;; + esac + ;; + -o|--model_save) + case "$2" in + *) + ModelSave=$2; + shift 2; + ;; + esac + ;; + + --) + shift + break + ;; + *) + echo "Internal error!" + exit 1 + ;; + esac + +done + +#echo ${Dataset} + +eval "$(conda shell.bash hook)" +conda activate hpds_train +cd /home/data/hpds_train_keras +python train.py --dataset ${Dataset} --img_size ${ImgSize} --batch_size ${BatchSize} \ + --epochs ${Epochs} --model ${Model} --model_save ${ModelSave} +