websocket.go 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310
  1. package websocket
  2. import (
  3. "github.com/olahol/melody"
  4. "github.com/pkg/errors"
  5. "net/http"
  6. "sync"
  7. "time"
  8. )
  // Callback signatures accepted by the Manager's Handle* methods. In every
// callback, context is the per-session key map that the Manager passes
// through from melody (the map built by Manager.HandleRequest).
type (
	// HandleConnectFunc is the callback registered via Manager.HandleConnect.
	HandleConnectFunc func(context map[string]any)
	// HandleDisconnectFunc is the callback registered via Manager.HandleDisconnect.
	HandleDisconnectFunc func(context map[string]any)
	// HandleErrorFunc is the callback registered via Manager.HandleError; it
	// receives the error melody reported for the session.
	HandleErrorFunc func(err error, context map[string]any)
	// HandleCloseFunc is the callback registered via Manager.HandleClose; i and
	// s are the two values melody passes to its close handler (the close code
	// and text, per melody's API). The returned error is handed back to melody.
	HandleCloseFunc func(i int, s string, context map[string]any) error
	// HandlePongFunc is the callback registered via Manager.HandlePong.
	HandlePongFunc func(context map[string]any)
	// HandleMessageFunc is the callback registered via Manager.HandleMessage;
	// message is the received payload.
	HandleMessageFunc func(message []byte, context map[string]any)
)
  15. var managerInstance *Manager
  16. func Init() {
  17. if managerInstance == nil {
  18. managerInstance = &Manager{
  19. melodyMapMutex: &sync.RWMutex{},
  20. melodyMap: make(map[string]*melody.Melody),
  21. }
  22. }
  23. }
  24. func Destroy() {
  25. if managerInstance != nil {
  26. managerInstance.melodyMapMutex.Lock()
  27. defer managerInstance.melodyMapMutex.Unlock()
  28. for _, melodyInstance := range managerInstance.melodyMap {
  29. err := melodyInstance.Close()
  30. if err != nil {
  31. panic(err)
  32. }
  33. }
  34. managerInstance.melodyMap = nil
  35. managerInstance = nil
  36. }
  37. managerInstance = nil
  38. }
  39. func GetInstance() *Manager {
  40. return managerInstance
  41. }
  42. type Manager struct {
  43. melodyMapMutex *sync.RWMutex
  44. melodyMap map[string]*melody.Melody
  45. }
  46. func (m *Manager) RegisterHub(groupID string, opts ...InitOption) {
  47. m.melodyMapMutex.Lock()
  48. defer m.melodyMapMutex.Unlock()
  49. _, ok := m.melodyMap[groupID]
  50. if ok {
  51. return
  52. }
  53. melodyInstance := melody.New()
  54. options := new(InitOptions)
  55. for _, opt := range opts {
  56. opt(options)
  57. }
  58. if options.writeWaitSec != 0 {
  59. melodyInstance.Config.WriteWait = time.Duration(options.writeWaitSec) * time.Second
  60. }
  61. if options.pongWaitSec != 0 {
  62. melodyInstance.Config.PongWait = time.Duration(options.pongWaitSec) * time.Second
  63. }
  64. if options.pingPeriodSec != 0 {
  65. melodyInstance.Config.PingPeriod = time.Duration(options.pingPeriodSec) * time.Second
  66. }
  67. if options.maxMessageSize != 0 {
  68. melodyInstance.Config.MaxMessageSize = options.maxMessageSize
  69. }
  70. if options.messageBufferSize != 0 {
  71. melodyInstance.Config.MessageBufferSize = options.messageBufferSize
  72. }
  73. melodyInstance.Config.ConcurrentMessageHandling = options.concurrentMessageHandling
  74. melodyInstance.Upgrader.CheckOrigin = func(r *http.Request) bool { return true }
  75. m.melodyMap[groupID] = melodyInstance
  76. }
  77. func (m *Manager) UnregisterHub(groupID string) {
  78. m.melodyMapMutex.Lock()
  79. defer m.melodyMapMutex.Unlock()
  80. melodyInstance, ok := m.melodyMap[groupID]
  81. if !ok {
  82. return
  83. }
  84. err := melodyInstance.Close()
  85. if err != nil {
  86. panic(err)
  87. }
  88. melodyInstance = nil
  89. delete(m.melodyMap, groupID)
  90. }
  91. func (m *Manager) HandleConnect(groupID string, handleConnectFunc HandleConnectFunc) {
  92. m.melodyMapMutex.RLock()
  93. defer m.melodyMapMutex.RUnlock()
  94. melodyInstance, ok := m.melodyMap[groupID]
  95. if !ok {
  96. return
  97. }
  98. melodyInstance.HandleConnect(func(session *melody.Session) {
  99. if handleConnectFunc != nil {
  100. handleConnectFunc(session.Keys)
  101. }
  102. })
  103. }
  104. func (m *Manager) HandleDisconnect(groupID string, handleDisconnectFunc HandleDisconnectFunc) {
  105. m.melodyMapMutex.Lock()
  106. defer m.melodyMapMutex.Unlock()
  107. melodyInstance, ok := m.melodyMap[groupID]
  108. if !ok {
  109. return
  110. }
  111. melodyInstance.HandleDisconnect(func(session *melody.Session) {
  112. if handleDisconnectFunc != nil {
  113. handleDisconnectFunc(session.Keys)
  114. }
  115. })
  116. }
  117. func (m *Manager) HandleError(groupID string, handleErrorFunc HandleErrorFunc) {
  118. m.melodyMapMutex.RLock()
  119. defer m.melodyMapMutex.RUnlock()
  120. melodyInstance, ok := m.melodyMap[groupID]
  121. if !ok {
  122. return
  123. }
  124. melodyInstance.HandleError(func(session *melody.Session, err error) {
  125. if handleErrorFunc != nil {
  126. handleErrorFunc(err, session.Keys)
  127. }
  128. })
  129. }
  130. func (m *Manager) HandleClose(groupID string, handleCloseFunc HandleCloseFunc) {
  131. m.melodyMapMutex.RLock()
  132. defer m.melodyMapMutex.RUnlock()
  133. melodyInstance, ok := m.melodyMap[groupID]
  134. if !ok {
  135. return
  136. }
  137. melodyInstance.HandleClose(func(session *melody.Session, i int, s string) error {
  138. if handleCloseFunc != nil {
  139. err := handleCloseFunc(i, s, session.Keys)
  140. if err != nil {
  141. return err
  142. }
  143. }
  144. return nil
  145. })
  146. }
  147. func (m *Manager) HandlePong(groupID string, handlePongFunc HandlePongFunc) {
  148. m.melodyMapMutex.RLock()
  149. defer m.melodyMapMutex.RUnlock()
  150. melodyInstance, ok := m.melodyMap[groupID]
  151. if !ok {
  152. return
  153. }
  154. melodyInstance.HandlePong(func(session *melody.Session) {
  155. if handlePongFunc != nil {
  156. handlePongFunc(session.Keys)
  157. }
  158. })
  159. }
  160. func (m *Manager) HandleRequest(groupID string, w http.ResponseWriter, r *http.Request, opts ...ConnectionOption) error {
  161. m.melodyMapMutex.RLock()
  162. defer m.melodyMapMutex.RUnlock()
  163. melodyInstance, ok := m.melodyMap[groupID]
  164. if !ok {
  165. return errors.New("groupID尚未注册")
  166. }
  167. sessionMap := map[string]any{}
  168. for _, opt := range opts {
  169. opt(&sessionMap)
  170. }
  171. err := melodyInstance.HandleRequestWithKeys(w, r, sessionMap)
  172. if err != nil {
  173. return err
  174. }
  175. return nil
  176. }
  177. func (m *Manager) HandleMessage(groupID string, handleMessageFunc HandleMessageFunc) error {
  178. m.melodyMapMutex.RLock()
  179. defer m.melodyMapMutex.RUnlock()
  180. melodyInstance, ok := m.melodyMap[groupID]
  181. if !ok {
  182. return errors.New("groupID尚未注册")
  183. }
  184. melodyInstance.HandleMessage(func(session *melody.Session, bytes []byte) {
  185. if handleMessageFunc != nil {
  186. handleMessageFunc(bytes, session.Keys)
  187. }
  188. })
  189. return nil
  190. }
  191. func (m *Manager) BroadCast(groupID string, msg []byte) error {
  192. m.melodyMapMutex.RLock()
  193. defer m.melodyMapMutex.RUnlock()
  194. melodyInstance, ok := m.melodyMap[groupID]
  195. if !ok {
  196. return errors.New("groupID尚未注册")
  197. }
  198. return melodyInstance.Broadcast(msg)
  199. }
  // InitOption is a functional option for Manager.RegisterHub.
