Skip to content

Commit

Permalink
feat(python): Add threaded executor
Browse files Browse the repository at this point in the history
  • Loading branch information
manzt committed Jul 21, 2024
1 parent e10951d commit abc546a
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 123 deletions.
40 changes: 0 additions & 40 deletions python/Untitled.ipynb

This file was deleted.

6 changes: 4 additions & 2 deletions python/deno.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
]
},
"fmt": {
"useTabs": true
"useTabs": true,
"exclude": [".venv", "notebooks"]
},
"lint": {
"rules": {
"exclude": [
"prefer-const"
]
}
},
"exclude": [".venv", "notebooks"]
}
}
7 changes: 2 additions & 5 deletions python/notebooks/mandelbrot.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,7 @@
"# Initialize the store\n",
"store = MandlebrotStore(levels=50, tilesize=512, compressor=numcodecs.Blosc())\n",
"# Wrap in a cache so that tiles don't need to be computed as often\n",
"store = zarr.LRUStoreCache(store, max_size=1e9)\n",
"\n",
"# This store implements the 'multiscales' zarr specfiication which is recognized by vizarr\n",
"grp = zarr.open(store, mode=\"r\")"
"store = zarr.LRUStoreCache(store, max_size=1e9)"
]
},
{
Expand All @@ -182,7 +179,7 @@
"import vizarr\n",
"\n",
"viewer = vizarr.Viewer()\n",
"viewer.add_image(source=grp, name=\"mandelbrot\")\n",
"viewer.add_image(source=store, name=\"mandelbrot\")\n",
"viewer"
]
}
Expand Down
23 changes: 20 additions & 3 deletions python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,23 @@ dependencies = ["anywidget", "zarr"]
[project.optional-dependencies]
dev = ["watchfiles", "jupyterlab"]

# automatically add the dev feature to the default env (e.g., hatch shell)
[tool.hatch.envs.default]
features = ["dev"]
[tool.ruff.lint]
pydocstyle = { convention = "numpy" }
select = [
"E", # style errors
"W", # style warnings
"F", # flakes
"D", # pydocstyle
"D417", # Missing argument descriptions in Docstrings
"I", # isort
"UP", # pyupgrade
"C4", # flake8-comprehensions
"B", # flake8-bugbear
"A001", # flake8-builtins
"RUF", # ruff-specific rules
"TCH", # flake8-type-checking
"TID", # flake8-tidy-imports
]

[tool.ruff.lint.per-file-ignores]
"tests/*.py" = ["D", "S"]
84 changes: 43 additions & 41 deletions python/src/vizarr/_widget.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import * as vizarr from "https://hms-dbmi.github.io/vizarr/index.js";
import debounce from "https://esm.sh/just-debounce-it@3";
import debounce from "https://esm.sh/just-debounce-it@3.2.0";

/**
* @template T
Expand All @@ -9,24 +9,24 @@ import debounce from "https://esm.sh/just-debounce-it@3";
* @returns {Promise<{ data: T, buffers: DataView[] }>}
*/
function send(model, payload, { timeout = 3000 } = {}) {
let uuid = globalThis.crypto.randomUUID();
let id = Math.random().toString(36).substring(7);
return new Promise((resolve, reject) => {
let timer = setTimeout(() => {
reject(new Error(`Promise timed out after ${timeout} ms`));
model.off("msg:custom", handler);
}, timeout);
/**
* @param {{ uuid: string, payload: T }} msg
* @param {{ id: string, payload: T }} msg
* @param {DataView[]} buffers
*/
function handler(msg, buffers) {
if (!(msg.uuid === uuid)) return;
if (!(msg.id === id)) return;
clearTimeout(timer);
resolve({ data: msg.payload, buffers });
model.off("msg:custom", handler);
}
model.on("msg:custom", handler);
model.send({ payload, uuid });
model.send({ payload, id });
});
}

