Browse Source

修改实现

yjp 9 months ago
parent
commit
10f3753ff2
4 changed files with 203 additions and 135 deletions
  1. 4 0
      client/client.go
  2. 0 115
      client/watch_workflow.go
  3. 191 0
      client/workflow_watcher.go
  4. 8 20
      test/workflow_test.go

+ 4 - 0
client/client.go

@@ -21,6 +21,7 @@ type Client struct {
 	client                  apiclient.Client
 	workflowTemplateService workflowtemplate.WorkflowTemplateServiceClient
 	workflowService         workflow.WorkflowServiceClient
+	workflowWatcherManager  *workflowWatcherManager
 }
 
 func NewClient(kubeConfigEnv string, opts ...Option) (*Client, error) {
@@ -62,6 +63,7 @@ func NewClient(kubeConfigEnv string, opts ...Option) (*Client, error) {
 		client:                  apiClient,
 		workflowTemplateService: workflowTemplateService,
 		workflowService:         apiClient.NewWorkflowServiceClient(),
+		workflowWatcherManager:  newWorkflowWatcherManager(),
 	}, nil
 }
 
@@ -70,6 +72,8 @@ func Destroy(c *Client) {
 		return
 	}
 
+	c.unregisterAllWorkflowWatchers()
+
 	c.ctx.Done()
 	c.client = nil
 	c.workflowTemplateService = nil

+ 0 - 115
client/watch_workflow.go

@@ -1,115 +0,0 @@
-package client
-
-import (
-	"context"
-	"fmt"
-	"github.com/argoproj/argo-workflows/v3/pkg/apiclient/workflow"
-	"github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1"
-	"github.com/argoproj/argo-workflows/v3/util"
-	"github.com/pkg/errors"
-	"io"
-	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
-	"time"
-)
-
-type WatchWorkflowCallback func(wf *v1alpha1.Workflow)
-
-type WatchWorkflowToken struct {
-	workflowChan        chan *v1alpha1.Workflow
-	doneChan            chan any
-	streamReceiveCancel context.CancelFunc
-}
-
-func (token *WatchWorkflowToken) Done() {
-	token.streamReceiveCancel()
-
-	token.doneChan <- nil
-	close(token.doneChan)
-	token.doneChan = nil
-
-	close(token.workflowChan)
-	token.workflowChan = nil
-}
-
-type WatchWorkflowParams struct {
-	Namespace string
-	Name      string
-}
-
-// WatchWorkflow 监听工作流
-func (c *Client) WatchWorkflow(params WatchWorkflowParams, callback WatchWorkflowCallback) (*WatchWorkflowToken, error) {
-	req := &workflow.WatchWorkflowsRequest{
-		Namespace: params.Namespace,
-		ListOptions: &metav1.ListOptions{
-			FieldSelector:   util.GenerateFieldSelectorFromWorkflowName(params.Name),
-			ResourceVersion: "0",
-		},
-	}
-
-	cancelCtx, cancel := context.WithCancel(c.ctx)
-	stream, err := c.workflowService.WatchWorkflows(cancelCtx, req)
-	if err != nil {
-		cancel()
-		return nil, errors.New(err.Error())
-	}
-
-	token := &WatchWorkflowToken{
-		workflowChan:        make(chan *v1alpha1.Workflow),
-		doneChan:            make(chan any),
-		streamReceiveCancel: cancel,
-	}
-
-	go func() {
-		for {
-			select {
-			case <-token.doneChan:
-				return
-			case wf := <-token.workflowChan:
-				if wf == nil {
-					continue
-				}
-
-				if callback != nil {
-					callback(wf)
-				}
-			}
-		}
-	}()
-
-	go func() {
-		for {
-			event, err := stream.Recv()
-			if err != nil {
-				if err == io.EOF {
-					if stream.Context().Err().Error() == "context canceled" {
-						return
-					}
-
-					cancelCtx, cancel := context.WithCancel(c.ctx)
-					stream, err = c.workflowService.WatchWorkflows(cancelCtx, req)
-					if err != nil {
-						cancel()
-						fmt.Printf("%v\n", errors.New(err.Error()))
-						time.Sleep(5 * time.Second)
-						continue
-					}
-
-					token.streamReceiveCancel = cancel
-
-					continue
-				}
-
-				fmt.Printf("%v\n", errors.New(err.Error()))
-				return
-			}
-
-			if event == nil {
-				continue
-			}
-
-			token.workflowChan <- event.Object
-		}
-	}()
-
-	return token, nil
-}

+ 191 - 0
client/workflow_watcher.go

