Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix parameter updates with ECOS #47

Merged
merged 7 commits into from
Apr 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cvxpygen/cpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def generate_code(problem, code_dir='CPG_code', solver=None, solver_opts=None,
interface_class, cvxpy_interface_class = get_interface_class(solver_name)

# configuration
configuration = get_configuration(code_dir, solver, unroll, prefix)
configuration = get_configuration(code_dir, solver_name, unroll, prefix)

# cone problems check
if hasattr(param_prob, 'cone_dims'):
Expand Down
15 changes: 10 additions & 5 deletions cvxpygen/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,8 @@ def __init__(self, data, p_prob, enable_settings):
self.parameter_update_structure = {
'init': ParameterUpdateLogic(
update_pending_logic=UpdatePendingLogic([], extra_condition='!{prefix}ecos_workspace', functions_if_false=['AbcGh']),
function_call=f'{{prefix}}ecos_workspace = ECOS_setup({canon_constants["n"]}, {canon_constants["m"]}, {canon_constants["p"]}, {canon_constants["l"]}, {canon_constants["n_cones"]}'
function_call=f'{{prefix}}cpg_copy_all();\n'
f' {{prefix}}ecos_workspace = ECOS_setup({canon_constants["n"]}, {canon_constants["m"]}, {canon_constants["p"]}, {canon_constants["l"]}, {canon_constants["n_cones"]}'
f', {"0" if canon_constants["n_cones"] == 0 else "(int *) &{prefix}ecos_q"}, {canon_constants["e"]}'
f', {{prefix}}Canon_Params_conditioning.G->x, {{prefix}}Canon_Params_conditioning.G->p, {{prefix}}Canon_Params_conditioning.G->i'
f', {"0" if canon_constants["p"] == 0 else "{prefix}Canon_Params_conditioning.A->x"}'
Expand All @@ -720,16 +721,19 @@ def __init__(self, data, p_prob, enable_settings):
),
'AbcGh': ParameterUpdateLogic(
update_pending_logic=UpdatePendingLogic(['A', 'b', 'G'], '||', ['c', 'h']),
function_call=f'ECOS_updateData({{prefix}}ecos_workspace, {{prefix}}Canon_Params_conditioning.G->x, {"0" if canon_constants["p"] == 0 else "{prefix}Canon_Params_conditioning.A->x"}'
function_call=f'{{prefix}}cpg_copy_all();\n'
f' ECOS_updateData({{prefix}}ecos_workspace, {{prefix}}Canon_Params_conditioning.G->x, {"0" if canon_constants["p"] == 0 else "{prefix}Canon_Params_conditioning.A->x"}'
f', {{prefix}}Canon_Params_conditioning.c, {{prefix}}Canon_Params_conditioning.h, {"0" if canon_constants["p"] == 0 else "{prefix}Canon_Params_conditioning.b"})'
),
'c': ParameterUpdateLogic(
update_pending_logic=UpdatePendingLogic(['c']),
function_call=f'for (i=0; i<{canon_constants["n"]}; i++) {{{{ ecos_updateDataEntry_c({{prefix}}ecos_workspace, i, {{prefix}}Canon_Params_conditioning.c[i]); }}}}'
function_call=f'{{prefix}}cpg_copy_c();\n'
f' for (i=0; i<{canon_constants["n"]}; i++) {{{{ ecos_updateDataEntry_c({{prefix}}ecos_workspace, i, {{prefix}}Canon_Params_conditioning.c[i]); }}}}'
),
'h': ParameterUpdateLogic(
update_pending_logic=UpdatePendingLogic(['h']),
function_call=f'for (i=0; i<{canon_constants["m"]}; i++) {{{{ ecos_updateDataEntry_h({{prefix}}ecos_workspace, i, {{prefix}}Canon_Params_conditioning.h[i]); }}}}'
function_call=f'{{prefix}}cpg_copy_h();\n'
f' for (i=0; i<{canon_constants["m"]}; i++) {{{{ ecos_updateDataEntry_h({{prefix}}ecos_workspace, i, {{prefix}}Canon_Params_conditioning.h[i]); }}}}'
)
}

Expand Down Expand Up @@ -929,7 +933,8 @@ def __init__(self, data, p_prob, enable_settings):
'init': ParameterUpdateLogic(
update_pending_logic=UpdatePendingLogic([], extra_condition=extra_condition, functions_if_false=[]),
function_call= \
f'clarabel_CscMatrix_init(&{{prefix}}P, {canon_constants["n"]}, {canon_constants["n"]}, {P_p}, {P_i}, {P_x});\n'
f'{{prefix}}cpg_copy_all();\n'
f' clarabel_CscMatrix_init(&{{prefix}}P, {canon_constants["n"]}, {canon_constants["n"]}, {P_p}, {P_i}, {P_x});\n'
f' clarabel_CscMatrix_init(&{{prefix}}A, {canon_constants["m"]}, {canon_constants["n"]}, {{prefix}}Canon_Params_conditioning.A->p, {{prefix}}Canon_Params_conditioning.A->i, {{prefix}}Canon_Params_conditioning.A->x);\n' \
f' {{prefix}}settings = clarabel_DefaultSettings_default()'
)
Expand Down
61 changes: 39 additions & 22 deletions cvxpygen/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,7 @@ def write_workspace_prot(f, configuration, variable_info, dual_variable_info, pa
f.write('typedef struct {\n')
for p_id in parameter_canon.p.keys():
if p_id.isupper():
f.write(f' cpg_csc *{(p_id+";").ljust(8)} // Canonical parameter {p_id}\n')
f.write(f' cpg_csc *{(p_id+";").ljust(8)} // Canonical parameter {p_id}\n')
else:
if p_id == 'd':
s = ''
Expand All @@ -644,7 +644,7 @@ def write_workspace_prot(f, configuration, variable_info, dual_variable_info, pa
f.write('// Flags indicating outdated canonical parameters\n')
f.write('typedef struct {\n')
for p_id in parameter_canon.p.keys():
f.write(f' int {(p_id + ";").ljust(8)} // Bool, if canonical parameter {p_id} outdated\n')
f.write(f' int {(p_id + ";").ljust(10)} // Bool, if canonical parameter {p_id} outdated\n')
f.write('} Canon_Outdated_t;\n\n')

f.write('// Primal solution\n')
Expand Down Expand Up @@ -672,17 +672,17 @@ def write_workspace_prot(f, configuration, variable_info, dual_variable_info, pa
f.write('typedef struct {\n')
f.write(' cpg_float obj_val; // Objective function value\n')
f.write(' cpg_int iter; // Number of iterations\n')
f.write(f' {"cpg_int " if solver_interface.status_is_int else "char *"}status; // Solver status\n')
f.write(f' {"cpg_int status; " if solver_interface.status_is_int else "char *status; "}// Solver status\n')
f.write(' cpg_float pri_res; // Primal residual\n')
f.write(' cpg_float dua_res; // Dual residual\n')
f.write('} CPG_Info_t;\n\n')

f.write('// Solution and solver information\n')
f.write('typedef struct {\n')
f.write(' CPG_Prim_t *prim; // Primal solution\n')
f.write(' CPG_Prim_t *prim; // Primal solution\n')
if len(dual_variable_info.name_to_init) > 0:
f.write(' CPG_Dual_t *dual; // Dual solution\n')
f.write(' CPG_Info_t *info; // Solver info\n')
f.write(' CPG_Dual_t *dual; // Dual solution\n')
f.write(' CPG_Info_t *info; // Solver info\n')
f.write('} CPG_Result_t;\n\n')

f.write('// Solver settings\n')
Expand Down Expand Up @@ -845,18 +845,12 @@ def write_solve_def(f, configuration, variable_info, dual_variable_info, paramet
f.write(f' {configuration.prefix}CPG_Info.pri_res = {result_prefix}{solver_interface.ws_ptrs.primal_residual};\n')
f.write(f' {configuration.prefix}CPG_Info.dua_res = {result_prefix}{solver_interface.ws_ptrs.dual_residual};\n')
f.write('}\n\n')

f.write('// Solve via canonicalization, canonical solve, retrieval\n')
f.write(f'void {configuration.prefix}cpg_solve(){{\n')
f.write(' // Canonicalize if necessary\n')

for p_id, changes in parameter_canon.p_id_to_changes.items():
if changes:
f.write(f' if ({configuration.prefix}Canon_Outdated.{p_id}) {{\n')
f.write(f' {configuration.prefix}cpg_canonicalize_{p_id}();\n')
f.write(' }\n')
if solver_interface.inmemory_preconditioning:
size = parameter_canon.p_id_to_size[p_id]

if solver_interface.inmemory_preconditioning:
f.write('// Copy canonical parameters for preconditioning\n')
for p_id, size in parameter_canon.p_id_to_size.items():
if p_id != 'd':
f.write(f'void {configuration.prefix}cpg_copy_{p_id}(){{\n')
if size == 1:
f.write(f' {configuration.prefix}Canon_Params_conditioning.{p_id} = {configuration.prefix}Canon_Params.{p_id};\n')
elif size > 1:
Expand All @@ -866,6 +860,22 @@ def write_solve_def(f, configuration, variable_info, dual_variable_info, paramet
else:
f.write(f' {configuration.prefix}Canon_Params_conditioning.{p_id}[i] = {configuration.prefix}Canon_Params.{p_id}[i];\n')
f.write(' }\n')
f.write('}\n\n')
f.write(f'void {configuration.prefix}cpg_copy_all(){{\n')
for p_id in parameter_canon.p.keys():
if p_id != 'd':
f.write(f' {configuration.prefix}cpg_copy_{p_id}();\n')
f.write('}\n\n')

f.write('// Solve via canonicalization, canonical solve, retrieval\n')
f.write(f'void {configuration.prefix}cpg_solve(){{\n')
f.write(' // Canonicalize if necessary\n')

for p_id, changes in parameter_canon.p_id_to_changes.items():
if changes:
f.write(f' if ({configuration.prefix}Canon_Outdated.{p_id}) {{\n')
f.write(f' {configuration.prefix}cpg_canonicalize_{p_id}();\n')
f.write(' }\n')

pus = solver_interface.parameter_update_structure
write_update_structure(f, configuration, parameter_canon, pus, *analyze_pus(pus, parameter_canon.p_id_to_changes))
Expand Down Expand Up @@ -939,6 +949,13 @@ def write_solve_prot(f, configuration, variable_info, dual_variable_info, parame
f.write('\n// Retrieve solver information\n')
f.write(f'extern void {configuration.prefix}cpg_retrieve_info();\n')

if solver_interface.inmemory_preconditioning:
f.write('\n// Copy canonical parameters for preconditioning\n')
for p_id in parameter_canon.p_id_to_size.keys():
if p_id != 'd':
f.write(f'extern void {configuration.prefix}cpg_copy_{p_id}();\n')
f.write(f'extern void {configuration.prefix}cpg_copy_all();\n')

f.write('\n// Solve via canonicalization, canonical solve, retrieval\n')
f.write(f'extern void {configuration.prefix}cpg_solve();\n')

Expand Down Expand Up @@ -1414,7 +1431,7 @@ def replace_html_data(text, configuration, variable_info, dual_variable_info, pa
CPGINFOTYPEDEF += 'typedef struct {\n'
CPGINFOTYPEDEF += ' cpg_float obj_val; // Objective function value\n'
CPGINFOTYPEDEF += ' cpg_int iter; // Number of iterations\n'
CPGINFOTYPEDEF += (f' {"cpg_int " if solver_interface.status_is_int else "char *"}status; // Solver status\n')
CPGINFOTYPEDEF += (f' {"cpg_int status; " if solver_interface.status_is_int else "char *status; "}// Solver status\n')
CPGINFOTYPEDEF += ' cpg_float pri_res; // Primal residual\n'
CPGINFOTYPEDEF += ' cpg_float dua_res; // Dual residual\n'
CPGINFOTYPEDEF += '} CPG_Info_t;\n'
Expand All @@ -1423,10 +1440,10 @@ def replace_html_data(text, configuration, variable_info, dual_variable_info, pa
# type definition of CPG_Result_t
CPGRESULTTYPEDEF = '\n// Struct type with user-defined objective value and solution as fields\n'
CPGRESULTTYPEDEF += 'typedef struct {\n'
CPGRESULTTYPEDEF += ' CPG_Prim_t *prim; // Primal solution\n'
CPGRESULTTYPEDEF += ' CPG_Prim_t *prim; // Primal solution\n'
if len(dual_variable_info.name_to_init) > 0:
CPGRESULTTYPEDEF += ' CPG_Dual_t *dual; // Dual solution\n'
CPGRESULTTYPEDEF += ' CPG_Info_t *info; // Solver information\n'
CPGRESULTTYPEDEF += ' CPG_Dual_t *dual; // Dual solution\n'
CPGRESULTTYPEDEF += ' CPG_Info_t *info; // Solver information\n'
CPGRESULTTYPEDEF += '} CPG_Result_t;\n'
text = text.replace('$CPGRESULTTYPEDEF', CPGRESULTTYPEDEF)

Expand Down
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

MAJOR = 0
MINOR = 3
MICRO = 2
MICRO = 3
VERSION = '%d.%d.%d' % (MAJOR, MINOR, MICRO)


Expand Down Expand Up @@ -43,8 +43,8 @@ def readme():
'pybind11 >= 2.8',
'osqp >= 0.6.2, < 1.0.0',
'clarabel >= 0.6.0',
'scipy >= 1.1.0',
'numpy >= 1.15',
'scipy >= 1.1.0, <1.12.0',
'numpy >= 1.15, <1.28.0',
],
extras_require={
'dev': [
Expand Down
Loading