Skip to content

Commit

Permalink
Merge pull request #99 from NexaAI/david/newfeature
Browse files Browse the repository at this point in the history
support nexa run/pull gguf models from huggingface
  • Loading branch information
zhiyuan8 authored Sep 18, 2024
2 parents 10995f7 + e9aa386 commit 05e0538
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 24 deletions.
22 changes: 15 additions & 7 deletions nexa/cli/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@ def run_ggml_inference(args):
from nexa.gguf.server.nexa_service import run_nexa_ai_service as NexaServer
NexaServer(model_path, **kwargs)
return

from nexa.general import pull_model
local_path, run_type = pull_model(model_path)


hf = kwargs.pop('huggingface', False)
stop_words = kwargs.pop("stop_words", [])

from nexa.general import pull_model
local_path, run_type = pull_model(model_path, hf)

try:
if run_type == "NLP":
from nexa.gguf.nexa_inference_text import NexaTextInference
Expand Down Expand Up @@ -107,6 +108,7 @@ def main():
text_group.add_argument("-k", "--top_k", type=int, help="Top-k sampling parameter")
text_group.add_argument("-p", "--top_p", type=float, help="Top-p sampling parameter")
text_group.add_argument("-sw", "--stop_words", nargs="*", help="List of stop words for early stopping")
text_group.add_argument("-hf", "--huggingface", action="store_true", help="Load model from Hugging Face Hub")

# Image generation arguments
image_group = run_parser.add_argument_group('Image generation options')
Expand Down Expand Up @@ -168,8 +170,13 @@ def main():
server_parser.add_argument("--nctx", type=int, default=2048, help="Length of context window")

# Other commands
subparsers.add_parser("pull", help="Pull a model from official or hub.").add_argument("model_path", type=str, help="Path or identifier for the model in Nexa Model Hub")
subparsers.add_parser("remove", help="Remove a model from local machine.").add_argument("model_path", type=str, help="Path or identifier for the model in Nexa Model Hub")
pull_parser = subparsers.add_parser("pull", help="Pull a model from official or hub.")
pull_parser.add_argument("model_path", type=str, help="Path or identifier for the model in Nexa Model Hub")
pull_parser.add_argument("-hf", "--huggingface", action="store_true", help="Pull model from Hugging Face Hub")

remove_parser = subparsers.add_parser("remove", help="Remove a model from local machine.")
remove_parser.add_argument("model_path", type=str, help="Path or identifier for the model in Nexa Model Hub")

subparsers.add_parser("clean", help="Clean up all model files.")
subparsers.add_parser("list", help="List all models in the local machine.")
subparsers.add_parser("login", help="Login to Nexa API.")
Expand All @@ -185,7 +192,8 @@ def main():
run_onnx_inference(args)
elif args.command == "pull":
from nexa.general import pull_model
pull_model(args.model_path)
hf = getattr(args, 'huggingface', False)
pull_model(args.model_path, hf)
elif args.command == "remove":
from nexa.general import remove_model
remove_model(args.model_path)
Expand Down
1 change: 1 addition & 0 deletions nexa/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
NEXA_TOKEN_PATH = NEXA_CACHE_ROOT / "token"
NEXA_MODELS_HUB_DIR = NEXA_CACHE_ROOT / "hub"
NEXA_MODELS_HUB_OFFICIAL_DIR = NEXA_MODELS_HUB_DIR / "official"
NEXA_MODELS_HUB_HF_DIR = NEXA_MODELS_HUB_DIR / "huggingface"
NEXA_MODEL_LIST_PATH = NEXA_MODELS_HUB_DIR / "model_list.json"

# URLs and buckets
Expand Down
128 changes: 111 additions & 17 deletions nexa/general.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import logging
from pathlib import Path
from typing import Tuple
import shutil
import requests

Expand All @@ -10,6 +11,7 @@
NEXA_MODEL_LIST_PATH,
NEXA_MODELS_HUB_DIR,
NEXA_MODELS_HUB_OFFICIAL_DIR,
NEXA_MODELS_HUB_HF_DIR,
NEXA_OFFICIAL_BUCKET,
NEXA_RUN_MODEL_MAP,
NEXA_TOKEN_PATH,
Expand Down Expand Up @@ -99,19 +101,22 @@ def get_user_info(token):
return None


def pull_model(model_path):
def pull_model(model_path, hf = False):
model_path = NEXA_RUN_MODEL_MAP.get(model_path, model_path)

try:
if is_model_exists(model_path):
location, run_type = get_model_info(model_path)
print(f"Model {model_path} already exists at {location}")
return location, run_type

if "/" in model_path:
result = pull_model_from_hub(model_path)
else:
result = pull_model_from_official(model_path)
if hf == True:
result = pull_model_from_hf(model_path)
else:
if is_model_exists(model_path):
location, run_type = get_model_info(model_path)
print(f"Model {model_path} already exists at {location}")
return location, run_type

if "/" in model_path:
result = pull_model_from_hub(model_path)
else:
result = pull_model_from_official(model_path)

