diff --git a/service/grpc/service.go b/service/grpc/service.go index 4d96c7e9..b5fe90f5 100644 --- a/service/grpc/service.go +++ b/service/grpc/service.go @@ -85,6 +85,7 @@ type Server struct { authToken string grpcServer *grpc.Server started uint32 + stopped uint32 } // Deprecated: Use RegisterActorImplFactoryContext instead. @@ -101,26 +102,33 @@ func (s *Server) Start() error { if !atomic.CompareAndSwapUint32(&s.started, 0, 1) { return errors.New("a gRPC server can only be started once") } + return s.grpcServer.Serve(s.listener) } // Stop stops the previously-started service. func (s *Server) Stop() error { - if atomic.LoadUint32(&s.started) == 0 { - return nil - } - s.grpcServer.Stop() - s.grpcServer = nil - return nil + return s.gracefulStop(false) } // GrecefulStop stops the previously-started service gracefully. func (s *Server) GracefulStop() error { + return s.gracefulStop(true) +} + +func (s *Server) gracefulStop(graceful bool) error { if atomic.LoadUint32(&s.started) == 0 { - return nil + return errors.New("the server doesn't start") + } + + if atomic.CompareAndSwapUint32(&s.stopped, 0, 1) { + if graceful { + s.grpcServer.GracefulStop() + } else { + s.grpcServer.Stop() + } + s.grpcServer = nil } - s.grpcServer.GracefulStop() - s.grpcServer = nil return nil } diff --git a/service/grpc/service_test.go b/service/grpc/service_test.go index b2a3e97d..ffc59a0d 100644 --- a/service/grpc/service_test.go +++ b/service/grpc/service_test.go @@ -15,6 +15,7 @@ package grpc import ( "testing" + "time" "github.com/stretchr/testify/assert" "google.golang.org/grpc" @@ -53,6 +54,7 @@ func startTestServer(server *Server) { panic(err) } }() + time.Sleep(time.Second) } func stopTestServer(t *testing.T, server *Server) { @@ -62,3 +64,52 @@ func stopTestServer(t *testing.T, server *Server) { err := server.Stop() assert.Nilf(t, err, "error stopping server") } + +func TestStartServerTimes(t *testing.T) { + server := getTestServer() + + startTestServer(server) + assert.PanicsWithError(t, "a gRPC server can only be started once", func() { + if err := server.Start(); err != nil && err.Error() != "closed" { + panic(err) + } + }) + + time.Sleep(time.Second) + + stopTestServer(t, server) +} + +func TestStopServerTimes(t *testing.T) { + server := getTestServer() + startTestServer(server) + + err := server.Stop() + assert.Nilf(t, err, "error stopping server") + + err = server.Stop() + assert.Nilf(t, err, "error stopping server") +} + +func TestStopServerBeforeStart(t *testing.T) { + server := getTestServer() + assert.NotNil(t, server) + err := server.Stop() + assert.NotNilf(t, err, "should return error when stopping server before starting") +} + +func TestStartServerAfterStop(t *testing.T) { + server := getTestServer() + startTestServer(server) + stopTestServer(t, server) + err := server.Start() + assert.NotNil(t, err) +} + +func TestGracefulStopServer(t *testing.T) { + server := getTestServer() + startTestServer(server) + assert.NotNil(t, server) + err := server.GracefulStop() + assert.Nilf(t, err, "error stopping server") +}