Skip to content

Commit

Permalink
Scheduler: NodeType does not need to be proto-generated (#3840)
Browse files Browse the repository at this point in the history
Signed-off-by: Robert Smith <[email protected]>
  • Loading branch information
robertdavidsmith authored Aug 1, 2024
1 parent 8a64c0f commit 688302c
Show file tree
Hide file tree
Showing 8 changed files with 298 additions and 754 deletions.
Original file line number Diff line number Diff line change
@@ -1,12 +1,65 @@
package schedulerobjects
package internaltypes

import (
"github.com/segmentio/fasthash/fnv1a"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
v1 "k8s.io/api/core/v1"

koTaint "github.com/armadaproject/armada/internal/scheduler/kubernetesobjects/taint"
)

// NodeType represents a particular combination of taints and labels.
// The scheduler groups nodes by node type. When assigning pods to nodes,
// the scheduler only considers nodes with a NodeType for which the taints and labels match.
//
// Its fields should be immutable! Do not change these!
// They are unexported; outside this package they are read through the
// accessor methods on *NodeType.
type NodeType struct {
	// Unique identifier. Used for map lookup.
	// Set by NewNodeType from the taints, labels and unset indexed labels.
	id uint64
	// Kubernetes taints.
	// To reduce the number of distinct node types,
	// may contain only a subset of the taints of the node the node type is created from.
	taints []v1.Taint
	// Kubernetes labels, keyed by label name.
	// To reduce the number of distinct node types,
	// may contain only a subset of the labels of the node the node type is created from.
	labels map[string]string
	// Well-known labels not set by this node type, keyed by label name.
	// Used to filter out nodes when looking for nodes for a pod
	// that requires at least one well-known label to be set.
	unsetIndexedLabels map[string]string
}

// GetId returns the unique identifier of this node type.
// The identifier is what callers use as a map key when indexing node types.
func (m *NodeType) GetId() uint64 {
	return m.id
}

// GetTaints returns the taints of this node type as produced by
// koTaint.DeepCopyTaints, i.e. not the internal slice itself, so that callers
// cannot modify the NodeType's own taints through the returned value.
// Use FindMatchingUntoleratedTaint when only a toleration check is needed.
func (m *NodeType) GetTaints() []v1.Taint {
	return koTaint.DeepCopyTaints(m.taints)
}

// FindMatchingUntoleratedTaint checks the given toleration lists against the
// taints of this node type by delegating to koTaint.FindMatchingUntoleratedTaint.
// The boolean result is true when an untolerated taint was found, in which case
// that taint is returned alongside it.
// The check runs on the internal taint slice directly; no copy is made.
func (m *NodeType) FindMatchingUntoleratedTaint(tolerations ...[]v1.Toleration) (v1.Taint, bool) {
	return koTaint.FindMatchingUntoleratedTaint(m.taints, tolerations...)
}

// GetLabels returns the labels of this node type as produced by deepCopyLabels,
// i.e. not the internal map itself (presumably a fresh copy — deepCopyLabels is
// defined elsewhere in this package).
// Use GetLabelValue to look up a single label by key.
func (m *NodeType) GetLabels() map[string]string {
	return deepCopyLabels(m.labels)
}

// GetLabelValue looks up a single label of this node type by key.
// It returns the label's value and true if the label is present,
// or the empty string and false otherwise.
// The lookup reads the internal map directly, so no copy of the labels is made.
func (m *NodeType) GetLabelValue(key string) (string, bool) {
	if value, found := m.labels[key]; found {
		return value, true
	}
	return "", false
}

// GetUnsetIndexedLabels returns the well-known labels not set by this node type,
// as produced by deepCopyLabels, i.e. not the internal map itself (presumably a
// fresh copy — deepCopyLabels is defined elsewhere in this package).
// Use GetUnsetIndexedLabelValue to look up a single entry by key.
func (m *NodeType) GetUnsetIndexedLabels() map[string]string {
	return deepCopyLabels(m.unsetIndexedLabels)
}

// GetUnsetIndexedLabelValue looks up key among the well-known labels that this
// node type does not set.
// It returns the stored value and true if key is recorded as unset,
// or the empty string and false otherwise.
// The lookup reads the internal map directly, so no copy is made.
func (m *NodeType) GetUnsetIndexedLabelValue(key string) (string, bool) {
	if value, found := m.unsetIndexedLabels[key]; found {
		return value, true
	}
	return "", false
}

type (
taintsFilterFunc func(*v1.Taint) bool
labelsFilterFunc func(key, value string) bool
Expand Down Expand Up @@ -63,10 +116,10 @@ func NewNodeType(taints []v1.Taint, labels map[string]string, indexedTaints map[
}

return &NodeType{
Id: nodeTypeIdFromTaintsAndLabels(taints, labels, unsetIndexedLabels),
Taints: taints,
Labels: labels,
UnsetIndexedLabels: unsetIndexedLabels,
id: nodeTypeIdFromTaintsAndLabels(taints, labels, unsetIndexedLabels),
taints: taints,
labels: labels,
unsetIndexedLabels: unsetIndexedLabels,
}
}

Expand Down Expand Up @@ -139,15 +192,3 @@ func getFilteredLabels(labels map[string]string, inclusionFilter labelsFilterFun
}
return filteredLabels
}

func (nodeType *NodeType) DeepCopy() *NodeType {
if nodeType == nil {
return nil
}
return &NodeType{
Id: nodeType.Id,
Taints: slices.Clone(nodeType.Taints),
Labels: maps.Clone(nodeType.Labels),
UnsetIndexedLabels: maps.Clone(nodeType.UnsetIndexedLabels),
}
}
92 changes: 92 additions & 0 deletions internal/scheduler/internaltypes/node_type_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package internaltypes

import (
"testing"

"github.com/stretchr/testify/assert"
v1 "k8s.io/api/core/v1"
)

// TestNodeType_GetId checks that a NodeType built by NewNodeType reports a non-zero id.
func TestNodeType_GetId(t *testing.T) {
	nodeType := makeSut()

	// NotZero prints the offending value on failure, whereas assert.True on a
	// comparison would only report "Should be true".
	assert.NotZero(t, nodeType.GetId())
}

func TestNodeType_GetTaints(t *testing.T) {
nodeType := makeSut()

assert.Equal(t,
[]v1.Taint{
{Key: "taint1", Value: "value1", Effect: v1.TaintEffectNoSchedule},
{Key: "taint2", Value: "value2", Effect: v1.TaintEffectNoSchedule},
},
nodeType.GetTaints(),
)
}

func TestNodeType_FindMatchingUntoleratedTaint(t *testing.T) {
nodeType := makeSut()
taint, ok := nodeType.FindMatchingUntoleratedTaint([]v1.Toleration{{Key: "taint1", Operator: v1.TolerationOpExists, Effect: v1.TaintEffectNoSchedule}})

assert.True(t, ok)
assert.Equal(t,
v1.Taint{Key: "taint2", Value: "value2", Effect: v1.TaintEffectNoSchedule},
taint)
}

// TestNodeTypeLabels exercises the label accessors of NodeType: the map-returning
// getters (GetLabels, GetUnsetIndexedLabels) and the single-key lookups
// (GetLabelValue, GetUnsetIndexedLabelValue), for both present and absent keys.
//
// All assert.Equal calls pass the expected value first and the actual value
// second, as testify requires, so that failure messages are labelled correctly.
func TestNodeTypeLabels(t *testing.T) {
	nodeType := makeSut()

	// Only the labels listed as indexed in makeSut are expected to be kept.
	assert.Equal(t,
		map[string]string{
			"label1": "value1",
			"label2": "value2",
		},
		nodeType.GetLabels(),
	)

	val1, ok1 := nodeType.GetLabelValue("label1")
	assert.Equal(t, "value1", val1)
	assert.True(t, ok1)

	val2, ok2 := nodeType.GetLabelValue("not-there")
	assert.Equal(t, "", val2)
	assert.False(t, ok2)

	// label3 is indexed but not set on the node, so it is expected among the
	// unset indexed labels with an empty value.
	assert.Equal(t,
		map[string]string{
			"label3": "",
		},
		nodeType.GetUnsetIndexedLabels(),
	)

	val3, ok3 := nodeType.GetUnsetIndexedLabelValue("label3")
	assert.Equal(t, "", val3)
	assert.True(t, ok3)

	val4, ok4 := nodeType.GetUnsetIndexedLabelValue("not-there")
	assert.Equal(t, "", val4)
	assert.False(t, ok4)
}

// makeSut builds the NodeType used as the system under test by the tests in
// this file. The node carries three taints and three labels, of which one taint
// and one label are deliberately not indexed; taint3 and label3 are indexed but
// not present on the node.
func makeSut() *NodeType {
	indexedTaints := map[string]interface{}{"taint1": true, "taint2": true, "taint3": true}
	indexedLabels := map[string]interface{}{"label1": true, "label2": true, "label3": true}

	nodeTaints := []v1.Taint{
		{Key: "taint1", Value: "value1", Effect: v1.TaintEffectNoSchedule},
		{Key: "not-indexed-taint", Value: "not-indexed-taint-value", Effect: v1.TaintEffectNoSchedule},
		{Key: "taint2", Value: "value2", Effect: v1.TaintEffectNoSchedule},
	}

	nodeLabels := map[string]string{
		"label1": "value1",
		"label2": "value2",
		// NOTE(review): the trailing ';' in this key looks like a typo. It does not
		// affect the tests because the key is not among the indexed labels either way.
		"not-indexed-label;": "not-indexed-label-value",
	}

	return NewNodeType(nodeTaints, nodeLabels, indexedTaints, indexedLabels)
}
20 changes: 10 additions & 10 deletions internal/scheduler/nodedb/nodedb.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (nodeDb *NodeDb) create(node *schedulerobjects.Node) (*internaltypes.Node,

totalResources := node.TotalResources

nodeType := schedulerobjects.NewNodeType(
nodeType := internaltypes.NewNodeType(
taints,
labels,
nodeDb.indexedTaints,
Expand Down Expand Up @@ -77,14 +77,14 @@ func (nodeDb *NodeDb) create(node *schedulerobjects.Node) (*internaltypes.Node,
}
index := uint64(nodeDb.numNodes)
nodeDb.numNodes++
nodeDb.numNodesByNodeType[nodeType.Id]++
nodeDb.numNodesByNodeType[nodeType.GetId()]++
nodeDb.totalResources.Add(totalResources)
nodeDb.nodeTypes[nodeType.Id] = nodeType
nodeDb.nodeTypes[nodeType.GetId()] = nodeType
nodeDb.mu.Unlock()

return internaltypes.CreateNode(
node.Id,
nodeType.Id,
nodeType.GetId(),
index,
node.Executor,
node.Name,
Expand Down Expand Up @@ -193,7 +193,7 @@ type NodeDb struct {
totalResources schedulerobjects.ResourceList
// Set of node types. Populated automatically as nodes are inserted.
// Node types are not cleaned up if all nodes of that type are removed from the NodeDb.
nodeTypes map[uint64]*schedulerobjects.NodeType
nodeTypes map[uint64]*internaltypes.NodeType

wellKnownNodeTypes map[string]*configuration.WellKnownNodeType

Expand Down Expand Up @@ -267,7 +267,7 @@ func NewNodeDb(
indexedTaints: mapFromSlice(indexedTaints),
indexedNodeLabels: mapFromSlice(indexedNodeLabels),
indexedNodeLabelValues: indexedNodeLabelValues,
nodeTypes: make(map[uint64]*schedulerobjects.NodeType),
nodeTypes: make(map[uint64]*internaltypes.NodeType),
wellKnownNodeTypes: make(map[string]*configuration.WellKnownNodeType),
numNodesByNodeType: make(map[uint64]int),
totalResources: schedulerobjects.ResourceList{Resources: make(map[string]resource.Quantity)},
Expand Down Expand Up @@ -353,7 +353,7 @@ func (nodeDb *NodeDb) String() string {
} else {
fmt.Fprint(w, "Node types:\n")
for _, nodeType := range nodeDb.nodeTypes {
fmt.Fprintf(w, " %d\n", nodeType.Id)
fmt.Fprintf(w, " %d\n", nodeType.GetId())
}
}
w.Flush()
Expand Down Expand Up @@ -1069,12 +1069,12 @@ func (nodeDb *NodeDb) NodeTypesMatchingJob(jctx *schedulercontext.JobSchedulingC
for _, nodeType := range nodeDb.nodeTypes {
matches, reason := NodeTypeJobRequirementsMet(nodeType, jctx)
if matches {
matchingNodeTypeIds = append(matchingNodeTypeIds, nodeType.Id)
matchingNodeTypeIds = append(matchingNodeTypeIds, nodeType.GetId())
} else if reason != nil {
s := nodeDb.stringFromPodRequirementsNotMetReason(reason)
numExcludedNodesByReason[s] += nodeDb.numNodesByNodeType[nodeType.Id]
numExcludedNodesByReason[s] += nodeDb.numNodesByNodeType[nodeType.GetId()]
} else {
numExcludedNodesByReason[PodRequirementsNotMetReasonUnknown] += nodeDb.numNodesByNodeType[nodeType.Id]
numExcludedNodesByReason[PodRequirementsNotMetReasonUnknown] += nodeDb.numNodesByNodeType[nodeType.GetId()]
}
}
return matchingNodeTypeIds, numExcludedNodesByReason, nil
Expand Down
4 changes: 2 additions & 2 deletions internal/scheduler/nodedb/nodeiteration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -957,11 +957,11 @@ func gpuNodeTypeLabelToNodeTypeId(nodeTypeLabel string) uint64 {
}

func labelsToNodeTypeId(labels map[string]string) uint64 {
nodeType := schedulerobjects.NewNodeType(
nodeType := internaltypes.NewNodeType(
[]v1.Taint{},
labels,
mapFromSlice(testfixtures.TestIndexedTaints),
mapFromSlice(testfixtures.TestIndexedNodeLabels),
)
return nodeType.Id
return nodeType.GetId()
}
26 changes: 9 additions & 17 deletions internal/scheduler/nodedb/nodematching.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ import (

schedulercontext "github.com/armadaproject/armada/internal/scheduler/context"
"github.com/armadaproject/armada/internal/scheduler/internaltypes"
koTaint "github.com/armadaproject/armada/internal/scheduler/kubernetesobjects/taint"
"github.com/armadaproject/armada/internal/scheduler/schedulerobjects"
)

const (
Expand Down Expand Up @@ -126,24 +124,18 @@ func (err *InsufficientResources) String() string {
// NodeTypeJobRequirementsMet determines whether a pod can be scheduled on nodes of this NodeType.
// If the requirements are not met, it returns the reason why.
// No error value is returned: the result is only the match flag and, when there is no match, the reason.
func NodeTypeJobRequirementsMet(nodeType *schedulerobjects.NodeType, jctx *schedulercontext.JobSchedulingContext) (bool, PodRequirementsNotMetReason) {
matches, reason := TolerationRequirementsMet(nodeType.GetTaints(), jctx.AdditionalTolerations, jctx.PodRequirements.GetTolerations())
func NodeTypeJobRequirementsMet(nodeType *internaltypes.NodeType, jctx *schedulercontext.JobSchedulingContext) (bool, PodRequirementsNotMetReason) {
matches, reason := TolerationRequirementsMet(nodeType, jctx.AdditionalTolerations, jctx.PodRequirements.GetTolerations())
if !matches {
return matches, reason
}

nodeTypeLabels := nodeType.GetLabels()
nodeTypeLabelGetter := func(key string) (string, bool) {
val, ok := nodeTypeLabels[key]
return val, ok
}

matches, reason = NodeSelectorRequirementsMet(nodeTypeLabelGetter, nodeType.GetUnsetIndexedLabels(), jctx.AdditionalNodeSelectors)
matches, reason = NodeSelectorRequirementsMet(nodeType.GetLabelValue, nodeType.GetUnsetIndexedLabelValue, jctx.AdditionalNodeSelectors)
if !matches {
return matches, reason
}

return NodeSelectorRequirementsMet(nodeTypeLabelGetter, nodeType.GetUnsetIndexedLabels(), jctx.PodRequirements.GetNodeSelector())
return NodeSelectorRequirementsMet(nodeType.GetLabelValue, nodeType.GetUnsetIndexedLabelValue, jctx.PodRequirements.GetNodeSelector())
}

// JobRequirementsMet determines whether a job can be scheduled onto this node.
Expand Down Expand Up @@ -202,8 +194,8 @@ func DynamicJobRequirementsMet(allocatableResources internaltypes.ResourceList,
return matches, reason
}

func TolerationRequirementsMet(taints []v1.Taint, tolerations ...[]v1.Toleration) (bool, PodRequirementsNotMetReason) {
untoleratedTaint, hasUntoleratedTaint := koTaint.FindMatchingUntoleratedTaint(taints, tolerations...)
func TolerationRequirementsMet(nodeType *internaltypes.NodeType, tolerations ...[]v1.Toleration) (bool, PodRequirementsNotMetReason) {
untoleratedTaint, hasUntoleratedTaint := nodeType.FindMatchingUntoleratedTaint(tolerations...)
if hasUntoleratedTaint {
return false, &UntoleratedTaint{Taint: untoleratedTaint}
}
Expand All @@ -218,7 +210,7 @@ func NodeTolerationRequirementsMet(node *internaltypes.Node, tolerations ...[]v1
return true, nil
}

func NodeSelectorRequirementsMet(nodeLabelGetter func(string) (string, bool), unsetIndexedLabels, nodeSelector map[string]string) (bool, PodRequirementsNotMetReason) {
func NodeSelectorRequirementsMet(nodeLabelGetter func(string) (string, bool), unsetIndexedLabelGetter func(string) (string, bool), nodeSelector map[string]string) (bool, PodRequirementsNotMetReason) {
for label, podValue := range nodeSelector {
// If the label value differs between nodeLabels and the pod, always return false.
if nodeValue, ok := nodeLabelGetter(label); ok {
Expand All @@ -233,8 +225,8 @@ func NodeSelectorRequirementsMet(nodeLabelGetter func(string) (string, bool), un
// If unsetIndexedLabelGetter is provided, return false only if this label is explicitly marked as not set.
//
// If unsetIndexedLabelGetter is not provided (nil), we assume that nodeLabelGetter knows all labels and return false.
if unsetIndexedLabels != nil {
if _, ok := unsetIndexedLabels[label]; ok {
if unsetIndexedLabelGetter != nil {
if _, ok := unsetIndexedLabelGetter(label); ok {
return false, &MissingLabel{Label: label}
}
} else {
Expand Down
2 changes: 1 addition & 1 deletion internal/scheduler/nodedb/nodematching_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ func TestNodeTypeSchedulingRequirementsMet(t *testing.T) {
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
nodeType := schedulerobjects.NewNodeType(
nodeType := internaltypes.NewNodeType(
tc.Taints,
tc.Labels,
tc.IndexedTaints,
Expand Down
Loading

0 comments on commit 688302c

Please sign in to comment.