Skip to content

Commit

Permalink
Merge pull request #5 from xconnio/concurrency
Browse files Browse the repository at this point in the history
Concurrency fixes
  • Loading branch information
om26er authored May 30, 2024
2 parents 37e0bc0 + c32fa15 commit 31f6e7e
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 40 deletions.
19 changes: 14 additions & 5 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"log"
"testing"

"github.com/gammazero/workerpool"
"github.com/stretchr/testify/require"

"github.com/xconnio/xconn-go"
Expand All @@ -17,7 +18,9 @@ func connect(t *testing.T) *xconn.Session {
defer func() { _ = listener.Close() }()
address := fmt.Sprintf("ws://%s/ws", listener.Addr().String())

client := &xconn.Client{}
client := &xconn.Client{
SerializerSpec: xconn.JSONSerializerSpec,
}

session, err := client.Connect(context.Background(), address, "realm1")
require.NoError(t, err)
Expand Down Expand Up @@ -54,11 +57,17 @@ func TestRegisterCall(t *testing.T) {
require.NotNil(t, reg)

t.Run("Call", func(t *testing.T) {
result, err := session.Call(context.Background(), "foo.bar", nil, nil, nil)
require.NoError(t, err)
require.NotNil(t, result)
wp := workerpool.New(10)
for i := 0; i < 100; i++ {
wp.Submit(func() {
result, err := session.Call(context.Background(), "foo.bar", nil, nil, nil)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "hello", result.Args[0])
})
}

require.Equal(t, "hello", result.Args[0])
wp.StopWait()
})
}

Expand Down
4 changes: 3 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@ go 1.20

require (
github.com/gammazero/nexus/v3 v3.2.2
github.com/gammazero/workerpool v1.1.3
github.com/gobwas/ws v1.4.0
github.com/stretchr/testify v1.8.4
github.com/xconnio/wampproto-go v0.0.0-20240530132134-a9a2ca11944a
github.com/xconnio/wampproto-go v0.0.0-20240530202948-a758eb534226
)

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/fxamacker/cbor/v2 v2.6.0 // indirect
github.com/gammazero/deque v0.2.0 // indirect
github.com/gobwas/httphead v0.1.0 // indirect
github.com/gobwas/pool v0.2.1 // indirect
github.com/gorilla/websocket v1.5.0 // indirect
Expand Down
8 changes: 6 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fxamacker/cbor/v2 v2.6.0 h1:sU6J2usfADwWlYDAFhZBQ6TnLFBHxgesMrQfQgk1tWA=
github.com/fxamacker/cbor/v2 v2.6.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ=
github.com/gammazero/deque v0.2.0 h1:SkieyNB4bg2/uZZLxvya0Pq6diUlwx7m2TeT7GAIWaA=
github.com/gammazero/deque v0.2.0/go.mod h1:LFroj8x4cMYCukHJDbxFCkT+r9AndaJnFMuZDV34tuU=
github.com/gammazero/nexus/v3 v3.2.2 h1:uEBe4rKIcbBcbdP6XuyKUhnWBXxT0BnJrecG9+yZSTs=
github.com/gammazero/nexus/v3 v3.2.2/go.mod h1:55oZwPZFgRFCEjpMj1kdzffiPORKKmRsipSY8BeKRvY=
github.com/gammazero/workerpool v1.1.3 h1:WixN4xzukFoN0XSeXF6puqEqFTl2mECI9S6W44HWy9Q=
github.com/gammazero/workerpool v1.1.3/go.mod h1:wPjyBLDbyKnUn2XwwyD3EEwo9dHutia9/fwNmSHWACc=
github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU=
github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM=
github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og=
Expand All @@ -24,8 +28,8 @@ github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAh
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/xconnio/wampproto-go v0.0.0-20240530132134-a9a2ca11944a h1:0Vb6+/sNho0zn7kjRsdBvdGylFkz9uZZzeT1RrgKHIU=
github.com/xconnio/wampproto-go v0.0.0-20240530132134-a9a2ca11944a/go.mod h1:BH0AFRLJ9POvVfxsFd9GyvA15U9o0XYQfq8TdkqO2vQ=
github.com/xconnio/wampproto-go v0.0.0-20240530202948-a758eb534226 h1:1UFs+1ev6G1qDVgf5tGqnhsm5e9btuon41ogyrR1QD4=
github.com/xconnio/wampproto-go v0.0.0-20240530202948-a758eb534226/go.mod h1:BH0AFRLJ9POvVfxsFd9GyvA15U9o0XYQfq8TdkqO2vQ=
go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A=
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
Expand Down
10 changes: 9 additions & 1 deletion peer.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package xconn

