2023-01-12 10:21:40 +08:00
|
|
|
package service
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"fmt"
|
|
|
|
"git.hpds.cc/Component/logging"
|
2023-06-17 09:38:26 +08:00
|
|
|
"go.uber.org/zap"
|
2023-01-12 10:21:40 +08:00
|
|
|
"hpds-iot-web/config"
|
|
|
|
"hpds-iot-web/internal/proto"
|
|
|
|
"hpds-iot-web/model"
|
2023-06-17 09:38:26 +08:00
|
|
|
"hpds-iot-web/pkg/utils"
|
|
|
|
"math"
|
|
|
|
"math/rand"
|
2023-01-12 10:21:40 +08:00
|
|
|
"net/http"
|
|
|
|
"time"
|
|
|
|
"xorm.io/xorm"
|
|
|
|
)
|
|
|
|
|
|
|
|
type DiseaseService interface {
|
2023-03-23 18:03:09 +08:00
|
|
|
DiseaseList(ctx context.Context, req proto.DiseaseRequest) (rsp *proto.BaseResponse, err error)
|
2023-06-17 09:38:26 +08:00
|
|
|
DiseaseListNew(ctx context.Context, req proto.DiseaseRequest) (rsp *proto.BaseResponse, err error)
|
|
|
|
DiseaseStatistics(ctx context.Context) (rsp *proto.BaseResponse, err error)
|
2023-01-12 10:21:40 +08:00
|
|
|
DiseaseTypeList(ctx context.Context, req proto.DiseaseTypeRequest) (rsp *proto.BaseResponse, err error)
|
|
|
|
AddDiseaseType(ctx context.Context, req proto.DiseaseTypeItemRequest) (rsp *proto.BaseResponse, err error)
|
|
|
|
EditDiseaseType(ctx context.Context, req proto.DiseaseTypeItemRequest) (rsp *proto.BaseResponse, err error)
|
|
|
|
DeleteDiseaseType(ctx context.Context, req proto.DiseaseTypeItemRequest) (rsp *proto.BaseResponse, err error)
|
2023-06-17 09:38:26 +08:00
|
|
|
|
|
|
|
CreateTrainDatasetByDisease(ctx context.Context, req proto.TrainDatasetRequest) (rsp *proto.BaseResponse, err error)
|
2023-01-12 10:21:40 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
func NewDiseaseService(cfg *config.WebConfig, engine *xorm.Engine, logger *logging.Logger) DiseaseService {
|
|
|
|
return &repo{
|
|
|
|
AppConfig: cfg,
|
|
|
|
engine: engine,
|
|
|
|
logger: logger,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-03-23 18:03:09 +08:00
|
|
|
func (rp *repo) DiseaseList(ctx context.Context, req proto.DiseaseRequest) (rsp *proto.BaseResponse, err error) {
|
|
|
|
rsp = new(proto.BaseResponse)
|
|
|
|
select {
|
|
|
|
case <-ctx.Done():
|
|
|
|
err = fmt.Errorf("超时/取消")
|
|
|
|
rsp.Code = http.StatusInternalServerError
|
|
|
|
rsp.Status = http.StatusText(http.StatusInternalServerError)
|
|
|
|
rsp.Message = "超时/取消"
|
|
|
|
rsp.Err = ctx.Err()
|
|
|
|
return rsp, ctx.Err()
|
|
|
|
default:
|
|
|
|
data := make([]model.Disease, 0)
|
|
|
|
count, err := rp.engine.Where("(? = '' or disease_name like ?)", req.Key, "%"+req.Key+"%").
|
2023-06-17 09:38:26 +08:00
|
|
|
And("(? = 0 or category_id = ?)", req.DiseaseType, req.DiseaseType).
|
2023-03-23 18:03:09 +08:00
|
|
|
Limit(int(req.Size), int(((req.Page)-1)*req.Size)).
|
|
|
|
FindAndCount(&data)
|
|
|
|
if err != nil {
|
|
|
|
goto ReturnPoint
|
|
|
|
}
|
|
|
|
rsp.Code = http.StatusOK
|
|
|
|
rsp.Status = http.StatusText(http.StatusOK)
|
|
|
|
rsp.Message = "成功"
|
|
|
|
rsp = FillPaging(count, req.Page, req.Size, data, rsp)
|
|
|
|
rsp.Err = err
|
|
|
|
return rsp, err
|
|
|
|
}
|
|
|
|
ReturnPoint:
|
|
|
|
if err != nil {
|
|
|
|
rsp.Code = http.StatusInternalServerError
|
|
|
|
rsp.Status = http.StatusText(http.StatusInternalServerError)
|
|
|
|
rsp.Err = err
|
|
|
|
rsp.Message = "失败"
|
|
|
|
}
|
|
|
|
return rsp, err
|
|
|
|
}
|
|
|
|
|
2023-06-17 09:38:26 +08:00
|
|
|
func (rp *repo) DiseaseListNew(ctx context.Context, req proto.DiseaseRequest) (rsp *proto.BaseResponse, err error) {
|
|
|
|
rsp = new(proto.BaseResponse)
|
|
|
|
select {
|
|
|
|
case <-ctx.Done():
|
|
|
|
err = fmt.Errorf("超时/取消")
|
|
|
|
rsp.Code = http.StatusInternalServerError
|
|
|
|
rsp.Status = http.StatusText(http.StatusInternalServerError)
|
|
|
|
rsp.Message = "超时/取消"
|
|
|
|
rsp.Err = ctx.Err()
|
|
|
|
return rsp, ctx.Err()
|
|
|
|
default:
|
|
|
|
var count int64
|
|
|
|
list := make([]model.LabelFile, 0)
|
|
|
|
count, err = rp.engine.Where("(? = 0 or category_id = ?)", req.DiseaseType, req.DiseaseType).
|
|
|
|
And("(?=0 or file_type = ?)", req.FileType, req.FileType).
|
|
|
|
And("(?=0 or label_type = ?)", req.LabelType, req.LabelType).
|
|
|
|
And("(?=-1 or pid = ?)", req.Pid, req.Pid).
|
|
|
|
Limit(int(req.Size), int(((req.Page)-1)*req.Size)).
|
|
|
|
FindAndCount(&list)
|
|
|
|
if err != nil {
|
|
|
|
goto ReturnPoint
|
|
|
|
}
|
|
|
|
data := make([]proto.DiseaseFileInfoItem, len(list))
|
|
|
|
for k, v := range list {
|
|
|
|
item := proto.DiseaseFileInfoItem{
|
|
|
|
FileId: v.FileId,
|
|
|
|
FileName: v.FileName,
|
|
|
|
FilePath: v.FilePath,
|
|
|
|
CategoryId: v.CategoryId,
|
|
|
|
CategoryName: model.GetBizType(v.CategoryId),
|
|
|
|
FileSize: v.FileSize,
|
|
|
|
LabelType: v.LabelType,
|
|
|
|
FileType: v.FileType,
|
|
|
|
FileTypeName: model.GetFileType(v.FileType),
|
|
|
|
LabelTypeName: model.GetLabelType(v.LabelType),
|
|
|
|
CreateAt: v.CreateAt,
|
|
|
|
UpdateAt: v.UpdateAt,
|
|
|
|
Pid: v.Pid,
|
|
|
|
}
|
|
|
|
if item.FileType == 1 || (item.FileType == 3 && item.Pid > 0) {
|
|
|
|
item.FileContent = utils.ImgFileToBase64(v.FilePath)
|
|
|
|
}
|
|
|
|
if item.FileType == 2 {
|
|
|
|
item.FileContent = utils.FileToBase64(v.FilePath)
|
|
|
|
}
|
|
|
|
|
|
|
|
data[k] = item
|
|
|
|
}
|
|
|
|
rsp.Code = http.StatusOK
|
|
|
|
rsp.Status = http.StatusText(http.StatusOK)
|
|
|
|
rsp.Message = "成功"
|
|
|
|
rsp = FillPaging(count, req.Page, req.Size, data, rsp)
|
|
|
|
rsp.Err = err
|
|
|
|
return rsp, err
|
|
|
|
}
|
|
|
|
ReturnPoint:
|
|
|
|
if err != nil {
|
|
|
|
rsp.Code = http.StatusInternalServerError
|
|
|
|
rsp.Status = http.StatusText(http.StatusInternalServerError)
|
|
|
|
rsp.Err = err
|
|
|
|
rsp.Message = "失败"
|
|
|
|
}
|
|
|
|
return rsp, err
|
|
|
|
}
|
|
|
|
|
|
|
|
func (rp *repo) DiseaseStatistics(ctx context.Context) (rsp *proto.BaseResponse, err error) {
|
|
|
|
rsp = new(proto.BaseResponse)
|
|
|
|
select {
|
|
|
|
case <-ctx.Done():
|
|
|
|
err = fmt.Errorf("超时/取消")
|
|
|
|
rsp.Code = http.StatusInternalServerError
|
|
|
|
rsp.Status = http.StatusText(http.StatusInternalServerError)
|
|
|
|
rsp.Message = "超时/取消"
|
|
|
|
rsp.Err = ctx.Err()
|
|
|
|
return rsp, ctx.Err()
|
|
|
|
default:
|
|
|
|
list := make([]*proto.DiseaseStatisticsItem, 5)
|
|
|
|
type Statistics struct {
|
|
|
|
CategoryId int
|
|
|
|
TotalCount int64
|
|
|
|
TotalSize int64
|
|
|
|
}
|
|
|
|
//所有的数据
|
|
|
|
statList := make([]Statistics, 0)
|
|
|
|
err = rp.engine.SQL(`select category_id, sum(file_size) total_size, count(file_id) total_count from label_file group by category_id;`).Find(&statList)
|
|
|
|
if err != nil {
|
|
|
|
goto ReturnPoint
|
|
|
|
}
|
|
|
|
totalItem := &proto.DiseaseStatisticsItem{
|
|
|
|
DiseaseType: 0,
|
|
|
|
DiseaseName: "数据量总计",
|
|
|
|
}
|
|
|
|
for _, v := range statList {
|
|
|
|
item := &proto.DiseaseStatisticsItem{
|
|
|
|
DiseaseType: v.CategoryId,
|
|
|
|
DiseaseName: model.GetBizType(v.CategoryId) + "数据",
|
|
|
|
TotalNum: v.TotalCount,
|
|
|
|
TotalSize: v.TotalSize,
|
|
|
|
}
|
|
|
|
totalItem.TotalNum += v.TotalCount
|
|
|
|
totalItem.TotalSize += v.TotalSize
|
|
|
|
list[v.CategoryId] = item
|
|
|
|
}
|
|
|
|
list[0] = totalItem
|
|
|
|
rsp.Code = http.StatusOK
|
|
|
|
rsp.Status = http.StatusText(http.StatusOK)
|
|
|
|
rsp.Message = "成功"
|
|
|
|
rsp.Data = list
|
|
|
|
rsp.Err = err
|
|
|
|
return rsp, err
|
|
|
|
}
|
|
|
|
ReturnPoint:
|
|
|
|
if err != nil {
|
|
|
|
rsp.Code = http.StatusInternalServerError
|
|
|
|
rsp.Status = http.StatusText(http.StatusInternalServerError)
|
|
|
|
rsp.Err = err
|
|
|
|
rsp.Message = "失败"
|
|
|
|
}
|
|
|
|
return rsp, err
|
|
|
|
}
|
|
|
|
|
2023-01-12 10:21:40 +08:00
|
|
|
func (rp *repo) DiseaseTypeList(ctx context.Context, req proto.DiseaseTypeRequest) (rsp *proto.BaseResponse, err error) {
|
|
|
|
rsp = new(proto.BaseResponse)
|
|
|
|
select {
|
|
|
|
case <-ctx.Done():
|
|
|
|
err = fmt.Errorf("超时/取消")
|
|
|
|
rsp.Code = http.StatusInternalServerError
|
|
|
|
rsp.Status = http.StatusText(http.StatusInternalServerError)
|
|
|
|
rsp.Message = "超时/取消"
|
|
|
|
rsp.Err = ctx.Err()
|
|
|
|
return rsp, ctx.Err()
|
|
|
|
default:
|
|
|
|
data := make([]model.DiseaseType, 0)
|
|
|
|
count, err := rp.engine.Where("(? = '' or type_name like ?)", req.Key, "%"+req.Key+"%").
|
|
|
|
And("(? = 0 or category_id = ?)", req.CategoryId, req.CategoryId).
|
|
|
|
And("status = 1").Limit(int(req.Size), int(((req.Page)-1)*req.Size)).
|
|
|
|
FindAndCount(&data)
|
|
|
|
if err != nil {
|
|
|
|
goto ReturnPoint
|
|
|
|
}
|
|
|
|
rsp.Code = http.StatusOK
|
|
|
|
rsp.Status = http.StatusText(http.StatusOK)
|
|
|
|
rsp.Message = "成功"
|
|
|
|
rsp = FillPaging(count, req.Page, req.Size, data, rsp)
|
|
|
|
rsp.Err = err
|
|
|
|
return rsp, err
|
|
|
|
}
|
|
|
|
ReturnPoint:
|
|
|
|
if err != nil {
|
|
|
|
rsp.Code = http.StatusInternalServerError
|
|
|
|
rsp.Status = http.StatusText(http.StatusInternalServerError)
|
|
|
|
rsp.Err = err
|
|
|
|
rsp.Message = "失败"
|
|
|
|
}
|
|
|
|
return rsp, err
|
|
|
|
}
|
|
|
|
func (rp *repo) AddDiseaseType(ctx context.Context, req proto.DiseaseTypeItemRequest) (rsp *proto.BaseResponse, err error) {
|
|
|
|
rsp = new(proto.BaseResponse)
|
|
|
|
select {
|
|
|
|
case <-ctx.Done():
|
|
|
|
err = fmt.Errorf("超时/取消")
|
|
|
|
rsp.Code = http.StatusInternalServerError
|
|
|
|
rsp.Status = http.StatusText(http.StatusInternalServerError)
|
|
|
|
rsp.Message = "超时/取消"
|
|
|
|
rsp.Err = ctx.Err()
|
|
|
|
return rsp, ctx.Err()
|
|
|
|
default:
|
|
|
|
item := &model.DiseaseType{
|
|
|
|
TypeName: req.TypeName,
|
|
|
|
CategoryId: req.CategoryId,
|
|
|
|
Status: 1,
|
|
|
|
CreateAt: time.Now().Unix(),
|
|
|
|
UpdateAt: time.Now().Unix(),
|
|
|
|
}
|
|
|
|
_, err = rp.engine.Insert(item)
|
|
|
|
if err != nil {
|
|
|
|
goto ReturnPoint
|
|
|
|
}
|
|
|
|
|
|
|
|
rsp.Code = http.StatusOK
|
|
|
|
rsp.Status = http.StatusText(http.StatusOK)
|
|
|
|
rsp.Message = "新增病害类型成功"
|
|
|
|
rsp.Err = ctx.Err()
|
|
|
|
rsp.Data = item
|
|
|
|
return rsp, err
|
|
|
|
}
|
|
|
|
ReturnPoint:
|
|
|
|
if err != nil {
|
|
|
|
rsp.Code = http.StatusInternalServerError
|
|
|
|
rsp.Status = http.StatusText(http.StatusInternalServerError)
|
|
|
|
rsp.Err = err
|
|
|
|
rsp.Message = "失败"
|
|
|
|
}
|
|
|
|
return rsp, err
|
|
|
|
}
|
|
|
|
func (rp *repo) EditDiseaseType(ctx context.Context, req proto.DiseaseTypeItemRequest) (rsp *proto.BaseResponse, err error) {
|
|
|
|
rsp = new(proto.BaseResponse)
|
|
|
|
select {
|
|
|
|
case <-ctx.Done():
|
|
|
|
err = fmt.Errorf("超时/取消")
|
|
|
|
rsp.Code = http.StatusInternalServerError
|
|
|
|
rsp.Status = http.StatusText(http.StatusInternalServerError)
|
|
|
|
rsp.Message = "超时/取消"
|
|
|
|
rsp.Err = ctx.Err()
|
|
|
|
return rsp, ctx.Err()
|
|
|
|
default:
|
|
|
|
var h bool
|
|
|
|
item := new(model.DiseaseType)
|
|
|
|
h, err = rp.engine.ID(req.TypeId).Get(item)
|
|
|
|
if err != nil {
|
|
|
|
goto ReturnPoint
|
|
|
|
}
|
|
|
|
if !h {
|
|
|
|
err = fmt.Errorf("未能找到对应的类型")
|
|
|
|
goto ReturnPoint
|
|
|
|
}
|
|
|
|
if len(req.TypeName) > 0 {
|
|
|
|
item.TypeName = req.TypeName
|
|
|
|
}
|
|
|
|
if req.CategoryId > 0 {
|
|
|
|
item.CategoryId = req.CategoryId
|
|
|
|
}
|
|
|
|
item.UpdateAt = time.Now().Unix()
|
|
|
|
_, err = rp.engine.ID(req.TypeId).AllCols().Update(item)
|
|
|
|
if err != nil {
|
|
|
|
goto ReturnPoint
|
|
|
|
}
|
|
|
|
rsp.Code = http.StatusOK
|
|
|
|
rsp.Status = http.StatusText(http.StatusOK)
|
|
|
|
rsp.Message = "修改病害类型成功"
|
|
|
|
rsp.Err = ctx.Err()
|
|
|
|
rsp.Data = item
|
|
|
|
return rsp, err
|
|
|
|
}
|
|
|
|
ReturnPoint:
|
|
|
|
if err != nil {
|
|
|
|
rsp.Code = http.StatusInternalServerError
|
|
|
|
rsp.Status = http.StatusText(http.StatusInternalServerError)
|
|
|
|
rsp.Err = err
|
|
|
|
rsp.Message = "失败"
|
|
|
|
}
|
|
|
|
return rsp, err
|
|
|
|
}
|
|
|
|
func (rp *repo) DeleteDiseaseType(ctx context.Context, req proto.DiseaseTypeItemRequest) (rsp *proto.BaseResponse, err error) {
|
|
|
|
rsp = new(proto.BaseResponse)
|
|
|
|
select {
|
|
|
|
case <-ctx.Done():
|
|
|
|
err = fmt.Errorf("超时/取消")
|
|
|
|
rsp.Code = http.StatusInternalServerError
|
|
|
|
rsp.Status = http.StatusText(http.StatusInternalServerError)
|
|
|
|
rsp.Message = "超时/取消"
|
|
|
|
rsp.Err = ctx.Err()
|
|
|
|
return rsp, ctx.Err()
|
|
|
|
default:
|
|
|
|
var h bool
|
|
|
|
item := new(model.DiseaseType)
|
|
|
|
h, err = rp.engine.ID(req.TypeId).Get(item)
|
|
|
|
if err != nil {
|
|
|
|
goto ReturnPoint
|
|
|
|
}
|
|
|
|
if !h {
|
|
|
|
err = fmt.Errorf("未能找到对应的类型")
|
|
|
|
goto ReturnPoint
|
|
|
|
}
|
|
|
|
item.Status = 0
|
|
|
|
item.UpdateAt = time.Now().Unix()
|
|
|
|
_, err = rp.engine.ID(req.TypeId).AllCols().Update(item)
|
|
|
|
if err != nil {
|
|
|
|
goto ReturnPoint
|
|
|
|
}
|
|
|
|
rsp.Code = http.StatusOK
|
|
|
|
rsp.Status = http.StatusText(http.StatusOK)
|
|
|
|
rsp.Message = "删除病害类型成功"
|
|
|
|
rsp.Err = ctx.Err()
|
|
|
|
rsp.Data = item
|
|
|
|
return rsp, err
|
|
|
|
}
|
|
|
|
ReturnPoint:
|
|
|
|
if err != nil {
|
|
|
|
rsp.Code = http.StatusInternalServerError
|
|
|
|
rsp.Status = http.StatusText(http.StatusInternalServerError)
|
|
|
|
rsp.Err = err
|
|
|
|
rsp.Message = "失败"
|
|
|
|
}
|
|
|
|
return rsp, err
|
|
|
|
}
|
2023-06-17 09:38:26 +08:00
|
|
|
|
|
|
|
func (rp *repo) CreateTrainDatasetByDisease(ctx context.Context, req proto.TrainDatasetRequest) (rsp *proto.BaseResponse, err error) {
|
|
|
|
rsp = new(proto.BaseResponse)
|
|
|
|
select {
|
|
|
|
case <-ctx.Done():
|
|
|
|
err = fmt.Errorf("超时/取消")
|
|
|
|
rsp.Code = http.StatusInternalServerError
|
|
|
|
rsp.Status = http.StatusText(http.StatusInternalServerError)
|
|
|
|
rsp.Message = "超时/取消"
|
|
|
|
rsp.Err = ctx.Err()
|
|
|
|
return rsp, ctx.Err()
|
|
|
|
default:
|
|
|
|
var (
|
|
|
|
h bool
|
|
|
|
//trainFileList []model.LabelFile
|
|
|
|
trainDiseaseFileList []model.LabelFile
|
|
|
|
trainNoDiseaseFileList []model.LabelFile
|
|
|
|
//valFileList []model.LabelFile
|
|
|
|
valDiseaseFileList []model.LabelFile
|
|
|
|
valNoDiseaseFileList []model.LabelFile
|
|
|
|
//testFileList []model.LabelFile
|
|
|
|
testDiseaseFileList []model.LabelFile
|
|
|
|
testNoDiseaseFileList []model.LabelFile
|
|
|
|
trainDiseaseCount int64
|
|
|
|
trainNoDiseaseCount int64
|
|
|
|
//wg sync.WaitGroup
|
|
|
|
)
|
|
|
|
log := new(model.DatasetOperationLog)
|
|
|
|
if err != nil {
|
|
|
|
goto ReturnPoint
|
|
|
|
}
|
|
|
|
fileList := make([]model.LabelFile, 0)
|
|
|
|
err = rp.engine.Where("category_id = ? and file_type = 1", req.BizType).Find(&fileList)
|
|
|
|
if err != nil {
|
|
|
|
goto ReturnPoint
|
|
|
|
}
|
|
|
|
trainDiseaseFileList = make([]model.LabelFile, 0)
|
|
|
|
err = rp.engine.Where("category_id = ? and label_type = 1 and file_type = 1", req.BizType).Find(&trainDiseaseFileList)
|
|
|
|
if err != nil {
|
|
|
|
goto ReturnPoint
|
|
|
|
}
|
|
|
|
trainNoDiseaseFileList = make([]model.LabelFile, 0)
|
|
|
|
err = rp.engine.Where("category_id = ? and label_type = 2 and file_type = 1", req.BizType).Find(&trainNoDiseaseFileList)
|
|
|
|
if err != nil {
|
|
|
|
goto ReturnPoint
|
|
|
|
}
|
|
|
|
|
|
|
|
if req.TargetData == 0 {
|
|
|
|
req.TargetData = len(fileList)
|
|
|
|
}
|
|
|
|
if req.TargetData > len(fileList) {
|
|
|
|
err = fmt.Errorf("超出现有标注数据集数量")
|
|
|
|
goto ReturnPoint
|
|
|
|
}
|
|
|
|
|
|
|
|
trainNumber := int(math.Floor(float64(int64(req.TargetData)*req.TrainNumber) / 100))
|
|
|
|
valNumber := int(math.Floor(float64(int64(req.TargetData)*req.ValidationNumber) / 100))
|
|
|
|
testNumber := req.TargetData - trainNumber - valNumber
|
|
|
|
trainDiseaseCount = int64(float64(int64(len(trainDiseaseFileList))*req.TrainNumber) / 100)
|
|
|
|
trainNoDiseaseCount = int64(float64(int64(len(trainNoDiseaseFileList))*req.TrainNumber) / 100)
|
|
|
|
if trainDiseaseCount+trainNoDiseaseCount > int64(trainNumber) {
|
|
|
|
if trainDiseaseCount > int64(trainNumber/2) {
|
|
|
|
trainDiseaseCount = int64(trainNumber / 2)
|
|
|
|
}
|
|
|
|
if trainNoDiseaseCount > int64(trainNumber)-trainDiseaseCount {
|
|
|
|
trainNoDiseaseCount = int64(trainNumber) - trainDiseaseCount
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
valDiseaseCount := int64(float64(int64(len(trainDiseaseFileList))*req.ValidationNumber) / 100)
|
|
|
|
valNoDiseaseCount := int64(float64(int64(len(trainNoDiseaseFileList))*req.ValidationNumber) / 100)
|
|
|
|
if valDiseaseCount+valNoDiseaseCount > int64(valNumber) {
|
|
|
|
if valDiseaseCount > int64(valNumber/2) {
|
|
|
|
valDiseaseCount = int64(valNumber / 2)
|
|
|
|
}
|
|
|
|
if valNoDiseaseCount > int64(valNumber)-valDiseaseCount {
|
|
|
|
valNoDiseaseCount = int64(valNumber) - valDiseaseCount
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
testDiseaseCount := int64(float64(int64(len(trainDiseaseFileList))*req.TestNumber) / 100)
|
|
|
|
testNoDiseaseCount := int64(float64(int64(len(trainNoDiseaseFileList))*req.TestNumber) / 100)
|
|
|
|
if testDiseaseCount+testNoDiseaseCount > int64(testNumber) {
|
|
|
|
if testDiseaseCount > int64(testNumber/2) {
|
|
|
|
testDiseaseCount = int64(testNumber / 2)
|
|
|
|
}
|
|
|
|
if testNoDiseaseCount > int64(testNumber)-testDiseaseCount {
|
|
|
|
testNoDiseaseCount = int64(testNumber) - testDiseaseCount
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if req.SplitMethod == 1 {
|
|
|
|
rand.Seed(time.Now().UnixNano())
|
|
|
|
rand.Shuffle(len(trainDiseaseFileList), func(i, j int) {
|
|
|
|
trainDiseaseFileList[i], trainDiseaseFileList[j] = trainDiseaseFileList[j], trainDiseaseFileList[i]
|
|
|
|
})
|
|
|
|
rand.Seed(time.Now().UnixNano())
|
|
|
|
rand.Shuffle(len(trainNoDiseaseFileList), func(i, j int) {
|
|
|
|
trainNoDiseaseFileList[i], trainNoDiseaseFileList[j] = trainNoDiseaseFileList[j], trainNoDiseaseFileList[i]
|
|
|
|
})
|
|
|
|
rand.Seed(time.Now().UnixNano())
|
|
|
|
rand.Shuffle(len(fileList), func(i, j int) {
|
|
|
|
fileList[i], fileList[j] = fileList[j], fileList[i]
|
|
|
|
})
|
|
|
|
}
|
|
|
|
testDiseaseFileList = trainDiseaseFileList[:testDiseaseCount]
|
|
|
|
testNoDiseaseFileList = trainNoDiseaseFileList[:testNoDiseaseCount]
|
|
|
|
valDiseaseFileList = trainDiseaseFileList[testDiseaseCount : testDiseaseCount+valDiseaseCount]
|
|
|
|
valNoDiseaseFileList = trainNoDiseaseFileList[testNoDiseaseCount : testNoDiseaseCount+valNoDiseaseCount]
|
|
|
|
trainDiseaseFileList = trainDiseaseFileList[testDiseaseCount+valDiseaseCount : testDiseaseCount+valDiseaseCount+trainDiseaseCount]
|
|
|
|
trainNoDiseaseFileList = trainNoDiseaseFileList[testNoDiseaseCount+valNoDiseaseCount : testNoDiseaseCount+valNoDiseaseCount+trainNoDiseaseCount]
|
|
|
|
rp.logger.With(zap.String("创建训练集", "数据集大小"),
|
|
|
|
zap.Int("有病害训练数据集", len(trainDiseaseFileList)),
|
|
|
|
zap.Int("无病害训练数据集", len(trainNoDiseaseFileList)),
|
|
|
|
zap.Int("有病害验证数据集", len(valDiseaseFileList)),
|
|
|
|
zap.Int("无病害验证数据集", len(valNoDiseaseFileList)),
|
|
|
|
).Info("总数据集", zap.Int("len(fileList)", len(fileList)))
|
|
|
|
train := new(model.TrainingDataset)
|
|
|
|
h, err = rp.engine.Where("name = ?", req.TrainName).Get(train)
|
|
|
|
if err != nil {
|
|
|
|
goto ReturnPoint
|
|
|
|
}
|
|
|
|
if !h {
|
|
|
|
train.Name = req.TrainName
|
|
|
|
train.DatasetDesc = req.TrainDesc
|
|
|
|
//train.DatasetId = req.DatasetId
|
|
|
|
train.CategoryId = req.BizType
|
|
|
|
train.ValidationNumber = float64(req.ValidationNumber)
|
|
|
|
train.TestNumber = float64(req.TestNumber)
|
|
|
|
_, err = rp.engine.Insert(train)
|
|
|
|
if err != nil {
|
|
|
|
goto ReturnPoint
|
|
|
|
}
|
|
|
|
}
|
|
|
|
log.TargetData = int64(req.TargetData)
|
|
|
|
log.DatasetId = req.DatasetId
|
|
|
|
log.TrainingDatasetId = train.DatasetId
|
|
|
|
log.SplitMethod = req.SplitMethod
|
|
|
|
log.TrainNumber = int64(trainNumber)
|
|
|
|
log.ValidationNumber = int64(valNumber)
|
|
|
|
log.TestNumber = int64(testNumber)
|
|
|
|
log.Creator = req.UserId
|
|
|
|
_, err = rp.engine.Insert(log)
|
|
|
|
if err != nil {
|
|
|
|
goto ReturnPoint
|
|
|
|
}
|
|
|
|
list := make([]model.TrainingDatasetDetail, 0)
|
|
|
|
for _, v := range trainDiseaseFileList {
|
|
|
|
item := model.TrainingDatasetDetail{
|
|
|
|
FileName: v.FileName,
|
|
|
|
FilePath: v.FilePath,
|
|
|
|
DatasetId: train.DatasetId,
|
|
|
|
CategoryId: 1,
|
|
|
|
FileSize: v.FileSize,
|
|
|
|
IsDisease: 1,
|
|
|
|
OperationLogId: log.LogId,
|
|
|
|
Creator: req.UserId,
|
|
|
|
CreateAt: time.Now().Unix(),
|
|
|
|
UpdateAt: time.Now().Unix(),
|
|
|
|
}
|
|
|
|
list = append(list, item)
|
|
|
|
}
|
|
|
|
for _, v := range trainNoDiseaseFileList {
|
|
|
|
item := model.TrainingDatasetDetail{
|
|
|
|
FileName: v.FileName,
|
|
|
|
FilePath: v.FilePath,
|
|
|
|
DatasetId: train.DatasetId,
|
|
|
|
CategoryId: 1,
|
|
|
|
FileSize: v.FileSize,
|
|
|
|
IsDisease: 2,
|
|
|
|
OperationLogId: log.LogId,
|
|
|
|
Creator: req.UserId,
|
|
|
|
CreateAt: time.Now().Unix(),
|
|
|
|
UpdateAt: time.Now().Unix(),
|
|
|
|
}
|
|
|
|
list = append(list, item)
|
|
|
|
}
|
|
|
|
for _, v := range valDiseaseFileList {
|
|
|
|
item := model.TrainingDatasetDetail{
|
|
|
|
FileName: v.FileName,
|
|
|
|
FilePath: v.FilePath,
|
|
|
|
DatasetId: train.DatasetId,
|
|
|
|
CategoryId: 3,
|
|
|
|
FileSize: v.FileSize,
|
|
|
|
IsDisease: 1,
|
|
|
|
OperationLogId: log.LogId,
|
|
|
|
Creator: req.UserId,
|
|
|
|
CreateAt: time.Now().Unix(),
|
|
|
|
UpdateAt: time.Now().Unix(),
|
|
|
|
}
|
|
|
|
list = append(list, item)
|
|
|
|
}
|
|
|
|
for _, v := range valNoDiseaseFileList {
|
|
|
|
item := model.TrainingDatasetDetail{
|
|
|
|
FileName: v.FileName,
|
|
|
|
FilePath: v.FilePath,
|
|
|
|
DatasetId: train.DatasetId,
|
|
|
|
CategoryId: 3,
|
|
|
|
FileSize: v.FileSize,
|
|
|
|
IsDisease: 2,
|
|
|
|
OperationLogId: log.LogId,
|
|
|
|
Creator: req.UserId,
|
|
|
|
CreateAt: time.Now().Unix(),
|
|
|
|
UpdateAt: time.Now().Unix(),
|
|
|
|
}
|
|
|
|
list = append(list, item)
|
|
|
|
}
|
|
|
|
for _, v := range testDiseaseFileList {
|
|
|
|
item := model.TrainingDatasetDetail{
|
|
|
|
FileName: v.FileName,
|
|
|
|
FilePath: v.FilePath,
|
|
|
|
DatasetId: train.DatasetId,
|
|
|
|
CategoryId: 2,
|
|
|
|
FileSize: v.FileSize,
|
|
|
|
IsDisease: 1,
|
|
|
|
OperationLogId: log.LogId,
|
|
|
|
Creator: req.UserId,
|
|
|
|
CreateAt: time.Now().Unix(),
|
|
|
|
UpdateAt: time.Now().Unix(),
|
|
|
|
}
|
|
|
|
list = append(list, item)
|
|
|
|
}
|
|
|
|
for _, v := range testNoDiseaseFileList {
|
|
|
|
item := model.TrainingDatasetDetail{
|
|
|
|
FileName: v.FileName,
|
|
|
|
FilePath: v.FilePath,
|
|
|
|
DatasetId: train.DatasetId,
|
|
|
|
CategoryId: 2,
|
|
|
|
FileSize: v.FileSize,
|
|
|
|
IsDisease: 2,
|
|
|
|
OperationLogId: log.LogId,
|
|
|
|
Creator: req.UserId,
|
|
|
|
CreateAt: time.Now().Unix(),
|
|
|
|
UpdateAt: time.Now().Unix(),
|
|
|
|
}
|
|
|
|
list = append(list, item)
|
|
|
|
}
|
|
|
|
_, err = rp.engine.Insert(list)
|
|
|
|
if err != nil {
|
|
|
|
goto ReturnPoint
|
|
|
|
}
|
|
|
|
rsp.Code = http.StatusOK
|
|
|
|
rsp.Status = http.StatusText(http.StatusOK)
|
|
|
|
rsp.Message = "成功"
|
|
|
|
rsp.Err = err
|
|
|
|
rsp.Data = log
|
|
|
|
return rsp, err
|
|
|
|
}
|
|
|
|
ReturnPoint:
|
|
|
|
if err != nil {
|
|
|
|
rsp.Code = http.StatusInternalServerError
|
|
|
|
rsp.Status = http.StatusText(http.StatusInternalServerError)
|
|
|
|
rsp.Err = err
|
|
|
|
rsp.Message = "失败"
|
|
|
|
}
|
|
|
|
return rsp, err
|
|
|
|
}
|