type InitOption func(*InitOptions)

// InitOptions collects the settings RegisterHub applies to a new hub. A zero
// numeric field leaves melody's default in place. concurrentMessageHandling is
// always copied into the hub's config.
type InitOptions struct {
	writeWaitSec              int64 // Config.WriteWait, in seconds
	pongWaitSec               int64 // Config.PongWait, in seconds
	pingPeriodSec             int64 // Config.PingPeriod, in seconds
	maxMessageSize            int64 // Config.MaxMessageSize
	messageBufferSize         int   // Config.MessageBufferSize
	concurrentMessageHandling bool  // Config.ConcurrentMessageHandling
}

// InitWithWriteWaitSec sets the hub's write wait, in whole seconds.
func InitWithWriteWaitSec(writeWaitSec int64) InitOption {
	return func(o *InitOptions) { o.writeWaitSec = writeWaitSec }
}

// InitWithPongWaitSec sets the hub's pong wait, in whole seconds.
func InitWithPongWaitSec(pongWaitSec int64) InitOption {
	return func(o *InitOptions) { o.pongWaitSec = pongWaitSec }
}

// InitWithPingPeriodSec sets the hub's ping period, in whole seconds.
func InitWithPingPeriodSec(pingPeriodSec int64) InitOption {
	return func(o *InitOptions) { o.pingPeriodSec = pingPeriodSec }
}

