From 981f4c542d06539ee4e356698c4cfbe535552e1c Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Wed, 21 Feb 2024 10:20:34 +0000 Subject: [PATCH] Prevent dataset to break if already exists (#9) --- lightning_data/processing/data_processor.py | 29 +++++--- lightning_data/processing/utilities.py | 78 +++++++++++++++++++-- 2 files changed, 91 insertions(+), 16 deletions(-) diff --git a/lightning_data/processing/data_processor.py b/lightning_data/processing/data_processor.py index fdaf83a9..a9347902 100644 --- a/lightning_data/processing/data_processor.py +++ b/lightning_data/processing/data_processor.py @@ -2,6 +2,7 @@ import json import logging import os +import random import shutil import signal import tempfile @@ -17,7 +18,7 @@ from urllib import parse import numpy as np -from lightning import seed_everything +import torch from tqdm.auto import tqdm as _tqdm from lightning_data.constants import ( @@ -29,6 +30,7 @@ _TORCH_GREATER_EQUAL_2_1_0, ) from lightning_data.processing.readers import BaseReader +from lightning_data.processing.utilities import _create_dataset from lightning_data.streaming import Cache from lightning_data.streaming.cache import Dir from lightning_data.streaming.client import S3Client @@ -41,7 +43,6 @@ if _LIGHTNING_CLOUD_LATEST: from lightning_cloud.openapi import V1DatasetType - from lightning_cloud.utils.dataset import _create_dataset if _BOTO3_AVAILABLE: @@ -427,7 +428,7 @@ def _loop(self) -> None: uploader.join() if self.remove: - assert self.remover # noqa: S101 + assert self.remover self.remove_queue.put(None) self.remover.join() @@ -487,7 +488,7 @@ def _try_upload(self, data: Optional[Union[str, Tuple[str, str]]]) -> None: if isinstance(data, str): assert os.path.exists(data), data else: - assert os.path.exists(data[-1]), data # noqa: S101 + assert os.path.exists(data[-1]), data self.to_upload_queues[self._counter % self.num_uploaders].put(data) @@ -724,7 +725,12 @@ def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Resul size = sum([c["dim"] if c["dim"] is not None else c["chunk_size"] for c in config["chunks"]]) num_bytes = sum([c["chunk_bytes"] for c in config["chunks"]]) - data_format = tree_unflatten(config["config"]["data_format"], treespec_loads(config["config"]["data_spec"])) + if config["config"] is not None: + data_format = tree_unflatten( + config["config"]["data_format"], treespec_loads(config["config"]["data_spec"]) + ) + else: + data_format = None num_chunks = len(config["chunks"]) # The platform can't store more than 1024 entries. @@ -735,7 +741,7 @@ def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Resul size=size, num_bytes=num_bytes, data_format=data_format, - compression=config["config"]["compression"], + compression=config["config"]["compression"] if config["config"] else None, num_chunks=len(config["chunks"]), num_bytes_per_chunk=num_bytes_per_chunk, ) @@ -772,7 +778,7 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra # Get the index file locally for node_rank in range(num_nodes - 1): output_dir_path = output_dir.url if output_dir.url else output_dir.path - assert output_dir_path # noqa: S101 + assert output_dir_path remote_filepath = os.path.join(output_dir_path, f"{node_rank}-{_INDEX_FILENAME}") node_index_filepath = os.path.join(cache_dir, os.path.basename(remote_filepath)) if obj.scheme == "s3": @@ -874,7 +880,9 @@ def run(self, data_recipe: DataRecipe) -> None: print(f"Setup started with fast_dev_run={self.fast_dev_run}.") # Force random seed to be fixed - seed_everything(self.random_seed) + random.seed(self.random_seed) + np.random.seed(self.random_seed) + torch.manual_seed(self.random_seed) # Call the setup method of the user user_items: List[Any] = data_recipe.prepare_structure(self.input_dir.path if self.input_dir else None) @@ -941,7 +949,7 @@ def run(self, data_recipe: DataRecipe) -> None: error = self.error_queue.get(timeout=0.001) self._exit_on_error(error) except Empty: - assert self.progress_queue # noqa: S101 + assert self.progress_queue try: index, counter = self.progress_queue.get(timeout=0.001) except Empty: @@ -973,7 +981,8 @@ def run(self, data_recipe: DataRecipe) -> None: print("Workers are finished.") result = data_recipe._done(len(user_items), self.delete_cached_files, self.output_dir) - if num_nodes == node_rank + 1 and self.output_dir.url: + if num_nodes == node_rank + 1 and self.output_dir.url and _IS_IN_STUDIO: + assert self.output_dir.path _create_dataset( input_dir=self.input_dir.path, storage_dir=self.output_dir.path, diff --git a/lightning_data/processing/utilities.py b/lightning_data/processing/utilities.py index 9e160839..a049fd81 100644 --- a/lightning_data/processing/utilities.py +++ b/lightning_data/processing/utilities.py @@ -2,10 +2,76 @@ import os import urllib from contextlib import contextmanager -from subprocess import Popen # noqa: S404 -from typing import Any, Callable, Optional, Tuple +from subprocess import DEVNULL, Popen +from typing import Any, Callable, List, Optional, Tuple, Union + +from lightning_data.constants import _IS_IN_STUDIO, _LIGHTNING_CLOUD_LATEST + +if _LIGHTNING_CLOUD_LATEST: + from lightning_cloud.openapi import ( + ProjectIdDatasetsBody, + V1DatasetType, + ) + from lightning_cloud.openapi.rest import ApiException + from lightning_cloud.rest_client import LightningClient + + +def _create_dataset( + input_dir: Optional[str], + storage_dir: str, + dataset_type: V1DatasetType, + empty: Optional[bool] = None, + size: Optional[int] = None, + num_bytes: Optional[str] = None, + data_format: Optional[Union[str, Tuple[str]]] = None, + compression: Optional[str] = None, + num_chunks: Optional[int] = None, + num_bytes_per_chunk: Optional[List[int]] = None, + name: Optional[str] = None, + version: Optional[int] = None, +) -> None: + """Create a dataset with metadata information about its source and destination.""" + project_id = os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None) + cluster_id = os.getenv("LIGHTNING_CLUSTER_ID", None) + user_id = os.getenv("LIGHTNING_USER_ID", None) + cloud_space_id = os.getenv("LIGHTNING_CLOUD_SPACE_ID", None) + lightning_app_id = os.getenv("LIGHTNING_CLOUD_APP_ID", None) + + if project_id is None: + return + + if not storage_dir: + raise ValueError("The storage_dir should be defined.") + + client = LightningClient(retry=False) -from lightning_data.constants import _IS_IN_STUDIO + try: + client.dataset_service_create_dataset( + body=ProjectIdDatasetsBody( + cloud_space_id=cloud_space_id if lightning_app_id is None else None, + cluster_id=cluster_id, + creator_id=user_id, + empty=empty, + input_dir=input_dir, + lightning_app_id=lightning_app_id, + name=name, + size=size, + num_bytes=num_bytes, + data_format=str(data_format) if data_format else data_format, + compression=compression, + num_chunks=num_chunks, + num_bytes_per_chunk=num_bytes_per_chunk, + storage_dir=storage_dir, + type=dataset_type, + version=version, + ), + project_id=project_id, + ) + except ApiException as ex: + if "already exists" in str(ex.body): + pass + else: + raise ex def get_worker_rank() -> Optional[str]: @@ -29,12 +95,12 @@ def _wrapper(*args: Any, **kwargs: Any) -> Tuple[Any, Optional[Exception]]: def make_request( url: str, timeout: int = 10, - user_agent_token: str = "lit-data", + user_agent_token: str = "pytorch-lightning", ) -> io.BytesIO: """Download an image with urllib.""" user_agent_string = "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:72.0) Gecko/20100101 Firefox/72.0" if user_agent_token: - user_agent_string += f" (compatible; {user_agent_token}; +https://github.com/Lightning-AI/lit-data)" + user_agent_string += f" (compatible; {user_agent_token}; +https://github.com/Lightning-AI/pytorch-lightning)" with urllib.request.urlopen( # noqa: S310 urllib.request.Request(url, data=None, headers={"User-Agent": user_agent_string}), timeout=timeout @@ -68,7 +134,7 @@ def optimize_dns(enable: bool) -> None: f"sudo /home/zeus/miniconda3/envs/cloudspace/bin/python" f" -c 'from lightning_data.processing.utilities import _optimize_dns; _optimize_dns({enable})'" ) - Popen(cmd, shell=True).wait() # E501 + Popen(cmd, shell=True, stdout=DEVNULL, stderr=DEVNULL).wait() # E501 def _optimize_dns(enable: bool) -> None: