Skip to content

Commit

Permalink
Merge pull request #85 from amarquand/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
amarquand authored Jun 15, 2022
2 parents cef7055 + 1c3d53f commit 4148fbc
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 44 deletions.
18 changes: 9 additions & 9 deletions pcntoolkit/model/hbr.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def hbr(X, y, batch_effects, batch_effects_size, configs, trace=None):
mu = pb.make_param("mu").get_samples(pb)
sigma = pb.make_param("sigma").get_samples(pb)
sigma_plus = pm.math.log(1+pm.math.exp(sigma))
y_like = pm.Normal('y',mu=mu, sigma=sigma, observed=y)
y_like = pm.Normal('y',mu=mu, sigma=sigma_plus, observed=y)

elif configs['likelihood'] in ['SHASHb','SHASHo','SHASHo2']:
"""
Expand All @@ -182,15 +182,15 @@ def hbr(X, y, batch_effects, batch_effects_size, configs, trace=None):
Comment 2
Any mapping that is applied here after sampling should also be applied in util.hbr_utils.forward in order for the functions there to properly work.
For example, the softplus applied to sigma here is also applied
For example, the softplus applied to sigma here is also applied in util.hbr_utils.forward
"""
SHASH_map = {'SHASHb':SHASHb,'SHASHo':SHASHo,'SHASHo2':SHASHo2}
mu = pb.make_param("mu").get_samples(pb)
sigma = pb.make_param("sigma").get_samples(pb)
sigma = pb.make_param("sigma", intercept_sigma_params = (1., 1.)).get_samples(pb)
sigma_plus = pm.math.log(1+pm.math.exp(sigma))
epsilon = pb.make_param("epsilon", epsilon_params=(0.,1.)).get_samples(pb)
delta = pb.make_param("delta", delta_dist='igamma',delta_params=(1.,1.)).get_samples(pb)
delta_plus = delta + 0.5
epsilon = pb.make_param("epsilon").get_samples(pb)
delta = pb.make_param("delta", intercept_delta_params=(1., 1.)).get_samples(pb)
delta_plus = pm.math.log(1+pm.math.exp(delta)) + 0.3
y_like = SHASH_map[configs['likelihood']]('y', mu=mu, sigma=sigma_plus, epsilon=epsilon, delta=delta_plus, observed = y)

