diff --git a/ext/AbstractFFTsChainRulesCoreExt.jl b/ext/AbstractFFTsChainRulesCoreExt.jl index f0c788e6..1ad2a866 100644 --- a/ext/AbstractFFTsChainRulesCoreExt.jl +++ b/ext/AbstractFFTsChainRulesCoreExt.jl @@ -30,10 +30,10 @@ function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims) halfdim = first(dims) d = size(x, halfdim) n = size(y, halfdim) - scale = reshape( + scale = typeof(x)(reshape( [i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n], ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))), - ) + )) project_x = ChainRulesCore.ProjectTo(x) function rfft_pullback(ȳ) @@ -72,10 +72,10 @@ function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims) n = size(x, halfdim) invN = AbstractFFTs.normalization(y, dims) twoinvN = 2 * invN - scale = reshape( + scale = typeof(x)(reshape( [i == 1 || (i == n && 2 * (i - 1) == d) ? invN : twoinvN for i in 1:n], ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))), - ) + )) project_x = ChainRulesCore.ProjectTo(x) function irfft_pullback(ȳ) @@ -111,10 +111,10 @@ function ChainRulesCore.rrule(::typeof(brfft), x::AbstractArray, d::Int, dims) # compute scaling factors halfdim = first(dims) n = size(x, halfdim) - scale = reshape( + scale = typeof(x)(reshape( [i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n], ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))), - ) + )) project_x = ChainRulesCore.ProjectTo(x) function brfft_pullback(ȳ)