Browse Source

修改部分bug

yjp 3 months ago
parent
commit
ef0b65186c

+ 30 - 1
baize.go

@@ -4,6 +4,7 @@ import (
 	"git.sxidc.com/go-framework/baize/framework/core/api"
 	"git.sxidc.com/go-framework/baize/framework/core/application"
 	"git.sxidc.com/go-framework/baize/framework/core/infrastructure"
+	"git.sxidc.com/go-framework/baize/framework/core/mqtt_api"
 )
 
 func NewApplication(conf application.Config) *application.App {
@@ -20,7 +21,30 @@ func NewApplication(conf application.Config) *application.App {
 
 	infrastructureInstance := infrastructure.NewInfrastructure(*infrastructureConfig)
 
-	return application.New(apiInstance, infrastructureInstance)
+	app := application.New(apiInstance, infrastructureInstance)
+
+	// 添加MqttApi
+	if conf.MqttApiConfig != nil {
+		mqttConfig := conf.MqttApiConfig
+		mqttApi, err := mqtt_api.New(mqtt_api.WithTopicPrefix(mqttConfig.TopicPrefix),
+			mqtt_api.WithLogSkipPaths(mqttConfig.LogSkipPaths...),
+			mqtt_api.WithMqttOptions(&mqtt_api.MqttClientOptions{
+				UserName:        mqttConfig.MqttConfig.UserName,
+				Password:        mqttConfig.MqttConfig.Password,
+				Address:         mqttConfig.MqttConfig.Address,
+				ClientID:        mqttConfig.MqttConfig.ClientID,
+				KeepAliveSec:    mqttConfig.MqttConfig.KeepAliveSec,
+				PingTimeoutSec:  mqttConfig.MqttConfig.PingTimeoutSec,
+				WriteTimeoutSec: mqttConfig.MqttConfig.WriteTimeoutSec,
+			}))
+		if err != nil {
+			panic(err)
+		}
+
+		app.AddMqttApi(mqttApi)
+	}
+
+	return app
 }
 
 func DestroyApplication(app *application.App) {
@@ -28,5 +52,10 @@ func DestroyApplication(app *application.App) {
 		return
 	}
 
+	mqttApi := app.MqttApi()
+	if mqttApi != nil {
+		mqtt_api.Destroy(mqttApi)
+	}
+
 	infrastructure.DestroyInfrastructure(app.Infrastructure())
 }

+ 38 - 0
framework/core/application/application.go

@@ -3,6 +3,7 @@ package application
 import (
 	"git.sxidc.com/go-framework/baize/framework/core/api"
 	"git.sxidc.com/go-framework/baize/framework/core/infrastructure"
+	"git.sxidc.com/go-framework/baize/framework/core/mqtt_api"
 )
 
 type App struct {
@@ -11,6 +12,9 @@ type App struct {
 
 	// 基础设施实例
 	infrastructureInstance *infrastructure.Infrastructure
+
+	// 可选mqtt api
+	mqttApiInstance *mqtt_api.MqttApi
 }
 
 // New 创建Application
@@ -75,3 +79,37 @@ func (app *App) Infrastructure() *infrastructure.Infrastructure {
 func (app *App) ChooseRouter(routerType string, version string) api.Router {
 	return app.Api().ChooseRouter(routerType, version)
 }
+
+func (app *App) StartMqttApi() error {
+	if app.mqttApiInstance == nil {
+		return nil
+	}
+
+	err := app.mqttApiInstance.Start()
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (app *App) FinishMqttApi() error {
+	if app.mqttApiInstance != nil {
+		app.mqttApiInstance.Finish()
+	}
+
+	err := app.apiInstance.Finish()
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (app *App) AddMqttApi(mqttApi *mqtt_api.MqttApi) {
+	app.mqttApiInstance = mqttApi
+}
+
+func (app *App) MqttApi() *mqtt_api.MqttApi {
+	return app.mqttApiInstance
+}

+ 17 - 0
framework/core/application/config.go

@@ -12,6 +12,7 @@ import (
 type Config struct {
 	ApiConfig            `json:"api" yaml:"api"`
 	InfrastructureConfig `json:"infrastructure" yaml:"infrastructure"`
+	*MqttApiConfig       `json:"mqtt_api" yaml:"mqtt_api"`
 }
 
 type ApiConfig struct {
@@ -26,6 +27,22 @@ type InfrastructureConfig struct {
 	MessageQueue infrastructure.MessageQueueConfig `json:"message_queue" yaml:"message_queue"`
 }
 
+type MqttApiConfig struct {
+	TopicPrefix  string     `json:"topic_prefix" yaml:"topic_prefix"`
+	LogSkipPaths []string   `json:"log_skip_paths" yaml:"log_skip_paths"`
+	MqttConfig   MqttConfig `json:"mqtt_config" yaml:"mqtt_config"`
+}
+
+type MqttConfig struct {
+	UserName        string `json:"username" yaml:"username"`
+	Password        string `json:"password" yaml:"password"`
+	Address         string `json:"address" yaml:"address"`
+	ClientID        string `json:"client_id" yaml:"client_id"`
+	KeepAliveSec    int64  `json:"keep_alive_sec" yaml:"keep_alive_sec"`
+	PingTimeoutSec  int64  `json:"ping_timeout_sec" yaml:"ping_timeout_sec"`
+	WriteTimeoutSec int64  `json:"write_timeout_sec" yaml:"write_timeout_sec"`
+}
+
 func LoadFromJsonFile(jsonFilePath string) (Config, error) {
 	if !fileutils.PathExists(jsonFilePath) {
 		return Config{}, errors.New("配置文件不存在")

+ 1 - 13
framework/core/mqtt_api/request/request.go → framework/core/mqtt_api/mqtt_request/mqtt_request.go

@@ -1,10 +1,8 @@
-package request
+package mqtt_request
 
 import (
 	"encoding/json"
-	"git.sxidc.com/go-framework/baize/framework/core/domain"
 	"git.sxidc.com/go-framework/baize/framework/core/mqtt_api"
-	"git.sxidc.com/go-framework/baize/framework/core/tag/assign"
 	"github.com/go-playground/validator/v10"
 )
 
@@ -27,13 +25,3 @@ func BindingJson(c *mqtt_api.Context, request any) error {
 
 	return nil
 }
-
-// AssignRequestParamsToDomainObject 基于assign tag将请求参数赋值到领域对象
-// 参数:
-// - params: 请求参数
-// - domainObject: 领域对象
-// 返回值:
-// - 错误
-func AssignRequestParamsToDomainObject(params Params, domainObject domain.Object) error {
-	return assign.DefaultUsage(params, domainObject)
-}

+ 91 - 0
framework/core/mqtt_api/mqtt_request/params.go

@@ -0,0 +1,91 @@
+package mqtt_request
+
+import (
+	"git.sxidc.com/go-framework/baize/framework/core/domain"
+	"git.sxidc.com/go-framework/baize/framework/core/tag/assign"
+	"git.sxidc.com/go-tools/utils/reflectutils"
+	"github.com/pkg/errors"
+	"reflect"
+)
+
+// AssignRequestParamsToDomainObject 基于assign tag将请求参数赋值到领域对象
+// 参数:
+// - params: 请求参数
+// - domainObject: 领域对象
+// 返回值:
+// - 错误
+func AssignRequestParamsToDomainObject(params any, domainObject domain.Object) error {
+	return assign.DefaultUsage(params, domainObject)
+}
+
+// Field 获取请求对象中的字段值
+// 泛型参数:
+// - T: 请求对象中的字段值的类型
+// 参数:
+// - params: 请求参数
+// - fieldName: 字段名
+// 返回值:
+// - 请求对象中的字段值
+// - 错误
+func Field[T any](params any, fieldName string) (T, error) {
+	zero := reflectutils.Zero[T]()
+
+	if params == nil {
+		return zero, errors.New("请求参数为nil")
+	}
+
+	fieldValue, err := getRequestParamsValue(params, fieldName)
+	if err != nil {
+		return zero, err
+	}
+
+	if !fieldValue.IsValid() {
+		return zero, errors.New("请求参数" + fieldValue.Type().String() + "的字段" + fieldName + "无法赋值")
+	}
+
+	retValue, ok := fieldValue.Interface().(T)
+	if !ok {
+		return zero, errors.New("请求参数" + fieldValue.Type().String() + "的字段" + fieldName + "无法转换类型")
+	}
+
+	return retValue, nil
+}
+
+// ToConcrete 将请求参数转换为具体的请求参数结构类型
+// 泛型参数:
+// - T: 具体的请求参数结构类型
+// 参数:
+// - params: 请求参数
+// 返回值:
+// - 具体的请求参数结构
+// - 错误
+func ToConcrete[T any](params any) (T, error) {
+	zero := reflectutils.Zero[T]()
+
+	if params == nil {
+		return zero, errors.New("请求参数为nil")
+	}
+
+	concrete, ok := params.(T)
+	if !ok {
+		return zero, errors.New("请求参数转化失败")
+	}
+
+	return concrete, nil
+}
+
+func getRequestParamsValue(params any, fieldName string) (*reflect.Value, error) {
+	if params == nil {
+		return nil, errors.New("请求参数为nil")
+	}
+
+	paramsValue := reflect.ValueOf(params)
+
+	if !reflectutils.IsValueStructOrStructPointer(paramsValue) {
+		return nil, errors.New("请求参数必须是结构或结构指针")
+	}
+
+	fieldValue := reflectutils.PointerValueElem(paramsValue).FieldByName(fieldName)
+
+	return &fieldValue, nil
+}

+ 1 - 1
framework/core/mqtt_api/response/response.go → framework/core/mqtt_api/mqtt_response/mqtt_response.go

@@ -1,4 +1,4 @@
-package response
+package mqtt_response
 
 import (
 	"encoding/json"

+ 79 - 43
framework/core/mqtt_api/router.go

@@ -7,6 +7,11 @@ import (
 
 type Handler func(c *Context)
 
+type subscribeItem struct {
+	handlers   []Handler
+	subscribed bool
+}
+
 type Router struct {
 	Group string
 
@@ -17,19 +22,19 @@ type Router struct {
 
 	globalHandlers []Handler
 
-	topicHandlersMutex *sync.Mutex
-	topicHandlers      map[string][]Handler
+	subscribeItemsMutex *sync.Mutex
+	subscribeItems      map[string]*subscribeItem
 }
 
 func NewRouter(group string, globalHandlers []Handler) *Router {
 	return &Router{
-		Group:              group,
-		mqttClient:         nil,
-		contextsMutex:      &sync.Mutex{},
-		contexts:           make([]*Context, 0),
-		globalHandlers:     globalHandlers,
-		topicHandlersMutex: &sync.Mutex{},
-		topicHandlers:      make(map[string][]Handler),
+		Group:               group,
+		mqttClient:          nil,
+		contextsMutex:       &sync.Mutex{},
+		contexts:            make([]*Context, 0),
+		globalHandlers:      globalHandlers,
+		subscribeItemsMutex: &sync.Mutex{},
+		subscribeItems:      make(map[string]*subscribeItem),
 	}
 }
 
@@ -83,17 +88,13 @@ func (router *Router) Finish() {
 	router.mqttClient = nil
 }
 
-func (router *Router) AddTopic(topic string, handlers ...Handler) error {
-	added := router.addTopicHandlers(topic, handlers...)
-	if !added {
-		return nil
-	}
-
-	if router.mqttClient == nil {
-		return nil
-	}
+func (router *Router) AddGlobalHandlers(handlers ...Handler) {
+	router.globalHandlers = append(router.globalHandlers, handlers...)
+}
 
-	err := router.mqttClient.subscribe(topic, func(topic string, data []byte) {})
+func (router *Router) AddTopic(topic string, handlers ...Handler) error {
+	allHandlers := append(router.globalHandlers, handlers...)
+	err := router.addAndSubscribeTopicHandlers(router.Group+topic, allHandlers...)
 	if err != nil {
 		return err
 	}
@@ -102,16 +103,12 @@ func (router *Router) AddTopic(topic string, handlers ...Handler) error {
 }
 
 func (router *Router) subscribeTopics(client *MqttClient) {
+	if !client.mqttClient.IsConnected() {
+		return
+	}
+
 	router.rangeTopicHandlers(func(topic string, handlers []Handler) {
-		err := client.subscribe(topic, func(topic string, data []byte) {
-			c, err := newContext(router.mqttClient, topic, data, handlers)
-			if err != nil {
-				logger.GetInstance().Error(err)
-				return
-			}
-
-			c.Next()
-		})
+		err := router.subscribeMqttClient(client, topic, handlers)
 		if err != nil {
 			logger.GetInstance().Error(err)
 			return
@@ -120,6 +117,10 @@ func (router *Router) subscribeTopics(client *MqttClient) {
 }
 
 func (router *Router) unsubscribeTopics(client *MqttClient) {
+	if !client.mqttClient.IsConnected() {
+		return
+	}
+
 	router.rangeTopicHandlers(func(topic string, handlers []Handler) {
 		err := client.unsubscribe(topic)
 		if err != nil {
@@ -129,35 +130,70 @@ func (router *Router) unsubscribeTopics(client *MqttClient) {
 	})
 }
 
-func (router *Router) addTopicHandlers(topic string, handler ...Handler) bool {
-	router.topicHandlersMutex.Lock()
-	defer router.topicHandlersMutex.Unlock()
+func (router *Router) subscribeMqttClient(client *MqttClient, topic string, handlers []Handler) error {
+	err := client.subscribe(topic, func(topic string, data []byte) {
+		c, err := newContext(router.mqttClient, topic, data, handlers)
+		if err != nil {
+			logger.GetInstance().Error(err)
+			return
+		}
+
+		c.callHandlers(data)
+
+		c.Next()
+	})
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (router *Router) addAndSubscribeTopicHandlers(topic string, handlers ...Handler) error {
+	router.subscribeItemsMutex.Lock()
+	defer router.subscribeItemsMutex.Unlock()
+
+	router.subscribeItems[topic] = &subscribeItem{
+		handlers:   handlers,
+		subscribed: false,
+	}
+
+	if router.mqttClient == nil {
+		return nil
+	}
 
-	if router.topicHandlers[topic] != nil && len(router.topicHandlers[topic]) > 0 {
-		return false
+	if !router.mqttClient.mqttClient.IsConnected() {
+		return nil
 	}
 
-	router.topicHandlers[topic] = append(router.globalHandlers, handler...)
+	err := router.subscribeMqttClient(router.mqttClient, topic, handlers)
+	if err != nil {
+		return err
+	}
+
+	router.subscribeItems[topic].subscribed = true
 
-	return true
+	return nil
 }
 
 func (router *Router) removeTopicHandlers(topic string) {
-	router.topicHandlersMutex.Lock()
-	defer router.topicHandlersMutex.Unlock()
+	router.subscribeItemsMutex.Lock()
+	defer router.subscribeItemsMutex.Unlock()
 
-	if router.topicHandlers[topic] == nil || len(router.topicHandlers[topic]) == 0 {
+	if router.subscribeItems[topic] == nil {
 		return
 	}
 
-	delete(router.topicHandlers, topic)
+	delete(router.subscribeItems, topic)
 }
 
 func (router *Router) rangeTopicHandlers(callback func(topic string, handlers []Handler)) {
-	router.topicHandlersMutex.Lock()
-	defer router.topicHandlersMutex.Unlock()
+	router.subscribeItemsMutex.Lock()
+	defer router.subscribeItemsMutex.Unlock()
 
-	for topic, topicHandlers := range router.topicHandlers {
-		callback(topic, topicHandlers)
+	for topic, item := range router.subscribeItems {
+		if !item.subscribed {
+			callback(topic, item.handlers)
+		}
 	}
 }

+ 10 - 9
framework/mqtt_binding/bind_item.go

@@ -5,8 +5,8 @@ import (
 	"git.sxidc.com/go-framework/baize/framework/core/infrastructure"
 	"git.sxidc.com/go-framework/baize/framework/core/infrastructure/logger"
 	"git.sxidc.com/go-framework/baize/framework/core/mqtt_api"
-	"git.sxidc.com/go-framework/baize/framework/core/mqtt_api/request"
-	"git.sxidc.com/go-framework/baize/framework/core/mqtt_api/response"
+	"git.sxidc.com/go-framework/baize/framework/core/mqtt_api/mqtt_request"
+	"git.sxidc.com/go-framework/baize/framework/core/mqtt_api/mqtt_response"
 	"git.sxidc.com/go-tools/utils/reflectutils"
 	"git.sxidc.com/go-tools/utils/strutils"
 	"github.com/pkg/errors"
@@ -55,8 +55,8 @@ type ServiceFunc[O any] func(c *mqtt_api.Context, params any, objects []domain.O
 // - item: 执行绑定的参数
 // - responseIdentifier: 全局响应标识符
 // - middlewares: 该绑定的中间件
-func Bind[O any](binder *Binder, item *BindItem[O], responseIdentifier string, middlewares ...Middleware) {
-	item.bind(binder, responseIdentifier, middlewares...)
+func Bind[O any](binder *Binder, item *BindItem[O], middlewares ...Middleware) {
+	item.bind(binder, middlewares...)
 }
 
 // BindItem 通用BindItem
@@ -65,7 +65,7 @@ type BindItem[O any] struct {
 	Topic string
 
 	// 响应泛型函数,如果不响应,需要使用NoResponse零值占位
-	SendResponseFunc response.SendResponseFunc[O]
+	SendResponseFunc mqtt_response.SendResponseFunc[O]
 
 	// 使用的请求参数,非必传,当请求参数为nil时,说明该接口没有参数
 	RequestParams any
@@ -87,7 +87,7 @@ type BindItem[O any] struct {
 	ServiceFunc ServiceFunc[O]
 }
 
-func (item *BindItem[O]) bind(binder *Binder, responseIdentifier string, middlewares ...Middleware) {
+func (item *BindItem[O]) bind(binder *Binder, middlewares ...Middleware) {
 	if strutils.IsStringEmpty(item.Topic) {
 		panic("需要指定主题")
 	}
@@ -106,7 +106,7 @@ func (item *BindItem[O]) bind(binder *Binder, responseIdentifier string, middlew
 		panic("bind的输出类型不能使用指针类型")
 	}
 
-	if outputZeroValue.IsValid() && strings.Contains(outputZeroValue.String(), "response.InfosData") {
+	if outputZeroValue.IsValid() && strings.Contains(outputZeroValue.String(), "mqtt_response.InfosData") {
 		infosField := outputZeroValue.FieldByName("Infos")
 		if infosField.IsValid() && infosField.Type().Elem().Kind() == reflect.Pointer {
 			panic("bind的输出类型不能使用指针类型")
@@ -123,6 +123,7 @@ func (item *BindItem[O]) bind(binder *Binder, responseIdentifier string, middlew
 
 	handlers := append(apiMiddlewares, func(c *mqtt_api.Context) {
 		var params any
+		var responseIdentifier string
 
 		// 有请求数据
 		if item.RequestParams != nil {
@@ -141,7 +142,7 @@ func (item *BindItem[O]) bind(binder *Binder, responseIdentifier string, middlew
 			}
 
 			// 将请求数据解析到请求参数中
-			err := request.BindingJson(c, params)
+			err := mqtt_request.BindingJson(c, params)
 			if err != nil {
 				logger.GetInstance().Error(err)
 				item.SendResponseFunc(c, responseIdentifier, outputZero, err)
@@ -186,7 +187,7 @@ func (item *BindItem[O]) bind(binder *Binder, responseIdentifier string, middlew
 					obj := reflect.New(reflectutils.PointerTypeElem(objectType)).Interface().(domain.Object)
 
 					if params != nil {
-						err := request.AssignRequestParamsToDomainObject(params, obj)
+						err := mqtt_request.AssignRequestParamsToDomainObject(params, obj)
 						if err != nil {
 							item.SendResponseFunc(c, responseIdentifier, outputZero, err)
 							return

+ 101 - 0
test/mqtt_binding_test.go

@@ -0,0 +1,101 @@
+package test
+
+import (
+	"fmt"
+	"git.sxidc.com/go-framework/baize"
+	"git.sxidc.com/go-framework/baize/framework/core/application"
+	"git.sxidc.com/go-framework/baize/framework/core/domain"
+	"git.sxidc.com/go-framework/baize/framework/core/infrastructure"
+	"git.sxidc.com/go-framework/baize/framework/core/mqtt_api"
+	"git.sxidc.com/go-framework/baize/framework/core/mqtt_api/mqtt_request"
+	"git.sxidc.com/go-framework/baize/framework/core/mqtt_api/mqtt_response"
+	"git.sxidc.com/go-framework/baize/framework/mqtt_binding"
+	"testing"
+)
+
+type Hello struct {
+	What  string `json:"what"`
+	Reply string `json:"reply"`
+}
+
+func TestMqttBinding(t *testing.T) {
+	app := baize.NewApplication(application.Config{
+		ApiConfig: application.ApiConfig{
+			UrlPrefix: "test",
+			Port:      "10080",
+		},
+		InfrastructureConfig: application.InfrastructureConfig{},
+		MqttApiConfig: &application.MqttApiConfig{
+			TopicPrefix:  "test",
+			LogSkipPaths: []string{"test/version"},
+			MqttConfig: application.MqttConfig{
+				UserName:        "admin",
+				Password:        "mtyzxhc123",
+				Address:         "localhost:1883",
+				ClientID:        "test",
+				KeepAliveSec:    60,
+				PingTimeoutSec:  60,
+				WriteTimeoutSec: 60,
+			},
+		},
+	})
+
+	defer baize.DestroyApplication(app)
+
+	err := app.StartMqttApi()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	defer func(app *application.App) {
+		err := app.FinishMqttApi()
+		if err != nil {
+			t.Fatal(err)
+		}
+	}(app)
+
+	app.MqttApi().Router().AddGlobalHandlers(func(c *mqtt_api.Context) {
+		fmt.Println("Global")
+	})
+
+	mqttBinder := mqtt_binding.NewBinder(app.MqttApi().Router(), app.Infrastructure())
+
+	mqtt_binding.Bind(mqttBinder, &mqtt_binding.BindItem[any]{
+		Topic:            "/version",
+		SendResponseFunc: mqtt_response.SendMsgResponse,
+		ServiceFunc: func(c *mqtt_api.Context, params any, objects []domain.Object, i *infrastructure.Infrastructure) (any, error) {
+			fmt.Println("Version")
+			return nil, nil
+		},
+	}, func(c *mqtt_api.Context, i *infrastructure.Infrastructure) {
+		fmt.Println("Version Middleware")
+	})
+
+	mqtt_binding.Bind(mqttBinder, &mqtt_binding.BindItem[map[string]any]{
+		Topic:            "/hello",
+		SendResponseFunc: mqtt_response.SendMapResponse,
+		RequestParams:    &Hello{},
+		ResponseIdentifierFunc: func(c *mqtt_api.Context, params any) (string, error) {
+			req, err := mqtt_request.ToConcrete[*Hello](params)
+			if err != nil {
+				return "", err
+			}
+
+			return req.Reply, nil
+		},
+		ServiceFunc: func(c *mqtt_api.Context, params any, objects []domain.Object, i *infrastructure.Infrastructure) (map[string]any, error) {
+			req, err := mqtt_request.ToConcrete[*Hello](params)
+			if err != nil {
+				return make(map[string]any), err
+			}
+
+			return map[string]any{"message": "Hello " + req.What}, nil
+		},
+	}, func(c *mqtt_api.Context, i *infrastructure.Infrastructure) {
+		fmt.Println("Hello Middleware")
+	})
+
+	for {
+
+	}
+}