aboutsummaryrefslogtreecommitdiffstats
path: root/node/node_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'node/node_test.go')
-rw-r--r--node/node_test.go54
1 files changed, 45 insertions, 9 deletions
diff --git a/node/node_test.go b/node/node_test.go
index 7201276b5..1d5570e42 100644
--- a/node/node_test.go
+++ b/node/node_test.go
@@ -102,7 +102,7 @@ func TestNodeUsedDataDir(t *testing.T) {
type NoopService struct{}
func (s *NoopService) Protocols() []p2p.Protocol { return nil }
-func (s *NoopService) Start() error { return nil }
+func (s *NoopService) Start(*p2p.Server) error { return nil }
func (s *NoopService) Stop() error { return nil }
func NewNoopService(*ServiceContext) (Service, error) { return new(NoopService), nil }
@@ -148,7 +148,7 @@ type InstrumentedService struct {
stop error
protocolsHook func()
- startHook func()
+ startHook func(*p2p.Server)
stopHook func()
}
@@ -159,9 +159,9 @@ func (s *InstrumentedService) Protocols() []p2p.Protocol {
return s.protocols
}
-func (s *InstrumentedService) Start() error {
+func (s *InstrumentedService) Start(server *p2p.Server) error {
if s.startHook != nil {
- s.startHook()
+ s.startHook(server)
}
return s.start
}
@@ -189,7 +189,7 @@ func TestServiceLifeCycle(t *testing.T) {
id := id // Closure for the constructor
constructor := func(*ServiceContext) (Service, error) {
return &InstrumentedService{
- startHook: func() { started[id] = true },
+ startHook: func(*p2p.Server) { started[id] = true },
stopHook: func() { stopped[id] = true },
}, nil
}
@@ -241,7 +241,7 @@ func TestServiceRestarts(t *testing.T) {
running = false
return &InstrumentedService{
- startHook: func() {
+ startHook: func(*p2p.Server) {
if running {
panic("already running")
}
@@ -288,7 +288,7 @@ func TestServiceConstructionAbortion(t *testing.T) {
id := id // Closure for the constructor
constructor := func(*ServiceContext) (Service, error) {
return &InstrumentedService{
- startHook: func() { started[id] = true },
+ startHook: func(*p2p.Server) { started[id] = true },
}, nil
}
if err := stack.Register(id, constructor); err != nil {
@@ -334,7 +334,7 @@ func TestServiceStartupAbortion(t *testing.T) {
id := id // Closure for the constructor
constructor := func(*ServiceContext) (Service, error) {
return &InstrumentedService{
- startHook: func() { started[id] = true },
+ startHook: func(*p2p.Server) { started[id] = true },
stopHook: func() { stopped[id] = true },
}, nil
}
@@ -384,7 +384,7 @@ func TestServiceTerminationGuarantee(t *testing.T) {
id := id // Closure for the constructor
constructor := func(*ServiceContext) (Service, error) {
return &InstrumentedService{
- startHook: func() { started[id] = true },
+ startHook: func(*p2p.Server) { started[id] = true },
stopHook: func() { stopped[id] = true },
}, nil
}
@@ -438,6 +438,42 @@ func TestServiceTerminationGuarantee(t *testing.T) {
}
}
+// TestSingletonServiceRetrieval tests that singleton services can be retrieved.
+func TestSingletonServiceRetrieval(t *testing.T) {
+ // Create a simple stack and register two service types
+ stack, err := New(testNodeConfig)
+ if err != nil {
+ t.Fatalf("failed to create protocol stack: %v", err)
+ }
+ if err := stack.Register("noop", func(*ServiceContext) (Service, error) { return new(NoopService), nil }); err != nil {
+ t.Fatalf("noop service registration failed: %v", err)
+ }
+ if err := stack.Register("instrumented", func(*ServiceContext) (Service, error) { return new(InstrumentedService), nil }); err != nil {
+ t.Fatalf("instrumented service registration failed: %v", err)
+ }
+ // Make sure none of the services can be retrieved until started
+ var noopServ *NoopService
+ if id, err := stack.SingletonService(&noopServ); id != "" || err != ErrServiceUnknown {
+ t.Fatalf("noop service retrieval mismatch: have %v/%v, want %v/%v", id, err, "", ErrServiceUnknown)
+ }
+ var instServ *InstrumentedService
+ if id, err := stack.SingletonService(&instServ); id != "" || err != ErrServiceUnknown {
+ t.Fatalf("instrumented service retrieval mismatch: have %v/%v, want %v/%v", id, err, "", ErrServiceUnknown)
+ }
+ // Start the stack and ensure everything is retrievable now
+ if err := stack.Start(); err != nil {
+ t.Fatalf("failed to start stack: %v", err)
+ }
+ defer stack.Stop()
+
+ if id, err := stack.SingletonService(&noopServ); id != "noop" || err != nil {
+ t.Fatalf("noop service retrieval mismatch: have %v/%v, want %v/%v", id, err, "noop", nil)
+ }
+ if id, err := stack.SingletonService(&instServ); id != "instrumented" || err != nil {
+ t.Fatalf("instrumented service retrieval mismatch: have %v/%v, want %v/%v", id, err, "instrumented", nil)
+ }
+}
+
// Tests that all protocols defined by individual services get launched.
func TestProtocolGather(t *testing.T) {
stack, err := New(testNodeConfig)