Skip to content

Commit

Permalink
Merge pull request #225 from amarquand/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
amarquand authored Dec 18, 2024
2 parents aacd7bf + 0d8c8e0 commit 17e71d0
Show file tree
Hide file tree
Showing 19 changed files with 859 additions and 2,699 deletions.
11 changes: 10 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -96,4 +96,13 @@ tests/cli_test/*

docs/autoapi/*
docs/_build/*
docs
docs
tests/cli_test_parallel_kfold/temp/be_te_fcon1000.pkl
tests/cli_test_parallel_kfold/temp/be_tr_fcon1000.pkl
tests/cli_test_parallel_kfold/temp/fcon1000
tests/cli_test_parallel_kfold/temp/X_te_fcon1000.pkl
tests/cli_test_parallel_kfold/temp/X_tr_fcon1000.pkl
tests/cli_test_parallel_kfold/temp/X_var_te_fcon1000.pkl
tests/cli_test_parallel_kfold/temp/X_var_tr_fcon1000.pkl
tests/cli_test_parallel_kfold/temp/Y_te_fcon1000.pkl
tests/cli_test_parallel_kfold/temp/Y_tr_fcon1000.pkl
247 changes: 124 additions & 123 deletions pcntoolkit/model/hbr.py

Large diffs are not rendered by default.

249 changes: 169 additions & 80 deletions pcntoolkit/normative.py

Large diffs are not rendered by default.

150 changes: 79 additions & 71 deletions pcntoolkit/normative_model/norm_hbr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,21 @@
@author: augub
"""

from __future__ import print_function
from __future__ import division
from sys import exit

from itertools import product
from __future__ import division, print_function

import os
import warnings
import sys
from sys import exit

import xarray
import arviz as az
import numpy as np
import xarray
from scipy import special as spp
from ast import literal_eval as make_tuple

try:
from pcntoolkit.dataio import fileio
from pcntoolkit.normative_model.norm_base import NormBase
from pcntoolkit.model.hbr import HBR
from pcntoolkit.normative_model.norm_base import NormBase
except ImportError:
pass

Expand Down Expand Up @@ -111,8 +106,9 @@ class NormHBR(NormBase):
but higher values like 0.9 or 0.95 often work better for problematic posteriors.
:param order: String that defines the order of bspline or polynomial model.
The default is '3'.
:param nknots: String that defines the numbers of knots for the bspline model.
The default is '5'. Higher values increase the model complexity with negative
:param nknots: String that defines the number of interior knots for the bspline model.
The default is '3'. Two knots will be added to this number for the boundaries, so the final
number of knots will be nknots+2. Higher values increase the model complexity with a negative
effect on the speed of estimation.
:param nn_hidden_layers_num: String that specifies the number of hidden layers
in neural network model. It can be either '1' or '2'. The default is set to '2'.
Expand Down Expand Up @@ -158,7 +154,7 @@ def __init__(self, **kwargs):

if self.configs["type"] == "bspline":
self.configs["order"] = int(kwargs.get("order", "3"))
self.configs["nknots"] = int(kwargs.get("nknots", "5"))
self.configs["nknots"] = int(kwargs.get("nknots", "3"))
elif self.configs["type"] == "polynomial":
self.configs["order"] = int(kwargs.get("order", "3"))
elif self.configs["type"] == "nn":
Expand Down Expand Up @@ -286,7 +282,7 @@ def estimate(self, X, y, **kwargs):
self.hbr.estimate(X, y, batch_effects_train)

return self

def predict(self, Xs, X=None, Y=None, **kwargs):
"""
Predict the target values for the given test data.
Expand All @@ -313,22 +309,22 @@ def predict(self, Xs, X=None, Y=None, **kwargs):

pred_type = self.configs["pred_type"]

if self.configs["transferred"] == False:
yhat, s2 = self.hbr.predict(
X=Xs,
batch_effects=batch_effects_test,
batch_effects_maps=self.batch_effects_maps,
pred=pred_type,
**kwargs,
)
else:
raise ValueError(
"This is a transferred model. Please use predict_on_new_sites function."
)
# if self.configs["transferred"] == False:
yhat, s2 = self.hbr.predict(
X=Xs,
batch_effects=batch_effects_test,
batch_effects_maps=self.batch_effects_maps,
pred=pred_type,
**kwargs,
)
# else:
# raise ValueError(
# "This is a transferred model. Please use predict_on_new_sites function."
# )

return yhat.squeeze(), s2.squeeze()

def estimate_on_new_sites(self, X, y, batch_effects):
def transfer(self, X, y, batch_effects):
"""
Samples from the posterior of the Hierarchical Bayesian Regression model.
Expand All @@ -342,7 +338,7 @@ def estimate_on_new_sites(self, X, y, batch_effects):
- 'trbefile': File containing the batch effects for the training data. Optional.
:return: The instance of the NormHBR object.
"""
self.hbr.estimate_on_new_site(X, y, batch_effects)
self.hbr.transfer(X, y, batch_effects)
self.configs["transferred"] = True
return self

Expand All @@ -357,9 +353,16 @@ def predict_on_new_sites(self, X, batch_effects):
:param batch_effects: Batch effects for the new sites.
:return: A tuple containing the predicted target values and the marginal variances for the test data on the new sites.
"""
yhat, s2 = self.hbr.predict_on_new_site(X, batch_effects)

yhat, s2 = self.hbr.predict(
X,
batch_effects=batch_effects,
batch_effects_maps=self.batch_effects_maps
)

return yhat, s2


def extend(
self,
X,
Expand All @@ -368,7 +371,7 @@ def extend(
X_dummy_ranges=[[0.1, 0.9, 0.01]],
merge_batch_dim=0,
samples=10,
informative_prior=False,
informative_prior=False
):
"""
Extend the Hierarchical Bayesian Regression model using data sampled from the posterior predictive distribution.
Expand All @@ -386,11 +389,11 @@ def extend(
:param informative_prior: Whether to use the adapt method for estimation. Default is False.
:return: The instance of the NormHBR object.
"""
X_dummy, batch_effects_dummy = self.hbr.create_dummy_inputs(
X_dummy_ranges)


X_dummy, batch_effects_dummy = self.hbr.create_dummy_inputs(X)
X_dummy, batch_effects_dummy, Y_dummy = self.hbr.generate(
X_dummy, batch_effects_dummy, samples
X_dummy, batch_effects_dummy, samples, batch_effects_maps=self.batch_effects_maps
)

batch_effects[:, merge_batch_dim] = (
Expand All @@ -399,18 +402,20 @@ def extend(
+ 1
)

X = np.concatenate((X_dummy, X))
y = np.concatenate((Y_dummy, y))
batch_effects = np.concatenate((batch_effects_dummy, batch_effects))

self.batch_effects_maps = [ {v: i for i, v in enumerate(np.unique(batch_effects[:, j]))}
for j in range(batch_effects.shape[1])
]

if informative_prior:
self.hbr.adapt(
np.concatenate((X_dummy, X)),
np.concatenate((Y_dummy, y)),
np.concatenate((batch_effects_dummy, batch_effects)),
)
#raise NotImplementedError("The extension with informative prior is not implemented yet.")
self.hbr.transfer(X, y, batch_effects)
else:
self.hbr.estimate(
np.concatenate((X_dummy, X)),
np.concatenate((Y_dummy, y)),
np.concatenate((batch_effects_dummy, batch_effects)),
)

self.hbr.estimate(X, y, batch_effects)

return self

Expand Down Expand Up @@ -522,44 +527,43 @@ def get_mcmc_quantiles(self, X, batch_effects=None, z_scores=None):
"""
# Set batch effects to zero if none are provided
if batch_effects is None:
batch_effects = batch_effects_test = np.zeros([X.shape[0], 1])
batch_effects = np.zeros([X.shape[0], 1])

# Set the z_scores for which the quantiles are computed
if z_scores is None:
z_scores = np.arange(-3, 4)
likelihood = self.configs["likelihood"]
elif len(z_scores.shape) == 2:
if not z_scores.shape[0] == X.shape[0]:
raise ValueError("The number of columns in z_scores must match the number of columns in X")
z_scores = z_scores.T

# Determine the variables to predict
if self.configs["likelihood"] == "Normal":
var_names = ["mu_samples", "sigma_samples", "sigma_plus_samples"]
elif self.configs["likelihood"].startswith("SHASH"):
var_names = [
"mu_samples",
"sigma_samples",
"sigma_plus_samples",
"epsilon_samples",
"delta_samples",
"delta_plus_samples",
]
else:
exit("Unknown likelihood: " + self.configs["likelihood"])
match self.configs["likelihood"]:
case "Normal":
var_names = ["mu_samples", "sigma_samples", "sigma_plus_samples"]
case "SHASHo" | "SHASHo2" | "SHASHb":
var_names = [
"mu_samples",
"sigma_samples",
"sigma_plus_samples",
"epsilon_samples",
"delta_samples",
"delta_plus_samples",
]
case _:
exit("Unknown likelihood: " + self.configs["likelihood"])

# Delete the posterior predictive if it already exists
if "posterior_predictive" in self.hbr.idata.groups():
del self.hbr.idata.posterior_predictive

if self.configs["transferred"] == True:
self.predict_on_new_sites(X=X, batch_effects=batch_effects)
# var_names = ["y_like"]
else:
self.hbr.predict(
# Do a forward to get the posterior predictive in the idata
X=X,
batch_effects=batch_effects,
batch_effects_maps=self.batch_effects_maps,
pred="single",
var_names=var_names + ["y_like"],
)
self.hbr.predict(
X=X,
batch_effects=batch_effects,
batch_effects_maps=self.batch_effects_maps,
pred="single",
var_names=var_names + ["y_like"],
)

# Extract the relevant samples from the idata
post_pred = az.extract(
Expand All @@ -580,8 +584,12 @@ def get_mcmc_quantiles(self, X, batch_effects=None, z_scores=None):
(z_scores.shape[0], len_synth_data, n_mcmc_samples))

# Compute the quantile iteratively for each z-score

for i, j in enumerate(z_scores):
zs = np.full((len_synth_data, n_mcmc_samples), j, dtype=float)
if len(z_scores.shape) == 1:
zs = np.full((len_synth_data, n_mcmc_samples), j, dtype=float)
else:
zs = np.repeat(j[:,None], n_mcmc_samples, axis=1)
quantiles[i] = xarray.apply_ufunc(
quantile,
*array_of_vars,
Expand Down
7 changes: 5 additions & 2 deletions pcntoolkit/normative_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ def execute_nm(processing_dir,
cv_folds = kwargs.get('cv_folds', None)
testcovfile_path = kwargs.get('testcovfile_path', None)
testrespfile_path = kwargs.get('testrespfile_path', None)
outputsuffix = kwargs.get('outputsuffix', '_estimate')
outputsuffix = kwargs.get('outputsuffix', 'estimate')
outputsuffix = "_" + outputsuffix.replace("_", "")
cluster_spec = kwargs.pop('cluster_spec', 'torque')
log_path = kwargs.get('log_path', None)
binary = kwargs.pop('binary', False)
Expand Down Expand Up @@ -473,7 +474,7 @@ def collect_nm(processing_dir,
collect=False,
binary=False,
batch_size=None,
outputsuffix='_estimate'):
outputsuffix='estimate'):
'''Function to check and collect all batches.
Basic usage::
Expand All @@ -493,6 +494,8 @@ def collect_nm(processing_dir,
written by (primarily) T Wolfers, (adapted) SM Kia, (adapted) S Rutherford.
'''

outputsuffix = "_" + outputsuffix.replace("_", "")

if binary:
file_extentions = '.pkl'
else:
Expand Down
Loading

0 comments on commit 17e71d0

Please sign in to comment.