Browse Source

添加停止任务接口

yjp 1 month ago
parent
commit
da1c50ea25

+ 22 - 0
convenient/domain/task_manager/simple.go

@@ -0,0 +1,22 @@
+package task_manager
+
+import (
+	"git.sxidc.com/go-framework/baize/framework/binding"
+	"git.sxidc.com/go-framework/baize/framework/core/api"
+	"git.sxidc.com/go-framework/baize/framework/core/application"
+)
+
+// Simple Bind参数
+type Simple struct {
+	// schema
+	Schema string
+}
+
+func (simple *Simple) bind(binder *binding.Binder) {
+	// TODO 完成查询接口, 编写Info,数据库YAML
+}
+
+func Bind(app *application.App, simple *Simple) {
+	binder := binding.NewBinder(app.Api().ChooseRouter(api.RouterPrefix, ""), app.Infrastructure())
+	simple.bind(binder)
+}

+ 24 - 0
convenient/domain/task_manager/task/entity.go

@@ -1,6 +1,7 @@
 package task
 
 import (
+	"encoding/json"
 	"git.sxidc.com/go-framework/baize/framework/core/domain"
 	"git.sxidc.com/go-framework/baize/framework/core/domain/entity"
 	"github.com/pkg/errors"
@@ -10,6 +11,7 @@ const (
 	StatusCodeCreated = iota + 1
 	StatusCodeRunning
 	StatusCodeCompleted
+	StatusCodeStop
 	StatusCodeError
 )
 
@@ -17,6 +19,7 @@ const (
 	statusCreated   = "已创建"
 	statusRunning   = "运行中"
 	statusCompleted = "已完成"
+	statusStop      = "已停止"
 	statusError     = "错误"
 )
 
@@ -25,6 +28,7 @@ var (
 		StatusCodeCreated:   statusCreated,
 		StatusCodeRunning:   statusRunning,
 		StatusCodeCompleted: statusCompleted,
+		StatusCodeStop:      statusStop,
 		StatusCodeError:     statusError,
 	}
 
@@ -32,6 +36,7 @@ var (
 		statusCreated:   StatusCodeCreated,
 		statusRunning:   StatusCodeRunning,
 		statusCompleted: StatusCodeCompleted,
+		statusStop:      StatusCodeStop,
 		statusError:     StatusCodeError,
 	}
 )
@@ -80,6 +85,16 @@ func (e *Entity) GetFieldMap() map[string]string {
 	return fieldMap
 }
 
