Skip to content

Commit

Permalink
Add tests to run KFTO PyTorch MNIST training using multi-node/multi-GPU use cases
Browse files Browse the repository at this point in the history
  • Loading branch information
abhijeet-dhumal committed Jan 6, 2025
1 parent 0dcc478 commit 48c546c
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions tests/kfto/kfto_mnist_training_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,26 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

func TestPyTorchJobMnistCpu(t *testing.T) {
func TestPyTorchJobMnistMultiNodeCpu(t *testing.T) {
runKFTOPyTorchMnistJob(t, 0, 2, "", GetCudaTrainingImage(), "resources/requirements.txt")
}
func TestPyTorchJobMnistWithCuda(t *testing.T) {

func TestPyTorchJobMnistMultiNodeWithCuda(t *testing.T) {
runKFTOPyTorchMnistJob(t, 1, 1, "nvidia.com/gpu", GetCudaTrainingImage(), "resources/requirements.txt")
}

func TestPyTorchJobMnistWithROCm(t *testing.T) {
func TestPyTorchJobMnistMultiNodeWithROCm(t *testing.T) {
runKFTOPyTorchMnistJob(t, 1, 1, "amd.com/gpu", GetROCmTrainingImage(), "resources/requirements-rocm.txt")
}

func TestPyTorchJobMnistMultiNodeMultiGpuWithCuda(t *testing.T) {
runKFTOPyTorchMnistJob(t, 2, 1, "nvidia.com/gpu", GetCudaTrainingImage(), "resources/requirements.txt")
}

func TestPyTorchJobMnistMultiNodeMultiGpuWithROCm(t *testing.T) {
runKFTOPyTorchMnistJob(t, 2, 1, "amd.com/gpu", GetROCmTrainingImage(), "resources/requirements-rocm.txt")
}

func runKFTOPyTorchMnistJob(t *testing.T, numGpus int, workerReplicas int, gpuLabel string, image string, requirementsFile string) {
test := With(t)

Expand Down

0 comments on commit 48c546c

Please sign in to comment.