Skip to content

Commit

Permalink
Add collimator generator.
Browse files Browse the repository at this point in the history
Signed-off-by: James Goppert <[email protected]>
  • Loading branch information
jgoppert committed Nov 11, 2024
1 parent fd967bd commit 7c86d22
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 2 deletions.
101 changes: 101 additions & 0 deletions src/generators/collimator.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
use crate::ast;

use tera::{Context, Tera};

static COLLIMATOR_TEMPLATE: &str = r#"
import jax.numpy as jnp
from collimator.framework import LeafSystem
{% for class in def.classes %}
class {{ class.name}} (LeafSystem):
def __init__(
self,
*args,
x0=[1.0, 0.0],
m=1.0,
g=9.81,
L=1.0,
b=0.0,
full_state_output=True,
**kwargs,
):
super().__init__(*args, **kwargs)
# Declare parameters
self.declare_dynamic_parameter("m", m)
self.declare_dynamic_parameter("g", g)
self.declare_dynamic_parameter("L", L)
self.declare_dynamic_parameter("b", b)
# Declare continuous state; it's default value indicates that its size is 2
# the attached ode callback should return the derivative of the state
self.declare_continuous_state(default_value=jnp.array(x0), ode=self.ode)
{% for comp in class.components -%}
self.{{ comp.name }} = ca.SX.sym('{{ comp.name }}');
{% endfor -%}
# Declare input port for the torque
self.declare_input_port(name="u")
if full_state_output:
# Declare output port for the full state
self.declare_continuous_state_output(name="x")
else:
def _observation_callback(time, state, *inputs, **parameters):
return state.continuous_state[0] # output only theta
self.declare_output_port(
_observation_callback,
name="y",
requires_inputs=False,
)
def ode(self, time, state, *inputs, **parameters):
# Get theta and omega from the continuous part of LeafSystem state
theta, omega = state.continuous_state
# Get parameters
m = parameters["m"]
g = parameters["g"]
L = parameters["L"]
b = parameters["b"]
# Get input
tau = inputs[0]
# Reshape to scalar if input was an array
tau = jnp.reshape(tau, ())
# Compute the time derivative of the state (ODE RHS)
dot_theta = omega
mLsq = m * L * L
dot_omega = -(g / L) * jnp.sin(theta) - b * omega / mLsq + tau / mLsq
# Return the derivative of the state
return jnp.array([dot_theta, dot_omega]
{% endfor %}
"#;

pub fn generate(def: &ast::StoredDefinition) {
let mut tera = Tera::default();
tera.add_raw_template("template", COLLIMATOR_TEMPLATE)
.unwrap();
let mut context = Context::new();
context.insert("def", def);
println!("{}", tera.render("template", &context).unwrap());
}

#[cfg(test)]
mod tests {
use super::*;
use crate::generators::parse_file;

#[test]
fn test_generate_casadi_sx() {
let def = parse_file("src/model.mo");
generate(&def);
}
}
1 change: 1 addition & 0 deletions src/generators/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pub mod casadi_mx;
pub mod casadi_sx;
pub mod collimator;
pub mod json;
pub mod sympy;

Expand Down
2 changes: 2 additions & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ enum Generator {
Json,
CasadiMx,
CasadiSx,
Collimator,
}

#[derive(Parser, Debug)]
Expand Down Expand Up @@ -49,5 +50,6 @@ fn main() {
Generator::Sympy => generators::sympy::generate(&def),
Generator::CasadiMx => generators::casadi_mx::generate(&def),
Generator::CasadiSx => generators::casadi_sx::generate(&def),
Generator::Collimator => generators::collimator::generate(&def),
}
}
2 changes: 0 additions & 2 deletions src/model.mo
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,4 @@ model Integrator
equation
der(x) = 1.0;
der(y) = x;
algorithm
x := 1;
end Integrator;

0 comments on commit 7c86d22

Please sign in to comment.