Commit

update tests to use sdk interface instead of from scratch
xyyimian committed Aug 21, 2024
1 parent 5bc20f8 commit 85d6ed3
Showing 6 changed files with 100 additions and 98 deletions.
84 changes: 62 additions & 22 deletions nexa/gguf/nexa_inference_image.py
@@ -28,8 +28,8 @@ class NexaImageInference:
A class used for loading image models and running image generation.
Methods:
run_txt2img: Run the text-to-image generation loop.
run_img2img: Run the image-to-image generation loop.
txt2img: (Used for SDK) Run the text-to-image generation loop.
img2img: (Used for SDK) Run the image-to-image generation loop.
run_streamlit: Run the Streamlit UI.
Args:
@@ -109,7 +109,16 @@ def _save_images(self, images):
image.save(file_path)
logging.info(f"\nImage {i+1} saved to: {file_path}")

def txt2img(self, prompt, negative_prompt):
def txt2img(self,
prompt,
negative_prompt="",
cfg_scale=7.5,
width=512,
height=512,
sample_steps=20,
seed=0,
control_cond="",
control_strength=0.9):
"""
Used for SDK. Generate images from text.
@@ -122,14 +131,14 @@
"""
images = self.model.txt_to_img(
prompt=prompt,
negative_prompt=negative_prompt if negative_prompt else "",
cfg_scale=self.params["guidance_scale"],
width=self.params["width"],
height=self.params["height"],
sample_steps=self.params["num_inference_steps"],
seed=self.params["random_seed"],
control_cond=self.params.get("control_image_path", ""),
control_strength=self.params.get("control_strength", 0.9),
negative_prompt=negative_prompt,
cfg_scale=cfg_scale,
width=width,
height=height,
sample_steps=sample_steps,
seed=seed,
control_cond=control_cond,
control_strength=control_strength,
)
return images
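
With this change, SDK callers pass generation parameters as explicit keyword arguments (all with defaults) instead of relying on self.params. A minimal usage sketch — the prompts and output filename are illustrative, and the constructor arguments mirror the updated test file below:

    from nexa.gguf import NexaImageInference

    sd = NexaImageInference(model_path="sd1-4", wtype="q4_0")
    images = sd.txt2img(
        "a lovely cat",
        negative_prompt="blurry",  # optional; defaults to ""
        cfg_scale=7.5,
        width=512,
        height=512,
        sample_steps=20,
        seed=0,
    )
    images[0].save("output_txt_to_img.png")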

@@ -141,7 +150,17 @@ def run_txt2img(self):
"Enter your negative prompt (press Enter to skip): "
)
try:
images = self.txt2img(prompt, negative_prompt)
images = self.txt2img(
prompt,
negative_prompt,
cfg_scale=self.params["guidance_scale"],
width=self.params["width"],
height=self.params["height"],
sample_steps=self.params["num_inference_steps"],
seed=self.params["random_seed"],
control_cond=self.params.get("control_image_path", ""),
control_strength=self.params.get("control_strength", 0.9),
)
self._save_images(images)
except Exception as e:
logging.error(f"Error during text to image generation: {e}")
@@ -150,7 +169,17 @@ def run_txt2img(self):
except Exception as e:
logging.error(f"Error during generation: {e}", exc_info=True)

def img2img(self, image_path, prompt, negative_prompt):
def img2img(self,
image_path,
prompt,
negative_prompt="",
cfg_scale=7.5,
width=512,
height=512,
sample_steps=20,
seed=0,
control_cond="",
control_strength=0.9):
"""
Used for SDK. Generate images from an image.
@@ -165,14 +194,14 @@ def img2img(self, image_path, prompt, negative_prompt):
images = self.model.img_to_img(
image=image_path,
prompt=prompt,
negative_prompt=negative_prompt if negative_prompt else "",
cfg_scale=self.params["guidance_scale"],
width=self.params["width"],
height=self.params["height"],
sample_steps=self.params["num_inference_steps"],
seed=self.params["random_seed"],
control_cond=self.params.get("control_image_path", ""),
control_strength=self.params.get("control_strength", 0.9),
negative_prompt=negative_prompt,
cfg_scale=cfg_scale,
width=width,
height=height,
sample_steps=sample_steps,
seed=seed,
control_cond=control_cond,
control_strength=control_strength,
)
return images
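
img2img follows the same pattern, with the source image path as its first argument. Reusing the sd instance from the sketch above (the input path is illustrative):

    images = sd.img2img(
        image_path="input.png",  # illustrative path
        prompt="blue sky",
        negative_prompt="black soil",
        width=128,
        height=128,
        sample_steps=2,
    )
    images[0].save("output_img_to_img.png")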

@@ -184,7 +213,18 @@ def run_img2img(self):
negative_prompt = nexa_prompt(
"Enter your negative prompt (press Enter to skip): "
)
images = self.img2img(image_path, prompt, negative_prompt)
images = self.img2img(image_path,
prompt,
negative_prompt,
cfg_scale=self.params["guidance_scale"],
width=self.params["width"],
height=self.params["height"],
sample_steps=self.params["num_inference_steps"],
seed=self.params["random_seed"],
control_cond=self.params.get("control_image_path", ""),
control_strength=self.params.get("control_strength", 0.9),
)

self._save_images(images)
except KeyboardInterrupt:
print(EXIT_REMINDER)
14 changes: 4 additions & 10 deletions nexa/gguf/nexa_inference_text.py
@@ -31,6 +31,7 @@ class NexaTextInference:
Args:
model_path (str): Path or identifier for the model in Nexa Model Hub.
embedding (bool): Enable embedding generation.
stop_words (list): List of stop words for early stopping.
profiling (bool): Enable timing measurements for the generation process.
streamlit (bool): Run the inference in Streamlit UI.
@@ -83,27 +84,19 @@ def __init__(self, model_path, stop_words=None, **kwargs):
"Failed to load model or tokenizer. Exiting.", exc_info=True
)
exit(1)
def embed(
def create_embedding(
self,
input: Union[str, List[str]],
normalize: bool = False,
truncate: bool = True,
return_count: bool = False,
):
"""Embed a string.
Args:
input: The utf-8 encoded string or a list of strings to embed.
normalize: whether to normalize the embedding in the embedding dimension.
truncate: whether to truncate tokens to window length before generating the embedding.
return_count: if true, return an (embedding, count) tuple; else return the embedding only.
Returns:
A list of embeddings
"""
return self.model.embed(input, normalize, truncate, return_count)
return self.model.create_embedding(input)
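
Because create_embedding delegates to the underlying model, embedding support must be enabled when the model is loaded (see the _load_model change below). A minimal sketch mirroring test_create_embedding in the updated tests; the model path is illustrative:

    from nexa.gguf import NexaTextInference

    model = NexaTextInference(model_path="gemma", embedding=True)
    embeddings = model.create_embedding("Hello, world!")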

@SpinningCursorAnimation()
def _load_model(self):
@@ -112,6 +105,7 @@ def _load_model(self):
with suppress_stdout_stderr():
from nexa.gguf.llama.llama import Llama
self.model = Llama(
embedding=self.params.get("embedding", False),
model_path=self.downloaded_path,
verbose=self.profiling,
chat_format=self.chat_format,
2 changes: 1 addition & 1 deletion nexa/gguf/nexa_inference_vlm.py
@@ -138,7 +138,7 @@ def __init__(self, model_path, stop_words=None, **kwargs):
)
exit(1)

@SpinningCursorAnimation()
# @SpinningCursorAnimation()
def _load_model(self):
logging.debug(f"Loading model from {self.downloaded_path}")
start_time = time.time()
35 changes: 11 additions & 24 deletions tests/test_image_generation.py
@@ -1,47 +1,34 @@
import os
from nexa.gguf.sd import stable_diffusion
from tests.utils import download_model
from nexa.gguf import NexaImageInference
from tempfile import TemporaryDirectory
from .utils import download_model

# Constants
STABLE_DIFFUSION_URL = "https://huggingface.co/second-state/stable-diffusion-v-1-4-GGUF/resolve/main/stable-diffusion-v1-4-Q4_0.gguf"
IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
OUTPUT_DIR = os.getcwd()
MODEL_PATH = download_model(STABLE_DIFFUSION_URL, OUTPUT_DIR)
sd = NexaImageInference(
model_path="sd1-4",
wtype="q4_0",
)


# Print the model path
print("Model downloaded to:", MODEL_PATH)

# Helper function for Stable Diffusion initialization
def init_stable_diffusion():
return stable_diffusion.StableDiffusion(
model_path=MODEL_PATH,
wtype="q4_0" # Weight type (options: default, f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0)
)

# Test text-to-image generation
def test_txt_to_img():
sd = init_stable_diffusion()
output = sd.txt_to_img("a lovely cat", width=128, height=128, sample_steps=2)
global sd
output = sd.txt2img("a lovely cat", width=128, height=128, sample_steps=2)
output[0].save("output_txt_to_img.png")

# Test image-to-image generation
def test_img_to_img():

sd = init_stable_diffusion()
global sd
img_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
with TemporaryDirectory() as temp_dir:
img_path = download_model(img_url, temp_dir)
output = sd.img_to_img(
image=img_path,
output = sd.img2img(
image_path=img_path,
prompt="blue sky",
width=128,
height=128,
negative_prompt="black soil",
sample_steps=2
)
output[0].save("output_img_to_img.png")

# Main execution
# if __name__ == "__main__":
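
The commented stub above suggests the tests were meant to be runnable directly; a hypothetical main block would simply call both tests:

    if __name__ == "__main__":
        test_txt_to_img()
        test_img_to_img()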
44 changes: 22 additions & 22 deletions tests/test_text_generation.py
@@ -1,36 +1,28 @@
import os
from nexa.gguf.llama import llama
from tests.utils import download_model
from nexa.gguf import NexaTextInference
from nexa.gguf.lib_utils import is_gpu_available
# Constants
TINY_LLAMA_URL = "https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
OUTPUT_DIR = os.getcwd()
MODEL_PATH = download_model(TINY_LLAMA_URL, OUTPUT_DIR)

# Initialize Llama model
def init_llama_model(verbose=False, n_gpu_layers=-1, chat_format=None, embedding=False):
return llama.Llama(
model_path=MODEL_PATH,
verbose=verbose,
n_gpu_layers=n_gpu_layers if is_gpu_available() else 0,
chat_format=chat_format,
embedding=embedding,
)
model = NexaTextInference(
model_path="gemma",
verbose=False,
n_gpu_layers=-1 if is_gpu_available() else 0,
chat_format="llama-2",
)

# Test text generation from a prompt
def test_text_generation():
model = init_llama_model()
output = model(
global model
output = model.create_completion(
"Q: Name the planets in the solar system? A: ",
max_tokens=512,
stop=["Q:", "\n"],
echo=True,
)
print(output)
# print(output)
# TODO: add assertions here

# Test chat completion in streaming mode
def test_streaming():
model = init_llama_model()
global model
output = model.create_completion(
"Q: Name the planets in the solar system? A: ",
max_tokens=512,
@@ -40,10 +32,12 @@ def test_streaming():
for chunk in output:
if "choices" in chunk:
print(chunk["choices"][0]["text"], end="", flush=True)
# TODO: add assertions here
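
The TODOs above ask for assertions. A minimal sketch, assuming the llama.cpp-style completion payload with a "choices" list (the same shape the streaming loop already checks):

    def _assert_completion(output):
        # hypothetical helper, not part of this commit
        assert "choices" in output
        assert len(output["choices"]) > 0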

# Test conversation mode with chat format
def test_create_chat_completion():
model = init_llama_model(chat_format="llama-2")
global model

output = model.create_chat_completion(
messages=[
{"role": "user", "content": "write a long 1000 word story about a detective"}
Expand All @@ -58,7 +52,13 @@ def test_create_chat_completion():
print(delta["content"], end="", flush=True)

def test_create_embedding():
model = init_llama_model(embedding=True)
model = NexaTextInference(
model_path="gemma",
verbose=False,
n_gpu_layers=-1 if is_gpu_available() else 0,
chat_format="llama-2",
embedding=True,
)
embeddings = model.create_embedding("Hello, world!")
print("Embeddings:\n", embeddings)

19 changes: 0 additions & 19 deletions tests/test_vlm.py
@@ -1,27 +1,8 @@
import base64
import os

from nexa.gguf import NexaVLMInference
from tests.utils import download_model
from nexa.gguf.lib_utils import is_gpu_available
import tempfile

def image_to_base64_data_uri(file_path):
"""
file_path = 'file_path.png'
data_uri = image_to_base64_data_uri(file_path)
"""
with open(file_path, "rb") as img_file:
base64_data = base64.b64encode(img_file.read()).decode("utf-8")
return f"data:image/png;base64,{base64_data}"


def test_image_generation():
with tempfile.TemporaryDirectory() as temp_dir:
temp_dir = os.path.dirname(os.path.abspath(__file__))
model_url = "https://nexa-model-hub-bucket.s3.us-west-1.amazonaws.com/public/nanoLLaVA/model-fp16.gguf"
mmproj_url = "https://nexa-model-hub-bucket.s3.us-west-1.amazonaws.com/public/nanoLLaVA/projector-fp16.gguf"

model = NexaVLMInference(
model_path="nanollava",
)
