taskExecute/mq/handler.go

402 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package mq
import (
"bufio"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"git.hpds.cc/pavement/hpds_node"
"github.com/emirpasic/gods/lists/arraylist"
"github.com/fsnotify/fsnotify"
"github.com/gammazero/workerpool"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
"github.com/shirou/gopsutil/v3/host"
"go.uber.org/zap"
"io"
"net/http"
"os"
"os/exec"
"path"
"strconv"
"strings"
"sync"
"taskExecute/config"
"taskExecute/pkg/compress"
"taskExecute/pkg/docker"
"taskExecute/pkg/download"
"taskExecute/pkg/utils"
"taskExecute/proto"
"time"
)
// Package-level state shared by the functions in this file.
var (
// wg is a WaitGroup shared by functions in this file (for example
// execCommandWait registers itself on it while a command is running).
wg sync.WaitGroup
// TaskList maps a string key to a docker.ContainerStore. It is neither
// initialised nor referenced in the visible part of this file.
// NOTE(review): presumably keyed by a task identifier and allocated
// elsewhere before the first write — confirm against its users.
TaskList map[string]docker.ContainerStore
)
// TaskExecuteHandler decodes an instruction message and dispatches it to the
// handler registered for its command, running that handler on the shared
// worker pool.
//
// data is the JSON encoding of an InstructionReq. For the TaskExecute and
// ModelIssueRepeater commands the payload must be a JSON object; the call
// blocks until the dispatched handler has finished.
//
// It returns the command code as the response tag with a nil body. When the
// message cannot be decoded, or the payload of a known command is not a JSON
// object, it returns the tag 0x0B and a description of the problem instead of
// panicking on the type assertion.
func TaskExecuteHandler(data []byte) (byte, []byte) {
	fmt.Println("接收数据", string(data))

	cmd := new(InstructionReq)
	if err := json.Unmarshal(data, cmd); err != nil {
		return 0x0B, []byte(err.Error())
	}

	var handler func(payload map[string]interface{})
	switch cmd.Command {
	case TaskExecute:
		// Task execution.
		handler = ModelTaskExecuteHandler
	case ModelIssueRepeater:
		// Model distribution.
		handler = ModelIssueRepeaterHandler
	default:
		// Commands without a handler are acknowledged without doing any work.
		return byte(cmd.Command), nil
	}

	payload, ok := cmd.Payload.(map[string]interface{})
	if !ok {
		msg := fmt.Sprintf("invalid payload type %T for command %v", cmd.Payload, cmd.Command)
		return 0x0B, []byte(msg)
	}

	waitWorkerStartFinish(config.WorkPool, payload, handler)
	return byte(cmd.Command), nil
}
// waitWorkerStartFinish submits f(payload) to the worker pool wp and blocks
// until it has completed, printing the start timestamp and then the finish
// timestamp as each one is reported by the worker.
func waitWorkerStartFinish(wp *workerpool.WorkerPool, payload map[string]interface{}, f func(payload map[string]interface{})) {
	// Capacity 2 lets the worker record both timestamps without waiting for
	// this goroutine to receive them.
	marks := make(chan time.Time, 2)

	job := func() {
		marks <- time.Now()
		f(payload)
		marks <- time.Now()
	}
	wp.Submit(job)

	startedAt := <-marks
	fmt.Println("Task started at:", startedAt)

	finishedAt := <-marks
	fmt.Println("Task finished at:", finishedAt)
}
// execCommandWait runs commandName with params and blocks until the command
// exits, logging every line the command writes to standard output as it is
// produced.
//
// It returns false when the stdout pipe cannot be created or the command
// cannot be started. Once the command has been started it returns true; a
// failure reported by Wait (for example a non-zero exit status) is logged but
// does not change the result.
func execCommandWait(commandName string, params []string) bool {
	cmd := exec.Command(commandName, params...)
	// Show the command that is about to run.
	fmt.Println(cmd.Args)

	stdout, err := cmd.StdoutPipe()
	if err != nil {
		fmt.Println(err)
		return false
	}
	if err = cmd.Start(); err != nil {
		fmt.Println(err)
		return false
	}

	// Register exactly one unit on the package-level WaitGroup for the
	// lifetime of the command. Adding inside the read loop would leave the
	// counter permanently above zero after the single Done.
	wg.Add(1)
	defer wg.Done()

	// Stream the output line by line. All reads must finish before Wait is
	// called, because Wait closes the pipe.
	reader := bufio.NewReader(stdout)
	for {
		line, readErr := reader.ReadString('\n')
		if len(line) > 0 {
			// ReadString also returns the data read before an error, so a
			// final line without a trailing newline is still logged.
			config.Logger.Info("执行命令",
				zap.String("execCommandWait", line))
		}
		if readErr != nil {
			if readErr != io.EOF {
				config.Logger.Error("执行命令",
					zap.String("execCommandWait", commandName), zap.Error(readErr))
			}
			break
		}
	}

	if err = cmd.Wait(); err != nil {
		config.Logger.Error("执行命令",
			zap.String("execCommandWait", commandName), zap.Error(err))
	}
	return true
}
// ModelIssueRepeaterHandler installs a model on this node: it downloads the
// model archive, unpacks it when it is a zip, imports the resulting tar as a
// docker image, starts a container from it and reports the stored container
// description on the "task-response" topic.
//
// payload is the JSON-decoded instruction payload. The handler does nothing
// unless payload["nodeGuid"] equals this host's ID. Keys read: nodeGuid,
// dockerFile, modelVersion, inPath, outPath, modelCommand, mappedPort,
// modelId, nodeId, httpUrl. Missing or wrongly typed string values are
// treated as "" and numeric values as 0 instead of panicking.
//
// Failures are logged and abort the remaining steps.
func ModelIssueRepeaterHandler(payload map[string]interface{}) {
	// str returns payload[key] when it holds a string and "" otherwise.
	str := func(key string) string {
		s, _ := payload[key].(string)
		return s
	}
	// num returns payload[key] as an int64. encoding/json decodes every JSON
	// number stored in an interface{} as float64, so that is the case that
	// matters; the integer cases cover payloads built in-process.
	num := func(key string) int64 {
		switch v := payload[key].(type) {
		case float64:
			return int64(v)
		case int:
			return int64(v)
		case int64:
			return v
		}
		return 0
	}

	hi, err := host.Info()
	if err != nil {
		config.Logger.Error("获取主机信息", zap.Error(err))
		return
	}
	if guid, ok := payload["nodeGuid"].(string); !ok || guid != hi.HostID {
		return
	}

	fileUrl := str("dockerFile")
	if fileUrl == "" {
		config.Logger.Error("模型下发", zap.String("dockerFile", fileUrl))
		return
	}
	modelVersion := str("modelVersion")
	downFileName := path.Base(fileUrl)
	// The file extension decides whether the download has to be unpacked.
	fileType := path.Ext(downFileName)
	fileNameOnly := strings.TrimSuffix(downFileName, fileType)
	workDir := path.Join(config.Cfg.TmpPath, fileNameOnly)
	dFile := path.Join(workDir, downFileName)

	// Download the archive.
	// NOTE(review): DownloadFolder receives the full file path, exactly as
	// before — confirm that this is what download.ThreadController expects.
	controller := download.ThreadController{
		ThreadCount:      download.ThreadCount,
		FileUrl:          fileUrl,
		DownloadFolder:   dFile,
		DownloadFileName: downFileName,
		Logger:           config.Logger,
	}
	controller.Download(download.OneThreadDownloadSize)

	if strings.ToLower(fileType) == ".zip" {
		if err = compress.UnzipFromFile(workDir, dFile); err != nil {
			controller.Logger.With(zap.String("文件解压缩", path.Join(config.Cfg.TmpPath, downFileName))).
				Error("发生错误", zap.Error(err))
			return
		}
		// The archive is expected to contain <name>.tar next to itself.
		dFile = path.Join(workDir, fileNameOnly+".tar")
	}

	// Only tar files can be imported into docker.
	if strings.ToLower(path.Ext(dFile)) != ".tar" {
		return
	}

	imgName := fmt.Sprintf("%s:%s", fileNameOnly, modelVersion)
	dCli := docker.NewDockerClient()
	if err = dCli.ImportImage(imgName, "latest", dFile); err != nil {
		controller.Logger.With(zap.String("导入docker的文件", dFile)).
			Error("发生错误", zap.Error(err))
		// Without the image there is nothing to start a container from.
		return
	}

	// Prepare the data directory that is mounted into the container.
	dataPath := path.Join(workDir, "data")
	if err = os.MkdirAll(dataPath, os.ModePerm); err != nil {
		controller.Logger.With(zap.String("数据目录", dataPath)).
			Error("发生错误", zap.Error(err))
		return
	}
	inPath, outPath := str("inPath"), str("outPath")
	vol := make(map[string]string)
	vol[path.Join(dataPath, inPath)] = inPath
	vol[path.Join(dataPath, outPath)] = outPath

	// Start the container.
	modelCommand := strings.Split(str("modelCommand"), " ")
	srcPort := strconv.FormatInt(num("mappedPort"), 10)
	dstPort := dCli.CreateContainer(fileNameOnly, imgName, modelCommand, vol, srcPort)

	// Persist the container description to the local store.
	item := docker.ContainerStore{
		ModelId: num("modelId"),
		NodeId:  num("nodeId"),
		Name:    fileNameOnly,
		ImgName: imgName,
		Volumes: vol,
		SrcPort: srcPort,
		DstPort: dstPort,
		Command: modelCommand,
		HttpUrl: str("httpUrl"),
	}
	docker.ContainerList = append(docker.ContainerList, item)
	docker.SaveStore()

	// Report the result to the task-response topic.
	cli := GetMqClient("task-response", 1)
	ap, ok := cli.EndPoint.(hpds_node.AccessPoint)
	if !ok {
		config.Logger.Error("模型下发", zap.String("topic", "task-response"))
		return
	}
	res := new(InstructionReq)
	res.Command = ModelIssueResponse
	res.Payload = item
	pData, err := json.Marshal(res)
	if err != nil {
		config.Logger.Error("模型下发", zap.Error(err))
		return
	}
	if err = GenerateAndSendData(ap, pData); err != nil {
		config.Logger.Error("模型下发", zap.Error(err))
	}
}
func ModelTaskExecuteHandler(payload map[string]interface{}) {
hi, _ := host.Info()
if payload["nodeGuid"] == hi.HostID {
if len(payload["subDataset"].(string)) > 0 {
sf := hpds_node.NewStreamFunction(
payload["subDataset"].(string),
hpds_node.WithMqAddr(fmt.Sprintf("%s:%d", config.Cfg.Node.Host, config.Cfg.Node.Port)),
hpds_node.WithObserveDataTags(payload["subDataTag"].(byte)),
hpds_node.WithCredential(config.Cfg.Node.Token),
)
err := sf.Connect()
must(config.Logger, err)
nodeInfo := HpdsMqNode{
MqType: 2,
Topic: payload["subDataset"].(string),
Node: config.Cfg.Node,
EndPoint: sf,
}
_ = sf.SetHandler(func(data []byte) (byte, []byte) {
//查询docker是否已经开启
issue := new(docker.ContainerStore)
_ = json.Unmarshal([]byte(payload["issueResult"].(string)), issue)
dCli := docker.NewDockerClient()
cList, err := dCli.SearchImage(issue.Name)
if err != nil {
}
if len(cList) > 0 {
if len(payload["workflow"].(string)) > 0 {
//是否设置工作流程
wf := new(Workflow)
err = json.Unmarshal([]byte(payload["workflow"].(string)), wf)
if err != nil {
}
if len(payload["datasetPath"].(string)) > 0 {
//数据集处理
opt := &minio.Options{
Creds: credentials.NewStaticV4(config.Cfg.Minio.AccessKeyId, config.Cfg.Minio.SecretAccessKey, ""),
Secure: false,
}
cli, _ := minio.New(config.Cfg.Minio.Endpoint, opt)
doneCh := make(chan struct{})
defer close(doneCh)
objectCh := cli.ListObjects(context.Background(), config.Cfg.Minio.Bucket, minio.ListObjectsOptions{
Prefix: payload["datasetPath"].(string),
Recursive: true,
})
for object := range objectCh {
file, _ := cli.GetObject(context.Background(), config.Cfg.Minio.Bucket, object.Key, minio.GetObjectOptions{})
imgByte, _ := io.ReadAll(file)
f := proto.FileCapture{
FileName: object.Key,
File: base64.StdEncoding.EncodeToString(imgByte),
DatasetName: payload["datasetName"].(string),
CaptureTime: object.LastModified.Unix(),
}
ProcessWorkflow(payload, f, wf)
}
}
} else {
f := new(proto.FileCapture)
err := json.Unmarshal(data, f)
if err != nil {
}
if len(f.File) > 0 {
i := strings.Index(f.File, ",")
dec := base64.NewDecoder(base64.StdEncoding, strings.NewReader(f.File[i+1:]))
if len(payload["httpUrl"].(string)) > 0 {
_ = os.MkdirAll(path.Join(config.Cfg.TmpPath, payload["subDataset"].(string)), os.ModePerm)
tmpFile, _ := os.Create(path.Join(config.Cfg.TmpPath, payload["subDataset"].(string), f.FileName))
_, err = io.Copy(tmpFile, dec)
reqUrl := fmt.Sprintf("http://localhost:%s/%s", issue.DstPort, issue.HttpUrl)
response, err := http.Post(reqUrl, "multipart/form-data", dec)
if err != nil {
config.Logger.With(zap.String("源文件名", f.FileName)).
With(zap.String("临时文件名", path.Join(config.Cfg.TmpPath, payload["subDataset"].(string), f.FileName))).
Error("文件提交", zap.Error(err))
}
defer func() {
_ = response.Body.Close()
config.Logger.With(zap.String("源文件名", f.FileName)).
With(zap.String("临时文件名", path.Join(config.Cfg.TmpPath, payload["subDataset"].(string), f.FileName))).
Info("模型识别")
}()
body, err := io.ReadAll(response.Body)
if err != nil {
config.Logger.With(zap.String("源文件名", f.FileName)).
With(zap.String("临时文件名", path.Join(config.Cfg.TmpPath, payload["subDataset"].(string), f.FileName))).
Error("模型识别", zap.Error(err))
}
cli := GetMqClient("task-response", 1)
ap := cli.EndPoint.(hpds_node.AccessPoint)
res := new(InstructionReq)
res.Command = TaskResponse
res.Payload = body
pData, _ := json.Marshal(res)
_ = GenerateAndSendData(ap, pData)
}
if len(payload["inPath"].(string)) > 0 {
outPath := ""
for k, v := range issue.Volumes {
if v == payload["outPath"].(string) {
outPath = k
break
}
}
//创建一个监控对象
watch, err := fsnotify.NewWatcher()
if err != nil {
config.Logger.Error("创建文件监控", zap.Error(err))
}
defer func(watch *fsnotify.Watcher) {
_ = watch.Close()
}(watch)
err = watch.Add(outPath)
if err != nil {
config.Logger.With(zap.String("监控目录", outPath)).
Error("创建文件监控", zap.Error(err))
}
for k, v := range issue.Volumes {
if v == payload["inPath"].(string) {
_ = os.MkdirAll(k, os.ModePerm)
tmpFile, _ := os.Create(path.Join(k, f.FileName))
_, err = io.Copy(tmpFile, dec)
break
}
}
list := arraylist.New() // empty
t1 := time.NewTicker(1 * time.Second)
go func() {
for {
select {
case ev := <-watch.Events:
{
//判断事件发生的类型如下5种
// Create 创建
// Write 写入
// Remove 删除
// Rename 重命名
// Chmod 修改权限
if ev.Op&fsnotify.Create == fsnotify.Create {
config.Logger.Info("创建文件", zap.String("文件名", ev.Name))
list.Add(ev.Name)
}
}
case <-t1.C:
{
if list.Size() > 0 {
returnFileHandleResult(list, payload, issue)
}
}
case err = <-watch.Errors:
{
config.Logger.With(zap.String("监控目录", outPath)).
Error("文件监控", zap.Error(err))
return
}
}
}
}()
}
}
}
}
return payload["subDataTag"].(byte), nil
})
MqList = append(MqList, nodeInfo)
}
}
}
// returnFileHandleResult sends every file named in list as a
// proto.ModelResult on the "task-response" topic.
//
// list holds the paths (strings) of the files to report. The list is drained:
// all entries are removed before they are processed, so a caller that invokes
// this function periodically does not report the same file twice. Entries
// that are not strings or whose file no longer exists are skipped.
//
// payload supplies taskId, modelId, nodeId (JSON numbers), datasetName and
// subDataset; missing or wrongly typed values are reported as 0 or "".
// issue is accepted for the caller's convenience and is not used.
func returnFileHandleResult(list *arraylist.List, payload map[string]interface{}, issue *docker.ContainerStore) {
	// num returns payload[key] as an int64; JSON numbers decode to float64.
	num := func(key string) int64 {
		v, _ := payload[key].(float64)
		return int64(v)
	}
	// str returns payload[key] when it holds a string and "" otherwise.
	str := func(key string) string {
		s, _ := payload[key].(string)
		return s
	}

	startTime := time.Now()

	// Values returns a copy, so the list can be emptied before the entries
	// are processed.
	pending := list.Values()
	list.Clear()

	for _, entry := range pending {
		fn, ok := entry.(string)
		if !ok || !utils.PathExists(fn) {
			continue
		}

		if src := utils.ReadFile(fn); src != nil {
			mr := new(proto.ModelResult)
			mr.File = base64.StdEncoding.EncodeToString(src)
			mr.TaskCode = utils.GetFileName(fn)
			mr.TaskId = num("taskId")
			mr.FileName = utils.GetFileNameAndExt(fn)
			mr.DatasetName = str("datasetName")
			mr.SubDataset = str("subDataset")
			mr.FileMd5 = utils.GetFileMd5(src)
			mr.ModelId = num("modelId")
			mr.NodeId = num("nodeId")
			mr.StartTime = startTime.Unix()
			mr.FinishTime = time.Now().Unix()

			cli := GetMqClient("task-response", 1)
			if ap, ok := cli.EndPoint.(hpds_node.AccessPoint); ok {
				res := new(InstructionReq)
				res.Command = TaskResponse
				res.Payload = mr
				pData, err := json.Marshal(res)
				if err == nil {
					err = GenerateAndSendData(ap, pData)
				}
				if err != nil {
					config.Logger.With(zap.String("文件名", fn)).
						Error("返回任务结果", zap.Error(err))
				}
			} else {
				config.Logger.With(zap.String("文件名", fn)).
					Error("返回任务结果", zap.String("topic", "task-response"))
			}
		}

		config.Logger.Info("返回任务完成",
			zap.String("文件名", fn),
			zap.Duration("运行时间", time.Since(startTime)),
		)
	}
}