Skip to content

Commit

Permalink
Update KFTO MNIST multinode test to make it compatible for disconnected environment
Browse files Browse the repository at this point in the history
  • Loading branch information
abhijeet-dhumal committed Jan 3, 2025
1 parent a5a6d39 commit f023590
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 7 deletions.
53 changes: 53 additions & 0 deletions tests/kfto/kfto_mnist_training_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@ func createKFTOPyTorchMnistJob(test Test, namespace string, config corev1.Config
backend = "gloo"
}

storage_bucket_endpoint, storage_bucket_endpoint_exists := GetStorageBucketDefaultEndpoint()
storage_bucket_access_key_id, storage_bucket_access_key_id_exists := GetStorageBucketAccessKeyId()
storage_bucket_secret_key, storage_bucket_secret_key_exists := GetStorageBucketSecretKey()
storage_bucket_name, storage_bucket_name_exists := GetStorageBucketName()
storage_bucket_mnist_dir, storage_bucket_mnist_dir_exists := GetStorageBucketMnistDir()

tuningJob := &kftov1.PyTorchJob{
TypeMeta: metav1.TypeMeta{
APIVersion: corev1.SchemeGroupVersion.String(),
Expand Down Expand Up @@ -298,6 +304,53 @@ func createKFTOPyTorchMnistJob(test Test, namespace string, config corev1.Config
}
}

if storage_bucket_endpoint_exists && storage_bucket_access_key_id_exists && storage_bucket_secret_key_exists && storage_bucket_name_exists && storage_bucket_mnist_dir_exists {
tuningJob.Spec.PyTorchReplicaSpecs["Master"].Template.Spec.Containers[0].Env = []corev1.EnvVar{
{
Name: "AWS_DEFAULT_ENDPOINT",
Value: storage_bucket_endpoint,
},
{
Name: "AWS_ACCESS_KEY_ID",
Value: storage_bucket_access_key_id,
},
{
Name: "AWS_SECRET_ACCESS_KEY",
Value: storage_bucket_secret_key,
},
{
Name: "AWS_STORAGE_BUCKET",
Value: storage_bucket_name,
},
{
Name: "AWS_STORAGE_BUCKET_MNIST_DIR",
Value: storage_bucket_mnist_dir,
},
}
tuningJob.Spec.PyTorchReplicaSpecs["Worker"].Template.Spec.Containers[0].Env = []corev1.EnvVar{
{
Name: "AWS_DEFAULT_ENDPOINT",
Value: storage_bucket_endpoint,
},
{
Name: "AWS_ACCESS_KEY_ID",
Value: storage_bucket_access_key_id,
},
{
Name: "AWS_SECRET_ACCESS_KEY",
Value: storage_bucket_secret_key,
},
{
Name: "AWS_STORAGE_BUCKET",
Value: storage_bucket_name,
},
{
Name: "AWS_STORAGE_BUCKET_MNIST_DIR",
Value: storage_bucket_mnist_dir,
},
}
}

tuningJob, err := test.Client().Kubeflow().KubeflowV1().PyTorchJobs(namespace).Create(test.Ctx(), tuningJob, metav1.CreateOptions{})
test.Expect(err).NotTo(HaveOccurred())
test.T().Logf("Created PytorchJob %s/%s successfully", tuningJob.Namespace, tuningJob.Name)
Expand Down
77 changes: 72 additions & 5 deletions tests/kfto/resources/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
from torch.utils.data import DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms

import gzip
import shutil
from minio import Minio

class Net(nn.Module):
def __init__(self):
Expand Down Expand Up @@ -206,17 +208,82 @@ def main():
dist.init_process_group(backend=args.backend)
model = nn.parallel.DistributedDataParallel(model)

if all(var in os.environ for var in ["AWS_DEFAULT_ENDPOINT","AWS_ACCESS_KEY_ID","AWS_SECRET_ACCESS_KEY","AWS_STORAGE_BUCKET","AWS_STORAGE_BUCKET_MNIST_DIR"]):
print("Using provided storage bucket to download datasets...")
dataset_dir = os.path.join("../data/", "MNIST/raw")
endpoint = os.environ.get("AWS_DEFAULT_ENDPOINT")
access_key = os.environ.get("AWS_ACCESS_KEY_ID")
secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
bucket_name = os.environ.get("AWS_STORAGE_BUCKET")
print(f"Storage bucket endpoint: {endpoint}")
print(f"Storage bucket name: {bucket_name}\n")

# remove prefix if specified in storage bucket endpoint url
secure = True
if endpoint.startswith("https://"):
endpoint = endpoint[len("https://") :]
elif endpoint.startswith("http://"):
endpoint = endpoint[len("http://") :]
secure = False

client = Minio(
endpoint,
access_key=access_key,
secret_key=secret_key,
cert_check=False,
secure=secure
)
if not os.path.exists(dataset_dir):
os.makedirs(dataset_dir)
else:
print(f"Directory '{dataset_dir}' already exists")

# To download datasets from storage bucket's specific directory, use prefix to provide directory name
prefix=os.environ.get("AWS_STORAGE_BUCKET_MNIST_DIR")
print(f"Storage bucket MNIST directory prefix: {prefix}\n")

# download all files from prefix folder of storage bucket recursively
for item in client.list_objects(
bucket_name, prefix=prefix, recursive=True
):
file_name=item.object_name[len(prefix)+1:]
dataset_file_path = os.path.join(dataset_dir, file_name)
print(f"Downloading dataset file {file_name} to {dataset_file_path}..")
if not os.path.exists(dataset_file_path):
client.fget_object(
bucket_name, item.object_name, dataset_file_path
)
# Unzip files --
## Sample zipfilepath : ../data/MNIST/raw/t10k-images-idx3-ubyte.gz
with gzip.open(dataset_file_path, "rb") as f_in:
filename=file_name.split(".")[0] #-> t10k-images-idx3-ubyte
file_path=("/".join(dataset_file_path.split("/")[:-1])) #->../data/MNIST/raw
full_file_path=os.path.join(file_path,filename) #->../data/MNIST/raw/t10k-images-idx3-ubyte
print(f"Extracting {dataset_file_path} to {file_path}..")

with open(full_file_path, "wb") as f_out:
shutil.copyfileobj(f_in, f_out)
print(f"Dataset file downloaded : {full_file_path}\n")
# delete zip file
os.remove(dataset_file_path)
else:
print(f"File-path '{dataset_file_path}' already exists")
download_datasets = False
else:
print("Using default MNIST mirror references to download datasets ...")
download_datasets = True

# Get FashionMNIST train and test dataset.
train_ds = datasets.FashionMNIST(
train_ds = datasets.MNIST(
"../data",
train=True,
download=True,
download=download_datasets,
transform=transforms.Compose([transforms.ToTensor()]),
)
test_ds = datasets.FashionMNIST(
test_ds = datasets.MNIST(
"../data",
train=False,
download=True,
download=download_datasets,
transform=transforms.Compose([transforms.ToTensor()]),
)
# Add train and test loaders.
Expand Down
3 changes: 2 additions & 1 deletion tests/kfto/resources/requirements-rocm.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@
torchvision==0.19.0
tensorboard==2.18.0
fsspec[http]==2024.6.1
numpy==2.0.2
numpy==2.0.2
minio==7.2.13
3 changes: 2 additions & 1 deletion tests/kfto/resources/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
torchvision==0.19.0
tensorboard==2.18.0
fsspec[http]==2024.6.1
numpy==2.0.2
numpy==2.0.2
minio==7.2.13

0 comments on commit f023590

Please sign in to comment.