Stable Diffusion 1.4 #70
base: main
Changes from 8 commits
tt-metal-stable-diffusion-1.4/.env.default
@@ -0,0 +1,6 @@
TT_METAL_DOCKERFILE_VERSION=v0.53.0-rc34
TT_METAL_COMMIT_SHA_OR_TAG=4da4a5e79a13ece7ff5096c30cef79cb0c504f0e
TT_METAL_COMMIT_DOCKER_TAG=4da4a5e79a13 # 12-character version of TT_METAL_COMMIT_SHA_OR_TAG
IMAGE_VERSION=v0.0.1
# These are secrets and must be stored securely for production environments
JWT_SECRET=testing
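`JWT_SECRET=testing` is only a development placeholder; as the comment above notes, production deployments need a securely stored secret. One way to generate a strong value (an illustrative command, not part of this PR):

```bash
# generate a random 256-bit hex string to use as JWT_SECRET
openssl rand -hex 32
```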
tt-metal-stable-diffusion-1.4/README.md
@@ -0,0 +1,51 @@
# TT Metalium Stable Diffusion 1.4 Inference API

This implementation supports Stable Diffusion 1.4 execution on Wormhole n150 (n300 is currently broken).

## Table of Contents
- [Run server](#run-server)
- [JWT_TOKEN Authorization](#jwt_token-authorization)
- [Development](#development)
- [Tests](#tests)

## Run server
To run the SD 1.4 inference server, run the following command from the project root at `tt-inference-server`:
```bash
cd tt-inference-server
docker compose --env-file tt-metal-stable-diffusion-1.4/.env.default -f tt-metal-stable-diffusion-1.4/docker-compose.yaml up --build
```

This starts the default Docker container with the entrypoint command set to run the gunicorn server. The next section describes how to override the container's default command with an interactive shell via `bash`.
### JWT_TOKEN Authorization

To authenticate requests, use the `Authorization` header. The JWT token can be computed using the script `jwt_util.py`. For example:
```bash
cd tt-inference-server/tt-metal-stable-diffusion-1.4/server
export JWT_SECRET=<your-secure-secret>
export AUTHORIZATION="Bearer $(python scripts/jwt_util.py --secret ${JWT_SECRET?ERROR env var JWT_SECRET must be set} encode '{"team_id": "tenstorrent", "token_id":"debug-test"}')"
```
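With `AUTHORIZATION` exported, authenticated requests can be sent to the server's `/submit` endpoint. A minimal sketch, assuming the server is reachable on `localhost:7000` (the port published in `docker-compose.yaml`):

```bash
# submit a text prompt to the generation queue; /submit requires the JWT header
curl -X POST http://localhost:7000/submit \
  -H "Authorization: ${AUTHORIZATION}" \
  -H "Content-Type: application/json" \
  -d '{"prompt": "Unicorn on a banana"}'
```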
## Development
To develop interactively, start the container with a `bash` shell in place of the default gunicorn command:
```bash
docker compose --env-file tt-metal-stable-diffusion-1.4/.env.default -f tt-metal-stable-diffusion-1.4/docker-compose.yaml run --rm --build inference_server /bin/bash
```

Inside the container, run `cd ~/app/server` to navigate to the server implementation. From there the server can be started manually, as sketched below.
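A sketch of a manual launch, assuming the gunicorn config added in this PR is available in the container at `~/app/server/gunicorn.conf.py`; note that its `wsgi_app = "server.flaskserver:create_server()"` setting means gunicorn should run from `~/app` so the `server` package is importable:

```bash
# run gunicorn with the config file; the app factory is set via wsgi_app
cd ~/app
gunicorn -c server/gunicorn.conf.py
```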
## Tests
Tests can be found in `tests/`. The tests have their own dependencies, listed in `requirements-test.txt`.

First, ensure the server is running (see [how to run the server](#run-server)). Then, in a different shell with the base dev `venv` activated:
```bash
cd tt-metal-stable-diffusion-1.4
pip install -r requirements-test.txt
cd tests/
locust --config locust_config.conf
```
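Because `/health` returns `503` until the warmup prompt has finished generating, it can help to poll for readiness before starting the load test. A small sketch, assuming the server is on `localhost:7000`:

```bash
# wait for the warmup prompt to complete before running locust
until curl -sf http://localhost:7000/health; do
  echo "server not ready yet..."
  sleep 3
done
```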
tt-metal-stable-diffusion-1.4/docker-compose.yaml
@@ -0,0 +1,26 @@
services:
  inference_server:
    image: ghcr.io/tenstorrent/tt-inference-server/tt-metal-stable-diffusion-1.4-src-base:${IMAGE_VERSION}-tt-metal-${TT_METAL_COMMIT_DOCKER_TAG}
    build:
      context: ../
      dockerfile: tt-metal-stable-diffusion-1.4/stable-diffusion-1.4.src.Dockerfile
      args:
        TT_METAL_DOCKERFILE_VERSION: ${TT_METAL_DOCKERFILE_VERSION}
        TT_METAL_COMMIT_SHA_OR_TAG: ${TT_METAL_COMMIT_SHA_OR_TAG}
    container_name: sd_inference_server
    ports:
      - "7000:7000"
    devices:
      - "/dev/tenstorrent:/dev/tenstorrent"
    volumes:
      - "/dev/hugepages-1G/:/dev/hugepages-1G:rw"
    shm_size: "32G"
    cap_add:
      - ALL
    stdin_open: true
    tty: true
    # this is redundant as docker compose automatically uses the .env file as it's in the same directory,
    # but this explicitly demonstrates its usage
    env_file:
      - .env.default
    restart: "no"
tt-metal-stable-diffusion-1.4/requirements-test.txt
@@ -0,0 +1,3 @@
pillow==10.3.0
locust==2.25.0
pytest==7.2.2
tt-metal-stable-diffusion-1.4/requirements.txt
@@ -0,0 +1,5 @@
# inference server requirements
flask==3.0.2
gunicorn==21.2.0
requests==2.31.0
pyjwt==2.7.0
tt-metal-stable-diffusion-1.4/server/flaskserver.py
@@ -0,0 +1,217 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

from flask import (
    abort,
    Flask,
    request,
    jsonify,
    send_from_directory,
)
import json
import os
import atexit
import time
import threading
from http import HTTPStatus
from utils.authentication import api_key_required

import subprocess
import signal
import sys

# script to run in background
script = "pytest models/demos/wormhole/stable_diffusion/demo/web_demo/sdserver.py"

# Start script using subprocess
process1 = subprocess.Popen(script, shell=True)


# Function to terminate both processes and kill port 5000
def signal_handler(sig, frame):
    print("Terminating processes...")
    process1.terminate()
    sys.exit(0)


signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)

app = Flask(__name__)

# var to indicate ready state
ready = False
> **Review comment:** Global variable
# internal json prompt file
json_file_path = (
    "models/demos/wormhole/stable_diffusion/demo/web_demo/input_prompts.json"
)


@app.route("/")
def hello_world():
    return "Hello, World!"


def submit_prompt(prompt_file, prompt):
    if not os.path.isfile(prompt_file):
        with open(prompt_file, "w") as f:
            json.dump({"prompts": []}, f)

    with open(prompt_file, "r") as f:
        prompts_data = json.load(f)

    prompts_data["prompts"].append({"prompt": prompt, "status": "not generated"})

    with open(prompt_file, "w") as f:
        json.dump(prompts_data, f, indent=4)


def warmup():
    sample_prompt = "Unicorn on a banana"
    # submit sample prompt to perform tracing and server warmup
    submit_prompt(json_file_path, sample_prompt)
    global ready
    while not ready:
        with open(json_file_path, "r") as f:
            prompts_data = json.load(f)
        # sample prompt should be first prompt
        sample_prompt_data = prompts_data["prompts"][0]
        if sample_prompt_data["prompt"] == sample_prompt:
            # TODO: remove this and replace with status check == "done"
            # to flip ready flag
            if sample_prompt_data["status"] == "done":
                ready = True
            print(sample_prompt_data["status"])
        time.sleep(3)


# start warmup routine in background while server starts
warmup_thread = threading.Thread(target=warmup, name="warmup")
warmup_thread.start()


@app.route("/health")
def health_check():
    if not ready:
        abort(HTTPStatus.SERVICE_UNAVAILABLE, description="Server is not ready yet")
    return jsonify({"message": "OK\n"}), 200


@app.route("/submit", methods=["POST"])
@api_key_required
def submit():
    global ready
    if not ready:
        abort(HTTPStatus.SERVICE_UNAVAILABLE, description="Server is not ready yet")
    data = request.get_json()
    prompt = data.get("prompt")
    print(prompt)

    submit_prompt(json_file_path, prompt)

    return jsonify({"message": "Prompt received and added to queue."})


@app.route("/update_status", methods=["POST"])
def update_status():
    data = request.get_json()
    prompt = data.get("prompt")

    with open(json_file_path, "r") as f:
        prompts_data = json.load(f)

    for p in prompts_data["prompts"]:
        if p["prompt"] == prompt:
            p["status"] = "generated"
            break

    with open(json_file_path, "w") as f:
        json.dump(prompts_data, f, indent=4)

    return jsonify({"message": "Prompt status updated to generated."})


@app.route("/get_image", methods=["GET"])
def get_image():
    image_name = "interactive_512x512_ttnn.png"
    directory = os.getcwd()  # Get the current working directory
    return send_from_directory(directory, image_name)


@app.route("/image_exists", methods=["GET"])
def image_exists():
    image_path = "interactive_512x512_ttnn.png"
    if os.path.isfile(image_path):
        return jsonify({"exists": True}), 200
    else:
        return jsonify({"exists": False}), 200
@app.route("/clean_up", methods=["POST"]) | ||
def clean_up(): | ||
with open(json_file_path, "r") as f: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Repeated code blocks : the code repeatedly reads and writes def read_json_file(file_path):
if not os.path.isfile(file_path):
return {"prompts": []}
with open(file_path, "r") as f:
return json.load(f)
def write_json_file(file_path, data):
with open(file_path, "w") as f:
json.dump(data, f, indent=4) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Refactored in a75d0e6 |
||
prompts_data = json.load(f) | ||
|
||
prompts_data["prompts"] = [ | ||
p for p in prompts_data["prompts"] if p["status"] != "done" | ||
] | ||
|
||
with open(json_file_path, "w") as f: | ||
json.dump(prompts_data, f, indent=4) | ||
|
||
return jsonify({"message": "Cleaned up done prompts."}) | ||
@app.route("/get_latest_time", methods=["GET"])
def get_latest_time():
    if not os.path.isfile(json_file_path):
        return jsonify({"message": "No prompts found"}), 404

    with open(json_file_path, "r") as f:
        prompts_data = json.load(f)

    # Filter prompts that have a total_acc time available
    completed_prompts = [p for p in prompts_data["prompts"] if "total_acc" in p]

    if not completed_prompts:
        return jsonify({"message": "No completed prompts with time available"}), 404

    # Get the latest prompt with total_acc
    latest_prompt = completed_prompts[-1]  # Assuming prompts are in chronological order

    return (
        jsonify(
            {
                "prompt": latest_prompt["prompt"],
                "total_acc": latest_prompt["total_acc"],
                "batch_size": latest_prompt["batch_size"],
                "steps": latest_prompt["steps"],
            }
        ),
        200,
    )


def cleanup():
    if os.path.isfile(
        "models/demos/wormhole/stable_diffusion/demo/web_demo/input_prompts.json"
    ):
        os.remove(
            "models/demos/wormhole/stable_diffusion/demo/web_demo/input_prompts.json"
        )
        print("Deleted json")
    if os.path.isfile("interactive_512x512_ttnn.png"):
        os.remove("interactive_512x512_ttnn.png")
        print("Deleted image")

    signal_handler(None, None)


atexit.register(cleanup)


def create_server():
    return app
tt-metal-stable-diffusion-1.4/server/gunicorn.conf.py
@@ -0,0 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
#
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC


workers = 1
# use 0.0.0.0 for externally accessible
bind = "0.0.0.0:7000"
reload = False
worker_class = "gthread"
threads = 16
timeout = 0

# server factory
wsgi_app = "server.flaskserver:create_server()"
> **Review comment:** Is there a reference to this commit in `tt-metal` anywhere? Curious how you selected this one.

> **Reply:** I copy+pasted this from the YOLOv4 server because that required a specific commit, as the metal YOLOv4 improvements got reverted since they couldn't pass CI.
> Should I just use the latest release? That would be release v0.55.0?

> **Reply:** Updated to use release v0.55.0 in 4582863