From 3091a6766a80cc9d051822dda6b9496937f679f3 Mon Sep 17 00:00:00 2001 From: "felix.wang" Date: Fri, 29 Dec 2023 16:49:53 +0800 Subject: [PATCH 1/5] chore: move model to hugging face --- server/clip_server/model/pretrained_models.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/server/clip_server/model/pretrained_models.py b/server/clip_server/model/pretrained_models.py index 3494bee4b..c16a1c799 100644 --- a/server/clip_server/model/pretrained_models.py +++ b/server/clip_server/model/pretrained_models.py @@ -2,9 +2,11 @@ import hashlib import shutil import urllib +import requests _OPENCLIP_S3_BUCKET = 'https://clip-as-service.s3.us-east-2.amazonaws.com/models/torch' +_OPENCLIP_HUGGINGFACE_BUCKET = 'https://huggingface.co/jinaai/' _OPENCLIP_MODELS = { 'RN50::openai': ('RN50.pt', '9140964eaaf9f68c95aa8df6ca13777c'), 'RN50::yfcc15m': ('RN50-yfcc15m.pt', 'e9c564f91ae7dc754d9043fdcd2a9f22'), @@ -143,6 +145,16 @@ def get_model_url_md5(name: str): if len(model_pretrained) == 0: # not on s3 return None, None else: + hg_download_url = _OPENCLIP_HUGGINGFACE_BUCKET + name.split('::')[0] + '/resolve/main/' + model_pretrained[0] + '?download=true' + try: + response = requests.head(hg_download_url) + if response.status_code in [200, 302] : + return (hg_download_url, model_pretrained[1]) + else: + print(f"Model not found on hugging face, trying to download from s3.") + except requests.exceptions.RequestException as e: + print(str(e)) + print(f"Model not found on hugging face, trying to download from s3.") return (_OPENCLIP_S3_BUCKET + '/' + model_pretrained[0], model_pretrained[1]) From 1fd6525f47341559ea98917bedc97e8b7f947d0f Mon Sep 17 00:00:00 2001 From: "felix.wang" Date: Fri, 29 Dec 2023 16:54:07 +0800 Subject: [PATCH 2/5] chore: black file --- server/clip_server/model/pretrained_models.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/server/clip_server/model/pretrained_models.py b/server/clip_server/model/pretrained_models.py index c16a1c799..ecb46c747 100644 --- a/server/clip_server/model/pretrained_models.py +++ b/server/clip_server/model/pretrained_models.py @@ -134,7 +134,7 @@ def md5file(filename: str): hash_md5 = hashlib.md5() with open(filename, 'rb') as f: - for chunk in iter(lambda: f.read(4096), b""): + for chunk in iter(lambda: f.read(4096), b''): hash_md5.update(chunk) return hash_md5.hexdigest() @@ -151,16 +151,16 @@ def get_model_url_md5(name: str): if response.status_code in [200, 302] : return (hg_download_url, model_pretrained[1]) else: - print(f"Model not found on hugging face, trying to download from s3.") + print(f'Model not found on hugging face, trying to download from s3.') except requests.exceptions.RequestException as e: print(str(e)) - print(f"Model not found on hugging face, trying to download from s3.") + print(f'Model not found on hugging face, trying to download from s3.') return (_OPENCLIP_S3_BUCKET + '/' + model_pretrained[0], model_pretrained[1]) def download_model( url: str, - target_folder: str = os.path.expanduser("~/.cache/clip"), + target_folder: str = os.path.expanduser('~/.cache/clip'), md5sum: str = None, with_resume: bool = True, max_attempts: int = 3, @@ -187,14 +187,14 @@ def download_model( ) progress = Progress( - " \n", # divide this bar from Flow's bar - TextColumn("[bold blue]{task.fields[filename]}", justify="right"), - "[progress.percentage]{task.percentage:>3.1f}%", - "•", + ' \n', # divide this bar from Flow's bar + TextColumn('[bold blue]{task.fields[filename]}', justify='right'), + 
'[progress.percentage]{task.percentage:>3.1f}%', + '•', DownloadColumn(), - "•", + '•', TransferSpeedColumn(), - "•", + '•', TimeRemainingColumn(), ) From 69b22ef87277d5b79303b619cc3869385e767653 Mon Sep 17 00:00:00 2001 From: "felix.wang" Date: Fri, 29 Dec 2023 16:55:51 +0800 Subject: [PATCH 3/5] chore: run black --- client/clip_client/__init__.py | 6 +- client/clip_client/client.py | 434 +++++++++--------- client/clip_client/helper.py | 17 +- client/setup.py | 112 ++--- docs/conf.py | 189 ++++---- scripts/benchmark.py | 34 +- scripts/get-last-release-note.py | 6 +- scripts/get-requirements.py | 6 +- scripts/setup.py | 84 ++-- server/clip_server/__init__.py | 2 +- server/clip_server/__main__.py | 10 +- server/clip_server/executors/clip_onnx.py | 96 ++-- server/clip_server/executors/clip_tensorrt.py | 82 ++-- server/clip_server/executors/clip_torch.py | 94 ++-- server/clip_server/executors/helper.py | 53 +-- server/clip_server/helper.py | 25 +- server/clip_server/model/clip.py | 2 +- server/clip_server/model/clip_model.py | 6 +- server/clip_server/model/clip_onnx.py | 282 ++++++------ server/clip_server/model/clip_trt.py | 44 +- server/clip_server/model/cnclip_model.py | 16 +- server/clip_server/model/mclip_model.py | 16 +- server/clip_server/model/model.py | 44 +- server/clip_server/model/openclip_model.py | 24 +- server/clip_server/model/pretrained_models.py | 282 ++++++------ server/clip_server/model/simple_tokenizer.py | 34 +- server/clip_server/model/tokenization.py | 14 +- server/setup.py | 126 ++--- tests/__init__.py | 2 +- tests/conftest.py | 20 +- tests/test_asyncio.py | 24 +- tests/test_client.py | 71 ++- tests/test_helper.py | 34 +- tests/test_model.py | 26 +- tests/test_ranker.py | 132 +++--- tests/test_search.py | 44 +- tests/test_server.py | 44 +- tests/test_simple.py | 90 ++-- tests/test_tensorrt.py | 40 +- tests/test_tokenization.py | 14 +- 40 files changed, 1342 insertions(+), 1339 deletions(-) diff --git a/client/clip_client/__init__.py b/client/clip_client/__init__.py index 46c9176ae..d69e4ee69 100644 --- a/client/clip_client/__init__.py +++ b/client/clip_client/__init__.py @@ -1,10 +1,10 @@ -__version__ = '0.8.4' +__version__ = "0.8.4" import os from clip_client.client import Client -if 'NO_VERSION_CHECK' not in os.environ: +if "NO_VERSION_CHECK" not in os.environ: from clip_client.helper import is_latest_version - is_latest_version(github_repo='clip-as-service') + is_latest_version(github_repo="clip-as-service") diff --git a/client/clip_client/client.py b/client/clip_client/client.py index c32a63d5c..213508e3c 100644 --- a/client/clip_client/client.py +++ b/client/clip_client/client.py @@ -37,21 +37,21 @@ def __init__(self, server: str, credential: dict = {}, **kwargs): _port = r.port self._scheme = r.scheme except: - raise ValueError(f'{server} is not a valid scheme') + raise ValueError(f"{server} is not a valid scheme") _tls = False - if self._scheme in ('grpcs', 'https', 'wss'): + if self._scheme in ("grpcs", "https", "wss"): self._scheme = self._scheme[:-1] _tls = True - if self._scheme == 'ws': - self._scheme = 'websocket' # temp fix for the core + if self._scheme == "ws": + self._scheme = "websocket" # temp fix for the core if credential: warnings.warn( - 'Credential is not supported for websocket, please use grpc or http' + "Credential is not supported for websocket, please use grpc or http" ) - if self._scheme in ('grpc', 'http', 'websocket'): + if self._scheme in ("grpc", "http", "websocket"): _kwargs = dict(host=r.hostname, port=_port, protocol=self._scheme, 
tls=_tls) from jina import Client @@ -59,13 +59,13 @@ def __init__(self, server: str, credential: dict = {}, **kwargs): self._client = Client(**_kwargs) self._async_client = Client(**_kwargs, asyncio=True) else: - raise ValueError(f'{server} is not a valid scheme') + raise ValueError(f"{server} is not a valid scheme") self._authorization = credential.get( - 'Authorization', os.environ.get('CLIP_AUTH_TOKEN') + "Authorization", os.environ.get("CLIP_AUTH_TOKEN") ) - def profile(self, content: Optional[str] = '') -> Dict[str, float]: + def profile(self, content: Optional[str] = "") -> Dict[str, float]: """Profiling a single query's roundtrip including network and computation latency. Results is summarized in a table. :param content: the content to be sent for profiling. By default it sends an empty Document that helps you understand the network latency. @@ -73,7 +73,7 @@ def profile(self, content: Optional[str] = '') -> Dict[str, float]: """ st = time.perf_counter() r = self._client.post( - '/', self._iter_doc([content], DocumentArray()), return_responses=True + "/", self._iter_doc([content], DocumentArray()), return_responses=True ) ed = (time.perf_counter() - st) * 1000 route = r[0].routes @@ -91,35 +91,35 @@ def profile(self, content: Optional[str] = '') -> Dict[str, float]: def make_table(_title, _time, _percent): table = Table(show_header=False, box=None) table.add_row( - _title, f'[b]{_time:.0f}[/b]ms', f'[dim]{_percent * 100:.0f}%[/dim]' + _title, f"[b]{_time:.0f}[/b]ms", f"[dim]{_percent * 100:.0f}%[/dim]" ) return table from rich.tree import Tree - t = Tree(make_table('Roundtrip', ed, 1)) - t.add(make_table('Client-server network', network_time, network_time / ed)) - t2 = t.add(make_table('Server', gateway_time, gateway_time / ed)) + t = Tree(make_table("Roundtrip", ed, 1)) + t.add(make_table("Client-server network", network_time, network_time / ed)) + t2 = t.add(make_table("Server", gateway_time, gateway_time / ed)) t2.add( make_table( - 'Gateway-CLIP network', server_network, server_network / gateway_time + "Gateway-CLIP network", server_network, server_network / gateway_time ) ) - t2.add(make_table('CLIP model', clip_time, clip_time / gateway_time)) + t2.add(make_table("CLIP model", clip_time, clip_time / gateway_time)) from rich import print print(t) return { - 'Roundtrip': ed, - 'Client-server network': network_time, - 'Server': gateway_time, - 'Gateway-CLIP network': server_network, - 'CLIP model': clip_time, + "Roundtrip": ed, + "Client-server network": network_time, + "Server": gateway_time, + "Gateway-CLIP network": server_network, + "CLIP model": clip_time, } - def _update_pbar(self, response, func: Optional['CallbackFnType'] = None): + def _update_pbar(self, response, func: Optional["CallbackFnType"] = None): from rich import filesize r = response.data.docs @@ -129,7 +129,7 @@ def _update_pbar(self, response, func: Optional['CallbackFnType'] = None): self._r_task, advance=len(r), total_size=str( - filesize.decimal(int(os.environ.get('JINA_GRPC_RECV_BYTES', '0'))) + filesize.decimal(int(os.environ.get("JINA_GRPC_RECV_BYTES", "0"))) ), ) if func is not None: @@ -139,48 +139,48 @@ def _prepare_streaming(self, disable, total): if total is None: total = 500 warnings.warn( - 'The length of the input is unknown, the progressbar would not be accurate.' + "The length of the input is unknown, the progressbar would not be accurate." ) elif total > 500: warnings.warn( - 'Please ensure all the inputs are valid, otherwise the request will be aborted.' 
+ "Please ensure all the inputs are valid, otherwise the request will be aborted." ) from docarray.array.mixins.io.pbar import get_pbar self._pbar = get_pbar(disable) - os.environ['JINA_GRPC_SEND_BYTES'] = '0' - os.environ['JINA_GRPC_RECV_BYTES'] = '0' + os.environ["JINA_GRPC_SEND_BYTES"] = "0" + os.environ["JINA_GRPC_RECV_BYTES"] = "0" self._r_task = self._pbar.add_task( - ':arrow_down: Progress', total=total, total_size=0, start=False + ":arrow_down: Progress", total=total, total_size=0, start=False ) @staticmethod def _gather_result( - response, results: 'DocumentArray', attribute: Optional[str] = None + response, results: "DocumentArray", attribute: Optional[str] = None ): r = response.data.docs if attribute: - results[r[:, 'id']][:, attribute] = r[:, attribute] + results[r[:, "id"]][:, attribute] = r[:, attribute] def _iter_doc( - self, content, results: Optional['DocumentArray'] = None - ) -> Generator['Document', None, None]: + self, content, results: Optional["DocumentArray"] = None + ) -> Generator["Document", None, None]: from docarray import Document for c in content: if isinstance(c, str): _mime = mimetypes.guess_type(c)[0] - if _mime and _mime.startswith('image'): + if _mime and _mime.startswith("image"): d = Document( uri=c, ).load_uri_to_blob() else: d = Document(text=c) elif isinstance(c, Document): - if c.content_type in ('text', 'blob'): + if c.content_type in ("text", "blob"): d = c elif not c.blob and c.uri: c.load_uri_to_blob() @@ -188,37 +188,37 @@ def _iter_doc( elif c.tensor is not None: d = c else: - raise TypeError(f'unsupported input type {c!r} {c.content_type}') + raise TypeError(f"unsupported input type {c!r} {c.content_type}") else: - raise TypeError(f'unsupported input type {c!r}') + raise TypeError(f"unsupported input type {c!r}") if results is not None: results.append(d) yield d def _get_post_payload( - self, content, results: Optional['DocumentArray'] = None, **kwargs + self, content, results: Optional["DocumentArray"] = None, **kwargs ): payload = dict( inputs=self._iter_doc(content, results), - request_size=kwargs.get('batch_size', 8), - total_docs=len(content) if hasattr(content, '__len__') else None, + request_size=kwargs.get("batch_size", 8), + total_docs=len(content) if hasattr(content, "__len__") else None, ) - if self._scheme == 'grpc' and self._authorization: - payload.update(metadata=(('authorization', self._authorization),)) - elif self._scheme == 'http' and self._authorization: - payload.update(headers={'Authorization': self._authorization}) + if self._scheme == "grpc" and self._authorization: + payload.update(metadata=(("authorization", self._authorization),)) + elif self._scheme == "http" and self._authorization: + payload.update(headers={"Authorization": self._authorization}) return payload @staticmethod - def _unboxed_result(results: Optional['DocumentArray'] = None, unbox: bool = False): + def _unboxed_result(results: Optional["DocumentArray"] = None, unbox: bool = False): if results is not None: if results.embeddings is None: raise ValueError( - 'Empty embedding returned from the server. ' - 'This often due to a mis-config of the server, ' - 'restarting the server or changing the serving port number often solves the problem' + "Empty embedding returned from the server. 
" + "This often due to a mis-config of the server, " + "restarting the server or changing the serving port number often solves the problem" ) return results.embeddings if unbox else results @@ -230,11 +230,11 @@ def encode( batch_size: Optional[int] = None, show_progress: bool = False, parameters: Optional[dict] = None, - on_done: Optional['CallbackFnType'] = None, - on_error: Optional['CallbackFnType'] = None, - on_always: Optional['CallbackFnType'] = None, + on_done: Optional["CallbackFnType"] = None, + on_error: Optional["CallbackFnType"] = None, + on_always: Optional["CallbackFnType"] = None, prefetch: int = 100, - ) -> 'np.ndarray': + ) -> "np.ndarray": """Encode images and texts into embeddings where the input is an iterable of raw strings. Each image and text must be represented as a string. The following strings are acceptable: - local image filepath, will be considered as an image @@ -260,16 +260,16 @@ def encode( @overload def encode( self, - content: Union['DocumentArray', Iterable['Document']], + content: Union["DocumentArray", Iterable["Document"]], *, batch_size: Optional[int] = None, show_progress: bool = False, parameters: Optional[dict] = None, - on_done: Optional['CallbackFnType'] = None, - on_error: Optional['CallbackFnType'] = None, - on_always: Optional['CallbackFnType'] = None, + on_done: Optional["CallbackFnType"] = None, + on_error: Optional["CallbackFnType"] = None, + on_always: Optional["CallbackFnType"] = None, prefetch: int = 100, - ) -> 'DocumentArray': + ) -> "DocumentArray": """Encode images and texts into embeddings where the input is an iterable of :class:`docarray.Document`. :param content: an iterable of :class:`docarray.Document`, each Document must be filled with `.uri`, `.text` or `.blob`. :param batch_size: the number of elements in each request when sending ``content`` @@ -292,32 +292,32 @@ def encode(self, content, **kwargs): raise TypeError( f'Content must be an Iterable of [str, Document], try `.encode(["{content}"])` instead' ) - if hasattr(content, '__len__') and len(content) == 0: + if hasattr(content, "__len__") and len(content) == 0: return DocumentArray() if isinstance(content, DocumentArray) else [] self._prepare_streaming( - not kwargs.get('show_progress'), - total=len(content) if hasattr(content, '__len__') else None, + not kwargs.get("show_progress"), + total=len(content) if hasattr(content, "__len__") else None, ) - on_done = kwargs.pop('on_done', None) - on_error = kwargs.pop('on_error', None) - on_always = kwargs.pop('on_always', None) - prefetch = kwargs.pop('prefetch', 100) + on_done = kwargs.pop("on_done", None) + on_error = kwargs.pop("on_error", None) + on_always = kwargs.pop("on_always", None) + prefetch = kwargs.pop("prefetch", 100) results = DocumentArray() if not on_done and not on_always else None if not on_done: on_done = partial( - self._gather_result, results=results, attribute='embedding' + self._gather_result, results=results, attribute="embedding" ) with self._pbar: - parameters = kwargs.pop('parameters', {}) - parameters['drop_image_content'] = parameters.get( - 'drop_image_content', True + parameters = kwargs.pop("parameters", {}) + parameters["drop_image_content"] = parameters.get( + "drop_image_content", True ) - model_name = parameters.pop('model_name', '') if parameters else '' + model_name = parameters.pop("model_name", "") if parameters else "" self._client.post( - on=f'/encode/{model_name}'.rstrip('/'), + on=f"/encode/{model_name}".rstrip("/"), **self._get_post_payload(content, results, **kwargs), 
on_done=on_done, on_error=on_error, @@ -326,7 +326,7 @@ def encode(self, content, **kwargs): prefetch=prefetch, ) - unbox = hasattr(content, '__len__') and isinstance(content[0], str) + unbox = hasattr(content, "__len__") and isinstance(content[0], str) return self._unboxed_result(results, unbox) @overload @@ -337,26 +337,26 @@ async def aencode( batch_size: Optional[int] = None, show_progress: bool = False, parameters: Optional[dict] = None, - on_done: Optional['CallbackFnType'] = None, - on_error: Optional['CallbackFnType'] = None, - on_always: Optional['CallbackFnType'] = None, + on_done: Optional["CallbackFnType"] = None, + on_error: Optional["CallbackFnType"] = None, + on_always: Optional["CallbackFnType"] = None, prefetch: int = 100, - ) -> 'np.ndarray': + ) -> "np.ndarray": ... @overload async def aencode( self, - content: Union['DocumentArray', Iterable['Document']], + content: Union["DocumentArray", Iterable["Document"]], *, batch_size: Optional[int] = None, show_progress: bool = False, parameters: Optional[dict] = None, - on_done: Optional['CallbackFnType'] = None, - on_error: Optional['CallbackFnType'] = None, - on_always: Optional['CallbackFnType'] = None, + on_done: Optional["CallbackFnType"] = None, + on_error: Optional["CallbackFnType"] = None, + on_always: Optional["CallbackFnType"] = None, prefetch: int = 100, - ) -> 'DocumentArray': + ) -> "DocumentArray": ... async def aencode(self, content, **kwargs): @@ -364,32 +364,32 @@ async def aencode(self, content, **kwargs): raise TypeError( f'Content must be an Iterable of [str, Document], try `.aencode(["{content}"])` instead' ) - if hasattr(content, '__len__') and len(content) == 0: + if hasattr(content, "__len__") and len(content) == 0: return DocumentArray() if isinstance(content, DocumentArray) else [] self._prepare_streaming( - not kwargs.get('show_progress'), - total=len(content) if hasattr(content, '__len__') else None, + not kwargs.get("show_progress"), + total=len(content) if hasattr(content, "__len__") else None, ) - on_done = kwargs.pop('on_done', None) - on_error = kwargs.pop('on_error', None) - on_always = kwargs.pop('on_always', None) - prefetch = kwargs.pop('prefetch', 100) + on_done = kwargs.pop("on_done", None) + on_error = kwargs.pop("on_error", None) + on_always = kwargs.pop("on_always", None) + prefetch = kwargs.pop("prefetch", 100) results = DocumentArray() if not on_done and not on_always else None if not on_done: on_done = partial( - self._gather_result, results=results, attribute='embedding' + self._gather_result, results=results, attribute="embedding" ) with self._pbar: - parameters = kwargs.pop('parameters', {}) - parameters['drop_image_content'] = parameters.get( - 'drop_image_content', True + parameters = kwargs.pop("parameters", {}) + parameters["drop_image_content"] = parameters.get( + "drop_image_content", True ) - model_name = parameters.get('model_name', '') if parameters else '' + model_name = parameters.get("model_name", "") if parameters else "" async for _ in self._async_client.post( - on=f'/encode/{model_name}'.rstrip('/'), + on=f"/encode/{model_name}".rstrip("/"), **self._get_post_payload(content, results, **kwargs), on_done=on_done, on_error=on_error, @@ -399,42 +399,42 @@ async def aencode(self, content, **kwargs): ): continue - unbox = hasattr(content, '__len__') and isinstance(content[0], str) + unbox = hasattr(content, "__len__") and isinstance(content[0], str) return self._unboxed_result(results, unbox) def _iter_rank_docs( - self, content, results: Optional['DocumentArray'] = None, 
source='matches' - ) -> Generator['Document', None, None]: + self, content, results: Optional["DocumentArray"] = None, source="matches" + ) -> Generator["Document", None, None]: from docarray import Document for c in content: if isinstance(c, Document): d = self._prepare_rank_doc(c, source) else: - raise TypeError(f'Unsupported input type {c!r}') + raise TypeError(f"Unsupported input type {c!r}") if results is not None: results.append(d) yield d def _get_rank_payload( - self, content, results: Optional['DocumentArray'] = None, **kwargs + self, content, results: Optional["DocumentArray"] = None, **kwargs ): payload = dict( inputs=self._iter_rank_docs( - content, results, source=kwargs.get('source', 'matches') + content, results, source=kwargs.get("source", "matches") ), - request_size=kwargs.get('batch_size', 8), - total_docs=len(content) if hasattr(content, '__len__') else None, + request_size=kwargs.get("batch_size", 8), + total_docs=len(content) if hasattr(content, "__len__") else None, ) - if self._scheme == 'grpc' and self._authorization: - payload.update(metadata=(('authorization', self._authorization),)) - elif self._scheme == 'http' and self._authorization: - payload.update(headers={'Authorization': self._authorization}) + if self._scheme == "grpc" and self._authorization: + payload.update(metadata=(("authorization", self._authorization),)) + elif self._scheme == "http" and self._authorization: + payload.update(headers={"Authorization": self._authorization}) return payload @staticmethod - def _prepare_single_doc(d: 'Document'): - if d.content_type in ('text', 'blob'): + def _prepare_single_doc(d: "Document"): + if d.content_type in ("text", "blob"): return d elif not d.blob and d.uri: d.load_uri_to_blob() @@ -442,20 +442,20 @@ def _prepare_single_doc(d: 'Document'): elif d.tensor is not None: return d else: - raise TypeError(f'Unsupported input type {d!r} {d.content_type}') + raise TypeError(f"Unsupported input type {d!r} {d.content_type}") @staticmethod - def _prepare_rank_doc(d: 'Document', _source: str = 'matches'): + def _prepare_rank_doc(d: "Document", _source: str = "matches"): _get = lambda d: getattr(d, _source) if not _get(d): - raise ValueError(f'`.rank()` requires every doc to have `.{_source}`') + raise ValueError(f"`.rank()` requires every doc to have `.{_source}`") d = Client._prepare_single_doc(d) setattr(d, _source, [Client._prepare_single_doc(c) for c in _get(d)]) return d def rank( - self, docs: Union['DocumentArray', Iterable['Document']], **kwargs - ) -> 'DocumentArray': + self, docs: Union["DocumentArray", Iterable["Document"]], **kwargs + ) -> "DocumentArray": """Rank image-text matches according to the server CLIP model. Given a Document with nested matches, where the root is image/text and the matches is in another modality, i.e. text/image; this method ranks the matches according to the CLIP model. @@ -466,30 +466,30 @@ def rank( :return: the ranked Documents in a DocumentArray. 
""" if isinstance(docs, str): - raise TypeError(f'Content must be an Iterable of [Document]') + raise TypeError(f"Content must be an Iterable of [Document]") self._prepare_streaming( - not kwargs.get('show_progress'), - total=len(docs) if hasattr(docs, '__len__') else None, + not kwargs.get("show_progress"), + total=len(docs) if hasattr(docs, "__len__") else None, ) - on_done = kwargs.pop('on_done', None) - on_error = kwargs.pop('on_error', None) - on_always = kwargs.pop('on_always', None) - prefetch = kwargs.pop('prefetch', 100) + on_done = kwargs.pop("on_done", None) + on_error = kwargs.pop("on_error", None) + on_always = kwargs.pop("on_always", None) + prefetch = kwargs.pop("prefetch", 100) results = DocumentArray() if not on_done and not on_always else None if not on_done: - on_done = partial(self._gather_result, results=results, attribute='matches') + on_done = partial(self._gather_result, results=results, attribute="matches") with self._pbar: - parameters = kwargs.pop('parameters', {}) - parameters['drop_image_content'] = parameters.get( - 'drop_image_content', True + parameters = kwargs.pop("parameters", {}) + parameters["drop_image_content"] = parameters.get( + "drop_image_content", True ) - model_name = parameters.get('model_name', '') if parameters else '' + model_name = parameters.get("model_name", "") if parameters else "" self._client.post( - on=f'/rank/{model_name}'.rstrip('/'), + on=f"/rank/{model_name}".rstrip("/"), **self._get_rank_payload(docs, results, **kwargs), on_done=on_done, on_error=on_error, @@ -501,32 +501,32 @@ def rank( return results async def arank( - self, docs: Union['DocumentArray', Iterable['Document']], **kwargs - ) -> 'DocumentArray': + self, docs: Union["DocumentArray", Iterable["Document"]], **kwargs + ) -> "DocumentArray": if isinstance(docs, str): - raise TypeError(f'Content must be an Iterable of [Document]') + raise TypeError(f"Content must be an Iterable of [Document]") self._prepare_streaming( - not kwargs.get('show_progress'), - total=len(docs) if hasattr(docs, '__len__') else None, + not kwargs.get("show_progress"), + total=len(docs) if hasattr(docs, "__len__") else None, ) - on_done = kwargs.pop('on_done', None) - on_error = kwargs.pop('on_error', None) - on_always = kwargs.pop('on_always', None) - prefetch = kwargs.pop('prefetch', 100) + on_done = kwargs.pop("on_done", None) + on_error = kwargs.pop("on_error", None) + on_always = kwargs.pop("on_always", None) + prefetch = kwargs.pop("prefetch", 100) results = DocumentArray() if not on_done and not on_always else None if not on_done: - on_done = partial(self._gather_result, results=results, attribute='matches') + on_done = partial(self._gather_result, results=results, attribute="matches") with self._pbar: - parameters = kwargs.pop('parameters', {}) - parameters['drop_image_content'] = parameters.get( - 'drop_image_content', True + parameters = kwargs.pop("parameters", {}) + parameters["drop_image_content"] = parameters.get( + "drop_image_content", True ) - model_name = parameters.get('model_name', '') if parameters else '' + model_name = parameters.get("model_name", "") if parameters else "" async for _ in self._async_client.post( - on=f'/rank/{model_name}'.rstrip('/'), + on=f"/rank/{model_name}".rstrip("/"), **self._get_rank_payload(docs, results, **kwargs), on_done=on_done, on_error=on_error, @@ -546,9 +546,9 @@ def index( batch_size: Optional[int] = None, show_progress: bool = False, parameters: Optional[Dict] = None, - on_done: Optional['CallbackFnType'] = None, - on_error: 
Optional['CallbackFnType'] = None, - on_always: Optional['CallbackFnType'] = None, + on_done: Optional["CallbackFnType"] = None, + on_error: Optional["CallbackFnType"] = None, + on_always: Optional["CallbackFnType"] = None, prefetch: int = 100, ): """Index the images or texts where their embeddings are computed by the server CLIP model. @@ -577,16 +577,16 @@ def index( @overload def index( self, - content: Union['DocumentArray', Iterable['Document']], + content: Union["DocumentArray", Iterable["Document"]], *, batch_size: Optional[int] = None, show_progress: bool = False, parameters: Optional[dict] = None, - on_done: Optional['CallbackFnType'] = None, - on_error: Optional['CallbackFnType'] = None, - on_always: Optional['CallbackFnType'] = None, + on_done: Optional["CallbackFnType"] = None, + on_error: Optional["CallbackFnType"] = None, + on_always: Optional["CallbackFnType"] = None, prefetch: int = 100, - ) -> 'DocumentArray': + ) -> "DocumentArray": """Index the images or texts where their embeddings are computed by the server CLIP model. :param content: an iterable of :class:`docarray.Document`, each Document must be filled with `.uri`, `.text` or `.blob`. @@ -612,27 +612,27 @@ def index(self, content, **kwargs): ) self._prepare_streaming( - not kwargs.get('show_progress'), - total=len(content) if hasattr(content, '__len__') else None, + not kwargs.get("show_progress"), + total=len(content) if hasattr(content, "__len__") else None, ) - on_done = kwargs.pop('on_done', None) - on_error = kwargs.pop('on_error', None) - on_always = kwargs.pop('on_always', None) - prefetch = kwargs.pop('prefetch', 100) + on_done = kwargs.pop("on_done", None) + on_error = kwargs.pop("on_error", None) + on_always = kwargs.pop("on_always", None) + prefetch = kwargs.pop("prefetch", 100) results = DocumentArray() if not on_done and not on_always else None if not on_done: on_done = partial( - self._gather_result, results=results, attribute='embedding' + self._gather_result, results=results, attribute="embedding" ) with self._pbar: - parameters = kwargs.pop('parameters', {}) - parameters['drop_image_content'] = parameters.get( - 'drop_image_content', True + parameters = kwargs.pop("parameters", {}) + parameters["drop_image_content"] = parameters.get( + "drop_image_content", True ) self._client.post( - on='/index', + on="/index", **self._get_post_payload(content, results, **kwargs), on_done=on_done, on_error=on_error, @@ -651,9 +651,9 @@ async def aindex( batch_size: Optional[int] = None, show_progress: bool = False, parameters: Optional[Dict] = None, - on_done: Optional['CallbackFnType'] = None, - on_error: Optional['CallbackFnType'] = None, - on_always: Optional['CallbackFnType'] = None, + on_done: Optional["CallbackFnType"] = None, + on_error: Optional["CallbackFnType"] = None, + on_always: Optional["CallbackFnType"] = None, prefetch: int = 100, ): ... @@ -661,14 +661,14 @@ async def aindex( @overload async def aindex( self, - content: Union['DocumentArray', Iterable['Document']], + content: Union["DocumentArray", Iterable["Document"]], *, batch_size: Optional[int] = None, show_progress: bool = False, parameters: Optional[dict] = None, - on_done: Optional['CallbackFnType'] = None, - on_error: Optional['CallbackFnType'] = None, - on_always: Optional['CallbackFnType'] = None, + on_done: Optional["CallbackFnType"] = None, + on_error: Optional["CallbackFnType"] = None, + on_always: Optional["CallbackFnType"] = None, prefetch: int = 100, ): ... 
@@ -680,27 +680,27 @@ async def aindex(self, content, **kwargs): ) self._prepare_streaming( - not kwargs.get('show_progress'), - total=len(content) if hasattr(content, '__len__') else None, + not kwargs.get("show_progress"), + total=len(content) if hasattr(content, "__len__") else None, ) - on_done = kwargs.pop('on_done', None) - on_error = kwargs.pop('on_error', None) - on_always = kwargs.pop('on_always', None) - prefetch = kwargs.pop('prefetch', 100) + on_done = kwargs.pop("on_done", None) + on_error = kwargs.pop("on_error", None) + on_always = kwargs.pop("on_always", None) + prefetch = kwargs.pop("prefetch", 100) results = DocumentArray() if not on_done and not on_always else None if not on_done: on_done = partial( - self._gather_result, results=results, attribute='embedding' + self._gather_result, results=results, attribute="embedding" ) with self._pbar: - parameters = kwargs.pop('parameters', {}) - parameters['drop_image_content'] = parameters.get( - 'drop_image_content', True + parameters = kwargs.pop("parameters", {}) + parameters["drop_image_content"] = parameters.get( + "drop_image_content", True ) async for _ in self._async_client.post( - on='/index', + on="/index", **self._get_post_payload(content, results, **kwargs), on_done=on_done, on_error=on_error, @@ -721,11 +721,11 @@ def search( batch_size: Optional[int] = None, show_progress: bool = False, parameters: Optional[Dict] = None, - on_done: Optional['CallbackFnType'] = None, - on_error: Optional['CallbackFnType'] = None, - on_always: Optional['CallbackFnType'] = None, + on_done: Optional["CallbackFnType"] = None, + on_error: Optional["CallbackFnType"] = None, + on_always: Optional["CallbackFnType"] = None, prefetch: int = 100, - ) -> 'DocumentArray': + ) -> "DocumentArray": """Search for top k results for given query string or ``Document``. If the input is a string, will use this string as query. If the input is a ``Document``, @@ -750,17 +750,17 @@ def search( @overload def search( self, - content: Union['DocumentArray', Iterable['Document']], + content: Union["DocumentArray", Iterable["Document"]], *, limit: int = 10, batch_size: Optional[int] = None, show_progress: bool = False, parameters: Optional[dict] = None, - on_done: Optional['CallbackFnType'] = None, - on_error: Optional['CallbackFnType'] = None, - on_always: Optional['CallbackFnType'] = None, + on_done: Optional["CallbackFnType"] = None, + on_error: Optional["CallbackFnType"] = None, + on_always: Optional["CallbackFnType"] = None, prefetch: int = 100, - ) -> 'DocumentArray': + ) -> "DocumentArray": """Search for top k results for given query string or ``Document``. If the input is a string, will use this string as query. If the input is a ``Document``, @@ -782,33 +782,33 @@ def search( """ ... 
- def search(self, content, limit: int = 10, **kwargs) -> 'DocumentArray': + def search(self, content, limit: int = 10, **kwargs) -> "DocumentArray": if isinstance(content, str): raise TypeError( f'content must be an Iterable of [str, Document], try `.search(["{content}"])` instead' ) self._prepare_streaming( - not kwargs.get('show_progress'), - total=len(content) if hasattr(content, '__len__') else None, + not kwargs.get("show_progress"), + total=len(content) if hasattr(content, "__len__") else None, ) - on_done = kwargs.pop('on_done', None) - on_error = kwargs.pop('on_error', None) - on_always = kwargs.pop('on_always', None) - prefetch = kwargs.pop('prefetch', 100) + on_done = kwargs.pop("on_done", None) + on_error = kwargs.pop("on_error", None) + on_always = kwargs.pop("on_always", None) + prefetch = kwargs.pop("prefetch", 100) results = DocumentArray() if not on_done and not on_always else None if not on_done: - on_done = partial(self._gather_result, results=results, attribute='matches') + on_done = partial(self._gather_result, results=results, attribute="matches") with self._pbar: - parameters = kwargs.pop('parameters', {}) - parameters['limit'] = limit - parameters['drop_image_content'] = parameters.get( - 'drop_image_content', True + parameters = kwargs.pop("parameters", {}) + parameters["limit"] = limit + parameters["drop_image_content"] = parameters.get( + "drop_image_content", True ) self._client.post( - on='/search', + on="/search", **self._get_post_payload(content, results, **kwargs), on_done=on_done, on_error=on_error, @@ -828,9 +828,9 @@ async def asearch( batch_size: Optional[int] = None, show_progress: bool = False, parameters: Optional[Dict] = None, - on_done: Optional['CallbackFnType'] = None, - on_error: Optional['CallbackFnType'] = None, - on_always: Optional['CallbackFnType'] = None, + on_done: Optional["CallbackFnType"] = None, + on_error: Optional["CallbackFnType"] = None, + on_always: Optional["CallbackFnType"] = None, prefetch: int = 100, ): ... @@ -838,15 +838,15 @@ async def asearch( @overload async def asearch( self, - content: Union['DocumentArray', Iterable['Document']], + content: Union["DocumentArray", Iterable["Document"]], *, limit: int = 10, batch_size: Optional[int] = None, show_progress: bool = False, parameters: Optional[dict] = None, - on_done: Optional['CallbackFnType'] = None, - on_error: Optional['CallbackFnType'] = None, - on_always: Optional['CallbackFnType'] = None, + on_done: Optional["CallbackFnType"] = None, + on_error: Optional["CallbackFnType"] = None, + on_always: Optional["CallbackFnType"] = None, prefetch: int = 100, ): ... 
@@ -858,26 +858,26 @@ async def asearch(self, content, limit: int = 10, **kwargs): ) self._prepare_streaming( - not kwargs.get('show_progress'), - total=len(content) if hasattr(content, '__len__') else None, + not kwargs.get("show_progress"), + total=len(content) if hasattr(content, "__len__") else None, ) - on_done = kwargs.pop('on_done', None) - on_error = kwargs.pop('on_error', None) - on_always = kwargs.pop('on_always', None) - prefetch = kwargs.pop('prefetch', 100) + on_done = kwargs.pop("on_done", None) + on_error = kwargs.pop("on_error", None) + on_always = kwargs.pop("on_always", None) + prefetch = kwargs.pop("prefetch", 100) results = DocumentArray() if not on_done and not on_always else None if not on_done: - on_done = partial(self._gather_result, results=results, attribute='matches') + on_done = partial(self._gather_result, results=results, attribute="matches") with self._pbar: - parameters = kwargs.pop('parameters', {}) - parameters['limit'] = limit - parameters['drop_image_content'] = parameters.get( - 'drop_image_content', True + parameters = kwargs.pop("parameters", {}) + parameters["limit"] = limit + parameters["drop_image_content"] = parameters.get( + "drop_image_content", True ) async for _ in self._async_client.post( - on='/search', + on="/search", **self._get_post_payload(content, results, **kwargs), on_done=on_done, on_error=on_error, diff --git a/client/clip_client/helper.py b/client/clip_client/helper.py index b47cbc32d..4be353253 100644 --- a/client/clip_client/helper.py +++ b/client/clip_client/helper.py @@ -11,31 +11,30 @@ def _version_check(package: str = None, github_repo: str = None): try: - if not package: - package = vars(sys.modules[__name__])['__package__'] + package = vars(sys.modules[__name__])["__package__"] if not github_repo: github_repo = package cur_ver = Version(pkg_resources.get_distribution(package).version) req = Request( - f'https://pypi.python.org/pypi/{package}/json', - headers={'User-Agent': 'Mozilla/5.0'}, + f"https://pypi.python.org/pypi/{package}/json", + headers={"User-Agent": "Mozilla/5.0"}, ) with urlopen( req, timeout=1 ) as resp: # 'with' is important to close the resource after use j = json.load(resp) - releases = j.get('releases', {}) + releases = j.get("releases", {}) latest_release_ver = max( - Version(v) for v in releases.keys() if '.dev' not in v + Version(v) for v in releases.keys() if ".dev" not in v ) if cur_ver < latest_release_ver: print( Panel( - f'You are using [b]{package} {cur_ver}[/b], but [bold green]{latest_release_ver}[/] is available. ' - f'You may upgrade it via [b]pip install -U {package}[/b]. [link=https://github.com/jina-ai/{github_repo}/releases]Read Changelog here[/link].', - title=':new: New version available!', + f"You are using [b]{package} {cur_ver}[/b], but [bold green]{latest_release_ver}[/] is available. " + f"You may upgrade it via [b]pip install -U {package}[/b]. 
[link=https://github.com/jina-ai/{github_repo}/releases]Read Changelog here[/link].", + title=":new: New version available!", width=50, ) ) diff --git a/client/setup.py b/client/setup.py index 0a248fec3..0f16ce6fa 100644 --- a/client/setup.py +++ b/client/setup.py @@ -5,89 +5,89 @@ from setuptools import setup if sys.version_info < (3, 7, 0): - raise OSError(f'CLIP-as-service requires Python >=3.7, but yours is {sys.version}') + raise OSError(f"CLIP-as-service requires Python >=3.7, but yours is {sys.version}") try: - pkg_name = 'clip-client' + pkg_name = "clip-client" libinfo_py = path.join( - path.dirname(__file__), pkg_name.replace('-', '_'), '__init__.py' + path.dirname(__file__), pkg_name.replace("-", "_"), "__init__.py" ) - libinfo_content = open(libinfo_py, 'r', encoding='utf8').readlines() - version_line = [l.strip() for l in libinfo_content if l.startswith('__version__')][ + libinfo_content = open(libinfo_py, "r", encoding="utf8").readlines() + version_line = [l.strip() for l in libinfo_content if l.startswith("__version__")][ 0 ] exec(version_line) # gives __version__ except FileNotFoundError as ex: - __version__ = '0.0.0' + __version__ = "0.0.0" try: - with open('../README.md', encoding='utf8') as fp: + with open("../README.md", encoding="utf8") as fp: _long_description = fp.read() except FileNotFoundError: - _long_description = '' + _long_description = "" setup( name=pkg_name, packages=find_packages(), version=__version__, include_package_data=True, - description='Embed images and sentences into fixed-length vectors via CLIP', - author='Jina AI', - author_email='hello@jina.ai', - license='Apache 2.0', - url='https://github.com/jina-ai/clip-as-service', - download_url='https://github.com/jina-ai/clip-as-service/tags', + description="Embed images and sentences into fixed-length vectors via CLIP", + author="Jina AI", + author_email="hello@jina.ai", + license="Apache 2.0", + url="https://github.com/jina-ai/clip-as-service", + download_url="https://github.com/jina-ai/clip-as-service/tags", long_description=_long_description, - long_description_content_type='text/markdown', + long_description_content_type="text/markdown", zip_safe=False, - setup_requires=['setuptools>=18.0', 'wheel'], + setup_requires=["setuptools>=18.0", "wheel"], install_requires=[ - 'jina>=3.12.0', - 'docarray[common]>=0.19.0,<0.30.0', - 'packaging', + "jina>=3.12.0", + "docarray[common]>=0.19.0,<0.30.0", + "packaging", ], extras_require={ - 'test': [ - 'pytest', - 'pytest-timeout', - 'pytest-mock', - 'pytest-asyncio', - 'pytest-cov', - 'pytest-repeat', - 'pytest-reraise', - 'mock', - 'pytest-custom_exit_code', - 'black', + "test": [ + "pytest", + "pytest-timeout", + "pytest-mock", + "pytest-asyncio", + "pytest-cov", + "pytest-repeat", + "pytest-reraise", + "mock", + "pytest-custom_exit_code", + "black", ], }, classifiers=[ - 'Development Status :: 5 - Production/Stable', - 'Intended Audience :: Developers', - 'Intended Audience :: Education', - 'Intended Audience :: Science/Research', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Unix Shell', - 'Environment :: Console', - 'License :: OSI Approved :: Apache Software License', - 'Operating System :: OS Independent', - 'Topic :: Database :: Database Engines/Servers', - 'Topic :: Scientific/Engineering :: Artificial Intelligence', - 'Topic :: Internet :: WWW/HTTP :: Indexing/Search', - 'Topic :: 
Scientific/Engineering :: Image Recognition', - 'Topic :: Multimedia :: Video', - 'Topic :: Scientific/Engineering', - 'Topic :: Scientific/Engineering :: Mathematics', - 'Topic :: Software Development', - 'Topic :: Software Development :: Libraries', - 'Topic :: Software Development :: Libraries :: Python Modules', + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Unix Shell", + "Environment :: Console", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Topic :: Database :: Database Engines/Servers", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Internet :: WWW/HTTP :: Indexing/Search", + "Topic :: Scientific/Engineering :: Image Recognition", + "Topic :: Multimedia :: Video", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Mathematics", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", ], project_urls={ - 'Documentation': 'https://clip-as-service.jina.ai', - 'Source': 'https://github.com/jina-ai/clip-as-service/', - 'Tracker': 'https://github.com/jina-ai/clip-as-service/issues', + "Documentation": "https://clip-as-service.jina.ai", + "Source": "https://github.com/jina-ai/clip-as-service/", + "Tracker": "https://github.com/jina-ai/clip-as-service/issues", }, - keywords='jina openai clip deep-learning cross-modal multi-modal neural-search', + keywords="jina openai clip deep-learning cross-modal multi-modal neural-search", ) diff --git a/docs/conf.py b/docs/conf.py index cec066bf4..1ae3742ea 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -3,54 +3,54 @@ import sys from os import path -sys.path.insert(0, path.abspath('..')) +sys.path.insert(0, path.abspath("..")) -project = 'CLIP-as-service' -slug = re.sub(r'\W+', '-', project.lower()) -author = 'Jina AI' -copyright = 'Jina AI Limited. All rights reserved.' -source_suffix = ['.rst', '.md'] -master_doc = 'index' -language = 'en' -repo_dir = '../' +project = "CLIP-as-service" +slug = re.sub(r"\W+", "-", project.lower()) +author = "Jina AI" +copyright = "Jina AI Limited. All rights reserved." 
+source_suffix = [".rst", ".md"] +master_doc = "index" +language = "en" +repo_dir = "../" try: - if 'CAS_VERSION' not in os.environ: - libinfo_py = path.join(repo_dir, 'client/clip_client', '__init__.py') - libinfo_content = open(libinfo_py, 'r').readlines() + if "CAS_VERSION" not in os.environ: + libinfo_py = path.join(repo_dir, "client/clip_client", "__init__.py") + libinfo_content = open(libinfo_py, "r").readlines() version_line = [ - l.strip() for l in libinfo_content if l.startswith('__version__') + l.strip() for l in libinfo_content if l.startswith("__version__") ][0] exec(version_line) else: - __version__ = os.environ['CAS_VERSION'] + __version__ = os.environ["CAS_VERSION"] except FileNotFoundError: - __version__ = '0.0.0' + __version__ = "0.0.0" version = __version__ release = __version__ -templates_path = ['_templates'] +templates_path = ["_templates"] exclude_patterns = [ - '_build', - 'Thumbs.db', - '.DS_Store', - 'tests', - 'page_templates', - '.github', + "_build", + "Thumbs.db", + ".DS_Store", + "tests", + "page_templates", + ".github", ] -pygments_style = 'rainbow_dash' -html_theme = 'furo' +pygments_style = "rainbow_dash" +html_theme = "furo" -base_url = '/' -html_baseurl = 'https://clip-as-service.jina.ai' -sitemap_url_scheme = '{link}' +base_url = "/" +html_baseurl = "https://clip-as-service.jina.ai" +sitemap_url_scheme = "{link}" sitemap_locales = [None] sitemap_filename = "sitemap.xml" html_theme_options = { - 'light_logo': 'logo-light.svg', - 'dark_logo': 'logo-dark.svg', + "light_logo": "logo-light.svg", + "dark_logo": "logo-dark.svg", "sidebar_hide_name": True, "light_css_variables": { "color-brand-primary": "#009191", @@ -66,101 +66,104 @@ # end-announce } -html_static_path = ['_static'] -html_extra_path = ['html_extra'] +html_static_path = ["_static"] +html_extra_path = ["html_extra"] html_css_files = [ - 'main.css', - 'https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0-beta2/css/all.min.css', + "main.css", + "https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0-beta2/css/all.min.css", ] html_js_files = [ - 'https://cdn.jsdelivr.net/npm/vue@2/dist/vue.min.js', + "https://cdn.jsdelivr.net/npm/vue@2/dist/vue.min.js", ] htmlhelp_basename = slug html_show_sourcelink = False -html_favicon = '_static/favicon.png' +html_favicon = "_static/favicon.png" -intersphinx_mapping = {'docarray': ('https://docarray.jina.ai/', None), 'finetuner': ('https://finetuner.jina.ai/', None)} +intersphinx_mapping = { + "docarray": ("https://docarray.jina.ai/", None), + "finetuner": ("https://finetuner.jina.ai/", None), +} -latex_documents = [(master_doc, f'{slug}.tex', project, author, 'manual')] +latex_documents = [(master_doc, f"{slug}.tex", project, author, "manual")] man_pages = [(master_doc, slug, project, [author], 1)] texinfo_documents = [ - (master_doc, slug, project, author, slug, project, 'Miscellaneous') + (master_doc, slug, project, author, slug, project, "Miscellaneous") ] epub_title = project -epub_exclude_files = ['search.html'] +epub_exclude_files = ["search.html"] # -- Extension configuration ------------------------------------------------- extensions = [ - 'sphinx.ext.autodoc', - 'sphinx_autodoc_typehints', - 'sphinx.ext.viewcode', - 'sphinx.ext.coverage', - 'sphinxcontrib.apidoc', - 'sphinxarg.ext', - 'sphinx_copybutton', - 'sphinx_sitemap', - 'sphinx.ext.intersphinx', - 'sphinxext.opengraph', - 'notfound.extension', - 'myst_parser', - 'sphinx_design', - 'sphinx_inline_tabs', + "sphinx.ext.autodoc", + "sphinx_autodoc_typehints", + "sphinx.ext.viewcode", + 
"sphinx.ext.coverage", + "sphinxcontrib.apidoc", + "sphinxarg.ext", + "sphinx_copybutton", + "sphinx_sitemap", + "sphinx.ext.intersphinx", + "sphinxext.opengraph", + "notfound.extension", + "myst_parser", + "sphinx_design", + "sphinx_inline_tabs", ] -myst_enable_extensions = ['colon_fence', 'substitution', 'deflist'] +myst_enable_extensions = ["colon_fence", "substitution", "deflist"] # -- Custom 404 page # sphinx-notfound-page # https://github.com/readthedocs/sphinx-notfound-page notfound_context = { - 'title': 'Page Not Found', - 'body': ''' + "title": "Page Not Found", + "body": """

 <h1>Page Not Found</h1>
 
 <p>Oops, we couldn't find that page.</p>
 
 <p>You can try "asking our docs" on the right corner of the page to find answer.</p>
 
 <p>Otherwise, <a href="https://github.com/jina-ai/clip-as-service/issues">please create a Github issue</a> and one of our team will respond.</p>

-''', +""", } notfound_no_urls_prefix = True -apidoc_module_dir = '../client' -apidoc_output_dir = 'api' -apidoc_excluded_paths = ['tests', 'legacy', 'hub', 'toy*', 'setup.py'] +apidoc_module_dir = "../client" +apidoc_output_dir = "api" +apidoc_excluded_paths = ["tests", "legacy", "hub", "toy*", "setup.py"] apidoc_separate_modules = True -apidoc_extra_args = ['-t', 'template/'] -autodoc_member_order = 'bysource' -autodoc_mock_imports = ['argparse', 'numpy', 'np', 'tensorflow', 'torch', 'scipy'] -autoclass_content = 'both' +apidoc_extra_args = ["-t", "template/"] +autodoc_member_order = "bysource" +autodoc_mock_imports = ["argparse", "numpy", "np", "tensorflow", "torch", "scipy"] +autoclass_content = "both" set_type_checking_flag = False -html_last_updated_fmt = '' +html_last_updated_fmt = "" nitpicky = True -nitpick_ignore = [('py:class', 'type')] +nitpick_ignore = [("py:class", "type")] linkcheck_ignore = [ # Avoid link check on local uri - 'http://0.0.0.0:*', - 'pods/encode.yml', - 'https://github.com/jina-ai/clip-as-service/commit/*', - '.github/*', - 'extra-requirements.txt', - 'fastentrypoints.py' '../../101', - '../../102', - 'http://www.twinsun.com/tz/tz-link.htm', # Broken link from pytz library - 'https://urllib3.readthedocs.io/en/latest/contrib.html#google-app-engine', # Broken link from urllib3 library - 'https://linuxize.com/post/how-to-add-swap-space-on-ubuntu-20-04/', + "http://0.0.0.0:*", + "pods/encode.yml", + "https://github.com/jina-ai/clip-as-service/commit/*", + ".github/*", + "extra-requirements.txt", + "fastentrypoints.py" "../../101", + "../../102", + "http://www.twinsun.com/tz/tz-link.htm", # Broken link from pytz library + "https://urllib3.readthedocs.io/en/latest/contrib.html#google-app-engine", # Broken link from urllib3 library + "https://linuxize.com/post/how-to-add-swap-space-on-ubuntu-20-04/", # This link works but gets 403 error on linkcheck ] linkcheck_timeout = 20 linkcheck_retries = 2 linkcheck_anchors = False -ogp_site_url = 'https://clip-as-service.jina.ai/' -ogp_image = 'https://clip-as-service.jina.ai/_static/banner.png' +ogp_site_url = "https://clip-as-service.jina.ai/" +ogp_image = "https://clip-as-service.jina.ai/_static/banner.png" ogp_use_first_image = True ogp_description_length = 300 -ogp_type = 'website' +ogp_type = "website" ogp_site_name = f'CLIP-as-service {os.environ.get("SPHINX_MULTIVERSION_VERSION", version)} Documentation' ogp_custom_meta_tags = [ @@ -169,7 +172,7 @@ '', '', '', - ''' + """ - ''', + """, ] def add_server_address(app): # This makes variable `server_address` available to docbot.js - server_address = app.config['server_address'] + server_address = app.config["server_address"] js_text = "var server_address = '%s';" % server_address app.add_js_file(None, body=js_text) @@ -198,23 +201,23 @@ def setup(app): from sphinx.locale import _ app.add_object_type( - 'confval', - 'confval', - objname='configuration value', - indextemplate='pair: %s; configuration value', + "confval", + "confval", + objname="configuration value", + indextemplate="pair: %s; configuration value", doc_field_types=[ PyField( - 'type', - label=_('Type'), + "type", + label=_("Type"), has_arg=False, - names=('type',), - bodyrolename='class', + names=("type",), + bodyrolename="class", ), Field( - 'default', - label=_('Default'), + "default", + label=_("Default"), has_arg=False, - names=('default',), + names=("default",), ), ], ) diff --git a/scripts/benchmark.py b/scripts/benchmark.py index 2ecd702bd..88beb4591 100644 --- a/scripts/benchmark.py +++ 
b/scripts/benchmark.py @@ -24,7 +24,7 @@ def __init__( self, server: str, batch_size: int = 1, - modality: str = 'text', + modality: str = "text", num_iter: Optional[int] = 100, image_sample: str = None, **kwargs, @@ -35,7 +35,7 @@ def __init__( @param num_iter: number of repeat run per experiment @param image_sample: uri of the test image """ - assert num_iter > 2, 'num_iter must be greater than 2' + assert num_iter > 2, "num_iter must be greater than 2" super().__init__() self.server = server self.batch_size = batch_size @@ -49,25 +49,25 @@ def run(self): from clip_client import Client except ImportError: raise ImportError( - 'clip_client module is not available. it is required for benchmarking.' + "clip_client module is not available. it is required for benchmarking." 'Please use ""pip install clip-client" to install it.' ) - if self.modality == 'text': + if self.modality == "text": from clip_server.model.simple_tokenizer import SimpleTokenizer tokenizer = SimpleTokenizer() vocab = list(tokenizer.encoder.keys()) batch = DocumentArray( [ - Document(text=' '.join(random.choices(vocab, k=78))) + Document(text=" ".join(random.choices(vocab, k=78))) for _ in range(self.batch_size) ] ) - elif self.modality == 'image': + elif self.modality == "image": batch = DocumentArray( [ - Document(blob=open(self.image_sample, 'rb').read()) + Document(blob=open(self.image_sample, "rb").read()) for _ in range(self.batch_size) ] ) @@ -84,26 +84,26 @@ def run(self): self.avg_time = np.mean(time_costs[2:]) -@click.command(name='clip-as-service benchmark') -@click.argument('server') +@click.command(name="clip-as-service benchmark") +@click.argument("server") @click.option( - '--batch_sizes', + "--batch_sizes", multiple=True, type=int, default=[1, 8, 16, 32, 64], - help='number of batch', + help="number of batch", ) @click.option( - '--num_iter', default=10, help='number of repeat run per experiment (must > 2)' + "--num_iter", default=10, help="number of repeat run per experiment (must > 2)" ) @click.option( "--concurrent_clients", multiple=True, type=int, default=[1, 4, 16, 32, 64], - help='number of concurrent clients per experiment', + help="number of concurrent clients per experiment", ) -@click.option("--image_sample", help='path to the image sample file') +@click.option("--image_sample", help="path to the image sample file") def main(server, batch_sizes, num_iter, concurrent_clients, image_sample): # wait until the server is ready for batch_size in batch_sizes: @@ -113,7 +113,7 @@ def main(server, batch_sizes, num_iter, concurrent_clients, image_sample): server, batch_size=batch_size, num_iter=num_iter, - modality='image' if (image_sample is not None) else 'text', + modality="image" if (image_sample is not None) else "text", image_sample=image_sample, ) for _ in range(num_client) @@ -134,11 +134,11 @@ def main(server, batch_sizes, num_iter, concurrent_clients, image_sample): ) print( - '(concurrent client=%d, batch_size=%d) avg speed: %.3f\tmax speed: %.3f\tmin speed: %.3f' + "(concurrent client=%d, batch_size=%d) avg speed: %.3f\tmax speed: %.3f\tmin speed: %.3f" % (num_client, batch_size, avg_speed, max_speed, min_speed), flush=True, ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/scripts/get-last-release-note.py b/scripts/get-last-release-note.py index 75ed1268e..c9f2173d4 100644 --- a/scripts/get-last-release-note.py +++ b/scripts/get-last-release-note.py @@ -2,12 +2,12 @@ # python scripts/get-last-release-note.py ## result in root/tmp.md -with open('CHANGELOG.md') as fp: 
+with open("CHANGELOG.md") as fp: n = [] for v in fp: - if v.startswith('## Release Note'): + if v.startswith("## Release Note"): n.clear() n.append(v) -with open('tmp.md', 'w') as fp: +with open("tmp.md", "w") as fp: fp.writelines(n) diff --git a/scripts/get-requirements.py b/scripts/get-requirements.py index c17066694..dae6b74d4 100644 --- a/scripts/get-requirements.py +++ b/scripts/get-requirements.py @@ -6,7 +6,7 @@ result = run_setup("./server/setup.py", stop_after="init") -with open(sys.argv[2], 'w') as fp: - fp.write('\n'.join(result.install_requires) + '\n') +with open(sys.argv[2], "w") as fp: + fp.write("\n".join(result.install_requires) + "\n") if sys.argv[1]: - fp.write('\n'.join(result.extras_require[sys.argv[1]]) + '\n') + fp.write("\n".join(result.extras_require[sys.argv[1]]) + "\n") diff --git a/scripts/setup.py b/scripts/setup.py index 416cf7b64..e6d4ffb93 100644 --- a/scripts/setup.py +++ b/scripts/setup.py @@ -5,68 +5,68 @@ from setuptools import setup if sys.version_info < (3, 7, 0): - raise OSError(f'Clip-as-service requires Python >=3.7, but yours is {sys.version}') + raise OSError(f"Clip-as-service requires Python >=3.7, but yours is {sys.version}") try: - pkg_name = 'clip-as-service' - libinfo_py = path.join('server/clip_server/__init__.py') - libinfo_content = open(libinfo_py, 'r', encoding='utf8').readlines() - version_line = [l.strip() for l in libinfo_content if l.startswith('__version__')][ + pkg_name = "clip-as-service" + libinfo_py = path.join("server/clip_server/__init__.py") + libinfo_content = open(libinfo_py, "r", encoding="utf8").readlines() + version_line = [l.strip() for l in libinfo_content if l.startswith("__version__")][ 0 ] exec(version_line) # gives __version__ except FileNotFoundError: - __version__ = '0.0.0' + __version__ = "0.0.0" try: - with open('README.md', encoding='utf8') as fp: + with open("README.md", encoding="utf8") as fp: _long_description = fp.read() except FileNotFoundError: - _long_description = '' + _long_description = "" setup( name=pkg_name, packages=find_packages(), version=__version__, include_package_data=True, - description='Embed images and sentences into fixed-length vectors via CLIP', - author='Jina AI', - author_email='hello@jina.ai', - license='Apache 2.0', - url='https://github.com/jina-ai/clip-as-service', - download_url='https://github.com/jina-ai/clip-as-service/tags', + description="Embed images and sentences into fixed-length vectors via CLIP", + author="Jina AI", + author_email="hello@jina.ai", + license="Apache 2.0", + url="https://github.com/jina-ai/clip-as-service", + download_url="https://github.com/jina-ai/clip-as-service/tags", long_description=_long_description, - long_description_content_type='text/markdown', + long_description_content_type="text/markdown", zip_safe=False, - setup_requires=['setuptools>=18.0', 'wheel'], - install_requires=['clip-server', 'clip-client'], + setup_requires=["setuptools>=18.0", "wheel"], + install_requires=["clip-server", "clip-client"], classifiers=[ - 'Development Status :: 5 - Production/Stable', - 'Intended Audience :: Developers', - 'Intended Audience :: Education', - 'Intended Audience :: Science/Research', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Unix Shell', - 'Environment :: Console', - 'License :: OSI Approved :: Apache Software License', - 'Operating System :: OS Independent', - 'Topic :: Database :: Database Engines/Servers', - 'Topic :: 
Scientific/Engineering :: Artificial Intelligence', - 'Topic :: Internet :: WWW/HTTP :: Indexing/Search', - 'Topic :: Scientific/Engineering :: Image Recognition', - 'Topic :: Multimedia :: Video', - 'Topic :: Scientific/Engineering', - 'Topic :: Scientific/Engineering :: Mathematics', - 'Topic :: Software Development', - 'Topic :: Software Development :: Libraries', - 'Topic :: Software Development :: Libraries :: Python Modules', + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Unix Shell", + "Environment :: Console", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Topic :: Database :: Database Engines/Servers", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Internet :: WWW/HTTP :: Indexing/Search", + "Topic :: Scientific/Engineering :: Image Recognition", + "Topic :: Multimedia :: Video", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Mathematics", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", ], project_urls={ - 'Documentation': 'https://clip-as-service.jina.ai/', - 'Source': 'https://github.com/jina-ai/clip-as-service', - 'Tracker': 'https://github.com/jina-ai/clip-as-service/issues', + "Documentation": "https://clip-as-service.jina.ai/", + "Source": "https://github.com/jina-ai/clip-as-service", + "Tracker": "https://github.com/jina-ai/clip-as-service/issues", }, - keywords='jina openai clip deep-learning cross-modal multi-modal neural-search', + keywords="jina openai clip deep-learning cross-modal multi-modal neural-search", ) diff --git a/server/clip_server/__init__.py b/server/clip_server/__init__.py index 21320a81c..fa3ddd8c5 100644 --- a/server/clip_server/__init__.py +++ b/server/clip_server/__init__.py @@ -1 +1 @@ -__version__ = '0.8.4' +__version__ = "0.8.4" diff --git a/server/clip_server/__main__.py b/server/clip_server/__main__.py index e844ee1fb..0a7efd398 100644 --- a/server/clip_server/__main__.py +++ b/server/clip_server/__main__.py @@ -2,21 +2,21 @@ import os import sys -if __name__ == '__main__': - if 'NO_VERSION_CHECK' not in os.environ: +if __name__ == "__main__": + if "NO_VERSION_CHECK" not in os.environ: from clip_server.helper import is_latest_version - is_latest_version(github_repo='clip-as-service') + is_latest_version(github_repo="clip-as-service") from jina import Flow if len(sys.argv) > 1: - if sys.argv[1] == '-i': + if sys.argv[1] == "-i": _input = sys.stdin.read() else: _input = sys.argv[1] else: - _input = 'torch-flow.yml' + _input = "torch-flow.yml" f = Flow.load_config( _input, diff --git a/server/clip_server/executors/clip_onnx.py b/server/clip_server/executors/clip_onnx.py index 204eb7648..507f87fb0 100644 --- a/server/clip_server/executors/clip_onnx.py +++ b/server/clip_server/executors/clip_onnx.py @@ -21,11 +21,11 @@ class CLIPEncoder(Executor): def __init__( self, - name: str = 'ViT-B-32::openai', + name: str = "ViT-B-32::openai", device: Optional[str] = None, num_worker_preprocess: int = 4, minibatch_size: int = 32, - access_paths: str = '@r', + access_paths: str = "@r", model_path: Optional[str] = None, dtype: Optional[str] = None, **kwargs, @@ -48,19 +48,19 @@ def 
__init__( import torch if not device: - device = 'cuda' if torch.cuda.is_available() else 'cpu' + device = "cuda" if torch.cuda.is_available() else "cpu" self._device = device if not dtype: - dtype = 'fp32' if self._device in ('cpu', torch.device('cpu')) else 'fp16' + dtype = "fp32" if self._device in ("cpu", torch.device("cpu")) else "fp16" self._dtype = dtype self._minibatch_size = minibatch_size self._access_paths = access_paths - if 'traversal_paths' in kwargs: + if "traversal_paths" in kwargs: warnings.warn( - f'`traversal_paths` is deprecated. Use `access_paths` instead.' + f"`traversal_paths` is deprecated. Use `access_paths` instead." ) - self._access_paths = kwargs['traversal_paths'] + self._access_paths = kwargs["traversal_paths"] self._num_worker_preprocess = num_worker_preprocess self._pool = ThreadPool(processes=num_worker_preprocess) @@ -71,11 +71,11 @@ def __init__( self._image_transform = clip._transform_blob(self._model.image_size) # define the priority order for the execution providers - providers = ['CPUExecutionProvider'] + providers = ["CPUExecutionProvider"] # prefer CUDA Execution Provider over CPU Execution Provider - if self._device.startswith('cuda'): - providers.insert(0, 'CUDAExecutionProvider') + if self._device.startswith("cuda"): + providers.insert(0, "CUDAExecutionProvider") sess_options = ort.SessionOptions() @@ -84,16 +84,16 @@ def __init__( ort.GraphOptimizationLevel.ORT_ENABLE_ALL ) - if not self._device.startswith('cuda') and ( - 'OMP_NUM_THREADS' not in os.environ - and hasattr(self.runtime_args, 'replicas') + if not self._device.startswith("cuda") and ( + "OMP_NUM_THREADS" not in os.environ + and hasattr(self.runtime_args, "replicas") ): - replicas = getattr(self.runtime_args, 'replicas', 1) + replicas = getattr(self.runtime_args, "replicas", 1) num_threads = max(1, torch.get_num_threads() * 2 // replicas) if num_threads < 2: warnings.warn( - f'Too many replicas ({replicas}) vs too few threads {num_threads} may result in ' - f'sub-optimal performance.' + f"Too many replicas ({replicas}) vs too few threads {num_threads} may result in " + f"sub-optimal performance." 
) # Run the operators in the graph in parallel (not support the CUDA Execution Provider) @@ -110,12 +110,12 @@ def __init__( if not self.tracer: self.tracer = NoOpTracer() - def _preproc_images(self, docs: 'DocumentArray', drop_image_content: bool): + def _preproc_images(self, docs: "DocumentArray", drop_image_content: bool): with self.monitor( - name='preprocess_images_seconds', - documentation='images preprocess time in seconds', + name="preprocess_images_seconds", + documentation="images preprocess time in seconds", ): - with self.tracer.start_as_current_span('preprocess_images'): + with self.tracer.start_as_current_span("preprocess_images"): return preproc_image( docs, preprocess_fn=self._image_transform, @@ -124,56 +124,56 @@ def _preproc_images(self, docs: 'DocumentArray', drop_image_content: bool): dtype=self._dtype, ) - def _preproc_texts(self, docs: 'DocumentArray'): + def _preproc_texts(self, docs: "DocumentArray"): with self.monitor( - name='preprocess_texts_seconds', - documentation='texts preprocess time in seconds', + name="preprocess_texts_seconds", + documentation="texts preprocess time in seconds", ): - with self.tracer.start_as_current_span('preprocess_images'): + with self.tracer.start_as_current_span("preprocess_images"): return preproc_text(docs, tokenizer=self._tokenizer, return_np=True) - @requests(on='/rank') - async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs): - _drop_image_content = parameters.get('drop_image_content', False) - await self.encode(docs['@r,m'], drop_image_content=_drop_image_content) + @requests(on="/rank") + async def rank(self, docs: "DocumentArray", parameters: Dict, **kwargs): + _drop_image_content = parameters.get("drop_image_content", False) + await self.encode(docs["@r,m"], drop_image_content=_drop_image_content) set_rank(docs) @requests async def encode( self, - docs: 'DocumentArray', + docs: "DocumentArray", tracing_context=None, parameters: Dict = {}, **kwargs, ): with self.tracer.start_as_current_span( - 'encode', context=tracing_context + "encode", context=tracing_context ) as span: - span.set_attribute('device', self._device) - span.set_attribute('runtime', 'onnx') - access_paths = parameters.get('access_paths', self._access_paths) - if 'traversal_paths' in parameters: + span.set_attribute("device", self._device) + span.set_attribute("runtime", "onnx") + access_paths = parameters.get("access_paths", self._access_paths) + if "traversal_paths" in parameters: warnings.warn( - f'`traversal_paths` is deprecated. Use `access_paths` instead.' + f"`traversal_paths` is deprecated. Use `access_paths` instead." 
) - access_paths = parameters['traversal_paths'] - _drop_image_content = parameters.get('drop_image_content', False) + access_paths = parameters["traversal_paths"] + _drop_image_content = parameters.get("drop_image_content", False) _img_da = DocumentArray() _txt_da = DocumentArray() for d in docs[access_paths]: split_img_txt_da(d, _img_da, _txt_da) - with self.tracer.start_as_current_span('inference') as inference_span: - inference_span.set_attribute('drop_image_content', _drop_image_content) - inference_span.set_attribute('minibatch_size', self._minibatch_size) - inference_span.set_attribute('has_img_da', True if _img_da else False) - inference_span.set_attribute('has_txt_da', True if _txt_da else False) + with self.tracer.start_as_current_span("inference") as inference_span: + inference_span.set_attribute("drop_image_content", _drop_image_content) + inference_span.set_attribute("minibatch_size", self._minibatch_size) + inference_span.set_attribute("has_img_da", True if _img_da else False) + inference_span.set_attribute("has_txt_da", True if _txt_da else False) # for image if _img_da: with self.tracer.start_as_current_span( - 'img_minibatch_encoding' + "img_minibatch_encoding" ) as img_encode_span: for minibatch, batch_data in _img_da.map_batch( partial( @@ -184,8 +184,8 @@ async def encode( pool=self._pool, ): with self.monitor( - name='encode_images_seconds', - documentation='images encode time in seconds', + name="encode_images_seconds", + documentation="images encode time in seconds", ): minibatch.embeddings = self._model.encode_image( batch_data @@ -194,7 +194,7 @@ async def encode( # for text if _txt_da: with self.tracer.start_as_current_span( - 'txt_minibatch_encoding' + "txt_minibatch_encoding" ) as txt_encode_span: for minibatch, batch_data in _txt_da.map_batch( self._preproc_texts, @@ -202,8 +202,8 @@ async def encode( pool=self._pool, ): with self.monitor( - name='encode_texts_seconds', - documentation='texts encode time in seconds', + name="encode_texts_seconds", + documentation="texts encode time in seconds", ): minibatch.embeddings = self._model.encode_text( batch_data diff --git a/server/clip_server/executors/clip_tensorrt.py b/server/clip_server/executors/clip_tensorrt.py index 28f4ddb9c..0cd801c53 100644 --- a/server/clip_server/executors/clip_tensorrt.py +++ b/server/clip_server/executors/clip_tensorrt.py @@ -20,11 +20,11 @@ class CLIPEncoder(Executor): def __init__( self, - name: str = 'ViT-B-32::openai', - device: str = 'cuda', + name: str = "ViT-B-32::openai", + device: str = "cuda", num_worker_preprocess: int = 4, minibatch_size: int = 32, - access_paths: str = '@r', + access_paths: str = "@r", **kwargs, ): """ @@ -44,19 +44,19 @@ def __init__( self._minibatch_size = minibatch_size self._access_paths = access_paths - if 'traversal_paths' in kwargs: + if "traversal_paths" in kwargs: warnings.warn( - f'`traversal_paths` is deprecated. Use `access_paths` instead.' + f"`traversal_paths` is deprecated. Use `access_paths` instead." 
) - self._access_paths = kwargs['traversal_paths'] + self._access_paths = kwargs["traversal_paths"] self._device = device import torch - assert self._device.startswith('cuda'), ( - f'can not perform inference on {self._device}' - f' with Nvidia TensorRT as backend' + assert self._device.startswith("cuda"), ( + f"can not perform inference on {self._device}" + f" with Nvidia TensorRT as backend" ) assert ( @@ -73,12 +73,12 @@ def __init__( if not self.tracer: self.tracer = NoOpTracer() - def _preproc_images(self, docs: 'DocumentArray', drop_image_content: bool): + def _preproc_images(self, docs: "DocumentArray", drop_image_content: bool): with self.monitor( - name='preprocess_images_seconds', - documentation='images preprocess time in seconds', + name="preprocess_images_seconds", + documentation="images preprocess time in seconds", ): - with self.tracer.start_as_current_span('preprocess_images'): + with self.tracer.start_as_current_span("preprocess_images"): return preproc_image( docs, preprocess_fn=self._image_transform, @@ -87,12 +87,12 @@ def _preproc_images(self, docs: 'DocumentArray', drop_image_content: bool): drop_image_content=drop_image_content, ) - def _preproc_texts(self, docs: 'DocumentArray'): + def _preproc_texts(self, docs: "DocumentArray"): with self.monitor( - name='preprocess_texts_seconds', - documentation='texts preprocess time in seconds', + name="preprocess_texts_seconds", + documentation="texts preprocess time in seconds", ): - with self.tracer.start_as_current_span('preprocess_images'): + with self.tracer.start_as_current_span("preprocess_images"): return preproc_text( docs, tokenizer=self._tokenizer, @@ -100,48 +100,48 @@ def _preproc_texts(self, docs: 'DocumentArray'): return_np=False, ) - @requests(on='/rank') - async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs): - _drop_image_content = parameters.get('drop_image_content', False) - await self.encode(docs['@r,m'], drop_image_content=_drop_image_content) + @requests(on="/rank") + async def rank(self, docs: "DocumentArray", parameters: Dict, **kwargs): + _drop_image_content = parameters.get("drop_image_content", False) + await self.encode(docs["@r,m"], drop_image_content=_drop_image_content) set_rank(docs) @requests async def encode( self, - docs: 'DocumentArray', + docs: "DocumentArray", tracing_context=None, parameters: Dict = {}, **kwargs, ): with self.tracer.start_as_current_span( - 'encode', context=tracing_context + "encode", context=tracing_context ) as span: - span.set_attribute('device', self._device) - span.set_attribute('runtime', 'tensorrt') - access_paths = parameters.get('access_paths', self._access_paths) - if 'traversal_paths' in parameters: + span.set_attribute("device", self._device) + span.set_attribute("runtime", "tensorrt") + access_paths = parameters.get("access_paths", self._access_paths) + if "traversal_paths" in parameters: warnings.warn( - f'`traversal_paths` is deprecated. Use `access_paths` instead.' + f"`traversal_paths` is deprecated. Use `access_paths` instead." 
) - access_paths = parameters['traversal_paths'] - _drop_image_content = parameters.get('drop_image_content', False) + access_paths = parameters["traversal_paths"] + _drop_image_content = parameters.get("drop_image_content", False) _img_da = DocumentArray() _txt_da = DocumentArray() for d in docs[access_paths]: split_img_txt_da(d, _img_da, _txt_da) - with self.tracer.start_as_current_span('inference') as inference_span: - inference_span.set_attribute('drop_image_content', _drop_image_content) - inference_span.set_attribute('minibatch_size', self._minibatch_size) - inference_span.set_attribute('has_img_da', True if _img_da else False) - inference_span.set_attribute('has_txt_da', True if _txt_da else False) + with self.tracer.start_as_current_span("inference") as inference_span: + inference_span.set_attribute("drop_image_content", _drop_image_content) + inference_span.set_attribute("minibatch_size", self._minibatch_size) + inference_span.set_attribute("has_img_da", True if _img_da else False) + inference_span.set_attribute("has_txt_da", True if _txt_da else False) # for image if _img_da: with self.tracer.start_as_current_span( - 'img_minibatch_encoding' + "img_minibatch_encoding" ) as img_encode_span: for minibatch, batch_data in _img_da.map_batch( partial( @@ -152,8 +152,8 @@ async def encode( pool=self._pool, ): with self.monitor( - name='encode_images_seconds', - documentation='images encode time in seconds', + name="encode_images_seconds", + documentation="images encode time in seconds", ): minibatch.embeddings = ( self._model.encode_image(batch_data) @@ -166,7 +166,7 @@ async def encode( # for text if _txt_da: with self.tracer.start_as_current_span( - 'txt_minibatch_encoding' + "txt_minibatch_encoding" ) as txt_encode_span: for minibatch, batch_data in _txt_da.map_batch( self._preproc_texts, @@ -174,8 +174,8 @@ async def encode( pool=self._pool, ): with self.monitor( - name='encode_texts_seconds', - documentation='texts encode time in seconds', + name="encode_texts_seconds", + documentation="texts encode time in seconds", ): minibatch.embeddings = ( self._model.encode_text(batch_data) diff --git a/server/clip_server/executors/clip_torch.py b/server/clip_server/executors/clip_torch.py index 78fe4f772..338813d40 100644 --- a/server/clip_server/executors/clip_torch.py +++ b/server/clip_server/executors/clip_torch.py @@ -23,12 +23,12 @@ class CLIPEncoder(Executor): def __init__( self, - name: str = 'ViT-B-32::openai', + name: str = "ViT-B-32::openai", device: Optional[str] = None, jit: bool = False, num_worker_preprocess: int = 4, minibatch_size: int = 32, - access_paths: str = '@r', + access_paths: str = "@r", dtype: Optional[Union[str, torch.dtype]] = None, **kwargs, ): @@ -48,35 +48,35 @@ def __init__( self._minibatch_size = minibatch_size self._access_paths = access_paths - if 'traversal_paths' in kwargs: + if "traversal_paths" in kwargs: warnings.warn( - f'`traversal_paths` is deprecated. Use `access_paths` instead.' + f"`traversal_paths` is deprecated. Use `access_paths` instead." 
) - self._access_paths = kwargs['traversal_paths'] + self._access_paths = kwargs["traversal_paths"] if not device: - device = 'cuda' if torch.cuda.is_available() else 'cpu' + device = "cuda" if torch.cuda.is_available() else "cpu" self._device = device if isinstance(dtype, str): dtype = __cast_dtype__.get(dtype) elif not dtype: dtype = ( torch.float32 - if self._device in ('cpu', torch.device('cpu')) + if self._device in ("cpu", torch.device("cpu")) else torch.float16 ) self._dtype = dtype - if not self._device.startswith('cuda') and ( - 'OMP_NUM_THREADS' not in os.environ - and hasattr(self.runtime_args, 'replicas') + if not self._device.startswith("cuda") and ( + "OMP_NUM_THREADS" not in os.environ + and hasattr(self.runtime_args, "replicas") ): - replicas = getattr(self.runtime_args, 'replicas', 1) + replicas = getattr(self.runtime_args, "replicas", 1) num_threads = max(1, torch.get_num_threads() // replicas) if num_threads < 2: warnings.warn( - f'Too many replicas ({replicas}) vs too few threads {num_threads} may result in ' - f'sub-optimal performance.' + f"Too many replicas ({replicas}) vs too few threads {num_threads} may result in " + f"sub-optimal performance." ) # NOTE: make sure to set the threads right after the torch import, @@ -97,12 +97,12 @@ def __init__( if not self.tracer: self.tracer = NoOpTracer() - def _preproc_images(self, docs: 'DocumentArray', drop_image_content: bool): + def _preproc_images(self, docs: "DocumentArray", drop_image_content: bool): with self.monitor( - name='preprocess_images_seconds', - documentation='images preprocess time in seconds', + name="preprocess_images_seconds", + documentation="images preprocess time in seconds", ): - with self.tracer.start_as_current_span('preprocess_images'): + with self.tracer.start_as_current_span("preprocess_images"): return preproc_image( docs, preprocess_fn=self._image_transform, @@ -112,12 +112,12 @@ def _preproc_images(self, docs: 'DocumentArray', drop_image_content: bool): dtype=self._dtype, ) - def _preproc_texts(self, docs: 'DocumentArray'): + def _preproc_texts(self, docs: "DocumentArray"): with self.monitor( - name='preprocess_texts_seconds', - documentation='texts preprocess time in seconds', + name="preprocess_texts_seconds", + documentation="texts preprocess time in seconds", ): - with self.tracer.start_as_current_span('preprocess_images'): + with self.tracer.start_as_current_span("preprocess_images"): return preproc_text( docs, tokenizer=self._tokenizer, @@ -125,58 +125,58 @@ def _preproc_texts(self, docs: 'DocumentArray'): return_np=False, ) - @requests(on='/rank') - async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs): - _drop_image_content = parameters.get('drop_image_content', False) - await self.encode(docs['@r,m'], drop_image_content=_drop_image_content) + @requests(on="/rank") + async def rank(self, docs: "DocumentArray", parameters: Dict, **kwargs): + _drop_image_content = parameters.get("drop_image_content", False) + await self.encode(docs["@r,m"], drop_image_content=_drop_image_content) set_rank(docs) @requests async def encode( self, - docs: 'DocumentArray', + docs: "DocumentArray", tracing_context=None, parameters: Dict = {}, **kwargs, ): with self.tracer.start_as_current_span( - 'encode', context=tracing_context + "encode", context=tracing_context ) as span: - span.set_attribute('device', self._device) - span.set_attribute('runtime', 'torch') - access_paths = parameters.get('access_paths', self._access_paths) - if 'traversal_paths' in parameters: + 
span.set_attribute("device", self._device) + span.set_attribute("runtime", "torch") + access_paths = parameters.get("access_paths", self._access_paths) + if "traversal_paths" in parameters: warnings.warn( - f'`traversal_paths` is deprecated. Use `access_paths` instead.' + f"`traversal_paths` is deprecated. Use `access_paths` instead." ) - access_paths = parameters['traversal_paths'] - _drop_image_content = parameters.get('drop_image_content', False) + access_paths = parameters["traversal_paths"] + _drop_image_content = parameters.get("drop_image_content", False) _img_da = DocumentArray() _txt_da = DocumentArray() for d in docs[access_paths]: split_img_txt_da(d, _img_da, _txt_da) - with self.tracer.start_as_current_span('inference') as inference_span: + with self.tracer.start_as_current_span("inference") as inference_span: with torch.inference_mode(): inference_span.set_attribute( - 'drop_image_content', _drop_image_content + "drop_image_content", _drop_image_content ) - inference_span.set_attribute('minibatch_size', self._minibatch_size) + inference_span.set_attribute("minibatch_size", self._minibatch_size) inference_span.set_attribute( - 'has_img_da', True if _img_da else False + "has_img_da", True if _img_da else False ) inference_span.set_attribute( - 'has_txt_da', True if _txt_da else False + "has_txt_da", True if _txt_da else False ) # for image if _img_da: with self.tracer.start_as_current_span( - 'img_minibatch_encoding' + "img_minibatch_encoding" ) as img_encode_span: img_encode_span.set_attribute( - 'num_pool_workers', self._num_worker_preprocess + "num_pool_workers", self._num_worker_preprocess ) for minibatch, batch_data in _img_da.map_batch( partial( @@ -187,8 +187,8 @@ async def encode( pool=self._pool, ): with self.monitor( - name='encode_images_seconds', - documentation='images encode time in seconds', + name="encode_images_seconds", + documentation="images encode time in seconds", ): minibatch.embeddings = ( self._model.encode_image(**batch_data) @@ -200,10 +200,10 @@ async def encode( # for text if _txt_da: with self.tracer.start_as_current_span( - 'txt_minibatch_encoding' + "txt_minibatch_encoding" ) as txt_encode_span: txt_encode_span.set_attribute( - 'num_pool_workers', self._num_worker_preprocess + "num_pool_workers", self._num_worker_preprocess ) for minibatch, batch_data in _txt_da.map_batch( self._preproc_texts, @@ -211,8 +211,8 @@ async def encode( pool=self._pool, ): with self.monitor( - name='encode_texts_seconds', - documentation='texts encode time in seconds', + name="encode_texts_seconds", + documentation="texts encode time in seconds", ): minibatch.embeddings = ( self._model.encode_text(**batch_data) diff --git a/server/clip_server/executors/helper.py b/server/clip_server/executors/helper.py index fa3846cad..b1293352c 100644 --- a/server/clip_server/executors/helper.py +++ b/server/clip_server/executors/helper.py @@ -9,7 +9,7 @@ from clip_server.model.tokenization import Tokenizer -def numpy_softmax(x: 'np.ndarray', axis: int = -1) -> 'np.ndarray': +def numpy_softmax(x: "np.ndarray", axis: int = -1) -> "np.ndarray": max = np.max(x, axis=axis, keepdims=True) e_x = np.exp(x - max) div = np.sum(e_x, axis=axis, keepdims=True) @@ -18,14 +18,13 @@ def numpy_softmax(x: 'np.ndarray', axis: int = -1) -> 'np.ndarray': def preproc_image( - da: 'DocumentArray', + da: "DocumentArray", preprocess_fn: Callable, - device: str = 'cpu', + device: str = "cpu", return_np: bool = False, drop_image_content: bool = False, dtype: Union[str, torch.dtype] = torch.float32, -) -> 
Tuple['DocumentArray', Dict]: - +) -> Tuple["DocumentArray", Dict]: if isinstance(dtype, str): dtype = __cast_dtype__.get(dtype) @@ -35,7 +34,7 @@ def preproc_image( content = d.content if d.tensor is not None: d.convert_image_tensor_to_blob() - elif d.content_type != 'blob' and d.uri: + elif d.content_type != "blob" and d.uri: # in case user uses HTTP protocol and send data via curl not using .blob (base64), but in .uri d.load_uri_to_blob() @@ -44,7 +43,7 @@ def preproc_image( # recover doc content d.content = content if drop_image_content: - d.pop('blob', 'tensor') + d.pop("blob", "tensor") tensors_batch = torch.stack(tensors_batch).type(dtype) @@ -53,33 +52,32 @@ def preproc_image( else: tensors_batch = tensors_batch.to(device) - return da, {'pixel_values': tensors_batch} + return da, {"pixel_values": tensors_batch} def preproc_text( - da: 'DocumentArray', - tokenizer: 'Tokenizer', - device: str = 'cpu', + da: "DocumentArray", + tokenizer: "Tokenizer", + device: str = "cpu", return_np: bool = False, -) -> Tuple['DocumentArray', Dict]: - +) -> Tuple["DocumentArray", Dict]: inputs = tokenizer(da.texts) - inputs['input_ids'] = inputs['input_ids'].detach() + inputs["input_ids"] = inputs["input_ids"].detach() if return_np: - inputs['input_ids'] = inputs['input_ids'].cpu().numpy().astype(np.int32) - inputs['attention_mask'] = ( - inputs['attention_mask'].cpu().numpy().astype(np.int32) + inputs["input_ids"] = inputs["input_ids"].cpu().numpy().astype(np.int32) + inputs["attention_mask"] = ( + inputs["attention_mask"].cpu().numpy().astype(np.int32) ) else: - inputs['input_ids'] = inputs['input_ids'].to(device) - inputs['attention_mask'] = inputs['attention_mask'].to(device) + inputs["input_ids"] = inputs["input_ids"].to(device) + inputs["attention_mask"] = inputs["attention_mask"].to(device) - da[:, 'mime_type'] = 'text' + da[:, "mime_type"] = "text" return da, inputs -def split_img_txt_da(doc: 'Document', img_da: 'DocumentArray', txt_da: 'DocumentArray'): +def split_img_txt_da(doc: "Document", img_da: "DocumentArray", txt_da: "DocumentArray"): if doc.text: txt_da.append(doc) elif doc.blob or (doc.tensor is not None) or doc.uri: @@ -88,7 +86,7 @@ def split_img_txt_da(doc: 'Document', img_da: 'DocumentArray', txt_da: 'Document def set_rank(docs, _logit_scale=np.exp(4.60517)): queries = docs - candidates = docs['@m'] + candidates = docs["@m"] query_embeddings = queries.embeddings # Q X D candidate_embeddings = candidates.embeddings # C = Sum(C_q1, C_q2, C_q3,...) 
x D @@ -97,7 +95,6 @@ def set_rank(docs, _logit_scale=np.exp(4.60517)): ) # Q x C Block matix start_idx = 0 for q, _cosine_scores in zip(docs, cosine_scores): - _candidates = q.matches end_idx = start_idx + len(_candidates) @@ -107,18 +104,18 @@ def set_rank(docs, _logit_scale=np.exp(4.60517)): for c, _c_score, _s_score in zip( _candidates, _candidate_cosines, _candidate_softmaxs ): - c.scores['clip_score'].value = _s_score - c.scores['clip_score'].op_name = 'softmax' + c.scores["clip_score"].value = _s_score + c.scores["clip_score"].op_name = "softmax" - c.scores['clip_score_cosine'].value = _c_score - c.scores['clip_score_cosine'].op_name = 'cosine' + c.scores["clip_score_cosine"].value = _c_score + c.scores["clip_score_cosine"].op_name = "cosine" start_idx = end_idx _candidates.embeddings = None # remove embedding to save bandwidth final = sorted( - _candidates, key=lambda _m: _m.scores['clip_score'].value, reverse=True + _candidates, key=lambda _m: _m.scores["clip_score"].value, reverse=True ) q.matches = final diff --git a/server/clip_server/helper.py b/server/clip_server/helper.py index d5108ffc0..5d50ef085 100644 --- a/server/clip_server/helper.py +++ b/server/clip_server/helper.py @@ -12,44 +12,43 @@ __resources_path__ = os.path.join( os.path.dirname( - sys.modules.get('clip_server').__file__ - if 'clip_server' in sys.modules + sys.modules.get("clip_server").__file__ + if "clip_server" in sys.modules else __file__ ), - 'resources', + "resources", ) -__cast_dtype__ = {'fp16': torch.float16, 'fp32': torch.float32, 'bf16': torch.bfloat16} +__cast_dtype__ = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16} def _version_check(package: str = None, github_repo: str = None): try: - if not package: - package = vars(sys.modules[__name__])['__package__'] + package = vars(sys.modules[__name__])["__package__"] if not github_repo: github_repo = package cur_ver = Version(pkg_resources.get_distribution(package).version) req = Request( - f'https://pypi.python.org/pypi/{package}/json', - headers={'User-Agent': 'Mozilla/5.0'}, + f"https://pypi.python.org/pypi/{package}/json", + headers={"User-Agent": "Mozilla/5.0"}, ) with urlopen( req, timeout=1 ) as resp: # 'with' is important to close the resource after use j = json.load(resp) - releases = j.get('releases', {}) + releases = j.get("releases", {}) latest_release_ver = max( - Version(v) for v in releases.keys() if '.dev' not in v + Version(v) for v in releases.keys() if ".dev" not in v ) if cur_ver < latest_release_ver: print( Panel( - f'You are using [b]{package} {cur_ver}[/b], but [bold green]{latest_release_ver}[/] is available. ' - f'You may upgrade it via [b]pip install -U {package}[/b]. [link=https://github.com/jina-ai/{github_repo}/releases]Read Changelog here[/link].', - title=':new: New version available!', + f"You are using [b]{package} {cur_ver}[/b], but [bold green]{latest_release_ver}[/] is available. " + f"You may upgrade it via [b]pip install -U {package}[/b]. 
[link=https://github.com/jina-ai/{github_repo}/releases]Read Changelog here[/link].", + title=":new: New version available!", width=50, ) ) diff --git a/server/clip_server/model/clip.py b/server/clip_server/model/clip.py index 0af92054c..860ba319a 100644 --- a/server/clip_server/model/clip.py +++ b/server/clip_server/model/clip.py @@ -15,7 +15,7 @@ def _convert_image_to_rgb(image): - return image.convert('RGB') + return image.convert("RGB") def _blob2image(blob): diff --git a/server/clip_server/model/clip_model.py b/server/clip_server/model/clip_model.py index fc40ee75b..3bc5433de 100644 --- a/server/clip_server/model/clip_model.py +++ b/server/clip_server/model/clip_model.py @@ -41,11 +41,11 @@ def __new__(cls, name: str, **kwargs): instance = super().__new__(CNClipModel) else: raise ValueError( - 'CLIP model {} not found; below is a list of all available models:\n{}'.format( + "CLIP model {} not found; below is a list of all available models:\n{}".format( name, - ''.join( + "".join( [ - '\t- {}\n'.format(i) + "\t- {}\n".format(i) for i in list(_OPENCLIP_MODELS.keys()) + list(_MULTILINGUALCLIP_MODELS.keys()) + list(_CNCLIP_MODELS.keys()) diff --git a/server/clip_server/model/clip_onnx.py b/server/clip_server/model/clip_onnx.py index 90a9c0f05..0927869ae 100644 --- a/server/clip_server/model/clip_onnx.py +++ b/server/clip_server/model/clip_onnx.py @@ -9,192 +9,192 @@ from clip_server.model.clip_model import BaseCLIPModel _S3_BUCKET = ( - 'https://clip-as-service.s3.us-east-2.amazonaws.com/models/onnx/' # Deprecated + "https://clip-as-service.s3.us-east-2.amazonaws.com/models/onnx/" # Deprecated ) -_S3_BUCKET_V2 = 'https://clip-as-service.s3.us-east-2.amazonaws.com/models-436c69702d61732d53657276696365/onnx/' +_S3_BUCKET_V2 = "https://clip-as-service.s3.us-east-2.amazonaws.com/models-436c69702d61732d53657276696365/onnx/" _MODELS = { - 'RN50::openai': ( - ('RN50/textual.onnx', '722418bfe47a1f5c79d1f44884bb3103'), - ('RN50/visual.onnx', '5761475db01c3abb68a5a805662dcd10'), + "RN50::openai": ( + ("RN50/textual.onnx", "722418bfe47a1f5c79d1f44884bb3103"), + ("RN50/visual.onnx", "5761475db01c3abb68a5a805662dcd10"), ), - 'RN50::yfcc15m': ( - ('RN50-yfcc15m/textual.onnx', '4ff2ea7228b9d2337b5440d1955c2108'), - ('RN50-yfcc15m/visual.onnx', '87daa9b4a67449b5390a9a73b8c15772'), + "RN50::yfcc15m": ( + ("RN50-yfcc15m/textual.onnx", "4ff2ea7228b9d2337b5440d1955c2108"), + ("RN50-yfcc15m/visual.onnx", "87daa9b4a67449b5390a9a73b8c15772"), ), - 'RN50::cc12m': ( - ('RN50-cc12m/textual.onnx', '78fa0ae0ea47aca4b8864f709c48dcec'), - ('RN50-cc12m/visual.onnx', '0e04bf92f3c181deea2944e322ebee77'), + "RN50::cc12m": ( + ("RN50-cc12m/textual.onnx", "78fa0ae0ea47aca4b8864f709c48dcec"), + ("RN50-cc12m/visual.onnx", "0e04bf92f3c181deea2944e322ebee77"), ), - 'RN101::openai': ( - ('RN101/textual.onnx', '2d9efb7d184c0d68a369024cedfa97af'), - ('RN101/visual.onnx', '0297ebc773af312faab54f8b5a622d71'), + "RN101::openai": ( + ("RN101/textual.onnx", "2d9efb7d184c0d68a369024cedfa97af"), + ("RN101/visual.onnx", "0297ebc773af312faab54f8b5a622d71"), ), - 'RN101::yfcc15m': ( - ('RN101-yfcc15m/textual.onnx', '7aa2a4e3d5b960998a397a6712389f08'), - ('RN101-yfcc15m/visual.onnx', '681a72dd91c9c79464947bf29b623cb4'), + "RN101::yfcc15m": ( + ("RN101-yfcc15m/textual.onnx", "7aa2a4e3d5b960998a397a6712389f08"), + ("RN101-yfcc15m/visual.onnx", "681a72dd91c9c79464947bf29b623cb4"), ), - 'RN50x4::openai': ( - ('RN50x4/textual.onnx', 'd9d63d3fe35fb14d4affaa2c4e284005'), - ('RN50x4/visual.onnx', '16afe1e35b85ad862e8bbdb12265c9cb'), + 
"RN50x4::openai": ( + ("RN50x4/textual.onnx", "d9d63d3fe35fb14d4affaa2c4e284005"), + ("RN50x4/visual.onnx", "16afe1e35b85ad862e8bbdb12265c9cb"), ), - 'RN50x16::openai': ( - ('RN50x16/textual.onnx', '1525785494ff5307cadc6bfa56db6274'), - ('RN50x16/visual.onnx', '2a293d9c3582f8abe29c9999e47d1091'), + "RN50x16::openai": ( + ("RN50x16/textual.onnx", "1525785494ff5307cadc6bfa56db6274"), + ("RN50x16/visual.onnx", "2a293d9c3582f8abe29c9999e47d1091"), ), - 'RN50x64::openai': ( - ('RN50x64/textual.onnx', '3ae8ade74578eb7a77506c11bfbfaf2c'), - ('RN50x64/visual.onnx', '1341f10b50b3aca6d2d5d13982cabcfc'), + "RN50x64::openai": ( + ("RN50x64/textual.onnx", "3ae8ade74578eb7a77506c11bfbfaf2c"), + ("RN50x64/visual.onnx", "1341f10b50b3aca6d2d5d13982cabcfc"), ), - 'ViT-B-32::openai': ( - ('ViT-B-32/textual.onnx', 'bd6d7871e8bb95f3cc83aff3398d7390'), - ('ViT-B-32/visual.onnx', '88c6f38e522269d6c04a85df18e6370c'), + "ViT-B-32::openai": ( + ("ViT-B-32/textual.onnx", "bd6d7871e8bb95f3cc83aff3398d7390"), + ("ViT-B-32/visual.onnx", "88c6f38e522269d6c04a85df18e6370c"), ), - 'ViT-B-32::laion2b_e16': ( - ('ViT-B-32-laion2b_e16/textual.onnx', 'aa6eac88fe77d21f337e806417957497'), - ('ViT-B-32-laion2b_e16/visual.onnx', '0cdc00a9dfad560153d40aced9df0c8f'), + "ViT-B-32::laion2b_e16": ( + ("ViT-B-32-laion2b_e16/textual.onnx", "aa6eac88fe77d21f337e806417957497"), + ("ViT-B-32-laion2b_e16/visual.onnx", "0cdc00a9dfad560153d40aced9df0c8f"), ), - 'ViT-B-32::laion400m_e31': ( - ('ViT-B-32-laion400m_e31/textual.onnx', '832f417bf1b3f1ced8f9958eda71665c'), - ('ViT-B-32-laion400m_e31/visual.onnx', '62326b925ae342313d4cc99c2741b313'), + "ViT-B-32::laion400m_e31": ( + ("ViT-B-32-laion400m_e31/textual.onnx", "832f417bf1b3f1ced8f9958eda71665c"), + ("ViT-B-32-laion400m_e31/visual.onnx", "62326b925ae342313d4cc99c2741b313"), ), - 'ViT-B-32::laion400m_e32': ( - ('ViT-B-32-laion400m_e32/textual.onnx', '93284915937ba42a2b52ae8d3e5283a0'), - ('ViT-B-32-laion400m_e32/visual.onnx', 'db220821a31fe9795fd8c2ba419078c5'), + "ViT-B-32::laion400m_e32": ( + ("ViT-B-32-laion400m_e32/textual.onnx", "93284915937ba42a2b52ae8d3e5283a0"), + ("ViT-B-32-laion400m_e32/visual.onnx", "db220821a31fe9795fd8c2ba419078c5"), ), - 'ViT-B-32::laion2b-s34b-b79k': ( - ('ViT-B-32-laion2b-s34b-b79k/textual.onnx', '84af5ae53da56464c76e67fe50fddbe9'), - ('ViT-B-32-laion2b-s34b-b79k/visual.onnx', 'a2d4cbd1cf2632cd09ffce9b40bfd8bd'), + "ViT-B-32::laion2b-s34b-b79k": ( + ("ViT-B-32-laion2b-s34b-b79k/textual.onnx", "84af5ae53da56464c76e67fe50fddbe9"), + ("ViT-B-32-laion2b-s34b-b79k/visual.onnx", "a2d4cbd1cf2632cd09ffce9b40bfd8bd"), ), - 'ViT-B-16::openai': ( - ('ViT-B-16/textual.onnx', '6f0976629a446f95c0c8767658f12ebe'), - ('ViT-B-16/visual.onnx', 'd5c03bfeef1abbd9bede54a8f6e1eaad'), + "ViT-B-16::openai": ( + ("ViT-B-16/textual.onnx", "6f0976629a446f95c0c8767658f12ebe"), + ("ViT-B-16/visual.onnx", "d5c03bfeef1abbd9bede54a8f6e1eaad"), ), - 'ViT-B-16::laion400m_e31': ( - ('ViT-B-16-laion400m_e31/textual.onnx', '5db27763c06c06c727c90240264bf4f7'), - ('ViT-B-16-laion400m_e31/visual.onnx', '04a6a780d855a36eee03abca64cd5361'), + "ViT-B-16::laion400m_e31": ( + ("ViT-B-16-laion400m_e31/textual.onnx", "5db27763c06c06c727c90240264bf4f7"), + ("ViT-B-16-laion400m_e31/visual.onnx", "04a6a780d855a36eee03abca64cd5361"), ), - 'ViT-B-16::laion400m_e32': ( - ('ViT-B-16-laion400m_e32/textual.onnx', '9abe000a51b6f1cbaac8fde601b16725'), - ('ViT-B-16-laion400m_e32/visual.onnx', 'd38c144ac3ad7fbc1966f88ff8fa522f'), + "ViT-B-16::laion400m_e32": ( + ("ViT-B-16-laion400m_e32/textual.onnx", 
"9abe000a51b6f1cbaac8fde601b16725"), + ("ViT-B-16-laion400m_e32/visual.onnx", "d38c144ac3ad7fbc1966f88ff8fa522f"), ), - 'ViT-B-16-plus-240::laion400m_e31': ( + "ViT-B-16-plus-240::laion400m_e31": ( ( - 'ViT-B-16-plus-240-laion400m_e31/textual.onnx', - '2b524e7a530a98010cc7e57756937c5c', + "ViT-B-16-plus-240-laion400m_e31/textual.onnx", + "2b524e7a530a98010cc7e57756937c5c", ), ( - 'ViT-B-16-plus-240-laion400m_e31/visual.onnx', - 'a78989da3300fd0c398a9877dd26a9f1', + "ViT-B-16-plus-240-laion400m_e31/visual.onnx", + "a78989da3300fd0c398a9877dd26a9f1", ), ), - 'ViT-B-16-plus-240::laion400m_e32': ( + "ViT-B-16-plus-240::laion400m_e32": ( ( - 'ViT-B-16-plus-240-laion400m_e32/textual.onnx', - '53c8d26726b386ca0749207876482907', + "ViT-B-16-plus-240-laion400m_e32/textual.onnx", + "53c8d26726b386ca0749207876482907", ), ( - 'ViT-B-16-plus-240-laion400m_e32/visual.onnx', - '7a32c4272c1ee46f734486570d81584b', + "ViT-B-16-plus-240-laion400m_e32/visual.onnx", + "7a32c4272c1ee46f734486570d81584b", ), ), - 'ViT-L-14::openai': ( - ('ViT-L-14/textual.onnx', '325380b31af4837c2e0d9aba2fad8e1b'), - ('ViT-L-14/visual.onnx', '53f5b319d3dc5d42572adea884e31056'), + "ViT-L-14::openai": ( + ("ViT-L-14/textual.onnx", "325380b31af4837c2e0d9aba2fad8e1b"), + ("ViT-L-14/visual.onnx", "53f5b319d3dc5d42572adea884e31056"), ), - 'ViT-L-14::laion400m_e31': ( - ('ViT-L-14-laion400m_e31/textual.onnx', '36216b85e32668ea849730a54e1e09a4'), - ('ViT-L-14-laion400m_e31/visual.onnx', '15fa5a24916e2a58325c5cf70350c300'), + "ViT-L-14::laion400m_e31": ( + ("ViT-L-14-laion400m_e31/textual.onnx", "36216b85e32668ea849730a54e1e09a4"), + ("ViT-L-14-laion400m_e31/visual.onnx", "15fa5a24916e2a58325c5cf70350c300"), ), - 'ViT-L-14::laion400m_e32': ( - ('ViT-L-14-laion400m_e32/textual.onnx', '8ba5b76ba71992923470c0261b10a67c'), - ('ViT-L-14-laion400m_e32/visual.onnx', '49db3ba92bd816001e932530ad92d76c'), + "ViT-L-14::laion400m_e32": ( + ("ViT-L-14-laion400m_e32/textual.onnx", "8ba5b76ba71992923470c0261b10a67c"), + ("ViT-L-14-laion400m_e32/visual.onnx", "49db3ba92bd816001e932530ad92d76c"), ), - 'ViT-L-14::laion2b-s32b-b82k': ( - ('ViT-L-14-laion2b-s32b-b82k/textual.onnx', 'da36a6cbed4f56abf576fdea8b6fe2ee'), - ('ViT-L-14-laion2b-s32b-b82k/visual.onnx', '1e337a190abba6a8650237dfae4740b7'), + "ViT-L-14::laion2b-s32b-b82k": ( + ("ViT-L-14-laion2b-s32b-b82k/textual.onnx", "da36a6cbed4f56abf576fdea8b6fe2ee"), + ("ViT-L-14-laion2b-s32b-b82k/visual.onnx", "1e337a190abba6a8650237dfae4740b7"), ), - 'ViT-L-14-336::openai': ( - ('ViT-L-14@336px/textual.onnx', '78fab479f136403eed0db46f3e9e7ed2'), - ('ViT-L-14@336px/visual.onnx', 'f3b1f5d55ca08d43d749e11f7e4ba27e'), + "ViT-L-14-336::openai": ( + ("ViT-L-14@336px/textual.onnx", "78fab479f136403eed0db46f3e9e7ed2"), + ("ViT-L-14@336px/visual.onnx", "f3b1f5d55ca08d43d749e11f7e4ba27e"), ), - 'ViT-H-14::laion2b-s32b-b79k': ( - ('ViT-H-14-laion2b-s32b-b79k/textual.onnx', '41e73c0c871d0e8e5d5e236f917f1ec3'), - ('ViT-H-14-laion2b-s32b-b79k/visual.zip', '38151ea5985d73de94520efef38db4e7'), + "ViT-H-14::laion2b-s32b-b79k": ( + ("ViT-H-14-laion2b-s32b-b79k/textual.onnx", "41e73c0c871d0e8e5d5e236f917f1ec3"), + ("ViT-H-14-laion2b-s32b-b79k/visual.zip", "38151ea5985d73de94520efef38db4e7"), ), - 'ViT-g-14::laion2b-s12b-b42k': ( - ('ViT-g-14-laion2b-s12b-b42k/textual.onnx', 'e597b7ab4414ecd92f715d47e79a033f'), - ('ViT-g-14-laion2b-s12b-b42k/visual.zip', '6d0ac4329de9b02474f4752a5d16ba82'), + "ViT-g-14::laion2b-s12b-b42k": ( + ("ViT-g-14-laion2b-s12b-b42k/textual.onnx", "e597b7ab4414ecd92f715d47e79a033f"), + 
("ViT-g-14-laion2b-s12b-b42k/visual.zip", "6d0ac4329de9b02474f4752a5d16ba82"), ), # older version name format - 'RN50': ( - ('RN50/textual.onnx', '722418bfe47a1f5c79d1f44884bb3103'), - ('RN50/visual.onnx', '5761475db01c3abb68a5a805662dcd10'), + "RN50": ( + ("RN50/textual.onnx", "722418bfe47a1f5c79d1f44884bb3103"), + ("RN50/visual.onnx", "5761475db01c3abb68a5a805662dcd10"), ), - 'RN101': ( - ('RN101/textual.onnx', '2d9efb7d184c0d68a369024cedfa97af'), - ('RN101/visual.onnx', '0297ebc773af312faab54f8b5a622d71'), + "RN101": ( + ("RN101/textual.onnx", "2d9efb7d184c0d68a369024cedfa97af"), + ("RN101/visual.onnx", "0297ebc773af312faab54f8b5a622d71"), ), - 'RN50x4': ( - ('RN50x4/textual.onnx', 'd9d63d3fe35fb14d4affaa2c4e284005'), - ('RN50x4/visual.onnx', '16afe1e35b85ad862e8bbdb12265c9cb'), + "RN50x4": ( + ("RN50x4/textual.onnx", "d9d63d3fe35fb14d4affaa2c4e284005"), + ("RN50x4/visual.onnx", "16afe1e35b85ad862e8bbdb12265c9cb"), ), - 'RN50x16': ( - ('RN50x16/textual.onnx', '1525785494ff5307cadc6bfa56db6274'), - ('RN50x16/visual.onnx', '2a293d9c3582f8abe29c9999e47d1091'), + "RN50x16": ( + ("RN50x16/textual.onnx", "1525785494ff5307cadc6bfa56db6274"), + ("RN50x16/visual.onnx", "2a293d9c3582f8abe29c9999e47d1091"), ), - 'RN50x64': ( - ('RN50x64/textual.onnx', '3ae8ade74578eb7a77506c11bfbfaf2c'), - ('RN50x64/visual.onnx', '1341f10b50b3aca6d2d5d13982cabcfc'), + "RN50x64": ( + ("RN50x64/textual.onnx", "3ae8ade74578eb7a77506c11bfbfaf2c"), + ("RN50x64/visual.onnx", "1341f10b50b3aca6d2d5d13982cabcfc"), ), - 'ViT-B/32': ( - ('ViT-B-32/textual.onnx', 'bd6d7871e8bb95f3cc83aff3398d7390'), - ('ViT-B-32/visual.onnx', '88c6f38e522269d6c04a85df18e6370c'), + "ViT-B/32": ( + ("ViT-B-32/textual.onnx", "bd6d7871e8bb95f3cc83aff3398d7390"), + ("ViT-B-32/visual.onnx", "88c6f38e522269d6c04a85df18e6370c"), ), - 'ViT-B/16': ( - ('ViT-B-16/textual.onnx', '6f0976629a446f95c0c8767658f12ebe'), - ('ViT-B-16/visual.onnx', 'd5c03bfeef1abbd9bede54a8f6e1eaad'), + "ViT-B/16": ( + ("ViT-B-16/textual.onnx", "6f0976629a446f95c0c8767658f12ebe"), + ("ViT-B-16/visual.onnx", "d5c03bfeef1abbd9bede54a8f6e1eaad"), ), - 'ViT-L/14': ( - ('ViT-L-14/textual.onnx', '325380b31af4837c2e0d9aba2fad8e1b'), - ('ViT-L-14/visual.onnx', '53f5b319d3dc5d42572adea884e31056'), + "ViT-L/14": ( + ("ViT-L-14/textual.onnx", "325380b31af4837c2e0d9aba2fad8e1b"), + ("ViT-L-14/visual.onnx", "53f5b319d3dc5d42572adea884e31056"), ), - 'ViT-L/14@336px': ( - ('ViT-L-14@336px/textual.onnx', '78fab479f136403eed0db46f3e9e7ed2'), - ('ViT-L-14@336px/visual.onnx', 'f3b1f5d55ca08d43d749e11f7e4ba27e'), + "ViT-L/14@336px": ( + ("ViT-L-14@336px/textual.onnx", "78fab479f136403eed0db46f3e9e7ed2"), + ("ViT-L-14@336px/visual.onnx", "f3b1f5d55ca08d43d749e11f7e4ba27e"), ), # MultilingualCLIP models - 'M-CLIP/LABSE-Vit-L-14': ( - ('M-CLIP-LABSE-Vit-L-14/textual.onnx', '03727820116e63c7d19c72bb5d839488'), - ('M-CLIP-LABSE-Vit-L-14/visual.onnx', 'a78028eab30084c3913edfb0c8411f15'), + "M-CLIP/LABSE-Vit-L-14": ( + ("M-CLIP-LABSE-Vit-L-14/textual.onnx", "03727820116e63c7d19c72bb5d839488"), + ("M-CLIP-LABSE-Vit-L-14/visual.onnx", "a78028eab30084c3913edfb0c8411f15"), ), - 'M-CLIP/XLM-Roberta-Large-Vit-B-32': ( + "M-CLIP/XLM-Roberta-Large-Vit-B-32": ( ( - 'M-CLIP-XLM-Roberta-Large-Vit-B-32/textual.zip', - '41f51ec9af4754d11c7b7929e2caf5b9', + "M-CLIP-XLM-Roberta-Large-Vit-B-32/textual.zip", + "41f51ec9af4754d11c7b7929e2caf5b9", ), ( - 'M-CLIP-XLM-Roberta-Large-Vit-B-32/visual.onnx', - '5f18f68ac94e294863bfd1f695c8c5ca', + "M-CLIP-XLM-Roberta-Large-Vit-B-32/visual.onnx", + 
"5f18f68ac94e294863bfd1f695c8c5ca", ), ), - 'M-CLIP/XLM-Roberta-Large-Vit-B-16Plus': ( + "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus": ( ( - 'M-CLIP-XLM-Roberta-Large-Vit-B-16Plus/textual.zip', - '6c3e55f7d2d6c12f2c1f1dd36fdec607', + "M-CLIP-XLM-Roberta-Large-Vit-B-16Plus/textual.zip", + "6c3e55f7d2d6c12f2c1f1dd36fdec607", ), ( - 'M-CLIP-XLM-Roberta-Large-Vit-B-16Plus/visual.onnx', - '467a3ef3e5f50abcf850c3db9e705f8e', + "M-CLIP-XLM-Roberta-Large-Vit-B-16Plus/visual.onnx", + "467a3ef3e5f50abcf850c3db9e705f8e", ), ), - 'M-CLIP/XLM-Roberta-Large-Vit-L-14': ( + "M-CLIP/XLM-Roberta-Large-Vit-L-14": ( ( - 'M-CLIP-XLM-Roberta-Large-Vit-L-14/textual.zip', - '3dff00335dc3093acb726dab975ae57d', + "M-CLIP-XLM-Roberta-Large-Vit-L-14/textual.zip", + "3dff00335dc3093acb726dab975ae57d", ), ( - 'M-CLIP-XLM-Roberta-Large-Vit-L-14/visual.onnx', - 'a78028eab30084c3913edfb0c8411f15', + "M-CLIP-XLM-Roberta-Large-Vit-L-14/visual.onnx", + "a78028eab30084c3913edfb0c8411f15", ), ), } @@ -202,7 +202,7 @@ class CLIPOnnxModel(BaseCLIPModel): def __init__( - self, name: str, model_path: str = None, dtype: Optional[str] = 'fp32' + self, name: str, model_path: str = None, dtype: Optional[str] = "fp32" ): super().__init__(name) self._dtype = dtype @@ -227,24 +227,24 @@ def __init__( ) else: if os.path.isdir(model_path): - self._textual_path = os.path.join(model_path, 'textual.onnx') - self._visual_path = os.path.join(model_path, 'visual.onnx') + self._textual_path = os.path.join(model_path, "textual.onnx") + self._visual_path = os.path.join(model_path, "visual.onnx") if not os.path.isfile(self._textual_path) or not os.path.isfile( self._visual_path ): raise RuntimeError( - f'The given model path {model_path} does not contain `textual.onnx` and `visual.onnx`' + f"The given model path {model_path} does not contain `textual.onnx` and `visual.onnx`" ) else: raise RuntimeError( - f'The given model path {model_path} should be a folder containing both ' - f'`textual.onnx` and `visual.onnx`.' + f"The given model path {model_path} should be a folder containing both " + f"`textual.onnx` and `visual.onnx`." 
) else: raise RuntimeError( - 'CLIP model {} not found or not supports ONNX backend; below is a list of all available models:\n{}'.format( + "CLIP model {} not found or not supports ONNX backend; below is a list of all available models:\n{}".format( name, - ''.join(['\t- {}\n'.format(i) for i in list(_MODELS.keys())]), + "".join(["\t- {}\n".format(i) for i in list(_MODELS.keys())]), ) ) @@ -269,18 +269,18 @@ def start_sessions( import onnxruntime as ort def _load_session(model_path: str, model_type: str, dtype: str): - if model_path.endswith('.zip') or dtype == 'fp16': + if model_path.endswith(".zip") or dtype == "fp16": import tempfile with tempfile.TemporaryDirectory() as tmp_dir: - tmp_model_path = tmp_dir + f'/{model_type}.onnx' - if model_path.endswith('.zip'): + tmp_model_path = tmp_dir + f"/{model_type}.onnx" + if model_path.endswith(".zip"): import zipfile - with zipfile.ZipFile(model_path, 'r') as zip_ref: + with zipfile.ZipFile(model_path, "r") as zip_ref: zip_ref.extractall(tmp_dir) model_path = tmp_model_path - if dtype == 'fp16': + if dtype == "fp16": import onnx from onnxmltools.utils import float16_converter @@ -293,8 +293,8 @@ def _load_session(model_path: str, model_type: str, dtype: str): return ort.InferenceSession(tmp_model_path, **kwargs) return ort.InferenceSession(model_path, **kwargs) - self._visual_session = _load_session(self._visual_path, 'visual', dtype) - self._textual_session = _load_session(self._textual_path, 'textual', dtype) + self._visual_session = _load_session(self._visual_path, "visual", dtype) + self._textual_session = _load_session(self._textual_path, "textual", dtype) self._visual_session.disable_fallback() self._textual_session.disable_fallback() diff --git a/server/clip_server/model/clip_trt.py b/server/clip_server/model/clip_trt.py index 1510003c5..5b84ddfee 100644 --- a/server/clip_server/model/clip_trt.py +++ b/server/clip_server/model/clip_trt.py @@ -21,27 +21,27 @@ from clip_server.model.clip_onnx import _MODELS as ONNX_MODELS _MODELS = [ - 'RN50::openai', - 'RN50::yfcc15m', - 'RN50::cc12m', - 'RN101::openai', - 'RN101::yfcc15m', - 'RN50x4::openai', - 'ViT-B-32::openai', - 'ViT-B-32::laion2b_e16', - 'ViT-B-32::laion400m_e31', - 'ViT-B-32::laion400m_e32', - 'ViT-B-16::openai', - 'ViT-B-16::laion400m_e31', - 'ViT-B-16::laion400m_e32', + "RN50::openai", + "RN50::yfcc15m", + "RN50::cc12m", + "RN101::openai", + "RN101::yfcc15m", + "RN50x4::openai", + "ViT-B-32::openai", + "ViT-B-32::laion2b_e16", + "ViT-B-32::laion400m_e31", + "ViT-B-32::laion400m_e32", + "ViT-B-16::openai", + "ViT-B-16::laion400m_e31", + "ViT-B-16::laion400m_e32", # older version name format - 'RN50', - 'RN101', - 'RN50x4', + "RN50", + "RN101", + "RN50x4", # 'RN50x16', # 'RN50x64', - 'ViT-B/32', - 'ViT-B/16', + "ViT-B/32", + "ViT-B/16", # 'ViT-L/14', # 'ViT-L/14@336px', ] @@ -61,11 +61,11 @@ def __init__( self._textual_path = os.path.join( cache_dir, - f'textual.{ONNX_MODELS[name][0][1]}.trt', + f"textual.{ONNX_MODELS[name][0][1]}.trt", ) self._visual_path = os.path.join( cache_dir, - f'visual.{ONNX_MODELS[name][1][1]}.trt', + f"visual.{ONNX_MODELS[name][1][1]}.trt", ) if not os.path.exists(self._textual_path) or not os.path.exists( @@ -114,9 +114,9 @@ def __init__( save_engine(text_engine, self._textual_path) else: raise RuntimeError( - 'CLIP model {} not found or not supports Nvidia TensorRT backend; below is a list of all available models:\n{}'.format( + "CLIP model {} not found or not supports Nvidia TensorRT backend; below is a list of all available models:\n{}".format( 
name, - ''.join(['\t- {}\n'.format(i) for i in list(_MODELS.keys())]), + "".join(["\t- {}\n".format(i) for i in list(_MODELS.keys())]), ) ) diff --git a/server/clip_server/model/cnclip_model.py b/server/clip_server/model/cnclip_model.py index a8761bae5..0a5b2f089 100644 --- a/server/clip_server/model/cnclip_model.py +++ b/server/clip_server/model/cnclip_model.py @@ -7,11 +7,11 @@ from cn_clip.clip import load_from_name _CNCLIP_MODEL_MAPS = { - 'CN-CLIP/ViT-B-16': 'ViT-B-16', - 'CN-CLIP/ViT-L-14': 'ViT-L-14', - 'CN-CLIP/ViT-L-14-336': 'ViT-L-14-336', - 'CN-CLIP/ViT-H-14': 'ViT-H-14', - 'CN-CLIP/RN50': 'RN50', + "CN-CLIP/ViT-B-16": "ViT-B-16", + "CN-CLIP/ViT-L-14": "ViT-L-14", + "CN-CLIP/ViT-L-14-336": "ViT-L-14-336", + "CN-CLIP/ViT-H-14": "ViT-H-14", + "CN-CLIP/RN50": "RN50", } @@ -19,7 +19,7 @@ class CNClipModel(CLIPModel): def __init__( self, name: str, - device: str = 'cpu', + device: str = "cpu", jit: bool = False, dtype: str = None, **kwargs @@ -36,10 +36,10 @@ def __init__( def get_model_name(name: str): return _CNCLIP_MODEL_MAPS[name] - def encode_text(self, input_ids: 'torch.Tensor', **kwargs): + def encode_text(self, input_ids: "torch.Tensor", **kwargs): return self._model.encode_text(input_ids).detach() - def encode_image(self, pixel_values: 'torch.Tensor', **kwargs): + def encode_image(self, pixel_values: "torch.Tensor", **kwargs): return self._model.encode_image(pixel_values).detach() @property diff --git a/server/clip_server/model/mclip_model.py b/server/clip_server/model/mclip_model.py index c85519501..12256671a 100644 --- a/server/clip_server/model/mclip_model.py +++ b/server/clip_server/model/mclip_model.py @@ -7,10 +7,10 @@ from clip_server.model.openclip_model import OpenCLIPModel _CLIP_MODEL_MAPS = { - 'M-CLIP/XLM-Roberta-Large-Vit-B-32': 'ViT-B-32::openai', - 'M-CLIP/XLM-Roberta-Large-Vit-L-14': 'ViT-L-14::openai', - 'M-CLIP/XLM-Roberta-Large-Vit-B-16Plus': 'ViT-B-16-plus-240::laion400m_e31', - 'M-CLIP/LABSE-Vit-L-14': 'ViT-L-14::openai', + "M-CLIP/XLM-Roberta-Large-Vit-B-32": "ViT-B-32::openai", + "M-CLIP/XLM-Roberta-Large-Vit-L-14": "ViT-L-14::openai", + "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus": "ViT-B-16-plus-240::laion400m_e31", + "M-CLIP/LABSE-Vit-L-14": "ViT-L-14::openai", } @@ -19,7 +19,7 @@ class MCLIPConfig(transformers.PretrainedConfig): def __init__( self, - modelBase: str = 'xlm-roberta-large', + modelBase: str = "xlm-roberta-large", transformerDimSize: int = 1024, imageDimSize: int = 768, **kwargs @@ -51,7 +51,7 @@ def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwarg class MultilingualCLIPModel(CLIPModel): - def __init__(self, name: str, device: str = 'cpu', jit: bool = False, **kwargs): + def __init__(self, name: str, device: str = "cpu", jit: bool = False, **kwargs): super().__init__(name, **kwargs) self._mclip_model = MultilingualCLIP.from_pretrained(name) self._mclip_model.to(device=device) @@ -60,10 +60,10 @@ def __init__(self, name: str, device: str = 'cpu', jit: bool = False, **kwargs): @staticmethod def get_model_name(name: str): - return _CLIP_MODEL_MAPS[name].split('::')[0] + return _CLIP_MODEL_MAPS[name].split("::")[0] def encode_text( - self, input_ids: 'torch.Tensor', attention_mask: 'torch.Tensor', **kwargs + self, input_ids: "torch.Tensor", attention_mask: "torch.Tensor", **kwargs ): return self._mclip_model( input_ids=input_ids, attention_mask=attention_mask, **kwargs diff --git a/server/clip_server/model/model.py b/server/clip_server/model/model.py index 77a55e22b..eb8a03411 100644 --- 
a/server/clip_server/model/model.py +++ b/server/clip_server/model/model.py @@ -126,10 +126,10 @@ class CLIPVisionCfg: False # use (imagenet) pretrained weights for named model ) timm_pool: str = ( - 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') + "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') ) timm_proj: str = ( - 'linear' # linear projection for timm model output ('linear', 'mlp', '') + "linear" # linear projection for timm model output ('linear', 'mlp', '') ) timm_proj_bias: bool = False # enable bias final projection @@ -145,8 +145,8 @@ class CLIPTextCfg: hf_model_name: str = None hf_tokenizer_name: str = None hf_model_pretrained: bool = True - proj: str = 'mlp' - pooler_type: str = 'mean_pooler' + proj: str = "mlp" + pooler_type: str = "mean_pooler" def _build_vision_tower( @@ -293,7 +293,7 @@ def __init__( self.positional_embedding = text.positional_embedding self.ln_final = text.ln_final self.text_projection = text.text_projection - self.register_buffer('attn_mask', text.attn_mask, persistent=False) + self.register_buffer("attn_mask", text.attn_mask, persistent=False) self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) @@ -329,13 +329,13 @@ def _convert_weights(l): convert_weights_to_fp16 = convert_weights_to_lp # backwards compat -def load_state_dict(checkpoint_path: str, map_location='cpu'): +def load_state_dict(checkpoint_path: str, map_location="cpu"): checkpoint = torch.load(checkpoint_path, map_location=map_location) - if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: - state_dict = checkpoint['state_dict'] + if isinstance(checkpoint, dict) and "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] else: state_dict = checkpoint - if next(iter(state_dict.items()))[0].startswith('module'): + if next(iter(state_dict.items()))[0].startswith("module"): state_dict = {k[7:]: v for k, v in state_dict.items()} return state_dict @@ -429,7 +429,7 @@ def build_model_from_openai_state_dict( def load_openai_model( model_path: str, - device: Union[str, torch.device] = 'cuda' if torch.cuda.is_available() else 'cpu', + device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", dtype: Optional[Union[str, torch.dtype]] = None, jit: bool = True, ): @@ -453,10 +453,10 @@ def load_openai_model( A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input """ if isinstance(dtype, str): - dtype = __cast_dtype__.get(dtype, 'amp') + dtype = __cast_dtype__.get(dtype, "amp") elif dtype is None: dtype = ( - torch.float32 if device in ('cpu', torch.device('cpu')) else torch.float16 + torch.float32 if device in ("cpu", torch.device("cpu")) else torch.float16 ) try: # loading JIT archive @@ -484,7 +484,7 @@ def load_openai_model( # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use model = model.to(device) if dtype == torch.float32 or ( - isinstance(dtype, str) and dtype.startswith('amp') + isinstance(dtype, str) and dtype.startswith("amp") ): model.float() elif dtype == torch.bfloat16: @@ -562,7 +562,7 @@ def patch_float(module): def load_openclip_model( model_name: str, model_path: str, - device: Union[str, torch.device] = 'cpu', + device: Union[str, torch.device] = "cpu", jit: bool = False, force_quick_gelu: bool = False, force_custom_text: bool = False, @@ -573,35 +573,35 @@ def load_openclip_model( dtype = __cast_dtype__.get(dtype) elif dtype is None: dtype = ( - torch.float32 if 
device in ('cpu', torch.device('cpu')) else torch.float16 + torch.float32 if device in ("cpu", torch.device("cpu")) else torch.float16 ) model_name = model_name.replace( - '/', '-' + "/", "-" ) # for callers using old naming with / in ViT names if model_name in _MODEL_CONFIGS: model_cfg = deepcopy(_MODEL_CONFIGS[model_name]) else: - raise RuntimeError(f'Model config for {model_name} not found.') + raise RuntimeError(f"Model config for {model_name} not found.") if force_quick_gelu: # override for use of QuickGELU on non-OpenAI transformer models model_cfg["quick_gelu"] = True if pretrained_image: - if 'timm_model_name' in model_cfg.get('vision_cfg', {}): + if "timm_model_name" in model_cfg.get("vision_cfg", {}): # pretrained weight loading for timm models set via vision_cfg - model_cfg['vision_cfg']['timm_model_pretrained'] = True + model_cfg["vision_cfg"]["timm_model_pretrained"] = True else: assert ( False - ), 'pretrained image towers currently only supported for timm models' + ), "pretrained image towers currently only supported for timm models" custom_text = ( - model_cfg.pop('custom_text', False) + model_cfg.pop("custom_text", False) or force_custom_text - or ('hf_model_name' in model_cfg['text_cfg']) + or ("hf_model_name" in model_cfg["text_cfg"]) ) if custom_text: diff --git a/server/clip_server/model/openclip_model.py b/server/clip_server/model/openclip_model.py index 2091be33d..a02f31d17 100644 --- a/server/clip_server/model/openclip_model.py +++ b/server/clip_server/model/openclip_model.py @@ -16,25 +16,25 @@ class OpenCLIPModel(CLIPModel): def __init__( self, name: str, - device: str = 'cpu', + device: str = "cpu", jit: bool = False, dtype: str = None, **kwargs ): super().__init__(name, **kwargs) - if '::' in name: - model_name, pretrained = name.split('::') + if "::" in name: + model_name, pretrained = name.split("::") else: model_name = name - pretrained = 'openai' + pretrained = "openai" self._model_name = model_name model_url, md5sum = get_model_url_md5(name) model_path = download_model(model_url, md5sum=md5sum) - if pretrained == 'openai': + if pretrained == "openai": self._model = load_openai_model( model_path=model_path, device=device, jit=jit, dtype=dtype ) @@ -49,16 +49,16 @@ def __init__( @staticmethod def get_model_name(name: str): - if '::' in name: - model_name, pretrained = name.split('::') + if "::" in name: + model_name, pretrained = name.split("::") else: model_name = name - if model_name == 'ViT-L/14@336px': - return 'ViT-L-14-336' - return model_name.replace('/', '-') + if model_name == "ViT-L/14@336px": + return "ViT-L-14-336" + return model_name.replace("/", "-") - def encode_text(self, input_ids: 'torch.Tensor', **kwargs): + def encode_text(self, input_ids: "torch.Tensor", **kwargs): return self._model.encode_text(input_ids) - def encode_image(self, pixel_values: 'torch.Tensor', **kwargs): + def encode_image(self, pixel_values: "torch.Tensor", **kwargs): return self._model.encode_image(pixel_values) diff --git a/server/clip_server/model/pretrained_models.py b/server/clip_server/model/pretrained_models.py index ecb46c747..222cec39c 100644 --- a/server/clip_server/model/pretrained_models.py +++ b/server/clip_server/model/pretrained_models.py @@ -5,136 +5,136 @@ import requests -_OPENCLIP_S3_BUCKET = 'https://clip-as-service.s3.us-east-2.amazonaws.com/models/torch' -_OPENCLIP_HUGGINGFACE_BUCKET = 'https://huggingface.co/jinaai/' +_OPENCLIP_S3_BUCKET = "https://clip-as-service.s3.us-east-2.amazonaws.com/models/torch" +_OPENCLIP_HUGGINGFACE_BUCKET = 
"https://huggingface.co/jinaai/" _OPENCLIP_MODELS = { - 'RN50::openai': ('RN50.pt', '9140964eaaf9f68c95aa8df6ca13777c'), - 'RN50::yfcc15m': ('RN50-yfcc15m.pt', 'e9c564f91ae7dc754d9043fdcd2a9f22'), - 'RN50::cc12m': ('RN50-cc12m.pt', '37cb01eb52bb6efe7666b1ff2d7311b5'), - 'RN101::openai': ('RN101.pt', 'fa9d5f64ebf152bc56a18db245071014'), - 'RN101::yfcc15m': ('RN101-yfcc15m.pt', '48f7448879ce25e355804f6bb7928cb8'), - 'RN50x4::openai': ('RN50x4.pt', '03830990bc768e82f7fb684cde7e5654'), - 'RN50x16::openai': ('RN50x16.pt', '83d63878a818c65d0fb417e5fab1e8fe'), - 'RN50x64::openai': ('RN50x64.pt', 'a6631a0de003c4075d286140fc6dd637'), - 'ViT-B-32::openai': ('ViT-B-32.pt', '3ba34e387b24dfe590eeb1ae6a8a122b'), - 'ViT-B-32::laion2b_e16': ( - 'ViT-B-32-laion2b_e16.pt', - 'df08de3d9f2dc53c71ea26e184633902', - ), - 'ViT-B-32::laion400m_e31': ( - 'ViT-B-32-laion400m_e31.pt', - 'ca8015f98ab0f8780510710681d7b73e', - ), - 'ViT-B-32::laion400m_e32': ( - 'ViT-B-32-laion400m_e32.pt', - '359e0dba4a419f175599ee0c63a110d8', - ), - 'ViT-B-32::laion2b-s34b-b79k': ( - 'ViT-B-32-laion2b-s34b-b79k.bin', - '2fc036aea9cd7306f5ce7ce6abb8d0bf', - ), - 'ViT-B-16::openai': ('ViT-B-16.pt', '44c3d804ecac03d9545ac1a3adbca3a6'), - 'ViT-B-16::laion400m_e31': ( - 'ViT-B-16-laion400m_e31.pt', - '31306a44224cc46fec1bc3b82fd0c4e6', - ), - 'ViT-B-16::laion400m_e32': ( - 'ViT-B-16-laion400m_e32.pt', - '07283adc5c17899f2ed22d82b563c54b', - ), - 'ViT-B-16-plus-240::laion400m_e31': ( - 'ViT-B-16-plus-240-laion400m_e31.pt', - 'c88f453644a998ecb094d878a2f0738d', - ), - 'ViT-B-16-plus-240::laion400m_e32': ( - 'ViT-B-16-plus-240-laion400m_e32.pt', - 'e573af3cef888441241e35022f30cc95', - ), - 'ViT-L-14::openai': ('ViT-L-14.pt', '096db1af569b284eb76b3881534822d9'), - 'ViT-L-14::laion400m_e31': ( - 'ViT-L-14-laion400m_e31.pt', - '09d223a6d41d2c5c201a9da618d833aa', - ), - 'ViT-L-14::laion400m_e32': ( - 'ViT-L-14-laion400m_e32.pt', - 'a76cde1bc744ca38c6036b920c847a89', - ), - 'ViT-L-14::laion2b-s32b-b82k': ( - 'ViT-L-14-laion2b-s32b-b82k.bin', - '4d2275fc7b2d7ee9db174f9b57ddecbd', - ), - 'ViT-L-14-336::openai': ('ViT-L-14-336px.pt', 'b311058cae50cb10fbfa2a44231c9473'), - 'ViT-H-14::laion2b-s32b-b79k': ( - 'ViT-H-14-laion2b-s32b-b79k.bin', - '2aa6c46521b165a0daeb8cdc6668c7d3', - ), - 'ViT-g-14::laion2b-s12b-b42k': ( - 'ViT-g-14-laion2b-s12b-b42k.bin', - '3bf99353f6f1829faac0bb155be4382a', - ), - 'roberta-ViT-B-32::laion2b-s12b-b32k': ( - 'roberta-ViT-B-32-laion2b-s12b-b32k.bin', - '76d4c9d13774cc15fa0e2b1b94a8402c', - ), - 'xlm-roberta-base-ViT-B-32::laion5b-s13b-b90k': ( - 'xlm-roberta-base-ViT-B-32-laion5b-s13b-b90k.bin', - 'f68abc07ef349720f1f880180803142d', - ), - 'xlm-roberta-large-ViT-H-14::frozen_laion5b_s13b_b90k': ( - 'xlm-roberta-large-ViT-H-14-frozen_laion5b_s13b_b90k.bin', - 'b49991239a419d704fdba59c42d5536d', + "RN50::openai": ("RN50.pt", "9140964eaaf9f68c95aa8df6ca13777c"), + "RN50::yfcc15m": ("RN50-yfcc15m.pt", "e9c564f91ae7dc754d9043fdcd2a9f22"), + "RN50::cc12m": ("RN50-cc12m.pt", "37cb01eb52bb6efe7666b1ff2d7311b5"), + "RN101::openai": ("RN101.pt", "fa9d5f64ebf152bc56a18db245071014"), + "RN101::yfcc15m": ("RN101-yfcc15m.pt", "48f7448879ce25e355804f6bb7928cb8"), + "RN50x4::openai": ("RN50x4.pt", "03830990bc768e82f7fb684cde7e5654"), + "RN50x16::openai": ("RN50x16.pt", "83d63878a818c65d0fb417e5fab1e8fe"), + "RN50x64::openai": ("RN50x64.pt", "a6631a0de003c4075d286140fc6dd637"), + "ViT-B-32::openai": ("ViT-B-32.pt", "3ba34e387b24dfe590eeb1ae6a8a122b"), + "ViT-B-32::laion2b_e16": ( + "ViT-B-32-laion2b_e16.pt", + 
"df08de3d9f2dc53c71ea26e184633902", + ), + "ViT-B-32::laion400m_e31": ( + "ViT-B-32-laion400m_e31.pt", + "ca8015f98ab0f8780510710681d7b73e", + ), + "ViT-B-32::laion400m_e32": ( + "ViT-B-32-laion400m_e32.pt", + "359e0dba4a419f175599ee0c63a110d8", + ), + "ViT-B-32::laion2b-s34b-b79k": ( + "ViT-B-32-laion2b-s34b-b79k.bin", + "2fc036aea9cd7306f5ce7ce6abb8d0bf", + ), + "ViT-B-16::openai": ("ViT-B-16.pt", "44c3d804ecac03d9545ac1a3adbca3a6"), + "ViT-B-16::laion400m_e31": ( + "ViT-B-16-laion400m_e31.pt", + "31306a44224cc46fec1bc3b82fd0c4e6", + ), + "ViT-B-16::laion400m_e32": ( + "ViT-B-16-laion400m_e32.pt", + "07283adc5c17899f2ed22d82b563c54b", + ), + "ViT-B-16-plus-240::laion400m_e31": ( + "ViT-B-16-plus-240-laion400m_e31.pt", + "c88f453644a998ecb094d878a2f0738d", + ), + "ViT-B-16-plus-240::laion400m_e32": ( + "ViT-B-16-plus-240-laion400m_e32.pt", + "e573af3cef888441241e35022f30cc95", + ), + "ViT-L-14::openai": ("ViT-L-14.pt", "096db1af569b284eb76b3881534822d9"), + "ViT-L-14::laion400m_e31": ( + "ViT-L-14-laion400m_e31.pt", + "09d223a6d41d2c5c201a9da618d833aa", + ), + "ViT-L-14::laion400m_e32": ( + "ViT-L-14-laion400m_e32.pt", + "a76cde1bc744ca38c6036b920c847a89", + ), + "ViT-L-14::laion2b-s32b-b82k": ( + "ViT-L-14-laion2b-s32b-b82k.bin", + "4d2275fc7b2d7ee9db174f9b57ddecbd", + ), + "ViT-L-14-336::openai": ("ViT-L-14-336px.pt", "b311058cae50cb10fbfa2a44231c9473"), + "ViT-H-14::laion2b-s32b-b79k": ( + "ViT-H-14-laion2b-s32b-b79k.bin", + "2aa6c46521b165a0daeb8cdc6668c7d3", + ), + "ViT-g-14::laion2b-s12b-b42k": ( + "ViT-g-14-laion2b-s12b-b42k.bin", + "3bf99353f6f1829faac0bb155be4382a", + ), + "roberta-ViT-B-32::laion2b-s12b-b32k": ( + "roberta-ViT-B-32-laion2b-s12b-b32k.bin", + "76d4c9d13774cc15fa0e2b1b94a8402c", + ), + "xlm-roberta-base-ViT-B-32::laion5b-s13b-b90k": ( + "xlm-roberta-base-ViT-B-32-laion5b-s13b-b90k.bin", + "f68abc07ef349720f1f880180803142d", + ), + "xlm-roberta-large-ViT-H-14::frozen_laion5b_s13b_b90k": ( + "xlm-roberta-large-ViT-H-14-frozen_laion5b_s13b_b90k.bin", + "b49991239a419d704fdba59c42d5536d", ), # older version name format - 'RN50': ('RN50.pt', '9140964eaaf9f68c95aa8df6ca13777c'), - 'RN101': ('RN101.pt', 'fa9d5f64ebf152bc56a18db245071014'), - 'RN50x4': ('RN50x4.pt', '03830990bc768e82f7fb684cde7e5654'), - 'RN50x16': ('RN50x16.pt', '83d63878a818c65d0fb417e5fab1e8fe'), - 'RN50x64': ('RN50x64.pt', 'a6631a0de003c4075d286140fc6dd637'), - 'ViT-B/32': ('ViT-B-32.pt', '3ba34e387b24dfe590eeb1ae6a8a122b'), - 'ViT-B/16': ('ViT-B-16.pt', '44c3d804ecac03d9545ac1a3adbca3a6'), - 'ViT-L/14': ('ViT-L-14.pt', '096db1af569b284eb76b3881534822d9'), - 'ViT-L/14@336px': ('ViT-L-14-336px.pt', 'b311058cae50cb10fbfa2a44231c9473'), + "RN50": ("RN50.pt", "9140964eaaf9f68c95aa8df6ca13777c"), + "RN101": ("RN101.pt", "fa9d5f64ebf152bc56a18db245071014"), + "RN50x4": ("RN50x4.pt", "03830990bc768e82f7fb684cde7e5654"), + "RN50x16": ("RN50x16.pt", "83d63878a818c65d0fb417e5fab1e8fe"), + "RN50x64": ("RN50x64.pt", "a6631a0de003c4075d286140fc6dd637"), + "ViT-B/32": ("ViT-B-32.pt", "3ba34e387b24dfe590eeb1ae6a8a122b"), + "ViT-B/16": ("ViT-B-16.pt", "44c3d804ecac03d9545ac1a3adbca3a6"), + "ViT-L/14": ("ViT-L-14.pt", "096db1af569b284eb76b3881534822d9"), + "ViT-L/14@336px": ("ViT-L-14-336px.pt", "b311058cae50cb10fbfa2a44231c9473"), } _MULTILINGUALCLIP_MODELS = { - 'M-CLIP/XLM-Roberta-Large-Vit-B-32': (), - 'M-CLIP/XLM-Roberta-Large-Vit-L-14': (), - 'M-CLIP/XLM-Roberta-Large-Vit-B-16Plus': (), - 'M-CLIP/LABSE-Vit-L-14': (), + "M-CLIP/XLM-Roberta-Large-Vit-B-32": (), + "M-CLIP/XLM-Roberta-Large-Vit-L-14": (), + 
"M-CLIP/XLM-Roberta-Large-Vit-B-16Plus": (), + "M-CLIP/LABSE-Vit-L-14": (), } _CNCLIP_MODELS = { - 'CN-CLIP/ViT-B-16': (), - 'CN-CLIP/ViT-L-14': (), - 'CN-CLIP/ViT-L-14-336': (), - 'CN-CLIP/ViT-H-14': (), - 'CN-CLIP/RN50': (), + "CN-CLIP/ViT-B-16": (), + "CN-CLIP/ViT-L-14": (), + "CN-CLIP/ViT-L-14-336": (), + "CN-CLIP/ViT-H-14": (), + "CN-CLIP/RN50": (), } _VISUAL_MODEL_IMAGE_SIZE = { - 'RN50': 224, - 'RN101': 224, - 'RN50x4': 288, - 'RN50x16': 384, - 'RN50x64': 448, - 'ViT-B-32': 224, - 'roberta-ViT-B-32': 224, - 'xlm-roberta-base-ViT-B-32': 224, - 'ViT-B-16': 224, - 'Vit-B-16Plus': 240, - 'ViT-B-16-plus-240': 240, - 'ViT-L-14': 224, - 'ViT-L-14-336': 336, - 'ViT-H-14': 224, - 'xlm-roberta-large-ViT-H-14': 224, - 'ViT-g-14': 224, + "RN50": 224, + "RN101": 224, + "RN50x4": 288, + "RN50x16": 384, + "RN50x64": 448, + "ViT-B-32": 224, + "roberta-ViT-B-32": 224, + "xlm-roberta-base-ViT-B-32": 224, + "ViT-B-16": 224, + "Vit-B-16Plus": 240, + "ViT-B-16-plus-240": 240, + "ViT-L-14": 224, + "ViT-L-14-336": 336, + "ViT-H-14": 224, + "xlm-roberta-large-ViT-H-14": 224, + "ViT-g-14": 224, } def md5file(filename: str): hash_md5 = hashlib.md5() - with open(filename, 'rb') as f: - for chunk in iter(lambda: f.read(4096), b''): + with open(filename, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): hash_md5.update(chunk) return hash_md5.hexdigest() @@ -145,22 +145,28 @@ def get_model_url_md5(name: str): if len(model_pretrained) == 0: # not on s3 return None, None else: - hg_download_url = _OPENCLIP_HUGGINGFACE_BUCKET + name.split('::')[0] + '/resolve/main/' + model_pretrained[0] + '?download=true' - try: - response = requests.head(hg_download_url) - if response.status_code in [200, 302] : + hg_download_url = ( + _OPENCLIP_HUGGINGFACE_BUCKET + + name.split("::")[0] + + "/resolve/main/" + + model_pretrained[0] + + "?download=true" + ) + try: + response = requests.head(hg_download_url) + if response.status_code in [200, 302]: return (hg_download_url, model_pretrained[1]) - else: - print(f'Model not found on hugging face, trying to download from s3.') + else: + print(f"Model not found on hugging face, trying to download from s3.") except requests.exceptions.RequestException as e: print(str(e)) - print(f'Model not found on hugging face, trying to download from s3.') - return (_OPENCLIP_S3_BUCKET + '/' + model_pretrained[0], model_pretrained[1]) + print(f"Model not found on hugging face, trying to download from s3.") + return (_OPENCLIP_S3_BUCKET + "/" + model_pretrained[0], model_pretrained[1]) def download_model( url: str, - target_folder: str = os.path.expanduser('~/.cache/clip'), + target_folder: str = os.path.expanduser("~/.cache/clip"), md5sum: str = None, with_resume: bool = True, max_attempts: int = 3, @@ -172,7 +178,7 @@ def download_model( if os.path.exists(download_target): if not os.path.isfile(download_target): - raise FileExistsError(f'{download_target} exists and is not a regular file') + raise FileExistsError(f"{download_target} exists and is not a regular file") actual_md5sum = md5file(download_target) if (not md5sum) or actual_md5sum == md5sum: @@ -187,33 +193,33 @@ def download_model( ) progress = Progress( - ' \n', # divide this bar from Flow's bar - TextColumn('[bold blue]{task.fields[filename]}', justify='right'), - '[progress.percentage]{task.percentage:>3.1f}%', - '•', + " \n", # divide this bar from Flow's bar + TextColumn("[bold blue]{task.fields[filename]}", justify="right"), + "[progress.percentage]{task.percentage:>3.1f}%", + "•", DownloadColumn(), - '•', + "•", 
TransferSpeedColumn(), - '•', + "•", TimeRemainingColumn(), ) with progress: - task = progress.add_task('download', filename=filename, start=False) + task = progress.add_task("download", filename=filename, start=False) for _ in range(max_attempts): - tmp_file_path = download_target + '.part' + tmp_file_path = download_target + ".part" resume_byte_pos = ( os.path.getsize(tmp_file_path) if os.path.exists(tmp_file_path) else 0 ) try: # resolve the 403 error by passing a valid user-agent - req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'}) + req = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0"}) total_bytes = int( - urllib.request.urlopen(req).info().get('Content-Length', -1) + urllib.request.urlopen(req).info().get("Content-Length", -1) ) - mode = 'ab' if (with_resume and resume_byte_pos) else 'wb' + mode = "ab" if (with_resume and resume_byte_pos) else "wb" with open(tmp_file_path, mode) as output: progress.update(task, total=total_bytes) @@ -221,7 +227,7 @@ def download_model( if resume_byte_pos and with_resume: progress.update(task, advance=resume_byte_pos) - req.headers['Range'] = f'bytes={resume_byte_pos}-' + req.headers["Range"] = f"bytes={resume_byte_pos}-" with urllib.request.urlopen(req) as source: while True: @@ -239,15 +245,15 @@ def download_model( else: os.remove(tmp_file_path) raise RuntimeError( - f'MD5 mismatch: expected {md5sum}, got {actual_md5}' + f"MD5 mismatch: expected {md5sum}, got {actual_md5}" ) except Exception as ex: progress.console.print( - f'Failed to download {url} with {ex!r} at the {_}th attempt' + f"Failed to download {url} with {ex!r} at the {_}th attempt" ) progress.reset(task) raise RuntimeError( - f'Failed to download {url} within retry limit {max_attempts}' + f"Failed to download {url} within retry limit {max_attempts}" ) diff --git a/server/clip_server/model/simple_tokenizer.py b/server/clip_server/model/simple_tokenizer.py index a5f6a5478..8e342ac8c 100644 --- a/server/clip_server/model/simple_tokenizer.py +++ b/server/clip_server/model/simple_tokenizer.py @@ -13,7 +13,7 @@ @lru_cache() def default_bpe(): - return os.path.join(__resources_path__, 'bpe_simple_vocab_16e6.txt.gz') + return os.path.join(__resources_path__, "bpe_simple_vocab_16e6.txt.gz") @lru_cache() @@ -62,7 +62,7 @@ def basic_clean(text): def whitespace_clean(text): - text = re.sub(r'\s+', ' ', text) + text = re.sub(r"\s+", " ", text) text = text.strip() return text @@ -71,20 +71,20 @@ class SimpleTokenizer(object): def __init__(self, bpe_path: str = default_bpe()): self.byte_encoder = bytes_to_unicode() self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} - merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = gzip.open(bpe_path).read().decode("utf-8").split("\n") merges = merges[1 : 49152 - 256 - 2 + 1] merges = [tuple(merge.split()) for merge in merges] vocab = list(bytes_to_unicode().values()) - vocab = vocab + [v + '' for v in vocab] + vocab = vocab + [v + "" for v in vocab] for merge in merges: - vocab.append(''.join(merge)) - vocab.extend(['<|startoftext|>', '<|endoftext|>']) + vocab.append("".join(merge)) + vocab.extend(["<|startoftext|>", "<|endoftext|>"]) self.encoder = dict(zip(vocab, range(len(vocab)))) self.decoder = {v: k for k, v in self.encoder.items()} self.bpe_ranks = dict(zip(merges, range(len(merges)))) self.cache = { - '<|startoftext|>': '<|startoftext|>', - '<|endoftext|>': '<|endoftext|>', + "<|startoftext|>": "<|startoftext|>", + "<|endoftext|>": "<|endoftext|>", } self.pat = re.compile( 
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", @@ -94,14 +94,14 @@ def __init__(self, bpe_path: str = default_bpe()): def bpe(self, token): if token in self.cache: return self.cache[token] - word = tuple(token[:-1]) + (token[-1] + '',) + word = tuple(token[:-1]) + (token[-1] + "",) pairs = get_pairs(word) if not pairs: - return token + '' + return token + "" while True: - bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) if bigram not in self.bpe_ranks: break first, second = bigram @@ -128,7 +128,7 @@ def bpe(self, token): break else: pairs = get_pairs(word) - word = ' '.join(word) + word = " ".join(word) self.cache[token] = word return word @@ -136,17 +136,17 @@ def encode(self, text): bpe_tokens = [] text = whitespace_clean(basic_clean(text)).lower() for token in re.findall(self.pat, text): - token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) bpe_tokens.extend( - self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ') + self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") ) return bpe_tokens def decode(self, tokens): - text = ''.join([self.decoder[token] for token in tokens]) + text = "".join([self.decoder[token] for token in tokens]) text = ( bytearray([self.byte_decoder[c] for c in text]) - .decode('utf-8', errors='replace') - .replace('', ' ') + .decode("utf-8", errors="replace") + .replace("", " ") ) return text diff --git a/server/clip_server/model/tokenization.py b/server/clip_server/model/tokenization.py index 605d571f0..3da29e6cf 100644 --- a/server/clip_server/model/tokenization.py +++ b/server/clip_server/model/tokenization.py @@ -57,13 +57,13 @@ def _tokenize( texts, max_length=context_length, return_attention_mask=True, - return_tensors='pt', + return_tensors="pt", padding=True, truncation=True, ) return { - 'input_ids': result['input_ids'], - 'attention_mask': result['attention_mask'], + "input_ids": result["input_ids"], + "attention_mask": result["attention_mask"], } elif self._name in _CNCLIP_MODELS: result = self._tokenizer.tokenize( @@ -77,8 +77,8 @@ def _tokenize( "attention_mask": attn_mask, } else: - sot_token = self._tokenizer.encoder['<|startoftext|>'] - eot_token = self._tokenizer.encoder['<|endoftext|>'] + sot_token = self._tokenizer.encoder["<|startoftext|>"] + eot_token = self._tokenizer.encoder["<|endoftext|>"] all_tokens = [ [sot_token] + self._tokenizer.encode(text) + [eot_token] for text in texts @@ -96,9 +96,9 @@ def _tokenize( tokens[-1] = eot_token else: raise RuntimeError( - f'Input {texts[i]} is too long for context length {context_length}' + f"Input {texts[i]} is too long for context length {context_length}" ) input_ids[i, : len(tokens)] = torch.tensor(tokens) attention_mask[i, : len(tokens)] = 1 - return {'input_ids': input_ids, 'attention_mask': attention_mask} + return {"input_ids": input_ids, "attention_mask": attention_mask} diff --git a/server/setup.py b/server/setup.py index a16cfda25..ba09a078b 100644 --- a/server/setup.py +++ b/server/setup.py @@ -4,99 +4,99 @@ from setuptools import find_packages, setup if sys.version_info < (3, 7, 0): - raise OSError(f'CLIP-as-service requires Python >=3.7, but yours is {sys.version}') + raise OSError(f"CLIP-as-service requires Python >=3.7, but yours is {sys.version}") try: - pkg_name = 'clip-server' + pkg_name = "clip-server" libinfo_py 
= path.join( - path.dirname(__file__), pkg_name.replace('-', '_'), '__init__.py' + path.dirname(__file__), pkg_name.replace("-", "_"), "__init__.py" ) - libinfo_content = open(libinfo_py, 'r', encoding='utf8').readlines() - version_line = [l.strip() for l in libinfo_content if l.startswith('__version__')][ + libinfo_content = open(libinfo_py, "r", encoding="utf8").readlines() + version_line = [l.strip() for l in libinfo_content if l.startswith("__version__")][ 0 ] exec(version_line) # gives __version__ except FileNotFoundError: - __version__ = '0.0.0' + __version__ = "0.0.0" try: - with open('../README.md', encoding='utf8') as fp: + with open("../README.md", encoding="utf8") as fp: _long_description = fp.read() except FileNotFoundError: - _long_description = '' + _long_description = "" setup( name=pkg_name, packages=find_packages(), version=__version__, include_package_data=True, - description='Embed images and sentences into fixed-length vectors via CLIP', - author='Jina AI', - author_email='hello@jina.ai', - license='Apache 2.0', - url='https://github.com/jina-ai/clip-as-service', - download_url='https://github.com/jina-ai/clip-as-service/tags', + description="Embed images and sentences into fixed-length vectors via CLIP", + author="Jina AI", + author_email="hello@jina.ai", + license="Apache 2.0", + url="https://github.com/jina-ai/clip-as-service", + download_url="https://github.com/jina-ai/clip-as-service/tags", long_description=_long_description, - long_description_content_type='text/markdown', + long_description_content_type="text/markdown", zip_safe=False, - setup_requires=['setuptools>=18.0', 'wheel'], + setup_requires=["setuptools>=18.0", "wheel"], install_requires=[ - 'ftfy', - 'torch', - 'regex', - 'torchvision<=0.13.0' if sys.version_info <= (3, 7, 2) else 'torchvision', - 'jina>=3.12.0', - 'docarray==0.21.0', - 'prometheus-client', - 'open_clip_torch>=2.8.0,<2.9.0', - 'pillow-avif-plugin', + "ftfy", + "torch", + "regex", + "torchvision<=0.13.0" if sys.version_info <= (3, 7, 2) else "torchvision", + "jina>=3.12.0", + "docarray==0.21.0", + "prometheus-client", + "open_clip_torch>=2.8.0,<2.9.0", + "pillow-avif-plugin", ], extras_require={ - 'onnx': [ - 'onnx', - 'onnxmltools<1.12.0', + "onnx": [ + "onnx", + "onnxmltools<1.12.0", ] + ( - ['onnxruntime-gpu<=1.13.1'] - if sys.platform != 'darwin' - else ['onnxruntime<=1.13.1'] + ["onnxruntime-gpu<=1.13.1"] + if sys.platform != "darwin" + else ["onnxruntime<=1.13.1"] ), - 'tensorrt': [ - 'nvidia-tensorrt==8.4.1.5', + "tensorrt": [ + "nvidia-tensorrt==8.4.1.5", ], - 'transformers': ['transformers>=4.16.2'], - 'search': ['annlite>=0.3.10'], - 'flash-attn': ['flash-attn'], - 'cn_clip': ['cn_clip'], + "transformers": ["transformers>=4.16.2"], + "search": ["annlite>=0.3.10"], + "flash-attn": ["flash-attn"], + "cn_clip": ["cn_clip"], }, classifiers=[ - 'Development Status :: 5 - Production/Stable', - 'Intended Audience :: Developers', - 'Intended Audience :: Education', - 'Intended Audience :: Science/Research', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Unix Shell', - 'Environment :: Console', - 'License :: OSI Approved :: Apache Software License', - 'Operating System :: OS Independent', - 'Topic :: Database :: Database Engines/Servers', - 'Topic :: Scientific/Engineering :: Artificial Intelligence', - 'Topic :: Internet :: WWW/HTTP :: Indexing/Search', - 'Topic :: 
Scientific/Engineering :: Image Recognition', - 'Topic :: Multimedia :: Video', - 'Topic :: Scientific/Engineering', - 'Topic :: Scientific/Engineering :: Mathematics', - 'Topic :: Software Development', - 'Topic :: Software Development :: Libraries', - 'Topic :: Software Development :: Libraries :: Python Modules', + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Unix Shell", + "Environment :: Console", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Topic :: Database :: Database Engines/Servers", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Internet :: WWW/HTTP :: Indexing/Search", + "Topic :: Scientific/Engineering :: Image Recognition", + "Topic :: Multimedia :: Video", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Mathematics", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", ], project_urls={ - 'Documentation': 'https://clip-as-service.jina.ai', - 'Source': 'https://github.com/jina-ai/clip-as-service/', - 'Tracker': 'https://github.com/jina-ai/clip-as-service/issues', + "Documentation": "https://clip-as-service.jina.ai", + "Source": "https://github.com/jina-ai/clip-as-service/", + "Tracker": "https://github.com/jina-ai/clip-as-service/issues", }, - keywords='jina openai clip deep-learning cross-modal multi-modal neural-search', + keywords="jina openai clip deep-learning cross-modal multi-modal neural-search", ) diff --git a/tests/__init__.py b/tests/__init__.py index fad87c7b2..839f3984b 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,3 +1,3 @@ import os -os.environ['OMP_NUM_THREADS'] = '1' +os.environ["OMP_NUM_THREADS"] = "1" diff --git a/tests/conftest.py b/tests/conftest.py index 0726beec3..c3cb34ae3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,7 @@ from jina import helper, Flow -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def port_generator(): generated_ports = set() @@ -16,10 +16,10 @@ def random_port(): return random_port -@pytest.fixture(scope='session', params=['onnx', 'torch', 'onnx_custom']) +@pytest.fixture(scope="session", params=["onnx", "torch", "onnx_custom"]) def make_flow(port_generator, request): - if request.param != 'onnx_custom': - if request.param == 'onnx': + if request.param != "onnx_custom": + if request.param == "onnx": from clip_server.executors.clip_onnx import CLIPEncoder else: from clip_server.executors.clip_torch import CLIPEncoder @@ -33,14 +33,14 @@ def make_flow(port_generator, request): name=request.param, uses=CLIPEncoder, uses_with={ - 'model_path': os.path.expanduser('~/.cache/clip/ViT-B-32-openai') + "model_path": os.path.expanduser("~/.cache/clip/ViT-B-32-openai") }, ) with f: yield f -@pytest.fixture(scope='session', params=['torch']) +@pytest.fixture(scope="session", params=["torch"]) def make_torch_flow(port_generator, request): from clip_server.executors.clip_torch import CLIPEncoder @@ -49,7 +49,7 @@ def make_torch_flow(port_generator, request): yield f -@pytest.fixture(scope='session', params=['tensorrt']) +@pytest.fixture(scope="session", params=["tensorrt"]) def 
make_trt_flow(port_generator, request): from clip_server.executors.clip_tensorrt import CLIPEncoder @@ -58,7 +58,7 @@ def make_trt_flow(port_generator, request): yield f -@pytest.fixture(params=['torch']) +@pytest.fixture(params=["torch"]) def make_search_flow(tmpdir, port_generator, request): from clip_server.executors.clip_torch import CLIPEncoder from annlite.executor import AnnLiteIndexer @@ -67,10 +67,10 @@ def make_search_flow(tmpdir, port_generator, request): Flow(port=port_generator()) .add(name=request.param, uses=CLIPEncoder) .add( - name='annlite', + name="annlite", uses=AnnLiteIndexer, workspace=tmpdir, - uses_with={'n_dim': 512}, + uses_with={"n_dim": 512}, ) ) with f: diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index 67bbbd2e7..cd7c2ff8c 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -12,33 +12,33 @@ async def another_heavylifting_job(): @pytest.mark.asyncio async def test_async_encode(make_flow): - c = Client(server=f'grpc://0.0.0.0:{make_flow.port}') + c = Client(server=f"grpc://0.0.0.0:{make_flow.port}") t1 = asyncio.create_task(another_heavylifting_job()) - t2 = asyncio.create_task(c.aencode(['hello world'] * 10)) + t2 = asyncio.create_task(c.aencode(["hello world"] * 10)) await asyncio.gather(t1, t2) assert t2.result().shape @pytest.mark.parametrize( - 'inputs', + "inputs", [ - DocumentArray([Document(text='hello, world'), Document(text='goodbye, world')]), + DocumentArray([Document(text="hello, world"), Document(text="goodbye, world")]), DocumentArray( [ Document( - uri='https://clip-as-service.jina.ai/_static/favicon.png', - text='hello, world', + uri="https://clip-as-service.jina.ai/_static/favicon.png", + text="hello, world", ), ] ), DocumentArray.from_files( - f'{os.path.dirname(os.path.abspath(__file__))}/**/*.jpg' + f"{os.path.dirname(os.path.abspath(__file__))}/**/*.jpg" ), ], ) @pytest.mark.asyncio async def test_async_docarray_preserve_original_inputs(make_flow, inputs): - c = Client(server=f'grpc://0.0.0.0:{make_flow.port}') + c = Client(server=f"grpc://0.0.0.0:{make_flow.port}") t1 = asyncio.create_task(another_heavylifting_job()) t2 = asyncio.create_task(c.aencode(inputs if not callable(inputs) else inputs())) await asyncio.gather(t1, t2) @@ -51,15 +51,15 @@ async def test_async_docarray_preserve_original_inputs(make_flow, inputs): @pytest.mark.parametrize( - 'inputs', + "inputs", [ - [Document(id=str(i), text='hello, world') for i in range(20)], - DocumentArray([Document(id=str(i), text='hello, world') for i in range(20)]), + [Document(id=str(i), text="hello, world") for i in range(20)], + DocumentArray([Document(id=str(i), text="hello, world") for i in range(20)]), ], ) @pytest.mark.asyncio async def test_async_docarray_preserve_original_order(make_flow, inputs): - c = Client(server=f'grpc://0.0.0.0:{make_flow.port}') + c = Client(server=f"grpc://0.0.0.0:{make_flow.port}") t1 = asyncio.create_task(another_heavylifting_job()) t2 = asyncio.create_task( c.aencode(inputs if not callable(inputs) else inputs(), batch_size=1) diff --git a/tests/test_client.py b/tests/test_client.py index 6a8b69726..776e0d899 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -15,7 +15,7 @@ async def aencode(self, docs, **kwargs): class Exec2(Executor): - def __init__(self, server_host: str = '', **kwargs): + def __init__(self, server_host: str = "", **kwargs): super().__init__(**kwargs) from clip_client.client import Client @@ -34,11 +34,10 @@ def foo(self, docs, **kwargs): def test_client_concurrent_requests(port_generator): - f1 = 
Flow(port=port_generator()).add(uses=Exec1) - f2 = Flow(protocol='http').add( - uses=Exec2, uses_with={'server_host': f'grpc://0.0.0.0:{f1.port}'} + f2 = Flow(protocol="http").add( + uses=Exec2, uses_with={"server_host": f"grpc://0.0.0.0:{f1.port}"} ) with f1, f2: @@ -46,18 +45,18 @@ def test_client_concurrent_requests(port_generator): from multiprocessing.pool import ThreadPool def run_post(docs): - c = jina.clients.Client(port=f2.port, protocol='http') - results = c.post(on='/', inputs=docs, request_size=2) + c = jina.clients.Client(port=f2.port, protocol="http") + results = c.post(on="/", inputs=docs, request_size=2) # assert set([d.id for d in results]) != set([d.id for d in docs]) return results def generate_docs(tag): return DocumentArray( - [Document(id=f'{tag}_{i}', text='hello') for i in range(20)] + [Document(id=f"{tag}_{i}", text="hello") for i in range(20)] ) with ThreadPool(5) as p: - results = p.map(run_post, [generate_docs(f't{k}') for k in range(5)]) + results = p.map(run_post, [generate_docs(f"t{k}") for k in range(5)]) for r in results: assert len(set([d.id[:2] for d in r])) == 1 @@ -66,29 +65,29 @@ def generate_docs(tag): def test_client_large_input(make_torch_flow): from clip_client.client import Client - inputs = ['hello' for _ in range(600)] + inputs = ["hello" for _ in range(600)] - c = Client(server=f'grpc://0.0.0.0:{make_torch_flow.port}') + c = Client(server=f"grpc://0.0.0.0:{make_torch_flow.port}") with pytest.warns(UserWarning): c.encode(inputs if not callable(inputs) else inputs()) @pytest.mark.parametrize( - 'inputs', + "inputs", [ [], DocumentArray(), ], ) -@pytest.mark.parametrize('endpoint', ['encode', 'rank', 'index', 'search']) +@pytest.mark.parametrize("endpoint", ["encode", "rank", "index", "search"]) @pytest.mark.asyncio def test_empty_input(make_torch_flow, inputs, endpoint): from clip_client.client import Client - c = Client(server=f'grpc://0.0.0.0:{make_torch_flow.port}') + c = Client(server=f"grpc://0.0.0.0:{make_torch_flow.port}") r = getattr(c, endpoint)(inputs if not callable(inputs) else inputs()) - if endpoint == 'encode': + if endpoint == "encode": if isinstance(inputs, DocumentArray): assert isinstance(r, DocumentArray) else: @@ -99,21 +98,21 @@ def test_empty_input(make_torch_flow, inputs, endpoint): @pytest.mark.parametrize( - 'inputs', + "inputs", [ [], DocumentArray(), ], ) -@pytest.mark.parametrize('endpoint', ['aencode', 'arank', 'aindex', 'asearch']) +@pytest.mark.parametrize("endpoint", ["aencode", "arank", "aindex", "asearch"]) @pytest.mark.asyncio async def test_async_empty_input(make_torch_flow, inputs, endpoint): from clip_client.client import Client - c = Client(server=f'grpc://0.0.0.0:{make_torch_flow.port}') + c = Client(server=f"grpc://0.0.0.0:{make_torch_flow.port}") r = await getattr(c, endpoint)(inputs if not callable(inputs) else inputs()) - if endpoint == 'aencode': + if endpoint == "aencode": if isinstance(inputs, DocumentArray): assert isinstance(r, DocumentArray) else: @@ -123,33 +122,33 @@ async def test_async_empty_input(make_torch_flow, inputs, endpoint): assert len(r) == 0 -@pytest.mark.parametrize('endpoint', ['encode', 'rank', 'index', 'search']) +@pytest.mark.parametrize("endpoint", ["encode", "rank", "index", "search"]) def test_wrong_input_type(make_torch_flow, endpoint): from clip_client.client import Client - c = Client(server=f'grpc://0.0.0.0:{make_torch_flow.port}') + c = Client(server=f"grpc://0.0.0.0:{make_torch_flow.port}") with pytest.raises(Exception): - getattr(c, endpoint)('hello') + getattr(c, 
endpoint)("hello") -@pytest.mark.parametrize('endpoint', ['aencode', 'arank', 'aindex', 'asearch']) +@pytest.mark.parametrize("endpoint", ["aencode", "arank", "aindex", "asearch"]) @pytest.mark.asyncio async def test_wrong_input_type(make_torch_flow, endpoint): from clip_client.client import Client - c = Client(server=f'grpc://0.0.0.0:{make_torch_flow.port}') + c = Client(server=f"grpc://0.0.0.0:{make_torch_flow.port}") with pytest.raises(Exception): - await getattr(c, endpoint)('hello') + await getattr(c, endpoint)("hello") -@pytest.mark.parametrize('endpoint', ['encode', 'rank', 'index', 'search']) +@pytest.mark.parametrize("endpoint", ["encode", "rank", "index", "search"]) @pytest.mark.slow def test_custom_on_done(make_torch_flow, mocker, endpoint): from clip_client.client import Client - c = Client(server=f'grpc://0.0.0.0:{make_torch_flow.port}') + c = Client(server=f"grpc://0.0.0.0:{make_torch_flow.port}") on_done_mock = mocker.Mock() on_error_mock = mocker.Mock() @@ -157,7 +156,7 @@ def test_custom_on_done(make_torch_flow, mocker, endpoint): r = getattr(c, endpoint)( DocumentArray( - [Document(text='hello', matches=DocumentArray([Document(text='jina')]))] + [Document(text="hello", matches=DocumentArray([Document(text="jina")]))] ), on_done=on_done_mock, on_error=on_error_mock, @@ -169,13 +168,13 @@ def test_custom_on_done(make_torch_flow, mocker, endpoint): on_always_mock.assert_called_once() -@pytest.mark.parametrize('endpoint', ['aencode', 'arank', 'aindex', 'asearch']) +@pytest.mark.parametrize("endpoint", ["aencode", "arank", "aindex", "asearch"]) @pytest.mark.slow @pytest.mark.asyncio async def test_async_custom_on_done(make_torch_flow, mocker, endpoint): from clip_client.client import Client - c = Client(server=f'grpc://0.0.0.0:{make_torch_flow.port}') + c = Client(server=f"grpc://0.0.0.0:{make_torch_flow.port}") on_done_mock = mocker.Mock() on_error_mock = mocker.Mock() @@ -183,7 +182,7 @@ async def test_async_custom_on_done(make_torch_flow, mocker, endpoint): r = await getattr(c, endpoint)( DocumentArray( - [Document(text='hello', matches=DocumentArray([Document(text='jina')]))] + [Document(text="hello", matches=DocumentArray([Document(text="jina")]))] ), on_done=on_done_mock, on_error=on_error_mock, @@ -195,7 +194,7 @@ async def test_async_custom_on_done(make_torch_flow, mocker, endpoint): on_always_mock.assert_called_once() -@pytest.mark.parametrize('endpoint', ['encode', 'rank', 'index', 'search']) +@pytest.mark.parametrize("endpoint", ["encode", "rank", "index", "search"]) @pytest.mark.slow def test_custom_on_error(port_generator, mocker, endpoint): from clip_client.client import Client @@ -203,7 +202,7 @@ def test_custom_on_error(port_generator, mocker, endpoint): f = Flow(port=port_generator()).add(uses=ErrorExec) with f: - c = Client(server=f'grpc://0.0.0.0:{f.port}') + c = Client(server=f"grpc://0.0.0.0:{f.port}") on_done_mock = mocker.Mock() on_error_mock = mocker.Mock() @@ -211,7 +210,7 @@ def test_custom_on_error(port_generator, mocker, endpoint): r = getattr(c, endpoint)( DocumentArray( - [Document(text='hello', matches=DocumentArray([Document(text='jina')]))] + [Document(text="hello", matches=DocumentArray([Document(text="jina")]))] ), on_done=on_done_mock, on_error=on_error_mock, @@ -223,7 +222,7 @@ def test_custom_on_error(port_generator, mocker, endpoint): on_always_mock.assert_called_once() -@pytest.mark.parametrize('endpoint', ['aencode', 'arank', 'aindex', 'asearch']) +@pytest.mark.parametrize("endpoint", ["aencode", "arank", "aindex", "asearch"]) 
@pytest.mark.slow @pytest.mark.asyncio async def test_async_custom_on_error(port_generator, mocker, endpoint): @@ -232,7 +231,7 @@ async def test_async_custom_on_error(port_generator, mocker, endpoint): f = Flow(port=port_generator()).add(uses=ErrorExec) with f: - c = Client(server=f'grpc://0.0.0.0:{f.port}') + c = Client(server=f"grpc://0.0.0.0:{f.port}") on_done_mock = mocker.Mock() on_error_mock = mocker.Mock() @@ -240,7 +239,7 @@ async def test_async_custom_on_error(port_generator, mocker, endpoint): r = await getattr(c, endpoint)( DocumentArray( - [Document(text='hello', matches=DocumentArray([Document(text='jina')]))] + [Document(text="hello", matches=DocumentArray([Document(text="jina")]))] ), on_done=on_done_mock, on_error=on_error_mock, diff --git a/tests/test_helper.py b/tests/test_helper.py index 9836b0da9..de0b11048 100644 --- a/tests/test_helper.py +++ b/tests/test_helper.py @@ -6,8 +6,8 @@ from docarray import Document, DocumentArray -@pytest.mark.parametrize('shape', [(5, 10), (5, 10, 10)]) -@pytest.mark.parametrize('axis', [-1, 1, 0]) +@pytest.mark.parametrize("shape", [(5, 10), (5, 10, 10)]) +@pytest.mark.parametrize("axis", [-1, 1, 0]) def test_numpy_softmax(shape, axis): import torch @@ -23,19 +23,19 @@ def test_numpy_softmax(shape, axis): @pytest.mark.parametrize( - 'inputs', + "inputs", [ ( DocumentArray( [ - Document(text='hello, world'), - Document(text='goodbye, world'), + Document(text="hello, world"), + Document(text="goodbye, world"), Document( - text='hello, world', - uri='https://clip-as-service.jina.ai/_static/favicon.png', + text="hello, world", + uri="https://clip-as-service.jina.ai/_static/favicon.png", ), Document( - uri='https://clip-as-service.jina.ai/_static/favicon.png', + uri="https://clip-as-service.jina.ai/_static/favicon.png", ), ] ), @@ -44,17 +44,17 @@ def test_numpy_softmax(shape, axis): ( DocumentArray( [ - Document(text='hello, world'), + Document(text="hello, world"), Document(tensor=np.array([0, 1, 2])), Document( - uri='https://clip-as-service.jina.ai/_static/favicon.png' + uri="https://clip-as-service.jina.ai/_static/favicon.png" ).load_uri_to_blob(), Document( tensor=np.array([0, 1, 2]), - uri='https://clip-as-service.jina.ai/_static/favicon.png', + uri="https://clip-as-service.jina.ai/_static/favicon.png", ), Document( - uri='https://clip-as-service.jina.ai/_static/favicon.png', + uri="https://clip-as-service.jina.ai/_static/favicon.png", ), ] ), @@ -63,8 +63,8 @@ def test_numpy_softmax(shape, axis): ( DocumentArray( [ - Document(text='hello, world'), - Document(uri='https://clip-as-service.jina.ai/_static/favicon.png'), + Document(text="hello, world"), + Document(uri="https://clip-as-service.jina.ai/_static/favicon.png"), ] ), (1, 1), @@ -81,12 +81,12 @@ def test_split_img_txt_da(inputs): @pytest.mark.parametrize( - 'inputs', + "inputs", [ DocumentArray( [ Document( - uri='https://clip-as-service.jina.ai/_static/favicon.png', + uri="https://clip-as-service.jina.ai/_static/favicon.png", ).load_uri_to_image_tensor(), ] ) @@ -100,4 +100,4 @@ def test_preproc_image(inputs): assert len(da) == 1 assert not da[0].blob assert not da[0].tensor - assert pixel_values.get('pixel_values') is not None + assert pixel_values.get("pixel_values") is not None diff --git a/tests/test_model.py b/tests/test_model.py index 75b870b2d..0f380128a 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -7,13 +7,13 @@ @pytest.mark.parametrize( - 'name, model_cls', + "name, model_cls", [ - ('ViT-L/14@336px', OpenCLIPModel), - ('RN50::openai', OpenCLIPModel), 
- ('roberta-ViT-B-32::laion2b-s12b-b32k', OpenCLIPModel), - ('M-CLIP/LABSE-Vit-L-14', MultilingualCLIPModel), - ('CN-CLIP/ViT-B-16', CNClipModel), + ("ViT-L/14@336px", OpenCLIPModel), + ("RN50::openai", OpenCLIPModel), + ("roberta-ViT-B-32::laion2b-s12b-b32k", OpenCLIPModel), + ("M-CLIP/LABSE-Vit-L-14", MultilingualCLIPModel), + ("CN-CLIP/ViT-B-16", CNClipModel), ], ) def test_torch_model(name, model_cls): @@ -22,11 +22,11 @@ def test_torch_model(name, model_cls): @pytest.mark.parametrize( - 'name', + "name", [ - 'RN50::openai', - 'ViT-H-14::laion2b-s32b-b79k', - 'M-CLIP/LABSE-Vit-L-14', + "RN50::openai", + "ViT-H-14::laion2b-s32b-b79k", + "M-CLIP/LABSE-Vit-L-14", ], ) def test_onnx_model(name): @@ -35,10 +35,10 @@ def test_onnx_model(name): @pytest.mark.gpu @pytest.mark.parametrize( - 'name', - ['ViT-H-14::laion2b-s32b-b79k'], + "name", + ["ViT-H-14::laion2b-s32b-b79k"], ) def test_large_onnx_model_fp16(name): from clip_server.executors.clip_onnx import CLIPEncoder - CLIPEncoder(name, dtype='fp16') + CLIPEncoder(name, dtype="fp16") diff --git a/tests/test_ranker.py b/tests/test_ranker.py index 450f66530..992913bab 100644 --- a/tests/test_ranker.py +++ b/tests/test_ranker.py @@ -10,102 +10,102 @@ @pytest.mark.asyncio -@pytest.mark.parametrize('encoder_class', [TorchCLIPEncoder, ONNXCLILPEncoder]) +@pytest.mark.parametrize("encoder_class", [TorchCLIPEncoder, ONNXCLILPEncoder]) async def test_torch_executor_rank_img2texts(encoder_class): ce = encoder_class() da = DocumentArray.from_files( - f'{os.path.dirname(os.path.abspath(__file__))}/**/*.jpg' + f"{os.path.dirname(os.path.abspath(__file__))}/**/*.jpg" ) for d in da: - d.matches.append(Document(text='hello, world!')) - d.matches.append(Document(text='goodbye, world!')) - d.matches.append(Document(text='goodbye,!')) - d.matches.append(Document(text='good world!')) - d.matches.append(Document(text='good!')) - d.matches.append(Document(text='world!')) + d.matches.append(Document(text="hello, world!")) + d.matches.append(Document(text="goodbye, world!")) + d.matches.append(Document(text="goodbye,!")) + d.matches.append(Document(text="good world!")) + d.matches.append(Document(text="good!")) + d.matches.append(Document(text="world!")) await ce.rank(da, {}) - print(da['@m', 'scores__clip_score__value']) + print(da["@m", "scores__clip_score__value"]) for d in da: for c in d.matches: - assert c.scores['clip_score'].value is not None + assert c.scores["clip_score"].value is not None assert not c.tensor - org_score = d.matches[:, 'scores__clip_score__value'] + org_score = d.matches[:, "scores__clip_score__value"] assert org_score == list(sorted(org_score, reverse=True)) assert not d.tensor @pytest.mark.asyncio -@pytest.mark.parametrize('encoder_class', [TorchCLIPEncoder, ONNXCLILPEncoder]) +@pytest.mark.parametrize("encoder_class", [TorchCLIPEncoder, ONNXCLILPEncoder]) async def test_torch_executor_rank_text2imgs(encoder_class): ce = encoder_class() db = DocumentArray( - [Document(text='hello, world!'), Document(text='goodbye, world!')] + [Document(text="hello, world!"), Document(text="goodbye, world!")] ) for d in db: d.matches.extend( DocumentArray.from_files( - f'{os.path.dirname(os.path.abspath(__file__))}/**/*.jpg' + f"{os.path.dirname(os.path.abspath(__file__))}/**/*.jpg" ) ) await ce.rank(db, {}) - print(db['@m', 'scores__clip_score__value']) + print(db["@m", "scores__clip_score__value"]) for d in db: for c in d.matches: - assert c.scores['clip_score'].value is not None - assert c.scores['clip_score_cosine'].value is not None + assert 
c.scores["clip_score"].value is not None + assert c.scores["clip_score_cosine"].value is not None assert not c.tensor np.testing.assert_almost_equal( - sum(c.scores['clip_score'].value for c in d.matches), 1 + sum(c.scores["clip_score"].value for c in d.matches), 1 ) assert not d.tensor assert not d.blob @pytest.mark.parametrize( - 'inputs', + "inputs", [ [ Document( - uri='https://clip-as-service.jina.ai/_static/favicon.png', + uri="https://clip-as-service.jina.ai/_static/favicon.png", matches=[ - Document(text='hello, world'), - Document(text='goodbye, world'), + Document(text="hello, world"), + Document(text="goodbye, world"), ], ), Document( - uri='https://clip-as-service.jina.ai/_static/favicon.png', + uri="https://clip-as-service.jina.ai/_static/favicon.png", matches=[ - Document(text='hello, world'), - Document(text='goodbye, world'), + Document(text="hello, world"), + Document(text="goodbye, world"), ], ), ], DocumentArray( [ Document( - uri='https://clip-as-service.jina.ai/_static/favicon.png', + uri="https://clip-as-service.jina.ai/_static/favicon.png", matches=[ - Document(text='hello, world'), - Document(text='goodbye, world'), + Document(text="hello, world"), + Document(text="goodbye, world"), ], ), Document( - uri='https://clip-as-service.jina.ai/_static/favicon.png', + uri="https://clip-as-service.jina.ai/_static/favicon.png", matches=[ - Document(text='hello, world'), - Document(text='goodbye, world'), + Document(text="hello, world"), + Document(text="goodbye, world"), ], ), ] ), lambda: ( Document( - uri='https://clip-as-service.jina.ai/_static/favicon.png', + uri="https://clip-as-service.jina.ai/_static/favicon.png", matches=[ - Document(text='hello, world'), - Document(text='goodbye, world'), + Document(text="hello, world"), + Document(text="goodbye, world"), ], ) for _ in range(10) @@ -113,13 +113,13 @@ async def test_torch_executor_rank_text2imgs(encoder_class): DocumentArray( [ Document( - text='hello, world', + text="hello, world", matches=[ Document( - uri='https://clip-as-service.jina.ai/_static/favicon.png' + uri="https://clip-as-service.jina.ai/_static/favicon.png" ), Document( - uri=f'{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg' + uri=f"{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg" ), ], ) @@ -128,12 +128,12 @@ async def test_torch_executor_rank_text2imgs(encoder_class): ], ) def test_docarray_inputs(make_flow, inputs): - c = Client(server=f'grpc://0.0.0.0:{make_flow.port}') + c = Client(server=f"grpc://0.0.0.0:{make_flow.port}") r = c.rank(inputs if not callable(inputs) else inputs()) assert not r[0].tensor assert isinstance(r, DocumentArray) - rv1 = r['@m', 'scores__clip_score__value'] - rv2 = r['@m', 'scores__clip_score_cosine__value'] + rv1 = r["@m", "scores__clip_score__value"] + rv2 = r["@m", "scores__clip_score_cosine__value"] for v1, v2 in zip(rv1, rv2): assert v1 is not None assert v1 > 0 @@ -142,34 +142,34 @@ def test_docarray_inputs(make_flow, inputs): @pytest.mark.parametrize( - 'inputs', + "inputs", [ [ Document( - uri='https://clip-as-service.jina.ai/_static/favicon.png', + uri="https://clip-as-service.jina.ai/_static/favicon.png", matches=[ - Document(text='hello, world'), - Document(text='goodbye, world'), + Document(text="hello, world"), + Document(text="goodbye, world"), ], ), ], DocumentArray( [ Document( - uri='https://clip-as-service.jina.ai/_static/favicon.png', + uri="https://clip-as-service.jina.ai/_static/favicon.png", matches=[ - Document(text='hello, world'), - Document(text='goodbye, world'), + 
Document(text="hello, world"), + Document(text="goodbye, world"), ], ), ] ), lambda: ( Document( - uri='https://clip-as-service.jina.ai/_static/favicon.png', + uri="https://clip-as-service.jina.ai/_static/favicon.png", matches=[ - Document(text='hello, world'), - Document(text='goodbye, world'), + Document(text="hello, world"), + Document(text="goodbye, world"), ], ) for _ in range(1) @@ -177,13 +177,13 @@ def test_docarray_inputs(make_flow, inputs): DocumentArray( [ Document( - text='hello, world', + text="hello, world", matches=[ Document( - uri='https://clip-as-service.jina.ai/_static/favicon.png' + uri="https://clip-as-service.jina.ai/_static/favicon.png" ), Document( - uri=f'{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg' + uri=f"{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg" ), ], ) @@ -193,28 +193,28 @@ def test_docarray_inputs(make_flow, inputs): ) @pytest.mark.asyncio async def test_async_arank(make_flow, inputs): - c = Client(server=f'grpc://0.0.0.0:{make_flow.port}') + c = Client(server=f"grpc://0.0.0.0:{make_flow.port}") r = await c.arank(inputs if not callable(inputs) else inputs()) assert not r[0].tensor assert isinstance(r, DocumentArray) - rv = r['@m', 'scores__clip_score__value'] + rv = r["@m", "scores__clip_score__value"] for v in rv: assert v is not None assert v > 0 np.testing.assert_almost_equal(sum(rv), 1.0) - rv = r['@m', 'scores__clip_score_cosine__value'] + rv = r["@m", "scores__clip_score_cosine__value"] for v in rv: assert v is not None assert -1.0 <= v <= 1.0 @pytest.mark.parametrize( - 'inputs', + "inputs", [ [ Document( - id=str(i), text='A', matches=[Document(text='B'), Document(text='C')] + id=str(i), text="A", matches=[Document(text="B"), Document(text="C")] ) for i in range(20) ], @@ -222,8 +222,8 @@ async def test_async_arank(make_flow, inputs): [ Document( id=str(i), - text='A', - matches=[Document(text='B'), Document(text='C')], + text="A", + matches=[Document(text="B"), Document(text="C")], ) for i in range(20) ] @@ -231,7 +231,7 @@ async def test_async_arank(make_flow, inputs): ], ) def test_docarray_preserve_original_order(make_flow, inputs): - c = Client(server=f'grpc://0.0.0.0:{make_flow.port}') + c = Client(server=f"grpc://0.0.0.0:{make_flow.port}") r = c.rank(inputs, batch_size=1) assert isinstance(r, DocumentArray) for i in range(len(inputs)): @@ -240,11 +240,11 @@ def test_docarray_preserve_original_order(make_flow, inputs): @pytest.mark.parametrize( - 'inputs', + "inputs", [ [ Document( - id=str(i), text='A', matches=[Document(text='B'), Document(text='C')] + id=str(i), text="A", matches=[Document(text="B"), Document(text="C")] ) for i in range(20) ], @@ -252,8 +252,8 @@ def test_docarray_preserve_original_order(make_flow, inputs): [ Document( id=str(i), - text='A', - matches=[Document(text='B'), Document(text='C')], + text="A", + matches=[Document(text="B"), Document(text="C")], ) for i in range(20) ] @@ -262,7 +262,7 @@ def test_docarray_preserve_original_order(make_flow, inputs): ) @pytest.mark.asyncio async def test_async_docarray_preserve_original_order(make_flow, inputs): - c = Client(server=f'grpc://0.0.0.0:{make_flow.port}') + c = Client(server=f"grpc://0.0.0.0:{make_flow.port}") r = await c.arank(inputs, batch_size=1) assert isinstance(r, DocumentArray) for i in range(len(inputs)): diff --git a/tests/test_search.py b/tests/test_search.py index b9645e96e..e8e4723dd 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -8,31 +8,31 @@ @pytest.mark.parametrize( - 'inputs', + "inputs", [ - 
[Document(text='hello, world'), Document(text='goodbye, world')], - DocumentArray([Document(text='hello, world'), Document(text='goodbye, world')]), - lambda: (Document(text='hello, world') for _ in range(10)), + [Document(text="hello, world"), Document(text="goodbye, world")], + DocumentArray([Document(text="hello, world"), Document(text="goodbye, world")]), + lambda: (Document(text="hello, world") for _ in range(10)), DocumentArray( [ - Document(uri='https://clip-as-service.jina.ai/_static/favicon.png'), + Document(uri="https://clip-as-service.jina.ai/_static/favicon.png"), Document( - uri=f'{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg' + uri=f"{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg" ), - Document(text='hello, world'), + Document(text="hello, world"), Document( - uri=f'{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg' + uri=f"{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg" ).load_uri_to_image_tensor(), ] ), DocumentArray.from_files( - f'{os.path.dirname(os.path.abspath(__file__))}/**/*.jpg' + f"{os.path.dirname(os.path.abspath(__file__))}/**/*.jpg" ), ], ) -@pytest.mark.parametrize('limit', [1, 2]) +@pytest.mark.parametrize("limit", [1, 2]) def test_index_search(make_search_flow, inputs, limit): - c = Client(server=f'grpc://0.0.0.0:{make_search_flow.port}') + c = Client(server=f"grpc://0.0.0.0:{make_search_flow.port}") r = c.index(inputs if not callable(inputs) else inputs()) assert isinstance(r, DocumentArray) @@ -45,32 +45,32 @@ def test_index_search(make_search_flow, inputs, limit): @pytest.mark.parametrize( - 'inputs', + "inputs", [ - [Document(text='hello, world'), Document(text='goodbye, world')], - DocumentArray([Document(text='hello, world'), Document(text='goodbye, world')]), - lambda: (Document(text='hello, world') for _ in range(10)), + [Document(text="hello, world"), Document(text="goodbye, world")], + DocumentArray([Document(text="hello, world"), Document(text="goodbye, world")]), + lambda: (Document(text="hello, world") for _ in range(10)), DocumentArray( [ - Document(uri='https://clip-as-service.jina.ai/_static/favicon.png'), + Document(uri="https://clip-as-service.jina.ai/_static/favicon.png"), Document( - uri=f'{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg' + uri=f"{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg" ), - Document(text='hello, world'), + Document(text="hello, world"), Document( - uri=f'{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg' + uri=f"{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg" ).load_uri_to_image_tensor(), ] ), DocumentArray.from_files( - f'{os.path.dirname(os.path.abspath(__file__))}/**/*.jpg' + f"{os.path.dirname(os.path.abspath(__file__))}/**/*.jpg" ), ], ) -@pytest.mark.parametrize('limit', [1, 2]) +@pytest.mark.parametrize("limit", [1, 2]) @pytest.mark.asyncio async def test_async_index_search(make_search_flow, inputs, limit): - c = Client(server=f'grpc://0.0.0.0:{make_search_flow.port}') + c = Client(server=f"grpc://0.0.0.0:{make_search_flow.port}") r = await c.aindex(inputs if not callable(inputs) else inputs()) assert isinstance(r, DocumentArray) diff --git a/tests/test_server.py b/tests/test_server.py index 56efa742f..d746c9f23 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -10,37 +10,37 @@ def test_server_download(tmpdir): download_model( - url='https://clip-as-service.jina.ai/_static/favicon.png', + url="https://clip-as-service.jina.ai/_static/favicon.png", target_folder=tmpdir, - 
md5sum='43104e468ddd23c55bc662d84c87a7f8', + md5sum="43104e468ddd23c55bc662d84c87a7f8", with_resume=False, ) - target_path = os.path.join(tmpdir, 'favicon.png') + target_path = os.path.join(tmpdir, "favicon.png") file_size = os.path.getsize(target_path) assert file_size > 0 - part_path = target_path + '.part' - with open(target_path, 'rb') as source, open(part_path, 'wb') as part_out: + part_path = target_path + ".part" + with open(target_path, "rb") as source, open(part_path, "wb") as part_out: buf = source.read(10) part_out.write(buf) os.remove(target_path) download_model( - url='https://clip-as-service.jina.ai/_static/favicon.png', + url="https://clip-as-service.jina.ai/_static/favicon.png", target_folder=tmpdir, - md5sum='43104e468ddd23c55bc662d84c87a7f8', + md5sum="43104e468ddd23c55bc662d84c87a7f8", with_resume=True, ) assert os.path.getsize(target_path) == file_size assert not os.path.exists(part_path) -@pytest.mark.parametrize('md5', ['ABC', None, '43104e468ddd23c55bc662d84c87a7f8']) +@pytest.mark.parametrize("md5", ["ABC", None, "43104e468ddd23c55bc662d84c87a7f8"]) def test_server_download_md5(tmpdir, md5): - if md5 != 'ABC': + if md5 != "ABC": download_model( - url='https://clip-as-service.jina.ai/_static/favicon.png', + url="https://clip-as-service.jina.ai/_static/favicon.png", target_folder=tmpdir, md5sum=md5, with_resume=False, @@ -48,7 +48,7 @@ def test_server_download_md5(tmpdir, md5): else: with pytest.raises(Exception): download_model( - url='https://clip-as-service.jina.ai/_static/favicon.png', + url="https://clip-as-service.jina.ai/_static/favicon.png", target_folder=tmpdir, md5sum=md5, with_resume=False, @@ -58,15 +58,15 @@ def test_server_download_md5(tmpdir, md5): def test_server_download_not_regular_file(tmpdir): with pytest.raises(Exception): download_model( - url='https://clip-as-service.jina.ai/_static/favicon.png', + url="https://clip-as-service.jina.ai/_static/favicon.png", target_folder=tmpdir, - md5sum='', + md5sum="", with_resume=False, ) download_model( - url='https://docarray.jina.ai/_static/', + url="https://docarray.jina.ai/_static/", target_folder=tmpdir, - md5sum='', + md5sum="", with_resume=False, ) @@ -76,21 +76,21 @@ def test_make_onnx_flow_wrong_name_path(): with pytest.raises(Exception): encoder = CLIPEncoder( - 'ABC', model_path=os.path.expanduser('~/.cache/clip/ViT-B-32') + "ABC", model_path=os.path.expanduser("~/.cache/clip/ViT-B-32") ) with pytest.raises(Exception) as info: - encoder = CLIPEncoder('ViT-B/32', model_path='~/.cache/') + encoder = CLIPEncoder("ViT-B/32", model_path="~/.cache/") @pytest.mark.parametrize( - 'image_uri', + "image_uri", [ - f'{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg', - 'https://clip-as-service.jina.ai/_static/favicon.png', + f"{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg", + "https://clip-as-service.jina.ai/_static/favicon.png", ], ) -@pytest.mark.parametrize('size', [224, 288, 384, 448]) +@pytest.mark.parametrize("size", [224, 288, 384, 448]) def test_server_preprocess_ndarray_image(image_uri, size): d1 = Document(uri=image_uri) d1.load_uri_to_blob() @@ -103,7 +103,7 @@ def test_server_preprocess_ndarray_image(image_uri, size): @pytest.mark.parametrize( - 'tensor', + "tensor", [ np.random.random([100, 100, 3]), np.random.random([1, 1, 3]), diff --git a/tests/test_simple.py b/tests/test_simple.py index 8dab51322..921d4164a 100644 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -7,42 +7,42 @@ from clip_client.client import Client -@pytest.mark.parametrize('protocol', ['grpc', 
'http', 'websocket', 'other']) -@pytest.mark.parametrize('jit', [True, False]) +@pytest.mark.parametrize("protocol", ["grpc", "http", "websocket", "other"]) +@pytest.mark.parametrize("jit", [True, False]) def test_protocols(port_generator, protocol, jit, pytestconfig): from clip_server.executors.clip_torch import CLIPEncoder - if protocol == 'other': + if protocol == "other": with pytest.raises(ValueError): - Client(server=f'{protocol}://0.0.0.0:8000') + Client(server=f"{protocol}://0.0.0.0:8000") return f = Flow(port=port_generator(), protocol=protocol).add( - uses=CLIPEncoder, uses_with={'jit': jit} + uses=CLIPEncoder, uses_with={"jit": jit} ) with f: - c = Client(server=f'{protocol}://0.0.0.0:{f.port}') + c = Client(server=f"{protocol}://0.0.0.0:{f.port}") c.profile() - c.profile(content='hello world') - c.profile(content=f'{pytestconfig.rootdir}/tests/img/00000.jpg') + c.profile(content="hello world") + c.profile(content=f"{pytestconfig.rootdir}/tests/img/00000.jpg") @pytest.mark.gpu @pytest.mark.parametrize( - 'inputs', + "inputs", [ - ['hello, world', 'goodbye, world'], - ('hello, world', 'goodbye, world'), - lambda: ('hello, world' for _ in range(10)), + ["hello, world", "goodbye, world"], + ("hello, world", "goodbye, world"), + lambda: ("hello, world" for _ in range(10)), [ - 'https://clip-as-service.jina.ai/_static/favicon.png', - f'{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg', - 'hello, world', + "https://clip-as-service.jina.ai/_static/favicon.png", + f"{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg", + "hello, world", ], ], ) def test_plain_inputs(make_flow, inputs): - c = Client(server=f'grpc://0.0.0.0:{make_flow.port}') + c = Client(server=f"grpc://0.0.0.0:{make_flow.port}") r = c.encode(inputs if not callable(inputs) else inputs()) assert ( r.shape[0] == len(list(inputs)) if not callable(inputs) else len(list(inputs())) @@ -51,57 +51,57 @@ def test_plain_inputs(make_flow, inputs): @pytest.mark.gpu @pytest.mark.parametrize( - 'inputs', + "inputs", [ - [Document(text='hello, world'), Document(text='goodbye, world')], - DocumentArray([Document(text='hello, world'), Document(text='goodbye, world')]), - lambda: (Document(text='hello, world') for _ in range(10)), + [Document(text="hello, world"), Document(text="goodbye, world")], + DocumentArray([Document(text="hello, world"), Document(text="goodbye, world")]), + lambda: (Document(text="hello, world") for _ in range(10)), DocumentArray( [ - Document(uri='https://clip-as-service.jina.ai/_static/favicon.png'), + Document(uri="https://clip-as-service.jina.ai/_static/favicon.png"), Document( - uri=f'{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg' + uri=f"{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg" ), - Document(text='hello, world'), + Document(text="hello, world"), Document( - uri=f'{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg' + uri=f"{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg" ).load_uri_to_image_tensor(), ] ), DocumentArray.from_files( - f'{os.path.dirname(os.path.abspath(__file__))}/**/*.jpg' + f"{os.path.dirname(os.path.abspath(__file__))}/**/*.jpg" ), ], ) def test_docarray_inputs(make_flow, inputs): - c = Client(server=f'grpc://0.0.0.0:{make_flow.port}') + c = Client(server=f"grpc://0.0.0.0:{make_flow.port}") r = c.encode(inputs if not callable(inputs) else inputs()) assert isinstance(r, DocumentArray) assert r.embeddings.shape assert not r[0].tensor - if hasattr(inputs, '__len__'): + if hasattr(inputs, "__len__"): assert inputs[0] is r[0] 
@pytest.mark.parametrize( - 'inputs', + "inputs", [ - DocumentArray([Document(text='hello, world'), Document(text='goodbye, world')]), + DocumentArray([Document(text="hello, world"), Document(text="goodbye, world")]), DocumentArray( [ Document( - uri='https://clip-as-service.jina.ai/_static/favicon.png', - text='hello, world', + uri="https://clip-as-service.jina.ai/_static/favicon.png", + text="hello, world", ), ] ), DocumentArray.from_files( - f'{os.path.dirname(os.path.abspath(__file__))}/**/*.jpg' + f"{os.path.dirname(os.path.abspath(__file__))}/**/*.jpg" ), ], ) def test_docarray_preserve_original_inputs(make_flow, inputs): - c = Client(server=f'grpc://0.0.0.0:{make_flow.port}') + c = Client(server=f"grpc://0.0.0.0:{make_flow.port}") r = c.encode(inputs if not callable(inputs) else inputs()) assert isinstance(r, DocumentArray) assert r.embeddings.shape @@ -111,19 +111,19 @@ def test_docarray_preserve_original_inputs(make_flow, inputs): @pytest.mark.parametrize( - 'inputs', + "inputs", [ - DocumentArray([Document(text='hello, world'), Document(text='goodbye, world')]), + DocumentArray([Document(text="hello, world"), Document(text="goodbye, world")]), DocumentArray( [ Document( - uri='https://clip-as-service.jina.ai/_static/favicon.png', - text='hello, world', + uri="https://clip-as-service.jina.ai/_static/favicon.png", + text="hello, world", ), ] ), DocumentArray.from_files( - f'{os.path.dirname(os.path.abspath(__file__))}/**/*.jpg' + f"{os.path.dirname(os.path.abspath(__file__))}/**/*.jpg" ), ], ) @@ -133,8 +133,8 @@ def test_docarray_traversal(make_flow, inputs): da = DocumentArray.empty(1) da[0].chunks = inputs - c = _Client(host=f'grpc://0.0.0.0', port=make_flow.port) - r1 = c.post(on='/', inputs=da, parameters={'traversal_paths': '@c'}) + c = _Client(host=f"grpc://0.0.0.0", port=make_flow.port) + r1 = c.post(on="/", inputs=da, parameters={"traversal_paths": "@c"}) assert isinstance(r1, DocumentArray) assert r1[0].chunks.embeddings.shape[0] == len(inputs) assert not r1[0].tensor @@ -142,7 +142,7 @@ def test_docarray_traversal(make_flow, inputs): assert not r1[0].chunks[0].tensor assert not r1[0].chunks[0].blob - r2 = c.post(on='/', inputs=da, parameters={'access_paths': '@c'}) + r2 = c.post(on="/", inputs=da, parameters={"access_paths": "@c"}) assert isinstance(r2, DocumentArray) assert r2[0].chunks.embeddings.shape[0] == len(inputs) assert not r2[0].tensor @@ -152,14 +152,14 @@ def test_docarray_traversal(make_flow, inputs): @pytest.mark.parametrize( - 'inputs', + "inputs", [ - [Document(id=str(i), text='hello, world') for i in range(20)], - DocumentArray([Document(id=str(i), text='hello, world') for i in range(20)]), + [Document(id=str(i), text="hello, world") for i in range(20)], + DocumentArray([Document(id=str(i), text="hello, world") for i in range(20)]), ], ) def test_docarray_preserve_original_order(make_flow, inputs): - c = Client(server=f'grpc://0.0.0.0:{make_flow.port}') + c = Client(server=f"grpc://0.0.0.0:{make_flow.port}") r = c.encode(inputs if not callable(inputs) else inputs(), batch_size=1) assert isinstance(r, DocumentArray) for i in range(len(inputs)): diff --git a/tests/test_tensorrt.py b/tests/test_tensorrt.py index 395413f0e..35d142851 100644 --- a/tests/test_tensorrt.py +++ b/tests/test_tensorrt.py @@ -10,69 +10,69 @@ @pytest.mark.gpu @pytest.mark.parametrize( - 'inputs', + "inputs", [ - [Document(text='hello, world'), Document(text='goodbye, world')], - DocumentArray([Document(text='hello, world'), Document(text='goodbye, world')]), - lambda: 
(Document(text='hello, world') for _ in range(10)), + [Document(text="hello, world"), Document(text="goodbye, world")], + DocumentArray([Document(text="hello, world"), Document(text="goodbye, world")]), + lambda: (Document(text="hello, world") for _ in range(10)), DocumentArray( [ - Document(uri='https://clip-as-service.jina.ai/_static/favicon.png'), + Document(uri="https://clip-as-service.jina.ai/_static/favicon.png"), Document( - uri=f'{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg' + uri=f"{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg" ), - Document(text='hello, world'), + Document(text="hello, world"), Document( - uri=f'{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg' + uri=f"{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg" ).load_uri_to_image_tensor(), ] ), DocumentArray.from_files( - f'{os.path.dirname(os.path.abspath(__file__))}/**/*.jpg' + f"{os.path.dirname(os.path.abspath(__file__))}/**/*.jpg" ), ], ) def test_docarray_inputs(make_trt_flow, inputs): - c = Client(server=f'grpc://0.0.0.0:{make_trt_flow.port}') + c = Client(server=f"grpc://0.0.0.0:{make_trt_flow.port}") r = c.encode(inputs if not callable(inputs) else inputs()) assert isinstance(r, DocumentArray) assert r.embeddings.shape - if hasattr(inputs, '__len__'): + if hasattr(inputs, "__len__"): assert inputs[0] is r[0] @pytest.mark.gpu @pytest.mark.asyncio @pytest.mark.parametrize( - 'd', + "d", [ Document( - uri='https://clip-as-service.jina.ai/_static/favicon.png', - matches=[Document(text='hello, world'), Document(text='goodbye, world')], + uri="https://clip-as-service.jina.ai/_static/favicon.png", + matches=[Document(text="hello, world"), Document(text="goodbye, world")], ), Document( - text='hello, world', + text="hello, world", matches=[ - Document(uri='https://clip-as-service.jina.ai/_static/favicon.png'), + Document(uri="https://clip-as-service.jina.ai/_static/favicon.png"), Document( - uri=f'{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg' + uri=f"{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg" ), ], ), ], ) async def test_async_arank(make_trt_flow, d): - c = Client(server=f'grpc://0.0.0.0:{make_trt_flow.port}') + c = Client(server=f"grpc://0.0.0.0:{make_trt_flow.port}") r = await c.arank([d]) assert isinstance(r, DocumentArray) assert d is r[0] - rv = r['@m', 'scores__clip_score__value'] + rv = r["@m", "scores__clip_score__value"] for v in rv: assert v is not None assert v > 0 np.testing.assert_almost_equal(sum(rv), 1.0) - rv = r['@m', 'scores__clip_score_cosine__value'] + rv = r["@m", "scores__clip_score_cosine__value"] for v in rv: assert v is not None assert -1.0 <= v <= 1.0 diff --git a/tests/test_tokenization.py b/tests/test_tokenization.py index 0fb954ca1..b79cada9c 100644 --- a/tests/test_tokenization.py +++ b/tests/test_tokenization.py @@ -3,15 +3,15 @@ @pytest.mark.parametrize( - 'name', ['ViT-L/14@336px', 'M-CLIP/XLM-Roberta-Large-Vit-B-32'] + "name", ["ViT-L/14@336px", "M-CLIP/XLM-Roberta-Large-Vit-B-32"] ) def test_tokenizer_name(name): tokenizer = Tokenizer(name) - result = tokenizer('hello world') - assert result['input_ids'].shape == result['attention_mask'].shape - assert result['input_ids'].shape[0] == 1 + result = tokenizer("hello world") + assert result["input_ids"].shape == result["attention_mask"].shape + assert result["input_ids"].shape[0] == 1 - result = tokenizer(['hello world', 'welcome to the world']) - assert result['input_ids'].shape == result['attention_mask'].shape - assert result['input_ids'].shape[0] == 2 + result 
= tokenizer(["hello world", "welcome to the world"]) + assert result["input_ids"].shape == result["attention_mask"].shape + assert result["input_ids"].shape[0] == 2 From a7d724e80588e9fc1c9498164f06e6018efd402b Mon Sep 17 00:00:00 2001 From: "felix.wang" Date: Fri, 29 Dec 2023 17:48:00 +0800 Subject: [PATCH 4/5] chore: fix black --- server/clip_server/model/pretrained_models.py | 242 +++++++++--------- 1 file changed, 121 insertions(+), 121 deletions(-) diff --git a/server/clip_server/model/pretrained_models.py b/server/clip_server/model/pretrained_models.py index 222cec39c..72d023dd7 100644 --- a/server/clip_server/model/pretrained_models.py +++ b/server/clip_server/model/pretrained_models.py @@ -5,136 +5,136 @@ import requests -_OPENCLIP_S3_BUCKET = "https://clip-as-service.s3.us-east-2.amazonaws.com/models/torch" -_OPENCLIP_HUGGINGFACE_BUCKET = "https://huggingface.co/jinaai/" +_OPENCLIP_S3_BUCKET = 'https://clip-as-service.s3.us-east-2.amazonaws.com/models/torch' +_OPENCLIP_HUGGINGFACE_BUCKET = 'https://huggingface.co/jinaai/' _OPENCLIP_MODELS = { - "RN50::openai": ("RN50.pt", "9140964eaaf9f68c95aa8df6ca13777c"), - "RN50::yfcc15m": ("RN50-yfcc15m.pt", "e9c564f91ae7dc754d9043fdcd2a9f22"), - "RN50::cc12m": ("RN50-cc12m.pt", "37cb01eb52bb6efe7666b1ff2d7311b5"), - "RN101::openai": ("RN101.pt", "fa9d5f64ebf152bc56a18db245071014"), - "RN101::yfcc15m": ("RN101-yfcc15m.pt", "48f7448879ce25e355804f6bb7928cb8"), - "RN50x4::openai": ("RN50x4.pt", "03830990bc768e82f7fb684cde7e5654"), - "RN50x16::openai": ("RN50x16.pt", "83d63878a818c65d0fb417e5fab1e8fe"), - "RN50x64::openai": ("RN50x64.pt", "a6631a0de003c4075d286140fc6dd637"), - "ViT-B-32::openai": ("ViT-B-32.pt", "3ba34e387b24dfe590eeb1ae6a8a122b"), - "ViT-B-32::laion2b_e16": ( - "ViT-B-32-laion2b_e16.pt", - "df08de3d9f2dc53c71ea26e184633902", + 'RN50::openai': ('RN50.pt', '9140964eaaf9f68c95aa8df6ca13777c'), + 'RN50::yfcc15m': ('RN50-yfcc15m.pt', 'e9c564f91ae7dc754d9043fdcd2a9f22'), + 'RN50::cc12m': ('RN50-cc12m.pt', '37cb01eb52bb6efe7666b1ff2d7311b5'), + 'RN101::openai': ('RN101.pt', 'fa9d5f64ebf152bc56a18db245071014'), + 'RN101::yfcc15m': ('RN101-yfcc15m.pt', '48f7448879ce25e355804f6bb7928cb8'), + 'RN50x4::openai': ('RN50x4.pt', '03830990bc768e82f7fb684cde7e5654'), + 'RN50x16::openai': ('RN50x16.pt', '83d63878a818c65d0fb417e5fab1e8fe'), + 'RN50x64::openai': ('RN50x64.pt', 'a6631a0de003c4075d286140fc6dd637'), + 'ViT-B-32::openai': ('ViT-B-32.pt', '3ba34e387b24dfe590eeb1ae6a8a122b'), + 'ViT-B-32::laion2b_e16': ( + 'ViT-B-32-laion2b_e16.pt', + 'df08de3d9f2dc53c71ea26e184633902', ), - "ViT-B-32::laion400m_e31": ( - "ViT-B-32-laion400m_e31.pt", - "ca8015f98ab0f8780510710681d7b73e", + 'ViT-B-32::laion400m_e31': ( + 'ViT-B-32-laion400m_e31.pt', + 'ca8015f98ab0f8780510710681d7b73e', ), - "ViT-B-32::laion400m_e32": ( - "ViT-B-32-laion400m_e32.pt", - "359e0dba4a419f175599ee0c63a110d8", + 'ViT-B-32::laion400m_e32': ( + 'ViT-B-32-laion400m_e32.pt', + '359e0dba4a419f175599ee0c63a110d8', ), - "ViT-B-32::laion2b-s34b-b79k": ( - "ViT-B-32-laion2b-s34b-b79k.bin", - "2fc036aea9cd7306f5ce7ce6abb8d0bf", + 'ViT-B-32::laion2b-s34b-b79k': ( + 'ViT-B-32-laion2b-s34b-b79k.bin', + '2fc036aea9cd7306f5ce7ce6abb8d0bf', ), - "ViT-B-16::openai": ("ViT-B-16.pt", "44c3d804ecac03d9545ac1a3adbca3a6"), - "ViT-B-16::laion400m_e31": ( - "ViT-B-16-laion400m_e31.pt", - "31306a44224cc46fec1bc3b82fd0c4e6", + 'ViT-B-16::openai': ('ViT-B-16.pt', '44c3d804ecac03d9545ac1a3adbca3a6'), + 'ViT-B-16::laion400m_e31': ( + 'ViT-B-16-laion400m_e31.pt', + '31306a44224cc46fec1bc3b82fd0c4e6', ), - 
"ViT-B-16::laion400m_e32": ( - "ViT-B-16-laion400m_e32.pt", - "07283adc5c17899f2ed22d82b563c54b", + 'ViT-B-16::laion400m_e32': ( + 'ViT-B-16-laion400m_e32.pt', + '07283adc5c17899f2ed22d82b563c54b', ), - "ViT-B-16-plus-240::laion400m_e31": ( - "ViT-B-16-plus-240-laion400m_e31.pt", - "c88f453644a998ecb094d878a2f0738d", + 'ViT-B-16-plus-240::laion400m_e31': ( + 'ViT-B-16-plus-240-laion400m_e31.pt', + 'c88f453644a998ecb094d878a2f0738d', ), - "ViT-B-16-plus-240::laion400m_e32": ( - "ViT-B-16-plus-240-laion400m_e32.pt", - "e573af3cef888441241e35022f30cc95", + 'ViT-B-16-plus-240::laion400m_e32': ( + 'ViT-B-16-plus-240-laion400m_e32.pt', + 'e573af3cef888441241e35022f30cc95', ), - "ViT-L-14::openai": ("ViT-L-14.pt", "096db1af569b284eb76b3881534822d9"), - "ViT-L-14::laion400m_e31": ( - "ViT-L-14-laion400m_e31.pt", - "09d223a6d41d2c5c201a9da618d833aa", + 'ViT-L-14::openai': ('ViT-L-14.pt', '096db1af569b284eb76b3881534822d9'), + 'ViT-L-14::laion400m_e31': ( + 'ViT-L-14-laion400m_e31.pt', + '09d223a6d41d2c5c201a9da618d833aa', ), - "ViT-L-14::laion400m_e32": ( - "ViT-L-14-laion400m_e32.pt", - "a76cde1bc744ca38c6036b920c847a89", + 'ViT-L-14::laion400m_e32': ( + 'ViT-L-14-laion400m_e32.pt', + 'a76cde1bc744ca38c6036b920c847a89', ), - "ViT-L-14::laion2b-s32b-b82k": ( - "ViT-L-14-laion2b-s32b-b82k.bin", - "4d2275fc7b2d7ee9db174f9b57ddecbd", + 'ViT-L-14::laion2b-s32b-b82k': ( + 'ViT-L-14-laion2b-s32b-b82k.bin', + '4d2275fc7b2d7ee9db174f9b57ddecbd', ), - "ViT-L-14-336::openai": ("ViT-L-14-336px.pt", "b311058cae50cb10fbfa2a44231c9473"), - "ViT-H-14::laion2b-s32b-b79k": ( - "ViT-H-14-laion2b-s32b-b79k.bin", - "2aa6c46521b165a0daeb8cdc6668c7d3", + 'ViT-L-14-336::openai': ('ViT-L-14-336px.pt', 'b311058cae50cb10fbfa2a44231c9473'), + 'ViT-H-14::laion2b-s32b-b79k': ( + 'ViT-H-14-laion2b-s32b-b79k.bin', + '2aa6c46521b165a0daeb8cdc6668c7d3', ), - "ViT-g-14::laion2b-s12b-b42k": ( - "ViT-g-14-laion2b-s12b-b42k.bin", - "3bf99353f6f1829faac0bb155be4382a", + 'ViT-g-14::laion2b-s12b-b42k': ( + 'ViT-g-14-laion2b-s12b-b42k.bin', + '3bf99353f6f1829faac0bb155be4382a', ), - "roberta-ViT-B-32::laion2b-s12b-b32k": ( - "roberta-ViT-B-32-laion2b-s12b-b32k.bin", - "76d4c9d13774cc15fa0e2b1b94a8402c", + 'roberta-ViT-B-32::laion2b-s12b-b32k': ( + 'roberta-ViT-B-32-laion2b-s12b-b32k.bin', + '76d4c9d13774cc15fa0e2b1b94a8402c', ), - "xlm-roberta-base-ViT-B-32::laion5b-s13b-b90k": ( - "xlm-roberta-base-ViT-B-32-laion5b-s13b-b90k.bin", - "f68abc07ef349720f1f880180803142d", + 'xlm-roberta-base-ViT-B-32::laion5b-s13b-b90k': ( + 'xlm-roberta-base-ViT-B-32-laion5b-s13b-b90k.bin', + 'f68abc07ef349720f1f880180803142d', ), - "xlm-roberta-large-ViT-H-14::frozen_laion5b_s13b_b90k": ( - "xlm-roberta-large-ViT-H-14-frozen_laion5b_s13b_b90k.bin", - "b49991239a419d704fdba59c42d5536d", + 'xlm-roberta-large-ViT-H-14::frozen_laion5b_s13b_b90k': ( + 'xlm-roberta-large-ViT-H-14-frozen_laion5b_s13b_b90k.bin', + 'b49991239a419d704fdba59c42d5536d', ), # older version name format - "RN50": ("RN50.pt", "9140964eaaf9f68c95aa8df6ca13777c"), - "RN101": ("RN101.pt", "fa9d5f64ebf152bc56a18db245071014"), - "RN50x4": ("RN50x4.pt", "03830990bc768e82f7fb684cde7e5654"), - "RN50x16": ("RN50x16.pt", "83d63878a818c65d0fb417e5fab1e8fe"), - "RN50x64": ("RN50x64.pt", "a6631a0de003c4075d286140fc6dd637"), - "ViT-B/32": ("ViT-B-32.pt", "3ba34e387b24dfe590eeb1ae6a8a122b"), - "ViT-B/16": ("ViT-B-16.pt", "44c3d804ecac03d9545ac1a3adbca3a6"), - "ViT-L/14": ("ViT-L-14.pt", "096db1af569b284eb76b3881534822d9"), - "ViT-L/14@336px": ("ViT-L-14-336px.pt", "b311058cae50cb10fbfa2a44231c9473"), + 
'RN50': ('RN50.pt', '9140964eaaf9f68c95aa8df6ca13777c'), + 'RN101': ('RN101.pt', 'fa9d5f64ebf152bc56a18db245071014'), + 'RN50x4': ('RN50x4.pt', '03830990bc768e82f7fb684cde7e5654'), + 'RN50x16': ('RN50x16.pt', '83d63878a818c65d0fb417e5fab1e8fe'), + 'RN50x64': ('RN50x64.pt', 'a6631a0de003c4075d286140fc6dd637'), + 'ViT-B/32': ('ViT-B-32.pt', '3ba34e387b24dfe590eeb1ae6a8a122b'), + 'ViT-B/16': ('ViT-B-16.pt', '44c3d804ecac03d9545ac1a3adbca3a6'), + 'ViT-L/14': ('ViT-L-14.pt', '096db1af569b284eb76b3881534822d9'), + 'ViT-L/14@336px': ('ViT-L-14-336px.pt', 'b311058cae50cb10fbfa2a44231c9473'), } _MULTILINGUALCLIP_MODELS = { - "M-CLIP/XLM-Roberta-Large-Vit-B-32": (), - "M-CLIP/XLM-Roberta-Large-Vit-L-14": (), - "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus": (), - "M-CLIP/LABSE-Vit-L-14": (), + 'M-CLIP/XLM-Roberta-Large-Vit-B-32': (), + 'M-CLIP/XLM-Roberta-Large-Vit-L-14': (), + 'M-CLIP/XLM-Roberta-Large-Vit-B-16Plus': (), + 'M-CLIP/LABSE-Vit-L-14': (), } _CNCLIP_MODELS = { - "CN-CLIP/ViT-B-16": (), - "CN-CLIP/ViT-L-14": (), - "CN-CLIP/ViT-L-14-336": (), - "CN-CLIP/ViT-H-14": (), - "CN-CLIP/RN50": (), + 'CN-CLIP/ViT-B-16': (), + 'CN-CLIP/ViT-L-14': (), + 'CN-CLIP/ViT-L-14-336': (), + 'CN-CLIP/ViT-H-14': (), + 'CN-CLIP/RN50': (), } _VISUAL_MODEL_IMAGE_SIZE = { - "RN50": 224, - "RN101": 224, - "RN50x4": 288, - "RN50x16": 384, - "RN50x64": 448, - "ViT-B-32": 224, - "roberta-ViT-B-32": 224, - "xlm-roberta-base-ViT-B-32": 224, - "ViT-B-16": 224, - "Vit-B-16Plus": 240, - "ViT-B-16-plus-240": 240, - "ViT-L-14": 224, - "ViT-L-14-336": 336, - "ViT-H-14": 224, - "xlm-roberta-large-ViT-H-14": 224, - "ViT-g-14": 224, + 'RN50': 224, + 'RN101': 224, + 'RN50x4': 288, + 'RN50x16': 384, + 'RN50x64': 448, + 'ViT-B-32': 224, + 'roberta-ViT-B-32': 224, + 'xlm-roberta-base-ViT-B-32': 224, + 'ViT-B-16': 224, + 'Vit-B-16Plus': 240, + 'ViT-B-16-plus-240': 240, + 'ViT-L-14': 224, + 'ViT-L-14-336': 336, + 'ViT-H-14': 224, + 'xlm-roberta-large-ViT-H-14': 224, + 'ViT-g-14': 224, } def md5file(filename: str): hash_md5 = hashlib.md5() - with open(filename, "rb") as f: - for chunk in iter(lambda: f.read(4096), b""): + with open(filename, 'rb') as f: + for chunk in iter(lambda: f.read(4096), b''): hash_md5.update(chunk) return hash_md5.hexdigest() @@ -147,26 +147,26 @@ def get_model_url_md5(name: str): else: hg_download_url = ( _OPENCLIP_HUGGINGFACE_BUCKET - + name.split("::")[0] - + "/resolve/main/" + + name.split('::')[0] + + '/resolve/main/' + model_pretrained[0] - + "?download=true" + + '?download=true' ) try: response = requests.head(hg_download_url) if response.status_code in [200, 302]: return (hg_download_url, model_pretrained[1]) else: - print(f"Model not found on hugging face, trying to download from s3.") + print(f'Model not found on hugging face, trying to download from s3.') except requests.exceptions.RequestException as e: print(str(e)) - print(f"Model not found on hugging face, trying to download from s3.") - return (_OPENCLIP_S3_BUCKET + "/" + model_pretrained[0], model_pretrained[1]) + print(f'Model not found on hugging face, trying to download from s3.') + return (_OPENCLIP_S3_BUCKET + '/' + model_pretrained[0], model_pretrained[1]) def download_model( url: str, - target_folder: str = os.path.expanduser("~/.cache/clip"), + target_folder: str = os.path.expanduser('~/.cache/clip'), md5sum: str = None, with_resume: bool = True, max_attempts: int = 3, @@ -178,7 +178,7 @@ def download_model( if os.path.exists(download_target): if not os.path.isfile(download_target): - raise FileExistsError(f"{download_target} exists and is 
not a regular file") + raise FileExistsError(f'{download_target} exists and is not a regular file') actual_md5sum = md5file(download_target) if (not md5sum) or actual_md5sum == md5sum: @@ -193,33 +193,33 @@ def download_model( ) progress = Progress( - " \n", # divide this bar from Flow's bar - TextColumn("[bold blue]{task.fields[filename]}", justify="right"), - "[progress.percentage]{task.percentage:>3.1f}%", - "•", + ' \n', # divide this bar from Flow's bar + TextColumn('[bold blue]{task.fields[filename]}', justify='right'), + '[progress.percentage]{task.percentage:>3.1f}%', + '•', DownloadColumn(), - "•", + '•', TransferSpeedColumn(), - "•", + '•', TimeRemainingColumn(), ) with progress: - task = progress.add_task("download", filename=filename, start=False) + task = progress.add_task('download', filename=filename, start=False) for _ in range(max_attempts): - tmp_file_path = download_target + ".part" + tmp_file_path = download_target + '.part' resume_byte_pos = ( os.path.getsize(tmp_file_path) if os.path.exists(tmp_file_path) else 0 ) try: # resolve the 403 error by passing a valid user-agent - req = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0"}) + req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'}) total_bytes = int( - urllib.request.urlopen(req).info().get("Content-Length", -1) + urllib.request.urlopen(req).info().get('Content-Length', -1) ) - mode = "ab" if (with_resume and resume_byte_pos) else "wb" + mode = 'ab' if (with_resume and resume_byte_pos) else 'wb' with open(tmp_file_path, mode) as output: progress.update(task, total=total_bytes) @@ -227,7 +227,7 @@ def download_model( if resume_byte_pos and with_resume: progress.update(task, advance=resume_byte_pos) - req.headers["Range"] = f"bytes={resume_byte_pos}-" + req.headers['Range'] = f'bytes={resume_byte_pos}-' with urllib.request.urlopen(req) as source: while True: @@ -245,15 +245,15 @@ def download_model( else: os.remove(tmp_file_path) raise RuntimeError( - f"MD5 mismatch: expected {md5sum}, got {actual_md5}" + f'MD5 mismatch: expected {md5sum}, got {actual_md5}' ) except Exception as ex: progress.console.print( - f"Failed to download {url} with {ex!r} at the {_}th attempt" + f'Failed to download {url} with {ex!r} at the {_}th attempt' ) progress.reset(task) raise RuntimeError( - f"Failed to download {url} within retry limit {max_attempts}" + f'Failed to download {url} within retry limit {max_attempts}' ) From 2bb2fe79b3f39019b84dcd05172f2b951ede67a1 Mon Sep 17 00:00:00 2001 From: "felix.wang" Date: Fri, 29 Dec 2023 17:58:04 +0800 Subject: [PATCH 5/5] fix: black bug --- .github/README-img/banner.svg | 60 +- .github/README-img/pyclient-output.svg | 4 +- .github/README-img/rerank-chart.svg | 2 +- .github/README-img/server-output.svg | 4 +- .github/codecov.yml | 4 +- .github/release-template.ejs | 2 +- .github/workflows/cd.yml | 34 +- .github/workflows/ci.yml | 68 +- .github/workflows/force-docker-build-cas.yml | 64 +- .github/workflows/force-docker-build.yml | 44 +- .github/workflows/force-docs-build.yml | 12 +- .github/workflows/force-hub-push.yml | 86 +- .github/workflows/force-release.yml | 8 +- .github/workflows/label-pr.yml | 6 +- .github/workflows/tag.yml | 10 +- Dockerfiles/base.Dockerfile | 22 +- Dockerfiles/cuda.Dockerfile | 24 +- Dockerfiles/server.Dockerfile | 18 +- Dockerfiles/tensorrt.Dockerfile | 24 +- LICENSE | 34 +- README.md | 196 +-- client/clip_client/__init__.py | 6 +- client/clip_client/client.py | 482 +++--- client/clip_client/helper.py | 20 +- client/setup.py | 112 +- 
docs/Makefile | 8 +- docs/_static/JCloud-dark.svg | 16 +- docs/_static/JCloud-light.svg | 16 +- docs/_static/cas-dark.svg | 16 +- docs/_static/cas-grafana.json | 1370 ++++++++--------- docs/_static/cas-light.svg | 16 +- docs/_static/demo-embed.html | 54 +- docs/_static/demo-text-rank.html | 72 +- docs/_static/docarray-dark.svg | 12 +- docs/_static/docarray-light.svg | 12 +- docs/_static/finetuner-dark.svg | 4 +- docs/_static/finetuner-light.svg | 4 +- docs/_static/hub-dark.svg | 8 +- docs/_static/hub-light.svg | 8 +- docs/_static/logo-dark.svg | 26 +- docs/_static/logo-light.svg | 58 +- docs/_static/main.css | 12 +- docs/_static/now-dark.svg | 18 +- docs/_static/now-light.svg | 18 +- docs/_static/search-dark.svg | 10 +- docs/_static/search-light.svg | 10 +- docs/_templates/page.html | 226 +-- docs/_templates/sidebar/brand.html | 30 +- docs/_templates/sidebar/navigation.html | 48 +- docs/changelog/index.md | 2 +- docs/conf.py | 228 +-- docs/hosting/cas-on-colab.ipynb | 686 ++++----- docs/hosting/cas-on-colab.svg | 2 +- docs/index.md | 4 +- docs/playground/embedding.md | 2 +- docs/playground/reasoning.md | 2 +- docs/playground/searching.md | 2 +- docs/user-guides/client.md | 24 +- docs/user-guides/faq.md | 14 +- docs/user-guides/finetuner.md | 60 +- docs/user-guides/retriever.md | 2 +- docs/user-guides/server.md | 12 +- scripts/benchmark.py | 44 +- scripts/black.sh | 4 +- scripts/docstrings_lint.sh | 4 +- scripts/get-all-test-paths.sh | 12 +- scripts/get-last-release-note.py | 6 +- scripts/get-requirements.py | 8 +- scripts/onnx_helper.py | 8 +- scripts/release.sh | 74 +- scripts/setup.py | 84 +- server/clip_server/__init__.py | 2 +- server/clip_server/__main__.py | 10 +- server/clip_server/executors/clip_onnx.py | 100 +- server/clip_server/executors/clip_tensorrt.py | 88 +- server/clip_server/executors/clip_torch.py | 98 +- server/clip_server/executors/helper.py | 50 +- server/clip_server/helper.py | 28 +- server/clip_server/model/clip.py | 2 +- server/clip_server/model/clip_model.py | 6 +- server/clip_server/model/clip_onnx.py | 284 ++-- server/clip_server/model/clip_trt.py | 54 +- server/clip_server/model/cnclip_model.py | 16 +- server/clip_server/model/flash_attention.py | 4 +- server/clip_server/model/mclip_model.py | 18 +- server/clip_server/model/model.py | 132 +- server/clip_server/model/openclip_model.py | 24 +- server/clip_server/model/simple_tokenizer.py | 50 +- server/clip_server/model/tokenization.py | 22 +- server/clip_server/model/trt_utils.py | 48 +- server/setup.py | 126 +- tests/__init__.py | 2 +- tests/conftest.py | 20 +- tests/test_asyncio.py | 24 +- tests/test_client.py | 70 +- tests/test_helper.py | 34 +- tests/test_model.py | 26 +- tests/test_ranker.py | 132 +- tests/test_search.py | 44 +- tests/test_server.py | 44 +- tests/test_simple.py | 90 +- tests/test_tensorrt.py | 40 +- tests/test_tokenization.py | 14 +- 103 files changed, 3202 insertions(+), 3202 deletions(-) diff --git a/.github/README-img/banner.svg b/.github/README-img/banner.svg index 80a5711b5..85c64c486 100644 --- a/.github/README-img/banner.svg +++ b/.github/README-img/banner.svg @@ -1,43 +1,43 @@ - - - - + + + + - - - - - - - - - + + + + + + + + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - - + + + diff --git a/.github/README-img/pyclient-output.svg b/.github/README-img/pyclient-output.svg index 9a7bfd8f2..5964d8558 100644 --- a/.github/README-img/pyclient-output.svg +++ b/.github/README-img/pyclient-output.svg @@ -1,10 +1,10 @@ -