529 lines
16 KiB
Go
529 lines
16 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"git.hpds.cc/Component/logging"
|
|
"hpds-iot-web/config"
|
|
"hpds-iot-web/internal/proto"
|
|
"hpds-iot-web/model"
|
|
"hpds-iot-web/pkg/utils"
|
|
"math"
|
|
"math/rand"
|
|
"net/http"
|
|
"path"
|
|
"time"
|
|
"xorm.io/xorm"
|
|
)
|
|
|
|
type DatasetService interface {
|
|
GetOwnerProjectList(ctx context.Context, req proto.OwnerProjectRequest) (rsp *proto.BaseResponse, err error)
|
|
DatasetList(ctx context.Context, req proto.DatasetRequest) (rsp *proto.BaseResponse, err error)
|
|
ImportDataset(ctx context.Context, req proto.ImportDatasetRequest) (rsp *proto.BaseResponse, err error)
|
|
DatasetInfo(ctx context.Context, req proto.DatasetItemRequest) (rsp *proto.BaseResponse, err error)
|
|
|
|
CreateTrainDataset(ctx context.Context, req proto.TrainDatasetRequest) (rsp *proto.BaseResponse, err error)
|
|
TrainDatasetList(ctx context.Context, req proto.TrainDatasetItemRequest) (rsp *proto.BaseResponse, err error)
|
|
TrainDatasetFileList(ctx context.Context, req proto.TrainDatasetItemRequest) (rsp *proto.BaseResponse, err error)
|
|
}
|
|
|
|
func NewDatasetService(cfg *config.WebConfig, engine *xorm.Engine, logger *logging.Logger) DatasetService {
|
|
return &repo{
|
|
AppConfig: cfg,
|
|
engine: engine,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
func (rp *repo) GetOwnerProjectList(ctx context.Context, req proto.OwnerProjectRequest) (rsp *proto.BaseResponse, err error) {
|
|
rsp = new(proto.BaseResponse)
|
|
select {
|
|
case <-ctx.Done():
|
|
err = fmt.Errorf("超时/取消")
|
|
rsp.Code = http.StatusInternalServerError
|
|
rsp.Status = http.StatusText(http.StatusInternalServerError)
|
|
rsp.Message = "超时/取消"
|
|
rsp.Err = ctx.Err()
|
|
return rsp, ctx.Err()
|
|
default:
|
|
ownerList := make([]model.Owner, 0)
|
|
err := rp.engine.Where("(? = '' or owner_name like ?)", req.Key, "%"+req.Key+"%").
|
|
And("status = 1").Find(&ownerList)
|
|
if err != nil {
|
|
goto ReturnPoint
|
|
}
|
|
data := make([]proto.OwnerProjectItem, len(ownerList))
|
|
for k, v := range ownerList {
|
|
projectList := make([]proto.ProjectItem, 0)
|
|
err = rp.engine.Table("project").Cols("concat('"+fmt.Sprintf("%d", v.OwnerId)+"-', project_id) as project_id", "project_name").
|
|
Where("owner_id = ?", v.OwnerId).And("status = 1").Find(&projectList)
|
|
if err != nil {
|
|
goto ReturnPoint
|
|
}
|
|
data[k] = proto.OwnerProjectItem{
|
|
OwnerId: v.OwnerId,
|
|
OwnerName: fmt.Sprintf("%s[%s]", v.OwnerName, v.ChargeUser),
|
|
ProjectList: projectList,
|
|
}
|
|
}
|
|
rsp.Code = http.StatusOK
|
|
rsp.Status = http.StatusText(http.StatusOK)
|
|
rsp.Message = "成功"
|
|
rsp.Data = data
|
|
rsp.Err = err
|
|
return rsp, err
|
|
}
|
|
ReturnPoint:
|
|
if err != nil {
|
|
rsp.Code = http.StatusInternalServerError
|
|
rsp.Status = http.StatusText(http.StatusInternalServerError)
|
|
rsp.Err = err
|
|
rsp.Message = "失败"
|
|
}
|
|
return rsp, err
|
|
}
|
|
|
|
func (rp *repo) DatasetList(ctx context.Context, req proto.DatasetRequest) (rsp *proto.BaseResponse, err error) {
|
|
rsp = new(proto.BaseResponse)
|
|
select {
|
|
case <-ctx.Done():
|
|
err = fmt.Errorf("超时/取消")
|
|
rsp.Code = http.StatusInternalServerError
|
|
rsp.Status = http.StatusText(http.StatusInternalServerError)
|
|
rsp.Message = "超时/取消"
|
|
rsp.Err = ctx.Err()
|
|
return rsp, ctx.Err()
|
|
default:
|
|
list := make([]model.Dataset, 0)
|
|
st := rp.engine.Where("(? = '' or dataset_name like ?)", req.DatasetName, "%"+req.DatasetName+"%").
|
|
And("(? = '' or create_at >= ?)", req.StartTime, req.StartTime).
|
|
And("(? = '' or create_at < ?)", req.EndTime, req.EndTime).
|
|
And("(? = 0 or owner_id = ?)", req.OwnerId, req.OwnerId).
|
|
And("status = 1")
|
|
if len(req.ProjectId) > 0 {
|
|
st.In("project_id", req.ProjectId)
|
|
}
|
|
count, err := st.Limit(int(req.Size), int(((req.Page)-1)*req.Size)).FindAndCount(&list)
|
|
if err != nil {
|
|
goto ReturnPoint
|
|
}
|
|
|
|
type QuantityStatistics struct {
|
|
TotalNumber int64
|
|
TotalSize int64
|
|
}
|
|
data := make([]proto.DatasetItem, len(list))
|
|
for k, v := range list {
|
|
detailList := make([]model.FileManager, 0)
|
|
datasetCount, err := rp.engine.Where("dataset_id = ?", v.DatasetId).FindAndCount(&detailList)
|
|
if err != nil {
|
|
goto ReturnPoint
|
|
}
|
|
fm := new(model.FileManager)
|
|
datasetSize, err := rp.engine.Where("dataset_id = ?", v.DatasetId).SumInt(fm, "file_size")
|
|
if err != nil {
|
|
goto ReturnPoint
|
|
}
|
|
qs := new(QuantityStatistics)
|
|
_, err = rp.engine.SQL(`select sum(file_size) total_size, count(file_id) total_number from file_manager where is_disease > 0`).Get(qs)
|
|
if err != nil {
|
|
goto ReturnPoint
|
|
}
|
|
data[k] = proto.DatasetItem{
|
|
DatasetId: v.DatasetId,
|
|
DatasetName: v.DatasetName,
|
|
DatasetDesc: v.DatasetDesc,
|
|
StoreName: v.StoreName,
|
|
CategoryId: v.CategoryId,
|
|
ProjectId: v.ProjectId,
|
|
OwnerId: v.OwnerId,
|
|
Creator: v.Creator,
|
|
CreateAt: v.CreateAt,
|
|
DatasetCount: datasetCount,
|
|
DatasetSize: datasetSize,
|
|
LabelCount: qs.TotalNumber,
|
|
LabelSize: qs.TotalSize,
|
|
}
|
|
}
|
|
rsp.Code = http.StatusOK
|
|
rsp.Status = http.StatusText(http.StatusOK)
|
|
rsp.Message = "成功"
|
|
rsp = FillPaging(count, req.Page, req.Size, data, rsp)
|
|
rsp.Err = err
|
|
return rsp, err
|
|
}
|
|
ReturnPoint:
|
|
if err != nil {
|
|
rsp.Code = http.StatusInternalServerError
|
|
rsp.Status = http.StatusText(http.StatusInternalServerError)
|
|
rsp.Err = err
|
|
rsp.Message = "失败"
|
|
}
|
|
return rsp, err
|
|
}
|
|
|
|
func (rp *repo) ImportDataset(ctx context.Context, req proto.ImportDatasetRequest) (rsp *proto.BaseResponse, err error) {
|
|
rsp = new(proto.BaseResponse)
|
|
select {
|
|
case <-ctx.Done():
|
|
err = fmt.Errorf("超时/取消")
|
|
rsp.Code = http.StatusInternalServerError
|
|
rsp.Status = http.StatusText(http.StatusInternalServerError)
|
|
rsp.Message = "超时/取消"
|
|
rsp.Err = ctx.Err()
|
|
return rsp, ctx.Err()
|
|
default:
|
|
item := &model.Dataset{
|
|
DatasetName: req.DatasetName,
|
|
DatasetDesc: req.DatasetDesc,
|
|
CategoryId: req.CategoryId,
|
|
ProjectId: req.ProjectId,
|
|
OwnerId: req.OwnerId,
|
|
StoreName: req.StoreName,
|
|
Creator: req.Creator,
|
|
Status: 1,
|
|
CreateAt: time.Now().Unix(),
|
|
UpdateAt: time.Now().Unix(),
|
|
}
|
|
_, err = rp.engine.Insert(item)
|
|
if err != nil {
|
|
goto ReturnPoint
|
|
}
|
|
rsp.Code = http.StatusOK
|
|
rsp.Status = http.StatusText(http.StatusOK)
|
|
rsp.Message = "成功"
|
|
rsp.Data = item
|
|
rsp.Err = err
|
|
return rsp, err
|
|
}
|
|
ReturnPoint:
|
|
if err != nil {
|
|
rsp.Code = http.StatusInternalServerError
|
|
rsp.Status = http.StatusText(http.StatusInternalServerError)
|
|
rsp.Err = err
|
|
rsp.Message = "失败"
|
|
}
|
|
return rsp, err
|
|
}
|
|
|
|
func (rp *repo) DatasetInfo(ctx context.Context, req proto.DatasetItemRequest) (rsp *proto.BaseResponse, err error) {
|
|
rsp = new(proto.BaseResponse)
|
|
select {
|
|
case <-ctx.Done():
|
|
err = fmt.Errorf("超时/取消")
|
|
rsp.Code = http.StatusInternalServerError
|
|
rsp.Status = http.StatusText(http.StatusInternalServerError)
|
|
rsp.Message = "超时/取消"
|
|
rsp.Err = ctx.Err()
|
|
return rsp, ctx.Err()
|
|
default:
|
|
item := new(model.Dataset)
|
|
var b bool
|
|
b, err = rp.engine.ID(req.DatasetId).Get(item)
|
|
if err != nil {
|
|
goto ReturnPoint
|
|
}
|
|
if !b {
|
|
err = fmt.Errorf("未能找到对应的数据集")
|
|
goto ReturnPoint
|
|
}
|
|
|
|
rsp.Code = http.StatusOK
|
|
rsp.Status = http.StatusText(http.StatusOK)
|
|
rsp.Message = "成功"
|
|
rsp.Err = err
|
|
rsp.Data = item
|
|
return rsp, err
|
|
}
|
|
ReturnPoint:
|
|
if err != nil {
|
|
rsp.Code = http.StatusInternalServerError
|
|
rsp.Status = http.StatusText(http.StatusInternalServerError)
|
|
rsp.Err = err
|
|
rsp.Message = "失败"
|
|
}
|
|
return rsp, err
|
|
}
|
|
|
|
func (rp *repo) CreateTrainDataset(ctx context.Context, req proto.TrainDatasetRequest) (rsp *proto.BaseResponse, err error) {
|
|
rsp = new(proto.BaseResponse)
|
|
select {
|
|
case <-ctx.Done():
|
|
err = fmt.Errorf("超时/取消")
|
|
rsp.Code = http.StatusInternalServerError
|
|
rsp.Status = http.StatusText(http.StatusInternalServerError)
|
|
rsp.Message = "超时/取消"
|
|
rsp.Err = ctx.Err()
|
|
return rsp, ctx.Err()
|
|
default:
|
|
var (
|
|
h bool
|
|
trainFileList []model.FileManager
|
|
valFileList []model.FileManager
|
|
testFileList []model.FileManager
|
|
//wg sync.WaitGroup
|
|
)
|
|
log := new(model.DatasetOperationLog)
|
|
dataset := new(model.Dataset)
|
|
h, err = rp.engine.ID(req.DatasetId).Get(dataset)
|
|
if err != nil {
|
|
goto ReturnPoint
|
|
}
|
|
if !h {
|
|
err = fmt.Errorf("未能找到对应的采集数据集")
|
|
goto ReturnPoint
|
|
}
|
|
fileList := make([]model.FileManager, 0)
|
|
err = rp.engine.Where("dataset_id = ?", req.DatasetId).And("is_disease > 0").Find(&fileList)
|
|
if err != nil {
|
|
goto ReturnPoint
|
|
}
|
|
if req.TargetData == 0 {
|
|
req.TargetData = len(fileList)
|
|
}
|
|
if req.TargetData > len(fileList) {
|
|
err = fmt.Errorf("超出现有标注数据集数量")
|
|
goto ReturnPoint
|
|
}
|
|
if req.SplitMethod == 1 { //随机
|
|
rand.Seed(time.Now().UnixNano())
|
|
rand.Shuffle(len(fileList), func(i, j int) {
|
|
fileList[i], fileList[j] = fileList[j], fileList[i]
|
|
})
|
|
}
|
|
trainNumber := int(math.Floor(float64(int64(req.TargetData)*req.TrainNumber) / 100))
|
|
valNumber := int(math.Floor(float64(int64(req.TargetData)*req.ValidationNumber) / 100))
|
|
testNumber := req.TargetData - trainNumber - valNumber
|
|
if trainNumber-1 > 1 {
|
|
trainFileList = fileList[:trainNumber-1]
|
|
} else {
|
|
trainFileList = make([]model.FileManager, 0)
|
|
trainFileList = append(trainFileList, fileList[0])
|
|
}
|
|
if trainNumber != trainNumber+valNumber-1 {
|
|
valFileList = fileList[trainNumber : trainNumber+valNumber-1]
|
|
} else {
|
|
valFileList = make([]model.FileManager, 0)
|
|
valFileList = append(valFileList, fileList[trainNumber])
|
|
}
|
|
if trainNumber+valNumber < len(fileList) {
|
|
testFileList = fileList[trainNumber+valNumber:]
|
|
} else {
|
|
testFileList = make([]model.FileManager, 0)
|
|
testFileList = append(testFileList, fileList[trainNumber+valNumber])
|
|
}
|
|
|
|
train := new(model.TrainingDataset)
|
|
h, err = rp.engine.Where("name = ?", req.TrainName).Get(train)
|
|
if err != nil {
|
|
goto ReturnPoint
|
|
}
|
|
if !h {
|
|
train.Name = req.TrainName
|
|
train.DatasetDesc = req.TrainDesc
|
|
//train.DatasetId = req.DatasetId
|
|
train.CategoryId = dataset.CategoryId
|
|
_, err = rp.engine.Insert(train)
|
|
if err != nil {
|
|
goto ReturnPoint
|
|
}
|
|
}
|
|
log.TargetData = int64(req.TargetData)
|
|
log.DatasetId = req.DatasetId
|
|
log.TrainingDatasetId = train.DatasetId
|
|
log.SplitMethod = req.SplitMethod
|
|
log.TrainNumber = int64(trainNumber)
|
|
log.ValidationNumber = int64(valNumber)
|
|
log.TestNumber = int64(testNumber)
|
|
log.Creator = req.UserId
|
|
_, err = rp.engine.Insert(log)
|
|
if err != nil {
|
|
goto ReturnPoint
|
|
}
|
|
//wg.Add(3)
|
|
go BatchCopyData(trainFileList, train.DatasetId, log.LogId, req.UserId, 1, req.TrainName, rp) //, &wg
|
|
go BatchCopyData(valFileList, train.DatasetId, log.LogId, req.UserId, 2, req.TrainName, rp) //, &wg
|
|
go BatchCopyData(testFileList, train.DatasetId, log.LogId, req.UserId, 3, req.TrainName, rp) //, &wg
|
|
//wg.Wait()
|
|
rsp.Code = http.StatusOK
|
|
rsp.Status = http.StatusText(http.StatusOK)
|
|
rsp.Message = "成功"
|
|
rsp.Err = err
|
|
rsp.Data = log
|
|
return rsp, err
|
|
}
|
|
ReturnPoint:
|
|
if err != nil {
|
|
rsp.Code = http.StatusInternalServerError
|
|
rsp.Status = http.StatusText(http.StatusInternalServerError)
|
|
rsp.Err = err
|
|
rsp.Message = "失败"
|
|
}
|
|
return rsp, err
|
|
}
|
|
|
|
func BatchCopyData(list []model.FileManager, trainId, logId, userId int64, categoryId int, trainName string, rp *repo) { //, wg *sync.WaitGroup
|
|
batchList := make([]model.TrainingDatasetDetail, len(list))
|
|
for k, v := range list {
|
|
dir := "no_disease"
|
|
if v.IsDisease == 1 {
|
|
dir = "disease"
|
|
}
|
|
utils.DownloadMinioFileToLocalPath(v.AccessUrl, path.Join(rp.AppConfig.TrainDir, trainName, model.GetTrainCategory(categoryId), dir), v.FileName,
|
|
rp.AppConfig.Minio.Protocol, rp.AppConfig.Minio.Endpoint, rp.AppConfig.Minio.Bucket, rp.AppConfig.Minio.AccessKeyId,
|
|
rp.AppConfig.Minio.SecretAccessKey, rp.logger)
|
|
item := model.TrainingDatasetDetail{
|
|
FileName: v.FileName,
|
|
FilePath: path.Join(rp.AppConfig.TrainDir, trainName, model.GetTrainCategory(categoryId), dir, v.FileName),
|
|
DatasetId: trainId,
|
|
CategoryId: categoryId,
|
|
FileSize: v.FileSize,
|
|
FileMd5: v.FileMd5,
|
|
IsDisease: v.IsDisease,
|
|
OperationLogId: logId,
|
|
Creator: userId,
|
|
CreateAt: time.Now().Unix(),
|
|
UpdateAt: time.Now().Unix(),
|
|
}
|
|
batchList[k] = item
|
|
}
|
|
_, _ = rp.engine.Insert(batchList)
|
|
//wg.Done()
|
|
}
|
|
|
|
func (rp *repo) TrainDatasetList(ctx context.Context, req proto.TrainDatasetItemRequest) (rsp *proto.BaseResponse, err error) {
|
|
rsp = new(proto.BaseResponse)
|
|
select {
|
|
case <-ctx.Done():
|
|
err = fmt.Errorf("超时/取消")
|
|
rsp.Code = http.StatusInternalServerError
|
|
rsp.Status = http.StatusText(http.StatusInternalServerError)
|
|
rsp.Message = "超时/取消"
|
|
rsp.Err = ctx.Err()
|
|
return rsp, ctx.Err()
|
|
default:
|
|
type QuantityStatistics struct {
|
|
DatasetId int64
|
|
CategoryId int
|
|
Total int64
|
|
}
|
|
|
|
var (
|
|
count int64
|
|
list []proto.TrainingDataset
|
|
)
|
|
trainingList := make([]model.TrainingDataset, 0)
|
|
count, err = rp.engine.Where("(?=0 or dataset_id = ?)", req.DatasetId, req.DatasetId).
|
|
And("(?= 0 or category_id = ?)", req.BizType, req.BizType).
|
|
And("(? ='' or name like ?)", req.TrainName, "%"+req.TrainName+"%").
|
|
Limit(int(req.Size), int(((req.Page)-1)*req.Size)).FindAndCount(&trainingList)
|
|
if err != nil {
|
|
goto ReturnPoint
|
|
}
|
|
list = make([]proto.TrainingDataset, len(trainingList))
|
|
for k, v := range trainingList {
|
|
qs := make([]QuantityStatistics, 0)
|
|
err = rp.engine.SQL("select dataset_id, category_id, count(1) as total from training_dataset_detail where dataset_id = ? group by category_id, dataset_id", v.DatasetId).Find(&qs)
|
|
if err != nil {
|
|
goto ReturnPoint
|
|
}
|
|
item := proto.TrainingDataset{
|
|
DatasetId: v.DatasetId,
|
|
Name: v.Name,
|
|
CategoryId: v.CategoryId,
|
|
DatasetDesc: v.DatasetDesc,
|
|
TotalSize: 0,
|
|
TrainSize: 0,
|
|
ValSize: 0,
|
|
TestSize: 0,
|
|
StoreName: v.StoreName,
|
|
CreateAt: v.CreateAt,
|
|
UpdateAt: v.UpdateAt,
|
|
}
|
|
for _, val := range qs {
|
|
switch val.CategoryId {
|
|
case 1:
|
|
item.TrainSize = val.Total
|
|
case 2:
|
|
item.ValSize = val.Total
|
|
case 3:
|
|
item.TestSize = val.Total
|
|
}
|
|
item.TotalSize += val.Total
|
|
}
|
|
list[k] = item
|
|
}
|
|
rsp.Code = http.StatusOK
|
|
rsp.Status = http.StatusText(http.StatusOK)
|
|
rsp.Message = "成功"
|
|
rsp = FillPaging(count, req.Page, req.Size, list, rsp)
|
|
rsp.Err = err
|
|
return rsp, err
|
|
}
|
|
|
|
ReturnPoint:
|
|
if err != nil {
|
|
rsp.Code = http.StatusInternalServerError
|
|
rsp.Status = http.StatusText(http.StatusInternalServerError)
|
|
rsp.Err = err
|
|
rsp.Message = "失败"
|
|
}
|
|
return rsp, err
|
|
}
|
|
|
|
func (rp *repo) TrainDatasetFileList(ctx context.Context, req proto.TrainDatasetItemRequest) (rsp *proto.BaseResponse, err error) {
|
|
rsp = new(proto.BaseResponse)
|
|
select {
|
|
case <-ctx.Done():
|
|
err = fmt.Errorf("超时/取消")
|
|
rsp.Code = http.StatusInternalServerError
|
|
rsp.Status = http.StatusText(http.StatusInternalServerError)
|
|
rsp.Message = "超时/取消"
|
|
rsp.Err = ctx.Err()
|
|
return rsp, ctx.Err()
|
|
default:
|
|
var (
|
|
count int64
|
|
list []proto.TrainingDatasetFileItem
|
|
)
|
|
fileList := make([]model.TrainingDatasetDetail, 0)
|
|
count, err = rp.engine.Where("dataset_id = ?", req.DatasetId).
|
|
Limit(int(req.Size), int(((req.Page)-1)*req.Size)).FindAndCount(&fileList)
|
|
if err != nil {
|
|
goto ReturnPoint
|
|
}
|
|
list = make([]proto.TrainingDatasetFileItem, len(fileList))
|
|
for k, v := range fileList {
|
|
buff := utils.ReadFile(v.FilePath)
|
|
img := utils.BuffToImage(buff)
|
|
buf := utils.ImageToBuff(img, "jpeg")
|
|
list[k] = proto.TrainingDatasetFileItem{
|
|
DetailId: v.DetailId,
|
|
FileName: v.FileName,
|
|
FileSize: v.FileSize,
|
|
FilePath: v.FilePath,
|
|
FileContent: "data:image/jpeg;base64," + base64.StdEncoding.EncodeToString(buf.Bytes()),
|
|
IsDisease: v.IsDisease,
|
|
CategoryId: v.CategoryId,
|
|
}
|
|
}
|
|
|
|
rsp.Code = http.StatusOK
|
|
rsp.Status = http.StatusText(http.StatusOK)
|
|
rsp.Message = "成功"
|
|
rsp = FillPaging(count, req.Page, req.Size, list, rsp)
|
|
rsp.Err = err
|
|
return rsp, err
|
|
}
|
|
|
|
ReturnPoint:
|
|
if err != nil {
|
|
rsp.Code = http.StatusInternalServerError
|
|
rsp.Status = http.StatusText(http.StatusInternalServerError)
|
|
rsp.Err = err
|
|
rsp.Message = "失败"
|
|
}
|
|
return rsp, err
|
|
}
|