package mqtt_client import ( "encoding/json" "errors" "git.sxidc.com/go-tools/utils/strutils" mqtt "github.com/eclipse/paho.mqtt.golang" "log" "sync" "time" ) var ( ErrMessageIgnore = errors.New("mqtt消息忽略") ) type MessageHandler func(client *Client, token *SubscribeToken, topic string, data []byte) error type ClientOptions struct { UserName string Password string Address string ClientID string KeepAliveSec int64 PingTimeoutSec int64 WriteTimeoutSec int64 } func (opts *ClientOptions) check() error { if strutils.IsStringEmpty(opts.UserName) { return errors.New("必须传递用户名") } if strutils.IsStringEmpty(opts.Password) { return errors.New("必须传递密码") } if strutils.IsStringEmpty(opts.Address) { return errors.New("必须传递地址") } if strutils.IsStringEmpty(opts.ClientID) { return errors.New("必须传递客户端ID") } return nil } type SubscribeToken struct { messageHandler MessageHandler successHandleCount int handleCount int } func (token *SubscribeToken) SuccessHandleCount() int { return token.successHandleCount } func (token *SubscribeToken) HandleCount() int { return token.handleCount } type subscribeTopic struct { topic string tokens []*SubscribeToken } type Client struct { client mqtt.Client opts *ClientOptions topicsMutex *sync.Mutex topics []*subscribeTopic publishAndReceiveTopicMapMutex sync.Mutex publishAndReceiveTopicMap map[string]chan any } func New(opts *ClientOptions) (*Client, error) { if opts == nil { return nil, errors.New("必须传递参数") } err := opts.check() if err != nil { return nil, err } if opts.WriteTimeoutSec == 0 { opts.WriteTimeoutSec = 60 } client := &Client{ opts: opts, topicsMutex: &sync.Mutex{}, topics: make([]*subscribeTopic, 0), publishAndReceiveTopicMap: make(map[string]chan any), } mqttClient := mqtt.NewClient(mqtt.NewClientOptions(). SetAutoReconnect(true). SetUsername(opts.UserName). SetPassword(opts.Password). AddBroker(opts.Address). SetClientID(opts.ClientID). SetKeepAlive(time.Duration(opts.KeepAliveSec)*time.Second). SetPingTimeout(time.Duration(opts.PingTimeoutSec)*time.Second). SetWriteTimeout(time.Duration(opts.WriteTimeoutSec)*time.Second). SetOrderMatters(false). SetWill(opts.ClientID+"/will", "dead", 2, true). SetOnConnectHandler(func(mqttClient mqtt.Client) { err := client.subscribeAll() if err != nil { log.Println("SetOnConnectHandler:", err.Error()) return } })) token := mqttClient.Connect() if !token.WaitTimeout(time.Duration(opts.WriteTimeoutSec) * time.Second) { return nil, errors.New("连接超时") } if token.Error() != nil { return nil, token.Error() } client.client = mqttClient return client, nil } func Destroy(client *Client) { client.client.Disconnect(250) client.topicsMutex.Lock() for _, subscribedTopic := range client.topics { token := client.client.Unsubscribe(subscribedTopic.topic) if !token.WaitTimeout(time.Duration(client.opts.WriteTimeoutSec) * time.Second) { continue } if token.Error() != nil { log.Println("Destroy: ", token.Error()) continue } } client.topics = nil client.topicsMutex.Unlock() client = nil } func (client *Client) Publish(topic string, qos byte, retained bool, payload any) error { if strutils.IsStringEmpty(topic) { return errors.New("没有传递发布主题") } if payload == nil { return errors.New("发布的payload不能为nil") } client.publish(topic, qos, retained, payload) return nil } func (client *Client) Subscribe(topic string, handlerFunc MessageHandler) (*SubscribeToken, error) { if strutils.IsStringEmpty(topic) { return nil, errors.New("没有传递订阅主题") } if handlerFunc == nil { return nil, errors.New("必须传递处理函数") } return client.subscribe(topic, handlerFunc) } func (client *Client) Unsubscribe(topic string, token *SubscribeToken) error { if strutils.IsStringEmpty(topic) { return errors.New("没有传递取消订阅主题") } return client.unsubscribe(topic, token) } type PublishAndReceiveReplyParams struct { Topic string ReplyTopic string PublishData []byte RepublishDurationSec int64 TryTimes int StopBefore bool } func (params *PublishAndReceiveReplyParams) Check() error { if strutils.IsStringEmpty(params.Topic) { return errors.New("没有传递订阅主题") } if strutils.IsStringEmpty(params.ReplyTopic) { return errors.New("没有传递响应订阅主题") } if params.PublishData == nil { return errors.New("发布的数据不能为nil") } return nil } type msgResponse struct { Success bool `json:"success"` Msg string `json:"msg"` } func (client *Client) PublishAndReceiveReplyMsgResponse(params *PublishAndReceiveReplyParams) error { err := params.Check() if err != nil { return err } err = client.publishAndReceiveReply(params, func(payload []byte) error { resp := new(msgResponse) err := json.Unmarshal(payload, resp) if err != nil { return err } if !resp.Success { return errors.New(resp.Msg) } return nil }) if err != nil { return err } return nil } func (client *Client) PublishAndReceiveReply(params *PublishAndReceiveReplyParams, payloadDealFunc func(payload []byte) error) error { err := params.Check() if err != nil { return err } return client.publishAndReceiveReply(params, payloadDealFunc) } func (client *Client) publish(topic string, qos byte, retained bool, payload any) { client.waitConnected() token := client.client.Publish(topic, qos, retained, payload) if !token.WaitTimeout(time.Duration(client.opts.WriteTimeoutSec) * time.Second) { log.Println("发布超时") return } if token.Error() != nil { log.Println("publish:", token.Error()) return } } func (client *Client) subscribe(topic string, handlerFunc MessageHandler) (*SubscribeToken, error) { client.waitConnected() return client.addSubscribedTopic(topic, handlerFunc, func(subscribedTopic *subscribeTopic) error { return client.doSubscribe(subscribedTopic) }) } func (client *Client) unsubscribe(topic string, token *SubscribeToken) error { client.waitConnected() err := client.removeSubscribedTopic(topic, token, func(subscribedTopic *subscribeTopic) error { token := client.client.Unsubscribe(subscribedTopic.topic) if !token.WaitTimeout(time.Duration(client.opts.WriteTimeoutSec) * time.Second) { return errors.New("取消订阅超时") } if token.Error() != nil { return token.Error() } return nil }) if err != nil { return err } return nil } func (client *Client) publishAndReceiveReply(params *PublishAndReceiveReplyParams, payloadDealFunc func(payload []byte) error) error { doneChan := make(chan any) if params.StopBefore { client.publishAndReceiveTopicMapMutex.Lock() existDoneChan, ok := client.publishAndReceiveTopicMap[params.Topic] if ok { close(existDoneChan) delete(client.publishAndReceiveTopicMap, params.Topic) } else { client.publishAndReceiveTopicMap[params.Topic] = doneChan } client.publishAndReceiveTopicMapMutex.Unlock() } // 订阅响应主题 token, err := client.subscribe(params.ReplyTopic, func(c *Client, token *SubscribeToken, topic string, data []byte) error { defer func() { client.publishAndReceiveTopicMapMutex.Lock() existDoneChan, ok := client.publishAndReceiveTopicMap[params.Topic] if ok { close(existDoneChan) delete(client.publishAndReceiveTopicMap, params.Topic) } client.publishAndReceiveTopicMapMutex.Unlock() }() if token.SuccessHandleCount() >= 1 { return ErrMessageIgnore } if payloadDealFunc != nil { err := payloadDealFunc(data) if err != nil { return err } } return nil }) if err != nil { return err } go func() { client.publish(params.Topic, 2, false, params.PublishData) currentTryTime := 1 timer := time.NewTimer(time.Duration(params.RepublishDurationSec) * time.Second) defer timer.Stop() for { select { case <-doneChan: err := client.unsubscribe(params.ReplyTopic, token) if err != nil { log.Println("publishAndReceiveReply done:", err.Error()) return } client.publishAndReceiveTopicMapMutex.Lock() delete(client.publishAndReceiveTopicMap, params.Topic) client.publishAndReceiveTopicMapMutex.Unlock() return case <-timer.C: client.publish(params.Topic, 2, false, params.PublishData) if params.TryTimes != 0 { currentTryTime++ } if params.TryTimes != 0 && currentTryTime > params.TryTimes { client.publishAndReceiveTopicMapMutex.Lock() existDoneChan, ok := client.publishAndReceiveTopicMap[params.Topic] if ok { close(existDoneChan) delete(client.publishAndReceiveTopicMap, params.Topic) } client.publishAndReceiveTopicMapMutex.Unlock() return } resetDuration := time.Duration(params.RepublishDurationSec*int64(currentTryTime)) * time.Second timer.Reset(resetDuration) } } }() return nil } func (client *Client) waitConnected() { for { if client.client.IsConnected() { break } time.Sleep(1 * time.Second) } } func (client *Client) subscribeAll() error { err := client.rangeSubscribedTopics(func(subscribedTopic *subscribeTopic) error { return client.doSubscribe(subscribedTopic) }) if err != nil { return err } return nil } func (client *Client) doSubscribe(subscribedTopic *subscribeTopic) error { token := client.client.Subscribe(subscribedTopic.topic, 2, func(mqttClient mqtt.Client, message mqtt.Message) { wg := sync.WaitGroup{} wg.Add(len(subscribedTopic.tokens)) for _, token := range subscribedTopic.tokens { go func(token *SubscribeToken, message mqtt.Message) { defer func() { token.handleCount++ wg.Done() }() err := token.messageHandler(client, token, subscribedTopic.topic, message.Payload()) if err != nil && !errors.Is(err, ErrMessageIgnore) { log.Println("doSubscribe token.messageHandler:", err.Error()) return } if err != nil && errors.Is(err, ErrMessageIgnore) { return } token.successHandleCount++ }(token, message) } wg.Wait() }) if !token.WaitTimeout(time.Duration(client.opts.WriteTimeoutSec) * time.Second) { return errors.New("订阅超时") } if token.Error() != nil { return token.Error() } return nil } func (client *Client) addSubscribedTopic(topic string, handler MessageHandler, addNew func(subscribedTopic *subscribeTopic) error) (*SubscribeToken, error) { client.topicsMutex.Lock() defer client.topicsMutex.Unlock() newToken := &SubscribeToken{messageHandler: handler} for _, savedTopic := range client.topics { if savedTopic.topic == topic { savedTopic.tokens = append(savedTopic.tokens, newToken) return newToken, nil } } subscribedTopic := &subscribeTopic{ topic: topic, tokens: []*SubscribeToken{newToken}, } err := addNew(subscribedTopic) if err != nil { return nil, err } client.topics = append(client.topics, subscribedTopic) return newToken, nil } func (client *Client) removeSubscribedTopic(topic string, token *SubscribeToken, noTokens func(subscribedTopic *subscribeTopic) error) error { client.topicsMutex.Lock() defer client.topicsMutex.Unlock() findSubscribeTopicIndex := -1 for index, subscribedTopic := range client.topics { if subscribedTopic.topic == topic { findSubscribeTopicIndex = index } } if findSubscribeTopicIndex == -1 { return nil } subscribedTopic := client.topics[findSubscribeTopicIndex] if subscribedTopic.tokens != nil && len(subscribedTopic.tokens) == 1 { err := noTokens(subscribedTopic) if err != nil { return err } client.topics = append(client.topics[:findSubscribeTopicIndex], client.topics[findSubscribeTopicIndex+1:]...) return nil } findTokenIndex := -1 for index, savedToken := range subscribedTopic.tokens { if savedToken == token { findTokenIndex = index } } if findTokenIndex == -1 { return nil } subscribedTopic.tokens = append(subscribedTopic.tokens[:findTokenIndex], subscribedTopic.tokens[findTokenIndex+1:]...) return nil } func (client *Client) rangeSubscribedTopics(rangeFunc func(subscribedTopic *subscribeTopic) error) error { client.topicsMutex.Lock() defer client.topicsMutex.Unlock() for _, subscribedTopic := range client.topics { tempTopic := subscribedTopic err := rangeFunc(tempTopic) if err != nil { return err } } return nil }