Skip to content

Commit

Permalink
Improve casadi sx generator.
Browse files Browse the repository at this point in the history
Signed-off-by: James Goppert <[email protected]>
  • Loading branch information
jgoppert committed Nov 19, 2024
1 parent 7c86d22 commit 3c6eeff
Show file tree
Hide file tree
Showing 16 changed files with 267 additions and 51 deletions.
57 changes: 57 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"type": "lldb",
"request": "launch",
"name": "Cargo test",
"cargo": {
"args": [
"test",
"--no-run",
]
},
"args": []
},
{
"type": "lldb",
"request": "launch",
"name": "Debug executable 'rumoca'",
"cargo": {
"args": [
"build",
"--bin=rumoca",
"--package=rumoca"
],
"filter": {
"name": "rumoca",
"kind": "bin"
}
},
"args": [],
"cwd": "${workspaceFolder}"
},
{
"type": "lldb",
"request": "launch",
"name": "Debug unit tests in executable 'rumoca'",
"cargo": {
"args": [
"test",
"--no-run",
"--bin=rumoca",
"--package=rumoca"
],
"filter": {
"name": "rumoca",
"kind": "bin"
}
},
"args": [],
"cwd": "${workspaceFolder}"
}
]
}
43 changes: 43 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@ license = "Apache-2.0"

[build-dependencies]
lalrpop = "0.22.0"
rerun_except = "1.0.0"

[dependencies]
clap = { version = "4.5.20", features = ["derive"] }
codespan-reporting = "0.11.1"
lalrpop-util = "0.22.0"
logos = "0.14.2"
serde = { version = "1.0.214", features = ["derive", "serde_derive"] }
serde_json = "1.0.132"
tera = "1.20.0"
unindent = "0.2.3"
3 changes: 3 additions & 0 deletions build.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use rerun_except::rerun_except;

fn main() {
lalrpop::process_root().unwrap();
rerun_except(&["src/generators/templates/*.tera"]).unwrap();
}
10 changes: 10 additions & 0 deletions integrator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@

import sympy

class Integrator:

def __init__(self):
self.x = sympy.symbols('x');
self.y = sympy.symbols('y');


