Browse Source

修改bug

yjp 1 week ago
parent
commit
984e4a3073

+ 1 - 0
baize.go

@@ -18,6 +18,7 @@ func NewApplication(conf application.Config) *application.App {
 	infrastructureConfig := new(infrastructure.Config)
 	infrastructureConfig.DatabaseConfig = conf.InfrastructureConfig.Database
 	infrastructureConfig.CacheConfig = conf.InfrastructureConfig.Cache
+	infrastructureConfig.MessageQueueConfig = conf.MessageQueue
 
 	infrastructureInstance := infrastructure.NewInfrastructure(*infrastructureConfig)
 

+ 3 - 2
framework/core/infrastructure/message_queue/common/common.go

@@ -7,8 +7,9 @@ import "git.sxidc.com/go-framework/baize/framework/core/data_protocol"
 // - queue: 消息队列
 // - topic: 主题
 // - data: 消息数据
-// 返回值: 无
-type MessageHandler func(queue MessageQueue, topic string, event *data_protocol.CloudEvent)
+// 返回值:
+// - error: 错误信息
+type MessageHandler func(queue MessageQueue, topic string, event *data_protocol.CloudEvent) error
 
 // MessageQueue 消息队列接口
 type MessageQueue interface {

+ 10 - 1
framework/core/infrastructure/message_queue/mqtt/mqtt.go

@@ -28,6 +28,7 @@ func New(address string, userName string, password string) (*MessageQueue, error
 		AddBroker(address).
 		SetClientID(clientID).
 		SetOrderMatters(false).
+		SetAutoAckDisabled(true).
 		SetWill(clientID+"/will", "dead", 2, true))
 
 	token := mqttClient.Connect()
@@ -78,8 +79,16 @@ func (messageQueue *MessageQueue) Subscribe(group string, topic string, handler
 		}
 
 		for _, groupHandler := range messageQueue.topicGroupHandlerMap[topic] {
-			go groupHandler(messageQueue, message.Topic(), event)
+			go func() {
+				err := groupHandler(messageQueue, message.Topic(), event)
+				if err != nil {
+					logger.GetInstance().Error(err)
+					return
+				}
+			}()
 		}
+
+		message.Ack()
 	})
 
 	if !token.WaitTimeout(20 * time.Second) {

+ 10 - 4
framework/core/infrastructure/message_queue/redis/redis.go

@@ -120,17 +120,23 @@ func (messageQueue *MessageQueue) Subscribe(group string, topic string, handler
 	newConsumer.Register(topic, func(message *redisqueue.Message) error {
 		data, ok := message.Values[messageValuesDataKey].(string)
 		if !ok {
-			logger.GetInstance().Error(errors.New("数据不存在"))
-			return nil
+			err := errors.New("数据不存在")
+			logger.GetInstance().Error(err)
+			return errors.New("数据不存在")
 		}
 
 		event, err := data_protocol.UnmarshalJsonCloudEvent([]byte(data))
 		if err != nil {
 			logger.GetInstance().Error(err)
-			return nil
+			return err
+		}
+
+		err = handler(messageQueue, message.Stream, event)
+		if err != nil {
+			logger.GetInstance().Error(err)
+			return err
 		}
 
-		handler(messageQueue, message.Stream, event)
 		return nil
 	})
 

+ 24 - 6
test/message_queue_test.go

@@ -60,7 +60,7 @@ func testRedisMessageQueue(t *testing.T, redisMessageQueue *redis.MessageQueue)
 	wg.Add(2)
 
 	err := redisMessageQueue.Subscribe("test1", "test-redis",
-		func(queue common.MessageQueue, topic string, event *data_protocol.CloudEvent) {
+		func(queue common.MessageQueue, topic string, event *data_protocol.CloudEvent) error {
 			if event.ID != "1" {
 				t.Fatalf("%+v\n", errors.New("消息ID不一致"))
 			}
@@ -76,13 +76,15 @@ func testRedisMessageQueue(t *testing.T, redisMessageQueue *redis.MessageQueue)
 			fmt.Println("redis test1 consumed")
 
 			wg.Done()
+
+			return nil
 		})
 	if err != nil {
 		t.Fatalf("%+v\n", err)
 	}
 
 	err = redisMessageQueue.Subscribe("test2", "test-redis",
-		func(queue common.MessageQueue, topic string, event *data_protocol.CloudEvent) {
+		func(queue common.MessageQueue, topic string, event *data_protocol.CloudEvent) error {
 			if event.ID != "1" {
 				t.Fatalf("%+v\n", errors.New("消息ID不一致"))
 			}
@@ -98,6 +100,8 @@ func testRedisMessageQueue(t *testing.T, redisMessageQueue *redis.MessageQueue)
 			fmt.Println("redis test2 consumed")
 
 			wg.Done()
+
+			return nil
 		})
 	if err != nil {
 		t.Fatalf("%+v\n", err)
@@ -116,8 +120,14 @@ func testMqttMessageQueue(t *testing.T, mqttMessageQueue *mqtt.MessageQueue) {
 	wg := sync.WaitGroup{}
 	wg.Add(2)
 
+	//err := mqttMessageQueue.Publish("test-mqtt",
+	//	data_protocol.NewCloudEvent("1", "test", "baize-test.com", "application/text", []byte("test-message")))
+	//if err != nil {
+	//	t.Fatalf("%+v\n", err)
+	//}
+
 	err := mqttMessageQueue.Subscribe("test1", "test-mqtt",
-		func(queue common.MessageQueue, topic string, event *data_protocol.CloudEvent) {
+		func(queue common.MessageQueue, topic string, event *data_protocol.CloudEvent) error {
 			if event.ID != "1" {
 				t.Fatalf("%+v\n", errors.New("消息ID不一致"))
 			}
@@ -133,13 +143,15 @@ func testMqttMessageQueue(t *testing.T, mqttMessageQueue *mqtt.MessageQueue) {
 			fmt.Println("mqtt test1 consumed")
 
 			wg.Done()
+
+			return nil
 		})
 	if err != nil {
 		t.Fatalf("%+v\n", err)
 	}
 
 	err = mqttMessageQueue.Subscribe("test2", "test-mqtt",
-		func(queue common.MessageQueue, topic string, event *data_protocol.CloudEvent) {
+		func(queue common.MessageQueue, topic string, event *data_protocol.CloudEvent) error {
 			if event.ID != "1" {
 				t.Fatalf("%+v\n", errors.New("消息ID不一致"))
 			}
@@ -155,6 +167,8 @@ func testMqttMessageQueue(t *testing.T, mqttMessageQueue *mqtt.MessageQueue) {
 			fmt.Println("mqtt test2 consumed")
 
 			wg.Done()
+
+			return nil
 		})
 	if err != nil {
 		t.Fatalf("%+v\n", err)
@@ -174,7 +188,7 @@ func testMessageQueue(t *testing.T, messageQueue common.MessageQueue) {
 	wg.Add(2)
 
 	err := message_queue.Subscribe(messageQueue, "test1", "test-message-queue",
-		func(queue common.MessageQueue, topic string, event *data_protocol.CloudEvent) {
+		func(queue common.MessageQueue, topic string, event *data_protocol.CloudEvent) error {
 			if event.ID != "1" {
 				t.Fatalf("%+v\n", errors.New("消息ID不一致"))
 			}
@@ -190,13 +204,15 @@ func testMessageQueue(t *testing.T, messageQueue common.MessageQueue) {
 			fmt.Println("test1 consumed")
 
 			wg.Done()
+
+			return nil
 		})
 	if err != nil {
 		t.Fatalf("%+v\n", err)
 	}
 
 	err = message_queue.Subscribe(messageQueue, "test2", "test-message-queue",
-		func(queue common.MessageQueue, topic string, event *data_protocol.CloudEvent) {
+		func(queue common.MessageQueue, topic string, event *data_protocol.CloudEvent) error {
 			if event.ID != "1" {
 				t.Fatalf("%+v\n", errors.New("消息ID不一致"))
 			}
@@ -212,6 +228,8 @@ func testMessageQueue(t *testing.T, messageQueue common.MessageQueue) {
 			fmt.Println("test2 consumed")
 
 			wg.Done()
+
+			return nil
 		})
 	if err != nil {
 		t.Fatalf("%+v\n", err)