@@ -0,0 +1,191 @@
+package client
+
+import (
+	"context"
+	"fmt"
+	"github.com/argoproj/argo-workflows/v3/pkg/apiclient/workflow"
+	"github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1"
+	"github.com/argoproj/argo-workflows/v3/util"
+	"github.com/pkg/errors"
+	"io"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"strings"
+	"sync"
+	"time"
+)
+
+type WatcherCallback func(wf *v1alpha1.Workflow)
+
+type workflowWatcherManager struct {
+	watcherMapMutex *sync.Mutex
+	watcherMap      map[string]*workflowWatcher
+}
+
+func newWorkflowWatcherManager() *workflowWatcherManager {
+	return &workflowWatcherManager{
+		watcherMapMutex: &sync.Mutex{},
+		watcherMap:      make(map[string]*workflowWatcher),
+	}
+}
+
+func (manager *workflowWatcherManager) addWatcher(namespace string, name string, watcher *workflowWatcher) {
+	manager.watcherMapMutex.Lock()
+	defer manager.watcherMapMutex.Unlock()
+
+	manager.watcherMap[manager.formKey(namespace, name)] = watcher
+}
+
+func (manager *workflowWatcherManager) removeWatcher(namespace string, name string) (*workflowWatcher, bool) {
+	manager.watcherMapMutex.Lock()
+	defer manager.watcherMapMutex.Unlock()
+
+	key := manager.formKey(namespace, name)
+
+	value, ok := manager.watcherMap[key]
+	if ok {
+		delete(manager.watcherMap, key)
+	}
+
+	return value, ok
+}
+
+func (manager *workflowWatcherManager) clearWatchers(f func(namespace string, name string, watcher *workflowWatcher) bool) {
+	manager.watcherMapMutex.Lock()
+	defer manager.watcherMapMutex.Unlock()
+
+	for key, value := range manager.watcherMap {
+		namespace, name := manager.splitKey(key)
+		notStop := f(namespace, name, value)
+		if !notStop {
+			break
+		}
+	}
+
+	manager.watcherMap = make(map[string]*workflowWatcher)
+}
+
+func (manager *workflowWatcherManager) formKey(namespace string, name string) string {
+	return strings.Join([]string{namespace, name}, "::")
+}
+
+func (manager *workflowWatcherManager) splitKey(key string) (namespace string, name string) {
+	keyParts := strings.Split(key, "::")
+	namespace = keyParts[0]
+	name = keyParts[1]
+	return
+}
+
+type workflowWatcher struct {
+	workflowChan        chan *v1alpha1.Workflow
+	doneChan            chan any
+	streamReceiveCancel context.CancelFunc
+}
+
+func (watcher *workflowWatcher) close() {
+	// 停止接收数据协程
+	watcher.streamReceiveCancel()
+
+	// 停止处理数据协程
+	watcher.doneChan <- nil
+	close(watcher.doneChan)
+	watcher.doneChan = nil
+
+	// 关闭数据通道
+	close(watcher.workflowChan)
+	watcher.workflowChan = nil
+}
+
+func (c *Client) RegisterWorkflowWatcher(namespace string, name string, callback WatcherCallback) error {
+	req := &workflow.WatchWorkflowsRequest{
+		Namespace: namespace,
+		ListOptions: &metav1.ListOptions{
+			FieldSelector:   util.GenerateFieldSelectorFromWorkflowName(name),
+			ResourceVersion: "0",
+		},
+	}
+
+	cancelCtx, cancel := context.WithCancel(c.ctx)
+	stream, err := c.workflowService.WatchWorkflows(cancelCtx, req)
+	if err != nil {
+		cancel()
+		return errors.New(err.Error())
+	}
+
+	watcher := &workflowWatcher{
+		workflowChan:        make(chan *v1alpha1.Workflow),
+		doneChan:            make(chan any),
+		streamReceiveCancel: cancel,
+	}
+
+	go func() {
+		for {
+			select {
+			case <-watcher.doneChan:
+				return
+			case wf := <-watcher.workflowChan:
+				if wf == nil {
+					continue
+				}
+
+				if callback != nil {
+					callback(wf)
+				}
+			}
+		}
+	}()
+
+	go func() {
+		for {
+			event, err := stream.Recv()
+			if err != nil {
+				if err == io.EOF {
+					if stream.Context().Err().Error() == "context canceled" {
+						return
+					}
+
+					cancelCtx, cancel := context.WithCancel(c.ctx)
+					stream, err = c.workflowService.WatchWorkflows(cancelCtx, req)
+					if err != nil {
+						cancel()
+						fmt.Printf("%v\n", errors.New(err.Error()))
+						time.Sleep(5 * time.Second)
+						continue
+					}
+
+					watcher.streamReceiveCancel = cancel
+
+					continue
+				}
+
+				fmt.Printf("%v\n", errors.New(err.Error()))
+				return
+			}
+
+			if event == nil {
+				continue
+			}
+
+			watcher.workflowChan <- event.Object
+		}
+	}()
+
+	c.workflowWatcherManager.addWatcher(namespace, name, watcher)
+
+	return nil
+}
+
+func (c *Client) UnregisterWorkflowWatcher(namespace string, name string) {
+	watcher, loaded := c.workflowWatcherManager.removeWatcher(namespace, name)
+	if !loaded {
+		return
+	}
+
+	watcher.close()
+}
+
+func (c *Client) unregisterAllWorkflowWatchers() {
+	c.workflowWatcherManager.clearWatchers(func(namespace string, name string, watcher *workflowWatcher) bool {
+		watcher.close()
+		return true
+	})
+}