4 changes: 4 additions & 0 deletions src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ pub struct ComponentReference {

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum Equation {
Der {
comp: ComponentReference,
rhs: Box<Expression>,
},
Simple {
lhs: Box<Expression>,
rhs: Box<Expression>,
Expand Down
2 changes: 1 addition & 1 deletion src/generators/casadi_mx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ mod tests {

#[test]
fn test_generate_casadi_mx() {
let def = parse_file("src/model.mo");
let def = parse_file("src/model.mo").expect("failed to parse");
generate(&def);
}
}
22 changes: 6 additions & 16 deletions src/generators/casadi_sx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,15 @@ use crate::ast;

use tera::{Context, Tera};

static CASADI_SX_TEMPLATE: &str = r#"
import casadi as ca
{% for class in def.classes %}
class {{ class.name }}:
def __init__(self):
{% for comp in class.components -%}
self.{{ comp.name }} = ca.SX.sym('{{ comp.name }}');
{% endfor -%}
{% endfor %}
"#;

pub fn generate(def: &ast::StoredDefinition) {
let template = std::fs::read_to_string("src/generators/templates/casadi_sx.tera").
expect("failed to read template");
let mut tera = Tera::default();
tera.add_raw_template("template", CASADI_SX_TEMPLATE)
.unwrap();
tera.add_raw_template("casadi_sx", &template).expect("failed to add template");

let mut context = Context::new();
context.insert("def", def);
println!("{}", tera.render("template", &context).unwrap());
println!("{}", tera.render("casadi_sx", &context).unwrap());
}

#[cfg(test)]
Expand All @@ -30,7 +20,7 @@ mod tests {

#[test]
fn test_generate_casadi_sx() {
let def = parse_file("src/model.mo");
let def = parse_file("src/model.mo").expect("failed to parse");
generate(&def);
}
}
27 changes: 7 additions & 20 deletions src/generators/collimator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,15 @@ class {{ class.name}} (LeafSystem):
g=9.81,
L=1.0,
b=0.0,
full_state_output=True,
**kwargs,
):
super().__init__(*args, **kwargs)
# Declare parameters
self.declare_dynamic_parameter("m", m)
self.declare_dynamic_parameter("g", g)
self.declare_dynamic_parameter("L", L)
self.declare_dynamic_parameter("b", b)
# self.declare_dynamic_parameter("m", m)
# self.declare_dynamic_parameter("g", g)
# self.declare_dynamic_parameter("L", L)
# self.declare_dynamic_parameter("b", b)
# Declare continuous state; it's default value indicates that its size is 2
# the attached ode callback should return the derivative of the state
self.declare_continuous_state(default_value=jnp.array(x0), ode=self.ode)
{% for comp in class.components -%}
Expand All @@ -38,19 +35,8 @@ class {{ class.name}} (LeafSystem):
# Declare input port for the torque
self.declare_input_port(name="u")
if full_state_output:
# Declare output port for the full state
self.declare_continuous_state_output(name="x")
else:
self.declare_continuous_state_output(name="x")
def _observation_callback(time, state, *inputs, **parameters):
return state.continuous_state[0] # output only theta
self.declare_output_port(
_observation_callback,
name="y",
requires_inputs=False,
)
def ode(self, time, state, *inputs, **parameters):
# Get theta and omega from the continuous part of LeafSystem state
Expand All @@ -77,6 +63,7 @@ class {{ class.name}} (LeafSystem):
return jnp.array([dot_theta, dot_omega]
{% endfor %}
"#;

pub fn generate(def: &ast::StoredDefinition) {
Expand All @@ -95,7 +82,7 @@ mod tests {

#[test]
fn test_generate_casadi_sx() {
let def = parse_file("src/model.mo");
let def = parse_file("src/model.mo").expect("failed to parse");
generate(&def);
}
}
2 changes: 1 addition & 1 deletion src/generators/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ mod tests {

#[test]
fn test_generate_json() {
let def: ast::StoredDefinition = parse_file("src/model.mo");
let def: ast::StoredDefinition = parse_file("src/model.mo").expect("failed to parse");
generate(&def);
}
}
62 changes: 57 additions & 5 deletions src/generators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,66 @@ pub mod sympy;
use crate::ast;
use crate::lexer;
use crate::parser;
use crate::tokens::{LexicalError, Token};

use codespan_reporting::files::SimpleFiles;

use lexer::Lexer;
use parser::StoredDefinitionParser;
use lalrpop_util::ParseError;

pub fn parse_file(filename: &str) -> ast::StoredDefinition {
let source_code = std::fs::read_to_string(filename).unwrap();
let lexer = Lexer::new(&source_code);
pub fn parse_file(filename: &str) -> Result<ast::StoredDefinition, ParseError<usize, Token, LexicalError>> {
let mut files = SimpleFiles::new();
let file_id = files.add(
filename,
std::fs::read_to_string(filename).unwrap(),
);
let file = files.get(file_id).expect("failed to get file");
let lexer = Lexer::new(file.source());
let parser = StoredDefinitionParser::new();
let def = parser.parse(lexer).expect("failed to parse");
return def;
parser.parse(lexer)
// if def.is_err() {
// // let type_id = def.as_ref().unwrap().type_id();
// // match type_id {
// // UnrecognizedToken => println!("unrecognized token!"),
// // _ => println!("type unhandled {:?}", type_id.)
// // }
// let err = def.as_ref().expect_err("error");

// let writer = StandardStream::stderr(ColorChoice::Always);
// let config = codespan_reporting::term::Config::default();

// match err {
// ParseError::InvalidToken { location } => {
// println!("invalid token loc:{}", location)
// },
// // ParseError::UnrecognizedEof { location , expected } => {
// // println!("unrecognized EOF {}, expected:", location);
// // // for tok in expected {
// // // println!("expected: {}", tok)
// // // }
// // },
// ParseError::UnrecognizedToken { token, expected } => {
// // for tok in expected {
// // println!("{}", tok)
// // }
// let diagonistic = Diagnostic::error()
// .with_message("failed to parse")
// .with_code("E001")
// .with_labels(vec![
// Label::primary(file_id, (token.0)
// Label::secondary(file_id, (0)..(token.2+100)),
// ])
// .with_notes(vec![expected[0].clone(), unindent(
// "
// expected type \"=\"
// "
// )]);
// codespan_reporting::term::emit(&mut writer.lock(), &config, &files, &diagonistic).expect("fail");
// }
// _ => { println!("unhandled") }
// }

// }
//return def.expect("failed to parse");
}
2 changes: 1 addition & 1 deletion src/generators/sympy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ mod tests {

#[test]
fn test_generate_sympy() {
let def: ast::StoredDefinition = parse_file("src/model.mo");
let def: ast::StoredDefinition = parse_file("src/model.mo").expect("failed to parse");
generate(&def);
}
}
Loading

0 comments on commit 3c6eeff

Please sign in to comment.