From db923ac9aea0df58da0f2bc79cfc5928a2424f9f Mon Sep 17 00:00:00 2001 From: wangjian Date: Thu, 18 May 2023 11:01:34 +0800 Subject: [PATCH] =?UTF-8?q?1=E3=80=81=E5=A2=9E=E5=8A=A0=E8=BE=B9=E7=BC=98?= =?UTF-8?q?=E8=AE=BE=E5=A4=87=E6=95=B0=E6=8D=AE=E9=9B=86=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=E6=A0=87=E6=B3=A8=EF=BC=9B=202=E3=80=81=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E8=AE=AD=E7=BB=83=E6=95=B0=E6=8D=AE=E9=9B=86=E3=80=81=E8=AE=AD?= =?UTF-8?q?=E7=BB=83=E4=BB=BB=E5=8A=A1=E3=80=81=E8=AE=AD=E7=BB=83=E6=97=A5?= =?UTF-8?q?=E5=BF=97=E3=80=81=E8=AE=AD=E7=BB=83=E7=BB=93=E6=9E=9C=E5=AF=BC?= =?UTF-8?q?=E5=87=BA=E7=9A=84=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cmd/server.go | 5 +- config/config-dev.yaml | 18 +++- config/config.go | 1 + config/config.yaml | 1 + internal/handler/dataset.go | 54 +++++++++++- internal/handler/edge.go | 40 +++++++++ internal/handler/file.go | 15 ++++ internal/handler/task.go | 52 +++++++++++ internal/proto/request.go | 17 ++++ internal/proto/response.go | 24 ++++++ internal/router/router.go | 24 +++++- internal/service/dataset.go | 191 ++++++++++++++++++++++++++++++++++++++--- internal/service/edge.go | 96 +++++++++++++++++++++ internal/service/fileManage.go | 48 +++++++++++ internal/service/manage.go | 16 ++-- internal/service/task.go | 129 ++++++++++++++++++++++++++++ model/file.go | 1 + model/index.go | 19 +++- model/trainTask.go | 18 ++++ model/trainTaskLog.go | 12 +++ model/trainingDatasetDetail.go | 3 +- pkg/minio/index.go | 46 ++++++++++ pkg/utils/file.go | 119 +++++++++++++++++++++++++ pkg/utils/http.go | 126 +++++++++++++++++++++++++++ 24 files changed, 1039 insertions(+), 36 deletions(-) create mode 100644 internal/handler/edge.go create mode 100644 internal/service/edge.go create mode 100644 model/trainTask.go create mode 100644 model/trainTaskLog.go create mode 100644 pkg/minio/index.go create mode 100644 pkg/utils/file.go create mode 100644 pkg/utils/http.go diff --git a/cmd/server.go b/cmd/server.go index b6cc915..323cca2 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -84,13 +84,12 @@ func NewStartCmd() *cobra.Command { tags, 300, 300, 300) must(err) + logger := LoadLoggerConfig(cfg.Logging) //连接数据库 - model.New(cfg.Db.DriveName, cfg.Db.Conn, cfg.Mode == "dev") + model.New(cfg.Db.DriveName, cfg.Db.Conn, cfg.Mode == "dev", logger) //连接redis model.NewCache(cfg.Cache) - logger := LoadLoggerConfig(cfg.Logging) - //创建消息连接点 mq.MqList, err = mq.NewMqClient(cfg.Funcs, cfg.Node, logger) must(err) diff --git a/config/config-dev.yaml b/config/config-dev.yaml index 53802df..abef04f 100644 --- a/config/config-dev.yaml +++ b/config/config-dev.yaml @@ -2,6 +2,7 @@ name: web host: 0.0.0.0 port: 8088 mode: dev +trainDir : ./classification_dataset_balanced/ logging: path: ./logs prefix: hpds-iot-web @@ -16,7 +17,7 @@ logging: mineData: accessKey: f0bda738033e47ffbfbd5d3f865c19e1 minio: - endpoint: 192.168.0.200:9000 + endpoint: 127.0.0.1:9000 accessKeyId: root secretAccessKey: OIxv7QptYBO3 consul: @@ -29,9 +30,18 @@ db: conn: root:OIxv7QptYBO3@tcp(114.55.236.153:27136)/hpds_jky?charset=utf8mb4 drive_name: mysql cache: - host: 192.168.0.200 + host: 127.0.0.1 port: 6379 - db: 8 + db: 0 pool_size: 10 +node: + host: 127.0.0.1 + port: 27188 + token: 06d36c6f5705507dae778fdce90d0767 functions: - - name: web-sf \ No newline at end of file + - name: task-request + dataTag : 12 + mqType: 1 + - name: task-log + dataTag: 28 + mqType: 2 \ No newline at end of file diff --git a/config/config.go b/config/config.go index 8be8716..790dd2e 100644 --- a/config/config.go +++ b/config/config.go @@ -15,6 +15,7 @@ type WebConfig struct { Host string `yaml:"host,omitempty"` Port int `yaml:"port,omitempty"` Mode string `yaml:"mode,omitempty"` + TrainDir string `yaml:"trainDir,omitempty"` Consul ConsulConfig `yaml:"consul,omitempty"` Db DbConfig `yaml:"db"` Cache CacheConfig `yaml:"cache"` diff --git a/config/config.yaml b/config/config.yaml index 8be431d..ec6aaa8 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -2,6 +2,7 @@ name: web host: 0.0.0.0 port: 8088 mode: dev +trainDir : ./classification_dataset_balanced/ logging: path: ./logs prefix: hpds-iot-web diff --git a/internal/handler/dataset.go b/internal/handler/dataset.go index c727687..39ccfa5 100644 --- a/internal/handler/dataset.go +++ b/internal/handler/dataset.go @@ -79,18 +79,66 @@ func (s HandlerService) DatasetInfo(c *gin.Context) (data interface{}, err error return } -func (s HandlerService) CreateTraining(c *gin.Context) (data interface{}, err error) { +func (s HandlerService) CreateTrainDataset(c *gin.Context) (data interface{}, err error) { repo := service.NewDatasetService(s.AppConfig, s.Engine, s.Logger) us, _ := c.Get("operatorUser") userInfo := us.(*model.SystemUser) var req proto.TrainDatasetRequest err = c.ShouldBindJSON(&req) if err != nil { - go s.SaveLog("CreateTraining", "Training", "", "", ToString(req), fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "") + go s.SaveLog("CreateTrainDataset", "Training", "", "", ToString(req), fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "") return nil, e.NewValidErr(err) } req.UserId = userInfo.UserId - data, err = repo.CreateTraining(c, req) + data, err = repo.CreateTrainDataset(c, req) go s.SaveLog("创建训练数据集", "Training", "", "", "", fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "") return } + +func (s HandlerService) TrainDatasetList(c *gin.Context) (data interface{}, err error) { + repo := service.NewDatasetService(s.AppConfig, s.Engine, s.Logger) + us, _ := c.Get("operatorUser") + userInfo := us.(*model.SystemUser) + var req proto.TrainDatasetItemRequest + err = c.ShouldBindJSON(&req) + if err != nil { + go s.SaveLog("CreateTrainDataset", "Training", "", "", ToString(req), fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "") + return nil, e.NewValidErr(err) + } + if req.Size < 1 { + req.Size = 20 + } + if req.Size > 1000 { + req.Size = 1000 + } + if req.Page < 1 { + req.Page = 1 + } + data, err = repo.TrainDatasetList(c, req) + go s.SaveLog("获取训练数据集列表", "Training", "", "", "", fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "") + return +} + +func (s HandlerService) TrainDatasetFileList(c *gin.Context) (data interface{}, err error) { + repo := service.NewDatasetService(s.AppConfig, s.Engine, s.Logger) + us, _ := c.Get("operatorUser") + userInfo := us.(*model.SystemUser) + var req proto.TrainDatasetItemRequest + err = c.ShouldBindJSON(&req) + if err != nil { + go s.SaveLog("TrainDatasetFileList", "Training", "", "", ToString(req), fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "") + return nil, e.NewValidErr(err) + } + if req.Size < 1 { + req.Size = 20 + } + if req.Size > 1000 { + req.Size = 1000 + } + if req.Page < 1 { + req.Page = 1 + } + data, err = repo.TrainDatasetFileList(c, req) + go s.SaveLog("获取训练数据集中的文件列表", "Training", "", "", "", fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "") + return +} diff --git a/internal/handler/edge.go b/internal/handler/edge.go new file mode 100644 index 0000000..f923952 --- /dev/null +++ b/internal/handler/edge.go @@ -0,0 +1,40 @@ +package handler + +import ( + "fmt" + "github.com/gin-gonic/gin" + "hpds-iot-web/internal/proto" + "hpds-iot-web/internal/service" + "hpds-iot-web/model" + e "hpds-iot-web/pkg/err" +) + +func (s HandlerService) GetEdgeList(c *gin.Context) (data interface{}, err error) { + repo := service.NewEdgeService(s.AppConfig, s.Engine, s.Logger) + us, _ := c.Get("operatorUser") + userInfo := us.(*model.SystemUser) + var req proto.EdgeDatasetRequest + err = c.ShouldBindJSON(&req) + if err != nil { + go s.SaveLog("GetEdgeList", "Dataset", "", "", ToString(req), fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "") + return nil, e.NewValidErr(err) + } + data, err = repo.GetEdgeList(c, req) + go s.SaveLog("获取边缘端数据列表", "Dataset", "", "", "", fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "") + return +} + +func (s HandlerService) GetEdgeInfo(c *gin.Context) (data interface{}, err error) { + repo := service.NewEdgeService(s.AppConfig, s.Engine, s.Logger) + us, _ := c.Get("operatorUser") + userInfo := us.(*model.SystemUser) + var req proto.EdgeDatasetRequest + err = c.ShouldBindJSON(&req) + if err != nil { + go s.SaveLog("GetEdgeInfo", "Dataset", "", "", ToString(req), fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "") + return nil, e.NewValidErr(err) + } + data, err = repo.GetEdgeInfo(c, req) + go s.SaveLog("获取边缘端数据详情", "Dataset", "", "", "", fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "") + return +} diff --git a/internal/handler/file.go b/internal/handler/file.go index 87b85e8..e3a7ed0 100644 --- a/internal/handler/file.go +++ b/internal/handler/file.go @@ -61,3 +61,18 @@ func (s HandlerService) FileList(c *gin.Context) (data interface{}, err error) { go s.SaveLog("获取数据集详情", "FileManage", "", "", ToString(data), fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "") return } + +func (s HandlerService) FileLabel(c *gin.Context) (data interface{}, err error) { + repo := service.NewFileService(s.AppConfig, s.Engine, s.Logger) + us, _ := c.Get("operatorUser") + userInfo := us.(*model.SystemUser) + var req proto.FileLabelRequest + err = c.ShouldBindJSON(&req) + if err != nil { + go s.SaveLog("FileLabel", "FileManage", "", "", ToString(req), fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "") + return nil, e.NewValidErr(err) + } + data, err = repo.FileLabel(c, req) + go s.SaveLog("标注文件", "FileManage", "", "", ToString(data), fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "") + return +} diff --git a/internal/handler/task.go b/internal/handler/task.go index e66a04c..367d614 100644 --- a/internal/handler/task.go +++ b/internal/handler/task.go @@ -115,3 +115,55 @@ func (s HandlerService) TaskLog(c *gin.Context) (data interface{}, err error) { go s.SaveLog("获取任务日志信息", "Manage", "", "", "", fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "") return } + +func (s HandlerService) TrainingTaskList(c *gin.Context) (data interface{}, err error) { + repo := service.NewTaskService(s.Engine, s.Logger) + us, _ := c.Get("operatorUser") + userInfo := us.(*model.SystemUser) + var req proto.TaskRequest + err = c.ShouldBindJSON(&req) + if err != nil { + go s.SaveLog("TrainingTaskList", "Manage", "", "", req.ToString(), fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "") + return nil, e.NewValidErr(err) + } + if req.Size < 1 { + req.Size = 20 + } + if req.Size > 1000 { + req.Size = 1000 + } + if req.Page < 1 { + req.Page = 1 + } + data, err = repo.TrainingTaskList(c, req) + go s.SaveLog("获取训练任务列表", "Manage", "", "", "", fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "") + return +} +func (s HandlerService) TrainingTaskInfo(c *gin.Context) (data interface{}, err error) { + repo := service.NewTaskService(s.Engine, s.Logger) + us, _ := c.Get("operatorUser") + userInfo := us.(*model.SystemUser) + var req proto.TaskItemRequest + err = c.ShouldBindJSON(&req) + if err != nil { + go s.SaveLog("TrainingTaskInfo", "Manage", "", "", req.ToString(), fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "") + return nil, e.NewValidErr(err) + } + data, err = repo.TrainingTaskInfo(c, req) + go s.SaveLog("获取训练任务详情", "Manage", "", "", "", fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "") + return +} +func (s HandlerService) TrainingTaskLog(c *gin.Context) (data interface{}, err error) { + repo := service.NewTaskService(s.Engine, s.Logger) + us, _ := c.Get("operatorUser") + userInfo := us.(*model.SystemUser) + var req proto.TaskItemRequest + err = c.ShouldBindJSON(&req) + if err != nil { + go s.SaveLog("TrainingTaskLog", "Manage", "", "", req.ToString(), fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "") + return nil, e.NewValidErr(err) + } + data, err = repo.TrainingTaskLog(c, req) + go s.SaveLog("获取训练任务日志详情", "Manage", "", "", "", fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "") + return +} diff --git a/internal/proto/request.go b/internal/proto/request.go index 34c6e17..cd5a334 100644 --- a/internal/proto/request.go +++ b/internal/proto/request.go @@ -580,6 +580,11 @@ func (p DatasetItemRequest) ToString() string { return string(data) } +type FileLabelRequest struct { + FileId int64 `json:"fileId"` + IsDisease bool `json:"isDisease"` +} + type ImportDatasetRequest struct { DatasetId int64 `json:"datasetId"` CategoryId int `json:"categoryId"` @@ -745,3 +750,15 @@ type TrainDatasetRequest struct { SplitMethod int `json:"splitMethod"` UserId int64 `json:"userId"` } + +type TrainDatasetItemRequest struct { + DatasetId int64 `json:"datasetId"` + TrainName string `json:"trainName"` + BizType int `json:"bizType"` + BasePageList +} + +type EdgeDatasetRequest struct { + NodeId int64 `json:"nodeId"` + Path string `json:"path"` +} diff --git a/internal/proto/response.go b/internal/proto/response.go index e678315..e4b231a 100644 --- a/internal/proto/response.go +++ b/internal/proto/response.go @@ -168,3 +168,27 @@ type TaskLogPayload struct { Status int `json:"status"` //1:执行成功;2:执行失败 EventTime int64 `json:"eventTime"` } + +type TrainingDataset struct { + DatasetId int64 `json:"datasetId"` + Name string `json:"name"` + CategoryId int `json:"categoryId"` + DatasetDesc string `json:"datasetDesc"` + TotalSize int64 `json:"totalSize"` + TrainSize int64 `json:"trainSize"` + ValSize int64 `json:"valSize"` + TestSize int64 `json:"testSize"` + StoreName string `json:"storeName"` + CreateAt int64 `json:"createAt"` + UpdateAt int64 `json:"updateAt"` +} + +type TrainingDatasetFileItem struct { + DetailId int64 `json:"detailId"` + FileName string `json:"fileName"` + FileSize int64 `json:"fileSize"` + FilePath string `json:"filePath"` + FileContent string `json:"fileContent"` + IsDisease int `json:"isDisease"` + CategoryId int `json:"categoryId"` +} diff --git a/internal/router/router.go b/internal/router/router.go index 6abe464..9f53b2a 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -148,6 +148,8 @@ func InitRouter(cfg *config.WebConfig, logger *logging.Logger, engine *xorm.Engi file.Use(middleware.JwtAuthMiddleware(logger.Logger)) file.POST("/upload", e.ErrorWrapper(hs.UploadFile)) file.POST("/list", e.ErrorWrapper(hs.FileList)) + file.POST("/label", e.ErrorWrapper(hs.FileLabel)) + //file.POST("/batchLabel", e.ErrorWrapper(hs.FileBatchLabel)) } system := r.Group("/system") { @@ -220,6 +222,12 @@ func InitRouter(cfg *config.WebConfig, logger *logging.Logger, engine *xorm.Engi // flusher.Flush() //} }) + train := task.Group("/train") + { + train.POST("/list", e.ErrorWrapper(hs.TrainingTaskList)) + train.POST("/info", e.ErrorWrapper(hs.TrainingTaskInfo)) + train.POST("/log", e.ErrorWrapper(hs.TrainingTaskLog)) + } } disease := r.Group("/disease") { @@ -243,10 +251,12 @@ func InitRouter(cfg *config.WebConfig, logger *logging.Logger, engine *xorm.Engi dataset.POST("/info", e.ErrorWrapper(hs.DatasetInfo)) } - training := r.Group("/training") + training := r.Group("/trainDataset") { training.Use(middleware.JwtAuthMiddleware(logger.Logger)) - training.POST("/create", e.ErrorWrapper(hs.CreateTraining)) + training.POST("/create", e.ErrorWrapper(hs.CreateTrainDataset)) + training.POST("/list", e.ErrorWrapper(hs.TrainDatasetList)) + training.POST("/fileList", e.ErrorWrapper(hs.TrainDatasetFileList)) //training.POST("/list", e.ErrorWrapper(hs.TrainingList)) //training.POST("/info", e.ErrorWrapper(hs.TrainingInfo)) } @@ -263,6 +273,16 @@ func InitRouter(cfg *config.WebConfig, logger *logging.Logger, engine *xorm.Engi report.POST("/generate", e.ErrorWrapper(hs.GenerateReport)) //report.POST("/view", e.ErrorWrapper(hs.ViewReport)) } + edge := r.Group("/edge") + { + edge.Use(middleware.JwtAuthMiddleware(logger.Logger)) + dir := edge.Group("/directory") + { + dir.POST("/list", e.ErrorWrapper(hs.GetEdgeList)) + dir.POST("/info", e.ErrorWrapper(hs.GetEdgeInfo)) + } + + } } return root } diff --git a/internal/service/dataset.go b/internal/service/dataset.go index 46949b8..ea05b39 100644 --- a/internal/service/dataset.go +++ b/internal/service/dataset.go @@ -2,14 +2,17 @@ package service import ( "context" + "encoding/base64" "fmt" "git.hpds.cc/Component/logging" "hpds-iot-web/config" "hpds-iot-web/internal/proto" "hpds-iot-web/model" + "hpds-iot-web/pkg/utils" "math" "math/rand" "net/http" + "path" "time" "xorm.io/xorm" ) @@ -20,7 +23,9 @@ type DatasetService interface { ImportDataset(ctx context.Context, req proto.ImportDatasetRequest) (rsp *proto.BaseResponse, err error) DatasetInfo(ctx context.Context, req proto.DatasetItemRequest) (rsp *proto.BaseResponse, err error) - CreateTraining(ctx context.Context, req proto.TrainDatasetRequest) (rsp *proto.BaseResponse, err error) + CreateTrainDataset(ctx context.Context, req proto.TrainDatasetRequest) (rsp *proto.BaseResponse, err error) + TrainDatasetList(ctx context.Context, req proto.TrainDatasetItemRequest) (rsp *proto.BaseResponse, err error) + TrainDatasetFileList(ctx context.Context, req proto.TrainDatasetItemRequest) (rsp *proto.BaseResponse, err error) } func NewDatasetService(cfg *config.WebConfig, engine *xorm.Engine, logger *logging.Logger) DatasetService { @@ -229,7 +234,7 @@ ReturnPoint: return rsp, err } -func (rp *repo) CreateTraining(ctx context.Context, req proto.TrainDatasetRequest) (rsp *proto.BaseResponse, err error) { +func (rp *repo) CreateTrainDataset(ctx context.Context, req proto.TrainDatasetRequest) (rsp *proto.BaseResponse, err error) { rsp = new(proto.BaseResponse) select { case <-ctx.Done(): @@ -258,7 +263,7 @@ func (rp *repo) CreateTraining(ctx context.Context, req proto.TrainDatasetReques goto ReturnPoint } fileList := make([]model.FileManager, 0) - err = rp.engine.Where("dataset_id = ?", req.DatasetId).Find(&fileList) + err = rp.engine.Where("dataset_id = ?", req.DatasetId).And("is_disease > 0").Find(&fileList) if err != nil { goto ReturnPoint } @@ -266,7 +271,7 @@ func (rp *repo) CreateTraining(ctx context.Context, req proto.TrainDatasetReques req.TargetData = len(fileList) } if req.TargetData > len(fileList) { - err = fmt.Errorf("超出现有数据集数量") + err = fmt.Errorf("超出现有标注数据集数量") goto ReturnPoint } if req.SplitMethod == 1 { //随机 @@ -278,9 +283,24 @@ func (rp *repo) CreateTraining(ctx context.Context, req proto.TrainDatasetReques trainNumber := int(math.Floor(float64(int64(req.TargetData)*req.TrainNumber) / 100)) valNumber := int(math.Floor(float64(int64(req.TargetData)*req.ValidationNumber) / 100)) testNumber := req.TargetData - trainNumber - valNumber - trainFileList = fileList[:trainNumber-1] - valFileList = fileList[trainNumber : valNumber-1] - testFileList = fileList[valNumber:] + if trainNumber-1 > 1 { + trainFileList = fileList[:trainNumber-1] + } else { + trainFileList = make([]model.FileManager, 0) + trainFileList = append(trainFileList, fileList[0]) + } + if trainNumber != trainNumber+valNumber-1 { + valFileList = fileList[trainNumber : trainNumber+valNumber-1] + } else { + valFileList = make([]model.FileManager, 0) + valFileList = append(valFileList, fileList[trainNumber]) + } + if trainNumber+valNumber < len(fileList) { + testFileList = fileList[trainNumber+valNumber:] + } else { + testFileList = make([]model.FileManager, 0) + testFileList = append(testFileList, fileList[trainNumber+valNumber]) + } train := new(model.TrainingDataset) h, err = rp.engine.Where("name = ?", req.TrainName).Get(train) @@ -290,7 +310,7 @@ func (rp *repo) CreateTraining(ctx context.Context, req proto.TrainDatasetReques if !h { train.Name = req.TrainName train.DatasetDesc = req.TrainDesc - train.DatasetId = req.DatasetId + //train.DatasetId = req.DatasetId train.CategoryId = dataset.CategoryId _, err = rp.engine.Insert(train) if err != nil { @@ -310,9 +330,9 @@ func (rp *repo) CreateTraining(ctx context.Context, req proto.TrainDatasetReques goto ReturnPoint } //wg.Add(3) - go BatchCopyData(trainFileList, train.DatasetId, log.LogId, req.UserId, 1, rp.engine) //, &wg - go BatchCopyData(valFileList, train.DatasetId, log.LogId, req.UserId, 2, rp.engine) //, &wg - go BatchCopyData(testFileList, train.DatasetId, log.LogId, req.UserId, 3, rp.engine) //, &wg + go BatchCopyData(trainFileList, train.DatasetId, log.LogId, req.UserId, 1, req.TrainName, rp) //, &wg + go BatchCopyData(valFileList, train.DatasetId, log.LogId, req.UserId, 2, req.TrainName, rp) //, &wg + go BatchCopyData(testFileList, train.DatasetId, log.LogId, req.UserId, 3, req.TrainName, rp) //, &wg //wg.Wait() rsp.Code = http.StatusOK rsp.Status = http.StatusText(http.StatusOK) @@ -331,16 +351,24 @@ ReturnPoint: return rsp, err } -func BatchCopyData(list []model.FileManager, trainId, logId, userId int64, categoryId int, engine *xorm.Engine) { //, wg *sync.WaitGroup +func BatchCopyData(list []model.FileManager, trainId, logId, userId int64, categoryId int, trainName string, rp *repo) { //, wg *sync.WaitGroup batchList := make([]model.TrainingDatasetDetail, len(list)) for k, v := range list { + dir := "no_disease" + if v.IsDisease == 1 { + dir = "disease" + } + utils.DownloadMinioFileToLocalPath(v.AccessUrl, path.Join(rp.AppConfig.TrainDir, trainName, model.GetTrainCategory(categoryId), dir), v.FileName, + rp.AppConfig.Minio.Protocol, rp.AppConfig.Minio.Endpoint, rp.AppConfig.Minio.Bucket, rp.AppConfig.Minio.AccessKeyId, + rp.AppConfig.Minio.SecretAccessKey, rp.logger) item := model.TrainingDatasetDetail{ FileName: v.FileName, - FilePath: v.AccessUrl, + FilePath: path.Join(rp.AppConfig.TrainDir, trainName, model.GetTrainCategory(categoryId), dir, v.FileName), DatasetId: trainId, CategoryId: categoryId, FileSize: v.FileSize, FileMd5: v.FileMd5, + IsDisease: v.IsDisease, OperationLogId: logId, Creator: userId, CreateAt: time.Now().Unix(), @@ -348,6 +376,141 @@ func BatchCopyData(list []model.FileManager, trainId, logId, userId int64, categ } batchList[k] = item } - _, _ = engine.Insert(batchList) + _, _ = rp.engine.Insert(batchList) //wg.Done() } + +func (rp *repo) TrainDatasetList(ctx context.Context, req proto.TrainDatasetItemRequest) (rsp *proto.BaseResponse, err error) { + rsp = new(proto.BaseResponse) + select { + case <-ctx.Done(): + err = fmt.Errorf("超时/取消") + rsp.Code = http.StatusInternalServerError + rsp.Status = http.StatusText(http.StatusInternalServerError) + rsp.Message = "超时/取消" + rsp.Err = ctx.Err() + return rsp, ctx.Err() + default: + type QuantityStatistics struct { + DatasetId int64 + CategoryId int + Total int64 + } + + var ( + count int64 + list []proto.TrainingDataset + ) + trainingList := make([]model.TrainingDataset, 0) + count, err = rp.engine.Where("(?=0 or dataset_id = ?)", req.DatasetId, req.DatasetId). + And("(?= 0 or category_id = ?)", req.BizType, req.BizType). + And("(? ='' or name like ?)", req.TrainName, "%"+req.TrainName+"%"). + Limit(int(req.Size), int(((req.Page)-1)*req.Size)).FindAndCount(&trainingList) + if err != nil { + goto ReturnPoint + } + list = make([]proto.TrainingDataset, len(trainingList)) + for k, v := range trainingList { + qs := make([]QuantityStatistics, 0) + err = rp.engine.SQL("select dataset_id, category_id, count(1) as total from training_dataset_detail where dataset_id = ? group by category_id, dataset_id", v.DatasetId).Find(&qs) + if err != nil { + goto ReturnPoint + } + item := proto.TrainingDataset{ + DatasetId: v.DatasetId, + Name: v.Name, + CategoryId: v.CategoryId, + DatasetDesc: v.DatasetDesc, + TotalSize: 0, + TrainSize: 0, + ValSize: 0, + TestSize: 0, + StoreName: v.StoreName, + CreateAt: v.CreateAt, + UpdateAt: v.UpdateAt, + } + for _, val := range qs { + switch val.CategoryId { + case 1: + item.TrainSize = val.Total + case 2: + item.ValSize = val.Total + case 3: + item.TestSize = val.Total + } + item.TotalSize += val.Total + } + list[k] = item + } + rsp.Code = http.StatusOK + rsp.Status = http.StatusText(http.StatusOK) + rsp.Message = "成功" + rsp = FillPaging(count, req.Page, req.Size, list, rsp) + rsp.Err = err + return rsp, err + } + +ReturnPoint: + if err != nil { + rsp.Code = http.StatusInternalServerError + rsp.Status = http.StatusText(http.StatusInternalServerError) + rsp.Err = err + rsp.Message = "失败" + } + return rsp, err +} + +func (rp *repo) TrainDatasetFileList(ctx context.Context, req proto.TrainDatasetItemRequest) (rsp *proto.BaseResponse, err error) { + rsp = new(proto.BaseResponse) + select { + case <-ctx.Done(): + err = fmt.Errorf("超时/取消") + rsp.Code = http.StatusInternalServerError + rsp.Status = http.StatusText(http.StatusInternalServerError) + rsp.Message = "超时/取消" + rsp.Err = ctx.Err() + return rsp, ctx.Err() + default: + var ( + count int64 + list []proto.TrainingDatasetFileItem + ) + fileList := make([]model.TrainingDatasetDetail, 0) + count, err = rp.engine.Where("dataset_id = ?", req.DatasetId). + Limit(int(req.Size), int(((req.Page)-1)*req.Size)).FindAndCount(&fileList) + if err != nil { + goto ReturnPoint + } + list = make([]proto.TrainingDatasetFileItem, len(fileList)) + for k, v := range fileList { + buff := utils.ReadFile(v.FilePath) + img := utils.BuffToImage(buff) + buf := utils.ImageToBuff(img, "jpeg") + list[k] = proto.TrainingDatasetFileItem{ + DetailId: v.DetailId, + FileName: v.FileName, + FileSize: v.FileSize, + FilePath: v.FilePath, + FileContent: "data:image/jpeg;base64," + base64.StdEncoding.EncodeToString(buf.Bytes()), + IsDisease: v.IsDisease, + CategoryId: v.CategoryId, + } + } + + rsp.Code = http.StatusOK + rsp.Status = http.StatusText(http.StatusOK) + rsp.Message = "成功" + rsp = FillPaging(count, req.Page, req.Size, list, rsp) + rsp.Err = err + return rsp, err + } + +ReturnPoint: + if err != nil { + rsp.Code = http.StatusInternalServerError + rsp.Status = http.StatusText(http.StatusInternalServerError) + rsp.Err = err + rsp.Message = "失败" + } + return rsp, err +} diff --git a/internal/service/edge.go b/internal/service/edge.go new file mode 100644 index 0000000..d7a9fc2 --- /dev/null +++ b/internal/service/edge.go @@ -0,0 +1,96 @@ +package service + +import ( + "context" + "encoding/json" + "fmt" + "git.hpds.cc/Component/logging" + "hpds-iot-web/config" + "hpds-iot-web/internal/proto" + "hpds-iot-web/pkg/utils" + "net/http" + "xorm.io/xorm" +) + +type EdgeService interface { + GetEdgeList(ctx context.Context, req proto.EdgeDatasetRequest) (rsp *proto.BaseResponse, err error) + GetEdgeInfo(ctx context.Context, req proto.EdgeDatasetRequest) (rsp *proto.BaseResponse, err error) +} + +func NewEdgeService(cfg *config.WebConfig, engine *xorm.Engine, logger *logging.Logger) EdgeService { + return &repo{ + AppConfig: cfg, + engine: engine, + logger: logger, + } +} + +func (rp *repo) GetEdgeList(ctx context.Context, req proto.EdgeDatasetRequest) (rsp *proto.BaseResponse, err error) { + rsp = new(proto.BaseResponse) + select { + case <-ctx.Done(): + err = fmt.Errorf("超时/取消") + rsp.Code = http.StatusInternalServerError + rsp.Status = http.StatusText(http.StatusInternalServerError) + rsp.Message = "超时/取消" + rsp.Err = ctx.Err() + return rsp, ctx.Err() + default: + param := make(map[string]string) + param["path"] = req.Path + header := make(map[string]string) + header["Content-Type"] = "application/json" + res, err := utils.HttpDo("http://192.168.22.151:8099/api/directory/list", "POST", param, header) + if err != nil { + goto ReturnPoint + } + err = json.Unmarshal(res, &rsp) + if err != nil { + goto ReturnPoint + } + return rsp, err + } +ReturnPoint: + if err != nil { + rsp.Code = http.StatusInternalServerError + rsp.Status = http.StatusText(http.StatusInternalServerError) + rsp.Err = err + rsp.Message = "失败" + } + return rsp, err +} + +func (rp *repo) GetEdgeInfo(ctx context.Context, req proto.EdgeDatasetRequest) (rsp *proto.BaseResponse, err error) { + rsp = new(proto.BaseResponse) + select { + case <-ctx.Done(): + err = fmt.Errorf("超时/取消") + rsp.Code = http.StatusInternalServerError + rsp.Status = http.StatusText(http.StatusInternalServerError) + rsp.Message = "超时/取消" + rsp.Err = ctx.Err() + return rsp, ctx.Err() + default: + param := make(map[string]string) + param["path"] = req.Path + header := make(map[string]string) + header["Content-Type"] = "application/json" + res, err := utils.HttpDo("http://192.168.22.151:8099/api/directory/info", "POST", param, header) + if err != nil { + goto ReturnPoint + } + err = json.Unmarshal(res, &rsp) + if err != nil { + goto ReturnPoint + } + return rsp, err + } +ReturnPoint: + if err != nil { + rsp.Code = http.StatusInternalServerError + rsp.Status = http.StatusText(http.StatusInternalServerError) + rsp.Err = err + rsp.Message = "失败" + } + return rsp, err +} diff --git a/internal/service/fileManage.go b/internal/service/fileManage.go index 3790c7e..deb96be 100644 --- a/internal/service/fileManage.go +++ b/internal/service/fileManage.go @@ -24,6 +24,7 @@ type FileService interface { UploadFile(ctx context.Context, req proto.UploadFileRequest) (rsp *proto.BaseResponse, err error) UploadFileToMinIo(ctx context.Context, srcFile *multipart.FileHeader, scene string, datasetId, creator int64, dataType int) (data *model.FileManager, err error) FileList(ctx context.Context, req proto.DatasetItemRequest) (rsp *proto.BaseResponse, err error) + FileLabel(ctx context.Context, req proto.FileLabelRequest) (rsp *proto.BaseResponse, err error) } func NewFileService(cfg *config.WebConfig, engine *xorm.Engine, logger *logging.Logger) FileService { @@ -171,3 +172,50 @@ ReturnPoint: } return rsp, err } + +func (rp *repo) FileLabel(ctx context.Context, req proto.FileLabelRequest) (rsp *proto.BaseResponse, err error) { + rsp = new(proto.BaseResponse) + select { + case <-ctx.Done(): + err = fmt.Errorf("超时/取消") + rsp.Code = http.StatusInternalServerError + rsp.Status = http.StatusText(http.StatusInternalServerError) + rsp.Message = "超时/取消" + rsp.Err = ctx.Err() + return rsp, ctx.Err() + default: + item := new(model.FileManager) + var h bool + h, err = rp.engine.ID(req.FileId).Get(item) + if err != nil { + goto ReturnPoint + } + if !h { + err = fmt.Errorf("未能找到对应的文件") + goto ReturnPoint + } + if req.IsDisease { + item.IsDisease = 1 + } else { + item.IsDisease = 2 + } + _, err = rp.engine.ID(req.FileId).Cols("is_disease").Update(item) + if err != nil { + goto ReturnPoint + } + + rsp.Code = http.StatusOK + rsp.Status = http.StatusText(http.StatusOK) + rsp.Message = "成功" + rsp.Err = err + return rsp, err + } +ReturnPoint: + if err != nil { + rsp.Code = http.StatusInternalServerError + rsp.Status = http.StatusText(http.StatusInternalServerError) + rsp.Err = err + rsp.Message = "失败" + } + return rsp, err +} diff --git a/internal/service/manage.go b/internal/service/manage.go index 2b2fb83..65f938d 100644 --- a/internal/service/manage.go +++ b/internal/service/manage.go @@ -383,20 +383,22 @@ func (rp *repo) AddProject(ctx context.Context, req proto.ProjectItemRequest) (r Status: 1, Creator: req.Creator, } - - slng, slat, err := rp.GetLngLat(ctx, fmt.Sprintf("%s+%s", req.LineName, req.StartName)) + var ( + sLng, sLat, eLng, eLat float64 + ) + sLng, sLat, err = rp.GetLngLat(ctx, fmt.Sprintf("%s+%s", req.LineName, req.StartName)) if err != nil { goto ReturnPoint } - item.StartPointLng = slng - item.StartPointLat = slat + item.StartPointLng = sLng + item.StartPointLat = sLat - elng, elat, err := rp.GetLngLat(ctx, fmt.Sprintf("%s+%s", req.LineName, req.EndName)) + eLng, eLat, err = rp.GetLngLat(ctx, fmt.Sprintf("%s+%s", req.LineName, req.EndName)) if err != nil { goto ReturnPoint } - item.EndPointLng = elng - item.EndPointLat = elat + item.EndPointLng = eLng + item.EndPointLat = eLat _, err = rp.engine.Insert(item) if err != nil { diff --git a/internal/service/task.go b/internal/service/task.go index 968bd1e..28d6065 100644 --- a/internal/service/task.go +++ b/internal/service/task.go @@ -26,6 +26,10 @@ type TaskService interface { //EditTask(ctx context.Context, req proto.ModelItemRequest) (rsp *proto.BaseResponse, err error) TaskResult(ctx context.Context, req proto.ReportRequest) (rsp *proto.BaseResponse, err error) TaskLog(ctx context.Context, req proto.TaskLogItem) (rsp *proto.BaseResponse, err error) + + TrainingTaskList(ctx context.Context, req proto.TaskRequest) (rsp *proto.BaseResponse, err error) + TrainingTaskInfo(ctx context.Context, req proto.TaskItemRequest) (rsp *proto.BaseResponse, err error) + TrainingTaskLog(ctx context.Context, req proto.TaskItemRequest) (rsp *proto.BaseResponse, err error) } func NewTaskService(engine *xorm.Engine, logger *logging.Logger) TaskService { @@ -402,6 +406,11 @@ func (rp *repo) TaskResult(ctx context.Context, req proto.ReportRequest) (rsp *p if err := json.Unmarshal([]byte(str), &mr); err != nil { continue } + switch md.BizType { + case 1: //道路 + case 2: //桥梁 + case 3: //隧道 + } switch mr.Code { case 0: //轻量化模型返回 lr := new(mq.LightweightResult) @@ -433,6 +442,9 @@ func (rp *repo) TaskResult(ctx context.Context, req proto.ReportRequest) (rsp *p fn, _ := base64.StdEncoding.DecodeString(fileDiscern) buff := bytes.NewBuffer(fn) _, imgType, _ := image.Decode(buff) + if len(fileDiscern) == 0 { + fileDiscern = lr.ImgSrc + } fileDiscern = fmt.Sprintf("data:image/%s;base64,%s", imgType, fileDiscern) item := proto.TaskResultItem{ FileId: v.FileId, @@ -555,6 +567,9 @@ func (rp *repo) TaskResult(ctx context.Context, req proto.ReportRequest) (rsp *p memo = "检测到坑洼" } fileDiscern = lr.ImgDiscern + if len(fileDiscern) == 0 { + fileDiscern = lr.ImgSrc + } diseaseLevel = 3 diseaseLevelName = "重度" switch md.BizType { @@ -568,6 +583,8 @@ func (rp *repo) TaskResult(ctx context.Context, req proto.ReportRequest) (rsp *p diseaseType = 4 diseaseTypeName = "横向裂缝" } + } else { + fileDiscern = lr.ImgSrc } // case 2000: //网新返回没有病害 @@ -677,3 +694,115 @@ ReturnPoint: } return rsp, err } + +func (rp *repo) TrainingTaskList(ctx context.Context, req proto.TaskRequest) (rsp *proto.BaseResponse, err error) { + rsp = new(proto.BaseResponse) + select { + case <-ctx.Done(): + err = fmt.Errorf("超时/取消") + rsp.Code = http.StatusInternalServerError + rsp.Status = http.StatusText(http.StatusInternalServerError) + rsp.Message = "超时/取消" + rsp.Err = ctx.Err() + return rsp, ctx.Err() + default: + var ( + count int64 + ) + list := make([]model.TrainTask, 0) + count, err = rp.engine.Where("(? = 0 or category_id = ?)", req.BizType, req.BizType). + And("(? = '' or task_name like ?)", req.TaskName, "%"+req.TaskName+"%"). + Limit(int(req.Size), int(((req.Page)-1)*req.Size)).FindAndCount(&list) + if err != nil { + goto ReturnPoint + } + rsp.Code = http.StatusOK + rsp.Status = http.StatusText(http.StatusOK) + rsp.Message = "成功" + rsp.Err = err + rsp = FillPaging(count, req.Page, req.Size, list, rsp) + return rsp, err + } +ReturnPoint: + if err != nil { + rsp.Code = http.StatusInternalServerError + rsp.Status = http.StatusText(http.StatusInternalServerError) + rsp.Err = err + rsp.Message = "失败" + } + return rsp, err +} + +func (rp *repo) TrainingTaskInfo(ctx context.Context, req proto.TaskItemRequest) (rsp *proto.BaseResponse, err error) { + rsp = new(proto.BaseResponse) + select { + case <-ctx.Done(): + err = fmt.Errorf("超时/取消") + rsp.Code = http.StatusInternalServerError + rsp.Status = http.StatusText(http.StatusInternalServerError) + rsp.Message = "超时/取消" + rsp.Err = ctx.Err() + return rsp, ctx.Err() + default: + var ( + h bool + ) + item := new(model.TrainTask) + h, err = rp.engine.ID(req.TaskId).Get(item) + if err != nil { + goto ReturnPoint + } + if !h { + err = fmt.Errorf("未能找到对应的任务信息") + goto ReturnPoint + } + rsp.Code = http.StatusOK + rsp.Status = http.StatusText(http.StatusOK) + rsp.Message = "成功" + rsp.Err = err + rsp.Data = item + return rsp, err + } +ReturnPoint: + if err != nil { + rsp.Code = http.StatusInternalServerError + rsp.Status = http.StatusText(http.StatusInternalServerError) + rsp.Err = err + rsp.Message = "失败" + } + return rsp, err +} + +func (rp *repo) TrainingTaskLog(ctx context.Context, req proto.TaskItemRequest) (rsp *proto.BaseResponse, err error) { + rsp = new(proto.BaseResponse) + select { + case <-ctx.Done(): + err = fmt.Errorf("超时/取消") + rsp.Code = http.StatusInternalServerError + rsp.Status = http.StatusText(http.StatusInternalServerError) + rsp.Message = "超时/取消" + rsp.Err = ctx.Err() + return rsp, ctx.Err() + default: + list := make([]model.TrainTaskLog, 0) + err = rp.engine.Where("task_id = ?", req.TaskId).Asc("epoch").Find(&list) + if err != nil { + goto ReturnPoint + } + + rsp.Code = http.StatusOK + rsp.Status = http.StatusText(http.StatusOK) + rsp.Message = "成功" + rsp.Err = err + rsp = FillPaging(int64(len(list)), 1, 1000, list, rsp) + return rsp, err + } +ReturnPoint: + if err != nil { + rsp.Code = http.StatusInternalServerError + rsp.Status = http.StatusText(http.StatusInternalServerError) + rsp.Err = err + rsp.Message = "失败" + } + return rsp, err +} diff --git a/model/file.go b/model/file.go index c6034bc..4266b88 100644 --- a/model/file.go +++ b/model/file.go @@ -9,6 +9,7 @@ type FileManager struct { DatasetId int64 `xorm:"INT(11) index default 0" json:"datasetId"` //数据集 FileSize int64 `xorm:"BIGINT" json:"fileSize"` //文件大小 FileMd5 string `xorm:"VARCHAR(64)" json:"fileMd5"` //文件MD5 + IsDisease int `xorm:"TINYINT index default 0" json:"isDisease"` //数据标注状态; 0:未标注;1:有病害;2:无病害 Creator int64 `xorm:"INT(11) index" json:"creator"` //上传人 CreateAt int64 `xorm:"created" json:"createAt"` //上传时间 UpdateAt int64 `xorm:"updated" json:"updateAt"` //更新时间 diff --git a/model/index.go b/model/index.go index 082b868..d00a14f 100644 --- a/model/index.go +++ b/model/index.go @@ -2,6 +2,7 @@ package model import ( "fmt" + "git.hpds.cc/Component/logging" "github.com/go-redis/redis" _ "github.com/go-sql-driver/mysql" "go.uber.org/zap" @@ -21,7 +22,7 @@ var ( Redis *redis.Client ) -func New(driveName, dsn string, showSql bool) { +func New(driveName, dsn string, showSql bool, logger *logging.Logger) { DB, _ = NewDbConnection(driveName, dsn) DB.ShowSQL(showSql) DB.Dialect().SetQuotePolicy(dialects.QuotePolicyReserved) @@ -64,9 +65,11 @@ func New(driveName, dsn string, showSql bool) { &TaskResult{}, &TrainingDataset{}, &TrainingDatasetDetail{}, + &TrainTask{}, + &TrainTaskLog{}, ) if err != nil { - zap.L().Error("同步数据库表结构", zap.Error(err)) + logger.Error("同步数据库表结构", zap.Error(err)) os.Exit(1) } } @@ -99,3 +102,15 @@ func NewCache(c config.CacheConfig) { zap.L().Info("Redis连接成功", zap.String("pong", pong)) } } + +func GetTrainCategory(categoryId int) string { + switch categoryId { + case 1: + return "train" + case 2: + return "val" + case 3: + return "test" + } + return "other" +} diff --git a/model/trainTask.go b/model/trainTask.go new file mode 100644 index 0000000..48e808a --- /dev/null +++ b/model/trainTask.go @@ -0,0 +1,18 @@ +package model + +type TrainTask struct { + TaskId int64 `xorm:"not null pk autoincr BIGINT(11)" json:"taskId"` + TrainDatasetId int64 `xorm:"INT(11) index" json:"trainDatasetId"` + CategoryId int `xorm:"not null SMALLINT default 1" json:"categoryId"` //业务分类, 1:道路 2:桥梁 3:隧道 4:边坡 + TaskName string `xorm:"VARCHAR(200)" json:"taskName"` + TaskDesc string `xorm:"VARCHAR(500)" json:"taskDesc"` + StartTime int64 `xorm:"BIGINT" json:"startTime"` + FinishTime int64 `xorm:"BIGINT" json:"finishTime"` + Loss float64 `xorm:"DECIMAL(18,6)" json:"loss"` + Accuracy float64 `xorm:"DECIMAL(18,6)" json:"accuracy"` + ModelFilePath string `xorm:"VARCHAR(2000)" json:"modelFilePath"` + PbModelFilePath string `xorm:"VARCHAR(2000)" json:"pbModelFilePath"` + Status int `xorm:"not null SMALLINT default 0" json:"status"` // 1:等待执行; 2:执行中; 3:执行完成; 4:任务分配失败; 5:任务执行失败 + CreateAt int64 `xorm:"created" json:"createAt"` + UpdateAt int64 `xorm:"updated" json:"updateAt"` +} diff --git a/model/trainTaskLog.go b/model/trainTaskLog.go new file mode 100644 index 0000000..9f6188d --- /dev/null +++ b/model/trainTaskLog.go @@ -0,0 +1,12 @@ +package model + +type TrainTaskLog struct { + LogId int64 `xorm:"not null pk autoincr BIGINT(11)" json:"logId"` + TaskId int64 `xorm:"INT(11) index" json:"taskId"` + Epoch int `xorm:"SMALLINT" json:"epoch"` + Loss float64 `xorm:"DECIMAL(18,6)" json:"loss"` + Accuracy float64 `xorm:"DECIMAL(18,6)" json:"accuracy"` + ValLoss float64 `xorm:"DECIMAL(18,6)" json:"valLoss"` + ValAccuracy float64 `xorm:"DECIMAL(18,6)" json:"valAccuracy"` + CreateAt int64 `xorm:"created" json:"createAt"` +} diff --git a/model/trainingDatasetDetail.go b/model/trainingDatasetDetail.go index 7c4b75e..37dcc44 100644 --- a/model/trainingDatasetDetail.go +++ b/model/trainingDatasetDetail.go @@ -6,8 +6,9 @@ type TrainingDatasetDetail struct { FilePath string `xorm:"VARCHAR(1000)" json:"filePath"` DatasetId int64 `xorm:"INT(11) index default 0" json:"datasetId"` //训练数据集 CategoryId int `xorm:"not null SMALLINT default 1" json:"categoryId"` //训练集分类,1:训练集;2:测试集;3:验证集 - FileSize int64 `xorm:"BININT" json:"fileSize"` //文件大小 + FileSize int64 `xorm:"BIGINT" json:"fileSize"` //文件大小 FileMd5 string `xorm:"VARCHAR(64)" json:"fileMd5"` //文件MD5 + IsDisease int `xorm:"TINYINT(1)" json:"isDisease"` //是否有病害 OperationLogId int64 `xorm:"INT(11) index" json:"operationLogId"` //操作日志编号 Creator int64 `xorm:"INT(11) index" json:"creator"` //上传人 CreateAt int64 `xorm:"created" json:"createAt"` //上传时间 diff --git a/pkg/minio/index.go b/pkg/minio/index.go new file mode 100644 index 0000000..b407b6f --- /dev/null +++ b/pkg/minio/index.go @@ -0,0 +1,46 @@ +package minio + +import ( + "context" + "git.hpds.cc/Component/logging" + "github.com/minio/minio-go/v7" + "github.com/minio/minio-go/v7/pkg/credentials" + "io" +) + +type MinClient struct { + Client *minio.Client + Logger *logging.Logger +} + +func NewClient(ak, sak, ep string, useSSL bool, logger *logging.Logger) *MinClient { + opt := &minio.Options{ + Creds: credentials.NewStaticV4(ak, sak, ""), + Secure: useSSL, + } + client, err := minio.New(ep, opt) + if err != nil { + return nil + } + return &MinClient{ + Client: client, + Logger: logger, + } +} + +func (cli *MinClient) UploadObject(fn, dst, bucket string) error { + _, err := cli.Client.FPutObject(context.Background(), bucket, dst, fn, minio.PutObjectOptions{}) + if err != nil { + return err + } + return nil +} + +func (cli *MinClient) GetObject(dstUrl, bucket string) ([]byte, error) { + f, err := cli.Client.GetObject(context.Background(), bucket, dstUrl, minio.GetObjectOptions{}) + if err != nil { + return nil, err + } + imgByte, _ := io.ReadAll(f) + return imgByte, nil +} diff --git a/pkg/utils/file.go b/pkg/utils/file.go new file mode 100644 index 0000000..59e067c --- /dev/null +++ b/pkg/utils/file.go @@ -0,0 +1,119 @@ +package utils + +import ( + "crypto/md5" + "encoding/hex" + "fmt" + "git.hpds.cc/Component/logging" + "go.uber.org/zap" + "hpds-iot-web/pkg/minio" + "io" + "os" + "path" + "path/filepath" + "strings" +) + +func CopyFile(src, dst string) error { + sourceFileStat, err := os.Stat(src) + if err != nil { + return err + } + + if !sourceFileStat.Mode().IsRegular() { + return fmt.Errorf("%s is not a regular file", src) + } + + source, err := os.Open(src) + if err != nil { + return err + } + defer func(source *os.File) { + _ = source.Close() + }(source) + + destination, err := os.Create(dst) + if err != nil { + return err + } + defer func(destination *os.File) { + _ = destination.Close() + }(destination) + _, err = io.Copy(destination, source) + return err +} + +func PathExists(path string) bool { + _, err := os.Stat(path) + if err == nil { + return true + } + if os.IsNotExist(err) { + return false + } + return false +} + +// ReadFile 读取到file中,再利用ioutil将file直接读取到[]byte中, 这是最优 +func ReadFile(fn string) []byte { + f, err := os.Open(fn) + if err != nil { + logging.L().Error("Read File", zap.String("File Name", fn), zap.Error(err)) + return nil + } + defer func(f *os.File) { + _ = f.Close() + }(f) + + fd, err := io.ReadAll(f) + if err != nil { + logging.L().Error("Read File To buff", zap.String("File Name", fn), zap.Error(err)) + return nil + } + + return fd +} + +func GetFileName(fn string) string { + fileType := path.Ext(fn) + return strings.TrimSuffix(fn, fileType) +} +func GetFileNameAndExt(fn string) string { + _, fileName := filepath.Split(fn) + return fileName +} +func GetFileMd5(data []byte) string { + hash := md5.New() + hash.Write(data) + return hex.EncodeToString(hash.Sum(nil)) +} + +func DownloadMinioFileToLocalPath(accessUrl, dstPath, fileName, protocol, endpoint, bucket, accessKeyId, secretAccessKey string, + logger *logging.Logger) { + + if !PathExists(path.Join(dstPath, fileName)) { + dPath := strings.Replace(accessUrl, fmt.Sprintf("%s://%s/", protocol, endpoint), "", 1) + + dPath = strings.Replace(dPath, bucket, "", 1) + minioCli := minio.NewClient(accessKeyId, secretAccessKey, endpoint, false, logger) + + imgByte, err := minioCli.GetObject(dPath, bucket) + if err != nil { + logger.With(zap.String("源文件名", accessUrl)). + With(zap.String("文件名", path.Join(dstPath, fileName))). + Error("文件下载", zap.Error(err)) + } + err = os.MkdirAll(dstPath, os.ModePerm) + if err != nil { + logger.With(zap.String("源文件名", accessUrl)). + With(zap.String("文件名", path.Join(dstPath, fileName))). + Error("创建文件下载目录", zap.Error(err)) + } + err = os.WriteFile(path.Join(dstPath, fileName), imgByte, os.ModePerm) + if err != nil { + logger.With(zap.String("源文件名", accessUrl)). + With(zap.String("文件名", path.Join(dstPath, fileName))). + Error("文件写入", zap.Error(err)) + } + } +} diff --git a/pkg/utils/http.go b/pkg/utils/http.go new file mode 100644 index 0000000..81783b2 --- /dev/null +++ b/pkg/utils/http.go @@ -0,0 +1,126 @@ +package utils + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "mime/multipart" + "net/http" + "net/url" + "path/filepath" + "strings" +) + +func HttpDo(reqUrl, method string, params map[string]string, header map[string]string) (data []byte, err error) { + var paramStr string = "" + if contentType, ok := header["Content-Type"]; ok && strings.Contains(contentType, "json") { + bytesData, _ := json.Marshal(params) + paramStr = string(bytesData) + } else { + for k, v := range params { + if len(paramStr) == 0 { + paramStr = fmt.Sprintf("%s=%s", k, url.QueryEscape(v)) + } else { + paramStr = fmt.Sprintf("%s&%s=%s", paramStr, k, url.QueryEscape(v)) + } + } + } + + client := &http.Client{} + req, err := http.NewRequest(strings.ToUpper(method), reqUrl, strings.NewReader(paramStr)) + if err != nil { + return nil, err + } + for k, v := range header { + req.Header.Set(k, v) + } + resp, err := client.Do(req) + if err != nil { + return nil, err + } + + defer func() { + if resp.Body != nil { + err = resp.Body.Close() + if err != nil { + return + } + } + }() + var body []byte + if resp.Body != nil { + body, err = io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + } + return body, nil +} + +type UploadFile struct { + // 表单名称 + Name string + Filepath string + // 文件全路径 + File *bytes.Buffer +} + +func PostFile(reqUrl string, reqParams map[string]string, contentType string, files []UploadFile, headers map[string]string) string { + requestBody, realContentType := getReader(reqParams, contentType, files) + httpRequest, _ := http.NewRequest("POST", reqUrl, requestBody) + // 添加请求头 + httpRequest.Header.Add("Content-Type", realContentType) + if headers != nil { + for k, v := range headers { + httpRequest.Header.Add(k, v) + } + } + httpClient := &http.Client{} + // 发送请求 + resp, err := httpClient.Do(httpRequest) + if err != nil { + panic(err) + } + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(resp.Body) + response, _ := io.ReadAll(resp.Body) + return string(response) +} + +func getReader(reqParams map[string]string, contentType string, files []UploadFile) (io.Reader, string) { + if strings.Index(contentType, "json") > -1 { + bytesData, _ := json.Marshal(reqParams) + return bytes.NewReader(bytesData), contentType + } else if files != nil { + body := &bytes.Buffer{} + // 文件写入 body + writer := multipart.NewWriter(body) + for _, uploadFile := range files { + part, err := writer.CreateFormFile(uploadFile.Name, filepath.Base(uploadFile.Filepath)) + if err != nil { + panic(err) + } + _, err = io.Copy(part, uploadFile.File) + } + // 其他参数列表写入 body + for k, v := range reqParams { + if err := writer.WriteField(k, v); err != nil { + panic(err) + } + } + if err := writer.Close(); err != nil { + panic(err) + } + // 上传文件需要自己专用的contentType + return body, writer.FormDataContentType() + } else { + urlValues := url.Values{} + for key, val := range reqParams { + urlValues.Set(key, val) + } + reqBody := urlValues.Encode() + return strings.NewReader(reqBody), contentType + } +}