package service import ( "context" "encoding/base64" "fmt" "git.hpds.cc/Component/logging" "hpds-iot-web/config" "hpds-iot-web/internal/proto" "hpds-iot-web/model" "hpds-iot-web/pkg/utils" "math" "math/rand" "net/http" "path" "time" "xorm.io/xorm" ) type DatasetService interface { GetOwnerProjectList(ctx context.Context, req proto.OwnerProjectRequest) (rsp *proto.BaseResponse, err error) DatasetList(ctx context.Context, req proto.DatasetRequest) (rsp *proto.BaseResponse, err error) ImportDataset(ctx context.Context, req proto.ImportDatasetRequest) (rsp *proto.BaseResponse, err error) DatasetInfo(ctx context.Context, req proto.DatasetItemRequest) (rsp *proto.BaseResponse, err error) CreateTrainDataset(ctx context.Context, req proto.TrainDatasetRequest) (rsp *proto.BaseResponse, err error) TrainDatasetList(ctx context.Context, req proto.TrainDatasetItemRequest) (rsp *proto.BaseResponse, err error) TrainDatasetFileList(ctx context.Context, req proto.TrainDatasetItemRequest) (rsp *proto.BaseResponse, err error) } func NewDatasetService(cfg *config.WebConfig, engine *xorm.Engine, logger *logging.Logger) DatasetService { return &repo{ AppConfig: cfg, engine: engine, logger: logger, } } func (rp *repo) GetOwnerProjectList(ctx context.Context, req proto.OwnerProjectRequest) (rsp *proto.BaseResponse, err error) { rsp = new(proto.BaseResponse) select { case <-ctx.Done(): err = fmt.Errorf("超时/取消") rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Message = "超时/取消" rsp.Err = ctx.Err() return rsp, ctx.Err() default: ownerList := make([]model.Owner, 0) err := rp.engine.Where("(? = '' or owner_name like ?)", req.Key, "%"+req.Key+"%"). And("status = 1").Find(&ownerList) if err != nil { goto ReturnPoint } data := make([]proto.OwnerProjectItem, len(ownerList)) for k, v := range ownerList { projectList := make([]proto.ProjectItem, 0) err = rp.engine.Table("project").Cols("concat('"+fmt.Sprintf("%d", v.OwnerId)+"-', project_id) as project_id", "project_name"). Where("owner_id = ?", v.OwnerId).And("status = 1").Find(&projectList) if err != nil { goto ReturnPoint } data[k] = proto.OwnerProjectItem{ OwnerId: v.OwnerId, OwnerName: fmt.Sprintf("%s[%s]", v.OwnerName, v.ChargeUser), ProjectList: projectList, } } rsp.Code = http.StatusOK rsp.Status = http.StatusText(http.StatusOK) rsp.Message = "成功" rsp.Data = data rsp.Err = err return rsp, err } ReturnPoint: if err != nil { rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Err = err rsp.Message = "失败" } return rsp, err } func (rp *repo) DatasetList(ctx context.Context, req proto.DatasetRequest) (rsp *proto.BaseResponse, err error) { rsp = new(proto.BaseResponse) select { case <-ctx.Done(): err = fmt.Errorf("超时/取消") rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Message = "超时/取消" rsp.Err = ctx.Err() return rsp, ctx.Err() default: list := make([]model.Dataset, 0) st := rp.engine.Where("(? = '' or dataset_name like ?)", req.DatasetName, "%"+req.DatasetName+"%"). And("(? = '' or create_at >= ?)", req.StartTime, req.StartTime). And("(? = '' or create_at < ?)", req.EndTime, req.EndTime). And("(? = 0 or owner_id = ?)", req.OwnerId, req.OwnerId). And("status = 1") if len(req.ProjectId) > 0 { st.In("project_id", req.ProjectId) } count, err := st.Limit(int(req.Size), int(((req.Page)-1)*req.Size)).FindAndCount(&list) if err != nil { goto ReturnPoint } type QuantityStatistics struct { TotalNumber int64 TotalSize int64 } data := make([]proto.DatasetItem, len(list)) for k, v := range list { detailList := make([]model.FileManager, 0) datasetCount, err := rp.engine.Where("dataset_id = ?", v.DatasetId).FindAndCount(&detailList) if err != nil { goto ReturnPoint } fm := new(model.FileManager) datasetSize, err := rp.engine.Where("dataset_id = ?", v.DatasetId).SumInt(fm, "file_size") if err != nil { goto ReturnPoint } qs := new(QuantityStatistics) _, err = rp.engine.SQL(`select sum(file_size) total_size, count(file_id) total_number from file_manager where is_disease > 0`).Get(qs) if err != nil { goto ReturnPoint } data[k] = proto.DatasetItem{ DatasetId: v.DatasetId, DatasetName: v.DatasetName, DatasetDesc: v.DatasetDesc, StoreName: v.StoreName, CategoryId: v.CategoryId, ProjectId: v.ProjectId, OwnerId: v.OwnerId, Creator: v.Creator, CreateAt: v.CreateAt, DatasetCount: datasetCount, DatasetSize: datasetSize, LabelCount: qs.TotalNumber, LabelSize: qs.TotalSize, } } rsp.Code = http.StatusOK rsp.Status = http.StatusText(http.StatusOK) rsp.Message = "成功" rsp = FillPaging(count, req.Page, req.Size, data, rsp) rsp.Err = err return rsp, err } ReturnPoint: if err != nil { rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Err = err rsp.Message = "失败" } return rsp, err } func (rp *repo) ImportDataset(ctx context.Context, req proto.ImportDatasetRequest) (rsp *proto.BaseResponse, err error) { rsp = new(proto.BaseResponse) select { case <-ctx.Done(): err = fmt.Errorf("超时/取消") rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Message = "超时/取消" rsp.Err = ctx.Err() return rsp, ctx.Err() default: item := &model.Dataset{ DatasetName: req.DatasetName, DatasetDesc: req.DatasetDesc, CategoryId: req.CategoryId, ProjectId: req.ProjectId, OwnerId: req.OwnerId, StoreName: req.StoreName, Creator: req.Creator, Status: 1, CreateAt: time.Now().Unix(), UpdateAt: time.Now().Unix(), } _, err = rp.engine.Insert(item) if err != nil { goto ReturnPoint } rsp.Code = http.StatusOK rsp.Status = http.StatusText(http.StatusOK) rsp.Message = "成功" rsp.Data = item rsp.Err = err return rsp, err } ReturnPoint: if err != nil { rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Err = err rsp.Message = "失败" } return rsp, err } func (rp *repo) DatasetInfo(ctx context.Context, req proto.DatasetItemRequest) (rsp *proto.BaseResponse, err error) { rsp = new(proto.BaseResponse) select { case <-ctx.Done(): err = fmt.Errorf("超时/取消") rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Message = "超时/取消" rsp.Err = ctx.Err() return rsp, ctx.Err() default: item := new(model.Dataset) var b bool b, err = rp.engine.ID(req.DatasetId).Get(item) if err != nil { goto ReturnPoint } if !b { err = fmt.Errorf("未能找到对应的数据集") goto ReturnPoint } rsp.Code = http.StatusOK rsp.Status = http.StatusText(http.StatusOK) rsp.Message = "成功" rsp.Err = err rsp.Data = item return rsp, err } ReturnPoint: if err != nil { rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Err = err rsp.Message = "失败" } return rsp, err } func (rp *repo) CreateTrainDataset(ctx context.Context, req proto.TrainDatasetRequest) (rsp *proto.BaseResponse, err error) { rsp = new(proto.BaseResponse) select { case <-ctx.Done(): err = fmt.Errorf("超时/取消") rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Message = "超时/取消" rsp.Err = ctx.Err() return rsp, ctx.Err() default: var ( h bool trainFileList []model.FileManager valFileList []model.FileManager testFileList []model.FileManager //wg sync.WaitGroup ) log := new(model.DatasetOperationLog) dataset := new(model.Dataset) h, err = rp.engine.ID(req.DatasetId).Get(dataset) if err != nil { goto ReturnPoint } if !h { err = fmt.Errorf("未能找到对应的采集数据集") goto ReturnPoint } fileList := make([]model.FileManager, 0) err = rp.engine.Where("dataset_id = ?", req.DatasetId).And("is_disease > 0").Find(&fileList) if err != nil { goto ReturnPoint } if req.TargetData == 0 { req.TargetData = len(fileList) } if req.TargetData > len(fileList) { err = fmt.Errorf("超出现有标注数据集数量") goto ReturnPoint } if req.SplitMethod == 1 { //随机 rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(fileList), func(i, j int) { fileList[i], fileList[j] = fileList[j], fileList[i] }) } trainNumber := int(math.Floor(float64(int64(req.TargetData)*req.TrainNumber) / 100)) valNumber := int(math.Floor(float64(int64(req.TargetData)*req.ValidationNumber) / 100)) testNumber := req.TargetData - trainNumber - valNumber if trainNumber-1 > 1 { trainFileList = fileList[:trainNumber-1] } else { trainFileList = make([]model.FileManager, 0) trainFileList = append(trainFileList, fileList[0]) } if trainNumber != trainNumber+valNumber-1 { valFileList = fileList[trainNumber : trainNumber+valNumber-1] } else { valFileList = make([]model.FileManager, 0) valFileList = append(valFileList, fileList[trainNumber]) } if trainNumber+valNumber < len(fileList) { testFileList = fileList[trainNumber+valNumber:] } else { testFileList = make([]model.FileManager, 0) testFileList = append(testFileList, fileList[trainNumber+valNumber]) } train := new(model.TrainingDataset) h, err = rp.engine.Where("name = ?", req.TrainName).Get(train) if err != nil { goto ReturnPoint } if !h { train.Name = req.TrainName train.DatasetDesc = req.TrainDesc //train.DatasetId = req.DatasetId train.CategoryId = dataset.CategoryId _, err = rp.engine.Insert(train) if err != nil { goto ReturnPoint } } log.TargetData = int64(req.TargetData) log.DatasetId = req.DatasetId log.TrainingDatasetId = train.DatasetId log.SplitMethod = req.SplitMethod log.TrainNumber = int64(trainNumber) log.ValidationNumber = int64(valNumber) log.TestNumber = int64(testNumber) log.Creator = req.UserId _, err = rp.engine.Insert(log) if err != nil { goto ReturnPoint } //wg.Add(3) go BatchCopyData(trainFileList, train.DatasetId, log.LogId, req.UserId, 1, req.TrainName, rp) //, &wg go BatchCopyData(valFileList, train.DatasetId, log.LogId, req.UserId, 2, req.TrainName, rp) //, &wg go BatchCopyData(testFileList, train.DatasetId, log.LogId, req.UserId, 3, req.TrainName, rp) //, &wg //wg.Wait() rsp.Code = http.StatusOK rsp.Status = http.StatusText(http.StatusOK) rsp.Message = "成功" rsp.Err = err rsp.Data = log return rsp, err } ReturnPoint: if err != nil { rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Err = err rsp.Message = "失败" } return rsp, err } func BatchCopyData(list []model.FileManager, trainId, logId, userId int64, categoryId int, trainName string, rp *repo) { //, wg *sync.WaitGroup batchList := make([]model.TrainingDatasetDetail, len(list)) for k, v := range list { dir := "no_disease" if v.IsDisease == 1 { dir = "disease" } utils.DownloadMinioFileToLocalPath(v.AccessUrl, path.Join(rp.AppConfig.TrainDir, trainName, model.GetTrainCategory(categoryId), dir), v.FileName, rp.AppConfig.Minio.Protocol, rp.AppConfig.Minio.Endpoint, rp.AppConfig.Minio.Bucket, rp.AppConfig.Minio.AccessKeyId, rp.AppConfig.Minio.SecretAccessKey, rp.logger) item := model.TrainingDatasetDetail{ FileName: v.FileName, FilePath: path.Join(rp.AppConfig.TrainDir, trainName, model.GetTrainCategory(categoryId), dir, v.FileName), DatasetId: trainId, CategoryId: categoryId, FileSize: v.FileSize, FileMd5: v.FileMd5, IsDisease: v.IsDisease, OperationLogId: logId, Creator: userId, CreateAt: time.Now().Unix(), UpdateAt: time.Now().Unix(), } batchList[k] = item } _, _ = rp.engine.Insert(batchList) //wg.Done() } func (rp *repo) TrainDatasetList(ctx context.Context, req proto.TrainDatasetItemRequest) (rsp *proto.BaseResponse, err error) { rsp = new(proto.BaseResponse) select { case <-ctx.Done(): err = fmt.Errorf("超时/取消") rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Message = "超时/取消" rsp.Err = ctx.Err() return rsp, ctx.Err() default: type QuantityStatistics struct { DatasetId int64 CategoryId int Total int64 } var ( count int64 list []proto.TrainingDataset ) trainingList := make([]model.TrainingDataset, 0) count, err = rp.engine.Where("(?=0 or dataset_id = ?)", req.DatasetId, req.DatasetId). And("(?= 0 or category_id = ?)", req.BizType, req.BizType). And("(? ='' or name like ?)", req.TrainName, "%"+req.TrainName+"%"). Limit(int(req.Size), int(((req.Page)-1)*req.Size)).FindAndCount(&trainingList) if err != nil { goto ReturnPoint } list = make([]proto.TrainingDataset, len(trainingList)) for k, v := range trainingList { qs := make([]QuantityStatistics, 0) err = rp.engine.SQL("select dataset_id, category_id, count(1) as total from training_dataset_detail where dataset_id = ? group by category_id, dataset_id", v.DatasetId).Find(&qs) if err != nil { goto ReturnPoint } item := proto.TrainingDataset{ DatasetId: v.DatasetId, Name: v.Name, CategoryId: v.CategoryId, DatasetDesc: v.DatasetDesc, TotalSize: 0, TrainSize: 0, ValSize: 0, TestSize: 0, StoreName: v.StoreName, CreateAt: v.CreateAt, UpdateAt: v.UpdateAt, } for _, val := range qs { switch val.CategoryId { case 1: item.TrainSize = val.Total case 2: item.ValSize = val.Total case 3: item.TestSize = val.Total } item.TotalSize += val.Total } list[k] = item } rsp.Code = http.StatusOK rsp.Status = http.StatusText(http.StatusOK) rsp.Message = "成功" rsp = FillPaging(count, req.Page, req.Size, list, rsp) rsp.Err = err return rsp, err } ReturnPoint: if err != nil { rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Err = err rsp.Message = "失败" } return rsp, err } func (rp *repo) TrainDatasetFileList(ctx context.Context, req proto.TrainDatasetItemRequest) (rsp *proto.BaseResponse, err error) { rsp = new(proto.BaseResponse) select { case <-ctx.Done(): err = fmt.Errorf("超时/取消") rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Message = "超时/取消" rsp.Err = ctx.Err() return rsp, ctx.Err() default: var ( count int64 list []proto.TrainingDatasetFileItem ) fileList := make([]model.TrainingDatasetDetail, 0) count, err = rp.engine.Where("dataset_id = ?", req.DatasetId). Limit(int(req.Size), int(((req.Page)-1)*req.Size)).FindAndCount(&fileList) if err != nil { goto ReturnPoint } list = make([]proto.TrainingDatasetFileItem, len(fileList)) for k, v := range fileList { buff := utils.ReadFile(v.FilePath) img := utils.BuffToImage(buff) buf := utils.ImageToBuff(img, "jpeg") list[k] = proto.TrainingDatasetFileItem{ DetailId: v.DetailId, FileName: v.FileName, FileSize: v.FileSize, FilePath: v.FilePath, FileContent: "data:image/jpeg;base64," + base64.StdEncoding.EncodeToString(buf.Bytes()), IsDisease: v.IsDisease, CategoryId: v.CategoryId, } } rsp.Code = http.StatusOK rsp.Status = http.StatusText(http.StatusOK) rsp.Message = "成功" rsp = FillPaging(count, req.Page, req.Size, list, rsp) rsp.Err = err return rsp, err } ReturnPoint: if err != nil { rsp.Code = http.StatusInternalServerError rsp.Status = http.StatusText(http.StatusInternalServerError) rsp.Err = err rsp.Message = "失败" } return rsp, err }