From 8b3bb8f9c08462ac30983b3cee0c9ae1dd72f28d Mon Sep 17 00:00:00 2001 From: Adam Letts Date: Tue, 21 Mar 2023 10:48:02 -0400 Subject: [PATCH 1/8] Initial creation of abstract StorageBackend, AssetServiceBackend, and LocalFileBackend. Some asset storage operations implemented. --- src/stability_sdk/api.py | 353 ++++++++++++++++++++++++++++++++------- 1 file changed, 293 insertions(+), 60 deletions(-) diff --git a/src/stability_sdk/api.py b/src/stability_sdk/api.py index 18813a87..18c56b00 100644 --- a/src/stability_sdk/api.py +++ b/src/stability_sdk/api.py @@ -1,20 +1,24 @@ import grpc import json import logging +import os import random +import shutil import time +import uuid import warnings from google.protobuf.struct_pb2 import Struct from PIL import Image from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from abc import ABC, abstractmethod try: import cv2 import numpy as np except ImportError: warnings.warn( - "Failed to import animation reqs. To use the animation toolchain, install the requisite dependencies via:" + "Failed to import animation reqs. To use the animation toolchain, install the requisite dependencies via:" " pip install --upgrade stability_sdk[anim]" ) @@ -27,6 +31,7 @@ from .utils import ( image_mix, + image_to_png_bytes, image_to_prompt, tensor_to_prompt, ) @@ -40,7 +45,7 @@ def open_channel(host: str, api_key: str = None, max_message_len: int = 10*1024* options=[ ('grpc.max_send_message_length', max_message_len), ('grpc.max_receive_message_length', max_message_len), - ] + ] if host.endswith(":443"): call_credentials = [grpc.access_token_call_credentials(api_key)] channel_credentials = grpc.composite_channel_credentials( @@ -68,45 +73,35 @@ def __init__(self, stub, engine_id): self.stub = stub self.engine_id = engine_id -class Project(): - def __init__(self, context: 'Context', project: project.Project): + +class StorageBackend(ABC): + def __init__(self, project_id: str, project_file_id: str, context: 'Context', primary: bool = False, primary_fs: bool = False): + self._project_id = project_id + self._project_file_id = project_file_id self._context = context - self._project = project + self.primary = primary + self.primary_fs = primary_fs - @property - def id(self) -> str: - return self._project.id + @abstractmethod + def load_settings(self) -> dict: + pass - @property - def file_id(self) -> str: - return self._project.file.id + @abstractmethod + def save_settings(self, data: dict) -> str: + pass - @property - def title(self) -> str: - return self._project.title + @abstractmethod + def put_image_asset(self, image: Union[Image.Image, np.ndarray], use: generation.AssetUse, name: str = None) -> str: + pass - @staticmethod - def create( - context: 'Context', - title: str, - access: project.ProjectAccess=project.PROJECT_ACCESS_PRIVATE, - status: project.ProjectStatus=project.PROJECT_STATUS_ACTIVE - ) -> 'Project': - req = project.CreateProjectRequest(title=title, access=access, status=status) - proj: project.Project = context._proj_stub.Create(req, wait_for_ready=True) - return Project(context, proj) + @abstractmethod + def put_video_asset(self, video_path: str, asset_id: str) -> str: + pass - def delete(self): - self._context._proj_stub.Delete(project.DeleteProjectRequest(id=self.id)) - @staticmethod - def list_projects(context: 'Context') -> List['Project']: - list_req = project.ListProjectRequest(owner_id="") - results = [] - for proj in context._proj_stub.List(list_req, wait_for_ready=True): - results.append(Project(context, proj)) - results.sort(key=lambda x: x.title.lower()) - return results +class AssetServiceBackend(StorageBackend): + def __init__(self, project_id: str, project_file_id: str, context: 'Context', primary: bool = False): + super().__init__(project_id, project_file_id, context, primary) def load_settings(self) -> dict: request = generation.Request( @@ -115,19 +110,20 @@ def load_settings(self) -> dict: artifact=generation.Artifact( type=generation.ARTIFACT_TEXT, mime="application/json", - uuid=self.file_id, + uuid=self._project_file_id, ) )], asset=generation.AssetParameters( - action=generation.ASSET_GET, - project_id=self.id, + action=generation.ASSET_GET, + project_id=self._project_id, use=generation.ASSET_USE_PROJECT ) ) results = self._context._run_request(self._context._asset, request) if generation.ARTIFACT_TEXT in results: return json.loads(results[generation.ARTIFACT_TEXT][0]) - raise Exception(f"Failed to load project file for {self.id}") + raise Exception(f"Failed to load project file for {self._project_id}") + def save_settings(self, data: dict) -> str: contents = json.dumps(data) @@ -138,48 +134,261 @@ def save_settings(self, data: dict) -> str: type=generation.ARTIFACT_TEXT, text=contents, mime="application/json", - uuid=self.file_id + uuid=self._project_file_id ) )], asset=generation.AssetParameters( - action=generation.ASSET_PUT, - project_id=self.id, + action=generation.ASSET_PUT, + project_id=self._project_id, use=generation.ASSET_USE_PROJECT ) ) results = self._context._run_request(self._context._asset, request) if generation.ARTIFACT_TEXT in results: return results[generation.ARTIFACT_TEXT][0] - raise Exception(f"Failed to save project file for {self.id}") + raise Exception(f"Failed to save project file for {self._project_id}") - def put_image_asset( - self, - image: Union[Image.Image, np.ndarray], - use: generation.AssetUse=generation.ASSET_USE_OUTPUT - ): + def put_image_asset(self, image: Union[Image.Image, np.ndarray], use: generation.AssetUse, asset_id: str = None) -> str: request = generation.Request( engine_id=self._context._asset.engine_id, prompt=[image_to_prompt(image)], asset=generation.AssetParameters( - action=generation.ASSET_PUT, - project_id=self.id, + action=generation.ASSET_PUT, + project_id=self._project_id, use=use ) ) results = self._context._run_request(self._context._asset, request) if generation.ARTIFACT_TEXT in results: return results[generation.ARTIFACT_TEXT][0] - raise Exception(f"Failed to store image asset for project {self.id}") + raise Exception(f"Failed to store image asset for project {self._project_id}") + + def get_image_asset(self, name: str, use: generation.AssetUse) -> str: + request = generation.Request( + engine_id=self._context._asset.engine_id, + prompt=[generation.Prompt( + artifact=generation.Artifact( + type=generation.ARTIFACT_TEXT, + mime="image/png", + uuid=name, + ) + )], + asset=generation.AssetParameters( + action=generation.ASSET_GET, + project_id=self._project_id, + use=generation.ASSET_USE_PROJECT + ) + ) + results = self._context._run_request(self._context._asset, request) + if generation.ARTIFACT_TEXT in results: + return results[generation.ARTIFACT_TEXT][0] + raise Exception(f"Failed to store image asset for project {self._project_id}") + + def put_video_asset(self, video_path: str, asset_id: str) -> str: + if not os.path.isfile(video_path) or not video_path.endswith(".mp4"): + raise ValueError("Invalid video file path. Must be an existing .mp4 file.") + + with open(video_path, "rb") as f: + binary_data = f.read() + + request = generation.Request( + engine_id=self._context._asset.engine_id, + prompt=[ + generation.Prompt( + artifact=generation.Artifact( + type=generation.ARTIFACT_VIDEO, + mime="video/mp4", + binary=binary_data, + ) + ) + ], + asset=generation.AssetParameters( + action=generation.ASSET_PUT, + project_id=self._project_id, + use=generation.ASSET_USE_INPUT, + ), + ) + results = self._context._run_request(self._context._asset, request) + if generation.ARTIFACT_TEXT in results: + return results[generation.ARTIFACT_TEXT][0] + raise Exception(f"Failed to store video asset for project {self._project_id}") + + +class LocalFileBackend(StorageBackend): + def __init__(self, project_id: str, project_file_id: str, context: 'Context', primary: bool = False, primary_fs: bool = True, projects_root = 'projects'): + super().__init__(project_id, project_file_id, context, primary, primary_fs = primary_fs) + self._projects_root = projects_root + + def load_settings(self) -> dict: + # TODO(ADAM): Implement + pass - def update(self, title:str=None, file_id:str=None, file_uri:str=None): + def save_settings(self, data: dict) -> str: + # TODO(ADAM): Implement + pass + + def put_image_asset(self, image: Union[Image.Image, np.ndarray], + use: generation.AssetUse, + asset_id: str = None) -> str: + png = image_to_png_bytes(image) + if asset_id is not None: + filename = asset_id + else: + if not self.primary: + raise ValueError("If name is None, then LocalFileBackend must be primary.") + filename = str(uuid.uuid4()) + output_path = self.get_path_for_asset(filename) + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path + '.png', "wb") as file: + file.write(png) + return filename + + def put_video_asset(self, video_path: str, asset_id: str = None) -> str: + if not os.path.isfile(video_path) or not video_path.endswith(".mp4"): + raise ValueError("Invalid video file path. Must be an existing .mp4 file.") + + if asset_id is not None: + filename = asset_id + else: + if not self.primary: + raise ValueError("If name is None, then LocalFileBackend must be primary.") + filename = str(uuid.uuid4()) + output_path = self.get_path_for_asset(filename) + os.makedirs(os.path.dirname(output_path), exist_ok=True) + shutil.copy(video_path, output_path) + return filename + + def get_path_for_asset(self, filename: str): + path = os.path.join(self._projects_root, self._project_id, filename) + return path + +class Project(): + def __init__(self, context: 'Context', project: project.Project): + ## __init__ could take backends: Optional[List[StorageBackend]] = None + # self._backends = backends if backends else [AssetServiceBackend(primary=True)] + self._backends = [AssetServiceBackend(project_id=project.id, project_file_id = project.file_id, context=context, primary=True), + LocalFileBackend(project_id=project.id, project_file_id = project.file_id, context=context, primary=False)] + self._context = context + self._project = project + self._metadata_index = self.load_metadata_index() + + def _primary_backend(self) -> Optional[StorageBackend]: + for backend in self._backends: + if backend.primary: + return backend + return None + + @property + def id(self) -> str: + return self._project.id + + @property + def file_id(self) -> str: + return self._project.file.id + + @property + def title(self) -> str: + return self._project.title + + @staticmethod + def create( + context: 'Context', + title: str, + access: project.ProjectAccess = project.PROJECT_ACCESS_PRIVATE, + status: project.ProjectStatus = project.PROJECT_STATUS_ACTIVE + ) -> 'Project': + req = project.CreateProjectRequest(title=title, access=access, status=status) + proj: project.Project = context._proj_stub.Create(req, wait_for_ready=True) + return Project(context, proj) + + @staticmethod + def get( + context: 'Context', + id: str + ) -> 'Project': + req = project.GetProjectRequest(id=id) + proj: project.Project = context._proj_stub.Get(req, wait_for_ready=True) + return Project(context, proj) + + def list_assets(self): + req = project.QueryAssetsRequest(id=self.id) + query_assets_response: project.QueryAssetsResponse = self._context._proj_stub.QueryAssets(req, + wait_for_ready=True) + return query_assets_response.assets + + def delete(self): + self._context._proj_stub.Delete(project.DeleteProjectRequest(id=self.id)) + + @staticmethod + def list_projects(context: 'Context') -> List['Project']: + list_req = project.ListProjectRequest(owner_id="") + results = [] + for proj in context._proj_stub.List(list_req, wait_for_ready=True): + results.append(Project(context, proj)) + results.sort(key=lambda x: x.title.lower()) + return results + + def load_settings(self) -> dict: + for backend in self._backends: + if backend.primary: + result = backend.load_settings() + return result + raise Exception(f"Failed to load project file for {self.id}") + + def save_settings(self, data: dict) -> str: + results = None + for backend in self._backends: + temp = backend.save_settings(data) + if backend.primary: + results = temp + return results + + def put_image_asset( + self, + image: Union[Image.Image, np.ndarray], + use: generation.AssetUse = generation.ASSET_USE_PROJECT + ): + results = [] + asset_id = None + filename = None + for backend in self._backends: + result = backend.put_image_asset(image, use, asset_id=asset_id) + if backend.primary: + rsplit_res = result.rsplit('/', 1) + asset_id = rsplit_res[1] if len(rsplit_res) > 1 else rsplit_res[0] + results.append(asset_id) + if backend.primary_fs: + filename = result + mimetype = "image/png" + self.add_asset_metadata(asset_id, mimetype, filename) + return results + + def put_video_asset(self, video_path: str) -> List[str]: + results = [] + filename = None + asset_id = None + for backend in self._backends: + result = backend.put_video_asset(video_path, asset_id=asset_id) + if backend.primary: + rsplit_res = result.rsplit('/', 1) + asset_id = rsplit_res[1] if len(rsplit_res) > 1 else rsplit_res[0] + results.append(asset_id) + if backend.primary_fs: + filename = result + # E.g.: {3: ['https://object.lga1.coreweave.com/stability-assets-dev/org-yP0GBrIgOnDA6wwfyohorEPw/178c0ff3-5e01-4e4e-9f49-278510d80289/b8912c3b-eb98-4c8e-b346-fe483ba17f83']} + mimetype = "video/mp4" + self.add_asset_metadata(asset_id, mimetype, filename) + return results + + def update(self, title: str = None, file_id: str = None, file_uri: str = None): file = project.ProjectAsset( id=file_id, uri=file_uri, use=project.PROJECT_ASSET_USE_PROJECT, ) if file_id and file_uri else None - + self._context._proj_stub.Update(project.UpdateProjectRequest( - id=self.id, + id=self.id, title=title, file=file )) @@ -191,9 +400,33 @@ def update(self, title:str=None, file_id:str=None, file_uri:str=None): if file_uri: self._project.file.uri = file_uri + def add_asset_metadata(self, asset_id: str, mime_type: str, filename: str) -> None: + # metadata_index = self.load_metadata_index() # I assume metadata is updated by each operation + self._metadata_index[asset_id] = { + "mime_type": mime_type + } + if filename is not None: + self._metadata_index[asset_id]["file_name"] = filename + self.save_metadata_index() + + def save_metadata_index(self, metadata_index: dict = None) -> None: + if metadata_index is None: + metadata_index = self._metadata_index + index_file = f"{self.id}_metadata_index.json" + with open(index_file, "w") as f: + json.dump(metadata_index, f) + + def load_metadata_index(self) -> dict: + index_file = f"{self.id}_metadata_index.json" + if os.path.exists(index_file): + with open(index_file, "r") as f: + metadata_index = json.load(f) + return metadata_index + return {} + class Context: - def __init__(self, host: str="", api_key: str=None, stub: generation_grpc.GenerationServiceStub=None): + def __init__(self, host: str = "", api_key: str = None, stub: generation_grpc.GenerationServiceStub = None): if not host and stub is None: raise Exception("Must provide either GRPC host or stub to Api") channel = open_channel(host, api_key) if host else None @@ -332,7 +565,7 @@ def inpaint( ) -> Dict[int, List[Union[np.ndarray, Any]]]: """ Apply inpainting to an image. - + :param image: Source image :param mask: Mask image with 0 for pixels to change and 255 for pixels to keep :param prompts: List of text prompts @@ -340,7 +573,7 @@ def inpaint( :param steps: Number of steps to run :param seed: Random seed :param samples: Number of samples to generate - :param cfg_scale: Classifier free guidance scale + :param cfg_scale: Classifier free guidance scale :param sampler: Sampler to use for the diffusion process :param init_strength: Strength of the initial image :param init_noise_scale: Scale of the initial noise @@ -396,7 +629,7 @@ def interpolate( elif ratios[0] == 1.0: return [images[1]] elif mode == generation.INTERPOLATE_LINEAR: - return [image_mix(images[0], images[1], ratios[0])] + return [image_mix(images[0], images[1], ratios[0])] p = [image_to_prompt(image) for image in images] request = generation.Request( @@ -584,7 +817,7 @@ def transform_3d( id=op_id, request=rq_transform, on_status=[generation.OnStatus(action=[generation.STAGE_ACTION_RETURN])] - ) + ) ]) results = self._run_request(self._transform, chain_rq) @@ -688,13 +921,13 @@ def _run_request( except ClassifierException as ce: if attempt == self._max_retries or not self._retry_obfuscation: raise ce - + for exceed in ce.classifier_result.exceeds: logger.warning(f"Received classifier obfuscation. Exceeded {exceed.name} threshold") for concept in exceed.concepts: if concept.HasField("threshold"): logger.warning(f" {concept.concept} ({concept.threshold})") - + if isinstance(request, generation.Request) and request.HasField("image"): self._adjust_request_for_retry(request, attempt) elif isinstance(request, generation.ChainRequest): From b208c3b9e88f6a674211c8578f03729697435596 Mon Sep 17 00:00:00 2001 From: Adam Letts Date: Fri, 24 Mar 2023 15:03:23 -0400 Subject: [PATCH 2/8] Refactor StorageBackend backends to not be specific to a project. Implement some create and load operations. --- src/stability_sdk/api.py | 365 +++++++++++++++++++++++++++++---------- 1 file changed, 269 insertions(+), 96 deletions(-) diff --git a/src/stability_sdk/api.py b/src/stability_sdk/api.py index 18c56b00..5347835e 100644 --- a/src/stability_sdk/api.py +++ b/src/stability_sdk/api.py @@ -75,57 +75,111 @@ def __init__(self, stub, engine_id): class StorageBackend(ABC): - def __init__(self, project_id: str, project_file_id: str, context: 'Context', primary: bool = False, primary_fs: bool = False): - self._project_id = project_id - self._project_file_id = project_file_id + def __init__(self, context: 'Context', primary: bool = False, primary_fs: bool = False): self._context = context self.primary = primary self.primary_fs = primary_fs + @staticmethod @abstractmethod - def load_settings(self) -> dict: + def create( + context: 'Context', + title: str, + access: project.ProjectAccess = project.PROJECT_ACCESS_PRIVATE, + status: project.ProjectStatus = project.PROJECT_STATUS_ACTIVE + ) -> 'Project': + pass + + @staticmethod + @abstractmethod + def load_project( + context: 'Context', + id: str + ) -> 'Project': pass + @staticmethod @abstractmethod - def save_settings(self, data: dict) -> str: + def list_projects(context: 'Context') -> List['Project']: pass @abstractmethod - def put_image_asset(self, image: Union[Image.Image, np.ndarray], use: generation.AssetUse, name: str = None) -> str: + def load_settings(self, proj: 'Project') -> dict: pass @abstractmethod - def put_video_asset(self, video_path: str, asset_id: str) -> str: + def save_settings(self, proj: 'Project', data: dict, asset_id: str = None) -> str: + pass + + def get_image_asset(self, proj: 'Project', asset_id: str, use: generation.AssetUse) -> Image.Image: + pass + + @abstractmethod + def put_image_asset(self, proj: 'Project', image: Union[Image.Image, np.ndarray], use: generation.AssetUse, asset_id: str = None) -> str: + pass + + def get_video_asset(self, proj: 'Project', asset_id: str, use: generation.AssetUse) -> bytes: + pass + + @abstractmethod + def put_video_asset(self, proj: 'Project', video_path: str, asset_id: str) -> str: pass class AssetServiceBackend(StorageBackend): - def __init__(self, project_id: str, project_file_id: str, context: 'Context', primary: bool = False): - super().__init__(project_id, project_file_id, context, primary) + def __init__(self, context: 'Context', primary: bool = False): + super().__init__(context, primary) - def load_settings(self) -> dict: + @staticmethod + def create( + context: 'Context', + title: str, + access: project.ProjectAccess = project.PROJECT_ACCESS_PRIVATE, + status: project.ProjectStatus = project.PROJECT_STATUS_ACTIVE + ) -> 'Project': + req = project.CreateProjectRequest(title=title, access=access, status=status) + proj: project.Project = context._proj_stub.Create(req, wait_for_ready=True) + return Project(context, proj) + + @staticmethod + def load_project(context: 'Context', id: str) -> 'Project': + req = project.GetProjectRequest(id=id) + proj: project.Project = context._proj_stub.Get(req, wait_for_ready=True) + return Project(context, proj) + + @staticmethod + def list_projects(context: 'Context') -> List['Project']: + list_req = project.ListProjectRequest(owner_id="") + results = [] + for proj in context._proj_stub.List(list_req, wait_for_ready=True): + results.append(Project(context, proj)) + results.sort(key=lambda x: x.title.lower()) + return results + + def load_settings(self, proj: 'Project') -> dict: request = generation.Request( engine_id=self._context._asset.engine_id, prompt=[generation.Prompt( artifact=generation.Artifact( type=generation.ARTIFACT_TEXT, mime="application/json", - uuid=self._project_file_id, + uuid=proj.file_id, ) )], asset=generation.AssetParameters( action=generation.ASSET_GET, - project_id=self._project_id, + project_id=proj.id, use=generation.ASSET_USE_PROJECT ) ) results = self._context._run_request(self._context._asset, request) if generation.ARTIFACT_TEXT in results: - return json.loads(results[generation.ARTIFACT_TEXT][0]) - raise Exception(f"Failed to load project file for {self._project_id}") + settings_json = json.loads(results[generation.ARTIFACT_TEXT][0]) + return settings_json + raise Exception(f"Failed to load project file for {proj.id}") - def save_settings(self, data: dict) -> str: + def save_settings(self, proj: 'Project', data: dict, asset_id: str = None) -> str: contents = json.dumps(data) request = generation.Request( engine_id=self._context._asset.engine_id, @@ -134,57 +188,56 @@ def save_settings(self, data: dict) -> str: type=generation.ARTIFACT_TEXT, text=contents, mime="application/json", - uuid=self._project_file_id + uuid=proj.file_id ) )], asset=generation.AssetParameters( action=generation.ASSET_PUT, - project_id=self._project_id, + project_id=proj.id, use=generation.ASSET_USE_PROJECT ) ) results = self._context._run_request(self._context._asset, request) if generation.ARTIFACT_TEXT in results: return results[generation.ARTIFACT_TEXT][0] - raise Exception(f"Failed to save project file for {self._project_id}") + raise Exception(f"Failed to save project file for {proj.id}") - def put_image_asset(self, image: Union[Image.Image, np.ndarray], use: generation.AssetUse, asset_id: str = None) -> str: + def get_image_asset(self, proj: 'Project', asset_id: str, use: generation.AssetUse) -> str: request = generation.Request( engine_id=self._context._asset.engine_id, - prompt=[image_to_prompt(image)], + prompt=[generation.Prompt( + artifact=generation.Artifact(generation.ARTIFACT_IMAGE, mime="image/png", uuid=asset_id) + )], asset=generation.AssetParameters( - action=generation.ASSET_PUT, - project_id=self._project_id, + action=generation.ASSET_GET, + project_id=proj.id, use=use ) ) results = self._context._run_request(self._context._asset, request) - if generation.ARTIFACT_TEXT in results: - return results[generation.ARTIFACT_TEXT][0] - raise Exception(f"Failed to store image asset for project {self._project_id}") + if generation.ARTIFACT_IMAGE in results: + return results[generation.ARTIFACT_IMAGE][0] + raise Exception(f"Failed to load image asset for project {proj.id}") - def get_image_asset(self, name: str, use: generation.AssetUse) -> str: + def put_image_asset(self, proj: 'Project', image: Union[Image.Image, np.ndarray], use: generation.AssetUse, asset_id: str = None) -> str: request = generation.Request( engine_id=self._context._asset.engine_id, - prompt=[generation.Prompt( - artifact=generation.Artifact( - type=generation.ARTIFACT_TEXT, - mime="image/png", - uuid=name, - ) - )], + prompt=[image_to_prompt(image)], asset=generation.AssetParameters( - action=generation.ASSET_GET, - project_id=self._project_id, - use=generation.ASSET_USE_PROJECT + action=generation.ASSET_PUT, + project_id=proj.id, + use=use ) ) results = self._context._run_request(self._context._asset, request) if generation.ARTIFACT_TEXT in results: return results[generation.ARTIFACT_TEXT][0] - raise Exception(f"Failed to store image asset for project {self._project_id}") + raise Exception(f"Failed to store image asset for project {proj.id}") - def put_video_asset(self, video_path: str, asset_id: str) -> str: + def get_video_asset(self, proj: 'Project', asset_id: str, use: generation.AssetUse) -> str: + pass + + def put_video_asset(self, proj: 'Project', video_path: str, asset_id: str) -> str: if not os.path.isfile(video_path) or not video_path.endswith(".mp4"): raise ValueError("Invalid video file path. Must be an existing .mp4 file.") @@ -204,30 +257,96 @@ def put_video_asset(self, video_path: str, asset_id: str) -> str: ], asset=generation.AssetParameters( action=generation.ASSET_PUT, - project_id=self._project_id, + project_id=proj.id, use=generation.ASSET_USE_INPUT, ), ) results = self._context._run_request(self._context._asset, request) if generation.ARTIFACT_TEXT in results: return results[generation.ARTIFACT_TEXT][0] - raise Exception(f"Failed to store video asset for project {self._project_id}") + raise Exception(f"Failed to store video asset for project {proj.id}") class LocalFileBackend(StorageBackend): - def __init__(self, project_id: str, project_file_id: str, context: 'Context', primary: bool = False, primary_fs: bool = True, projects_root = 'projects'): - super().__init__(project_id, project_file_id, context, primary, primary_fs = primary_fs) - self._projects_root = projects_root + _projects_root = None - def load_settings(self) -> dict: - # TODO(ADAM): Implement - pass + def __init__(self, context: 'Context', primary: bool = False, primary_fs: bool = True, projects_root = 'projects'): + super().__init__(context, primary, primary_fs = primary_fs) + LocalFileBackend._projects_root = projects_root - def save_settings(self, data: dict) -> str: - # TODO(ADAM): Implement - pass + @staticmethod + def create( + context: 'Context', + title: str, + access: project.ProjectAccess = project.PROJECT_ACCESS_PRIVATE, + status: project.ProjectStatus = project.PROJECT_STATUS_ACTIVE + ) -> 'Project': + proj_id = str(uuid.uuid4()) + proj_file_id = proj_id # str(uuid.uuid4()) # Let's keep it the same as the proj_id for now.. + proj = {"id": proj_id, + "title": title, + "file": {"id": proj_file_id}} + output_path = os.path.join(LocalFileBackend._projects_root, proj_id, proj_file_id) + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, "w") as file: + json.dump(proj, file) + return Project(context, proj) - def put_image_asset(self, image: Union[Image.Image, np.ndarray], + @staticmethod + def load_project(context: 'Context', id: str) -> 'Project': + input_path = os.path.join(LocalFileBackend._projects_root, id, id) + with open(input_path, "r") as file: + proj = json.load(file) + return Project(context, proj) + + @staticmethod + def list_projects(context: 'Context') -> List['Project']: + # TODO: Replace with something more reliable than listing directories + # It's not reliable because the user might create directories there for various reasons. + # It may however be useful to require existence of the directory as an additional filter. + proj_root = LocalFileBackend._projects_root + all_entries = os.listdir(proj_root) + directories = [entry for entry in all_entries if os.path.isdir(os.path.join(proj_root, entry))] + projects = [] + for proj_id in directories: + proj_path = LocalFileBackend.get_path_for_asset(proj_id, proj_id) + try: + with open(proj_path, "r") as file: + proj_json = json.load(file) + proj_data = {"id": proj_json["id"], + "title": proj_json["title"], + "file": {"id": proj_json["file"]["id"]}} + projects.append(Project(context, proj_data)) + except FileNotFoundError: + pass + return projects + + def load_settings(self, proj: 'Project') -> dict: + input_path = self.get_path_for_asset(proj.id, proj.file_id) + with open(input_path, "r") as file: + settings_json = json.load(file) + return settings_json + + def save_settings(self, proj: 'Project', data: dict, asset_id: str = None) -> str: + if asset_id is not None: + filename = asset_id + else: + if not self.primary: + raise ValueError("If asset_id is None, then LocalFileBackend must be primary.") + filename = str(uuid.uuid4()) + output_path = self.get_path_for_asset(proj.id, filename) + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, "w") as file: + json.dump(data, file) + return filename + + def get_image_asset(self, proj: 'Project', asset_id: str, use: generation.AssetUse) -> Image.Image: + input_path = self.get_path_for_asset(proj.id, asset_id) + pil_image = Image.open(input_path) + return pil_image + + def put_image_asset(self, proj: 'Project', + image: Union[Image.Image, np.ndarray], use: generation.AssetUse, asset_id: str = None) -> str: png = image_to_png_bytes(image) @@ -235,15 +354,21 @@ def put_image_asset(self, image: Union[Image.Image, np.ndarray], filename = asset_id else: if not self.primary: - raise ValueError("If name is None, then LocalFileBackend must be primary.") + raise ValueError("If asset_id is None, then LocalFileBackend must be primary.") filename = str(uuid.uuid4()) - output_path = self.get_path_for_asset(filename) + output_path = self.get_path_for_asset(proj.id, filename) os.makedirs(os.path.dirname(output_path), exist_ok=True) with open(output_path + '.png', "wb") as file: file.write(png) return filename - def put_video_asset(self, video_path: str, asset_id: str = None) -> str: + def get_video_asset(self, proj: 'Project', asset_id: str, use: generation.AssetUse) -> bytes: + input_path = self.get_path_for_asset(proj.id, asset_id) + with open(input_path, 'rb') as file: + binary_data = file.read() + return binary_data + + def put_video_asset(self, proj: 'Project', video_path: str, asset_id: str = None) -> str: if not os.path.isfile(video_path) or not video_path.endswith(".mp4"): raise ValueError("Invalid video file path. Must be an existing .mp4 file.") @@ -253,31 +378,46 @@ def put_video_asset(self, video_path: str, asset_id: str = None) -> str: if not self.primary: raise ValueError("If name is None, then LocalFileBackend must be primary.") filename = str(uuid.uuid4()) - output_path = self.get_path_for_asset(filename) + output_path = self.get_path_for_asset(proj.id, filename) os.makedirs(os.path.dirname(output_path), exist_ok=True) shutil.copy(video_path, output_path) return filename - def get_path_for_asset(self, filename: str): - path = os.path.join(self._projects_root, self._project_id, filename) + @staticmethod + def get_path_for_asset(project_id: str, filename: str): + path = os.path.join(LocalFileBackend._projects_root, project_id, filename) return path + class Project(): - def __init__(self, context: 'Context', project: project.Project): + _backends = None + _metadata_index = None + + def __init__(self, context: 'Context', proj: Union[project.Project, dict]): ## __init__ could take backends: Optional[List[StorageBackend]] = None # self._backends = backends if backends else [AssetServiceBackend(primary=True)] - self._backends = [AssetServiceBackend(project_id=project.id, project_file_id = project.file_id, context=context, primary=True), - LocalFileBackend(project_id=project.id, project_file_id = project.file_id, context=context, primary=False)] self._context = context - self._project = project - self._metadata_index = self.load_metadata_index() + + # proj should be project.Project or dict + # Currently, a supplied project.Project may contain additional properties that are ignored. + if isinstance(proj, dict): + self._project = project.Project() + self._project.id = proj["id"] + self._project.title = proj["title"] + self._project.file.id = proj["file"]["id"] + else: + self._project = proj def _primary_backend(self) -> Optional[StorageBackend]: - for backend in self._backends: + for backend in self.backends: if backend.primary: return backend return None + @property + def backends(self) -> str: + return Project._backends + @property def id(self) -> str: return self._project.id @@ -290,6 +430,11 @@ def file_id(self) -> str: def title(self) -> str: return self._project.title + @classmethod + def init_backends(cls, context: 'Context'): + cls._backends = [LocalFileBackend(context=context, primary=True)] + cls._metadata_index = cls.load_metadata_index() + @staticmethod def create( context: 'Context', @@ -297,18 +442,34 @@ def create( access: project.ProjectAccess = project.PROJECT_ACCESS_PRIVATE, status: project.ProjectStatus = project.PROJECT_STATUS_ACTIVE ) -> 'Project': - req = project.CreateProjectRequest(title=title, access=access, status=status) - proj: project.Project = context._proj_stub.Create(req, wait_for_ready=True) - return Project(context, proj) - - @staticmethod - def get( + proj_file_id = '' + for backend in Project._backends: + proj = backend.create(context, title, access, status) + if isinstance(proj, dict): + proj_id = proj["id"] + proj_title = proj["title"] + else: + proj_id = proj.id + proj_title = proj.title + if backend.primary: + asset_id = proj_id + if backend.primary_fs: + filename = proj_id + proj_file_id = proj.file_id + mimetype = "application/json" + Project.add_asset_metadata(asset_id, asset_id, mimetype, filename, project_key="project_file_id") + return proj + + @classmethod + def get(cls, context: 'Context', id: str ) -> 'Project': - req = project.GetProjectRequest(id=id) - proj: project.Project = context._proj_stub.Get(req, wait_for_ready=True) - return Project(context, proj) + for backend in cls._backends: + if backend.primary: + proj = backend.load_project(context, id) + return proj + raise Exception(f"Failed to list projects") def list_assets(self): req = project.QueryAssetsRequest(id=self.id) @@ -319,29 +480,34 @@ def list_assets(self): def delete(self): self._context._proj_stub.Delete(project.DeleteProjectRequest(id=self.id)) - @staticmethod - def list_projects(context: 'Context') -> List['Project']: - list_req = project.ListProjectRequest(owner_id="") - results = [] - for proj in context._proj_stub.List(list_req, wait_for_ready=True): - results.append(Project(context, proj)) - results.sort(key=lambda x: x.title.lower()) - return results + @classmethod + def list_projects(cls, context: 'Context') -> List['Project']: + for backend in cls._backends: + if backend.primary: + results = backend.list_projects(context) + return results + raise Exception(f"Failed to list projects") def load_settings(self) -> dict: - for backend in self._backends: + for backend in self.backends: if backend.primary: - result = backend.load_settings() + result = backend.load_settings(self) return result raise Exception(f"Failed to load project file for {self.id}") def save_settings(self, data: dict) -> str: - results = None - for backend in self._backends: + asset_id = None + filename = None + for backend in self.backends: temp = backend.save_settings(data) if backend.primary: - results = temp - return results + rsplit_res = temp.rsplit('/', 1) + asset_id = rsplit_res[1] if len(rsplit_res) > 1 else rsplit_res[0] + if backend.primary_fs: + filename = temp + mimetype = "application/json" + self.add_asset_metadata(asset_id, mimetype, filename, project_key="project_file_id") + return asset_id def put_image_asset( self, @@ -351,7 +517,7 @@ def put_image_asset( results = [] asset_id = None filename = None - for backend in self._backends: + for backend in self.backends: result = backend.put_image_asset(image, use, asset_id=asset_id) if backend.primary: rsplit_res = result.rsplit('/', 1) @@ -367,7 +533,7 @@ def put_video_asset(self, video_path: str) -> List[str]: results = [] filename = None asset_id = None - for backend in self._backends: + for backend in self.backends: result = backend.put_video_asset(video_path, asset_id=asset_id) if backend.primary: rsplit_res = result.rsplit('/', 1) @@ -400,24 +566,31 @@ def update(self, title: str = None, file_id: str = None, file_uri: str = None): if file_uri: self._project.file.uri = file_uri - def add_asset_metadata(self, asset_id: str, mime_type: str, filename: str) -> None: + @staticmethod + def add_asset_metadata(project_id: str, asset_id: str, mime_type: str, filename: str, project_key: str = None) -> None: # metadata_index = self.load_metadata_index() # I assume metadata is updated by each operation - self._metadata_index[asset_id] = { + if project_id not in Project._metadata_index: + Project._metadata_index[project_id] = {} + Project._metadata_index[project_id][asset_id] = { "mime_type": mime_type } if filename is not None: - self._metadata_index[asset_id]["file_name"] = filename - self.save_metadata_index() + Project._metadata_index[project_id][asset_id]["filename"] = filename + if project_key is not None: + Project._metadata_index[project_id][project_key] = asset_id + Project.save_metadata_index() - def save_metadata_index(self, metadata_index: dict = None) -> None: + @classmethod + def save_metadata_index(cls, metadata_index: dict = None) -> None: if metadata_index is None: - metadata_index = self._metadata_index - index_file = f"{self.id}_metadata_index.json" + metadata_index = cls._metadata_index + index_file = f"metadata_index.json" with open(index_file, "w") as f: json.dump(metadata_index, f) - def load_metadata_index(self) -> dict: - index_file = f"{self.id}_metadata_index.json" + @classmethod + def load_metadata_index(cls) -> dict: + index_file = "metadata_index.json" if os.path.exists(index_file): with open(index_file, "r") as f: metadata_index = json.load(f) From 18d8ffa2e0e9cfedd9ba9e8d623ca767123a56ae Mon Sep 17 00:00:00 2001 From: Adam Letts Date: Fri, 24 Mar 2023 18:01:37 -0400 Subject: [PATCH 3/8] Verified saving and loading of project settings on asset service and local filesystem --- src/stability_sdk/animation_ui.py | 2 +- src/stability_sdk/api.py | 147 +++++++++++++++++++----------- 2 files changed, 95 insertions(+), 54 deletions(-) diff --git a/src/stability_sdk/animation_ui.py b/src/stability_sdk/animation_ui.py index d340de86..ff625108 100644 --- a/src/stability_sdk/animation_ui.py +++ b/src/stability_sdk/animation_ui.py @@ -253,7 +253,7 @@ def project_load(title: str): global project project = next(p for p in projects if p.title == title) try: - data = project.load_settings() + data = project.get_settings() except OutOfCreditsException as e: log = f"Not enough credits to load project '{title}'\n{e.details}" returns = args_to_controls(get_default_project()) diff --git a/src/stability_sdk/api.py b/src/stability_sdk/api.py index 5347835e..9376da33 100644 --- a/src/stability_sdk/api.py +++ b/src/stability_sdk/api.py @@ -82,33 +82,39 @@ def __init__(self, context: 'Context', primary: bool = False, primary_fs: bool = @staticmethod @abstractmethod - def create( + def create_project( context: 'Context', title: str, access: project.ProjectAccess = project.PROJECT_ACCESS_PRIVATE, - status: project.ProjectStatus = project.PROJECT_STATUS_ACTIVE + status: project.ProjectStatus = project.PROJECT_STATUS_ACTIVE, + proj_id_to_use: str = None ) -> 'Project': pass @staticmethod @abstractmethod - def load_project( + def get_project( context: 'Context', id: str ) -> 'Project': pass + @staticmethod + @abstractmethod + def delete_project(context: 'Context', id: str) -> None: + pass + @staticmethod @abstractmethod def list_projects(context: 'Context') -> List['Project']: pass @abstractmethod - def load_settings(self, proj: 'Project') -> dict: + def get_project_settings(self, proj: 'Project', asset_id: str = None) -> dict: pass @abstractmethod - def save_settings(self, proj: 'Project', data: dict, asset_id: str = None) -> str: + def put_project_settings(self, context: 'Context', proj: 'Project', data: dict) -> str: pass def get_image_asset(self, proj: 'Project', asset_id: str, use: generation.AssetUse) -> Image.Image: @@ -125,28 +131,36 @@ def get_video_asset(self, proj: 'Project', asset_id: str, use: generation.AssetU def put_video_asset(self, proj: 'Project', video_path: str, asset_id: str) -> str: pass + def update_project(context: 'Context', proj: 'Project', title: str = None, file_id: str = None, file_uri: str = None) -> None: + pass + class AssetServiceBackend(StorageBackend): def __init__(self, context: 'Context', primary: bool = False): super().__init__(context, primary) @staticmethod - def create( + def create_project( context: 'Context', title: str, access: project.ProjectAccess = project.PROJECT_ACCESS_PRIVATE, - status: project.ProjectStatus = project.PROJECT_STATUS_ACTIVE + status: project.ProjectStatus = project.PROJECT_STATUS_ACTIVE, + proj_id_to_use: str = None ) -> 'Project': req = project.CreateProjectRequest(title=title, access=access, status=status) proj: project.Project = context._proj_stub.Create(req, wait_for_ready=True) return Project(context, proj) @staticmethod - def load_project(context: 'Context', id: str) -> 'Project': + def get_project(context: 'Context', id: str) -> 'Project': req = project.GetProjectRequest(id=id) proj: project.Project = context._proj_stub.Get(req, wait_for_ready=True) return Project(context, proj) + @staticmethod + def delete_project(context: 'Context', id: str) -> None: + context._proj_stub.Delete(project.DeleteProjectRequest(id=id)) + @staticmethod def list_projects(context: 'Context') -> List['Project']: list_req = project.ListProjectRequest(owner_id="") @@ -156,14 +170,15 @@ def list_projects(context: 'Context') -> List['Project']: results.sort(key=lambda x: x.title.lower()) return results - def load_settings(self, proj: 'Project') -> dict: + def get_project_settings(self, proj: 'Project', asset_id: str = None) -> dict: + asset_id = asset_id if asset_id else proj.file.id request = generation.Request( engine_id=self._context._asset.engine_id, prompt=[generation.Prompt( artifact=generation.Artifact( type=generation.ARTIFACT_TEXT, mime="application/json", - uuid=proj.file_id, + uuid=asset_id, ) )], asset=generation.AssetParameters( @@ -178,11 +193,10 @@ def load_settings(self, proj: 'Project') -> dict: return settings_json raise Exception(f"Failed to load project file for {proj.id}") - - def save_settings(self, proj: 'Project', data: dict, asset_id: str = None) -> str: + def put_project_settings(self, context: 'Context', proj: 'Project', data: dict) -> str: contents = json.dumps(data) request = generation.Request( - engine_id=self._context._asset.engine_id, + engine_id=context._asset.engine_id, prompt=[generation.Prompt( artifact=generation.Artifact( type=generation.ARTIFACT_TEXT, @@ -197,7 +211,7 @@ def save_settings(self, proj: 'Project', data: dict, asset_id: str = None) -> st use=generation.ASSET_USE_PROJECT ) ) - results = self._context._run_request(self._context._asset, request) + results = context._run_request(context._asset, request) if generation.ARTIFACT_TEXT in results: return results[generation.ARTIFACT_TEXT][0] raise Exception(f"Failed to save project file for {proj.id}") @@ -266,6 +280,19 @@ def put_video_asset(self, proj: 'Project', video_path: str, asset_id: str) -> st return results[generation.ARTIFACT_TEXT][0] raise Exception(f"Failed to store video asset for project {proj.id}") + def update_project(context: 'Context', proj: 'Project', title: str = None, file_id: str = None, file_uri: str = None) -> None: + file = project.ProjectAsset( + id=file_id, + uri=file_uri, + use=project.PROJECT_ASSET_USE_PROJECT, + ) if file_id and file_uri else None + + context._proj_stub.Update(project.UpdateProjectRequest( + id=proj.id, + title=title, + file=file + )) + class LocalFileBackend(StorageBackend): _projects_root = None @@ -275,14 +302,15 @@ def __init__(self, context: 'Context', primary: bool = False, primary_fs: bool = LocalFileBackend._projects_root = projects_root @staticmethod - def create( + def create_project( context: 'Context', title: str, access: project.ProjectAccess = project.PROJECT_ACCESS_PRIVATE, - status: project.ProjectStatus = project.PROJECT_STATUS_ACTIVE + status: project.ProjectStatus = project.PROJECT_STATUS_ACTIVE, + proj_id_to_use: str = None ) -> 'Project': - proj_id = str(uuid.uuid4()) - proj_file_id = proj_id # str(uuid.uuid4()) # Let's keep it the same as the proj_id for now.. + proj_id = proj_id_to_use if proj_id_to_use else str(uuid.uuid4()) + proj_file_id = proj_id # Let's keep it the same as the proj_id for now proj = {"id": proj_id, "title": title, "file": {"id": proj_file_id}} @@ -293,12 +321,19 @@ def create( return Project(context, proj) @staticmethod - def load_project(context: 'Context', id: str) -> 'Project': + def get_project(context: 'Context', id: str) -> 'Project': input_path = os.path.join(LocalFileBackend._projects_root, id, id) with open(input_path, "r") as file: proj = json.load(file) return Project(context, proj) + @staticmethod + def delete_project(context: 'Context', id: str): + if not id: + return + project_dir_path = os.path.join(LocalFileBackend._projects_root, id) + shutil.rmtree(project_dir_path) + @staticmethod def list_projects(context: 'Context') -> List['Project']: # TODO: Replace with something more reliable than listing directories @@ -321,19 +356,14 @@ def list_projects(context: 'Context') -> List['Project']: pass return projects - def load_settings(self, proj: 'Project') -> dict: - input_path = self.get_path_for_asset(proj.id, proj.file_id) + def get_project_settings(self, proj: 'Project', asset_id: str = None) -> dict: + input_path = self.get_path_for_asset(proj.id, "project_settings.json") with open(input_path, "r") as file: settings_json = json.load(file) return settings_json - def save_settings(self, proj: 'Project', data: dict, asset_id: str = None) -> str: - if asset_id is not None: - filename = asset_id - else: - if not self.primary: - raise ValueError("If asset_id is None, then LocalFileBackend must be primary.") - filename = str(uuid.uuid4()) + def put_project_settings(self, context: 'Context', proj: 'Project', data: dict) -> str: + filename = "project_settings.json" output_path = self.get_path_for_asset(proj.id, filename) os.makedirs(os.path.dirname(output_path), exist_ok=True) with open(output_path, "w") as file: @@ -388,6 +418,17 @@ def get_path_for_asset(project_id: str, filename: str): path = os.path.join(LocalFileBackend._projects_root, project_id, filename) return path + def update_project(context: 'Context', proj: 'Project', title: str = None, file_id: str = None, file_uri: str = None): + proj_file_id = proj.file.id + proj = {"id": proj.id, + "title": title if title is not None else proj.title, + "file": {"id": file_id if file_id is not None else proj_file_id}} + output_path = os.path.join(LocalFileBackend._projects_root, proj.id, proj_file_id) + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, "w") as file: + json.dump(proj, file) + return Project(context, proj) + class Project(): _backends = None @@ -432,7 +473,10 @@ def title(self) -> str: @classmethod def init_backends(cls, context: 'Context'): - cls._backends = [LocalFileBackend(context=context, primary=True)] + cls._backends = [ + AssetServiceBackend(context=context, primary=True), + LocalFileBackend(context=context, primary=False)] + #cls._backends = [LocalFileBackend(context=context, primary=True)] cls._metadata_index = cls.load_metadata_index() @staticmethod @@ -442,9 +486,9 @@ def create( access: project.ProjectAccess = project.PROJECT_ACCESS_PRIVATE, status: project.ProjectStatus = project.PROJECT_STATUS_ACTIVE ) -> 'Project': - proj_file_id = '' + asset_id = None for backend in Project._backends: - proj = backend.create(context, title, access, status) + proj = backend.create_project(context, title, access, status, asset_id) if isinstance(proj, dict): proj_id = proj["id"] proj_title = proj["title"] @@ -457,7 +501,7 @@ def create( filename = proj_id proj_file_id = proj.file_id mimetype = "application/json" - Project.add_asset_metadata(asset_id, asset_id, mimetype, filename, project_key="project_file_id") + Project.add_asset_metadata(proj_id, asset_id, mimetype, filename, project_key="project_file_id") return proj @classmethod @@ -467,7 +511,7 @@ def get(cls, ) -> 'Project': for backend in cls._backends: if backend.primary: - proj = backend.load_project(context, id) + proj = backend.get_project(context, id) return proj raise Exception(f"Failed to list projects") @@ -478,7 +522,9 @@ def list_assets(self): return query_assets_response.assets def delete(self): - self._context._proj_stub.Delete(project.DeleteProjectRequest(id=self.id)) + for backend in self.backends: + backend.delete_project(self._context, self.id) + Project.delete_project_metadata(self.id) @classmethod def list_projects(cls, context: 'Context') -> List['Project']: @@ -488,10 +534,10 @@ def list_projects(cls, context: 'Context') -> List['Project']: return results raise Exception(f"Failed to list projects") - def load_settings(self) -> dict: + def get_settings(self) -> dict: for backend in self.backends: if backend.primary: - result = backend.load_settings(self) + result = backend.get_project_settings(self, self._metadata_index[self.id]["project_file_id"]) return result raise Exception(f"Failed to load project file for {self.id}") @@ -499,14 +545,14 @@ def save_settings(self, data: dict) -> str: asset_id = None filename = None for backend in self.backends: - temp = backend.save_settings(data) + temp = backend.put_project_settings(self._context, self, data) if backend.primary: rsplit_res = temp.rsplit('/', 1) asset_id = rsplit_res[1] if len(rsplit_res) > 1 else rsplit_res[0] if backend.primary_fs: filename = temp mimetype = "application/json" - self.add_asset_metadata(asset_id, mimetype, filename, project_key="project_file_id") + Project.add_asset_metadata(self.id, asset_id, mimetype, filename, project_key="project_file_id") return asset_id def put_image_asset( @@ -526,7 +572,7 @@ def put_image_asset( if backend.primary_fs: filename = result mimetype = "image/png" - self.add_asset_metadata(asset_id, mimetype, filename) + Project.add_asset_metadata(self.id, asset_id, mimetype, filename) return results def put_video_asset(self, video_path: str) -> List[str]: @@ -543,22 +589,12 @@ def put_video_asset(self, video_path: str) -> List[str]: filename = result # E.g.: {3: ['https://object.lga1.coreweave.com/stability-assets-dev/org-yP0GBrIgOnDA6wwfyohorEPw/178c0ff3-5e01-4e4e-9f49-278510d80289/b8912c3b-eb98-4c8e-b346-fe483ba17f83']} mimetype = "video/mp4" - self.add_asset_metadata(asset_id, mimetype, filename) + Project.add_asset_metadata(self.id, asset_id, mimetype, filename) return results - def update(self, title: str = None, file_id: str = None, file_uri: str = None): - file = project.ProjectAsset( - id=file_id, - uri=file_uri, - use=project.PROJECT_ASSET_USE_PROJECT, - ) if file_id and file_uri else None - - self._context._proj_stub.Update(project.UpdateProjectRequest( - id=self.id, - title=title, - file=file - )) - + def update_project(self, title: str = None, file_id: str = None, file_uri: str = None): + for backend in self.backends: + result = backend.update_project(self._context, self, title, file_id, file_uri) if title: self._project.title = title if file_id: @@ -566,6 +602,7 @@ def update(self, title: str = None, file_id: str = None, file_uri: str = None): if file_uri: self._project.file.uri = file_uri + @staticmethod def add_asset_metadata(project_id: str, asset_id: str, mime_type: str, filename: str, project_key: str = None) -> None: # metadata_index = self.load_metadata_index() # I assume metadata is updated by each operation @@ -580,6 +617,10 @@ def add_asset_metadata(project_id: str, asset_id: str, mime_type: str, filename: Project._metadata_index[project_id][project_key] = asset_id Project.save_metadata_index() + @staticmethod + def delete_project_metadata(project_id: str) -> None: + Project._metadata_index.pop(project_id, None) + @classmethod def save_metadata_index(cls, metadata_index: dict = None) -> None: if metadata_index is None: From 044b7d29f4be660411b021b74f33186a1e24583d Mon Sep 17 00:00:00 2001 From: Adam Letts Date: Sat, 25 Mar 2023 00:59:06 -0400 Subject: [PATCH 4/8] Implement storage backend get and put functions for image and video. get_video_asset for Asset Service backend isn't working. --- src/stability_sdk/api.py | 48 +++++++++++++++++++++++++++++++++------- 1 file changed, 40 insertions(+), 8 deletions(-) diff --git a/src/stability_sdk/api.py b/src/stability_sdk/api.py index 9376da33..abc98e5d 100644 --- a/src/stability_sdk/api.py +++ b/src/stability_sdk/api.py @@ -216,11 +216,11 @@ def put_project_settings(self, context: 'Context', proj: 'Project', data: dict) return results[generation.ARTIFACT_TEXT][0] raise Exception(f"Failed to save project file for {proj.id}") - def get_image_asset(self, proj: 'Project', asset_id: str, use: generation.AssetUse) -> str: + def get_image_asset(self, proj: 'Project', asset_id: str, use: generation.AssetUse) -> Image.Image: request = generation.Request( engine_id=self._context._asset.engine_id, prompt=[generation.Prompt( - artifact=generation.Artifact(generation.ARTIFACT_IMAGE, mime="image/png", uuid=asset_id) + artifact=generation.Artifact(type=generation.ARTIFACT_IMAGE, mime="image/png", uuid=asset_id) )], asset=generation.AssetParameters( action=generation.ASSET_GET, @@ -230,8 +230,10 @@ def get_image_asset(self, proj: 'Project', asset_id: str, use: generation.AssetU ) results = self._context._run_request(self._context._asset, request) if generation.ARTIFACT_IMAGE in results: - return results[generation.ARTIFACT_IMAGE][0] - raise Exception(f"Failed to load image asset for project {proj.id}") + img = results[generation.ARTIFACT_IMAGE][0] + pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) + return pil_img + raise Exception(f"Failed to load image asset {asset_id} for project {proj.id}") def put_image_asset(self, proj: 'Project', image: Union[Image.Image, np.ndarray], use: generation.AssetUse, asset_id: str = None) -> str: request = generation.Request( @@ -249,7 +251,23 @@ def put_image_asset(self, proj: 'Project', image: Union[Image.Image, np.ndarray] raise Exception(f"Failed to store image asset for project {proj.id}") def get_video_asset(self, proj: 'Project', asset_id: str, use: generation.AssetUse) -> str: - pass + request = generation.Request( + engine_id=self._context._asset.engine_id, + prompt=[generation.Prompt( + artifact=generation.Artifact(type=generation.ARTIFACT_VIDEO, mime="video/mp4", uuid=asset_id) + )], + asset=generation.AssetParameters( + action=generation.ASSET_GET, + project_id=proj.id, + use=use + ) + ) + results = self._context._run_request(self._context._asset, request) + # TODO: In testing so far, results contains ARTIFACT_VIDEO key.. but the value for it is an empty list. + # Thus it doesn't seem to be working. + if generation.ARTIFACT_VIDEO in results: + return results[generation.ARTIFACT_VIDEO][0] + raise Exception(f"Failed to load video asset {asset_id} for project {proj.id}") def put_video_asset(self, proj: 'Project', video_path: str, asset_id: str) -> str: if not os.path.isfile(video_path) or not video_path.endswith(".mp4"): @@ -371,7 +389,7 @@ def put_project_settings(self, context: 'Context', proj: 'Project', data: dict) return filename def get_image_asset(self, proj: 'Project', asset_id: str, use: generation.AssetUse) -> Image.Image: - input_path = self.get_path_for_asset(proj.id, asset_id) + input_path = self.get_path_for_asset(proj.id, asset_id + '.png') pil_image = Image.open(input_path) return pil_image @@ -555,6 +573,13 @@ def save_settings(self, data: dict) -> str: Project.add_asset_metadata(self.id, asset_id, mimetype, filename, project_key="project_file_id") return asset_id + def get_image_asset(self, asset_id: str, use: generation.AssetUse = generation.ASSET_USE_PROJECT) -> Image.Image: + for backend in self.backends: + if backend.primary: + result = backend.get_image_asset(self, asset_id, use) + return result + raise Exception(f"Failed to load image asset {asset_id}") + def put_image_asset( self, image: Union[Image.Image, np.ndarray], @@ -564,7 +589,7 @@ def put_image_asset( asset_id = None filename = None for backend in self.backends: - result = backend.put_image_asset(image, use, asset_id=asset_id) + result = backend.put_image_asset(self, image, use, asset_id=asset_id) if backend.primary: rsplit_res = result.rsplit('/', 1) asset_id = rsplit_res[1] if len(rsplit_res) > 1 else rsplit_res[0] @@ -575,12 +600,19 @@ def put_image_asset( Project.add_asset_metadata(self.id, asset_id, mimetype, filename) return results + def get_video_asset(self, asset_id: str, use: generation.AssetUse = generation.ASSET_USE_INPUT) -> bytes: + for backend in self.backends: + if backend.primary: + result = backend.get_video_asset(self, asset_id, use) + return result + raise Exception(f"Failed to load video asset {asset_id}") + def put_video_asset(self, video_path: str) -> List[str]: results = [] filename = None asset_id = None for backend in self.backends: - result = backend.put_video_asset(video_path, asset_id=asset_id) + result = backend.put_video_asset(self, video_path, asset_id=asset_id) if backend.primary: rsplit_res = result.rsplit('/', 1) asset_id = rsplit_res[1] if len(rsplit_res) > 1 else rsplit_res[0] From d6e2ad4b2f503aa25b448f7298b19df8369c1f7a Mon Sep 17 00:00:00 2001 From: Adam Letts Date: Sun, 26 Mar 2023 12:33:12 -0400 Subject: [PATCH 5/8] Clean up of comments in storage backend api --- src/stability_sdk/api.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/stability_sdk/api.py b/src/stability_sdk/api.py index abc98e5d..6611973c 100644 --- a/src/stability_sdk/api.py +++ b/src/stability_sdk/api.py @@ -354,9 +354,7 @@ def delete_project(context: 'Context', id: str): @staticmethod def list_projects(context: 'Context') -> List['Project']: - # TODO: Replace with something more reliable than listing directories - # It's not reliable because the user might create directories there for various reasons. - # It may however be useful to require existence of the directory as an additional filter. + # This returns a listing of directories in the projects root. proj_root = LocalFileBackend._projects_root all_entries = os.listdir(proj_root) directories = [entry for entry in all_entries if os.path.isdir(os.path.join(proj_root, entry))] @@ -619,7 +617,6 @@ def put_video_asset(self, video_path: str) -> List[str]: results.append(asset_id) if backend.primary_fs: filename = result - # E.g.: {3: ['https://object.lga1.coreweave.com/stability-assets-dev/org-yP0GBrIgOnDA6wwfyohorEPw/178c0ff3-5e01-4e4e-9f49-278510d80289/b8912c3b-eb98-4c8e-b346-fe483ba17f83']} mimetype = "video/mp4" Project.add_asset_metadata(self.id, asset_id, mimetype, filename) return results From b7fc9d0232f1e794ab7569f4c4e29e94a2aca4cf Mon Sep 17 00:00:00 2001 From: Adam Letts Date: Tue, 28 Mar 2023 14:32:14 -0400 Subject: [PATCH 6/8] Adds a working AssetServiceBackend get_video_asset(). Stored video assets are treated as binary data. --- src/stability_sdk/api.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/stability_sdk/api.py b/src/stability_sdk/api.py index 6611973c..7a015760 100644 --- a/src/stability_sdk/api.py +++ b/src/stability_sdk/api.py @@ -1137,6 +1137,8 @@ def _process_response(self, response) -> Dict[int, List[np.ndarray]]: results[artifact.type].append(artifact.tensor) elif artifact.type == generation.ARTIFACT_TEXT: results[artifact.type].append(artifact.text) + elif artifact.type == generation.ARTIFACT_VIDEO: + results[artifact.type].append(artifact.binary) return results def _run_request( From c148be4ec499e98e7441be4789817dcc84c0b089 Mon Sep 17 00:00:00 2001 From: Adam Letts Date: Tue, 28 Mar 2023 14:33:36 -0400 Subject: [PATCH 7/8] Remove comment indicating that AssetServiceBackend get_video_asset doesn't work. --- src/stability_sdk/api.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/stability_sdk/api.py b/src/stability_sdk/api.py index 7a015760..39856c67 100644 --- a/src/stability_sdk/api.py +++ b/src/stability_sdk/api.py @@ -263,8 +263,6 @@ def get_video_asset(self, proj: 'Project', asset_id: str, use: generation.AssetU ) ) results = self._context._run_request(self._context._asset, request) - # TODO: In testing so far, results contains ARTIFACT_VIDEO key.. but the value for it is an empty list. - # Thus it doesn't seem to be working. if generation.ARTIFACT_VIDEO in results: return results[generation.ARTIFACT_VIDEO][0] raise Exception(f"Failed to load video asset {asset_id} for project {proj.id}") From 9602b0a65eac038057a99d99f10e2c3d47e5d41c Mon Sep 17 00:00:00 2001 From: Adam Letts Date: Wed, 29 Mar 2023 13:15:00 -0400 Subject: [PATCH 8/8] Save project_settings.json with UTF-8. Raise ValueError if delete_project() for local filesystem has empty id. --- src/stability_sdk/api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/stability_sdk/api.py b/src/stability_sdk/api.py index 39856c67..2ecd101a 100644 --- a/src/stability_sdk/api.py +++ b/src/stability_sdk/api.py @@ -346,7 +346,7 @@ def get_project(context: 'Context', id: str) -> 'Project': @staticmethod def delete_project(context: 'Context', id: str): if not id: - return + raise ValueError("Delete project requires a project id") project_dir_path = os.path.join(LocalFileBackend._projects_root, id) shutil.rmtree(project_dir_path) @@ -380,7 +380,7 @@ def put_project_settings(self, context: 'Context', proj: 'Project', data: dict) filename = "project_settings.json" output_path = self.get_path_for_asset(proj.id, filename) os.makedirs(os.path.dirname(output_path), exist_ok=True) - with open(output_path, "w") as file: + with open(output_path, "w", encoding="utf-8") as file: json.dump(data, file) return filename