diff --git a/client_test.go b/client_test.go index 9041c7a..19b03b1 100644 --- a/client_test.go +++ b/client_test.go @@ -7,6 +7,7 @@ import ( "log" "testing" + "github.com/gammazero/workerpool" "github.com/stretchr/testify/require" "github.com/xconnio/xconn-go" @@ -17,7 +18,9 @@ func connect(t *testing.T) *xconn.Session { defer func() { _ = listener.Close() }() address := fmt.Sprintf("ws://%s/ws", listener.Addr().String()) - client := &xconn.Client{} + client := &xconn.Client{ + SerializerSpec: xconn.JSONSerializerSpec, + } session, err := client.Connect(context.Background(), address, "realm1") require.NoError(t, err) @@ -54,11 +57,17 @@ func TestRegisterCall(t *testing.T) { require.NotNil(t, reg) t.Run("Call", func(t *testing.T) { - result, err := session.Call(context.Background(), "foo.bar", nil, nil, nil) - require.NoError(t, err) - require.NotNil(t, result) + wp := workerpool.New(10) + for i := 0; i < 100; i++ { + wp.Submit(func() { + result, err := session.Call(context.Background(), "foo.bar", nil, nil, nil) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "hello", result.Args[0]) + }) + } - require.Equal(t, "hello", result.Args[0]) + wp.StopWait() }) } diff --git a/go.mod b/go.mod index a04de23..6e61729 100644 --- a/go.mod +++ b/go.mod @@ -4,14 +4,16 @@ go 1.20 require ( github.com/gammazero/nexus/v3 v3.2.2 + github.com/gammazero/workerpool v1.1.3 github.com/gobwas/ws v1.4.0 github.com/stretchr/testify v1.8.4 - github.com/xconnio/wampproto-go v0.0.0-20240530132134-a9a2ca11944a + github.com/xconnio/wampproto-go v0.0.0-20240530202948-a758eb534226 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/fxamacker/cbor/v2 v2.6.0 // indirect + github.com/gammazero/deque v0.2.0 // indirect github.com/gobwas/httphead v0.1.0 // indirect github.com/gobwas/pool v0.2.1 // indirect github.com/gorilla/websocket v1.5.0 // indirect diff --git a/go.sum b/go.sum index 85e1e7d..08a80c3 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,12 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fxamacker/cbor/v2 v2.6.0 h1:sU6J2usfADwWlYDAFhZBQ6TnLFBHxgesMrQfQgk1tWA= github.com/fxamacker/cbor/v2 v2.6.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= +github.com/gammazero/deque v0.2.0 h1:SkieyNB4bg2/uZZLxvya0Pq6diUlwx7m2TeT7GAIWaA= +github.com/gammazero/deque v0.2.0/go.mod h1:LFroj8x4cMYCukHJDbxFCkT+r9AndaJnFMuZDV34tuU= github.com/gammazero/nexus/v3 v3.2.2 h1:uEBe4rKIcbBcbdP6XuyKUhnWBXxT0BnJrecG9+yZSTs= github.com/gammazero/nexus/v3 v3.2.2/go.mod h1:55oZwPZFgRFCEjpMj1kdzffiPORKKmRsipSY8BeKRvY= +github.com/gammazero/workerpool v1.1.3 h1:WixN4xzukFoN0XSeXF6puqEqFTl2mECI9S6W44HWy9Q= +github.com/gammazero/workerpool v1.1.3/go.mod h1:wPjyBLDbyKnUn2XwwyD3EEwo9dHutia9/fwNmSHWACc= github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU= github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM= github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og= @@ -24,8 +28,8 @@ github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAh github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= -github.com/xconnio/wampproto-go v0.0.0-20240530132134-a9a2ca11944a h1:0Vb6+/sNho0zn7kjRsdBvdGylFkz9uZZzeT1RrgKHIU= -github.com/xconnio/wampproto-go v0.0.0-20240530132134-a9a2ca11944a/go.mod h1:BH0AFRLJ9POvVfxsFd9GyvA15U9o0XYQfq8TdkqO2vQ= +github.com/xconnio/wampproto-go v0.0.0-20240530202948-a758eb534226 h1:1UFs+1ev6G1qDVgf5tGqnhsm5e9btuon41ogyrR1QD4= +github.com/xconnio/wampproto-go v0.0.0-20240530202948-a758eb534226/go.mod h1:BH0AFRLJ9POvVfxsFd9GyvA15U9o0XYQfq8TdkqO2vQ= go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A= golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= diff --git a/peer.go b/peer.go index 177a474..86c90b5 100644 --- a/peer.go +++ b/peer.go @@ -1,6 +1,9 @@ package xconn -import "net" +import ( + "net" + "sync" +) func NewBaseSession(id int64, realm, authID, authRole string, cl Peer) BaseSession { return &baseSession{ @@ -80,6 +83,8 @@ type WebSocketPeer struct { conn net.Conn wsReader ReaderFunc wsWriter WriterFunc + + wm sync.Mutex } func (c *WebSocketPeer) Read() ([]byte, error) { @@ -87,6 +92,9 @@ func (c *WebSocketPeer) Read() ([]byte, error) { } func (c *WebSocketPeer) Write(bytes []byte) error { + c.wm.Lock() + defer c.wm.Unlock() + return c.wsWriter(c.conn, bytes) } diff --git a/session.go b/session.go index 75dead8..2188a93 100644 --- a/session.go +++ b/session.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log" + "sync" "github.com/xconnio/wampproto-go" "github.com/xconnio/wampproto-go/messages" @@ -21,10 +22,10 @@ type Session struct { idGen *wampproto.SessionScopeIDGenerator // remote procedure calls data structures - registerRequests map[int64]chan *RegisterResponse - unregisterRequests map[int64]chan *UnRegisterResponse - registrations map[int64]InvocationHandler - callRequests map[int64]chan *CallResponse + registerRequests sync.Map + unregisterRequests sync.Map + registrations sync.Map + callRequests sync.Map // publish subscribe data structures subscribeRequests map[int64]chan *SubscribeResponse @@ -39,10 +40,10 @@ func NewSession(base BaseSession, serializer serializers.Serializer) *Session { proto: wampproto.NewSession(serializer), idGen: &wampproto.SessionScopeIDGenerator{}, - registerRequests: map[int64]chan *RegisterResponse{}, - unregisterRequests: map[int64]chan *UnRegisterResponse{}, - registrations: map[int64]InvocationHandler{}, - callRequests: map[int64]chan *CallResponse{}, + registerRequests: sync.Map{}, + unregisterRequests: sync.Map{}, + registrations: sync.Map{}, + callRequests: sync.Map{}, subscribeRequests: map[int64]chan *SubscribeResponse{}, unsubscribeRequests: map[int64]chan *UnSubscribeResponse{}, @@ -81,31 +82,35 @@ func (s *Session) processIncomingMessage(msg messages.Message) error { switch msg.Type() { case messages.MessageTypeRegistered: registered := msg.(*messages.Registered) - request, exists := s.registerRequests[registered.RequestID()] + request, exists := s.registerRequests.Load(registered.RequestID()) if !exists { return fmt.Errorf("received REGISTERED for unknown request") } - request <- &RegisterResponse{msg: registered} + requestChan := request.(chan *RegisterResponse) + requestChan <- &RegisterResponse{msg: registered} case messages.MessageTypeUnRegistered: unregistered := msg.(*messages.UnRegistered) - request, exists := s.unregisterRequests[unregistered.RequestID()] + request, exists := s.unregisterRequests.Load(unregistered.RequestID()) if !exists { return fmt.Errorf("received UNREGISTERED for unknown request") } - request <- &UnRegisterResponse{msg: unregistered} + requestChan := request.(chan *UnRegisterResponse) + requestChan <- &UnRegisterResponse{msg: unregistered} case messages.MessageTypeResult: result := msg.(*messages.Result) - request, exists := s.callRequests[result.RequestID()] + request, exists := s.callRequests.Load(result.RequestID()) if !exists { return fmt.Errorf("received RESULT for unknown request") } - request <- &CallResponse{msg: result} + req := request.(chan *CallResponse) + req <- &CallResponse{msg: result} case messages.MessageTypeInvocation: invocation := msg.(*messages.Invocation) - endpoint := s.registrations[invocation.RegistrationID()] + end, _ := s.registrations.Load(invocation.RegistrationID()) + endpoint := end.(InvocationHandler) inv := &Invocation{ Args: invocation.Args(), @@ -167,32 +172,31 @@ func (s *Session) processIncomingMessage(msg messages.Message) error { errorMsg := msg.(*messages.Error) switch errorMsg.MessageType() { case messages.MessageTypeCall: - responseChan, exists := s.callRequests[errorMsg.RequestID()] + response, exists := s.callRequests.LoadAndDelete(errorMsg.RequestID()) if !exists { return fmt.Errorf("received ERROR for invalid call request") } - delete(s.callRequests, errorMsg.RequestID()) err := &Error{URI: errorMsg.URI(), Args: errorMsg.Args(), KwArgs: errorMsg.KwArgs()} + responseChan := response.(chan *CallResponse) responseChan <- &CallResponse{error: err} return nil case messages.MessageTypeRegister: - responseChan, exists := s.registerRequests[errorMsg.RequestID()] + request, exists := s.registerRequests.LoadAndDelete(errorMsg.RequestID()) if !exists { return fmt.Errorf("received ERROR for invalid register request") } - delete(s.registerRequests, errorMsg.RequestID()) err := &Error{URI: errorMsg.URI(), Args: errorMsg.Args(), KwArgs: errorMsg.KwArgs()} - responseChan <- &RegisterResponse{error: err} + requestChan := request.(chan *RegisterResponse) + requestChan <- &RegisterResponse{error: err} return nil case messages.MessageTypeUnRegister: - _, exists := s.unregisterRequests[errorMsg.RequestID()] + _, exists := s.unregisterRequests.LoadAndDelete(errorMsg.RequestID()) if !exists { return fmt.Errorf("received ERROR for invalid unregister request") } - delete(s.unregisterRequests, errorMsg.RequestID()) return nil case messages.MessageTypeSubscribe: _, exists := s.subscribeRequests[errorMsg.RequestID()] @@ -238,8 +242,8 @@ func (s *Session) Register(ctx context.Context, procedure string, handler Invoca } channel := make(chan *RegisterResponse, 1) - s.registerRequests[register.RequestID()] = channel - defer delete(s.registerRequests, register.RequestID()) + s.registerRequests.Store(register.RequestID(), channel) + defer s.registerRequests.Delete(register.RequestID()) if err = s.base.Write(toSend); err != nil { return nil, err @@ -251,7 +255,7 @@ func (s *Session) Register(ctx context.Context, procedure string, handler Invoca return nil, response.error } - s.registrations[response.msg.RegistrationID()] = handler + s.registrations.Store(response.msg.RegistrationID(), handler) registration := &Registration{ ID: response.msg.RegistrationID(), } @@ -269,8 +273,8 @@ func (s *Session) UnRegister(ctx context.Context, registrationID int64) error { } channel := make(chan *UnRegisterResponse, 1) - s.unregisterRequests[unregister.RequestID()] = channel - defer delete(s.unregisterRequests, unregister.RequestID()) + s.unregisterRequests.Store(unregister.RequestID(), channel) + defer s.unregisterRequests.Delete(unregister.RequestID()) if err = s.base.Write(toSend); err != nil { return err @@ -282,7 +286,7 @@ func (s *Session) UnRegister(ctx context.Context, registrationID int64) error { return response.error } - delete(s.registrations, registrationID) + s.registrations.Delete(registrationID) return nil case <-ctx.Done(): return fmt.Errorf("unregister request timed") @@ -299,8 +303,8 @@ func (s *Session) Call(ctx context.Context, procedure string, args []any, kwArgs } channel := make(chan *CallResponse, 1) - s.callRequests[call.RequestID()] = channel - defer delete(s.callRequests, call.RequestID()) + s.callRequests.Store(call.RequestID(), channel) + defer s.callRequests.Delete(call.RequestID()) if err = s.base.Write(toSend); err != nil { return nil, err } diff --git a/types.go b/types.go index b2eb331..3fb0841 100644 --- a/types.go +++ b/types.go @@ -111,12 +111,15 @@ type Result struct { } type Error struct { - error URI string Args []any KwArgs map[string]any } +func (e *Error) Error() string { + return e.URI +} + type RegisterResponse struct { msg *messages.Registered error *Error