Skip to content

Commit

Permalink
push finding delete candidate logic into db layer.
Browse files Browse the repository at this point in the history
  • Loading branch information
sudiptob2 committed Oct 3, 2024
1 parent e5b3ccc commit 4f06b06
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 156 deletions.
2 changes: 1 addition & 1 deletion internal/database/repository/mock_repository.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 13 additions & 7 deletions internal/database/repository/partitions.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,18 +127,24 @@ func (s *PostgresRepository) GetActivePartitions(ctx context.Context) ([]*model.
return partitions, nil
}

func (s *PostgresRepository) DeletePartitions(ctx context.Context, partitions []*model.PartitionInfo) error {
partitionIds := make([]string, len(partitions))
func (s *PostgresRepository) DeletePartitions(ctx context.Context, partitions []*dao.PartitionInfo) error {
partitionNames := make([]string, len(partitions))
for i, p := range partitions {
partitionIds[i] = p.Id
partitionNames[i] = p.Name
}

deletedAt := time.Now().Unix()
query := `UPDATE partitions SET deleted_at = $1 WHERE id = ANY($2)`
_, err := s.dbpool.Exec(ctx, query, deletedAt, partitionIds)
query := `
UPDATE partitions
SET deleted_at = $1
WHERE id IN (
SELECT id
FROM partitions
WHERE deleted_at IS NULL AND NOT(name = ANY($2))
)
`
_, err := s.dbpool.Exec(ctx, query, deletedAt, partitionNames)
if err != nil {
return fmt.Errorf("could not mark partitions as deleted in DB: %w", err)
}

return nil
}
2 changes: 1 addition & 1 deletion internal/database/repository/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ type Repository interface {
UpsertPartitions(ctx context.Context, partitions []*dao.PartitionInfo) error
GetAllPartitions(ctx context.Context) ([]*model.PartitionInfo, error)
GetActivePartitions(ctx context.Context) ([]*model.PartitionInfo, error)
DeletePartitions(ctx context.Context, partitions []*model.PartitionInfo) error
DeletePartitions(ctx context.Context, partitions []*dao.PartitionInfo) error
AddQueues(ctx context.Context, parentId *string, queues []*dao.PartitionQueueDAOInfo) error
UpdateQueue(ctx context.Context, queue *dao.PartitionQueueDAOInfo) error
UpsertQueues(ctx context.Context, queues []*dao.PartitionQueueDAOInfo) error
Expand Down
30 changes: 1 addition & 29 deletions internal/yunikorn/sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,6 @@ func (s *Service) syncPartitions(ctx context.Context) ([]*dao.PartitionInfo, err
if err != nil {
return nil, fmt.Errorf("could not get partitions: %v", err)
}
deleteCandidates, err := s.findPartitionDeleteCandidates(ctx, partitions)
if err != nil {
return nil, fmt.Errorf("failed to find delete candidates: %w", err)
}

err = s.workqueue.Add(func(ctx context.Context) error {
logger.Infow("syncing partitions", "count", len(partitions))
Expand All @@ -94,7 +90,7 @@ func (s *Service) syncPartitions(ctx context.Context) ([]*dao.PartitionInfo, err
return fmt.Errorf("could not upsert partitions: %w", err)
}
// Delete partitions that are not present in the API response
return s.repo.DeletePartitions(ctx, deleteCandidates)
return s.repo.DeletePartitions(ctx, partitions)
}, workqueue.WithJobName("sync_partitions"))
if err != nil {
logger.Errorf("could not add sync_partitions job to workqueue: %v", err)
Expand All @@ -103,30 +99,6 @@ func (s *Service) syncPartitions(ctx context.Context) ([]*dao.PartitionInfo, err
return partitions, nil
}

func (s *Service) findPartitionDeleteCandidates(ctx context.Context, apiPartitions []*dao.PartitionInfo) ([]*model.PartitionInfo, error) {

apiPartitionMap := make(map[string]*dao.PartitionInfo)
for _, p := range apiPartitions {
apiPartitionMap[p.Name] = p
}

// Fetch partitions from the database
dbPartitions, err := s.repo.GetActivePartitions(ctx)
if err != nil {
return nil, fmt.Errorf("failed to retrieve partitions from DB: %w", err)
}

// Identify active partitions in the database that are not present in the API response
var deleteCandidates []*model.PartitionInfo
for _, dbPartition := range dbPartitions {
if _, found := apiPartitionMap[dbPartition.Name]; !found {
deleteCandidates = append(deleteCandidates, dbPartition)
}
}

return deleteCandidates, nil
}

// syncQueues fetches queues for each partition and upserts them into the database
func (s *Service) syncQueues(ctx context.Context, partitions []*dao.PartitionInfo) ([]*dao.PartitionQueueDAOInfo, error) {
logger := log.FromContext(ctx)
Expand Down
11 changes: 9 additions & 2 deletions internal/yunikorn/sync_int_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -561,8 +561,15 @@ func TestSync_syncPartitions_Integration(t *testing.T) {
return
}
require.NoError(t, err)
partitionsInDB, err := s.repo.GetActivePartitions(ctx)
require.NoError(t, err)
var partitionsInDB []*model.PartitionInfo
assert.Eventually(t, func() bool {
partitionsInDB, err = s.repo.GetActivePartitions(ctx)
if err != nil {
t.Logf("error getting partitions: %v", err)
}
return len(partitionsInDB) == len(tt.expected)
}, 5*time.Second, 50*time.Millisecond)

for _, target := range tt.expected {
if !isPartitionPresent(partitionsInDB, target) {
t.Errorf("Partition %s is not found in the DB", target.Name)
Expand Down
116 changes: 0 additions & 116 deletions internal/yunikorn/sync_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,119 +114,3 @@ func TestSync_findQueueDeleteCandidates(t *testing.T) {
})
}
}

func TestSync_findPartitionDeleteCandidates(t *testing.T) {
tests := []struct {
name string
apiPartitions []*dao.PartitionInfo
activePartitionsInDB []*model.PartitionInfo
expectedDelete []*model.PartitionInfo
expectedErr error
}{
{
name: "Single partition in DB not present in API",
apiPartitions: []*dao.PartitionInfo{
{Name: "partition1"},
},
activePartitionsInDB: []*model.PartitionInfo{
{PartitionInfo: dao.PartitionInfo{Name: "partition1"}, Id: "p1"},
{PartitionInfo: dao.PartitionInfo{Name: "partition3"}, Id: "p3"},
},
expectedDelete: []*model.PartitionInfo{
{PartitionInfo: dao.PartitionInfo{Name: "partition3"}, Id: "p3"},
},
expectedErr: nil,
},
{
name: "Multiple partitions, no delete candidates",
apiPartitions: []*dao.PartitionInfo{
{Name: "partition1"},
{Name: "partition2"},
},
activePartitionsInDB: []*model.PartitionInfo{
{PartitionInfo: dao.PartitionInfo{Name: "partition1"}, Id: "p1"},
{PartitionInfo: dao.PartitionInfo{Name: "partition2"}, Id: "p2"},
},
expectedDelete: nil,
expectedErr: nil,
},
{
name: "Multiple delete candidates in DB",
apiPartitions: []*dao.PartitionInfo{
{Name: "partition1"},
},
activePartitionsInDB: []*model.PartitionInfo{
{PartitionInfo: dao.PartitionInfo{Name: "partition1"}, Id: "p1"},
{PartitionInfo: dao.PartitionInfo{Name: "partition2"}, Id: "p2"},
{PartitionInfo: dao.PartitionInfo{Name: "partition3"}, Id: "p3"},
},
expectedDelete: []*model.PartitionInfo{
{PartitionInfo: dao.PartitionInfo{Name: "partition2"}, Id: "p2"},
{PartitionInfo: dao.PartitionInfo{Name: "partition3"}, Id: "p3"},
},
expectedErr: nil,
},
{
name: "Previously deleted partition with same name in DB",
apiPartitions: []*dao.PartitionInfo{
{Name: "partition1"},
},
activePartitionsInDB: []*model.PartitionInfo{
{PartitionInfo: dao.PartitionInfo{Name: "partition1"}, Id: "p1"},
{PartitionInfo: dao.PartitionInfo{Name: "partition2"}, Id: "p3"},
},
expectedDelete: []*model.PartitionInfo{
{PartitionInfo: dao.PartitionInfo{Name: "partition2"}, Id: "p3"},
},
expectedErr: nil,
},
{
name: "No partitions in API or DB",
apiPartitions: []*dao.PartitionInfo{},
activePartitionsInDB: []*model.PartitionInfo{},
expectedDelete: nil,
expectedErr: nil,
},
{
name: "DB returns error",
apiPartitions: []*dao.PartitionInfo{
{Name: "partition1"},
},
activePartitionsInDB: nil, // Simulate an error from the DB
expectedDelete: nil,
expectedErr: errors.New("db error"),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()

mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()

mockRepo := repository.NewMockRepository(mockCtrl)

if tt.expectedErr != nil {
mockRepo.EXPECT().GetActivePartitions(ctx).Return(nil, tt.expectedErr)
} else {
mockRepo.EXPECT().GetActivePartitions(ctx).Return(tt.activePartitionsInDB, nil)
}

s := &Service{
repo: mockRepo,
}

deleteCandidates, err := s.findPartitionDeleteCandidates(ctx, tt.apiPartitions)

if tt.expectedErr != nil {
require.Error(t, err)
} else {
require.NoError(t, err)
}
require.Equal(t, len(tt.expectedDelete), len(deleteCandidates))

require.Equal(t, tt.expectedDelete, deleteCandidates)
})
}
}

0 comments on commit 4f06b06

Please sign in to comment.