if result["success"]:
add_model_to_list(model_path, result["local_path"], result["model_type"], result["run_type"])
Expand Down Expand Up @@ -208,6 +213,18 @@ def pull_model_from_official(model_path):
"run_type": run_type_str
}

def pull_model_from_hf(repo_id):
    """Pull a GGUF model from a Hugging Face repository.

    Prompts the user to choose one of the repo's .gguf files, downloads it,
    and returns a result dict shaped like the other pull_* helpers
    (success flag, local path, model type, run type).
    """
    resolved_repo, gguf_name = select_gguf_in_hf_repo(repo_id)
    ok, local_file = download_gguf_from_hf(resolved_repo, gguf_name)

    # For beta version, we only support NLP gguf models
    return {
        "success": ok,
        "local_path": local_file,
        "model_type": "gguf",
        "run_type": "NLP",
    }


def get_run_type_from_model_path(model_path):
model_name, _ = model_path.split(":")
Expand Down Expand Up @@ -309,6 +326,30 @@ def download_model_from_official(model_path, model_type):
print(f"An error occurred while downloading or processing the model: {e}")
return False, None

def download_gguf_from_hf(repo_id, filename):
    """Download a single GGUF file from a Hugging Face repository.

    Args:
        repo_id: Hugging Face repository ID (e.g. "org/repo").
        filename: Name of the .gguf file inside the repository.

    Returns:
        tuple: (success, local_path) — success is a bool; local_path is the
        downloaded file's path on success, or None on any failure.
    """
    try:
        from huggingface_hub import hf_hub_download
        from pathlib import Path
    except ImportError:
        print("The huggingface-hub package is required. Please install it with `pip install huggingface-hub`.")
        # Fix: return a (success, path) pair like every other exit path so
        # callers that unpack the result don't crash with a TypeError.
        return False, None

    # Define the local directory to save the model
    local_dir = NEXA_MODELS_HUB_HF_DIR / Path(repo_id)
    local_dir.mkdir(parents=True, exist_ok=True)

    # Download the model
    try:
        model_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            local_dir=local_dir,
            local_files_only=False,
        )
        return True, model_path
    except Exception as e:
        print(f"Failed to download the model: {e}")
        return False, None

def is_model_exists(model_name):
if not NEXA_MODEL_LIST_PATH.exists():
Expand Down Expand Up @@ -441,12 +482,65 @@ def clean():
except Exception as e:
print(f"An error occurred while cleaning the directory: {e}")

def select_gguf_in_hf_repo(repo_id: str) -> Tuple[str, str]:
    """
    Lists all files ending with .gguf in the given Hugging Face repository,
    prompts the user to select one, and returns the repo_id and the selected filename.

    Args:
        repo_id (str): The Hugging Face repository ID.

    Returns:
        Tuple[str, str]: A tuple containing the repo_id and the selected filename.
    """
    try:
        from huggingface_hub import HfFileSystem
        from huggingface_hub.utils import validate_repo_id
        from pathlib import Path
    except ImportError:
        print("The huggingface-hub package is required. Please install it with `pip install huggingface-hub`.")
        exit(1)

    validate_repo_id(repo_id)
    hffs = HfFileSystem()

    try:
        # hffs.ls may return dicts (with a "name" key) or plain path strings.
        files = [
            file["name"] if isinstance(file, dict) else file
            for file in hffs.ls(repo_id, recursive=True)
        ]
    except Exception:
        print(f"Error accessing repository '{repo_id}'. Please make sure you have access to the Hugging Face repository first.")
        exit(1)

    # Remove the repo prefix from files
    file_list = []
    for file in files:
        rel_path = Path(file).relative_to(repo_id)
        file_list.append(str(rel_path))

    # Filter for files ending with .gguf
    gguf_files = [file for file in file_list if file.endswith('.gguf')]

    if not gguf_files:
        print(f"No gguf models found in repository '{repo_id}'.")
        exit(1)

    print("Available gguf models in the repository:")
    for i, file in enumerate(gguf_files, 1):
        print(f"{i}. {file}")

    # Prompt the user to select a file; loop until a valid index is given.
    while True:
        try:
            selected_index = int(input("Please enter the number of the model you want to download and use: "))
            if 1 <= selected_index <= len(gguf_files):
                filename = gguf_files[selected_index - 1]
                # Fix: interpolate the chosen filename (the scraped source had
                # a garbled "(unknown)" placeholder here).
                print(f"You have selected: {filename}")
                break
            else:
                print(f"Please enter a number between 1 and {len(gguf_files)}")
        except ValueError:
            print("Invalid input. Please enter a number.")

    return repo_id, filename


if __name__ == "__main__":
    # login()
    # whoami()
    # logout()
    # pull_model("phi3")
    list_models()
    # remove_model("phi3")
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ dependencies = [
"uvicorn",
"pydantic",
"pillow",
"huggingface_hub",
"prompt_toolkit",
"tqdm", # Shared dependencies
"tabulate",
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ uvicorn
pydantic
pillow
python-multipart
huggingface_hub

# For onnx
optimum[onnxruntime] # for CPU version
Expand Down

0 comments on commit 05e0538

Please sign in to comment.