Skip to content

Commit

Permalink
Add support for complex-valued data and kernels.
Browse files Browse the repository at this point in the history
TODO: I really should move some elements of the Project.toml tree into
weakdeps with extensions.
  • Loading branch information
Chris Geoga committed Oct 30, 2023
1 parent 13a6255 commit 17f21a0
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 52 deletions.
2 changes: 1 addition & 1 deletion src/em.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ function prepare_z0_SR0(cfg, arg, data, errormodel)
n = size(data, 1)
Rinv = error_precision(errormodel, arg)
U = sparse(Vecchia.rchol(cfg, arg, issue_warning=false))
SR = Symmetric(U*U' + Rinv)
SR = Hermitian(U*U' + Rinv)
SRf = cholesky(SR, perm=n:-1:1) # for now
(SRf\(Rinv*data), SRf)
end
Expand Down
6 changes: 3 additions & 3 deletions src/nll.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ function cnll_str(V::VecchiaConfig{H,D,F}, j::Int,
updatebuf!(cov_pp, pts, pts, V.kernel, params, skipltri=false)
# if the conditioning set is empty, just return the marginal nll:
if isempty(idxs)
cov_pp_f = cholesky!(Symmetric(cov_pp))
cov_pp_f = cholesky!(Hermitian(cov_pp))
return negloglik(cov_pp_f.U, mdat)
end
# otherwise, proceed and prepare conditioning points:
Expand All @@ -65,7 +65,7 @@ function cnll_str(V::VecchiaConfig{H,D,F}, j::Int,
updatebuf!(cov_cc, cpts, cpts, V.kernel, params, skipltri=true)
updatebuf!(cov_cp, cpts, pts, V.kernel, params, skipltri=false)
# Factorize the covariance matrix for the conditioning points:
cov_cc_f = cholesky!(Symmetric(cov_cc))
cov_cc_f = cholesky!(Hermitian(cov_cc))
# Before mutating the cross-covariance buffer, compute y - hat{y}, where
# hat{y} is the conditional expectation of y given the conditioning data.
ldiv!(cov_cc_f, cdat)
Expand All @@ -74,7 +74,7 @@ function cnll_str(V::VecchiaConfig{H,D,F}, j::Int,
# out any unnecessary allocations.
ldiv!(cov_cc_f.U', cov_cp)
mul!(cov_pp, adjoint(cov_cp), cov_cp, -one(T), one(T))
cov_pp_cond = cholesky!(Symmetric(cov_pp))
cov_pp_cond = cholesky!(Hermitian(cov_pp))
# compute the log-likelihood:
negloglik(cov_pp_cond.U, mdat)
end
Expand Down
14 changes: 7 additions & 7 deletions src/rcholesky.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ end

# TODO (cg 2022/05/30 11:05): Continually look to squeeze allocations out of
# here. Maybe I can pre-allocate things for the BLAS calls, even?
function rchol_instantiate!(strbuf::RCholesky{T}, V::VecchiaConfig{H,D,F},
function rchol_instantiate!(strbuf::RCholesky, V::VecchiaConfig{H,D,F},
params::AbstractVector{T}, ::Val{Z}, tiles) where{H,D,F,T,Z}
checkthreads()
@assert !strbuf.is_instantiated[] RCHOL_INSTANTIATE_ERROR
strbuf.is_instantiated[] = true
kernel = V.kernel
cpts_sz = V.chunksize*V.blockrank
pts_sz = V.chunksize
cpts_sz = chunksize(V)*blockrank(V)
pts_sz = chunksize(V)
# allocate three buffers:
bufs = allocate_crchol_bufs(Threads.nthreads(), Val(D), Val(Z), cpts_sz, pts_sz)
# do the main loop:
Expand All @@ -49,7 +49,7 @@ function rchol_instantiate!(strbuf::RCholesky{T}, V::VecchiaConfig{H,D,F},
else
updatebuf_tiles!(cov_pp, tiles, j, j)
end
cov_pp_f = cholesky!(Symmetric(cov_pp))
cov_pp_f = cholesky!(Hermitian(cov_pp))
buf = strbuf.diagonals[j]
ldiv!(cov_pp_f.U, buf)
else
Expand Down Expand Up @@ -77,11 +77,11 @@ function rchol_instantiate!(strbuf::RCholesky{T}, V::VecchiaConfig{H,D,F},
# acknowledge that this is a little hard to read, but it really nicely cuts
# out all the unnecessary allocations. If you do a manual check, you can
# confirm that cov_pp becomes the conditional covariance of pts | cpts, etc.
cov_cc_f = cholesky!(Symmetric(cov_cc))
cov_cc_f = cholesky!(Hermitian(cov_cc))
ldiv!(cov_cc_f.U', cov_cp)
mul!(cov_pp, adjoint(cov_cp), cov_cp, -one(Z), one(Z))
ldiv!(cov_cc_f.U, cov_cp)
Djf = cholesky!(Symmetric(cov_pp))
Djf = cholesky!(Hermitian(cov_pp))
# Update the struct buffers. Note that the diagonal elements are actually
# UpperTriangular, and I am not supposed to mutate those. But we do the ugly
# hack of directly working with the data buffer backing the UpperTriangular
Expand All @@ -108,7 +108,7 @@ function rchol(V::VecchiaConfig{H,D,F}, params::AbstractVector{T};
@warn RCHOL_WARN maxlog=1
end
# allocate:
out = RCholesky_alloc(V, Val(T))
out = RCholesky_alloc(V, Val(promote_type(T,H)))
# compute the out type and the number of threads to pass in as vals:
Z = promote_type(H, T)
# create tiles if requested:
Expand Down
42 changes: 8 additions & 34 deletions src/structstypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ end
function (f::AutoFwdfgh{F,R})(x) where{F,R}
ForwardDiff.hessian!(f.res, f.f, x)
(DiffResults.value(f.res), DiffResults.gradient(f.res),
Symmetric(DiffResults.hessian(f.res)))
Hermitian(DiffResults.hessian(f.res)))
end

struct AutoFwdBFGS{F,R}
Expand Down Expand Up @@ -64,14 +64,14 @@ function (f::AutoFwdBFGS{F,R})(x) where{F,R}
# update the xm1:
f.xm1 .= x
# return everything:
(DiffResults.value(f.res), f.g, Symmetric(f.B))
(DiffResults.value(f.res), f.g, Hermitian(f.B))
end

# Writing a local quadratic approximation struct to avoid creating a closure.
struct LocalQuadraticApprox
fk::Float64
gk::Vector{Float64}
hk::Symmetric{Float64, Matrix{Float64}}
hk::Hermitian{Float64, Matrix{Float64}}
end
(m::LocalQuadraticApprox)(p) = m.fk + dot(m.gk, p) + dot(p, m.hk, p)/2

Expand All @@ -90,8 +90,6 @@ end
# And having them sort of provides a dangerously easy option to not check and
# make sure what those sizes really need to be.
struct VecchiaConfig{H,D,F} <: AbstractVecchiaConfig{H,D,F}
chunksize::Int64
blockrank::Int64
kernel::F
data::Vector{Matrix{H}}
pts::Vector{Vector{SVector{D, Float64}}}
Expand All @@ -100,35 +98,12 @@ end

function Base.display(V::VecchiaConfig)
println("Vecchia configuration with:")
println(" - chunksize: $(V.chunksize)")
println(" - block rank: $(V.blockrank)")
println(" - chunksize: $(chunksize(V))")
println(" - block rank: $(blockrank(V))")
println(" - data size: $(sum(x->size(x,1), V.data))")
println(" - nsamples: $(size(V.data[1], 2))")
end



# TODO (cg 2021/04/25 13:06): should these fields chunksize and blockrank be in
# here? Arguably the are redundant and encoded in the data/pts/condix values.
# And having them sort of provides a dangerously easy option to not check and
# make sure what those sizes really need to be.
struct ScalarVecchiaConfig{H,D,F} <: AbstractVecchiaConfig{H,D,F}
chunksize::Int64
blockrank::Int64
kernel::F
data::Vector{Matrix{H}}
pts::Vector{Vector{Float64}}
condix::Vector{Vector{Int64}}
end

function Base.display(V::ScalarVecchiaConfig)
println("Scalarized Vecchia configuration with:")
println("chunksize: $(V.chunksize)")
println("block rank: $(V.blockrank)")
println("data size: $(sum(x->size(x,1), V.data))")
println("nsamples: $(size(V.data[1], 2))")
end

struct CondLogLikBuf{D,T}
buf_pp::Matrix{T}
buf_cp::Matrix{T}
Expand Down Expand Up @@ -235,6 +210,7 @@ end
# with a KD-tree to choose the conditioning points.
function kdtreeconfig(data, pts, chunksize, blockrank, kfun)
(data isa Vector) && return kdtreeconfig(hcat(data), pts, chunksize, blockrank, kfun)
size(data, 1) == length(pts) || @warn "Your input data and points don't have the same length. Consider checking your code for mistakes."
# Make a KDTree of the points with a certain leaf size
tree = KDTree(pts, leafsize=chunksize)
# re-order the data accordingly:
Expand All @@ -258,9 +234,7 @@ function kdtreeconfig(data, pts, chunksize, blockrank, kfun)
# Create the conditioning meta-indices for the chunks.
condix = [cond_ixs(j,blockrank) for j in eachindex(pts_out)]
(H,D,F) = (eltype(data), length(first(pts)), typeof(kfun))
VecchiaConfig{H,D,F}(min(chunksize, length(first(pts_out))),
min(blockrank, length(pts_out)),
kfun, dat_out, pts_out, condix)
VecchiaConfig{H,D,F}(kfun, dat_out, pts_out, condix)
end

function nosortknnconfig(data, pts, blockranks, kfun)
Expand All @@ -275,7 +249,7 @@ function nosortknnconfig(data, pts, blockranks, kfun)
end
pts = [[x] for x in pts]
dat = map(x->Matrix(x'), eachrow(data))
VecchiaConfig(1, maximum(blockranks), kfun, dat, pts, condix)
VecchiaConfig(kfun, dat, pts, condix)
end

function nosortknnconfig(data, pts, blockrank::Int64, kfun)
Expand Down
16 changes: 9 additions & 7 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

_square(x::Real) = x*x
_square(x::Complex) = real(x*conj(x))

_mean(x) = sum(x)/length(x)

Expand All @@ -11,6 +12,9 @@ function checkthreads()::Nothing
nothing
end

blockrank(cfg::VecchiaConfig) = maximum(length, cfg.condix)
chunksize(cfg::VecchiaConfig) = maximum(length, cfg.pts)

# A hacky function to return an empty Int64[] for the first conditioning set.
@inline cond_ixs(j, r) = j == 1 ? Int64[] : collect(max(1,j-r):max(1,j-1))

Expand Down Expand Up @@ -128,7 +132,7 @@ function _nnz(vchunks, condix)
end

function generic_dense_nll(S, data)
Sf = cholesky(Symmetric(S))
Sf = cholesky(Hermitian(S))
(logdet(Sf) + sum(abs2, Sf.U'\data))/2
end

Expand Down Expand Up @@ -171,8 +175,7 @@ end
function vecchia_estimate_nugget(cfg, init, optimizer, errormodel;
optimizer_kwargs...)
nugkernel = ErrorKernel(cfg.kernel, errormodel)
nug_cfg = Vecchia.VecchiaConfig(cfg.chunksize, cfg.blockrank,
nugkernel, cfg.data, cfg.pts, cfg.condix)
nug_cfg = Vecchia.VecchiaConfig(nugkernel, cfg.data, cfg.pts, cfg.condix)
likelihood = WrappedLogLikelihood(nug_cfg)
optimizer(likelihood, init; optimizer_kwargs...)
end
Expand All @@ -189,8 +192,7 @@ function augmented_em_cfg(V::VecchiaConfig{H,D,F}, z0, presolved_saa) where{H,D,
new_data = map(chunksix) do ixj
hcat(z0[ixj,:], presolved_saa[ixj,:])
end
Vecchia.VecchiaConfig{H,D,F}(V.chunksize, V.blockrank, V.kernel,
new_data, V.pts, V.condix)
Vecchia.VecchiaConfig{H,D,F}(V.kernel, new_data, V.pts, V.condix)
end

function globalidxs(datavv)
Expand All @@ -210,8 +212,8 @@ end

function split_nll_pieces(V::VecchiaConfig{H,D,F}, ::Val{Z}, m) where{H,D,F,Z}
ndata = size(first(V.data), 2)
cpts_sz = V.chunksize*V.blockrank
pts_sz = V.chunksize
cpts_sz = chunksize(V)*blockrank(V)
pts_sz = chunksize(V)
chunks = Iterators.partition(eachindex(V.pts), cld(length(V.pts), m))
map(chunks) do chunk
local_buf = cnllbuf(Val(D), Val(Z), ndata, cpts_sz, pts_sz)
Expand Down

0 comments on commit 17f21a0

Please sign in to comment.