1、增加边缘设备数据集功能标注;

2、增加训练数据集、训练任务、训练日志、训练结果导出的功能
This commit is contained in:
wangjian 2023-05-18 11:01:34 +08:00
parent 1d698fe0a6
commit db923ac9ae
24 changed files with 1039 additions and 36 deletions

View File

@ -84,13 +84,12 @@ func NewStartCmd() *cobra.Command {
tags, 300, 300, 300)
must(err)
logger := LoadLoggerConfig(cfg.Logging)
//连接数据库
model.New(cfg.Db.DriveName, cfg.Db.Conn, cfg.Mode == "dev")
model.New(cfg.Db.DriveName, cfg.Db.Conn, cfg.Mode == "dev", logger)
//连接redis
model.NewCache(cfg.Cache)
logger := LoadLoggerConfig(cfg.Logging)
//创建消息连接点
mq.MqList, err = mq.NewMqClient(cfg.Funcs, cfg.Node, logger)
must(err)

View File

@ -2,6 +2,7 @@ name: web
host: 0.0.0.0
port: 8088
mode: dev
trainDir : ./classification_dataset_balanced/
logging:
path: ./logs
prefix: hpds-iot-web
@ -16,7 +17,7 @@ logging:
mineData:
accessKey: f0bda738033e47ffbfbd5d3f865c19e1
minio:
endpoint: 192.168.0.200:9000
endpoint: 127.0.0.1:9000
accessKeyId: root
secretAccessKey: OIxv7QptYBO3
consul:
@ -29,9 +30,18 @@ db:
conn: root:OIxv7QptYBO3@tcp(114.55.236.153:27136)/hpds_jky?charset=utf8mb4
drive_name: mysql
cache:
host: 192.168.0.200
host: 127.0.0.1
port: 6379
db: 8
db: 0
pool_size: 10
node:
host: 127.0.0.1
port: 27188
token: 06d36c6f5705507dae778fdce90d0767
functions:
- name: web-sf
- name: task-request
dataTag : 12
mqType: 1
- name: task-log
dataTag: 28
mqType: 2

View File

