Skip to content

Commit

Permalink
add an example: transpose.py
Browse files Browse the repository at this point in the history
  • Loading branch information
AllanZyne committed Oct 21, 2024
1 parent c208ce6 commit 5af4fc4
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 4 deletions.
181 changes: 181 additions & 0 deletions examples/transpose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
"""
Rambo benchmark
Examples:
# run 1000 iterations of 10 events and 100 outputs on sharpy backend
python rambo.py -nevts 10 -nout 100 -b sharpy -i 1000
# MPI parallel run
mpiexec -n 3 python rambo.py -nevts 64 -nout 64 -b sharpy -i 1000
"""

import argparse
import time as time_mod

import numpy

import sharpy

try:
import mpi4py

mpi4py.rc.finalize = False
from mpi4py import MPI

comm_rank = MPI.COMM_WORLD.Get_rank()
comm = MPI.COMM_WORLD
except ImportError:
comm_rank = 0
comm = None


def info(s):
if comm_rank == 0:
print(s)


def sp_transpose(arr):
brr = sharpy.permute_dims(arr, [1, 0])
sharpy.sync()
return brr


def np_transpose(arr):
return arr.transpose()


def initialize(np, row, col, dtype):
arr = np.arange(row * col, dtype=dtype)
return np.reshape(arr, (row, col))


def run(row, col, backend, iterations, datatype):
if backend == "sharpy":
import sharpy as np
from sharpy import fini, init, sync

transpose = sp_transpose

init(False)
elif backend == "numpy":
import numpy as np

if comm is not None:
assert (
comm.Get_size() == 1
), "Numpy backend only supports serial execution."

fini = sync = lambda x=None: None
transpose = np_transpose
else:
raise ValueError(f'Unknown backend: "{backend}"')

dtype = {
"f32": np.float32,
"f64": np.float64,
}[datatype]

info(f"Using backend: {backend}")
info(f"Number of row: {row}")
info(f"Number of column: {col}")
info(f"Datatype: {datatype}")

arr = initialize(np, row, col, dtype)
sync()

# verify
if backend == "sharpy":
brr = sp_transpose(arr)
crr = np_transpose(sharpy.to_numpy(arr))
assert numpy.allclose(sharpy.to_numpy(brr), crr)

def eval():
tic = time_mod.perf_counter()
transpose(arr)
toc = time_mod.perf_counter()
return toc - tic

# warm-up run
t_warm = eval()

# evaluate
info(f"Running {iterations} iterations")
time_list = []
for i in range(iterations):
time_list.append(eval())

# get max time over mpi ranks
if comm is not None:
t_warm = comm.allreduce(t_warm, MPI.MAX)
time_list = comm.allreduce(time_list, MPI.MAX)

t_min = numpy.min(time_list)
t_max = numpy.max(time_list)
t_med = numpy.median(time_list)
init_overhead = t_warm - t_med
if backend == "sharpy":
info(f"Estimated initialization overhead: {init_overhead:.5f} s")
info(f"Min. duration: {t_min:.5f} s")
info(f"Max. duration: {t_max:.5f} s")
info(f"Median duration: {t_med:.5f} s")

fini()


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Run transpose benchmark",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser.add_argument(
"-r",
"--row",
type=int,
default=10000,
help="Number of row.",
)
parser.add_argument(
"-c",
"--column",
type=int,
default=10000,
help="Number of column.",
)

parser.add_argument(
"-b",
"--backend",
type=str,
default="sharpy",
choices=["sharpy", "numpy"],
help="Backend to use.",
)

parser.add_argument(
"-i",
"--iterations",
type=int,
default=10,
help="Number of iterations to run.",
)

parser.add_argument(
"-d",
"--datatype",
type=str,
default="f64",
choices=["f32", "f64"],
help="Datatype for model state variables",
)

args = parser.parse_args()
run(
args.row,
args.column,
args.backend,
args.iterations,
args.datatype,
)
15 changes: 11 additions & 4 deletions sharpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,17 @@ def _validate_device(device):
raise ValueError(f"Invalid device string: {device}")


def arange(start, stop=None, step=1, dtype=int64, device="", team=1):
if stop is None:
stop = start
start = 0
return ndarray(
_csp.Creator.arange(
start, stop, step, dtype, _validate_device(device), team
)
)


for func in api.api_categories["Creator"]:
FUNC = func.upper()
if func == "full":
Expand All @@ -114,10 +125,6 @@ def _validate_device(device):
exec(
f"{func} = lambda shape, dtype=float64, device='', team=1: ndarray(_csp.Creator.full(shape, 0, dtype, _validate_device(device), team))"
)
elif func == "arange":
exec(
f"{func} = lambda start, end, step, dtype=int64, device='', team=1: ndarray(_csp.Creator.arange(start, end, step, dtype, _validate_device(device), team))"
)
elif func == "linspace":
exec(
f"{func} = lambda start, end, step, endpoint, dtype=float64, device='', team=1: ndarray(_csp.Creator.linspace(start, end, step, endpoint, dtype, _validate_device(device), team))"
Expand Down

0 comments on commit 5af4fc4

Please sign in to comment.