
Commit

Make async
yankeexe committed Jan 22, 2025
1 parent cfb27f6 commit ee586f9
Showing 3 changed files with 98 additions and 80 deletions.
4 changes: 3 additions & 1 deletion ollama_manager/app.py

@@ -5,10 +5,12 @@
 from ollama_manager.commands.delete import delete_model
 from ollama_manager.commands.pull import pull_model
 from ollama_manager.commands.run import run_model
+from ollama_manager.utils import coro
 
 
 @click.group()
-def cli():
+@coro
+async def cli():
     pass
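
Click invokes command callbacks synchronously, so an async def callback needs a shim that drives the event loop; that is what the new coro decorator provides (defined in ollama_manager/utils/__init__.py below). Decorator order matters: @click.group() sits above @coro so that Click registers the already-wrapped synchronous function. A minimal sketch of the pattern — the ping command here is hypothetical, not part of this commit:

import asyncio

import click

from ollama_manager.utils import coro  # wraps each coroutine call in asyncio.run


@click.command()
@coro  # applied first, so Click sees a plain sync function
async def ping():
    # Hypothetical command; real awaitable work would go here.
    await asyncio.sleep(0.1)
    click.echo("pong")


if __name__ == "__main__":
    ping()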
164 changes: 85 additions & 79 deletions ollama_manager/commands/pull.py

@@ -3,11 +3,11 @@
 import sys
 
 import click
+import httpx
 import ollama
-import requests
-from bs4 import BeautifulSoup
+from bs4 import BeautifulSoup, SoupStrainer
 
-from ollama_manager.utils import get_session, handle_interaction, make_request
+from ollama_manager.utils import coro, handle_interaction
 
 
 def extract_quantization(text):
@@ -92,11 +92,8 @@ def format_bytes(size_bytes: int) -> str:
     return f"{scaled_size:.2f} {_SUFFIXES[magnitude]}"
 
 
-def list_remote_model_tags(model_name: str, session: requests.Session):
-    response = make_request(
-        session=session,
-        url=f"https://ollama.com/library/{model_name}/tags",
-    )
+async def list_remote_model_tags(model_name: str, client: httpx.AsyncClient):
+    response = await client.get(f"https://ollama.com/library/{model_name}/tags")
     soup = BeautifulSoup(response.text, "html.parser")
 
     # Find the span with the specific attribute
@@ -136,21 +133,21 @@ def list_remote_model_tags(model_name: str, session: requests.Session):
     return results
 
 
-def list_remote_models(session: requests.Session) -> list[str] | None:
-    response = make_request(session=session, url="https://ollama.com/search")
+async def list_remote_models(client: httpx.AsyncClient) -> list[str] | None:
+    response = await client.get(url="https://ollama.com/search")
 
-    soup = BeautifulSoup(response.text, "html.parser")
     # Find the span with the specific attribute
     # @NOTE: This might change with website updates.
+    title_strainer = SoupStrainer("span", attrs={"x-test-search-response-title": True})
+    soup = BeautifulSoup(response.text, "html.parser", parse_only=title_strainer)
     elements = soup.find_all("span", attrs={"x-test-search-response-title": True})
 
     if not elements:
         return None
 
     return [element.text.strip() for element in elements]
 
 
-def list_hugging_face_models(
-    session: requests.Session, limit: int, query: str
+async def list_hugging_face_models(
+    client: httpx.AsyncClient, limit: int, query: str
 ) -> list[dict[str, str]]:
     BASE_API_ENDPOINT = "https://huggingface.co/api/models"
     params = {
@@ -162,7 +159,7 @@ def list_hugging_face_models(
         "config": False,
         "search": query,
     }
-    res = make_request(session, url=BASE_API_ENDPOINT, params=params)
+    res = await client.get(url=BASE_API_ENDPOINT, params=params)
     hf_response = res.json()
     payload = []
@@ -175,9 +172,12 @@ def list_hugging_face_models(
     return payload
 
 
-def list_hugging_face_model_quantization(session: requests.Session, model_name: str):
-    API_ENDPOINT = f"https://huggingface.co/api/models/{model_name}?blobs=true"
-    res = make_request(session=session, url=API_ENDPOINT)
+async def list_hugging_face_model_quantization(
+    client: httpx.AsyncClient, model_name: str
+):
+    res = await client.get(
+        url=f"https://huggingface.co/api/models/{model_name}?blobs=true"
+    )
     hf_response = res.json()
     payload = []
     files = hf_response.get("siblings")
@@ -217,75 +217,81 @@ def list_hugging_face_model_quantization(session: requests.Session, model_name:
     type=int,
     default=20,
 )
-def pull_model(hugging_face: bool, query: str, limit: int):
+@coro
+async def pull_model(hugging_face: bool, query: str, limit: int):
     """
     Pull models from Ollama library:
     https://ollama.dev/search
     """
-    session = get_session()
-    if hugging_face:
-        if not query:
-            query = input("🤗 hf search: ")
-        models = list_hugging_face_models(session, limit, query)
-    else:
-        models = list_remote_models(session)
+    async with httpx.AsyncClient() as client:
+        if hugging_face:
+            if not query:
+                query = input("🤗 hf search: ")
+            models = await list_hugging_face_models(client, limit, query)
+        else:
+            models = await list_remote_models(client)
 
-    if not models:
-        print("❌ No models selected for download")
-        sys.exit(0)
+        if not models:
+            print("❌ No models selected for download")
+            sys.exit(0)
 
-    model_selection = handle_interaction(
-        models, title="📦 Select remote Ollama model\s:\n", multi_select=False
-    )
-    if model_selection:
-        print("Pulling Model....")
-        if hugging_face:
-            model_tags = list_hugging_face_model_quantization(
-                session=session, model_name=model_selection[0]
-            )
-        else:
-            model_tags = list_remote_model_tags(
-                model_name=model_selection[0], session=session
-            )
-        if not model_tags:
-            print(f"❌ Failed fetching tags for: {model_selection}. Please try again.")
-            sys.exit(1)
+        model_selection = handle_interaction(
+            models, title="📦 Select remote Ollama model\s:\n", multi_select=False
+        )
+        if model_selection:
+            if hugging_face:
+                model_tags = await list_hugging_face_model_quantization(
+                    client=client, model_name=model_selection[0]
+                )
+            else:
+                model_tags = await list_remote_model_tags(
+                    model_name=model_selection[0], client=client
+                )
+            if not model_tags:
+                print(
+                    f"❌ Failed fetching tags for: {model_selection}. Please try again."
+                )
+                sys.exit(1)
 
-        max_length = max(len(f"{model_selection}:{tag['title']}") for tag in model_tags)
+            max_length = max(
+                len(f"{model_selection}:{tag['title']}") for tag in model_tags
+            )
 
-        if hugging_face:
-            model_name_with_tags = [
-                f"{tag['title']:<{max_length}}{tag['size']:<{max_length}}{tag['updated']}"
-                for tag in model_tags
-            ]
-        else:
-            model_name_with_tags = [
-                f"{model_selection[0]}:{tag['title']:<{max_length + 5}}{tag['size']:<{max_length + 5}}{tag['updated']}"
-                for tag in model_tags
-            ]
-        selected_model_with_tag = handle_interaction(
-            model_name_with_tags, title="🔖 Select tag/quantization:\n"
-        )
-        if not selected_model_with_tag:
-            print("No tag selected for the model")
-            sys.exit(1)
+            if hugging_face:
+                model_name_with_tags = [
+                    f"{tag['title']:<{max_length}}{tag['size']:<{max_length}}{tag['updated']}"
+                    for tag in model_tags
+                ]
+            else:
+                model_name_with_tags = [
+                    f"{model_selection[0]}:{tag['title']:<{max_length + 5}}{tag['size']:<{max_length + 5}}{tag['updated']}"
+                    for tag in model_tags
+                ]
+            selected_model_with_tag = handle_interaction(
+                model_name_with_tags, title="🔖 Select tag/quantization:\n"
+            )
+            if not selected_model_with_tag:
+                print("No tag selected for the model")
+                sys.exit(1)
 
-        if hugging_face:
-            final_model = (
-                f"hf.co/{model_selection[0]}:{model_name_with_tags[0]}".split()[0]
-            )
-        else:
-            final_model = selected_model_with_tag[0].split()[0]
-        print(f">>> Pulling model: {final_model}")
-        try:
-            response = ollama.pull(final_model, stream=True)
-            screen_padding = 100
+            if hugging_face:
+                final_model = (
+                    f"hf.co/{model_selection[0]}:{model_name_with_tags[0]}".split()[0]
+                )
+            else:
+                final_model = selected_model_with_tag[0].split()[0]
+            print(f">>> Pulling model: {final_model}")
+            try:
+                response = ollama.pull(final_model, stream=True)
+                screen_padding = 100
 
-            for data in response:
-                out = f"Status: {data.get('status')} | Completed: {format_bytes(data.get('completed'))}/{format_bytes(data.get('total'))}"
-                print(f"{out:<{screen_padding}}", end="\r", flush=True)
+                for data in response:
+                    out = f"Status: {data.get('status')} | Completed: {format_bytes(data.get('completed'))}/{format_bytes(data.get('total'))}"
+                    print(f"{out:<{screen_padding}}", end="\r", flush=True)
 
-            print(f'\r{" " * screen_padding}\r')  # Clear screen
-            print(f"✅ {final_model} model is ready for use!\n\n>>> olm run\n")
-        except Exception as e:
-            print(f"❌ Failed downloading {final_model}\n{str(e)}")
+                print(f'\r{" " * screen_padding}\r')  # Clear screen
+                print(f"✅ {final_model} model is ready for use!\n\n>>> olm run\n")
+            except Exception as e:
+                print(f"❌ Failed downloading {final_model}\n{str(e)}")
10 changes: 10 additions & 0 deletions ollama_manager/utils/__init__.py

@@ -1,10 +1,20 @@
+import asyncio
 import sys
+from functools import wraps
 
 import ollama
 import requests
 from simple_term_menu import TerminalMenu
 
 
+def coro(f):
+    @wraps(f)
+    def wrapper(*args, **kwargs):
+        return asyncio.run(f(*args, **kwargs))
+
+    return wrapper
+
+
 def get_session() -> requests.Session:
     session = requests.Session()
     session.headers = {
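
Because coro returns a plain function, a wrapped command behaves like any synchronous callable: each invocation spins up a fresh event loop via asyncio.run and hands back the coroutine's result, which suits a short-lived CLI process (though it would fail if called from inside an already-running loop). A quick standalone check of that behavior — the demo coroutine is hypothetical, not in the commit:

from ollama_manager.utils import coro


@coro
async def demo() -> str:
    return "ran to completion"


# The caller receives the result directly, never a coroutine object.
assert demo() == "ran to completion"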
