diff --git a/.gitignore b/.gitignore index e2e3b4f..a6d289e 100644 --- a/.gitignore +++ b/.gitignore @@ -14,4 +14,5 @@ packages *.pickle *.json *.npy +*.png models.txt diff --git a/image/Dockerfile b/image/Dockerfile new file mode 100644 index 0000000..d8de12d --- /dev/null +++ b/image/Dockerfile @@ -0,0 +1,12 @@ +FROM python:3.10.12 + +WORKDIR /app + +HEALTHCHECK --interval=15s --timeout=5s --start-period=30s --start-interval=30s --retries=15 CMD curl --silent --fail http://localhost/ > /dev/null || exit 1 + +COPY requirements.txt . +RUN pip install --no-cache-dir --upgrade -r requirements.txt + +COPY main.py . + +ENTRYPOINT ["uvicorn", "main:app", "--port", "80", "--host", "0.0.0.0"] diff --git a/image/main.py b/image/main.py new file mode 100644 index 0000000..5a61f69 --- /dev/null +++ b/image/main.py @@ -0,0 +1,71 @@ +import time +from enum import Enum +from fastapi import FastAPI +import base64 +from io import BytesIO +import os +from huggingface_hub import HfApi +import diffusers +import torch + +from pydantic import BaseModel + +MODEL_NAME = os.getenv("MODEL", None) +if MODEL_NAME is None: + print("Missing model name") + exit() + +print("loading api") +api = HfApi() +model = api.model_info(MODEL_NAME) +if model is None or model.config is None: + "Cant find model" + exit() + + +diffuser_class = model.config["diffusers"]["_class_name"] +diffuser = getattr(diffusers, diffuser_class) +print("loading from pretrained") +model = diffuser.from_pretrained(MODEL_NAME) +print("moving to cuda") +model.to("cuda") + + +app = FastAPI() + + +class Sizes(Enum): + SMALL = "256x256" + MEDIUM = "512x512" + LARGE = "1024x1024" + EXTRA_WIDE = "1792x1024" + EXTRA_TALL = "1024x1792" + + +class ImageRequest(BaseModel): + prompt: str + model: str + size: Sizes + + +@app.post("/v1/images/generations") +async def generate_question(req: ImageRequest): + generator = torch.Generator(device="cuda").manual_seed(4) + width, height = req.size.value.split("x") + image = model( + prompt=req.prompt, height=int(height), width=int(width), generator=generator + ) + print(image) + image = image.images[0] + buffered = BytesIO() + image.save(buffered, format="png") + img_str = base64.b64encode(buffered.getvalue()) + return {"created": time.time(), "data": [{"b64_json": img_str}]} + + +@app.get("/") +def ping(): + return "", 200 + + +print("Starting fastapi") diff --git a/image/requirements.txt b/image/requirements.txt new file mode 100644 index 0000000..ebf20d7 --- /dev/null +++ b/image/requirements.txt @@ -0,0 +1,8 @@ +diffusers[torch]==0.31.0 +safetensors==0.4.5 +sentencepiece==0.2.0 +protobuf==5.28.3 +fastapi==0.115.0 +uvicorn==0.30.6 +transformers==4.46.2 +huggingface_hub==0.27.0 diff --git a/justfile b/justfile index f18c9c2..571937a 100644 --- a/justfile +++ b/justfile @@ -28,3 +28,9 @@ run_verifier_prod model port gpu gpus name memory_util='.9' tag='latest': push_verifier: build_verifier docker push manifoldlabs/sn4-verifier:latest + +image: + docker run -p 80:80 -v /var/targon/huggingface/cache:/root/.cache/huggingface -e MODEL=black-forest-labs/FLUX.1-schnell -d --gpus all --name image image + +build_image: + cd image && docker build . 
-t image diff --git a/neurons/miner.py b/neurons/miner.py index a913e68..d31bf9d 100644 --- a/neurons/miner.py +++ b/neurons/miner.py @@ -8,6 +8,7 @@ import requests from starlette.background import BackgroundTask from starlette.responses import StreamingResponse +from starlette.responses import Response from neurons.base import BaseNeuron, NeuronType from targon.epistula import verify_signature @@ -77,6 +78,19 @@ async def create_completion(self, request: Request): return StreamingResponse( r.aiter_raw(), background=BackgroundTask(r.aclose), headers=r.headers ) + + async def create_image_completion(self, request: Request): + bt.logging.info( + "\u2713", + f"Getting Image Completion request from {request.headers.get('Epistula-Signed-By', '')[:8]}!", + ) + req = self.client.build_request( + "POST", "/images/generations", content=await request.body(), timeout=httpx.Timeout(300.0) + ) + r = await self.client.send(req) + return Response( + r.content, headers=r.headers + ) async def receive_models(self, request: Request): models = await request.json() @@ -194,6 +208,12 @@ def run(self): dependencies=[Depends(self.determine_epistula_version_and_verify)], methods=["POST"], ) + router.add_api_route( + "/v1/images/generations", + self.create_image_completion, + dependencies=[Depends(self.determine_epistula_version_and_verify)], + methods=["POST"], + ) router.add_api_route( "/models", self.receive_models, diff --git a/tests/client.py b/tests/client.py new file mode 100644 index 0000000..070eff7 --- /dev/null +++ b/tests/client.py @@ -0,0 +1,77 @@ +import requests +import io +from PIL import Image + +class OmniGenClient: + def __init__(self, base_url="http://103.219.171.95:8000"): + self.base_url = base_url.rstrip('/') + + def ping(self): + """Test the connection to the server""" + try: + response = requests.get(f"{self.base_url}/") + return response.json() + except requests.RequestException as e: + return {"error": str(e)} + + def generate_image(self, prompt, height=1024, width=1024, guidance_scale=2.5, seed=0, save_path=None): + """Generate an image using the OmniGen model""" + try: + payload = { + "prompt": prompt, + "height": height, + "width": width, + "guidance_scale": guidance_scale, + "seed": seed + } + + response = requests.post(f"{self.base_url}/generate", json=payload) + + if response.status_code == 200: + # Convert the response content to a PIL Image + image = Image.open(io.BytesIO(response.content)) + + # Save the image if a path is provided + if save_path: + image.save(save_path) + print(f"Image saved to: {save_path}") + + return image + else: + return {"error": f"Request failed with status code: {response.status_code}"} + + except requests.RequestException as e: + return {"error": str(e)} + +def main(): + # Example usage + client = OmniGenClient() + + # Test the connection + print("Testing connection...") + result = client.ping() + print(f"Server response: {result}") + + # Generate an image + print("\nGenerating image...") + prompt = "a beautiful sunset over mountains" + + # You can now specify a custom save path + save_path = "./output/my_generated_image.png" + image = client.generate_image(prompt, save_path=save_path) + + if isinstance(image, Image.Image): + print("Image generated successfully!") + + # You can also perform additional operations on the image + # For example, display it (if running in a notebook): + # image.show() + + # Or resize it: + # resized_image = image.resize((512, 512)) + # resized_image.save("resized_image.png") + else: + print(f"Error generating image: 
{image.get('error')}") + +if __name__ == "__main__": + main() diff --git a/tests/requirements.txt b/tests/requirements.txt new file mode 100644 index 0000000..86600e9 --- /dev/null +++ b/tests/requirements.txt @@ -0,0 +1 @@ +git+https://github.com/VectorSpaceLab/OmniGen.git diff --git a/tests/server.py b/tests/server.py new file mode 100644 index 0000000..58214f6 --- /dev/null +++ b/tests/server.py @@ -0,0 +1,47 @@ +import uvicorn +from OmniGen import OmniGenPipeline +from fastapi import FastAPI +from pydantic import BaseModel +import io +from fastapi.responses import StreamingResponse + +app = FastAPI() + +# Create Pydantic model for request body +class ImageRequest(BaseModel): + prompt: str + height: int = 1024 + width: int = 1024 + guidance_scale: float = 2.5 + seed: int = 0 + +# Initialize the pipeline globally +pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1") + +@app.get("/") +def read_root(): + return {"message": "Hello, World!"} + +@app.post("/generate") +async def generate_image(request: ImageRequest): + # Generate image using the pipeline + images = pipe( + prompt=request.prompt, + height=request.height, + width=request.width, + guidance_scale=request.guidance_scale, + seed=request.seed, + ) + + # Convert PIL Image to bytes + img_byte_arr = io.BytesIO() + images[0].save(img_byte_arr, format='PNG') + img_byte_arr.seek(0) + + # Return the image as a streaming response + return StreamingResponse(img_byte_arr, media_type="image/png") + +if __name__ == "__main__": + uvicorn.run(app, host="0.0.0.0", port=8000) + + diff --git a/verifier/requirements.txt b/verifier/requirements.txt index 2bc3b40..2da54a3 100644 --- a/verifier/requirements.txt +++ b/verifier/requirements.txt @@ -1,4 +1,11 @@ -vllm==0.6.2 +# LLM +#vllm==0.6.2 fastapi==0.115.0 openai==1.44.1 uvicorn==0.30.6 + +# Image +diffusers[torch]==0.31.0 +safetensors==0.4.5 +sentencepiece==0.2.0 +protobuf==5.28.3 diff --git a/verifier_new/Dockerfile b/verifier_new/Dockerfile new file mode 100644 index 0000000..815e18b --- /dev/null +++ b/verifier_new/Dockerfile @@ -0,0 +1,10 @@ +FROM python:3.9 +WORKDIR /app + +COPY ./requirements.txt requirements.txt +RUN pip install --no-cache-dir --upgrade -r requirements.txt +COPY ./verifier.py . 
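+# The healthcheck below probes the FastAPI root route ("/") that verifier.py exposes,
+# so the container only reports healthy once the model has loaded and uvicorn is serving.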
+ +HEALTHCHECK --interval=15s --timeout=5s --start-period=30s --start-interval=30s --retries=15 CMD curl --silent --fail http://localhost/ > /dev/null || exit 1 + +ENTRYPOINT ["uvicorn", "verifier:app", "--port", "80", "--host", "0.0.0.0"] diff --git a/verifier_new/image.py b/verifier_new/image.py new file mode 100644 index 0000000..d270f9f --- /dev/null +++ b/verifier_new/image.py @@ -0,0 +1,31 @@ +#################################### +# ___ +# |_ _|_ __ ___ __ _ __ _ ___ +# | || '_ ` _ \ / _` |/ _` |/ _ \ +# | || | | | | | (_| | (_| | __/ +# |___|_| |_| |_|\__,_|\__, |\___| +# |___/ +# +#################################### + + +import base64 +from io import BytesIO + + +## TODO build verification for images +def generate_image_functions(MODEL_WRAPPER,MODEL_NAME,ENDPOINTS): + + ## cache this across requests for the same inputs + async def generate_ground_truth(prompt,width, height): + image = MODEL_WRAPPER(prompt, height=int(height), width=int(width)).images[0] # type: ignore + buffered = BytesIO() + image.save(buffered, format="png") + img_str = base64.b64encode(buffered.getvalue()) + return img_str + + async def verify(ground_truth, miner_response): + pass + + return verify + diff --git a/verifier_new/llm.py b/verifier_new/llm.py new file mode 100644 index 0000000..2b7fabe --- /dev/null +++ b/verifier_new/llm.py @@ -0,0 +1,415 @@ +#################################### +# _ _ __ __ +# | | | | | \/ | +# | | | | | |\/| | +# | |___| |___| | | | +# |_____|_____|_| |_| +# +#################################### + +import random +import math +import os +import traceback +from pydantic import BaseModel +from enum import Enum +from typing import Dict, List, Optional, Tuple +from vllm import SamplingParams + +LOGPROB_LOG_THRESHOLD = 0.65 +LOGPROB_FAILURE_THRESHOLD = 0.75 + +MODEL_NAME = os.getenv("MODEL_NAME", None) +if MODEL_NAME is None: + exit() + +class RequestParams(BaseModel): + messages: Optional[List[Dict[str, str]]] = None + prompt: Optional[str] = None + temperature: float = 0.0 + seed: int = 42 + max_tokens: int + + +class OutputItem(BaseModel): + text: str + logprob: float + token_id: int + + +class RequestType(Enum): + CHAT = "CHAT" + COMPLETION = "COMPLETION" + + +class VerificationRequest(BaseModel): + request_type: str + model: str = MODEL_NAME + request_params: RequestParams + output_sequence: List[OutputItem] + + +class RequestSamplingParams(BaseModel): + temperature: float = 0.0 + seed: int = 42 + max_tokens: int + + +class GenerateRequest(BaseModel): + messages: List[Dict[str, str]] + sampling_params: RequestSamplingParams + +def get_llm_functions(MODEL_WRAPPER, TOKENIZER, LOCK, LOCK_GENERATE, ENDPOINTS): + async def verify(request: VerificationRequest) -> Dict: + """Verify a miner's output.""" + if MODEL_WRAPPER is None or TOKENIZER is None: + return { + "error": f"Unable to verify model={request.model}, since we are using {MODEL_NAME}", + "cause": "INTERNAL_ERROR", + } + + # If the miner didn't return any outputs, fail. 
+ if len(request.output_sequence) < 3: + return { + "verified": False, + "error": "Output sequence too short!", + "cause": "TOO_SHORT", + } + if ( + request.request_params.max_tokens + and len(request.output_sequence) > request.request_params.max_tokens + ): + return { + "verified": False, + "error": f"Too many tokens produced: {request.request_params.max_tokens} < {len(request.output_sequence)}", + "cause": "TOO_LONG", + } + if request.model != MODEL_NAME: + return { + "error": f"Unable to verify model={request.model}, since we are using {MODEL_NAME}", + "cause": "INTERNAL_ERROR", + } + + # Tokenize the input sequence. + input_text = ( + request.request_params.prompt + if request.request_type == RequestType.COMPLETION.value + else TOKENIZER.apply_chat_template( + request.request_params.messages, # type: ignore + tokenize=False, + add_special_tokens=False, + add_generation_prompt=True, + ) + ) + assert isinstance(input_text, str) + if hasattr(TOKENIZER, "bos_token"): + if input_text.startswith(TOKENIZER.bos_token): # type: ignore + input_text = input_text[len(TOKENIZER.bos_token) :] # type: ignore + input_tokens = TOKENIZER(input_text).input_ids + + # Verify! + async with LOCK: + return_value = { + "verified": False, + "error": None, + } + + # Logprob checks. + res = verify_logprobs(request, str(input_text), input_tokens) + if res is None: + return {"error": "Failed to check log probs", "cause": "INTERNAL_ERROR"} + result, message, cause = res + return_value.update( + { + "verified": result, + "cause": cause, + "error": message, + } + ) + if not result: + return return_value + + # Random logprob check. + if request.request_params.temperature > 0.75: + return {"verified": True} + + res = verify_logprobs_random(request, str(input_text)) + if res is None: + return { + "error": "Failed to check log probs", + "cause": "INTERNAL_ERROR", + } + result, message = res + return_value.update( + { + "verified": result, + "cause": "LOGPROB_RANDOM", + "error": message, + } + ) + if not result: + return return_value + + return {"verified": True} + + async def generate_question(req: GenerateRequest): + if MODEL_WRAPPER is None: + print("Failed generate request, endpoint not supported") + return {"text": None} + async with LOCK_GENERATE: + try: + if "chat" in ENDPOINTS: + output = ( + MODEL_WRAPPER.chat( + messages=req.messages, sampling_params=SamplingParams(**req.sampling_params.model_dump()), use_tqdm=False # type: ignore + )[0] + .outputs[0] + .text + ) + else: + prompt = "" + for message in req.messages: + prompt += ( + message.get("role", "") + + ": " + + message.get("content", "") + + "\n" + ) + prompt += "\nResponse: " + output = ( + MODEL_WRAPPER.generate( + prompts=prompt, + sampling_params=SamplingParams( + **req.sampling_params.model_dump() + ), + use_tqdm=False, + )[0] + .outputs[0] + .text + ) + return {"text": output} + except Exception as e: + print("Failed generate request", str(e), traceback.format_exc()) + return {"text": None} + + def verify_logprobs_random( + request: VerificationRequest, input_text: str + ) -> Tuple[bool, str]: + """ + Generate a handful of random outputs to ensure the logprobs weren't generated after the fact. 
+ """ + if MODEL_WRAPPER is None or TOKENIZER is None: + message = "Failed generate request, endpoint not supported" + print(message) + return False, message + indices = list(range(1, len(request.output_sequence) - 1)) + indices_to_check = list( + sorted( + [ + 0, # always check first token + len(request.output_sequence) - 1, # always check last token + ] + + random.sample(indices, min(len(indices), 3)) + ) + ) + + # Generate a single token at each index, comparing logprobs. + top_logprobs = int(request.request_params.temperature * 10) + 3 + sampling_params = SamplingParams( + temperature=request.request_params.temperature, + seed=request.request_params.seed, + max_tokens=1, + logprobs=top_logprobs, + ) + for idx in indices_to_check: + full_text = input_text + "".join( + [item.text for item in request.output_sequence[0:idx]] + ) + output = MODEL_WRAPPER.generate([full_text], sampling_params, use_tqdm=False)[ + 0 + ].outputs[0] + + # The miner's output token should be in the logprobs... + top_tokens = [] + if output.logprobs is None: + print("No log probs to check") + continue + for lp in output.logprobs: + top_tokens += list(lp.keys()) + if request.output_sequence[idx].token_id not in top_tokens: + message = f"Token output at index {idx} [{TOKENIZER.decode([request.output_sequence[idx].token_id])}] not found in top {top_logprobs} logprobs: {[TOKENIZER.decode([token]) for token in top_tokens]}" + return False, message + return ( + True, + f"Successfully verified {len(indices_to_check)} random logprobs: {indices_to_check}", + ) + + + def verify_logprobs( + request: VerificationRequest, input_text: str, input_tokens: List[int] + ) -> Optional[Tuple[bool, str, str]]: + """ + Compare the produced logprob values against the ground truth, or at least + the ground truth according to this particular GPU/software pairing. + """ + if MODEL_WRAPPER is None or TOKENIZER is None: + message = "Failed generate request, endpoint not supported" + print(message) + return None + + # Set up sampling parameters for the "fast" check, which just compares input logprobs against output logprobs. + top_logprobs = int(request.request_params.temperature * 10) + 6 + sampling_params = SamplingParams( + temperature=request.request_params.temperature, + seed=request.request_params.seed, + max_tokens=1, + logprobs=top_logprobs, + prompt_logprobs=top_logprobs, + ) + + # Generate output for a single token, which will return input logprobs based on prompt_logprobs=1 + output = None + for _ in range(5): + full_text = input_text + "".join( + [item.text for item in request.output_sequence] + ) + output = MODEL_WRAPPER.generate([full_text], sampling_params, use_tqdm=False)[0] + if output.prompt_logprobs is not None: + break + + if not output or output.prompt_logprobs is None: + return None + + # The actual logprobs should be *very* close, but typically not 100% because of GPU/driver/etc. differences. 
+ total_score = 0.0 + idxs = min( + len(output.prompt_logprobs) - len(input_tokens) - 3, + len(request.output_sequence) - 1, + ) + perfect_tokens = 0 + eos_token_id = getattr(TOKENIZER, "eos_token_id", -1) + eot_token_id = TOKENIZER.get_vocab().get("<|eot_id|>", -1) # type: ignore + output_tokens = [item.token_id for item in request.output_sequence] + really_low_prob = 0 + not_first = 0 + for idx in range(idxs): + item = request.output_sequence[idx] + expected_logprob = output.prompt_logprobs[idx + len(input_tokens)] + assert expected_logprob is not None + eos_logprob = expected_logprob.get(eos_token_id) + eot_logprob = expected_logprob.get(eot_token_id) + if ( + not eos_logprob + and eot_logprob + or ( + eos_logprob + and eot_logprob + and eot_logprob.rank != None + and eos_logprob.rank != None + and eot_logprob.rank < eos_logprob.rank + ) + ): + eos_logprob = eot_logprob + expected_logprob = expected_logprob.get(item.token_id) + if eos_logprob and ( + not expected_logprob + or ( + eos_logprob + and expected_logprob.rank != None + and eos_logprob.rank != None + and eos_logprob.rank < expected_logprob.rank + and expected_logprob.rank > 10 + ) + ): + return False, f"Expected EOS/EOT token at index {idx}", "SKIPPED_EOS_EOT" + if expected_logprob is None: + continue + rank = expected_logprob.rank + assert rank != None + if rank >= 75: + return ( + False, + f"Found extraordinarily improbable token '{TOKENIZER.decode([item.token_id])}' at index {idx}: {rank=}", + "UNLIKELY_TOKEN", + ) + elif rank >= 25: + really_low_prob += 1 + elif rank > top_logprobs: + continue + if rank != 1: + not_first += 1 + expected_logprob = expected_logprob.logprob + produced_logprob = item.logprob + score = 1.0 - min( + 1.0, abs(math.exp(expected_logprob) - math.exp(produced_logprob)) + ) + + # Prevents over fitting smaller models + if produced_logprob == 0: + perfect_tokens += 1 + + # To accomodate architectural difference and such, we'll give a perfect score if >= 0.9 + if score >= 0.9: + score = 1.0 + + # Logprobs rarely match well for high temps so we can use rank instead. + if ( + rank == 1 + and request.request_params.temperature >= 0.9 + and produced_logprob != 0 + ): + score = 1.0 + + total_score += score + + # Check if miner produced non-top ranking tokens more than top-ranking tokens. + ratio = not_first / len(output_tokens) + if ratio >= 0.5: + return ( + False, + f"{not_first} of {len(output_tokens)} [{ratio=}] tokens were not rank 1.", + "UNLIKELY_TOKENS", + ) + + # Check if miner prematurely stopped generating, meaning the single output token generated + # from the "throwaway" above was NOT an EOS/EOT token. + if eos_token_id > 0 or eot_token_id > 0: + if len(output_tokens) < request.request_params.max_tokens: + last_token_probs = [] + if output: + last_token_probs = output.outputs[0] + last_token_probs = ( + last_token_probs.logprobs[0] + if last_token_probs and last_token_probs.logprobs + else [] + ) + if ( + eos_token_id not in last_token_probs + and eot_token_id not in last_token_probs + and len(last_token_probs) != 0 + ): + return ( + False, + "Premature end of generation, EOS/EOT unlikely after last token.", + "EARLY_END", + ) + + # Calculate average score. + average_score = round(total_score / idxs, 5) + passes = average_score >= LOGPROB_FAILURE_THRESHOLD + perfect_avg = round(perfect_tokens / idxs, 5) + if passes and perfect_avg >= ( + 1 - min(request.request_params.temperature * 0.5, 0.6) + ): + return False, f"Overfitted response tokens. 
{perfect_avg}% perfect", "OVERFIT" + if really_low_prob >= 5: + return ( + False, + f"Found {really_low_prob} highly improbable tokens.", + "UNLIKELY_TOKEN", + ) + + return True, "", "" + return verify, generate_question + + diff --git a/verifier_new/requirements.txt b/verifier_new/requirements.txt new file mode 100644 index 0000000..e4c09e5 --- /dev/null +++ b/verifier_new/requirements.txt @@ -0,0 +1,11 @@ +# LLM +vllm==0.6.2 +fastapi==0.115.0 +openai==1.44.1 +uvicorn==0.30.6 + +# Image +diffusers[torch]==0.31.0 +safetensors==0.4.5 +sentencepiece==0.2.0 +protobuf==5.28.3 diff --git a/verifier_new/verifier.py b/verifier_new/verifier.py new file mode 100644 index 0000000..5fe8ed6 --- /dev/null +++ b/verifier_new/verifier.py @@ -0,0 +1,72 @@ +import os +import asyncio +from fastapi import FastAPI +from fastapi.routing import APIRouter +from vllm import LLM +from huggingface_hub import HfApi +import importlib + +from verifier_new.image import generate_image_functions +from verifier_new.llm import get_llm_functions + +# Load the model. +MODEL_NAME = os.getenv("MODEL_NAME", None) +if MODEL_NAME is None: + exit() + +api = HfApi() +model = api.model_info(MODEL_NAME) +if model is None or model.config is None: + exit() + +# Lock to ensure atomicity. +LOCK = asyncio.Lock() +LOCK_GENERATE = asyncio.Lock() + +ENDPOINTS = [] + +app = FastAPI() +router = APIRouter() + +match model.pipeline_tag: + case "text-to-image": + ENDPOINTS.append("image") + diffuser_class = model.config["diffusers"]["_class_name"] + diffuser = importlib.import_module(f"diffusers.{diffuser_class}") + MODEL_WRAPPER = diffuser.from_pretrained(MODEL_NAME) + MODEL_WRAPPER.to("cuda") + verify = generate_image_functions(MODEL_WRAPPER, MODEL_NAME, ENDPOINTS) + router.add_api_route("/image/verify", verify, methods=["POST"]) + case "text-generation": + TENSOR_PARALLEL = int(os.getenv("TENSOR_PARALLEL", 1)) + ENDPOINTS.append("completion") + MODEL_WRAPPER = LLM( + model=MODEL_NAME, + enforce_eager=True, + gpu_memory_utilization=1, + tensor_parallel_size=TENSOR_PARALLEL, + ) + TOKENIZER = MODEL_WRAPPER.get_tokenizer() + if TOKENIZER.chat_template is not None: + ENDPOINTS.append("chat") + verify, generate_question = get_llm_functions( + MODEL_WRAPPER, TOKENIZER, LOCK, LOCK_GENERATE, ENDPOINTS + ) + router.add_api_route("/llm/generate", generate_question, methods=["POST"]) + router.add_api_route("/llm/verify", verify, methods=["POST"]) + case _: + print(f"Unknown pipeline {model.pipeline_tag}") + exit() + + +app.include_router(router) + + +@app.get("/endpoints") +def endpoints(): + return ENDPOINTS + + +@app.get("/") +def ping(): + return "", 200
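The /v1/images/generations route added in image/main.py (and proxied by neurons/miner.py) returns the generated PNG base64-encoded in an OpenAI-style payload rather than as raw bytes. Below is a minimal client sketch for decoding that payload; it is not part of this diff, the localhost URL only assumes the port mapping used in the justfile's `image` recipe, and the prompt is just an example.

import base64
import io

import requests
from PIL import Image

resp = requests.post(
    "http://localhost:80/v1/images/generations",  # port 80 as mapped in the justfile `image` recipe
    json={
        "prompt": "a watercolor painting of a lighthouse",  # example prompt
        "model": "black-forest-labs/FLUX.1-schnell",
        "size": "512x512",  # must match one of the Sizes enum values
    },
    timeout=300,
)
resp.raise_for_status()
b64_png = resp.json()["data"][0]["b64_json"]
Image.open(io.BytesIO(base64.b64decode(b64_png))).save("generation.png")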