Skip to content

Commit

Permalink
fix degree search, module can now find the 8 rounds distinguisher of …
Browse files Browse the repository at this point in the history
…Aradi
  • Loading branch information
SiMohamedRachidi committed Nov 12, 2024
1 parent f82fba9 commit 2d0215a
Showing 1 changed file with 37 additions and 41 deletions.
78 changes: 37 additions & 41 deletions claasp/cipher_modules/division_trail_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class MilpDivisionTrailModel():
This module can only be used if the user possesses a Gurobi license.
"""

def __init__(self, cipher):
self._cipher = cipher
self._variables = None
Expand All @@ -44,13 +45,20 @@ def __init__(self, cipher):
self._used_variables = []
self._variables_as_list = []
self._unused_variables = []
self._used_predecessors_sorted = None
self._output_id = None
self._output_bit_index_previous_comp = None
self._block_needed = None
self._input_id_link_needed = None

def get_all_variables_as_list(self):
for component_id in list(self._variables.keys())[:-1]:
for bit_position in self._variables[component_id].keys():
for value in self._variables[component_id][bit_position].keys():
if value != "current":
self._variables_as_list.append(self._variables[component_id][bit_position][value].VarName)
varname = self._variables[component_id][bit_position][value].VarName
if varname not in self._variables_as_list: # rot and intermediate has the same name than original
self._variables_as_list.append(varname)

def get_unused_variables(self):
self.get_all_variables_as_list()
Expand Down Expand Up @@ -81,9 +89,7 @@ def build_gurobi_model(self):
model = Model("basic_model", env=env)
# model = Model()
model.Params.LogToConsole = 0
model.Params.Threads = 16 # best found experimentaly on ascon_sbox_2rounds
model.setParam("PoolSolutions", 1234) # 200000000
model.setParam(GRB.Param.PoolSearchMode, 2)
# model.Params.Threads = 16
self._model = model

def get_anfs_from_sbox(self, component):
Expand Down Expand Up @@ -217,10 +223,13 @@ def add_sbox_constraints(self, component):

def add_xor_constraints(self, component):
output_vars = self.get_output_vars(component)
# print(output_vars)

input_vars_concat = []
constant_flag = []
for index, input_name in enumerate(component.input_id_links):
# print(input_name)
# print(self._variables[input_name])
for pos in component.input_bit_positions[index]:
current = self._variables[input_name][pos]["current"]
if input_name[:8] == "constant":
Expand Down Expand Up @@ -372,10 +381,11 @@ def add_constraints(self, predecessors, input_id_link_needed, block_needed):
self.create_gurobi_vars_from_all_components(predecessors, input_id_link_needed, block_needed)

used_predecessors_sorted = self.order_predecessors(list(self._occurences.keys()))
self._used_predecessors_sorted = used_predecessors_sorted
for component_id in used_predecessors_sorted:
if component_id not in self._cipher.inputs:
component = self._cipher.get_component_from_id(component_id)
print(f"---------> {component.id}")
# print(f"---------> {component.id}")
if component.type == "sbox":
self.add_sbox_constraints(component)
elif component.type in ["cipher_output", "constant", "intermediate_output"]:
Expand Down Expand Up @@ -462,8 +472,8 @@ def create_gurobi_vars_from_all_components(self, predecessors, input_id_link_nee
occurences = self.get_where_component_is_used(predecessors, input_id_link_needed, block_needed)
all_vars = {}
used_predecessors_sorted = self.order_predecessors(list(occurences.keys()))
print("used_predecessors_sorted")
print(used_predecessors_sorted)
# print("used_predecessors_sorted")
# print(used_predecessors_sorted)
for component_id in used_predecessors_sorted:
all_vars[component_id] = {}
# We need the inputs vars to be the first ones defined by gurobi in order to find their values with X.values method.
Expand Down Expand Up @@ -571,20 +581,15 @@ def get_output_bit_index_previous_component(self, output_bit_index_ciphertext, c
block_needed = comp.input_bit_positions[index]
input_id_link_needed = chosen_cipher_output
output_bit_index_previous_comp = output_bit_index_ciphertext
print(output_id)
print(block_needed)
print(input_id_link_needed)
print(output_bit_index_previous_comp)
return output_id, output_bit_index_previous_comp, block_needed, input_id_link_needed, pivot
else:
output_id = self.get_cipher_output_component_id()
# output_id = "xor_1_69"
component = self._cipher.get_component_from_id(output_id)
pivot = 0
output_bit_index_previous_comp = output_bit_index_ciphertext
for index, block in enumerate(component.input_bit_positions):
if pivot <= output_bit_index_ciphertext < pivot + len(block):
output_bit_index_previous_comp = output_bit_index_ciphertext - pivot
output_bit_index_previous_comp = block[output_bit_index_ciphertext - pivot]
block_needed = block
input_id_link_needed = component.input_id_links[index]
break
Expand All @@ -608,17 +613,21 @@ def build_generic_model_for_specific_output_bit(self, output_bit_index_ciphertex
start = time.time()
output_id, output_bit_index_previous_comp, block_needed, input_id_link_needed, pivot = self.get_output_bit_index_previous_component(
output_bit_index_ciphertext, chosen_cipher_output)
# print(output_id)
# print(block_needed)
# print(input_id_link_needed)
# print(output_bit_index_previous_comp)
self._output_id = output_id
self._output_bit_index_previous_comp = output_bit_index_previous_comp
self._block_needed = block_needed
self._input_id_link_needed = input_id_link_needed

G = create_networkx_graph_from_input_ids(self._cipher)
predecessors = list(_get_predecessors_subgraph(G, [input_id_link_needed]))
for input_id in self._cipher.inputs + ['']:
if input_id in predecessors:
predecessors.remove(input_id)

# print("input_id_link_needed")
# print(input_id_link_needed)
# print("predecessors")
# print(predecessors)
self.add_constraints(predecessors, input_id_link_needed, block_needed)

var_from_block_needed = []
Expand Down Expand Up @@ -654,7 +663,7 @@ def build_generic_model_for_specific_output_bit(self, output_bit_index_ciphertex

self.set_unused_variables_to_zero()
self._model.update()
self._model.write("division_trail_model.lp")
# self._model.write("division_trail_model.lp")
end = time.time()
building_time = end - start
print(f"########## building_time : {building_time}")
Expand Down Expand Up @@ -695,6 +704,8 @@ def get_solutions(self):
else:
if index < len(list(self._occurences[self._cipher.inputs[0]].keys())):
tmp += self._cipher.inputs[0][0] + str(first_input_bit_positions[index])
if 1 not in values[:max_input_bit_pos]:
tmp += str(1)
if tmp in monomials:
monomials.remove(tmp)
else:
Expand All @@ -716,14 +727,8 @@ def optimize_model(self):

def find_anf_of_specific_output_bit(self, output_bit_index, fixed_degree=None, chosen_cipher_output=None):
self.build_generic_model_for_specific_output_bit(output_bit_index, fixed_degree, chosen_cipher_output)

# # Specific to Aradi analysis:
# for i in range(96):
# v = self._model.getVarByName(f"plaintext[{i}]")
# self._model.addConstr(v == 0)
# self._model.update()
# self._model.write("division_trail_model.lp")
# ########################
self._model.setParam("PoolSolutions", 200000000) # 200000000 to be large
self._model.setParam(GRB.Param.PoolSearchMode, 2)

self.optimize_model()
self.get_solutions()
Expand Down Expand Up @@ -753,9 +758,9 @@ def find_degree_of_specific_output_bit(self, output_bit_index, chosen_cipher_out
self.build_generic_model_for_specific_output_bit(output_bit_index, fixed_degree, chosen_cipher_output)
self._model.setParam(GRB.Param.PoolSearchMode, 1)
self._model.setParam('Presolve', 2)
self._model.setParam('MIPFocus', 3)
# self._model.setParam('Cuts', 2)
self._model.setParam('NodefileStart', 2.0)
self._model.setParam("MIPFocus", 2)
self._model.setParam("MIPGap", 0) # when set to 0, best solution = optimal solution
self._model.setParam('Cuts', 2)

index_plaintext = self._cipher.inputs.index("plaintext")
plaintext_bit_size = self._cipher.inputs_bit_size[index_plaintext]
Expand All @@ -765,19 +770,10 @@ def find_degree_of_specific_output_bit(self, output_bit_index, chosen_cipher_out
p.append(self._model.getVarByName(f"plaintext[{i}]"))
self._model.setObjective(sum(p[i] for i in range(nb_plaintext_bits_used)), GRB.MAXIMIZE)

## Specific to Aradi analysis:
# for i in range(128):
# v = self._model.getVarByName(f"plaintext[{i}]")
# if 0 <= i < 128: # free vars
# self._model.addConstr(v >= 0)
# else:
# self._model.addConstr(v == 0)
# self._model.update()
# self._model.write("division_trail_model.lp")
#######################

self._model.update()
self._model.write("division_trail_model.lp")
self.optimize_model()
# get degree

degree = self._model.getObjective().getValue()
return degree

Expand Down

0 comments on commit 2d0215a

Please sign in to comment.