Stable Diffusion 1.4 #70
Open · bgoelTT wants to merge 15 commits into main from ben/sd1.4
Commits (15)
a352aaa Add initial docker setup (bgoelTT)
900c08b Add sd server as background process in flask server (bgoelTT)
302f293 Update README and testing (bgoelTT)
136841a Add ready check mechanism, not finished (bgoelTT)
572cdb3 Finish readycheck warmup thread (bgoelTT)
cd869eb Add health check endpoint (bgoelTT)
f7774d7 Add common API token utility and use (bgoelTT)
a92e81c Return JSON object (bgoelTT)
4582863 Use tt-metal release v0.55.0 (bgoelTT)
f5650c5 Use logging instead of print (bgoelTT)
ada9877 Update header year to 2025 (bgoelTT)
a75d0e6 Refactor repeated functionality (bgoelTT)
835fc91 Add mutual exclusion to ready variable (bgoelTT)
016af35 Remove locust tests and add healthcheck test (bgoelTT)
63d973c Test more API endpoints and fix bug in locking mechanism (bgoelTT)
@@ -0,0 +1,6 @@
TT_METAL_DOCKERFILE_VERSION=v0.55.0
TT_METAL_COMMIT_SHA_OR_TAG=v0.55.0
TT_METAL_COMMIT_DOCKER_TAG=v0.55.0 # technically redundant, but this var is used to name the image
IMAGE_VERSION=v0.0.1
# These are secrets and must be stored securely in production environments
JWT_SECRET=testing
@@ -0,0 +1,51 @@
# TT Metalium Stable Diffusion 1.4 Inference API

This implementation supports Stable Diffusion 1.4 execution on Wormhole n150 (n300 currently broken).

## Table of Contents
- [Run server](#run-server)
- [JWT_TOKEN Authorization](#jwt_token-authorization)
- [Development](#development)
- [Tests](#tests)

## Run server
To run the SD1.4 inference server, run the following command from the project root at `tt-inference-server`:
```bash
cd tt-inference-server
docker compose --env-file tt-metal-stable-diffusion-1.4/.env.default -f tt-metal-stable-diffusion-1.4/docker-compose.yaml up --build
```

This starts the default Docker container with the entrypoint command set to run the gunicorn server. The next section describes how to override the container's default command with an interactive shell via `bash`.

### JWT_TOKEN Authorization

To authenticate requests, use the `Authorization` header. The JWT token can be computed using the script `jwt_util.py`. For example:
```bash
cd tt-inference-server/tt-metal-yolov4/server
export JWT_SECRET=<your-secure-secret>
export AUTHORIZATION="Bearer $(python scripts/jwt_util.py --secret ${JWT_SECRET?ERROR env var JWT_SECRET must be set} encode '{"team_id": "tenstorrent", "token_id":"debug-test"}')"
```

## Development
To get an interactive shell inside the container instead of the default server entrypoint, run:
```bash
docker compose --env-file tt-metal-stable-diffusion-1.4/.env.default -f tt-metal-stable-diffusion-1.4/docker-compose.yaml run --rm --build inference_server /bin/bash
```

Inside the container, run `cd ~/app/server` to navigate to the server implementation.

## Tests
Tests can be found in `tests/`. The tests have their own dependencies, found in `requirements-test.txt`.

First, ensure the server is running (see [how to run the server](#run-server)). Then, in a different shell with the base dev `venv` activated:
```bash
cd tt-metal-stable-diffusion-1.4
pip install -r requirements-test.txt
cd tests/
locust --config locust_config.conf
```
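For reference, the token produced by `jwt_util.py` can also be reconstructed without that script. The sketch below is a minimal stdlib-only version, assuming the token is a standard HS256 JWT (pyjwt's default signing algorithm, and pyjwt is in `requirements.txt`); the exact claims and secret mirror the example above:

```python
import base64
import hashlib
import hmac
import json


def b64url(data: bytes) -> str:
    # JWTs use unpadded URL-safe base64
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode()


def encode_jwt(payload: dict, secret: str) -> str:
    # assumes an HS256 JWT: base64url(header).base64url(payload).base64url(signature)
    header = {"alg": "HS256", "typ": "JWT"}
    signing_input = ".".join(
        b64url(json.dumps(part, separators=(",", ":")).encode())
        for part in (header, payload)
    )
    sig = hmac.new(secret.encode(), signing_input.encode(), hashlib.sha256).digest()
    return f"{signing_input}.{b64url(sig)}"


# matches the payload in the README example; "testing" is the .env.default secret
token = encode_jwt({"team_id": "tenstorrent", "token_id": "debug-test"}, "testing")
authorization = f"Bearer {token}"
```

This is a sketch under the HS256 assumption, not a substitute for `jwt_util.py`; if the script uses a different algorithm or claim set, the server will reject tokens built this way.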
@@ -0,0 +1,26 @@ | ||
services: | ||
inference_server: | ||
image: ghcr.io/tenstorrent/tt-inference-server/tt-metal-stable-diffusion-1.4-src-base:${IMAGE_VERSION}-tt-metal-${TT_METAL_COMMIT_DOCKER_TAG} | ||
build: | ||
context: ../ | ||
dockerfile: tt-metal-stable-diffusion-1.4/stable-diffusion-1.4.src.Dockerfile | ||
args: | ||
TT_METAL_DOCKERFILE_VERSION: ${TT_METAL_DOCKERFILE_VERSION} | ||
TT_METAL_COMMIT_SHA_OR_TAG: ${TT_METAL_COMMIT_SHA_OR_TAG} | ||
container_name: sd_inference_server | ||
ports: | ||
- "7000:7000" | ||
devices: | ||
- "/dev/tenstorrent:/dev/tenstorrent" | ||
volumes: | ||
- "/dev/hugepages-1G/:/dev/hugepages-1G:rw" | ||
shm_size: "32G" | ||
cap_add: | ||
- ALL | ||
stdin_open: true | ||
tty: true | ||
# this is redundant as docker compose automatically uses the .env file as its in the same directory | ||
# but this explicitly demonstrates its usage | ||
env_file: | ||
- .env.default | ||
restart: no |
@@ -0,0 +1,3 @@
pillow==10.3.0
locust==2.25.0
pytest==7.2.2
@@ -0,0 +1,5 @@
# inference server requirements
flask==3.0.2
gunicorn==21.2.0
requests==2.31.0
pyjwt==2.7.0
@@ -0,0 +1,234 @@
# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
#
# SPDX-License-Identifier: Apache-2.0

from flask import (
    abort,
    Flask,
    request,
    jsonify,
    send_from_directory,
)
import json
import logging
import os
import atexit
import time
import threading
from http import HTTPStatus
from utils.authentication import api_key_required

import subprocess
import signal
import sys


# initialize logger
logger = logging.getLogger(__name__)


# script to run in background
script = "pytest models/demos/wormhole/stable_diffusion/demo/web_demo/sdserver.py"

# start the SD worker script using subprocess
process1 = subprocess.Popen(script, shell=True)


# terminate the background process and exit on shutdown signals
def signal_handler(sig, frame):
    logger.info("Terminating processes...")
    process1.terminate()
    sys.exit(0)


signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)

app = Flask(__name__)

# var to indicate ready state
ready = False
# lock for guaranteeing mutual exclusion
ready_lock = threading.Lock()

# internal json prompt file
json_file_path = (
    "models/demos/wormhole/stable_diffusion/demo/web_demo/input_prompts.json"
)


@app.route("/")
def hello_world():
    return "Hello, World!"


def read_json_file(file_path):
    if not os.path.isfile(file_path):
        raise FileNotFoundError(f"{file_path} is not a file")
    with open(file_path, "r") as f:
        return json.load(f)


def write_json_file(file_path, data):
    with open(file_path, "w") as f:
        json.dump(data, f, indent=4)


def submit_prompt(prompt_file, prompt):
    if not os.path.isfile(prompt_file):
        write_json_file(prompt_file, {"prompts": []})

    prompts_data = read_json_file(prompt_file)
    prompts_data["prompts"].append({"prompt": prompt, "status": "not generated"})

    write_json_file(prompt_file, prompts_data)


def warmup():
    sample_prompt = "Unicorn on a banana"
    # submit sample prompt to perform tracing and server warmup
    submit_prompt(json_file_path, sample_prompt)
    global ready
    while True:
        with ready_lock:
            if ready:
                break
        prompts_data = read_json_file(json_file_path)
        # sample prompt should be the first prompt
        sample_prompt_data = prompts_data["prompts"][0]
        if sample_prompt_data["prompt"] == sample_prompt:
            if sample_prompt_data["status"] == "done":
                with ready_lock:
                    ready = True
                logger.info("Warmup complete")
        time.sleep(3)


# start warmup routine in background while server starts
warmup_thread = threading.Thread(target=warmup, name="warmup")
warmup_thread.start()


@app.route("/health")
def health_check():
    with ready_lock:
        if not ready:
            abort(HTTPStatus.SERVICE_UNAVAILABLE, description="Server is not ready yet")
    return jsonify({"message": "OK\n"}), 200


@app.route("/submit", methods=["POST"])
@api_key_required
def submit():
    with ready_lock:
        if not ready:
            abort(HTTPStatus.SERVICE_UNAVAILABLE, description="Server is not ready yet")
    data = request.get_json()
    prompt = data.get("prompt")
    logger.info(f"Prompt: {prompt}")

    submit_prompt(json_file_path, prompt)

    return jsonify({"message": "Prompt received and added to queue."})


@app.route("/update_status", methods=["POST"])
def update_status():
    data = request.get_json()
    prompt = data.get("prompt")

    prompts_data = read_json_file(json_file_path)

    for p in prompts_data["prompts"]:
        if p["prompt"] == prompt:
            p["status"] = "generated"
            break

    write_json_file(json_file_path, prompts_data)

    return jsonify({"message": "Prompt status updated to generated."})


@app.route("/get_image", methods=["GET"])
def get_image():
    image_name = "interactive_512x512_ttnn.png"
    directory = os.getcwd()  # get the current working directory
    return send_from_directory(directory, image_name)


@app.route("/image_exists", methods=["GET"])
def image_exists():
    image_path = "interactive_512x512_ttnn.png"
    if os.path.isfile(image_path):
        return jsonify({"exists": True}), 200
    else:
        return jsonify({"exists": False}), 200


@app.route("/clean_up", methods=["POST"])
def clean_up():
    prompts_data = read_json_file(json_file_path)

    prompts_data["prompts"] = [
        p for p in prompts_data["prompts"] if p["status"] != "done"
    ]

    write_json_file(json_file_path, prompts_data)

    return jsonify({"message": "Cleaned up done prompts."})


@app.route("/get_latest_time", methods=["GET"])
def get_latest_time():
    if not os.path.isfile(json_file_path):
        return jsonify({"message": "No prompts found"}), 404

    prompts_data = read_json_file(json_file_path)

    # filter prompts that have a total_acc time available
    completed_prompts = [p for p in prompts_data["prompts"] if "total_acc" in p]

    if not completed_prompts:
        return jsonify({"message": "No completed prompts with time available"}), 404

    # get the latest prompt with total_acc
    latest_prompt = completed_prompts[-1]  # assuming prompts are in chronological order

    return (
        jsonify(
            {
                "prompt": latest_prompt["prompt"],
                "total_acc": latest_prompt["total_acc"],
                "batch_size": latest_prompt["batch_size"],
                "steps": latest_prompt["steps"],
            }
        ),
        200,
    )


def cleanup():
    if os.path.isfile(json_file_path):
        os.remove(json_file_path)
        logger.info("Deleted json")

    if os.path.isfile("interactive_512x512_ttnn.png"):
        os.remove("interactive_512x512_ttnn.png")
        logger.info("Deleted image")

    # log before signal_handler, which calls sys.exit() and does not return
    logger.info("Cleanup complete")
    signal_handler(None, None)


atexit.register(cleanup)


def create_server():
    return app
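The Flask server and the SD worker coordinate purely through the JSON prompt file. As a standalone illustration of that lifecycle, the sketch below re-implements the queue helpers against a temporary file (self-contained copies for illustration, not an import of the module above): a prompt enters as `"not generated"`, the worker eventually marks it `"done"`, and the clean-up step drops finished entries.

```python
import json
import os
import tempfile


def read_json_file(path):
    with open(path, "r") as f:
        return json.load(f)


def write_json_file(path, data):
    with open(path, "w") as f:
        json.dump(data, f, indent=4)


def submit_prompt(prompt_file, prompt):
    # create the queue file on first use, then append with the initial status
    if not os.path.isfile(prompt_file):
        write_json_file(prompt_file, {"prompts": []})
    data = read_json_file(prompt_file)
    data["prompts"].append({"prompt": prompt, "status": "not generated"})
    write_json_file(prompt_file, data)


queue = os.path.join(tempfile.mkdtemp(), "input_prompts.json")
submit_prompt(queue, "Unicorn on a banana")
assert read_json_file(queue)["prompts"][0]["status"] == "not generated"

# the SD worker flips the status once the image is ready
data = read_json_file(queue)
data["prompts"][0]["status"] = "done"
write_json_file(queue, data)

# /clean_up drops entries whose status is "done"
data["prompts"] = [p for p in data["prompts"] if p["status"] != "done"]
write_json_file(queue, data)
```

Note the server itself uses two different terminal statuses (`update_status` writes `"generated"` while warmup and `clean_up` check `"done"`); the sketch uses `"done"` throughout.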
@@ -0,0 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
#
# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC


workers = 1
# bind to 0.0.0.0 so the server is externally accessible
bind = "0.0.0.0:7000"
reload = False
worker_class = "gthread"
threads = 16
timeout = 0

# server factory
wsgi_app = "server.flaskserver:create_server()"
Review comment:
Global variable `ready`: this flag is accessed and modified by multiple threads (warmup and request handlers). This may lead to race conditions. Consider using thread-safe mechanisms like `threading.Lock` or `Event` to manage state changes.

Reply:
Enforced mutual exclusion of the `ready` flag in 835fc91, then fixed a bug in that implementation in 63d973c.