Modify templates
pnbabu committed Jan 26, 2024
1 parent 48b72a4 commit a6717ab
Showing 5 changed files with 45 additions and 25 deletions.
@@ -63,7 +63,7 @@ def print_variable(self, variable: ASTVariable) -> str:
return "((post_neuron_t*)(__target))->get_" + _name + "(_tr_t)"

if variable.get_name() == PredefinedVariables.E_CONSTANT:
return "numerics::e"
return "M_E"

symbol = variable.get_scope().resolve_to_symbol(variable.get_complete_name(), SymbolKind.VARIABLE)

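Aside (not part of the commit): M_E is the Euler's-number macro from the C math header, which the printer now emits in place of numerics::e. A minimal sketch, assuming a POSIX/GNU toolchain where <cmath> exposes M_E:

#include <cmath>    // defines M_E (Euler's number e) on POSIX/GNU toolchains
#include <cstdio>

int main()
{
  std::printf("M_E      = %.17g\n", M_E);            // 2.718281828459045...
  std::printf("exp(1.0) = %.17g\n", std::exp(1.0));  // same value, computed at run time
  return 0;
}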
@@ -24,11 +24,15 @@
#include <cmath>
#include <iostream>
#include "{{ neuronName }}.h"
{%- if not uses_numeric_solver %}
#include "{{ neuronName }}_kernel.h"
#include "rk5.h"
{%- endif %}
#include "spike_buffer.h"

{%- import 'directives/SetScalParamAndVar.jinja2' as set_scal_param_var with context %}

{%- if uses_analytic_solver %}
{%- if not uses_numeric_solver %}
using namespace {{ neuronName }}_ns;

__global__ void {{ neuronName }}_Calibrate(int n_node, float *param_arr,
@@ -74,13 +78,14 @@ __global__ void {{ neuronName }}_Update(int n_node, int i_node_0, float *var_arr
{%- endif %}
}
}
{%- endif %}


{{ neuronName }}::~{{ neuronName }}()
{
FreeVarArr();
FreeParamArr();
}
{%- endif %}

{%- if uses_numeric_solver %}
namespace {{neuronName}}_ns
@@ -97,8 +102,10 @@ void NodeInit(int n_var, int n_param, double x, float *y, float *param,

// Internal variables
{%- for variable_symbol in neuron.get_internal_symbols() %}
{%- if variable_symbol.get_symbol_name() != "__h" %}
{%- set variable = utils.get_internal_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
{%- include "directives/MemberInitialization.jinja2" %}
{%- endif %}
{%- endfor %}

// State variables
@@ -112,12 +119,14 @@ __device__
void NodeCalibrate(int n_var, int n_param, double x, float *y,
float *param, {{neuronName}}_rk5 data_struct)
{
refractory_step = 0;
// refractory_step = 0;
{%- filter indent(4,True) %}
{%- for internals_block in neuron.get_internals_blocks() %}
{%- for decl in internals_block.get_declarations() %}
{%- for variable in decl.get_variables() %}
{%- if variable.get_name() != "h" %}
{%- include "directives/MemberInitialization.jinja2" %}
{%- endif %}
{%- endfor %}
{%- endfor %}
{%- endfor %}
@@ -140,7 +149,6 @@ void NodeCalibrate(int n_var, int n_param, double x, float *y,
{
{{neuronName}}_ns::NodeCalibrate(n_var, n_param, x, y, param, data_struct);
}
using namespace aeif_cond_alpha_ns;
{%- endif %}

int {{ neuronName }}::Init(int i_node_0, int n_node, int /*n_port*/,
@@ -165,7 +173,7 @@ int {{ neuronName }}::Init(int i_node_0, int n_node, int /*n_port*/,
scal_param_name_ = {{ neuronName }}_scal_param_name;

{%- if uses_numeric_solver %}
group_param_name_ = aeif_cond_alpha_group_param_name;
group_param_name_ = {{neuronName}}_group_param_name;
rk5_data_struct_.i_node_0_ = i_node_0_;

SetGroupParam("h_min_rel", 1.0e-3);
@@ -219,18 +227,6 @@ int {{ neuronName }}::Init(int i_node_0, int n_node, int /*n_port*/,
return 0;
}

int {{ neuronName }}::Update(long long it, double t1)
{
{%- if uses_analytic_solver %}
{{ neuronName }}_Update<<<(n_node_+1023)/1024, 1024>>>
(n_node_, i_node_0_, var_arr_, param_arr_, n_var_, n_param_);
// gpuErrchk( cudaDeviceSynchronize() );
{%- else %}
rk5_.Update<N_SCAL_VAR, N_SCAL_PARAM>(t1, h_min_, rk5_data_struct_);
{%- endif %}
return 0;
}

int {{ neuronName }}::Free()
{
FreeVarArr();
@@ -241,13 +237,25 @@ int {{ neuronName }}::Free()

int {{ neuronName }}::Calibrate(double time_min, float time_resolution)
{
{%- if uses_analytic_solver %}
{{ neuronName }}_Calibrate<<<(n_node_+1023)/1024, 1024>>>
(n_node_, param_arr_, n_param_, time_resolution);
{%- else %}
{%- if uses_numeric_solver %}
h_min_ = h_min_rel_* time_resolution;
h_ = h0_rel_* time_resolution;
rk5_.Calibrate(time_min, h_, rk5_data_struct_);
{%- else %}
{{ neuronName }}_Calibrate<<<(n_node_+1023)/1024, 1024>>>
(n_node_, param_arr_, n_param_, time_resolution);
{%- endif %}
return 0;
}

int {{ neuronName }}::Update(long long it, double t1)
{
{%- if uses_numeric_solver %}
rk5_.Update<N_SCAL_VAR, N_SCAL_PARAM>(t1, h_min_, rk5_data_struct_);
{%- else %}
{{ neuronName }}_Update<<<(n_node_+1023)/1024, 1024>>>
(n_node_, i_node_0_, var_arr_, param_arr_, n_var_, n_param_);
// gpuErrchk( cudaDeviceSynchronize() );
{%- endif %}
return 0;
}
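A hedged illustration of the reordered solver branches above: assuming an analytic-solver neuron named iaf_psc_exp (that name and the member fields are illustrative, not taken from this diff) and the numeric-solver aeif_cond_alpha referenced elsewhere in the commit, the template would render Update() roughly as:

// Rendered Update() when uses_numeric_solver is false: a plain kernel launch
int iaf_psc_exp::Update(long long it, double t1)
{
  iaf_psc_exp_Update<<<(n_node_ + 1023) / 1024, 1024>>>
    (n_node_, i_node_0_, var_arr_, param_arr_, n_var_, n_param_);
  // gpuErrchk( cudaDeviceSynchronize() );
  return 0;
}

// Rendered Update() when uses_numeric_solver is true: step the RK5 integrator
int aeif_cond_alpha::Update(long long it, double t1)
{
  rk5_.Update<N_SCAL_VAR, N_SCAL_PARAM>(t1, h_min_, rk5_data_struct_);
  return 0;
}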
@@ -136,7 +136,7 @@ const std::string {{ neuronName }}_group_param_name[N_GROUP_PARAM] = {
#}

{%- if uses_numeric_solver %}
{%- for variable_symbol in utils.adjusted_state_symbols() %}
{%- for variable_symbol in utils.adjusted_state_symbols(neuron) %}
{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
#define {{ printer_no_origin.print(variable) }} y[i_{{ printer_no_origin.print(variable) }}]
{%- endfor %}
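For the numeric-solver branch, the loop above turns each adjusted state symbol into a macro aliasing it to its slot in the RK5 state vector y. A hypothetical expansion, assuming a state variable named V_m (the name is illustrative):

#define V_m y[i_V_m]   // generated once per adjusted state symbol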
@@ -24,8 +24,8 @@



#ifndef {{ neuronName.upper() }}_KERNEL)_H
#define {{ neuronName.upper() }}_KERNEL)_H
#ifndef {{ neuronName.upper() }}_KERNEL_H
#define {{ neuronName.upper() }}_KERNEL_H

#include <string>
#include <cmath>
12 changes: 12 additions & 0 deletions tests/nest_gpu_tests/nest_gpu_code_generator_test.py
@@ -36,3 +36,15 @@ def test_nest_gpu_code_generator(self):
generate_nest_gpu_target(input_path, target_path,
logging_level=logging_level,
suffix=suffix)

def test_nest_gpu_code_generator_numeric(self):
input_path = os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join(
os.pardir, os.pardir, "models", "neurons", "aeif_cond_alpha.nestml"))))
target_path = "target_gpu_numeric"
logging_level = "INFO"
suffix = "_nestml"
codegen_opts = {"solver": "numeric"}
generate_nest_gpu_target(input_path, target_path,
logging_level=logging_level,
suffix=suffix,
codegen_opts=codegen_opts)
