diff --git a/wit/wasi-nn.wit b/wit/wasi-nn.wit
index 1749c14..edc3d6b 100644
--- a/wit/wasi-nn.wit
+++ b/wit/wasi-nn.wit
@@ -15,8 +15,20 @@ world ml {
     import errors;
 }
 
+/// Inference is performed on a specific `device`.
+interface device {
+    /// Define where tensors reside and graphs execute.
+    enum location {
+        cpu,
+        gpu,
+        tpu
+    }
+}
+
 /// All inputs and outputs to an ML inference are represented as `tensor`s.
 interface tensor {
+    use device.{location};
+
     /// The dimensions of a tensor.
     ///
     /// The array length matches the tensor rank and each element in the array describes the size of
@@ -44,8 +56,8 @@ interface tensor {
     type tensor-data = list<u8>;
 
     resource tensor {
-        constructor(dimensions: tensor-dimensions, ty: tensor-type, data: tensor-data,
-            location: option<execution-target>);
+        /// Construct a tensor that lives on the host CPU.
+        constructor(dimensions: tensor-dimensions, ty: tensor-type, data: tensor-data);
 
         // Describe the size of the tensor (e.g., 2x2x2x2 -> [2, 2, 2, 2]). To represent a tensor
         // containing a single value, use `[1]` for the tensor dimensions.
@@ -55,7 +67,7 @@ interface tensor {
         ty: func() -> tensor-type;
 
         // Describe where the tensor is currently located (e.g., `cpu`, `gpu`, `tpu`).
-        location: func() -> execution-target;
+        location: func() -> location;
 
         // Return the tensor data. If the tensor is located on a device other than the CPU, this
         // operation may result in an expensive data copy operation.
@@ -74,8 +86,9 @@ interface tensor {
 /// framework (e.g., TensorFlow):
 interface graph {
     use errors.{error};
-    use tensor.{tensor};
+    use device.{location};
     use inference.{graph-execution-context};
+    use tensor.{tensor};
 
     /// An execution graph for performing inference (i.e., a model).
     resource graph {
@@ -93,21 +106,15 @@ interface graph {
         autodetect,
     }
 
-    /// Define where the graph should be executed.
-    enum execution-target {
-        cpu,
-        gpu,
-        tpu
-    }
-
     /// The graph initialization data.
    ///
     /// This gets bundled up into an array of buffers because implementing backends may encode their
     /// graph IR in parts (e.g., OpenVINO stores its IR and weights separately).
     type graph-builder = list<u8>;
 
-    /// Load a `graph` from an opaque sequence of bytes to use for inference.
-    load: func(builder: list<graph-builder>, encoding: graph-encoding, target: execution-target) -> result<graph, error>;
+    /// Load a `graph` from an opaque sequence of bytes to use for inference on the specified device
+    /// `location`.
+    load: func(builder: list<graph-builder>, encoding: graph-encoding, location: location) -> result<graph, error>;
 
     /// Load a `graph` by name.
     ///
@@ -128,6 +135,11 @@ interface inference {
     /// TODO: this may no longer be necessary in WIT
     /// (https://github.com/WebAssembly/wasi-nn/issues/43)
     resource graph-execution-context {
+        /// Load a tensor using the graph context. Unlike the `tensor` constructor, this function
+        /// will co-locate the tensor data on a specific device using the graph's underlying
+        /// backend; this may avoid some copies, improving performance.
+        load-tensor: func(dimensions: tensor-dimensions, ty: tensor-type, data: tensor-data) -> result<tensor, error>;
+
         /// Define the inputs to use for inference.
         set-input: func(name: string, tensor: tensor) -> result<_, error>;
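
For reviewers, here is a rough sketch of what the new guest-side flow could look like, assuming `wit-bindgen`-generated Rust bindings. The module paths, binding names, model encoding, and tensor shape below are illustrative assumptions, not part of this change; `compute` and `get-output` are existing wasi-nn functions untouched by this diff.

```rust
// Hypothetical guest code against wit-bindgen output for this WIT; all module
// paths and generated names here are illustrative assumptions.
use wasi_nn::{
    device::Location,
    errors::Error,
    graph::{self, GraphEncoding},
    tensor::TensorType,
};

fn infer(model: Vec<u8>, input: Vec<u8>) -> Result<Vec<u8>, Error> {
    // Place the graph on the GPU; `location` replaces the old `execution-target`.
    let graph = graph::load(&[model], GraphEncoding::Onnx, Location::Gpu)?;
    let ctx = graph.init_execution_context()?;

    // `load-tensor` asks the backend to co-locate the tensor with the graph
    // (e.g., in GPU memory), potentially avoiding the host-to-device copy that
    // a CPU-resident tensor built with the plain constructor would incur.
    let tensor = ctx.load_tensor(&[1, 3, 224, 224], TensorType::Fp32, &input)?;
    ctx.set_input("input", tensor)?;
    ctx.compute()?;

    // Reading data back from a non-CPU tensor may itself be an expensive copy.
    let output = ctx.get_output("output")?;
    Ok(output.data())
}
```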