yjp %!s(int64=2) %!d(string=hai) anos
pai
achega
069fdb54c6

+ 1 - 1
mqtt_binding/mqtt_binding.go

@@ -93,7 +93,7 @@ func (item *BindItem[I, O]) bind(r *router.Router, handlers ...router.Handler) {
 	})
 
 	// 所有的函数加入到执行链中
-	routerItem, err := router.NewItem(r.Group+item.Topic, item.Qos, item.Retained, handlers)
+	routerItem, err := router.NewItem(r.Group+item.Topic, item.Qos, handlers)
 	if err != nil {
 		panic("创建路由条目失败: " + err.Error())
 		return

+ 95 - 49
mqtt_binding/mqtt_client/mqtt_client.go

@@ -40,11 +40,10 @@ func (opt *MqttClientOptions) check() error {
 }
 
 type MqttClient struct {
-	client        mqtt.Client
-	clientOptions *mqtt.ClientOptions
+	client mqtt.Client
 
 	routersMutex *sync.Mutex
-	routers      []router.Router
+	routers      []*router.Router
 }
 
 func NewMqttClient(opts *MqttClientOptions) (*MqttClient, error) {
@@ -57,22 +56,38 @@ func NewMqttClient(opts *MqttClientOptions) (*MqttClient, error) {
 		return nil, err
 	}
 
+	mqttClient := &MqttClient{
+		routersMutex: &sync.Mutex{},
+		routers:      make([]*router.Router, 0),
+	}
+
 	mqttOptions := mqtt.NewClientOptions().
 		SetAutoReconnect(true).
 		SetUsername(opts.UserName).
 		SetPassword(opts.Password).
 		AddBroker(opts.Address).
 		SetClientID(opts.ClientID).
-		SetKeepAlive(opts.KeepAliveSec).
-		SetPingTimeout(opts.PingTimeoutSec).
-		SetWill(opts.ClientID+"/will", "dead", 2, true)
-
-	return &MqttClient{
-		client:        mqtt.NewClient(mqttOptions),
-		clientOptions: mqttOptions,
-		routersMutex:  &sync.Mutex{},
-		routers:       make([]router.Router, 0),
-	}, nil
+		SetKeepAlive(opts.KeepAliveSec*time.Second).
+		SetPingTimeout(opts.PingTimeoutSec*time.Second).
+		SetWill(opts.ClientID+"/will", "dead", 2, false).
+		SetOnConnectHandler(func(client mqtt.Client) {
+			err := mqttClient.onConnect()
+			if err != nil {
+				fmt.Println(err)
+				return
+			}
+		}).
+		SetConnectionLostHandler(func(client mqtt.Client, _ error) {
+			err := mqttClient.onConnectLost()
+			if err != nil {
+				fmt.Println(err)
+				return
+			}
+		})
+
+	mqttClient.client = mqtt.NewClient(mqttOptions)
+
+	return mqttClient, nil
 }
 
 func DestroyMqttClient(c *MqttClient) {
@@ -81,7 +96,7 @@ func DestroyMqttClient(c *MqttClient) {
 
 		c.routersMutex.Lock()
 		for _, r := range c.routers {
-			router.DestroyRouter(&r)
+			router.DestroyRouter(r)
 		}
 		c.routers = nil
 		c.routersMutex.Unlock()
@@ -91,33 +106,6 @@ func DestroyMqttClient(c *MqttClient) {
 }
 
 func (c *MqttClient) Connect() error {
-	c.clientOptions.SetOnConnectHandler(func(client mqtt.Client) {
-		c.routersMutex.Lock()
-		defer c.routersMutex.Unlock()
-
-		err := c.rangeRouters(func(r *router.Router) error {
-			err := r.RangeItem(func(item router.Item) error {
-				token := c.client.Subscribe(item.Topic, item.Qos, func(client mqtt.Client, message mqtt.Message) {
-					item.CallHandlers(message.Payload())
-				})
-				if token.Wait(); token.Error() != nil {
-					return token.Error()
-				}
-
-				return nil
-			})
-			if err != nil {
-				return errors.New("SetOnConnectHandler订阅失败: " + err.Error())
-			}
-
-			return nil
-		})
-		if err != nil {
-			fmt.Println(err)
-			return
-		}
-	})
-
 	token := c.client.Connect()
 	if token.Wait(); token.Error() != nil {
 		return token.Error()
@@ -131,7 +119,7 @@ func (c *MqttClient) Disconnect() {
 }
 
 func (c *MqttClient) GetRouter(group string, handlers []router.Handler) *router.Router {
-	r := router.NewRouter(group, handlers, func(item router.Item) error {
+	r := router.NewRouter(group, handlers, func(item *router.Item) error {
 		for {
 			if c.client.IsConnected() {
 				break
@@ -140,11 +128,18 @@ func (c *MqttClient) GetRouter(group string, handlers []router.Handler) *router.
 			time.Sleep(1 * time.Second)
 		}
 
-		token := c.client.Subscribe(item.Topic, item.Qos, func(client mqtt.Client, message mqtt.Message) {
-			item.CallHandlers(message.Payload())
+		err := item.DoIfUnSubscribe(func() error {
+			token := c.client.Subscribe(item.Topic, item.Qos, func(client mqtt.Client, message mqtt.Message) {
+				item.CallHandlers(message.Payload())
+			})
+			if token.Wait(); token.Error() != nil {
+				return token.Error()
+			}
+
+			return nil
 		})
-		if token.Wait(); token.Error() != nil {
-			return token.Error()
+		if err != nil {
+			return err
 		}
 
 		return nil
@@ -156,7 +151,7 @@ func (c *MqttClient) GetRouter(group string, handlers []router.Handler) *router.
 }
 
 func (c *MqttClient) Response(item *router.Item, data []byte) error {
-	token := c.client.Publish(item.Topic+"/reply", item.Qos, item.ResponseRetained, data)
+	token := c.client.Publish(item.Topic+"/reply", item.Qos, false, data)
 	if token.Wait(); token.Error() != nil {
 		return token.Error()
 	}
@@ -164,11 +159,62 @@ func (c *MqttClient) Response(item *router.Item, data []byte) error {
 	return nil
 }
 
+func (c *MqttClient) onConnect() error {
+	err := c.rangeRouters(func(r *router.Router) error {
+		err := r.RangeItem(func(item *router.Item) error {
+			err := item.DoIfUnSubscribe(func() error {
+				token := c.client.Subscribe(item.Topic, item.Qos, func(client mqtt.Client, message mqtt.Message) {
+					item.CallHandlers(message.Payload())
+				})
+				if token.Wait(); token.Error() != nil {
+					return token.Error()
+				}
+
+				return nil
+			})
+			if err != nil {
+				return err
+			}
+
+			return nil
+		})
+		if err != nil {
+			return errors.New("SetOnConnectHandler订阅失败: " + err.Error())
+		}
+
+		return nil
+	})
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (c *MqttClient) onConnectLost() error {
+	err := c.rangeRouters(func(r *router.Router) error {
+		err := r.RangeItem(func(item *router.Item) error {
+			item.SetUnSubscribe()
+			return nil
+		})
+		if err != nil {
+			return err
+		}
+
+		return nil
+	})
+	if err != nil {
+		return errors.New("SetOnConnectHandler订阅失败: " + err.Error())
+	}
+
+	return nil
+}
+
 func (c *MqttClient) addRouter(router *router.Router) {
 	c.routersMutex.Lock()
 	defer c.routersMutex.Unlock()
 
-	c.routers = append(c.routers, *router)
+	c.routers = append(c.routers, router)
 }
 
 func (c *MqttClient) rangeRouters(rangeFunc func(router *router.Router) error) error {
@@ -176,7 +222,7 @@ func (c *MqttClient) rangeRouters(rangeFunc func(router *router.Router) error) e
 	defer c.routersMutex.Unlock()
 
 	for _, r := range c.routers {
-		err := rangeFunc(&r)
+		err := rangeFunc(r)
 		if err != nil {
 			return err
 		}

+ 41 - 14
mqtt_binding/mqtt_client/router/router.go

@@ -7,13 +7,13 @@ import (
 )
 
 type Handler func(item *Item, data []byte)
-type OnAddItemFunc func(item Item) error
+type OnAddItemFunc func(item *Item) error
 
 type Router struct {
 	Group string
 
 	itemsMutex    *sync.Mutex
-	items         []Item
+	items         []*Item
 	handlers      []Handler
 	onAddItemFunc OnAddItemFunc
 }
@@ -22,7 +22,7 @@ func NewRouter(group string, handlers []Handler, onAddItemFunc OnAddItemFunc) *R
 	return &Router{
 		Group:         group,
 		itemsMutex:    &sync.Mutex{},
-		items:         make([]Item, 0),
+		items:         make([]*Item, 0),
 		onAddItemFunc: onAddItemFunc,
 		handlers:      handlers,
 	}
@@ -35,7 +35,7 @@ func DestroyRouter(router *Router) {
 
 	router.itemsMutex.Lock()
 	for _, item := range router.items {
-		DestroyItem(&item)
+		DestroyItem(item)
 	}
 	router.items = nil
 	router.itemsMutex.Unlock()
@@ -51,21 +51,21 @@ func (router *Router) AddItem(item *Item) error {
 	router.itemsMutex.Lock()
 	defer router.itemsMutex.Unlock()
 
-	item.handlers = append(item.handlers, router.handlers...)
-
-	router.items = append(router.items, *item)
+	item.handlers = append(router.handlers, item.handlers...)
 
 	if router.onAddItemFunc != nil {
-		err := router.onAddItemFunc(*item)
+		err := router.onAddItemFunc(item)
 		if err != nil {
 			return err
 		}
 	}
 
+	router.items = append(router.items, item)
+
 	return nil
 }
 
-func (router *Router) RangeItem(rangeFunc func(item Item) error) error {
+func (router *Router) RangeItem(rangeFunc func(item *Item) error) error {
 	if rangeFunc == nil {
 		return nil
 	}
@@ -84,16 +84,18 @@ func (router *Router) RangeItem(rangeFunc func(item Item) error) error {
 }
 
 type Item struct {
-	Topic            string
-	Qos              byte
-	ResponseRetained bool
+	Topic string
+	Qos   byte
+
+	subscribedMutex *sync.Mutex
+	subscribed      bool
 
 	handlers            []Handler
 	currentHandlerIndex int
 	currentData         []byte
 }
 
-func NewItem(topic string, qos byte, responseRetained bool, handlers []Handler) (*Item, error) {
+func NewItem(topic string, qos byte, handlers []Handler) (*Item, error) {
 	if utils.IsStringEmpty(topic) {
 		return nil, errors.New("没有传递主题")
 	}
@@ -101,7 +103,8 @@ func NewItem(topic string, qos byte, responseRetained bool, handlers []Handler)
 	return &Item{
 		Topic:               topic,
 		Qos:                 qos,
-		ResponseRetained:    responseRetained,
+		subscribedMutex:     &sync.Mutex{},
+		subscribed:          false,
 		handlers:            handlers,
 		currentHandlerIndex: 0,
 		currentData:         make([]byte, 0),
@@ -132,3 +135,27 @@ func (item *Item) Next() {
 func (item *Item) GetData() []byte {
 	return item.currentData
 }
+
+func (item *Item) SetUnSubscribe() {
+	item.subscribedMutex.Lock()
+	defer item.subscribedMutex.Unlock()
+
+	item.subscribed = false
+}
+
+func (item *Item) DoIfUnSubscribe(doFunc func() error) error {
+	item.subscribedMutex.Lock()
+	defer item.subscribedMutex.Unlock()
+
+	if item.subscribed {
+		return nil
+	}
+
+	err := doFunc()
+	if err != nil {
+		return err
+	}
+
+	item.subscribed = true
+	return nil
+}

+ 29 - 4
mqtt_binding_test.go

@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"git.sxidc.com/go-tools/api_binding/mqtt_binding"
 	"git.sxidc.com/go-tools/api_binding/mqtt_binding/mqtt_client"
+	"git.sxidc.com/go-tools/api_binding/mqtt_binding/mqtt_client/router"
 	"git.sxidc.com/go-tools/api_binding/mqtt_binding/response"
 	mqtt "github.com/eclipse/paho.mqtt.golang"
 	"sync"
@@ -27,20 +28,34 @@ func TestMqttBinding(t *testing.T) {
 
 	defer mqtt_binding.Destroy()
 
-	testBinding := mqtt_binding.NewBinding("test")
-	mqtt_binding.Bind(testBinding, &mqtt_binding.BindItem[any, map[string]interface{}]{
+	testBinding := mqtt_binding.NewBinding("test", func(item *router.Item, data []byte) {
+		fmt.Println("Global Middleware!!!")
+		item.Next()
+	})
+	mqtt_binding.Bind(testBinding, &mqtt_binding.BindItem[struct {
+		Time string `json:"time"`
+	}, map[string]interface{}]{
 		Topic:        "/test-topic",
 		Qos:          2,
 		Retained:     true,
 		ResponseFunc: response.SendMapResponse,
-		BusinessFunc: func(c *mqtt_client.MqttClient, inputModel any) (map[string]interface{}, error) {
+		BusinessFunc: func(c *mqtt_client.MqttClient, inputModel struct {
+			Time string `json:"time"`
+		}) (map[string]interface{}, error) {
+			fmt.Printf("Received: %v\n", inputModel)
+
 			return map[string]interface{}{
 				"result": "pong",
 			}, nil
 		},
 		OptionalBindingFunc: nil,
+	}, func(item *router.Item, data []byte) {
+		fmt.Println("Binding Middleware!!!")
+		item.Next()
 	})
 
+	time.Sleep(10 * time.Minute)
+
 	wg := sync.WaitGroup{}
 	wg.Add(1)
 
@@ -79,7 +94,17 @@ func TestMqttBinding(t *testing.T) {
 					return
 				}
 
-				token = client.Publish("test/test-topic", 2, true, "test")
+				sendMap := map[string]any{
+					"time": time.Now().Format(time.DateTime),
+				}
+
+				sendJson, err := json.Marshal(sendMap)
+				if token.Wait(); token.Error() != nil {
+					fmt.Println(err)
+					return
+				}
+
+				token = client.Publish("test/test-topic", 2, false, sendJson)
 				if token.Wait(); token.Error() != nil {
 					fmt.Println(token.Error())
 					return