Prevent dataset from breaking if it already exists #9

Merged · 29 commits · Feb 21, 2024

Changes from all commits
lightning_data/processing/data_processor.py (29 changes: 19 additions & 10 deletions)

@@ -2,6 +2,7 @@
 import json
 import logging
 import os
+import random
 import shutil
 import signal
 import tempfile
@@ -17,7 +18,7 @@
 from urllib import parse
 
 import numpy as np
-from lightning import seed_everything
+import torch
 from tqdm.auto import tqdm as _tqdm
 
 from lightning_data.constants import (
@@ -29,6 +30,7 @@
     _TORCH_GREATER_EQUAL_2_1_0,
 )
 from lightning_data.processing.readers import BaseReader
+from lightning_data.processing.utilities import _create_dataset
 from lightning_data.streaming import Cache
 from lightning_data.streaming.cache import Dir
 from lightning_data.streaming.client import S3Client
@@ -41,7 +43,6 @@
 
 if _LIGHTNING_CLOUD_LATEST:
     from lightning_cloud.openapi import V1DatasetType
-    from lightning_cloud.utils.dataset import _create_dataset
 
 
 if _BOTO3_AVAILABLE:
@@ -427,7 +428,7 @@ def _loop(self) -> None:
                         uploader.join()
 
                     if self.remove:
-                        assert self.remover  # noqa: S101
+                        assert self.remover
                         self.remove_queue.put(None)
                         self.remover.join()
 
@@ -487,7 +488,7 @@ def _try_upload(self, data: Optional[Union[str, Tuple[str, str]]]) -> None:
         if isinstance(data, str):
             assert os.path.exists(data), data
         else:
-            assert os.path.exists(data[-1]), data  # noqa: S101
+            assert os.path.exists(data[-1]), data
 
         self.to_upload_queues[self._counter % self.num_uploaders].put(data)
 
@@ -724,7 +725,12 @@ def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Resul
 
         size = sum([c["dim"] if c["dim"] is not None else c["chunk_size"] for c in config["chunks"]])
         num_bytes = sum([c["chunk_bytes"] for c in config["chunks"]])
-        data_format = tree_unflatten(config["config"]["data_format"], treespec_loads(config["config"]["data_spec"]))
+        if config["config"] is not None:
+            data_format = tree_unflatten(
+                config["config"]["data_format"], treespec_loads(config["config"]["data_spec"])
+            )
+        else:
+            data_format = None
         num_chunks = len(config["chunks"])
 
         # The platform can't store more than 1024 entries.
@@ -735,7 +741,7 @@
             size=size,
             num_bytes=num_bytes,
             data_format=data_format,
-            compression=config["config"]["compression"],
+            compression=config["config"]["compression"] if config["config"] else None,
             num_chunks=len(config["chunks"]),
             num_bytes_per_chunk=num_bytes_per_chunk,
         )
@@ -772,7 +778,7 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra
         # Get the index file locally
         for node_rank in range(num_nodes - 1):
            output_dir_path = output_dir.url if output_dir.url else output_dir.path
-            assert output_dir_path  # noqa: S101
+            assert output_dir_path
             remote_filepath = os.path.join(output_dir_path, f"{node_rank}-{_INDEX_FILENAME}")
             node_index_filepath = os.path.join(cache_dir, os.path.basename(remote_filepath))
             if obj.scheme == "s3":
@@ -874,7 +880,9 @@ def run(self, data_recipe: DataRecipe) -> None:
         print(f"Setup started with fast_dev_run={self.fast_dev_run}.")
 
         # Force random seed to be fixed
-        seed_everything(self.random_seed)
+        random.seed(self.random_seed)
+        np.random.seed(self.random_seed)
+        torch.manual_seed(self.random_seed)
 
         # Call the setup method of the user
         user_items: List[Any] = data_recipe.prepare_structure(self.input_dir.path if self.input_dir else None)
@@ -941,7 +949,7 @@ def run(self, data_recipe: DataRecipe) -> None:
                 error = self.error_queue.get(timeout=0.001)
                 self._exit_on_error(error)
             except Empty:
-                assert self.progress_queue  # noqa: S101
+                assert self.progress_queue
                 try:
                     index, counter = self.progress_queue.get(timeout=0.001)
                 except Empty:
@@ -973,7 +981,8 @@ def run(self, data_recipe: DataRecipe) -> None:
         print("Workers are finished.")
         result = data_recipe._done(len(user_items), self.delete_cached_files, self.output_dir)
 
-        if num_nodes == node_rank + 1 and self.output_dir.url:
+        if num_nodes == node_rank + 1 and self.output_dir.url and _IS_IN_STUDIO:
+            assert self.output_dir.path
             _create_dataset(
                 input_dir=self.input_dir.path,
                 storage_dir=self.output_dir.path,
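A note on the seeding change in `run` above: dropping `lightning.seed_everything` removes the dependency on the `lightning` package, and the PR instead seeds the three RNGs the pipeline relies on directly. A minimal sketch of the equivalent behavior, assuming only the stdlib, NumPy, and PyTorch (the `_seed_rngs` helper name is illustrative; the PR inlines the three calls):

    import random

    import numpy as np
    import torch


    def _seed_rngs(seed: int) -> None:
        """Fix the global RNGs used by the processing pipeline."""
        random.seed(seed)  # Python stdlib RNG (e.g., item shuffling)
        np.random.seed(seed)  # NumPy global RNG
        torch.manual_seed(seed)  # PyTorch RNG (also seeds all CUDA devices)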
lightning_data/processing/utilities.py (78 changes: 72 additions & 6 deletions)

@@ -2,10 +2,76 @@
 import os
 import urllib
 from contextlib import contextmanager
-from subprocess import Popen  # noqa: S404
-from typing import Any, Callable, Optional, Tuple
+from subprocess import DEVNULL, Popen
+from typing import Any, Callable, List, Optional, Tuple, Union
 
+from lightning_data.constants import _IS_IN_STUDIO, _LIGHTNING_CLOUD_LATEST
+
+if _LIGHTNING_CLOUD_LATEST:
+    from lightning_cloud.openapi import (
+        ProjectIdDatasetsBody,
+        V1DatasetType,
+    )
+    from lightning_cloud.openapi.rest import ApiException
+    from lightning_cloud.rest_client import LightningClient
+
+
+def _create_dataset(
+    input_dir: Optional[str],
+    storage_dir: str,
+    dataset_type: V1DatasetType,
+    empty: Optional[bool] = None,
+    size: Optional[int] = None,
+    num_bytes: Optional[str] = None,
+    data_format: Optional[Union[str, Tuple[str]]] = None,
+    compression: Optional[str] = None,
+    num_chunks: Optional[int] = None,
+    num_bytes_per_chunk: Optional[List[int]] = None,
+    name: Optional[str] = None,
+    version: Optional[int] = None,
+) -> None:
+    """Create a dataset with metadata information about its source and destination."""
+    project_id = os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None)
+    cluster_id = os.getenv("LIGHTNING_CLUSTER_ID", None)
+    user_id = os.getenv("LIGHTNING_USER_ID", None)
+    cloud_space_id = os.getenv("LIGHTNING_CLOUD_SPACE_ID", None)
+    lightning_app_id = os.getenv("LIGHTNING_CLOUD_APP_ID", None)
+
+    if project_id is None:
+        return
+
+    if not storage_dir:
+        raise ValueError("The storage_dir should be defined.")
+
+    client = LightningClient(retry=False)
+
+    from lightning_data.constants import _IS_IN_STUDIO
+    try:
+        client.dataset_service_create_dataset(
+            body=ProjectIdDatasetsBody(
+                cloud_space_id=cloud_space_id if lightning_app_id is None else None,
+                cluster_id=cluster_id,
+                creator_id=user_id,
+                empty=empty,
+                input_dir=input_dir,
+                lightning_app_id=lightning_app_id,
+                name=name,
+                size=size,
+                num_bytes=num_bytes,
+                data_format=str(data_format) if data_format else data_format,
+                compression=compression,
+                num_chunks=num_chunks,
+                num_bytes_per_chunk=num_bytes_per_chunk,
+                storage_dir=storage_dir,
+                type=dataset_type,
+                version=version,
+            ),
+            project_id=project_id,
+        )
+    except ApiException as ex:
+        if "already exists" in str(ex.body):
+            pass
+        else:
+            raise ex
+
+
 def get_worker_rank() -> Optional[str]:
@@ -29,12 +95,12 @@ def _wrapper(*args: Any, **kwargs: Any) -> Tuple[Any, Optional[Exception]]:
 def make_request(
     url: str,
     timeout: int = 10,
-    user_agent_token: str = "lit-data",
+    user_agent_token: str = "pytorch-lightning",
 ) -> io.BytesIO:
     """Download an image with urllib."""
     user_agent_string = "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:72.0) Gecko/20100101 Firefox/72.0"
     if user_agent_token:
-        user_agent_string += f" (compatible; {user_agent_token}; +https://github.com/Lightning-AI/lit-data)"
+        user_agent_string += f" (compatible; {user_agent_token}; +https://github.com/Lightning-AI/pytorch-lightning)"
 
     with urllib.request.urlopen(  # noqa: S310
         urllib.request.Request(url, data=None, headers={"User-Agent": user_agent_string}), timeout=timeout
@@ -68,7 +134,7 @@ def optimize_dns(enable: bool) -> None:
             f"sudo /home/zeus/miniconda3/envs/cloudspace/bin/python"
             f" -c 'from lightning_data.processing.utilities import _optimize_dns; _optimize_dns({enable})'"
         )
-        Popen(cmd, shell=True).wait()  # E501
+        Popen(cmd, shell=True, stdout=DEVNULL, stderr=DEVNULL).wait()  # E501
 
 
 def _optimize_dns(enable: bool) -> None:
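
For context on the fix itself: the new `_create_dataset` helper is what makes dataset registration idempotent. It returns silently when `LIGHTNING_CLOUD_PROJECT_ID` is unset, validates `storage_dir`, and swallows only the platform's "already exists" `ApiException`, so re-processing an existing dataset no longer breaks the run. A hypothetical direct call with illustrative values (in this PR it is only invoked from `run` in data_processor.py):

    from lightning_cloud.openapi import V1DatasetType

    from lightning_data.processing.utilities import _create_dataset

    # Registering the same dataset twice is a no-op: the "already exists"
    # ApiException raised by the platform is caught and ignored.
    _create_dataset(
        input_dir="/teamspace/datasets/my-source",  # illustrative path
        storage_dir="my-dataset",  # required; an empty value raises ValueError
        dataset_type=V1DatasetType.CHUNKED,
        empty=False,
        size=1_000,
        num_bytes="104857600",
        data_format="jpeg",  # stringified before being sent
        num_chunks=10,
        num_bytes_per_chunk=[10_485_760] * 10,
    )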