return model
Expand Down Expand Up @@ -426,7 +426,7 @@ def __init__(self, name, dist, params, pb, shape=(1,)) -> None:
def make_dist(self, dist, params, pb):
"""This creates a pymc3 distribution. If there is a trace, the distribution is fitted to the trace. If there isn't a trace, the prior is parameterized by the values in (params)"""
with pb.model as m:
if pb.trace is not None:
if (pb.trace is not None) and (not self.has_random_effect):
int_dist = from_posterior(param=self.name,
samples=pb.trace[self.name],
distribution=dist,
Expand Down Expand Up @@ -606,9 +606,9 @@ def get_samples(self, pb:ParamBuilder):
with pb.model:
samples = theano.tensor.zeros([pb.n_ys, *self.dim])
for be, idx in pb.be_idx_tups:
dot = theano.tensor.dot(pb.X[idx,:], self.slope_parameterization.dist[be])
dot = theano.tensor.dot(pb.X[idx,:], self.slope_parameterization.dist[be]).T
intercept = self.intercept_parameterization.dist[be]
samples = theano.tensor.set_subtensor(samples[idx,0],dot+intercept)
samples = theano.tensor.set_subtensor(samples[idx,:],dot+intercept)
return samples


Expand Down
4 changes: 3 additions & 1 deletion pcntoolkit/normative.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,7 +914,9 @@ def transfer(covfile, respfile, testcov=None, testresp=None, maskfile=None,
if tsbefile is not None:
batch_effects_test = fileio.load(tsbefile)
else:
batch_effects_test = np.zeros([Xte.shape[0],2])
batch_effects_test = np.zeros([Xte.shape[0],2])
else:
ts_sample_num = 0

Yhat = np.zeros([ts_sample_num, feature_num])
S2 = np.zeros([ts_sample_num, feature_num])
Expand Down
3 changes: 1 addition & 2 deletions pcntoolkit/normative_model/norm_hbr.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,7 @@ def predict(self, Xs, X=None, Y=None, **kwargs):
return yhat.squeeze(), s2.squeeze()

def estimate_on_new_sites(self, X, y, batch_effects):

self.hbr.adapt(X, y, batch_effects)
self.hbr.estimate_on_new_site(X, y, batch_effects)
self.configs['transferred'] = True
return self

Expand Down
61 changes: 29 additions & 32 deletions pcntoolkit/util/hbr_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,34 +64,32 @@ def z_score(Y, mu, sigma, epsilon=None, delta=None, likelihood = "Normal"):
return Z


def get_MCMC_quantiles(min_x, max_x, z_scores, model, be):
def get_MCMC_quantiles(synthetic_X, z_scores, model, be):
"""Get an MCMC estimate of the quantiles"""
"""This does not use the get_single_quantiles function, for memory efficiency"""
resolution = 200
synthetic_X = np.linspace(min_x, max_x, resolution)[:,None]
resolution = synthetic_X.shape[0]
synthetic_X_transformed = bspline_transform(synthetic_X, model.hbr.bsp)
be = np.reshape(np.array(be),(1,-1))
synthetic_Z = np.repeat(be, resolution, axis = 0)
synthetic_Z = np.repeat(be, resolution, axis = 0)[:,None]
z_scores = np.reshape(np.array(z_scores),(1,-1))
zs = np.repeat(z_scores, resolution, axis=0).T
zs = np.repeat(z_scores, resolution, axis=0)
def f(sample):
ps = forward(np.squeeze(synthetic_X_transformed),synthetic_Z, model,sample)
ps = forward(synthetic_X_transformed,synthetic_Z, model,sample)
q = quantile(zs, ps['mu'], ps['sigma'],ps.get('epsilon',None),ps.get('delta',None), likelihood = model.configs['likelihood'])
return q
out = MCMC_estimate(f, model.hbr.trace)
return synthetic_X, out


def get_single_quantiles(min_x, max_x, z_scores, model, be, sample):
def get_single_quantiles(synthetic_X, z_scores, model, be, sample):
"""Get the quantiles within a given range of covariates, given a model"""
resolution = 200
synthetic_X = np.linspace(min_x, max_x, resolution)[:,None]
resolution = synthetic_X.shape[0]
synthetic_X_transformed = bspline_transform(synthetic_X, model.hbr.bsp)
be = np.reshape(np.array(be),(1,-1))
synthetic_Z = np.repeat(be, resolution, axis = 0)
z_scores = np.reshape(np.array(z_scores),(1,-1))
zs = np.repeat(z_scores, resolution, axis=0).T
ps = forward(np.squeeze(synthetic_X_transformed),synthetic_Z, model,sample)
zs = np.repeat(z_scores, resolution, axis=0)
ps = forward(synthetic_X_transformed,synthetic_Z, model,sample)
q = quantile(zs, ps['mu'], ps['sigma'],ps.get('epsilon',None),ps.get('delta',None), likelihood = model.configs['likelihood'])
return q

Expand All @@ -115,7 +113,7 @@ def quantile(zs, mu, sigma, epsilon=None, delta=None, likelihood = "Normal"):

def single_parameter_forward(X, Z, model, sample, p_name):
"""Get a likelihood paramameter given covariates, batch-effects and model parameters"""
outs = np.zeros(X.shape[0])
outs = np.zeros(X.shape[0])[:,None]
all_bes = np.unique(Z,axis=0)
for be in all_bes:
bet = tuple(be)
Expand All @@ -129,21 +127,21 @@ def single_parameter_forward(X, Z, model, sample, p_name):
intercept_be = sample[f"intercept_{p_name}"][bet]
else:
intercept_be = sample[f"intercept_{p_name}"]

outs[idx] = X[idx]@slope_be + intercept_be

out = (X[np.squeeze(idx),:]@slope_be)[:,None] + intercept_be
outs[np.squeeze(idx),:] = out
else:
if model.configs[f'random_{p_name}']:
outs[idx] = sample[p_name][bet]
outs[np.squeeze(idx),:] = sample[p_name][bet]
else:
outs[idx] = sample[p_name]
outs[np.squeeze(idx),:] = sample[p_name]
return outs


def forward(X, Z, model, sample):
"""Get all likelihood paramameters given covariates and batch-effects and model parameters"""
# TODO think if this is the correct spot for this
mapfuncs={'sigma': lambda x: np.log(1+np.exp(x)), 'delta':lambda x :x+0.5}

mapfuncs={'sigma': lambda x: np.log(1+np.exp(x)), 'delta':lambda x :np.log(1+np.exp(x)) + 0.3}
likelihood = model.configs['likelihood']
if likelihood == 'Normal':
parameter_list = ['mu','sigma']
Expand All @@ -164,27 +162,26 @@ def forward(X, Z, model, sample):
return output_dict


def Rhats(model, thin = 1, resolution = 100, varnames = None):
    """Get Rhat as function of sampling iteration.

    For each requested variable, ``pm.rhat`` is evaluated on growing prefixes
    of the sampled draws, giving ``resolution`` Rhat values per flattened
    component of that variable.

    Parameters
    ----------
    model :
        Object exposing the fitted trace as ``model.hbr.trace``.
    thin : int, optional
        Thinning factor passed to ``trace.get_values`` when measuring the
        chain length that determines the prefix step size.
    resolution : int, optional
        Number of prefix lengths at which Rhat is evaluated.
    varnames : sequence of str, optional
        Names of the trace variables to process. Defaults to
        ``trace.varnames`` (all variables in the trace).

    Returns
    -------
    dict
        Maps each variable name to an array of shape
        ``(resolution, n_components)``, where ``n_components`` is the number
        of entries left after flattening every axis beyond the first two of
        the stacked per-chain values.
    """
    # Bind the trace first: the default for `varnames` is read from it.
    trace = model.hbr.trace
    if varnames is None:
        varnames = trace.varnames

    # Step between successive prefix lengths, based on the (thinned) length of
    # the first chain of the first variable.
    # NOTE(review): the values used below are fetched without `thin`, so for
    # thin > 1 the prefixes cover only part of each chain -- confirm intended.
    chain_length = trace.get_values(varnames[0], chains=trace.chains[0], thin=thin).shape[0]
    interval_skip = chain_length // resolution

    rhat_dict = {}
    for varname in varnames:
        # Stack the per-chain arrays and flatten all trailing axes
        # (presumably giving chains x draws x components -- TODO confirm).
        var = np.stack(trace.get_values(varname, combine=False))
        var = var.reshape((var.shape[0], var.shape[1], -1))
        rhats_var = np.zeros((resolution, var.shape[2]))
        for v in range(var.shape[2]):
            for j in range(resolution):
                # NOTE(review): j == 0 selects an empty prefix of draws;
                # kept as in the original -- verify what pm.rhat yields there.
                rhats_var[j, v] = pm.rhat(var[:, :j*interval_skip, v])
        rhat_dict[varname] = rhats_var
    return rhat_dict

Expand Down

0 comments on commit 4148fbc

Please sign in to comment.