Skip to content

Commit

Permalink
Fix compilation errros
Browse files Browse the repository at this point in the history
  • Loading branch information
pnbabu committed Jan 29, 2024
1 parent a6717ab commit 5cd6533
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 3 deletions.
23 changes: 23 additions & 0 deletions pynestml/codegeneration/nest_gpu_code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def replace_text_between_tags(filepath, replace_str, begin_tag="// <<BEGIN_NESTM
file_str = file_str[:start_pos] + replace_str + file_str[end_pos:]
with open(filepath, "w") as f:
f.write(file_str)
f.close()


class NESTGPUCodeGenerator(NESTCodeGenerator):
Expand Down Expand Up @@ -133,6 +134,9 @@ def generate_module_code(self, neurons: Sequence[ASTNeuron], synapses: Sequence[
self.add_model_name_to_neuron_header(neuron)
self.add_model_to_neuron_class(neuron)
self.add_files_to_makefile(neuron)
if neuron.get_name() in self.numeric_solver.keys() \
and self.numeric_solver[neuron.get_name()] is not None:
self.add_model_header_to_rk5_interface(neuron)

def copy_models_from_target_path(self, neuron: ASTNeuron):
"""Copies all the files related to the neuron model to the NEST GPU src directory"""
Expand Down Expand Up @@ -190,3 +194,22 @@ def add_files_to_makefile(self, neuron: ASTNeuron):
replace_text_between_tags(cmakelists_path, code_block,
begin_tag="# <<BEGIN_NESTML_GENERATED>>",
end_tag="# <<END_NESTML_GENERATED>>")

def add_model_header_to_rk5_interface(self, neuron: ASTNeuron):
"""
Modifies the rk5_interface.h header file to add the model rk5 header file. This is only for
neuron models with a numeric solver.
"""
rk5_interface_path = str(os.path.join(self.nest_gpu_path, "src", "rk5_interface.h"))
shutil.copy(rk5_interface_path, rk5_interface_path + ".bak")

code_block = f"#include \"{neuron.get_name()}_rk5.h\""

replace_text_between_tags(rk5_interface_path, code_block)

# with open(rk5_interface_path, "a+") as f:
# lines = f.readlines()
# lines.append(f"#include \"{neuron.get_name()}_rk5.h\"")
# f.writelines(lines)

# f.close()
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ __global__ void {{ neuronName }}_Calibrate(int n_node, float *param_arr,
{%- for internals_block in neuron.get_internals_blocks() %}
{%- for decl in internals_block.get_declarations() %}
{%- for variable in decl.get_variables() %}
{%- if variable.get_name() != "__h" %}
{%- include "directives/MemberInitialization.jinja2" %}
{%- endif %}
{%- endfor %}
{%- endfor %}
{%- endfor %}
Expand Down Expand Up @@ -124,7 +126,7 @@ void NodeCalibrate(int n_var, int n_param, double x, float *y,
{%- for internals_block in neuron.get_internals_blocks() %}
{%- for decl in internals_block.get_declarations() %}
{%- for variable in decl.get_variables() %}
{%- if variable.get_name() != "h" %}
{%- if variable.get_name() != "__h" %}
{%- include "directives/MemberInitialization.jinja2" %}
{%- endif %}
{%- endfor %}
Expand All @@ -151,6 +153,8 @@ void NodeCalibrate(int n_var, int n_param, double x, float *y,
}
{%- endif %}

using namespace {{ neuronName }}_ns;

int {{ neuronName }}::Init(int i_node_0, int n_node, int /*n_port*/,
int i_group, unsigned long long *seed)
{
Expand All @@ -166,8 +170,10 @@ int {{ neuronName }}::Init(int i_node_0, int n_node, int /*n_port*/,
group_param_ = new float[N_GROUP_PARAM];
{%- endif %}

{%- if not uses_numeric_solver %}
AllocParamArr();
AllocVarArr();
{%- endif %}

scal_var_name_ = {{ neuronName }}_scal_var_name;
scal_param_name_ = {{ neuronName }}_scal_param_name;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ __device__
void Derivatives(double x, float *y, float *dydx, float *param,
{{ neuronName }}_rk5 data_struct)
{
float I_syn_tot = 0.0;
{# float I_syn_tot = 0.0;
I_syn_tot += I_syn_ex - I_syn_in;
float V = ( refractory_step > 0 ) ? V_reset : MIN(V_m, V_peak);
Expand All @@ -52,7 +52,7 @@ __device__
// Adaptation current w.
dwdt = (a*(V - E_L) - w) / tau_w;
dI_syn_exdt = -I_syn_ex / tau_syn_ex;
dI_syn_indt = -I_syn_in / tau_syn_in;
dI_syn_indt = -I_syn_in / tau_syn_in; #}
}

template<int NVAR, int NPARAM> //, class DataStruct>
Expand Down

0 comments on commit 5cd6533

Please sign in to comment.