Skip to content

Commit

Permalink
shallow water: add partial gpu support, warm up jit cache
Browse files Browse the repository at this point in the history
  • Loading branch information
tkarna committed Nov 7, 2024
1 parent 0c674f4 commit 74cae69
Showing 1 changed file with 98 additions and 72 deletions.
170 changes: 98 additions & 72 deletions examples/shallow_water.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,28 +111,32 @@ def run(n, backend, datatype, benchmark_mode):
t_end = 1.0

# coordinate arrays
sync()
x_t_2d = fromfunction(
lambda i, j: xmin + i * dx + dx / 2,
(nx, ny),
dtype=dtype,
lambda i, j: xmin + i * dx + dx / 2, (nx, ny), dtype=dtype, device=""
)
y_t_2d = fromfunction(
lambda i, j: ymin + j * dy + dy / 2,
(nx, ny),
dtype=dtype,
lambda i, j: ymin + j * dy + dy / 2, (nx, ny), dtype=dtype, device=""
)
x_u_2d = fromfunction(
lambda i, j: xmin + i * dx, (nx + 1, ny), dtype=dtype, device=""
)
x_u_2d = fromfunction(lambda i, j: xmin + i * dx, (nx + 1, ny), dtype=dtype)
y_u_2d = fromfunction(
lambda i, j: ymin + j * dy + dy / 2,
(nx + 1, ny),
dtype=dtype,
device="",
)
x_v_2d = fromfunction(
lambda i, j: xmin + i * dx + dx / 2,
(nx, ny + 1),
dtype=dtype,
device="",
)
y_v_2d = fromfunction(lambda i, j: ymin + j * dy, (nx, ny + 1), dtype=dtype)
y_v_2d = fromfunction(
lambda i, j: ymin + j * dy, (nx, ny + 1), dtype=dtype, device=""
)
sync()

T_shape = (nx, ny)
U_shape = (nx + 1, ny)
Expand All @@ -157,7 +161,7 @@ def run(n, backend, datatype, benchmark_mode):
q = create_full(F_shape, 0.0, dtype)

# bathymetry
h = create_full(T_shape, 0.0, dtype)
h = create_full(T_shape, 1.0, dtype) # HACK init with 1

hu = create_full(U_shape, 0.0, dtype)
hv = create_full(V_shape, 0.0, dtype)
Expand Down Expand Up @@ -205,22 +209,16 @@ def bathymetry(x_t_2d, y_t_2d, lx, ly):
bath = 1.0
return bath * create_full(T_shape, 1.0, dtype)

# initial elevation
u0, v0, e0 = exact_solution(
0, x_t_2d, y_t_2d, x_u_2d, y_u_2d, x_v_2d, y_v_2d
)
e[:, :] = e0
u[:, :] = u0
v[:, :] = v0

# set bathymetry
h[:, :] = bathymetry(x_t_2d, y_t_2d, lx, ly)
# h[:, :] = bathymetry(x_t_2d, y_t_2d, lx, ly).to_device(device)
# steady state potential energy
pe_offset = 0.5 * g * float(np.sum(h**2.0, all_axes)) / nx / ny
# pe_offset = 0.5 * g * float(np.sum(h**2.0, all_axes)) / nx / ny
pe_offset = 0.5 * g * float(1.0) / nx / ny

# compute time step
alpha = 0.5
h_max = float(np.max(h, all_axes))
# h_max = float(np.max(h, all_axes))
h_max = float(1.0)
c = (g * h_max) ** 0.5
dt = alpha * dx / c
dt = t_export / int(math.ceil(t_export / dt))
Expand Down Expand Up @@ -329,6 +327,18 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
v[:, 1:-1] = v[:, 1:-1] / 3.0 + 2.0 / 3.0 * (v2[:, 1:-1] + dt * dvdt)
e[:, :] = e[:, :] / 3.0 + 2.0 / 3.0 * (e2[:, :] + dt * dedt)

# warm jit cache
step(u, v, e, u1, v1, e1, u2, v2, e2)
sync()

# initial solution
u0, v0, e0 = exact_solution(
0, x_t_2d, y_t_2d, x_u_2d, y_u_2d, x_v_2d, y_v_2d
)
e[:, :] = e0.to_device(device)
u[:, :] = u0.to_device(device)
v[:, :] = v0.to_device(device)

t = 0
i_export = 0
next_t_export = 0
Expand All @@ -341,30 +351,41 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
t = i * dt

if t >= next_t_export - 1e-8:
_elev_max = np.max(e, all_axes)
_u_max = np.max(u, all_axes)
_q_max = np.max(q, all_axes)
_total_v = np.sum(e + h, all_axes)

