-
Notifications
You must be signed in to change notification settings - Fork 50
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
231 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,231 @@ | ||
{ | ||
"nbformat": 4, | ||
"nbformat_minor": 0, | ||
"metadata": { | ||
"colab": { | ||
"provenance": [], | ||
"authorship_tag": "ABX9TyNfYdKn7+C4d4WSym/CFRMQ", | ||
"include_colab_link": true | ||
}, | ||
"kernelspec": { | ||
"name": "python3", | ||
"display_name": "Python 3" | ||
}, | ||
"language_info": { | ||
"name": "python" | ||
} | ||
}, | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "view-in-github", | ||
"colab_type": "text" | ||
}, | ||
"source": [ | ||
"<a href=\"https://colab.research.google.com/github/amarquand/PCNtoolkit/blob/master/notebooks/pcntk_colab_env.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"!pip install https://github.com/amarquand/PCNtoolkit/archive/dev.zip\n", | ||
"!pip install nutpie" | ||
], | ||
"metadata": { | ||
"colab": { | ||
"base_uri": "https://localhost:8080/" | ||
}, | ||
"id": "vIbnkHN9ydb3", | ||
"outputId": "0d4ade30-dab7-4d39-f1be-f8123f30184f" | ||
}, | ||
"execution_count": null, | ||
"outputs": [ | ||
{ | ||
"output_type": "stream", | ||
"name": "stdout", | ||
"text": [ | ||
"Collecting https://github.com/amarquand/PCNtoolkit/archive/dev.zip\n", | ||
" Using cached https://github.com/amarquand/PCNtoolkit/archive/dev.zip\n", | ||
" Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", | ||
" Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", | ||
" Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", | ||
"Requirement already satisfied: bspline<0.2.0,>=0.1.1 in /usr/local/lib/python3.10/dist-packages (from pcntoolkit==0.30.2) (0.1.1)\n", | ||
"Requirement already satisfied: nibabel<6.0.0,>=5.3.1 in /usr/local/lib/python3.10/dist-packages (from pcntoolkit==0.30.2) (5.3.2)\n", | ||
"Requirement already satisfied: pymc<6.0.0,>=5.17.0 in /usr/local/lib/python3.10/dist-packages (from pcntoolkit==0.30.2) (5.17.0)\n", | ||
"Requirement already satisfied: scikit-learn<2.0.0,>=1.5.2 in /usr/local/lib/python3.10/dist-packages (from pcntoolkit==0.30.2) (1.5.2)\n", | ||
"Requirement already satisfied: scipy<2.0,>=1.12 in /usr/local/lib/python3.10/dist-packages (from pcntoolkit==0.30.2) (1.13.1)\n", | ||
"Requirement already satisfied: seaborn<0.14.0,>=0.13.2 in /usr/local/lib/python3.10/dist-packages (from pcntoolkit==0.30.2) (0.13.2)\n", | ||
"Requirement already satisfied: six<2.0.0,>=1.16.0 in /usr/local/lib/python3.10/dist-packages (from pcntoolkit==0.30.2) (1.16.0)\n", | ||
"Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from bspline<0.2.0,>=0.1.1->pcntoolkit==0.30.2) (1.26.4)\n", | ||
"Requirement already satisfied: importlib-resources>=5.12 in /usr/local/lib/python3.10/dist-packages (from nibabel<6.0.0,>=5.3.1->pcntoolkit==0.30.2) (6.4.5)\n", | ||
"Requirement already satisfied: packaging>=20 in /usr/local/lib/python3.10/dist-packages (from nibabel<6.0.0,>=5.3.1->pcntoolkit==0.30.2) (24.1)\n", | ||
"Requirement already satisfied: typing-extensions>=4.6 in /usr/local/lib/python3.10/dist-packages (from nibabel<6.0.0,>=5.3.1->pcntoolkit==0.30.2) (4.12.2)\n", | ||
"Requirement already satisfied: arviz>=0.13.0 in /usr/local/lib/python3.10/dist-packages (from pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (0.20.0)\n", | ||
"Requirement already satisfied: cachetools>=4.2.1 in /usr/local/lib/python3.10/dist-packages (from pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (5.5.0)\n", | ||
"Requirement already satisfied: cloudpickle in /usr/local/lib/python3.10/dist-packages (from pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (3.1.0)\n", | ||
"Requirement already satisfied: pandas>=0.24.0 in /usr/local/lib/python3.10/dist-packages (from pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (2.2.2)\n", | ||
"Requirement already satisfied: pytensor<2.26,>=2.25.1 in /usr/local/lib/python3.10/dist-packages (from pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (2.25.5)\n", | ||
"Requirement already satisfied: rich>=13.7.1 in /usr/local/lib/python3.10/dist-packages (from pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (13.9.4)\n", | ||
"Requirement already satisfied: threadpoolctl<4.0.0,>=3.1.0 in /usr/local/lib/python3.10/dist-packages (from pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (3.5.0)\n", | ||
"Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn<2.0.0,>=1.5.2->pcntoolkit==0.30.2) (1.4.2)\n", | ||
"Requirement already satisfied: matplotlib!=3.6.1,>=3.4 in /usr/local/lib/python3.10/dist-packages (from seaborn<0.14.0,>=0.13.2->pcntoolkit==0.30.2) (3.8.0)\n", | ||
"Requirement already satisfied: setuptools>=60.0.0 in /usr/local/lib/python3.10/dist-packages (from arviz>=0.13.0->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (75.1.0)\n", | ||
"Requirement already satisfied: xarray>=2022.6.0 in /usr/local/lib/python3.10/dist-packages (from arviz>=0.13.0->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (2024.10.0)\n", | ||
"Requirement already satisfied: h5netcdf>=1.0.2 in /usr/local/lib/python3.10/dist-packages (from arviz>=0.13.0->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (1.4.0)\n", | ||
"Requirement already satisfied: xarray-einstats>=0.3 in /usr/local/lib/python3.10/dist-packages (from arviz>=0.13.0->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (0.8.0)\n", | ||
"Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib!=3.6.1,>=3.4->seaborn<0.14.0,>=0.13.2->pcntoolkit==0.30.2) (1.3.0)\n", | ||
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib!=3.6.1,>=3.4->seaborn<0.14.0,>=0.13.2->pcntoolkit==0.30.2) (0.12.1)\n", | ||
"Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib!=3.6.1,>=3.4->seaborn<0.14.0,>=0.13.2->pcntoolkit==0.30.2) (4.54.1)\n", | ||
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib!=3.6.1,>=3.4->seaborn<0.14.0,>=0.13.2->pcntoolkit==0.30.2) (1.4.7)\n", | ||
"Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib!=3.6.1,>=3.4->seaborn<0.14.0,>=0.13.2->pcntoolkit==0.30.2) (10.4.0)\n", | ||
"Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib!=3.6.1,>=3.4->seaborn<0.14.0,>=0.13.2->pcntoolkit==0.30.2) (3.2.0)\n", | ||
"Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.10/dist-packages (from matplotlib!=3.6.1,>=3.4->seaborn<0.14.0,>=0.13.2->pcntoolkit==0.30.2) (2.8.2)\n", | ||
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=0.24.0->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (2024.2)\n", | ||
"Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas>=0.24.0->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (2024.2)\n", | ||
"Requirement already satisfied: filelock>=3.15 in /usr/local/lib/python3.10/dist-packages (from pytensor<2.26,>=2.25.1->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (3.16.1)\n", | ||
"Requirement already satisfied: etuples in /usr/local/lib/python3.10/dist-packages (from pytensor<2.26,>=2.25.1->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (0.3.9)\n", | ||
"Requirement already satisfied: logical-unification in /usr/local/lib/python3.10/dist-packages (from pytensor<2.26,>=2.25.1->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (0.4.6)\n", | ||
"Requirement already satisfied: miniKanren in /usr/local/lib/python3.10/dist-packages (from pytensor<2.26,>=2.25.1->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (1.0.3)\n", | ||
"Requirement already satisfied: cons in /usr/local/lib/python3.10/dist-packages (from pytensor<2.26,>=2.25.1->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (0.4.6)\n", | ||
"Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich>=13.7.1->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (3.0.0)\n", | ||
"Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich>=13.7.1->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (2.18.0)\n", | ||
"Requirement already satisfied: h5py in /usr/local/lib/python3.10/dist-packages (from h5netcdf>=1.0.2->arviz>=0.13.0->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (3.12.1)\n", | ||
"Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich>=13.7.1->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (0.1.2)\n", | ||
"Requirement already satisfied: toolz in /usr/local/lib/python3.10/dist-packages (from logical-unification->pytensor<2.26,>=2.25.1->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (0.12.1)\n", | ||
"Requirement already satisfied: multipledispatch in /usr/local/lib/python3.10/dist-packages (from logical-unification->pytensor<2.26,>=2.25.1->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (1.0.0)\n" | ||
] | ||
} | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# %%\n", | ||
"from warnings import filterwarnings\n", | ||
"\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"import numpy as np\n", | ||
"import os\n", | ||
"\n", | ||
"from pcntoolkit.normative_model.norm_utils import norm_init\n", | ||
"from pcntoolkit.util.utils import simulate_data\n", | ||
"\n", | ||
"filterwarnings(\"ignore\")\n", | ||
"\n", | ||
"\n", | ||
"########################### Experiment Settings ###############################\n", | ||
"\n", | ||
"\n", | ||
"random_state = 40\n", | ||
"working_dir = \"temp\" # Specify a working directory to save data and results.\n", | ||
"os.makedirs(working_dir, exist_ok=True)\n", | ||
"simulation_method = \"linear\"\n", | ||
"n_features = 1 # The number of input features of X\n", | ||
"n_grps = 3 # Number of batches in data\n", | ||
"n_samples = 500 # Number of samples in each group (use a list for different\n", | ||
"# sample numbers across different batches)\n", | ||
"\n", | ||
"model_type = \"bspline\" # modelto try 'linear, ''polynomial', 'bspline'\n", | ||
"\n", | ||
"\n", | ||
"############################## Data Simulation ################################\n", | ||
"\n", | ||
"\n", | ||
"X_train, Y_train, grp_id_train, X_test, Y_test, grp_id_test, coef = simulate_data(\n", | ||
" simulation_method,\n", | ||
" n_samples,\n", | ||
" n_features,\n", | ||
" n_grps,\n", | ||
" working_dir=working_dir,\n", | ||
" plot=True,\n", | ||
" noise=\"heteroscedastic_nongaussian\",\n", | ||
" random_state=random_state,\n", | ||
")\n", | ||
"\n", | ||
"################################# Fittig and Predicting ###############################\n", | ||
"\n", | ||
"nm = norm_init(\n", | ||
" X_train,\n", | ||
" Y_train,\n", | ||
" alg=\"hbr\",\n", | ||
" model_type=model_type,\n", | ||
" likelihood=\"SHASHb\",\n", | ||
" linear_sigma=\"True\",\n", | ||
" random_slope_mu=\"False\",\n", | ||
" linear_epsilon=\"False\",\n", | ||
" linear_delta=\"False\",\n", | ||
" nuts_sampler=\"nutpie\",\n", | ||
")\n", | ||
"\n", | ||
"nm.estimate(X_train, Y_train, trbefile=os.path.join(working_dir, \"trbefile.pkl\"))\n", | ||
"yhat, ys2 = nm.predict(X_test, tsbefile=os.path.join(working_dir, \"tsbefile.pkl\"))\n", | ||
"\n", | ||
"\n", | ||
"################################# Plotting Quantiles ###############################\n", | ||
"for i in range(n_features):\n", | ||
" sorted_idx = X_test[:, i].argsort(axis=0).squeeze()\n", | ||
" temp_X = X_test[sorted_idx, i]\n", | ||
" temp_Y = Y_test[sorted_idx,]\n", | ||
" temp_be = grp_id_test[sorted_idx, :].squeeze()\n", | ||
" temp_yhat = yhat[sorted_idx,]\n", | ||
" temp_s2 = ys2[sorted_idx,]\n", | ||
"\n", | ||
" plt.figure()\n", | ||
" for j in range(n_grps):\n", | ||
" scat1 = plt.scatter(\n", | ||
" temp_X[temp_be == j,], temp_Y[temp_be == j,], label=\"Group\" + str(j)\n", | ||
" )\n", | ||
" # Showing the quantiles\n", | ||
" resolution = 200\n", | ||
" synth_X = np.linspace(np.min(X_train), np.max(X_train), resolution)\n", | ||
" q = nm.get_mcmc_quantiles(synth_X, batch_effects=j * np.ones(resolution))\n", | ||
" col = scat1.get_facecolors()[0]\n", | ||
" plt.plot(synth_X, q.T, linewidth=1, color=col, zorder=0)\n", | ||
"\n", | ||
" plt.title(\"Model %s, Feature %d\" % (model_type, i))\n", | ||
" plt.legend()\n", | ||
" plt.show(block=False)\n", | ||
" plt.savefig(working_dir + \"quantiles_\" + model_type + \"_feature_\" + str(i) + \".png\")\n", | ||
"\n", | ||
" for j in range(n_grps):\n", | ||
" plt.figure()\n", | ||
" plt.scatter(temp_X[temp_be == j,], temp_Y[temp_be == j,])\n", | ||
" plt.plot(temp_X[temp_be == j,], temp_yhat[temp_be == j,], color=\"red\")\n", | ||
" plt.fill_between(\n", | ||
" temp_X[temp_be == j,].squeeze(),\n", | ||
" (temp_yhat[temp_be == j,] - 2 * np.sqrt(temp_s2[temp_be == j,])).squeeze(),\n", | ||
" (temp_yhat[temp_be == j,] + 2 * np.sqrt(temp_s2[temp_be == j,])).squeeze(),\n", | ||
" color=\"red\",\n", | ||
" alpha=0.2,\n", | ||
" )\n", | ||
" plt.title(\"Model %s, Group %d, Feature %d\" % (model_type, j, i))\n", | ||
" plt.show(block=False)\n", | ||
" plt.savefig(\n", | ||
" working_dir\n", | ||
" + \"pred_\"\n", | ||
" + model_type\n", | ||
" + \"_group_\"\n", | ||
" + str(j)\n", | ||
" + \"_feature_\"\n", | ||
" + str(i)\n", | ||
" + \".png\"\n", | ||
" )" | ||
], | ||
"metadata": { | ||
"id": "RT0EbS7yzCNh" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [], | ||
"metadata": { | ||
"id": "3gneUhT80BZZ" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
} | ||
] | ||
} |