diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 6956842462..252dc3f4a5 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -107,7 +107,29 @@ function enzyme_custom_setup_args( val = lookup_value(gutils, val, B) end - activep = API.EnzymeGradientUtilsGetDiffeType(gutils, op, false) #=isforeign=# + cmode = mode + if cmode == API.DEM_ReverseModeGradient + cmode = API.DEM_ReverseModePrimal + end + uncacheable = Vector{UInt8}(undef, length(collect(LLVM.operands(op))) - 1) + activep = + if mode == API.DEM_ForwardMode || + API.EnzymeGradientUtilsGetUncacheableArgs( + gutils, + op, + uncacheable, + length(uncacheable), + ) == 1 + API.EnzymeGradientUtilsGetReturnDiffeType( + gutils, + op, + C_NULL, + C_NULL, + cmode, + ) + else + API.EnzymeGradientUtilsGetDiffeType(gutils, op, false) + end if isKWCall && arg.arg_i == 2 Ty = arg.typ