Skip to content

Commit

Permalink
Accidentally broke EM iteration. This release fixes that.
Browse files Browse the repository at this point in the history
  • Loading branch information
Chris Geoga committed Mar 21, 2023
1 parent 0b6f981 commit afb90f1
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 14 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Vecchia"
uuid = "8d73829f-f4b0-474a-9580-cecc8e084068"
authors = ["Chris Geoga <[email protected]>"]
version = "0.9.7"
version = "0.9.8"

[deps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand Down
12 changes: 6 additions & 6 deletions src/em.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
# Putting the struct here instead of in structstypes.jl as a violation of my own
# stye rules. It just only gets used here. Maybe there is a lesson about how to
# organize code in this choice somewhere...
struct ExpectedJointNll{C,R} <: Function
cfg::C
struct ExpectedJointNll{H,D,F,R} <: Function
cfg::VecchiaConfig{H,D,F}
errormodel::R
data_minus_z0::Matrix{Float64}
presolved_saa::Matrix{Float64}
Expand All @@ -12,18 +12,18 @@ end

# Trying to move to callable structs instead of closures so that the
# precompilation can be better...
function (E::ExpectedJointNll{C})(p) where{C}
function (E::ExpectedJointNll{H,D,F})(p::AbstractVector{T}) where{H,D,F,T}
# Like with the normal nll function, this section handles the things that
# create type instability, and then passes them to _nll so that the function
# barrier means that everything _inside_ _nll, which we want to be fast and
# multithreaded, is stable and non-allocating.
Z = promote_type(eltype(first(E.cfg.data)), eltype(p))
nthr = Threads.nthreads()
Z = promote_type(H,T)
ndata = size(E.data_minus_z0, 2)
# compute the following terms at once using the augmented data:
# - nll(V, z0)
# - (2M)^{-1} sum_j \norm[2]{U(\p)^T v_j}^2, w/ v_j the pre-solved SAA.
(logdets, qforms) = _nll(E.cfg, p, Val(nthr), Val(Z))
pieces = split_nll_pieces(E.cfg, Val(Z), Threads.nthreads())
(logdets, qforms) = _nll(pieces, p)
out = (logdets*ndata + qforms)/2
# add on the generic nll for the measurement noise and the quadratic forms
# with the error matrix that contribute to the trace term.
Expand Down
21 changes: 15 additions & 6 deletions src/nll.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,31 @@ function (vp::VecchiaLikelihoodPiece{H,D,F,T})(p) where{H,D,F,T}
out_logdet += ldj
out_qforms += qfj
end
(out_logdet*size(first(vp.cfg.data), 2) + out_qforms)/2
(out_logdet, out_qforms)
end

function nll(V::VecchiaConfig{H,D,F}, params::AbstractVector{T}) where{H,D,F,T}
checkthreads()
Z = promote_type(H,T)
ndata = size(first(V.data), 2)
pieces = split_nll_pieces(V, Val(Z), Threads.nthreads())
_nll(pieces, params)
(logdets, qforms) = _nll(pieces, params)
(logdets*ndata + qforms)/2
end

function _nll(pieces, params)
out = zeros(eltype(params), length(pieces))
function _nll(pieces::Vector{VecchiaLikelihoodPiece{H,D,F,T}},
params) where{H,D,F,T}
logdets = zeros(eltype(params), length(pieces))
qforms = zeros(eltype(params), length(pieces))
@sync for j in eachindex(pieces)
Threads.@spawn (out[j] = pieces[j](params))
Threads.@spawn begin
pj = pieces[j]
(ldj, qfj) = pj(params)
logdets[j] = ldj
qforms[j] = qfj
end
end
sum(out)
(sum(logdets), sum(qforms))
end

function cnll_str(V::VecchiaConfig{H,D,F}, j::Int,
Expand Down
12 changes: 11 additions & 1 deletion src/structstypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,22 @@ end
# thread, and then parallelize the calls to ReverseDiff.gradient!.
#
# see the method definition in ./nll.jl.
struct VecchiaLikelihoodPiece{H,D,F,T} <: Function
struct VecchiaLikelihoodPiece{H,D,F,T}
cfg::VecchiaConfig{H,D,F}
buf::CondLogLikBuf{D,T}
ixrange::UnitRange{Int64}
end

struct PieceEvaluation{H,D,F,T} <: Function
piece::VecchiaLikelihoodPiece{H,D,F,T}
end

function (c::PieceEvaluation{H,D,F,T})(p) where{H,D,F,T}
(logdets, qforms) = c.piece(p)
ndata = size(first(c.piece.cfg.data), 2)
(ndata*logdets + qforms)/2
end

struct CondRCholBuf{D,T}
buf_pp::Matrix{T}
buf_cp::Matrix{T}
Expand Down

0 comments on commit afb90f1

Please sign in to comment.