Skip to content

Commit

Permalink
Implement recursive_map as basis for make_zero
Browse files Browse the repository at this point in the history
...as well as recursive_add, recursive_accumulate!, and accumulate_into!
  • Loading branch information
danielwe committed Jan 9, 2025
1 parent 2309abd commit 16d5b65
Show file tree
Hide file tree
Showing 13 changed files with 2,029 additions and 1,566 deletions.
49 changes: 5 additions & 44 deletions ext/EnzymeStaticArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,50 +32,11 @@ end
end
end

# Out-of-place zeroing for an `SArray` whose element type is a real or complex float:
# delegate to the array's own `zero` method and assert that the result keeps the concrete
# input type `FT`.
@inline function Enzyme.EnzymeCore.make_zero(
    prev::FT
) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:SArray{S,T}}
    zeroed = Base.zero(prev)
    return zeroed::FT
end
# Out-of-place zeroing for an `MArray` whose element type is a real or complex float:
# delegate to the array's own `zero` method and assert that the result keeps the concrete
# input type `FT`.
@inline function Enzyme.EnzymeCore.make_zero(
    prev::FT
) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}}
    zeroed = Base.zero(prev)
    return zeroed::FT
end

# Four-argument `make_zero` for an `SArray` with real or complex float elements.
# `seen` and `copy_if_inactive` are part of the signature but are not consulted by this
# method; the result is simply `zero(prev)`, asserted to be of type `FT`.
@inline function Enzyme.EnzymeCore.make_zero(
    ::Type{FT}, seen::IdDict, prev::FT, ::Val{copy_if_inactive} = Val(false)
) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:SArray{S,T},copy_if_inactive}
    zeroed = Base.zero(prev)
    return zeroed::FT
end
# Four-argument `make_zero` for an `MArray` with real or complex float elements.
# `seen` maps already-visited objects to their zeroed counterparts: if `prev` has an entry
# it is returned as-is, otherwise a fresh `zero(prev)` is recorded under `prev` and
# returned. `copy_if_inactive` is not consulted by this method.
@inline function Enzyme.EnzymeCore.make_zero(
    ::Type{FT}, seen::IdDict, prev::FT, ::Val{copy_if_inactive} = Val(false)
) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T},copy_if_inactive}
    if !haskey(seen, prev)
        zeroed = Base.zero(prev)::FT
        seen[prev] = zeroed
        return zeroed
    end
    return seen[prev]
end

# In-place zeroing for an `MArray` with real or complex float elements.
# When `seen` is not `nothing` it is treated as a collection of already-zeroed objects:
# a `prev` that is already a member is left untouched, otherwise `prev` is added before
# being overwritten with `zero(T)`. Always returns `nothing`.
@inline function Enzyme.EnzymeCore.make_zero!(
    prev::FT, seen
) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}}
    if seen !== nothing
        prev in seen && return nothing
        push!(seen, prev)
    end
    fill!(prev, zero(T))
    return nothing
end
@inline function Enzyme.EnzymeCore.make_zero!(
prev::FT
) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}}
Enzyme.EnzymeCore.make_zero!(prev, nothing)
return nothing
# SArrays and MArrays don't need special treatment for `make_zero(!)` to work or be correct,
# but in case their dedicated `zero` and `fill!` methods are more efficient than
# `make_zero(!)`'s recursion, we opt into treating them as leaves.
@inline function Enzyme.EnzymeCore.isvectortype(::Type{<:StaticArray{S,T}}) where {S,T}
    # A static array is a leaf only when its element type is both a bits type and an
    # Enzyme scalar type; check the cheap `isbitstype` condition first.
    isbitstype(T) || return false
    return Enzyme.EnzymeCore.isscalartype(T)
end

end
122 changes: 111 additions & 11 deletions lib/EnzymeCore/src/EnzymeCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -506,28 +506,128 @@ function autodiff_thunk end
function autodiff_deferred_thunk end

"""
make_zero(::Type{T}, seen::IdDict, prev::T, ::Val{copy_if_inactive}=Val(false))::T
Recursively make a zero'd copy of the value `prev` of type `T`. The argument `copy_if_inactive` specifies
what to do if the type `T` is guaranteed to be inactive, use the primal (the default) or still copy the value.
make_zero(
prev::T, ::Val{copy_if_inactive}=Val(false), ::Val{runtime_inactive}=Val(false)
)::T
make_zero(
::Type{T},
seen::IdDict,
prev::T,
::Val{copy_if_inactive}=Val(false),
::Val{runtime_inactive}=Val(false),
)::T
Recursively make a copy of the value `prev::T` in which all differentiable values are
zeroed. The argument `copy_if_inactive` specifies what to do if the type `T` or any
of its constituent parts is guaranteed to be inactive (non-differentiable): reuse `prev`'s
instance (the default) or make a copy.
The argument `runtime_inactive` specifies whether each constituent type is checked for being
guaranteed inactive at runtime for every call to `make_zero`, or if this can be checked once
at compile-time and reused across multiple calls to `make_zero` and related functions (the
default). Runtime checks are necessary to pick up recently added methods to
`EnzymeRules.inactive_type`, but may incur a significant performance penalty and are usually
not needed unless `EnzymeRules.inactive_type` is extended interactively for types that have
previously been passed to `make_zero` or related functions.
Extending this method for custom types is rarely needed. If you implement a new type, such
as a GPU array type, for which `make_zero` should directly invoke `zero` rather than
iterate/broadcast when the eltype is scalar, it is sufficient to implement `Base.zero` and
make sure your type subtypes `DenseArray`. (If subtyping `DenseArray` is not appropriate,
extend [`EnzymeCore.isvectortype`](@ref) directly instead.)
"""
function make_zero end

