Skip to content

Commit

Permalink
make session concurrency-safe
Browse files Browse the repository at this point in the history
  • Loading branch information
om26er committed May 30, 2024
1 parent 6c23e09 commit 91d825e
Showing 1 changed file with 34 additions and 30 deletions.
64 changes: 34 additions & 30 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"log"
"sync"

"github.com/xconnio/wampproto-go"
"github.com/xconnio/wampproto-go/messages"
Expand All @@ -21,10 +22,10 @@ type Session struct {
idGen *wampproto.SessionScopeIDGenerator

// remote procedure calls data structures
registerRequests map[int64]chan *RegisterResponse
unregisterRequests map[int64]chan *UnRegisterResponse
registrations map[int64]InvocationHandler
callRequests map[int64]chan *CallResponse
registerRequests sync.Map
unregisterRequests sync.Map
registrations sync.Map
callRequests sync.Map

// publish subscribe data structures
subscribeRequests map[int64]chan *SubscribeResponse
Expand All @@ -39,10 +40,10 @@ func NewSession(base BaseSession, serializer serializers.Serializer) *Session {
proto: wampproto.NewSession(serializer),
idGen: &wampproto.SessionScopeIDGenerator{},

registerRequests: map[int64]chan *RegisterResponse{},
unregisterRequests: map[int64]chan *UnRegisterResponse{},
registrations: map[int64]InvocationHandler{},
callRequests: map[int64]chan *CallResponse{},
registerRequests: sync.Map{},
unregisterRequests: sync.Map{},
registrations: sync.Map{},
callRequests: sync.Map{},

subscribeRequests: map[int64]chan *SubscribeResponse{},
unsubscribeRequests: map[int64]chan *UnSubscribeResponse{},
Expand Down Expand Up @@ -81,31 +82,35 @@ func (s *Session) processIncomingMessage(msg messages.Message) error {
switch msg.Type() {
case messages.MessageTypeRegistered:
registered := msg.(*messages.Registered)
request, exists := s.registerRequests[registered.RequestID()]
request, exists := s.registerRequests.Load(registered.RequestID())
if !exists {
return fmt.Errorf("received REGISTERED for unknown request")
}

request <- &RegisterResponse{msg: registered}
requestChan := request.(chan *RegisterResponse)
requestChan <- &RegisterResponse{msg: registered}
case messages.MessageTypeUnRegistered:
unregistered := msg.(*messages.UnRegistered)
request, exists := s.unregisterRequests[unregistered.RequestID()]
request, exists := s.unregisterRequests.Load(unregistered.RequestID())
if !exists {
return fmt.Errorf("received UNREGISTERED for unknown request")
}

request <- &UnRegisterResponse{msg: unregistered}
requestChan := request.(chan *UnRegisterResponse)
requestChan <- &UnRegisterResponse{msg: unregistered}
case messages.MessageTypeResult:
result := msg.(*messages.Result)
request, exists := s.callRequests[result.RequestID()]
request, exists := s.callRequests.Load(result.RequestID())
if !exists {
return fmt.Errorf("received RESULT for unknown request")
}

request <- &CallResponse{msg: result}
req := request.(chan *CallResponse)
req <- &CallResponse{msg: result}
case messages.MessageTypeInvocation:
invocation := msg.(*messages.Invocation)
endpoint := s.registrations[invocation.RegistrationID()]
end, _ := s.registrations.Load(invocation.RegistrationID())
endpoint := end.(InvocationHandler)

inv := &Invocation{
Args: invocation.Args(),
Expand Down Expand Up @@ -167,32 +172,31 @@ func (s *Session) processIncomingMessage(msg messages.Message) error {
errorMsg := msg.(*messages.Error)
switch errorMsg.MessageType() {
case messages.MessageTypeCall:
responseChan, exists := s.callRequests[errorMsg.RequestID()]
response, exists := s.callRequests.LoadAndDelete(errorMsg.RequestID())
if !exists {
return fmt.Errorf("received ERROR for invalid call request")
}

delete(s.callRequests, errorMsg.RequestID())
err := &Error{URI: errorMsg.URI(), Args: errorMsg.Args(), KwArgs: errorMsg.KwArgs()}
responseChan := response.(chan *CallResponse)
responseChan <- &CallResponse{error: err}
return nil
case messages.MessageTypeRegister:
responseChan, exists := s.registerRequests[errorMsg.RequestID()]
request, exists := s.registerRequests.LoadAndDelete(errorMsg.RequestID())
if !exists {
return fmt.Errorf("received ERROR for invalid register request")
}

delete(s.registerRequests, errorMsg.RequestID())
err := &Error{URI: errorMsg.URI(), Args: errorMsg.Args(), KwArgs: errorMsg.KwArgs()}
responseChan <- &RegisterResponse{error: err}
requestChan := request.(chan *RegisterResponse)
requestChan <- &RegisterResponse{error: err}
return nil
case messages.MessageTypeUnRegister:
_, exists := s.unregisterRequests[errorMsg.RequestID()]
_, exists := s.unregisterRequests.LoadAndDelete(errorMsg.RequestID())
if !exists {
return fmt.Errorf("received ERROR for invalid unregister request")
}

delete(s.unregisterRequests, errorMsg.RequestID())
return nil
case messages.MessageTypeSubscribe:
_, exists := s.subscribeRequests[errorMsg.RequestID()]
Expand Down Expand Up @@ -238,8 +242,8 @@ func (s *Session) Register(ctx context.Context, procedure string, handler Invoca
}

channel := make(chan *RegisterResponse, 1)
s.registerRequests[register.RequestID()] = channel
defer delete(s.registerRequests, register.RequestID())
s.registerRequests.Store(register.RequestID(), channel)
defer s.registerRequests.Delete(register.RequestID())

if err = s.base.Write(toSend); err != nil {
return nil, err
Expand All @@ -251,7 +255,7 @@ func (s *Session) Register(ctx context.Context, procedure string, handler Invoca
return nil, response.error
}

s.registrations[response.msg.RegistrationID()] = handler
s.registrations.Store(response.msg.RegistrationID(), handler)
registration := &Registration{
ID: response.msg.RegistrationID(),
}
Expand All @@ -269,8 +273,8 @@ func (s *Session) UnRegister(ctx context.Context, registrationID int64) error {
}

channel := make(chan *UnRegisterResponse, 1)
s.unregisterRequests[unregister.RequestID()] = channel
defer delete(s.unregisterRequests, unregister.RequestID())
s.unregisterRequests.Store(unregister.RequestID(), channel)
defer s.unregisterRequests.Delete(unregister.RequestID())

if err = s.base.Write(toSend); err != nil {
return err
Expand All @@ -282,7 +286,7 @@ func (s *Session) UnRegister(ctx context.Context, registrationID int64) error {
return response.error
}

delete(s.registrations, registrationID)
s.registrations.Delete(registrationID)
return nil
case <-ctx.Done():
return fmt.Errorf("unregister request timed")
Expand All @@ -299,8 +303,8 @@ func (s *Session) Call(ctx context.Context, procedure string, args []any, kwArgs
}

channel := make(chan *CallResponse, 1)
s.callRequests[call.RequestID()] = channel
defer delete(s.callRequests, call.RequestID())
s.callRequests.Store(call.RequestID(), channel)
defer s.callRequests.Delete(call.RequestID())
if err = s.base.Write(toSend); err != nil {
return nil, err
}
Expand Down

0 comments on commit 91d825e

Please sign in to comment.