Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add partition delete logic in sync #243

Merged
merged 13 commits into from
Oct 11, 2024
33 changes: 31 additions & 2 deletions internal/database/repository/mock_repository.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

73 changes: 67 additions & 6 deletions internal/database/repository/partitions.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@ package repository
import (
"context"
"fmt"
"time"

"github.com/apache/yunikorn-core/pkg/webservice/dao"
"github.com/jackc/pgx/v5"
"github.com/oklog/ulid/v2"

"github.com/G-Research/yunikorn-history-server/internal/model"
)

func (s *PostgresRepository) UpsertPartitions(ctx context.Context, partitions []*dao.PartitionInfo) error {
Expand Down Expand Up @@ -54,19 +57,53 @@ func (s *PostgresRepository) UpsertPartitions(ctx context.Context, partitions []
return nil
}

func (s *PostgresRepository) GetAllPartitions(ctx context.Context) ([]*dao.PartitionInfo, error) {
rows, err := s.dbpool.Query(ctx, "SELECT * FROM partitions")
func (s *PostgresRepository) GetAllPartitions(ctx context.Context) ([]*model.PartitionInfo, error) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we be preparing for pagination parameters here? It seems unlikely we'd ever want to return everything to the front end all at once...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, we also have several other APIs that need to be updated with pagination. This task is tracked in #245

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, you should probably get only active partitions. We don't need to get deleted partitions to operate on them.
We should also change the name, so we know that we are fetching active partitions.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, Thanks for the awesome reviews 🚀
I found that GetAllPartitions is also used by the webservice here. I think the frontend might need the deleted partition data as well, so that it can display them graphically. Therefore, I updated the logic in findPartitionDeleteCandidates() instead.

rows, err := s.dbpool.Query(ctx, `SELECT * FROM partitions`)
if err != nil {
return nil, fmt.Errorf("could not get partitions from DB: %v", err)
}
defer rows.Close()

var partitions []*model.PartitionInfo
for rows.Next() {
var p model.PartitionInfo
if err := rows.Scan(
&p.Id,
&p.ClusterID,
&p.Name,
&p.Capacity.Capacity,
&p.Capacity.UsedCapacity,
&p.Capacity.Utilization,
&p.TotalNodes,
&p.Applications,
&p.TotalContainers,
&p.State,
&p.LastStateTransitionTime,
&p.DeletedAt,
); err != nil {
return nil, fmt.Errorf("could not scan partition from DB: %v", err)
}
partitions = append(partitions, &p)
}

if err := rows.Err(); err != nil {
return nil, fmt.Errorf("failed to read rows: %v", err)
}
return partitions, nil
}

func (s *PostgresRepository) GetActivePartitions(ctx context.Context) ([]*model.PartitionInfo, error) {
rows, err := s.dbpool.Query(ctx, `SELECT * FROM partitions WHERE deleted_at IS NULL`)
if err != nil {
return nil, fmt.Errorf("could not get partitions from DB: %v", err)
}
defer rows.Close()

var partitions []*dao.PartitionInfo
var partitions []*model.PartitionInfo
for rows.Next() {
var p dao.PartitionInfo
var id string
var p model.PartitionInfo
if err := rows.Scan(
&id,
&p.Id,
&p.ClusterID,
&p.Name,
&p.Capacity.Capacity,
Expand All @@ -77,6 +114,7 @@ func (s *PostgresRepository) GetAllPartitions(ctx context.Context) ([]*dao.Parti
&p.TotalContainers,
&p.State,
&p.LastStateTransitionTime,
&p.DeletedAt,
); err != nil {
return nil, fmt.Errorf("could not scan partition from DB: %v", err)
}
Expand All @@ -88,3 +126,26 @@ func (s *PostgresRepository) GetAllPartitions(ctx context.Context) ([]*dao.Parti
}
return partitions, nil
}

// DeleteInactivePartitions soft-deletes every partition row that is still live
// (deleted_at IS NULL) and whose name does not appear in activePartitions, by
// stamping deleted_at with the current Unix time in seconds.
//
// Only the Name of each entry in activePartitions is used; nil entries are
// ignored. An empty or nil slice therefore marks every live partition as
// deleted. Rows that already carry a deleted_at value are not touched, so their
// original deletion time is preserved.
//
// It returns a wrapped error if the UPDATE statement fails.
func (s *PostgresRepository) DeleteInactivePartitions(ctx context.Context, activePartitions []*dao.PartitionInfo) error {
	// Always build a non-nil slice so that an empty input is bound as an empty
	// array rather than NULL (NOT(name = ANY(NULL)) would match no rows).
	partitionNames := make([]string, 0, len(activePartitions))
	for _, p := range activePartitions {
		if p == nil {
			// A nil entry names nothing to keep; skip it rather than panic.
			continue
		}
		partitionNames = append(partitionNames, p.Name)
	}

	deletedAt := time.Now().Unix()
	// The filter is applied directly on the UPDATE; re-selecting the ids from
	// the same table in a subquery is not needed to target the same rows.
	const query = `
		UPDATE partitions
		SET deleted_at = $1
		WHERE deleted_at IS NULL AND NOT(name = ANY($2))
	`
	if _, err := s.dbpool.Exec(ctx, query, deletedAt, partitionNames); err != nil {
		return fmt.Errorf("could not mark partitions as deleted in DB: %w", err)
	}
	return nil
}
4 changes: 3 additions & 1 deletion internal/database/repository/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ type Repository interface {
GetNodeUtilizations(ctx context.Context) ([]*dao.PartitionNodesUtilDAOInfo, error)
GetNodesPerPartition(ctx context.Context, partition string) ([]*dao.NodeDAOInfo, error)
UpsertPartitions(ctx context.Context, partitions []*dao.PartitionInfo) error
GetAllPartitions(ctx context.Context) ([]*dao.PartitionInfo, error)
GetAllPartitions(ctx context.Context) ([]*model.PartitionInfo, error)
GetActivePartitions(ctx context.Context) ([]*model.PartitionInfo, error)
DeleteInactivePartitions(ctx context.Context, activePartitions []*dao.PartitionInfo) error
AddQueues(ctx context.Context, parentId *string, queues []*dao.PartitionQueueDAOInfo) error
UpdateQueue(ctx context.Context, queue *dao.PartitionQueueDAOInfo) error
UpsertQueues(ctx context.Context, queues []*dao.PartitionQueueDAOInfo) error
Expand Down
6 changes: 6 additions & 0 deletions internal/model/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,9 @@ type PartitionQueueDAOInfo struct {
CreatedAt *int64 `json:"createdAt,omitempty"`
DeletedAt *int64 `json:"deletedAt,omitempty"`
}

// PartitionInfo is the repository-level representation of a partition row: the
// upstream dao.PartitionInfo plus the identifier and soft-delete marker that
// are stored alongside it in the partitions table.
type PartitionInfo struct {
// Id is the row identifier; the repository scans it from the first column
// of the partitions table.
Id string `json:"id"`
// The embedded dao.PartitionInfo supplies the partition fields themselves.
// NOTE(review): encoding/json flattens embedded structs that have no JSON
// name on its own; "inline" is not an option it defines, so this tag is
// presumably informational only — confirm no other encoder relies on it.
dao.PartitionInfo `json:",inline"`
// DeletedAt is nil while the partition is active and holds the Unix time in
// seconds at which the partition was marked deleted otherwise. It is
// omitted from JSON output when nil.
DeletedAt *int64 `json:"deletedAt,omitempty"`
}
2 changes: 1 addition & 1 deletion internal/yunikorn/event_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ func (s *Service) handleQueueEvents(ctx context.Context, events []*si.EventRecor
func (s *Service) handleQueueAddEvent(ctx context.Context) {
logger := log.FromContext(ctx)

partitions, err := s.upsertPartitions(ctx)
partitions, err := s.syncPartitions(ctx)
if err != nil {
logger.Errorf("could not get partitions: %v", err)
return
Expand Down
23 changes: 14 additions & 9 deletions internal/yunikorn/sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (

// sync fetches the state of the applications from the Yunikorn API and upserts them into the database
func (s *Service) sync(ctx context.Context) error {
partitions, err := s.upsertPartitions(ctx)
partitions, err := s.syncPartitions(ctx)
if err != nil {
return fmt.Errorf("error getting and upserting partitions: %v", err)
}
Expand Down Expand Up @@ -74,8 +74,8 @@ func (s *Service) sync(ctx context.Context) error {
return nil
}

// upsertPartitions fetches partitions from the Yunikorn API and upserts them into the database
func (s *Service) upsertPartitions(ctx context.Context) ([]*dao.PartitionInfo, error) {
// syncPartitions fetches partitions from the Yunikorn API and syncs them into the database
func (s *Service) syncPartitions(ctx context.Context) ([]*dao.PartitionInfo, error) {
logger := log.FromContext(ctx)
// Get partitions from Yunikorn API and upsert into DB
partitions, err := s.client.GetPartitions(ctx)
Expand All @@ -84,11 +84,16 @@ func (s *Service) upsertPartitions(ctx context.Context) ([]*dao.PartitionInfo, e
}

err = s.workqueue.Add(func(ctx context.Context) error {
logger.Infow("upserting partitions", "count", len(partitions))
return s.repo.UpsertPartitions(ctx, partitions)
}, workqueue.WithJobName("upsert_partitions"))
logger.Infow("syncing partitions", "count", len(partitions))
err := s.repo.UpsertPartitions(ctx, partitions)
if err != nil {
return fmt.Errorf("could not upsert partitions: %w", err)
}
// Delete partitions that are not present in the API response
return s.repo.DeleteInactivePartitions(ctx, partitions)
}, workqueue.WithJobName("sync_partitions"))
if err != nil {
logger.Errorf("could not add upsert partitions job to workqueue: %v", err)
logger.Errorf("could not add sync_partitions job to workqueue: %v", err)
}

return partitions, nil
Expand Down Expand Up @@ -152,7 +157,7 @@ func (s *Service) syncPartitionQueues(ctx context.Context, partition *dao.Partit

queues := flattenQueues([]*dao.PartitionQueueDAOInfo{queue})
// Find candidates for deletion
deleteCandidates, err := s.findDeleteCandidates(ctx, partition, queues)
deleteCandidates, err := s.findQueueDeleteCandidates(ctx, partition, queues)
if err != nil {
return nil, fmt.Errorf("failed to find delete candidates: %w", err)
}
Expand All @@ -165,7 +170,7 @@ func (s *Service) syncPartitionQueues(ctx context.Context, partition *dao.Partit
return queues, nil
}

func (s *Service) findDeleteCandidates(
func (s *Service) findQueueDeleteCandidates(
ctx context.Context,
partition *dao.PartitionInfo,
apiQueues []*dao.PartitionQueueDAOInfo) ([]*model.PartitionQueueDAOInfo, error) {
Expand Down
129 changes: 129 additions & 0 deletions internal/yunikorn/sync_int_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,126 @@ func TestSync_syncQueues_Integration(t *testing.T) {
}
}

// TestSync_syncPartitions_Integration verifies that Service.syncPartitions
// brings the set of active partitions in the database in line with the
// partitions returned by a mocked Yunikorn API, including marking partitions
// that are no longer returned as deleted.
//
// It needs a real database (see setupDatabase) and is skipped in -short mode.
func TestSync_syncPartitions_Integration(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}

// This context is used for database setup and for the per-case table cleanup.
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

pool, repo, cleanupDB := setupDatabase(t, ctx)
t.Cleanup(cleanupDB)
eventRepository := repository.NewInMemoryEventRepository()

// Each case describes the API response (via setup), the rows seeded into the
// database beforehand, and the partitions expected to be active afterwards.
tests := []struct {
name string
setup func() *httptest.Server
existingPartitions []*dao.PartitionInfo
expected []*model.PartitionInfo
wantErr bool
}{
{
name: "Sync partition with no existing partitions in DB",
setup: func() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
response := []*dao.PartitionInfo{
{Name: "default"},
{Name: "secondary"},
}
writeResponse(t, w, response)
}))
},
existingPartitions: nil,
expected: []*model.PartitionInfo{
{PartitionInfo: dao.PartitionInfo{Name: "default"}},
{PartitionInfo: dao.PartitionInfo{Name: "secondary"}},
},
wantErr: false,
},
{
// The API no longer reports "secondary", so only "default" should
// remain among the active partitions.
name: "Should mark secondary partition as deleted in DB",
setup: func() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
response := []*dao.PartitionInfo{
{Name: "default"},
}
writeResponse(t, w, response)
}))
},
existingPartitions: []*dao.PartitionInfo{
{Name: "default"},
{Name: "secondary"},
},
expected: []*model.PartitionInfo{
{PartitionInfo: dao.PartitionInfo{Name: "default"}},
},
wantErr: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// clean up the table after the test
t.Cleanup(func() {
_, err := pool.Exec(ctx, "DELETE FROM partitions")
require.NoError(t, err)
})
// seed the existing partitions
if tt.existingPartitions != nil {
if err := repo.UpsertPartitions(ctx, tt.existingPartitions); err != nil {
t.Fatalf("could not seed partition: %v", err)
}
}

ts := tt.setup()
defer ts.Close()

client := NewRESTClient(getMockServerYunikornConfig(t, ts.URL))
s := NewService(repo, eventRepository, client)

// Start the service
// NOTE: this ctx/cancel pair shadows the outer one for the rest of the
// subtest, so the service gets its own context that is cancelled per case.
ctx, cancel := context.WithCancel(context.Background())
go func() {
_ = s.Run(ctx)
}()

// Ensure workqueue is started
assert.Eventually(t, func() bool {
return s.workqueue.Started()
}, 500*time.Millisecond, 50*time.Millisecond)
// NOTE(review): fixed extra delay after the workqueue reports started;
// presumably gives Run time to settle — confirm whether it is still needed.
time.Sleep(100 * time.Millisecond)

// Cleanup after each test case
t.Cleanup(func() {
cancel()
s.workqueue.Shutdown()
})

_, err := s.syncPartitions(ctx)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
// The database is not necessarily updated when syncPartitions returns, so
// poll until the number of active partitions matches the expectation.
var partitionsInDB []*model.PartitionInfo
assert.Eventually(t, func() bool {
partitionsInDB, err = s.repo.GetActivePartitions(ctx)
if err != nil {
t.Logf("error getting partitions: %v", err)
}
return len(partitionsInDB) == len(tt.expected)
}, 5*time.Second, 50*time.Millisecond)

// Partitions are matched by name only; other fields are not compared.
for _, target := range tt.expected {
if !isPartitionPresent(partitionsInDB, target) {
t.Errorf("Partition %s is not found in the DB", target.Name)
}
}
})
}
}

func isQueuePresent(queuesInDB []*model.PartitionQueueDAOInfo, targetQueue *model.PartitionQueueDAOInfo) bool {
for _, dbQueue := range queuesInDB {
if dbQueue.QueueName == targetQueue.QueueName && dbQueue.Partition == targetQueue.Partition {
Expand All @@ -482,6 +602,15 @@ func extractPartitionNameFromURL(urlPath string) string {
return ""
}

// isPartitionPresent reports whether partitionsInDB contains an entry whose
// Name equals the Name of targetPartition. Only the name is compared.
func isPartitionPresent(partitionsInDB []*model.PartitionInfo, targetPartition *model.PartitionInfo) bool {
	found := false
	// Stop scanning as soon as a matching name has been seen.
	for i := 0; i < len(partitionsInDB) && !found; i++ {
		found = partitionsInDB[i].Name == targetPartition.Name
	}
	return found
}

func setupDatabase(t *testing.T, ctx context.Context) (*pgxpool.Pool, repository.Repository, func()) {
schema := database.CreateTestSchema(ctx, t)
cfg := config.GetTestPostgresConfig()
Expand Down
Loading