From d87f32320bd720843f95f7d1fbd7f005b1606763 Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Tue, 9 Apr 2024 23:49:44 -0400 Subject: [PATCH] Define and export a `sum` function --- crates/web/src/lib.rs | 13 +++++++++++++ packages/core/src/impl.ts | 13 +++++++++++++ packages/core/src/index.test.ts | 19 +++---------------- packages/core/src/index.ts | 1 + 4 files changed, 30 insertions(+), 16 deletions(-) diff --git a/crates/web/src/lib.rs b/crates/web/src/lib.rs index ccbd532..0c07101 100644 --- a/crates/web/src/lib.rs +++ b/crates/web/src/lib.rs @@ -1586,6 +1586,19 @@ impl Block { self.instr(f, id::ty(t), expr) } + /// Return the variable ID for a new instruction accumulating `addend` into `accum`. + /// + /// Assumes `accum` and `addend` are defined and in scope. + #[wasm_bindgen(js_name = "addTo")] + pub fn add_to(&mut self, f: &mut FuncBuilder, accum: usize, addend: usize) -> usize { + let t = id::ty(f.ty_unit()); + let expr = rose::Expr::Add { + accum: id::var(accum), + addend: id::var(addend), + }; + self.instr(f, t, expr) + } + /// Return the variable ID for a new instruction resolving the given accumulator `var`. /// /// Assumes `var` is defined and in scope, and that `t` is the inner type of the reference type diff --git a/packages/core/src/impl.ts b/packages/core/src/impl.ts index 7d1a4ce..51582b4 100644 --- a/packages/core/src/impl.ts +++ b/packages/core/src/impl.ts @@ -1160,6 +1160,19 @@ export const vec = ( return idVal(ctx, t, id) as Vec>; }; +/** Return the sum after computing each number via `f`. */ +export const sum = (index: I, f: (i: Symbolic) => Real): Real => { + const ctx = getCtx(); + const reals = ctx.func.tyF64(); + const acc = ctx.block.accum(ctx.func, ctx.func.tyRef(reals), realId(ctx, 0)); + vec(index, Null, (i) => { + const x = realId(ctx, f(i)); + const t = ctx.func.tyUnit(); + return idVal(ctx, t, ctx.block.addTo(ctx.func, acc, x)) as Null; + }); + return idVal(ctx, reals, ctx.block.resolve(ctx.func, reals, acc)) as Real; +}; + /** Return the variable ID for the abstract number or tangent `x`. */ const numId = (ctx: Context, x: Real | Tan): number => { if (typeof x === "object") return (x as any)[variable]; diff --git a/packages/core/src/index.test.ts b/packages/core/src/index.test.ts index 4a60f35..b76f34e 100644 --- a/packages/core/src/index.test.ts +++ b/packages/core/src/index.test.ts @@ -35,6 +35,7 @@ import { sqrt, struct, sub, + sum, trunc, vec, vjp, @@ -233,12 +234,7 @@ describe("valid", () => { test("dot product", () => { const R3 = Vec(3, Real); - const dot = fn([R3, R3], Real, (u, v) => { - const x = mul(u[0], v[0]); - const y = mul(u[1], v[1]); - const z = mul(u[2], v[2]); - return add(add(x, y), z); - }); + const dot = fn([R3, R3], Real, (u, v) => sum(3, (i) => mul(u[i], v[i]))); const f = interp(dot); expect(f([1, 3, -5], [4, -2, -1])).toBe(3); }); @@ -280,16 +276,7 @@ describe("valid", () => { const Rn = Vec(n, Real); - const dot = fn([Rn, Rn], Real, (u, v) => { - const w = vec(n, Real, (i) => mul(u[i], v[i])); - let s = w[0]; - s = add(s, w[1]); - s = add(s, w[2]); - s = add(s, w[3]); - s = add(s, w[4]); - s = add(s, w[5]); - return s; - }); + const dot = fn([Rn, Rn], Real, (u, v) => sum(n, (i) => mul(u[i], v[i]))); const m = 5; const p = 7; diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index ce0476c..9314eea 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -50,6 +50,7 @@ export { sqrt, struct, sub, + sum, trunc, vec, vjp,