package websocket import ( "github.com/olahol/melody" "github.com/pkg/errors" "net/http" "sync" "time" ) type HandleConnectFunc func(context map[string]any) type HandleDisconnectFunc func(context map[string]any) type HandleErrorFunc func(err error, context map[string]any) type HandleCloseFunc func(i int, s string, context map[string]any) error type HandlePongFunc func(context map[string]any) type HandleMessageFunc func(message []byte, context map[string]any) type BroadCastFilterCallback func(context map[string]any) bool var managerInstance *Manager func Init() { if managerInstance == nil { managerInstance = &Manager{ melodyMapMutex: &sync.RWMutex{}, melodyMap: make(map[string]*melody.Melody), } } } func Destroy() { if managerInstance != nil { managerInstance.melodyMapMutex.Lock() defer managerInstance.melodyMapMutex.Unlock() for _, melodyInstance := range managerInstance.melodyMap { err := melodyInstance.Close() if err != nil { panic(err) } } managerInstance.melodyMap = nil managerInstance = nil } managerInstance = nil } func GetInstance() *Manager { return managerInstance } type Manager struct { melodyMapMutex *sync.RWMutex melodyMap map[string]*melody.Melody } func (m *Manager) RegisterHub(groupID string, opts ...InitOption) { m.melodyMapMutex.Lock() defer m.melodyMapMutex.Unlock() _, ok := m.melodyMap[groupID] if ok { return } melodyInstance := melody.New() options := new(InitOptions) for _, opt := range opts { opt(options) } if options.writeWaitSec != 0 { melodyInstance.Config.WriteWait = time.Duration(options.writeWaitSec) * time.Second } if options.pongWaitSec != 0 { melodyInstance.Config.PongWait = time.Duration(options.pongWaitSec) * time.Second } if options.pingPeriodSec != 0 { melodyInstance.Config.PingPeriod = time.Duration(options.pingPeriodSec) * time.Second } if options.maxMessageSize != 0 { melodyInstance.Config.MaxMessageSize = options.maxMessageSize } if options.messageBufferSize != 0 { melodyInstance.Config.MessageBufferSize = options.messageBufferSize } melodyInstance.Config.ConcurrentMessageHandling = options.concurrentMessageHandling melodyInstance.Upgrader.CheckOrigin = func(r *http.Request) bool { return true } m.melodyMap[groupID] = melodyInstance } func (m *Manager) UnregisterHub(groupID string) { m.melodyMapMutex.Lock() defer m.melodyMapMutex.Unlock() melodyInstance, ok := m.melodyMap[groupID] if !ok { return } err := melodyInstance.Close() if err != nil { panic(err) } melodyInstance = nil delete(m.melodyMap, groupID) } func (m *Manager) HandleConnect(groupID string, handleConnectFunc HandleConnectFunc) { m.melodyMapMutex.RLock() defer m.melodyMapMutex.RUnlock() melodyInstance, ok := m.melodyMap[groupID] if !ok { return } melodyInstance.HandleConnect(func(session *melody.Session) { if handleConnectFunc != nil { handleConnectFunc(session.Keys) } }) } func (m *Manager) HandleDisconnect(groupID string, handleDisconnectFunc HandleDisconnectFunc) { m.melodyMapMutex.Lock() defer m.melodyMapMutex.Unlock() melodyInstance, ok := m.melodyMap[groupID] if !ok { return } melodyInstance.HandleDisconnect(func(session *melody.Session) { if handleDisconnectFunc != nil { handleDisconnectFunc(session.Keys) } }) } func (m *Manager) HandleError(groupID string, handleErrorFunc HandleErrorFunc) { m.melodyMapMutex.RLock() defer m.melodyMapMutex.RUnlock() melodyInstance, ok := m.melodyMap[groupID] if !ok { return } melodyInstance.HandleError(func(session *melody.Session, err error) { if handleErrorFunc != nil { handleErrorFunc(err, session.Keys) } }) } func (m *Manager) HandleClose(groupID string, handleCloseFunc HandleCloseFunc) { m.melodyMapMutex.RLock() defer m.melodyMapMutex.RUnlock() melodyInstance, ok := m.melodyMap[groupID] if !ok { return } melodyInstance.HandleClose(func(session *melody.Session, i int, s string) error { if handleCloseFunc != nil { err := handleCloseFunc(i, s, session.Keys) if err != nil { return err } } return nil }) } func (m *Manager) HandlePong(groupID string, handlePongFunc HandlePongFunc) { m.melodyMapMutex.RLock() defer m.melodyMapMutex.RUnlock() melodyInstance, ok := m.melodyMap[groupID] if !ok { return } melodyInstance.HandlePong(func(session *melody.Session) { if handlePongFunc != nil { handlePongFunc(session.Keys) } }) } func (m *Manager) HandleRequest(groupID string, w http.ResponseWriter, r *http.Request, opts ...ConnectionOption) error { m.melodyMapMutex.RLock() defer m.melodyMapMutex.RUnlock() melodyInstance, ok := m.melodyMap[groupID] if !ok { return errors.New("groupID尚未注册") } sessionMap := make(map[string]any) for _, opt := range opts { opt(&sessionMap) } err := melodyInstance.HandleRequestWithKeys(w, r, sessionMap) if err != nil { return err } return nil } func (m *Manager) HandleMessage(groupID string, handleMessageFunc HandleMessageFunc) error { m.melodyMapMutex.RLock() defer m.melodyMapMutex.RUnlock() melodyInstance, ok := m.melodyMap[groupID] if !ok { return errors.New("groupID尚未注册") } melodyInstance.HandleMessage(func(session *melody.Session, bytes []byte) { if handleMessageFunc != nil { handleMessageFunc(bytes, session.Keys) } }) return nil } func (m *Manager) BroadCast(groupID string, msg []byte) error { m.melodyMapMutex.RLock() defer m.melodyMapMutex.RUnlock() melodyInstance, ok := m.melodyMap[groupID] if !ok { return errors.New("groupID尚未注册") } return melodyInstance.Broadcast(msg) } func (m *Manager) BroadCastFilter(groupID string, msg []byte, filterCallback BroadCastFilterCallback) error { m.melodyMapMutex.RLock() defer m.melodyMapMutex.RUnlock() melodyInstance, ok := m.melodyMap[groupID] if !ok { return errors.New("groupID尚未注册") } return melodyInstance.BroadcastFilter(msg, func(session *melody.Session) bool { return filterCallback(session.Keys) }) } type InitOption func(*InitOptions) type InitOptions struct { writeWaitSec int64 pongWaitSec int64 pingPeriodSec int64 maxMessageSize int64 messageBufferSize int concurrentMessageHandling bool } func InitWithWriteWaitSec(writeWaitSec int64) InitOption { return func(options *InitOptions) { options.writeWaitSec = writeWaitSec } } func InitWithPongWaitSec(pongWaitSec int64) InitOption { return func(options *InitOptions) { options.pongWaitSec = pongWaitSec } } func InitWithPingPeriodSec(pingPeriodSec int64) InitOption { return func(options *InitOptions) { options.pingPeriodSec = pingPeriodSec } } func InitWithMaxMessageSize(maxMessageSize int64) InitOption { return func(options *InitOptions) { options.maxMessageSize = maxMessageSize } } func InitWithMaxMessageBufferSize(messageBufferSize int) InitOption { return func(options *InitOptions) { options.messageBufferSize = messageBufferSize } } func InitWithConcurrentMessageHandling(concurrentMessageHandling bool) InitOption { return func(options *InitOptions) { options.concurrentMessageHandling = concurrentMessageHandling } } type ConnectionOption func(sessionMap *map[string]any) func WithConnectionContext(context map[string]any) ConnectionOption { return func(sessionMap *map[string]any) { *sessionMap = context } }