diff --git a/notebooks/pcntk_colab_env.ipynb b/notebooks/pcntk_colab_env.ipynb new file mode 100644 index 00000000..1f804dc0 --- /dev/null +++ b/notebooks/pcntk_colab_env.ipynb @@ -0,0 +1,231 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "authorship_tag": "ABX9TyNfYdKn7+C4d4WSym/CFRMQ", + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "code", + "source": [ + "!pip install https://github.com/amarquand/PCNtoolkit/archive/dev.zip\n", + "!pip install nutpie" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "vIbnkHN9ydb3", + "outputId": "0d4ade30-dab7-4d39-f1be-f8123f30184f" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Collecting https://github.com/amarquand/PCNtoolkit/archive/dev.zip\n", + " Using cached https://github.com/amarquand/PCNtoolkit/archive/dev.zip\n", + " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "Requirement already satisfied: bspline<0.2.0,>=0.1.1 in /usr/local/lib/python3.10/dist-packages (from pcntoolkit==0.30.2) (0.1.1)\n", + "Requirement already satisfied: nibabel<6.0.0,>=5.3.1 in /usr/local/lib/python3.10/dist-packages (from pcntoolkit==0.30.2) (5.3.2)\n", + "Requirement already satisfied: pymc<6.0.0,>=5.17.0 in /usr/local/lib/python3.10/dist-packages (from pcntoolkit==0.30.2) (5.17.0)\n", + "Requirement already satisfied: scikit-learn<2.0.0,>=1.5.2 in /usr/local/lib/python3.10/dist-packages (from pcntoolkit==0.30.2) (1.5.2)\n", + "Requirement already satisfied: scipy<2.0,>=1.12 in /usr/local/lib/python3.10/dist-packages (from pcntoolkit==0.30.2) (1.13.1)\n", + "Requirement already satisfied: seaborn<0.14.0,>=0.13.2 in /usr/local/lib/python3.10/dist-packages (from pcntoolkit==0.30.2) (0.13.2)\n", + "Requirement already satisfied: six<2.0.0,>=1.16.0 in /usr/local/lib/python3.10/dist-packages (from pcntoolkit==0.30.2) (1.16.0)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from bspline<0.2.0,>=0.1.1->pcntoolkit==0.30.2) (1.26.4)\n", + "Requirement already satisfied: importlib-resources>=5.12 in /usr/local/lib/python3.10/dist-packages (from nibabel<6.0.0,>=5.3.1->pcntoolkit==0.30.2) (6.4.5)\n", + "Requirement already satisfied: packaging>=20 in /usr/local/lib/python3.10/dist-packages (from nibabel<6.0.0,>=5.3.1->pcntoolkit==0.30.2) (24.1)\n", + "Requirement already satisfied: typing-extensions>=4.6 in /usr/local/lib/python3.10/dist-packages (from nibabel<6.0.0,>=5.3.1->pcntoolkit==0.30.2) (4.12.2)\n", + "Requirement already satisfied: arviz>=0.13.0 in /usr/local/lib/python3.10/dist-packages (from pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (0.20.0)\n", + "Requirement already satisfied: cachetools>=4.2.1 in /usr/local/lib/python3.10/dist-packages (from pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (5.5.0)\n", + "Requirement already satisfied: cloudpickle in /usr/local/lib/python3.10/dist-packages (from pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (3.1.0)\n", + "Requirement already satisfied: pandas>=0.24.0 in /usr/local/lib/python3.10/dist-packages (from pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (2.2.2)\n", + "Requirement already satisfied: pytensor<2.26,>=2.25.1 in /usr/local/lib/python3.10/dist-packages (from pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (2.25.5)\n", + "Requirement already satisfied: rich>=13.7.1 in /usr/local/lib/python3.10/dist-packages (from pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (13.9.4)\n", + "Requirement already satisfied: threadpoolctl<4.0.0,>=3.1.0 in /usr/local/lib/python3.10/dist-packages (from pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (3.5.0)\n", + "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn<2.0.0,>=1.5.2->pcntoolkit==0.30.2) (1.4.2)\n", + "Requirement already satisfied: matplotlib!=3.6.1,>=3.4 in /usr/local/lib/python3.10/dist-packages (from seaborn<0.14.0,>=0.13.2->pcntoolkit==0.30.2) (3.8.0)\n", + "Requirement already satisfied: setuptools>=60.0.0 in /usr/local/lib/python3.10/dist-packages (from arviz>=0.13.0->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (75.1.0)\n", + "Requirement already satisfied: xarray>=2022.6.0 in /usr/local/lib/python3.10/dist-packages (from arviz>=0.13.0->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (2024.10.0)\n", + "Requirement already satisfied: h5netcdf>=1.0.2 in /usr/local/lib/python3.10/dist-packages (from arviz>=0.13.0->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (1.4.0)\n", + "Requirement already satisfied: xarray-einstats>=0.3 in /usr/local/lib/python3.10/dist-packages (from arviz>=0.13.0->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (0.8.0)\n", + "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib!=3.6.1,>=3.4->seaborn<0.14.0,>=0.13.2->pcntoolkit==0.30.2) (1.3.0)\n", + "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib!=3.6.1,>=3.4->seaborn<0.14.0,>=0.13.2->pcntoolkit==0.30.2) (0.12.1)\n", + "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib!=3.6.1,>=3.4->seaborn<0.14.0,>=0.13.2->pcntoolkit==0.30.2) (4.54.1)\n", + "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib!=3.6.1,>=3.4->seaborn<0.14.0,>=0.13.2->pcntoolkit==0.30.2) (1.4.7)\n", + "Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib!=3.6.1,>=3.4->seaborn<0.14.0,>=0.13.2->pcntoolkit==0.30.2) (10.4.0)\n", + "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib!=3.6.1,>=3.4->seaborn<0.14.0,>=0.13.2->pcntoolkit==0.30.2) (3.2.0)\n", + "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.10/dist-packages (from matplotlib!=3.6.1,>=3.4->seaborn<0.14.0,>=0.13.2->pcntoolkit==0.30.2) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=0.24.0->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (2024.2)\n", + "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas>=0.24.0->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (2024.2)\n", + "Requirement already satisfied: filelock>=3.15 in /usr/local/lib/python3.10/dist-packages (from pytensor<2.26,>=2.25.1->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (3.16.1)\n", + "Requirement already satisfied: etuples in /usr/local/lib/python3.10/dist-packages (from pytensor<2.26,>=2.25.1->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (0.3.9)\n", + "Requirement already satisfied: logical-unification in /usr/local/lib/python3.10/dist-packages (from pytensor<2.26,>=2.25.1->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (0.4.6)\n", + "Requirement already satisfied: miniKanren in /usr/local/lib/python3.10/dist-packages (from pytensor<2.26,>=2.25.1->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (1.0.3)\n", + "Requirement already satisfied: cons in /usr/local/lib/python3.10/dist-packages (from pytensor<2.26,>=2.25.1->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (0.4.6)\n", + "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich>=13.7.1->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (3.0.0)\n", + "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich>=13.7.1->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (2.18.0)\n", + "Requirement already satisfied: h5py in /usr/local/lib/python3.10/dist-packages (from h5netcdf>=1.0.2->arviz>=0.13.0->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (3.12.1)\n", + "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich>=13.7.1->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (0.1.2)\n", + "Requirement already satisfied: toolz in /usr/local/lib/python3.10/dist-packages (from logical-unification->pytensor<2.26,>=2.25.1->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (0.12.1)\n", + "Requirement already satisfied: multipledispatch in /usr/local/lib/python3.10/dist-packages (from logical-unification->pytensor<2.26,>=2.25.1->pymc<6.0.0,>=5.17.0->pcntoolkit==0.30.2) (1.0.0)\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# %%\n", + "from warnings import filterwarnings\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import os\n", + "\n", + "from pcntoolkit.normative_model.norm_utils import norm_init\n", + "from pcntoolkit.util.utils import simulate_data\n", + "\n", + "filterwarnings(\"ignore\")\n", + "\n", + "\n", + "########################### Experiment Settings ###############################\n", + "\n", + "\n", + "random_state = 40\n", + "working_dir = \"temp\" # Specify a working directory to save data and results.\n", + "os.makedirs(working_dir, exist_ok=True)\n", + "simulation_method = \"linear\"\n", + "n_features = 1 # The number of input features of X\n", + "n_grps = 3 # Number of batches in data\n", + "n_samples = 500 # Number of samples in each group (use a list for different\n", + "# sample numbers across different batches)\n", + "\n", + "model_type = \"bspline\" # modelto try 'linear, ''polynomial', 'bspline'\n", + "\n", + "\n", + "############################## Data Simulation ################################\n", + "\n", + "\n", + "X_train, Y_train, grp_id_train, X_test, Y_test, grp_id_test, coef = simulate_data(\n", + " simulation_method,\n", + " n_samples,\n", + " n_features,\n", + " n_grps,\n", + " working_dir=working_dir,\n", + " plot=True,\n", + " noise=\"heteroscedastic_nongaussian\",\n", + " random_state=random_state,\n", + ")\n", + "\n", + "################################# Fittig and Predicting ###############################\n", + "\n", + "nm = norm_init(\n", + " X_train,\n", + " Y_train,\n", + " alg=\"hbr\",\n", + " model_type=model_type,\n", + " likelihood=\"SHASHb\",\n", + " linear_sigma=\"True\",\n", + " random_slope_mu=\"False\",\n", + " linear_epsilon=\"False\",\n", + " linear_delta=\"False\",\n", + " nuts_sampler=\"nutpie\",\n", + ")\n", + "\n", + "nm.estimate(X_train, Y_train, trbefile=os.path.join(working_dir, \"trbefile.pkl\"))\n", + "yhat, ys2 = nm.predict(X_test, tsbefile=os.path.join(working_dir, \"tsbefile.pkl\"))\n", + "\n", + "\n", + "################################# Plotting Quantiles ###############################\n", + "for i in range(n_features):\n", + " sorted_idx = X_test[:, i].argsort(axis=0).squeeze()\n", + " temp_X = X_test[sorted_idx, i]\n", + " temp_Y = Y_test[sorted_idx,]\n", + " temp_be = grp_id_test[sorted_idx, :].squeeze()\n", + " temp_yhat = yhat[sorted_idx,]\n", + " temp_s2 = ys2[sorted_idx,]\n", + "\n", + " plt.figure()\n", + " for j in range(n_grps):\n", + " scat1 = plt.scatter(\n", + " temp_X[temp_be == j,], temp_Y[temp_be == j,], label=\"Group\" + str(j)\n", + " )\n", + " # Showing the quantiles\n", + " resolution = 200\n", + " synth_X = np.linspace(np.min(X_train), np.max(X_train), resolution)\n", + " q = nm.get_mcmc_quantiles(synth_X, batch_effects=j * np.ones(resolution))\n", + " col = scat1.get_facecolors()[0]\n", + " plt.plot(synth_X, q.T, linewidth=1, color=col, zorder=0)\n", + "\n", + " plt.title(\"Model %s, Feature %d\" % (model_type, i))\n", + " plt.legend()\n", + " plt.show(block=False)\n", + " plt.savefig(working_dir + \"quantiles_\" + model_type + \"_feature_\" + str(i) + \".png\")\n", + "\n", + " for j in range(n_grps):\n", + " plt.figure()\n", + " plt.scatter(temp_X[temp_be == j,], temp_Y[temp_be == j,])\n", + " plt.plot(temp_X[temp_be == j,], temp_yhat[temp_be == j,], color=\"red\")\n", + " plt.fill_between(\n", + " temp_X[temp_be == j,].squeeze(),\n", + " (temp_yhat[temp_be == j,] - 2 * np.sqrt(temp_s2[temp_be == j,])).squeeze(),\n", + " (temp_yhat[temp_be == j,] + 2 * np.sqrt(temp_s2[temp_be == j,])).squeeze(),\n", + " color=\"red\",\n", + " alpha=0.2,\n", + " )\n", + " plt.title(\"Model %s, Group %d, Feature %d\" % (model_type, j, i))\n", + " plt.show(block=False)\n", + " plt.savefig(\n", + " working_dir\n", + " + \"pred_\"\n", + " + model_type\n", + " + \"_group_\"\n", + " + str(j)\n", + " + \"_feature_\"\n", + " + str(i)\n", + " + \".png\"\n", + " )" + ], + "metadata": { + "id": "RT0EbS7yzCNh" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "3gneUhT80BZZ" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file