1、增加任务处理进度
This commit is contained in:
parent
7d09ba0286
commit
90d43d468a
|
@ -19,6 +19,9 @@ type ControlCenterConfig struct {
|
|||
Host string `yaml:"host,omitempty"`
|
||||
Port int `yaml:"port,omitempty"`
|
||||
Mode string `yaml:"mode,omitempty"`
|
||||
TmpTrainDir string `yaml:"tmpTrainDir"`
|
||||
TrainScriptPath string `yaml:"trainScriptPath"`
|
||||
ModelOutPath string `yaml:"modelOutPath"`
|
||||
Consul ConsulConfig `yaml:"consul,omitempty"`
|
||||
Db DbConfig `yaml:"db"`
|
||||
Cache CacheConfig `yaml:"cache"`
|
||||
|
|
|
@ -2,6 +2,9 @@ name: control_center
|
|||
host: 0.0.0.0
|
||||
port: 8088
|
||||
mode: dev
|
||||
tmpTrainDir: ./tmp
|
||||
trainScriptPath: ./scripts/runTrainScript.sh
|
||||
modelOutPath: ./out
|
||||
logging:
|
||||
path: ./logs
|
||||
prefix: hpds-control
|
||||
|
|
7
go.mod
7
go.mod
|
@ -7,12 +7,14 @@ require (
|
|||
git.hpds.cc/Component/network v0.0.0-20230405135741-a4ea724bab76
|
||||
git.hpds.cc/pavement/hpds_node v0.0.0-20230405153516-9403c4d01e12
|
||||
github.com/go-sql-driver/mysql v1.7.0
|
||||
github.com/google/uuid v1.3.0
|
||||
github.com/hashicorp/consul/api v1.20.0
|
||||
github.com/minio/minio-go v6.0.14+incompatible
|
||||
github.com/minio/minio-go/v7 v7.0.52
|
||||
github.com/spf13/cobra v1.6.1
|
||||
github.com/spf13/viper v1.15.0
|
||||
go.uber.org/zap v1.23.0
|
||||
golang.org/x/image v0.0.0-20191009234506-e7c1f5e7dbb8
|
||||
golang.org/x/text v0.7.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
xorm.io/xorm v1.3.2
|
||||
)
|
||||
|
@ -30,7 +32,6 @@ require (
|
|||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/fatih/color v1.13.0 // indirect
|
||||
github.com/fsnotify/fsnotify v1.6.0 // indirect
|
||||
github.com/go-ini/ini v1.67.0 // indirect
|
||||
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect
|
||||
github.com/goccy/go-json v0.8.1 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
|
@ -40,7 +41,6 @@ require (
|
|||
github.com/golang/snappy v0.0.4 // indirect
|
||||
github.com/google/go-cmp v0.5.9 // indirect
|
||||
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect
|
||||
github.com/google/uuid v1.3.0 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.2.1 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.7.0 // indirect
|
||||
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
|
||||
|
@ -94,7 +94,6 @@ require (
|
|||
golang.org/x/oauth2 v0.0.0-20221014153046-6fdb5e3db783 // indirect
|
||||
golang.org/x/sync v0.1.0 // indirect
|
||||
golang.org/x/sys v0.5.0 // indirect
|
||||
golang.org/x/text v0.7.0 // indirect
|
||||
golang.org/x/time v0.1.0 // indirect
|
||||
golang.org/x/tools v0.2.0 // indirect
|
||||
golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect
|
||||
|
|
|
@ -4,7 +4,7 @@ type TaskLog struct {
|
|||
TaskLogId int64 `xorm:"not null pk autoincr BIGINT(11)" json:"taskLogId"`
|
||||
TaskId int64 `xorm:"INT(11) index" json:"taskId"`
|
||||
NodeId int64 `xorm:"INT(11) index" json:"nodeId"`
|
||||
Content string `xorm:"LANGTEXT" json:"content"`
|
||||
Content string `xorm:"LONGTEXT" json:"content"`
|
||||
CreateAt int64 `xorm:"created" json:"createAt"`
|
||||
UpdateAt int64 `xorm:"updated" json:"updateAt"`
|
||||
}
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
package model
|
||||
|
||||
// DiseaseType 病害类别
|
||||
type DiseaseType struct {
|
||||
TypeId int64 `xorm:"not null pk autoincr INT(11)" json:"typeId"`
|
||||
TypeName string `xorm:"varchar(200) not null" json:"typeName"`
|
||||
CategoryId int `xorm:"not null SMALLINT default 1" json:"categoryId"` //病害分类, 1:道路 2:桥梁 3:隧道 4:边坡
|
||||
Status int `xorm:"not null INT(11) default 0" json:"status"`
|
||||
CreateAt int64 `xorm:"created" json:"createAt"`
|
||||
UpdateAt int64 `xorm:"updated" json:"updateAt"`
|
||||
}
|
||||
|
||||
func GetDiseaseType(name string, categoryId int) int64 {
|
||||
item := new(DiseaseType)
|
||||
h, err := DB.Where("type_name like ?", "%"+name+"%").
|
||||
And("category_id = ?", categoryId).
|
||||
And("status = 1").Get(item)
|
||||
if err != nil || !h {
|
||||
return 0
|
||||
}
|
||||
return item.TypeId
|
||||
|
||||
}
|
|
@ -26,6 +26,11 @@ func New(driveName, dsn string, showSql bool) {
|
|||
&Task{},
|
||||
&TaskLog{},
|
||||
&TaskResult{},
|
||||
&TrainingDataset{},
|
||||
&TrainingDatasetDetail{},
|
||||
&TrainTask{},
|
||||
&TrainTaskLog{},
|
||||
&TrainTaskResult{},
|
||||
)
|
||||
if err != nil {
|
||||
fmt.Println("同步数据库表结构", err)
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
package model
|
||||
|
||||
type Project struct {
|
||||
ProjectId int64 `xorm:"not null pk autoincr INT(11)" json:"projectId"`
|
||||
ProjectName string `xorm:"varchar(200) not null " json:"projectName"`
|
||||
OwnerId int64 `xorm:"not null INT(11) default 0" json:"ownerId"`
|
||||
BizType int `xorm:"SMALLINT" json:"bizType"`
|
||||
LineName string `xorm:"varchar(200) not null " json:"lineName"`
|
||||
StartName string `xorm:"varchar(200) not null " json:"startName"`
|
||||
EndName string `xorm:"varchar(200) not null " json:"endName"`
|
||||
FixedDeviceNum int `xorm:"not null INT(11) default 0" json:"fixedDeviceNum"`
|
||||
Direction string `xorm:"varchar(200) not null " json:"direction"`
|
||||
LaneNum int `xorm:"not null INT(4) default 0" json:"laneNum"`
|
||||
StartPointLng float64 `xorm:"decimal(18,6)" json:"startPointLng"`
|
||||
StartPointLat float64 `xorm:"decimal(18,6)" json:"startPointLat"`
|
||||
EndPointLng float64 `xorm:"decimal(18,6)" json:"endPointLng"`
|
||||
EndPointLat float64 `xorm:"decimal(18,6)" json:"endPointLat"`
|
||||
Status int `xorm:"SMALLINT default 1" json:"status"`
|
||||
Creator int64 `xorm:"INT(11) default 0" json:"creator"`
|
||||
Modifier int64 `xorm:"INT(11) default 0" json:"modifier"`
|
||||
CreateAt int64 `xorm:"created" json:"createAt"`
|
||||
UpdateAt int64 `xorm:"updated" json:"updateAt"`
|
||||
}
|
|
@ -0,0 +1,21 @@
|
|||
package model
|
||||
|
||||
type ProjectResult struct {
|
||||
Id int64 `xorm:"not null pk autoincr BIGINT" json:"id"`
|
||||
ProjectId int64 `xorm:"INT(11) index" json:"projectId"` //项目编号
|
||||
SourceResultId int64 `xorm:"INT(11) index" json:"sourceResultId"` //识别结果来源编号
|
||||
MilepostNumber string `xorm:"VARCHAR(50)" json:"milepostNumber"` //里程桩号
|
||||
UpDown string `xorm:"VARCHAR(20)" json:"upDown"` //上下行
|
||||
LineNum int `xorm:"SMALLINT default 1" json:"lineNum"` //车道号
|
||||
DiseaseType string `xorm:"VARCHAR(50)" json:"diseaseType"` //病害类型
|
||||
DiseaseLevel string `xorm:"VARCHAR(20)" json:"diseaseLevel"` //病害等级
|
||||
Length float64 `xorm:"decimal(18,6)" json:"length"` //长度
|
||||
Width float64 `xorm:"decimal(18,6)" json:"width"` //宽度
|
||||
Acreage float64 `xorm:"decimal(18,6)" json:"acreage"` //面积
|
||||
Memo string `xorm:"VARCHAR(1000)" json:"memo"` //备注说明
|
||||
Result string `xorm:"LONGTEXT" json:"result"` //识别结果
|
||||
Creator int64 `xorm:"INT(11) default 0" json:"creator"`
|
||||
Modifier int64 `xorm:"INT(11) default 0" json:"modifier"`
|
||||
CreateAt int64 `xorm:"created" json:"createAt"`
|
||||
UpdateAt int64 `xorm:"updated" json:"updateAt"`
|
||||
}
|
|
@ -62,12 +62,11 @@ func UpdateTaskProgress(taskProgress *proto.TaskLogProgress) {
|
|||
}
|
||||
}
|
||||
|
||||
func UpdateTaskProgressByLog(res *TaskResult, isFailing bool) float64 {
|
||||
ret := -1.0
|
||||
func UpdateTaskProgressByLog(res *TaskResult, isFailing bool) (int, int) {
|
||||
item := new(Task)
|
||||
h, err := DB.ID(res.TaskId).Get(item)
|
||||
if err != nil || !h {
|
||||
return ret
|
||||
return 0, 0
|
||||
}
|
||||
if isFailing {
|
||||
item.FailingCount += 1
|
||||
|
@ -79,12 +78,11 @@ func UpdateTaskProgressByLog(res *TaskResult, isFailing bool) float64 {
|
|||
item.FinishTime = time.Now().Unix()
|
||||
item.UnfinishedCount = 0
|
||||
item.Status = 3
|
||||
ret = 1.0
|
||||
}
|
||||
item.UpdateAt = time.Now().Unix()
|
||||
_, _ = DB.ID(res.TaskId).Cols("completed_count", "failing_count", "total_count", "unfinished_count", "update_at", "finish_time", "status").Update(item)
|
||||
if item.TotalCount > 0 {
|
||||
return 1 - float64(item.CompletedCount)/float64(item.TotalCount)
|
||||
return int(item.CompletedCount), int(item.UnfinishedCount)
|
||||
}
|
||||
return ret
|
||||
return int(item.CompletedCount), int(item.UnfinishedCount)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
package model
|
||||
|
||||
type TrainTask struct {
|
||||
TaskId int64 `xorm:"not null pk autoincr BIGINT(11)" json:"taskId"`
|
||||
TrainDatasetId int64 `xorm:"INT(11) index" json:"trainDatasetId"`
|
||||
CategoryId int `xorm:"not null SMALLINT default 1" json:"categoryId"` //业务分类, 1:道路 2:桥梁 3:隧道 4:边坡
|
||||
TaskName string `xorm:"VARCHAR(200)" json:"taskName"`
|
||||
TaskDesc string `xorm:"VARCHAR(500)" json:"taskDesc"`
|
||||
Arithmetic string `xorm:"VARCHAR(50)" json:"arithmetic"`
|
||||
ImageSize int `xorm:"INT" json:"imageSize"`
|
||||
BatchSize int `xorm:"INT" json:"batchSize"`
|
||||
EpochsSize int `xorm:"INT" json:"epochsSize"`
|
||||
OutputType string `xorm:"VARCHAR(20)" json:"outputType"`
|
||||
StartTime int64 `xorm:"BIGINT" json:"startTime"`
|
||||
FinishTime int64 `xorm:"BIGINT" json:"finishTime"`
|
||||
Loss float64 `xorm:"DECIMAL(18,6)" json:"loss"`
|
||||
Accuracy float64 `xorm:"DECIMAL(18,6)" json:"accuracy"`
|
||||
ModelFilePath string `xorm:"VARCHAR(2000)" json:"modelFilePath"`
|
||||
ModelFileMetricsPath string `xorm:"VARCHAR(2000)" json:"modelFileMetricsPath"`
|
||||
Status int `xorm:"not null SMALLINT default 0" json:"status"` // 1:等待执行; 2:执行中; 3:执行完成; 4:任务分配失败; 5:任务执行失败
|
||||
CreateAt int64 `xorm:"created" json:"createAt"`
|
||||
UpdateAt int64 `xorm:"updated" json:"updateAt"`
|
||||
}
|
|
@ -0,0 +1,12 @@
|
|||
package model
|
||||
|
||||
type TrainTaskLog struct {
|
||||
LogId int64 `xorm:"not null pk autoincr BIGINT(11)" json:"logId"`
|
||||
TaskId int64 `xorm:"INT(11) index" json:"taskId"`
|
||||
Epoch int `xorm:"SMALLINT" json:"epoch"`
|
||||
Loss float64 `xorm:"DECIMAL(18,6)" json:"loss"`
|
||||
Accuracy float64 `xorm:"DECIMAL(18,6)" json:"accuracy"`
|
||||
ValLoss float64 `xorm:"DECIMAL(18,6)" json:"valLoss"`
|
||||
ValAccuracy float64 `xorm:"DECIMAL(18,6)" json:"valAccuracy"`
|
||||
CreateAt int64 `xorm:"created" json:"createAt"`
|
||||
}
|
|
@ -0,0 +1,11 @@
|
|||
package model
|
||||
|
||||
type TrainTaskResult struct {
|
||||
ResultId int64 `xorm:"not null pk autoincr BIGINT(11)" json:"resultId"`
|
||||
TaskId int64 `xorm:"INT(11) index" json:"taskId"`
|
||||
Content string `xorm:"LONGTEXT" json:"content"`
|
||||
Result string `xorm:"VARCHAR(200)" json:"result"`
|
||||
Loss float64 `xorm:"DECIMAL(18,6)" json:"loss"`
|
||||
Accuracy float64 `xorm:"DECIMAL(18,6)" json:"accuracy"`
|
||||
CreateAt int64 `xorm:"created" json:"createAt"`
|
||||
}
|
|
@ -0,0 +1,13 @@
|
|||
package model
|
||||
|
||||
type TrainingDataset struct {
|
||||
DatasetId int64 `xorm:"not null pk autoincr INT(11)" json:"datasetId"`
|
||||
Name string `xorm:"VARCHAR(200)" json:"name"`
|
||||
CategoryId int `xorm:"not null SMALLINT default 1" json:"categoryId"` //业务分类, 1:道路 2:桥梁 3:隧道 4:边坡
|
||||
DatasetDesc string `xorm:"varchar(200)" json:"datasetDesc"`
|
||||
StoreName string `xorm:"varchar(200)" json:"storeName"` //存储路径
|
||||
ValidationNumber float64 `xorm:"DECIMAL(18,4)" json:"validationNumber"` //验证占比
|
||||
TestNumber float64 `xorm:"DECIMAL(18,4)" json:"testNumber"` //测试占比
|
||||
CreateAt int64 `xorm:"created" json:"createAt"`
|
||||
UpdateAt int64 `xorm:"updated" json:"updateAt"`
|
||||
}
|
|
@ -0,0 +1,16 @@
|
|||
package model
|
||||
|
||||
type TrainingDatasetDetail struct {
|
||||
DetailId int64 `xorm:"not null pk autoincr INT(11)" json:"detailId"`
|
||||
FileName string `xorm:"VARCHAR(200)" json:"fileName"`
|
||||
FilePath string `xorm:"VARCHAR(1000)" json:"filePath"`
|
||||
DatasetId int64 `xorm:"INT(11) index default 0" json:"datasetId"` //训练数据集
|
||||
CategoryId int `xorm:"not null SMALLINT default 1" json:"categoryId"` //训练集分类,1:训练集;2:测试集;3:验证集
|
||||
FileSize int64 `xorm:"BIGINT" json:"fileSize"` //文件大小
|
||||
FileMd5 string `xorm:"VARCHAR(64)" json:"fileMd5"` //文件MD5
|
||||
IsDisease int `xorm:"TINYINT(1)" json:"isDisease"` //是否有病害, 1:有病害;2:无病害;
|
||||
OperationLogId int64 `xorm:"INT(11) index" json:"operationLogId"` //操作日志编号
|
||||
Creator int64 `xorm:"INT(11) index" json:"creator"` //上传人
|
||||
CreateAt int64 `xorm:"created" json:"createAt"` //上传时间
|
||||
UpdateAt int64 `xorm:"updated" json:"updateAt"` //更新时间
|
||||
}
|
515
mq/index.go
515
mq/index.go
|
@ -1,6 +1,7 @@
|
|||
package mq
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
@ -8,12 +9,18 @@ import (
|
|||
"git.hpds.cc/Component/network/frame"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/text/encoding/simplifiedchinese"
|
||||
"hpds_control_center/config"
|
||||
"hpds_control_center/internal/balance"
|
||||
"hpds_control_center/internal/minio"
|
||||
"hpds_control_center/internal/proto"
|
||||
"hpds_control_center/model"
|
||||
"hpds_control_center/pkg/utils"
|
||||
"io"
|
||||
"math"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
@ -22,6 +29,13 @@ import (
|
|||
"git.hpds.cc/pavement/hpds_node"
|
||||
)
|
||||
|
||||
type Charset string
|
||||
|
||||
const (
|
||||
UTF8 = Charset("UTF-8")
|
||||
GB18030 = Charset("GB18030")
|
||||
)
|
||||
|
||||
var (
|
||||
MqList []HpdsMqNode
|
||||
TaskList = make(map[int64]*TaskItem)
|
||||
|
@ -379,8 +393,8 @@ func TaskRequestHandler(data []byte) (frame.Tag, []byte) {
|
|||
item.UpdateAt = time.Now().Unix()
|
||||
_, _ = model.DB.ID(item.Id).AllCols().Update(item)
|
||||
} else {
|
||||
item.ModelId = payload["modelId"].(int64)
|
||||
item.NodeId = payload["nodeId"].(int64)
|
||||
item.ModelId = int64(payload["modelId"].(float64))
|
||||
item.NodeId = int64(payload["nodeId"].(float64))
|
||||
item.Status = 1
|
||||
item.IssueResult = string(pData)
|
||||
item.CreateAt = time.Now().Unix()
|
||||
|
@ -403,7 +417,16 @@ func TaskRequestHandler(data []byte) (frame.Tag, []byte) {
|
|||
// _, _ = model.DB.Insert(item)
|
||||
// //fn := payload["fileName"].(string)
|
||||
// //dec := base64.NewDecoder(base64.StdEncoding, strings.NewReader(payload["file"].(string)))
|
||||
case TrainTaskAdd:
|
||||
payload := cmd.Payload.(map[string]interface{})
|
||||
if itemId, ok := payload["taskId"].(float64); ok {
|
||||
item := new(model.TrainTask)
|
||||
h, err := model.DB.ID(int64(itemId)).Get(item)
|
||||
if err != nil || !h {
|
||||
|
||||
}
|
||||
RunTraining(item)
|
||||
}
|
||||
default:
|
||||
|
||||
}
|
||||
|
@ -450,14 +473,16 @@ func TaskResponseHandler(data []byte) (frame.Tag, []byte) {
|
|||
if err != nil {
|
||||
fmt.Println("接收TaskResponse数据出错", err)
|
||||
}
|
||||
//处理到项目结果表
|
||||
go processToProjectResult(item)
|
||||
//更新运行进度
|
||||
rat := model.UpdateTaskProgressByLog(item, isFailing)
|
||||
processed, unProcessed := model.UpdateTaskProgressByLog(item, isFailing)
|
||||
var (
|
||||
ratStr string
|
||||
)
|
||||
if rat > 0 && rat < 1 {
|
||||
ratStr = fmt.Sprintf("[已处理%2.f,剩余%2.f未处理]", 1-rat, rat)
|
||||
} else if rat == 1 {
|
||||
if unProcessed > 0 {
|
||||
ratStr = fmt.Sprintf("[已处理[%d],剩余[%d]未处理]", processed, unProcessed)
|
||||
} else {
|
||||
ratStr = "[已全部处理]"
|
||||
}
|
||||
taskLog := new(model.TaskLog)
|
||||
|
@ -479,6 +504,267 @@ func TaskResponseHandler(data []byte) (frame.Tag, []byte) {
|
|||
return frame.Tag(cmd.Command), nil
|
||||
}
|
||||
|
||||
type ModelResult struct {
|
||||
Code int `json:"code"`
|
||||
}
|
||||
|
||||
type InsigmaResult struct {
|
||||
Code int `json:"code"`
|
||||
NumOfDiseases int `json:"num_of_diseases"`
|
||||
Diseases []DiseasesInfo `json:"diseases"`
|
||||
Image string `json:"image"`
|
||||
}
|
||||
|
||||
type DiseasesInfo struct {
|
||||
Id int `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Level string `json:"level"`
|
||||
Param DiseasesParam `json:"param"`
|
||||
}
|
||||
|
||||
type DiseasesParam struct {
|
||||
Length float64 `json:"length"`
|
||||
Area float64 `json:"area"`
|
||||
MaxWidth string `json:"max_width"`
|
||||
}
|
||||
|
||||
type LightweightResult struct {
|
||||
Code int `json:"code"`
|
||||
Crack bool `json:"crack"`
|
||||
ImgDiscern string `json:"img_discern"`
|
||||
ImgSrc string `json:"img_src"`
|
||||
Pothole bool `json:"pothole"`
|
||||
}
|
||||
|
||||
func processToProjectResult(src *model.TaskResult) {
|
||||
project := new(model.Project)
|
||||
h, err := model.DB.Table("project").Alias("p").Join("inner", []string{"dataset", "d"}, "d.project_id= p.project_id").Where("d.dataset_id=?", src.DatasetId).Get(project)
|
||||
if !h {
|
||||
err = fmt.Errorf("未能找到对应的项目信息")
|
||||
}
|
||||
if err != nil {
|
||||
logging.L().With(zap.String("控制节点", "错误信息")).Error("获取项目信息", zap.Error(err))
|
||||
return
|
||||
}
|
||||
var (
|
||||
mr ModelResult
|
||||
mrList []string
|
||||
fileDiscern string
|
||||
memo string
|
||||
milepostNumber string
|
||||
upDown string
|
||||
lineNum int
|
||||
width float64
|
||||
)
|
||||
|
||||
switch project.BizType {
|
||||
case 1: //道路
|
||||
arr := strings.Split(src.SrcPath, " ")
|
||||
if len(arr) > 1 {
|
||||
milepostNumber = GetMilepost(project.StartName, arr[1], arr[2])
|
||||
if arr[2] == "D" {
|
||||
upDown = "下行"
|
||||
} else {
|
||||
upDown = "上行"
|
||||
}
|
||||
}
|
||||
if len(arr) > 3 {
|
||||
lineNum, _ = strconv.Atoi(arr[3])
|
||||
}
|
||||
case 2: //桥梁
|
||||
case 3: //隧道
|
||||
//隧道名-采集方向(D/X)-相机编号(01-22)-采集序号(五位)K里程桩号.bmp DAXIASHAN-D-05-00003K15069.5.bmp
|
||||
arr := strings.Split(src.SrcPath, "K")
|
||||
if len(arr) > 1 {
|
||||
arrM := strings.Split(arr[1], ".")
|
||||
milepostNumber = meter2Milepost(arrM[0])
|
||||
arrD := strings.Split(arr[0], ".")
|
||||
if len(arrD) > 1 {
|
||||
if arrD[1] == "D" {
|
||||
upDown = "下行"
|
||||
} else {
|
||||
upDown = "上行"
|
||||
}
|
||||
}
|
||||
if len(arrD) > 4 {
|
||||
lineNum, _ = strconv.Atoi(arrD[3])
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(src.Result) > 0 && src.Result[0] == '[' {
|
||||
mrList = make([]string, 0)
|
||||
if err := json.Unmarshal([]byte(src.Result), &mrList); err != nil {
|
||||
return
|
||||
}
|
||||
list := make([]*model.ProjectResult, 0)
|
||||
for _, str := range mrList {
|
||||
if err := json.Unmarshal([]byte(str), &mr); err != nil {
|
||||
continue
|
||||
}
|
||||
if mr.Code == 2001 {
|
||||
ir := new(InsigmaResult)
|
||||
if err := json.Unmarshal([]byte(str), &ir); err != nil {
|
||||
continue
|
||||
}
|
||||
fileDiscern = ir.Image
|
||||
for key, value := range ir.Diseases {
|
||||
if len(value.Param.MaxWidth) > 0 && width == 0 {
|
||||
width, _ = strconv.ParseFloat(value.Param.MaxWidth, 64)
|
||||
} else {
|
||||
width = 0
|
||||
}
|
||||
memo = fmt.Sprintf("%d. 发现[%s],等级[%s],长度[%f],最大宽度[%s],面积[%f];\n", key+1, value.Type, value.Level, value.Param.Length, value.Param.MaxWidth, value.Param.Area)
|
||||
item := &model.ProjectResult{
|
||||
ProjectId: project.ProjectId,
|
||||
SourceResultId: src.ResultId,
|
||||
MilepostNumber: milepostNumber,
|
||||
UpDown: upDown,
|
||||
LineNum: lineNum,
|
||||
DiseaseType: value.Type,
|
||||
DiseaseLevel: value.Level,
|
||||
Length: value.Param.Length,
|
||||
Width: width,
|
||||
Acreage: value.Param.Area,
|
||||
Memo: memo,
|
||||
Result: fileDiscern,
|
||||
Creator: 0,
|
||||
Modifier: 0,
|
||||
CreateAt: time.Now().Unix(),
|
||||
UpdateAt: time.Now().Unix(),
|
||||
}
|
||||
list = append(list, item)
|
||||
}
|
||||
}
|
||||
}
|
||||
_, _ = model.DB.Insert(list)
|
||||
} else {
|
||||
if err := json.Unmarshal([]byte(src.Result), &mr); err != nil {
|
||||
return
|
||||
}
|
||||
switch mr.Code {
|
||||
case 0: //轻量化模型返回
|
||||
lr := new(LightweightResult)
|
||||
if err := json.Unmarshal([]byte(src.Result), &lr); err != nil {
|
||||
return
|
||||
}
|
||||
if lr.Crack || lr.Pothole {
|
||||
if lr.Crack {
|
||||
memo = "检测到裂缝"
|
||||
} else {
|
||||
memo = "检测到坑洼"
|
||||
}
|
||||
fileDiscern = lr.ImgDiscern
|
||||
if len(fileDiscern) == 0 {
|
||||
fileDiscern = lr.ImgSrc
|
||||
}
|
||||
diseaseLevelName := "重度"
|
||||
diseaseTypeName := ""
|
||||
switch project.BizType {
|
||||
case 2:
|
||||
diseaseTypeName = "结构裂缝"
|
||||
case 3:
|
||||
diseaseTypeName = "衬砌裂缝"
|
||||
default:
|
||||
diseaseTypeName = "横向裂缝"
|
||||
}
|
||||
item := &model.ProjectResult{
|
||||
ProjectId: project.ProjectId,
|
||||
SourceResultId: src.ResultId,
|
||||
MilepostNumber: milepostNumber,
|
||||
UpDown: upDown,
|
||||
LineNum: lineNum,
|
||||
DiseaseType: diseaseTypeName,
|
||||
DiseaseLevel: diseaseLevelName,
|
||||
Length: 0,
|
||||
Width: 0,
|
||||
Acreage: 0,
|
||||
Memo: memo,
|
||||
Result: fileDiscern,
|
||||
Creator: 0,
|
||||
Modifier: 0,
|
||||
CreateAt: time.Now().Unix(),
|
||||
UpdateAt: time.Now().Unix(),
|
||||
}
|
||||
_, _ = model.DB.Insert(item)
|
||||
} else {
|
||||
fileDiscern = lr.ImgSrc
|
||||
}
|
||||
//
|
||||
case 2001: //网新返回有病害
|
||||
ir := new(InsigmaResult)
|
||||
if err := json.Unmarshal([]byte(src.Result), &ir); err != nil {
|
||||
return
|
||||
}
|
||||
fileDiscern = ir.Image
|
||||
list := make([]*model.ProjectResult, 0)
|
||||
for _, val := range ir.Diseases {
|
||||
if len(val.Param.MaxWidth) > 0 && width == 0 {
|
||||
width, _ = strconv.ParseFloat(val.Param.MaxWidth, 64)
|
||||
} else {
|
||||
width = 0
|
||||
}
|
||||
memo = fmt.Sprintf("发现[%s],等级[%s],长度[%f],最大宽度[%s],面积[%f];\n", val.Type, val.Level, val.Param.Length, val.Param.MaxWidth, val.Param.Area)
|
||||
maxWidth, _ := strconv.ParseFloat(val.Param.MaxWidth, 64)
|
||||
item := &model.ProjectResult{
|
||||
ProjectId: project.ProjectId,
|
||||
SourceResultId: src.ResultId,
|
||||
MilepostNumber: milepostNumber,
|
||||
UpDown: upDown,
|
||||
LineNum: lineNum,
|
||||
DiseaseType: val.Type,
|
||||
DiseaseLevel: val.Level,
|
||||
Length: val.Param.Length,
|
||||
Width: maxWidth,
|
||||
Acreage: val.Param.Area,
|
||||
Memo: memo,
|
||||
Result: fileDiscern,
|
||||
Creator: 0,
|
||||
Modifier: 0,
|
||||
CreateAt: time.Now().Unix(),
|
||||
UpdateAt: time.Now().Unix(),
|
||||
}
|
||||
list = append(list, item)
|
||||
}
|
||||
_, _ = model.DB.Insert(list)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 里程桩加减里程,返回里程桩
|
||||
func GetMilepost(start, num, upDown string) string {
|
||||
arr := strings.Split(start, "+")
|
||||
var (
|
||||
kilometre, meter, milepost, counter, res, resMilepost, resMeter float64
|
||||
)
|
||||
if len(arr) == 1 {
|
||||
meter = 0
|
||||
} else {
|
||||
meter, _ = strconv.ParseFloat(arr[1], 64)
|
||||
}
|
||||
str := strings.Replace(arr[0], "k", "", -1)
|
||||
str = strings.Replace(str, "K", "", -1)
|
||||
kilometre, _ = strconv.ParseFloat(str, 64)
|
||||
milepost = kilometre + meter/1000
|
||||
counter, _ = strconv.ParseFloat(num, 64)
|
||||
if upDown == "D" {
|
||||
res = milepost - counter
|
||||
|
||||
} else {
|
||||
res = milepost + counter
|
||||
}
|
||||
resMilepost = math.Floor(res)
|
||||
resMeter = (res - resMilepost) * 100
|
||||
return fmt.Sprintf("K%d+%.2f", int(resMilepost), resMeter)
|
||||
}
|
||||
|
||||
// 米装换成里程桩号
|
||||
func meter2Milepost(meter string) string {
|
||||
meter = strings.Replace(meter, "K", "", -1)
|
||||
m, _ := strconv.ParseFloat(meter, 64)
|
||||
resMilepost := math.Floor(m / 1000)
|
||||
resMeter := (m - resMilepost*1000) * 100
|
||||
return fmt.Sprintf("K%d+%.2f", int(resMilepost), resMeter)
|
||||
}
|
||||
func deliver(topic string, mqType uint, payload interface{}) {
|
||||
cli := GetMqClient(topic, mqType)
|
||||
pData, _ := json.Marshal(payload)
|
||||
|
@ -544,3 +830,220 @@ func TaskExecuteLogHandler(data []byte) (frame.Tag, []byte) {
|
|||
l.Unlock()
|
||||
return frame.Tag(cmd.Command), nil
|
||||
}
|
||||
|
||||
func RunTraining(task *model.TrainTask) {
|
||||
var (
|
||||
args []string
|
||||
modelPath, modelFileName, testSize string
|
||||
modelAcc, modelLoss float64
|
||||
)
|
||||
fmt.Println("curr tmp dir====>>>>", config.Cfg.TmpTrainDir)
|
||||
modelFileName = utils.GetUUIDString()
|
||||
//复制训练数据集
|
||||
tmpTrainDir := path.Join(config.Cfg.TmpTrainDir, fmt.Sprintf("%s_%s_%d_%d", modelFileName, task.Arithmetic, task.BatchSize, task.EpochsSize))
|
||||
fileList := make([]model.TrainingDatasetDetail, 0)
|
||||
_ = model.DB.Where("dataset_id = ?", task.TrainDatasetId).Find(&fileList)
|
||||
_ = os.MkdirAll(tmpTrainDir, os.ModePerm)
|
||||
_ = os.MkdirAll(path.Join(tmpTrainDir, "train"), os.ModePerm)
|
||||
_ = os.MkdirAll(path.Join(tmpTrainDir, "train", "0"), os.ModePerm)
|
||||
_ = os.MkdirAll(path.Join(tmpTrainDir, "train", "1"), os.ModePerm)
|
||||
_ = os.MkdirAll(path.Join(tmpTrainDir, "val"), os.ModePerm)
|
||||
_ = os.MkdirAll(path.Join(tmpTrainDir, "val", "0"), os.ModePerm)
|
||||
_ = os.MkdirAll(path.Join(tmpTrainDir, "val", "1"), os.ModePerm)
|
||||
_ = os.MkdirAll(path.Join(tmpTrainDir, "test"), os.ModePerm)
|
||||
for _, v := range fileList {
|
||||
dstFilePath := ""
|
||||
switch v.CategoryId {
|
||||
case 2:
|
||||
dstFilePath = "test"
|
||||
default:
|
||||
dstFilePath = "train"
|
||||
}
|
||||
if v.CategoryId != 2 {
|
||||
if v.IsDisease == 1 {
|
||||
dstFilePath = path.Join(tmpTrainDir, dstFilePath, "0")
|
||||
} else {
|
||||
dstFilePath = path.Join(tmpTrainDir, dstFilePath, "1")
|
||||
}
|
||||
} else {
|
||||
dstFilePath = path.Join(tmpTrainDir, dstFilePath)
|
||||
}
|
||||
err := utils.CopyFile(v.FilePath, path.Join(dstFilePath, v.FileName))
|
||||
if err != nil {
|
||||
fmt.Println("copy error: ", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
modelPath = path.Join(config.Cfg.ModelOutPath, fmt.Sprintf("%s_%s_%d_%d_%s", modelFileName, task.Arithmetic, task.BatchSize, task.EpochsSize, task.OutputType))
|
||||
_ = os.MkdirAll(modelPath, os.ModePerm)
|
||||
|
||||
dt := new(model.TrainingDataset)
|
||||
_, err := model.DB.ID(task.TrainDatasetId).Get(dt)
|
||||
if err != nil {
|
||||
goto ReturnPoint
|
||||
}
|
||||
testSize = fmt.Sprintf("%.2f", dt.ValidationNumber/100)
|
||||
//执行训练命令
|
||||
args = []string{"--dataset=" + path.Join(tmpTrainDir, "train"),
|
||||
"--img_size=" + strconv.Itoa(task.ImageSize), "--batch_size=" + strconv.Itoa(task.BatchSize), "--test_size=" + testSize,
|
||||
"--epochs=" + strconv.Itoa(task.EpochsSize), "--model=" + task.Arithmetic, "--model_save=" + path.Join(modelPath, modelFileName+".h5"),
|
||||
}
|
||||
fmt.Println("args====>>>", args)
|
||||
err = ExecCommand(config.Cfg.TrainScriptPath, args, path.Join(modelPath, modelFileName+".log"), task.TaskId)
|
||||
ReturnPoint:
|
||||
//返回训练结果
|
||||
modelMetricsFile := modelFileName + "_model_metrics.png"
|
||||
task.FinishTime = time.Now().Unix()
|
||||
task.ModelFilePath = path.Join(modelPath, modelFileName+".h5")
|
||||
task.Loss = GetIndicatorByLog(path.Join(modelPath, modelFileName+".log"), "[INFO] Model Validation Loss:")
|
||||
task.Accuracy = GetIndicatorByLog(path.Join(modelPath, modelFileName+".log"), "[INFO] Model Validation Accuracy:")
|
||||
task.Status = 3
|
||||
if err != nil {
|
||||
task.Status = 5
|
||||
}
|
||||
task.ModelFileMetricsPath = path.Join(modelPath, modelMetricsFile)
|
||||
_, _ = model.DB.ID(task.TaskId).AllCols().Update(task)
|
||||
if utils.PathExists(path.Join(modelPath, modelFileName+".log")) {
|
||||
logContext := utils.ReadFile(path.Join(modelPath, modelFileName+".log"))
|
||||
taskRes := new(model.TrainTaskResult)
|
||||
taskRes.TaskId = task.TaskId
|
||||
taskRes.CreateAt = time.Now().Unix()
|
||||
taskRes.Content = string(logContext)
|
||||
taskRes.Result = path.Join(modelPath, modelMetricsFile)
|
||||
taskRes.Accuracy = modelAcc
|
||||
taskRes.Loss = modelLoss
|
||||
c, err := model.DB.Insert(taskRes)
|
||||
if err != nil {
|
||||
fmt.Println("model.DB.Insert(taskRes) error ========>>>>>>", err)
|
||||
}
|
||||
fmt.Println("model.DB.Insert(taskRes) count ========>>>>>>", c)
|
||||
} else {
|
||||
fmt.Println("logContext========>>>>>>未读取")
|
||||
}
|
||||
}
|
||||
|
||||
func GetIndicatorByLog(logFileName, indicator string) float64 {
|
||||
logFn, _ := os.Open(logFileName)
|
||||
defer func() {
|
||||
_ = logFn.Close()
|
||||
}()
|
||||
buf := bufio.NewReader(logFn)
|
||||
for {
|
||||
line, err := buf.ReadString('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
//fmt.Println("File read ok!")
|
||||
break
|
||||
} else {
|
||||
fmt.Println("Read file error!", err)
|
||||
return 0
|
||||
}
|
||||
}
|
||||
if strings.Index(line, indicator) >= 0 {
|
||||
str := strings.Replace(line, indicator, "", -1)
|
||||
str = strings.Replace(str, "\n", "", -1)
|
||||
value, _ := strconv.ParseFloat(strings.Trim(str, " "), 64)
|
||||
return value
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func ExecCommand(cmd string, args []string, logFileName string, taskId int64) (err error) {
|
||||
logFile, _ := os.Create(logFileName)
|
||||
defer func() {
|
||||
_ = logFile.Close()
|
||||
}()
|
||||
fmt.Print("开始训练......")
|
||||
|
||||
c := exec.Command(cmd, args...) // mac or linux
|
||||
stdout, err := c.StdoutPipe()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var (
|
||||
wg sync.WaitGroup
|
||||
)
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
reader := bufio.NewReader(stdout)
|
||||
var (
|
||||
epoch int
|
||||
//modelLoss, modelAcc float64
|
||||
)
|
||||
for {
|
||||
readString, err := reader.ReadString('\n')
|
||||
if err != nil || err == io.EOF {
|
||||
fmt.Println("训练2===>>>", err)
|
||||
//wg.Done()
|
||||
return
|
||||
}
|
||||
byte2String := ConvertByte2String([]byte(readString), "GB18030")
|
||||
_, _ = fmt.Fprint(logFile, byte2String)
|
||||
if strings.Index(byte2String, "Epoch") >= 0 {
|
||||
str := strings.Replace(byte2String, "Epoch ", "", -1)
|
||||
arr := strings.Split(str, "/")
|
||||
epoch, _ = strconv.Atoi(arr[0])
|
||||
}
|
||||
if strings.Index(byte2String, "- loss:") > 0 &&
|
||||
strings.Index(byte2String, "- accuracy:") > 0 &&
|
||||
strings.Index(byte2String, "- val_loss:") > 0 &&
|
||||
strings.Index(byte2String, "- val_accuracy:") > 0 {
|
||||
var (
|
||||
loss, acc, valLoss, valAcc float64
|
||||
)
|
||||
arr := strings.Split(byte2String, "-")
|
||||
for _, v := range arr {
|
||||
if strings.Index(v, "loss:") > 0 && strings.Index(v, "val_loss:") < 0 {
|
||||
strLoss := strings.Replace(v, " loss: ", "", -1)
|
||||
loss, _ = strconv.ParseFloat(strings.Trim(strLoss, " "), 64)
|
||||
}
|
||||
if strings.Index(v, "accuracy:") > 0 && strings.Index(v, "val_accuracy:") < 0 {
|
||||
strAcc := strings.Replace(v, " accuracy: ", "", -1)
|
||||
acc, _ = strconv.ParseFloat(strings.Trim(strAcc, " "), 64)
|
||||
}
|
||||
if strings.Index(v, "val_loss:") > 0 {
|
||||
strValLoss := strings.Replace(v, "val_loss: ", "", -1)
|
||||
valLoss, _ = strconv.ParseFloat(strings.Trim(strValLoss, " "), 64)
|
||||
}
|
||||
if strings.Index(v, "val_accuracy:") > 0 {
|
||||
strValAcc := strings.Replace(v, "val_accuracy: ", "", -1)
|
||||
strValAcc = strings.Replace(strValAcc, "\n", "", -1)
|
||||
valAcc, _ = strconv.ParseFloat(strings.Trim(strValAcc, " "), 64)
|
||||
}
|
||||
}
|
||||
taskLog := new(model.TrainTaskLog)
|
||||
taskLog.Epoch = epoch
|
||||
taskLog.TaskId = taskId
|
||||
taskLog.CreateAt = time.Now().Unix()
|
||||
taskLog.Loss = loss
|
||||
taskLog.Accuracy = acc
|
||||
taskLog.ValLoss = valLoss
|
||||
taskLog.ValAccuracy = valAcc
|
||||
_, _ = model.DB.Insert(taskLog)
|
||||
}
|
||||
fmt.Print(byte2String)
|
||||
}
|
||||
}()
|
||||
err = c.Start()
|
||||
if err != nil {
|
||||
fmt.Println("训练3===>>>", err)
|
||||
}
|
||||
wg.Wait()
|
||||
return
|
||||
}
|
||||
func ConvertByte2String(byte []byte, charset Charset) string {
|
||||
var str string
|
||||
switch charset {
|
||||
case GB18030:
|
||||
var decodeBytes, _ = simplifiedchinese.GB18030.NewDecoder().Bytes(byte)
|
||||
str = string(decodeBytes)
|
||||
case UTF8:
|
||||
fallthrough
|
||||
default:
|
||||
str = string(byte)
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ const (
|
|||
ModelIssueResponse
|
||||
TaskExecuteLog
|
||||
TaskLog
|
||||
TrainTaskAdd
|
||||
)
|
||||
|
||||
type InstructionReq struct {
|
||||
|
|
|
@ -0,0 +1,94 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"git.hpds.cc/Component/logging"
|
||||
"go.uber.org/zap"
|
||||
"io"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func CopyFile(src, dst string) error {
|
||||
sourceFileStat, err := os.Stat(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !sourceFileStat.Mode().IsRegular() {
|
||||
return fmt.Errorf("%s is not a regular file", src)
|
||||
}
|
||||
|
||||
source, err := os.Open(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func(source *os.File) {
|
||||
_ = source.Close()
|
||||
}(source)
|
||||
|
||||
destination, err := os.Create(dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func(destination *os.File) {
|
||||
_ = destination.Close()
|
||||
}(destination)
|
||||
_, err = io.Copy(destination, source)
|
||||
return err
|
||||
}
|
||||
|
||||
func PathExists(path string) bool {
|
||||
_, err := os.Stat(path)
|
||||
if err == nil {
|
||||
return true
|
||||
}
|
||||
if os.IsNotExist(err) {
|
||||
return false
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ReadFile 读取到file中,再利用ioutil将file直接读取到[]byte中, 这是最优
|
||||
func ReadFile(fn string) []byte {
|
||||
f, err := os.Open(fn)
|
||||
if err != nil {
|
||||
logging.L().Error("Read File", zap.String("File Name", fn), zap.Error(err))
|
||||
return nil
|
||||
}
|
||||
defer func(f *os.File) {
|
||||
_ = f.Close()
|
||||
}(f)
|
||||
|
||||
fd, err := io.ReadAll(f)
|
||||
if err != nil {
|
||||
logging.L().Error("Read File To buff", zap.String("File Name", fn), zap.Error(err))
|
||||
return nil
|
||||
}
|
||||
|
||||
return fd
|
||||
}
|
||||
|
||||
func GetFileName(fn string) string {
|
||||
fileType := path.Ext(fn)
|
||||
return strings.TrimSuffix(fn, fileType)
|
||||
}
|
||||
func GetFileNameAndExt(fn string) string {
|
||||
_, fileName := filepath.Split(fn)
|
||||
return fileName
|
||||
}
|
||||
func GetFileMd5(data []byte) string {
|
||||
hash := md5.New()
|
||||
hash.Write(data)
|
||||
return hex.EncodeToString(hash.Sum(nil))
|
||||
}
|
||||
|
||||
func FileToBase64(fn string) string {
|
||||
buff := ReadFile(fn)
|
||||
return base64.StdEncoding.EncodeToString(buff) // 加密成base64字符串
|
||||
}
|
|
@ -0,0 +1,126 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func HttpDo(reqUrl, method string, params map[string]string, header map[string]string) (data []byte, err error) {
|
||||
var paramStr string = ""
|
||||
if contentType, ok := header["Content-Type"]; ok && strings.Contains(contentType, "json") {
|
||||
bytesData, _ := json.Marshal(params)
|
||||
paramStr = string(bytesData)
|
||||
} else {
|
||||
for k, v := range params {
|
||||
if len(paramStr) == 0 {
|
||||
paramStr = fmt.Sprintf("%s=%s", k, url.QueryEscape(v))
|
||||
} else {
|
||||
paramStr = fmt.Sprintf("%s&%s=%s", paramStr, k, url.QueryEscape(v))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequest(strings.ToUpper(method), reqUrl, strings.NewReader(paramStr))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for k, v := range header {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if resp.Body != nil {
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
var body []byte
|
||||
if resp.Body != nil {
|
||||
body, err = io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
|
||||
type UploadFile struct {
|
||||
// 表单名称
|
||||
Name string
|
||||
Filepath string
|
||||
// 文件全路径
|
||||
File *bytes.Buffer
|
||||
}
|
||||
|
||||
func PostFile(reqUrl string, reqParams map[string]string, contentType string, files []UploadFile, headers map[string]string) string {
|
||||
requestBody, realContentType := getReader(reqParams, contentType, files)
|
||||
httpRequest, _ := http.NewRequest("POST", reqUrl, requestBody)
|
||||
// 添加请求头
|
||||
httpRequest.Header.Add("Content-Type", realContentType)
|
||||
if headers != nil {
|
||||
for k, v := range headers {
|
||||
httpRequest.Header.Add(k, v)
|
||||
}
|
||||
}
|
||||
httpClient := &http.Client{}
|
||||
// 发送请求
|
||||
resp, err := httpClient.Do(httpRequest)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer func(Body io.ReadCloser) {
|
||||
_ = Body.Close()
|
||||
}(resp.Body)
|
||||
response, _ := io.ReadAll(resp.Body)
|
||||
return string(response)
|
||||
}
|
||||
|
||||
func getReader(reqParams map[string]string, contentType string, files []UploadFile) (io.Reader, string) {
|
||||
if strings.Index(contentType, "json") > -1 {
|
||||
bytesData, _ := json.Marshal(reqParams)
|
||||
return bytes.NewReader(bytesData), contentType
|
||||
} else if files != nil {
|
||||
body := &bytes.Buffer{}
|
||||
// 文件写入 body
|
||||
writer := multipart.NewWriter(body)
|
||||
for _, uploadFile := range files {
|
||||
part, err := writer.CreateFormFile(uploadFile.Name, filepath.Base(uploadFile.Filepath))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
_, err = io.Copy(part, uploadFile.File)
|
||||
}
|
||||
// 其他参数列表写入 body
|
||||
for k, v := range reqParams {
|
||||
if err := writer.WriteField(k, v); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
// 上传文件需要自己专用的contentType
|
||||
return body, writer.FormDataContentType()
|
||||
} else {
|
||||
urlValues := url.Values{}
|
||||
for key, val := range reqParams {
|
||||
urlValues.Set(key, val)
|
||||
}
|
||||
reqBody := urlValues.Encode()
|
||||
return strings.NewReader(reqBody), contentType
|
||||
}
|
||||
}
|
|
@ -0,0 +1,139 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"golang.org/x/image/bmp"
|
||||
"golang.org/x/image/tiff"
|
||||
"image"
|
||||
"image/color"
|
||||
"image/jpeg"
|
||||
"image/png"
|
||||
)
|
||||
|
||||
func BuffToImage(in []byte) image.Image {
|
||||
buff := bytes.NewBuffer(in)
|
||||
m, _, _ := image.Decode(buff)
|
||||
return m
|
||||
}
|
||||
|
||||
// Clip 图片裁剪
|
||||
func Clip(in []byte, wi, hi int, equalProportion bool) (out image.Image, imageType string, err error) {
|
||||
buff := bytes.NewBuffer(in)
|
||||
m, imgType, _ := image.Decode(buff)
|
||||
rgbImg := m.(*image.YCbCr)
|
||||
if equalProportion {
|
||||
w := m.Bounds().Max.X
|
||||
h := m.Bounds().Max.Y
|
||||
if w > 0 && h > 0 && wi > 0 && hi > 0 {
|
||||
wi, hi = fixSize(w, h, wi, hi)
|
||||
}
|
||||
}
|
||||
return rgbImg.SubImage(image.Rect(0, 0, wi, hi)), imgType, nil
|
||||
}
|
||||
|
||||
func fixSize(img1W, img2H, wi, hi int) (new1W, new2W int) {
|
||||
var ( //为了方便计算,将图片的宽转为 float64
|
||||
imgWidth, imgHeight = float64(img1W), float64(img2H)
|
||||
ratio float64
|
||||
)
|
||||
if imgWidth >= imgHeight {
|
||||
ratio = imgWidth / float64(wi)
|
||||
return int(imgWidth * ratio), int(imgHeight * ratio)
|
||||
}
|
||||
ratio = imgHeight / float64(hi)
|
||||
return int(imgWidth * ratio), int(imgHeight * ratio)
|
||||
}
|
||||
|
||||
func Gray(in []byte) (out image.Image, err error) {
|
||||
m := BuffToImage(in)
|
||||
bounds := m.Bounds()
|
||||
dx := bounds.Dx()
|
||||
dy := bounds.Dy()
|
||||
newRgba := image.NewRGBA(bounds)
|
||||
for i := 0; i < dx; i++ {
|
||||
for j := 0; j < dy; j++ {
|
||||
colorRgb := m.At(i, j)
|
||||
_, g, _, a := colorRgb.RGBA()
|
||||
gUint8 := uint8(g >> 8)
|
||||
aUint8 := uint8(a >> 8)
|
||||
newRgba.SetRGBA(i, j, color.RGBA{R: gUint8, G: gUint8, B: gUint8, A: aUint8})
|
||||
}
|
||||
}
|
||||
r := image.Rect(0, 0, dx, dy)
|
||||
return newRgba.SubImage(r), nil
|
||||
}
|
||||
|
||||
func Rotate90(in []byte) image.Image {
|
||||
m := BuffToImage(in)
|
||||
rotate90 := image.NewRGBA(image.Rect(0, 0, m.Bounds().Dy(), m.Bounds().Dx()))
|
||||
// 矩阵旋转
|
||||
for x := m.Bounds().Min.Y; x < m.Bounds().Max.Y; x++ {
|
||||
for y := m.Bounds().Max.X - 1; y >= m.Bounds().Min.X; y-- {
|
||||
// 设置像素点
|
||||
rotate90.Set(m.Bounds().Max.Y-x, y, m.At(y, x))
|
||||
}
|
||||
}
|
||||
return rotate90
|
||||
}
|
||||
|
||||
// Rotate180 旋转180度
|
||||
func Rotate180(in []byte) image.Image {
|
||||
m := BuffToImage(in)
|
||||
rotate180 := image.NewRGBA(image.Rect(0, 0, m.Bounds().Dx(), m.Bounds().Dy()))
|
||||
// 矩阵旋转
|
||||
for x := m.Bounds().Min.X; x < m.Bounds().Max.X; x++ {
|
||||
for y := m.Bounds().Min.Y; y < m.Bounds().Max.Y; y++ {
|
||||
// 设置像素点
|
||||
rotate180.Set(m.Bounds().Max.X-x, m.Bounds().Max.Y-y, m.At(x, y))
|
||||
}
|
||||
}
|
||||
return rotate180
|
||||
}
|
||||
|
||||
// Rotate270 旋转270度
|
||||
func Rotate270(in []byte) image.Image {
|
||||
m := BuffToImage(in)
|
||||
rotate270 := image.NewRGBA(image.Rect(0, 0, m.Bounds().Dy(), m.Bounds().Dx()))
|
||||
// 矩阵旋转
|
||||
for x := m.Bounds().Min.Y; x < m.Bounds().Max.Y; x++ {
|
||||
for y := m.Bounds().Max.X - 1; y >= m.Bounds().Min.X; y-- {
|
||||
// 设置像素点
|
||||
rotate270.Set(x, m.Bounds().Max.X-y, m.At(y, x))
|
||||
}
|
||||
}
|
||||
return rotate270
|
||||
|
||||
}
|
||||
|
||||
func ImageToBase64(img image.Image, imgType string) string {
|
||||
buff := ImageToBuff(img, imgType)
|
||||
return base64.StdEncoding.EncodeToString(buff.Bytes())
|
||||
}
|
||||
|
||||
func ImageToBuff(img image.Image, imgType string) *bytes.Buffer {
|
||||
buff := bytes.NewBuffer(nil)
|
||||
switch imgType {
|
||||
case "bmp":
|
||||
imgType = "bmp"
|
||||
_ = bmp.Encode(buff, img)
|
||||
case "png":
|
||||
imgType = "png"
|
||||
_ = png.Encode(buff, img)
|
||||
case "tiff":
|
||||
imgType = "tiff"
|
||||
_ = tiff.Encode(buff, img, nil)
|
||||
default:
|
||||
imgType = "jpeg"
|
||||
_ = jpeg.Encode(buff, img, nil)
|
||||
}
|
||||
return buff
|
||||
}
|
||||
|
||||
func ImgFileToBase64(fn string) string {
|
||||
fileByte := ReadFile(fn)
|
||||
buff := bytes.NewBuffer(fileByte)
|
||||
m, _, _ := image.Decode(buff)
|
||||
return "data:image/jpeg;base64," + ImageToBase64(m, "jpeg")
|
||||
|
||||
}
|
|
@ -0,0 +1,48 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"github.com/google/uuid"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
/*
|
||||
RandomString 产生随机数
|
||||
- size 随机码的位数
|
||||
- kind 0 // 纯数字
|
||||
1 // 小写字母
|
||||
2 // 大写字母
|
||||
3 // 数字、大小写字母
|
||||
*/
|
||||
func RandomString(size int, kind int) string {
|
||||
iKind, kinds, rsBytes := kind, [][]int{[]int{10, 48}, []int{26, 97}, []int{26, 65}}, make([]byte, size)
|
||||
isAll := kind > 2 || kind < 0
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
for i := 0; i < size; i++ {
|
||||
if isAll { // random iKind
|
||||
iKind = rand.Intn(3)
|
||||
}
|
||||
scope, base := kinds[iKind][0], kinds[iKind][1]
|
||||
rsBytes[i] = uint8(base + rand.Intn(scope))
|
||||
}
|
||||
return string(rsBytes)
|
||||
}
|
||||
|
||||
func GetUserSha1Pass(pass, salt string) string {
|
||||
key := []byte(salt)
|
||||
mac := hmac.New(sha1.New, key)
|
||||
mac.Write([]byte(pass))
|
||||
//进行base64编码
|
||||
res := base64.StdEncoding.EncodeToString(mac.Sum(nil))
|
||||
return res
|
||||
}
|
||||
|
||||
func GetUUIDString() string {
|
||||
u, _ := uuid.NewUUID()
|
||||
str := strings.Replace(u.String(), "-", "", -1)
|
||||
return str
|
||||
}
|
|
@ -0,0 +1,91 @@
|
|||
#!/bin/bash
|
||||
|
||||
#获取对应的参数 没有的话赋默认值
|
||||
ARGS=`getopt -o d::s::b::e::m::o:: --long dataset::,img_size::,batch_size::,epochs::,model::,model_save:: -n 'example.sh' -- "$@"`
|
||||
if [ $? != 0 ]; then
|
||||
echo "Terminating..."
|
||||
exit 1
|
||||
fi
|
||||
echo $ARGS
|
||||
eval set -- "${ARGS}"
|
||||
while true;
|
||||
do
|
||||
case "$1" in
|
||||
-d|--dataset)
|
||||
case "$2" in
|
||||
"")
|
||||
echo "Internal error!"
|
||||
exit 1
|
||||
;;
|
||||
*)
|
||||
Dataset=$2
|
||||
shift 2;
|
||||
;;
|
||||
esac
|
||||
;;
|
||||
-s|--img_size)
|
||||
case "$2" in
|
||||
"")
|
||||
echo "Internal error!"
|
||||
exit 1
|
||||
;;
|
||||
|
||||
*)
|
||||
ImgSize=$2;
|
||||
shift 2;
|
||||
;;
|
||||
esac
|
||||
;;
|
||||
-b|--batch_size)
|
||||
case "$2" in
|
||||
*)
|
||||
BatchSize=$2;
|
||||
shift 2;
|
||||
;;
|
||||
esac
|
||||
;;
|
||||
-e|--epochs)
|
||||
case "$2" in
|
||||
*)
|
||||
Epochs=$2;
|
||||
shift 2;
|
||||
;;
|
||||
esac
|
||||
;;
|
||||
-m|--model)
|
||||
case "$2" in
|
||||
*)
|
||||
Model=$2;
|
||||
shift 2;
|
||||
;;
|
||||
esac
|
||||
;;
|
||||
-o|--model_save)
|
||||
case "$2" in
|
||||
*)
|
||||
ModelSave=$2;
|
||||
shift 2;
|
||||
;;
|
||||
esac
|
||||
;;
|
||||
|
||||
--)
|
||||
shift
|
||||
break
|
||||
;;
|
||||
*)
|
||||
echo "Internal error!"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
|
||||
done
|
||||
|
||||
#echo ${Dataset}
|
||||
|
||||
eval "$(conda shell.bash hook)"
|
||||
conda activate hpds_train
|
||||
cd /home/data/hpds_train_keras
|
||||
python train.py --dataset ${Dataset} --img_size ${ImgSize} --batch_size ${BatchSize} \
|
||||
--epochs ${Epochs} --model ${Model} --model_save ${ModelSave}
|
||||
|
Loading…
Reference in New Issue