"""
make_zero!(val::T, seen::IdSet{Any}=IdSet())::Nothing
make_zero!(val::T, [seen::IdDict], ::Val{runtime_inactive}=Val(false))::Nothing
Recursively set a variable's differentiable values to zero. Only applicable for types `T`
that are mutable or hold all differentiable values in mutable storage (e.g.,
`Tuple{Vector{Float64}}` qualifies but `Tuple{Float64}` does not). The recursion skips over
parts of `val` that are guaranteed to be inactive.
Recursively set a variable's differentiable fields to zero. Only applicable for mutable types `T`.
The argument `runtime_inactive` specifies whether each constituent type is checked for being
guaranteed inactive at runtime for every call to `make_zero!`, or if this can be checked once
at compile-time and reused across multiple calls to `make_zero!` and related functions (the
default). Runtime checks are necessary to pick up recently added methods to
`EnzymeRules.inactive_type`, but may incur a significant performance penalty and are usually
not needed unless `EnzymeRules.inactive_type` is extended interactively for types that have
previously been passed to `make_zero!` or related functions.
Extending this method for custom types is rarely needed. If you implement a new mutable
type, such as a GPU array type, for which `make_zero!` should directly invoke
`fill!(x, false)` rather than iterate/broadcast when the eltype is scalar, it is sufficient
to implement `Base.zero`, `Base.fill!`, and make sure your type subtypes `DenseArray`. (If
subtyping `DenseArray` is not appropriate, extend [`EnzymeCore.isvectortype`](@ref) directly
instead.)
"""
function make_zero! end

