@@ -32,15 +32,15 @@ import (
32
32
)
33
33
34
34
func TestMnistPyTorchAppWrapperCpu (t * testing.T ) {
35
- runMnistPyTorchAppWrapper (t , "cpu" , 0 )
35
+ runMnistPyTorchAppWrapper (t , CPU )
36
36
}
37
37
38
38
func TestMnistPyTorchAppWrapperGpu (t * testing.T ) {
39
- runMnistPyTorchAppWrapper (t , "gpu" , 1 )
39
+ runMnistPyTorchAppWrapper (t , NVIDIA )
40
40
}
41
41
42
42
// Trains the MNIST dataset as a batch Job in an AppWrapper, and asserts successful completion of the training job.
43
- func runMnistPyTorchAppWrapper (t * testing.T , accelerator string , numberOfGpus int ) {
43
+ func runMnistPyTorchAppWrapper (t * testing.T , accelerator Accelerator ) {
44
44
test := With (t )
45
45
46
46
// Create a namespace
@@ -51,7 +51,7 @@ func runMnistPyTorchAppWrapper(t *testing.T, accelerator string, numberOfGpus in
51
51
defer func () {
52
52
_ = test .Client ().Kueue ().KueueV1beta1 ().ResourceFlavors ().Delete (test .Ctx (), resourceFlavor .Name , metav1.DeleteOptions {})
53
53
}()
54
- clusterQueue := createClusterQueue (test , resourceFlavor , numberOfGpus )
54
+ clusterQueue := createClusterQueue (test , resourceFlavor , accelerator )
55
55
defer func () {
56
56
_ = test .Client ().Kueue ().KueueV1beta1 ().ClusterQueues ().Delete (test .Ctx (), clusterQueue .Name , metav1.DeleteOptions {})
57
57
}()
@@ -109,7 +109,7 @@ func runMnistPyTorchAppWrapper(t *testing.T, accelerator string, numberOfGpus in
109
109
{Name : "MNIST_DATASET_URL" , Value : GetMnistDatasetURL ()},
110
110
{Name : "PIP_INDEX_URL" , Value : GetPipIndexURL ()},
111
111
{Name : "PIP_TRUSTED_HOST" , Value : GetPipTrustedHost ()},
112
- {Name : "ACCELERATOR" , Value : accelerator },
112
+ {Name : "ACCELERATOR" , Value : accelerator . Type },
113
113
},
114
114
Command : []string {"/bin/sh" , "-c" , "pip install -r /test/requirements.txt && torchrun /test/mnist.py" },
115
115
VolumeMounts : []corev1.VolumeMount {
0 commit comments