package service

import (
	"context"
	"fmt"
	"math"
	"math/rand"
	"net/http"
	"time"

	"git.hpds.cc/Component/logging"
	"hpds-iot-web/config"
	"hpds-iot-web/internal/proto"
	"hpds-iot-web/model"
	"xorm.io/xorm"
)

// DatasetService groups the dataset-related operations implemented by repo.
// Every method returns a populated *proto.BaseResponse; on failure the
// response carries HTTP 500 and the same error is also returned.
type DatasetService interface {
	// GetOwnerProjectList lists active owners (optionally filtered by name)
	// together with their active projects.
	GetOwnerProjectList(ctx context.Context, req proto.OwnerProjectRequest) (rsp *proto.BaseResponse, err error)
	// DatasetList returns one page of active datasets matching the filters.
	DatasetList(ctx context.Context, req proto.DatasetRequest) (rsp *proto.BaseResponse, err error)
	// ImportDataset inserts a new dataset record.
	ImportDataset(ctx context.Context, req proto.ImportDatasetRequest) (rsp *proto.BaseResponse, err error)
	// DatasetInfo loads a single dataset by id.
	DatasetInfo(ctx context.Context, req proto.DatasetItemRequest) (rsp *proto.BaseResponse, err error)
	// CreateTraining splits a dataset's files into train/validation/test
	// subsets and records the operation.
	CreateTraining(ctx context.Context, req proto.TrainDatasetRequest) (rsp *proto.BaseResponse, err error)
}

// NewDatasetService returns a DatasetService backed by the given
// configuration, xorm engine and logger.
func NewDatasetService(cfg *config.WebConfig, engine *xorm.Engine, logger *logging.Logger) DatasetService {
	return &repo{
		AppConfig: cfg,
		engine:    engine,
		logger:    logger,
	}
}

// GetOwnerProjectList returns every owner with status = 1 whose name contains
// req.Key (all owners when req.Key is empty), each with its status = 1
// projects. Project ids are returned as "<ownerId>-<projectId>".
//
// If ctx is already done the call fails immediately with ctx.Err().
func (rp *repo) GetOwnerProjectList(ctx context.Context, req proto.OwnerProjectRequest) (rsp *proto.BaseResponse, err error) {
	rsp = new(proto.BaseResponse)
	select {
	case <-ctx.Done():
		rsp.Code = http.StatusInternalServerError
		rsp.Status = http.StatusText(http.StatusInternalServerError)
		rsp.Message = "超时/取消"
		rsp.Err = ctx.Err()
		return rsp, ctx.Err()
	default:
		ownerList := make([]model.Owner, 0)
		// Assign to the named result (not :=): a shadowed err would be
		// invisible at ReturnPoint and the failure would be reported as success.
		err = rp.engine.Where("(? = '' or owner_name like ?)", req.Key, "%"+req.Key+"%").
			And("status = 1").Find(&ownerList)
		if err != nil {
			goto ReturnPoint
		}
		data := make([]proto.OwnerProjectItem, len(ownerList))
		for k, v := range ownerList {
			projectList := make([]proto.ProjectItem, 0)
			// OwnerId is formatted with %d, so only digits reach the SQL text.
			projectIDCol := fmt.Sprintf("concat('%d-', project_id) as project_id", v.OwnerId)
			err = rp.engine.Table("project").Cols(projectIDCol, "project_name").
				Where("owner_id = ?", v.OwnerId).And("status = 1").Find(&projectList)
			if err != nil {
				goto ReturnPoint
			}
			data[k] = proto.OwnerProjectItem{
				OwnerId:     v.OwnerId,
				OwnerName:   fmt.Sprintf("%s[%s]", v.OwnerName, v.ChargeUser),
				ProjectList: projectList,
			}
		}
		rsp.Code = http.StatusOK
		rsp.Status = http.StatusText(http.StatusOK)
		rsp.Message = "成功"
		rsp.Data = data
		rsp.Err = err
		return rsp, err
	}
ReturnPoint:
	if err != nil {
		rsp.Code = http.StatusInternalServerError
		rsp.Status = http.StatusText(http.StatusInternalServerError)
		rsp.Err = err
		rsp.Message = "失败"
	}
	return rsp, err
}

