Stable Diffusion 1.4 #70

Open · wants to merge 15 commits into `main`
6 changes: 6 additions & 0 deletions tt-metal-stable-diffusion-1.4/.env.default
@@ -0,0 +1,6 @@
TT_METAL_DOCKERFILE_VERSION=v0.55.0
TT_METAL_COMMIT_SHA_OR_TAG=v0.55.0
TT_METAL_COMMIT_DOCKER_TAG=v0.55.0 # technically redundant but this var is used to name the image
IMAGE_VERSION=v0.0.1
# These are secrets and must be stored securely for production environments
JWT_SECRET=testing
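
The default `JWT_SECRET=testing` is for local development only. As a suggestion (not part of this PR), a stronger secret can be generated with Python's standard library:

```python
# generate_secret.py -- suggestion: produce a random value for JWT_SECRET
import secrets

print(secrets.token_hex(32))  # 64 hex chars of cryptographically strong randomness
```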
51 changes: 51 additions & 0 deletions tt-metal-stable-diffusion-1.4/README.md
@@ -0,0 +1,51 @@
# TT Metalium Stable Diffusion 1.4 Inference API

This implementation supports Stable Diffusion 1.4 inference on Wormhole n150 (n300 is currently broken).


## Table of Contents
- [Run server](#run-server)
- [JWT_TOKEN Authorization](#jwt_token-authorization)
- [Development](#development)
- [Tests](#tests)


## Run server
To run the SD 1.4 inference server, run the following command from the project root, `tt-inference-server`:
```bash
cd tt-inference-server
docker compose --env-file tt-metal-stable-diffusion-1.4/.env.default -f tt-metal-stable-diffusion-1.4/docker-compose.yaml up --build
```

This will start the default Docker container with the entrypoint command set to run the gunicorn server. The [Development](#development) section below describes how to override the container's default command with an interactive shell via `bash`.
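
Because the first prompt triggers tracing, startup can take a while; the `/health` route defined in `server/flaskserver.py` returns 503 until warmup completes and 200 afterwards. A minimal readiness poll, as a sketch (assuming the default `7000:7000` port mapping from `docker-compose.yaml`):

```python
# poll_health.py -- sketch: wait until the SD 1.4 server finishes warmup
import time

import requests

BASE_URL = "http://localhost:7000"  # assumed default port mapping

while True:
    try:
        # /health returns 503 while warming up and 200 once ready
        if requests.get(f"{BASE_URL}/health", timeout=5).status_code == 200:
            print("server ready")
            break
    except requests.RequestException:
        pass  # server not accepting connections yet
    time.sleep(10)
```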


### JWT_TOKEN Authorization

To authenticate requests, use the `Authorization` header. The JWT token can be computed using the script `jwt_util.py`. For example:
```bash
cd tt-inference-server/tt-metal-stable-diffusion-1.4/server
export JWT_SECRET=<your-secure-secret>
export AUTHORIZATION="Bearer $(python scripts/jwt_util.py --secret ${JWT_SECRET?ERROR env var JWT_SECRET must be set} encode '{"team_id": "tenstorrent", "token_id":"debug-test"}')"
```
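
With `AUTHORIZATION` exported, an authenticated call to the `/submit` route from `server/flaskserver.py` might look like the following sketch (the host and port are assumptions based on `docker-compose.yaml`):

```python
# submit_prompt.py -- sketch: send an authenticated prompt to the server
import os

import requests

resp = requests.post(
    "http://localhost:7000/submit",
    headers={"Authorization": os.environ["AUTHORIZATION"]},  # "Bearer <token>"
    json={"prompt": "Unicorn on a banana"},
)
print(resp.status_code, resp.text)
```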


## Development
To develop inside the container, override its default command with an interactive shell:
```bash
docker compose --env-file tt-metal-stable-diffusion-1.4/.env.default -f tt-metal-stable-diffusion-1.4/docker-compose.yaml run --rm --build inference_server /bin/bash
```

Inside the container, run `cd ~/app/server` to navigate to the server implementation.


## Tests
Tests can be found in `tests/`. The tests have their own dependencies found in `requirements-test.txt`.

First, ensure the server is running (see [how to run the server](#run-server)). Then in a different shell with the base dev `venv` activated:
```bash
cd tt-metal-stable-diffusion-1.4
pip install -r requirements-test.txt
cd tests/
locust --config locust_config.conf
```
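
The locust configuration and the test files themselves are not part of this diff. Purely as an illustration, a standalone smoke test against the routes defined in `server/flaskserver.py` could look like this sketch (the base URL is an assumption):

```python
# test_smoke_sketch.py -- illustrative only; the PR's real tests are driven by
# locust via locust_config.conf, which is not shown in this diff.
import requests

BASE_URL = "http://localhost:7000"  # assumed default from docker-compose.yaml


def test_health_responds():
    # 503 during warmup, 200 once the warmup prompt has completed
    resp = requests.get(f"{BASE_URL}/health", timeout=5)
    assert resp.status_code in (200, 503)


def test_image_exists_reports_boolean():
    resp = requests.get(f"{BASE_URL}/image_exists", timeout=5)
    assert resp.status_code == 200
    assert isinstance(resp.json().get("exists"), bool)
```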
26 changes: 26 additions & 0 deletions tt-metal-stable-diffusion-1.4/docker-compose.yaml
@@ -0,0 +1,26 @@
services:
  inference_server:
    image: ghcr.io/tenstorrent/tt-inference-server/tt-metal-stable-diffusion-1.4-src-base:${IMAGE_VERSION}-tt-metal-${TT_METAL_COMMIT_DOCKER_TAG}
    build:
      context: ../
      dockerfile: tt-metal-stable-diffusion-1.4/stable-diffusion-1.4.src.Dockerfile
      args:
        TT_METAL_DOCKERFILE_VERSION: ${TT_METAL_DOCKERFILE_VERSION}
        TT_METAL_COMMIT_SHA_OR_TAG: ${TT_METAL_COMMIT_SHA_OR_TAG}
    container_name: sd_inference_server
    ports:
      - "7000:7000"
    devices:
      - "/dev/tenstorrent:/dev/tenstorrent"
    volumes:
      - "/dev/hugepages-1G/:/dev/hugepages-1G:rw"
    shm_size: "32G"
    cap_add:
      - ALL
    stdin_open: true
    tty: true
    # this is redundant, as docker compose automatically picks up the .env file
    # in the same directory, but it is kept to demonstrate its usage explicitly
    env_file:
      - .env.default
    restart: "no"
3 changes: 3 additions & 0 deletions tt-metal-stable-diffusion-1.4/requirements-test.txt
@@ -0,0 +1,3 @@
pillow==10.3.0
locust==2.25.0
pytest==7.2.2
5 changes: 5 additions & 0 deletions tt-metal-stable-diffusion-1.4/requirements.txt
@@ -0,0 +1,5 @@
# inference server requirements
flask==3.0.2
gunicorn==21.2.0
requests==2.31.0
pyjwt==2.7.0
234 changes: 234 additions & 0 deletions tt-metal-stable-diffusion-1.4/server/flaskserver.py
@@ -0,0 +1,234 @@
# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
#
# SPDX-License-Identifier: Apache-2.0

from flask import (
    abort,
    Flask,
    request,
    jsonify,
    send_from_directory,
)
import json
import logging
import os
import atexit
import time
import threading
from http import HTTPStatus
from utils.authentication import api_key_required

import subprocess
import signal
import sys


# initialize logger
logger = logging.getLogger(__name__)


# script to run in background
script = "pytest models/demos/wormhole/stable_diffusion/demo/web_demo/sdserver.py"

# Start script using subprocess
process1 = subprocess.Popen(script, shell=True)


# Terminate the background script process on shutdown
def signal_handler(sig, frame):
    logger.info("Terminating processes...")
    process1.terminate()
    sys.exit(0)


signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)

app = Flask(__name__)

# var to indicate ready state
ready = False
> **Review comment (Contributor):** Global variable `ready`: this flag is accessed and modified by multiple threads (warmup and request handlers). This may lead to race conditions. Consider using thread-safe mechanisms like `threading.Lock` or `Event` to manage state changes.
>
> **@bgoelTT (Author, Feb 3, 2025):** Enforced mutual exclusion of the `ready` flag in 835fc91, then fixed a bug in that implementation in 63d973c.
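For reference, the reviewer's `threading.Event` suggestion might look roughly like the sketch below; this is illustrative only and is not the implementation adopted in this PR (which uses the `Lock`-guarded flag that follows).

```python
# Illustrative alternative using threading.Event instead of a Lock-guarded
# boolean; Event.set() and Event.is_set() are atomic, so no lock is needed.
import threading
from http import HTTPStatus

from flask import abort

ready_event = threading.Event()


def mark_ready():
    # called once warmup sees the sample prompt's status == "done"
    ready_event.set()


def require_ready():
    # request handlers call this instead of reading a shared boolean
    if not ready_event.is_set():
        abort(HTTPStatus.SERVICE_UNAVAILABLE, description="Server is not ready yet")
```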

# lock for guaranteeing mutual exclusion
ready_lock = threading.Lock()

# internal json prompt file
json_file_path = (
    "models/demos/wormhole/stable_diffusion/demo/web_demo/input_prompts.json"
)


@app.route("/")
def hello_world():
return "Hello, World!"


def read_json_file(file_path):
if not os.path.isfile(file_path):
raise FileNotFoundError(f"{file_path} is not a file")
with open(file_path, "r") as f:
return json.load(f)


def write_json_file(file_path, data):
with open(file_path, "w") as f:
json.dump(data, f, indent=4)


def submit_prompt(prompt_file, prompt):
if not os.path.isfile(prompt_file):
write_json_file(prompt_file, {"prompts": []})

prompts_data = read_json_file(prompt_file)
prompts_data["prompts"].append({"prompt": prompt, "status": "not generated"})

write_json_file(prompt_file, prompts_data)


def warmup():
    sample_prompt = "Unicorn on a banana"
    # submit sample prompt to perform tracing and server warmup
    submit_prompt(json_file_path, sample_prompt)
    global ready
    while True:
        with ready_lock:
            if ready:
                break
        prompts_data = read_json_file(json_file_path)
        # sample prompt should be first prompt
        sample_prompt_data = prompts_data["prompts"][0]
        if sample_prompt_data["prompt"] == sample_prompt:
            # TODO: remove this and replace with status check == "done"
            # to flip ready flag
            if sample_prompt_data["status"] == "done":
                with ready_lock:
                    ready = True
                logger.info("Warmup complete")
        time.sleep(3)


# start warmup routine in background while server starts
warmup_thread = threading.Thread(target=warmup, name="warmup")
warmup_thread.start()


@app.route("/health")
def health_check():
with ready_lock:
if not ready:
abort(HTTPStatus.SERVICE_UNAVAILABLE, description="Server is not ready yet")
return jsonify({"message": "OK\n"}), 200


@app.route("/submit", methods=["POST"])
@api_key_required
def submit():
global ready
with ready_lock:
if not ready:
abort(HTTPStatus.SERVICE_UNAVAILABLE, description="Server is not ready yet")
data = request.get_json()
prompt = data.get("prompt")
logger.info(f"Prompt: {prompt}")

submit_prompt(json_file_path, prompt)

return jsonify({"message": "Prompt received and added to queue."})


@app.route("/update_status", methods=["POST"])
def update_status():
data = request.get_json()
prompt = data.get("prompt")

prompts_data = read_json_file(json_file_path)

for p in prompts_data["prompts"]:
if p["prompt"] == prompt:
p["status"] = "generated"
break

write_json_file(json_file_path, prompts_data)

return jsonify({"message": "Prompt status updated to generated."})


@app.route("/get_image", methods=["GET"])
def get_image():
image_name = "interactive_512x512_ttnn.png"
directory = os.getcwd() # Get the current working directory
return send_from_directory(directory, image_name)


@app.route("/image_exists", methods=["GET"])
def image_exists():
image_path = "interactive_512x512_ttnn.png"
if os.path.isfile(image_path):
return jsonify({"exists": True}), 200
else:
return jsonify({"exists": False}), 200


@app.route("/clean_up", methods=["POST"])
def clean_up():
prompts_data = read_json_file(json_file_path)

prompts_data["prompts"] = [
p for p in prompts_data["prompts"] if p["status"] != "done"
]

write_json_file(json_file_path, prompts_data)

return jsonify({"message": "Cleaned up done prompts."})


@app.route("/get_latest_time", methods=["GET"])
def get_latest_time():
if not os.path.isfile(json_file_path):
return jsonify({"message": "No prompts found"}), 404

prompts_data = read_json_file(json_file_path)

# Filter prompts that have a total_acc time available
completed_prompts = [p for p in prompts_data["prompts"] if "total_acc" in p]

if not completed_prompts:
return jsonify({"message": "No completed prompts with time available"}), 404

# Get the latest prompt with total_acc
latest_prompt = completed_prompts[-1] # Assuming prompts are in chronological order

return (
jsonify(
{
"prompt": latest_prompt["prompt"],
"total_acc": latest_prompt["total_acc"],
"batch_size": latest_prompt["batch_size"],
"steps": latest_prompt["steps"],
}
),
200,
)


def cleanup():
    if os.path.isfile(
        "models/demos/wormhole/stable_diffusion/demo/web_demo/input_prompts.json"
    ):
        os.remove(
            "models/demos/wormhole/stable_diffusion/demo/web_demo/input_prompts.json"
        )
        logger.info("Deleted json")

    if os.path.isfile("interactive_512x512_ttnn.png"):
        os.remove("interactive_512x512_ttnn.png")
        logger.info("Deleted image")

    signal_handler(None, None)
    logger.info("Cleanup complete")


atexit.register(cleanup)


def create_server():
    return app
15 changes: 15 additions & 0 deletions tt-metal-stable-diffusion-1.4/server/gunicorn.conf.py
@@ -0,0 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
#
# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC


workers = 1
# use 0.0.0.0 for externally accessible
bind = "0.0.0.0:7000"
reload = False
worker_class = "gthread"
threads = 16
timeout = 0

# server factory
wsgi_app = "server.flaskserver:create_server()"