-
Notifications
You must be signed in to change notification settings - Fork 0
/
pyproject.toml
58 lines (52 loc) · 1.53 KB
/
pyproject.toml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
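# Manual alternative for installing JAX with TPU support via uv (the requirement is
# quoted so the shell does not treat ">" as a redirect; note that [project].dependencies
# declares jax[tpu]>=0.4.34, a higher floor than the one used in this command):
#   uv pip install "jax[tpu]>=0.4.31" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html --prerelease=allow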
# Core package metadata (PEP 621).
[project]
name = "fae"
version = "0.1.0"
description = "Flux SAE"
readme = "README.md"
requires-python = ">=3.11"
# Runtime requirements as PEP 508 strings, one per line, kept in alphabetical order.
dependencies = [
    "accelerate>=1.2.1",
    "datasets>=2.19.1",
    "diffusers>=0.31.0",
    "einops>=0.8.0",
    "equinox>=0.11.10",
    "fastapi>=0.115.4",
    "fh-plotly>=0.1.1",
    "flax>=0.9.0",
    "jax-smi>=1.0.4",
    "jax[tpu]>=0.4.34",
    # No version constraint here; the origin is given in [tool.uv.sources].
    "jaxonnxruntime",
    # No version constraint here; the origin is given in [tool.uv.sources].
    "libtpu-nightly",
    "loguru>=0.7.2",
    "matplotlib>=3.9.2",
    "more-itertools>=10.5.0",
    "numba>=0.60.0",
    "onnx>=1.17.0",
    "orbax>=0.1.9",
    # Also has a git source in [tool.uv.sources].
    "oryx>=0.2.7",
    "plotly-express>=0.4.1",
    "python-fasthtml>=0.10.0",
    "qax>=0.4.1",
    "sentencepiece>=0.2.0",
    # Exact "+cpu" builds; the extra index in [tool.uv] points at the PyTorch CPU wheel index.
    "torch==2.4.1+cpu",
    "torchvision==0.19.1+cpu",
    "transformers>=4.44.2",
    "uvicorn>=0.32.0",
    "wandb>=0.18.7",
]
# Build backend declaration: the project is built with Hatchling.
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
# Package directories Hatch includes when building the wheel.
# NOTE(review): the directory is "src/vaov" while [project].name is "fae" —
# confirm the differing names are intentional.
[tool.hatch.build.targets.wheel]
packages = ["src/vaov"]
# uv resolver settings.
[tool.uv]
# Additional package index: the PyTorch CPU wheel index. The "+cpu" pins for
# torch and torchvision in [project].dependencies rely on it.
extra-index-url = ["https://download.pytorch.org/whl/cpu"]
# "first-index": each package is taken from the first index that lists it,
# rather than comparing candidate versions across all indexes.
index-strategy = "first-index"
# Per-package overrides telling uv where to fetch a dependency from.
[tool.uv.sources]
# Direct wheel URL; the filename pins the 0.1.dev20241002 nightly build.
libtpu-nightly = { url = "https://storage.googleapis.com/libtpu-nightly-releases/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20241002+nightly-py3-none-any.whl" }
# Taken from the local "jort" directory (a relative path).
jaxonnxruntime = { path = "jort" }
# Git source pinned to an exact commit via `rev`.
oryx = { git = "https://github.com/jax-ml/oryx", rev = "339bc9215962dc998637d53ca8f19a3cea100ec0" }
# Git source with no rev/tag/branch, so no commit is pinned in this file.
# NOTE(review): "mishax" does not appear in [project].dependencies in this file —
# confirm whether the dependency entry is missing or this source is unused.
mishax = { git = "https://github.com/google-deepmind/mishax" }
# Ruff lint configuration.
[tool.ruff.lint]
# F722 is the "syntax error in forward annotation" rule.
# NOTE(review): presumably disabled because of jaxtyping-style string annotations
# (e.g. Float[Array, "batch dim"]) — confirm against the code base.
ignore = ["F722"]