Skip to content

Commit

Permalink
Nontrivial refactoring of nll to better compartmentalize the
Browse files Browse the repository at this point in the history
operations that are done on a single thread and the logic of splitting
those up to parallelize. Ultimately hoping that this re-org will make it
possible to easily use `ReverseDiff.gradient!` on a compiled tape in
parallel. Working on it...
  • Loading branch information
Chris Geoga committed Mar 20, 2023
1 parent c51816c commit 0b6f981
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 28 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Vecchia"
uuid = "8d73829f-f4b0-474a-9580-cecc8e084068"
authors = ["Chris Geoga <[email protected]>"]
version = "0.9.6"
version = "0.9.7"

[deps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand Down
47 changes: 21 additions & 26 deletions src/nll.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,37 +4,32 @@ function negloglik(U::UpperTriangular, y_mut_allowed::AbstractMatrix{T}) where{T
(2*logdet(U), sum(_square, y_mut_allowed))
end

# Call method for a VecchiaLikelihoodPiece (the struct itself is declared in
# ./structstypes.jl). It is kept in this file because it is the core of the nll
# computation: walk the piece's index range, accumulate the two terms returned
# by cnll_str, and combine them at the end.
function (vp::VecchiaLikelihoodPiece{H,D,F,T})(p) where{H,D,F,T}
    P = eltype(p)
    logdet_acc = zero(P)
    qform_acc  = zero(P)
    for ix in vp.ixrange
        # cnll_str hands back a (logdet, quadratic form) pair for index ix.
        ld, qf = cnll_str(vp.cfg, ix, vp.buf, p)
        logdet_acc += ld
        qform_acc  += qf
    end
    # Number of columns in the first data block; the logdet term is scaled by it.
    ncols = size(first(vp.cfg.data), 2)
    return (logdet_acc*ncols + qform_acc)/2
end

# Entry point for evaluating the (presumably negative log-likelihood) objective
# of the Vecchia configuration V at the parameter vector params.
#
# NOTE(review): this span comes from a rendered diff with no +/- markers, so
# removed and added lines are interleaved. The statements from `ndata = ...`
# through `(logdets*ndata + qforms)/2` look like the pre-commit body, and the
# final three statements look like the post-commit body. Confirm against the
# repository before relying on this text.
function nll(V::VecchiaConfig{H,D,F}, params::AbstractVector{T}) where{H,D,F,T}
checkthreads()
# Z is the promotion of the config's type parameter H and the params eltype T.
Z = promote_type(H,T)
ndata = size(V.data[1], 2)
cpts_sz = V.chunksize*V.blockrank
pts_sz = V.chunksize
nthr = Threads.nthreads()
bufs = allocate_cnll_bufs(nthr, Val(D), Val(Z), ndata, cpts_sz, pts_sz)
(logdets, qforms) = _nll(V, params, bufs)
(logdets*ndata + qforms)/2
Z = promote_type(H,T)
# Split the work into at most Threads.nthreads() pieces and let _nll evaluate
# them and reduce the results.
pieces = split_nll_pieces(V, Val(Z), Threads.nthreads())
_nll(pieces, params)
end

function _nll(V::VecchiaConfig{H,D,F}, params::AbstractVector{T},
bufs::Vector{CondLogLikBuf{D,Z}})::Tuple{Z,Z} where{H,D,F,T,Z}
kernel = V.kernel
out_logdet = zeros(Z, length(bufs))
out_qforms = zeros(Z, length(bufs))
# Note that I'm not just using Threads.@threads for [...] and then getting
# buffers with bufs[Threads.threadid()], because this has the potential for
# some soundness issues. Further reading:
# https://discourse.julialang.org/t/behavior-of-threads-threads-for-loop/76042
m = cld(length(V.condix), Threads.nthreads())
@sync for (i, chunk) in enumerate(Iterators.partition(eachindex(V.condix), m))
tbuf = bufs[i]
Threads.@spawn for j in chunk
(ldj, qfj) = cnll_str(V, j, tbuf, params)
out_logdet[i] += ldj
out_qforms[i] += qfj
end
# Evaluate every element of pieces at params, one spawned task per piece, and
# return the sum of the per-piece results.
#
# NOTE(review): the `sum(out_logdet), sum(out_qforms)` line below looks like a
# removed line left over from the diff rendering (neither name is defined in
# this method) — confirm against the repository.
# NOTE(review): out is typed by eltype(params); if a piece can return a wider
# type than that, the assignment into out would need a conversion — confirm.
function _nll(pieces, params)
# One output slot per piece, so each task writes to its own index.
out = zeros(eltype(params), length(pieces))
# @sync blocks until all tasks spawned in the loop have finished.
@sync for j in eachindex(pieces)
Threads.@spawn (out[j] = pieces[j](params))
end
sum(out_logdet), sum(out_qforms)
sum(out)
end

function cnll_str(V::VecchiaConfig{H,D,F}, j::Int,
Expand Down
16 changes: 16 additions & 0 deletions src/structstypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ end
# Evaluate the quadratic model at p: fk + <gk, p> + <p, hk*p>/2.
function (m::LocalQuadraticApprox)(p)
    linear_part    = dot(m.gk, p)
    quadratic_part = dot(p, m.hk, p)/2
    return m.fk + linear_part + quadratic_part
end



# TODO (cg 2021/04/25 13:06): should these fields chunksize and blockrank be in
# here? Arguably they are redundant and encoded in the data/pts/condix values.
# And having them sort of provides a dangerously easy option to not check and
Expand All @@ -97,6 +98,8 @@ function Base.display(V::VecchiaConfig)
println(" - nsamples: $(size(V.data[1], 2))")
end



# TODO (cg 2021/04/25 13:06): should these fields chunksize and blockrank be in
# here? Arguably they are redundant and encoded in the data/pts/condix values.
# And having them sort of provides a dangerously easy option to not check and
Expand Down Expand Up @@ -137,6 +140,19 @@ function cnllbuf(::Val{D}, ::Val{Z}, ndata, cpts_sz, pts_sz) where{D,Z}
CondLogLikBuf{D,Z}(buf_pp, buf_cp, buf_cc, buf_cdat, buf_mdat, buf_cpts)
end

# A piece of a Vecchia approximation with a single-argument method. Split up
# like this because using ReverseDiff.gradient on the thread-parallel nll
# doesn't work, and so breaking it into pieces like this means that I can more
# easily compile tapes for the chunks that would each be evaluated on a single
# thread, and then parallelize the calls to ReverseDiff.gradient!.
#
# Fields:
#   cfg     -- the full VecchiaConfig this piece belongs to.
#   buf     -- the CondLogLikBuf handed to cnll_str on every call of this piece.
#   ixrange -- the indices j that this piece passes to cnll_str. Typed with Int
#              (not a hard-coded Int64) so that it matches the ranges produced
#              by eachindex and the `j::Int` argument of cnll_str on every
#              platform.
#
# see the method definition in ./nll.jl.
struct VecchiaLikelihoodPiece{H,D,F,T} <: Function
    cfg::VecchiaConfig{H,D,F}
    buf::CondLogLikBuf{D,T}
    ixrange::UnitRange{Int}
end

struct CondRCholBuf{D,T}
buf_pp::Matrix{T}
buf_cp::Matrix{T}
Expand Down
13 changes: 12 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,21 @@ function globalidxs(datavv)
end

# Build a Vector of N buffers, each produced by its own call to cnllbuf with the
# same type tags (D, Z) and sizes (ndata, cpts_sz, pts_sz).
#
# NOTE(review): the signature's continuation line appears twice below; this
# looks like a whitespace-only change shown as a removed/added pair in the diff
# rendering — confirm against the repository.
function allocate_cnll_bufs(N, ::Val{D}, ::Val{Z},
ndata, cpts_sz, pts_sz) where{D,Z}
ndata, cpts_sz, pts_sz) where{D,Z}
[cnllbuf(Val(D), Val(Z), ndata, cpts_sz, pts_sz) for _ in 1:N]
end

# Partition the indices of V.pts into consecutive ranges of length
# cld(length(V.pts), m) (the last one may be shorter) and wrap each range in a
# VecchiaLikelihoodPiece that carries its own freshly allocated buffer.
function split_nll_pieces(V::VecchiaConfig{H,D,F}, ::Val{Z}, m) where{H,D,F,Z}
    ncols     = size(first(V.data), 2)
    block_sz  = V.chunksize
    cond_sz   = block_sz*V.blockrank
    piece_len = cld(length(V.pts), m)
    ranges    = Iterators.partition(eachindex(V.pts), piece_len)
    # Each piece gets a separate buffer so pieces never share scratch space.
    return [VecchiaLikelihoodPiece(V,
                                   cnllbuf(Val(D), Val(Z), ncols, cond_sz, block_sz),
                                   r)
            for r in ranges]
end

@generated function allocate_crchol_bufs(::Val{N}, ::Val{D}, ::Val{Z},
cpts_sz, pts_sz) where{N,D,Z}
quote
Expand Down
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@

using Test, LinearAlgebra, StaticArrays, StableRNGs, Vecchia, SparseArrays

BLAS.set_num_threads(1)

# TODO (cg 2022/12/23 15:18):
# 1) Any EM tests
# 2) Any sqp/tr tests
Expand Down

0 comments on commit 0b6f981

Please sign in to comment.