# potential energy
_pe = 0.5 * g * (e + h) * (e - h) + pe_offset
_total_pe = np.sum(_pe, all_axes)

# kinetic energy
u2 = u * u
v2 = v * v
u2_at_t = 0.5 * (u2[1:, :] + u2[:-1, :])
v2_at_t = 0.5 * (v2[:, 1:] + v2[:, :-1])
_ke = 0.5 * (u2_at_t + v2_at_t) * (e + h)
_total_ke = np.sum(_ke, all_axes)

total_pe = float(_total_pe) * dx * dy
total_ke = float(_total_ke) * dx * dy
total_e = total_ke + total_pe
elev_max = float(_elev_max)
u_max = float(_u_max)
q_max = float(_q_max)
total_v = float(_total_v) * dx * dy
if device:
# FIXME gpu.memcpy to host requires identity layout
# FIXME reduction on gpu
elev_max = 0
u_max = 0
q_max = 0
diff_e = 0
diff_v = 0
total_pe = 0
total_ke = 0
else:
_elev_max = np.max(e, all_axes)
_u_max = np.max(u, all_axes)
_q_max = np.max(q, all_axes)
_total_v = np.sum(e + h, all_axes)

# potential energy
_pe = 0.5 * g * (e + h) * (e - h) + pe_offset
_total_pe = np.sum(_pe, all_axes)

# kinetic energy
u2 = u * u
v2 = v * v
u2_at_t = 0.5 * (u2[1:, :] + u2[:-1, :])
v2_at_t = 0.5 * (v2[:, 1:] + v2[:, :-1])
_ke = 0.5 * (u2_at_t + v2_at_t) * (e + h)
_total_ke = np.sum(_ke, all_axes)

total_pe = float(_total_pe) * dx * dy
total_ke = float(_total_ke) * dx * dy
total_e = total_ke + total_pe
elev_max = float(_elev_max)
u_max = float(_u_max)
q_max = float(_q_max)
total_v = float(_total_v) * dx * dy

if i_export == 0:
initial_v = total_v
Expand Down Expand Up @@ -399,35 +420,40 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
duration = time_mod.perf_counter() - tic
info(f"Duration: {duration:.2f} s")

e_exact = exact_solution(t, x_t_2d, y_t_2d, x_u_2d, y_u_2d, x_v_2d, y_v_2d)[
2
]
err2 = (e_exact - e) * (e_exact - e) * dx * dy / lx / ly
err_L2 = math.sqrt(float(np.sum(err2, all_axes)))
info(f"L2 error: {err_L2:7.15e}")

if nx < 128 or ny < 128:
info("Skipping correctness test due to small problem size.")
elif not benchmark_mode:
tolerance_ene = 1e-7 if datatype == "f32" else 1e-9
assert (
diff_e < tolerance_ene
), f"Energy error exceeds tolerance: {diff_e} > {tolerance_ene}"
if nx == 128 and ny == 128:
if datatype == "f32":
assert numpy.allclose(
err_L2, 4.3127859e-05, rtol=1e-5
), "L2 error does not match"
else:
assert numpy.allclose(
err_L2, 4.315799035627906e-05
), "L2 error does not match"
else:
tolerance_l2 = 1e-4
if device:
# FIXME gpu.memcpy to host requires identity layout
# FIXME reduction on gpu
pass
else:
e_exact = exact_solution(
t, x_t_2d, y_t_2d, x_u_2d, y_u_2d, x_v_2d, y_v_2d
)[2]
err2 = (e_exact - e) * (e_exact - e) * dx * dy / lx / ly
err_L2 = math.sqrt(float(np.sum(err2, all_axes)))
info(f"L2 error: {err_L2:7.15e}")

if nx < 128 or ny < 128:
info("Skipping correctness test due to small problem size.")
elif not benchmark_mode:
tolerance_ene = 1e-7 if datatype == "f32" else 1e-9
assert (
err_L2 < tolerance_l2
), f"L2 error exceeds tolerance: {err_L2} > {tolerance_l2}"
info("SUCCESS")
diff_e < tolerance_ene
), f"Energy error exceeds tolerance: {diff_e} > {tolerance_ene}"
if nx == 128 and ny == 128:
if datatype == "f32":
assert numpy.allclose(
err_L2, 4.3127859e-05, rtol=1e-5
), "L2 error does not match"
else:
assert numpy.allclose(
err_L2, 4.315799035627906e-05
), "L2 error does not match"
else:
tolerance_l2 = 1e-4
assert (
err_L2 < tolerance_l2
), f"L2 error exceeds tolerance: {err_L2} > {tolerance_l2}"
info("SUCCESS")

fini()

Expand Down

0 comments on commit 74cae69

Please sign in to comment.