Skip to content

Commit

Permalink
Add functions to parser (#57)
Browse files Browse the repository at this point in the history
* Add functions

* clippy

* added expr tests in parser

* Update error

* added identity to parser tests

* testing existance of std functions

* fmt

---------

Co-authored-by: Aleksander Tudruj <[email protected]>
  • Loading branch information
mgr0dzicki and tudny authored Apr 30, 2023
1 parent 0975c2d commit e10ee08
Show file tree
Hide file tree
Showing 3 changed files with 222 additions and 35 deletions.
70 changes: 68 additions & 2 deletions src/environment.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::collections::btree_map::IterMut;
use std::collections::BTreeMap;

use anyhow::bail;
use anyhow::{bail, Context};

use crate::locale::Locale;
use crate::traits::{GuiDisplayable, LaTeXable};
Expand All @@ -23,6 +23,10 @@ impl Identifier {
}
}

fn new_unsafe(id: String) -> Self {
Self { id }
}

pub fn result() -> Self {
Self {
id: Self::RESULT.to_string(),
Expand Down Expand Up @@ -64,6 +68,20 @@ impl<T: MatrixNumber> Type<T> {
pub fn from_matrix_result(opt: anyhow::Result<Matrix<T>>) -> anyhow::Result<Self> {
Ok(Self::Matrix(opt?))
}

pub fn into_scalar(self) -> anyhow::Result<T> {
match self {
Type::Scalar(s) => Ok(s),
Type::Matrix(_) => bail!("Expected scalar, got matrix."),
}
}

pub fn into_matrix(self) -> anyhow::Result<Matrix<T>> {
match self {
Type::Matrix(m) => Ok(m),
Type::Scalar(_) => bail!("Expected matrix, got scalar."),
}
}
}

impl<T: MatrixNumber> ToString for Type<T> {
Expand Down Expand Up @@ -106,25 +124,58 @@ impl<T: MatrixNumber> LaTeXable for Type<T> {
}
}

pub type Callable<T> = dyn Fn(Type<T>) -> anyhow::Result<Type<T>>;

fn builtin_functions<T: MatrixNumber>() -> BTreeMap<Identifier, Box<Callable<T>>> {
BTreeMap::from([
(
Identifier::new_unsafe("transpose".to_string()),
Box::new(|t: Type<T>| Ok(Type::Matrix(t.into_matrix()?.transpose())))
as Box<Callable<T>>,
),
(
Identifier::new_unsafe("identity".to_string()),
Box::new(|t: Type<T>| {
Ok(Type::Matrix(Matrix::identity(
t.into_scalar()?
.to_usize()
.context("Invalid identity argument")?,
)))
}) as Box<Callable<T>>,
),
(
Identifier::new_unsafe("inverse".to_string()),
Box::new(|t: Type<T>| Ok(Type::Matrix(t.into_matrix()?.inverse()?.result)))
as Box<Callable<T>>,
),
])
}

pub struct Environment<T: MatrixNumber> {
env: BTreeMap<Identifier, Type<T>>,
fun: BTreeMap<Identifier, Box<Callable<T>>>,
}

impl<T: MatrixNumber> Environment<T> {
pub fn new() -> Self {
Self {
env: BTreeMap::new(),
fun: builtin_functions(),
}
}

pub fn insert(&mut self, id: Identifier, value: Type<T>) {
self.env.insert(id, value);
}

pub fn get(&self, id: &Identifier) -> Option<&Type<T>> {
pub fn get_value(&self, id: &Identifier) -> Option<&Type<T>> {
self.env.get(id)
}

pub fn get_function(&self, id: &Identifier) -> Option<&Callable<T>> {
self.fun.get(id).map(|f| f.as_ref())
}

pub fn iter_mut(&mut self) -> IterMut<'_, Identifier, Type<T>> {
self.env.iter_mut()
}
Expand Down Expand Up @@ -155,4 +206,19 @@ mod tests {
assert!(matches!(Identifier::new("32".to_string()), Err(_)));
assert!(matches!(Identifier::new("".to_string()), Err(_)));
}

#[test]
fn test_env_contains_std_fun() {
let env = Environment::<i64>::new();

assert!(env
.get_function(&Identifier::new_unsafe("transpose".to_string()))
.is_some());
assert!(env
.get_function(&Identifier::new_unsafe("identity".to_string()))
.is_some());
assert!(env
.get_function(&Identifier::new_unsafe("inverse".to_string()))
.is_some());
}
}
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ impl<K: MatrixNumber> eframe::App for MatrixApp<K> {
let mut windows_result = None;
for (id, window) in self.state.windows.iter_mut() {
if window.is_open {
let element = self.state.env.get(id).unwrap();
let element = self.state.env.get_value(id).unwrap();
let local_result = display_env_element_window(
ctx,
(id, element),
Expand Down
185 changes: 153 additions & 32 deletions src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,10 @@ impl<'a> Tokenizer<'a> {
}
}

#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Clone, PartialEq, Eq)]
enum WorkingToken<T: MatrixNumber> {
Type(Type<T>),
Function(Identifier),
UnaryOp(char),
BinaryOp(char),
LeftBracket,
Expand All @@ -87,6 +88,7 @@ impl<T: MatrixNumber> Display for WorkingToken<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
WorkingToken::Type(_) => write!(f, "value token"),
WorkingToken::Function(_) => write!(f, "function token"),
WorkingToken::UnaryOp(op) => write!(f, "unary operator \"{op}\""),
WorkingToken::BinaryOp(op) => write!(f, "binary operator \"{op}\""),
WorkingToken::LeftBracket => write!(f, "( bracket"),
Expand Down Expand Up @@ -114,23 +116,35 @@ fn binary_op<T: MatrixNumber>(left: Type<T>, right: Type<T>, op: char) -> anyhow
(Type::Scalar(l), Type::Matrix(r)) => Type::from_matrix_result(r.checked_mul_scl(&l)),
},
'/' => match (left, right) {
(Type::Scalar(l), Type::Scalar(r)) => if !r.is_zero() {
Type::from_scalar_option(l.checked_div(&r))
} else {
bail!("Division by zero!")
},
(Type::Matrix(_), Type::Matrix(_)) => bail!("WTF dividing by matrix? You should use the `inv` function (not implemented yet, wait for it...)"),
(Type::Matrix(_), Type::Scalar(_)) => bail!("Diving matrix by scalar is not supported yet..."),
(Type::Scalar(_), Type::Matrix(_)) => bail!("Diving scalar by matrix does not make sense!"),
(Type::Scalar(l), Type::Scalar(r)) => {
if !r.is_zero() {
Type::from_scalar_option(l.checked_div(&r))
} else {
bail!("Division by zero!")
}
}
(Type::Matrix(_), Type::Matrix(_)) => {
bail!("WTF dividing by matrix? You should use the `inverse` function instead!")
}
(Type::Matrix(_), Type::Scalar(_)) => {
bail!("Diving matrix by scalar is not supported yet...")
}
(Type::Scalar(_), Type::Matrix(_)) => {
bail!("Diving scalar by matrix does not make sense!")
}
},
'^' => if let Type::Scalar(exp) = right {
let exp = exp.to_usize().context("Exponent should be a nonnegative integer.")?;
match left {
Type::Scalar(base) => Type::from_scalar_option(checked_pow(base, exp)),
Type::Matrix(base) => Type::from_matrix_result(base.checked_pow(exp)),
'^' => {
if let Type::Scalar(exp) = right {
let exp = exp
.to_usize()
.context("Exponent should be a nonnegative integer.")?;
match left {
Type::Scalar(base) => Type::from_scalar_option(checked_pow(base, exp)),
Type::Matrix(base) => Type::from_matrix_result(base.checked_pow(exp)),
}
} else {
bail!("Exponent cannot be a matrix!");
}
} else {
bail!("Exponent cannot be a matrix!");
}
_ => unimplemented!(),
}
Expand All @@ -155,7 +169,7 @@ fn unary_op<T: MatrixNumber>(arg: Type<T>, op: char) -> anyhow::Result<Type<T>>
<unary_op> ::= "+" | "-"
<binary_op> ::= "+" | "-" | "*" | "/"
<expr> ::= <integer> | <identifier> | <expr> <binary_op> <expr>
| "(" <expr> ")" | <unary_op> <expr>
| "(" <expr> ")" | <unary_op> <expr> | <identifier> "(" <expr> ")"
*/
pub fn parse_expression<T: MatrixNumber>(
raw: &str,
Expand Down Expand Up @@ -185,6 +199,7 @@ pub fn parse_expression<T: MatrixNumber>(
None | Some(WorkingToken::LeftBracket)
| Some(WorkingToken::BinaryOp(_))
| Some(WorkingToken::UnaryOp(_))
| Some(WorkingToken::Function(_))
),
Token::Operator(_) => matches!(
previous,
Expand Down Expand Up @@ -221,15 +236,18 @@ pub fn parse_expression<T: MatrixNumber>(
outputs.back()
}
Token::Identifier(id) => {
outputs.push_back(WorkingToken::Type(
env.get(id)
.context(format!(
"Undefined identifier! Object \"{}\" is unknown.",
id.to_string()
))?
.clone(),
));
outputs.back()
if let Some(value) = env.get_value(id) {
outputs.push_back(WorkingToken::Type(value.clone()));
outputs.back()
} else if env.get_function(id).is_some() {
operators.push_front(WorkingToken::Function(id.clone()));
operators.front()
} else {
bail!(
"Undefined identifier! Object \"{}\" is unknown.",
id.to_string()
)
}
}
Token::LeftBracket => {
operators.push_front(WorkingToken::LeftBracket);
Expand All @@ -248,10 +266,11 @@ pub fn parse_expression<T: MatrixNumber>(
bail!("Mismatched brackets!");
}
if let Some(op) = operators.pop_front() {
if matches!(op, WorkingToken::UnaryOp(_)) {
outputs.push_back(op);
} else {
operators.push_front(op);
match op {
WorkingToken::UnaryOp(_) | WorkingToken::Function(_) => {
outputs.push_back(op)
}
_ => operators.push_front(op),
}
}
Some(&WorkingToken::RightBracket)
Expand Down Expand Up @@ -312,6 +331,10 @@ pub fn parse_expression<T: MatrixNumber>(
let arg = val_stack.pop_front().context("Invalid expression!")?;
val_stack.push_front(unary_op(arg, op)?);
}
WorkingToken::Function(id) => {
let arg = val_stack.pop_front().context("Invalid expression!")?;
val_stack.push_front(env.get_function(&id).unwrap()(arg)?);
}
_ => unreachable!(),
}
}
Expand Down Expand Up @@ -546,7 +569,8 @@ mod tests {
}

assert_eq!(
*env.get(&Identifier::new("b".to_string()).unwrap()).unwrap(),
*env.get_value(&Identifier::new("b".to_string()).unwrap())
.unwrap(),
Type::<i64>::Scalar(89)
);
}
Expand All @@ -561,8 +585,105 @@ mod tests {
exec("a = $ ^ $");

assert_eq!(
*env.get(&Identifier::new("a".to_string()).unwrap()).unwrap(),
*env.get_value(&Identifier::new("a".to_string()).unwrap())
.unwrap(),
Type::<i64>::Scalar(256)
);
}

#[test]
fn test_expression_functions() {
let mut env = Environment::new();

let a = im![1, 2, 3; 4, 5, 6];
let at = im![1, 4; 2, 5; 3, 6];
let b = im![1, 2; 3, 4];

env.insert(Identifier::new("A".to_string()).unwrap(), Type::Matrix(a));
env.insert(
Identifier::new("B".to_string()).unwrap(),
Type::Matrix(b.clone()),
);

assert_eq!(
parse_expression("transpose(A)", &env).unwrap(),
Type::Matrix(at)
);
assert_eq!(
parse_expression("identity(4)", &env).unwrap(),
Type::Matrix(Matrix::identity(4))
);
assert_eq!(
parse_expression("inverse(B)", &env).unwrap(),
Type::Matrix(b.inverse().unwrap().result)
);
}

#[test]
fn test_nested_functions() {
let mut env = Environment::new();

let a = im![1, 2, 3; 4, 5, 6];
let att = im![1, 2, 3; 4, 5, 6];

env.insert(Identifier::new("A".to_string()).unwrap(), Type::Matrix(a));

assert_eq!(
parse_expression("transpose(transpose(A))", &env).unwrap(),
Type::Matrix(att)
)
}

#[test]
fn test_expr_with_function() {
let mut env = Environment::new();

let a = im![1, 2, 3; 4, 5, 6];
let b = im![1, 2; 3, 4];

env.insert(Identifier::new("A".to_string()).unwrap(), Type::Matrix(a));
env.insert(
Identifier::new("B".to_string()).unwrap(),
Type::Matrix(b.clone()),
);

assert_eq!(
parse_expression("transpose(A) * B", &env).unwrap(),
Type::Matrix(im![13, 18; 17, 24; 21, 30])
);
}

#[test]
fn test_expr_in_function() {
let mut env = Environment::new();

let a = im![1, 2, 3; 4, 5, 6];
let i = Matrix::identity(2);
let at = im![1, 4; 2, 5; 3, 6];

env.insert(Identifier::new("A".to_string()).unwrap(), Type::Matrix(a));
env.insert(Identifier::new("I".to_string()).unwrap(), Type::Matrix(i));

assert_eq!(
parse_expression("transpose(I * A)", &env).unwrap(),
Type::Matrix(at)
);
}

#[test]
fn test_complex_nested_function_with_expr() {
let mut env = Environment::new();

let a = im![1, 2, 3; 4, 5, 6];

env.insert(Identifier::new("A".to_string()).unwrap(), Type::Matrix(a));

assert_eq!(
parse_expression(
"transpose(transpose(identity(2137 - 2135 + 1 - 1 + (42 - 420) * 0) * A) + transpose(identity(2) * A))",
&env
).unwrap(),
Type::Matrix(im![2, 4, 6; 8, 10, 12])
);
}
}

0 comments on commit e10ee08

Please sign in to comment.