Skip to content

Commit

Permalink
Add delete logic in partition sync
Browse files Browse the repository at this point in the history
  • Loading branch information
sudiptob2 committed Sep 27, 2024
1 parent 27af92a commit bda4b34
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 13 deletions.
2 changes: 1 addition & 1 deletion internal/yunikorn/event_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ func (s *Service) handleQueueEvents(ctx context.Context, events []*si.EventRecor
func (s *Service) handleQueueAddEvent(ctx context.Context) {
logger := log.FromContext(ctx)

partitions, err := s.upsertPartitions(ctx)
partitions, err := s.syncPartitions(ctx)
if err != nil {
logger.Errorf("could not get partitions: %v", err)
return
Expand Down
51 changes: 42 additions & 9 deletions internal/yunikorn/sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (

// sync fetches the state of the applications from the Yunikorn API and upserts them into the database
func (s *Service) sync(ctx context.Context) error {
partitions, err := s.upsertPartitions(ctx)
partitions, err := s.syncPartitions(ctx)
if err != nil {
return fmt.Errorf("error getting and upserting partitions: %v", err)
}
Expand Down Expand Up @@ -75,8 +75,8 @@ func (s *Service) sync(ctx context.Context) error {
return nil
}

// upsertPartitions fetches partitions from the Yunikorn API and upserts them into the database
func (s *Service) upsertPartitions(ctx context.Context) ([]*dao.PartitionInfo, error) {
// syncPartitions fetches partitions from the Yunikorn API and syncs them into the database
func (s *Service) syncPartitions(ctx context.Context) ([]*dao.PartitionInfo, error) {
logger := log.FromContext(ctx)
// Get partitions from Yunikorn API and upsert into DB
partitions, err := s.client.GetPartitions(ctx)
Expand All @@ -85,16 +85,49 @@ func (s *Service) upsertPartitions(ctx context.Context) ([]*dao.PartitionInfo, e
}

err = s.workqueue.Add(func(ctx context.Context) error {
logger.Infow("upserting partitions", "count", len(partitions))
return s.repo.UpsertPartitions(ctx, partitions)
}, workqueue.WithJobName("upsert_partitions"))
logger.Infow("syncing partitions", "count", len(partitions))
err := s.repo.UpsertPartitions(ctx, partitions)
if err != nil {
return fmt.Errorf("could not upsert partitions: %w", err)
}
// Delete partitions that are not present in the API response
deleteCandidates, err := s.findPartitionDeleteCandidates(ctx, partitions)
if err != nil {
return fmt.Errorf("failed to find delete candidates: %w", err)
}
return s.repo.DeletePartitions(ctx, deleteCandidates)
}, workqueue.WithJobName("sync_partitions"))
if err != nil {
logger.Errorf("could not add upsert partitions job to workqueue: %v", err)
logger.Errorf("could not add sync_partitions job to workqueue: %v", err)
}

return partitions, nil
}

func (s *Service) findPartitionDeleteCandidates(ctx context.Context, apiPartitions []*dao.PartitionInfo) ([]*model.PartitionInfo, error) {

apiPartitionMap := make(map[string]*dao.PartitionInfo)
for _, p := range apiPartitions {
apiPartitionMap[p.Name] = p
}

// Fetch partitions from the database
dbPartitions, err := s.repo.GetAllPartitions(ctx)
if err != nil {
return nil, fmt.Errorf("failed to retrieve partitions from DB: %w", err)
}

// Identify partitions in the database that are not present in the API response
var deleteCandidates []*model.PartitionInfo
for _, dbPartition := range dbPartitions {
if _, found := apiPartitionMap[dbPartition.Name]; !found {
deleteCandidates = append(deleteCandidates, dbPartition)
}
}

return deleteCandidates, nil
}

// syncQueues fetches queues for each partition and upserts them into the database
func (s *Service) syncQueues(ctx context.Context, partitions []*dao.PartitionInfo) ([]*dao.PartitionQueueDAOInfo, error) {
logger := log.FromContext(ctx)
Expand Down Expand Up @@ -153,7 +186,7 @@ func (s *Service) syncPartitionQueues(ctx context.Context, partition *dao.Partit

queues := flattenQueues([]*dao.PartitionQueueDAOInfo{queue})
// Find candidates for deletion
deleteCandidates, err := s.findDeleteCandidates(ctx, partition, queues)
deleteCandidates, err := s.findQueueDeleteCandidates(ctx, partition, queues)
if err != nil {
return nil, fmt.Errorf("failed to find delete candidates: %w", err)
}
Expand All @@ -166,7 +199,7 @@ func (s *Service) syncPartitionQueues(ctx context.Context, partition *dao.Partit
return queues, nil
}

func (s *Service) findDeleteCandidates(
func (s *Service) findQueueDeleteCandidates(
ctx context.Context,
partition *dao.PartitionInfo,
apiQueues []*dao.PartitionQueueDAOInfo) ([]*model.PartitionQueueDAOInfo, error) {
Expand Down
132 changes: 130 additions & 2 deletions internal/yunikorn/sync_int_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,11 @@ import (
"testing"
"time"

"github.com/G-Research/yunikorn-history-server/internal/model"
"github.com/apache/yunikorn-core/pkg/webservice/dao"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/stretchr/testify/require"

"github.com/G-Research/yunikorn-history-server/internal/model"

"github.com/G-Research/yunikorn-history-server/internal/database/migrations"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -459,6 +458,122 @@ func TestSync_syncQueues_Integration(t *testing.T) {
}
}

func TestSync_syncPartitions_Integration(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}

ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

pool, repo, cleanupDB := setupDatabase(t, ctx)
t.Cleanup(cleanupDB)
eventRepository := repository.NewInMemoryEventRepository()

now := time.Now().Unix()

tests := []struct {
name string
setup func() *httptest.Server
existingPartitions []*dao.PartitionInfo
expected []*model.PartitionInfo
wantErr bool
}{
{
name: "Sync partition with no existing partitions in DB",
setup: func() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
response := []*dao.PartitionInfo{
{Name: "default"},
{Name: "secondary"},
}
writeResponse(t, w, response)
}))
},
existingPartitions: nil,
expected: []*model.PartitionInfo{
{PartitionInfo: dao.PartitionInfo{Name: "default"}},
{PartitionInfo: dao.PartitionInfo{Name: "secondary"}},
},
wantErr: false,
},
{
name: "Should mark secondary partition as deleted in DB",
setup: func() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
response := []*dao.PartitionInfo{
{Name: "default"},
}
writeResponse(t, w, response)
}))
},
existingPartitions: []*dao.PartitionInfo{
{Name: "default"},
{Name: "secondary"},
},
expected: []*model.PartitionInfo{
{PartitionInfo: dao.PartitionInfo{Name: "default"}},
{PartitionInfo: dao.PartitionInfo{Name: "secondary"}, DeletedAt: &now},
},
wantErr: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// clean up the table after the test
t.Cleanup(func() {
_, err := pool.Exec(ctx, "DELETE FROM partitions")
require.NoError(t, err)
})
// seed the existing partitions
if tt.existingPartitions != nil {
if err := repo.UpsertPartitions(ctx, tt.existingPartitions); err != nil {
t.Fatalf("could not seed partition: %v", err)
}
}

ts := tt.setup()
defer ts.Close()

client := NewRESTClient(getMockServerYunikornConfig(t, ts.URL))
s := NewService(repo, eventRepository, client)

// Start the service
ctx, cancel := context.WithCancel(context.Background())
go func() {
_ = s.Run(ctx)
}()

// Ensure workqueue is started
assert.Eventually(t, func() bool {
return s.workqueue.Started()
}, 500*time.Millisecond, 50*time.Millisecond)
time.Sleep(100 * time.Millisecond)

// Cleanup after each test case
t.Cleanup(func() {
cancel()
s.workqueue.Shutdown()
})

_, err := s.syncPartitions(ctx)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
partitionsInDB, err := s.repo.GetAllPartitions(ctx)
require.NoError(t, err)
for _, target := range tt.expected {
if !isPartitionPresent(partitionsInDB, target) {
t.Errorf("Partition %s is not found in the DB", target.Name)
}
}
})
}
}

