From 2de2bea8502d0009eba6f207d2e135d12bac422b Mon Sep 17 00:00:00 2001 From: Will Date: Sat, 28 Sep 2024 16:26:16 -0700 Subject: [PATCH 1/3] dumping sympy equations to file. Fix to work on newer versions of python --- msdsl/eqn/lds.py | 125 +++++++++++++++++++++++++++++++++++++++++++++++ msdsl/model.py | 45 +++++++++++++++-- 2 files changed, 166 insertions(+), 4 deletions(-) diff --git a/msdsl/eqn/lds.py b/msdsl/eqn/lds.py index 5dde501..0f6eb26 100644 --- a/msdsl/eqn/lds.py +++ b/msdsl/eqn/lds.py @@ -3,6 +3,11 @@ import numpy as np import scipy.linalg +import sympy as sp +from msdsl.assignment import Assignment +from msdsl.expr.expr import Array, LessThan, GreaterThan, Product, Sum, EqualTo, UIntConstant, RealConstant, Constant +from msdsl.expr.signals import AnalogSignal, DigitalSignal + class LDS: def __init__(self, A=None, B=None, C=None, D=None): # save settings @@ -55,6 +60,31 @@ def __str__(self): # return result return retval + + def convert_to_sympy(self, states, inputs, outputs): + state_strings = list(map(lambda x: str(x), states)) + inputs_strings = list(map(lambda x: str(x), inputs)) + outputs_strings = list(map(lambda x: str(x), outputs)) + # Convert state, input, and output strings to sympy symbols + states = sp.symbols(state_strings) + inputs = sp.symbols(inputs_strings) + outputs = sp.symbols(outputs_strings) + + # Convert numpy arrays to sympy matrices + A_sym = sp.Matrix(self.A) + B_sym = sp.Matrix(self.B) + C_sym = sp.Matrix(self.C) + D_sym = sp.Matrix(self.D) + + # Define state-space equations + state_eq = A_sym * sp.Matrix(states) + B_sym * sp.Matrix(inputs) + output_eq = C_sym * sp.Matrix(states) + D_sym * sp.Matrix(inputs) + + # Explicitly compute the derivatives (dot{x}) + state_ode = sp.Matrix([sp.diff(state, 't') for state in states]) - state_eq + + return state_ode, output_eq + class LdsCollection: def __init__(self): @@ -75,3 +105,98 @@ def append(self, lds: LDS): self.B = np.concatenate((self.B, B), axis=2) if self.B is not None else B self.C = np.concatenate((self.C, C), axis=2) if self.C is not None else C self.D = np.concatenate((self.D, D), axis=2) if self.D is not None else D + + + def convert_to_sympy_piecewise(self, states, inputs, outputs, sel_bits, sel_eqns): + # Convert states, inputs, and outputs to sympy symbols + state_strings = list(map(str, states)) + inputs_strings = list(map(str, inputs)) + outputs_strings = list(map(str, outputs)) + + states = sp.symbols(state_strings) + inputs = sp.symbols(inputs_strings) + outputs = sp.symbols(outputs_strings) + + # Convert sel_bits to SymPy symbols if they aren't already + sel_bits_sympy = [sp.Symbol(str(sel_bit)) if not isinstance(sel_bit, sp.Basic) else sel_bit for sel_bit in sel_bits] + + # Initialize lists for state and output equations with default expressions + state_eq_piecewise = [sp.Piecewise((0, True)) for _ in range(len(states))] + output_eq_piecewise = [sp.Piecewise((0, True)) for _ in range(len(outputs))] + + # Iterate over all possible configurations of sel_bits + + for k in range(self.A.shape[2]): # Number of scenarios + A_sym = sp.Matrix(self.A[:, :, k]) + B_sym = sp.Matrix(self.B[:, :, k]) + C_sym = sp.Matrix(self.C[:, :, k]) + D_sym = sp.Matrix(self.D[:, :, k]) + + # Define state-space equations + state_eq = A_sym * sp.Matrix(states) + B_sym * sp.Matrix(inputs) + output_eq = C_sym * sp.Matrix(states) + D_sym * sp.Matrix(inputs) + + # Compute derivatives (dot{x}) for the state equations + state_ode = sp.Matrix([sp.diff(state, 't') for state in states]) - state_eq + + # Create the condition for this scenario + condition = True + for i, sel_bit_sym in enumerate(sel_bits_sympy): + bit_value = (k >> i) & 1 + # Use logical AND to build up the condition + condition = sp.And(condition, sp.Eq(sel_bit_sym, bit_value)) + + + # Process sel_eqns to adjust the condition if needed + for sel_eqn in sel_eqns: + + + if isinstance(sel_eqn, Assignment): + signal = sp.Symbol(sel_eqn.signal.name) + expr = msdsl_ast_to_sympy(sel_eqn.expr) + + condition = condition.subs(signal, expr) + + # Assign the corresponding equations to the Piecewise objects + for i in range(len(states)): + + state_eq_piecewise[i] = sp.Piecewise((state_ode[i], condition), (state_eq_piecewise[i], True)) + + for i in range(len(outputs)): + output_eq_piecewise[i] = sp.Piecewise((output_eq[i], condition), (output_eq_piecewise[i], True)) + + return state_eq_piecewise + output_eq_piecewise + + + + +def msdsl_ast_to_sympy(ast): + """ + Convert an AST from msdsl to a sympy expression. + """ + if isinstance(ast, LessThan): + return sp.Lt(msdsl_ast_to_sympy(ast.lhs), msdsl_ast_to_sympy(ast.rhs)) + elif isinstance(ast, GreaterThan): + return sp.Gt(msdsl_ast_to_sympy(ast.lhs), msdsl_ast_to_sympy(ast.rhs)) + elif isinstance(ast, Product): + accum = 1 # Corrected from 0 to 1 to properly accumulate products + for operand in ast.operands: + accum *= msdsl_ast_to_sympy(operand) + return accum + elif isinstance(ast, Sum): + accum = 0 + for operand in ast.operands: + accum += msdsl_ast_to_sympy(operand) + return accum + elif isinstance(ast, EqualTo): + return sp.Eq(msdsl_ast_to_sympy(ast.lhs), msdsl_ast_to_sympy(ast.rhs)) + elif isinstance(ast, Array): + elements = ast.operands[:-1] + address = msdsl_ast_to_sympy(ast.operands[-1]) + return sp.Piecewise(*[(msdsl_ast_to_sympy(elem), address if i == 1 else sp.Not(address)) for i, elem in enumerate(elements)]) + elif isinstance(ast, Constant): + return ast.value + elif isinstance(ast, AnalogSignal) or isinstance(ast, DigitalSignal): + return sp.Symbol(str(ast)) + else: + raise Exception(f"Unsupported AST node: {type(ast)}") diff --git a/msdsl/model.py b/msdsl/model.py index b4df97e..734c13c 100644 --- a/msdsl/model.py +++ b/msdsl/model.py @@ -1,4 +1,5 @@ -from collections import OrderedDict, Iterable +from collections import OrderedDict +from collections.abc import Iterable from itertools import chain from numbers import Integral, Number from typing import List, Set, Union @@ -778,7 +779,7 @@ def get_equation_io(self, eqn_sys: EqnSys): # determine sel_bits sel_bit_names = set(signal_names(eqn_sys.get_sel_bits())) sel_bits = self.get_signals(sel_bit_names) - + # return result return inputs, states, outputs, sel_bits @@ -794,6 +795,8 @@ def add_eqn_sys(self, eqns: List[ModelExpr], extra_outputs=None, clk=None, rst=N :param extra_outputs: List of internal variables in the system of equations that should be bound to analog signals. :param clk: Name of clock signal to use (None will default to `CLK_MSDSL) :param rst: Name of the reset signal to use (None will default to `RST_MSDSL) + + Returns an LDSCollection Object """ # set defaults @@ -804,7 +807,7 @@ def add_eqn_sys(self, eqns: List[ModelExpr], extra_outputs=None, clk=None, rst=N # analyze equation to find out knowns and unknowns inputs, states, outputs, sel_bits = self.get_equation_io(eqn_sys) - + # add the extra outputs as needed for extra_output in extra_outputs: if not isinstance(extra_output, Signal): @@ -832,6 +835,8 @@ def add_eqn_sys(self, eqns: List[ModelExpr], extra_outputs=None, clk=None, rst=N # add to collection of LDS systems collection.append(lds) + + # construct address for selection if len(sel_bits) > 0: @@ -844,6 +849,10 @@ def add_eqn_sys(self, eqns: List[ModelExpr], extra_outputs=None, clk=None, rst=N states=states, outputs=outputs, sel=sel, clk=clk, rst=rst) + return collection, inputs, states, outputs, sel_bits + + + def add_discrete_time_lds(self, collection, inputs=None, states=None, outputs=None, sel=None, clk=None, rst=None): # set defaults @@ -870,6 +879,25 @@ def add_discrete_time_lds(self, collection, inputs=None, states=None, outputs=No else: self.bind_name(outputs[row].name, expr) + def get_sympy_lds(self): + """ + Must be run after add_eqn_sys. + Returns an array of piecewise LDSs in the form of sympy equations. + """ + + sympy_lds = [] + + for circuit in self.circuits: + filtered_states = list(filter(lambda x: not isinstance(x, DigitalInput), circuit.sel_bits)) + state_str = list(map(lambda x: str(x), filtered_states)) + + sel_eqns = self.get_assignments(state_str) + + circuit_lds = circuit.collection.convert_to_sympy_piecewise(circuit.states, circuit.inputs, circuit.outputs, circuit.sel_bits, sel_eqns) + sympy_lds.append(circuit_lds) + + return sympy_lds + def set_tf(self, input_: Signal, output: Signal, tf, clk=None, rst=None): """ Method to assign an output signal as a function of the input signal by applying a given transfer function. @@ -1014,8 +1042,17 @@ def compile(self, gen: CodeGenerator): # compile circuits for circuit in self.circuits: eqns = circuit.compile_to_eqn_list() - self.add_eqn_sys(eqns, circuit.extra_outputs, clk=circuit.clk, rst=circuit.rst) + ldscollection, inputs, states, outputs, sel_bits = self.add_eqn_sys(eqns, circuit.extra_outputs, clk=circuit.clk, rst=circuit.rst) + + circuit.collection = ldscollection #Assign the circuit.collection object so we may use it later. + circuit.inputs = inputs + circuit.states = states + circuit.outputs = outputs + circuit.sel_bits = sel_bits + + + # determine the I/Os and internal variables ios = [] internals = [] From 5631a730b2610ad2605e8a845fc5ead1bb0206c9 Mon Sep 17 00:00:00 2001 From: Will Date: Thu, 3 Oct 2024 12:56:08 -0700 Subject: [PATCH 2/3] fixed lds generation --- msdsl/eqn/lds.py | 39 ++++----------------------------------- msdsl/expr/expr.py | 4 ++++ msdsl/expr/extras.py | 38 ++++++++++++++++++++++++++++++++++++-- msdsl/model.py | 33 +++++++++++++++++++++++++++++++-- msdsl/util.py | 35 ++++++++++++++++++++++++++++++++++- 5 files changed, 109 insertions(+), 40 deletions(-) diff --git a/msdsl/eqn/lds.py b/msdsl/eqn/lds.py index 0f6eb26..52a174f 100644 --- a/msdsl/eqn/lds.py +++ b/msdsl/eqn/lds.py @@ -4,9 +4,8 @@ import scipy.linalg import sympy as sp +from msdsl.expr.extras import msdsl_ast_to_sympy from msdsl.assignment import Assignment -from msdsl.expr.expr import Array, LessThan, GreaterThan, Product, Sum, EqualTo, UIntConstant, RealConstant, Constant -from msdsl.expr.signals import AnalogSignal, DigitalSignal class LDS: def __init__(self, A=None, B=None, C=None, D=None): @@ -159,44 +158,14 @@ def convert_to_sympy_piecewise(self, states, inputs, outputs, sel_bits, sel_eqns # Assign the corresponding equations to the Piecewise objects for i in range(len(states)): - - state_eq_piecewise[i] = sp.Piecewise((state_ode[i], condition), (state_eq_piecewise[i], True)) + state_eq_piecewise[i] = sp.Eq( sp.Derivative(states[i],sp.Symbol('t')), sp.Piecewise((state_ode[i], condition), (state_eq_piecewise[i], True))) for i in range(len(outputs)): - output_eq_piecewise[i] = sp.Piecewise((output_eq[i], condition), (output_eq_piecewise[i], True)) + output_eq_piecewise[i] = sp.Eq(outputs[i], sp.Piecewise((output_eq[i], condition), (output_eq_piecewise[i], True))) + # For each equation we need to substitute in return state_eq_piecewise + output_eq_piecewise -def msdsl_ast_to_sympy(ast): - """ - Convert an AST from msdsl to a sympy expression. - """ - if isinstance(ast, LessThan): - return sp.Lt(msdsl_ast_to_sympy(ast.lhs), msdsl_ast_to_sympy(ast.rhs)) - elif isinstance(ast, GreaterThan): - return sp.Gt(msdsl_ast_to_sympy(ast.lhs), msdsl_ast_to_sympy(ast.rhs)) - elif isinstance(ast, Product): - accum = 1 # Corrected from 0 to 1 to properly accumulate products - for operand in ast.operands: - accum *= msdsl_ast_to_sympy(operand) - return accum - elif isinstance(ast, Sum): - accum = 0 - for operand in ast.operands: - accum += msdsl_ast_to_sympy(operand) - return accum - elif isinstance(ast, EqualTo): - return sp.Eq(msdsl_ast_to_sympy(ast.lhs), msdsl_ast_to_sympy(ast.rhs)) - elif isinstance(ast, Array): - elements = ast.operands[:-1] - address = msdsl_ast_to_sympy(ast.operands[-1]) - return sp.Piecewise(*[(msdsl_ast_to_sympy(elem), address if i == 1 else sp.Not(address)) for i, elem in enumerate(elements)]) - elif isinstance(ast, Constant): - return ast.value - elif isinstance(ast, AnalogSignal) or isinstance(ast, DigitalSignal): - return sp.Symbol(str(ast)) - else: - raise Exception(f"Unsupported AST node: {type(ast)}") diff --git a/msdsl/expr/expr.py b/msdsl/expr/expr.py index cf5a708..0bd3287 100644 --- a/msdsl/expr/expr.py +++ b/msdsl/expr/expr.py @@ -6,6 +6,7 @@ from msdsl.expr.format import RealFormat, SIntFormat, UIntFormat, Format, IntFormat +import sympy as sp # constant wrapping def wrap_constant(operand): @@ -1039,6 +1040,9 @@ def mt19937(clk=None, rst=None, cke=None, seed=None): def lcg_op(clk=None, rst=None, cke=None, seed=None): return LCG(clk=clk, rst=rst, cke=cke, seed=seed) + + + # testing def main(): diff --git a/msdsl/expr/extras.py b/msdsl/expr/extras.py index 794397f..9969790 100644 --- a/msdsl/expr/extras.py +++ b/msdsl/expr/extras.py @@ -1,6 +1,9 @@ from typing import Union, List from numbers import Number, Integral -from msdsl.expr.expr import ModelExpr, concatenate, BitwiseAnd, array +from msdsl.expr.expr import ModelExpr, concatenate, BitwiseAnd, array, LessThan, GreaterThan, Product, Sum, EqualTo, Array, Constant +from msdsl.expr.signals import AnalogSignal, DigitalSignal + +import sympy as sp def all_between(x: List[ModelExpr], lo: Union[Number, ModelExpr], hi: Union[Number, ModelExpr]) -> ModelExpr: """ @@ -37,4 +40,35 @@ def if_(condition, then, else_): :param else_: Action to be executed for False case :return: Boolean """ - return array([else_, then], condition) \ No newline at end of file + return array([else_, then], condition) + +def msdsl_ast_to_sympy(ast): + """ + Convert an AST from msdsl to a sympy expression. + """ + if isinstance(ast, LessThan): + return sp.Lt(msdsl_ast_to_sympy(ast.lhs), msdsl_ast_to_sympy(ast.rhs)) + elif isinstance(ast, GreaterThan): + return sp.Gt(msdsl_ast_to_sympy(ast.lhs), msdsl_ast_to_sympy(ast.rhs)) + elif isinstance(ast, Product): + accum = 1 # Corrected from 0 to 1 to properly accumulate products + for operand in ast.operands: + accum *= msdsl_ast_to_sympy(operand) + return accum + elif isinstance(ast, Sum): + accum = 0 + for operand in ast.operands: + accum += msdsl_ast_to_sympy(operand) + return accum + elif isinstance(ast, EqualTo): + return sp.Eq(msdsl_ast_to_sympy(ast.lhs), msdsl_ast_to_sympy(ast.rhs)) + elif isinstance(ast, Array): + elements = ast.operands[:-1] + address = msdsl_ast_to_sympy(ast.operands[-1]) + return sp.Piecewise(*[(msdsl_ast_to_sympy(elem), address if i == 1 else sp.Not(address)) for i, elem in enumerate(elements)]) + elif isinstance(ast, Constant): + return ast.value + elif isinstance(ast, AnalogSignal) or isinstance(ast, DigitalSignal): + return sp.Symbol(str(ast)) + else: + raise Exception(f"Unsupported AST node: {type(ast)}") \ No newline at end of file diff --git a/msdsl/model.py b/msdsl/model.py index 734c13c..4f028ce 100644 --- a/msdsl/model.py +++ b/msdsl/model.py @@ -31,8 +31,12 @@ from msdsl.function import GeneralFunction, Function, PlaceholderFunction, MultiFunction from msdsl.lfsr import LFSR +from msdsl.expr.extras import msdsl_ast_to_sympy + from scipy.signal import cont2discrete +import sympy as sp + class Bus: def __init__(self, signal: Signal, n: Integral): self.signal = signal @@ -799,6 +803,8 @@ def add_eqn_sys(self, eqns: List[ModelExpr], extra_outputs=None, clk=None, rst=N Returns an LDSCollection Object """ + + # set defaults extra_outputs = extra_outputs if extra_outputs is not None else [] @@ -883,20 +889,40 @@ def get_sympy_lds(self): """ Must be run after add_eqn_sys. Returns an array of piecewise LDSs in the form of sympy equations. + Appends all assignments. """ sympy_lds = [] - + + #self.assignments + for circuit in self.circuits: filtered_states = list(filter(lambda x: not isinstance(x, DigitalInput), circuit.sel_bits)) state_str = list(map(lambda x: str(x), filtered_states)) sel_eqns = self.get_assignments(state_str) + circuit_lds = circuit.collection.convert_to_sympy_piecewise(circuit.states, circuit.inputs, circuit.outputs, circuit.sel_bits, sel_eqns) + circuit.sympy_eqs = circuit_lds sympy_lds.append(circuit_lds) + for symbol_name, assignment_obj in self.unmodified_assignments.items(): + sympy_lds.append(sp.Eq(sp.Symbol(symbol_name), msdsl_ast_to_sympy(assignment_obj.expr))) + return sympy_lds + + def write_sympy_lds_to_file(self, filename): + """ + Must be run after compile(). + Writes the sympy equations to json format. + """ + + sympy_lds = self.diffeqs + + with open(filename, 'w') as f: + for eq in sympy_lds: + f.write(str(eq) + '\n') def set_tf(self, input_: Signal, output: Signal, tf, clk=None, rst=None): """ @@ -1040,6 +1066,7 @@ def make_circuit(self, clk=None, rst=None): def compile(self, gen: CodeGenerator): # compile circuits + self.unmodified_assignments = deepcopy(self.assignments) for circuit in self.circuits: eqns = circuit.compile_to_eqn_list() ldscollection, inputs, states, outputs, sel_bits = self.add_eqn_sys(eqns, circuit.extra_outputs, clk=circuit.clk, rst=circuit.rst) @@ -1051,7 +1078,9 @@ def compile(self, gen: CodeGenerator): circuit.sel_bits = sel_bits - + + self.diffeqs = self.get_sympy_lds() + # determine the I/Os and internal variables ios = [] diff --git a/msdsl/util.py b/msdsl/util.py index 3cab005..9112d61 100644 --- a/msdsl/util.py +++ b/msdsl/util.py @@ -33,9 +33,42 @@ def warn(s): def list2dict(l): return {elem: k for k, elem in enumerate(l)} + +def msdsl_ast_to_sympy(ast): + """ + Convert an AST from msdsl to a sympy expression. + """ + if isinstance(ast, LessThan): + return sp.Lt(msdsl_ast_to_sympy(ast.lhs), msdsl_ast_to_sympy(ast.rhs)) + elif isinstance(ast, GreaterThan): + return sp.Gt(msdsl_ast_to_sympy(ast.lhs), msdsl_ast_to_sympy(ast.rhs)) + elif isinstance(ast, Product): + accum = 1 # Corrected from 0 to 1 to properly accumulate products + for operand in ast.operands: + accum *= msdsl_ast_to_sympy(operand) + return accum + elif isinstance(ast, Sum): + accum = 0 + for operand in ast.operands: + accum += msdsl_ast_to_sympy(operand) + return accum + elif isinstance(ast, EqualTo): + return sp.Eq(msdsl_ast_to_sympy(ast.lhs), msdsl_ast_to_sympy(ast.rhs)) + elif isinstance(ast, Array): + elements = ast.operands[:-1] + address = msdsl_ast_to_sympy(ast.operands[-1]) + return sp.Piecewise(*[(msdsl_ast_to_sympy(elem), address if i == 1 else sp.Not(address)) for i, elem in enumerate(elements)]) + elif isinstance(ast, Constant): + return ast.value + elif isinstance(ast, AnalogSignal) or isinstance(ast, DigitalSignal): + return sp.Symbol(str(ast)) + else: + raise Exception(f"Unsupported AST node: {type(ast)}") + def main(): # list2dict tests print(list2dict(['a', 'b', 'c'])) if __name__ == '__main__': - main() \ No newline at end of file + main() + From a2fdfea727c9c76188e04bb61a8b67336a8746a5 Mon Sep 17 00:00:00 2001 From: Will Date: Thu, 3 Oct 2024 13:17:34 -0700 Subject: [PATCH 3/3] change output array format --- msdsl/model.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/msdsl/model.py b/msdsl/model.py index 4f028ce..8ab3068 100644 --- a/msdsl/model.py +++ b/msdsl/model.py @@ -854,7 +854,7 @@ def add_eqn_sys(self, eqns: List[ModelExpr], extra_outputs=None, clk=None, rst=N self.add_discrete_time_lds(collection=collection, inputs=inputs, states=states, outputs=outputs, sel=sel, clk=clk, rst=rst) - + return collection, inputs, states, outputs, sel_bits @@ -895,7 +895,7 @@ def get_sympy_lds(self): sympy_lds = [] #self.assignments - + for circuit in self.circuits: filtered_states = list(filter(lambda x: not isinstance(x, DigitalInput), circuit.sel_bits)) state_str = list(map(lambda x: str(x), filtered_states)) @@ -905,7 +905,7 @@ def get_sympy_lds(self): circuit_lds = circuit.collection.convert_to_sympy_piecewise(circuit.states, circuit.inputs, circuit.outputs, circuit.sel_bits, sel_eqns) circuit.sympy_eqs = circuit_lds - sympy_lds.append(circuit_lds) + sympy_lds += circuit_lds for symbol_name, assignment_obj in self.unmodified_assignments.items(): sympy_lds.append(sp.Eq(sp.Symbol(symbol_name), msdsl_ast_to_sympy(assignment_obj.expr))) @@ -1067,8 +1067,10 @@ def make_circuit(self, clk=None, rst=None): def compile(self, gen: CodeGenerator): # compile circuits self.unmodified_assignments = deepcopy(self.assignments) + for circuit in self.circuits: eqns = circuit.compile_to_eqn_list() + ldscollection, inputs, states, outputs, sel_bits = self.add_eqn_sys(eqns, circuit.extra_outputs, clk=circuit.clk, rst=circuit.rst) circuit.collection = ldscollection #Assign the circuit.collection object so we may use it later. @@ -1076,8 +1078,8 @@ def compile(self, gen: CodeGenerator): circuit.states = states circuit.outputs = outputs circuit.sel_bits = sel_bits - - + + self.diffeqs = self.get_sympy_lds()