Skip to content

Commit

Permalink
feat(nodeadm): PCIe detection for nvidia GPU instances
Browse files Browse the repository at this point in the history
  • Loading branch information
ndbaker1 committed Feb 8, 2025
1 parent e80a713 commit 61313e7
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 1 deletion.
28 changes: 27 additions & 1 deletion nodeadm/internal/containerd/runtime_config.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package containerd

import (
"os"
"slices"
"strings"

Expand All @@ -15,11 +16,34 @@ type instanceOptions struct {
// instanceTypeMixin bundles the criteria used to recognise a class of
// instances with a function producing the instanceOptions for that class.
// An instance is recognised either by its instance family or, as a fallback,
// by a name found in a PCI device listing file.
type instanceTypeMixin struct {
// instanceFamilies holds the instance family names this mixin applies to.
// A family is the part of the instance type before the first ".".
instanceFamilies []string
// apply returns the instanceOptions contributed by this mixin.
// NOTE(review): presumably invoked only when the mixin matches — confirm
// against the caller, which is outside this view.
apply func() instanceOptions
// pcieVendorName is the text searched for in the file at pcieDevicesPath.
// When empty, PCIe-based detection is disabled.
pcieVendorName string
// pcieDevicesPath is the path of the PCI device listing that is read
// during PCIe-based detection (e.g. "/proc/bus/pci/devices").
pcieDevicesPath string
}

// matches reports whether this mixin applies to the given instance type.
//
// The instance family is the portion of instanceType before the first "."
// (the whole string when there is no "."). The mixin matches when that family
// is listed in m.instanceFamilies or, failing that, when PCIe detection via
// matchesPCIeVendor succeeds. The PCIe check is only evaluated when the family
// lookup misses, so known families never touch the filesystem.
func (m *instanceTypeMixin) matches(instanceType string) bool {
	// Cut yields the same prefix as Split(...)[0] without building a slice.
	instanceFamily, _, _ := strings.Cut(instanceType, ".")
	return slices.Contains(m.instanceFamilies, instanceFamily) || m.matchesPCIeVendor()
}

// matchesPCIeVendor reports whether the PCI device listing stored at
// m.pcieDevicesPath contains m.pcieVendorName anywhere in its contents.
//
// A mixin without a configured vendor name never matches, and the listing is
// not read in that case. When the listing cannot be read the failure is
// logged and false is returned.
func (m *instanceTypeMixin) matchesPCIeVendor() bool {
	if m.pcieVendorName == "" {
		return false
	}

	listing, readErr := os.ReadFile(m.pcieDevicesPath)
	if readErr != nil {
		zap.L().Error("Failed to read PCIe devices", zap.Error(readErr))
		return false
	}

	// A sample of '/proc/bus/pci/devices' is shown below; the trailing column,
	// when present, carries a name such as nvme, ena or nvidia:
	//
	// 0018 1d0f1111 0 c1000008 0 0 0 0 0 c0002 400000 0 0 0 0 0 20000
	// 0020 1d0f8061 b c1508000 0 0 0 0 0 0 4000 0 0 0 0 0 0 nvme
	// 0028 1d0fec20 0 c1504000 0 c1400008 0 0 0 0 4000 0 100000 0 0 0 0 ena
	// 00f0 10de1eb8 a c0000000 44000000c 0 45000000c 0 0 0 1000000 10000000 0 2000000 0 0 0 nvidia
	// 00f8 1d0fcd01 0 c1500000 0 c150c008 0 0 0 0 4000 0 2000 0 0 0 0 nvme
	// 0030 1d0fec20 0 c1510000 0 c1600008 0 0 0 0 4000 0 100000 0 0 0 0 ena
	//
	// NOTE(review): on Linux this column appears to be the name of the kernel
	// driver bound to the device rather than a hardware vendor string, so a
	// match presumably requires the driver to be loaded — confirm.
	//
	// The search is a plain substring match over the whole file, not a
	// per-line comparison of the last column.
	return strings.Contains(string(listing), m.pcieVendorName)
}

var (
Expand All @@ -28,6 +52,8 @@ var (
NvidiaInstanceTypeMixin = instanceTypeMixin{
instanceFamilies: nvidiaInstances,
apply: applyNvidia,
pcieVendorName: "nvidia",
pcieDevicesPath: "/proc/bus/pci/devices",
}

mixins = []instanceTypeMixin{
Expand Down
30 changes: 30 additions & 0 deletions nodeadm/internal/containerd/runtime_config_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
package containerd

import (
"os"
"path/filepath"
"reflect"
"testing"

"github.com/stretchr/testify/assert"
)

func TestApplyInstanceTypeMixins(t *testing.T) {
Expand All @@ -29,3 +33,29 @@ func TestApplyInstanceTypeMixins(t *testing.T) {
}
}
}

// TestPCIeDetection exercises the PCIe-based matching of instanceTypeMixin
// using fixture files that stand in for the kernel's PCI device listing.
//
// Every case uses the instance type "x.x", which belongs to no configured
// family, so matches can only succeed through the PCIe fallback.
func TestPCIeDetection(t *testing.T) {
	// writeDevices stores contents in a "devices" file inside a fresh
	// temporary directory and returns its path. A setup failure aborts the
	// subtest immediately instead of surfacing later as a misleading
	// assertion failure.
	writeDevices := func(t *testing.T, contents string) string {
		t.Helper()
		path := filepath.Join(t.TempDir(), "devices")
		if err := os.WriteFile(path, []byte(contents), 0o600); err != nil {
			t.Fatalf("writing devices fixture: %v", err)
		}
		return path
	}

	t.Run("Matches", func(t *testing.T) {
		mixin := instanceTypeMixin{
			pcieDevicesPath: writeDevices(t, "nvidia"),
			pcieVendorName:  "nvidia",
		}
		assert.True(t, mixin.matches("x.x"))
		assert.True(t, mixin.matchesPCIeVendor())
	})
	t.Run("MatchesRealisticListing", func(t *testing.T) {
		// Mirrors the multi-line layout of /proc/bus/pci/devices, with the
		// name of interest in the last column of one of several rows.
		listing := "0020 1d0f8061 b c1508000 0 0 0 0 0 0 4000 0 0 0 0 0 0 nvme\n" +
			"0028 1d0fec20 0 c1504000 0 c1400008 0 0 0 0 4000 0 100000 0 0 0 0 ena\n" +
			"00f0 10de1eb8 a c0000000 44000000c 0 45000000c 0 0 0 1000000 10000000 0 2000000 0 0 0 nvidia\n"
		mixin := instanceTypeMixin{
			pcieDevicesPath: writeDevices(t, listing),
			pcieVendorName:  "nvidia",
		}
		assert.True(t, mixin.matches("x.x"))
		assert.True(t, mixin.matchesPCIeVendor())
	})
	t.Run("NotMatchesBecauseDifferent", func(t *testing.T) {
		mixin := instanceTypeMixin{
			pcieDevicesPath: writeDevices(t, "nvme"),
			pcieVendorName:  "nvidia",
		}
		assert.False(t, mixin.matches("x.x"))
		assert.False(t, mixin.matchesPCIeVendor())
	})
	t.Run("NotMatchesBecauseMissing", func(t *testing.T) {
		// Zero-value mixin: no vendor name is configured, so detection is
		// disabled before any file is read.
		mixin := instanceTypeMixin{}
		assert.False(t, mixin.matches("x.x"))
		assert.False(t, mixin.matchesPCIeVendor())
	})
	t.Run("NotMatchesBecauseFileMissing", func(t *testing.T) {
		// A vendor name is configured but the listing does not exist, which
		// drives the read-error path.
		mixin := instanceTypeMixin{
			pcieDevicesPath: filepath.Join(t.TempDir(), "does-not-exist"),
			pcieVendorName:  "nvidia",
		}
		assert.False(t, mixin.matches("x.x"))
		assert.False(t, mixin.matchesPCIeVendor())
	})
}

0 comments on commit 61313e7

Please sign in to comment.