Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tsplit #2018

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft

Tsplit #2018

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "0.13.12"
[deps]
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
Enzyme_jll = "7cc45869-7501-5eee-bdea-0790c847d4ef"
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
Expand Down
2 changes: 2 additions & 0 deletions src/absint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ function absint(arg::LLVM.Value, partial::Bool = false)
return (false, nothing)
end
ptr = unsafe_load(reinterpret(Ptr{Ptr{Cvoid}}, convert(UInt, ce)))
@show ptr
if ptr == C_NULL
# bt = GPUCompiler.backtrace(arg)
# btstr = sprint() do io
Expand All @@ -155,6 +156,7 @@ function absint(arg::LLVM.Value, partial::Bool = false)
return (false, nothing)
end
typ = Base.unsafe_pointer_to_objref(ptr)
@show typ
return (true, typ)
end
return (false, nothing)
Expand Down
24 changes: 22 additions & 2 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1615,8 +1615,8 @@ function julia_error(
legal2, obj = absint(cur)

# Only do so for the immediate operand/etc to a phi, since otherwise we will make multiple
if legal2 &&
active_reg_inner(TT, (), world) == ActiveState &&
if legal2
if active_reg_inner(TT, (), world) == ActiveState &&
isa(cur, LLVM.ConstantExpr) &&
cur == data2
if width == 1
Expand All @@ -1634,6 +1634,14 @@ function julia_error(
end
return shadowres
end
end

@static if VERSION < v"1.11-"
else
if obj isa Memory && obj === typeof(obj).instance
return make_batched(ncur, prevbb)
end
end
end

badval = if legal2
Expand Down Expand Up @@ -6157,6 +6165,8 @@ function GPUCompiler.codegen(
expectLen += 1
end

@show mi, RT, LLVM.name(f)

# Unsupported calling conv
# also wouldn't have any type info for this [would for earlier args though]
if mi.specTypes.parameters[end] === Vararg{Any}
Expand Down Expand Up @@ -6322,6 +6332,14 @@ function GPUCompiler.codegen(

func = mi.specTypes.parameters[1]

@static if VERSION < v"1.11-"
else
if func == typeof(Core.memoryref)
attributes = function_attributes(llvmfn)
push!(attributes, EnumAttribute("alwaysinline", 0))
end
end

meth = mi.def
name = meth.name
jlmod = meth.module
Expand Down Expand Up @@ -7092,6 +7110,8 @@ function GPUCompiler.codegen(
end
end

mark_type_split!(mod)

if params.run_enzyme
# Generate the adjoint
memcpy_alloca_to_loadstore(mod)
Expand Down
5 changes: 5 additions & 0 deletions src/compiler/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,11 @@ end

function is_alwaysinline_func(@nospecialize(TT))
isa(TT, DataType) || return false
@static if VERSION ≥ v"1.11-"
if TT.parameters[1] == typeof(Core.memoryref)
return true
end
end
return false
end

Expand Down
62 changes: 62 additions & 0 deletions src/compiler/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -435,10 +435,72 @@ function unwrap_ptr_casts(val::LLVM.Value)
end
end

function mark_type_split!(mod::LLVM.Module)
for f in LLVM.functions(mod), bb in blocks(f)
inst = LLVM.terminator(bb)
@show string(inst)
if isa(inst, LLVM.SwitchInst) || (isa(inst, LLVM.BrInst) && LLVM.isconditional(inst))
cond = if isa(inst, LLVM.BrInst)
LLVM.condition(inst)
else
operands(inst)[1]
end
x = cond
while true
if isa(x, LLVM.BitCastInst) || isa(x, LLVM.AddrSpaceCastInst) || isa(x, LLVM.PtrToIntInst) || isa(x, LLVM.IntToPtrInst)
x = operands(x)[1]
continue
end
if isa(x, LLVM.AddInst) || isa(x, LLVM.SubInst)
if isa(operands(x)[2], LLVM.ConstantInt)
x = operands(x)[1]
continue
end
if isa(operands(x)[1], LLVM.ConstantInt)
x = operands(x)[2]
continue
end
end
if isa(x, LLVM.CallInst)
cv = LLVM.called_value(x)
if cv isa LLVM.Function
nm = LLVM.name(cv)
if nm == "julia.pointer_from_objref"
x = operands(x)[1]
continue
end
intr = LLVM.API.LLVMGetIntrinsicID(cv)
@show cv, intr, LLVM.Intrinsic("llvm.fshl").id, (intr == LLVM.Intrinsic("llvm.fshl").id), isa(operands(x)[2], LLVM.ConstantInt)
if intr == LLVM.Intrinsic("llvm.fshl").id
x = operands(x)[1]
@show "post fshl", x
continue
end
end
end
break
end
@show string(cond), string(x), string(inst)
if isa(x, LLVM.CallInst)
cv = LLVM.called_value(x)
if cv isa LLVM.Function
nm = LLVM.name(cv)
if nm == "julia.typeof"
metadata(inst)["enzyme_notypeprop"] = MDNode(LLVM.Metadata[])
@show "added", string(inst)
end
end
end
flush(stdout)
end
end
end

function check_ir!(job, errors, imported, f::LLVM.Function, deletedfns)
calls = []
isInline = API.EnzymeGetCLBool(cglobal((:EnzymeInline, API.libEnzyme))) != 0
mod = LLVM.parent(f)
println(string(f))
for bb in blocks(f), inst in collect(instructions(bb))
if isa(inst, LLVM.CallInst)
push!(calls, inst)
Expand Down
18 changes: 18 additions & 0 deletions test/abi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -590,5 +590,23 @@ end

Enzyme.autodiff(Forward, byrefs, BatchDuplicated([1.0], ([1.0], [1.0])), BatchDuplicated([1.0], ([1.0], [1.0]) ) )
end

function myunique0()
return Vector{Float64}(undef, 0)
end
@static if VERSION < v"1.11-"
@testset "Forward mode array construct" begin
autodiff(Forward, myunique0, Duplicated)
end
else
function myunique()
m = Memory{Float64}.instance
return Core.memoryref(m)
end
@testset "Forward mode array construct" begin
autodiff(Forward, myunique, Duplicated)
autodiff(Forward, myunique0, Duplicated)
end
end

include("usermixed.jl")