Skip to content

Commit

Permalink
Merge pull request #231 from amarquand/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
amarquand authored Jan 8, 2025
2 parents 68018a2 + 254dfcd commit 6edae0a
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 36 deletions.
Binary file removed dist/pcntoolkit-0.30.post2-py3.12.egg
Binary file not shown.
40 changes: 24 additions & 16 deletions pcntoolkit/model/hbr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,23 @@
@author: augub
"""

from __future__ import print_function
from __future__ import division
from __future__ import division, print_function

from collections import OrderedDict
from functools import reduce
from itertools import product

import arviz as az
import numpy as np
import pymc as pm
import pytensor
import arviz as az
import xarray
from itertools import product
from functools import reduce
from scipy import stats
from util.utils import create_poly_basis, expand_all

from util.utils import create_poly_basis
from util.utils import expand_all
from pcntoolkit.util.utils import cartesian_product
from pcntoolkit.util.bspline import BSplineBasis
from pcntoolkit.model.SHASH import *
from pcntoolkit.util.bspline import BSplineBasis
from pcntoolkit.util.utils import cartesian_product


def create_poly_basis(X, order):
Expand Down Expand Up @@ -708,7 +707,7 @@ def Rhats(self, var_names=None, thin=1, resolution=100):
testvars = az.extract(idata, group='posterior',
var_names=var_names, combined=False)
testvar_names = [var for var in list(
testvars.data_vars.keys()) if not '_samples' in var]
testvars.data_vars.keys()) if '_samples' not in var]
rhat_dict = {}
for var_name in testvar_names:
var = np.stack(testvars[var_name].to_numpy())[:, ::thin]
Expand Down Expand Up @@ -795,12 +794,21 @@ def get_new_dim_size(tup):
dims = dims + pb.batch_effect_dim_names
if self.name.startswith("slope") or self.name.startswith("offset_slope"):
dims = dims + ["basis_functions"]
self.dist = from_posterior(
param=self.name,
samples=samples.to_numpy(),
shape=new_shape,
distribution=dist,
dims=dims,
if dims == []:
self.dist = from_posterior(
param=self.name,
samples=samples.to_numpy(),
shape=new_shape,
distribution=dist,
freedom=pb.configs["freedom"],
)
else:
self.dist = from_posterior(
param=self.name,
samples=samples.to_numpy(),
shape=new_shape,
distribution=dist,
dims=dims,
freedom=pb.configs["freedom"],
)

Expand Down
2 changes: 1 addition & 1 deletion pcntoolkit/normative.py
Original file line number Diff line number Diff line change
Expand Up @@ -1145,7 +1145,7 @@ def transfer(covfile, respfile, testcov=None, testresp=None, maskfile=None,
if testcov is not None:
yhat, s2 = nm.predict_on_new_sites(Xte, batch_effects_test)
if testresp is not None:
Z[:, i] = nm.get_mcmc_zscores(Xte, Yte[:, i:i+1], **kwargs)
Z[:, i] = nm.get_mcmc_zscores(Xte, Yte[:, i:i+1], tsbefile=tsbefile, **kwargs)

# We basically use normative.predict script here.
if alg == 'blr':
Expand Down
12 changes: 5 additions & 7 deletions pcntoolkit/normative_model/norm_hbr.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,10 +317,6 @@ def predict(self, Xs, X=None, Y=None, **kwargs):
pred=pred_type,
**kwargs,
)
# else:
# raise ValueError(
# "This is a transferred model. Please use predict_on_new_sites function."
# )

return yhat.squeeze(), s2.squeeze()

Expand All @@ -339,6 +335,8 @@ def transfer(self, X, y, batch_effects):
:return: The instance of the NormHBR object.
"""
self.hbr.transfer(X, y, batch_effects)
self.batch_effects_maps = [{v: i for i, v in enumerate(np.unique(batch_effects[:, j]))}
for j in range(batch_effects.shape[1])]
self.configs["transferred"] = True
return self

Expand Down Expand Up @@ -452,7 +450,7 @@ def tune(
]

X_dummy, batch_effects_dummy, Y_dummy = self.hbr.generate(
X_dummy, batch_effects_dummy, samples
X_dummy, batch_effects_dummy, samples, self.batch_effects_maps
)

if informative_prior:
Expand Down Expand Up @@ -490,7 +488,7 @@ def merge(
X_dummy_ranges)

X_dummy1, batch_effects_dummy1, Y_dummy1 = self.hbr.generate(
X_dummy1, batch_effects_dummy1, samples
X_dummy1, batch_effects_dummy1, samples, self.batch_effects_maps
)
X_dummy2, batch_effects_dummy2, Y_dummy2 = nm.hbr.generate(
X_dummy2, batch_effects_dummy2, samples
Expand All @@ -512,7 +510,7 @@ def merge(

def generate(self, X, batch_effects, samples=10):
X, batch_effects, generated_samples = self.hbr.generate(
X, batch_effects, samples
X, batch_effects, samples, self.batch_effects_maps
)
return X, batch_effects, generated_samples

Expand Down
4 changes: 2 additions & 2 deletions pcntoolkit/normative_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,8 +822,8 @@ def collect_nm(processing_dir,
if meta_data['outscaler'] in ['standardize', 'minmax',
'robminmax']:
Y_scalers.append(meta_data['scaler_resp'])
meta_data['mean_resp'] = np.squeeze(np.column_stack(mY))
meta_data['std_resp'] = np.squeeze(np.column_stack(sY))
meta_data['mean_resp'] = [np.squeeze(np.concatenate(mY))]
meta_data['std_resp'] = [np.squeeze(np.concatenate(sY))]
meta_data['scaler_cov'] = X_scalers
meta_data['scaler_resp'] = Y_scalers

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pcntoolkit"
version = "0.32.0"
version = "0.33.0"
description = "Predictive Clinical Neuroscience Toolkit"
authors = ["Andre Marquand"]
license = "GNU GPLv3"
Expand Down
39 changes: 30 additions & 9 deletions tests/test_HBR.ipynb

Large diffs are not rendered by default.

0 comments on commit 6edae0a

Please sign in to comment.