Skip to content

Commit

Permalink
wave equation: add gpu reductions
Browse files Browse the repository at this point in the history
  • Loading branch information
tkarna committed Dec 13, 2024
1 parent ad5c58d commit c195a74
Showing 1 changed file with 14 additions and 26 deletions.
40 changes: 14 additions & 26 deletions examples/wave_equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,22 +244,15 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
t = i * dt

if t >= next_t_export - 1e-8:
if device:
# FIXME gpu.memcpy to host requires identity layout
# FIXME reduction on gpu
# e_host = e.to_device()
# u_host = u.to_device()
# h_host = h.to_device()
# _elev_max = np.max(e_host, all_axes)
# _u_max = np.max(u_host, all_axes)
# _total_v = np.sum(e_host + h, all_axes)
_elev_max = 0
_u_max = 0
_total_v = 0
else:
_elev_max = np.max(e, all_axes)
_u_max = np.max(u, all_axes)
_total_v = np.sum(e + h, all_axes)
sync()
H_tmp = e + h
sync()
_elev_max = np.max(e, all_axes).to_device()
# NOTE max(u) segfaults, shape (n+1, n) too large for tiling
_u_max = np.max(u[1:, :], all_axes).to_device()
_total_v = np.sum(H_tmp, all_axes).to_device()
# NOTE this segfaults
# _total_v = np.sum(e + h, all_axes).to_device() # segfaults

elev_max = float(_elev_max)
u_max = float(_u_max)
Expand Down Expand Up @@ -294,16 +287,11 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
duration = time_mod.perf_counter() - tic
info(f"Duration: {duration:.2f} s")

if device:
# FIXME gpu.memcpy to host requires identity layout
# FIXME reduction on gpu
# err2_host = err2.to_device()
# err_L2 = math.sqrt(float(np.sum(err2_host, all_axes)))
err_L2 = 0
else:
e_exact = exact_elev(t, x_t_2d, y_t_2d, lx, ly)
err2 = (e_exact - e) * (e_exact - e) * dx * dy / lx / ly
err_L2 = math.sqrt(float(np.sum(err2, all_axes)))
e_exact = exact_elev(t, x_t_2d, y_t_2d, lx, ly).to_device(device)
err2 = (e_exact - e) * (e_exact - e) * dx * dy / lx / ly
err2sum = np.sum(err2, all_axes).to_device()
sync()
err_L2 = math.sqrt(float(err2sum))
info(f"L2 error: {err_L2:7.5e}")

if nx == 128 and ny == 128 and not benchmark_mode and not device:
Expand Down

0 comments on commit c195a74

Please sign in to comment.