1、增加任务处理进度
This commit is contained in:
parent
7d09ba0286
commit
90d43d468a
|
@ -19,6 +19,9 @@ type ControlCenterConfig struct {
|
||||||
Host string `yaml:"host,omitempty"`
|
Host string `yaml:"host,omitempty"`
|
||||||
Port int `yaml:"port,omitempty"`
|
Port int `yaml:"port,omitempty"`
|
||||||
Mode string `yaml:"mode,omitempty"`
|
Mode string `yaml:"mode,omitempty"`
|
||||||
|
TmpTrainDir string `yaml:"tmpTrainDir"`
|
||||||
|
TrainScriptPath string `yaml:"trainScriptPath"`
|
||||||
|
ModelOutPath string `yaml:"modelOutPath"`
|
||||||
Consul ConsulConfig `yaml:"consul,omitempty"`
|
Consul ConsulConfig `yaml:"consul,omitempty"`
|
||||||
Db DbConfig `yaml:"db"`
|
Db DbConfig `yaml:"db"`
|
||||||
Cache CacheConfig `yaml:"cache"`
|
Cache CacheConfig `yaml:"cache"`
|
||||||
|
|
|
@ -2,6 +2,9 @@ name: control_center
|
||||||
host: 0.0.0.0
|
host: 0.0.0.0
|
||||||
port: 8088
|
port: 8088
|
||||||
mode: dev
|
mode: dev
|
||||||
|
tmpTrainDir: ./tmp
|
||||||
|
trainScriptPath: ./scripts/runTrainScript.sh
|
||||||
|
modelOutPath: ./out
|
||||||
logging:
|
logging:
|
||||||
path: ./logs
|
path: ./logs
|
||||||
prefix: hpds-control
|
prefix: hpds-control
|
||||||
|
|
7
go.mod
7
go.mod
|
@ -7,12 +7,14 @@ require (
|
||||||
git.hpds.cc/Component/network v0.0.0-20230405135741-a4ea724bab76
|
git.hpds.cc/Component/network v0.0.0-20230405135741-a4ea724bab76
|
||||||
git.hpds.cc/pavement/hpds_node v0.0.0-20230405153516-9403c4d01e12
|
git.hpds.cc/pavement/hpds_node v0.0.0-20230405153516-9403c4d01e12
|
||||||
github.com/go-sql-driver/mysql v1.7.0
|
github.com/go-sql-driver/mysql v1.7.0
|
||||||
|
github.com/google/uuid v1.3.0
|
||||||
github.com/hashicorp/consul/api v1.20.0
|
github.com/hashicorp/consul/api v1.20.0
|
||||||
github.com/minio/minio-go v6.0.14+incompatible
|
|
||||||
github.com/minio/minio-go/v7 v7.0.52
|
github.com/minio/minio-go/v7 v7.0.52
|
||||||
github.com/spf13/cobra v1.6.1
|
github.com/spf13/cobra v1.6.1
|
||||||
github.com/spf13/viper v1.15.0
|
github.com/spf13/viper v1.15.0
|
||||||
go.uber.org/zap v1.23.0
|
go.uber.org/zap v1.23.0
|
||||||
|
golang.org/x/image v0.0.0-20191009234506-e7c1f5e7dbb8
|
||||||
|
golang.org/x/text v0.7.0
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
xorm.io/xorm v1.3.2
|
xorm.io/xorm v1.3.2
|
||||||
)
|
)
|
||||||
|
@ -30,7 +32,6 @@ require (
|
||||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||||
github.com/fatih/color v1.13.0 // indirect
|
github.com/fatih/color v1.13.0 // indirect
|
||||||
github.com/fsnotify/fsnotify v1.6.0 // indirect
|
github.com/fsnotify/fsnotify v1.6.0 // indirect
|
||||||
github.com/go-ini/ini v1.67.0 // indirect
|
|
||||||
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect
|
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect
|
||||||
github.com/goccy/go-json v0.8.1 // indirect
|
github.com/goccy/go-json v0.8.1 // indirect
|
||||||
github.com/gogo/protobuf v1.3.2 // indirect
|
github.com/gogo/protobuf v1.3.2 // indirect
|
||||||
|
@ -40,7 +41,6 @@ require (
|
||||||
github.com/golang/snappy v0.0.4 // indirect
|
github.com/golang/snappy v0.0.4 // indirect
|
||||||
github.com/google/go-cmp v0.5.9 // indirect
|
github.com/google/go-cmp v0.5.9 // indirect
|
||||||
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect
|
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect
|
||||||
github.com/google/uuid v1.3.0 // indirect
|
|
||||||
github.com/googleapis/enterprise-certificate-proxy v0.2.1 // indirect
|
github.com/googleapis/enterprise-certificate-proxy v0.2.1 // indirect
|
||||||
github.com/googleapis/gax-go/v2 v2.7.0 // indirect
|
github.com/googleapis/gax-go/v2 v2.7.0 // indirect
|
||||||
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
|
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
|
||||||
|
@ -94,7 +94,6 @@ require (
|
||||||
golang.org/x/oauth2 v0.0.0-20221014153046-6fdb5e3db783 // indirect
|
golang.org/x/oauth2 v0.0.0-20221014153046-6fdb5e3db783 // indirect
|
||||||
golang.org/x/sync v0.1.0 // indirect
|
golang.org/x/sync v0.1.0 // indirect
|
||||||
golang.org/x/sys v0.5.0 // indirect
|
golang.org/x/sys v0.5.0 // indirect
|
||||||
golang.org/x/text v0.7.0 // indirect
|
|
||||||
golang.org/x/time v0.1.0 // indirect
|
golang.org/x/time v0.1.0 // indirect
|
||||||
golang.org/x/tools v0.2.0 // indirect
|
golang.org/x/tools v0.2.0 // indirect
|
||||||
golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect
|
golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect
|
||||||
|
|
|
@ -4,7 +4,7 @@ type TaskLog struct {
|
||||||
TaskLogId int64 `xorm:"not null pk autoincr BIGINT(11)" json:"taskLogId"`
|
TaskLogId int64 `xorm:"not null pk autoincr BIGINT(11)" json:"taskLogId"`
|
||||||
TaskId int64 `xorm:"INT(11) index" json:"taskId"`
|
TaskId int64 `xorm:"INT(11) index" json:"taskId"`
|
||||||
NodeId int64 `xorm:"INT(11) index" json:"nodeId"`
|
NodeId int64 `xorm:"INT(11) index" json:"nodeId"`
|
||||||
Content string `xorm:"LANGTEXT" json:"content"`
|
Content string `xorm:"LONGTEXT" json:"content"`
|
||||||
CreateAt int64 `xorm:"created" json:"createAt"`
|
CreateAt int64 `xorm:"created" json:"createAt"`
|
||||||
UpdateAt int64 `xorm:"updated" json:"updateAt"`
|
UpdateAt int64 `xorm:"updated" json:"updateAt"`
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,23 @@
|
||||||
|
package model
|
||||||
|
|
||||||
|
// DiseaseType 病害类别
|
||||||
|
type DiseaseType struct {
|
||||||
|
TypeId int64 `xorm:"not null pk autoincr INT(11)" json:"typeId"`
|
||||||
|
TypeName string `xorm:"varchar(200) not null" json:"typeName"`
|
||||||
|
CategoryId int `xorm:"not null SMALLINT default 1" json:"categoryId"` //病害分类, 1:道路 2:桥梁 3:隧道 4:边坡
|
||||||
|
Status int `xorm:"not null INT(11) default 0" json:"status"`
|
||||||
|
CreateAt int64 `xorm:"created" json:"createAt"`
|
||||||
|
UpdateAt int64 `xorm:"updated" json:"updateAt"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetDiseaseType(name string, categoryId int) int64 {
|
||||||
|
item := new(DiseaseType)
|
||||||
|
h, err := DB.Where("type_name like ?", "%"+name+"%").
|
||||||
|
And("category_id = ?", categoryId).
|
||||||
|
And("status = 1").Get(item)
|
||||||
|
if err != nil || !h {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return item.TypeId
|
||||||
|
|
||||||
|
}
|
|
@ -26,6 +26,11 @@ func New(driveName, dsn string, showSql bool) {
|
||||||
&Task{},
|
&Task{},
|
||||||
&TaskLog{},
|
&TaskLog{},
|
||||||
&TaskResult{},
|
&TaskResult{},
|
||||||
|
&TrainingDataset{},
|
||||||
|
&TrainingDatasetDetail{},
|
||||||
|
&TrainTask{},
|
||||||
|
&TrainTaskLog{},
|
||||||
|
&TrainTaskResult{},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("同步数据库表结构", err)
|
fmt.Println("同步数据库表结构", err)
|
||||||
|
|
|
@ -0,0 +1,23 @@
|
||||||
|
package model
|
||||||
|
|
||||||
|
type Project struct {
|
||||||
|
ProjectId int64 `xorm:"not null pk autoincr INT(11)" json:"projectId"`
|
||||||
|
ProjectName string `xorm:"varchar(200) not null " json:"projectName"`
|
||||||
|
OwnerId int64 `xorm:"not null INT(11) default 0" json:"ownerId"`
|
||||||
|
BizType int `xorm:"SMALLINT" json:"bizType"`
|
||||||
|
LineName string `xorm:"varchar(200) not null " json:"lineName"`
|
||||||
|
StartName string `xorm:"varchar(200) not null " json:"startName"`
|
||||||
|
EndName string `xorm:"varchar(200) not null " json:"endName"`
|
||||||
|
FixedDeviceNum int `xorm:"not null INT(11) default 0" json:"fixedDeviceNum"`
|
||||||
|
Direction string `xorm:"varchar(200) not null " json:"direction"`
|
||||||
|
LaneNum int `xorm:"not null INT(4) default 0" json:"laneNum"`
|
||||||
|
StartPointLng float64 `xorm:"decimal(18,6)" json:"startPointLng"`
|
||||||
|
StartPointLat float64 `xorm:"decimal(18,6)" json:"startPointLat"`
|
||||||
|
EndPointLng float64 `xorm:"decimal(18,6)" json:"endPointLng"`
|
||||||
|
EndPointLat float64 `xorm:"decimal(18,6)" json:"endPointLat"`
|
||||||
|
Status int `xorm:"SMALLINT default 1" json:"status"`
|
||||||
|
Creator int64 `xorm:"INT(11) default 0" json:"creator"`
|
||||||
|
Modifier int64 `xorm:"INT(11) default 0" json:"modifier"`
|
||||||
|
CreateAt int64 `xorm:"created" json:"createAt"`
|
||||||
|
UpdateAt int64 `xorm:"updated" json:"updateAt"`
|
||||||
|
}
|
|
@ -0,0 +1,21 @@
|
||||||
|
package model
|
||||||
|
|
||||||
|
type ProjectResult struct {
|
||||||
|
Id int64 `xorm:"not null pk autoincr BIGINT" json:"id"`
|
||||||
|
ProjectId int64 `xorm:"INT(11) index" json:"projectId"` //项目编号
|
||||||
|
SourceResultId int64 `xorm:"INT(11) index" json:"sourceResultId"` //识别结果来源编号
|
||||||
|
MilepostNumber string `xorm:"VARCHAR(50)" json:"milepostNumber"` //里程桩号
|
||||||
|
UpDown string `xorm:"VARCHAR(20)" json:"upDown"` //上下行
|
||||||
|
LineNum int `xorm:"SMALLINT default 1" json:"lineNum"` //车道号
|
||||||
|
DiseaseType string `xorm:"VARCHAR(50)" json:"diseaseType"` //病害类型
|
||||||
|
DiseaseLevel string `xorm:"VARCHAR(20)" json:"diseaseLevel"` //病害等级
|
||||||
|
Length float64 `xorm:"decimal(18,6)" json:"length"` //长度
|
||||||
|
Width float64 `xorm:"decimal(18,6)" json:"width"` //宽度
|
||||||
|
Acreage float64 `xorm:"decimal(18,6)" json:"acreage"` //面积
|
||||||
|
Memo string `xorm:"VARCHAR(1000)" json:"memo"` //备注说明
|
||||||
|
Result string `xorm:"LONGTEXT" json:"result"` //识别结果
|
||||||
|
Creator int64 `xorm:"INT(11) default 0" json:"creator"`
|
||||||
|
Modifier int64 `xorm:"INT(11) default 0" json:"modifier"`
|
||||||
|
CreateAt int64 `xorm:"created" json:"createAt"`
|
||||||
|
UpdateAt int64 `xorm:"updated" json:"updateAt"`
|
||||||
|
}
|
|
@ -62,12 +62,11 @@ func UpdateTaskProgress(taskProgress *proto.TaskLogProgress) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateTaskProgressByLog(res *TaskResult, isFailing bool) float64 {
|
func UpdateTaskProgressByLog(res *TaskResult, isFailing bool) (int, int) {
|
||||||
ret := -1.0
|
|
||||||
item := new(Task)
|
item := new(Task)
|
||||||
h, err := DB.ID(res.TaskId).Get(item)
|
h, err := DB.ID(res.TaskId).Get(item)
|
||||||
if err != nil || !h {
|
if err != nil || !h {
|
||||||
return ret
|
return 0, 0
|
||||||
}
|
}
|
||||||
if isFailing {
|
if isFailing {
|
||||||
item.FailingCount += 1
|
item.FailingCount += 1
|
||||||
|
@ -79,12 +78,11 @@ func UpdateTaskProgressByLog(res *TaskResult, isFailing bool) float64 {
|
||||||
item.FinishTime = time.Now().Unix()
|
item.FinishTime = time.Now().Unix()
|
||||||
item.UnfinishedCount = 0
|
item.UnfinishedCount = 0
|
||||||
item.Status = 3
|
item.Status = 3
|
||||||
ret = 1.0
|
|
||||||
}
|
}
|
||||||
item.UpdateAt = time.Now().Unix()
|
item.UpdateAt = time.Now().Unix()
|
||||||
_, _ = DB.ID(res.TaskId).Cols("completed_count", "failing_count", "total_count", "unfinished_count", "update_at", "finish_time", "status").Update(item)
|
_, _ = DB.ID(res.TaskId).Cols("completed_count", "failing_count", "total_count", "unfinished_count", "update_at", "finish_time", "status").Update(item)
|
||||||
if item.TotalCount > 0 {
|
if item.TotalCount > 0 {
|
||||||
return 1 - float64(item.CompletedCount)/float64(item.TotalCount)
|
return int(item.CompletedCount), int(item.UnfinishedCount)
|
||||||
}
|
}
|
||||||
return ret
|
return int(item.CompletedCount), int(item.UnfinishedCount)
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,23 @@
|
||||||
|
package model
|
||||||
|
|
||||||
|
type TrainTask struct {
|
||||||
|
TaskId int64 `xorm:"not null pk autoincr BIGINT(11)" json:"taskId"`
|
||||||
|
TrainDatasetId int64 `xorm:"INT(11) index" json:"trainDatasetId"`
|
||||||
|
CategoryId int `xorm:"not null SMALLINT default 1" json:"categoryId"` //业务分类, 1:道路 2:桥梁 3:隧道 4:边坡
|
||||||
|
TaskName string `xorm:"VARCHAR(200)" json:"taskName"`
|
||||||
|
TaskDesc string `xorm:"VARCHAR(500)" json:"taskDesc"`
|
||||||
|
Arithmetic string `xorm:"VARCHAR(50)" json:"arithmetic"`
|
||||||
|
ImageSize int `xorm:"INT" json:"imageSize"`
|
||||||
|
BatchSize int `xorm:"INT" json:"batchSize"`
|
||||||
|
EpochsSize int `xorm:"INT" json:"epochsSize"`
|
||||||
|
OutputType string `xorm:"VARCHAR(20)" json:"outputType"`
|
||||||
|
StartTime int64 `xorm:"BIGINT" json:"startTime"`
|
||||||
|
FinishTime int64 `xorm:"BIGINT" json:"finishTime"`
|
||||||
|
Loss float64 `xorm:"DECIMAL(18,6)" json:"loss"`
|
||||||
|
Accuracy float64 `xorm:"DECIMAL(18,6)" json:"accuracy"`
|
||||||
|
ModelFilePath string `xorm:"VARCHAR(2000)" json:"modelFilePath"`
|
||||||
|
ModelFileMetricsPath string `xorm:"VARCHAR(2000)" json:"modelFileMetricsPath"`
|
||||||
|
Status int `xorm:"not null SMALLINT default 0" json:"status"` // 1:等待执行; 2:执行中; 3:执行完成; 4:任务分配失败; 5:任务执行失败
|
||||||
|
CreateAt int64 `xorm:"created" json:"createAt"`
|
||||||
|
UpdateAt int64 `xorm:"updated" json:"updateAt"`
|
||||||
|
}
|
|
@ -0,0 +1,12 @@
|
||||||
|
package model
|
||||||
|
|
||||||
|
type TrainTaskLog struct {
|
||||||
|
LogId int64 `xorm:"not null pk autoincr BIGINT(11)" json:"logId"`
|
||||||
|
TaskId int64 `xorm:"INT(11) index" json:"taskId"`
|
||||||
|
Epoch int `xorm:"SMALLINT" json:"epoch"`
|
||||||
|
Loss float64 `xorm:"DECIMAL(18,6)" json:"loss"`
|
||||||
|
Accuracy float64 `xorm:"DECIMAL(18,6)" json:"accuracy"`
|
||||||
|
ValLoss float64 `xorm:"DECIMAL(18,6)" json:"valLoss"`
|
||||||
|
ValAccuracy float64 `xorm:"DECIMAL(18,6)" json:"valAccuracy"`
|
||||||
|
CreateAt int64 `xorm:"created" json:"createAt"`
|
||||||
|
}
|
|
@ -0,0 +1,11 @@
|
||||||
|
package model
|
||||||
|
|
||||||
|
type TrainTaskResult struct {
|
||||||
|
ResultId int64 `xorm:"not null pk autoincr BIGINT(11)" json:"resultId"`
|
||||||
|
TaskId int64 `xorm:"INT(11) index" json:"taskId"`
|
||||||
|
Content string `xorm:"LONGTEXT" json:"content"`
|
||||||
|
Result string `xorm:"VARCHAR(200)" json:"result"`
|
||||||
|
Loss float64 `xorm:"DECIMAL(18,6)" json:"loss"`
|
||||||
|
Accuracy float64 `xorm:"DECIMAL(18,6)" json:"accuracy"`
|
||||||
|
CreateAt int64 `xorm:"created" json:"createAt"`
|
||||||
|
}
|
|
@ -0,0 +1,13 @@
|
||||||
|
package model
|
||||||
|
|
||||||
|
type TrainingDataset struct {
|
||||||
|
DatasetId int64 `xorm:"not null pk autoincr INT(11)" json:"datasetId"`
|
||||||
|
Name string `xorm:"VARCHAR(200)" json:"name"`
|
||||||
|
CategoryId int `xorm:"not null SMALLINT default 1" json:"categoryId"` //业务分类, 1:道路 2:桥梁 3:隧道 4:边坡
|
||||||
|
DatasetDesc string `xorm:"varchar(200)" json:"datasetDesc"`
|
||||||
|
StoreName string `xorm:"varchar(200)" json:"storeName"` //存储路径
|
||||||
|
ValidationNumber float64 `xorm:"DECIMAL(18,4)" json:"validationNumber"` //验证占比
|
||||||
|
TestNumber float64 `xorm:"DECIMAL(18,4)" json:"testNumber"` //测试占比
|
||||||
|
CreateAt int64 `xorm:"created" json:"createAt"`
|
||||||
|
UpdateAt int64 `xorm:"updated" json:"updateAt"`
|
||||||
|
}
|
|
@ -0,0 +1,16 @@
|
||||||
|
package model
|
||||||
|
|
||||||
|
type TrainingDatasetDetail struct {
|
||||||
|
DetailId int64 `xorm:"not null pk autoincr INT(11)" json:"detailId"`
|
||||||
|
FileName string `xorm:"VARCHAR(200)" json:"fileName"`
|
||||||
|
FilePath string `xorm:"VARCHAR(1000)" json:"filePath"`
|
||||||
|
DatasetId int64 `xorm:"INT(11) index default 0" json:"datasetId"` //训练数据集
|
||||||
|
CategoryId int `xorm:"not null SMALLINT default 1" json:"categoryId"` //训练集分类,1:训练集;2:测试集;3:验证集
|
||||||
|
FileSize int64 `xorm:"BIGINT" json:"fileSize"` //文件大小
|
||||||
|
FileMd5 string `xorm:"VARCHAR(64)" json:"fileMd5"` //文件MD5
|
||||||
|
IsDisease int `xorm:"TINYINT(1)" json:"isDisease"` //是否有病害, 1:有病害;2:无病害;
|
||||||
|
OperationLogId int64 `xorm:"INT(11) index" json:"operationLogId"` //操作日志编号
|
||||||
|
Creator int64 `xorm:"INT(11) index" json:"creator"` //上传人
|
||||||
|
CreateAt int64 `xorm:"created" json:"createAt"` //上传时间
|
||||||
|
UpdateAt int64 `xorm:"updated" json:"updateAt"` //更新时间
|
||||||
|
}
|
515
mq/index.go
515
mq/index.go
|
@ -1,6 +1,7 @@
|
||||||
package mq
|
package mq
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -8,12 +9,18 @@ import (
|
||||||
"git.hpds.cc/Component/network/frame"
|
"git.hpds.cc/Component/network/frame"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
|
"golang.org/x/text/encoding/simplifiedchinese"
|
||||||
"hpds_control_center/config"
|
"hpds_control_center/config"
|
||||||
"hpds_control_center/internal/balance"
|
"hpds_control_center/internal/balance"
|
||||||
"hpds_control_center/internal/minio"
|
"hpds_control_center/internal/minio"
|
||||||
"hpds_control_center/internal/proto"
|
"hpds_control_center/internal/proto"
|
||||||
"hpds_control_center/model"
|
"hpds_control_center/model"
|
||||||
|
"hpds_control_center/pkg/utils"
|
||||||
|
"io"
|
||||||
|
"math"
|
||||||
"os"
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -22,6 +29,13 @@ import (
|
||||||
"git.hpds.cc/pavement/hpds_node"
|
"git.hpds.cc/pavement/hpds_node"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Charset string
|
||||||
|
|
||||||
|
const (
|
||||||
|
UTF8 = Charset("UTF-8")
|
||||||
|
GB18030 = Charset("GB18030")
|
||||||
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
MqList []HpdsMqNode
|
MqList []HpdsMqNode
|
||||||
TaskList = make(map[int64]*TaskItem)
|
TaskList = make(map[int64]*TaskItem)
|
||||||
|
@ -379,8 +393,8 @@ func TaskRequestHandler(data []byte) (frame.Tag, []byte) {
|
||||||
item.UpdateAt = time.Now().Unix()
|
item.UpdateAt = time.Now().Unix()
|
||||||
_, _ = model.DB.ID(item.Id).AllCols().Update(item)
|
_, _ = model.DB.ID(item.Id).AllCols().Update(item)
|
||||||
} else {
|
} else {
|
||||||
item.ModelId = payload["modelId"].(int64)
|
item.ModelId = int64(payload["modelId"].(float64))
|
||||||
item.NodeId = payload["nodeId"].(int64)
|
item.NodeId = int64(payload["nodeId"].(float64))
|
||||||
item.Status = 1
|
item.Status = 1
|
||||||
item.IssueResult = string(pData)
|
item.IssueResult = string(pData)
|
||||||
item.CreateAt = time.Now().Unix()
|
item.CreateAt = time.Now().Unix()
|
||||||
|
@ -403,7 +417,16 @@ func TaskRequestHandler(data []byte) (frame.Tag, []byte) {
|
||||||
// _, _ = model.DB.Insert(item)
|
// _, _ = model.DB.Insert(item)
|
||||||
// //fn := payload["fileName"].(string)
|
// //fn := payload["fileName"].(string)
|
||||||
// //dec := base64.NewDecoder(base64.StdEncoding, strings.NewReader(payload["file"].(string)))
|
// //dec := base64.NewDecoder(base64.StdEncoding, strings.NewReader(payload["file"].(string)))
|
||||||
|
case TrainTaskAdd:
|
||||||
|
payload := cmd.Payload.(map[string]interface{})
|
||||||
|
if itemId, ok := payload["taskId"].(float64); ok {
|
||||||
|
item := new(model.TrainTask)
|
||||||
|
h, err := model.DB.ID(int64(itemId)).Get(item)
|
||||||
|
if err != nil || !h {
|
||||||
|
|
||||||
|
}
|
||||||
|
RunTraining(item)
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -450,14 +473,16 @@ func TaskResponseHandler(data []byte) (frame.Tag, []byte) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("接收TaskResponse数据出错", err)
|
fmt.Println("接收TaskResponse数据出错", err)
|
||||||
}
|
}
|
||||||
|
//处理到项目结果表
|
||||||
|
go processToProjectResult(item)
|
||||||
//更新运行进度
|
//更新运行进度
|
||||||
rat := model.UpdateTaskProgressByLog(item, isFailing)
|
processed, unProcessed := model.UpdateTaskProgressByLog(item, isFailing)
|
||||||
var (
|
var (
|
||||||
ratStr string
|
ratStr string
|
||||||
)
|
)
|
||||||
if rat > 0 && rat < 1 {
|
if unProcessed > 0 {
|
||||||
ratStr = fmt.Sprintf("[已处理%2.f,剩余%2.f未处理]", 1-rat, rat)
|
ratStr = fmt.Sprintf("[已处理[%d],剩余[%d]未处理]", processed, unProcessed)
|
||||||
} else if rat == 1 {
|
} else {
|
||||||
ratStr = "[已全部处理]"
|
ratStr = "[已全部处理]"
|
||||||
}
|
}
|
||||||
taskLog := new(model.TaskLog)
|
taskLog := new(model.TaskLog)
|
||||||
|
@ -479,6 +504,267 @@ func TaskResponseHandler(data []byte) (frame.Tag, []byte) {
|
||||||
return frame.Tag(cmd.Command), nil
|
return frame.Tag(cmd.Command), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ModelResult struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type InsigmaResult struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
NumOfDiseases int `json:"num_of_diseases"`
|
||||||
|
Diseases []DiseasesInfo `json:"diseases"`
|
||||||
|
Image string `json:"image"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type DiseasesInfo struct {
|
||||||
|
Id int `json:"id"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Level string `json:"level"`
|
||||||
|
Param DiseasesParam `json:"param"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type DiseasesParam struct {
|
||||||
|
Length float64 `json:"length"`
|
||||||
|
Area float64 `json:"area"`
|
||||||
|
MaxWidth string `json:"max_width"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type LightweightResult struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Crack bool `json:"crack"`
|
||||||
|
ImgDiscern string `json:"img_discern"`
|
||||||
|
ImgSrc string `json:"img_src"`
|
||||||
|
Pothole bool `json:"pothole"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func processToProjectResult(src *model.TaskResult) {
|
||||||
|
project := new(model.Project)
|
||||||
|
h, err := model.DB.Table("project").Alias("p").Join("inner", []string{"dataset", "d"}, "d.project_id= p.project_id").Where("d.dataset_id=?", src.DatasetId).Get(project)
|
||||||
|
if !h {
|
||||||
|
err = fmt.Errorf("未能找到对应的项目信息")
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
logging.L().With(zap.String("控制节点", "错误信息")).Error("获取项目信息", zap.Error(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
mr ModelResult
|
||||||
|
mrList []string
|
||||||
|
fileDiscern string
|
||||||
|
memo string
|
||||||
|
milepostNumber string
|
||||||
|
upDown string
|
||||||
|
lineNum int
|
||||||
|
width float64
|
||||||
|
)
|
||||||
|
|
||||||
|
switch project.BizType {
|
||||||
|
case 1: //道路
|
||||||
|
arr := strings.Split(src.SrcPath, " ")
|
||||||
|
if len(arr) > 1 {
|
||||||
|
milepostNumber = GetMilepost(project.StartName, arr[1], arr[2])
|
||||||
|
if arr[2] == "D" {
|
||||||
|
upDown = "下行"
|
||||||
|
} else {
|
||||||
|
upDown = "上行"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(arr) > 3 {
|
||||||
|
lineNum, _ = strconv.Atoi(arr[3])
|
||||||
|
}
|
||||||
|
case 2: //桥梁
|
||||||
|
case 3: //隧道
|
||||||
|
//隧道名-采集方向(D/X)-相机编号(01-22)-采集序号(五位)K里程桩号.bmp DAXIASHAN-D-05-00003K15069.5.bmp
|
||||||
|
arr := strings.Split(src.SrcPath, "K")
|
||||||
|
if len(arr) > 1 {
|
||||||
|
arrM := strings.Split(arr[1], ".")
|
||||||
|
milepostNumber = meter2Milepost(arrM[0])
|
||||||
|
arrD := strings.Split(arr[0], ".")
|
||||||
|
if len(arrD) > 1 {
|
||||||
|
if arrD[1] == "D" {
|
||||||
|
upDown = "下行"
|
||||||
|
} else {
|
||||||
|
upDown = "上行"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(arrD) > 4 {
|
||||||
|
lineNum, _ = strconv.Atoi(arrD[3])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(src.Result) > 0 && src.Result[0] == '[' {
|
||||||
|
mrList = make([]string, 0)
|
||||||
|
if err := json.Unmarshal([]byte(src.Result), &mrList); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
list := make([]*model.ProjectResult, 0)
|
||||||
|
for _, str := range mrList {
|
||||||
|
if err := json.Unmarshal([]byte(str), &mr); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if mr.Code == 2001 {
|
||||||
|
ir := new(InsigmaResult)
|
||||||
|
if err := json.Unmarshal([]byte(str), &ir); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fileDiscern = ir.Image
|
||||||
|
for key, value := range ir.Diseases {
|
||||||
|
if len(value.Param.MaxWidth) > 0 && width == 0 {
|
||||||
|
width, _ = strconv.ParseFloat(value.Param.MaxWidth, 64)
|
||||||
|
} else {
|
||||||
|
width = 0
|
||||||
|
}
|
||||||
|
memo = fmt.Sprintf("%d. 发现[%s],等级[%s],长度[%f],最大宽度[%s],面积[%f];\n", key+1, value.Type, value.Level, value.Param.Length, value.Param.MaxWidth, value.Param.Area)
|
||||||
|
item := &model.ProjectResult{
|
||||||
|
ProjectId: project.ProjectId,
|
||||||
|
SourceResultId: src.ResultId,
|
||||||
|
MilepostNumber: milepostNumber,
|
||||||
|
UpDown: upDown,
|
||||||
|
LineNum: lineNum,
|
||||||
|
DiseaseType: value.Type,
|
||||||
|
DiseaseLevel: value.Level,
|
||||||
|
Length: value.Param.Length,
|
||||||
|
Width: width,
|
||||||
|
Acreage: value.Param.Area,
|
||||||
|
Memo: memo,
|
||||||
|
Result: fileDiscern,
|
||||||
|
Creator: 0,
|
||||||
|
Modifier: 0,
|
||||||
|
CreateAt: time.Now().Unix(),
|
||||||
|
UpdateAt: time.Now().Unix(),
|
||||||
|
}
|
||||||
|
list = append(list, item)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_, _ = model.DB.Insert(list)
|
||||||
|
} else {
|
||||||
|
if err := json.Unmarshal([]byte(src.Result), &mr); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch mr.Code {
|
||||||
|
case 0: //轻量化模型返回
|
||||||
|
lr := new(LightweightResult)
|
||||||
|
if err := json.Unmarshal([]byte(src.Result), &lr); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if lr.Crack || lr.Pothole {
|
||||||
|
if lr.Crack {
|
||||||
|
memo = "检测到裂缝"
|
||||||
|
} else {
|
||||||
|
memo = "检测到坑洼"
|
||||||
|
}
|
||||||
|
fileDiscern = lr.ImgDiscern
|
||||||
|
if len(fileDiscern) == 0 {
|
||||||
|
fileDiscern = lr.ImgSrc
|
||||||
|
}
|
||||||
|
diseaseLevelName := "重度"
|
||||||
|
diseaseTypeName := ""
|
||||||
|
switch project.BizType {
|
||||||
|
case 2:
|
||||||
|
diseaseTypeName = "结构裂缝"
|
||||||
|
case 3:
|
||||||
|
diseaseTypeName = "衬砌裂缝"
|
||||||
|
default:
|
||||||
|
diseaseTypeName = "横向裂缝"
|
||||||
|
}
|
||||||
|
item := &model.ProjectResult{
|
||||||
|
ProjectId: project.ProjectId,
|
||||||
|
SourceResultId: src.ResultId,
|
||||||
|
MilepostNumber: milepostNumber,
|
||||||
|
UpDown: upDown,
|
||||||
|
LineNum: lineNum,
|
||||||
|
DiseaseType: diseaseTypeName,
|
||||||
|
DiseaseLevel: diseaseLevelName,
|
||||||
|
Length: 0,
|
||||||
|
Width: 0,
|
||||||
|
Acreage: 0,
|
||||||
|
Memo: memo,
|
||||||
|
Result: fileDiscern,
|
||||||
|
Creator: 0,
|
||||||
|
Modifier: 0,
|
||||||
|
CreateAt: time.Now().Unix(),
|
||||||
|
UpdateAt: time.Now().Unix(),
|
||||||
|
}
|
||||||
|
_, _ = model.DB.Insert(item)
|
||||||
|
} else {
|
||||||
|
fileDiscern = lr.ImgSrc
|
||||||
|
}
|
||||||
|
//
|
||||||
|
case 2001: //网新返回有病害
|
||||||
|
ir := new(InsigmaResult)
|
||||||
|
if err := json.Unmarshal([]byte(src.Result), &ir); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fileDiscern = ir.Image
|
||||||
|
list := make([]*model.ProjectResult, 0)
|
||||||
|
for _, val := range ir.Diseases {
|
||||||
|
if len(val.Param.MaxWidth) > 0 && width == 0 {
|
||||||
|
width, _ = strconv.ParseFloat(val.Param.MaxWidth, 64)
|
||||||
|
} else {
|
||||||
|
width = 0
|
||||||
|
}
|
||||||
|
memo = fmt.Sprintf("发现[%s],等级[%s],长度[%f],最大宽度[%s],面积[%f];\n", val.Type, val.Level, val.Param.Length, val.Param.MaxWidth, val.Param.Area)
|
||||||
|
maxWidth, _ := strconv.ParseFloat(val.Param.MaxWidth, 64)
|
||||||
|
item := &model.ProjectResult{
|
||||||
|
ProjectId: project.ProjectId,
|
||||||
|
SourceResultId: src.ResultId,
|
||||||
|
MilepostNumber: milepostNumber,
|
||||||
|
UpDown: upDown,
|
||||||
|
LineNum: lineNum,
|
||||||
|
DiseaseType: val.Type,
|
||||||
|
DiseaseLevel: val.Level,
|
||||||
|
Length: val.Param.Length,
|
||||||
|
Width: maxWidth,
|
||||||
|
Acreage: val.Param.Area,
|
||||||
|
Memo: memo,
|
||||||
|
Result: fileDiscern,
|
||||||
|
Creator: 0,
|
||||||
|
Modifier: 0,
|
||||||
|
CreateAt: time.Now().Unix(),
|
||||||
|
UpdateAt: time.Now().Unix(),
|
||||||
|
}
|
||||||
|
list = append(list, item)
|
||||||
|
}
|
||||||
|
_, _ = model.DB.Insert(list)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 里程桩加减里程,返回里程桩
|
||||||
|
func GetMilepost(start, num, upDown string) string {
|
||||||
|
arr := strings.Split(start, "+")
|
||||||
|
var (
|
||||||
|
kilometre, meter, milepost, counter, res, resMilepost, resMeter float64
|
||||||
|
)
|
||||||
|
if len(arr) == 1 {
|
||||||
|
meter = 0
|
||||||
|
} else {
|
||||||
|
meter, _ = strconv.ParseFloat(arr[1], 64)
|
||||||
|
}
|
||||||
|
str := strings.Replace(arr[0], "k", "", -1)
|
||||||
|
str = strings.Replace(str, "K", "", -1)
|
||||||
|
kilometre, _ = strconv.ParseFloat(str, 64)
|
||||||
|
milepost = kilometre + meter/1000
|
||||||
|
counter, _ = strconv.ParseFloat(num, 64)
|
||||||
|
if upDown == "D" {
|
||||||
|
res = milepost - counter
|
||||||
|
|
||||||
|
} else {
|
||||||
|
res = milepost + counter
|
||||||
|
}
|
||||||
|
resMilepost = math.Floor(res)
|
||||||
|
resMeter = (res - resMilepost) * 100
|
||||||
|
return fmt.Sprintf("K%d+%.2f", int(resMilepost), resMeter)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 米装换成里程桩号
|
||||||
|
func meter2Milepost(meter string) string {
|
||||||
|
meter = strings.Replace(meter, "K", "", -1)
|
||||||
|
m, _ := strconv.ParseFloat(meter, 64)
|
||||||
|
resMilepost := math.Floor(m / 1000)
|
||||||
|
resMeter := (m - resMilepost*1000) * 100
|
||||||
|
return fmt.Sprintf("K%d+%.2f", int(resMilepost), resMeter)
|
||||||
|
}
|
||||||
func deliver(topic string, mqType uint, payload interface{}) {
|
func deliver(topic string, mqType uint, payload interface{}) {
|
||||||
cli := GetMqClient(topic, mqType)
|
cli := GetMqClient(topic, mqType)
|
||||||
pData, _ := json.Marshal(payload)
|
pData, _ := json.Marshal(payload)
|
||||||
|
@ -544,3 +830,220 @@ func TaskExecuteLogHandler(data []byte) (frame.Tag, []byte) {
|
||||||
l.Unlock()
|
l.Unlock()
|
||||||
return frame.Tag(cmd.Command), nil
|
return frame.Tag(cmd.Command), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func RunTraining(task *model.TrainTask) {
|
||||||
|
var (
|
||||||
|
args []string
|
||||||
|
modelPath, modelFileName, testSize string
|
||||||
|
modelAcc, modelLoss float64
|
||||||
|
)
|
||||||
|
fmt.Println("curr tmp dir====>>>>", config.Cfg.TmpTrainDir)
|
||||||
|
modelFileName = utils.GetUUIDString()
|
||||||
|
//复制训练数据集
|
||||||
|
tmpTrainDir := path.Join(config.Cfg.TmpTrainDir, fmt.Sprintf("%s_%s_%d_%d", modelFileName, task.Arithmetic, task.BatchSize, task.EpochsSize))
|
||||||
|
fileList := make([]model.TrainingDatasetDetail, 0)
|
||||||
|
_ = model.DB.Where("dataset_id = ?", task.TrainDatasetId).Find(&fileList)
|
||||||
|
_ = os.MkdirAll(tmpTrainDir, os.ModePerm)
|
||||||
|
_ = os.MkdirAll(path.Join(tmpTrainDir, "train"), os.ModePerm)
|
||||||
|
_ = os.MkdirAll(path.Join(tmpTrainDir, "train", "0"), os.ModePerm)
|
||||||
|
_ = os.MkdirAll(path.Join(tmpTrainDir, "train", "1"), os.ModePerm)
|
||||||
|
_ = os.MkdirAll(path.Join(tmpTrainDir, "val"), os.ModePerm)
|
||||||
|
_ = os.MkdirAll(path.Join(tmpTrainDir, "val", "0"), os.ModePerm)
|
||||||
|
_ = os.MkdirAll(path.Join(tmpTrainDir, "val", "1"), os.ModePerm)
|
||||||
|
_ = os.MkdirAll(path.Join(tmpTrainDir, "test"), os.ModePerm)
|
||||||
|
for _, v := range fileList {
|
||||||
|
dstFilePath := ""
|
||||||
|
switch v.CategoryId {
|
||||||
|
case 2:
|
||||||
|
dstFilePath = "test"
|
||||||
|
default:
|
||||||
|
dstFilePath = "train"
|
||||||
|
}
|
||||||
|
if v.CategoryId != 2 {
|
||||||
|
if v.IsDisease == 1 {
|
||||||
|
dstFilePath = path.Join(tmpTrainDir, dstFilePath, "0")
|
||||||
|
} else {
|
||||||
|
dstFilePath = path.Join(tmpTrainDir, dstFilePath, "1")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
dstFilePath = path.Join(tmpTrainDir, dstFilePath)
|
||||||
|
}
|
||||||
|
err := utils.CopyFile(v.FilePath, path.Join(dstFilePath, v.FileName))
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("copy error: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
modelPath = path.Join(config.Cfg.ModelOutPath, fmt.Sprintf("%s_%s_%d_%d_%s", modelFileName, task.Arithmetic, task.BatchSize, task.EpochsSize, task.OutputType))
|
||||||
|
_ = os.MkdirAll(modelPath, os.ModePerm)
|
||||||
|
|
||||||
|
dt := new(model.TrainingDataset)
|
||||||
|
_, err := model.DB.ID(task.TrainDatasetId).Get(dt)
|
||||||
|
if err != nil {
|
||||||
|
goto ReturnPoint
|
||||||
|
}
|
||||||
|
testSize = fmt.Sprintf("%.2f", dt.ValidationNumber/100)
|
||||||
|
//执行训练命令
|
||||||
|
args = []string{"--dataset=" + path.Join(tmpTrainDir, "train"),
|
||||||
|
"--img_size=" + strconv.Itoa(task.ImageSize), "--batch_size=" + strconv.Itoa(task.BatchSize), "--test_size=" + testSize,
|
||||||
|
"--epochs=" + strconv.Itoa(task.EpochsSize), "--model=" + task.Arithmetic, "--model_save=" + path.Join(modelPath, modelFileName+".h5"),
|
||||||
|
}
|
||||||
|
fmt.Println("args====>>>", args)
|
||||||
|
err = ExecCommand(config.Cfg.TrainScriptPath, args, path.Join(modelPath, modelFileName+".log"), task.TaskId)
|
||||||
|
ReturnPoint:
|
||||||
|
//返回训练结果
|
||||||
|
modelMetricsFile := modelFileName + "_model_metrics.png"
|
||||||
|
task.FinishTime = time.Now().Unix()
|
||||||
|
task.ModelFilePath = path.Join(modelPath, modelFileName+".h5")
|
||||||
|
task.Loss = GetIndicatorByLog(path.Join(modelPath, modelFileName+".log"), "[INFO] Model Validation Loss:")
|
||||||
|
task.Accuracy = GetIndicatorByLog(path.Join(modelPath, modelFileName+".log"), "[INFO] Model Validation Accuracy:")
|
||||||
|
task.Status = 3
|
||||||
|
if err != nil {
|
||||||
|
task.Status = 5
|
||||||
|
}
|
||||||
|
task.ModelFileMetricsPath = path.Join(modelPath, modelMetricsFile)
|
||||||
|
_, _ = model.DB.ID(task.TaskId).AllCols().Update(task)
|
||||||
|
if utils.PathExists(path.Join(modelPath, modelFileName+".log")) {
|
||||||
|
logContext := utils.ReadFile(path.Join(modelPath, modelFileName+".log"))
|
||||||
|
taskRes := new(model.TrainTaskResult)
|
||||||
|
taskRes.TaskId = task.TaskId
|
||||||
|
taskRes.CreateAt = time.Now().Unix()
|
||||||
|
taskRes.Content = string(logContext)
|
||||||
|
taskRes.Result = path.Join(modelPath, modelMetricsFile)
|
||||||
|
taskRes.Accuracy = modelAcc
|
||||||
|
taskRes.Loss = modelLoss
|
||||||
|
c, err := model.DB.Insert(taskRes)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("model.DB.Insert(taskRes) error ========>>>>>>", err)
|
||||||
|
}
|
||||||
|
fmt.Println("model.DB.Insert(taskRes) count ========>>>>>>", c)
|
||||||
|
} else {
|
||||||
|
fmt.Println("logContext========>>>>>>未读取")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetIndicatorByLog(logFileName, indicator string) float64 {
|
||||||
|
logFn, _ := os.Open(logFileName)
|
||||||
|
defer func() {
|
||||||
|
_ = logFn.Close()
|
||||||
|
}()
|
||||||
|
buf := bufio.NewReader(logFn)
|
||||||
|
for {
|
||||||
|
line, err := buf.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
if err == io.EOF {
|
||||||
|
//fmt.Println("File read ok!")
|
||||||
|
break
|
||||||
|
} else {
|
||||||
|
fmt.Println("Read file error!", err)
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if strings.Index(line, indicator) >= 0 {
|
||||||
|
str := strings.Replace(line, indicator, "", -1)
|
||||||
|
str = strings.Replace(str, "\n", "", -1)
|
||||||
|
value, _ := strconv.ParseFloat(strings.Trim(str, " "), 64)
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExecCommand(cmd string, args []string, logFileName string, taskId int64) (err error) {
|
||||||
|
logFile, _ := os.Create(logFileName)
|
||||||
|
defer func() {
|
||||||
|
_ = logFile.Close()
|
||||||
|
}()
|
||||||
|
fmt.Print("开始训练......")
|
||||||
|
|
||||||
|
c := exec.Command(cmd, args...) // mac or linux
|
||||||
|
stdout, err := c.StdoutPipe()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
wg sync.WaitGroup
|
||||||
|
)
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
reader := bufio.NewReader(stdout)
|
||||||
|
var (
|
||||||
|
epoch int
|
||||||
|
//modelLoss, modelAcc float64
|
||||||
|
)
|
||||||
|
for {
|
||||||
|
readString, err := reader.ReadString('\n')
|
||||||
|
if err != nil || err == io.EOF {
|
||||||
|
fmt.Println("训练2===>>>", err)
|
||||||
|
//wg.Done()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
byte2String := ConvertByte2String([]byte(readString), "GB18030")
|
||||||
|
_, _ = fmt.Fprint(logFile, byte2String)
|
||||||
|
if strings.Index(byte2String, "Epoch") >= 0 {
|
||||||
|
str := strings.Replace(byte2String, "Epoch ", "", -1)
|
||||||
|
arr := strings.Split(str, "/")
|
||||||
|
epoch, _ = strconv.Atoi(arr[0])
|
||||||
|
}
|
||||||
|
if strings.Index(byte2String, "- loss:") > 0 &&
|
||||||
|
strings.Index(byte2String, "- accuracy:") > 0 &&
|
||||||
|
strings.Index(byte2String, "- val_loss:") > 0 &&
|
||||||
|
strings.Index(byte2String, "- val_accuracy:") > 0 {
|
||||||
|
var (
|
||||||
|
loss, acc, valLoss, valAcc float64
|
||||||
|
)
|
||||||
|
arr := strings.Split(byte2String, "-")
|
||||||
|
for _, v := range arr {
|
||||||
|
if strings.Index(v, "loss:") > 0 && strings.Index(v, "val_loss:") < 0 {
|
||||||
|
strLoss := strings.Replace(v, " loss: ", "", -1)
|
||||||
|
loss, _ = strconv.ParseFloat(strings.Trim(strLoss, " "), 64)
|
||||||
|
}
|
||||||
|
if strings.Index(v, "accuracy:") > 0 && strings.Index(v, "val_accuracy:") < 0 {
|
||||||
|
strAcc := strings.Replace(v, " accuracy: ", "", -1)
|
||||||
|
acc, _ = strconv.ParseFloat(strings.Trim(strAcc, " "), 64)
|
||||||
|
}
|
||||||
|
if strings.Index(v, "val_loss:") > 0 {
|
||||||
|
strValLoss := strings.Replace(v, "val_loss: ", "", -1)
|
||||||
|
valLoss, _ = strconv.ParseFloat(strings.Trim(strValLoss, " "), 64)
|
||||||
|
}
|
||||||
|
if strings.Index(v, "val_accuracy:") > 0 {
|
||||||
|
strValAcc := strings.Replace(v, "val_accuracy: ", "", -1)
|
||||||
|
strValAcc = strings.Replace(strValAcc, "\n", "", -1)
|
||||||
|
valAcc, _ = strconv.ParseFloat(strings.Trim(strValAcc, " "), 64)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
taskLog := new(model.TrainTaskLog)
|
||||||
|
taskLog.Epoch = epoch
|
||||||
|
taskLog.TaskId = taskId
|
||||||
|
taskLog.CreateAt = time.Now().Unix()
|
||||||
|
taskLog.Loss = loss
|
||||||
|
taskLog.Accuracy = acc
|
||||||
|
taskLog.ValLoss = valLoss
|
||||||
|
taskLog.ValAccuracy = valAcc
|
||||||
|
_, _ = model.DB.Insert(taskLog)
|
||||||
|
}
|
||||||
|
fmt.Print(byte2String)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
err = c.Start()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("训练3===>>>", err)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
func ConvertByte2String(byte []byte, charset Charset) string {
|
||||||
|
var str string
|
||||||
|
switch charset {
|
||||||
|
case GB18030:
|
||||||
|
var decodeBytes, _ = simplifiedchinese.GB18030.NewDecoder().Bytes(byte)
|
||||||
|
str = string(decodeBytes)
|
||||||
|
case UTF8:
|
||||||
|
fallthrough
|
||||||
|
default:
|
||||||
|
str = string(byte)
|
||||||
|
}
|
||||||
|
return str
|
||||||
|
}
|
||||||
|
|
|
@ -9,6 +9,7 @@ const (
|
||||||
ModelIssueResponse
|
ModelIssueResponse
|
||||||
TaskExecuteLog
|
TaskExecuteLog
|
||||||
TaskLog
|
TaskLog
|
||||||
|
TrainTaskAdd
|
||||||
)
|
)
|
||||||
|
|
||||||
type InstructionReq struct {
|
type InstructionReq struct {
|
||||||
|
|
|
@ -0,0 +1,94 @@
|
||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/md5"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"git.hpds.cc/Component/logging"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"path"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func CopyFile(src, dst string) error {
|
||||||
|
sourceFileStat, err := os.Stat(src)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !sourceFileStat.Mode().IsRegular() {
|
||||||
|
return fmt.Errorf("%s is not a regular file", src)
|
||||||
|
}
|
||||||
|
|
||||||
|
source, err := os.Open(src)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func(source *os.File) {
|
||||||
|
_ = source.Close()
|
||||||
|
}(source)
|
||||||
|
|
||||||
|
destination, err := os.Create(dst)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func(destination *os.File) {
|
||||||
|
_ = destination.Close()
|
||||||
|
}(destination)
|
||||||
|
_, err = io.Copy(destination, source)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func PathExists(path string) bool {
|
||||||
|
_, err := os.Stat(path)
|
||||||
|
if err == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadFile 读取到file中,再利用ioutil将file直接读取到[]byte中, 这是最优
|
||||||
|
func ReadFile(fn string) []byte {
|
||||||
|
f, err := os.Open(fn)
|
||||||
|
if err != nil {
|
||||||
|
logging.L().Error("Read File", zap.String("File Name", fn), zap.Error(err))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
defer func(f *os.File) {
|
||||||
|
_ = f.Close()
|
||||||
|
}(f)
|
||||||
|
|
||||||
|
fd, err := io.ReadAll(f)
|
||||||
|
if err != nil {
|
||||||
|
logging.L().Error("Read File To buff", zap.String("File Name", fn), zap.Error(err))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fd
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetFileName(fn string) string {
|
||||||
|
fileType := path.Ext(fn)
|
||||||
|
return strings.TrimSuffix(fn, fileType)
|
||||||
|
}
|
||||||
|
func GetFileNameAndExt(fn string) string {
|
||||||
|
_, fileName := filepath.Split(fn)
|
||||||
|
return fileName
|
||||||
|
}
|
||||||
|
func GetFileMd5(data []byte) string {
|
||||||
|
hash := md5.New()
|
||||||
|
hash.Write(data)
|
||||||
|
return hex.EncodeToString(hash.Sum(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func FileToBase64(fn string) string {
|
||||||
|
buff := ReadFile(fn)
|
||||||
|
return base64.StdEncoding.EncodeToString(buff) // 加密成base64字符串
|
||||||
|
}
|
|
@ -0,0 +1,126 @@
|
||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func HttpDo(reqUrl, method string, params map[string]string, header map[string]string) (data []byte, err error) {
|
||||||
|
var paramStr string = ""
|
||||||
|
if contentType, ok := header["Content-Type"]; ok && strings.Contains(contentType, "json") {
|
||||||
|
bytesData, _ := json.Marshal(params)
|
||||||
|
paramStr = string(bytesData)
|
||||||
|
} else {
|
||||||
|
for k, v := range params {
|
||||||
|
if len(paramStr) == 0 {
|
||||||
|
paramStr = fmt.Sprintf("%s=%s", k, url.QueryEscape(v))
|
||||||
|
} else {
|
||||||
|
paramStr = fmt.Sprintf("%s&%s=%s", paramStr, k, url.QueryEscape(v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &http.Client{}
|
||||||
|
req, err := http.NewRequest(strings.ToUpper(method), reqUrl, strings.NewReader(paramStr))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for k, v := range header {
|
||||||
|
req.Header.Set(k, v)
|
||||||
|
}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if resp.Body != nil {
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
var body []byte
|
||||||
|
if resp.Body != nil {
|
||||||
|
body, err = io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type UploadFile struct {
|
||||||
|
// 表单名称
|
||||||
|
Name string
|
||||||
|
Filepath string
|
||||||
|
// 文件全路径
|
||||||
|
File *bytes.Buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func PostFile(reqUrl string, reqParams map[string]string, contentType string, files []UploadFile, headers map[string]string) string {
|
||||||
|
requestBody, realContentType := getReader(reqParams, contentType, files)
|
||||||
|
httpRequest, _ := http.NewRequest("POST", reqUrl, requestBody)
|
||||||
|
// 添加请求头
|
||||||
|
httpRequest.Header.Add("Content-Type", realContentType)
|
||||||
|
if headers != nil {
|
||||||
|
for k, v := range headers {
|
||||||
|
httpRequest.Header.Add(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
httpClient := &http.Client{}
|
||||||
|
// 发送请求
|
||||||
|
resp, err := httpClient.Do(httpRequest)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
defer func(Body io.ReadCloser) {
|
||||||
|
_ = Body.Close()
|
||||||
|
}(resp.Body)
|
||||||
|
response, _ := io.ReadAll(resp.Body)
|
||||||
|
return string(response)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getReader(reqParams map[string]string, contentType string, files []UploadFile) (io.Reader, string) {
|
||||||
|
if strings.Index(contentType, "json") > -1 {
|
||||||
|
bytesData, _ := json.Marshal(reqParams)
|
||||||
|
return bytes.NewReader(bytesData), contentType
|
||||||
|
} else if files != nil {
|
||||||
|
body := &bytes.Buffer{}
|
||||||
|
// 文件写入 body
|
||||||
|
writer := multipart.NewWriter(body)
|
||||||
|
for _, uploadFile := range files {
|
||||||
|
part, err := writer.CreateFormFile(uploadFile.Name, filepath.Base(uploadFile.Filepath))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
_, err = io.Copy(part, uploadFile.File)
|
||||||
|
}
|
||||||
|
// 其他参数列表写入 body
|
||||||
|
for k, v := range reqParams {
|
||||||
|
if err := writer.WriteField(k, v); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := writer.Close(); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
// 上传文件需要自己专用的contentType
|
||||||
|
return body, writer.FormDataContentType()
|
||||||
|
} else {
|
||||||
|
urlValues := url.Values{}
|
||||||
|
for key, val := range reqParams {
|
||||||
|
urlValues.Set(key, val)
|
||||||
|
}
|
||||||
|
reqBody := urlValues.Encode()
|
||||||
|
return strings.NewReader(reqBody), contentType
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,139 @@
|
||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
|
"golang.org/x/image/bmp"
|
||||||
|
"golang.org/x/image/tiff"
|
||||||
|
"image"
|
||||||
|
"image/color"
|
||||||
|
"image/jpeg"
|
||||||
|
"image/png"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BuffToImage(in []byte) image.Image {
|
||||||
|
buff := bytes.NewBuffer(in)
|
||||||
|
m, _, _ := image.Decode(buff)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clip 图片裁剪
|
||||||
|
func Clip(in []byte, wi, hi int, equalProportion bool) (out image.Image, imageType string, err error) {
|
||||||
|
buff := bytes.NewBuffer(in)
|
||||||
|
m, imgType, _ := image.Decode(buff)
|
||||||
|
rgbImg := m.(*image.YCbCr)
|
||||||
|
if equalProportion {
|
||||||
|
w := m.Bounds().Max.X
|
||||||
|
h := m.Bounds().Max.Y
|
||||||
|
if w > 0 && h > 0 && wi > 0 && hi > 0 {
|
||||||
|
wi, hi = fixSize(w, h, wi, hi)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return rgbImg.SubImage(image.Rect(0, 0, wi, hi)), imgType, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func fixSize(img1W, img2H, wi, hi int) (new1W, new2W int) {
|
||||||
|
var ( //为了方便计算,将图片的宽转为 float64
|
||||||
|
imgWidth, imgHeight = float64(img1W), float64(img2H)
|
||||||
|
ratio float64
|
||||||
|
)
|
||||||
|
if imgWidth >= imgHeight {
|
||||||
|
ratio = imgWidth / float64(wi)
|
||||||
|
return int(imgWidth * ratio), int(imgHeight * ratio)
|
||||||
|
}
|
||||||
|
ratio = imgHeight / float64(hi)
|
||||||
|
return int(imgWidth * ratio), int(imgHeight * ratio)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Gray(in []byte) (out image.Image, err error) {
|
||||||
|
m := BuffToImage(in)
|
||||||
|
bounds := m.Bounds()
|
||||||
|
dx := bounds.Dx()
|
||||||
|
dy := bounds.Dy()
|
||||||
|
newRgba := image.NewRGBA(bounds)
|
||||||
|
for i := 0; i < dx; i++ {
|
||||||
|
for j := 0; j < dy; j++ {
|
||||||
|
colorRgb := m.At(i, j)
|
||||||
|
_, g, _, a := colorRgb.RGBA()
|
||||||
|
gUint8 := uint8(g >> 8)
|
||||||
|
aUint8 := uint8(a >> 8)
|
||||||
|
newRgba.SetRGBA(i, j, color.RGBA{R: gUint8, G: gUint8, B: gUint8, A: aUint8})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r := image.Rect(0, 0, dx, dy)
|
||||||
|
return newRgba.SubImage(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Rotate90(in []byte) image.Image {
|
||||||
|
m := BuffToImage(in)
|
||||||
|
rotate90 := image.NewRGBA(image.Rect(0, 0, m.Bounds().Dy(), m.Bounds().Dx()))
|
||||||
|
// 矩阵旋转
|
||||||
|
for x := m.Bounds().Min.Y; x < m.Bounds().Max.Y; x++ {
|
||||||
|
for y := m.Bounds().Max.X - 1; y >= m.Bounds().Min.X; y-- {
|
||||||
|
// 设置像素点
|
||||||
|
rotate90.Set(m.Bounds().Max.Y-x, y, m.At(y, x))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return rotate90
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rotate180 旋转180度
|
||||||
|
func Rotate180(in []byte) image.Image {
|
||||||
|
m := BuffToImage(in)
|
||||||
|
rotate180 := image.NewRGBA(image.Rect(0, 0, m.Bounds().Dx(), m.Bounds().Dy()))
|
||||||
|
// 矩阵旋转
|
||||||
|
for x := m.Bounds().Min.X; x < m.Bounds().Max.X; x++ {
|
||||||
|
for y := m.Bounds().Min.Y; y < m.Bounds().Max.Y; y++ {
|
||||||
|
// 设置像素点
|
||||||
|
rotate180.Set(m.Bounds().Max.X-x, m.Bounds().Max.Y-y, m.At(x, y))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return rotate180
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rotate270 旋转270度
|
||||||
|
func Rotate270(in []byte) image.Image {
|
||||||
|
m := BuffToImage(in)
|
||||||
|
rotate270 := image.NewRGBA(image.Rect(0, 0, m.Bounds().Dy(), m.Bounds().Dx()))
|
||||||
|
// 矩阵旋转
|
||||||
|
for x := m.Bounds().Min.Y; x < m.Bounds().Max.Y; x++ {
|
||||||
|
for y := m.Bounds().Max.X - 1; y >= m.Bounds().Min.X; y-- {
|
||||||
|
// 设置像素点
|
||||||
|
rotate270.Set(x, m.Bounds().Max.X-y, m.At(y, x))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return rotate270
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func ImageToBase64(img image.Image, imgType string) string {
|
||||||
|
buff := ImageToBuff(img, imgType)
|
||||||
|
return base64.StdEncoding.EncodeToString(buff.Bytes())
|
||||||
|
}
|
||||||
|
|
||||||
|
func ImageToBuff(img image.Image, imgType string) *bytes.Buffer {
|
||||||
|
buff := bytes.NewBuffer(nil)
|
||||||
|
switch imgType {
|
||||||
|
case "bmp":
|
||||||
|
imgType = "bmp"
|
||||||
|
_ = bmp.Encode(buff, img)
|
||||||
|
case "png":
|
||||||
|
imgType = "png"
|
||||||
|
_ = png.Encode(buff, img)
|
||||||
|
case "tiff":
|
||||||
|
imgType = "tiff"
|
||||||
|
_ = tiff.Encode(buff, img, nil)
|
||||||
|
default:
|
||||||
|
imgType = "jpeg"
|
||||||
|
_ = jpeg.Encode(buff, img, nil)
|
||||||
|
}
|
||||||
|
return buff
|
||||||
|
}
|
||||||
|
|
||||||
|
func ImgFileToBase64(fn string) string {
|
||||||
|
fileByte := ReadFile(fn)
|
||||||
|
buff := bytes.NewBuffer(fileByte)
|
||||||
|
m, _, _ := image.Decode(buff)
|
||||||
|
return "data:image/jpeg;base64," + ImageToBase64(m, "jpeg")
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,48 @@
|
||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha1"
|
||||||
|
"encoding/base64"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"math/rand"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
/*
|
||||||
|
RandomString 产生随机数
|
||||||
|
- size 随机码的位数
|
||||||
|
- kind 0 // 纯数字
|
||||||
|
1 // 小写字母
|
||||||
|
2 // 大写字母
|
||||||
|
3 // 数字、大小写字母
|
||||||
|
*/
|
||||||
|
func RandomString(size int, kind int) string {
|
||||||
|
iKind, kinds, rsBytes := kind, [][]int{[]int{10, 48}, []int{26, 97}, []int{26, 65}}, make([]byte, size)
|
||||||
|
isAll := kind > 2 || kind < 0
|
||||||
|
rand.Seed(time.Now().UnixNano())
|
||||||
|
for i := 0; i < size; i++ {
|
||||||
|
if isAll { // random iKind
|
||||||
|
iKind = rand.Intn(3)
|
||||||
|
}
|
||||||
|
scope, base := kinds[iKind][0], kinds[iKind][1]
|
||||||
|
rsBytes[i] = uint8(base + rand.Intn(scope))
|
||||||
|
}
|
||||||
|
return string(rsBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetUserSha1Pass(pass, salt string) string {
|
||||||
|
key := []byte(salt)
|
||||||
|
mac := hmac.New(sha1.New, key)
|
||||||
|
mac.Write([]byte(pass))
|
||||||
|
//进行base64编码
|
||||||
|
res := base64.StdEncoding.EncodeToString(mac.Sum(nil))
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetUUIDString() string {
|
||||||
|
u, _ := uuid.NewUUID()
|
||||||
|
str := strings.Replace(u.String(), "-", "", -1)
|
||||||
|
return str
|
||||||
|
}
|
|
@ -0,0 +1,91 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
#获取对应的参数 没有的话赋默认值
|
||||||
|
ARGS=`getopt -o d::s::b::e::m::o:: --long dataset::,img_size::,batch_size::,epochs::,model::,model_save:: -n 'example.sh' -- "$@"`
|
||||||
|
if [ $? != 0 ]; then
|
||||||
|
echo "Terminating..."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo $ARGS
|
||||||
|
eval set -- "${ARGS}"
|
||||||
|
while true;
|
||||||
|
do
|
||||||
|
case "$1" in
|
||||||
|
-d|--dataset)
|
||||||
|
case "$2" in
|
||||||
|
"")
|
||||||
|
echo "Internal error!"
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
Dataset=$2
|
||||||
|
shift 2;
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
;;
|
||||||
|
-s|--img_size)
|
||||||
|
case "$2" in
|
||||||
|
"")
|
||||||
|
echo "Internal error!"
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
|
||||||
|
*)
|
||||||
|
ImgSize=$2;
|
||||||
|
shift 2;
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
;;
|
||||||
|
-b|--batch_size)
|
||||||
|
case "$2" in
|
||||||
|
*)
|
||||||
|
BatchSize=$2;
|
||||||
|
shift 2;
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
;;
|
||||||
|
-e|--epochs)
|
||||||
|
case "$2" in
|
||||||
|
*)
|
||||||
|
Epochs=$2;
|
||||||
|
shift 2;
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
;;
|
||||||
|
-m|--model)
|
||||||
|
case "$2" in
|
||||||
|
*)
|
||||||
|
Model=$2;
|
||||||
|
shift 2;
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
;;
|
||||||
|
-o|--model_save)
|
||||||
|
case "$2" in
|
||||||
|
*)
|
||||||
|
ModelSave=$2;
|
||||||
|
shift 2;
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
;;
|
||||||
|
|
||||||
|
--)
|
||||||
|
shift
|
||||||
|
break
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "Internal error!"
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
done
|
||||||
|
|
||||||
|
#echo ${Dataset}
|
||||||
|
|
||||||
|
eval "$(conda shell.bash hook)"
|
||||||
|
conda activate hpds_train
|
||||||
|
cd /home/data/hpds_train_keras
|
||||||
|
python train.py --dataset ${Dataset} --img_size ${ImgSize} --batch_size ${BatchSize} \
|
||||||
|
--epochs ${Epochs} --model ${Model} --model_save ${ModelSave}
|
||||||
|
|
Loading…
Reference in New Issue