Prevent dataset creation from breaking if the dataset already exists (#9)
tchaton authored Feb 21, 2024
1 parent 4309257 commit 981f4c5
Showing 2 changed files with 91 additions and 16 deletions.
29 changes: 19 additions & 10 deletions lightning_data/processing/data_processor.py
@@ -2,6 +2,7 @@
 import json
 import logging
 import os
+import random
 import shutil
 import signal
 import tempfile
@@ -17,7 +18,7 @@
 from urllib import parse

 import numpy as np
-from lightning import seed_everything
+import torch
 from tqdm.auto import tqdm as _tqdm

 from lightning_data.constants import (
@@ -29,6 +30,7 @@
     _TORCH_GREATER_EQUAL_2_1_0,
 )
 from lightning_data.processing.readers import BaseReader
+from lightning_data.processing.utilities import _create_dataset
 from lightning_data.streaming import Cache
 from lightning_data.streaming.cache import Dir
 from lightning_data.streaming.client import S3Client
@@ -41,7 +43,6 @@

 if _LIGHTNING_CLOUD_LATEST:
     from lightning_cloud.openapi import V1DatasetType
-    from lightning_cloud.utils.dataset import _create_dataset


 if _BOTO3_AVAILABLE:
@@ -427,7 +428,7 @@ def _loop(self) -> None:
             uploader.join()

         if self.remove:
-            assert self.remover  # noqa: S101
+            assert self.remover
             self.remove_queue.put(None)
             self.remover.join()

@@ -487,7 +488,7 @@ def _try_upload(self, data: Optional[Union[str, Tuple[str, str]]]) -> None:
         if isinstance(data, str):
             assert os.path.exists(data), data
         else:
-            assert os.path.exists(data[-1]), data  # noqa: S101
+            assert os.path.exists(data[-1]), data

         self.to_upload_queues[self._counter % self.num_uploaders].put(data)

@@ -724,7 +725,12 @@ def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Resul
         size = sum([c["dim"] if c["dim"] is not None else c["chunk_size"] for c in config["chunks"]])
         num_bytes = sum([c["chunk_bytes"] for c in config["chunks"]])
-        data_format = tree_unflatten(config["config"]["data_format"], treespec_loads(config["config"]["data_spec"]))
+        if config["config"] is not None:
+            data_format = tree_unflatten(
+                config["config"]["data_format"], treespec_loads(config["config"]["data_spec"])
+            )
+        else:
+            data_format = None
         num_chunks = len(config["chunks"])

         # The platform can't store more than 1024 entries.
@@ -735,7 +741,7 @@
             size=size,
             num_bytes=num_bytes,
             data_format=data_format,
-            compression=config["config"]["compression"],
+            compression=config["config"]["compression"] if config["config"] else None,
             num_chunks=len(config["chunks"]),
             num_bytes_per_chunk=num_bytes_per_chunk,
         )
Expand Down Expand Up @@ -772,7 +778,7 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra
# Get the index file locally
for node_rank in range(num_nodes - 1):
output_dir_path = output_dir.url if output_dir.url else output_dir.path
assert output_dir_path # noqa: S101
assert output_dir_path
remote_filepath = os.path.join(output_dir_path, f"{node_rank}-{_INDEX_FILENAME}")
node_index_filepath = os.path.join(cache_dir, os.path.basename(remote_filepath))
if obj.scheme == "s3":
@@ -874,7 +880,9 @@ def run(self, data_recipe: DataRecipe) -> None:
         print(f"Setup started with fast_dev_run={self.fast_dev_run}.")

         # Force random seed to be fixed
-        seed_everything(self.random_seed)
+        random.seed(self.random_seed)
+        np.random.seed(self.random_seed)
+        torch.manual_seed(self.random_seed)

         # Call the setup method of the user
         user_items: List[Any] = data_recipe.prepare_structure(self.input_dir.path if self.input_dir else None)
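The three added lines inline what `lightning.seed_everything` previously did here, dropping the dependency on the `lightning` package. A minimal sketch of an equivalent helper, assuming only the `random`, `numpy`, and `torch` generators matter for this pipeline (the name `_seed_everything` is hypothetical):

import random

import numpy as np
import torch


def _seed_everything(seed: int) -> None:
    # Hypothetical helper mirroring the inlined calls in DataProcessor.run:
    # seed Python's, NumPy's, and PyTorch's global RNGs so every run (and
    # every worker given the same seed) draws identical random numbers.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # also seeds the RNG on all CUDA devices

Note that `lightning.seed_everything` additionally records the seed in the `PL_GLOBAL_SEED` environment variable; the inlined version keeps only the three generator seeds.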
@@ -941,7 +949,7 @@ def run(self, data_recipe: DataRecipe) -> None:
                 error = self.error_queue.get(timeout=0.001)
                 self._exit_on_error(error)
             except Empty:
-                assert self.progress_queue  # noqa: S101
+                assert self.progress_queue
                 try:
                     index, counter = self.progress_queue.get(timeout=0.001)
                 except Empty:
@@ -973,7 +981,8 @@ def run(self, data_recipe: DataRecipe) -> None:
         print("Workers are finished.")
         result = data_recipe._done(len(user_items), self.delete_cached_files, self.output_dir)

-        if num_nodes == node_rank + 1 and self.output_dir.url:
+        if num_nodes == node_rank + 1 and self.output_dir.url and _IS_IN_STUDIO:
+            assert self.output_dir.path
             _create_dataset(
                 input_dir=self.input_dir.path,
                 storage_dir=self.output_dir.path,
78 changes: 72 additions & 6 deletions lightning_data/processing/utilities.py
@@ -2,10 +2,76 @@
 import os
 import urllib
 from contextlib import contextmanager
-from subprocess import Popen  # noqa: S404
-from typing import Any, Callable, Optional, Tuple
+from subprocess import DEVNULL, Popen
+from typing import Any, Callable, List, Optional, Tuple, Union

+from lightning_data.constants import _IS_IN_STUDIO, _LIGHTNING_CLOUD_LATEST
+
+if _LIGHTNING_CLOUD_LATEST:
+    from lightning_cloud.openapi import (
+        ProjectIdDatasetsBody,
+        V1DatasetType,
+    )
+    from lightning_cloud.openapi.rest import ApiException
+    from lightning_cloud.rest_client import LightningClient
+
+
+def _create_dataset(
+    input_dir: Optional[str],
+    storage_dir: str,
+    dataset_type: V1DatasetType,
+    empty: Optional[bool] = None,
+    size: Optional[int] = None,
+    num_bytes: Optional[str] = None,
+    data_format: Optional[Union[str, Tuple[str]]] = None,
+    compression: Optional[str] = None,
+    num_chunks: Optional[int] = None,
+    num_bytes_per_chunk: Optional[List[int]] = None,
+    name: Optional[str] = None,
+    version: Optional[int] = None,
+) -> None:
+    """Create a dataset with metadata information about its source and destination."""
+    project_id = os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None)
+    cluster_id = os.getenv("LIGHTNING_CLUSTER_ID", None)
+    user_id = os.getenv("LIGHTNING_USER_ID", None)
+    cloud_space_id = os.getenv("LIGHTNING_CLOUD_SPACE_ID", None)
+    lightning_app_id = os.getenv("LIGHTNING_CLOUD_APP_ID", None)
+
+    if project_id is None:
+        return
+
+    if not storage_dir:
+        raise ValueError("The storage_dir should be defined.")
+
+    client = LightningClient(retry=False)
+
+    try:
+        client.dataset_service_create_dataset(
+            body=ProjectIdDatasetsBody(
+                cloud_space_id=cloud_space_id if lightning_app_id is None else None,
+                cluster_id=cluster_id,
+                creator_id=user_id,
+                empty=empty,
+                input_dir=input_dir,
+                lightning_app_id=lightning_app_id,
+                name=name,
+                size=size,
+                num_bytes=num_bytes,
+                data_format=str(data_format) if data_format else data_format,
+                compression=compression,
+                num_chunks=num_chunks,
+                num_bytes_per_chunk=num_bytes_per_chunk,
+                storage_dir=storage_dir,
+                type=dataset_type,
+                version=version,
+            ),
+            project_id=project_id,
+        )
+    except ApiException as ex:
+        if "already exists" in str(ex.body):
+            pass
+        else:
+            raise ex
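The `except ApiException` block at the end is the fix this commit is named for: when the platform responds that the dataset already exists, the error is swallowed and processing continues, so re-running a job is no longer fatal, while any other API failure is still re-raised. A self-contained sketch of the same tolerate-duplicates pattern, where `DuplicateError` and `register` are hypothetical stand-ins for `ApiException` and `dataset_service_create_dataset`:

class DuplicateError(Exception):
    # Hypothetical stand-in for lightning_cloud's ApiException, which
    # carries the server response in a `body` attribute.
    def __init__(self, body: str) -> None:
        super().__init__(body)
        self.body = body


def register(name: str, registry: dict) -> None:
    # Hypothetical stand-in for dataset_service_create_dataset.
    if name in registry:
        raise DuplicateError(f"record {name!r} already exists")
    registry[name] = {}


def register_if_absent(name: str, registry: dict) -> None:
    # Same shape as _create_dataset's handler: treat "already exists"
    # as success, re-raise anything else so real failures still surface.
    try:
        register(name, registry)
    except DuplicateError as ex:
        if "already exists" not in str(ex.body):
            raise


store: dict = {}
register_if_absent("my-dataset", store)
register_if_absent("my-dataset", store)  # second call is now a harmless no-op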

def get_worker_rank() -> Optional[str]:
Expand All @@ -29,12 +95,12 @@ def _wrapper(*args: Any, **kwargs: Any) -> Tuple[Any, Optional[Exception]]:
def make_request(
url: str,
timeout: int = 10,
user_agent_token: str = "lit-data",
user_agent_token: str = "pytorch-lightning",
) -> io.BytesIO:
"""Download an image with urllib."""
user_agent_string = "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:72.0) Gecko/20100101 Firefox/72.0"
if user_agent_token:
user_agent_string += f" (compatible; {user_agent_token}; +https://github.com/Lightning-AI/lit-data)"
user_agent_string += f" (compatible; {user_agent_token}; +https://github.com/Lightning-AI/pytorch-lightning)"

with urllib.request.urlopen( # noqa: S310
urllib.request.Request(url, data=None, headers={"User-Agent": user_agent_string}), timeout=timeout
@@ -68,7 +134,7 @@ def optimize_dns(enable: bool) -> None:
         f"sudo /home/zeus/miniconda3/envs/cloudspace/bin/python"
         f" -c 'from lightning_data.processing.utilities import _optimize_dns; _optimize_dns({enable})'"
     )
-    Popen(cmd, shell=True).wait()  # E501
+    Popen(cmd, shell=True, stdout=DEVNULL, stderr=DEVNULL).wait()  # E501


 def _optimize_dns(enable: bool) -> None:
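Redirecting the subprocess's streams to `DEVNULL` keeps the `sudo python -c ...` DNS helper from printing into the processor's console output. A small sketch of the same silencing pattern, using a harmless `echo` in place of the real command:

from subprocess import DEVNULL, Popen

# Run a shell command but discard everything it writes to stdout/stderr,
# the same pattern used in optimize_dns above.
proc = Popen("echo 'this output is discarded'", shell=True, stdout=DEVNULL, stderr=DEVNULL)
assert proc.wait() == 0  # .wait() still returns the exit code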