@ -15,6 +15,7 @@ type WebConfig struct {
Host string `yaml:"host,omitempty"`
Port int `yaml:"port,omitempty"`
Mode string `yaml:"mode,omitempty"`
TrainDir string `yaml:"trainDir,omitempty"`
Consul ConsulConfig `yaml:"consul,omitempty"`
Db DbConfig `yaml:"db"`
Cache CacheConfig `yaml:"cache"`

View File

@ -2,6 +2,7 @@ name: web
host: 0.0.0.0
port: 8088
mode: dev
trainDir : ./classification_dataset_balanced/
logging:
path: ./logs
prefix: hpds-iot-web

View File

@ -79,18 +79,66 @@ func (s HandlerService) DatasetInfo(c *gin.Context) (data interface{}, err error
return
}
func (s HandlerService) CreateTraining(c *gin.Context) (data interface{}, err error) {
func (s HandlerService) CreateTrainDataset(c *gin.Context) (data interface{}, err error) {
repo := service.NewDatasetService(s.AppConfig, s.Engine, s.Logger)
us, _ := c.Get("operatorUser")
userInfo := us.(*model.SystemUser)
var req proto.TrainDatasetRequest
err = c.ShouldBindJSON(&req)
if err != nil {
go s.SaveLog("CreateTraining", "Training", "", "", ToString(req), fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "")
go s.SaveLog("CreateTrainDataset", "Training", "", "", ToString(req), fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "")
return nil, e.NewValidErr(err)
}
req.UserId = userInfo.UserId
data, err = repo.CreateTraining(c, req)
data, err = repo.CreateTrainDataset(c, req)
go s.SaveLog("创建训练数据集", "Training", "", "", "", fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "")
return
}
func (s HandlerService) TrainDatasetList(c *gin.Context) (data interface{}, err error) {
repo := service.NewDatasetService(s.AppConfig, s.Engine, s.Logger)
us, _ := c.Get("operatorUser")
userInfo := us.(*model.SystemUser)
var req proto.TrainDatasetItemRequest
err = c.ShouldBindJSON(&req)
if err != nil {
go s.SaveLog("CreateTrainDataset", "Training", "", "", ToString(req), fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "")
return nil, e.NewValidErr(err)
}
if req.Size < 1 {
req.Size = 20
}
if req.Size > 1000 {
req.Size = 1000
}
if req.Page < 1 {
req.Page = 1
}
data, err = repo.TrainDatasetList(c, req)
go s.SaveLog("获取训练数据集列表", "Training", "", "", "", fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "")
return
}
func (s HandlerService) TrainDatasetFileList(c *gin.Context) (data interface{}, err error) {
repo := service.NewDatasetService(s.AppConfig, s.Engine, s.Logger)
us, _ := c.Get("operatorUser")
userInfo := us.(*model.SystemUser)
var req proto.TrainDatasetItemRequest
err = c.ShouldBindJSON(&req)
if err != nil {
go s.SaveLog("TrainDatasetFileList", "Training", "", "", ToString(req), fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "")
return nil, e.NewValidErr(err)
}
if req.Size < 1 {
req.Size = 20
}
if req.Size > 1000 {
req.Size = 1000
}
if req.Page < 1 {
req.Page = 1
}
data, err = repo.TrainDatasetFileList(c, req)
go s.SaveLog("获取训练数据集中的文件列表", "Training", "", "", "", fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "")
return
}

40
internal/handler/edge.go Normal file
View File

@ -0,0 +1,40 @@
package handler
import (
"fmt"
"github.com/gin-gonic/gin"
"hpds-iot-web/internal/proto"
"hpds-iot-web/internal/service"
"hpds-iot-web/model"
e "hpds-iot-web/pkg/err"
)
func (s HandlerService) GetEdgeList(c *gin.Context) (data interface{}, err error) {
repo := service.NewEdgeService(s.AppConfig, s.Engine, s.Logger)
us, _ := c.Get("operatorUser")
userInfo := us.(*model.SystemUser)
var req proto.EdgeDatasetRequest
err = c.ShouldBindJSON(&req)
if err != nil {
go s.SaveLog("GetEdgeList", "Dataset", "", "", ToString(req), fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "")
return nil, e.NewValidErr(err)
}
data, err = repo.GetEdgeList(c, req)
go s.SaveLog("获取边缘端数据列表", "Dataset", "", "", "", fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "")
return
}
func (s HandlerService) GetEdgeInfo(c *gin.Context) (data interface{}, err error) {
repo := service.NewEdgeService(s.AppConfig, s.Engine, s.Logger)
us, _ := c.Get("operatorUser")
userInfo := us.(*model.SystemUser)
var req proto.EdgeDatasetRequest
err = c.ShouldBindJSON(&req)
if err != nil {
go s.SaveLog("GetEdgeInfo", "Dataset", "", "", ToString(req), fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "")
return nil, e.NewValidErr(err)
}
data, err = repo.GetEdgeInfo(c, req)
go s.SaveLog("获取边缘端数据详情", "Dataset", "", "", "", fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "")
return
}

View File

@ -61,3 +61,18 @@ func (s HandlerService) FileList(c *gin.Context) (data interface{}, err error) {
go s.SaveLog("获取数据集详情", "FileManage", "", "", ToString(data), fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "")
return
}
func (s HandlerService) FileLabel(c *gin.Context) (data interface{}, err error) {
repo := service.NewFileService(s.AppConfig, s.Engine, s.Logger)
us, _ := c.Get("operatorUser")
userInfo := us.(*model.SystemUser)
var req proto.FileLabelRequest
err = c.ShouldBindJSON(&req)
if err != nil {
go s.SaveLog("FileLabel", "FileManage", "", "", ToString(req), fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "")
return nil, e.NewValidErr(err)
}
data, err = repo.FileLabel(c, req)
go s.SaveLog("标注文件", "FileManage", "", "", ToString(data), fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "")
return
}

View File

@ -115,3 +115,55 @@ func (s HandlerService) TaskLog(c *gin.Context) (data interface{}, err error) {
go s.SaveLog("获取任务日志信息", "Manage", "", "", "", fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "")
return
}
func (s HandlerService) TrainingTaskList(c *gin.Context) (data interface{}, err error) {
repo := service.NewTaskService(s.Engine, s.Logger)
us, _ := c.Get("operatorUser")
userInfo := us.(*model.SystemUser)
var req proto.TaskRequest
err = c.ShouldBindJSON(&req)
if err != nil {
go s.SaveLog("TrainingTaskList", "Manage", "", "", req.ToString(), fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "")
return nil, e.NewValidErr(err)
}
if req.Size < 1 {
req.Size = 20
}
if req.Size > 1000 {
req.Size = 1000
}
if req.Page < 1 {
req.Page = 1
}
data, err = repo.TrainingTaskList(c, req)
go s.SaveLog("获取训练任务列表", "Manage", "", "", "", fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "")
return
}
func (s HandlerService) TrainingTaskInfo(c *gin.Context) (data interface{}, err error) {
repo := service.NewTaskService(s.Engine, s.Logger)
us, _ := c.Get("operatorUser")
userInfo := us.(*model.SystemUser)
var req proto.TaskItemRequest
err = c.ShouldBindJSON(&req)
if err != nil {
go s.SaveLog("TrainingTaskInfo", "Manage", "", "", req.ToString(), fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "")
return nil, e.NewValidErr(err)
}
data, err = repo.TrainingTaskInfo(c, req)
go s.SaveLog("获取训练任务详情", "Manage", "", "", "", fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "")
return
}
func (s HandlerService) TrainingTaskLog(c *gin.Context) (data interface{}, err error) {
repo := service.NewTaskService(s.Engine, s.Logger)
us, _ := c.Get("operatorUser")
userInfo := us.(*model.SystemUser)
var req proto.TaskItemRequest
err = c.ShouldBindJSON(&req)
if err != nil {
go s.SaveLog("TrainingTaskLog", "Manage", "", "", req.ToString(), fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "")
return nil, e.NewValidErr(err)
}
data, err = repo.TrainingTaskLog(c, req)
go s.SaveLog("获取训练任务日志详情", "Manage", "", "", "", fmt.Sprintf("%d", userInfo.UserId), c.Request.RemoteAddr, "")
return
}

View File

@ -580,6 +580,11 @@ func (p DatasetItemRequest) ToString() string {
return string(data)
}
type FileLabelRequest struct {
FileId int64 `json:"fileId"`
IsDisease bool `json:"isDisease"`
}
type ImportDatasetRequest struct {
DatasetId int64 `json:"datasetId"`
CategoryId int `json:"categoryId"`
@ -745,3 +750,15 @@ type TrainDatasetRequest struct {
SplitMethod int `json:"splitMethod"`
UserId int64 `json:"userId"`
}
type TrainDatasetItemRequest struct {
DatasetId int64 `json:"datasetId"`
TrainName string `json:"trainName"`
BizType int `json:"bizType"`
BasePageList
}
type EdgeDatasetRequest struct {
NodeId int64 `json:"nodeId"`
Path string `json:"path"`
}

View File

@ -168,3 +168,27 @@ type TaskLogPayload struct {
Status int `json:"status"` //1执行成功;2:执行失败
EventTime int64 `json:"eventTime"`
}
type TrainingDataset struct {
DatasetId int64 `json:"datasetId"`
Name string `json:"name"`
CategoryId int `json:"categoryId"`
DatasetDesc string `json:"datasetDesc"`
TotalSize int64 `json:"totalSize"`
TrainSize int64 `json:"trainSize"`
ValSize int64 `json:"valSize"`
TestSize int64 `json:"testSize"`
StoreName string `json:"storeName"`
CreateAt int64 `json:"createAt"`
UpdateAt int64 `json:"updateAt"`
}
type TrainingDatasetFileItem struct {
DetailId int64 `json:"detailId"`
FileName string `json:"fileName"`
FileSize int64 `json:"fileSize"`
FilePath string `json:"filePath"`
FileContent string `json:"fileContent"`
IsDisease int `json:"isDisease"`
CategoryId int `json:"categoryId"`
}

View File

@ -148,6 +148,8 @@ func InitRouter(cfg *config.WebConfig, logger *logging.Logger, engine *xorm.Engi
file.Use(middleware.JwtAuthMiddleware(logger.Logger))
file.POST("/upload", e.ErrorWrapper(hs.UploadFile))
file.POST("/list", e.ErrorWrapper(hs.FileList))
file.POST("/label", e.ErrorWrapper(hs.FileLabel))
//file.POST("/batchLabel", e.ErrorWrapper(hs.FileBatchLabel))
}
system := r.Group("/system")
{
@ -220,6 +222,12 @@ func InitRouter(cfg *config.WebConfig, logger *logging.Logger, engine *xorm.Engi
// flusher.Flush()
//}
})
train := task.Group("/train")
{
train.POST("/list", e.ErrorWrapper(hs.TrainingTaskList))
train.POST("/info", e.ErrorWrapper(hs.TrainingTaskInfo))
train.POST("/log", e.ErrorWrapper(hs.TrainingTaskLog))
}
}
disease := r.Group("/disease")
{
@ -243,10 +251,12 @@ func InitRouter(cfg *config.WebConfig, logger *logging.Logger, engine *xorm.Engi
dataset.POST("/info", e.ErrorWrapper(hs.DatasetInfo))
}
training := r.Group("/training")
training := r.Group("/trainDataset")
{
training.Use(middleware.JwtAuthMiddleware(logger.Logger))
training.POST("/create", e.ErrorWrapper(hs.CreateTraining))
training.POST("/create", e.ErrorWrapper(hs.CreateTrainDataset))
training.POST("/list", e.ErrorWrapper(hs.TrainDatasetList))
training.POST("/fileList", e.ErrorWrapper(hs.TrainDatasetFileList))
//training.POST("/list", e.ErrorWrapper(hs.TrainingList))
//training.POST("/info", e.ErrorWrapper(hs.TrainingInfo))
}
@ -263,6 +273,16 @@ func InitRouter(cfg *config.WebConfig, logger *logging.Logger, engine *xorm.Engi
report.POST("/generate", e.ErrorWrapper(hs.GenerateReport))
//report.POST("/view", e.ErrorWrapper(hs.ViewReport))
}
edge := r.Group("/edge")
{
edge.Use(middleware.JwtAuthMiddleware(logger.Logger))
dir := edge.Group("/directory")
{
dir.POST("/list", e.ErrorWrapper(hs.GetEdgeList))
dir.POST("/info", e.ErrorWrapper(hs.GetEdgeInfo))
}
}
}
return root
}

View File

@ -2,14 +2,17 @@ package service
import (
"context"
"encoding/base64"
"fmt"
"git.hpds.cc/Component/logging"
"hpds-iot-web/config"
"hpds-iot-web/internal/proto"
"hpds-iot-web/model"
"hpds-iot-web/pkg/utils"
"math"
"math/rand"
"net/http"
"path"
"time"
"xorm.io/xorm"
)
@ -20,7 +23,9 @@ type DatasetService interface {
ImportDataset(ctx context.Context, req proto.ImportDatasetRequest) (rsp *proto.BaseResponse, err error)
DatasetInfo(ctx context.Context, req proto.DatasetItemRequest) (rsp *proto.BaseResponse, err error)
CreateTraining(ctx context.Context, req proto.TrainDatasetRequest) (rsp *proto.BaseResponse, err error)
CreateTrainDataset(ctx context.Context, req proto.TrainDatasetRequest) (rsp *proto.BaseResponse, err error)
TrainDatasetList(ctx context.Context, req proto.TrainDatasetItemRequest) (rsp *proto.BaseResponse, err error)
TrainDatasetFileList(ctx context.Context, req proto.TrainDatasetItemRequest) (rsp *proto.BaseResponse, err error)
}
func NewDatasetService(cfg *config.WebConfig, engine *xorm.Engine, logger *logging.Logger) DatasetService {
@ -229,7 +234,7 @@ ReturnPoint:
return rsp, err
}
func (rp *repo) CreateTraining(ctx context.Context, req proto.TrainDatasetRequest) (rsp *proto.BaseResponse, err error) {
func (rp *repo) CreateTrainDataset(ctx context.Context, req proto.TrainDatasetRequest) (rsp *proto.BaseResponse, err error) {
rsp = new(proto.BaseResponse)
select {
case <-ctx.Done():
@ -258,7 +263,7 @@ func (rp *repo) CreateTraining(ctx context.Context, req proto.TrainDatasetReques
goto ReturnPoint
}
fileList := make([]model.FileManager, 0)
err = rp.engine.Where("dataset_id = ?", req.DatasetId).Find(&fileList)
err = rp.engine.Where("dataset_id = ?", req.DatasetId).And("is_disease > 0").Find(&fileList)
if err != nil {
goto ReturnPoint
}
@ -266,7 +271,7 @@ func (rp *repo) CreateTraining(ctx context.Context, req proto.TrainDatasetReques
req.TargetData = len(fileList)
}
if req.TargetData > len(fileList) {
err = fmt.Errorf("超出现有数据集数量")
err = fmt.Errorf("超出现有标注数据集数量")
goto ReturnPoint
}
if req.SplitMethod == 1 { //随机
@ -278,9 +283,24 @@ func (rp *repo) CreateTraining(ctx context.Context, req proto.TrainDatasetReques
trainNumber := int(math.Floor(float64(int64(req.TargetData)*req.TrainNumber) / 100))
valNumber := int(math.Floor(float64(int64(req.TargetData)*req.ValidationNumber) / 100))
testNumber := req.TargetData - trainNumber - valNumber
trainFileList = fileList[:trainNumber-1]
valFileList = fileList[trainNumber : valNumber-1]
testFileList = fileList[valNumber:]
if trainNumber-1 > 1 {
trainFileList = fileList[:trainNumber-1]
} else {
trainFileList = make([]model.FileManager, 0)
trainFileList = append(trainFileList, fileList[0])
}
if trainNumber != trainNumber+valNumber-1 {
valFileList = fileList[trainNumber : trainNumber+valNumber-1]
} else {
valFileList = make([]model.FileManager, 0)
valFileList = append(valFileList, fileList[trainNumber])
}
if trainNumber+valNumber < len(fileList) {
testFileList = fileList[trainNumber+valNumber:]
} else {
testFileList = make([]model.FileManager, 0)
testFileList = append(testFileList, fileList[trainNumber+valNumber])
}
train := new(model.TrainingDataset)
h, err = rp.engine.Where("name = ?", req.TrainName).Get(train)
@ -290,7 +310,7 @@ func (rp *repo) CreateTraining(ctx context.Context, req proto.TrainDatasetReques
if !h {
train.Name = req.TrainName
train.DatasetDesc = req.TrainDesc
train.DatasetId = req.DatasetId
//train.DatasetId = req.DatasetId
train.CategoryId = dataset.CategoryId
_, err = rp.engine.Insert(train)
if err != nil {
@ -310,9 +330,9 @@ func (rp *repo) CreateTraining(ctx context.Context, req proto.TrainDatasetReques
goto ReturnPoint
}
//wg.Add(3)
go BatchCopyData(trainFileList, train.DatasetId, log.LogId, req.UserId, 1, rp.engine) //, &wg
go BatchCopyData(valFileList, train.DatasetId, log.LogId, req.UserId, 2, rp.engine) //, &wg
go BatchCopyData(testFileList, train.DatasetId, log.LogId, req.UserId, 3, rp.engine) //, &wg
go BatchCopyData(trainFileList, train.DatasetId, log.LogId, req.UserId, 1, req.TrainName, rp) //, &wg
go BatchCopyData(valFileList, train.DatasetId, log.LogId, req.UserId, 2, req.TrainName, rp) //, &wg
go BatchCopyData(testFileList, train.DatasetId, log.LogId, req.UserId, 3, req.TrainName, rp) //, &wg
//wg.Wait()
rsp.Code = http.StatusOK
rsp.Status = http.StatusText(http.StatusOK)
@ -331,16 +351,24 @@ ReturnPoint:
return rsp, err
}
func BatchCopyData(list []model.FileManager, trainId, logId, userId int64, categoryId int, engine *xorm.Engine) { //, wg *sync.WaitGroup
func BatchCopyData(list []model.FileManager, trainId, logId, userId int64, categoryId int, trainName string, rp *repo) { //, wg *sync.WaitGroup
batchList := make([]model.TrainingDatasetDetail, len(list))
for k, v := range list {
dir := "no_disease"
if v.IsDisease == 1 {
dir = "disease"
}
utils.DownloadMinioFileToLocalPath(v.AccessUrl, path.Join(rp.AppConfig.TrainDir, trainName, model.GetTrainCategory(categoryId), dir), v.FileName,
rp.AppConfig.Minio.Protocol, rp.AppConfig.Minio.Endpoint, rp.AppConfig.Minio.Bucket, rp.AppConfig.Minio.AccessKeyId,
rp.AppConfig.Minio.SecretAccessKey, rp.logger)
item := model.TrainingDatasetDetail{
FileName: v.FileName,
FilePath: v.AccessUrl,
FilePath: path.Join(rp.AppConfig.TrainDir, trainName, model.GetTrainCategory(categoryId), dir, v.FileName),
DatasetId: trainId,
CategoryId: categoryId,
FileSize: v.FileSize,
FileMd5: v.FileMd5,
IsDisease: v.IsDisease,
OperationLogId: logId,
Creator: userId,
CreateAt: time.Now().Unix(),
@ -348,6 +376,141 @@ func BatchCopyData(list []model.FileManager, trainId, logId, userId int64, categ
}
batchList[k] = item
}
_, _ = engine.Insert(batchList)
_, _ = rp.engine.Insert(batchList)
//wg.Done()
}
func (rp *repo) TrainDatasetList(ctx context.Context, req proto.TrainDatasetItemRequest) (rsp *proto.BaseResponse, err error) {
rsp = new(proto.BaseResponse)
select {
case <-ctx.Done():
err = fmt.Errorf("超时/取消")
rsp.Code = http.StatusInternalServerError
rsp.Status = http.StatusText(http.StatusInternalServerError)
rsp.Message = "超时/取消"
rsp.Err = ctx.Err()
return rsp, ctx.Err()
default:
type QuantityStatistics struct {
DatasetId int64
CategoryId int
Total int64
}
var (
count int64
list []proto.TrainingDataset
)
trainingList := make([]model.TrainingDataset, 0)
count, err = rp.engine.Where("(?=0 or dataset_id = ?)", req.DatasetId, req.DatasetId).
And("(?= 0 or category_id = ?)", req.BizType, req.BizType).
And("(? ='' or name like ?)", req.TrainName, "%"+req.TrainName+"%").
Limit(int(req.Size), int(((req.Page)-1)*req.Size)).FindAndCount(&trainingList)
if err != nil {
goto ReturnPoint
}
list = make([]proto.TrainingDataset, len(trainingList))
for k, v := range trainingList {
qs := make([]QuantityStatistics, 0)
err = rp.engine.SQL("select dataset_id, category_id, count(1) as total from training_dataset_detail where dataset_id = ? group by category_id, dataset_id", v.DatasetId).Find(&qs)
if err != nil {
goto ReturnPoint
}
item := proto.TrainingDataset{
DatasetId: v.DatasetId,
Name: v.Name,
CategoryId: v.CategoryId,
DatasetDesc: v.DatasetDesc,
TotalSize: 0,
TrainSize: 0,
ValSize: 0,
TestSize: 0,
StoreName: v.StoreName,
CreateAt: v.CreateAt,
UpdateAt: v.UpdateAt,
}
for _, val := range qs {
switch val.CategoryId {
case 1:
item.TrainSize = val.Total
case 2:
item.ValSize = val.Total
case 3:
item.TestSize = val.Total
}
item.TotalSize += val.Total
}
list[k] = item
}
rsp.Code = http.StatusOK
rsp.Status = http.StatusText(http.StatusOK)
rsp.Message = "成功"
rsp = FillPaging(count, req.Page, req.Size, list, rsp)
rsp.Err = err
return rsp, err
}
ReturnPoint:
if err != nil {
rsp.Code = http.StatusInternalServerError
rsp.Status = http.StatusText(http.StatusInternalServerError)
rsp.Err = err
rsp.Message = "失败"
}
return rsp, err
}
func (rp *repo) TrainDatasetFileList(ctx context.Context, req proto.TrainDatasetItemRequest) (rsp *proto.BaseResponse, err error) {
rsp = new(proto.BaseResponse)
select {
case <-ctx.Done():
err = fmt.Errorf("超时/取消")
rsp.Code = http.StatusInternalServerError
rsp.Status = http.StatusText(http.StatusInternalServerError)
rsp.Message = "超时/取消"
rsp.Err = ctx.Err()
return rsp, ctx.Err()
default:
var (
count int64
list []proto.TrainingDatasetFileItem
)
fileList := make([]model.TrainingDatasetDetail, 0)
count, err = rp.engine.Where("dataset_id = ?", req.DatasetId).
Limit(int(req.Size), int(((req.Page)-1)*req.Size)).FindAndCount(&fileList)
if err != nil {
goto ReturnPoint
}
list = make([]proto.TrainingDatasetFileItem, len(fileList))
for k, v := range fileList {
buff := utils.ReadFile(v.FilePath)
img := utils.BuffToImage(buff)
buf := utils.ImageToBuff(img, "jpeg")
list[k] = proto.TrainingDatasetFileItem{
DetailId: v.DetailId,
FileName: v.FileName,
FileSize: v.FileSize,
FilePath: v.FilePath,
FileContent: "data:image/jpeg;base64," + base64.StdEncoding.EncodeToString(buf.Bytes()),
IsDisease: v.IsDisease,
CategoryId: v.CategoryId,
}
}
rsp.Code = http.StatusOK
rsp.Status = http.StatusText(http.StatusOK)
rsp.Message = "成功"
rsp = FillPaging(count, req.Page, req.Size, list, rsp)
rsp.Err = err
return rsp, err
}
ReturnPoint:
if err != nil {
rsp.Code = http.StatusInternalServerError
rsp.Status = http.StatusText(http.StatusInternalServerError)
rsp.Err = err
rsp.Message = "失败"
}
return rsp, err
}

96
internal/service/edge.go Normal file
View File

@ -0,0 +1,96 @@
package service
import (
"context"
"encoding/json"
"fmt"
"git.hpds.cc/Component/logging"
"hpds-iot-web/config"
"hpds-iot-web/internal/proto"
"hpds-iot-web/pkg/utils"
"net/http"
"xorm.io/xorm"
)
type EdgeService interface {
GetEdgeList(ctx context.Context, req proto.EdgeDatasetRequest) (rsp *proto.BaseResponse, err error)
GetEdgeInfo(ctx context.Context, req proto.EdgeDatasetRequest) (rsp *proto.BaseResponse, err error)
}
func NewEdgeService(cfg *config.WebConfig, engine *xorm.Engine, logger *logging.Logger) EdgeService {
return &repo{
AppConfig: cfg,
engine: engine,
logger: logger,
}
}
func (rp *repo) GetEdgeList(ctx context.Context, req proto.EdgeDatasetRequest) (rsp *proto.BaseResponse, err error) {
rsp = new(proto.BaseResponse)
select {
case <-ctx.Done():
err = fmt.Errorf("超时/取消")
rsp.Code = http.StatusInternalServerError
rsp.Status = http.StatusText(http.StatusInternalServerError)
rsp.Message = "超时/取消"
rsp.Err = ctx.Err()
return rsp, ctx.Err()
default:
param := make(map[string]string)
param["path"] = req.Path
header := make(map[string]string)
header["Content-Type"] = "application/json"
res, err := utils.HttpDo("http://192.168.22.151:8099/api/directory/list", "POST", param, header)
if err != nil {
goto ReturnPoint
}
err = json.Unmarshal(res, &rsp)
if err != nil {
goto ReturnPoint
}
return rsp, err
}
ReturnPoint:
if err != nil {
rsp.Code = http.StatusInternalServerError
rsp.Status = http.StatusText(http.StatusInternalServerError)
rsp.Err = err
rsp.Message = "失败"
}
return rsp, err
}
func (rp *repo) GetEdgeInfo(ctx context.Context, req proto.EdgeDatasetRequest) (rsp *proto.BaseResponse, err error) {
rsp = new(proto.BaseResponse)
select {
case <-ctx.Done():
err = fmt.Errorf("超时/取消")
rsp.Code = http.StatusInternalServerError
rsp.Status = http.StatusText(http.StatusInternalServerError)
rsp.Message = "超时/取消"
rsp.Err = ctx.Err()
return rsp, ctx.Err()
default:
param := make(map[string]string)
param["path"] = req.Path
header := make(map[string]string)
header["Content-Type"] = "application/json"
res, err := utils.HttpDo("http://192.168.22.151:8099/api/directory/info", "POST", param, header)
if err != nil {
goto ReturnPoint
}
err = json.Unmarshal(res, &rsp)
if err != nil {
goto ReturnPoint
}
return rsp, err
}
ReturnPoint:
if err != nil {
rsp.Code = http.StatusInternalServerError
rsp.Status = http.StatusText(http.StatusInternalServerError)
rsp.Err = err
rsp.Message = "失败"
}
return rsp, err
}

View File

@ -24,6 +24,7 @@ type FileService interface {
UploadFile(ctx context.Context, req proto.UploadFileRequest) (rsp *proto.BaseResponse, err error)
UploadFileToMinIo(ctx context.Context, srcFile *multipart.FileHeader, scene string, datasetId, creator int64, dataType int) (data *model.FileManager, err error)
FileList(ctx context.Context, req proto.DatasetItemRequest) (rsp *proto.BaseResponse, err error)
FileLabel(ctx context.Context, req proto.FileLabelRequest) (rsp *proto.BaseResponse, err error)
}
func NewFileService(cfg *config.WebConfig, engine *xorm.Engine, logger *logging.Logger) FileService {
@ -171,3 +172,50 @@ ReturnPoint:
}
return rsp, err
}
func (rp *repo) FileLabel(ctx context.Context, req proto.FileLabelRequest) (rsp *proto.BaseResponse, err error) {
rsp = new(proto.BaseResponse)
select {
case <-ctx.Done():
err = fmt.Errorf("超时/取消")
rsp.Code = http.StatusInternalServerError
rsp.Status = http.StatusText(http.StatusInternalServerError)
rsp.Message = "超时/取消"
rsp.Err = ctx.Err()
return rsp, ctx.Err()
default:
item := new(model.FileManager)
var h bool
h, err = rp.engine.ID(req.FileId).Get(item)
if err != nil {
goto ReturnPoint
}
if !h {
err = fmt.Errorf("未能找到对应的文件")
goto ReturnPoint
}
if req.IsDisease {
item.IsDisease = 1
} else {
item.IsDisease = 2
}
_, err = rp.engine.ID(req.FileId).Cols("is_disease").Update(item)
if err != nil {
goto ReturnPoint
}
rsp.Code = http.StatusOK
rsp.Status = http.StatusText(http.StatusOK)
rsp.Message = "成功"
rsp.Err = err
return rsp, err
}
ReturnPoint:
if err != nil {
rsp.Code = http.StatusInternalServerError
rsp.Status = http.StatusText(http.StatusInternalServerError)
rsp.Err = err
rsp.Message = "失败"
}
return rsp, err
}

View File

@ -383,20 +383,22 @@ func (rp *repo) AddProject(ctx context.Context, req proto.ProjectItemRequest) (r
Status: 1,
Creator: req.Creator,
}
slng, slat, err := rp.GetLngLat(ctx, fmt.Sprintf("%s+%s", req.LineName, req.StartName))
var (
sLng, sLat, eLng, eLat float64
)
sLng, sLat, err = rp.GetLngLat(ctx, fmt.Sprintf("%s+%s", req.LineName, req.StartName))
if err != nil {
goto ReturnPoint
}
item.StartPointLng = slng
item.StartPointLat = slat
item.StartPointLng = sLng
item.StartPointLat = sLat
elng, elat, err := rp.GetLngLat(ctx, fmt.Sprintf("%s+%s", req.LineName, req.EndName))
eLng, eLat, err = rp.GetLngLat(ctx, fmt.Sprintf("%s+%s", req.LineName, req.EndName))
if err != nil {
goto ReturnPoint
}
item.EndPointLng = elng
item.EndPointLat = elat
item.EndPointLng = eLng
item.EndPointLat = eLat
_, err = rp.engine.Insert(item)
if err != nil {

View File

@ -26,6 +26,10 @@ type TaskService interface {
//EditTask(ctx context.Context, req proto.ModelItemRequest) (rsp *proto.BaseResponse, err error)
TaskResult(ctx context.Context, req proto.ReportRequest) (rsp *proto.BaseResponse, err error)
TaskLog(ctx context.Context, req proto.TaskLogItem) (rsp *proto.BaseResponse, err error)
TrainingTaskList(ctx context.Context, req proto.TaskRequest) (rsp *proto.BaseResponse, err error)
TrainingTaskInfo(ctx context.Context, req proto.TaskItemRequest) (rsp *proto.BaseResponse, err error)
TrainingTaskLog(ctx context.Context, req proto.TaskItemRequest) (rsp *proto.BaseResponse, err error)
}
func NewTaskService(engine *xorm.Engine, logger *logging.Logger) TaskService {
@ -402,6 +406,11 @@ func (rp *repo) TaskResult(ctx context.Context, req proto.ReportRequest) (rsp *p
if err := json.Unmarshal([]byte(str), &mr); err != nil {
continue
}
switch md.BizType {
case 1: //道路
case 2: //桥梁
case 3: //隧道
}
switch mr.Code {
case 0: //轻量化模型返回
lr := new(mq.LightweightResult)
@ -433,6 +442,9 @@ func (rp *repo) TaskResult(ctx context.Context, req proto.ReportRequest) (rsp *p
fn, _ := base64.StdEncoding.DecodeString(fileDiscern)
buff := bytes.NewBuffer(fn)
_, imgType, _ := image.Decode(buff)
if len(fileDiscern) == 0 {
fileDiscern = lr.ImgSrc
}
fileDiscern = fmt.Sprintf("data:image/%s;base64,%s", imgType, fileDiscern)
item := proto.TaskResultItem{
FileId: v.FileId,
@ -555,6 +567,9 @@ func (rp *repo) TaskResult(ctx context.Context, req proto.ReportRequest) (rsp *p
memo = "检测到坑洼"
}
fileDiscern = lr.ImgDiscern
if len(fileDiscern) == 0 {
fileDiscern = lr.ImgSrc
}
diseaseLevel = 3
diseaseLevelName = "重度"
switch md.BizType {
@ -568,6 +583,8 @@ func (rp *repo) TaskResult(ctx context.Context, req proto.ReportRequest) (rsp *p
diseaseType = 4
diseaseTypeName = "横向裂缝"
}
} else {
fileDiscern = lr.ImgSrc
}
//
case 2000: //网新返回没有病害
@ -677,3 +694,115 @@ ReturnPoint:
}
return rsp, err
}
func (rp *repo) TrainingTaskList(ctx context.Context, req proto.TaskRequest) (rsp *proto.BaseResponse, err error) {
rsp = new(proto.BaseResponse)
select {
case <-ctx.Done():
err = fmt.Errorf("超时/取消")
rsp.Code = http.StatusInternalServerError
rsp.Status = http.StatusText(http.StatusInternalServerError)
rsp.Message = "超时/取消"
rsp.Err = ctx.Err()
return rsp, ctx.Err()
default:
var (
count int64
)
list := make([]model.TrainTask, 0)
count, err = rp.engine.Where("(? = 0 or category_id = ?)", req.BizType, req.BizType).
And("(? = '' or task_name like ?)", req.TaskName, "%"+req.TaskName+"%").
Limit(int(req.Size), int(((req.Page)-1)*req.Size)).FindAndCount(&list)
if err != nil {
goto ReturnPoint
}
rsp.Code = http.StatusOK
rsp.Status = http.StatusText(http.StatusOK)
rsp.Message = "成功"
rsp.Err = err
rsp = FillPaging(count, req.Page, req.Size, list, rsp)
return rsp, err
}
ReturnPoint:
if err != nil {
rsp.Code = http.StatusInternalServerError
rsp.Status = http.StatusText(http.StatusInternalServerError)
rsp.Err = err
rsp.Message = "失败"
}
return rsp, err
}
func (rp *repo) TrainingTaskInfo(ctx context.Context, req proto.TaskItemRequest) (rsp *proto.BaseResponse, err error) {
rsp = new(proto.BaseResponse)
select {
case <-ctx.Done():
err = fmt.Errorf("超时/取消")
rsp.Code = http.StatusInternalServerError
rsp.Status = http.StatusText(http.StatusInternalServerError)
rsp.Message = "超时/取消"
rsp.Err = ctx.Err()
return rsp, ctx.Err()
default:
var (
h bool
)
item := new(model.TrainTask)
h, err = rp.engine.ID(req.TaskId).Get(item)
if err != nil {
goto ReturnPoint
}
if !h {
err = fmt.Errorf("未能找到对应的任务信息")
goto ReturnPoint
}
rsp.Code = http.StatusOK
rsp.Status = http.StatusText(http.StatusOK)
rsp.Message = "成功"
rsp.Err = err
rsp.Data = item
return rsp, err
}
ReturnPoint:
if err != nil {
rsp.Code = http.StatusInternalServerError
rsp.Status = http.StatusText(http.StatusInternalServerError)
rsp.Err = err
rsp.Message = "失败"
}
return rsp, err
}
func (rp *repo) TrainingTaskLog(ctx context.Context, req proto.TaskItemRequest) (rsp *proto.BaseResponse, err error) {
rsp = new(proto.BaseResponse)
select {
case <-ctx.Done():
err = fmt.Errorf("超时/取消")
rsp.Code = http.StatusInternalServerError
rsp.Status = http.StatusText(http.StatusInternalServerError)
rsp.Message = "超时/取消"
rsp.Err = ctx.Err()
return rsp, ctx.Err()
default:
list := make([]model.TrainTaskLog, 0)
err = rp.engine.Where("task_id = ?", req.TaskId).Asc("epoch").Find(&list)
if err != nil {
goto ReturnPoint
}
rsp.Code = http.StatusOK
rsp.Status = http.StatusText(http.StatusOK)
rsp.Message = "成功"
rsp.Err = err
rsp = FillPaging(int64(len(list)), 1, 1000, list, rsp)
return rsp, err
}
ReturnPoint:
if err != nil {
rsp.Code = http.StatusInternalServerError
rsp.Status = http.StatusText(http.StatusInternalServerError)
rsp.Err = err
rsp.Message = "失败"
}
return rsp, err
}

View File

@ -9,6 +9,7 @@ type FileManager struct {
DatasetId int64 `xorm:"INT(11) index default 0" json:"datasetId"` //数据集
FileSize int64 `xorm:"BIGINT" json:"fileSize"` //文件大小
FileMd5 string `xorm:"VARCHAR(64)" json:"fileMd5"` //文件MD5
IsDisease int `xorm:"TINYINT index default 0" json:"isDisease"` //数据标注状态; 0:未标注;1:有病害;2:无病害
Creator int64 `xorm:"INT(11) index" json:"creator"` //上传人
CreateAt int64 `xorm:"created" json:"createAt"` //上传时间
UpdateAt int64 `xorm:"updated" json:"updateAt"` //更新时间

View File

@ -2,6 +2,7 @@ package model
import (
"fmt"
"git.hpds.cc/Component/logging"
"github.com/go-redis/redis"
_ "github.com/go-sql-driver/mysql"
"go.uber.org/zap"
@ -21,7 +22,7 @@ var (
Redis *redis.Client
)
func New(driveName, dsn string, showSql bool) {
func New(driveName, dsn string, showSql bool, logger *logging.Logger) {
DB, _ = NewDbConnection(driveName, dsn)
DB.ShowSQL(showSql)
DB.Dialect().SetQuotePolicy(dialects.QuotePolicyReserved)
@ -64,9 +65,11 @@ func New(driveName, dsn string, showSql bool) {
&TaskResult{},
&TrainingDataset{},
&TrainingDatasetDetail{},
&TrainTask{},
&TrainTaskLog{},
)
if err != nil {
zap.L().Error("同步数据库表结构", zap.Error(err))
logger.Error("同步数据库表结构", zap.Error(err))
os.Exit(1)
}
}
@ -99,3 +102,15 @@ func NewCache(c config.CacheConfig) {
zap.L().Info("Redis连接成功", zap.String("pong", pong))
}
}
func GetTrainCategory(categoryId int) string {
switch categoryId {
case 1:
return "train"
case 2:
return "val"
case 3:
return "test"
}
return "other"
}

18
model/trainTask.go Normal file
View File

@ -0,0 +1,18 @@
package model
type TrainTask struct {
TaskId int64 `xorm:"not null pk autoincr BIGINT(11)" json:"taskId"`
TrainDatasetId int64 `xorm:"INT(11) index" json:"trainDatasetId"`
CategoryId int `xorm:"not null SMALLINT default 1" json:"categoryId"` //业务分类, 1:道路 2:桥梁 3:隧道 4:边坡
TaskName string `xorm:"VARCHAR(200)" json:"taskName"`
TaskDesc string `xorm:"VARCHAR(500)" json:"taskDesc"`
StartTime int64 `xorm:"BIGINT" json:"startTime"`
FinishTime int64 `xorm:"BIGINT" json:"finishTime"`
Loss float64 `xorm:"DECIMAL(18,6)" json:"loss"`
Accuracy float64 `xorm:"DECIMAL(18,6)" json:"accuracy"`
ModelFilePath string `xorm:"VARCHAR(2000)" json:"modelFilePath"`
PbModelFilePath string `xorm:"VARCHAR(2000)" json:"pbModelFilePath"`
Status int `xorm:"not null SMALLINT default 0" json:"status"` // 1:等待执行; 2:执行中; 3:执行完成; 4:任务分配失败; 5:任务执行失败
CreateAt int64 `xorm:"created" json:"createAt"`
UpdateAt int64 `xorm:"updated" json:"updateAt"`
}

12
model/trainTaskLog.go Normal file
View File

@ -0,0 +1,12 @@
package model
type TrainTaskLog struct {
LogId int64 `xorm:"not null pk autoincr BIGINT(11)" json:"logId"`
TaskId int64 `xorm:"INT(11) index" json:"taskId"`
Epoch int `xorm:"SMALLINT" json:"epoch"`
Loss float64 `xorm:"DECIMAL(18,6)" json:"loss"`
Accuracy float64 `xorm:"DECIMAL(18,6)" json:"accuracy"`
ValLoss float64 `xorm:"DECIMAL(18,6)" json:"valLoss"`
ValAccuracy float64 `xorm:"DECIMAL(18,6)" json:"valAccuracy"`
CreateAt int64 `xorm:"created" json:"createAt"`
}

View File

@ -6,8 +6,9 @@ type TrainingDatasetDetail struct {
FilePath string `xorm:"VARCHAR(1000)" json:"filePath"`
DatasetId int64 `xorm:"INT(11) index default 0" json:"datasetId"` //训练数据集
CategoryId int `xorm:"not null SMALLINT default 1" json:"categoryId"` //训练集分类1训练集;2测试集;3:验证集
FileSize int64 `xorm:"BININT" json:"fileSize"` //文件大小
FileSize int64 `xorm:"BIGINT" json:"fileSize"` //文件大小
FileMd5 string `xorm:"VARCHAR(64)" json:"fileMd5"` //文件MD5
IsDisease int `xorm:"TINYINT(1)" json:"isDisease"` //是否有病害
OperationLogId int64 `xorm:"INT(11) index" json:"operationLogId"` //操作日志编号
Creator int64 `xorm:"INT(11) index" json:"creator"` //上传人
CreateAt int64 `xorm:"created" json:"createAt"` //上传时间

46
pkg/minio/index.go Normal file
View File

@ -0,0 +1,46 @@
package minio
import (
"context"
"git.hpds.cc/Component/logging"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
"io"
)
type MinClient struct {
Client *minio.Client
Logger *logging.Logger
}
func NewClient(ak, sak, ep string, useSSL bool, logger *logging.Logger) *MinClient {
opt := &minio.Options{
Creds: credentials.NewStaticV4(ak, sak, ""),
Secure: useSSL,
}
client, err := minio.New(ep, opt)
if err != nil {
return nil
}
return &MinClient{
Client: client,
Logger: logger,
}
}
func (cli *MinClient) UploadObject(fn, dst, bucket string) error {
_, err := cli.Client.FPutObject(context.Background(), bucket, dst, fn, minio.PutObjectOptions{})
if err != nil {
return err
}
return nil
}
func (cli *MinClient) GetObject(dstUrl, bucket string) ([]byte, error) {
f, err := cli.Client.GetObject(context.Background(), bucket, dstUrl, minio.GetObjectOptions{})
if err != nil {
return nil, err
}
imgByte, _ := io.ReadAll(f)
return imgByte, nil
}

119
pkg/utils/file.go Normal file
View File

@ -0,0 +1,119 @@
package utils
import (
"crypto/md5"
"encoding/hex"
"fmt"
"git.hpds.cc/Component/logging"
"go.uber.org/zap"
"hpds-iot-web/pkg/minio"
"io"
"os"
"path"
"path/filepath"
"strings"
)
func CopyFile(src, dst string) error {
sourceFileStat, err := os.Stat(src)
if err != nil {
return err
}
if !sourceFileStat.Mode().IsRegular() {
return fmt.Errorf("%s is not a regular file", src)
}
source, err := os.Open(src)
if err != nil {
return err
}
defer func(source *os.File) {
_ = source.Close()
}(source)
destination, err := os.Create(dst)
if err != nil {
return err
}
defer func(destination *os.File) {
_ = destination.Close()
}(destination)
_, err = io.Copy(destination, source)
return err
}
func PathExists(path string) bool {
_, err := os.Stat(path)
if err == nil {
return true
}
if os.IsNotExist(err) {
return false
}
return false
}
// ReadFile 读取到file中再利用ioutil将file直接读取到[]byte中, 这是最优
func ReadFile(fn string) []byte {
f, err := os.Open(fn)
if err != nil {
logging.L().Error("Read File", zap.String("File Name", fn), zap.Error(err))
return nil
}
defer func(f *os.File) {
_ = f.Close()
}(f)
fd, err := io.ReadAll(f)
if err != nil {
logging.L().Error("Read File To buff", zap.String("File Name", fn), zap.Error(err))
return nil
}
return fd
}
func GetFileName(fn string) string {
fileType := path.Ext(fn)
return strings.TrimSuffix(fn, fileType)
}
func GetFileNameAndExt(fn string) string {
_, fileName := filepath.Split(fn)
return fileName
}
func GetFileMd5(data []byte) string {
hash := md5.New()
hash.Write(data)
return hex.EncodeToString(hash.Sum(nil))
}
func DownloadMinioFileToLocalPath(accessUrl, dstPath, fileName, protocol, endpoint, bucket, accessKeyId, secretAccessKey string,
logger *logging.Logger) {
if !PathExists(path.Join(dstPath, fileName)) {
dPath := strings.Replace(accessUrl, fmt.Sprintf("%s://%s/", protocol, endpoint), "", 1)
dPath = strings.Replace(dPath, bucket, "", 1)
minioCli := minio.NewClient(accessKeyId, secretAccessKey, endpoint, false, logger)
imgByte, err := minioCli.GetObject(dPath, bucket)
if err != nil {
logger.With(zap.String("源文件名", accessUrl)).
With(zap.String("文件名", path.Join(dstPath, fileName))).
Error("文件下载", zap.Error(err))
}
err = os.MkdirAll(dstPath, os.ModePerm)
if err != nil {
logger.With(zap.String("源文件名", accessUrl)).
With(zap.String("文件名", path.Join(dstPath, fileName))).
Error("创建文件下载目录", zap.Error(err))
}
err = os.WriteFile(path.Join(dstPath, fileName), imgByte, os.ModePerm)
if err != nil {
logger.With(zap.String("源文件名", accessUrl)).
With(zap.String("文件名", path.Join(dstPath, fileName))).
Error("文件写入", zap.Error(err))
}
}
}

126
pkg/utils/http.go Normal file
View File

@ -0,0 +1,126 @@
package utils
import (
"bytes"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/url"
"path/filepath"
"strings"
)
func HttpDo(reqUrl, method string, params map[string]string, header map[string]string) (data []byte, err error) {
var paramStr string = ""
if contentType, ok := header["Content-Type"]; ok && strings.Contains(contentType, "json") {
bytesData, _ := json.Marshal(params)
paramStr = string(bytesData)
} else {
for k, v := range params {
if len(paramStr) == 0 {
paramStr = fmt.Sprintf("%s=%s", k, url.QueryEscape(v))
} else {
paramStr = fmt.Sprintf("%s&%s=%s", paramStr, k, url.QueryEscape(v))
}
}
}
client := &http.Client{}
req, err := http.NewRequest(strings.ToUpper(method), reqUrl, strings.NewReader(paramStr))
if err != nil {
return nil, err
}
for k, v := range header {
req.Header.Set(k, v)
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer func() {
if resp.Body != nil {
err = resp.Body.Close()
if err != nil {
return
}
}
}()
var body []byte
if resp.Body != nil {
body, err = io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
}
return body, nil
}
type UploadFile struct {
// 表单名称
Name string
Filepath string
// 文件全路径
File *bytes.Buffer
}
func PostFile(reqUrl string, reqParams map[string]string, contentType string, files []UploadFile, headers map[string]string) string {
requestBody, realContentType := getReader(reqParams, contentType, files)
httpRequest, _ := http.NewRequest("POST", reqUrl, requestBody)
// 添加请求头
httpRequest.Header.Add("Content-Type", realContentType)
if headers != nil {
for k, v := range headers {
httpRequest.Header.Add(k, v)
}
}
httpClient := &http.Client{}
// 发送请求
resp, err := httpClient.Do(httpRequest)
if err != nil {
panic(err)
}
defer func(Body io.ReadCloser) {
_ = Body.Close()
}(resp.Body)
response, _ := io.ReadAll(resp.Body)
return string(response)
}
func getReader(reqParams map[string]string, contentType string, files []UploadFile) (io.Reader, string) {
if strings.Index(contentType, "json") > -1 {
bytesData, _ := json.Marshal(reqParams)
return bytes.NewReader(bytesData), contentType
} else if files != nil {
body := &bytes.Buffer{}
// 文件写入 body
writer := multipart.NewWriter(body)
for _, uploadFile := range files {
part, err := writer.CreateFormFile(uploadFile.Name, filepath.Base(uploadFile.Filepath))
if err != nil {
panic(err)
}
_, err = io.Copy(part, uploadFile.File)
}
// 其他参数列表写入 body
for k, v := range reqParams {
if err := writer.WriteField(k, v); err != nil {
panic(err)
}
}
if err := writer.Close(); err != nil {
panic(err)
}
// 上传文件需要自己专用的contentType
return body, writer.FormDataContentType()
} else {
urlValues := url.Values{}
for key, val := range reqParams {
urlValues.Set(key, val)
}
reqBody := urlValues.Encode()
return strings.NewReader(reqBody), contentType
}
}