+ 8 - 20
test/workflow_test.go

@@ -209,10 +209,7 @@ func TestRetryWorkflow(t *testing.T) {
 	wg := sync.WaitGroup{}
 	wg.Add(1)
 
-	token, err := argo.GetInstance().WatchWorkflow(client.WatchWorkflowParams{
-		Namespace: namespace,
-		Name:      workflowName,
-	}, func(wf *v1alpha1.Workflow) {
+	err = argo.GetInstance().RegisterWorkflowWatcher(namespace, workflowName, func(wf *v1alpha1.Workflow) {
 		if terminalCalled {
 			if wf.Status.Phase == "Running" {
 				return
@@ -225,7 +222,7 @@ func TestRetryWorkflow(t *testing.T) {
 		t.Fatalf("%+v\n", err)
 	}
 
-	defer token.Done()
+	defer argo.GetInstance().UnregisterWorkflowWatcher(namespace, workflowName)
 
 	err = argo.GetInstance().TerminateWorkflow(client.TerminateWorkflowParams{
 		Namespace: namespace,
@@ -278,10 +275,7 @@ func TestStopWorkflow(t *testing.T) {
 	wg := sync.WaitGroup{}
 	wg.Add(1)
 
-	token, err := argo.GetInstance().WatchWorkflow(client.WatchWorkflowParams{
-		Namespace: namespace,
-		Name:      workflowName,
-	}, func(wf *v1alpha1.Workflow) {
+	err = argo.GetInstance().RegisterWorkflowWatcher(namespace, workflowName, func(wf *v1alpha1.Workflow) {
 		if stopCalled {
 			if wf.Status.Phase == "Running" {
 				return
@@ -294,7 +288,7 @@ func TestStopWorkflow(t *testing.T) {
 		t.Fatalf("%+v\n", err)
 	}
 
-	defer token.Done()
+	defer argo.GetInstance().UnregisterWorkflowWatcher(namespace, workflowName)
 
 	err = argo.GetInstance().StopWorkflow(client.StopWorkflowParams{
 		Namespace: namespace,
@@ -339,10 +333,7 @@ func TestTerminateWorkflow(t *testing.T) {
 	wg := sync.WaitGroup{}
 	wg.Add(1)
 
-	token, err := argo.GetInstance().WatchWorkflow(client.WatchWorkflowParams{
-		Namespace: namespace,
-		Name:      workflowName,
-	}, func(wf *v1alpha1.Workflow) {
+	err = argo.GetInstance().RegisterWorkflowWatcher(namespace, workflowName, func(wf *v1alpha1.Workflow) {
 		if terminalCalled {
 			if wf.Status.Phase == "Running" {
 				return
@@ -355,7 +346,7 @@ func TestTerminateWorkflow(t *testing.T) {
 		t.Fatalf("%+v\n", err)
 	}
 
-	defer token.Done()
+	defer argo.GetInstance().UnregisterWorkflowWatcher(namespace, workflowName)
 
 	err = argo.GetInstance().TerminateWorkflow(client.TerminateWorkflowParams{
 		Namespace: namespace,
@@ -479,10 +470,7 @@ func TestSuspendAndResumeWorkflow(t *testing.T) {
 	wg1 := sync.WaitGroup{}
 	wg1.Add(1)
 
-	token, err := argo.GetInstance().WatchWorkflow(client.WatchWorkflowParams{
-		Namespace: namespace,
-		Name:      workflowName,
-	}, func(wf *v1alpha1.Workflow) {
+	err = argo.GetInstance().RegisterWorkflowWatcher(namespace, workflowName, func(wf *v1alpha1.Workflow) {
 		if suspendCalled {
 			if wf.Status.Phase != "Running" {
 				return
@@ -499,7 +487,7 @@ func TestSuspendAndResumeWorkflow(t *testing.T) {
 		t.Fatalf("%+v\n", err)
 	}
 
-	defer token.Done()
+	defer argo.GetInstance().UnregisterWorkflowWatcher(namespace, workflowName)
 
 	err = argo.GetInstance().SuspendWorkflow(client.SuspendWorkflowParams{
 		Namespace: namespace,