Expand Down Expand Up @@ -71,41 +71,43 @@ function get_source(model, source) {
* @property {[x: number, y: number]} target
*/

/** @type {import("npm:@anywidget/types").Render<Model>} */
export async function render({ model, el }) {
let div = document.createElement("div");
{
div.style.height = model.get("height");
div.style.backgroundColor = "black";
model.on("change:height", () => {
export default {
/** @type {import("npm:@anywidget/types").Render<Model>} */
async render({ model, el }) {
let div = document.createElement("div");
{
div.style.height = model.get("height");
});
}
let viewer = await vizarr.createViewer(div);
{
model.on("change:view_state", () => {
viewer.setViewState(model.get("view_state"));
});
viewer.on(
"viewStateChange",
debounce((/** @type {ViewState} */ update) => {
model.set("view_state", update);
model.save_changes();
}, 200),
);
}
{
// sources are append-only now
for (const config of model.get("_configs")) {
const source = get_source(model, config.source);
viewer.addImage({ ...config, source });
div.style.backgroundColor = "black";
model.on("change:height", () => {
div.style.height = model.get("height");
});
}
model.on("change:_configs", () => {
const last = model.get("_configs").at(-1);
if (!last) return;
const source = get_source(model, last.source);
viewer.addImage({ ...last, source });
});
}
el.appendChild(div);
}
let viewer = await vizarr.createViewer(div);
{
model.on("change:view_state", () => {
viewer.setViewState(model.get("view_state"));
});
viewer.on(
"viewStateChange",
debounce((/** @type {ViewState} */ update) => {
model.set("view_state", update);
model.save_changes();
}, 200),
);
}
{
// sources are append-only now
for (const config of model.get("_configs")) {
const source = get_source(model, config.source);
viewer.addImage({ ...config, source });
}
model.on("change:_configs", () => {
const last = model.get("_configs").at(-1);
if (!last) return;
const source = get_source(model, last.source);
viewer.addImage({ ...last, source });
});
}
el.appendChild(div);
},
};
98 changes: 66 additions & 32 deletions python/src/vizarr/_widget.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,81 @@
from __future__ import annotations

import concurrent.futures
import os
import pathlib
from typing import TYPE_CHECKING, TypeGuard

import anywidget
import traitlets
import pathlib

import zarr
import numpy as np
if TYPE_CHECKING:
import numpy as np
import zarr
import zarr.storage

__all__ = ["Viewer"]

THREAD_EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count())


def is_zarr_node(obj: object) -> TypeGuard[zarr.Array | zarr.Group]:
return hasattr(obj, "store") and hasattr(obj, "_key_prefix")


def is_readable_store(obj: object) -> TypeGuard[zarr.storage.BaseStore]:
return hasattr(obj, "__getitem__") and hasattr(obj, "__contains__")


def has_array_protocol(obj: object) -> bool:
return hasattr(obj, "__array__") or hasattr(obj, "__array_interface__")


def handle_custom_message(widget: Viewer, msg: dict, _buffers: list[bytes]):
store, key_prefix = widget._store_paths[msg["payload"]["source_id"]]
key = key_prefix + msg["payload"]["key"].lstrip("/")

if msg["payload"]["type"] == "has":
widget.send({"id": msg["id"], "payload": key in store})
return

if msg["payload"]["type"] == "get":

def target():
try:
buffers = [store[key]]
except KeyError:
buffers = []
widget.send(
{"id": msg["id"], "payload": {"success": len(buffers) == 1}},
buffers,
)

THREAD_EXECUTOR.submit(target)
return

def _store_keyprefix(obj):
# Just grab the store and key_prefix from zarr.Array and zarr.Group objects
if isinstance(obj, (zarr.Array, zarr.Group)):
raise ValueError(f"Unknown message type: {msg['payload']['type']}")


def get_store_keyprefix(obj: zarr.Array | zarr.Group | np.ndarray | dict):
if is_zarr_node(obj):
# Just grab the store and key_prefix from zarr.Array and zarr.Group objects
return obj.store, obj._key_prefix

if isinstance(obj, np.ndarray):
if has_array_protocol(obj):
# Create an in-memory store, and write array as as single chunk
store = {}
import numpy as np
import zarr
import zarr.storage

store = zarr.storage.MemoryStore()
data = np.asarray(obj)
arr = zarr.create(
store=store, shape=obj.shape, chunks=obj.shape, dtype=obj.dtype
store=store, shape=data.shape, chunks=data.shape, dtype=data.dtype
)
arr[:] = obj
return store, ""

if hasattr(obj, "__getitem__") and hasattr(obj, "__contains__"):
if is_readable_store(obj):
return obj, ""

raise TypeError("Cannot normalize store path")
Expand All @@ -37,31 +90,12 @@ class Viewer(anywidget.AnyWidget):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._store_paths = []
self.on_msg(self._handle_custom_msg)

def _handle_custom_msg(self, msg, buffers):
store, key_prefix = self._store_paths[msg["payload"]["source_id"]]
key = key_prefix + msg["payload"]["key"].lstrip("/")

if msg["payload"]["type"] == "has":
self.send({"uuid": msg["uuid"], "payload": key in store})
return

if msg["payload"]["type"] == "get":
try:
buffers = [store[key]]
except KeyError:
buffers = []
self.send(
{"uuid": msg["uuid"], "payload": {"success": len(buffers) == 1}},
buffers,
)
return
self.on_msg(handle_custom_message)

def add_image(self, source, **config):
if not isinstance(source, str):
store, key_prefix = _store_keyprefix(source)
store, key_prefix = get_store_keyprefix(source)
source = {"id": len(self._store_paths)}
self._store_paths.append((store, key_prefix))
config["source"] = source
self._configs = self._configs + [config]
self._configs = [*self._configs, config]

0 comments on commit abc546a

Please sign in to comment.