Skip to content

Commit

Permalink
Created using Colab
Browse files Browse the repository at this point in the history
  • Loading branch information
AuguB committed Nov 12, 2024
1 parent 131bea0 commit b82c3b3
Showing 1 changed file with 231 additions and 0 deletions.
231 changes: 231 additions & 0 deletions notebooks/pcntk_colab_env.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyNfYdKn7+C4d4WSym/CFRMQ",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/amarquand/PCNtoolkit/blob/master/notebooks/pcntk_colab_env.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"source": [
"!pip install https://github.com/amarquand/PCNtoolkit/archive/dev.zip\n",
"!pip install nutpie"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "vIbnkHN9ydb3",
"outputId": "0d4ade30-dab7-4d39-f1be-f8123f30184f"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Collecting https://github.com/amarquand/PCNtoolkit/archive/dev.zip\n",
" Using cached https://github.com/amarquand/PCNtoolkit/archive/dev.zip\n",
" Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
" Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
" Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
"Requirement already satisfied: bspline<0.2.0,>=0.1.1 in /usr/local/lib/python3.10/dist-packages (from pcntoolkit==0.30.2) (0.1.1)\n",
"Requirement already satisfied: nibabel<6.0.0,>=5.3.1 in /usr/local/lib/python3.10/dist-packages (from pcntoolkit==0.30.2) (5.3.2)\n",
"Requirement already satisfied: pymc<6.0.0,>=5.17.0 in /usr/local/lib/python3.10/dist-packages (from pcntoolkit==0.30.2) (5.17.0)\n",
"Requirement already satisfied: scikit-learn<2.0.0,>=1.5.2 in /usr/local/lib/python3.10/dist-packages (from pcntoolkit==0.30.2) (1.5.2)\n",
"Requirement already satisfied: scipy<2.0,>=1.12 in /usr/local/lib/python3.10/dist-packages (from pcntoolkit==0.30.2) (1.13.1)\n",
"Requirement already satisfied: seaborn<0.14.0,>=0.13.2 in /usr/local/lib/python3.10/dist-packages (from pcntoolkit==0.30.2) (0.13.2)\n",
"Requirement already satisfied: six<2.0.0,>=1.16.0 in /usr/local/lib/python3.10/dist-packages (from pcntoolkit==0.30.2) (1.16.0)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from bspline<0.2.0,>=0.1.1->pcntoolkit==0.30.2) (1.26.4)\n",
"Requirement already satisfied: importlib-resources>=5.12 in /usr/local/lib/python3.10/dist-packages (from nibabel<6.0.0,>=5.3.1->pcntoolkit==0.30.2) (6.4.5)\n",
"Requirement already satisfied: packaging>=20 in /usr/local/lib/python3.10/dist-packages (from nibabel<6.0.0,>=5.3.1->pcntoolkit==0.30.2) (24.1)\n",
"Requirement already satisfied: typing-extensions>=4.6 in /usr/local/lib/python3.10/dist-packages (from nibabel<6.0.0,>=5.3.1->pcntoolkit==0.30.2) (4.12.2)\n",
"Requirement already satisfied: arviz>=0.13.0 in /usr/local/lib/python3.10/dist-packages (from pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (0.20.0)\n",
"Requirement already satisfied: cachetools>=4.2.1 in /usr/local/lib/python3.10/dist-packages (from pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (5.5.0)\n",
"Requirement already satisfied: cloudpickle in /usr/local/lib/python3.10/dist-packages (from pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (3.1.0)\n",
"Requirement already satisfied: pandas>=0.24.0 in /usr/local/lib/python3.10/dist-packages (from pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (2.2.2)\n",
"Requirement already satisfied: pytensor<2.26,>=2.25.1 in /usr/local/lib/python3.10/dist-packages (from pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (2.25.5)\n",
"Requirement already satisfied: rich>=13.7.1 in /usr/local/lib/python3.10/dist-packages (from pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (13.9.4)\n",
"Requirement already satisfied: threadpoolctl<4.0.0,>=3.1.0 in /usr/local/lib/python3.10/dist-packages (from pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (3.5.0)\n",
"Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn<2.0.0,>=1.5.2->pcntoolkit==0.30.2) (1.4.2)\n",
"Requirement already satisfied: matplotlib!=3.6.1,>=3.4 in /usr/local/lib/python3.10/dist-packages (from seaborn<0.14.0,>=0.13.2->pcntoolkit==0.30.2) (3.8.0)\n",
"Requirement already satisfied: setuptools>=60.0.0 in /usr/local/lib/python3.10/dist-packages (from arviz>=0.13.0->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (75.1.0)\n",
"Requirement already satisfied: xarray>=2022.6.0 in /usr/local/lib/python3.10/dist-packages (from arviz>=0.13.0->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (2024.10.0)\n",
"Requirement already satisfied: h5netcdf>=1.0.2 in /usr/local/lib/python3.10/dist-packages (from arviz>=0.13.0->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (1.4.0)\n",
"Requirement already satisfied: xarray-einstats>=0.3 in /usr/local/lib/python3.10/dist-packages (from arviz>=0.13.0->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (0.8.0)\n",
"Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib!=3.6.1,>=3.4->seaborn<0.14.0,>=0.13.2->pcntoolkit==0.30.2) (1.3.0)\n",
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib!=3.6.1,>=3.4->seaborn<0.14.0,>=0.13.2->pcntoolkit==0.30.2) (0.12.1)\n",
"Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib!=3.6.1,>=3.4->seaborn<0.14.0,>=0.13.2->pcntoolkit==0.30.2) (4.54.1)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib!=3.6.1,>=3.4->seaborn<0.14.0,>=0.13.2->pcntoolkit==0.30.2) (1.4.7)\n",
"Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib!=3.6.1,>=3.4->seaborn<0.14.0,>=0.13.2->pcntoolkit==0.30.2) (10.4.0)\n",
"Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib!=3.6.1,>=3.4->seaborn<0.14.0,>=0.13.2->pcntoolkit==0.30.2) (3.2.0)\n",
"Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.10/dist-packages (from matplotlib!=3.6.1,>=3.4->seaborn<0.14.0,>=0.13.2->pcntoolkit==0.30.2) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=0.24.0->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (2024.2)\n",
"Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas>=0.24.0->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (2024.2)\n",
"Requirement already satisfied: filelock>=3.15 in /usr/local/lib/python3.10/dist-packages (from pytensor<2.26,>=2.25.1->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (3.16.1)\n",
"Requirement already satisfied: etuples in /usr/local/lib/python3.10/dist-packages (from pytensor<2.26,>=2.25.1->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (0.3.9)\n",
"Requirement already satisfied: logical-unification in /usr/local/lib/python3.10/dist-packages (from pytensor<2.26,>=2.25.1->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (0.4.6)\n",
"Requirement already satisfied: miniKanren in /usr/local/lib/python3.10/dist-packages (from pytensor<2.26,>=2.25.1->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (1.0.3)\n",
"Requirement already satisfied: cons in /usr/local/lib/python3.10/dist-packages (from pytensor<2.26,>=2.25.1->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (0.4.6)\n",
"Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich>=13.7.1->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (3.0.0)\n",
"Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich>=13.7.1->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (2.18.0)\n",
"Requirement already satisfied: h5py in /usr/local/lib/python3.10/dist-packages (from h5netcdf>=1.0.2->arviz>=0.13.0->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (3.12.1)\n",
"Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich>=13.7.1->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (0.1.2)\n",
"Requirement already satisfied: toolz in /usr/local/lib/python3.10/dist-packages (from logical-unification->pytensor<2.26,>=2.25.1->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (0.12.1)\n",
"Requirement already satisfied: multipledispatch in /usr/local/lib/python3.10/dist-packages (from logical-unification->pytensor<2.26,>=2.25.1->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (1.0.0)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# %%\n",
"from warnings import filterwarnings\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import os\n",
"\n",
"from pcntoolkit.normative_model.norm_utils import norm_init\n",
"from pcntoolkit.util.utils import simulate_data\n",
"\n",
"filterwarnings(\"ignore\")\n",
"\n",
"\n",
"########################### Experiment Settings ###############################\n",
"\n",
"\n",
"random_state = 40\n",
"working_dir = \"temp\" # Specify a working directory to save data and results.\n",
"os.makedirs(working_dir, exist_ok=True)\n",
"simulation_method = \"linear\"\n",
"n_features = 1 # The number of input features of X\n",
"n_grps = 3 # Number of batches in data\n",
"n_samples = 500 # Number of samples in each group (use a list for different\n",
"# sample numbers across different batches)\n",
"\n",
"model_type = \"bspline\" # modelto try 'linear, ''polynomial', 'bspline'\n",
"\n",
"\n",
"############################## Data Simulation ################################\n",
"\n",
"\n",
"X_train, Y_train, grp_id_train, X_test, Y_test, grp_id_test, coef = simulate_data(\n",
" simulation_method,\n",
" n_samples,\n",
" n_features,\n",
" n_grps,\n",
" working_dir=working_dir,\n",
" plot=True,\n",
" noise=\"heteroscedastic_nongaussian\",\n",
" random_state=random_state,\n",
")\n",
"\n",
"################################# Fittig and Predicting ###############################\n",
"\n",
"nm = norm_init(\n",
" X_train,\n",
" Y_train,\n",
" alg=\"hbr\",\n",
" model_type=model_type,\n",
" likelihood=\"SHASHb\",\n",
" linear_sigma=\"True\",\n",
" random_slope_mu=\"False\",\n",
" linear_epsilon=\"False\",\n",
" linear_delta=\"False\",\n",
" nuts_sampler=\"nutpie\",\n",
")\n",
"\n",
"nm.estimate(X_train, Y_train, trbefile=os.path.join(working_dir, \"trbefile.pkl\"))\n",
"yhat, ys2 = nm.predict(X_test, tsbefile=os.path.join(working_dir, \"tsbefile.pkl\"))\n",
"\n",
"\n",
"################################# Plotting Quantiles ###############################\n",
"for i in range(n_features):\n",
" sorted_idx = X_test[:, i].argsort(axis=0).squeeze()\n",
" temp_X = X_test[sorted_idx, i]\n",
" temp_Y = Y_test[sorted_idx,]\n",
" temp_be = grp_id_test[sorted_idx, :].squeeze()\n",
" temp_yhat = yhat[sorted_idx,]\n",
" temp_s2 = ys2[sorted_idx,]\n",
"\n",
" plt.figure()\n",
" for j in range(n_grps):\n",
" scat1 = plt.scatter(\n",
" temp_X[temp_be == j,], temp_Y[temp_be == j,], label=\"Group\" + str(j)\n",
" )\n",
" # Showing the quantiles\n",
" resolution = 200\n",
" synth_X = np.linspace(np.min(X_train), np.max(X_train), resolution)\n",
" q = nm.get_mcmc_quantiles(synth_X, batch_effects=j * np.ones(resolution))\n",
" col = scat1.get_facecolors()[0]\n",
" plt.plot(synth_X, q.T, linewidth=1, color=col, zorder=0)\n",
"\n",
" plt.title(\"Model %s, Feature %d\" % (model_type, i))\n",
" plt.legend()\n",
" plt.show(block=False)\n",
" plt.savefig(working_dir + \"quantiles_\" + model_type + \"_feature_\" + str(i) + \".png\")\n",
"\n",
" for j in range(n_grps):\n",
" plt.figure()\n",
" plt.scatter(temp_X[temp_be == j,], temp_Y[temp_be == j,])\n",
" plt.plot(temp_X[temp_be == j,], temp_yhat[temp_be == j,], color=\"red\")\n",
" plt.fill_between(\n",
" temp_X[temp_be == j,].squeeze(),\n",
" (temp_yhat[temp_be == j,] - 2 * np.sqrt(temp_s2[temp_be == j,])).squeeze(),\n",
" (temp_yhat[temp_be == j,] + 2 * np.sqrt(temp_s2[temp_be == j,])).squeeze(),\n",
" color=\"red\",\n",
" alpha=0.2,\n",
" )\n",
" plt.title(\"Model %s, Group %d, Feature %d\" % (model_type, j, i))\n",
" plt.show(block=False)\n",
" plt.savefig(\n",
" working_dir\n",
" + \"pred_\"\n",
" + model_type\n",
" + \"_group_\"\n",
" + str(j)\n",
" + \"_feature_\"\n",
" + str(i)\n",
" + \".png\"\n",
" )"
],
"metadata": {
"id": "RT0EbS7yzCNh"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "3gneUhT80BZZ"
},
"execution_count": null,
"outputs": []
}
]
}

0 comments on commit b82c3b3

Please sign in to comment.