// DatasetList returns the req.Page-th page (req.Size rows per page) of
// datasets with status = 1, filtered by name, creation-time window, owner and
// project ids; empty/zero filter values are ignored. For each dataset it also
// reports the number and total file_size of its data_type = 1 files.
//
// If ctx is already done the call fails immediately with ctx.Err().
func (rp *repo) DatasetList(ctx context.Context, req proto.DatasetRequest) (rsp *proto.BaseResponse, err error) {
	rsp = new(proto.BaseResponse)
	select {
	case <-ctx.Done():
		rsp.Code = http.StatusInternalServerError
		rsp.Status = http.StatusText(http.StatusInternalServerError)
		rsp.Message = "超时/取消"
		rsp.Err = ctx.Err()
		return rsp, ctx.Err()
	default:
		var count int64
		list := make([]model.Dataset, 0)
		st := rp.engine.Where("(? = '' or dataset_name like ?)", req.DatasetName, "%"+req.DatasetName+"%").
			And("(? = '' or create_at >= ?)", req.StartTime, req.StartTime).
			And("(? = '' or create_at < ?)", req.EndTime, req.EndTime).
			And("(? = 0 or owner_id = ?)", req.OwnerId, req.OwnerId).
			And("status = 1")
		if len(req.ProjectId) > 0 {
			st.In("project_id", req.ProjectId)
		}
		// All queries below assign to the named result err (never :=) so that
		// a failure is still visible after the jump to ReturnPoint.
		count, err = st.Limit(int(req.Size), int(((req.Page)-1)*req.Size)).FindAndCount(&list)
		if err != nil {
			goto ReturnPoint
		}
		data := make([]proto.DatasetItem, len(list))
		for k, v := range list {
			var datasetCount, datasetSize int64
			// Only the number of files is needed, so count in the database
			// instead of loading every row.
			datasetCount, err = rp.engine.Where("dataset_id = ?", v.DatasetId).And("data_type=1").
				Count(new(model.FileManager))
			if err != nil {
				goto ReturnPoint
			}
			datasetSize, err = rp.engine.Where("dataset_id = ?", v.DatasetId).And("data_type=1").
				SumInt(new(model.FileManager), "file_size")
			if err != nil {
				goto ReturnPoint
			}
			data[k] = proto.DatasetItem{
				DatasetId:    v.DatasetId,
				DatasetName:  v.DatasetName,
				DatasetDesc:  v.DatasetDesc,
				StoreName:    v.StoreName,
				CategoryId:   v.CategoryId,
				ProjectId:    v.ProjectId,
				OwnerId:      v.OwnerId,
				Creator:      v.Creator,
				CreateAt:     v.CreateAt,
				DatasetCount: datasetCount,
				DatasetSize:  datasetSize,
			}
		}
		rsp.Code = http.StatusOK
		rsp.Status = http.StatusText(http.StatusOK)
		rsp.Message = "成功"
		rsp = FillPaging(count, req.Page, req.Size, data, rsp)
		rsp.Err = err
		return rsp, err
	}
ReturnPoint:
	if err != nil {
		rsp.Code = http.StatusInternalServerError
		rsp.Status = http.StatusText(http.StatusInternalServerError)
		rsp.Err = err
		rsp.Message = "失败"
	}
	return rsp, err
}

// ImportDataset inserts a dataset row built from req with Status = 1 and the
// current Unix time as both CreateAt and UpdateAt, and returns the inserted
// record in rsp.Data.
//
// If ctx is already done the call fails immediately with ctx.Err().
func (rp *repo) ImportDataset(ctx context.Context, req proto.ImportDatasetRequest) (rsp *proto.BaseResponse, err error) {
	rsp = new(proto.BaseResponse)
	select {
	case <-ctx.Done():
		rsp.Code = http.StatusInternalServerError
		rsp.Status = http.StatusText(http.StatusInternalServerError)
		rsp.Message = "超时/取消"
		rsp.Err = ctx.Err()
		return rsp, ctx.Err()
	default:
		// Read the clock once so CreateAt and UpdateAt are guaranteed equal.
		now := time.Now().Unix()
		item := &model.Dataset{
			DatasetName: req.DatasetName,
			DatasetDesc: req.DatasetDesc,
			CategoryId:  req.CategoryId,
			ProjectId:   req.ProjectId,
			OwnerId:     req.OwnerId,
			StoreName:   req.StoreName,
			Creator:     req.Creator,
			Status:      1,
			CreateAt:    now,
			UpdateAt:    now,
		}
		_, err = rp.engine.Insert(item)
		if err != nil {
			goto ReturnPoint
		}
		rsp.Code = http.StatusOK
		rsp.Status = http.StatusText(http.StatusOK)
		rsp.Message = "成功"
		rsp.Data = item
		rsp.Err = err
		return rsp, err
	}
ReturnPoint:
	if err != nil {
		rsp.Code = http.StatusInternalServerError
		rsp.Status = http.StatusText(http.StatusInternalServerError)
		rsp.Err = err
		rsp.Message = "失败"
	}
	return rsp, err
}

