package service import ( "context" "fmt" "git.hpds.cc/Component/logging" "go.uber.org/zap" "hpds-iot-web/config" "hpds-iot-web/internal/proto" "hpds-iot-web/model" "hpds-iot-web/pkg/utils" "math" "math/rand" "net/http" "time" "xorm.io/xorm" ) type DiseaseService interface { DiseaseList(ctx context.Context, req proto.DiseaseRequest) (rsp *proto.BaseResponse, err error) DiseaseListNew(ctx context.Context, req proto.DiseaseRequest) (rsp *proto.BaseResponse, err error) DiseaseStatistics(ctx context.Context) (rsp *proto.BaseResponse, err error) DiseaseTypeList(ctx context.Context, req proto.DiseaseTypeRequest) (rsp *proto.BaseResponse, err error) AddDiseaseType(ctx context.Context, req proto.DiseaseTypeItemRequest) (rsp *proto.BaseResponse, err error) EditDiseaseType(ctx context.Context, req proto.DiseaseTypeItemRequest) (rsp *proto.BaseResponse, err error) DeleteDiseaseType(ctx context.Context, req proto.DiseaseTypeItemRequest) (rsp *proto.BaseResponse, err error) CreateTrainDatasetByDisease(ctx context.Context, req proto.TrainDatasetRequest) (rsp *proto.BaseResponse, err error) } func NewDiseaseService(cfg *config.WebConfig, engine *xorm.Engine, logger *logging.Logger) DiseaseService { return &repo{ AppConfig: cfg, engine: engine, logger: logger, } } func (rp *repo) DiseaseList(ctx context.Context, req proto.DiseaseRequest) (rsp *proto.BaseResponse, err error) { rsp = new(proto.BaseResponse) select { case <-ctx.Done(): err = fmt.Errorf("超时/取消") rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Message = "超时/取消" rsp.Err = ctx.Err() return rsp, ctx.Err() default: data := make([]model.Disease, 0) count, err := rp.engine.Where("(? = '' or disease_name like ?)", req.Key, "%"+req.Key+"%"). And("(? = 0 or category_id = ?)", req.DiseaseType, req.DiseaseType). Limit(int(req.Size), int(((req.Page)-1)*req.Size)). FindAndCount(&data) if err != nil { goto ReturnPoint } rsp.Code = http.StatusOK rsp.Status = http.StatusText(http.StatusOK) rsp.Message = "成功" rsp = FillPaging(count, req.Page, req.Size, data, rsp) rsp.Err = err return rsp, err } ReturnPoint: if err != nil { rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Err = err rsp.Message = "失败" } return rsp, err } func (rp *repo) DiseaseListNew(ctx context.Context, req proto.DiseaseRequest) (rsp *proto.BaseResponse, err error) { rsp = new(proto.BaseResponse) select { case <-ctx.Done(): err = fmt.Errorf("超时/取消") rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Message = "超时/取消" rsp.Err = ctx.Err() return rsp, ctx.Err() default: var count int64 list := make([]model.LabelFile, 0) count, err = rp.engine.Where("(? = 0 or category_id = ?)", req.DiseaseType, req.DiseaseType). And("(?=0 or file_type = ?)", req.FileType, req.FileType). And("(?=0 or label_type = ?)", req.LabelType, req.LabelType). And("(?=-1 or pid = ?)", req.Pid, req.Pid). Limit(int(req.Size), int(((req.Page)-1)*req.Size)). FindAndCount(&list) if err != nil { goto ReturnPoint } data := make([]proto.DiseaseFileInfoItem, len(list)) for k, v := range list { item := proto.DiseaseFileInfoItem{ FileId: v.FileId, FileName: v.FileName, FilePath: v.FilePath, CategoryId: v.CategoryId, CategoryName: model.GetBizType(v.CategoryId), FileSize: v.FileSize, LabelType: v.LabelType, FileType: v.FileType, FileTypeName: model.GetFileType(v.FileType), LabelTypeName: model.GetLabelType(v.LabelType), CreateAt: v.CreateAt, UpdateAt: v.UpdateAt, Pid: v.Pid, } if item.FileType == 1 || (item.FileType == 3 && item.Pid > 0) { item.FileContent = utils.ImgFileToBase64(v.FilePath) } if item.FileType == 2 { item.FileContent = utils.FileToBase64(v.FilePath) } data[k] = item } rsp.Code = http.StatusOK rsp.Status = http.StatusText(http.StatusOK) rsp.Message = "成功" rsp = FillPaging(count, req.Page, req.Size, data, rsp) rsp.Err = err return rsp, err } ReturnPoint: if err != nil { rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Err = err rsp.Message = "失败" } return rsp, err } func (rp *repo) DiseaseStatistics(ctx context.Context) (rsp *proto.BaseResponse, err error) { rsp = new(proto.BaseResponse) select { case <-ctx.Done(): err = fmt.Errorf("超时/取消") rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Message = "超时/取消" rsp.Err = ctx.Err() return rsp, ctx.Err() default: list := make([]*proto.DiseaseStatisticsItem, 5) type Statistics struct { CategoryId int TotalCount int64 TotalSize int64 } //所有的数据 statList := make([]Statistics, 0) err = rp.engine.SQL(`select category_id, sum(file_size) total_size, count(file_id) total_count from label_file group by category_id;`).Find(&statList) if err != nil { goto ReturnPoint } totalItem := &proto.DiseaseStatisticsItem{ DiseaseType: 0, DiseaseName: "数据量总计", } for _, v := range statList { item := &proto.DiseaseStatisticsItem{ DiseaseType: v.CategoryId, DiseaseName: model.GetBizType(v.CategoryId) + "数据", TotalNum: v.TotalCount, TotalSize: v.TotalSize, } totalItem.TotalNum += v.TotalCount totalItem.TotalSize += v.TotalSize list[v.CategoryId] = item } list[0] = totalItem rsp.Code = http.StatusOK rsp.Status = http.StatusText(http.StatusOK) rsp.Message = "成功" rsp.Data = list rsp.Err = err return rsp, err } ReturnPoint: if err != nil { rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Err = err rsp.Message = "失败" } return rsp, err } func (rp *repo) DiseaseTypeList(ctx context.Context, req proto.DiseaseTypeRequest) (rsp *proto.BaseResponse, err error) { rsp = new(proto.BaseResponse) select { case <-ctx.Done(): err = fmt.Errorf("超时/取消") rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Message = "超时/取消" rsp.Err = ctx.Err() return rsp, ctx.Err() default: data := make([]model.DiseaseType, 0) count, err := rp.engine.Where("(? = '' or type_name like ?)", req.Key, "%"+req.Key+"%"). And("(? = 0 or category_id = ?)", req.CategoryId, req.CategoryId). And("status = 1").Limit(int(req.Size), int(((req.Page)-1)*req.Size)). FindAndCount(&data) if err != nil { goto ReturnPoint } rsp.Code = http.StatusOK rsp.Status = http.StatusText(http.StatusOK) rsp.Message = "成功" rsp = FillPaging(count, req.Page, req.Size, data, rsp) rsp.Err = err return rsp, err } ReturnPoint: if err != nil { rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Err = err rsp.Message = "失败" } return rsp, err } func (rp *repo) AddDiseaseType(ctx context.Context, req proto.DiseaseTypeItemRequest) (rsp *proto.BaseResponse, err error) { rsp = new(proto.BaseResponse) select { case <-ctx.Done(): err = fmt.Errorf("超时/取消") rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Message = "超时/取消" rsp.Err = ctx.Err() return rsp, ctx.Err() default: item := &model.DiseaseType{ TypeName: req.TypeName, CategoryId: req.CategoryId, Status: 1, CreateAt: time.Now().Unix(), UpdateAt: time.Now().Unix(), } _, err = rp.engine.Insert(item) if err != nil { goto ReturnPoint } rsp.Code = http.StatusOK rsp.Status = http.StatusText(http.StatusOK) rsp.Message = "新增病害类型成功" rsp.Err = ctx.Err() rsp.Data = item return rsp, err } ReturnPoint: if err != nil { rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Err = err rsp.Message = "失败" } return rsp, err } func (rp *repo) EditDiseaseType(ctx context.Context, req proto.DiseaseTypeItemRequest) (rsp *proto.BaseResponse, err error) { rsp = new(proto.BaseResponse) select { case <-ctx.Done(): err = fmt.Errorf("超时/取消") rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Message = "超时/取消" rsp.Err = ctx.Err() return rsp, ctx.Err() default: var h bool item := new(model.DiseaseType) h, err = rp.engine.ID(req.TypeId).Get(item) if err != nil { goto ReturnPoint } if !h { err = fmt.Errorf("未能找到对应的类型") goto ReturnPoint } if len(req.TypeName) > 0 { item.TypeName = req.TypeName } if req.CategoryId > 0 { item.CategoryId = req.CategoryId } item.UpdateAt = time.Now().Unix() _, err = rp.engine.ID(req.TypeId).AllCols().Update(item) if err != nil { goto ReturnPoint } rsp.Code = http.StatusOK rsp.Status = http.StatusText(http.StatusOK) rsp.Message = "修改病害类型成功" rsp.Err = ctx.Err() rsp.Data = item return rsp, err } ReturnPoint: if err != nil { rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Err = err rsp.Message = "失败" } return rsp, err } func (rp *repo) DeleteDiseaseType(ctx context.Context, req proto.DiseaseTypeItemRequest) (rsp *proto.BaseResponse, err error) { rsp = new(proto.BaseResponse) select { case <-ctx.Done(): err = fmt.Errorf("超时/取消") rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Message = "超时/取消" rsp.Err = ctx.Err() return rsp, ctx.Err() default: var h bool item := new(model.DiseaseType) h, err = rp.engine.ID(req.TypeId).Get(item) if err != nil { goto ReturnPoint } if !h { err = fmt.Errorf("未能找到对应的类型") goto ReturnPoint } item.Status = 0 item.UpdateAt = time.Now().Unix() _, err = rp.engine.ID(req.TypeId).AllCols().Update(item) if err != nil { goto ReturnPoint } rsp.Code = http.StatusOK rsp.Status = http.StatusText(http.StatusOK) rsp.Message = "删除病害类型成功" rsp.Err = ctx.Err() rsp.Data = item return rsp, err } ReturnPoint: if err != nil { rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Err = err rsp.Message = "失败" } return rsp, err } func (rp *repo) CreateTrainDatasetByDisease(ctx context.Context, req proto.TrainDatasetRequest) (rsp *proto.BaseResponse, err error) { rsp = new(proto.BaseResponse) select { case <-ctx.Done(): err = fmt.Errorf("超时/取消") rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Message = "超时/取消" rsp.Err = ctx.Err() return rsp, ctx.Err() default: var ( h bool //trainFileList []model.LabelFile trainDiseaseFileList []model.LabelFile trainNoDiseaseFileList []model.LabelFile //valFileList []model.LabelFile valDiseaseFileList []model.LabelFile valNoDiseaseFileList []model.LabelFile //testFileList []model.LabelFile testDiseaseFileList []model.LabelFile testNoDiseaseFileList []model.LabelFile trainDiseaseCount int64 trainNoDiseaseCount int64 //wg sync.WaitGroup ) log := new(model.DatasetOperationLog) if err != nil { goto ReturnPoint } fileList := make([]model.LabelFile, 0) err = rp.engine.Where("category_id = ? and file_type = 1", req.BizType).Find(&fileList) if err != nil { goto ReturnPoint } trainDiseaseFileList = make([]model.LabelFile, 0) err = rp.engine.Where("category_id = ? and label_type = 1 and file_type = 1", req.BizType).Find(&trainDiseaseFileList) if err != nil { goto ReturnPoint } trainNoDiseaseFileList = make([]model.LabelFile, 0) err = rp.engine.Where("category_id = ? and label_type = 2 and file_type = 1", req.BizType).Find(&trainNoDiseaseFileList) if err != nil { goto ReturnPoint } if req.TargetData == 0 { req.TargetData = len(fileList) } if req.TargetData > len(fileList) { err = fmt.Errorf("超出现有标注数据集数量") goto ReturnPoint } trainNumber := int(math.Floor(float64(int64(req.TargetData)*req.TrainNumber) / 100)) valNumber := int(math.Floor(float64(int64(req.TargetData)*req.ValidationNumber) / 100)) testNumber := req.TargetData - trainNumber - valNumber trainDiseaseCount = int64(float64(int64(len(trainDiseaseFileList))*req.TrainNumber) / 100) trainNoDiseaseCount = int64(float64(int64(len(trainNoDiseaseFileList))*req.TrainNumber) / 100) if trainDiseaseCount+trainNoDiseaseCount > int64(trainNumber) { if trainDiseaseCount > int64(trainNumber/2) { trainDiseaseCount = int64(trainNumber / 2) } if trainNoDiseaseCount > int64(trainNumber)-trainDiseaseCount { trainNoDiseaseCount = int64(trainNumber) - trainDiseaseCount } } valDiseaseCount := int64(float64(int64(len(trainDiseaseFileList))*req.ValidationNumber) / 100) valNoDiseaseCount := int64(float64(int64(len(trainNoDiseaseFileList))*req.ValidationNumber) / 100) if valDiseaseCount+valNoDiseaseCount > int64(valNumber) { if valDiseaseCount > int64(valNumber/2) { valDiseaseCount = int64(valNumber / 2) } if valNoDiseaseCount > int64(valNumber)-valDiseaseCount { valNoDiseaseCount = int64(valNumber) - valDiseaseCount } } testDiseaseCount := int64(float64(int64(len(trainDiseaseFileList))*req.TestNumber) / 100) testNoDiseaseCount := int64(float64(int64(len(trainNoDiseaseFileList))*req.TestNumber) / 100) if testDiseaseCount+testNoDiseaseCount > int64(testNumber) { if testDiseaseCount > int64(testNumber/2) { testDiseaseCount = int64(testNumber / 2) } if testNoDiseaseCount > int64(testNumber)-testDiseaseCount { testNoDiseaseCount = int64(testNumber) - testDiseaseCount } } if req.SplitMethod == 1 { rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(trainDiseaseFileList), func(i, j int) { trainDiseaseFileList[i], trainDiseaseFileList[j] = trainDiseaseFileList[j], trainDiseaseFileList[i] }) rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(trainNoDiseaseFileList), func(i, j int) { trainNoDiseaseFileList[i], trainNoDiseaseFileList[j] = trainNoDiseaseFileList[j], trainNoDiseaseFileList[i] }) rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(fileList), func(i, j int) { fileList[i], fileList[j] = fileList[j], fileList[i] }) } testDiseaseFileList = trainDiseaseFileList[:testDiseaseCount] testNoDiseaseFileList = trainNoDiseaseFileList[:testNoDiseaseCount] valDiseaseFileList = trainDiseaseFileList[testDiseaseCount : testDiseaseCount+valDiseaseCount] valNoDiseaseFileList = trainNoDiseaseFileList[testNoDiseaseCount : testNoDiseaseCount+valNoDiseaseCount] trainDiseaseFileList = trainDiseaseFileList[testDiseaseCount+valDiseaseCount : testDiseaseCount+valDiseaseCount+trainDiseaseCount] trainNoDiseaseFileList = trainNoDiseaseFileList[testNoDiseaseCount+valNoDiseaseCount : testNoDiseaseCount+valNoDiseaseCount+trainNoDiseaseCount] rp.logger.With(zap.String("创建训练集", "数据集大小"), zap.Int("有病害训练数据集", len(trainDiseaseFileList)), zap.Int("无病害训练数据集", len(trainNoDiseaseFileList)), zap.Int("有病害验证数据集", len(valDiseaseFileList)), zap.Int("无病害验证数据集", len(valNoDiseaseFileList)), ).Info("总数据集", zap.Int("len(fileList)", len(fileList))) train := new(model.TrainingDataset) h, err = rp.engine.Where("name = ?", req.TrainName).Get(train) if err != nil { goto ReturnPoint } if !h { train.Name = req.TrainName train.DatasetDesc = req.TrainDesc //train.DatasetId = req.DatasetId train.CategoryId = req.BizType train.ValidationNumber = float64(req.ValidationNumber) train.TestNumber = float64(req.TestNumber) _, err = rp.engine.Insert(train) if err != nil { goto ReturnPoint } } log.TargetData = int64(req.TargetData) log.DatasetId = req.DatasetId log.TrainingDatasetId = train.DatasetId log.SplitMethod = req.SplitMethod log.TrainNumber = int64(trainNumber) log.ValidationNumber = int64(valNumber) log.TestNumber = int64(testNumber) log.Creator = req.UserId _, err = rp.engine.Insert(log) if err != nil { goto ReturnPoint } list := make([]model.TrainingDatasetDetail, 0) for _, v := range trainDiseaseFileList { item := model.TrainingDatasetDetail{ FileName: v.FileName, FilePath: v.FilePath, DatasetId: train.DatasetId, CategoryId: 1, FileSize: v.FileSize, IsDisease: 1, OperationLogId: log.LogId, Creator: req.UserId, CreateAt: time.Now().Unix(), UpdateAt: time.Now().Unix(), } list = append(list, item) } for _, v := range trainNoDiseaseFileList { item := model.TrainingDatasetDetail{ FileName: v.FileName, FilePath: v.FilePath, DatasetId: train.DatasetId, CategoryId: 1, FileSize: v.FileSize, IsDisease: 2, OperationLogId: log.LogId, Creator: req.UserId, CreateAt: time.Now().Unix(), UpdateAt: time.Now().Unix(), } list = append(list, item) } for _, v := range valDiseaseFileList { item := model.TrainingDatasetDetail{ FileName: v.FileName, FilePath: v.FilePath, DatasetId: train.DatasetId, CategoryId: 3, FileSize: v.FileSize, IsDisease: 1, OperationLogId: log.LogId, Creator: req.UserId, CreateAt: time.Now().Unix(), UpdateAt: time.Now().Unix(), } list = append(list, item) } for _, v := range valNoDiseaseFileList { item := model.TrainingDatasetDetail{ FileName: v.FileName, FilePath: v.FilePath, DatasetId: train.DatasetId, CategoryId: 3, FileSize: v.FileSize, IsDisease: 2, OperationLogId: log.LogId, Creator: req.UserId, CreateAt: time.Now().Unix(), UpdateAt: time.Now().Unix(), } list = append(list, item) } for _, v := range testDiseaseFileList { item := model.TrainingDatasetDetail{ FileName: v.FileName, FilePath: v.FilePath, DatasetId: train.DatasetId, CategoryId: 2, FileSize: v.FileSize, IsDisease: 1, OperationLogId: log.LogId, Creator: req.UserId, CreateAt: time.Now().Unix(), UpdateAt: time.Now().Unix(), } list = append(list, item) } for _, v := range testNoDiseaseFileList { item := model.TrainingDatasetDetail{ FileName: v.FileName, FilePath: v.FilePath, DatasetId: train.DatasetId, CategoryId: 2, FileSize: v.FileSize, IsDisease: 2, OperationLogId: log.LogId, Creator: req.UserId, CreateAt: time.Now().Unix(), UpdateAt: time.Now().Unix(), } list = append(list, item) } _, err = rp.engine.Insert(list) if err != nil { goto ReturnPoint } rsp.Code = http.StatusOK rsp.Status = http.StatusText(http.StatusOK) rsp.Message = "成功" rsp.Err = err rsp.Data = log return rsp, err } ReturnPoint: if err != nil { rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Err = err rsp.Message = "失败" } return rsp, err }