Skip to content

Commit

Permalink
Merge pull request #47 from cvxgrp/ms/ecos-update
Browse files Browse the repository at this point in the history
Fix parameter updates with ECOS
  • Loading branch information
maxschaller authored Apr 21, 2024
2 parents 1bf6277 + ac7320c commit aa4aa00
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 31 deletions.
2 changes: 1 addition & 1 deletion cvxpygen/cpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def generate_code(problem, code_dir='CPG_code', solver=None, solver_opts=None,
interface_class, cvxpy_interface_class = get_interface_class(solver_name)

# configuration
configuration = get_configuration(code_dir, solver, unroll, prefix)
configuration = get_configuration(code_dir, solver_name, unroll, prefix)

# cone problems check
if hasattr(param_prob, 'cone_dims'):
Expand Down
15 changes: 10 additions & 5 deletions cvxpygen/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,8 @@ def __init__(self, data, p_prob, enable_settings):
self.parameter_update_structure = {
'init': ParameterUpdateLogic(
update_pending_logic=UpdatePendingLogic([], extra_condition='!{prefix}ecos_workspace', functions_if_false=['AbcGh']),
function_call=f'{{prefix}}ecos_workspace = ECOS_setup({canon_constants["n"]}, {canon_constants["m"]}, {canon_constants["p"]}, {canon_constants["l"]}, {canon_constants["n_cones"]}'
function_call=f'{{prefix}}cpg_copy_all();\n'
f' {{prefix}}ecos_workspace = ECOS_setup({canon_constants["n"]}, {canon_constants["m"]}, {canon_constants["p"]}, {canon_constants["l"]}, {canon_constants["n_cones"]}'
f', {"0" if canon_constants["n_cones"] == 0 else "(int *) &{prefix}ecos_q"}, {canon_constants["e"]}'
f', {{prefix}}Canon_Params_conditioning.G->x, {{prefix}}Canon_Params_conditioning.G->p, {{prefix}}Canon_Params_conditioning.G->i'
f', {"0" if canon_constants["p"] == 0 else "{prefix}Canon_Params_conditioning.A->x"}'
Expand All @@ -720,16 +721,19 @@ def __init__(self, data, p_prob, enable_settings):
),
'AbcGh': ParameterUpdateLogic(
update_pending_logic=UpdatePendingLogic(['A', 'b', 'G'], '||', ['c', 'h']),
function_call=f'ECOS_updateData({{prefix}}ecos_workspace, {{prefix}}Canon_Params_conditioning.G->x, {"0" if canon_constants["p"] == 0 else "{prefix}Canon_Params_conditioning.A->x"}'
function_call=f'{{prefix}}cpg_copy_all();\n'
f' ECOS_updateData({{prefix}}ecos_workspace, {{prefix}}Canon_Params_conditioning.G->x, {"0" if canon_constants["p"] == 0 else "{prefix}Canon_Params_conditioning.A->x"}'
f', {{prefix}}Canon_Params_conditioning.c, {{prefix}}Canon_Params_conditioning.h, {"0" if canon_constants["p"] == 0 else "{prefix}Canon_Params_conditioning.b"})'
),
'c': ParameterUpdateLogic(
update_pending_logic=UpdatePendingLogic(['c']),
function_call=f'for (i=0; i<{canon_constants["n"]}; i++) {{{{ ecos_updateDataEntry_c({{prefix}}ecos_workspace, i, {{prefix}}Canon_Params_conditioning.c[i]); }}}}'
function_call=f'{{prefix}}cpg_copy_c();\n'
f' for (i=0; i<{canon_constants["n"]}; i++) {{{{ ecos_updateDataEntry_c({{prefix}}ecos_workspace, i, {{prefix}}Canon_Params_conditioning.c[i]); }}}}'
),
'h': ParameterUpdateLogic(
update_pending_logic=UpdatePendingLogic(['h']),
function_call=f'for (i=0; i<{canon_constants["m"]}; i++) {{{{ ecos_updateDataEntry_h({{prefix}}ecos_workspace, i, {{prefix}}Canon_Params_conditioning.h[i]); }}}}'
function_call=f'{{prefix}}cpg_copy_h();\n'
f' for (i=0; i<{canon_constants["m"]}; i++) {{{{ ecos_updateDataEntry_h({{prefix}}ecos_workspace, i, {{prefix}}Canon_Params_conditioning.h[i]); }}}}'
)
}

Expand Down Expand Up @@ -929,7 +933,8 @@ def __init__(self, data, p_prob, enable_settings):
'init': ParameterUpdateLogic(
update_pending_logic=UpdatePendingLogic([], extra_condition=extra_condition, functions_if_false=[]),
function_call= \
f'clarabel_CscMatrix_init(&{{prefix}}P, {canon_constants["n"]}, {canon_constants["n"]}, {P_p}, {P_i}, {P_x});\n'
f'{{prefix}}cpg_copy_all();\n'
f' clarabel_CscMatrix_init(&{{prefix}}P, {canon_constants["n"]}, {canon_constants["n"]}, {P_p}, {P_i}, {P_x});\n'
f' clarabel_CscMatrix_init(&{{prefix}}A, {canon_constants["m"]}, {canon_constants["n"]}, {{prefix}}Canon_Params_conditioning.A->p, {{prefix}}Canon_Params_conditioning.A->i, {{prefix}}Canon_Params_conditioning.A->x);\n' \
f' {{prefix}}settings = clarabel_DefaultSettings_default()'
)
Expand Down
61 changes: 39 additions & 22 deletions cvxpygen/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,7 @@ def write_workspace_prot(f, configuration, variable_info, dual_variable_info, pa
f.write('typedef struct {\n')
for p_id in parameter_canon.p.keys():
if p_id.isupper():
f.write(f' cpg_csc *{(p_id+";").ljust(8)} // Canonical parameter {p_id}\n')
f.write(f' cpg_csc *{(p_id+";").ljust(8)} // Canonical parameter {p_id}\n')
else:
if p_id == 'd':
s = ''
Expand All @@ -644,7 +644,7 @@ def write_workspace_prot(f, configuration, variable_info, dual_variable_info, pa
f.write('// Flags indicating outdated canonical parameters\n')
f.write('typedef struct {\n')
for p_id in parameter_canon.p.keys():
f.write(f' int {(p_id + ";").ljust(8)} // Bool, if canonical parameter {p_id} outdated\n')
f.write(f' int {(p_id + ";").ljust(10)} // Bool, if canonical parameter {p_id} outdated\n')
f.write('} Canon_Outdated_t;\n\n')

f.write('// Primal solution\n')
Expand Down Expand Up @@ -672,17 +672,17 @@ def write_workspace_prot(f, configuration, variable_info, dual_variable_info, pa
f.write('typedef struct {\n')
f.write(' cpg_float obj_val; // Objective function value\n')
f.write(' cpg_int iter; // Number of iterations\n')
f.write(f' {"cpg_int " if solver_interface.status_is_int else "char *"}status; // Solver status\n')
f.write(f' {"cpg_int status; " if solver_interface.status_is_int else "char *status; "}// Solver status\n')
f.write(' cpg_float pri_res; // Primal residual\n')
f.write(' cpg_float dua_res; // Dual residual\n')
f.write('} CPG_Info_t;\n\n')

f.write('// Solution and solver information\n')
f.write('typedef struct {\n')
f.write(' CPG_Prim_t *prim; // Primal solution\n')
f.write(' CPG_Prim_t *prim; // Primal solution\n')
if len(dual_variable_info.name_to_init) > 0:
f.write(' CPG_Dual_t *dual; // Dual solution\n')
f.write(' CPG_Info_t *info; // Solver info\n')
f.write(' CPG_Dual_t *dual; // Dual solution\n')
f.write(' CPG_Info_t *info; // Solver info\n')
f.write('} CPG_Result_t;\n\n')

f.write('// Solver settings\n')
Expand Down Expand Up @@ -845,18 +845,12 @@ def write_solve_def(f, configuration, variable_info, dual_variable_info, paramet
f.write(f' {configuration.prefix}CPG_Info.pri_res = {result_prefix}{solver_interface.ws_ptrs.primal_residual};\n')
f.write(f' {configuration.prefix}CPG_Info.dua_res = {result_prefix}{solver_interface.ws_ptrs.dual_residual};\n')
f.write('}\n\n')

f.write('// Solve via canonicalization, canonical solve, retrieval\n')
f.write(f'void {configuration.prefix}cpg_solve(){{\n')
f.write(' // Canonicalize if necessary\n')

for p_id, changes in parameter_canon.p_id_to_changes.items():
if changes:
f.write(f' if ({configuration.prefix}Canon_Outdated.{p_id}) {{\n')
f.write(f' {configuration.prefix}cpg_canonicalize_{p_id}();\n')
f.write(' }\n')
if solver_interface.inmemory_preconditioning:
size = parameter_canon.p_id_to_size[p_id]

if solver_interface.inmemory_preconditioning:
f.write('// Copy canonical parameters for preconditioning\n')
for p_id, size in parameter_canon.p_id_to_size.items():
if p_id != 'd':
f.write(f'void {configuration.prefix}cpg_copy_{p_id}(){{\n')
if size == 1:
f.write(f' {configuration.prefix}Canon_Params_conditioning.{p_id} = {configuration.prefix}Canon_Params.{p_id};\n')
elif size > 1:
Expand All @@ -866,6 +860,22 @@ def write_solve_def(f, configuration, variable_info, dual_variable_info, paramet
else:
f.write(f' {configuration.prefix}Canon_Params_conditioning.{p_id}[i] = {configuration.prefix}Canon_Params.{p_id}[i];\n')
f.write(' }\n')
f.write('}\n\n')
f.write(f'void {configuration.prefix}cpg_copy_all(){{\n')
for p_id in parameter_canon.p.keys():
if p_id != 'd':
f.write(f' {configuration.prefix}cpg_copy_{p_id}();\n')
f.write('}\n\n')

f.write('// Solve via canonicalization, canonical solve, retrieval\n')
f.write(f'void {configuration.prefix}cpg_solve(){{\n')
f.write(' // Canonicalize if necessary\n')

for p_id, changes in parameter_canon.p_id_to_changes.items():
if changes:
f.write(f' if ({configuration.prefix}Canon_Outdated.{p_id}) {{\n')
f.write(f' {configuration.prefix}cpg_canonicalize_{p_id}();\n')
f.write(' }\n')

pus = solver_interface.parameter_update_structure
write_update_structure(f, configuration, parameter_canon, pus, *analyze_pus(pus, parameter_canon.p_id_to_changes))
Expand Down Expand Up @@ -939,6 +949,13 @@ def write_solve_prot(f, configuration, variable_info, dual_variable_info, parame
f.write('\n// Retrieve solver information\n')
f.write(f'extern void {configuration.prefix}cpg_retrieve_info();\n')

if solver_interface.inmemory_preconditioning:
f.write('\n// Copy canonical parameters for preconditioning\n')
for p_id in parameter_canon.p_id_to_size.keys():
if p_id != 'd':
f.write(f'extern void {configuration.prefix}cpg_copy_{p_id}();\n')
f.write(f'extern void {configuration.prefix}cpg_copy_all();\n')

f.write('\n// Solve via canonicalization, canonical solve, retrieval\n')
f.write(f'extern void {configuration.prefix}cpg_solve();\n')

Expand Down Expand Up @@ -1414,7 +1431,7 @@ def replace_html_data(text, configuration, variable_info, dual_variable_info, pa
CPGINFOTYPEDEF += 'typedef struct {\n'
CPGINFOTYPEDEF += ' cpg_float obj_val; // Objective function value\n'
CPGINFOTYPEDEF += ' cpg_int iter; // Number of iterations\n'
CPGINFOTYPEDEF += (f' {"cpg_int " if solver_interface.status_is_int else "char *"}status; // Solver status\n')
CPGINFOTYPEDEF += (f' {"cpg_int status; " if solver_interface.status_is_int else "char *status; "}// Solver status\n')
CPGINFOTYPEDEF += ' cpg_float pri_res; // Primal residual\n'
CPGINFOTYPEDEF += ' cpg_float dua_res; // Dual residual\n'
CPGINFOTYPEDEF += '} CPG_Info_t;\n'
Expand All @@ -1423,10 +1440,10 @@ def replace_html_data(text, configuration, variable_info, dual_variable_info, pa
# type definition of CPG_Result_t
CPGRESULTTYPEDEF = '\n// Struct type with user-defined objective value and solution as fields\n'
CPGRESULTTYPEDEF += 'typedef struct {\n'
CPGRESULTTYPEDEF += ' CPG_Prim_t *prim; // Primal solution\n'
CPGRESULTTYPEDEF += ' CPG_Prim_t *prim; // Primal solution\n'
if len(dual_variable_info.name_to_init) > 0:
CPGRESULTTYPEDEF += ' CPG_Dual_t *dual; // Dual solution\n'
CPGRESULTTYPEDEF += ' CPG_Info_t *info; // Solver information\n'
CPGRESULTTYPEDEF += ' CPG_Dual_t *dual; // Dual solution\n'
CPGRESULTTYPEDEF += ' CPG_Info_t *info; // Solver information\n'
CPGRESULTTYPEDEF += '} CPG_Result_t;\n'
text = text.replace('$CPGRESULTTYPEDEF', CPGRESULTTYPEDEF)

Expand Down
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

MAJOR = 0
MINOR = 3
MICRO = 2
MICRO = 3
VERSION = '%d.%d.%d' % (MAJOR, MINOR, MICRO)


Expand Down Expand Up @@ -43,8 +43,8 @@ def readme():
'pybind11 >= 2.8',
'osqp >= 0.6.2, < 1.0.0',
'clarabel >= 0.6.0',
'scipy >= 1.1.0',
'numpy >= 1.15',
'scipy >= 1.1.0, <1.12.0',
'numpy >= 1.15, <1.28.0',
],
extras_require={
'dev': [
Expand Down

0 comments on commit aa4aa00

Please sign in to comment.