1、增加任务处理进度
This commit is contained in:
		
							parent
							
								
									7d09ba0286
								
							
						
					
					
						commit
						90d43d468a
					
				| 
						 | 
				
			
			@ -19,6 +19,9 @@ type ControlCenterConfig struct {
 | 
			
		|||
	Host            string       `yaml:"host,omitempty"`
 | 
			
		||||
	Port            int          `yaml:"port,omitempty"`
 | 
			
		||||
	Mode            string       `yaml:"mode,omitempty"`
 | 
			
		||||
	TmpTrainDir     string       `yaml:"tmpTrainDir"`
 | 
			
		||||
	TrainScriptPath string       `yaml:"trainScriptPath"`
 | 
			
		||||
	ModelOutPath    string       `yaml:"modelOutPath"`
 | 
			
		||||
	Consul          ConsulConfig `yaml:"consul,omitempty"`
 | 
			
		||||
	Db              DbConfig     `yaml:"db"`
 | 
			
		||||
	Cache           CacheConfig  `yaml:"cache"`
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -2,6 +2,9 @@ name: control_center
 | 
			
		|||
host: 0.0.0.0
 | 
			
		||||
port: 8088
 | 
			
		||||
mode: dev
 | 
			
		||||
tmpTrainDir: ./tmp
 | 
			
		||||
trainScriptPath: ./scripts/runTrainScript.sh
 | 
			
		||||
modelOutPath: ./out
 | 
			
		||||
logging:
 | 
			
		||||
  path: ./logs
 | 
			
		||||
  prefix: hpds-control
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										7
									
								
								go.mod
								
								
								
								
							
							
						
						
									
										7
									
								
								go.mod
								
								
								
								
							| 
						 | 
				
			
			@ -7,12 +7,14 @@ require (
 | 
			
		|||
	git.hpds.cc/Component/network v0.0.0-20230405135741-a4ea724bab76
 | 
			
		||||
	git.hpds.cc/pavement/hpds_node v0.0.0-20230405153516-9403c4d01e12
 | 
			
		||||
	github.com/go-sql-driver/mysql v1.7.0
 | 
			
		||||
	github.com/google/uuid v1.3.0
 | 
			
		||||
	github.com/hashicorp/consul/api v1.20.0
 | 
			
		||||
	github.com/minio/minio-go v6.0.14+incompatible
 | 
			
		||||
	github.com/minio/minio-go/v7 v7.0.52
 | 
			
		||||
	github.com/spf13/cobra v1.6.1
 | 
			
		||||
	github.com/spf13/viper v1.15.0
 | 
			
		||||
	go.uber.org/zap v1.23.0
 | 
			
		||||
	golang.org/x/image v0.0.0-20191009234506-e7c1f5e7dbb8
 | 
			
		||||
	golang.org/x/text v0.7.0
 | 
			
		||||
	gopkg.in/yaml.v3 v3.0.1
 | 
			
		||||
	xorm.io/xorm v1.3.2
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			@ -30,7 +32,6 @@ require (
 | 
			
		|||
	github.com/dustin/go-humanize v1.0.1 // indirect
 | 
			
		||||
	github.com/fatih/color v1.13.0 // indirect
 | 
			
		||||
	github.com/fsnotify/fsnotify v1.6.0 // indirect
 | 
			
		||||
	github.com/go-ini/ini v1.67.0 // indirect
 | 
			
		||||
	github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect
 | 
			
		||||
	github.com/goccy/go-json v0.8.1 // indirect
 | 
			
		||||
	github.com/gogo/protobuf v1.3.2 // indirect
 | 
			
		||||
| 
						 | 
				
			
			@ -40,7 +41,6 @@ require (
 | 
			
		|||
	github.com/golang/snappy v0.0.4 // indirect
 | 
			
		||||
	github.com/google/go-cmp v0.5.9 // indirect
 | 
			
		||||
	github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect
 | 
			
		||||
	github.com/google/uuid v1.3.0 // indirect
 | 
			
		||||
	github.com/googleapis/enterprise-certificate-proxy v0.2.1 // indirect
 | 
			
		||||
	github.com/googleapis/gax-go/v2 v2.7.0 // indirect
 | 
			
		||||
	github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
 | 
			
		||||
| 
						 | 
				
			
			@ -94,7 +94,6 @@ require (
 | 
			
		|||
	golang.org/x/oauth2 v0.0.0-20221014153046-6fdb5e3db783 // indirect
 | 
			
		||||
	golang.org/x/sync v0.1.0 // indirect
 | 
			
		||||
	golang.org/x/sys v0.5.0 // indirect
 | 
			
		||||
	golang.org/x/text v0.7.0 // indirect
 | 
			
		||||
	golang.org/x/time v0.1.0 // indirect
 | 
			
		||||
	golang.org/x/tools v0.2.0 // indirect
 | 
			
		||||
	golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -4,7 +4,7 @@ type TaskLog struct {
 | 
			
		|||
	TaskLogId int64  `xorm:"not null pk autoincr BIGINT(11)" json:"taskLogId"`
 | 
			
		||||
	TaskId    int64  `xorm:"INT(11) index" json:"taskId"`
 | 
			
		||||
	NodeId    int64  `xorm:"INT(11) index" json:"nodeId"`
 | 
			
		||||
	Content   string `xorm:"LANGTEXT" json:"content"`
 | 
			
		||||
	Content   string `xorm:"LONGTEXT" json:"content"`
 | 
			
		||||
	CreateAt  int64  `xorm:"created" json:"createAt"`
 | 
			
		||||
	UpdateAt  int64  `xorm:"updated" json:"updateAt"`
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,23 @@
 | 
			
		|||
package model
 | 
			
		||||
 | 
			
		||||
// DiseaseType 病害类别
 | 
			
		||||
type DiseaseType struct {
 | 
			
		||||
	TypeId     int64  `xorm:"not null pk autoincr INT(11)" json:"typeId"`
 | 
			
		||||
	TypeName   string `xorm:"varchar(200) not null" json:"typeName"`
 | 
			
		||||
	CategoryId int    `xorm:"not null SMALLINT default 1" json:"categoryId"` //病害分类, 1:道路 2:桥梁 3:隧道 4:边坡
 | 
			
		||||
	Status     int    `xorm:"not null INT(11) default 0" json:"status"`
 | 
			
		||||
	CreateAt   int64  `xorm:"created" json:"createAt"`
 | 
			
		||||
	UpdateAt   int64  `xorm:"updated" json:"updateAt"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetDiseaseType(name string, categoryId int) int64 {
 | 
			
		||||
	item := new(DiseaseType)
 | 
			
		||||
	h, err := DB.Where("type_name like ?", "%"+name+"%").
 | 
			
		||||
		And("category_id = ?", categoryId).
 | 
			
		||||
		And("status = 1").Get(item)
 | 
			
		||||
	if err != nil || !h {
 | 
			
		||||
		return 0
 | 
			
		||||
	}
 | 
			
		||||
	return item.TypeId
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -26,6 +26,11 @@ func New(driveName, dsn string, showSql bool) {
 | 
			
		|||
		&Task{},
 | 
			
		||||
		&TaskLog{},
 | 
			
		||||
		&TaskResult{},
 | 
			
		||||
		&TrainingDataset{},
 | 
			
		||||
		&TrainingDatasetDetail{},
 | 
			
		||||
		&TrainTask{},
 | 
			
		||||
		&TrainTaskLog{},
 | 
			
		||||
		&TrainTaskResult{},
 | 
			
		||||
	)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		fmt.Println("同步数据库表结构", err)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,23 @@
 | 
			
		|||
package model
 | 
			
		||||
 | 
			
		||||
type Project struct {
 | 
			
		||||
	ProjectId      int64   `xorm:"not null pk autoincr INT(11)" json:"projectId"`
 | 
			
		||||
	ProjectName    string  `xorm:"varchar(200) not null " json:"projectName"`
 | 
			
		||||
	OwnerId        int64   `xorm:"not null INT(11) default 0" json:"ownerId"`
 | 
			
		||||
	BizType        int     `xorm:"SMALLINT" json:"bizType"`
 | 
			
		||||
	LineName       string  `xorm:"varchar(200) not null " json:"lineName"`
 | 
			
		||||
	StartName      string  `xorm:"varchar(200) not null " json:"startName"`
 | 
			
		||||
	EndName        string  `xorm:"varchar(200) not null " json:"endName"`
 | 
			
		||||
	FixedDeviceNum int     `xorm:"not null INT(11) default 0" json:"fixedDeviceNum"`
 | 
			
		||||
	Direction      string  `xorm:"varchar(200) not null " json:"direction"`
 | 
			
		||||
	LaneNum        int     `xorm:"not null INT(4) default 0" json:"laneNum"`
 | 
			
		||||
	StartPointLng  float64 `xorm:"decimal(18,6)" json:"startPointLng"`
 | 
			
		||||
	StartPointLat  float64 `xorm:"decimal(18,6)" json:"startPointLat"`
 | 
			
		||||
	EndPointLng    float64 `xorm:"decimal(18,6)" json:"endPointLng"`
 | 
			
		||||
	EndPointLat    float64 `xorm:"decimal(18,6)" json:"endPointLat"`
 | 
			
		||||
	Status         int     `xorm:"SMALLINT default 1" json:"status"`
 | 
			
		||||
	Creator        int64   `xorm:"INT(11) default 0" json:"creator"`
 | 
			
		||||
	Modifier       int64   `xorm:"INT(11) default 0" json:"modifier"`
 | 
			
		||||
	CreateAt       int64   `xorm:"created" json:"createAt"`
 | 
			
		||||
	UpdateAt       int64   `xorm:"updated" json:"updateAt"`
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,21 @@
 | 
			
		|||
package model
 | 
			
		||||
 | 
			
		||||
type ProjectResult struct {
 | 
			
		||||
	Id             int64   `xorm:"not null pk autoincr BIGINT" json:"id"`
 | 
			
		||||
	ProjectId      int64   `xorm:"INT(11) index" json:"projectId"`      //项目编号
 | 
			
		||||
	SourceResultId int64   `xorm:"INT(11) index" json:"sourceResultId"` //识别结果来源编号
 | 
			
		||||
	MilepostNumber string  `xorm:"VARCHAR(50)" json:"milepostNumber"`   //里程桩号
 | 
			
		||||
	UpDown         string  `xorm:"VARCHAR(20)" json:"upDown"`           //上下行
 | 
			
		||||
	LineNum        int     `xorm:"SMALLINT default 1" json:"lineNum"`   //车道号
 | 
			
		||||
	DiseaseType    string  `xorm:"VARCHAR(50)" json:"diseaseType"`      //病害类型
 | 
			
		||||
	DiseaseLevel   string  `xorm:"VARCHAR(20)" json:"diseaseLevel"`     //病害等级
 | 
			
		||||
	Length         float64 `xorm:"decimal(18,6)" json:"length"`         //长度
 | 
			
		||||
	Width          float64 `xorm:"decimal(18,6)" json:"width"`          //宽度
 | 
			
		||||
	Acreage        float64 `xorm:"decimal(18,6)" json:"acreage"`        //面积
 | 
			
		||||
	Memo           string  `xorm:"VARCHAR(1000)" json:"memo"`           //备注说明
 | 
			
		||||
	Result         string  `xorm:"LONGTEXT" json:"result"`              //识别结果
 | 
			
		||||
	Creator        int64   `xorm:"INT(11) default 0" json:"creator"`
 | 
			
		||||
	Modifier       int64   `xorm:"INT(11) default 0" json:"modifier"`
 | 
			
		||||
	CreateAt       int64   `xorm:"created" json:"createAt"`
 | 
			
		||||
	UpdateAt       int64   `xorm:"updated" json:"updateAt"`
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -62,12 +62,11 @@ func UpdateTaskProgress(taskProgress *proto.TaskLogProgress) {
 | 
			
		|||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func UpdateTaskProgressByLog(res *TaskResult, isFailing bool) float64 {
 | 
			
		||||
	ret := -1.0
 | 
			
		||||
func UpdateTaskProgressByLog(res *TaskResult, isFailing bool) (int, int) {
 | 
			
		||||
	item := new(Task)
 | 
			
		||||
	h, err := DB.ID(res.TaskId).Get(item)
 | 
			
		||||
	if err != nil || !h {
 | 
			
		||||
		return ret
 | 
			
		||||
		return 0, 0
 | 
			
		||||
	}
 | 
			
		||||
	if isFailing {
 | 
			
		||||
		item.FailingCount += 1
 | 
			
		||||
| 
						 | 
				
			
			@ -79,12 +78,11 @@ func UpdateTaskProgressByLog(res *TaskResult, isFailing bool) float64 {
 | 
			
		|||
		item.FinishTime = time.Now().Unix()
 | 
			
		||||
		item.UnfinishedCount = 0
 | 
			
		||||
		item.Status = 3
 | 
			
		||||
		ret = 1.0
 | 
			
		||||
	}
 | 
			
		||||
	item.UpdateAt = time.Now().Unix()
 | 
			
		||||
	_, _ = DB.ID(res.TaskId).Cols("completed_count", "failing_count", "total_count", "unfinished_count", "update_at", "finish_time", "status").Update(item)
 | 
			
		||||
	if item.TotalCount > 0 {
 | 
			
		||||
		return 1 - float64(item.CompletedCount)/float64(item.TotalCount)
 | 
			
		||||
		return int(item.CompletedCount), int(item.UnfinishedCount)
 | 
			
		||||
	}
 | 
			
		||||
	return ret
 | 
			
		||||
	return int(item.CompletedCount), int(item.UnfinishedCount)
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,23 @@
 | 
			
		|||
package model
 | 
			
		||||
 | 
			
		||||
type TrainTask struct {
 | 
			
		||||
	TaskId               int64   `xorm:"not null pk autoincr BIGINT(11)" json:"taskId"`
 | 
			
		||||
	TrainDatasetId       int64   `xorm:"INT(11) index" json:"trainDatasetId"`
 | 
			
		||||
	CategoryId           int     `xorm:"not null SMALLINT default 1" json:"categoryId"` //业务分类, 1:道路 2:桥梁 3:隧道 4:边坡
 | 
			
		||||
	TaskName             string  `xorm:"VARCHAR(200)" json:"taskName"`
 | 
			
		||||
	TaskDesc             string  `xorm:"VARCHAR(500)" json:"taskDesc"`
 | 
			
		||||
	Arithmetic           string  `xorm:"VARCHAR(50)" json:"arithmetic"`
 | 
			
		||||
	ImageSize            int     `xorm:"INT" json:"imageSize"`
 | 
			
		||||
	BatchSize            int     `xorm:"INT" json:"batchSize"`
 | 
			
		||||
	EpochsSize           int     `xorm:"INT" json:"epochsSize"`
 | 
			
		||||
	OutputType           string  `xorm:"VARCHAR(20)" json:"outputType"`
 | 
			
		||||
	StartTime            int64   `xorm:"BIGINT" json:"startTime"`
 | 
			
		||||
	FinishTime           int64   `xorm:"BIGINT" json:"finishTime"`
 | 
			
		||||
	Loss                 float64 `xorm:"DECIMAL(18,6)" json:"loss"`
 | 
			
		||||
	Accuracy             float64 `xorm:"DECIMAL(18,6)" json:"accuracy"`
 | 
			
		||||
	ModelFilePath        string  `xorm:"VARCHAR(2000)" json:"modelFilePath"`
 | 
			
		||||
	ModelFileMetricsPath string  `xorm:"VARCHAR(2000)" json:"modelFileMetricsPath"`
 | 
			
		||||
	Status               int     `xorm:"not null SMALLINT default 0" json:"status"` // 1:等待执行; 2:执行中; 3:执行完成; 4:任务分配失败; 5:任务执行失败
 | 
			
		||||
	CreateAt             int64   `xorm:"created" json:"createAt"`
 | 
			
		||||
	UpdateAt             int64   `xorm:"updated" json:"updateAt"`
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,12 @@
 | 
			
		|||
package model
 | 
			
		||||
 | 
			
		||||
type TrainTaskLog struct {
 | 
			
		||||
	LogId       int64   `xorm:"not null pk autoincr BIGINT(11)" json:"logId"`
 | 
			
		||||
	TaskId      int64   `xorm:"INT(11) index" json:"taskId"`
 | 
			
		||||
	Epoch       int     `xorm:"SMALLINT" json:"epoch"`
 | 
			
		||||
	Loss        float64 `xorm:"DECIMAL(18,6)" json:"loss"`
 | 
			
		||||
	Accuracy    float64 `xorm:"DECIMAL(18,6)" json:"accuracy"`
 | 
			
		||||
	ValLoss     float64 `xorm:"DECIMAL(18,6)" json:"valLoss"`
 | 
			
		||||
	ValAccuracy float64 `xorm:"DECIMAL(18,6)" json:"valAccuracy"`
 | 
			
		||||
	CreateAt    int64   `xorm:"created" json:"createAt"`
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,11 @@
 | 
			
		|||
package model
 | 
			
		||||
 | 
			
		||||
type TrainTaskResult struct {
 | 
			
		||||
	ResultId int64   `xorm:"not null pk autoincr BIGINT(11)"  json:"resultId"`
 | 
			
		||||
	TaskId   int64   `xorm:"INT(11) index" json:"taskId"`
 | 
			
		||||
	Content  string  `xorm:"LONGTEXT" json:"content"`
 | 
			
		||||
	Result   string  `xorm:"VARCHAR(200)" json:"result"`
 | 
			
		||||
	Loss     float64 `xorm:"DECIMAL(18,6)" json:"loss"`
 | 
			
		||||
	Accuracy float64 `xorm:"DECIMAL(18,6)" json:"accuracy"`
 | 
			
		||||
	CreateAt int64   `xorm:"created" json:"createAt"`
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,13 @@
 | 
			
		|||
package model
 | 
			
		||||
 | 
			
		||||
type TrainingDataset struct {
 | 
			
		||||
	DatasetId        int64   `xorm:"not null pk autoincr INT(11)" json:"datasetId"`
 | 
			
		||||
	Name             string  `xorm:"VARCHAR(200)" json:"name"`
 | 
			
		||||
	CategoryId       int     `xorm:"not null SMALLINT default 1" json:"categoryId"` //业务分类, 1:道路 2:桥梁 3:隧道 4:边坡
 | 
			
		||||
	DatasetDesc      string  `xorm:"varchar(200)" json:"datasetDesc"`
 | 
			
		||||
	StoreName        string  `xorm:"varchar(200)" json:"storeName"`         //存储路径
 | 
			
		||||
	ValidationNumber float64 `xorm:"DECIMAL(18,4)" json:"validationNumber"` //验证占比
 | 
			
		||||
	TestNumber       float64 `xorm:"DECIMAL(18,4)" json:"testNumber"`       //测试占比
 | 
			
		||||
	CreateAt         int64   `xorm:"created" json:"createAt"`
 | 
			
		||||
	UpdateAt         int64   `xorm:"updated" json:"updateAt"`
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,16 @@
 | 
			
		|||
package model
 | 
			
		||||
 | 
			
		||||
type TrainingDatasetDetail struct {
 | 
			
		||||
	DetailId       int64  `xorm:"not null pk autoincr INT(11)" json:"detailId"`
 | 
			
		||||
	FileName       string `xorm:"VARCHAR(200)" json:"fileName"`
 | 
			
		||||
	FilePath       string `xorm:"VARCHAR(1000)" json:"filePath"`
 | 
			
		||||
	DatasetId      int64  `xorm:"INT(11) index default 0" json:"datasetId"`      //训练数据集
 | 
			
		||||
	CategoryId     int    `xorm:"not null SMALLINT default 1" json:"categoryId"` //训练集分类,1:训练集;2:测试集;3:验证集
 | 
			
		||||
	FileSize       int64  `xorm:"BIGINT" json:"fileSize"`                        //文件大小
 | 
			
		||||
	FileMd5        string `xorm:"VARCHAR(64)" json:"fileMd5"`                    //文件MD5
 | 
			
		||||
	IsDisease      int    `xorm:"TINYINT(1)" json:"isDisease"`                   //是否有病害, 1:有病害;2:无病害;
 | 
			
		||||
	OperationLogId int64  `xorm:"INT(11) index" json:"operationLogId"`           //操作日志编号
 | 
			
		||||
	Creator        int64  `xorm:"INT(11) index" json:"creator"`                  //上传人
 | 
			
		||||
	CreateAt       int64  `xorm:"created" json:"createAt"`                       //上传时间
 | 
			
		||||
	UpdateAt       int64  `xorm:"updated" json:"updateAt"`                       //更新时间
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										515
									
								
								mq/index.go
								
								
								
								
							
							
						
						
									
										515
									
								
								mq/index.go
								
								
								
								
							| 
						 | 
				
			
			@ -1,6 +1,7 @@
 | 
			
		|||
package mq
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
| 
						 | 
				
			
			@ -8,12 +9,18 @@ import (
 | 
			
		|||
	"git.hpds.cc/Component/network/frame"
 | 
			
		||||
	"github.com/google/uuid"
 | 
			
		||||
	"go.uber.org/zap"
 | 
			
		||||
	"golang.org/x/text/encoding/simplifiedchinese"
 | 
			
		||||
	"hpds_control_center/config"
 | 
			
		||||
	"hpds_control_center/internal/balance"
 | 
			
		||||
	"hpds_control_center/internal/minio"
 | 
			
		||||
	"hpds_control_center/internal/proto"
 | 
			
		||||
	"hpds_control_center/model"
 | 
			
		||||
	"hpds_control_center/pkg/utils"
 | 
			
		||||
	"io"
 | 
			
		||||
	"math"
 | 
			
		||||
	"os"
 | 
			
		||||
	"os/exec"
 | 
			
		||||
	"path"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
| 
						 | 
				
			
			@ -22,6 +29,13 @@ import (
 | 
			
		|||
	"git.hpds.cc/pavement/hpds_node"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Charset string
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	UTF8    = Charset("UTF-8")
 | 
			
		||||
	GB18030 = Charset("GB18030")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	MqList   []HpdsMqNode
 | 
			
		||||
	TaskList = make(map[int64]*TaskItem)
 | 
			
		||||
| 
						 | 
				
			
			@ -379,8 +393,8 @@ func TaskRequestHandler(data []byte) (frame.Tag, []byte) {
 | 
			
		|||
			item.UpdateAt = time.Now().Unix()
 | 
			
		||||
			_, _ = model.DB.ID(item.Id).AllCols().Update(item)
 | 
			
		||||
		} else {
 | 
			
		||||
			item.ModelId = payload["modelId"].(int64)
 | 
			
		||||
			item.NodeId = payload["nodeId"].(int64)
 | 
			
		||||
			item.ModelId = int64(payload["modelId"].(float64))
 | 
			
		||||
			item.NodeId = int64(payload["nodeId"].(float64))
 | 
			
		||||
			item.Status = 1
 | 
			
		||||
			item.IssueResult = string(pData)
 | 
			
		||||
			item.CreateAt = time.Now().Unix()
 | 
			
		||||
| 
						 | 
				
			
			@ -403,7 +417,16 @@ func TaskRequestHandler(data []byte) (frame.Tag, []byte) {
 | 
			
		|||
	//	_, _ = model.DB.Insert(item)
 | 
			
		||||
	//	//fn := payload["fileName"].(string)
 | 
			
		||||
	//	//dec := base64.NewDecoder(base64.StdEncoding, strings.NewReader(payload["file"].(string)))
 | 
			
		||||
	case TrainTaskAdd:
 | 
			
		||||
		payload := cmd.Payload.(map[string]interface{})
 | 
			
		||||
		if itemId, ok := payload["taskId"].(float64); ok {
 | 
			
		||||
			item := new(model.TrainTask)
 | 
			
		||||
			h, err := model.DB.ID(int64(itemId)).Get(item)
 | 
			
		||||
			if err != nil || !h {
 | 
			
		||||
 | 
			
		||||
			}
 | 
			
		||||
			RunTraining(item)
 | 
			
		||||
		}
 | 
			
		||||
	default:
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			@ -450,14 +473,16 @@ func TaskResponseHandler(data []byte) (frame.Tag, []byte) {
 | 
			
		|||
		if err != nil {
 | 
			
		||||
			fmt.Println("接收TaskResponse数据出错", err)
 | 
			
		||||
		}
 | 
			
		||||
		//处理到项目结果表
 | 
			
		||||
		go processToProjectResult(item)
 | 
			
		||||
		//更新运行进度
 | 
			
		||||
		rat := model.UpdateTaskProgressByLog(item, isFailing)
 | 
			
		||||
		processed, unProcessed := model.UpdateTaskProgressByLog(item, isFailing)
 | 
			
		||||
		var (
 | 
			
		||||
			ratStr string
 | 
			
		||||
		)
 | 
			
		||||
		if rat > 0 && rat < 1 {
 | 
			
		||||
			ratStr = fmt.Sprintf("[已处理%2.f,剩余%2.f未处理]", 1-rat, rat)
 | 
			
		||||
		} else if rat == 1 {
 | 
			
		||||
		if unProcessed > 0 {
 | 
			
		||||
			ratStr = fmt.Sprintf("[已处理[%d],剩余[%d]未处理]", processed, unProcessed)
 | 
			
		||||
		} else {
 | 
			
		||||
			ratStr = "[已全部处理]"
 | 
			
		||||
		}
 | 
			
		||||
		taskLog := new(model.TaskLog)
 | 
			
		||||
| 
						 | 
				
			
			@ -479,6 +504,267 @@ func TaskResponseHandler(data []byte) (frame.Tag, []byte) {
 | 
			
		|||
	return frame.Tag(cmd.Command), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ModelResult struct {
 | 
			
		||||
	Code int `json:"code"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type InsigmaResult struct {
 | 
			
		||||
	Code          int            `json:"code"`
 | 
			
		||||
	NumOfDiseases int            `json:"num_of_diseases"`
 | 
			
		||||
	Diseases      []DiseasesInfo `json:"diseases"`
 | 
			
		||||
	Image         string         `json:"image"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type DiseasesInfo struct {
 | 
			
		||||
	Id    int           `json:"id"`
 | 
			
		||||
	Type  string        `json:"type"`
 | 
			
		||||
	Level string        `json:"level"`
 | 
			
		||||
	Param DiseasesParam `json:"param"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type DiseasesParam struct {
 | 
			
		||||
	Length   float64 `json:"length"`
 | 
			
		||||
	Area     float64 `json:"area"`
 | 
			
		||||
	MaxWidth string  `json:"max_width"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type LightweightResult struct {
 | 
			
		||||
	Code       int    `json:"code"`
 | 
			
		||||
	Crack      bool   `json:"crack"`
 | 
			
		||||
	ImgDiscern string `json:"img_discern"`
 | 
			
		||||
	ImgSrc     string `json:"img_src"`
 | 
			
		||||
	Pothole    bool   `json:"pothole"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func processToProjectResult(src *model.TaskResult) {
 | 
			
		||||
	project := new(model.Project)
 | 
			
		||||
	h, err := model.DB.Table("project").Alias("p").Join("inner", []string{"dataset", "d"}, "d.project_id= p.project_id").Where("d.dataset_id=?", src.DatasetId).Get(project)
 | 
			
		||||
	if !h {
 | 
			
		||||
		err = fmt.Errorf("未能找到对应的项目信息")
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logging.L().With(zap.String("控制节点", "错误信息")).Error("获取项目信息", zap.Error(err))
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	var (
 | 
			
		||||
		mr             ModelResult
 | 
			
		||||
		mrList         []string
 | 
			
		||||
		fileDiscern    string
 | 
			
		||||
		memo           string
 | 
			
		||||
		milepostNumber string
 | 
			
		||||
		upDown         string
 | 
			
		||||
		lineNum        int
 | 
			
		||||
		width          float64
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	switch project.BizType {
 | 
			
		||||
	case 1: //道路
 | 
			
		||||
		arr := strings.Split(src.SrcPath, " ")
 | 
			
		||||
		if len(arr) > 1 {
 | 
			
		||||
			milepostNumber = GetMilepost(project.StartName, arr[1], arr[2])
 | 
			
		||||
			if arr[2] == "D" {
 | 
			
		||||
				upDown = "下行"
 | 
			
		||||
			} else {
 | 
			
		||||
				upDown = "上行"
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		if len(arr) > 3 {
 | 
			
		||||
			lineNum, _ = strconv.Atoi(arr[3])
 | 
			
		||||
		}
 | 
			
		||||
	case 2: //桥梁
 | 
			
		||||
	case 3: //隧道
 | 
			
		||||
		//隧道名-采集方向(D/X)-相机编号(01-22)-采集序号(五位)K里程桩号.bmp DAXIASHAN-D-05-00003K15069.5.bmp
 | 
			
		||||
		arr := strings.Split(src.SrcPath, "K")
 | 
			
		||||
		if len(arr) > 1 {
 | 
			
		||||
			arrM := strings.Split(arr[1], ".")
 | 
			
		||||
			milepostNumber = meter2Milepost(arrM[0])
 | 
			
		||||
			arrD := strings.Split(arr[0], ".")
 | 
			
		||||
			if len(arrD) > 1 {
 | 
			
		||||
				if arrD[1] == "D" {
 | 
			
		||||
					upDown = "下行"
 | 
			
		||||
				} else {
 | 
			
		||||
					upDown = "上行"
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			if len(arrD) > 4 {
 | 
			
		||||
				lineNum, _ = strconv.Atoi(arrD[3])
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if len(src.Result) > 0 && src.Result[0] == '[' {
 | 
			
		||||
		mrList = make([]string, 0)
 | 
			
		||||
		if err := json.Unmarshal([]byte(src.Result), &mrList); err != nil {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		list := make([]*model.ProjectResult, 0)
 | 
			
		||||
		for _, str := range mrList {
 | 
			
		||||
			if err := json.Unmarshal([]byte(str), &mr); err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			if mr.Code == 2001 {
 | 
			
		||||
				ir := new(InsigmaResult)
 | 
			
		||||
				if err := json.Unmarshal([]byte(str), &ir); err != nil {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				fileDiscern = ir.Image
 | 
			
		||||
				for key, value := range ir.Diseases {
 | 
			
		||||
					if len(value.Param.MaxWidth) > 0 && width == 0 {
 | 
			
		||||
						width, _ = strconv.ParseFloat(value.Param.MaxWidth, 64)
 | 
			
		||||
					} else {
 | 
			
		||||
						width = 0
 | 
			
		||||
					}
 | 
			
		||||
					memo = fmt.Sprintf("%d. 发现[%s],等级[%s],长度[%f],最大宽度[%s],面积[%f];\n", key+1, value.Type, value.Level, value.Param.Length, value.Param.MaxWidth, value.Param.Area)
 | 
			
		||||
					item := &model.ProjectResult{
 | 
			
		||||
						ProjectId:      project.ProjectId,
 | 
			
		||||
						SourceResultId: src.ResultId,
 | 
			
		||||
						MilepostNumber: milepostNumber,
 | 
			
		||||
						UpDown:         upDown,
 | 
			
		||||
						LineNum:        lineNum,
 | 
			
		||||
						DiseaseType:    value.Type,
 | 
			
		||||
						DiseaseLevel:   value.Level,
 | 
			
		||||
						Length:         value.Param.Length,
 | 
			
		||||
						Width:          width,
 | 
			
		||||
						Acreage:        value.Param.Area,
 | 
			
		||||
						Memo:           memo,
 | 
			
		||||
						Result:         fileDiscern,
 | 
			
		||||
						Creator:        0,
 | 
			
		||||
						Modifier:       0,
 | 
			
		||||
						CreateAt:       time.Now().Unix(),
 | 
			
		||||
						UpdateAt:       time.Now().Unix(),
 | 
			
		||||
					}
 | 
			
		||||
					list = append(list, item)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		_, _ = model.DB.Insert(list)
 | 
			
		||||
	} else {
 | 
			
		||||
		if err := json.Unmarshal([]byte(src.Result), &mr); err != nil {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		switch mr.Code {
 | 
			
		||||
		case 0: //轻量化模型返回
 | 
			
		||||
			lr := new(LightweightResult)
 | 
			
		||||
			if err := json.Unmarshal([]byte(src.Result), &lr); err != nil {
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			if lr.Crack || lr.Pothole {
 | 
			
		||||
				if lr.Crack {
 | 
			
		||||
					memo = "检测到裂缝"
 | 
			
		||||
				} else {
 | 
			
		||||
					memo = "检测到坑洼"
 | 
			
		||||
				}
 | 
			
		||||
				fileDiscern = lr.ImgDiscern
 | 
			
		||||
				if len(fileDiscern) == 0 {
 | 
			
		||||
					fileDiscern = lr.ImgSrc
 | 
			
		||||
				}
 | 
			
		||||
				diseaseLevelName := "重度"
 | 
			
		||||
				diseaseTypeName := ""
 | 
			
		||||
				switch project.BizType {
 | 
			
		||||
				case 2:
 | 
			
		||||
					diseaseTypeName = "结构裂缝"
 | 
			
		||||
				case 3:
 | 
			
		||||
					diseaseTypeName = "衬砌裂缝"
 | 
			
		||||
				default:
 | 
			
		||||
					diseaseTypeName = "横向裂缝"
 | 
			
		||||
				}
 | 
			
		||||
				item := &model.ProjectResult{
 | 
			
		||||
					ProjectId:      project.ProjectId,
 | 
			
		||||
					SourceResultId: src.ResultId,
 | 
			
		||||
					MilepostNumber: milepostNumber,
 | 
			
		||||
					UpDown:         upDown,
 | 
			
		||||
					LineNum:        lineNum,
 | 
			
		||||
					DiseaseType:    diseaseTypeName,
 | 
			
		||||
					DiseaseLevel:   diseaseLevelName,
 | 
			
		||||
					Length:         0,
 | 
			
		||||
					Width:          0,
 | 
			
		||||
					Acreage:        0,
 | 
			
		||||
					Memo:           memo,
 | 
			
		||||
					Result:         fileDiscern,
 | 
			
		||||
					Creator:        0,
 | 
			
		||||
					Modifier:       0,
 | 
			
		||||
					CreateAt:       time.Now().Unix(),
 | 
			
		||||
					UpdateAt:       time.Now().Unix(),
 | 
			
		||||
				}
 | 
			
		||||
				_, _ = model.DB.Insert(item)
 | 
			
		||||
			} else {
 | 
			
		||||
				fileDiscern = lr.ImgSrc
 | 
			
		||||
			}
 | 
			
		||||
			//
 | 
			
		||||
		case 2001: //网新返回有病害
 | 
			
		||||
			ir := new(InsigmaResult)
 | 
			
		||||
			if err := json.Unmarshal([]byte(src.Result), &ir); err != nil {
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			fileDiscern = ir.Image
 | 
			
		||||
			list := make([]*model.ProjectResult, 0)
 | 
			
		||||
			for _, val := range ir.Diseases {
 | 
			
		||||
				if len(val.Param.MaxWidth) > 0 && width == 0 {
 | 
			
		||||
					width, _ = strconv.ParseFloat(val.Param.MaxWidth, 64)
 | 
			
		||||
				} else {
 | 
			
		||||
					width = 0
 | 
			
		||||
				}
 | 
			
		||||
				memo = fmt.Sprintf("发现[%s],等级[%s],长度[%f],最大宽度[%s],面积[%f];\n", val.Type, val.Level, val.Param.Length, val.Param.MaxWidth, val.Param.Area)
 | 
			
		||||
				maxWidth, _ := strconv.ParseFloat(val.Param.MaxWidth, 64)
 | 
			
		||||
				item := &model.ProjectResult{
 | 
			
		||||
					ProjectId:      project.ProjectId,
 | 
			
		||||
					SourceResultId: src.ResultId,
 | 
			
		||||
					MilepostNumber: milepostNumber,
 | 
			
		||||
					UpDown:         upDown,
 | 
			
		||||
					LineNum:        lineNum,
 | 
			
		||||
					DiseaseType:    val.Type,
 | 
			
		||||
					DiseaseLevel:   val.Level,
 | 
			
		||||
					Length:         val.Param.Length,
 | 
			
		||||
					Width:          maxWidth,
 | 
			
		||||
					Acreage:        val.Param.Area,
 | 
			
		||||
					Memo:           memo,
 | 
			
		||||
					Result:         fileDiscern,
 | 
			
		||||
					Creator:        0,
 | 
			
		||||
					Modifier:       0,
 | 
			
		||||
					CreateAt:       time.Now().Unix(),
 | 
			
		||||
					UpdateAt:       time.Now().Unix(),
 | 
			
		||||
				}
 | 
			
		||||
				list = append(list, item)
 | 
			
		||||
			}
 | 
			
		||||
			_, _ = model.DB.Insert(list)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 里程桩加减里程,返回里程桩
 | 
			
		||||
func GetMilepost(start, num, upDown string) string {
 | 
			
		||||
	arr := strings.Split(start, "+")
 | 
			
		||||
	var (
 | 
			
		||||
		kilometre, meter, milepost, counter, res, resMilepost, resMeter float64
 | 
			
		||||
	)
 | 
			
		||||
	if len(arr) == 1 {
 | 
			
		||||
		meter = 0
 | 
			
		||||
	} else {
 | 
			
		||||
		meter, _ = strconv.ParseFloat(arr[1], 64)
 | 
			
		||||
	}
 | 
			
		||||
	str := strings.Replace(arr[0], "k", "", -1)
 | 
			
		||||
	str = strings.Replace(str, "K", "", -1)
 | 
			
		||||
	kilometre, _ = strconv.ParseFloat(str, 64)
 | 
			
		||||
	milepost = kilometre + meter/1000
 | 
			
		||||
	counter, _ = strconv.ParseFloat(num, 64)
 | 
			
		||||
	if upDown == "D" {
 | 
			
		||||
		res = milepost - counter
 | 
			
		||||
 | 
			
		||||
	} else {
 | 
			
		||||
		res = milepost + counter
 | 
			
		||||
	}
 | 
			
		||||
	resMilepost = math.Floor(res)
 | 
			
		||||
	resMeter = (res - resMilepost) * 100
 | 
			
		||||
	return fmt.Sprintf("K%d+%.2f", int(resMilepost), resMeter)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 米装换成里程桩号
 | 
			
		||||
func meter2Milepost(meter string) string {
 | 
			
		||||
	meter = strings.Replace(meter, "K", "", -1)
 | 
			
		||||
	m, _ := strconv.ParseFloat(meter, 64)
 | 
			
		||||
	resMilepost := math.Floor(m / 1000)
 | 
			
		||||
	resMeter := (m - resMilepost*1000) * 100
 | 
			
		||||
	return fmt.Sprintf("K%d+%.2f", int(resMilepost), resMeter)
 | 
			
		||||
}
 | 
			
		||||
func deliver(topic string, mqType uint, payload interface{}) {
 | 
			
		||||
	cli := GetMqClient(topic, mqType)
 | 
			
		||||
	pData, _ := json.Marshal(payload)
 | 
			
		||||
| 
						 | 
				
			
			@ -544,3 +830,220 @@ func TaskExecuteLogHandler(data []byte) (frame.Tag, []byte) {
 | 
			
		|||
	l.Unlock()
 | 
			
		||||
	return frame.Tag(cmd.Command), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func RunTraining(task *model.TrainTask) {
 | 
			
		||||
	var (
 | 
			
		||||
		args                               []string
 | 
			
		||||
		modelPath, modelFileName, testSize string
 | 
			
		||||
		modelAcc, modelLoss                float64
 | 
			
		||||
	)
 | 
			
		||||
	fmt.Println("curr tmp dir====>>>>", config.Cfg.TmpTrainDir)
 | 
			
		||||
	modelFileName = utils.GetUUIDString()
 | 
			
		||||
	//复制训练数据集
 | 
			
		||||
	tmpTrainDir := path.Join(config.Cfg.TmpTrainDir, fmt.Sprintf("%s_%s_%d_%d", modelFileName, task.Arithmetic, task.BatchSize, task.EpochsSize))
 | 
			
		||||
	fileList := make([]model.TrainingDatasetDetail, 0)
 | 
			
		||||
	_ = model.DB.Where("dataset_id = ?", task.TrainDatasetId).Find(&fileList)
 | 
			
		||||
	_ = os.MkdirAll(tmpTrainDir, os.ModePerm)
 | 
			
		||||
	_ = os.MkdirAll(path.Join(tmpTrainDir, "train"), os.ModePerm)
 | 
			
		||||
	_ = os.MkdirAll(path.Join(tmpTrainDir, "train", "0"), os.ModePerm)
 | 
			
		||||
	_ = os.MkdirAll(path.Join(tmpTrainDir, "train", "1"), os.ModePerm)
 | 
			
		||||
	_ = os.MkdirAll(path.Join(tmpTrainDir, "val"), os.ModePerm)
 | 
			
		||||
	_ = os.MkdirAll(path.Join(tmpTrainDir, "val", "0"), os.ModePerm)
 | 
			
		||||
	_ = os.MkdirAll(path.Join(tmpTrainDir, "val", "1"), os.ModePerm)
 | 
			
		||||
	_ = os.MkdirAll(path.Join(tmpTrainDir, "test"), os.ModePerm)
 | 
			
		||||
	for _, v := range fileList {
 | 
			
		||||
		dstFilePath := ""
 | 
			
		||||
		switch v.CategoryId {
 | 
			
		||||
		case 2:
 | 
			
		||||
			dstFilePath = "test"
 | 
			
		||||
		default:
 | 
			
		||||
			dstFilePath = "train"
 | 
			
		||||
		}
 | 
			
		||||
		if v.CategoryId != 2 {
 | 
			
		||||
			if v.IsDisease == 1 {
 | 
			
		||||
				dstFilePath = path.Join(tmpTrainDir, dstFilePath, "0")
 | 
			
		||||
			} else {
 | 
			
		||||
				dstFilePath = path.Join(tmpTrainDir, dstFilePath, "1")
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			dstFilePath = path.Join(tmpTrainDir, dstFilePath)
 | 
			
		||||
		}
 | 
			
		||||
		err := utils.CopyFile(v.FilePath, path.Join(dstFilePath, v.FileName))
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			fmt.Println("copy error: ", err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	modelPath = path.Join(config.Cfg.ModelOutPath, fmt.Sprintf("%s_%s_%d_%d_%s", modelFileName, task.Arithmetic, task.BatchSize, task.EpochsSize, task.OutputType))
 | 
			
		||||
	_ = os.MkdirAll(modelPath, os.ModePerm)
 | 
			
		||||
 | 
			
		||||
	dt := new(model.TrainingDataset)
 | 
			
		||||
	_, err := model.DB.ID(task.TrainDatasetId).Get(dt)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		goto ReturnPoint
 | 
			
		||||
	}
 | 
			
		||||
	testSize = fmt.Sprintf("%.2f", dt.ValidationNumber/100)
 | 
			
		||||
	//执行训练命令
 | 
			
		||||
	args = []string{"--dataset=" + path.Join(tmpTrainDir, "train"),
 | 
			
		||||
		"--img_size=" + strconv.Itoa(task.ImageSize), "--batch_size=" + strconv.Itoa(task.BatchSize), "--test_size=" + testSize,
 | 
			
		||||
		"--epochs=" + strconv.Itoa(task.EpochsSize), "--model=" + task.Arithmetic, "--model_save=" + path.Join(modelPath, modelFileName+".h5"),
 | 
			
		||||
	}
 | 
			
		||||
	fmt.Println("args====>>>", args)
 | 
			
		||||
	err = ExecCommand(config.Cfg.TrainScriptPath, args, path.Join(modelPath, modelFileName+".log"), task.TaskId)
 | 
			
		||||
ReturnPoint:
 | 
			
		||||
	//返回训练结果
 | 
			
		||||
	modelMetricsFile := modelFileName + "_model_metrics.png"
 | 
			
		||||
	task.FinishTime = time.Now().Unix()
 | 
			
		||||
	task.ModelFilePath = path.Join(modelPath, modelFileName+".h5")
 | 
			
		||||
	task.Loss = GetIndicatorByLog(path.Join(modelPath, modelFileName+".log"), "[INFO] Model Validation Loss:")
 | 
			
		||||
	task.Accuracy = GetIndicatorByLog(path.Join(modelPath, modelFileName+".log"), "[INFO] Model Validation Accuracy:")
 | 
			
		||||
	task.Status = 3
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		task.Status = 5
 | 
			
		||||
	}
 | 
			
		||||
	task.ModelFileMetricsPath = path.Join(modelPath, modelMetricsFile)
 | 
			
		||||
	_, _ = model.DB.ID(task.TaskId).AllCols().Update(task)
 | 
			
		||||
	if utils.PathExists(path.Join(modelPath, modelFileName+".log")) {
 | 
			
		||||
		logContext := utils.ReadFile(path.Join(modelPath, modelFileName+".log"))
 | 
			
		||||
		taskRes := new(model.TrainTaskResult)
 | 
			
		||||
		taskRes.TaskId = task.TaskId
 | 
			
		||||
		taskRes.CreateAt = time.Now().Unix()
 | 
			
		||||
		taskRes.Content = string(logContext)
 | 
			
		||||
		taskRes.Result = path.Join(modelPath, modelMetricsFile)
 | 
			
		||||
		taskRes.Accuracy = modelAcc
 | 
			
		||||
		taskRes.Loss = modelLoss
 | 
			
		||||
		c, err := model.DB.Insert(taskRes)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			fmt.Println("model.DB.Insert(taskRes) error ========>>>>>>", err)
 | 
			
		||||
		}
 | 
			
		||||
		fmt.Println("model.DB.Insert(taskRes) count ========>>>>>>", c)
 | 
			
		||||
	} else {
 | 
			
		||||
		fmt.Println("logContext========>>>>>>未读取")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetIndicatorByLog(logFileName, indicator string) float64 {
 | 
			
		||||
	logFn, _ := os.Open(logFileName)
 | 
			
		||||
	defer func() {
 | 
			
		||||
		_ = logFn.Close()
 | 
			
		||||
	}()
 | 
			
		||||
	buf := bufio.NewReader(logFn)
 | 
			
		||||
	for {
 | 
			
		||||
		line, err := buf.ReadString('\n')
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			if err == io.EOF {
 | 
			
		||||
				//fmt.Println("File read ok!")
 | 
			
		||||
				break
 | 
			
		||||
			} else {
 | 
			
		||||
				fmt.Println("Read file error!", err)
 | 
			
		||||
				return 0
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		if strings.Index(line, indicator) >= 0 {
 | 
			
		||||
			str := strings.Replace(line, indicator, "", -1)
 | 
			
		||||
			str = strings.Replace(str, "\n", "", -1)
 | 
			
		||||
			value, _ := strconv.ParseFloat(strings.Trim(str, " "), 64)
 | 
			
		||||
			return value
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ExecCommand(cmd string, args []string, logFileName string, taskId int64) (err error) {
 | 
			
		||||
	logFile, _ := os.Create(logFileName)
 | 
			
		||||
	defer func() {
 | 
			
		||||
		_ = logFile.Close()
 | 
			
		||||
	}()
 | 
			
		||||
	fmt.Print("开始训练......")
 | 
			
		||||
 | 
			
		||||
	c := exec.Command(cmd, args...) // mac or linux
 | 
			
		||||
	stdout, err := c.StdoutPipe()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	var (
 | 
			
		||||
		wg sync.WaitGroup
 | 
			
		||||
	)
 | 
			
		||||
	wg.Add(1)
 | 
			
		||||
	go func() {
 | 
			
		||||
		defer wg.Done()
 | 
			
		||||
		reader := bufio.NewReader(stdout)
 | 
			
		||||
		var (
 | 
			
		||||
			epoch int
 | 
			
		||||
			//modelLoss, modelAcc float64
 | 
			
		||||
		)
 | 
			
		||||
		for {
 | 
			
		||||
			readString, err := reader.ReadString('\n')
 | 
			
		||||
			if err != nil || err == io.EOF {
 | 
			
		||||
				fmt.Println("训练2===>>>", err)
 | 
			
		||||
				//wg.Done()
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			byte2String := ConvertByte2String([]byte(readString), "GB18030")
 | 
			
		||||
			_, _ = fmt.Fprint(logFile, byte2String)
 | 
			
		||||
			if strings.Index(byte2String, "Epoch") >= 0 {
 | 
			
		||||
				str := strings.Replace(byte2String, "Epoch ", "", -1)
 | 
			
		||||
				arr := strings.Split(str, "/")
 | 
			
		||||
				epoch, _ = strconv.Atoi(arr[0])
 | 
			
		||||
			}
 | 
			
		||||
			if strings.Index(byte2String, "- loss:") > 0 &&
 | 
			
		||||
				strings.Index(byte2String, "- accuracy:") > 0 &&
 | 
			
		||||
				strings.Index(byte2String, "- val_loss:") > 0 &&
 | 
			
		||||
				strings.Index(byte2String, "- val_accuracy:") > 0 {
 | 
			
		||||
				var (
 | 
			
		||||
					loss, acc, valLoss, valAcc float64
 | 
			
		||||
				)
 | 
			
		||||
				arr := strings.Split(byte2String, "-")
 | 
			
		||||
				for _, v := range arr {
 | 
			
		||||
					if strings.Index(v, "loss:") > 0 && strings.Index(v, "val_loss:") < 0 {
 | 
			
		||||
						strLoss := strings.Replace(v, " loss: ", "", -1)
 | 
			
		||||
						loss, _ = strconv.ParseFloat(strings.Trim(strLoss, " "), 64)
 | 
			
		||||
					}
 | 
			
		||||
					if strings.Index(v, "accuracy:") > 0 && strings.Index(v, "val_accuracy:") < 0 {
 | 
			
		||||
						strAcc := strings.Replace(v, " accuracy: ", "", -1)
 | 
			
		||||
						acc, _ = strconv.ParseFloat(strings.Trim(strAcc, " "), 64)
 | 
			
		||||
					}
 | 
			
		||||
					if strings.Index(v, "val_loss:") > 0 {
 | 
			
		||||
						strValLoss := strings.Replace(v, "val_loss: ", "", -1)
 | 
			
		||||
						valLoss, _ = strconv.ParseFloat(strings.Trim(strValLoss, " "), 64)
 | 
			
		||||
					}
 | 
			
		||||
					if strings.Index(v, "val_accuracy:") > 0 {
 | 
			
		||||
						strValAcc := strings.Replace(v, "val_accuracy: ", "", -1)
 | 
			
		||||
						strValAcc = strings.Replace(strValAcc, "\n", "", -1)
 | 
			
		||||
						valAcc, _ = strconv.ParseFloat(strings.Trim(strValAcc, " "), 64)
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
				taskLog := new(model.TrainTaskLog)
 | 
			
		||||
				taskLog.Epoch = epoch
 | 
			
		||||
				taskLog.TaskId = taskId
 | 
			
		||||
				taskLog.CreateAt = time.Now().Unix()
 | 
			
		||||
				taskLog.Loss = loss
 | 
			
		||||
				taskLog.Accuracy = acc
 | 
			
		||||
				taskLog.ValLoss = valLoss
 | 
			
		||||
				taskLog.ValAccuracy = valAcc
 | 
			
		||||
				_, _ = model.DB.Insert(taskLog)
 | 
			
		||||
			}
 | 
			
		||||
			fmt.Print(byte2String)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
	err = c.Start()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		fmt.Println("训练3===>>>", err)
 | 
			
		||||
	}
 | 
			
		||||
	wg.Wait()
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
func ConvertByte2String(byte []byte, charset Charset) string {
 | 
			
		||||
	var str string
 | 
			
		||||
	switch charset {
 | 
			
		||||
	case GB18030:
 | 
			
		||||
		var decodeBytes, _ = simplifiedchinese.GB18030.NewDecoder().Bytes(byte)
 | 
			
		||||
		str = string(decodeBytes)
 | 
			
		||||
	case UTF8:
 | 
			
		||||
		fallthrough
 | 
			
		||||
	default:
 | 
			
		||||
		str = string(byte)
 | 
			
		||||
	}
 | 
			
		||||
	return str
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -9,6 +9,7 @@ const (
 | 
			
		|||
	ModelIssueResponse
 | 
			
		||||
	TaskExecuteLog
 | 
			
		||||
	TaskLog
 | 
			
		||||
	TrainTaskAdd
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type InstructionReq struct {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,94 @@
 | 
			
		|||
package utils
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/md5"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"encoding/hex"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"git.hpds.cc/Component/logging"
 | 
			
		||||
	"go.uber.org/zap"
 | 
			
		||||
	"io"
 | 
			
		||||
	"os"
 | 
			
		||||
	"path"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func CopyFile(src, dst string) error {
 | 
			
		||||
	sourceFileStat, err := os.Stat(src)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !sourceFileStat.Mode().IsRegular() {
 | 
			
		||||
		return fmt.Errorf("%s is not a regular file", src)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	source, err := os.Open(src)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	defer func(source *os.File) {
 | 
			
		||||
		_ = source.Close()
 | 
			
		||||
	}(source)
 | 
			
		||||
 | 
			
		||||
	destination, err := os.Create(dst)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	defer func(destination *os.File) {
 | 
			
		||||
		_ = destination.Close()
 | 
			
		||||
	}(destination)
 | 
			
		||||
	_, err = io.Copy(destination, source)
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func PathExists(path string) bool {
 | 
			
		||||
	_, err := os.Stat(path)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	if os.IsNotExist(err) {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ReadFile 读取到file中,再利用ioutil将file直接读取到[]byte中, 这是最优
 | 
			
		||||
func ReadFile(fn string) []byte {
 | 
			
		||||
	f, err := os.Open(fn)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logging.L().Error("Read File", zap.String("File Name", fn), zap.Error(err))
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	defer func(f *os.File) {
 | 
			
		||||
		_ = f.Close()
 | 
			
		||||
	}(f)
 | 
			
		||||
 | 
			
		||||
	fd, err := io.ReadAll(f)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logging.L().Error("Read File To buff", zap.String("File Name", fn), zap.Error(err))
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return fd
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetFileName(fn string) string {
 | 
			
		||||
	fileType := path.Ext(fn)
 | 
			
		||||
	return strings.TrimSuffix(fn, fileType)
 | 
			
		||||
}
 | 
			
		||||
func GetFileNameAndExt(fn string) string {
 | 
			
		||||
	_, fileName := filepath.Split(fn)
 | 
			
		||||
	return fileName
 | 
			
		||||
}
 | 
			
		||||
func GetFileMd5(data []byte) string {
 | 
			
		||||
	hash := md5.New()
 | 
			
		||||
	hash.Write(data)
 | 
			
		||||
	return hex.EncodeToString(hash.Sum(nil))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func FileToBase64(fn string) string {
 | 
			
		||||
	buff := ReadFile(fn)
 | 
			
		||||
	return base64.StdEncoding.EncodeToString(buff) // 加密成base64字符串
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,126 @@
 | 
			
		|||
package utils
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"mime/multipart"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func HttpDo(reqUrl, method string, params map[string]string, header map[string]string) (data []byte, err error) {
 | 
			
		||||
	var paramStr string = ""
 | 
			
		||||
	if contentType, ok := header["Content-Type"]; ok && strings.Contains(contentType, "json") {
 | 
			
		||||
		bytesData, _ := json.Marshal(params)
 | 
			
		||||
		paramStr = string(bytesData)
 | 
			
		||||
	} else {
 | 
			
		||||
		for k, v := range params {
 | 
			
		||||
			if len(paramStr) == 0 {
 | 
			
		||||
				paramStr = fmt.Sprintf("%s=%s", k, url.QueryEscape(v))
 | 
			
		||||
			} else {
 | 
			
		||||
				paramStr = fmt.Sprintf("%s&%s=%s", paramStr, k, url.QueryEscape(v))
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	client := &http.Client{}
 | 
			
		||||
	req, err := http.NewRequest(strings.ToUpper(method), reqUrl, strings.NewReader(paramStr))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	for k, v := range header {
 | 
			
		||||
		req.Header.Set(k, v)
 | 
			
		||||
	}
 | 
			
		||||
	resp, err := client.Do(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if resp.Body != nil {
 | 
			
		||||
			err = resp.Body.Close()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
	var body []byte
 | 
			
		||||
	if resp.Body != nil {
 | 
			
		||||
		body, err = io.ReadAll(resp.Body)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return body, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type UploadFile struct {
 | 
			
		||||
	// 表单名称
 | 
			
		||||
	Name     string
 | 
			
		||||
	Filepath string
 | 
			
		||||
	// 文件全路径
 | 
			
		||||
	File *bytes.Buffer
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func PostFile(reqUrl string, reqParams map[string]string, contentType string, files []UploadFile, headers map[string]string) string {
 | 
			
		||||
	requestBody, realContentType := getReader(reqParams, contentType, files)
 | 
			
		||||
	httpRequest, _ := http.NewRequest("POST", reqUrl, requestBody)
 | 
			
		||||
	// 添加请求头
 | 
			
		||||
	httpRequest.Header.Add("Content-Type", realContentType)
 | 
			
		||||
	if headers != nil {
 | 
			
		||||
		for k, v := range headers {
 | 
			
		||||
			httpRequest.Header.Add(k, v)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	httpClient := &http.Client{}
 | 
			
		||||
	// 发送请求
 | 
			
		||||
	resp, err := httpClient.Do(httpRequest)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		panic(err)
 | 
			
		||||
	}
 | 
			
		||||
	defer func(Body io.ReadCloser) {
 | 
			
		||||
		_ = Body.Close()
 | 
			
		||||
	}(resp.Body)
 | 
			
		||||
	response, _ := io.ReadAll(resp.Body)
 | 
			
		||||
	return string(response)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getReader(reqParams map[string]string, contentType string, files []UploadFile) (io.Reader, string) {
 | 
			
		||||
	if strings.Index(contentType, "json") > -1 {
 | 
			
		||||
		bytesData, _ := json.Marshal(reqParams)
 | 
			
		||||
		return bytes.NewReader(bytesData), contentType
 | 
			
		||||
	} else if files != nil {
 | 
			
		||||
		body := &bytes.Buffer{}
 | 
			
		||||
		// 文件写入 body
 | 
			
		||||
		writer := multipart.NewWriter(body)
 | 
			
		||||
		for _, uploadFile := range files {
 | 
			
		||||
			part, err := writer.CreateFormFile(uploadFile.Name, filepath.Base(uploadFile.Filepath))
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				panic(err)
 | 
			
		||||
			}
 | 
			
		||||
			_, err = io.Copy(part, uploadFile.File)
 | 
			
		||||
		}
 | 
			
		||||
		// 其他参数列表写入 body
 | 
			
		||||
		for k, v := range reqParams {
 | 
			
		||||
			if err := writer.WriteField(k, v); err != nil {
 | 
			
		||||
				panic(err)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		if err := writer.Close(); err != nil {
 | 
			
		||||
			panic(err)
 | 
			
		||||
		}
 | 
			
		||||
		// 上传文件需要自己专用的contentType
 | 
			
		||||
		return body, writer.FormDataContentType()
 | 
			
		||||
	} else {
 | 
			
		||||
		urlValues := url.Values{}
 | 
			
		||||
		for key, val := range reqParams {
 | 
			
		||||
			urlValues.Set(key, val)
 | 
			
		||||
		}
 | 
			
		||||
		reqBody := urlValues.Encode()
 | 
			
		||||
		return strings.NewReader(reqBody), contentType
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,139 @@
 | 
			
		|||
package utils
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"golang.org/x/image/bmp"
 | 
			
		||||
	"golang.org/x/image/tiff"
 | 
			
		||||
	"image"
 | 
			
		||||
	"image/color"
 | 
			
		||||
	"image/jpeg"
 | 
			
		||||
	"image/png"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func BuffToImage(in []byte) image.Image {
 | 
			
		||||
	buff := bytes.NewBuffer(in)
 | 
			
		||||
	m, _, _ := image.Decode(buff)
 | 
			
		||||
	return m
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Clip 图片裁剪
 | 
			
		||||
func Clip(in []byte, wi, hi int, equalProportion bool) (out image.Image, imageType string, err error) {
 | 
			
		||||
	buff := bytes.NewBuffer(in)
 | 
			
		||||
	m, imgType, _ := image.Decode(buff)
 | 
			
		||||
	rgbImg := m.(*image.YCbCr)
 | 
			
		||||
	if equalProportion {
 | 
			
		||||
		w := m.Bounds().Max.X
 | 
			
		||||
		h := m.Bounds().Max.Y
 | 
			
		||||
		if w > 0 && h > 0 && wi > 0 && hi > 0 {
 | 
			
		||||
			wi, hi = fixSize(w, h, wi, hi)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return rgbImg.SubImage(image.Rect(0, 0, wi, hi)), imgType, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func fixSize(img1W, img2H, wi, hi int) (new1W, new2W int) {
 | 
			
		||||
	var ( //为了方便计算,将图片的宽转为 float64
 | 
			
		||||
		imgWidth, imgHeight = float64(img1W), float64(img2H)
 | 
			
		||||
		ratio               float64
 | 
			
		||||
	)
 | 
			
		||||
	if imgWidth >= imgHeight {
 | 
			
		||||
		ratio = imgWidth / float64(wi)
 | 
			
		||||
		return int(imgWidth * ratio), int(imgHeight * ratio)
 | 
			
		||||
	}
 | 
			
		||||
	ratio = imgHeight / float64(hi)
 | 
			
		||||
	return int(imgWidth * ratio), int(imgHeight * ratio)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Gray(in []byte) (out image.Image, err error) {
 | 
			
		||||
	m := BuffToImage(in)
 | 
			
		||||
	bounds := m.Bounds()
 | 
			
		||||
	dx := bounds.Dx()
 | 
			
		||||
	dy := bounds.Dy()
 | 
			
		||||
	newRgba := image.NewRGBA(bounds)
 | 
			
		||||
	for i := 0; i < dx; i++ {
 | 
			
		||||
		for j := 0; j < dy; j++ {
 | 
			
		||||
			colorRgb := m.At(i, j)
 | 
			
		||||
			_, g, _, a := colorRgb.RGBA()
 | 
			
		||||
			gUint8 := uint8(g >> 8)
 | 
			
		||||
			aUint8 := uint8(a >> 8)
 | 
			
		||||
			newRgba.SetRGBA(i, j, color.RGBA{R: gUint8, G: gUint8, B: gUint8, A: aUint8})
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	r := image.Rect(0, 0, dx, dy)
 | 
			
		||||
	return newRgba.SubImage(r), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Rotate90(in []byte) image.Image {
 | 
			
		||||
	m := BuffToImage(in)
 | 
			
		||||
	rotate90 := image.NewRGBA(image.Rect(0, 0, m.Bounds().Dy(), m.Bounds().Dx()))
 | 
			
		||||
	// 矩阵旋转
 | 
			
		||||
	for x := m.Bounds().Min.Y; x < m.Bounds().Max.Y; x++ {
 | 
			
		||||
		for y := m.Bounds().Max.X - 1; y >= m.Bounds().Min.X; y-- {
 | 
			
		||||
			//  设置像素点
 | 
			
		||||
			rotate90.Set(m.Bounds().Max.Y-x, y, m.At(y, x))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return rotate90
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Rotate180 旋转180度
 | 
			
		||||
func Rotate180(in []byte) image.Image {
 | 
			
		||||
	m := BuffToImage(in)
 | 
			
		||||
	rotate180 := image.NewRGBA(image.Rect(0, 0, m.Bounds().Dx(), m.Bounds().Dy()))
 | 
			
		||||
	// 矩阵旋转
 | 
			
		||||
	for x := m.Bounds().Min.X; x < m.Bounds().Max.X; x++ {
 | 
			
		||||
		for y := m.Bounds().Min.Y; y < m.Bounds().Max.Y; y++ {
 | 
			
		||||
			//  设置像素点
 | 
			
		||||
			rotate180.Set(m.Bounds().Max.X-x, m.Bounds().Max.Y-y, m.At(x, y))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return rotate180
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Rotate270 旋转270度
 | 
			
		||||
func Rotate270(in []byte) image.Image {
 | 
			
		||||
	m := BuffToImage(in)
 | 
			
		||||
	rotate270 := image.NewRGBA(image.Rect(0, 0, m.Bounds().Dy(), m.Bounds().Dx()))
 | 
			
		||||
	// 矩阵旋转
 | 
			
		||||
	for x := m.Bounds().Min.Y; x < m.Bounds().Max.Y; x++ {
 | 
			
		||||
		for y := m.Bounds().Max.X - 1; y >= m.Bounds().Min.X; y-- {
 | 
			
		||||
			// 设置像素点
 | 
			
		||||
			rotate270.Set(x, m.Bounds().Max.X-y, m.At(y, x))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return rotate270
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ImageToBase64(img image.Image, imgType string) string {
 | 
			
		||||
	buff := ImageToBuff(img, imgType)
 | 
			
		||||
	return base64.StdEncoding.EncodeToString(buff.Bytes())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ImageToBuff(img image.Image, imgType string) *bytes.Buffer {
 | 
			
		||||
	buff := bytes.NewBuffer(nil)
 | 
			
		||||
	switch imgType {
 | 
			
		||||
	case "bmp":
 | 
			
		||||
		imgType = "bmp"
 | 
			
		||||
		_ = bmp.Encode(buff, img)
 | 
			
		||||
	case "png":
 | 
			
		||||
		imgType = "png"
 | 
			
		||||
		_ = png.Encode(buff, img)
 | 
			
		||||
	case "tiff":
 | 
			
		||||
		imgType = "tiff"
 | 
			
		||||
		_ = tiff.Encode(buff, img, nil)
 | 
			
		||||
	default:
 | 
			
		||||
		imgType = "jpeg"
 | 
			
		||||
		_ = jpeg.Encode(buff, img, nil)
 | 
			
		||||
	}
 | 
			
		||||
	return buff
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ImgFileToBase64(fn string) string {
 | 
			
		||||
	fileByte := ReadFile(fn)
 | 
			
		||||
	buff := bytes.NewBuffer(fileByte)
 | 
			
		||||
	m, _, _ := image.Decode(buff)
 | 
			
		||||
	return "data:image/jpeg;base64," + ImageToBase64(m, "jpeg")
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,48 @@
 | 
			
		|||
package utils
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/hmac"
 | 
			
		||||
	"crypto/sha1"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"github.com/google/uuid"
 | 
			
		||||
	"math/rand"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
/*
 | 
			
		||||
		RandomString 产生随机数
 | 
			
		||||
	  - size 随机码的位数
 | 
			
		||||
	  - kind 0    // 纯数字
 | 
			
		||||
	    1    // 小写字母
 | 
			
		||||
	    2    // 大写字母
 | 
			
		||||
	    3    // 数字、大小写字母
 | 
			
		||||
*/
 | 
			
		||||
func RandomString(size int, kind int) string {
 | 
			
		||||
	iKind, kinds, rsBytes := kind, [][]int{[]int{10, 48}, []int{26, 97}, []int{26, 65}}, make([]byte, size)
 | 
			
		||||
	isAll := kind > 2 || kind < 0
 | 
			
		||||
	rand.Seed(time.Now().UnixNano())
 | 
			
		||||
	for i := 0; i < size; i++ {
 | 
			
		||||
		if isAll { // random iKind
 | 
			
		||||
			iKind = rand.Intn(3)
 | 
			
		||||
		}
 | 
			
		||||
		scope, base := kinds[iKind][0], kinds[iKind][1]
 | 
			
		||||
		rsBytes[i] = uint8(base + rand.Intn(scope))
 | 
			
		||||
	}
 | 
			
		||||
	return string(rsBytes)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetUserSha1Pass(pass, salt string) string {
 | 
			
		||||
	key := []byte(salt)
 | 
			
		||||
	mac := hmac.New(sha1.New, key)
 | 
			
		||||
	mac.Write([]byte(pass))
 | 
			
		||||
	//进行base64编码
 | 
			
		||||
	res := base64.StdEncoding.EncodeToString(mac.Sum(nil))
 | 
			
		||||
	return res
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetUUIDString() string {
 | 
			
		||||
	u, _ := uuid.NewUUID()
 | 
			
		||||
	str := strings.Replace(u.String(), "-", "", -1)
 | 
			
		||||
	return str
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,91 @@
 | 
			
		|||
#!/bin/bash
 | 
			
		||||
 | 
			
		||||
#获取对应的参数 没有的话赋默认值
 | 
			
		||||
ARGS=`getopt -o d::s::b::e::m::o:: --long dataset::,img_size::,batch_size::,epochs::,model::,model_save:: -n 'example.sh' -- "$@"`
 | 
			
		||||
if [ $? != 0 ]; then
 | 
			
		||||
    echo "Terminating..."
 | 
			
		||||
    exit 1
 | 
			
		||||
fi
 | 
			
		||||
echo $ARGS
 | 
			
		||||
eval set -- "${ARGS}"
 | 
			
		||||
while true;
 | 
			
		||||
do
 | 
			
		||||
    case "$1" in
 | 
			
		||||
        -d|--dataset)
 | 
			
		||||
            case "$2" in
 | 
			
		||||
                "")
 | 
			
		||||
		    echo "Internal error!"
 | 
			
		||||
            	    exit 1
 | 
			
		||||
		    ;;
 | 
			
		||||
		*)
 | 
			
		||||
                    Dataset=$2
 | 
			
		||||
                    shift 2;
 | 
			
		||||
                    ;;
 | 
			
		||||
            esac
 | 
			
		||||
             ;;
 | 
			
		||||
        -s|--img_size)
 | 
			
		||||
    case "$2" in
 | 
			
		||||
		"")
 | 
			
		||||
                    echo "Internal error!"
 | 
			
		||||
                    exit 1
 | 
			
		||||
                    ;;
 | 
			
		||||
 | 
			
		||||
                *)
 | 
			
		||||
                  ImgSize=$2;
 | 
			
		||||
                    shift 2;
 | 
			
		||||
                    ;;
 | 
			
		||||
            esac
 | 
			
		||||
             ;;
 | 
			
		||||
        -b|--batch_size)
 | 
			
		||||
            case "$2" in
 | 
			
		||||
                *)
 | 
			
		||||
                  BatchSize=$2;
 | 
			
		||||
                    shift 2;
 | 
			
		||||
                    ;;
 | 
			
		||||
            esac
 | 
			
		||||
            ;;
 | 
			
		||||
        -e|--epochs)
 | 
			
		||||
            case "$2" in
 | 
			
		||||
                *)
 | 
			
		||||
                  Epochs=$2;
 | 
			
		||||
                    shift 2;
 | 
			
		||||
                    ;;
 | 
			
		||||
            esac
 | 
			
		||||
            ;;
 | 
			
		||||
        -m|--model)
 | 
			
		||||
            case "$2" in
 | 
			
		||||
                *)
 | 
			
		||||
                  Model=$2;
 | 
			
		||||
                    shift 2;
 | 
			
		||||
                    ;;
 | 
			
		||||
            esac
 | 
			
		||||
            ;;
 | 
			
		||||
         -o|--model_save)
 | 
			
		||||
            case "$2" in
 | 
			
		||||
                *)
 | 
			
		||||
                  ModelSave=$2;
 | 
			
		||||
                    shift 2;
 | 
			
		||||
                    ;;
 | 
			
		||||
            esac
 | 
			
		||||
            ;;
 | 
			
		||||
 | 
			
		||||
        --)
 | 
			
		||||
            shift
 | 
			
		||||
            break
 | 
			
		||||
            ;;
 | 
			
		||||
        *)
 | 
			
		||||
            echo "Internal error!"
 | 
			
		||||
            exit 1
 | 
			
		||||
            ;;
 | 
			
		||||
    esac
 | 
			
		||||
 | 
			
		||||
done
 | 
			
		||||
 | 
			
		||||
#echo ${Dataset}
 | 
			
		||||
 | 
			
		||||
eval "$(conda shell.bash hook)"
 | 
			
		||||
conda activate hpds_train
 | 
			
		||||
cd /home/data/hpds_train_keras
 | 
			
		||||
python train.py --dataset ${Dataset} --img_size ${ImgSize}  --batch_size  ${BatchSize} \
 | 
			
		||||
                 --epochs  ${Epochs} --model  ${Model} --model_save  ${ModelSave}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue