Skip to content

Commit

Permalink
Refactor function.
Browse files Browse the repository at this point in the history
  • Loading branch information
isazi committed Aug 22, 2024
1 parent 24857cc commit 3d32c47
Showing 1 changed file with 27 additions and 13 deletions.
40 changes: 27 additions & 13 deletions kernel_tuner/utils/directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,24 +481,38 @@ def wrap_data(code: str, langs: Code, data: dict, preprocessor: list = None, use
for name in data.keys():
if "*" in data[name][0]:
size = parse_size(data[name][1], preprocessor=preprocessor, dimensions=user_dimensions)
temp = None
if is_openacc(langs.directive):
if is_cxx(langs.language):
intro += create_data_directive_openacc_cxx(name, size)
outro += exit_data_directive_openacc_cxx(name, size)
elif is_fortran(langs.language):
intro += create_data_directive_openacc_fortran(name, size)
outro += exit_data_directive_openacc_fortran(name, size)
temp = wrap_data_openacc(name, size)
elif is_openmp(langs.directive):
if is_cxx(langs.language):
intro += create_data_directive_openmp_cxx(name, size)
outro += exit_data_directive_openmp_cxx(name, size)
elif is_fortran(langs.language):
intro += create_data_directive_openmp_fortran(name, size)
outro += exit_data_directive_openmp_fortran(name, size)

temp = wrap_data_openmp(name, size)
intro += temp[0]
outro += temp[1]
return "\n".join([intro, code, outro])


def wrap_data_openacc(name: str, size: int) -> Tuple[str, str]:
"""Create language specific data directives"""
if is_cxx(langs.language):
intro = create_data_directive_openacc_cxx(name, size)
outro = exit_data_directive_openacc_cxx(name, size)
elif is_fortran(langs.language):
intro = create_data_directive_openacc_fortran(name, size)
outro = exit_data_directive_openacc_fortran(name, size)
return intro, outro


def wrap_data_openmp(name: str, size: int) -> Tuple[str, str]:
"""Create language specific data directives"""
if is_cxx(langs.language):
intro += create_data_directive_openmp_cxx(name, size)
outro += exit_data_directive_openmp_cxx(name, size)
elif is_fortran(langs.language):
intro += create_data_directive_openmp_fortran(name, size)
outro += exit_data_directive_openmp_fortran(name, size)
return intro, outro


def extract_directive_code(code: str, langs: Code, kernel_name: str = None) -> dict:
"""Extract explicitly marked directive sections from code"""
if is_cxx(langs.language):
Expand Down

0 comments on commit 3d32c47

Please sign in to comment.