func isQueuePresent(queuesInDB []*model.PartitionQueueDAOInfo, targetQueue *model.PartitionQueueDAOInfo) bool {
for _, dbQueue := range queuesInDB {
if dbQueue.QueueName == targetQueue.QueueName && dbQueue.Partition == targetQueue.Partition {
Expand All @@ -482,6 +597,19 @@ func extractPartitionNameFromURL(urlPath string) string {
return ""
}

func isPartitionPresent(partitionsInDB []*model.PartitionInfo, targetPartition *model.PartitionInfo) bool {
for _, dbPartition := range partitionsInDB {
if dbPartition.Name == targetPartition.Name {
// Check if DeletedAt fields match
if (dbPartition.DeletedAt == nil && targetPartition.DeletedAt != nil) || (dbPartition.DeletedAt != nil && targetPartition.DeletedAt == nil) {
return false
}
return true
}
}
return false
}

func setupDatabase(t *testing.T, ctx context.Context) (*pgxpool.Pool, repository.Repository, func()) {
schema := database.CreateTestSchema(ctx, t)
cfg := config.GetTestPostgresConfig()
Expand Down
2 changes: 1 addition & 1 deletion internal/yunikorn/sync_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func TestSync_findDeleteCandidates(t *testing.T) {
repo: mockRepo,
}

deleteCandidates, err := s.findDeleteCandidates(ctx, partition, tt.apiQueues)
deleteCandidates, err := s.findQueueDeleteCandidates(ctx, partition, tt.apiQueues)

if tt.expectedErr != nil {
require.Error(t, err)
Expand Down

0 comments on commit bda4b34

Please sign in to comment.