From 91d825ef04c41efb3aeae21a9b796672f10fd229 Mon Sep 17 00:00:00 2001 From: Omer Akram Date: Fri, 31 May 2024 01:35:38 +0500 Subject: [PATCH] make session concurrency-safe --- session.go | 64 +++++++++++++++++++++++++++++------------------------- 1 file changed, 34 insertions(+), 30 deletions(-) diff --git a/session.go b/session.go index 75dead8..2188a93 100644 --- a/session.go +++ b/session.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log" + "sync" "github.com/xconnio/wampproto-go" "github.com/xconnio/wampproto-go/messages" @@ -21,10 +22,10 @@ type Session struct { idGen *wampproto.SessionScopeIDGenerator // remote procedure calls data structures - registerRequests map[int64]chan *RegisterResponse - unregisterRequests map[int64]chan *UnRegisterResponse - registrations map[int64]InvocationHandler - callRequests map[int64]chan *CallResponse + registerRequests sync.Map + unregisterRequests sync.Map + registrations sync.Map + callRequests sync.Map // publish subscribe data structures subscribeRequests map[int64]chan *SubscribeResponse @@ -39,10 +40,10 @@ func NewSession(base BaseSession, serializer serializers.Serializer) *Session { proto: wampproto.NewSession(serializer), idGen: &wampproto.SessionScopeIDGenerator{}, - registerRequests: map[int64]chan *RegisterResponse{}, - unregisterRequests: map[int64]chan *UnRegisterResponse{}, - registrations: map[int64]InvocationHandler{}, - callRequests: map[int64]chan *CallResponse{}, + registerRequests: sync.Map{}, + unregisterRequests: sync.Map{}, + registrations: sync.Map{}, + callRequests: sync.Map{}, subscribeRequests: map[int64]chan *SubscribeResponse{}, unsubscribeRequests: map[int64]chan *UnSubscribeResponse{}, @@ -81,31 +82,35 @@ func (s *Session) processIncomingMessage(msg messages.Message) error { switch msg.Type() { case messages.MessageTypeRegistered: registered := msg.(*messages.Registered) - request, exists := s.registerRequests[registered.RequestID()] + request, exists := s.registerRequests.Load(registered.RequestID()) if !exists { return fmt.Errorf("received REGISTERED for unknown request") } - request <- &RegisterResponse{msg: registered} + requestChan := request.(chan *RegisterResponse) + requestChan <- &RegisterResponse{msg: registered} case messages.MessageTypeUnRegistered: unregistered := msg.(*messages.UnRegistered) - request, exists := s.unregisterRequests[unregistered.RequestID()] + request, exists := s.unregisterRequests.Load(unregistered.RequestID()) if !exists { return fmt.Errorf("received UNREGISTERED for unknown request") } - request <- &UnRegisterResponse{msg: unregistered} + requestChan := request.(chan *UnRegisterResponse) + requestChan <- &UnRegisterResponse{msg: unregistered} case messages.MessageTypeResult: result := msg.(*messages.Result) - request, exists := s.callRequests[result.RequestID()] + request, exists := s.callRequests.Load(result.RequestID()) if !exists { return fmt.Errorf("received RESULT for unknown request") } - request <- &CallResponse{msg: result} + req := request.(chan *CallResponse) + req <- &CallResponse{msg: result} case messages.MessageTypeInvocation: invocation := msg.(*messages.Invocation) - endpoint := s.registrations[invocation.RegistrationID()] + end, _ := s.registrations.Load(invocation.RegistrationID()) + endpoint := end.(InvocationHandler) inv := &Invocation{ Args: invocation.Args(), @@ -167,32 +172,31 @@ func (s *Session) processIncomingMessage(msg messages.Message) error { errorMsg := msg.(*messages.Error) switch errorMsg.MessageType() { case messages.MessageTypeCall: - responseChan, exists := s.callRequests[errorMsg.RequestID()] + response, exists := s.callRequests.LoadAndDelete(errorMsg.RequestID()) if !exists { return fmt.Errorf("received ERROR for invalid call request") } - delete(s.callRequests, errorMsg.RequestID()) err := &Error{URI: errorMsg.URI(), Args: errorMsg.Args(), KwArgs: errorMsg.KwArgs()} + responseChan := response.(chan *CallResponse) responseChan <- &CallResponse{error: err} return nil case messages.MessageTypeRegister: - responseChan, exists := s.registerRequests[errorMsg.RequestID()] + request, exists := s.registerRequests.LoadAndDelete(errorMsg.RequestID()) if !exists { return fmt.Errorf("received ERROR for invalid register request") } - delete(s.registerRequests, errorMsg.RequestID()) err := &Error{URI: errorMsg.URI(), Args: errorMsg.Args(), KwArgs: errorMsg.KwArgs()} - responseChan <- &RegisterResponse{error: err} + requestChan := request.(chan *RegisterResponse) + requestChan <- &RegisterResponse{error: err} return nil case messages.MessageTypeUnRegister: - _, exists := s.unregisterRequests[errorMsg.RequestID()] + _, exists := s.unregisterRequests.LoadAndDelete(errorMsg.RequestID()) if !exists { return fmt.Errorf("received ERROR for invalid unregister request") } - delete(s.unregisterRequests, errorMsg.RequestID()) return nil case messages.MessageTypeSubscribe: _, exists := s.subscribeRequests[errorMsg.RequestID()] @@ -238,8 +242,8 @@ func (s *Session) Register(ctx context.Context, procedure string, handler Invoca } channel := make(chan *RegisterResponse, 1) - s.registerRequests[register.RequestID()] = channel - defer delete(s.registerRequests, register.RequestID()) + s.registerRequests.Store(register.RequestID(), channel) + defer s.registerRequests.Delete(register.RequestID()) if err = s.base.Write(toSend); err != nil { return nil, err @@ -251,7 +255,7 @@ func (s *Session) Register(ctx context.Context, procedure string, handler Invoca return nil, response.error } - s.registrations[response.msg.RegistrationID()] = handler + s.registrations.Store(response.msg.RegistrationID(), handler) registration := &Registration{ ID: response.msg.RegistrationID(), } @@ -269,8 +273,8 @@ func (s *Session) UnRegister(ctx context.Context, registrationID int64) error { } channel := make(chan *UnRegisterResponse, 1) - s.unregisterRequests[unregister.RequestID()] = channel - defer delete(s.unregisterRequests, unregister.RequestID()) + s.unregisterRequests.Store(unregister.RequestID(), channel) + defer s.unregisterRequests.Delete(unregister.RequestID()) if err = s.base.Write(toSend); err != nil { return err @@ -282,7 +286,7 @@ func (s *Session) UnRegister(ctx context.Context, registrationID int64) error { return response.error } - delete(s.registrations, registrationID) + s.registrations.Delete(registrationID) return nil case <-ctx.Done(): return fmt.Errorf("unregister request timed") @@ -299,8 +303,8 @@ func (s *Session) Call(ctx context.Context, procedure string, args []any, kwArgs } channel := make(chan *CallResponse, 1) - s.callRequests[call.RequestID()] = channel - defer delete(s.callRequests, call.RequestID()) + s.callRequests.Store(call.RequestID(), channel) + defer s.callRequests.Delete(call.RequestID()) if err = s.base.Write(toSend); err != nil { return nil, err }