From 2d0215a6a54993861e02f8e641bd4b605025634d Mon Sep 17 00:00:00 2001 From: Mohamed Rachidi Date: Tue, 12 Nov 2024 13:33:11 +0400 Subject: [PATCH] fix degree search, module can now find the 8 rounds distinguisher of Aradi --- .../cipher_modules/division_trail_search.py | 78 +++++++++---------- 1 file changed, 37 insertions(+), 41 deletions(-) diff --git a/claasp/cipher_modules/division_trail_search.py b/claasp/cipher_modules/division_trail_search.py index 827d341d..88dc5aa8 100644 --- a/claasp/cipher_modules/division_trail_search.py +++ b/claasp/cipher_modules/division_trail_search.py @@ -36,6 +36,7 @@ class MilpDivisionTrailModel(): This module can only be used if the user possesses a Gurobi license. """ + def __init__(self, cipher): self._cipher = cipher self._variables = None @@ -44,13 +45,20 @@ def __init__(self, cipher): self._used_variables = [] self._variables_as_list = [] self._unused_variables = [] + self._used_predecessors_sorted = None + self._output_id = None + self._output_bit_index_previous_comp = None + self._block_needed = None + self._input_id_link_needed = None def get_all_variables_as_list(self): for component_id in list(self._variables.keys())[:-1]: for bit_position in self._variables[component_id].keys(): for value in self._variables[component_id][bit_position].keys(): if value != "current": - self._variables_as_list.append(self._variables[component_id][bit_position][value].VarName) + varname = self._variables[component_id][bit_position][value].VarName + if varname not in self._variables_as_list: # rot and intermediate has the same name than original + self._variables_as_list.append(varname) def get_unused_variables(self): self.get_all_variables_as_list() @@ -81,9 +89,7 @@ def build_gurobi_model(self): model = Model("basic_model", env=env) # model = Model() model.Params.LogToConsole = 0 - model.Params.Threads = 16 # best found experimentaly on ascon_sbox_2rounds - model.setParam("PoolSolutions", 1234) # 200000000 - model.setParam(GRB.Param.PoolSearchMode, 2) + # model.Params.Threads = 16 self._model = model def get_anfs_from_sbox(self, component): @@ -217,10 +223,13 @@ def add_sbox_constraints(self, component): def add_xor_constraints(self, component): output_vars = self.get_output_vars(component) + # print(output_vars) input_vars_concat = [] constant_flag = [] for index, input_name in enumerate(component.input_id_links): + # print(input_name) + # print(self._variables[input_name]) for pos in component.input_bit_positions[index]: current = self._variables[input_name][pos]["current"] if input_name[:8] == "constant": @@ -372,10 +381,11 @@ def add_constraints(self, predecessors, input_id_link_needed, block_needed): self.create_gurobi_vars_from_all_components(predecessors, input_id_link_needed, block_needed) used_predecessors_sorted = self.order_predecessors(list(self._occurences.keys())) + self._used_predecessors_sorted = used_predecessors_sorted for component_id in used_predecessors_sorted: if component_id not in self._cipher.inputs: component = self._cipher.get_component_from_id(component_id) - print(f"---------> {component.id}") + # print(f"---------> {component.id}") if component.type == "sbox": self.add_sbox_constraints(component) elif component.type in ["cipher_output", "constant", "intermediate_output"]: @@ -462,8 +472,8 @@ def create_gurobi_vars_from_all_components(self, predecessors, input_id_link_nee occurences = self.get_where_component_is_used(predecessors, input_id_link_needed, block_needed) all_vars = {} used_predecessors_sorted = self.order_predecessors(list(occurences.keys())) - print("used_predecessors_sorted") - print(used_predecessors_sorted) + # print("used_predecessors_sorted") + # print(used_predecessors_sorted) for component_id in used_predecessors_sorted: all_vars[component_id] = {} # We need the inputs vars to be the first ones defined by gurobi in order to find their values with X.values method. @@ -571,20 +581,15 @@ def get_output_bit_index_previous_component(self, output_bit_index_ciphertext, c block_needed = comp.input_bit_positions[index] input_id_link_needed = chosen_cipher_output output_bit_index_previous_comp = output_bit_index_ciphertext - print(output_id) - print(block_needed) - print(input_id_link_needed) - print(output_bit_index_previous_comp) return output_id, output_bit_index_previous_comp, block_needed, input_id_link_needed, pivot else: output_id = self.get_cipher_output_component_id() - # output_id = "xor_1_69" component = self._cipher.get_component_from_id(output_id) pivot = 0 output_bit_index_previous_comp = output_bit_index_ciphertext for index, block in enumerate(component.input_bit_positions): if pivot <= output_bit_index_ciphertext < pivot + len(block): - output_bit_index_previous_comp = output_bit_index_ciphertext - pivot + output_bit_index_previous_comp = block[output_bit_index_ciphertext - pivot] block_needed = block input_id_link_needed = component.input_id_links[index] break @@ -608,6 +613,14 @@ def build_generic_model_for_specific_output_bit(self, output_bit_index_ciphertex start = time.time() output_id, output_bit_index_previous_comp, block_needed, input_id_link_needed, pivot = self.get_output_bit_index_previous_component( output_bit_index_ciphertext, chosen_cipher_output) + # print(output_id) + # print(block_needed) + # print(input_id_link_needed) + # print(output_bit_index_previous_comp) + self._output_id = output_id + self._output_bit_index_previous_comp = output_bit_index_previous_comp + self._block_needed = block_needed + self._input_id_link_needed = input_id_link_needed G = create_networkx_graph_from_input_ids(self._cipher) predecessors = list(_get_predecessors_subgraph(G, [input_id_link_needed])) @@ -615,10 +628,6 @@ def build_generic_model_for_specific_output_bit(self, output_bit_index_ciphertex if input_id in predecessors: predecessors.remove(input_id) - # print("input_id_link_needed") - # print(input_id_link_needed) - # print("predecessors") - # print(predecessors) self.add_constraints(predecessors, input_id_link_needed, block_needed) var_from_block_needed = [] @@ -654,7 +663,7 @@ def build_generic_model_for_specific_output_bit(self, output_bit_index_ciphertex self.set_unused_variables_to_zero() self._model.update() - self._model.write("division_trail_model.lp") + # self._model.write("division_trail_model.lp") end = time.time() building_time = end - start print(f"########## building_time : {building_time}") @@ -695,6 +704,8 @@ def get_solutions(self): else: if index < len(list(self._occurences[self._cipher.inputs[0]].keys())): tmp += self._cipher.inputs[0][0] + str(first_input_bit_positions[index]) + if 1 not in values[:max_input_bit_pos]: + tmp += str(1) if tmp in monomials: monomials.remove(tmp) else: @@ -716,14 +727,8 @@ def optimize_model(self): def find_anf_of_specific_output_bit(self, output_bit_index, fixed_degree=None, chosen_cipher_output=None): self.build_generic_model_for_specific_output_bit(output_bit_index, fixed_degree, chosen_cipher_output) - - # # Specific to Aradi analysis: - # for i in range(96): - # v = self._model.getVarByName(f"plaintext[{i}]") - # self._model.addConstr(v == 0) - # self._model.update() - # self._model.write("division_trail_model.lp") - # ######################## + self._model.setParam("PoolSolutions", 200000000) # 200000000 to be large + self._model.setParam(GRB.Param.PoolSearchMode, 2) self.optimize_model() self.get_solutions() @@ -753,9 +758,9 @@ def find_degree_of_specific_output_bit(self, output_bit_index, chosen_cipher_out self.build_generic_model_for_specific_output_bit(output_bit_index, fixed_degree, chosen_cipher_output) self._model.setParam(GRB.Param.PoolSearchMode, 1) self._model.setParam('Presolve', 2) - self._model.setParam('MIPFocus', 3) - # self._model.setParam('Cuts', 2) - self._model.setParam('NodefileStart', 2.0) + self._model.setParam("MIPFocus", 2) + self._model.setParam("MIPGap", 0) # when set to 0, best solution = optimal solution + self._model.setParam('Cuts', 2) index_plaintext = self._cipher.inputs.index("plaintext") plaintext_bit_size = self._cipher.inputs_bit_size[index_plaintext] @@ -765,19 +770,10 @@ def find_degree_of_specific_output_bit(self, output_bit_index, chosen_cipher_out p.append(self._model.getVarByName(f"plaintext[{i}]")) self._model.setObjective(sum(p[i] for i in range(nb_plaintext_bits_used)), GRB.MAXIMIZE) - ## Specific to Aradi analysis: - # for i in range(128): - # v = self._model.getVarByName(f"plaintext[{i}]") - # if 0 <= i < 128: # free vars - # self._model.addConstr(v >= 0) - # else: - # self._model.addConstr(v == 0) - # self._model.update() - # self._model.write("division_trail_model.lp") - ####################### - + self._model.update() + self._model.write("division_trail_model.lp") self.optimize_model() - # get degree + degree = self._model.getObjective().getValue() return degree