// DatasetInfo loads the dataset whose primary key is req.DatasetId and
// returns it in rsp.Data. A missing row is reported as an error.
//
// If ctx is already done the call fails immediately with ctx.Err().
func (rp *repo) DatasetInfo(ctx context.Context, req proto.DatasetItemRequest) (rsp *proto.BaseResponse, err error) {
	rsp = new(proto.BaseResponse)
	select {
	case <-ctx.Done():
		rsp.Code = http.StatusInternalServerError
		rsp.Status = http.StatusText(http.StatusInternalServerError)
		rsp.Message = "超时/取消"
		rsp.Err = ctx.Err()
		return rsp, ctx.Err()
	default:
		item := new(model.Dataset)
		var found bool
		found, err = rp.engine.ID(req.DatasetId).Get(item)
		if err != nil {
			goto ReturnPoint
		}
		if !found {
			err = fmt.Errorf("未能找到对应的数据集")
			goto ReturnPoint
		}
		rsp.Code = http.StatusOK
		rsp.Status = http.StatusText(http.StatusOK)
		rsp.Message = "成功"
		rsp.Err = err
		rsp.Data = item
		return rsp, err
	}
ReturnPoint:
	if err != nil {
		rsp.Code = http.StatusInternalServerError
		rsp.Status = http.StatusText(http.StatusInternalServerError)
		rsp.Err = err
		rsp.Message = "失败"
	}
	return rsp, err
}

