|
@@ -0,0 +1,532 @@
|
|
|
+package mqtt_client
|
|
|
+
|
|
|
+import (
|
|
|
+ "encoding/json"
|
|
|
+ "errors"
|
|
|
+ "git.sxidc.com/go-tools/utils/strutils"
|
|
|
+ "git.sxidc.com/service-supports/fslog"
|
|
|
+ mqtt "github.com/eclipse/paho.mqtt.golang"
|
|
|
+ "log"
|
|
|
+ "sync"
|
|
|
+ "time"
|
|
|
+)
|
|
|
+
|
|
|
+var (
|
|
|
+ ErrMqttMessageIgnore = errors.New("mqtt消息忽略")
|
|
|
+)
|
|
|
+
|
|
|
+type MessageHandler func(client *Client, token *SubscribeToken, topic string, data []byte) error
|
|
|
+
|
|
|
+type ClientOptions struct {
|
|
|
+ UserName string
|
|
|
+ Password string
|
|
|
+ Address string
|
|
|
+ ClientID string
|
|
|
+ KeepAliveSec int64
|
|
|
+ PingTimeoutSec int64
|
|
|
+ WriteTimeoutSec int64
|
|
|
+}
|
|
|
+
|
|
|
+func (opts *ClientOptions) check() error {
|
|
|
+ if strutils.IsStringEmpty(opts.UserName) {
|
|
|
+ return errors.New("必须传递用户名")
|
|
|
+ }
|
|
|
+
|
|
|
+ if strutils.IsStringEmpty(opts.Password) {
|
|
|
+ return errors.New("必须传递密码")
|
|
|
+ }
|
|
|
+
|
|
|
+ if strutils.IsStringEmpty(opts.Address) {
|
|
|
+ return errors.New("必须传递地址")
|
|
|
+ }
|
|
|
+
|
|
|
+ if strutils.IsStringEmpty(opts.ClientID) {
|
|
|
+ return errors.New("必须传递客户端ID")
|
|
|
+ }
|
|
|
+
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+type SubscribeToken struct {
|
|
|
+ messageHandler MessageHandler
|
|
|
+ successHandleCount int
|
|
|
+ handleCount int
|
|
|
+}
|
|
|
+
|
|
|
+func (token *SubscribeToken) SuccessHandleCount() int {
|
|
|
+ return token.successHandleCount
|
|
|
+}
|
|
|
+
|
|
|
+func (token *SubscribeToken) HandleCount() int {
|
|
|
+ return token.handleCount
|
|
|
+}
|
|
|
+
|
|
|
+type subscribeTopic struct {
|
|
|
+ topic string
|
|
|
+ tokens []*SubscribeToken
|
|
|
+}
|
|
|
+
|
|
|
+type Client struct {
|
|
|
+ client mqtt.Client
|
|
|
+ opts *ClientOptions
|
|
|
+
|
|
|
+ topicsMutex *sync.Mutex
|
|
|
+ topics []*subscribeTopic
|
|
|
+
|
|
|
+ publishAndReceiveTopicMapMutex sync.Mutex
|
|
|
+ publishAndReceiveTopicMap map[string]chan any
|
|
|
+}
|
|
|
+
|
|
|
+func New(opts *ClientOptions) (*Client, error) {
|
|
|
+ if opts == nil {
|
|
|
+ return nil, errors.New("必须传递参数")
|
|
|
+ }
|
|
|
+
|
|
|
+ err := opts.check()
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ if opts.WriteTimeoutSec == 0 {
|
|
|
+ opts.WriteTimeoutSec = 60
|
|
|
+ }
|
|
|
+
|
|
|
+ client := &Client{
|
|
|
+ opts: opts,
|
|
|
+ topicsMutex: &sync.Mutex{},
|
|
|
+ topics: make([]*subscribeTopic, 0),
|
|
|
+ publishAndReceiveTopicMap: make(map[string]chan any),
|
|
|
+ }
|
|
|
+
|
|
|
+ mqttClient := mqtt.NewClient(mqtt.NewClientOptions().
|
|
|
+ SetAutoReconnect(true).
|
|
|
+ SetUsername(opts.UserName).
|
|
|
+ SetPassword(opts.Password).
|
|
|
+ AddBroker(opts.Address).
|
|
|
+ SetClientID(opts.ClientID).
|
|
|
+ SetKeepAlive(time.Duration(opts.KeepAliveSec)*time.Second).
|
|
|
+ SetPingTimeout(time.Duration(opts.PingTimeoutSec)*time.Second).
|
|
|
+ SetWriteTimeout(time.Duration(opts.WriteTimeoutSec)*time.Second).
|
|
|
+ SetOrderMatters(false).
|
|
|
+ SetWill(opts.ClientID+"/will", "dead", 2, true).
|
|
|
+ SetOnConnectHandler(func(mqttClient mqtt.Client) {
|
|
|
+ err := client.subscribeAll()
|
|
|
+ if err != nil {
|
|
|
+ fslog.Error("SetOnConnectHandler: " + err.Error())
|
|
|
+ return
|
|
|
+ }
|
|
|
+ }))
|
|
|
+
|
|
|
+ token := mqttClient.Connect()
|
|
|
+ if !token.WaitTimeout(time.Duration(opts.WriteTimeoutSec) * time.Second) {
|
|
|
+ return nil, errors.New("连接超时")
|
|
|
+ }
|
|
|
+
|
|
|
+ if token.Error() != nil {
|
|
|
+ return nil, token.Error()
|
|
|
+ }
|
|
|
+
|
|
|
+ client.client = mqttClient
|
|
|
+
|
|
|
+ return client, nil
|
|
|
+}
|
|
|
+
|
|
|
+func Destroy(client *Client) {
|
|
|
+ client.client.Disconnect(250)
|
|
|
+ client.topicsMutex.Lock()
|
|
|
+
|
|
|
+ for _, subscribedTopic := range client.topics {
|
|
|
+ token := client.client.Unsubscribe(subscribedTopic.topic)
|
|
|
+ if !token.WaitTimeout(time.Duration(client.opts.WriteTimeoutSec) * time.Second) {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ if token.Error() != nil {
|
|
|
+ fslog.Error(token.Error())
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ client.topics = nil
|
|
|
+
|
|
|
+ client.topicsMutex.Unlock()
|
|
|
+
|
|
|
+ client = nil
|
|
|
+}
|
|
|
+
|
|
|
+func (client *Client) Publish(topic string, qos byte, retained bool, payload any) error {
|
|
|
+ if strutils.IsStringEmpty(topic) {
|
|
|
+ return errors.New("没有传递发布主题")
|
|
|
+ }
|
|
|
+
|
|
|
+ if payload == nil {
|
|
|
+ return errors.New("发布的payload不能为nil")
|
|
|
+ }
|
|
|
+
|
|
|
+ client.publish(topic, qos, retained, payload)
|
|
|
+
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func (client *Client) Subscribe(topic string, handlerFunc MessageHandler) (*SubscribeToken, error) {
|
|
|
+ if strutils.IsStringEmpty(topic) {
|
|
|
+ return nil, errors.New("没有传递订阅主题")
|
|
|
+ }
|
|
|
+
|
|
|
+ if handlerFunc == nil {
|
|
|
+ return nil, errors.New("必须传递处理函数")
|
|
|
+ }
|
|
|
+
|
|
|
+ return client.subscribe(topic, handlerFunc)
|
|
|
+}
|
|
|
+
|
|
|
+func (client *Client) Unsubscribe(topic string, token *SubscribeToken) error {
|
|
|
+ if strutils.IsStringEmpty(topic) {
|
|
|
+ return errors.New("没有传递取消订阅主题")
|
|
|
+ }
|
|
|
+
|
|
|
+ return client.unsubscribe(topic, token)
|
|
|
+}
|
|
|
+
|
|
|
+type PublishAndReceiveReplyParams struct {
|
|
|
+ Topic string
|
|
|
+ ReplyTopic string
|
|
|
+ PublishData []byte
|
|
|
+ RepublishDurationSec int64
|
|
|
+ TryTimes int
|
|
|
+ StopBefore bool
|
|
|
+}
|
|
|
+
|
|
|
+func (params *PublishAndReceiveReplyParams) Check() error {
|
|
|
+ if strutils.IsStringEmpty(params.Topic) {
|
|
|
+ return errors.New("没有传递订阅主题")
|
|
|
+ }
|
|
|
+
|
|
|
+ if strutils.IsStringEmpty(params.ReplyTopic) {
|
|
|
+ return errors.New("没有传递响应订阅主题")
|
|
|
+ }
|
|
|
+
|
|
|
+ if params.PublishData == nil {
|
|
|
+ return errors.New("发布的数据不能为nil")
|
|
|
+ }
|
|
|
+
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+type msgResponse struct {
|
|
|
+ Success bool `json:"success"`
|
|
|
+ Msg string `json:"msg"`
|
|
|
+}
|
|
|
+
|
|
|
+func (client *Client) PublishAndReceiveReplyMsgResponse(params *PublishAndReceiveReplyParams) error {
|
|
|
+ err := params.Check()
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ err = client.publishAndReceiveReply(params, func(payload []byte) error {
|
|
|
+ resp := new(msgResponse)
|
|
|
+ err := json.Unmarshal(payload, resp)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ if !resp.Success {
|
|
|
+ return errors.New(resp.Msg)
|
|
|
+ }
|
|
|
+
|
|
|
+ return nil
|
|
|
+ })
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func (client *Client) PublishAndReceiveReply(params *PublishAndReceiveReplyParams, payloadDealFunc func(payload []byte) error) error {
|
|
|
+ err := params.Check()
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ return client.publishAndReceiveReply(params, payloadDealFunc)
|
|
|
+}
|
|
|
+
|
|
|
+func (client *Client) publish(topic string, qos byte, retained bool, payload any) {
|
|
|
+ client.waitConnected()
|
|
|
+
|
|
|
+ token := client.client.Publish(topic, qos, retained, payload)
|
|
|
+ if !token.WaitTimeout(time.Duration(client.opts.WriteTimeoutSec) * time.Second) {
|
|
|
+ log.Println("发布超时")
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ if token.Error() != nil {
|
|
|
+ log.Println(token.Error())
|
|
|
+ return
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (client *Client) subscribe(topic string, handlerFunc MessageHandler) (*SubscribeToken, error) {
|
|
|
+ client.waitConnected()
|
|
|
+
|
|
|
+ return client.addSubscribedTopic(topic, handlerFunc, func(subscribedTopic *subscribeTopic) error {
|
|
|
+ return client.subscribeMqtt(subscribedTopic)
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
+func (client *Client) unsubscribe(topic string, token *SubscribeToken) error {
|
|
|
+ client.waitConnected()
|
|
|
+
|
|
|
+ err := client.removeSubscribedTopic(topic, token, func(subscribedTopic *subscribeTopic) error {
|
|
|
+ token := client.client.Unsubscribe(subscribedTopic.topic)
|
|
|
+ if !token.WaitTimeout(time.Duration(client.opts.WriteTimeoutSec) * time.Second) {
|
|
|
+ return errors.New("取消订阅超时")
|
|
|
+ }
|
|
|
+
|
|
|
+ if token.Error() != nil {
|
|
|
+ return token.Error()
|
|
|
+ }
|
|
|
+
|
|
|
+ return nil
|
|
|
+ })
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func (client *Client) publishAndReceiveReply(params *PublishAndReceiveReplyParams, payloadDealFunc func(payload []byte) error) error {
|
|
|
+ doneChan := make(chan any)
|
|
|
+
|
|
|
+ if params.StopBefore {
|
|
|
+ client.publishAndReceiveTopicMapMutex.Lock()
|
|
|
+ existDoneChan, ok := client.publishAndReceiveTopicMap[params.Topic]
|
|
|
+ if ok {
|
|
|
+ close(existDoneChan)
|
|
|
+ delete(client.publishAndReceiveTopicMap, params.Topic)
|
|
|
+ } else {
|
|
|
+ client.publishAndReceiveTopicMap[params.Topic] = doneChan
|
|
|
+ }
|
|
|
+ client.publishAndReceiveTopicMapMutex.Unlock()
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+ token, err := client.subscribe(params.ReplyTopic, func(c *Client, token *SubscribeToken, topic string, data []byte) error {
|
|
|
+ defer func() {
|
|
|
+ client.publishAndReceiveTopicMapMutex.Lock()
|
|
|
+ existDoneChan, ok := client.publishAndReceiveTopicMap[params.Topic]
|
|
|
+ if ok {
|
|
|
+ close(existDoneChan)
|
|
|
+ delete(client.publishAndReceiveTopicMap, params.Topic)
|
|
|
+ }
|
|
|
+ client.publishAndReceiveTopicMapMutex.Unlock()
|
|
|
+ }()
|
|
|
+
|
|
|
+ if token.SuccessHandleCount() >= 1 {
|
|
|
+ return ErrMqttMessageIgnore
|
|
|
+ }
|
|
|
+
|
|
|
+ if payloadDealFunc != nil {
|
|
|
+ err := payloadDealFunc(data)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return nil
|
|
|
+ })
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ go func() {
|
|
|
+ client.publish(params.Topic, 2, false, params.PublishData)
|
|
|
+
|
|
|
+ currentTryTime := 1
|
|
|
+ timer := time.NewTimer(time.Duration(params.RepublishDurationSec) * time.Second)
|
|
|
+ defer timer.Stop()
|
|
|
+
|
|
|
+ for {
|
|
|
+ select {
|
|
|
+ case <-doneChan:
|
|
|
+ err := client.unsubscribe(params.ReplyTopic, token)
|
|
|
+ if err != nil {
|
|
|
+ fslog.Error(err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ client.publishAndReceiveTopicMapMutex.Lock()
|
|
|
+ delete(client.publishAndReceiveTopicMap, params.Topic)
|
|
|
+ client.publishAndReceiveTopicMapMutex.Unlock()
|
|
|
+
|
|
|
+ return
|
|
|
+ case <-timer.C:
|
|
|
+ client.publish(params.Topic, 2, false, params.PublishData)
|
|
|
+
|
|
|
+ if params.TryTimes != 0 {
|
|
|
+ currentTryTime++
|
|
|
+ }
|
|
|
+
|
|
|
+ if params.TryTimes != 0 && currentTryTime > params.TryTimes {
|
|
|
+ client.publishAndReceiveTopicMapMutex.Lock()
|
|
|
+ existDoneChan, ok := client.publishAndReceiveTopicMap[params.Topic]
|
|
|
+ if ok {
|
|
|
+ close(existDoneChan)
|
|
|
+ delete(client.publishAndReceiveTopicMap, params.Topic)
|
|
|
+ }
|
|
|
+ client.publishAndReceiveTopicMapMutex.Unlock()
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ resetDuration := time.Duration(params.RepublishDurationSec*int64(currentTryTime)) * time.Second
|
|
|
+ timer.Reset(resetDuration)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func (client *Client) waitConnected() {
|
|
|
+ for {
|
|
|
+ if client.client.IsConnected() {
|
|
|
+ break
|
|
|
+ }
|
|
|
+
|
|
|
+ time.Sleep(1 * time.Second)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (client *Client) subscribeAll() error {
|
|
|
+ err := client.rangeSubscribedTopics(func(subscribedTopic *subscribeTopic) error {
|
|
|
+ return client.subscribeMqtt(subscribedTopic)
|
|
|
+ })
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func (client *Client) subscribeMqtt(subscribedTopic *subscribeTopic) error {
|
|
|
+ token := client.client.Subscribe(subscribedTopic.topic, 2, func(mqttClient mqtt.Client, message mqtt.Message) {
|
|
|
+ wg := sync.WaitGroup{}
|
|
|
+ wg.Add(len(subscribedTopic.tokens))
|
|
|
+ for _, token := range subscribedTopic.tokens {
|
|
|
+ go func(token *SubscribeToken, message mqtt.Message) {
|
|
|
+ defer func() {
|
|
|
+ token.handleCount++
|
|
|
+ wg.Done()
|
|
|
+ }()
|
|
|
+
|
|
|
+ err := token.messageHandler(client, token, subscribedTopic.topic, message.Payload())
|
|
|
+ if err != nil && !errors.Is(err, ErrMqttMessageIgnore) {
|
|
|
+ fslog.Error(err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ if err != nil && errors.Is(err, ErrMqttMessageIgnore) {
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ token.successHandleCount++
|
|
|
+ }(token, message)
|
|
|
+ }
|
|
|
+ wg.Wait()
|
|
|
+ })
|
|
|
+ if !token.WaitTimeout(time.Duration(client.opts.WriteTimeoutSec) * time.Second) {
|
|
|
+ return errors.New("订阅超时")
|
|
|
+ }
|
|
|
+
|
|
|
+ if token.Error() != nil {
|
|
|
+ return token.Error()
|
|
|
+ }
|
|
|
+
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func (client *Client) addSubscribedTopic(topic string, handler MessageHandler, addNew func(subscribedTopic *subscribeTopic) error) (*SubscribeToken, error) {
|
|
|
+ client.topicsMutex.Lock()
|
|
|
+ defer client.topicsMutex.Unlock()
|
|
|
+
|
|
|
+ newToken := &SubscribeToken{messageHandler: handler}
|
|
|
+
|
|
|
+ for _, savedTopic := range client.topics {
|
|
|
+ if savedTopic.topic == topic {
|
|
|
+ savedTopic.tokens = append(savedTopic.tokens, newToken)
|
|
|
+ return newToken, nil
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ subscribedTopic := &subscribeTopic{
|
|
|
+ topic: topic,
|
|
|
+ tokens: []*SubscribeToken{newToken},
|
|
|
+ }
|
|
|
+
|
|
|
+ err := addNew(subscribedTopic)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ client.topics = append(client.topics, subscribedTopic)
|
|
|
+ return newToken, nil
|
|
|
+}
|
|
|
+
|
|
|
+func (client *Client) removeSubscribedTopic(topic string, token *SubscribeToken, noTokens func(subscribedTopic *subscribeTopic) error) error {
|
|
|
+ client.topicsMutex.Lock()
|
|
|
+ defer client.topicsMutex.Unlock()
|
|
|
+
|
|
|
+ findSubscribeTopicIndex := -1
|
|
|
+ for index, subscribedTopic := range client.topics {
|
|
|
+ if subscribedTopic.topic == topic {
|
|
|
+ findSubscribeTopicIndex = index
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if findSubscribeTopicIndex == -1 {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
+ subscribedTopic := client.topics[findSubscribeTopicIndex]
|
|
|
+ if subscribedTopic.tokens != nil && len(subscribedTopic.tokens) == 1 {
|
|
|
+ err := noTokens(subscribedTopic)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ client.topics = append(client.topics[:findSubscribeTopicIndex], client.topics[findSubscribeTopicIndex+1:]...)
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
+ findTokenIndex := -1
|
|
|
+ for index, savedToken := range subscribedTopic.tokens {
|
|
|
+ if savedToken == token {
|
|
|
+ findTokenIndex = index
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if findTokenIndex == -1 {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
+ subscribedTopic.tokens = append(subscribedTopic.tokens[:findTokenIndex], subscribedTopic.tokens[findTokenIndex+1:]...)
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func (client *Client) rangeSubscribedTopics(rangeFunc func(subscribedTopic *subscribeTopic) error) error {
|
|
|
+ client.topicsMutex.Lock()
|
|
|
+ defer client.topicsMutex.Unlock()
|
|
|
+
|
|
|
+ for _, subscribedTopic := range client.topics {
|
|
|
+ tempTopic := subscribedTopic
|
|
|
+ err := rangeFunc(tempTopic)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return nil
|
|
|
+}
|