Skip to content

Commit

Permalink
shallow-water: simplify reduction calls
Browse files Browse the repository at this point in the history
  • Loading branch information
tkarna committed Nov 11, 2024
1 parent 57b40c1 commit 0952df3
Showing 1 changed file with 9 additions and 11 deletions.
20 changes: 9 additions & 11 deletions examples/shallow_water.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def run(n, backend, datatype, benchmark_mode):
def transpose(a):
return np.permute_dims(a, [1, 0])

all_axes = [0, 1]
init(False)

elif backend == "numpy":
Expand All @@ -76,7 +75,6 @@ def transpose(a):
transpose = np.transpose

fini = sync = lambda x=None: None
all_axes = None
else:
raise ValueError(f'Unknown backend: "{backend}"')

Expand Down Expand Up @@ -207,11 +205,11 @@ def bathymetry(x_t_2d, y_t_2d, lx, ly):
# set bathymetry
h[:, :] = bathymetry(x_t_2d, y_t_2d, lx, ly)
# steady state potential energy
pe_offset = 0.5 * g * float(np.sum(h**2.0, all_axes)) / nx / ny
pe_offset = 0.5 * g * float(np.sum(h**2.0)) / nx / ny

# compute time step
alpha = 0.5
h_max = float(np.max(h, all_axes))
h_max = float(np.max(h))
c = (g * h_max) ** 0.5
dt = alpha * dx / c
dt = t_export / int(math.ceil(t_export / dt))
Expand Down Expand Up @@ -344,22 +342,22 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
t = i * dt

if t >= next_t_export - 1e-8:
_elev_max = np.max(e, all_axes)
_u_max = np.max(u, all_axes)
_q_max = np.max(q, all_axes)
_total_v = np.sum(e + h, all_axes)
_elev_max = np.max(e)
_u_max = np.max(u)
_q_max = np.max(q)
_total_v = np.sum(e + h)

# potential energy
_pe = 0.5 * g * (e + h) * (e - h) + pe_offset
_total_pe = np.sum(_pe, all_axes)
_total_pe = np.sum(_pe)

# kinetic energy
u2 = u * u
v2 = v * v
u2_at_t = 0.5 * (u2[1:, :] + u2[:-1, :])
v2_at_t = 0.5 * (v2[:, 1:] + v2[:, :-1])
_ke = 0.5 * (u2_at_t + v2_at_t) * (e + h)
_total_ke = np.sum(_ke, all_axes)
_total_ke = np.sum(_ke)

total_pe = float(_total_pe) * dx * dy
total_ke = float(_total_ke) * dx * dy
Expand Down Expand Up @@ -406,7 +404,7 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
2
]
err2 = (e_exact - e) * (e_exact - e) * dx * dy / lx / ly
err_L2 = math.sqrt(float(np.sum(err2, all_axes)))
err_L2 = math.sqrt(float(np.sum(err2)))
info(f"L2 error: {err_L2:7.15e}")

if nx < 128 or ny < 128:
Expand Down

0 comments on commit 0952df3

Please sign in to comment.