// CreateTraining builds a training dataset from the files of dataset
// req.DatasetId.
//
// It takes the first req.TargetData files (all files when TargetData is 0;
// shuffled first when req.SplitMethod is 1) and partitions them into three
// contiguous, non-overlapping groups: floor(TargetData*TrainNumber/100)
// training files, floor(TargetData*ValidationNumber/100) validation files and
// the remainder as test files. The training dataset named req.TrainName is
// created if it does not exist, an operation log row is inserted and returned
// in rsp.Data, and the three groups are copied by background goroutines whose
// outcome is not reported to the caller.
//
// Errors are returned when the source dataset is missing, TargetData exceeds
// the number of files, the percentages produce a negative group size, or a
// database call fails. If ctx is already done the call fails immediately with
// ctx.Err().
func (rp *repo) CreateTraining(ctx context.Context, req proto.TrainDatasetRequest) (rsp *proto.BaseResponse, err error) {
	rsp = new(proto.BaseResponse)
	select {
	case <-ctx.Done():
		rsp.Code = http.StatusInternalServerError
		rsp.Status = http.StatusText(http.StatusInternalServerError)
		rsp.Message = "超时/取消"
		rsp.Err = ctx.Err()
		return rsp, ctx.Err()
	default:
		var (
			found         bool
			trainFileList []model.FileManager
			valFileList   []model.FileManager
			testFileList  []model.FileManager
		)
		opLog := new(model.DatasetOperationLog)
		dataset := new(model.Dataset)
		found, err = rp.engine.ID(req.DatasetId).Get(dataset)
		if err != nil {
			goto ReturnPoint
		}
		if !found {
			err = fmt.Errorf("未能找到对应的采集数据集")
			goto ReturnPoint
		}
		fileList := make([]model.FileManager, 0)
		err = rp.engine.Where("dataset_id = ?", req.DatasetId).Find(&fileList)
		if err != nil {
			goto ReturnPoint
		}
		if req.TargetData == 0 {
			req.TargetData = len(fileList)
		}
		if req.TargetData > len(fileList) {
			err = fmt.Errorf("超出现有数据集数量")
			goto ReturnPoint
		}
		if req.SplitMethod == 1 { // random split
			// A local generator avoids reseeding the process-wide source.
			rng := rand.New(rand.NewSource(time.Now().UnixNano()))
			rng.Shuffle(len(fileList), func(i, j int) {
				fileList[i], fileList[j] = fileList[j], fileList[i]
			})
		}
		trainNumber := int(math.Floor(float64(int64(req.TargetData)*req.TrainNumber) / 100))
		valNumber := int(math.Floor(float64(int64(req.TargetData)*req.ValidationNumber) / 100))
		testNumber := req.TargetData - trainNumber - valNumber
		// Negative sizes (negative input or percentages summing past 100)
		// would make the slice expressions below panic.
		if trainNumber < 0 || valNumber < 0 || testNumber < 0 {
			err = fmt.Errorf("数据集划分比例无效")
			goto ReturnPoint
		}
		// Contiguous ranges: [0,train) [train,train+val) [train+val,TargetData).
		valEnd := trainNumber + valNumber
		trainFileList = fileList[:trainNumber]
		valFileList = fileList[trainNumber:valEnd]
		testFileList = fileList[valEnd:req.TargetData]

		train := new(model.TrainingDataset)
		found, err = rp.engine.Where("name = ?", req.TrainName).Get(train)
		if err != nil {
			goto ReturnPoint
		}
		if !found {
			train.Name = req.TrainName
			train.DatasetDesc = req.TrainDesc
			train.DatasetId = req.DatasetId
			train.CategoryId = dataset.CategoryId
			_, err = rp.engine.Insert(train)
			if err != nil {
				goto ReturnPoint
			}
		}
		opLog.TargetData = int64(req.TargetData)
		opLog.DatasetId = req.DatasetId
		// NOTE(review): train.DatasetId is set to the source dataset id above
		// when the row is created; confirm that this is the intended training
		// dataset identifier here and in the copy calls below.
		opLog.TrainingDatasetId = train.DatasetId
		opLog.SplitMethod = req.SplitMethod
		opLog.TrainNumber = int64(trainNumber)
		opLog.ValidationNumber = int64(valNumber)
		opLog.TestNumber = int64(testNumber)
		opLog.Creator = req.UserId
		_, err = rp.engine.Insert(opLog)
		if err != nil {
			goto ReturnPoint
		}
		// The copies run in the background; the response is sent without
		// waiting for them. The last argument tags the group: 1 = training,
		// 2 = validation, 3 = test.
		go BatchCopyData(trainFileList, train.DatasetId, opLog.LogId, req.UserId, 1, rp.engine)
		go BatchCopyData(valFileList, train.DatasetId, opLog.LogId, req.UserId, 2, rp.engine)
		go BatchCopyData(testFileList, train.DatasetId, opLog.LogId, req.UserId, 3, rp.engine)
		rsp.Code = http.StatusOK
		rsp.Status = http.StatusText(http.StatusOK)
		rsp.Message = "成功"
		rsp.Err = err
		rsp.Data = opLog
		return rsp, err
	}
ReturnPoint:
	if err != nil {
		rsp.Code = http.StatusInternalServerError
		rsp.Status = http.StatusText(http.StatusInternalServerError)
		rsp.Err = err
		rsp.Message = "失败"
	}
	return rsp, err
}

// BatchCopyData inserts one model.TrainingDatasetDetail row per entry of list
// in a single Insert call. Each row records the file's name, access URL, size
// and MD5 together with trainId, logId, userId and categoryId; callers in this
// file pass 1, 2 or 3 as categoryId for the training, validation and test
// groups. An empty list is a no-op.
func BatchCopyData(list []model.FileManager, trainId, logId, userId int64, categoryId int, engine *xorm.Engine) {
	if len(list) == 0 {
		return
	}
	now := time.Now().Unix()
	batchList := make([]model.TrainingDatasetDetail, len(list))
	for k, v := range list {
		batchList[k] = model.TrainingDatasetDetail{
			FileName:       v.FileName,
			FilePath:       v.AccessUrl,
			DatasetId:      trainId,
			CategoryId:     categoryId,
			FileSize:       v.FileSize,
			FileMd5:        v.FileMd5,
			OperationLogId: logId,
			Creator:        userId,
			CreateAt:       now,
			UpdateAt:       now,
		}
	}
	// Best effort: this runs in a fire-and-forget goroutine and the signature
	// offers no way to report a failure, so the insert error is dropped.
	_, _ = engine.Insert(batchList)
}