commit c9dd1ee1de22f169d802be4076902bd0f00cca0d
Author: wangjian
Date:   Thu Mar 23 14:35:24 2023 +0800

    1. Initialize the code

diff --git a/cmd/server.go b/cmd/server.go
new file mode 100644
index 0000000..6ce2487
--- /dev/null
+++ b/cmd/server.go
@@ -0,0 +1,119 @@
+package cmd
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"github.com/gammazero/workerpool"
+	"github.com/spf13/cobra"
+	"go.uber.org/zap"
+	"os"
+	"os/signal"
+	"syscall"
+	"taskExecute/config"
+	"taskExecute/mq"
+	"taskExecute/pkg/docker"
+	"taskExecute/pkg/utils"
+
+	"git.hpds.cc/Component/logging"
+)
+
+var (
+	ConfigFileFlag string = "./config/config.yaml"
+)
+
+func must(err error) {
+	if err != nil {
+		_, _ = fmt.Fprint(os.Stderr, err)
+		os.Exit(1)
+	}
+}
+
+func NewStartCmd() *cobra.Command {
+	cmd := &cobra.Command{
+		Use:   "start",
+		Short: "Start hpds_task_execute application",
+		Run: func(cmd *cobra.Command, args []string) {
+			var (
+				cfg *config.TaskExecutorConfig
+				err error
+			)
+			ctx, cancel := context.WithCancel(context.Background())
+			defer cancel()
+
+			configFileFlag, err := cmd.Flags().GetString("c")
+			if err != nil {
+				fmt.Println("get local config err: ", err)
+				return
+			}
+
+			cfg, err = config.ParseConfigByFile(configFileFlag)
+			must(err)
+			logger := LoadLoggerConfig(cfg.Logging)
+			config.Cfg = cfg
+			if len(cfg.TmpPath) > 0 {
+				_ = os.MkdirAll(cfg.TmpPath, 0755)
+			}
+			config.Logger = logger
+			// initialize the local container configuration list
+			docker.ContainerList = make([]docker.ContainerStore, 0)
+			if utils.PathExists(cfg.Store) {
+				store, err := os.ReadFile(cfg.Store)
+				must(err)
+				err = json.Unmarshal(store, &docker.ContainerList)
+				must(err)
+			} else {
+				f, _ := os.Create(cfg.Store)
+				defer func() {
+					_ = f.Close()
+				}()
+			}
+
+			exitChannel := make(chan os.Signal, 1)
+			defer close(exitChannel)
+			// listen for exit signals and persist the container store on shutdown
+			signal.Notify(exitChannel, syscall.SIGINT, syscall.SIGTERM)
+			defer func() {
+				docker.SaveStore()
+			}()
+			// create the message-queue endpoints
+			mq.MqList, err = mq.NewMqClient(cfg.Functions, cfg.Node, logger)
+			must(err)
+			// task worker pool
+			config.WorkPool = workerpool.New(cfg.TaskPoolCount)
+
+			for {
+				select {
+				case <-ctx.Done():
+					logger.With(
+						zap.String("web", "exit"),
+					).Error(ctx.Err().Error())
+					return
+				case errs := <-exitChannel:
+					logger.With(
+						zap.String("web", "服务退出"),
+					).Info(errs.String())
+					return
+				}
+			}
+		},
+	}
+	cmd.Flags().StringVar(&ConfigFileFlag, "c", "./config/config.yaml", "The configuration file path")
+	return cmd
+}
+
+func LoadLoggerConfig(opt config.LogOptions) *logging.Logger {
+	return logging.NewLogger(
+		logging.SetPath(opt.Path),
+		logging.SetPrefix(opt.Prefix),
+		logging.SetDevelopment(opt.Development),
+		logging.SetDebugFileSuffix(opt.DebugFileSuffix),
+		logging.SetWarnFileSuffix(opt.WarnFileSuffix),
+		logging.SetErrorFileSuffix(opt.ErrorFileSuffix),
+		logging.SetInfoFileSuffix(opt.InfoFileSuffix),
+		logging.SetMaxAge(opt.MaxAge),
+		logging.SetMaxBackups(opt.MaxBackups),
+		logging.SetMaxSize(opt.MaxSize),
+		logging.SetLevel(logging.LogLevel["debug"]),
+	)
+}
diff --git a/config/config.go b/config/config.go
new file mode 100644
index 0000000..d0ded30
--- /dev/null
+++ b/config/config.go
@@ -0,0 +1,75 @@
+package config
+
+import (
+	"git.hpds.cc/Component/logging"
+	"github.com/gammazero/workerpool"
+	"gopkg.in/yaml.v3"
+	"os"
+)
+
+var (
+	Cfg      *TaskExecutorConfig
+	Logger   *logging.Logger
+	WorkPool *workerpool.WorkerPool
+)
+
+type TaskExecutorConfig struct {
+	Name          string      `yaml:"name,omitempty"`
+	Mode          string      `yaml:"mode,omitempty"`
+	TmpPath       string      `yaml:"tmpPath"`
+	Store         string      `yaml:"store"`
+	TaskPoolCount int 
`yaml:"taskPoolCount"` + Logging LogOptions `yaml:"logging"` + Minio MinioConfig `yaml:"minio"` + Node HpdsNode `yaml:"node,omitempty"` + Functions []FuncConfig `yaml:"functions,omitempty"` +} + +type MinioConfig struct { + Protocol string `yaml:"protocol"` //http or https + Endpoint string `yaml:"endpoint"` + AccessKeyId string `yaml:"accessKeyId"` + SecretAccessKey string `yaml:"secretAccessKey"` + Bucket string `yaml:"bucket"` +} + +type FuncConfig struct { + Name string `yaml:"name"` + DataTag uint8 `yaml:"dataTag"` + MqType uint `yaml:"mqType"` //消息类型, 发布,1;订阅;2 +} + +type HpdsNode struct { + Host string `yaml:"host"` + Port int `yaml:"port"` + Token string `yaml:"token,omitempty"` +} + +type LogOptions struct { + Path string `yaml:"path" json:"path" toml:"path"` // 文件保存地方 + Prefix string `yaml:"prefix" json:"prefix" toml:"prefix"` // 日志文件前缀 + ErrorFileSuffix string `yaml:"errorFileSuffix" json:"errorFileSuffix" toml:"errorFileSuffix"` // error日志文件后缀 + WarnFileSuffix string `yaml:"warnFileSuffix" json:"warnFileSuffix" toml:"warnFileSuffix"` // warn日志文件后缀 + InfoFileSuffix string `yaml:"infoFileSuffix" json:"infoFileSuffix" toml:"infoFileSuffix"` // info日志文件后缀 + DebugFileSuffix string `yaml:"debugFileSuffix" json:"debugFileSuffix" toml:"debugFileSuffix"` // debug日志文件后缀 + Level string `yaml:"level" json:"level" toml:"level"` // 日志等级 + MaxSize int `yaml:"maxSize" json:"maxSize" toml:"maxSize"` // 日志文件大小(M) + MaxBackups int `yaml:"maxBackups" json:"maxBackups" toml:"maxBackups"` // 最多存在多少个切片文件 + MaxAge int `yaml:"maxAge" json:"maxAge" toml:"maxAge"` // 保存的最大天数 + Development bool `yaml:"development" json:"development" toml:"development"` // 是否是开发模式 +} + +func ParseConfigByFile(path string) (cfg *TaskExecutorConfig, err error) { + buffer, err := os.ReadFile(path) + if err != nil { + return nil, err + } + return load(buffer) +} + +func load(buf []byte) (cfg *TaskExecutorConfig, err error) { + cfg = new(TaskExecutorConfig) + cfg.Functions = make([]FuncConfig, 0) + err = yaml.Unmarshal(buf, cfg) + return +} diff --git a/config/config.yaml b/config/config.yaml new file mode 100644 index 0000000..e54b29c --- /dev/null +++ b/config/config.yaml @@ -0,0 +1,32 @@ +name: task-execute +mode: dev +tmpPath : ./tmp +store: ./config/store.json +taskPoolCount: 1 +logging: + path: ./logs + prefix: hpds-task-execute + errorFileSuffix: error.log + warnFileSuffix: warn.log + infoFileSuffix: info.log + debugFileSuffix: debug.log + maxSize: 100 + maxBackups: 3000 + maxAge: 30 + development: true +minio: + protocol: http + endpoint: 127.0.0.1:9000 + accessKeyId: root + secretAccessKey: OIxv7QptYBO3 +node: + host: 127.0.0.1 + port: 27188 + token: 06d36c6f5705507dae778fdce90d0767 +functions: + - name: task-response + dataTag: 14 + mqType: 1 + - name: task-execute + dataTag: 16 + mqType: 2 \ No newline at end of file diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..185364f --- /dev/null +++ b/go.mod @@ -0,0 +1,75 @@ +module taskExecute + +go 1.19 + +require ( + git.hpds.cc/Component/logging v0.0.0-20230106105738-e378e873921b + git.hpds.cc/Component/network v0.0.0-20221012021659-2433c68452d5 + git.hpds.cc/pavement/hpds_node v0.0.0-20230307094826-753c4fe9c877 + github.com/docker/docker v23.0.1+incompatible + github.com/docker/go-connections v0.4.0 + github.com/emirpasic/gods v1.18.1 + github.com/fsnotify/fsnotify v1.4.9 + github.com/gammazero/workerpool v1.1.3 + github.com/minio/minio-go/v7 v7.0.49 + github.com/shirou/gopsutil/v3 v3.23.2 + github.com/spf13/cobra v1.6.1 + go.uber.org/zap v1.24.0 + 
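+	// golang.org/x/image is presumably pulled in by the image helpers in pkg/utils (crop, grayscale, rotate, format conversion)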
golang.org/x/image v0.0.0-20191009234506-e7c1f5e7dbb8 + gopkg.in/yaml.v3 v3.0.1 +) + +require ( + git.hpds.cc/Component/mq_coder v0.0.0-20221010064749-174ae7ae3340 // indirect + github.com/Microsoft/go-winio v0.6.0 // indirect + github.com/docker/distribution v2.8.1+incompatible // indirect + github.com/docker/go-units v0.5.0 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/gammazero/deque v0.2.0 // indirect + github.com/go-ole/go-ole v1.2.6 // indirect + github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect + github.com/gogo/protobuf v1.3.2 // indirect + github.com/golang/mock v1.6.0 // indirect + github.com/google/uuid v1.3.0 // indirect + github.com/inconshreveable/mousetrap v1.0.1 // indirect + github.com/json-iterator/go v1.1.12 // indirect + github.com/klauspost/compress v1.15.15 // indirect + github.com/klauspost/cpuid/v2 v2.2.3 // indirect + github.com/lucas-clemente/quic-go v0.29.1 // indirect + github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect + github.com/marten-seemann/qtls-go1-18 v0.1.2 // indirect + github.com/marten-seemann/qtls-go1-19 v0.1.0 // indirect + github.com/matoous/go-nanoid/v2 v2.0.0 // indirect + github.com/minio/md5-simd v1.1.2 // indirect + github.com/minio/sha256-simd v1.0.0 // indirect + github.com/moby/term v0.0.0-20221205130635-1aeaba878587 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/morikuni/aec v1.0.0 // indirect + github.com/nxadm/tail v1.4.8 // indirect + github.com/onsi/ginkgo v1.16.4 // indirect + github.com/opencontainers/go-digest v1.0.0 // indirect + github.com/opencontainers/image-spec v1.0.2 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect + github.com/rs/xid v1.4.0 // indirect + github.com/sirupsen/logrus v1.9.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect + github.com/tklauser/go-sysconf v0.3.11 // indirect + github.com/tklauser/numcpus v0.6.0 // indirect + github.com/yusufpapurcu/wmi v1.2.2 // indirect + go.uber.org/atomic v1.7.0 // indirect + go.uber.org/multierr v1.6.0 // indirect + golang.org/x/crypto v0.6.0 // indirect + golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e // indirect + golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect + golang.org/x/net v0.7.0 // indirect + golang.org/x/sys v0.5.0 // indirect + golang.org/x/text v0.7.0 // indirect + golang.org/x/time v0.3.0 // indirect + golang.org/x/tools v0.1.12 // indirect + gopkg.in/ini.v1 v1.67.0 // indirect + gopkg.in/natefinch/lumberjack.v2 v2.0.0 // indirect + gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect + gotest.tools/v3 v3.4.0 // indirect +) diff --git a/main.go b/main.go new file mode 100644 index 0000000..3c1e41c --- /dev/null +++ b/main.go @@ -0,0 +1,26 @@ +package main + +import ( + "fmt" + "github.com/spf13/cobra" + "os" + "taskExecute/cmd" +) + +var ( + rootCmd = &cobra.Command{ + Use: "hpds_task_execute", + Long: "hpds_task_execute is a task execute", + Version: "0.1", + } +) + +func init() { + rootCmd.AddCommand(cmd.NewStartCmd()) +} +func main() { + if err := rootCmd.Execute(); err != nil { + _, _ = fmt.Fprint(os.Stderr, err.Error()) + os.Exit(1) + } +} diff --git a/mq/handler.go b/mq/handler.go new file mode 100644 index 0000000..831397b --- /dev/null +++ b/mq/handler.go @@ -0,0 +1,401 @@ +package mq + +import ( + "bufio" + "context" + 
"encoding/base64" + "encoding/json" + "fmt" + "git.hpds.cc/pavement/hpds_node" + "github.com/emirpasic/gods/lists/arraylist" + "github.com/fsnotify/fsnotify" + "github.com/gammazero/workerpool" + "github.com/minio/minio-go/v7" + "github.com/minio/minio-go/v7/pkg/credentials" + "github.com/shirou/gopsutil/v3/host" + "go.uber.org/zap" + "io" + "net/http" + "os" + "os/exec" + "path" + "strconv" + "strings" + "sync" + "taskExecute/config" + "taskExecute/pkg/compress" + "taskExecute/pkg/docker" + "taskExecute/pkg/download" + "taskExecute/pkg/utils" + "taskExecute/proto" + "time" +) + +var ( + wg sync.WaitGroup + TaskList map[string]docker.ContainerStore +) + +func TaskExecuteHandler(data []byte) (byte, []byte) { + fmt.Println("接收数据", string(data)) + cmd := new(InstructionReq) + err := json.Unmarshal(data, cmd) + if err != nil { + return 0x0B, []byte(err.Error()) + } + switch cmd.Command { + case TaskExecute: + //任务执行 + waitWorkerStartFinish(config.WorkPool, cmd.Payload.(map[string]interface{}), ModelTaskExecuteHandler) + case ModelIssueRepeater: + //模型下发 + waitWorkerStartFinish(config.WorkPool, cmd.Payload.(map[string]interface{}), ModelIssueRepeaterHandler) + } + return byte(cmd.Command), nil +} + +func waitWorkerStartFinish(wp *workerpool.WorkerPool, payload map[string]interface{}, f func(payload map[string]interface{})) { + startStop := make(chan time.Time, 2) + wp.Submit(func() { + startStop <- time.Now() + f(payload) + startStop <- time.Now() + }) + fmt.Println("Task started at:", <-startStop) + fmt.Println("Task finished at:", <-startStop) +} + +// execCommand 执行命令 +func execCommandWait(commandName string, params []string) bool { + cmd := exec.Command(commandName, params...) + + //显示运行的命令 + fmt.Println(cmd.Args) + + stdout, err := cmd.StdoutPipe() + + if err != nil { + fmt.Println(err) + return false + } + + _ = cmd.Start() + + reader := bufio.NewReader(stdout) + + //实时循环读取输出流中的一行内容 + for { + wg.Add(1) + line, err2 := reader.ReadString('\n') + if err2 != nil || io.EOF == err2 { + break + } + config.Logger.Info("执行命令", + zap.String("execCommandWait", line)) + } + + _ = cmd.Wait() + wg.Done() + return true +} + +func ModelIssueRepeaterHandler(payload map[string]interface{}) { + hi, _ := host.Info() + if payload["nodeGuid"].(string) == hi.HostID { + fileUrl := payload["dockerFile"].(string) + modelVersion := payload["modelVersion"].(string) + downFileName := path.Base(fileUrl) + //判断文件后缀名 + fileType := path.Ext(downFileName) + fileNameOnly := strings.TrimSuffix(downFileName, fileType) + dFile := path.Join(config.Cfg.TmpPath, fileNameOnly, downFileName) + //执行文件下载 + controller := download.ThreadController{ + ThreadCount: download.ThreadCount, + FileUrl: fileUrl, + DownloadFolder: dFile, + DownloadFileName: downFileName, + Logger: config.Logger, + } + controller.Download(download.OneThreadDownloadSize) + if strings.ToLower(fileType) == ".zip" { + err := compress.UnzipFromFile(path.Join(config.Cfg.TmpPath, fileNameOnly), dFile) + if err != nil { + controller.Logger.With(zap.String("文件解压缩", path.Join(config.Cfg.TmpPath, downFileName))). + Error("发生错误", zap.Error(err)) + return + } + dFile = path.Join(config.Cfg.TmpPath, fileNameOnly, fileNameOnly+".tar") + } + //docker 导入并运行 + imgName := fmt.Sprintf("%s:%s", fileNameOnly, modelVersion) + if strings.ToLower(path.Ext(dFile)) == ".tar" { + dCli := docker.NewDockerClient() + err := dCli.ImportImage(imgName, "latest", dFile) + //err = dCli.LoadImage(dFile) + if err != nil { + controller.Logger.With(zap.String("导入docker的文件", dFile)). 
+ Error("发生错误", zap.Error(err)) + } + //设置data目录 + dataPath := path.Join(config.Cfg.TmpPath, fileNameOnly, "data") + _ = os.MkdirAll(dataPath, os.ModePerm) + vol := make(map[string]string) + vol[path.Join(dataPath, payload["inPath"].(string))] = payload["inPath"].(string) + vol[path.Join(dataPath, payload["outPath"].(string))] = payload["outPath"].(string) + //docker运行 + modelCommand := strings.Split(payload["modelCommand"].(string), " ") + + dstPort := dCli.CreateContainer(fileNameOnly, imgName, modelCommand, vol, strconv.Itoa(payload["mappedPort"].(int))) + //保存到本地临时文件 + item := docker.ContainerStore{ + ModelId: payload["modelId"].(int64), + NodeId: payload["nodeId"].(int64), + Name: fileNameOnly, + ImgName: imgName, + Volumes: vol, + SrcPort: strconv.Itoa(payload["mappedPort"].(int)), + DstPort: dstPort, + Command: modelCommand, + HttpUrl: payload["httpUrl"].(string), + } + docker.ContainerList = append(docker.ContainerList, item) + docker.SaveStore() + cli := GetMqClient("task-response", 1) + ap := cli.EndPoint.(hpds_node.AccessPoint) + res := new(InstructionReq) + res.Command = ModelIssueResponse + res.Payload = item + pData, _ := json.Marshal(res) + _ = GenerateAndSendData(ap, pData) + } + } +} + +func ModelTaskExecuteHandler(payload map[string]interface{}) { + hi, _ := host.Info() + if payload["nodeGuid"] == hi.HostID { + if len(payload["subDataset"].(string)) > 0 { + sf := hpds_node.NewStreamFunction( + payload["subDataset"].(string), + hpds_node.WithMqAddr(fmt.Sprintf("%s:%d", config.Cfg.Node.Host, config.Cfg.Node.Port)), + hpds_node.WithObserveDataTags(payload["subDataTag"].(byte)), + hpds_node.WithCredential(config.Cfg.Node.Token), + ) + err := sf.Connect() + must(config.Logger, err) + nodeInfo := HpdsMqNode{ + MqType: 2, + Topic: payload["subDataset"].(string), + Node: config.Cfg.Node, + EndPoint: sf, + } + _ = sf.SetHandler(func(data []byte) (byte, []byte) { + + //查询docker是否已经开启 + issue := new(docker.ContainerStore) + _ = json.Unmarshal([]byte(payload["issueResult"].(string)), issue) + dCli := docker.NewDockerClient() + cList, err := dCli.SearchImage(issue.Name) + if err != nil { + + } + if len(cList) > 0 { + if len(payload["workflow"].(string)) > 0 { + //是否设置工作流程 + wf := new(Workflow) + err = json.Unmarshal([]byte(payload["workflow"].(string)), wf) + if err != nil { + + } + if len(payload["datasetPath"].(string)) > 0 { + //数据集处理 + opt := &minio.Options{ + Creds: credentials.NewStaticV4(config.Cfg.Minio.AccessKeyId, config.Cfg.Minio.SecretAccessKey, ""), + Secure: false, + } + cli, _ := minio.New(config.Cfg.Minio.Endpoint, opt) + doneCh := make(chan struct{}) + defer close(doneCh) + objectCh := cli.ListObjects(context.Background(), config.Cfg.Minio.Bucket, minio.ListObjectsOptions{ + Prefix: payload["datasetPath"].(string), + Recursive: true, + }) + for object := range objectCh { + file, _ := cli.GetObject(context.Background(), config.Cfg.Minio.Bucket, object.Key, minio.GetObjectOptions{}) + imgByte, _ := io.ReadAll(file) + + f := proto.FileCapture{ + FileName: object.Key, + File: base64.StdEncoding.EncodeToString(imgByte), + DatasetName: payload["datasetName"].(string), + CaptureTime: object.LastModified.Unix(), + } + ProcessWorkflow(payload, f, wf) + } + + } + } else { + f := new(proto.FileCapture) + err := json.Unmarshal(data, f) + if err != nil { + + } + if len(f.File) > 0 { + + i := strings.Index(f.File, ",") + dec := base64.NewDecoder(base64.StdEncoding, strings.NewReader(f.File[i+1:])) + if len(payload["httpUrl"].(string)) > 0 { + _ = 
os.MkdirAll(path.Join(config.Cfg.TmpPath, payload["subDataset"].(string)), os.ModePerm) + tmpFile, _ := os.Create(path.Join(config.Cfg.TmpPath, payload["subDataset"].(string), f.FileName)) + _, err = io.Copy(tmpFile, dec) + + reqUrl := fmt.Sprintf("http://localhost:%s/%s", issue.DstPort, issue.HttpUrl) + response, err := http.Post(reqUrl, "multipart/form-data", dec) + if err != nil { + config.Logger.With(zap.String("源文件名", f.FileName)). + With(zap.String("临时文件名", path.Join(config.Cfg.TmpPath, payload["subDataset"].(string), f.FileName))). + Error("文件提交", zap.Error(err)) + } + defer func() { + _ = response.Body.Close() + config.Logger.With(zap.String("源文件名", f.FileName)). + With(zap.String("临时文件名", path.Join(config.Cfg.TmpPath, payload["subDataset"].(string), f.FileName))). + Info("模型识别") + }() + body, err := io.ReadAll(response.Body) + if err != nil { + config.Logger.With(zap.String("源文件名", f.FileName)). + With(zap.String("临时文件名", path.Join(config.Cfg.TmpPath, payload["subDataset"].(string), f.FileName))). + Error("模型识别", zap.Error(err)) + } + cli := GetMqClient("task-response", 1) + ap := cli.EndPoint.(hpds_node.AccessPoint) + res := new(InstructionReq) + res.Command = TaskResponse + res.Payload = body + pData, _ := json.Marshal(res) + _ = GenerateAndSendData(ap, pData) + } + if len(payload["inPath"].(string)) > 0 { + outPath := "" + for k, v := range issue.Volumes { + if v == payload["outPath"].(string) { + outPath = k + break + } + } + //创建一个监控对象 + watch, err := fsnotify.NewWatcher() + if err != nil { + config.Logger.Error("创建文件监控", zap.Error(err)) + } + defer func(watch *fsnotify.Watcher) { + _ = watch.Close() + }(watch) + + err = watch.Add(outPath) + if err != nil { + config.Logger.With(zap.String("监控目录", outPath)). + Error("创建文件监控", zap.Error(err)) + } + for k, v := range issue.Volumes { + if v == payload["inPath"].(string) { + _ = os.MkdirAll(k, os.ModePerm) + tmpFile, _ := os.Create(path.Join(k, f.FileName)) + _, err = io.Copy(tmpFile, dec) + break + } + } + list := arraylist.New() // empty + t1 := time.NewTicker(1 * time.Second) + go func() { + for { + select { + case ev := <-watch.Events: + { + //判断事件发生的类型,如下5种 + // Create 创建 + // Write 写入 + // Remove 删除 + // Rename 重命名 + // Chmod 修改权限 + if ev.Op&fsnotify.Create == fsnotify.Create { + config.Logger.Info("创建文件", zap.String("文件名", ev.Name)) + list.Add(ev.Name) + } + } + case <-t1.C: + { + if list.Size() > 0 { + returnFileHandleResult(list, payload, issue) + } + } + case err = <-watch.Errors: + { + config.Logger.With(zap.String("监控目录", outPath)). 
+ Error("文件监控", zap.Error(err)) + return + } + } + } + }() + } + } + } + } + return payload["subDataTag"].(byte), nil + }) + MqList = append(MqList, nodeInfo) + } + } +} +func returnFileHandleResult(list *arraylist.List, payload map[string]interface{}, issue *docker.ContainerStore) { + var ( + mu sync.RWMutex + wgp sync.WaitGroup + resTime time.Duration + ) + mu.Lock() + defer mu.Unlock() + startTime := time.Now() + for i := 0; i < list.Size(); i++ { + if fn, ok := list.Get(0); ok { + if utils.PathExists(fn.(string)) { + wgp.Add(1) + go func() { + mr := new(proto.ModelResult) + src := utils.ReadFile(fn.(string)) + + if src != nil { + mr.File = base64.StdEncoding.EncodeToString(src) + mr.TaskCode = utils.GetFileName(fn.(string)) + mr.TaskId = int64(payload["taskId"].(float64)) + mr.FileName = utils.GetFileNameAndExt(fn.(string)) + mr.DatasetName = payload["datasetName"].(string) + mr.SubDataset = payload["subDataset"].(string) + mr.FileMd5 = utils.GetFileMd5(src) + mr.ModelId = int64(payload["modelId"].(float64)) + mr.NodeId = int64(payload["nodeId"].(float64)) + mr.StartTime = startTime.Unix() + mr.FinishTime = time.Now().Unix() + cli := GetMqClient("task-response", 1) + ap := cli.EndPoint.(hpds_node.AccessPoint) + res := new(InstructionReq) + res.Command = TaskResponse + res.Payload = mr + pData, _ := json.Marshal(res) + _ = GenerateAndSendData(ap, pData) + } + wg.Done() + }() + wg.Wait() + resTime = time.Since(startTime) + config.Logger.Info("返回任务完成", + zap.String("文件名", fn.(string)), + zap.Duration("运行时间", resTime), + ) + } + } + } +} diff --git a/mq/index.go b/mq/index.go new file mode 100644 index 0000000..b86e26f --- /dev/null +++ b/mq/index.go @@ -0,0 +1,99 @@ +package mq + +import ( + "fmt" + "git.hpds.cc/Component/logging" + "go.uber.org/zap" + "os" + "taskExecute/config" + "time" + + "git.hpds.cc/pavement/hpds_node" +) + +var MqList []HpdsMqNode + +type HpdsMqNode struct { + MqType uint + Topic string + Node config.HpdsNode + EndPoint interface{} + Logger *logging.Logger +} + +func must(logger *logging.Logger, err error) { + if err != nil { + if logger != nil { + logger.With(zap.String("任务执行节点", "错误信息")).Error("启动错误", zap.Error(err)) + } else { + _, _ = fmt.Fprint(os.Stderr, err) + } + os.Exit(1) + } +} + +func NewMqClient(funcs []config.FuncConfig, node config.HpdsNode, logger *logging.Logger) (mqList []HpdsMqNode, err error) { + mqList = make([]HpdsMqNode, 0) + for _, v := range funcs { + switch v.MqType { + case 2: + sf := hpds_node.NewStreamFunction( + v.Name, + hpds_node.WithMqAddr(fmt.Sprintf("%s:%d", node.Host, node.Port)), + hpds_node.WithObserveDataTags(v.DataTag), + hpds_node.WithCredential(node.Token), + ) + err = sf.Connect() + must(logger, err) + nodeInfo := HpdsMqNode{ + MqType: 2, + Topic: v.Name, + Node: node, + EndPoint: sf, + } + switch v.Name { + case "task-execute": + _ = sf.SetHandler(TaskExecuteHandler) + default: + + } + mqList = append(mqList, nodeInfo) + default: + ap := hpds_node.NewAccessPoint( + v.Name, + hpds_node.WithMqAddr(fmt.Sprintf("%s:%d", node.Host, node.Port)), + hpds_node.WithCredential(node.Token), + ) + err = ap.Connect() + nodeInfo := HpdsMqNode{ + MqType: 1, + Topic: v.Name, + Node: node, + EndPoint: ap, + } + must(logger, err) + ap.SetDataTag(v.DataTag) + mqList = append(mqList, nodeInfo) + } + + } + return mqList, err +} + +func GetMqClient(topic string, mqType uint) *HpdsMqNode { + for _, v := range MqList { + if v.Topic == topic && v.MqType == mqType { + return &v + } + } + return nil +} + +func GenerateAndSendData(stream 
hpds_node.AccessPoint, data []byte) error { + _, err := stream.Write(data) + if err != nil { + return err + } + time.Sleep(1000 * time.Millisecond) + return nil +} diff --git a/mq/instruction.go b/mq/instruction.go new file mode 100644 index 0000000..a3ee695 --- /dev/null +++ b/mq/instruction.go @@ -0,0 +1,27 @@ +package mq + +const ( + TaskAdd = iota + 1 + ModelIssue + TaskExecute + TaskResponse + ModelIssueRepeater + ModelIssueResponse +) + +type InstructionReq struct { + Command int `json:"command"` + Payload interface{} `json:"payload"` +} + +type TaskResponseBody struct { + Code int `json:"code"` + TaskId int64 `json:"taskId"` + TaskCode string `json:"taskCode"` + NodeId int64 `json:"nodeId"` + ModelId int64 `json:"modelId"` + StartTime int64 `json:"startTime"` + FinishTime int64 `json:"finishTime"` + Msg string `json:"msg"` + Body string `json:"body"` +} diff --git a/mq/queue.go b/mq/queue.go new file mode 100644 index 0000000..794a9d6 --- /dev/null +++ b/mq/queue.go @@ -0,0 +1,77 @@ +package mq + +import "fmt" + +type WorkflowQueue struct { + buff []string //队列的的数据存储在数组上 + maxsize int //队列最大容量 + front int //队列头索引,不包括自己(队列头索引值-1) + rear int //队列尾索引 +} + +func NewQueue(size int) *WorkflowQueue { + return &WorkflowQueue{ + buff: make([]string, 0, size), + maxsize: 5, + front: -1, + rear: -1, + } +} + +// Push +// @Description: 压入队列 +// @Author: maxwell.ke +// @time 2022-10-25 22:58:58 +// @receiver q +// @param n +// @return error +func (q *WorkflowQueue) Push(id string) error { + if q.rear == q.maxsize-1 { + if q.front == -1 { //头尾都到头了 + return fmt.Errorf("队列已满,PUSH失败") + } else { + q.front = -1 + q.rear = len(q.buff) - 1 + } + } + q.rear++ + q.buff = append(q.buff, id) + return nil +} + +// Pop +// @Description: 出队列 +// @Author: maxwell.ke +// @time 2022-10-25 23:14:20 +// @receiver q +// @return n +// @return err +func (q *WorkflowQueue) Pop() (id string, err error) { + if len(q.buff) == 0 { + return "", fmt.Errorf("空队列,POP失败") + } + id = q.buff[0] + q.buff = q.buff[1:] + q.front++ + return id, nil +} + +// List +// @Description: 队列遍历 +// @Author: maxwell.ke +// @time 2022-10-25 23:13:10 +// @receiver q +// @return error +func (q *WorkflowQueue) List() error { + if len(q.buff) == 0 { + return fmt.Errorf("空队列") + } + for i := 0; i < q.maxsize; i++ { + if i > q.front && i <= q.rear { + fmt.Println(q.buff[i-q.front-1]) + } else { + return fmt.Errorf("空队列") + } + } + return nil +} diff --git a/mq/service.go b/mq/service.go new file mode 100644 index 0000000..f757970 --- /dev/null +++ b/mq/service.go @@ -0,0 +1,195 @@ +package mq + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "git.hpds.cc/pavement/hpds_node" + "image" + "strings" + "taskExecute/pkg/utils" + "taskExecute/proto" + "time" +) + +func CreateWorkflowQueue(wf *Workflow) *WorkflowQueue { + nodeId := "" + qList := NewQueue(len(wf.Nodes)) + for i := 0; i < len(wf.Nodes); i++ { + node := GetNextNode(wf, nodeId) + _ = qList.Push(node.Id) + nodeId = node.Id + } + return qList + //switch node.Type { + //case "start-node": + // node = GetNextNode(wf, node.Id) + //case "image-node": + // //处理图像后 + // img, _ := ProcessWorkflowNode(node, payload, fc) + // payload["resImage"] = img + // nextNode := GetNextNode(wf, node.Id) + // + //case "fetch-node": + //case "model-node": + //case "mq-node": + //default: + // + //} +} +func ProcessWorkflow(payload map[string]interface{}, fc proto.FileCapture, wf *Workflow) { + qList := CreateWorkflowQueue(wf) + var ( + img image.Image + //imgBase64 string + imgType string = "jpeg" + err error 
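+		// resultData holds the body returned by a fetch-node HTTP call; it is copied into the task response at ReturnPoint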
+ resultData string + ) + startTime := time.Now().Unix() + for i := 0; i < len(wf.Nodes); i++ { + nodeId, _ := qList.Pop() + node := GetWorkflowNodeById(wf, nodeId) + switch node.Type { + case "start-node": + continue + case "image-node": + //处理图像后 + fn, _ := base64.StdEncoding.DecodeString(fc.File) + if node.Properties.NodeData.Method == "crop" { + img, imgType, err = utils.Clip(fn, node.Properties.NodeData.Width, node.Properties.NodeData.Height, node.Properties.NodeData.EqualProportion) + if err != nil { + goto ReturnPoint + } + } else if node.Properties.NodeData.Method == "gray" { + img, err = utils.Gray(fn) + if err != nil { + goto ReturnPoint + } + } else if node.Properties.NodeData.Method == "rotate" { + switch node.Properties.NodeData.RotationAngle { + case 90: + img = utils.Rotate90(fn) + case 180: + img = utils.Rotate180(fn) + case 270: + img = utils.Rotate270(fn) + default: + img = utils.BuffToImage(fn) + } + } else if node.Properties.NodeData.Method == "formatConversion" { + img = utils.BuffToImage(fn) + switch node.Properties.NodeData.Format { + case "bmp": + imgType = "bmp" + case "png": + imgType = "png" + case "tiff": + imgType = "tiff" + default: + imgType = "jpeg" + } + } + case "fetch-node": + header := make(map[string]string) + header["ContentType"] = node.Properties.NodeData.ContentType + param := make(map[string]string) + isBody := false + for _, val := range node.Properties.NodeData.DynamicValidateForm.Fields { + switch val.Type { + case "fileName": + param[val.Key] = fc.FileName + case "imgBase64": + param[val.Key] = utils.ImageToBase64(img, imgType) + default: + isBody = true + } + } + if !isBody { + data, err := utils.HttpDo(fmt.Sprintf("%s%s", node.Properties.NodeData.Proto, node.Properties.NodeData.Url), + strings.ToUpper(node.Properties.NodeData.MethodType), param, header) + if err != nil { + goto ReturnPoint + } + resultData = string(data) + } else { + buff := utils.ImageToBuff(img, imgType) + files := make([]utils.UploadFile, 1) + files[0] = utils.UploadFile{ + Name: "file", + Filepath: "./output.jpg", + File: buff, + } + data := utils.PostFile(fmt.Sprintf("%s%s", node.Properties.NodeData.Proto, node.Properties.NodeData.Url), + param, "multipart/form-data", files, header) + resultData = data + } + + case "model-node": + continue + case "mq-node": + continue + default: + continue + } + } +ReturnPoint: + item := new(TaskResponseBody) + item.TaskId = int64(payload["taskId"].(float64)) + item.TaskCode = payload["taskCode"].(string) + item.NodeId = int64(payload["nodeId"].(float64)) + item.ModelId = int64(payload["modelId"].(float64)) + item.StartTime = startTime + item.FinishTime = time.Now().Unix() + if err != nil { + item.Code = 500 + item.Msg = fmt.Sprintf("执行任务:%s", err.Error()) + } else { + item.Code = 0 + item.Msg = "执行成功" + item.Body = resultData + } + cli := GetMqClient("task-response", 1) + ap := cli.EndPoint.(hpds_node.AccessPoint) + res := new(InstructionReq) + res.Command = TaskResponse + res.Payload = item + pData, _ := json.Marshal(res) + _ = GenerateAndSendData(ap, pData) +} + +func GetNextNode(wf *Workflow, currNodeId string) (node *WorkflowNode) { + var nextId string + if len(currNodeId) > 0 { + //下一节点 + for _, v := range wf.Edges { + if v.SourceNodeId == currNodeId { + nextId = v.TargetNodeId + } + } + } else { + //开始节点 + for _, v := range wf.Nodes { + if v.Type == "start-node" { + return &v + } + } + } + if len(nextId) > 0 { + for _, v := range wf.Nodes { + if v.Id == nextId { + return &v + } + } + } + return nil +} + +func 
GetWorkflowNodeById(wf *Workflow, id string) (node *WorkflowNode) { + for _, v := range wf.Nodes { + if v.Id == id { + return &v + } + } + return nil +} diff --git a/mq/workflow.go b/mq/workflow.go new file mode 100644 index 0000000..c6709f2 --- /dev/null +++ b/mq/workflow.go @@ -0,0 +1,73 @@ +package mq + +type Workflow struct { + Nodes []WorkflowNode `json:"nodes"` + Edges []WorkflowEdge `json:"edges"` +} + +type WorkflowNode struct { + Id string `json:"id"` + Type string `json:"type"` + X int `json:"x"` + Y int `json:"y"` + Properties NodeProperties `json:"properties,omitempty"` + Text NodePropertiesText `json:"text,omitempty"` +} + +type NodeProperties struct { + Ui string `json:"ui"` + Id string `json:"id,omitempty"` + Type string `json:"type,omitempty"` + X int `json:"x,omitempty"` + Y int `json:"y,omitempty"` + Text NodePropertiesText `json:"text,omitempty"` + NodeData NodeData `json:"nodeData,omitempty"` +} + +type NodePropertiesText struct { + X int `json:"x"` + Y int `json:"y"` + Value string `json:"value"` +} + +type NodeData struct { + Method string `json:"method,omitempty"` + Width int `json:"width,omitempty"` + Height int `json:"height,omitempty"` + EqualProportion bool `json:"equalProportion,omitempty"` + RotationAngle int `json:"rotationAngle,omitempty"` + Format string `json:"format,omitempty"` + ResultFormat string `json:"resultFormat,omitempty"` + Proto string `json:"proto,omitempty"` + Url string `json:"url,omitempty"` + MethodType string `json:"methodType,omitempty"` + ContentType string `json:"contentType,omitempty"` + DynamicValidateForm DynamicForm `json:"dynamicValidateForm,omitempty"` + ResultData interface{} `json:"resultData,omitempty"` + Topic string `json:"topic,omitempty"` +} + +type DynamicForm struct { + Fields []RequestField `json:"fields"` +} + +type RequestField struct { + Key string `json:"key"` + Type string `json:"type"` + Id int64 `json:"id"` +} +type WorkflowEdge struct { + Id string `json:"id"` + Type string `json:"type"` + SourceNodeId string `json:"sourceNodeId"` + TargetNodeId string `json:"targetNodeId"` + StartPoint Point `json:"startPoint"` + EndPoint Point `json:"endPoint"` + Properties interface{} `json:"properties"` + PointsList []Point `json:"pointsList"` +} + +type Point struct { + X int `json:"x"` + Y int `json:"y"` +} diff --git a/pkg/compress/index.go b/pkg/compress/index.go new file mode 100644 index 0000000..1003c74 --- /dev/null +++ b/pkg/compress/index.go @@ -0,0 +1,100 @@ +package compress + +import ( + "archive/zip" + "bytes" + "io" + "os" + "path/filepath" +) + +// UnzipFromFile 解压压缩文件 +// @params dst string 解压后目标路径 +// @params src string 压缩文件目标路径 +func UnzipFromFile(dst, src string) error { + // 打开压缩文件 + zr, err := zip.OpenReader(filepath.Clean(src)) + if err != nil { + return err + } + defer func() { + _ = zr.Close() + }() + + // 解压 + return Unzip(dst, &zr.Reader) +} + +// UnzipFromBytes 解压压缩字节流 +// @params dst string 解压后目标路径 +// @params src []byte 压缩字节流 +func UnzipFromBytes(dst string, src []byte) error { + // 通过字节流创建zip的Reader对象 + zr, err := zip.NewReader(bytes.NewReader(src), int64(len(src))) + if err != nil { + return err + } + + // 解压 + return Unzip(dst, zr) +} + +// Unzip 解压压缩文件 +// @params dst string 解压后的目标路径 +// @params src *zip.Reader 压缩文件可读流 +func Unzip(dst string, src *zip.Reader) error { + // 强制转换一遍目录 + dst = filepath.Clean(dst) + // 遍历压缩文件 + for _, file := range src.File { + // 在闭包中完成以下操作可以及时释放文件句柄 + err := func() error { + // 跳过文件夹 + if file.Mode().IsDir() { + return nil + } + // 配置输出目标路径 + filename := 
filepath.Join(dst, file.Name) + // 创建目标路径所在文件夹 + e := os.MkdirAll(filepath.Dir(filename), os.ModeDir) + if e != nil { + return e + } + + // 打开这个压缩文件 + zfr, e := file.Open() + if e != nil { + return e + } + defer func() { + _ = zfr.Close() + }() + + // 创建目标文件 + fw, e := os.Create(filename) + if e != nil { + return e + } + defer func() { + _ = fw.Close() + }() + + // 执行拷贝 + _, e = io.Copy(fw, zfr) + if e != nil { + return e + } + + // 拷贝成功 + return nil + }() + + // 是否发生异常 + if err != nil { + return err + } + } + + // 解压完成 + return nil +} diff --git a/pkg/docker/index.go b/pkg/docker/index.go new file mode 100644 index 0000000..c1e0c5e --- /dev/null +++ b/pkg/docker/index.go @@ -0,0 +1,371 @@ +package docker + +import ( + "archive/tar" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "github.com/docker/docker/api/types" + "github.com/docker/docker/api/types/container" + "github.com/docker/docker/api/types/filters" + "github.com/docker/docker/api/types/mount" + "github.com/docker/docker/api/types/registry" + "github.com/docker/docker/client" + "github.com/docker/go-connections/nat" + "go.uber.org/zap" + "io" + "os" + "path/filepath" + "strconv" + "strings" + "taskExecute/config" + "taskExecute/pkg/utils" +) + +var ( + ContainerList []ContainerStore +) + +func SaveStore() { + str, _ := json.Marshal(ContainerList) + _ = os.WriteFile(config.Cfg.Store, str, os.ModePerm) +} + +// Docker 1.Docker docker client +type Docker struct { + *client.Client +} + +// NewDockerClient 2.init docker client +func NewDockerClient() *Docker { + cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation()) + if err != nil { + return nil + } + return &Docker{ + cli, + } +} + +// Images get images from +func (d *Docker) Images(opt types.ImageListOptions) ([]types.ImageSummary, error) { + return d.ImageList(context.TODO(), opt) +} + +// PushImage --> pull image to harbor仓库 +func (d *Docker) PushImage(image, user, pwd string) error { + authConfig := types.AuthConfig{ + Username: user, //harbor用户名 + Password: pwd, //harbor 密码 + } + encodedJSON, err := json.Marshal(authConfig) + if err != nil { + return err + } + authStr := base64.URLEncoding.EncodeToString(encodedJSON) + out, err := d.ImagePush(context.TODO(), image, types.ImagePushOptions{RegistryAuth: authStr}) + if err != nil { + return err + } + + body, err := io.ReadAll(out) + if err != nil { + return err + } + fmt.Printf("Push docker image output: %v\n", string(body)) + + if strings.Contains(string(body), "error") { + return fmt.Errorf("push image to docker error") + } + + return nil +} + +// PullImage pull image +func (d *Docker) PullImage(name string) error { + resp, err := d.ImagePull(context.TODO(), name, types.ImagePullOptions{}) + + if err != nil { + return err + } + _, err = io.Copy(io.Discard, resp) + if err != nil { + return err + } + return nil +} + +// RemoveImage remove image 这里需要注意的一点就是移除了镜像之后, +// 会出现:的标签,这个是因为下载的镜像是分层的,所以删除会导致 +func (d *Docker) RemoveImage(name string) error { + _, err := d.ImageRemove(context.TODO(), name, types.ImageRemoveOptions{}) + return err +} + +// RemoveDanglingImages remove dangling images +func (d *Docker) RemoveDanglingImages() error { + opt := types.ImageListOptions{ + Filters: filters.NewArgs(filters.Arg("dangling", "true")), + } + + images, err := d.Images(opt) + if err != nil { + return err + } + + errIDs := make([]string, 0) + + for _, image := range images { + fmt.Printf("image.ID: %v\n", image.ID) + if err := d.RemoveImage(image.ID); err != nil { + errIDs = append(errIDs, 
image.ID[7:19]) + } + } + + if len(errIDs) > 1 { + return fmt.Errorf("can not remove ids\n%s", errIDs) + } + + return nil +} + +// SaveImage save image to tar file +func (d *Docker) SaveImage(ids []string, path string) error { + file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0666) + if err != nil { + return err + } + defer func() { + _ = file.Close() + }() + + out, err := d.ImageSave(context.TODO(), ids) + if err != nil { + return err + } + + if _, err = io.Copy(file, out); err != nil { + return err + } + + return nil +} + +// LoadImage load image from tar file +func (d *Docker) LoadImage(path string) error { + file, err := os.Open(path) + if err != nil { + return err + } + defer func() { + _ = file.Close() + }() + _, err = d.ImageLoad(context.TODO(), file, true) + return err +} + +// ImportImage import image +func (d *Docker) ImportImage(name, tag, path string) error { + file, err := os.Open(path) + if err != nil { + return err + } + defer func() { + _ = file.Close() + }() + source := types.ImageImportSource{ + Source: file, + SourceName: "-", + } + + opt := types.ImageImportOptions{ + Tag: tag, + } + + _, err = d.ImageImport(context.TODO(), source, name, opt) + + return err +} + +// SearchImage search images +func (d *Docker) SearchImage(name string) ([]registry.SearchResult, error) { + + return d.ImageSearch(context.TODO(), name, types.ImageSearchOptions{Limit: 100}) +} + +// BuildImage build image image 需要构建的镜像名称 +func (d *Docker) BuildImage(warName, image string) error { + + // 1.需要构建的war包上传到docker/web/目录下 + err := utils.CopyFile(fmt.Sprintf("/tmp/docker/%s", warName), fmt.Sprintf("docker/web/%s", warName)) + if err != nil { + return err + } + + var tags []string + tags = append(tags, image) + //打一个docker.tar包 + err = tarIt("docker/", ".") + if err != nil { + return err + } //src:要打包文件的源地址 target:要打包文件的目标地址 (使用相对路径-->相对于main.go) + //打开刚刚打的tar包 + dockerBuildContext, _ := os.Open("docker.tar") //打开打包的文件, + defer func(dockerBuildContext *os.File) { + _ = dockerBuildContext.Close() + }(dockerBuildContext) + options := types.ImageBuildOptions{ + Dockerfile: "docker/Dockerfile", //不能是绝对路径 是相对于build context来说的, + SuppressOutput: false, + Remove: true, + ForceRemove: true, + PullParent: true, + Tags: tags, //[]string{"192.168.0.1/harbor/cdisample:v1"} + } + buildResponse, err := d.ImageBuild(context.Background(), dockerBuildContext, options) + fmt.Printf("err build: %v\n", err) + if err != nil { + fmt.Printf("%s", err.Error()) + return err + } + fmt.Printf("********* %s **********", buildResponse.OSType) + response, err := io.ReadAll(buildResponse.Body) + if err != nil { + fmt.Printf("%s", err.Error()) + return err + } + fmt.Println(string(response)) + return nil + +} + +/* +source:打包的的路径 +target:放置打包文件的位置 +*/ +func tarIt(source string, target string) error { + filename := filepath.Base(source) + fmt.Println(filename) + target = filepath.Join(target, fmt.Sprintf("%s.tar", filename)) + //target := fmt.Sprintf("%s.tar", filename) + fmt.Println(target) + tarFile, err := os.Create(target) + if err != nil { + return err + } + fmt.Println(tarFile) + defer func(tarFile *os.File) { + _ = tarFile.Close() + }(tarFile) + + tarball := tar.NewWriter(tarFile) + // 这里不要忘记关闭,如果不能成功关闭会造成 tar 包不完整 + // 所以这里在关闭的同时进行判断,可以清楚的知道是否成功关闭 + defer func() { + if err := tarball.Close(); err != nil { + config.Logger.With(zap.String("docker打包的的路径", source)). 
+ With(zap.String("放置打包文件的位置", target)).Error("错误信息", zap.Error(err)) + } + }() + + info, err := os.Stat(source) + if err != nil { + return nil + } + + var baseDir string + if info.IsDir() { + baseDir = filepath.Base(source) + } + + return filepath.Walk(source, + func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + header, err := tar.FileInfoHeader(info, info.Name()) + if err != nil { + return err + } + if baseDir != "" { + header.Name = filepath.Join(baseDir, strings.TrimPrefix(path, source)) + } + + if err := tarball.WriteHeader(header); err != nil { + return err + } + + if info.IsDir() { + return nil + } + + file, err := os.Open(path) + if err != nil { + return err + } + defer func(file *os.File) { + _ = file.Close() + }(file) + _, err = io.Copy(tarball, file) + return err + }) +} + +// CreateContainer create container +func (d *Docker) CreateContainer(containerName, image string, cmd []string, volumes map[string]string, srcPort string) string { + // 文件挂载 + m := make([]mount.Mount, 0, len(volumes)) + for k, v := range volumes { + m = append(m, mount.Mount{Type: mount.TypeBind, Source: k, Target: v}) + } + + exports := make(nat.PortSet) + netPort := make(nat.PortMap) + + // 网络端口映射 + natPort, _ := nat.NewPort("tcp", srcPort) + exports[natPort] = struct{}{} + dstPort, err := utils.GetAvailablePort() + if err != nil { + config.Logger.With( + zap.String("containerName", containerName), + zap.Strings("cmd", cmd), + zap.String("image", image), + zap.Int("dstPort", dstPort), + ).Error("创建镜像错误", zap.Error(err)) + } + portList := make([]nat.PortBinding, 0, 1) + portList = append(portList, nat.PortBinding{HostIP: "0.0.0.0", HostPort: strconv.Itoa(dstPort)}) + netPort[natPort] = portList + + ctx := context.Background() + // 创建容器 + resp, err := d.ContainerCreate(ctx, &container.Config{ + Image: image, + ExposedPorts: exports, + Cmd: cmd, + Tty: false, + // WorkingDir: workDir, + }, &container.HostConfig{ + PortBindings: netPort, + Mounts: m, + }, nil, nil, containerName) + + if err != nil { + config.Logger.With( + zap.String("containerName", containerName), + zap.Strings("cmd", cmd), zap.String("image", image), + ).Error("创建镜像错误", zap.Error(err)) + return "" + } + + if err := d.ContainerStart(ctx, resp.ID, types.ContainerStartOptions{}); err != nil { + config.Logger.With( + zap.String("containerName", containerName), + zap.Strings("cmd", cmd), zap.String("image", image), + ).Error("启动镜像错误", zap.Error(err)) + return "" + } + return strconv.Itoa(dstPort) +} diff --git a/pkg/docker/store.go b/pkg/docker/store.go new file mode 100644 index 0000000..9ef69ab --- /dev/null +++ b/pkg/docker/store.go @@ -0,0 +1,13 @@ +package docker + +type ContainerStore struct { + ModelId int64 `json:"modelId" yaml:"modelId"` + NodeId int64 `json:"nodeId" yaml:"nodeId"` + Name string `json:"name" yaml:"name"` + ImgName string `json:"imgName" yaml:"imgName"` + Volumes map[string]string `json:"volumes" yaml:"volumes"` + SrcPort string `json:"srcPort" yaml:"srcPort"` + DstPort string `json:"dstPort" yaml:"dstPort"` + Command []string `json:"command" yaml:"command"` + HttpUrl string `json:"httpUrl" yaml:"httpUrl"` +} diff --git a/pkg/download/index.go b/pkg/download/index.go new file mode 100644 index 0000000..ed0fc50 --- /dev/null +++ b/pkg/download/index.go @@ -0,0 +1,492 @@ +package download + +import ( + "bytes" + "encoding/hex" + "errors" + "fmt" + "git.hpds.cc/Component/logging" + "go.uber.org/zap" + "io" + "net/http" + "os" + "strconv" + "strings" + "sync" + "time" +) + +const ( + 
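+	// defaults for the multi-threaded downloader: bytes fetched per HTTP Range request and the number of worker goroutines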
OneThreadDownloadSize = 1024 * 1024 * 2 // 一个线程下载文件的大小 + ThreadCount = 6 +) + +type Task struct { + customFunc func(params interface{}) // 执行方法 + paramsInfo interface{} // 执行方法参数 +} + +type ThreadController struct { + TaskQueue chan Task // 用于接收下载任务 + TaskCount chan int // 用于记载当前任务数量 + Exit chan int // 用于记载当前任务数量 + ThreadCount int // 最大协程数 + WaitGroup sync.WaitGroup // 等待协程完成 + RangeStrs map[int]string // 所有需要下载的文件名 + FileUrl string // 下载链接 + DownloadResultInfoChan chan DownFileParams // 下载任务响应通道 + DownloadFolder string // 下载文件保存文件夹 + DownloadFileName string // 下载文件保存文件名 + Filenames []string // 子文件名,有序 + Logger *logging.Logger //日志 +} + +type DownFileParams struct { + UrlStr string + RangeStr string + RangeIndex int + TempFilename string + Succeed bool +} + +var fileTypeMap sync.Map + +func init() { //用于判断文件名的后缀 + fileTypeMap.Store("ffd8ffe000104a464946", "jpg") //JPEG (jpg) + fileTypeMap.Store("89504e470d0a1a0a0000", "png") //PNG (png) + fileTypeMap.Store("47494638396126026f01", "gif") //GIF (gif) + fileTypeMap.Store("49492a00227105008037", "tif") //TIFF (tif) + fileTypeMap.Store("424d228c010000000000", "bmp") //16色位图(bmp) + fileTypeMap.Store("424d8240090000000000", "bmp") //24位位图(bmp) + fileTypeMap.Store("424d8e1b030000000000", "bmp") //256色位图(bmp) + fileTypeMap.Store("41433130313500000000", "dwg") //CAD (dwg) + fileTypeMap.Store("3c21444f435459504520", "html") //HTML (html) 3c68746d6c3e0 3c68746d6c3e0 + fileTypeMap.Store("3c68746d6c3e0", "html") //HTML (html) 3c68746d6c3e0 3c68746d6c3e0 + fileTypeMap.Store("3c21646f637479706520", "htm") //HTM (htm) + fileTypeMap.Store("48544d4c207b0d0a0942", "css") //css + fileTypeMap.Store("696b2e71623d696b2e71", "js") //js + fileTypeMap.Store("7b5c727466315c616e73", "rtf") //Rich Text Format (rtf) + fileTypeMap.Store("38425053000100000000", "psd") //Photoshop (psd) + fileTypeMap.Store("46726f6d3a203d3f6762", "eml") //Email [Outlook Express 6] (eml) + fileTypeMap.Store("d0cf11e0a1b11ae10000", "doc") //MS Excel 注意:word、msi 和 excel的文件头一样 + fileTypeMap.Store("d0cf11e0a1b11ae10000", "vsd") //Visio 绘图 + fileTypeMap.Store("5374616E64617264204A", "mdb") //MS Access (mdb) + fileTypeMap.Store("252150532D41646F6265", "ps") + fileTypeMap.Store("255044462d312e350d0a", "pdf") //Adobe Acrobat (pdf) + fileTypeMap.Store("2e524d46000000120001", "rmvb") //rmvb/rm相同 + fileTypeMap.Store("464c5601050000000900", "flv") //flv与f4v相同 + fileTypeMap.Store("00000020667479706d70", "mp4") + fileTypeMap.Store("49443303000000002176", "mp3") + fileTypeMap.Store("000001ba210001000180", "mpg") // + fileTypeMap.Store("3026b2758e66cf11a6d9", "wmv") //wmv与asf相同 + fileTypeMap.Store("52494646e27807005741", "wav") //Wave (wav) + fileTypeMap.Store("52494646d07d60074156", "avi") + fileTypeMap.Store("4d546864000000060001", "mid") //MIDI (mid) + fileTypeMap.Store("504b0304140000000800", "zip") + fileTypeMap.Store("526172211a0700cf9073", "rar") + fileTypeMap.Store("235468697320636f6e66", "ini") + fileTypeMap.Store("504b03040a0000000000", "jar") + fileTypeMap.Store("4d5a9000030000000400", "exe") //可执行文件 + fileTypeMap.Store("3c25402070616765206c", "jsp") //jsp文件 + fileTypeMap.Store("4d616e69666573742d56", "mf") //MF文件 + fileTypeMap.Store("3c3f786d6c2076657273", "xml") //xml文件 + fileTypeMap.Store("494e5345525420494e54", "sql") //xml文件 + fileTypeMap.Store("7061636b616765207765", "java") //java文件 + fileTypeMap.Store("406563686f206f66660d", "bat") //bat文件 + fileTypeMap.Store("1f8b0800000000000000", "gz") //gz文件 + fileTypeMap.Store("6c6f67346a2e726f6f74", "properties") //bat文件 + 
fileTypeMap.Store("cafebabe0000002e0041", "class") //bat文件 + fileTypeMap.Store("49545346030000006000", "chm") //bat文件 + fileTypeMap.Store("04000000010000001300", "mxp") //bat文件 + fileTypeMap.Store("504b0304140006000800", "docx") //docx文件 + fileTypeMap.Store("d0cf11e0a1b11ae10000", "wps") //WPS文字wps、表格et、演示dps都是一样的 + fileTypeMap.Store("6431303a637265617465", "torrent") + fileTypeMap.Store("6D6F6F76", "mov") //Quicktime (mov) + fileTypeMap.Store("FF575043", "wpd") //WordPerfect (wpd) + fileTypeMap.Store("CFAD12FEC5FD746F", "dbx") //Outlook Express (dbx) + fileTypeMap.Store("2142444E", "pst") //Outlook (pst) + fileTypeMap.Store("AC9EBD8F", "qdf") //Quicken (qdf) + fileTypeMap.Store("E3828596", "pwl") //Windows Password (pwl) + fileTypeMap.Store("2E7261FD", "ram") //Real Audio (ram) +} + +// 获取前面结果字节的二进制 +func bytesToHexString(src []byte) string { + res := bytes.Buffer{} + if src == nil || len(src) <= 0 { + return "" + } + temp := make([]byte, 0) + for _, v := range src { + sub := v & 0xFF + hv := hex.EncodeToString(append(temp, sub)) + if len(hv) < 2 { + res.WriteString(strconv.FormatInt(int64(0), 10)) + } + res.WriteString(hv) + } + return res.String() +} + +func SafeMkdir(folder string) { + if _, err := os.Stat(folder); os.IsNotExist(err) { + _ = os.MkdirAll(folder, os.ModePerm) + } +} + +// GetFileType 用文件前面几个字节来判断 +// fSrc: 文件字节流(就用前面几个字节) +func GetFileType(fSrc []byte) string { + var fileType string + fileCode := bytesToHexString(fSrc) + + fileTypeMap.Range(func(key, value interface{}) bool { + k := key.(string) + v := value.(string) + if strings.HasPrefix(fileCode, strings.ToLower(k)) || + strings.HasPrefix(k, strings.ToLower(fileCode)) { + fileType = v + return false + } + return true + }) + return fileType +} +func GetBytesFile(filename string, bufferSize int) []byte { + file, err := os.Open(filename) + if err != nil { + fmt.Println(err) + return nil + } + defer func() { + _ = file.Close() + }() + + buffer := make([]byte, bufferSize) + _, err = file.Read(buffer) + if err != nil { + fmt.Println(err) + return nil + } + return buffer +} + +func (controller *ThreadController) GetSuffix(contentType string) string { + suffix := "" + contentTypes := map[string]string{ + "image/gif": "gif", + "image/jpeg": "jpg", + "application/x-img": "img", + "image/png": "png", + "application/json": "json", + "application/pdf": "pdf", + "application/msword": "word", + "application/octet-stream": "rar", + "application/x-zip-compressed": "zip", + "application/x-msdownload": "exe", + "video/mpeg4": "mp4", + "video/avi": "avi", + "audio/mp3": "mp3", + "text/css": "css", + "application/x-javascript": "js", + "application/vnd.android.package-archive": "apk", + } + for key, value := range contentTypes { + if strings.Contains(contentType, key) { + suffix = value + break + } + } + return suffix +} + +func (controller *ThreadController) Put(task Task) { + // 用于开启单个协程任务,下载文件的部分内容 + defer func() { + err := recover() //内置函数,可以捕捉到函数异常 + if err != nil { + controller.Logger.With(zap.String("文件下载", "错误信息")).Error("recover 错误", zap.Any("错误信息", err)) + } + }() + controller.WaitGroup.Add(1) // 每插入一个任务,就需要计数 + controller.TaskCount <- 1 // 含缓冲区的通道,用于控制下载器的协程最大数量 + controller.TaskQueue <- task // 插入下载任务 + //go task.customFunc(task.paramsInfo) +} + +func (controller *ThreadController) DownloadFile(paramsInfo interface{}) { + // 下载任务,接收对应的参数,负责从网页中下载对应部分的文件资源 + defer func() { + controller.WaitGroup.Done() // 下载任务完成,协程结束 + }() + switch paramsInfo.(type) { + case DownFileParams: + params := paramsInfo.(DownFileParams) + 
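+		// assume the chunk failed until the request, file open, and write below all complete successfully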
params.Succeed = false + defer func() { + err := recover() //内置函数,可以捕捉到函数异常 + if err != nil { + // 如果任意环节出错,表明下载流程未成功完成,标记下载失败 + controller.Logger.With(zap.String("文件下载", "错误信息")).Error("recover 错误", zap.Any("错误信息", err)) + params.Succeed = false + } + }() + //fmt.Println("Start to down load " + params.UrlStr + ", Content-type: " + params.RangeStr + " , save to file: " + params.TempFilename) + urlStr := params.UrlStr + rangeStr := params.RangeStr + tempFilename := params.TempFilename + _ = os.Remove(tempFilename) // 删除已有的文件, 避免下载的数据被污染 + // 发起文件下载请求 + req, _ := http.NewRequest("GET", urlStr, nil) + req.Header.Add("Range", rangeStr) // 测试下载部分内容 + res, err := http.DefaultClient.Do(req) // 发出下载请求,等待回应 + if err != nil { + controller.Logger.With(zap.String("文件下载", "错误信息")).Error("连接失败", zap.Error(err)) + params.Succeed = false // 无法连接, 标记下载失败 + } else if res.StatusCode != 206 { + params.Succeed = false + } else { // 能正常发起请求 + // 打开文件,写入文件 + fileObj, err := os.OpenFile(tempFilename, os.O_RDONLY|os.O_CREATE|os.O_APPEND, 0666) + if err != nil { + fmt.Println("Failed to open file " + tempFilename) + controller.Logger.With(zap.String("文件下载", "错误信息")). + With(zap.String("文件名", tempFilename)). + Error("打开文件失败", zap.Error(err)) + + params.Succeed = false // 无法打开文件, 标记下载失败 + } else { + defer func() { + _ = fileObj.Close() + }() // 关闭文件流 + body, err := io.ReadAll(res.Body) // 读取响应体的所有内容 + if err != nil { + controller.Logger.With(zap.String("文件下载", "错误信息")). + Error("读取返回值错误", zap.Error(err)) + params.Succeed = false + } else { + defer func() { + _ = res.Body.Close() + }() // 关闭连接流 + _, _ = fileObj.Write(body) // 写入字节数据到文件 + params.Succeed = true // 成功执行到最后一步,则表示下载成功 + } + } + } + controller.DownloadResultInfoChan <- params // 将下载结果传入 + } +} + +func (controller *ThreadController) Run() { + // 只需要将待下载的请求发送一次即可,成功了会直接剔除,不成功则由接收方重试 + for rangeIndex, rangeStr := range controller.RangeStrs { + params := DownFileParams{ + UrlStr: controller.FileUrl, + RangeStr: rangeStr, + TempFilename: controller.DownloadFolder + "/" + rangeStr, + RangeIndex: rangeIndex, + Succeed: true, + } // 下载参数初始化 + task := Task{controller.DownloadFile, params} + controller.Put(task) // 若通道满了会阻塞,等待空闲时再下载 + } +} + +func (controller *ThreadController) ResultProcess(trunkSize int) string { + // 负责处理各个协程下载资源的结果, 若成功则从下载列表中剔除,否则重新将该任务Put到任务列表中;超过5秒便会停止 + MaxRetryTime := 100 + nowRetryTime := 0 + resultMsg := "" + for { + select { + case resultInfo := <-controller.DownloadResultInfoChan: + <-controller.TaskCount // 取出一个计数器,表示一个协程已经完成 + if resultInfo.Succeed { // 成功下载该文件,清除文件名列表中的信息 + delete(controller.RangeStrs, resultInfo.RangeIndex) // 删除任务队列中的该任务(rangeStr队列) + fmt.Println("Download progress -> " + strconv.FormatFloat((1.0-float64(len(controller.RangeStrs))/float64(trunkSize))*100, 'f', 2, 64) + "%") + if len(controller.RangeStrs) == 0 { + resultMsg = "SUCCESSED" + break + } + } else { + nowRetryTime += 1 + if nowRetryTime > MaxRetryTime { // 超过最大的重试次数退出下载 + resultMsg = "MAX_RETRY" + break + } + task := Task{ + customFunc: controller.DownloadFile, + paramsInfo: resultInfo, + } // 重新加载该任务 + go controller.Put(task) + } + case task := <-controller.TaskQueue: + function := task.customFunc + go function(task.paramsInfo) + case <-time.After(5 * time.Second): + resultMsg = "TIMEOUT" + break + } + if resultMsg == "MAX_RETRY" { + fmt.Println("The network is unstable, exceeding the maximum number of downloads.") + break + } else if resultMsg == "SUCCESSED" { + fmt.Println("Download file success!") + break + } else if resultMsg == "TIMEOUT" { + 
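+			// no chunk result arrived within the 5-second select timeout, so the remaining ranges are abandoned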
fmt.Println("Download timeout!") + break + } + } + + close(controller.TaskCount) + close(controller.TaskQueue) + close(controller.DownloadResultInfoChan) + return resultMsg +} + +func (controller *ThreadController) Download(oneThreadDownloadSize int) bool { + fmt.Println("Try to parse the object file...") + length, rangeMaps, tempFilenames, contentType, err := TryDownload(controller.FileUrl, oneThreadDownloadSize) + fmt.Println("File total size -> " + strconv.FormatFloat(float64(length)/(1024.0*1024.0), 'f', 2, 64) + "M") + if err != nil { + fmt.Println("The file does not support multi-threaded download.") + return false + } + fmt.Println("Parse the target file successfully, start downloading the target file...") + controller.Init() // 初始化通道、分片等配置 + //oneThreadDownloadSize := 1024 * 1024 * 2 // 1024字节 = 1024bite = 1kb -> 2M + oneThreadDownloadSize = 1024 * 1024 * 4 // 1024字节 = 1024bite = 1kb -> 4M + filenames := make([]string, 0) + for _, value := range tempFilenames { + filenames = append(filenames, controller.DownloadFolder+"/"+value) + } + fileSuffix := controller.GetSuffix(contentType) + filename := controller.DownloadFileName // 获取文件下载名 + controller.Filenames = filenames //下载文件的切片列表 + controller.RangeStrs = rangeMaps // 下载文件的Range范围 + go controller.Run() // 开始下载文件 + processResult := controller.ResultProcess(len(rangeMaps)) + downloadResult := false // 定义下载结果标记 + if processResult == "SUCCESSED" { + absoluteFilename := controller.DownloadFolder + "/" + filename + "." + fileSuffix + downloadResult = controller.CombineFiles(filename + "." + fileSuffix) + if downloadResult { + newSuffix := GetFileType(GetBytesFile(absoluteFilename, 10)) + err = os.Rename(absoluteFilename, controller.DownloadFolder+"/"+filename+"."+newSuffix) + if err != nil { + downloadResult = false + fmt.Println("Combine file successes, Rename file failed " + absoluteFilename) + } else { + fmt.Println("Combine file successes, rename successes, new file name is -> " + controller.DownloadFolder + "/" + filename + "." + newSuffix) + } + } else { + fmt.Println("Failed to download file.") + } + } else { + fmt.Println("Failed to download file. 
Reason -> " + processResult) + downloadResult = false + } + return downloadResult +} + +func (controller *ThreadController) CombineFiles(filename string) bool { + _ = os.Remove(controller.DownloadFolder + "/" + filename) + goalFile, err := os.OpenFile(controller.DownloadFolder+"/"+filename, os.O_CREATE|os.O_WRONLY|os.O_APPEND, os.ModePerm) + if err != nil { + fmt.Println("Failed to open file ") + return false + } + + // 正确的话应按照初始计算的文件名顺序合并,并且无缺失 + for _, value := range controller.Filenames { + retryTime := 3 + tempFileBytes := make([]byte, 0) + for retryTime > 0 { + tempFileBytes = ReadFile(value) + time.Sleep(100) // 休眠100毫秒,看看是不是文件加载错误 + if tempFileBytes != nil { + break + } + retryTime = retryTime - 1 + } + _, _ = goalFile.Write(tempFileBytes) + _ = os.Remove(value) + } + _ = goalFile.Close() + return true +} + +func ReadFile(filename string) []byte { + tempFile, err := os.OpenFile(filename, os.O_RDONLY, os.ModePerm) + if err != nil { + fmt.Println("Failed to open file " + filename) + return nil + } + tempFileBytes, err := io.ReadAll(tempFile) + if err != nil { + fmt.Println("Failed to read file data " + filename) + return nil + } + _ = tempFile.Close() + return tempFileBytes +} + +func TryDownload(urlStr string, perThreadSize int) (int, map[int]string, []string, string, error) { + // 尝试连接目标资源,目标资源是否可以使用多线程下载 + length := 0 + rangeMaps := make(map[int]string) + req, _ := http.NewRequest("GET", urlStr, nil) + req.Header.Add("Range", "bytes=0-1") // 测试下载部分内容 + res, err := http.DefaultClient.Do(req) + contentType := "" + rangeIndex := 1 + filenames := make([]string, 0) + if err != nil { + rangeMaps[rangeIndex] = urlStr + return length, rangeMaps, filenames, contentType, errors.New("Failed to connect " + urlStr) + } + if res.StatusCode != 206 { + rangeMaps[rangeIndex] = urlStr + return length, rangeMaps, filenames, contentType, errors.New("Http status is not equal to 206! 
") + } + // 206表示响应成功,仅仅返回部分内容 + contentLength := res.Header.Get("Content-Range") + contentType = res.Header.Get("Content-Type") + totalLength, err := strconv.Atoi(strings.Split(contentLength, "/")[1]) + if err != nil { + return length, rangeMaps, filenames, contentType, errors.New("Can't calculate the content-length form server " + urlStr) + } + nowLength := 0 // 记录byte偏移量 + for { + if nowLength >= totalLength { + break + } + var tempRangeStr string // 记录临时文件名 + if nowLength+perThreadSize >= totalLength { + tempRangeStr = "bytes=" + strconv.Itoa(nowLength) + "-" + strconv.Itoa(totalLength-1) + nowLength = totalLength + } else { + tempRangeStr = "bytes=" + strconv.Itoa(nowLength) + "-" + strconv.Itoa(nowLength+perThreadSize-1) + nowLength = nowLength + perThreadSize + } + rangeMaps[rangeIndex] = tempRangeStr + filenames = append(filenames, tempRangeStr) + rangeIndex = rangeIndex + 1 + } + return totalLength, rangeMaps, filenames, contentType, nil +} + +func (controller *ThreadController) Init() { + taskQueue := make(chan Task, controller.ThreadCount) + taskCount := make(chan int, controller.ThreadCount+1) + exit := make(chan int) + downloadResultInfoChan := make(chan DownFileParams) + controller.TaskQueue = taskQueue + controller.TaskCount = taskCount + controller.Exit = exit + controller.DownloadResultInfoChan = downloadResultInfoChan + controller.WaitGroup = sync.WaitGroup{} + controller.RangeStrs = make(map[int]string) + SafeMkdir(controller.DownloadFolder) +} diff --git a/pkg/utils/file.go b/pkg/utils/file.go new file mode 100644 index 0000000..5c7e24d --- /dev/null +++ b/pkg/utils/file.go @@ -0,0 +1,88 @@ +package utils + +import ( + "crypto/md5" + "encoding/hex" + "fmt" + "git.hpds.cc/Component/logging" + "go.uber.org/zap" + "io" + "os" + "path" + "path/filepath" + "strings" +) + +func CopyFile(src, dst string) error { + sourceFileStat, err := os.Stat(src) + if err != nil { + return err + } + + if !sourceFileStat.Mode().IsRegular() { + return fmt.Errorf("%s is not a regular file", src) + } + + source, err := os.Open(src) + if err != nil { + return err + } + defer func(source *os.File) { + _ = source.Close() + }(source) + + destination, err := os.Create(dst) + if err != nil { + return err + } + defer func(destination *os.File) { + _ = destination.Close() + }(destination) + _, err = io.Copy(destination, source) + return err +} + +func PathExists(path string) bool { + _, err := os.Stat(path) + if err == nil { + return true + } + if os.IsNotExist(err) { + return false + } + return false +} + +// ReadFile 读取到file中,再利用ioutil将file直接读取到[]byte中, 这是最优 +func ReadFile(fn string) []byte { + f, err := os.Open(fn) + if err != nil { + logging.L().Error("Read File", zap.String("File Name", fn), zap.Error(err)) + return nil + } + defer func(f *os.File) { + _ = f.Close() + }(f) + + fd, err := io.ReadAll(f) + if err != nil { + logging.L().Error("Read File To buff", zap.String("File Name", fn), zap.Error(err)) + return nil + } + + return fd +} + +func GetFileName(fn string) string { + fileType := path.Ext(fn) + return strings.TrimSuffix(fn, fileType) +} +func GetFileNameAndExt(fn string) string { + _, fileName := filepath.Split(fn) + return fileName +} +func GetFileMd5(data []byte) string { + hash := md5.New() + hash.Write(data) + return hex.EncodeToString(hash.Sum(nil)) +} diff --git a/pkg/utils/http.go b/pkg/utils/http.go new file mode 100644 index 0000000..939b427 --- /dev/null +++ b/pkg/utils/http.go @@ -0,0 +1,120 @@ +package utils + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + 
"mime/multipart" + "net/http" + "net/url" + "path/filepath" + "strings" +) + +func HttpDo(reqUrl, method string, params map[string]string, header map[string]string) (data []byte, err error) { + var paramStr string = "" + for k, v := range params { + if len(paramStr) == 0 { + paramStr = fmt.Sprintf("%s=%s", k, url.QueryEscape(v)) + } else { + paramStr = fmt.Sprintf("%s&%s=%s", paramStr, k, url.QueryEscape(v)) + } + } + client := &http.Client{} + req, err := http.NewRequest(strings.ToUpper(method), reqUrl, strings.NewReader(paramStr)) + if err != nil { + return nil, err + } + for k, v := range header { + req.Header.Set(k, v) + } + resp, err := client.Do(req) + if err != nil { + return nil, err + } + + defer func() { + if resp.Body != nil { + err = resp.Body.Close() + if err != nil { + return + } + } + }() + var body []byte + if resp.Body != nil { + body, err = io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + } + return body, nil +} + +type UploadFile struct { + // 表单名称 + Name string + Filepath string + // 文件全路径 + File *bytes.Buffer +} + +func PostFile(reqUrl string, reqParams map[string]string, contentType string, files []UploadFile, headers map[string]string) string { + requestBody, realContentType := getReader(reqParams, contentType, files) + httpRequest, _ := http.NewRequest("POST", reqUrl, requestBody) + // 添加请求头 + httpRequest.Header.Add("Content-Type", realContentType) + if headers != nil { + for k, v := range headers { + httpRequest.Header.Add(k, v) + } + } + httpClient := &http.Client{} + // 发送请求 + resp, err := httpClient.Do(httpRequest) + if err != nil { + panic(err) + } + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(resp.Body) + response, _ := io.ReadAll(resp.Body) + return string(response) +} + +func getReader(reqParams map[string]string, contentType string, files []UploadFile) (io.Reader, string) { + if strings.Index(contentType, "json") > -1 { + bytesData, _ := json.Marshal(reqParams) + return bytes.NewReader(bytesData), contentType + } else if files != nil { + body := &bytes.Buffer{} + // 文件写入 body + writer := multipart.NewWriter(body) + for _, uploadFile := range files { + part, err := writer.CreateFormFile(uploadFile.Name, filepath.Base(uploadFile.Filepath)) + if err != nil { + panic(err) + } + _, err = io.Copy(part, uploadFile.File) + } + // 其他参数列表写入 body + for k, v := range reqParams { + if err := writer.WriteField(k, v); err != nil { + panic(err) + } + } + if err := writer.Close(); err != nil { + panic(err) + } + // 上传文件需要自己专用的contentType + return body, writer.FormDataContentType() + } else { + urlValues := url.Values{} + for key, val := range reqParams { + urlValues.Set(key, val) + } + reqBody := urlValues.Encode() + return strings.NewReader(reqBody), contentType + } +} diff --git a/pkg/utils/image.go b/pkg/utils/image.go new file mode 100644 index 0000000..68dce3c --- /dev/null +++ b/pkg/utils/image.go @@ -0,0 +1,131 @@ +package utils + +import ( + "bytes" + "encoding/base64" + "golang.org/x/image/bmp" + "golang.org/x/image/tiff" + "image" + "image/color" + "image/jpeg" + "image/png" +) + +func BuffToImage(in []byte) image.Image { + buff := bytes.NewBuffer(in) + m, _, _ := image.Decode(buff) + return m +} + +// Clip 图片裁剪 +func Clip(in []byte, wi, hi int, equalProportion bool) (out image.Image, imageType string, err error) { + buff := bytes.NewBuffer(in) + m, imgType, _ := image.Decode(buff) + rgbImg := m.(*image.YCbCr) + if equalProportion { + w := m.Bounds().Max.X + h := m.Bounds().Max.Y + if w > 0 && h > 0 && wi > 0 && hi > 0 { + wi, hi = 
+		}
+	}
+	return rgbImg.SubImage(image.Rect(0, 0, wi, hi)), imgType, nil
+}
+
+func fixSize(img1W, img2H, wi, hi int) (new1W, new2W int) {
+	var ( // convert the image dimensions to float64 so the aspect-ratio math is not truncated
+		imgWidth, imgHeight = float64(img1W), float64(img2H)
+		ratio               float64
+	)
+	if imgWidth >= imgHeight {
+		ratio = imgWidth / float64(wi)
+		return int(imgWidth / ratio), int(imgHeight / ratio)
+	}
+	ratio = imgHeight / float64(hi)
+	return int(imgWidth / ratio), int(imgHeight / ratio)
+}
+
+func Gray(in []byte) (out image.Image, err error) {
+	m := BuffToImage(in)
+	bounds := m.Bounds()
+	dx := bounds.Dx()
+	dy := bounds.Dy()
+	newRgba := image.NewRGBA(bounds)
+	for i := 0; i < dx; i++ {
+		for j := 0; j < dy; j++ {
+			colorRgb := m.At(i, j)
+			_, g, _, a := colorRgb.RGBA() // use the green channel as the gray value
+			gUint8 := uint8(g >> 8)
+			aUint8 := uint8(a >> 8)
+			newRgba.SetRGBA(i, j, color.RGBA{R: gUint8, G: gUint8, B: gUint8, A: aUint8})
+		}
+	}
+	r := image.Rect(0, 0, dx, dy)
+	return newRgba.SubImage(r), nil
+}
+
+func Rotate90(in []byte) image.Image {
+	m := BuffToImage(in)
+	rotate90 := image.NewRGBA(image.Rect(0, 0, m.Bounds().Dy(), m.Bounds().Dx()))
+	// rotate by remapping every pixel
+	for x := m.Bounds().Min.Y; x < m.Bounds().Max.Y; x++ {
+		for y := m.Bounds().Max.X - 1; y >= m.Bounds().Min.X; y-- {
+			// set the pixel in its rotated position
+			rotate90.Set(m.Bounds().Max.Y-x, y, m.At(y, x))
+		}
+	}
+	return rotate90
+}
+
+// Rotate180 rotates the image by 180 degrees.
+func Rotate180(in []byte) image.Image {
+	m := BuffToImage(in)
+	rotate180 := image.NewRGBA(image.Rect(0, 0, m.Bounds().Dx(), m.Bounds().Dy()))
+	// rotate by remapping every pixel
+	for x := m.Bounds().Min.X; x < m.Bounds().Max.X; x++ {
+		for y := m.Bounds().Min.Y; y < m.Bounds().Max.Y; y++ {
+			// set the pixel in its rotated position
+			rotate180.Set(m.Bounds().Max.X-x, m.Bounds().Max.Y-y, m.At(x, y))
+		}
+	}
+	return rotate180
+}
+
+// Rotate270 rotates the image by 270 degrees.
+func Rotate270(in []byte) image.Image {
+	m := BuffToImage(in)
+	rotate270 := image.NewRGBA(image.Rect(0, 0, m.Bounds().Dy(), m.Bounds().Dx()))
+	// rotate by remapping every pixel
+	for x := m.Bounds().Min.Y; x < m.Bounds().Max.Y; x++ {
+		for y := m.Bounds().Max.X - 1; y >= m.Bounds().Min.X; y-- {
+			// set the pixel in its rotated position
+			rotate270.Set(x, m.Bounds().Max.X-y, m.At(y, x))
+		}
+	}
+	return rotate270
+
+}
+
+func ImageToBase64(img image.Image, imgType string) string {
+	buff := ImageToBuff(img, imgType)
+	return base64.StdEncoding.EncodeToString(buff.Bytes())
+}
+
+func ImageToBuff(img image.Image, imgType string) *bytes.Buffer {
+	buff := bytes.NewBuffer(nil)
+	switch imgType {
+	case "bmp":
+		imgType = "bmp"
+		_ = bmp.Encode(buff, img)
+	case "png":
+		imgType = "png"
+		_ = png.Encode(buff, img)
+	case "tiff":
+		imgType = "tiff"
+		_ = tiff.Encode(buff, img, nil)
+	default:
+		imgType = "jpeg"
+		_ = jpeg.Encode(buff, img, nil)
+	}
+	return buff
+}
diff --git a/pkg/utils/network.go b/pkg/utils/network.go
new file mode 100644
index 0000000..c619faf
--- /dev/null
+++ b/pkg/utils/network.go
@@ -0,0 +1,40 @@
+package utils
+
+import (
+	"fmt"
+	"git.hpds.cc/Component/network/log"
+	"net"
+)
+
+// GetAvailablePort returns a free TCP port chosen by the kernel.
+func GetAvailablePort() (int, error) {
+	address, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("%s:0", "0.0.0.0"))
+	if err != nil {
+		return 0, err
+	}
+
+	listener, err := net.ListenTCP("tcp", address)
+	if err != nil {
+		return 0, err
+	}
+
+	defer func(listener *net.TCPListener) {
+		_ = listener.Close()
+	}(listener)
+	return listener.Addr().(*net.TCPAddr).Port, nil
+}
+
+// IsPortAvailable reports whether the port is free (not already in use).
+func IsPortAvailable(port int) bool {
+	address := fmt.Sprintf("%s:%d", "0.0.0.0", port)
+	listener, err := net.Listen("tcp", address)
+	if err != nil {
+		log.Infof("port %s is taken: %s", address, err)
taken: %s", address, err) + return false + } + + defer func(listener net.Listener) { + _ = listener.Close() + }(listener) + return true +} diff --git a/proto/mq.go b/proto/mq.go new file mode 100644 index 0000000..30573c2 --- /dev/null +++ b/proto/mq.go @@ -0,0 +1,24 @@ +package proto + +type FileCapture struct { + FileName string `json:"fileName"` + File string `json:"file"` + DatasetName string `json:"datasetName"` + CaptureTime int64 `json:"captureTime"` +} + +type ModelResult struct { + FileName string `json:"fileName"` + File string `json:"file"` + FileMd5 string `json:"fileMd5"` + DatasetName string `json:"datasetName"` + SubDataset string `json:"subDataset"` + Crack bool `json:"crack"` + Pothole bool `json:"pothole"` + TaskId int64 `json:"taskId"` + TaskCode string `json:"taskCode"` + ModelId int64 `json:"modelId"` + NodeId int64 `json:"nodeId"` + StartTime int64 `json:"startTime"` + FinishTime int64 `json:"finishTime"` +}