diff --git a/ext/cuda/data_layouts.jl b/ext/cuda/data_layouts.jl index 35c438ade1..7eec9dfe21 100644 --- a/ext/cuda/data_layouts.jl +++ b/ext/cuda/data_layouts.jl @@ -1,11 +1,17 @@ +import ClimaCore.DataLayouts: AbstractData +import ClimaCore.DataLayouts: FusedMultiBroadcast import ClimaCore.DataLayouts: IJKFVH, IJFH, VIJFH, VIFH, IFH, IJF, IF, VF, DataF import ClimaCore.DataLayouts: IJFHStyle, VIJFHStyle, VFStyle, DataFStyle import ClimaCore.DataLayouts: promote_parent_array_type import ClimaCore.DataLayouts: parent_array_type +import ClimaCore.DataLayouts: device_from_array_type, isascalar +import ClimaCore.DataLayouts: fused_copyto! import Adapt import CUDA +device_from_array_type(::Type{<:CUDA.CuArray}) = ClimaComms.CUDADevice() + parent_array_type(::Type{<:CUDA.CuArray{T, N, B} where {N}}) where {T, B} = CUDA.CuArray{T, N, B} where {N} @@ -180,3 +186,83 @@ function Base.fill!(dest::DataF{S, A}, val) where {S, A <: CUDA.CuArray} ) return dest end + +Base.@propagate_inbounds function rcopyto_at!( + pair::Pair{<:AbstractData, <:Any}, + I, + v, +) + dest, bc = pair.first, pair.second + if v <= size(dest, 4) + bcI = isascalar(bc) ? bc[] : bc[I] + dest[I] = bcI + end + return nothing +end +Base.@propagate_inbounds function rcopyto_at!(pairs::Tuple, I, v) + rcopyto_at!(first(pairs), I, v) + rcopyto_at!(Base.tail(pairs), I, v) +end +Base.@propagate_inbounds rcopyto_at!(pairs::Tuple{<:Any}, I, v) = + rcopyto_at!(first(pairs), I, v) +@inline rcopyto_at!(pairs::Tuple{}, I, v) = nothing + +function knl_fused_copyto!(fmbc::FusedMultiBroadcast) + + @inbounds begin + i = CUDA.threadIdx().x + j = CUDA.threadIdx().y + + h = CUDA.blockIdx().x + v = CUDA.blockDim().z * (CUDA.blockIdx().y - 1) + CUDA.threadIdx().z + (; pairs) = fmbc + I = CartesianIndex((i, j, 1, v, h)) + rcopyto_at!(pairs, I, v) + end + return nothing +end + +function fused_copyto!( + fmbc::FusedMultiBroadcast, + dest1::VIJFH{S, Nij}, + ::ClimaComms.CUDADevice, +) where {S, Nij} + _, _, _, Nv, Nh = size(dest1) + if Nv > 0 && Nh > 0 + Nv_per_block = min(Nv, fld(256, Nij * Nij)) + Nv_blocks = cld(Nv, Nv_per_block) + args = (fmbc,) + auto_launch!( + knl_fused_copyto!, + args, + dest1; + threads_s = (Nij, Nij, Nv_per_block), + blocks_s = (Nh, Nv_blocks), + ) + end + return nothing +end + +adapt_f(to, f::F) where {F} = Adapt.adapt(to, f) +adapt_f(to, ::Type{F}) where {F} = (x...) -> F(x...) + +function Adapt.adapt_structure( + to::CUDA.KernelAdaptor, + fmbc::FusedMultiBroadcast, +) + FusedMultiBroadcast( + map(fmbc.pairs) do pair + dest = pair.first + bc = pair.second + Pair( + Adapt.adapt(to, dest), + Base.Broadcast.Broadcasted( + bc.style, + adapt_f(to, bc.f), + Adapt.adapt(to, bc.args), + Adapt.adapt(to, bc.axes), + ), + ) + end, + ) +end diff --git a/src/DataLayouts/DataLayouts.jl b/src/DataLayouts/DataLayouts.jl index ca72c59cbe..3a7a5ef2f5 100644 --- a/src/DataLayouts/DataLayouts.jl +++ b/src/DataLayouts/DataLayouts.jl @@ -17,6 +17,7 @@ module DataLayouts import Base: Base, @propagate_inbounds import StaticArrays: SOneTo, MArray, SArray import ClimaComms +import MultiBroadcastFusion as MBF import Adapt import ..slab, ..slab_args, ..column, ..column_args, ..level @@ -1451,4 +1452,11 @@ Adapt.adapt_structure(to, data::VF{S}) where {S} = Adapt.adapt_structure(to, data::DataF{S}) where {S} = DataF{S}(Adapt.adapt(to, parent(data))) +# TODO: Should the DataLayout be device-aware? So that we can +# determine if we're multi-threaded or not? +# This is only currently used in FusedMultiBroadcast kernels +device_from_array_type(::Type{<:AbstractArray}) = ClimaComms.CPUSingleThreaded() +ClimaComms.device(data::AbstractData) = + device_from_array_type(typeof(parent(data))) + end # module diff --git a/src/DataLayouts/broadcast.jl b/src/DataLayouts/broadcast.jl index c77b5ad0b8..b2d0ed91e2 100644 --- a/src/DataLayouts/broadcast.jl +++ b/src/DataLayouts/broadcast.jl @@ -1,3 +1,11 @@ +import MultiBroadcastFusion as MBF +import MultiBroadcastFusion: fused_direct + +# Make a MultiBroadcastFusion type, `FusedMultiBroadcast`, and macro, `@fused`: +# via https://github.com/CliMA/MultiBroadcastFusion.jl +MBF.@make_type FusedMultiBroadcast +MBF.@make_fused fused_direct FusedMultiBroadcast fused_direct + # Broadcasting of AbstractData objects # https://docs.julialang.org/en/v1/manual/interfaces/#Broadcast-Styles @@ -587,3 +595,96 @@ function Base.copyto!( ) where {S, Nij, A} return _serial_copyto!(dest, bc) end + +# ============= FusedMultiBroadcast + +isascalar( + bc::Base.Broadcast.Broadcasted{Style}, +) where { + Style <: + Union{Base.Broadcast.AbstractArrayStyle{0}, Base.Broadcast.Style{Tuple}}, +} = true +isascalar(bc) = false + + +# Fused multi-broadcast entry point for DataLayouts +function Base.copyto!( + fmbc::FusedMultiBroadcast{T}, +) where {N, T <: NTuple{N, Pair{<:AbstractData, <:Any}}} + dest1 = first(fmbc.pairs).first + # check_fused_broadcast_axes(fmbc) # we should already have checked the axes + fused_copyto!(fmbc, dest1, ClimaComms.device(dest1)) +end + +function fused_copyto!( + fmbc::FusedMultiBroadcast, + dest1::VIJFH{S1, Nij}, + ::ClimaComms.AbstractCPUDevice, +) where {S1, Nij} + _, _, _, Nv, Nh = size(dest1) + for (dest, bc) in fmbc.pairs + # Base.copyto!(dest, bc) # we can just fall back like this + @inbounds for h in 1:Nh, j in 1:Nij, i in 1:Nij, v in 1:Nv + I = CartesianIndex(i, j, 1, v, h) + bcI = isascalar(bc) ? bc[] : bc[I] + dest[I] = convert(eltype(dest), bcI) + end + end + return nothing +end + +function fused_copyto!( + fmbc::FusedMultiBroadcast, + dest1::VIFH{S, Ni, A}, + ::ClimaComms.AbstractCPUDevice, +) where {S, Ni, A} + # copy contiguous columns + _, _, _, Nv, Nh = size(dest1) + for (dest, bc) in fmbc.pairs + @inbounds for h in 1:Nh, i in 1:Ni, v in 1:Nv + I = CartesianIndex(i, 1, 1, v, h) + bcI = isascalar(bc) ? bc[] : bc[I] + dest[I] = convert(eltype(dest), bcI) + end + end + return nothing +end + +function fused_copyto!( + fmbc::FusedMultiBroadcast, + dest1::VF{S1, A}, + ::ClimaComms.AbstractCPUDevice, +) where {S1, A} + _, _, _, Nv, _ = size(dest1) + for (dest, bc) in fmbc.pairs + @inbounds for v in 1:Nv + I = CartesianIndex(1, 1, 1, v, 1) + dest[I] = convert(eltype(dest), bc[I]) + end + end + return nothing +end + +# we've already diagonalized dest, so we only need to make +# sure that all the broadcast axes are compatible. +# Logic here is similar to Base.Broadcast.instantiate +@inline function _check_fused_broadcast_axes(bc1, bc2) + axes = Base.Broadcast.combine_axes(bc1.args..., bc2.args...) + if !(axes isa Nothing) + Base.Broadcast.check_broadcast_axes(axes, bc1.args...) + Base.Broadcast.check_broadcast_axes(axes, bc2.args...) + end +end + +@inline check_fused_broadcast_axes(fmbc::FusedMultiBroadcast) = + check_fused_broadcast_axes( + map(x -> x.second, fmbc.pairs), + first(fmbc.pairs).second, + ) +@inline check_fused_broadcast_axes(bcs::Tuple{<:Any}, bc1) = + _check_fused_broadcast_axes(first(bcs), bc1) +@inline check_fused_broadcast_axes(bcs::Tuple{}, bc1) = nothing +@inline function check_fused_broadcast_axes(bcs::Tuple, bc1) + _check_fused_broadcast_axes(first(bcs), bc1) + check_fused_broadcast_axes(Base.tail(bcs), bc1) +end diff --git a/src/Fields/Fields.jl b/src/Fields/Fields.jl index 13b6fbf730..992b1f5aa0 100644 --- a/src/Fields/Fields.jl +++ b/src/Fields/Fields.jl @@ -3,7 +3,14 @@ module Fields import ClimaComms import MultiBroadcastFusion as MBF import ..slab, ..slab_args, ..column, ..column_args, ..level -import ..DataLayouts: DataLayouts, AbstractData, DataStyle +import ..DataLayouts: + DataLayouts, + AbstractData, + DataStyle, + FusedMultiBroadcast, + @fused_direct, + isascalar, + check_fused_broadcast_axes import ..Domains import ..Topologies import ..Quadratures diff --git a/src/Fields/broadcast.jl b/src/Fields/broadcast.jl index 2922ba6588..c5ef39dfa2 100644 --- a/src/Fields/broadcast.jl +++ b/src/Fields/broadcast.jl @@ -150,6 +150,45 @@ end return dest end +# Fused multi-broadcast entry point for Fields +function Base.copyto!( + fmbc::FusedMultiBroadcast{T}, +) where {N, T <: NTuple{N, Pair{<:Field, <:Any}}} + fmb_data = FusedMultiBroadcast( + map(fmbc.pairs) do pair + bc = Base.Broadcast.instantiate(todata(pair.second)) + bc′ = if isascalar(bc) + Base.Broadcast.instantiate( + Base.Broadcast.Broadcasted(bc.style, bc.f, bc.args, ()), + ) + else + bc + end + Pair(field_values(pair.first), bc′) + end, + ) + check_mismatched_spaces(fmbc) + check_fused_broadcast_axes(fmbc) + Base.copyto!(fmb_data) # forward to DataLayouts +end + +@inline check_mismatched_spaces(fmbc::FusedMultiBroadcast) = + check_mismatched_spaces( + map(x -> axes(x.first), fmbc.pairs), + axes(first(fmbc.pairs).first), + ) +@inline check_mismatched_spaces(axs::Tuple{<:Any}, ax1) = + _check_mismatched_spaces(first(axs), ax1) +@inline check_mismatched_spaces(axs::Tuple{}, ax1) = nothing +@inline function check_mismatched_spaces(axs::Tuple, ax1) + _check_mismatched_spaces(first(axs), ax1) + check_mismatched_spaces(Base.tail(axs), ax1) +end + +_check_mismatched_spaces(::T, ::T) where {T <: AbstractSpace} = nothing +_check_mismatched_spaces(space1, space2) = + error("FusedMultiBroadcast spaces are not the same.") + @noinline function error_mismatched_spaces(space1::Type, space2::Type) error("Broacasted spaces are not the same.") end diff --git a/test/Fields/field.jl b/test/Fields/field.jl index a49faf578d..b91b583c6d 100644 --- a/test/Fields/field.jl +++ b/test/Fields/field.jl @@ -1,4 +1,5 @@ #= +julia --check-bounds=yes --project=test julia --project=test using Revise; include(joinpath("test", "Fields", "field.jl")) =# @@ -915,3 +916,7 @@ end end nothing end + +include("field_multi_broadcast_fusion.jl") + +nothing diff --git a/test/Fields/field_multi_broadcast_fusion.jl b/test/Fields/field_multi_broadcast_fusion.jl new file mode 100644 index 0000000000..e642046166 --- /dev/null +++ b/test/Fields/field_multi_broadcast_fusion.jl @@ -0,0 +1,310 @@ +#= +julia --check-bounds=yes --project=test +julia --project=test +using Revise; include(joinpath("test", "Fields", "field_multi_broadcast_fusion.jl")) +=# +using Test +using JET +using BenchmarkTools + +using ClimaComms +using OrderedCollections +using StaticArrays, IntervalSets +import ClimaCore +import ClimaCore.Utilities: PlusHalf +import ClimaCore.DataLayouts: IJFH +import ClimaCore.DataLayouts +import ClimaCore: + Fields, + slab, + Domains, + Topologies, + Meshes, + Operators, + Spaces, + Geometry, + Quadratures + +import ClimaCore.Fields: @fused_direct +using LinearAlgebra: norm +using Statistics: mean +using ForwardDiff +using CUDA +using CUDA: @allowscalar + +util_file = + joinpath(pkgdir(ClimaCore), "test", "TestUtilities", "TestUtilities.jl") +if !(@isdefined(TU)) + include(util_file) + import .TestUtilities as TU +end + +function CenterExtrudedFiniteDifferenceSpaceLineHSpace( + ::Type{FT}; + zelem = 10, + context = ClimaComms.SingletonCommsContext(), + helem = 4, + Nq = 4, +) where {FT} + radius = FT(128) + zlim = (0, 1) + domain = Domains.IntervalDomain( + Geometry.XPoint(zero(FT)), + Geometry.XPoint(FT(1)); + periodic = true, + ) + hmesh = Meshes.IntervalMesh(domain; nelems = helem) + + vertdomain = Domains.IntervalDomain( + Geometry.ZPoint{FT}(zlim[1]), + Geometry.ZPoint{FT}(zlim[2]); + boundary_names = (:bottom, :top), + ) + vertmesh = Meshes.IntervalMesh(vertdomain, nelems = zelem) + vtopology = Topologies.IntervalTopology(context, vertmesh) + vspace = Spaces.CenterFiniteDifferenceSpace(vtopology) + + quad = Quadratures.GLL{Nq}() + htopology = Topologies.IntervalTopology(context, hmesh) + hspace = Spaces.SpectralElementSpace1D(htopology, quad) + return Spaces.ExtrudedFiniteDifferenceSpace(hspace, vspace) +end + +function benchmark_kernel!(f!, X, Y) + println("\n--------------------------- $(nameof(typeof(f!))) ") + trial = benchmark_kernel!(f!, X, Y, ClimaComms.device(X.x1)) + show(stdout, MIME("text/plain"), trial) +end +benchmark_kernel!(f!, X, Y, ::ClimaComms.CUDADevice) = + CUDA.@sync BenchmarkTools.@benchmark $f!($X, $Y); +benchmark_kernel!(f!, X, Y, ::ClimaComms.AbstractCPUDevice) = + BenchmarkTools.@benchmark $f!($X, $Y); + +function show_diff(A, B) + for pn in propertynames(A) + Ai = getproperty(A, pn) + Bi = getproperty(B, pn) + println("==================== Comparing $pn") + @show Ai + @show Bi + @show abs.(Ai .- Bi) + end +end + +function compare(A, B) + pass = true + for pn in propertynames(A) + pass = + pass && + all(parent(getproperty(A, pn)) .== parent(getproperty(B, pn))) + end + pass || show_diff(A, B) + return pass +end +function test_kernel!(; fused!, unfused!, X, Y) + for pn in propertynames(X) + rand_field!(getproperty(X, pn)) + end + for pn in propertynames(Y) + rand_field!(getproperty(Y, pn)) + end + X_fused = similar(X) + X_fused .= X + X_unfused = similar(X) + X_unfused .= X + Y_fused = similar(Y) + Y_fused .= Y + Y_unfused = similar(Y) + Y_unfused .= Y + unfused!(X_unfused, Y_unfused) + fused!(X_fused, Y_fused) + @testset "Test correctness of $(nameof(typeof(fused!)))" begin + @test compare(X_fused, X_unfused) + @test compare(Y_fused, Y_unfused) + end +end + +function fused!(X, Y) + (; x1, x2, x3) = X + (; y1, y2, y3) = Y + @fused_direct begin + @. y1 = x1 + x2 + x3 + @. y2 = x1 + x2 + x3 + end + return nothing +end +function unfused!(X, Y) + (; x1, x2, x3) = X + (; y1, y2, y3) = Y + @. y1 = x1 + x2 + x3 + @. y2 = x1 + x2 + x3 + return nothing +end +function fused_bycolumn!(X, Y) + (; x1, x2, x3) = X + (; y1, y2, y3) = Y + Fields.bycolumn(axes(x1)) do colidx + @fused_direct begin + @. y1[colidx] = x1[colidx] + x2[colidx] + x3[colidx] + @. y2[colidx] = x1[colidx] + x2[colidx] + x3[colidx] + end + end + return nothing +end +function unfused_bycolumn!(X, Y) + (; x1, x2, x3) = X + (; y1, y2, y3) = Y + Fields.bycolumn(axes(x1)) do colidx + @. y1[colidx] = x1[colidx] + x2[colidx] + x3[colidx] + @. y2[colidx] = x1[colidx] + x2[colidx] + x3[colidx] + end + return nothing +end + +function rand_field(FT, space) + f = Fields.Field(FT, space) + rand_field!(f) +end + +function rand_field!(f) + parent(f) .= map(x -> rand(), parent(f)) + return f +end + +@testset "FusedMultiBroadcast - restrict to only similar fields" begin + FT = Float64 + dev = ClimaComms.device() + cspace = TU.CenterExtrudedFiniteDifferenceSpace( + FT; + zelem = 3, + helem = 4, + context = ClimaComms.context(dev), + ) + fspace = Spaces.FaceExtrudedFiniteDifferenceSpace(cspace) + x = rand_field(FT, cspace) + y = rand_field(FT, fspace) + # Cannot fuse center and face-spaced broadcasting + @test_throws ErrorException begin + @fused_direct begin + @. x += 1 + @. y += 1 + end + end + nothing +end + +struct SomeData{FT} + a::FT + b::FT + c::FT +end +@testset "FusedMultiBroadcast - restrict to only similar broadcast types" begin + FT = Float64 + dev = ClimaComms.device() + cspace = TU.CenterExtrudedFiniteDifferenceSpace( + FT; + zelem = 3, + helem = 4, + context = ClimaComms.context(dev), + ) + fspace = Spaces.FaceExtrudedFiniteDifferenceSpace(cspace) + x = rand_field(FT, cspace) + sd = Fields.Field(SomeData{FT}, cspace) + x2 = rand_field(FT, cspace) + y = rand_field(FT, fspace) + # Error when the axes of the RHS are incompatible + @test_throws DimensionMismatch begin + @fused_direct begin + @. x += 1 + @. x += y + end + end + @test_throws DimensionMismatch begin + @fused_direct begin + @. x += y + @. x += y + end + end + # Different but compatible broadcasts + @fused_direct begin + @. x += 1 + @. x += x2 + end + # Different fields but same spaces + @fused_direct begin + @. x += 1 + @. sd = SomeData{FT}(1, 2, 3) + end + @fused_direct begin + @. x += 1 + @. sd.b = 3 + end + nothing +end + +@testset "FusedMultiBroadcast VIJFH and VF" begin + FT = Float64 + space = TU.CenterExtrudedFiniteDifferenceSpace( + FT; + zelem = 3, + helem = 4, + context = ClimaComms.context(), + ) + X = Fields.FieldVector( + x1 = rand_field(FT, space), + x2 = rand_field(FT, space), + x3 = rand_field(FT, space), + ) + Y = Fields.FieldVector( + y1 = rand_field(FT, space), + y2 = rand_field(FT, space), + y3 = rand_field(FT, space), + ) + test_kernel!(; fused!, unfused!, X, Y) + test_kernel!(; fused! = fused_bycolumn!, unfused! = unfused_bycolumn!, X, Y) + + benchmark_kernel!(unfused!, X, Y) + benchmark_kernel!(fused!, X, Y) + + benchmark_kernel!(unfused_bycolumn!, X, Y) + benchmark_kernel!(fused_bycolumn!, X, Y) + nothing +end + +@testset "FusedMultiBroadcast VIFH" begin + FT = Float64 + device = ClimaComms.device() + # Add GPU test when https://github.com/CliMA/ClimaCore.jl/issues/1383 is fixed + if device isa ClimaComms.CPUSingleThreaded + space = CenterExtrudedFiniteDifferenceSpaceLineHSpace( + FT; + zelem = 3, + helem = 4, + context = ClimaComms.context(device), + ) + X = Fields.FieldVector( + x1 = rand_field(FT, space), + x2 = rand_field(FT, space), + x3 = rand_field(FT, space), + ) + Y = Fields.FieldVector( + y1 = rand_field(FT, space), + y2 = rand_field(FT, space), + y3 = rand_field(FT, space), + ) + test_kernel!(; fused!, unfused!, X, Y) + test_kernel!(; + fused! = fused_bycolumn!, + unfused! = unfused_bycolumn!, + X, + Y, + ) + + benchmark_kernel!(unfused!, X, Y) + benchmark_kernel!(fused!, X, Y) + + benchmark_kernel!(unfused_bycolumn!, X, Y) + benchmark_kernel!(fused_bycolumn!, X, Y) + nothing + end +end