Image Support #67

Draft · wants to merge 10 commits into main
1 change: 1 addition & 0 deletions .gitignore
@@ -14,4 +14,5 @@ packages
*.pickle
*.json
*.npy
*.png
models.txt
12 changes: 12 additions & 0 deletions image/Dockerfile
@@ -0,0 +1,12 @@
FROM python:3.10.12

WORKDIR /app

HEALTHCHECK --interval=15s --timeout=5s --start-period=30s --start-interval=30s --retries=15 CMD curl --silent --fail http://localhost/ > /dev/null || exit 1

COPY requirements.txt .
RUN pip install --no-cache-dir --upgrade -r requirements.txt

COPY main.py .

ENTRYPOINT ["uvicorn", "main:app", "--port", "80", "--host", "0.0.0.0"]
71 changes: 71 additions & 0 deletions image/main.py
@@ -0,0 +1,71 @@
import time
from enum import Enum
from fastapi import FastAPI
import base64
from io import BytesIO
import os
from huggingface_hub import HfApi
import diffusers
import torch

from pydantic import BaseModel

MODEL_NAME = os.getenv("MODEL", None)
if MODEL_NAME is None:
print("Missing model name")
exit()

print("loading api")
api = HfApi()
model = api.model_info(MODEL_NAME)
if model is None or model.config is None:
"Cant find model"
exit()


diffuser_class = model.config["diffusers"]["_class_name"]
diffuser = getattr(diffusers, diffuser_class)
print("loading from pretrained")
model = diffuser.from_pretrained(MODEL_NAME)
print("moving to cuda")
model.to("cuda")


app = FastAPI()


class Sizes(Enum):
SMALL = "256x256"
MEDIUM = "512x512"
LARGE = "1024x1024"
EXTRA_WIDE = "1792x1024"
EXTRA_TALL = "1024x1792"


class ImageRequest(BaseModel):
    prompt: str
    model: str
    size: Sizes


@app.post("/v1/images/generations")
async def generate_question(req: ImageRequest):
generator = torch.Generator(device="cuda").manual_seed(4)
width, height = req.size.value.split("x")
image = model(
prompt=req.prompt, height=int(height), width=int(width), generator=generator
)
print(image)
image = image.images[0]
buffered = BytesIO()
image.save(buffered, format="png")
img_str = base64.b64encode(buffered.getvalue())
return {"created": time.time(), "data": [{"b64_json": img_str}]}


@app.get("/")
def ping():
return "", 200


print("Starting fastapi")
8 changes: 8 additions & 0 deletions image/requirements.txt
@@ -0,0 +1,8 @@
diffusers[torch]==0.31.0
safetensors==0.4.5
sentencepiece==0.2.0
protobuf==5.28.3
fastapi==0.115.0
uvicorn==0.30.6
transformers==4.46.2
huggingface_hub==0.27.0
6 changes: 6 additions & 0 deletions justfile
@@ -28,3 +28,9 @@ run_verifier_prod model port gpu gpus name memory_util='.9' tag='latest':

push_verifier: build_verifier
    docker push manifoldlabs/sn4-verifier:latest

image:
    docker run -p 80:80 -v /var/targon/huggingface/cache:/root/.cache/huggingface -e MODEL=black-forest-labs/FLUX.1-schnell -d --gpus all --name image image

build_image:
    cd image && docker build . -t image
20 changes: 20 additions & 0 deletions neurons/miner.py
@@ -8,6 +8,7 @@
import requests
from starlette.background import BackgroundTask
from starlette.responses import StreamingResponse
from starlette.responses import Response

from neurons.base import BaseNeuron, NeuronType
from targon.epistula import verify_signature
@@ -77,6 +78,19 @@ async def create_completion(self, request: Request):
        return StreamingResponse(
            r.aiter_raw(), background=BackgroundTask(r.aclose), headers=r.headers
        )

    async def create_image_completion(self, request: Request):
        bt.logging.info(
            "\u2713",
            f"Getting Image Completion request from {request.headers.get('Epistula-Signed-By', '')[:8]}!",
        )
        req = self.client.build_request(
            "POST", "/images/generations", content=await request.body(), timeout=httpx.Timeout(300.0)
        )
        r = await self.client.send(req)
        return Response(
            r.content, headers=r.headers
        )

    async def receive_models(self, request: Request):
        models = await request.json()
@@ -194,6 +208,12 @@ def run(self):
            dependencies=[Depends(self.determine_epistula_version_and_verify)],
            methods=["POST"],
        )
        router.add_api_route(
            "/v1/images/generations",
            self.create_image_completion,
            dependencies=[Depends(self.determine_epistula_version_and_verify)],
            methods=["POST"],
        )
        router.add_api_route(
            "/models",
            self.receive_models,
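The new miner route is a thin proxy: it forwards the signed request body to the local image container and relays the raw response. A self-contained sketch of the same pattern outside the miner class (the module-level names and base_url here are illustrative, not from this PR):

import httpx
from fastapi import FastAPI, Request
from starlette.responses import Response

app = FastAPI()
# assumed address of the locally running image container
client = httpx.AsyncClient(base_url="http://localhost:80/v1")


@app.post("/v1/images/generations")
async def proxy_image(request: Request):
    # forward the raw body so the upstream sees exactly what the caller signed
    req = client.build_request(
        "POST",
        "/images/generations",
        content=await request.body(),
        timeout=httpx.Timeout(300.0),
    )
    r = await client.send(req)
    # relay body and headers back to the caller unchanged
    return Response(r.content, headers=dict(r.headers))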
77 changes: 77 additions & 0 deletions tests/client.py
@@ -0,0 +1,77 @@
import requests
import io
from PIL import Image

class OmniGenClient:
    def __init__(self, base_url="http://103.219.171.95:8000"):
        self.base_url = base_url.rstrip('/')

    def ping(self):
        """Test the connection to the server"""
        try:
            response = requests.get(f"{self.base_url}/")
            return response.json()
        except requests.RequestException as e:
            return {"error": str(e)}

    def generate_image(self, prompt, height=1024, width=1024, guidance_scale=2.5, seed=0, save_path=None):
        """Generate an image using the OmniGen model"""
        try:
            payload = {
                "prompt": prompt,
                "height": height,
                "width": width,
                "guidance_scale": guidance_scale,
                "seed": seed,
            }

            response = requests.post(f"{self.base_url}/generate", json=payload)

            if response.status_code == 200:
                # Convert the response content to a PIL Image
                image = Image.open(io.BytesIO(response.content))

                # Save the image if a path is provided
                if save_path:
                    image.save(save_path)
                    print(f"Image saved to: {save_path}")

                return image
            else:
                return {"error": f"Request failed with status code: {response.status_code}"}

        except requests.RequestException as e:
            return {"error": str(e)}


def main():
    # Example usage
    client = OmniGenClient()

    # Test the connection
    print("Testing connection...")
    result = client.ping()
    print(f"Server response: {result}")

    # Generate an image
    print("\nGenerating image...")
    prompt = "a beautiful sunset over mountains"

    # You can now specify a custom save path
    save_path = "./output/my_generated_image.png"
    image = client.generate_image(prompt, save_path=save_path)

    if isinstance(image, Image.Image):
        print("Image generated successfully!")

        # You can also perform additional operations on the image
        # For example, display it (if running in a notebook):
        # image.show()

        # Or resize it:
        # resized_image = image.resize((512, 512))
        # resized_image.save("resized_image.png")
    else:
        print(f"Error generating image: {image.get('error')}")


if __name__ == "__main__":
    main()
1 change: 1 addition & 0 deletions tests/requirements.txt
@@ -0,0 +1 @@
git+https://github.com/VectorSpaceLab/OmniGen.git
47 changes: 47 additions & 0 deletions tests/server.py
@@ -0,0 +1,47 @@
import uvicorn
from OmniGen import OmniGenPipeline
from fastapi import FastAPI
from pydantic import BaseModel
import io
from fastapi.responses import StreamingResponse

app = FastAPI()

# Create Pydantic model for request body
class ImageRequest(BaseModel):
    prompt: str
    height: int = 1024
    width: int = 1024
    guidance_scale: float = 2.5
    seed: int = 0

# Initialize the pipeline globally
pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1")

@app.get("/")
def read_root():
return {"message": "Hello, World!"}

@app.post("/generate")
async def generate_image(request: ImageRequest):
    # Generate image using the pipeline
    images = pipe(
        prompt=request.prompt,
        height=request.height,
        width=request.width,
        guidance_scale=request.guidance_scale,
        seed=request.seed,
    )

    # Convert PIL Image to bytes
    img_byte_arr = io.BytesIO()
    images[0].save(img_byte_arr, format='PNG')
    img_byte_arr.seek(0)

    # Return the image as a streaming response
    return StreamingResponse(img_byte_arr, media_type="image/png")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)


9 changes: 8 additions & 1 deletion verifier/requirements.txt
@@ -1,4 +1,11 @@
vllm==0.6.2
# LLM
#vllm==0.6.2
fastapi==0.115.0
openai==1.44.1
uvicorn==0.30.6

# Image
diffusers[torch]==0.31.0
safetensors==0.4.5
sentencepiece==0.2.0
protobuf==5.28.3
10 changes: 10 additions & 0 deletions verifier_new/Dockerfile
@@ -0,0 +1,10 @@
FROM python:3.9
WORKDIR /app

COPY ./requirements.txt requirements.txt
RUN pip install --no-cache-dir --upgrade -r requirements.txt
COPY ./verifier.py .

HEALTHCHECK --interval=15s --timeout=5s --start-period=30s --start-interval=30s --retries=15 CMD curl --silent --fail http://localhost/ > /dev/null || exit 1

ENTRYPOINT ["uvicorn", "verifier:app", "--port", "80", "--host", "0.0.0.0"]
31 changes: 31 additions & 0 deletions verifier_new/image.py
@@ -0,0 +1,31 @@
####################################
# ___
# |_ _|_ __ ___ __ _ __ _ ___
# | || '_ ` _ \ / _` |/ _` |/ _ \
# | || | | | | | (_| | (_| | __/
# |___|_| |_| |_|\__,_|\__, |\___|
# |___/
#
####################################


import base64
from io import BytesIO


## TODO build verification for images
def generate_image_functions(MODEL_WRAPPER, MODEL_NAME, ENDPOINTS):
    ## cache this across requests for the same inputs
    async def generate_ground_truth(prompt, width, height):
        image = MODEL_WRAPPER(prompt, height=int(height), width=int(width)).images[0]  # type: ignore
        buffered = BytesIO()
        image.save(buffered, format="png")
        img_str = base64.b64encode(buffered.getvalue())
        return img_str

    async def verify(ground_truth, miner_response):
        pass

    return verify
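
verify is the open TODO above; one plausible direction (an illustration only, with an untuned placeholder threshold) is to decode both base64 payloads and compare pixels, since generation is seeded and should be near-deterministic on identical hardware:

import base64
from io import BytesIO

import numpy as np
from PIL import Image


async def verify(ground_truth, miner_response):
    # decode both base64-encoded PNG payloads back into RGB images
    truth = Image.open(BytesIO(base64.b64decode(ground_truth))).convert("RGB")
    candidate = Image.open(BytesIO(base64.b64decode(miner_response))).convert("RGB")
    if truth.size != candidate.size:
        return False
    # mean absolute pixel difference: 0 for identical images
    diff = np.abs(
        np.asarray(truth, dtype=np.int16) - np.asarray(candidate, dtype=np.int16)
    ).mean()
    return diff < 5.0  # tolerance is a placeholder, not tuned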