+func (e *Entity) GetMapContext() (map[string]any, error) {
+	ctx := make(map[string]any)
+	err := json.Unmarshal([]byte(e.Context), &ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	return ctx, nil
+}
+
 func (e *Entity) SetStatusCreated() {
 	e.StatusCode = StatusCodeCreated
 	e.Status = statusCreated
@@ -95,6 +110,11 @@ func (e *Entity) SetStatusCompleted() {
 	e.Status = statusCompleted
 }
 
+func (e *Entity) SetStatusStop() {
+	e.StatusCode = StatusCodeStop
+	e.Status = statusStop
+}
+
 func (e *Entity) SetStatusError(errMsg string) {
 	e.StatusCode = StatusCodeError
 	e.Status = statusError
@@ -113,6 +133,10 @@ func (e *Entity) IsStatusCompleted() bool {
 	return e.StatusCode == StatusCodeCompleted
 }
 
+func (e *Entity) IsStatusStop() bool {
+	return e.StatusCode == StatusCodeStop
+}
+
 func (e *Entity) IsStatusError() bool {
 	return e.StatusCode == StatusCodeError
 }

+ 1 - 0
convenient/domain/task_manager/task/runner.go

@@ -3,4 +3,5 @@ package task
 type Runner interface {
 	Run(ctx map[string]any) error
 	Restart(ctx map[string]any) error
+	Stop(ctx map[string]any) error
 }

+ 121 - 26
convenient/domain/task_manager/task_manager.go

@@ -3,17 +3,18 @@ package task_manager
 import (
 	"encoding/json"
 	"git.sxidc.com/go-framework/baize/convenient/domain/task_manager/task"
-	"git.sxidc.com/go-framework/baize/framework/binding"
-	"git.sxidc.com/go-framework/baize/framework/core/api"
-	"git.sxidc.com/go-framework/baize/framework/core/application"
 	"git.sxidc.com/go-framework/baize/framework/core/domain"
+	"git.sxidc.com/go-framework/baize/framework/core/domain/entity"
 	"git.sxidc.com/go-framework/baize/framework/core/infrastructure/database"
 	"git.sxidc.com/go-framework/baize/framework/core/infrastructure/database/sql"
 	"git.sxidc.com/go-framework/baize/framework/core/infrastructure/logger"
 	"git.sxidc.com/go-tools/utils/strutils"
 	"github.com/pkg/errors"
+	"sync"
 )
 
+var runnerRegister = new(sync.Map)
+
 type RunTaskParams struct {
 	Group      string
 	OperatorID string
@@ -42,10 +43,15 @@ func (params *RunTaskParams) check() error {
 	return nil
 }
 
-func RunTask(runner task.Runner, params *RunTaskParams) error {
+func RunTask(runner task.Runner, params *RunTaskParams) (string, error) {
+	err := params.check()
+	if err != nil {
+		return "", err
+	}
+
 	ctxJsonBytes, err := json.Marshal(params.Context)
 	if err != nil {
-		return err
+		return "", err
 	}
 
 	ctxJsonStr := string(ctxJsonBytes)
@@ -59,14 +65,17 @@ func RunTask(runner task.Runner, params *RunTaskParams) error {
 
 	err = createTaskDB(taskEntity, params.DBSchema, params.DBExecutor)
 	if err != nil {
-		return err
+		return "", err
 	}
 
-	go runTask(taskEntity, runner, params.DBSchema, params.DBExecutor, func(ctx map[string]any, runner task.Runner) error {
+	loaded, _ := runnerRegister.LoadOrStore(params.Group, runner)
+	loadedRunner := loaded.(task.Runner)
+
+	go runTask(taskEntity, loadedRunner, params.DBSchema, params.DBExecutor, func(ctx map[string]any, runner task.Runner) error {
 		return runner.Run(ctx)
 	})
 
-	return nil
+	return taskEntity.ID, nil
 }
 
 type RestartTaskParams struct {
@@ -92,6 +101,11 @@ func (params *RestartTaskParams) check() error {
 }
 
 func RestartTask(runner task.Runner, params *RestartTaskParams) error {
+	err := params.check()
+	if err != nil {
+		return err
+	}
+
 	runningResults, _, err := database.Query(params.DBExecutor, &sql.QueryExecuteParams{
 		TableName: domain.TableName(params.DBSchema, &task.Entity{}),
 		Conditions: sql.NewConditions().Equal(task.ColumnGroup, params.Group).
@@ -110,7 +124,10 @@ func RestartTask(runner task.Runner, params *RestartTaskParams) error {
 	}
 
 	for _, taskEntity := range taskEntities {
-		go runTask(&taskEntity, runner, params.DBSchema, params.DBExecutor, func(ctx map[string]any, runner task.Runner) error {
+		loaded, _ := runnerRegister.LoadOrStore(params.Group, runner)
+		loadedRunner := loaded.(task.Runner)
+
+		go runTask(&taskEntity, loadedRunner, params.DBSchema, params.DBExecutor, func(ctx map[string]any, runner task.Runner) error {
 			return runner.Restart(ctx)
 		})
 	}
@@ -118,6 +135,89 @@ func RestartTask(runner task.Runner, params *RestartTaskParams) error {
 	return nil
 }
 
+type StopTaskParams struct {
+	ID         string
+	DBSchema   string
+	DBExecutor database.Executor
+}
+
+func (params *StopTaskParams) check() error {
+	if strutils.IsStringEmpty(params.ID) {
+		return errors.New("没有传递任务ID")
+	}
+
+	if strutils.IsStringEmpty(params.DBSchema) {
+		return errors.New("没有传递数据库schema")
+	}
+
+	if params.DBExecutor == nil {
+		return errors.New("没有传递数据库执行器")
+	}
+
+	return nil
+}
+func StopTask(params *StopTaskParams) error {
+	err := params.check()
+	if err != nil {
+		return err
+	}
+
+	result, err := database.QueryOne(params.DBExecutor, &sql.QueryOneExecuteParams{
+		TableName: domain.TableName(params.DBSchema, &task.Entity{}),
+		Conditions: sql.NewConditions().Equal(entity.ColumnID, params.ID).
+			Equal(task.ColumnStatus, task.StatusCodeRunning),
+	})
+	if err != nil {
+		if database.IsErrorDBRecordNotExist(err) {
+			return errors.New("任务不存在")
+		}
+
+		return err
+	}
+
+	taskEntity := new(task.Entity)
+	err = sql.ParseSqlResult(result, taskEntity)
+	if err != nil {
+		return err
+	}
+
+	ctx, err := taskEntity.GetMapContext()
+	if err != nil {
+		updateErr := updateTaskStatusErrorDB(taskEntity, err.Error(), params.DBSchema, params.DBExecutor)
+		if updateErr != nil {
+			return updateErr
+		}
+
+		return err
+	}
+
+	loaded, ok := runnerRegister.Load(taskEntity.Group)
+	if !ok {
+		return errors.New("没有找到任务对应的执行器")
+	}
+
+	loadedRunner := loaded.(task.Runner)
+
+	err = database.Transaction(params.DBExecutor, func(tx database.Executor) error {
+		err := updateTaskStatusStopDBTx(taskEntity, params.DBSchema, tx)
+		if err != nil {
+			return err
+		}
+
+		err = loadedRunner.Stop(ctx)
+		if err != nil {
+			return err
+		}
+
+		return nil
+	})
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
 func createTaskDB(taskEntity *task.Entity, dbSchema string, dbExecutor database.Executor) error {
 	taskEntity.SetStatusCreated()
 
@@ -156,6 +256,17 @@ func updateTaskStatusCompleteDB(taskEntity *task.Entity, dbSchema string, dbExec
 	return nil
 }
 
+func updateTaskStatusStopDBTx(taskEntity *task.Entity, dbSchema string, tx database.Executor) error {
+	taskEntity.SetStatusStop()
+
+	err := database.UpdateEntity(tx, domain.TableName(dbSchema, taskEntity), taskEntity)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
 func updateTaskStatusErrorDB(taskEntity *task.Entity, errMsg string, dbSchema string, dbExecutor database.Executor) error {
 	taskEntity.SetStatusError(errMsg)
 
@@ -168,8 +279,7 @@ func updateTaskStatusErrorDB(taskEntity *task.Entity, errMsg string, dbSchema st
 }
 
 func runTask(taskEntity *task.Entity, runner task.Runner, dbSchema string, dbExecutor database.Executor, executeFunc func(ctx map[string]any, runner task.Runner) error) {
-	ctx := make(map[string]any)
-	err := json.Unmarshal([]byte(taskEntity.Context), &ctx)
+	ctx, err := taskEntity.GetMapContext()
 	if err != nil {
 		err = updateTaskStatusErrorDB(taskEntity, err.Error(), dbSchema, dbExecutor)
 		if err != nil {
@@ -204,18 +314,3 @@ func runTask(taskEntity *task.Entity, runner task.Runner, dbSchema string, dbExe
 		return
 	}
 }
-
-// Simple Bind参数
-type Simple struct {
-	// schema
-	Schema string
-}
-
-func (simple *Simple) bind(binder *binding.Binder) {
-	// TODO 完成查询接口, 编写Info,数据库YAML
-}
-
-func Bind(app *application.App, simple *Simple) {
-	binder := binding.NewBinder(app.Api().ChooseRouter(api.RouterPrefix, ""), app.Infrastructure())
-	simple.bind(binder)
-}