// InitWithMaxMessageSize sets the hub's maximum message size.
func InitWithMaxMessageSize(maxMessageSize int64) InitOption {
	return func(o *InitOptions) { o.maxMessageSize = maxMessageSize }
}

// InitWithMaxMessageBufferSize sets the hub's message buffer size.
func InitWithMaxMessageBufferSize(messageBufferSize int) InitOption {
	return func(o *InitOptions) { o.messageBufferSize = messageBufferSize }
}

// InitWithConcurrentMessageHandling enables or disables concurrent message
// handling on the hub.
func InitWithConcurrentMessageHandling(concurrentMessageHandling bool) InitOption {
	return func(o *InitOptions) { o.concurrentMessageHandling = concurrentMessageHandling }
}
  // ConnectionOption customizes the session key map that Manager.HandleRequest
// attaches to a new connection.
type ConnectionOption func(sessionMap *map[string]any)

// WithConnectionContext copies every entry of context into the session key
// map, so the values reach the Handle* callbacks of that connection as their
// context argument. Several options merge, and later ones win on duplicate
// keys. A nil context adds nothing.
//
// The previous body assigned to the local parameter (sessionMap = &context).
// That has no effect on the caller's map, so the context was silently
// dropped. Entries are copied rather than aliased, which keeps the session
// map and the caller's map independent.
func WithConnectionContext(context map[string]any) ConnectionOption {
	return func(sessionMap *map[string]any) {
		if sessionMap == nil {
			return
		}
		if *sessionMap == nil {
			*sessionMap = make(map[string]any, len(context))
		}
		for k, v := range context {
			(*sessionMap)[k] = v
		}
	}
}