From abcae5f249d21d4c751db95e3a7241da239de675 Mon Sep 17 00:00:00 2001 From: Clement Date: Sat, 29 Jun 2024 02:19:52 +0800 Subject: [PATCH] etcdserver: compact the raft log up to the minimum snapshot index of all ongoin snapshot creation Signed-off-by: Clement --- server/etcdserver/raft.go | 16 ++- server/etcdserver/server.go | 31 +++--- server/etcdserver/snapshot_tracker.go | 74 +++++++++++++ server/etcdserver/snapshot_tracker_test.go | 122 +++++++++++++++++++++ 4 files changed, 223 insertions(+), 20 deletions(-) create mode 100644 server/etcdserver/snapshot_tracker.go create mode 100644 server/etcdserver/snapshot_tracker_test.go diff --git a/server/etcdserver/raft.go b/server/etcdserver/raft.go index d397612af9c4..f6858dc7b378 100644 --- a/server/etcdserver/raft.go +++ b/server/etcdserver/raft.go @@ -92,6 +92,9 @@ type raftNode struct { // a chan to send out readState readStateC chan raft.ReadState + // keep track of snapshots being created + snapshotTracker SnapshotTracker + // utility ticker *time.Ticker // contention detectors for raft heartbeat message @@ -136,12 +139,13 @@ func newRaftNode(cfg raftNodeConfig) *raftNode { raftNodeConfig: cfg, // set up contention detectors for raft heartbeat message. // expect to send a heartbeat within 2 heartbeat intervals. - td: contention.NewTimeoutDetector(2 * cfg.heartbeat), - readStateC: make(chan raft.ReadState, 1), - msgSnapC: make(chan raftpb.Message, maxInFlightMsgSnap), - applyc: make(chan toApply), - stopped: make(chan struct{}), - done: make(chan struct{}), + td: contention.NewTimeoutDetector(2 * cfg.heartbeat), + readStateC: make(chan raft.ReadState, 1), + snapshotTracker: SnapshotTracker{}, + msgSnapC: make(chan raftpb.Message, maxInFlightMsgSnap), + applyc: make(chan toApply), + stopped: make(chan struct{}), + done: make(chan struct{}), } if r.heartbeat == 0 { r.ticker = &time.Ticker{} diff --git a/server/etcdserver/server.go b/server/etcdserver/server.go index d1784f8a3a1a..c061cae2c955 100644 --- a/server/etcdserver/server.go +++ b/server/etcdserver/server.go @@ -208,7 +208,6 @@ type Server interface { type EtcdServer struct { // inflightSnapshots holds count the number of snapshots currently inflight. inflightSnapshots int64 // must use atomic operations to access; keep 64-bit aligned. - creatingSnapshots int64 // must use atomic operations to access; keep 64-bit aligned. appliedIndex uint64 // must use atomic operations to access; keep 64-bit aligned. committedIndex uint64 // must use atomic operations to access; keep 64-bit aligned. term uint64 // must use atomic operations to access; keep 64-bit aligned. @@ -2129,11 +2128,13 @@ func (s *EtcdServer) snapshot(snapi uint64, confState raftpb.ConfState) { // the go routine created below. s.KV().Commit() - atomic.AddInt64(&s.creatingSnapshots, 1) + s.r.snapshotTracker.Track(snapi) + s.GoAttach(func() { defer func() { - atomic.AddInt64(&s.creatingSnapshots, -1) + s.r.snapshotTracker.UnTrack(snapi) }() + lg := s.Logger() // For backward compatibility, generate v2 snapshot from v3 state. @@ -2177,21 +2178,23 @@ func (s *EtcdServer) compactRaftLog(appliedi uint64) { return } - // If there are snapshots being created, skip compaction until they are done. - // This ensures `s.r.raftStorage.Compact` does not remove elements from `s.r.raftStorage.Ents`, - // preventing `s.r.raftStorage.CreateSnapshot` from causing a panic. - if atomic.LoadInt64(&s.creatingSnapshots) != 0 { - lg.Info("skip compaction since there are snapshots being created") + // keep some in memory log entries for slow followers. + compacti := uint64(0) + if appliedi > s.Cfg.SnapshotCatchUpEntries { + compacti = appliedi - s.Cfg.SnapshotCatchUpEntries + } + + // if there are snapshots being created, compact the raft log up to the minimum snapshot index. + if minSpani, err := s.r.snapshotTracker.MinSnapi(); err == nil && minSpani < appliedi && minSpani > s.Cfg.SnapshotCatchUpEntries { + compacti = minSpani - s.Cfg.SnapshotCatchUpEntries + } + + // no need to compact if compacti == 0 + if compacti == 0 { return } s.GoAttach(func() { - // keep some in memory log entries for slow followers. - compacti := uint64(0) - if appliedi > s.Cfg.SnapshotCatchUpEntries { - compacti = appliedi - s.Cfg.SnapshotCatchUpEntries - } - err := s.r.raftStorage.Compact(compacti) if err != nil { // the compaction was done asynchronously with the progress of raft. diff --git a/server/etcdserver/snapshot_tracker.go b/server/etcdserver/snapshot_tracker.go new file mode 100644 index 000000000000..e18a87255f84 --- /dev/null +++ b/server/etcdserver/snapshot_tracker.go @@ -0,0 +1,74 @@ +package etcdserver + +import ( + "cmp" + "container/heap" + "errors" + "sync" +) + +// SnapshotTracker keeps track of all ongoing snapshot creation. To safeguard ongoing snapshot creation, +// only compact the raft log up to the minimum snapshot index in the track. +type SnapshotTracker struct { + h minHeap[uint64] + mu sync.Mutex +} + +// MinSnapi returns the minimum snapshot index in the track or an error if the tracker is empty. +func (st *SnapshotTracker) MinSnapi() (uint64, error) { + st.mu.Lock() + defer st.mu.Unlock() + if st.h.Len() == 0 { + return 0, errors.New("SnapshotTracker is empty") + } + return st.h[0], nil +} + +// Track adds a snapi to the tracker. Make sure to call UnTrack once the snapshot has been created. +func (st *SnapshotTracker) Track(snapi uint64) { + st.mu.Lock() + defer st.mu.Unlock() + heap.Push(&st.h, snapi) +} + +// UnTrack removes 'snapi' from the tracker. No action taken if 'snapi' is not found. +func (st *SnapshotTracker) UnTrack(snapi uint64) { + st.mu.Lock() + defer st.mu.Unlock() + + for i := 0; i < len((*st).h); i++ { + if (*st).h[i] == snapi { + heap.Remove(&st.h, i) + return + } + } +} + +// minHeap implements the heap.Interface for E. +type minHeap[E interface { + cmp.Ordered +}] []E + +func (h minHeap[_]) Len() int { + return len(h) +} + +func (h minHeap[_]) Less(i, j int) bool { + return h[i] < h[j] +} + +func (h minHeap[_]) Swap(i, j int) { + h[i], h[j] = h[j], h[i] +} + +func (h *minHeap[E]) Push(x any) { + *h = append(*h, x.(E)) +} + +func (h *minHeap[E]) Pop() any { + old := *h + n := len(old) + x := old[n-1] + *h = old[0 : n-1] + return x +} diff --git a/server/etcdserver/snapshot_tracker_test.go b/server/etcdserver/snapshot_tracker_test.go new file mode 100644 index 000000000000..27f2cf670d04 --- /dev/null +++ b/server/etcdserver/snapshot_tracker_test.go @@ -0,0 +1,122 @@ +package etcdserver + +import ( + "container/heap" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestSnapTracker_MinSnapi(t *testing.T) { + st := SnapshotTracker{} + + _, err := st.MinSnapi() + assert.NotNil(t, err, "SnapshotTracker should be empty initially") + + st.Track(10) + minSnapi, err := st.MinSnapi() + assert.Nil(t, err) + assert.Equal(t, uint64(10), minSnapi, "MinSnapi should return the only tracked snapshot index") + + st.Track(5) + minSnapi, err = st.MinSnapi() + assert.Nil(t, err) + assert.Equal(t, uint64(5), minSnapi, "MinSnapi should return the minimum tracked snapshot index") + + st.UnTrack(5) + minSnapi, err = st.MinSnapi() + assert.Nil(t, err) + assert.Equal(t, uint64(10), minSnapi, "MinSnapi should return the remaining tracked snapshot index") +} + +func TestSnapTracker_Track(t *testing.T) { + st := SnapshotTracker{} + st.Track(20) + st.Track(10) + st.Track(15) + + assert.Equal(t, 3, st.h.Len(), "SnapshotTracker should have 3 snapshots tracked") + + minSnapi, err := st.MinSnapi() + assert.Nil(t, err) + assert.Equal(t, uint64(10), minSnapi, "MinSnapi should return the minimum tracked snapshot index") +} + +func TestSnapTracker_UnTrack(t *testing.T) { + st := SnapshotTracker{} + st.Track(20) + st.Track(30) + st.Track(40) + // track another snapshot with the same index + st.Track(20) + + st.UnTrack(30) + assert.Equal(t, 3, st.h.Len()) + + minSnapi, err := st.MinSnapi() + assert.Nil(t, err) + assert.Equal(t, uint64(20), minSnapi) + + st.UnTrack(20) + assert.Equal(t, 2, st.h.Len()) + + minSnapi, err = st.MinSnapi() + assert.Nil(t, err) + assert.Equal(t, uint64(20), minSnapi) + + st.UnTrack(20) + minSnapi, err = st.MinSnapi() + assert.Equal(t, uint64(40), minSnapi) + + st.UnTrack(40) + _, err = st.MinSnapi() + assert.NotNil(t, err) +} + +func newMinHeap(elements ...uint64) minHeap[uint64] { + h := minHeap[uint64](elements) + heap.Init(&h) + return h +} + +func TestMinHeapLen(t *testing.T) { + h := newMinHeap(3, 2, 1) + assert.Equal(t, 3, h.Len()) +} + +func TestMinHeapLess(t *testing.T) { + h := newMinHeap(3, 2, 1) + assert.True(t, h.Less(0, 1)) +} + +func TestMinHeapSwap(t *testing.T) { + h := newMinHeap(3, 2, 1) + h.Swap(0, 1) + assert.Equal(t, uint64(2), h[0]) + assert.Equal(t, uint64(1), h[1]) + assert.Equal(t, uint64(3), h[2]) +} + +func TestMinHeapPushPop(t *testing.T) { + h := newMinHeap(3, 2) + heap.Push(&h, uint64(1)) + assert.Equal(t, 3, h.Len()) + + got := heap.Pop(&h).(uint64) + assert.Equal(t, uint64(1), got) +} + +func TestMinHeapEmpty(t *testing.T) { + h := minHeap[uint64]{} + assert.Equal(t, 0, h.Len()) +} + +func TestMinHeapSingleElement(t *testing.T) { + h := newMinHeap(uint64(1)) + assert.Equal(t, 1, h.Len()) + + heap.Push(&h, uint64(2)) + assert.Equal(t, 2, h.Len()) + + got := heap.Pop(&h) + assert.Equal(t, uint64(1), got) +}