import "net"
import (
"net"
"sync"
)

func NewBaseSession(id int64, realm, authID, authRole string, cl Peer) BaseSession {
return &baseSession{
Expand Down Expand Up @@ -80,13 +83,18 @@ type WebSocketPeer struct {
conn net.Conn
wsReader ReaderFunc
wsWriter WriterFunc

wm sync.Mutex
}

func (c *WebSocketPeer) Read() ([]byte, error) {
return c.wsReader(c.conn)
}

func (c *WebSocketPeer) Write(bytes []byte) error {
c.wm.Lock()
defer c.wm.Unlock()

return c.wsWriter(c.conn, bytes)
}

Expand Down
64 changes: 34 additions & 30 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"log"
"sync"

"github.com/xconnio/wampproto-go"
"github.com/xconnio/wampproto-go/messages"
Expand All @@ -21,10 +22,10 @@ type Session struct {
idGen *wampproto.SessionScopeIDGenerator

// remote procedure calls data structures
registerRequests map[int64]chan *RegisterResponse
unregisterRequests map[int64]chan *UnRegisterResponse
registrations map[int64]InvocationHandler
callRequests map[int64]chan *CallResponse
registerRequests sync.Map
unregisterRequests sync.Map
registrations sync.Map
callRequests sync.Map

// publish subscribe data structures
subscribeRequests map[int64]chan *SubscribeResponse
Expand All @@ -39,10 +40,10 @@ func NewSession(base BaseSession, serializer serializers.Serializer) *Session {
proto: wampproto.NewSession(serializer),
idGen: &wampproto.SessionScopeIDGenerator{},

registerRequests: map[int64]chan *RegisterResponse{},
unregisterRequests: map[int64]chan *UnRegisterResponse{},
registrations: map[int64]InvocationHandler{},
callRequests: map[int64]chan *CallResponse{},
registerRequests: sync.Map{},
unregisterRequests: sync.Map{},
registrations: sync.Map{},
callRequests: sync.Map{},

subscribeRequests: map[int64]chan *SubscribeResponse{},
unsubscribeRequests: map[int64]chan *UnSubscribeResponse{},
Expand Down Expand Up @@ -81,31 +82,35 @@ func (s *Session) processIncomingMessage(msg messages.Message) error {
switch msg.Type() {
case messages.MessageTypeRegistered:
registered := msg.(*messages.Registered)
request, exists := s.registerRequests[registered.RequestID()]
request, exists := s.registerRequests.Load(registered.RequestID())
if !exists {
return fmt.Errorf("received REGISTERED for unknown request")
}

request <- &RegisterResponse{msg: registered}
requestChan := request.(chan *RegisterResponse)
requestChan <- &RegisterResponse{msg: registered}
case messages.MessageTypeUnRegistered:
unregistered := msg.(*messages.UnRegistered)
request, exists := s.unregisterRequests[unregistered.RequestID()]
request, exists := s.unregisterRequests.Load(unregistered.RequestID())
if !exists {
return fmt.Errorf("received UNREGISTERED for unknown request")
}

request <- &UnRegisterResponse{msg: unregistered}
requestChan := request.(chan *UnRegisterResponse)
requestChan <- &UnRegisterResponse{msg: unregistered}
case messages.MessageTypeResult:
result := msg.(*messages.Result)
request, exists := s.callRequests[result.RequestID()]
request, exists := s.callRequests.Load(result.RequestID())
if !exists {
return fmt.Errorf("received RESULT for unknown request")
}

request <- &CallResponse{msg: result}
req := request.(chan *CallResponse)
req <- &CallResponse{msg: result}
case messages.MessageTypeInvocation:
invocation := msg.(*messages.Invocation)
endpoint := s.registrations[invocation.RegistrationID()]
end, _ := s.registrations.Load(invocation.RegistrationID())
endpoint := end.(InvocationHandler)

inv := &Invocation{
Args: invocation.Args(),
Expand Down Expand Up @@ -167,32 +172,31 @@ func (s *Session) processIncomingMessage(msg messages.Message) error {
errorMsg := msg.(*messages.Error)
switch errorMsg.MessageType() {
case messages.MessageTypeCall:
responseChan, exists := s.callRequests[errorMsg.RequestID()]
response, exists := s.callRequests.LoadAndDelete(errorMsg.RequestID())
if !exists {
return fmt.Errorf("received ERROR for invalid call request")
}

delete(s.callRequests, errorMsg.RequestID())
err := &Error{URI: errorMsg.URI(), Args: errorMsg.Args(), KwArgs: errorMsg.KwArgs()}
responseChan := response.(chan *CallResponse)
responseChan <- &CallResponse{error: err}
return nil
case messages.MessageTypeRegister:
responseChan, exists := s.registerRequests[errorMsg.RequestID()]
request, exists := s.registerRequests.LoadAndDelete(errorMsg.RequestID())
if !exists {
return fmt.Errorf("received ERROR for invalid register request")
}

delete(s.registerRequests, errorMsg.RequestID())
err := &Error{URI: errorMsg.URI(), Args: errorMsg.Args(), KwArgs: errorMsg.KwArgs()}
responseChan <- &RegisterResponse{error: err}
requestChan := request.(chan *RegisterResponse)
requestChan <- &RegisterResponse{error: err}
return nil
case messages.MessageTypeUnRegister:
_, exists := s.unregisterRequests[errorMsg.RequestID()]
_, exists := s.unregisterRequests.LoadAndDelete(errorMsg.RequestID())
if !exists {
return fmt.Errorf("received ERROR for invalid unregister request")
}

delete(s.unregisterRequests, errorMsg.RequestID())
return nil
case messages.MessageTypeSubscribe:
_, exists := s.subscribeRequests[errorMsg.RequestID()]
Expand Down Expand Up @@ -238,8 +242,8 @@ func (s *Session) Register(ctx context.Context, procedure string, handler Invoca
}

channel := make(chan *RegisterResponse, 1)
s.registerRequests[register.RequestID()] = channel
defer delete(s.registerRequests, register.RequestID())
s.registerRequests.Store(register.RequestID(), channel)
defer s.registerRequests.Delete(register.RequestID())

if err = s.base.Write(toSend); err != nil {
return nil, err
Expand All @@ -251,7 +255,7 @@ func (s *Session) Register(ctx context.Context, procedure string, handler Invoca
return nil, response.error
}

s.registrations[response.msg.RegistrationID()] = handler
s.registrations.Store(response.msg.RegistrationID(), handler)
registration := &Registration{
ID: response.msg.RegistrationID(),
}
Expand All @@ -269,8 +273,8 @@ func (s *Session) UnRegister(ctx context.Context, registrationID int64) error {
}

channel := make(chan *UnRegisterResponse, 1)
s.unregisterRequests[unregister.RequestID()] = channel
defer delete(s.unregisterRequests, unregister.RequestID())
s.unregisterRequests.Store(unregister.RequestID(), channel)
defer s.unregisterRequests.Delete(unregister.RequestID())

if err = s.base.Write(toSend); err != nil {
return err
Expand All @@ -282,7 +286,7 @@ func (s *Session) UnRegister(ctx context.Context, registrationID int64) error {
return response.error
}

delete(s.registrations, registrationID)
s.registrations.Delete(registrationID)
return nil
case <-ctx.Done():
return fmt.Errorf("unregister request timed")
Expand All @@ -299,8 +303,8 @@ func (s *Session) Call(ctx context.Context, procedure string, args []any, kwArgs
}

channel := make(chan *CallResponse, 1)
s.callRequests[call.RequestID()] = channel
defer delete(s.callRequests, call.RequestID())
s.callRequests.Store(call.RequestID(), channel)
defer s.callRequests.Delete(call.RequestID())
if err = s.base.Write(toSend); err != nil {
return nil, err
}
Expand Down
5 changes: 4 additions & 1 deletion types.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,15 @@ type Result struct {
}

type Error struct {
error
URI string
Args []any
KwArgs map[string]any
}

func (e *Error) Error() string {
return e.URI
}

type RegisterResponse struct {
msg *messages.Registered
error *Error
Expand Down

0 comments on commit 31f6e7e

Please sign in to comment.