hpds_jkw_web/internal/service/dataset.go

354 lines
10 KiB
Go

package service
import (
"context"
"fmt"
"git.hpds.cc/Component/logging"
"hpds-iot-web/config"
"hpds-iot-web/internal/proto"
"hpds-iot-web/model"
"math"
"math/rand"
"net/http"
"time"
"xorm.io/xorm"
)
// DatasetService is the set of dataset operations exposed by this package.
// Every method receives the request context first and returns the response
// envelope together with the error (if any) that produced it.
type DatasetService interface {
	// GetOwnerProjectList lists active owners together with their active projects.
	GetOwnerProjectList(ctx context.Context, req proto.OwnerProjectRequest) (*proto.BaseResponse, error)
	// DatasetList returns one page of active datasets matching the request filters.
	DatasetList(ctx context.Context, req proto.DatasetRequest) (*proto.BaseResponse, error)
	// ImportDataset stores a new dataset record built from the request.
	ImportDataset(ctx context.Context, req proto.ImportDatasetRequest) (*proto.BaseResponse, error)
	// DatasetInfo loads a single dataset by its id.
	DatasetInfo(ctx context.Context, req proto.DatasetItemRequest) (*proto.BaseResponse, error)
	// CreateTraining splits a dataset's files into train/validation/test parts.
	CreateTraining(ctx context.Context, req proto.TrainDatasetRequest) (*proto.BaseResponse, error)
}
// NewDatasetService returns a DatasetService backed by a repo that holds the
// given application config, xorm engine and logger.
func NewDatasetService(cfg *config.WebConfig, engine *xorm.Engine, logger *logging.Logger) DatasetService {
	svc := new(repo)
	svc.AppConfig = cfg
	svc.engine = engine
	svc.logger = logger
	return svc
}
// GetOwnerProjectList returns every owner with status = 1 whose owner_name
// contains req.Key (an empty key matches all owners), each paired with the
// status = 1 projects that belong to it. Project ids are returned in the form
// "<ownerId>-<projectId>".
//
// If ctx is already done the call returns ctx.Err() without touching the
// database. Any query failure is returned both as the error result and inside
// the response envelope (Code 500, Message "失败").
func (rp *repo) GetOwnerProjectList(ctx context.Context, req proto.OwnerProjectRequest) (rsp *proto.BaseResponse, err error) {
	rsp = new(proto.BaseResponse)
	select {
	case <-ctx.Done():
		rsp.Code = http.StatusInternalServerError
		rsp.Status = http.StatusText(http.StatusInternalServerError)
		rsp.Message = "超时/取消"
		rsp.Err = ctx.Err()
		return rsp, ctx.Err()
	default:
	}

	// fail fills the error envelope and hands the cause back to the caller.
	// The previous goto-based version shadowed err inside the select clause,
	// so failures were reported as an empty response with a nil error.
	fail := func(cause error) (*proto.BaseResponse, error) {
		rsp.Code = http.StatusInternalServerError
		rsp.Status = http.StatusText(http.StatusInternalServerError)
		rsp.Err = cause
		rsp.Message = "失败"
		return rsp, cause
	}

	ownerList := make([]model.Owner, 0)
	err = rp.engine.Where("(? = '' or owner_name like ?)", req.Key, "%"+req.Key+"%").
		And("status = 1").Find(&ownerList)
	if err != nil {
		return fail(err)
	}

	data := make([]proto.OwnerProjectItem, len(ownerList))
	for k, v := range ownerList {
		projectList := make([]proto.ProjectItem, 0)
		// The owner id is formatted with %d, so only digits reach the SQL text.
		idColumn := fmt.Sprintf("concat('%d-', project_id) as project_id", v.OwnerId)
		err = rp.engine.Table("project").Cols(idColumn, "project_name").
			Where("owner_id = ?", v.OwnerId).And("status = 1").Find(&projectList)
		if err != nil {
			return fail(err)
		}
		data[k] = proto.OwnerProjectItem{
			OwnerId:     v.OwnerId,
			OwnerName:   fmt.Sprintf("%s[%s]", v.OwnerName, v.ChargeUser),
			ProjectList: projectList,
		}
	}

	rsp.Code = http.StatusOK
	rsp.Status = http.StatusText(http.StatusOK)
	rsp.Message = "成功"
	rsp.Data = data
	return rsp, nil
}
// DatasetList returns one page of datasets with status = 1, filtered by the
// optional request fields: DatasetName (substring match), StartTime/EndTime
// (create_at range, start inclusive, end exclusive), OwnerId (0 means any) and
// ProjectId (ignored when empty). For every dataset on the page it also
// reports the number and total file_size of its files with data_type = 1.
//
// If ctx is already done the call returns ctx.Err() without touching the
// database. Any query failure is returned both as the error result and inside
// the response envelope (Code 500, Message "失败").
func (rp *repo) DatasetList(ctx context.Context, req proto.DatasetRequest) (rsp *proto.BaseResponse, err error) {
	rsp = new(proto.BaseResponse)
	select {
	case <-ctx.Done():
		rsp.Code = http.StatusInternalServerError
		rsp.Status = http.StatusText(http.StatusInternalServerError)
		rsp.Message = "超时/取消"
		rsp.Err = ctx.Err()
		return rsp, ctx.Err()
	default:
	}

	// fail fills the error envelope and hands the cause back to the caller.
	// The previous goto-based version shadowed err with := inside the select
	// clause and the loop, so failures surfaced as an empty response with a
	// nil error.
	fail := func(cause error) (*proto.BaseResponse, error) {
		rsp.Code = http.StatusInternalServerError
		rsp.Status = http.StatusText(http.StatusInternalServerError)
		rsp.Err = cause
		rsp.Message = "失败"
		return rsp, cause
	}

	list := make([]model.Dataset, 0)
	st := rp.engine.Where("(? = '' or dataset_name like ?)", req.DatasetName, "%"+req.DatasetName+"%").
		And("(? = '' or create_at >= ?)", req.StartTime, req.StartTime).
		And("(? = '' or create_at < ?)", req.EndTime, req.EndTime).
		And("(? = 0 or owner_id = ?)", req.OwnerId, req.OwnerId).
		And("status = 1")
	if len(req.ProjectId) > 0 {
		st.In("project_id", req.ProjectId)
	}

	var count int64
	count, err = st.Limit(int(req.Size), int(((req.Page)-1)*req.Size)).FindAndCount(&list)
	if err != nil {
		return fail(err)
	}

	data := make([]proto.DatasetItem, len(list))
	for k, v := range list {
		var datasetCount, datasetSize int64
		// Only the number of files is needed, so count in the database
		// instead of loading every file row into memory.
		datasetCount, err = rp.engine.Where("dataset_id = ?", v.DatasetId).And("data_type=1").Count(new(model.FileManager))
		if err != nil {
			return fail(err)
		}
		datasetSize, err = rp.engine.Where("dataset_id = ?", v.DatasetId).And("data_type=1").SumInt(new(model.FileManager), "file_size")
		if err != nil {
			return fail(err)
		}
		data[k] = proto.DatasetItem{
			DatasetId:    v.DatasetId,
			DatasetName:  v.DatasetName,
			DatasetDesc:  v.DatasetDesc,
			StoreName:    v.StoreName,
			CategoryId:   v.CategoryId,
			ProjectId:    v.ProjectId,
			OwnerId:      v.OwnerId,
			Creator:      v.Creator,
			CreateAt:     v.CreateAt,
			DatasetCount: datasetCount,
			DatasetSize:  datasetSize,
		}
	}

	rsp.Code = http.StatusOK
	rsp.Status = http.StatusText(http.StatusOK)
	rsp.Message = "成功"
	rsp = FillPaging(count, req.Page, req.Size, data, rsp)
	rsp.Err = nil
	return rsp, nil
}
// ImportDataset inserts a new dataset row built from req, with Status set to 1
// and CreateAt/UpdateAt set to the current Unix time in seconds. On success
// the inserted record is returned in the response Data.
//
// If ctx is already done the call returns ctx.Err() without touching the
// database. An insert failure is returned both as the error result and inside
// the response envelope (Code 500, Message "失败").
func (rp *repo) ImportDataset(ctx context.Context, req proto.ImportDatasetRequest) (rsp *proto.BaseResponse, err error) {
	rsp = new(proto.BaseResponse)
	select {
	case <-ctx.Done():
		rsp.Code = http.StatusInternalServerError
		rsp.Status = http.StatusText(http.StatusInternalServerError)
		rsp.Message = "超时/取消"
		rsp.Err = ctx.Err()
		return rsp, ctx.Err()
	default:
	}

	record := new(model.Dataset)
	record.DatasetName = req.DatasetName
	record.DatasetDesc = req.DatasetDesc
	record.CategoryId = req.CategoryId
	record.ProjectId = req.ProjectId
	record.OwnerId = req.OwnerId
	record.StoreName = req.StoreName
	record.Creator = req.Creator
	record.Status = 1
	record.CreateAt = time.Now().Unix()
	record.UpdateAt = time.Now().Unix()

	if _, err = rp.engine.Insert(record); err != nil {
		rsp.Code = http.StatusInternalServerError
		rsp.Status = http.StatusText(http.StatusInternalServerError)
		rsp.Err = err
		rsp.Message = "失败"
		return rsp, err
	}

	rsp.Code = http.StatusOK
	rsp.Status = http.StatusText(http.StatusOK)
	rsp.Message = "成功"
	rsp.Data = record
	rsp.Err = err
	return rsp, err
}
// DatasetInfo loads the dataset whose primary key is req.DatasetId and returns
// it in the response Data.
//
// If ctx is already done the call returns ctx.Err() without touching the
// database. A query failure, or a missing row, is returned both as the error
// result and inside the response envelope (Code 500, Message "失败").
func (rp *repo) DatasetInfo(ctx context.Context, req proto.DatasetItemRequest) (rsp *proto.BaseResponse, err error) {
	rsp = new(proto.BaseResponse)
	select {
	case <-ctx.Done():
		rsp.Code = http.StatusInternalServerError
		rsp.Status = http.StatusText(http.StatusInternalServerError)
		rsp.Message = "超时/取消"
		rsp.Err = ctx.Err()
		return rsp, ctx.Err()
	default:
	}

	record := new(model.Dataset)
	found, getErr := rp.engine.ID(req.DatasetId).Get(record)
	switch {
	case getErr != nil:
		err = getErr
	case !found:
		// A missing row is reported as an error rather than an empty result.
		err = fmt.Errorf("未能找到对应的数据集")
	}
	if err != nil {
		rsp.Code = http.StatusInternalServerError
		rsp.Status = http.StatusText(http.StatusInternalServerError)
		rsp.Err = err
		rsp.Message = "失败"
		return rsp, err
	}

	rsp.Code = http.StatusOK
	rsp.Status = http.StatusText(http.StatusOK)
	rsp.Message = "成功"
	rsp.Err = err
	rsp.Data = record
	return rsp, err
}
// CreateTraining splits the files of dataset req.DatasetId into training,
// validation and test parts and records the operation.
//
// req.TargetData is the number of files to use (0 means all files of the
// dataset). req.TrainNumber and req.ValidationNumber are percentages of
// TargetData; whatever remains after flooring both becomes the test part.
// When req.SplitMethod is 1 the files are shuffled before splitting.
//
// The training dataset named req.TrainName is created if it does not exist,
// an operation log row is inserted, and the three parts are copied by
// BatchCopyData in background goroutines that are not awaited; the response
// is returned before those copies finish and their errors are not reported.
//
// If ctx is already done the call returns ctx.Err() without touching the
// database. Invalid input or any query failure is returned both as the error
// result and inside the response envelope (Code 500, Message "失败").
func (rp *repo) CreateTraining(ctx context.Context, req proto.TrainDatasetRequest) (rsp *proto.BaseResponse, err error) {
	rsp = new(proto.BaseResponse)
	select {
	case <-ctx.Done():
		rsp.Code = http.StatusInternalServerError
		rsp.Status = http.StatusText(http.StatusInternalServerError)
		rsp.Message = "超时/取消"
		rsp.Err = ctx.Err()
		return rsp, ctx.Err()
	default:
	}

	// fail fills the error envelope and hands the cause back to the caller.
	fail := func(cause error) (*proto.BaseResponse, error) {
		rsp.Code = http.StatusInternalServerError
		rsp.Status = http.StatusText(http.StatusInternalServerError)
		rsp.Err = cause
		rsp.Message = "失败"
		return rsp, cause
	}

	// Reject ratios that would produce a negative part size; slicing with
	// such sizes would panic further down.
	if req.TrainNumber < 0 || req.ValidationNumber < 0 || req.TrainNumber+req.ValidationNumber > 100 {
		return fail(fmt.Errorf("训练集与验证集比例无效"))
	}
	if req.TargetData < 0 {
		return fail(fmt.Errorf("目标数据量不能为负数"))
	}

	var found bool
	dataset := new(model.Dataset)
	found, err = rp.engine.ID(req.DatasetId).Get(dataset)
	if err != nil {
		return fail(err)
	}
	if !found {
		return fail(fmt.Errorf("未能找到对应的采集数据集"))
	}

	fileList := make([]model.FileManager, 0)
	err = rp.engine.Where("dataset_id = ?", req.DatasetId).Find(&fileList)
	if err != nil {
		return fail(err)
	}
	if req.TargetData == 0 {
		req.TargetData = len(fileList)
	}
	if req.TargetData > len(fileList) {
		return fail(fmt.Errorf("超出现有数据集数量"))
	}

	if req.SplitMethod == 1 { // random split
		// Use a private generator instead of reseeding the shared global
		// source on every request.
		rng := rand.New(rand.NewSource(time.Now().UnixNano()))
		rng.Shuffle(len(fileList), func(i, j int) {
			fileList[i], fileList[j] = fileList[j], fileList[i]
		})
	}

	trainNumber := int(math.Floor(float64(int64(req.TargetData)*req.TrainNumber) / 100))
	valNumber := int(math.Floor(float64(int64(req.TargetData)*req.ValidationNumber) / 100))
	testNumber := req.TargetData - trainNumber - valNumber

	// The three parts are consecutive, non-overlapping windows over the first
	// TargetData files: [0, train), [train, train+val), [train+val, target).
	valEnd := trainNumber + valNumber
	trainFileList := fileList[:trainNumber]
	valFileList := fileList[trainNumber:valEnd]
	testFileList := fileList[valEnd:req.TargetData]

	train := new(model.TrainingDataset)
	found, err = rp.engine.Where("name = ?", req.TrainName).Get(train)
	if err != nil {
		return fail(err)
	}
	if !found {
		train.Name = req.TrainName
		train.DatasetDesc = req.TrainDesc
		train.DatasetId = req.DatasetId
		train.CategoryId = dataset.CategoryId
		_, err = rp.engine.Insert(train)
		if err != nil {
			return fail(err)
		}
	}

	opLog := new(model.DatasetOperationLog)
	opLog.TargetData = int64(req.TargetData)
	opLog.DatasetId = req.DatasetId
	// NOTE(review): train.DatasetId is assigned from req.DatasetId above when
	// the training dataset is new; confirm this is the intended identifier of
	// the training dataset and not the source dataset id.
	opLog.TrainingDatasetId = train.DatasetId
	opLog.SplitMethod = req.SplitMethod
	opLog.TrainNumber = int64(trainNumber)
	opLog.ValidationNumber = int64(valNumber)
	opLog.TestNumber = int64(testNumber)
	opLog.Creator = req.UserId
	_, err = rp.engine.Insert(opLog)
	if err != nil {
		return fail(err)
	}

	// Copy each part in the background; the last argument before the engine
	// tags the part: 1 = train, 2 = validation, 3 = test.
	go BatchCopyData(trainFileList, train.DatasetId, opLog.LogId, req.UserId, 1, rp.engine)
	go BatchCopyData(valFileList, train.DatasetId, opLog.LogId, req.UserId, 2, rp.engine)
	go BatchCopyData(testFileList, train.DatasetId, opLog.LogId, req.UserId, 3, rp.engine)

	rsp.Code = http.StatusOK
	rsp.Status = http.StatusText(http.StatusOK)
	rsp.Message = "成功"
	rsp.Data = opLog
	return rsp, nil
}
// BatchCopyData builds one TrainingDatasetDetail row per entry of list and
// inserts them all with a single engine.Insert call.
//
// Each row copies the file name, access URL (stored as FilePath), size and MD5
// from the source entry, and is stamped with trainId, categoryId, logId,
// userId and the current Unix time in seconds.
//
// The result of the insert is deliberately discarded: the function has no
// return value and no logger, so a failed insert is not reported anywhere.
func BatchCopyData(list []model.FileManager, trainId, logId, userId int64, categoryId int, engine *xorm.Engine) {
	rows := make([]model.TrainingDatasetDetail, 0, len(list))
	for i := range list {
		src := list[i]
		rows = append(rows, model.TrainingDatasetDetail{
			FileName:       src.FileName,
			FilePath:       src.AccessUrl,
			DatasetId:      trainId,
			CategoryId:     categoryId,
			FileSize:       src.FileSize,
			FileMd5:        src.FileMd5,
			OperationLogId: logId,
			Creator:        userId,
			CreateAt:       time.Now().Unix(),
			UpdateAt:       time.Now().Unix(),
		})
	}
	_, _ = engine.Insert(rows)
}