Skip to content

Commit

Permalink
fix: mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
jaluma committed Aug 1, 2024
1 parent 21c622e commit 755f886
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions private_gpt/utils/ollama.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import logging
from typing import Any, Generator, Mapping, Iterator
from tqdm import tqdm
from collections import deque
from collections.abc import Iterator, Mapping
from typing import Any

from tqdm import tqdm # type: ignore

try:
from ollama import Client # type: ignore
Expand All @@ -24,10 +26,12 @@ def check_connection(client: Client) -> bool:

def process_streaming(generator: Iterator[Mapping[str, Any]]) -> None:
progress_bars = {}
queue = deque()
queue = deque() # type: ignore

def create_progress_bar(dgt: str, total: int) -> tqdm:
return tqdm(total=total, desc=f"Pulling model {dgt[7:17]}...", unit='B', unit_scale=True)
def create_progress_bar(dgt: str, total: int) -> Any:
return tqdm(
total=total, desc=f"Pulling model {dgt[7:17]}...", unit="B", unit_scale=True
)

current_digest = None

Expand All @@ -52,10 +56,7 @@ def create_progress_bar(dgt: str, total: int) -> tqdm:
progress_bar.update(progress)
if progress_bar.n >= total_size:
progress_bar.close()
if queue:
current_digest = queue.popleft()
else:
current_digest = None
current_digest = queue.popleft() if queue else None
else:
# Store progress for later update
progress_bars[digest].total = total_size
Expand All @@ -65,6 +66,7 @@ def create_progress_bar(dgt: str, total: int) -> tqdm:
for progress_bar in progress_bars.values():
progress_bar.close()


def pull_model(client: Client, model_name: str, raise_error: bool = True) -> None:
try:
installed_models = [model["name"] for model in client.list().get("models", {})]
Expand Down

0 comments on commit 755f886

Please sign in to comment.