"""
make_zero(prev::T)
isvectortype(::Type{T})::Bool
Helper function to recursively make zero.
"""
@inline function make_zero(prev::T, ::Val{copy_if_inactive}=Val(false)) where {T, copy_if_inactive}
make_zero(Core.Typeof(prev), IdDict(), prev, Val(copy_if_inactive))
Trait defining types whose values should be considered leaf nodes when [`make_zero`](@ref)
and [`make_zero!`](@ref) recurse through an object.
By default, `isvectortype(T) == true` for `T` such that `isscalartype(T) == true` or
`T <: DenseArray{U}` where `U` is a bitstype and `isscalartype(U) == true`.
A new vector type, such as a GPU array type, should normally subtype `DenseArray` and
inherit `isvectortype` that way. However, if this is not appropriate, `isvectortype` may be
extended directly as follows:
```julia
@inline function EnzymeCore.isvectortype(::Type{T}) where {T<:NewArray}
U = eltype(T)
return isbitstype(U) && EnzymeCore.isscalartype(U)
end
```
Such a type should implement `Base.zero` and, if mutable, `Base.fill!`.
Extending `isvectortype` is mostly relevant for the lowest level of abstraction of memory at
which vector space operations like addition and scalar multiplication are supported, the
prototypical case being `Array`. Regular Julia structs with vector space-like semantics
should normally not extend `isvectortype`; `make_zero(!)` will recurse into them and act
directly on their backing arrays, just like how Enzyme treats them when differentiating. For
example, structured matrix wrappers and sparse array types that are backed by `Array` should
not extend `isvectortype`.
See also [`isscalartype`](@ref).
"""
function isvectortype end

"""
isscalartype(::Type{T})::Bool
Trait defining a subset of [`isvectortype`](@ref) types that should not be considered
composite, such that even if the type is mutable, [`make_zero!`](@ref) will not try to zero
values of the type in-place. For example, `BigFloat` is a mutable type but does not support
in-place mutation through any Julia API; `isscalartype(BigFloat) == true` ensures that
`make_zero!` will not try to mutate `BigFloat` values.[^BigFloat]
By default, `isscalartype(T) == true` and `isscalartype(Complex{T}) == true` for concrete
`T <: AbstractFloat`.
A hypothetical new real number type with Enzyme support should usually subtype
`AbstractFloat` and inherit the `isscalartype` trait that way. If this is not appropriate,
the function can be extended as follows:
```julia
@inline EnzymeCore.isscalartype(::Type{NewReal}) = true
@inline EnzymeCore.isscalartype(::Type{Complex{NewReal}}) = true
```
In either case, the type should implement `Base.zero`.
See also [`isvectortype`](@ref).
[^BigFloat]: Enzyme does not support differentiating `BigFloat` as of this writing; it is
mentioned here only to demonstrate that it would be inappropriate to use traits like
`ismutable` or `isbitstype` to choose between in-place and out-of-place zeroing,
showing the need for a dedicated `isscalartype` trait.
"""
function isscalartype end

function tape_type end

Expand Down
8 changes: 2 additions & 6 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -463,12 +463,8 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0))
# compute the correct complex derivative in reverse mode by propagating the conjugate return values
# then subtracting twice the imaginary component to get the correct result

for (k, v) in seen
Compiler.recursive_accumulate(k, v, refn_seed)
end
for (k, v) in seen2
Compiler.recursive_accumulate(k, v, imfn_seed)
end
Compiler.accumulate_seen!(refn_seed, seen)
Compiler.accumulate_seen!(imfn_seed, seen2)

fused = fuse_complex_results(results, args...)

Expand Down
5 changes: 5 additions & 0 deletions src/analyses/activity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,11 @@ Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_n
return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState
end

# Non-generated variant taking an explicit `world`: classify `T` with `active_reg_inner`
# and report `true` when the resulting state is `AnyState` or `DupState`.
Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_nonactive_nongen(::Type{T}, world)::Bool where {T}
    state = Enzyme.Compiler.active_reg_inner(T, (), world)
    state == Enzyme.Compiler.AnyState && return true
    return state == Enzyme.Compiler.DupState
end

"""
Enzyme.guess_activity(::Type{T}, mode::Enzyme.Mode)
Expand Down
2 changes: 1 addition & 1 deletion src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ const JuliaGlobalNameMap = Dict{String,Any}(
include("absint.jl")
include("llvm/transforms.jl")
include("llvm/passes.jl")
include("typeutils/make_zero.jl")
include("typeutils/recursive_maps.jl")

function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, @nospecialize(f), @nospecialize(tt::Type), world::UInt)
funcspec = my_methodinstance(mode == API.DEM_ForwardMode ? Forward : Reverse, typeof(f), tt, world)
Expand Down
50 changes: 1 addition & 49 deletions src/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,47 +253,6 @@ function EnzymeRules.augmented_primal(
return EnzymeRules.AugmentedReturn(primal, shadow, shadow)
end


# Accumulate the array `from` into the array `into`, element by element.
#
# Returns the `(into, from)` pair recorded in `seen` for `into`. Arrays whose type is
# guaranteed constant are returned unchanged. Each element pair is handed to the matching
# `accumulate_into` method and both arrays are overwritten with the returned values.
#
# Throws `DimensionMismatch` when `into` and `from` do not share the same indices.
@inline function accumulate_into(
    into::RT,
    seen::IdDict,
    from::RT,
)::Tuple{RT,RT} where {RT<:Array}
    if Enzyme.Compiler.guaranteed_const(RT)
        return (into, from)
    end
    if !haskey(seen, into)
        # Record the pair before recursing so that a nested reference back to `into` is
        # answered from `seen` instead of being processed a second time.
        seen[into] = (into, from)
        # Iterating over the indices of *both* arrays raises `DimensionMismatch` on a
        # shape mismatch, which is what makes the `@inbounds` stores below safe for
        # `into` as well as `from`.
        for i in eachindex(into, from)
            tup = accumulate_into(into[i], seen, from[i])
            @inbounds into[i] = tup[1]
            @inbounds from[i] = tup[2]
        end
    end
    return seen[into]
end

# Accumulate a floating-point scalar: returns `(into + from, zero(RT))`, i.e. the summed
# value and the zeroed source.
#
# `seen` is kept in the signature for compatibility with the other methods but is
# deliberately neither read nor written. An `IdDict` compares keys with `===`, so two
# distinct array elements holding the same float value are indistinguishable as keys;
# memoizing on the value would hand the second element the first element's cached sum
# instead of its own.
@inline function accumulate_into(
    into::RT,
    seen::IdDict,
    from::RT,
)::Tuple{RT,RT} where {RT<:AbstractFloat}
    return (into + from, zero(RT))
end

# Fallback for types without a dedicated method. Values whose type is guaranteed constant
# pass through unchanged; any other value must already have an entry in `seen`, otherwise
# an `AssertionError` naming the type is thrown.
@inline function accumulate_into(into::RT, seen::IdDict, from::RT)::Tuple{RT,RT} where {RT}
    Enzyme.Compiler.guaranteed_const(RT) && return (into, from)
    haskey(seen, into) ||
        throw(AssertionError("Unknown type to accumulate into: $RT"))
    return seen[into]
end

function EnzymeRules.reverse(
config::EnzymeRules.RevConfig,
func::Const{typeof(Base.deepcopy)},
Expand All @@ -302,15 +261,8 @@ function EnzymeRules.reverse(
x::Annotation{Ty},
) where {RT,Ty}
if EnzymeRules.needs_shadow(config)
if EnzymeRules.width(config) == 1
accumulate_into(x.dval, IdDict(), shadow)
else
for i = 1:EnzymeRules.width(config)
accumulate_into(x.dval[i], IdDict(), shadow[i])
end
end
Compiler.accumulate_into!(x.dval, shadow)
end

return (nothing,)
end

Expand Down
Loading

0 comments on commit 16d5b65

Please sign in to comment.