Stable Diffusion 1.4 #70

Open · wants to merge 15 commits into main · Changes from 8 commits
6 changes: 6 additions & 0 deletions tt-metal-stable-diffusion-1.4/.env.default
@@ -0,0 +1,6 @@
TT_METAL_DOCKERFILE_VERSION=v0.53.0-rc34
TT_METAL_COMMIT_SHA_OR_TAG=4da4a5e79a13ece7ff5096c30cef79cb0c504f0e
Contributor:
Is there a reference to this commit in tt-metal anywhere? Curious how you selected this one.

@bgoelTT (Contributor Author), Feb 1, 2025:
I copied this from the YOLOv4 server, which required a specific commit because the tt-metal YOLOv4 improvements were reverted after failing CI.

Should I just use the latest release? That would be release v0.55.0?

@bgoelTT (Contributor Author):
Updated to use release v0.55.0 in 4582863

TT_METAL_COMMIT_DOCKER_TAG=4da4a5e79a13 # 12-character version of TT_METAL_COMMIT_SHA_OR_TAG
IMAGE_VERSION=v0.0.1
# These are secrets and must be stored securely for production environments
JWT_SECRET=testing
51 changes: 51 additions & 0 deletions tt-metal-stable-diffusion-1.4/README.md
@@ -0,0 +1,51 @@
# TT Metalium Stable Diffusion 1.4 Inference API

This implementation supports Stable Diffusion 1.4 execution on Wormhole n150 (n300 is currently broken).


## Table of Contents
- [Run server](#run-server)
- [JWT_TOKEN Authorization](#jwt_token-authorization)
- [Development](#development)
- [Tests](#tests)


## Run server
To run the SD1.4 inference server, execute the following from the `tt-inference-server` project root:
```bash
cd tt-inference-server
docker compose --env-file tt-metal-stable-diffusion-1.4/.env.default -f tt-metal-stable-diffusion-1.4/docker-compose.yaml up --build
```

This starts the default Docker container with its entrypoint set to run the gunicorn server. The [Development](#development) section describes how to override the container's default command with an interactive shell via `bash`.


### JWT_TOKEN Authorization

To authenticate requests, use the `Authorization` header. The JWT token can be computed using the script `jwt_util.py`. For example:
```bash
cd tt-inference-server/tt-metal-yolov4/server
export JWT_SECRET=<your-secure-secret>
export AUTHORIZATION="Bearer $(python scripts/jwt_util.py --secret ${JWT_SECRET?ERROR env var JWT_SECRET must be set} encode '{"team_id": "tenstorrent", "token_id":"debug-test"}')"
```
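`jwt_util.py` itself is not part of this diff, so the following is only a hedged sketch of what the command above produces: a standard HS256 JWT of the given payload (the server's `requirements.txt` includes `pyjwt` for this). Using just the standard library, the token structure looks like:

```python
import base64
import hashlib
import hmac
import json

def b64url(data: bytes) -> str:
    # JWT segments use unpadded URL-safe base64
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode()

def encode_hs256_jwt(payload: dict, secret: str) -> str:
    # header.payload signed with HMAC-SHA256, as in any HS256 JWT
    header = {"alg": "HS256", "typ": "JWT"}
    signing_input = b64url(json.dumps(header).encode()) + "." + b64url(json.dumps(payload).encode())
    signature = hmac.new(secret.encode(), signing_input.encode(), hashlib.sha256).digest()
    return signing_input + "." + b64url(signature)

token = encode_hs256_jwt({"team_id": "tenstorrent", "token_id": "debug-test"}, "testing")
# the token is then sent as: Authorization: Bearer <token>
```

The resulting string has three dot-separated segments (header, payload, signature); the server only needs the shared `JWT_SECRET` to verify it.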


## Development
To develop inside the container, start it with an interactive `bash` shell (overriding the default entrypoint):
```bash
docker compose --env-file tt-metal-stable-diffusion-1.4/.env.default -f tt-metal-stable-diffusion-1.4/docker-compose.yaml run --rm --build inference_server /bin/bash
```

Inside the container, run `cd ~/app/server` to navigate to the server implementation.


## Tests
Tests can be found in `tests/`. The tests have their own dependencies found in `requirements-test.txt`.

First, ensure the server is running (see [how to run the server](#run-server)). Then in a different shell with the base dev `venv` activated:
```bash
cd tt-metal-stable-diffusion-1.4
pip install -r requirements-test.txt
cd tests/
locust --config locust_config.conf
```
26 changes: 26 additions & 0 deletions tt-metal-stable-diffusion-1.4/docker-compose.yaml
@@ -0,0 +1,26 @@
services:
  inference_server:
    image: ghcr.io/tenstorrent/tt-inference-server/tt-metal-stable-diffusion-1.4-src-base:${IMAGE_VERSION}-tt-metal-${TT_METAL_COMMIT_DOCKER_TAG}
    build:
      context: ../
      dockerfile: tt-metal-stable-diffusion-1.4/stable-diffusion-1.4.src.Dockerfile
      args:
        TT_METAL_DOCKERFILE_VERSION: ${TT_METAL_DOCKERFILE_VERSION}
        TT_METAL_COMMIT_SHA_OR_TAG: ${TT_METAL_COMMIT_SHA_OR_TAG}
    container_name: sd_inference_server
    ports:
      - "7000:7000"
    devices:
      - "/dev/tenstorrent:/dev/tenstorrent"
    volumes:
      - "/dev/hugepages-1G/:/dev/hugepages-1G:rw"
    shm_size: "32G"
    cap_add:
      - ALL
    stdin_open: true
    tty: true
    # this is redundant, as docker compose automatically uses a .env file in the same directory,
    # but it explicitly demonstrates its usage
    env_file:
      - .env.default
    restart: "no"  # quoted so YAML does not parse it as the boolean false
3 changes: 3 additions & 0 deletions tt-metal-stable-diffusion-1.4/requirements-test.txt
@@ -0,0 +1,3 @@
pillow==10.3.0
locust==2.25.0
pytest==7.2.2
5 changes: 5 additions & 0 deletions tt-metal-stable-diffusion-1.4/requirements.txt
@@ -0,0 +1,5 @@
# inference server requirements
flask==3.0.2
gunicorn==21.2.0
requests==2.31.0
pyjwt==2.7.0
217 changes: 217 additions & 0 deletions tt-metal-stable-diffusion-1.4/server/flaskserver.py
@@ -0,0 +1,217 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

from flask import (
    abort,
    Flask,
    request,
    jsonify,
    send_from_directory,
)
import json
import os
import atexit
import time
import threading
from http import HTTPStatus
from utils.authentication import api_key_required

import subprocess
import signal
import sys

# script to run in background
script = "pytest models/demos/wormhole/stable_diffusion/demo/web_demo/sdserver.py"

# Start script using subprocess
process1 = subprocess.Popen(script, shell=True)


# Terminate the background demo process on SIGINT/SIGTERM
def signal_handler(sig, frame):
    print("Terminating processes...")
    process1.terminate()
    sys.exit(0)


signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)

app = Flask(__name__)

# var to indicate ready state
ready = False
Contributor:
Global variable `ready`: this flag is accessed and modified by multiple threads (warmup and request handlers). This may lead to race conditions. Consider using thread-safe mechanisms like threading.Lock or Event to manage state changes.

@bgoelTT (Contributor Author), Feb 3, 2025:
Enforced mutual exclusion of the ready flag in 835fc91, then fixed a bug in that implementation in 63d973c

# internal json prompt file
json_file_path = (
    "models/demos/wormhole/stable_diffusion/demo/web_demo/input_prompts.json"
)


@app.route("/")
def hello_world():
    return "Hello, World!"


def submit_prompt(prompt_file, prompt):
    if not os.path.isfile(prompt_file):
        with open(prompt_file, "w") as f:
            json.dump({"prompts": []}, f)

    with open(prompt_file, "r") as f:
        prompts_data = json.load(f)

    prompts_data["prompts"].append({"prompt": prompt, "status": "not generated"})

    with open(prompt_file, "w") as f:
        json.dump(prompts_data, f, indent=4)
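`submit_prompt` is a read-modify-write queue over a JSON file. Exercised in isolation against a temp file (the function body is restated so the demo is self-contained), it behaves like this:

```python
import json
import os
import tempfile

def submit_prompt(prompt_file, prompt):
    # same read-modify-write pattern as the server code
    if not os.path.isfile(prompt_file):
        with open(prompt_file, "w") as f:
            json.dump({"prompts": []}, f)
    with open(prompt_file, "r") as f:
        prompts_data = json.load(f)
    prompts_data["prompts"].append({"prompt": prompt, "status": "not generated"})
    with open(prompt_file, "w") as f:
        json.dump(prompts_data, f, indent=4)

path = os.path.join(tempfile.mkdtemp(), "input_prompts.json")
submit_prompt(path, "a red fox")
submit_prompt(path, "a blue whale")
with open(path) as f:
    queue = json.load(f)["prompts"]
# queue holds both prompts, each with status "not generated"
```

Note the whole file is rewritten per submission with no locking; concurrent requests could interleave reads and writes, which is the same class of race the reviewer flagged for the `ready` flag.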


def warmup():
    sample_prompt = "Unicorn on a banana"
    # submit sample prompt to perform tracing and server warmup
    submit_prompt(json_file_path, sample_prompt)
    global ready
    while not ready:
        with open(json_file_path, "r") as f:
            prompts_data = json.load(f)
        # sample prompt should be first prompt
        sample_prompt_data = prompts_data["prompts"][0]
        if sample_prompt_data["prompt"] == sample_prompt:
            # TODO: remove this and replace with status check == "done"
            # to flip ready flag
            if sample_prompt_data["status"] == "done":
                ready = True
        print(sample_prompt_data["status"])
        time.sleep(3)
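The warmup loop amounts to polling the prompts file until the sample prompt's status flips to "done". The same pattern with an explicit timeout, sketched self-contained (the helper name is illustrative):

```python
import json
import os
import tempfile
import time

def wait_until_done(path, prompt, timeout=2.0, interval=0.05):
    """Poll the prompts JSON until `prompt` (the first entry) is marked done.

    Returns True once done, False if the timeout elapses first.
    """
    deadline = time.time() + timeout
    while time.time() < deadline:
        with open(path) as f:
            data = json.load(f)
        entry = data["prompts"][0]  # warmup checks only the first entry
        if entry["prompt"] == prompt and entry["status"] == "done":
            return True
        time.sleep(interval)
    return False

path = os.path.join(tempfile.mkdtemp(), "input_prompts.json")
with open(path, "w") as f:
    json.dump({"prompts": [{"prompt": "Unicorn on a banana", "status": "done"}]}, f)
# wait_until_done(path, "Unicorn on a banana") succeeds on the first poll
```

A timeout like this would also prevent the real warmup thread from spinning forever if the backend never marks the prompt done.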


# start warmup routine in background while server starts
warmup_thread = threading.Thread(target=warmup, name="warmup")
warmup_thread.start()


@app.route("/health")
def health_check():
    if not ready:
        abort(HTTPStatus.SERVICE_UNAVAILABLE, description="Server is not ready yet")
    return jsonify({"message": "OK\n"}), 200


@app.route("/submit", methods=["POST"])
@api_key_required
def submit():
    global ready
    if not ready:
        abort(HTTPStatus.SERVICE_UNAVAILABLE, description="Server is not ready yet")
    data = request.get_json()
    prompt = data.get("prompt")
    print(prompt)

    submit_prompt(json_file_path, prompt)

    return jsonify({"message": "Prompt received and added to queue."})


@app.route("/update_status", methods=["POST"])
def update_status():
    data = request.get_json()
    prompt = data.get("prompt")

    with open(json_file_path, "r") as f:
        prompts_data = json.load(f)

    for p in prompts_data["prompts"]:
        if p["prompt"] == prompt:
            p["status"] = "generated"
            break

    with open(json_file_path, "w") as f:
        json.dump(prompts_data, f, indent=4)

    return jsonify({"message": "Prompt status updated to generated."})
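The status transition in `/update_status` is a linear scan that flips only the first matching entry. Isolated, the core logic is:

```python
def mark_generated(prompts_data, prompt):
    # flip the first matching entry's status, mirroring /update_status
    for p in prompts_data["prompts"]:
        if p["prompt"] == prompt:
            p["status"] = "generated"
            break
    return prompts_data

data = {"prompts": [
    {"prompt": "a", "status": "not generated"},
    {"prompt": "b", "status": "not generated"},
]}
mark_generated(data, "b")
# entry "b" is now "generated"; entry "a" is untouched
```

Because matching is by prompt text, duplicate prompts would only ever have their first occurrence updated.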


@app.route("/get_image", methods=["GET"])
def get_image():
    image_name = "interactive_512x512_ttnn.png"
    directory = os.getcwd()  # Get the current working directory
    return send_from_directory(directory, image_name)


@app.route("/image_exists", methods=["GET"])
def image_exists():
    image_path = "interactive_512x512_ttnn.png"
    if os.path.isfile(image_path):
        return jsonify({"exists": True}), 200
    else:
        return jsonify({"exists": False}), 200


@app.route("/clean_up", methods=["POST"])
def clean_up():
    with open(json_file_path, "r") as f:
        prompts_data = json.load(f)

    prompts_data["prompts"] = [
        p for p in prompts_data["prompts"] if p["status"] != "done"
    ]

    with open(json_file_path, "w") as f:
        json.dump(prompts_data, f, indent=4)

    return jsonify({"message": "Cleaned up done prompts."})

Contributor:
Repeated code blocks: the code repeatedly reads and writes json_file_path. This can be refactored into utility functions for cleaner logic:

    def read_json_file(file_path):
        if not os.path.isfile(file_path):
            return {"prompts": []}
        with open(file_path, "r") as f:
            return json.load(f)

    def write_json_file(file_path, data):
        with open(file_path, "w") as f:
            json.dump(data, f, indent=4)

@bgoelTT (Contributor Author):
Refactored in a75d0e6
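`/clean_up` retains only unfinished prompts. That filter, isolated:

```python
def drop_done(prompts):
    # keep entries whose status is anything other than "done"
    return [p for p in prompts if p["status"] != "done"]

queue = [
    {"prompt": "a", "status": "done"},
    {"prompt": "b", "status": "not generated"},
    {"prompt": "c", "status": "generated"},
]
remaining = drop_done(queue)
# remaining keeps "b" and "c"; the "done" entry is dropped
```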


@app.route("/get_latest_time", methods=["GET"])
def get_latest_time():
    if not os.path.isfile(json_file_path):
        return jsonify({"message": "No prompts found"}), 404

    with open(json_file_path, "r") as f:
        prompts_data = json.load(f)

    # Filter prompts that have a total_acc time available
    completed_prompts = [p for p in prompts_data["prompts"] if "total_acc" in p]

    if not completed_prompts:
        return jsonify({"message": "No completed prompts with time available"}), 404

    # Get the latest prompt with total_acc
    latest_prompt = completed_prompts[-1]  # Assuming prompts are in chronological order

    return (
        jsonify(
            {
                "prompt": latest_prompt["prompt"],
                "total_acc": latest_prompt["total_acc"],
                "batch_size": latest_prompt["batch_size"],
                "steps": latest_prompt["steps"],
            }
        ),
        200,
    )
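`/get_latest_time` picks the last entry carrying a `total_acc` timing. That selection, isolated (the helper name is illustrative):

```python
def latest_completed(prompts):
    """Return the newest prompt that has a total_acc timing, or None.

    Prompts are assumed to be in chronological order, as in the endpoint.
    """
    completed = [p for p in prompts if "total_acc" in p]
    return completed[-1] if completed else None

prompts = [
    {"prompt": "a", "total_acc": 12.3},
    {"prompt": "b"},
    {"prompt": "c", "total_acc": 11.9},
]
# latest_completed(prompts) returns the "c" entry; with no timed
# entries at all it returns None, which the endpoint maps to a 404
```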


def cleanup():
    if os.path.isfile(
        "models/demos/wormhole/stable_diffusion/demo/web_demo/input_prompts.json"
    ):
        os.remove(
            "models/demos/wormhole/stable_diffusion/demo/web_demo/input_prompts.json"
        )
        print("Deleted json")

    if os.path.isfile("interactive_512x512_ttnn.png"):
        os.remove("interactive_512x512_ttnn.png")
        print("Deleted image")

    signal_handler(None, None)

Contributor:
Use logging instead of print statements for better production diagnostics.

@bgoelTT (Contributor Author):
Used logging instead of print in f5650c5

atexit.register(cleanup)


def create_server():
    return app
15 changes: 15 additions & 0 deletions tt-metal-stable-diffusion-1.4/server/gunicorn.conf.py
@@ -0,0 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
#
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC


workers = 1
# use 0.0.0.0 for externally accessible
bind = f"0.0.0.0:{7000}"
reload = False
worker_class = "gthread"
threads = 16
timeout = 0

# server factory
wsgi_app = "server.flaskserver:create_server()"