websocket.go 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325
  1. package websocket
  2. import (
  3. "github.com/olahol/melody"
  4. "github.com/pkg/errors"
  5. "net/http"
  6. "sync"
  7. "time"
  8. )
  9. type HandleConnectFunc func(context map[string]any)
  10. type HandleDisconnectFunc func(context map[string]any)
  11. type HandleErrorFunc func(err error, context map[string]any)
  12. type HandleCloseFunc func(i int, s string, context map[string]any) error
  13. type HandlePongFunc func(context map[string]any)
  14. type HandleMessageFunc func(message []byte, context map[string]any)
  15. type BroadCastFilterCallback func(context map[string]any) bool
  16. var managerInstance *Manager
  17. func Init() {
  18. if managerInstance == nil {
  19. managerInstance = &Manager{
  20. melodyMapMutex: &sync.RWMutex{},
  21. melodyMap: make(map[string]*melody.Melody),
  22. }
  23. }
  24. }
  25. func Destroy() {
  26. if managerInstance != nil {
  27. managerInstance.melodyMapMutex.Lock()
  28. defer managerInstance.melodyMapMutex.Unlock()
  29. for _, melodyInstance := range managerInstance.melodyMap {
  30. err := melodyInstance.Close()
  31. if err != nil {
  32. panic(err)
  33. }
  34. }
  35. managerInstance.melodyMap = nil
  36. managerInstance = nil
  37. }
  38. managerInstance = nil
  39. }
  40. func GetInstance() *Manager {
  41. return managerInstance
  42. }
  43. type Manager struct {
  44. melodyMapMutex *sync.RWMutex
  45. melodyMap map[string]*melody.Melody
  46. }
  47. func (m *Manager) RegisterHub(groupID string, opts ...InitOption) {
  48. m.melodyMapMutex.Lock()
  49. defer m.melodyMapMutex.Unlock()
  50. _, ok := m.melodyMap[groupID]
  51. if ok {
  52. return
  53. }
  54. melodyInstance := melody.New()
  55. options := new(InitOptions)
  56. for _, opt := range opts {
  57. opt(options)
  58. }
  59. if options.writeWaitSec != 0 {
  60. melodyInstance.Config.WriteWait = time.Duration(options.writeWaitSec) * time.Second
  61. }
  62. if options.pongWaitSec != 0 {
  63. melodyInstance.Config.PongWait = time.Duration(options.pongWaitSec) * time.Second
  64. }
  65. if options.pingPeriodSec != 0 {
  66. melodyInstance.Config.PingPeriod = time.Duration(options.pingPeriodSec) * time.Second
  67. }
  68. if options.maxMessageSize != 0 {
  69. melodyInstance.Config.MaxMessageSize = options.maxMessageSize
  70. }
  71. if options.messageBufferSize != 0 {
  72. melodyInstance.Config.MessageBufferSize = options.messageBufferSize
  73. }
  74. melodyInstance.Config.ConcurrentMessageHandling = options.concurrentMessageHandling
  75. melodyInstance.Upgrader.CheckOrigin = func(r *http.Request) bool { return true }
  76. m.melodyMap[groupID] = melodyInstance
  77. }
  78. func (m *Manager) UnregisterHub(groupID string) {
  79. m.melodyMapMutex.Lock()
  80. defer m.melodyMapMutex.Unlock()
  81. melodyInstance, ok := m.melodyMap[groupID]
  82. if !ok {
  83. return
  84. }
  85. err := melodyInstance.Close()
  86. if err != nil {
  87. panic(err)
  88. }
  89. melodyInstance = nil
  90. delete(m.melodyMap, groupID)
  91. }
  92. func (m *Manager) HandleConnect(groupID string, handleConnectFunc HandleConnectFunc) {
  93. m.melodyMapMutex.RLock()
  94. defer m.melodyMapMutex.RUnlock()
  95. melodyInstance, ok := m.melodyMap[groupID]
  96. if !ok {
  97. return
  98. }
  99. melodyInstance.HandleConnect(func(session *melody.Session) {
  100. if handleConnectFunc != nil {
  101. handleConnectFunc(session.Keys)
  102. }
  103. })
  104. }
  105. func (m *Manager) HandleDisconnect(groupID string, handleDisconnectFunc HandleDisconnectFunc) {
  106. m.melodyMapMutex.Lock()
  107. defer m.melodyMapMutex.Unlock()
  108. melodyInstance, ok := m.melodyMap[groupID]
  109. if !ok {
  110. return
  111. }
  112. melodyInstance.HandleDisconnect(func(session *melody.Session) {
  113. if handleDisconnectFunc != nil {
  114. handleDisconnectFunc(session.Keys)
  115. }
  116. })
  117. }
  118. func (m *Manager) HandleError(groupID string, handleErrorFunc HandleErrorFunc) {
  119. m.melodyMapMutex.RLock()
  120. defer m.melodyMapMutex.RUnlock()
  121. melodyInstance, ok := m.melodyMap[groupID]
  122. if !ok {
  123. return
  124. }
  125. melodyInstance.HandleError(func(session *melody.Session, err error) {
  126. if handleErrorFunc != nil {
  127. handleErrorFunc(err, session.Keys)
  128. }
  129. })
  130. }
  131. func (m *Manager) HandleClose(groupID string, handleCloseFunc HandleCloseFunc) {
  132. m.melodyMapMutex.RLock()
  133. defer m.melodyMapMutex.RUnlock()
  134. melodyInstance, ok := m.melodyMap[groupID]
  135. if !ok {
  136. return
  137. }
  138. melodyInstance.HandleClose(func(session *melody.Session, i int, s string) error {
  139. if handleCloseFunc != nil {
  140. err := handleCloseFunc(i, s, session.Keys)
  141. if err != nil {
  142. return err
  143. }
  144. }
  145. return nil
  146. })
  147. }
  148. func (m *Manager) HandlePong(groupID string, handlePongFunc HandlePongFunc) {
  149. m.melodyMapMutex.RLock()
  150. defer m.melodyMapMutex.RUnlock()
  151. melodyInstance, ok := m.melodyMap[groupID]
  152. if !ok {
  153. return
  154. }
  155. melodyInstance.HandlePong(func(session *melody.Session) {
  156. if handlePongFunc != nil {
  157. handlePongFunc(session.Keys)
  158. }
  159. })
  160. }
  161. func (m *Manager) HandleRequest(groupID string, w http.ResponseWriter, r *http.Request, opts ...ConnectionOption) error {
  162. m.melodyMapMutex.RLock()
  163. defer m.melodyMapMutex.RUnlock()
  164. melodyInstance, ok := m.melodyMap[groupID]
  165. if !ok {
  166. return errors.New("groupID尚未注册")
  167. }
  168. sessionMap := map[string]any{}
  169. for _, opt := range opts {
  170. opt(&sessionMap)
  171. }
  172. err := melodyInstance.HandleRequestWithKeys(w, r, sessionMap)
  173. if err != nil {
  174. return err
  175. }
  176. return nil
  177. }
  178. func (m *Manager) HandleMessage(groupID string, handleMessageFunc HandleMessageFunc) error {
  179. m.melodyMapMutex.RLock()
  180. defer m.melodyMapMutex.RUnlock()
  181. melodyInstance, ok := m.melodyMap[groupID]
  182. if !ok {
  183. return errors.New("groupID尚未注册")
  184. }
  185. melodyInstance.HandleMessage(func(session *melody.Session, bytes []byte) {
  186. if handleMessageFunc != nil {
  187. handleMessageFunc(bytes, session.Keys)
  188. }
  189. })
  190. return nil
  191. }
  192. func (m *Manager) BroadCast(groupID string, msg []byte) error {
  193. m.melodyMapMutex.RLock()
  194. defer m.melodyMapMutex.RUnlock()
  195. melodyInstance, ok := m.melodyMap[groupID]
  196. if !ok {
  197. return errors.New("groupID尚未注册")
  198. }
  199. return melodyInstance.Broadcast(msg)
  200. }
  201. func (m *Manager) BroadCastFilter(groupID string, msg []byte, filterCallback BroadCastFilterCallback) error {
  202. m.melodyMapMutex.RLock()
  203. defer m.melodyMapMutex.RUnlock()
  204. melodyInstance, ok := m.melodyMap[groupID]
  205. if !ok {
  206. return errors.New("groupID尚未注册")
  207. }
  208. return melodyInstance.BroadcastFilter(msg, func(session *melody.Session) bool {
  209. return filterCallback(session.Keys)
  210. })
  211. }
  212. type InitOption func(*InitOptions)
  213. type InitOptions struct {
  214. writeWaitSec int64
  215. pongWaitSec int64
  216. pingPeriodSec int64
  217. maxMessageSize int64
  218. messageBufferSize int
  219. concurrentMessageHandling bool
  220. }
  221. func InitWithWriteWaitSec(writeWaitSec int64) InitOption {
  222. return func(options *InitOptions) {
  223. options.writeWaitSec = writeWaitSec
  224. }
  225. }
  226. func InitWithPongWaitSec(pongWaitSec int64) InitOption {
  227. return func(options *InitOptions) {
  228. options.pongWaitSec = pongWaitSec
  229. }
  230. }
  231. func InitWithPingPeriodSec(pingPeriodSec int64) InitOption {
  232. return func(options *InitOptions) {
  233. options.pingPeriodSec = pingPeriodSec
  234. }
  235. }
  236. func InitWithMaxMessageSize(maxMessageSize int64) InitOption {
  237. return func(options *InitOptions) {
  238. options.maxMessageSize = maxMessageSize
  239. }
  240. }
  241. func InitWithMaxMessageBufferSize(messageBufferSize int) InitOption {
  242. return func(options *InitOptions) {
  243. options.messageBufferSize = messageBufferSize
  244. }
  245. }
  246. func InitWithConcurrentMessageHandling(concurrentMessageHandling bool) InitOption {
  247. return func(options *InitOptions) {
  248. options.concurrentMessageHandling = concurrentMessageHandling
  249. }
  250. }
  251. type ConnectionOption func(sessionMap *map[string]any)
  252. func WithConnectionContext(context map[string]any) ConnectionOption {
  253. return func(sessionMap *map[string]any) {
  254. sessionMap = &context
  255. }
  256. }