diff --git a/embed/etcd.go b/embed/etcd.go
index a09e8e5da1c..d04103820eb 100644
--- a/embed/etcd.go
+++ b/embed/etcd.go
@@ -109,7 +109,7 @@ func StartEtcd(inCfg *Config) (e *Etcd, err error) {
 		if !serving {
 			// errored before starting gRPC server for serveCtx.serversC
 			for _, sctx := range e.sctxs {
-				close(sctx.serversC)
+				sctx.close()
 			}
 		}
 		e.Close()
diff --git a/embed/serve.go b/embed/serve.go
index 34856d1fac2..09d68c8854d 100644
--- a/embed/serve.go
+++ b/embed/serve.go
@@ -16,12 +16,14 @@ package embed
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"io/ioutil"
 	defaultLog "log"
 	"net"
 	"net/http"
 	"strings"
+	"sync"
 
 	"go.etcd.io/etcd/etcdserver"
 	"go.etcd.io/etcd/etcdserver/api/v3client"
@@ -63,6 +65,7 @@ type serveCtx struct {
 	userHandlers    map[string]http.Handler
 	serviceRegister func(*grpc.Server)
 	serversC        chan *servers
+	closeOnce       sync.Once
 }
 
 type servers struct {
@@ -94,7 +97,15 @@ func (sctx *serveCtx) serve(
 	splitHttp bool,
 	gopts ...grpc.ServerOption) (err error) {
 	logger := defaultLog.New(ioutil.Discard, "etcdhttp", 0)
-	<-s.ReadyNotify()
+
+	// Make sure serversC is closed even if we prematurely exit the function.
+	defer sctx.close()
+
+	select {
+	case <-s.StoppingNotify():
+		return errors.New("server is stopping")
+	case <-s.ReadyNotify():
+	}
 
 	if sctx.lg != nil {
 		sctx.lg.Info("ready to serve client requests")
@@ -113,8 +124,6 @@ func (sctx *serveCtx) serve(
 	servElection := v3election.NewElectionServer(v3c)
 	servLock := v3lock.NewLockServer(v3c)
 
-	// Make sure serversC is closed even if we prematurely exit the function.
-	defer close(sctx.serversC)
 	var gwmux *gw.ServeMux
 	if s.Cfg.EnableGRPCGateway {
 		// GRPC gateway connects to grpc server via connection provided by grpc dial.
@@ -549,3 +558,11 @@ func (sctx *serveCtx) registerTrace() {
 	evf := func(w http.ResponseWriter, r *http.Request) { trace.RenderEvents(w, r, true) }
 	sctx.registerUserHandler("/debug/events", http.HandlerFunc(evf))
 }
+
+// close closes serversC at most once, so both serve and the StartEtcd
+// error path can call it without a double-close panic.
+func (sctx *serveCtx) close() {
+	sctx.closeOnce.Do(func() {
+		close(sctx.serversC)
+	})
+}
diff --git a/etcdserver/server.go b/etcdserver/server.go
index 9264dea4fcf..174254dbe2a 100644
--- a/etcdserver/server.go
+++ b/etcdserver/server.go
@@ -1670,6 +1670,10 @@ func (s *EtcdServer) stopWithDelay(d time.Duration, err error) {
 // when the server is stopped.
 func (s *EtcdServer) StopNotify() <-chan struct{} { return s.done }
 
+// StoppingNotify returns a channel that receives an empty struct
+// when the server is being stopped.
+func (s *EtcdServer) StoppingNotify() <-chan struct{} { return s.stopping }
+
 func (s *EtcdServer) SelfStats() []byte { return s.stats.JSON() }
 
 func (s *EtcdServer) LeaderStats() []byte {
@@ -2163,6 +2167,7 @@ func (s *EtcdServer) publish(timeout time.Duration) {
 		Val:    string(b),
 	}
 
+	// gofail: var beforePublishing struct{}
 	for {
 		ctx, cancel := context.WithTimeout(s.ctx, timeout)
 		_, err := s.Do(ctx, req)
diff --git a/integration/embed_test.go b/integration/embed_test.go
index 14e950e0521..ff639c34674 100644
--- a/integration/embed_test.go
+++ b/integration/embed_test.go
@@ -29,8 +29,11 @@ import (
 	"testing"
 	"time"
 
+	"github.com/stretchr/testify/require"
+
 	"go.etcd.io/etcd/clientv3"
 	"go.etcd.io/etcd/embed"
+	gofail "go.etcd.io/gofail/runtime"
 )
 
 func TestEmbedEtcd(t *testing.T) {
@@ -196,3 +199,62 @@ func setupEmbedCfg(cfg *embed.Config, curls []url.URL, purls []url.URL) {
 	}
 	cfg.InitialCluster = cfg.InitialCluster[1:]
 }
+
+// TestEmbedEtcdStopDuringBootstrapping stops an embedded server while the
+// "beforePublishing" failpoint is delaying bootstrap, and fails if starting
+// and closing the server does not finish within the test timeout.
+func TestEmbedEtcdStopDuringBootstrapping(t *testing.T) {
+	if len(gofail.List()) == 0 {
+		t.Skip("please run 'make gofail-enable' before running the test")
+	}
+
+	fpName := "beforePublishing"
+	require.NoError(t, gofail.Enable(fpName, `sleep("2s")`))
+	t.Cleanup(func() {
+		terr := gofail.Disable(fpName)
+		if terr != nil && terr != gofail.ErrDisabled {
+			t.Fatalf("failed to disable %s: %v", fpName, terr)
+		}
+	})
+
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+
+		cfg := embed.NewConfig()
+		urls := newEmbedURLs(false, 2)
+		setupEmbedCfg(cfg, []url.URL{urls[0]}, []url.URL{urls[1]})
+		cfg.Dir = filepath.Join(t.TempDir(), "embed-etcd")
+
+		e, err := embed.StartEtcd(cfg)
+		if err != nil {
+			t.Errorf("Failed to start etcd, got error %v", err)
+			// e must not be used when StartEtcd reports an error; bail out
+			// instead of dereferencing it below.
+			return
+		}
+		defer e.Close()
+
+		go func() {
+			time.Sleep(time.Second)
+			e.Server.Stop()
+			t.Log("Stopped server during bootstrapping")
+		}()
+
+		select {
+		case <-e.Server.ReadyNotify():
+			t.Log("Server is ready!")
+		case <-e.Server.StopNotify():
+			t.Log("Server is stopped")
+		case <-time.After(20 * time.Second):
+			e.Server.Stop() // trigger a shutdown
+			t.Error("Server took too long to start!")
+		}
+	}()
+
+	select {
+	case <-done:
+	case <-time.After(10 * time.Second):
+		t.Error("timeout in bootstrapping etcd")
+	}
+}