-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: James Goppert <[email protected]>
- Loading branch information
Showing
4 changed files
with
104 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
use crate::ast; | ||
|
||
use tera::{Context, Tera}; | ||
|
||
static COLLIMATOR_TEMPLATE: &str = r#" | ||
import jax.numpy as jnp | ||
from collimator.framework import LeafSystem | ||
{% for class in def.classes %} | ||
class {{ class.name}} (LeafSystem): | ||
def __init__( | ||
self, | ||
*args, | ||
x0=[1.0, 0.0], | ||
m=1.0, | ||
g=9.81, | ||
L=1.0, | ||
b=0.0, | ||
full_state_output=True, | ||
**kwargs, | ||
): | ||
super().__init__(*args, **kwargs) | ||
# Declare parameters | ||
self.declare_dynamic_parameter("m", m) | ||
self.declare_dynamic_parameter("g", g) | ||
self.declare_dynamic_parameter("L", L) | ||
self.declare_dynamic_parameter("b", b) | ||
# Declare continuous state; it's default value indicates that its size is 2 | ||
# the attached ode callback should return the derivative of the state | ||
self.declare_continuous_state(default_value=jnp.array(x0), ode=self.ode) | ||
{% for comp in class.components -%} | ||
self.{{ comp.name }} = ca.SX.sym('{{ comp.name }}'); | ||
{% endfor -%} | ||
# Declare input port for the torque | ||
self.declare_input_port(name="u") | ||
if full_state_output: | ||
# Declare output port for the full state | ||
self.declare_continuous_state_output(name="x") | ||
else: | ||
def _observation_callback(time, state, *inputs, **parameters): | ||
return state.continuous_state[0] # output only theta | ||
self.declare_output_port( | ||
_observation_callback, | ||
name="y", | ||
requires_inputs=False, | ||
) | ||
def ode(self, time, state, *inputs, **parameters): | ||
# Get theta and omega from the continuous part of LeafSystem state | ||
theta, omega = state.continuous_state | ||
# Get parameters | ||
m = parameters["m"] | ||
g = parameters["g"] | ||
L = parameters["L"] | ||
b = parameters["b"] | ||
# Get input | ||
tau = inputs[0] | ||
# Reshape to scalar if input was an array | ||
tau = jnp.reshape(tau, ()) | ||
# Compute the time derivative of the state (ODE RHS) | ||
dot_theta = omega | ||
mLsq = m * L * L | ||
dot_omega = -(g / L) * jnp.sin(theta) - b * omega / mLsq + tau / mLsq | ||
# Return the derivative of the state | ||
return jnp.array([dot_theta, dot_omega] | ||
{% endfor %} | ||
"#; | ||
|
||
pub fn generate(def: &ast::StoredDefinition) { | ||
let mut tera = Tera::default(); | ||
tera.add_raw_template("template", COLLIMATOR_TEMPLATE) | ||
.unwrap(); | ||
let mut context = Context::new(); | ||
context.insert("def", def); | ||
println!("{}", tera.render("template", &context).unwrap()); | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
use crate::generators::parse_file; | ||
|
||
#[test] | ||
fn test_generate_casadi_sx() { | ||
let def = parse_file("src/model.mo"); | ||
generate(&def); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
pub mod casadi_mx; | ||
pub mod casadi_sx; | ||
pub mod collimator; | ||
pub mod json; | ||
pub mod sympy; | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,4 @@ model Integrator | |
equation | ||
der(x) = 1.0; | ||
der(y) = x; | ||
algorithm | ||
x := 1; | ||
end Integrator; |