{ "cells": [ { "cell_type": "markdown", "id": "e9b462e4", "metadata": {}, "source": [ "# Generalization to Bayesian Softmax Regression" ] }, { "cell_type": "markdown", "id": "8510f432", "metadata": {}, "source": [ "Ref: Chap 4 of Mar18\n", "\n", "https://cfteach.github.io/brds/referencesmd.html" ] }, { "cell_type": "code", "execution_count": 2, "id": "5015e78e", "metadata": {}, "outputs": [], "source": [ "import pymc3 as pm\n", "import numpy as np\n", "import pandas as pd\n", "import theano.tensor as tt\n", "import seaborn as sns\n", "import scipy.stats as stats\n", "from scipy.special import expit as logistic\n", "import matplotlib.pyplot as plt\n", "import arviz as az\n", "import requests\n", "import io " ] }, { "cell_type": "code", "execution_count": 3, "id": "6c2e3cbf", "metadata": {}, "outputs": [], "source": [ "az.style.use('arviz-darkgrid')" ] }, { "cell_type": "code", "execution_count": 4, "id": "224771aa", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
sepal_lengthsepal_widthpetal_lengthpetal_widthspecies
05.13.51.40.2setosa
14.93.01.40.2setosa
24.73.21.30.2setosa
34.63.11.50.2setosa
45.03.61.40.2setosa
\n", "
" ], "text/plain": [ " sepal_length sepal_width petal_length petal_width species\n", "0 5.1 3.5 1.4 0.2 setosa\n", "1 4.9 3.0 1.4 0.2 setosa\n", "2 4.7 3.2 1.3 0.2 setosa\n", "3 4.6 3.1 1.5 0.2 setosa\n", "4 5.0 3.6 1.4 0.2 setosa" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "target_url = 'https://raw.githubusercontent.com/cfteach/brds/main/datasets/iris.csv'\n", "\n", "# Use a timeout and fail loudly on HTTP errors instead of parsing an error page as CSV\n", "response = requests.get(target_url, timeout=30)\n", "response.raise_for_status()\n", "download = response.content\n", "iris = pd.read_csv(io.StringIO(download.decode('utf-8')))\n", "\n", "iris.head()" ] }, { "cell_type": "markdown", "id": "4f145d8b", "metadata": {}, "source": [ "## Recipe 1: Dealing with correlated data" ] }, { "cell_type": "code", "execution_count": 5, "id": "a3e9f488", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/var/folders/r2/_2532dgx683084s9v9ss0cfc0000gq/T/ipykernel_21894/3442237557.py:1: FutureWarning: The default value of numeric_only in DataFrame.corr is deprecated. In a future version, it will default to False. 
Select only valid columns or specify the value of numeric_only to silence this warning.\n", " corr = iris[iris['species'] != 'virginica'].corr()\n" ] }, { "data": { "text/plain": [ "[Text(0, 0.5, 'sepal_length'),\n", " Text(0, 1.5, 'sepal_width'),\n", " Text(0, 2.5, 'petal_length'),\n", " Text(0, 3.5, 'petal_width')]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAbcAAAEoCAYAAADbp799AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAA1EklEQVR4nO3dd5xU9dXH8c/uKkivCigiiHKwRbBroqDEaEwzamJiL7HGBGKMxhI1scWOibHEWBOe51HTLInRqIANRIOIiBxDL9IWkA7L7s7zx+/ucnfZNlvmzs5836/XvJi5bc69rnPu73fPvb+CVCqFiIhILilMOgAREZHmpuQmIiI5R8lNRERyjpKbiIjkHCU3ERHJOUpuIiKSc7ZLOgCpatWqVbo3Q0RaXLdu3Qoau275kkH1/k4V9v600dtvDkpuIiKSli2p0nqXaZuBOOqi5CYiImkpJ/s7mJTcREQkLVtSZfUu0y4DcdRFyU1ERNKilpuIiOScMiU3ERHJNVtS5UmHUC8lNxERSUv2pzYlNxERSZO6JUVEJOdsyf7cpuQmIiLpKSPRh480iJKbiIikpVwtNxERyTUlreCZ+0puIiKSlvKUuiVFRCTH6JqbiIjknC2poqRDqJeSm4iIpEUtt2ZiZv2BOcDB7v5+cy2bCWZ2DnC/u3dMOhYRkeZQlsr+gpLsj7AVMbO5ZnZF0nGIiLSkLRTV+0paq2i5iYhI9mgNLbcGJTczOwq4A9gXKAMcOM/dp5nZEcBtwMHAKuB54Cp3XxOtOw6YAWwGzoo2+YdomfJomTOAkcBgYCMwHhjl7ouaYR8xs72BO4Gjou2/BvzE3ZdE858AegL/Bq4E2gN/B37o7huiZToADwInAeuB0cAXgWJ3Pyfaz92AO83sTgB3r+yYNrMRwH3AAGAS4fjNaY79ExHJpGxomdWn3vRrZtsBzwFvAfsDhxJ+2MvMbD/gFUJC25/wwz8EeKzaZk6Pvutw4CLgQmBUbH4b4IZoG18nJJr/bdQebRt/H+ANYBpwCPBloCPwnJnF9/9IQvL+MnAq8G1Cwq1wNzAsmn5MFOuRsfknAQuBXwF9oleFtsDVwHmEY9AVeKg59k9EJNPKUoX1vpLWkJZbZ8KP8QvuPiuaNgPAzJ4Cnnb3uysWNrNLgA/MbCd3XxZNXgz82N1TwAwzGwRcDtwD4O7xZDg72sYnZtbX3Rc2fvcAuAT40N2visV4FrASOIjQigJYA1zs7mXRdz8LjABuM7OOhMR0lrv/O9rG+YRkRrQPK82sDFhb0SKM2Y7QCvRo3buAx8ysIDomIiKtRnkrKNeoN7lFP9pPAC+b2WuELr0/u/t84EBgDzM7NbZKRVfcQKAiuU2s9iM+AbjJzDq7+xozO4DQchsCdI9tox+xBNJIBwJHmdm6GuYNZGtymx4ltgqfEVqpFcttH1sWd19vZtMaGMPmisQW23YboBshyYqItBoluXKfm7ufa2ajgeOBbwK3mNmJhK7GPwD31rBag66XRdeyXgZeBc4kJMSewJuEBNBUhcA/gJqqGJfG3m+pNi9F81WTltawbZpx+yIiGVOe
Bd2O9WlwtaS7fwh8CNxuZi8BZwOTgX3cfWY9qx9arQvuMOCzqNV2ICGZXVNRYGFmJ6W7I3WYDHwXmOfu1RNYQ80iJL+DgdkAZtaecI1uVmy5EmgFV1pFRJqgrBWcl9eb3MxsAKEI5HlCa2x34AuEysHngYlm9hDwMLCWUPH4DXe/KLaZnYHRZvYAsB/wM+DmaN58QiXlZWb2O2Av4Kam71ql3wEXAE+b2e3A8mgfvgv81N3X1rcBd19nZo8REnsx4RridYSWV7y7dS5wpJn9idAVWdyM+yEikhVaw+O3GpJ+NwCDgGeBT4EngTHA7e4+lVBe359Qvv8h4baApdW2MYbQonkXeAR4lKgr092XE1qBJwLTCdfeLm/8LlXl7p8RSvbLgX8BHxMS3ubo1VBXELpKnwfGAlOB94FNsWWuB3YltOaWNzV2EZFs1BqqJQtSqZYt1ovu/5rm7pe16BdlmJm1BeYBd8arRZtq1apVqp4UkRbXrVu3Rj8g8g+fHlnv79QPBr2Z6AMo9YSSBjKzoYQu00lAJ+Cq6N+nk4xLRCTTSlLZnzqyP8Jqout7Z9Qy+0/ufnELfv3lgBGqH6cARzXDfXgiIq2KBisF3H14M2/yeuCuWuataebvquTuHxBu+hYRyWs5US2ZbaKnniyrd0EREWkRraFastUlNxERSVZO3cQtIiICGolbRERy0Jby7E8d2R+hiIhklXK13EREJNdkwxNI6qPkJiIiaVG1pIiI5BzdxC0iIjmnKS03M7uUMDJMH8KD7Ee5+5t1LH8acCXhAf5rCGN/XuHuS+r6nuzvOBURkaxSniqo91UTMzsVuA+4FRgKvAO8ZGb9aln+i8AfCaPR7EMYPWZvwkgzdVLLTURE0tKEm7gvB55w90eizz8ys+OBS4Cra1j+cGChu98bfZ5jZr8FflvfFym5iYhIWrY0IrmZWRvgQLZ9NvArwBG1rPY2cKuZfQN4EegBfA/4Z33fp25JERFJS3mqsN5XDXoSBq2uPpj1UqB3TSu4+wRCMhsDlBAGgS4gDHBdJyU3ERFJSzkF9b6ag5ntTeiCvInQ6juekAgfrm9ddUuKiEhatpQ3qlqyGCgDelWb3guorfLxamCSu98ZfZ5qZuuBN83smrrG01RyyzLf7XFh0iFkrWdW/D7pEESExt3n5u4lZvYf4Fjg2disY4G/1LJae0JCjKv4XGfPo5KbiIikpQndjvcAfzSzSYRikYuBnYGHAMzsKQB3Pyta/gXgETO7BHiZcG/caGCyu8+v64uU3EREJC2ljeuWxN2fNrMewHWERDUNOMHd50WL9Ku2/BNm1gm4DLgbWA28DlxV33cpuYmISFqa8vgtd38AeKCWecNrmNag+9qqU3ITEZG0aMgbERHJOaXl2X8XmZKbiIikRaMCiIhIzlFyExGRnFOqkbhFRCTXqOUmIiI5RwUlIiKSc1JquYmISK7RfW4iIpJzytQtKSIiuUYFJSIiknN0zU1ERHJOWXn2J7fs7zitg5n1N7OUmR3UAtt+wsxerGeZF83siXqWOcfM1jVrcCIiCSqnoN5X0tRyq91ISO+/kJnNBe5397taJCIRkSygbslWzN1XJx1DSzrouCFcOvpcCosKeenR13j69r9XmX/yT77OV88fQVlpGauXr+Gu8x9g2fxiAG7957XsddieTHtrBr/45q8TiF5EktQauiWblNzM7CjgDmBfoAxw4Dx3n2ZmRwC3AQcDq4DngavcfU207jhgBrAZqBhS/A/RMuXRMmcQWlCDgY3AeGCUuy9qRKwTgb+7+6+jz38CTgf6uPsSM2sfxTnC3d+Kuht7uvvXo+XbEwbYOwVYD9xXbfvjgN2AO83sTgB3L4jNHxGtMwCYFB2nOenuR3MoLCzkR/efz1VfuYnihSu5f9JtTHj+feZ/srBymZkfzOGHB1/F5o0lfP3ir3DB7Wdyy/fvBeDZu56jbfu2fO3CY5MIX0QS1hpabo2+5mZm2wHPAW8B
+wOHAqOBMjPbD3iFkND2B04ChgCPVdvM6VEMhwMXARcCo2Lz2wA3RNv4OtAT+N9GhjwOGB77PAwojk07AiglJJ6a3AUcC5wMjACGAkfF5p8ELAR+RRg+vU9sXlvgauA8wr52BR5q3G40nR2yB5/NXMKSOcso3VLKuKff5ohvVb1s+eG4j9m8sQSATyZ+yo59u1fO++D1aWxYuzGjMYtI9kilCup9Ja0pLbfOhB/pF9x9VjRtBoCZPQU87e53VyxsZpcAH5jZTu6+LJq8GPixu6eAGWY2CLgcuAfA3ePJcHa0jU/MrK+7LyQ944DLoqTcH+gC/AY4Gvg/QpKb4O4l1Vc0s47A+YTW1svRtHMJyYwo1pVmVgasdfcl1TaxHfBDd/do3buAx8ysINr3jOq5S3eWL1xR+bl44UoGH7pnrct/9fwRTPrXB5kITURagZzulox+zJ8AXjaz14DXgD+7+3zgQGAPMzs1tkrF0RgIVCS3idV+3CcAN5lZZ3dfY2YHEFpuQ4DusW30I5ZYGugtQgvqYGCf6POrwMPR/OHAv2pZdyChFTmhYoK7rzOzjxr43ZsrElvks2h73YCVDdxGIkacfiSDDtydnw6/IelQRCRLZEPLrD5NuhXA3c8ldEe+AXwTcDM7LtruHwhJqeK1P7AnMKUh2zazDsDLwAbgTEJSOj6a3aYRsa4D/kNoqQ0HxgITgX5mtke0/XHpbreBSqt9rkjoidyKUbxoJTv27VH5uWff7hQvWrHNckNH7Mdp15zE9d+6nS0l1XdBRPJVrndLAuDuHwIfAreb2UvA2cBkYB93n1nP6odW65o7DPgsarUdSLjGdk1F4YWZndTEcMcRkttg4D5332Rm7wLXUvf1tlnAlii+2VEsHQiFNLNiy5UARU2MscX5ezPZZc8+9O6/E8WLVjL81C9y2+lV6mMYOKQ/ox66kGu+egufL1+TUKQiko1y+vFbZjaAUATyPLAI2B34AvBgNG2imT1E6PZbS0go33D3i2Kb2RkYbWYPAPsBPwNujubNJ1RSXmZmvwP2Am5qbLyRccBPCa3BybFp1wLja7reBpVdkI8SEvhyQrfi9WybyOYCR0aVmJvdvbiJ8baI8rJy7v/Ro9z2r2spLCrk5cfHMm/6Qs7+5al8+v4sJrzwPhfecSbtOu7AL575KQDL5hdz/Ym3A3DP+F+x6+BdaNdxB/5n/kPc84MHef+VD5PcJRHJpIxXCqSvKS23DcAg4FlCC2spMAa43d23RLcJ3Ewo3y8itHj+Vm0bY6J57xIO16PAvQDuvtzMzgZuBX4ITCUUm9R2Xawh3or+fdPdy6L34wjX9cbVs+4VQIdoHzYAv40+x11PSOazCNf3svb0ZtJLHzDppapFIk/e8HTl+6u+Uvt5xOXDrm+xuEQk+5W3goKSglQqmRQc3Rc2zd0vSySALHVs4XdawTlRMp5Z8fukQxDJGd26dWt0hhr49C31/k7NOvXaRDOgnlAiIiLpyeVrbtkmur53Ri2z/+TuF2cyHhGRXJUqTzqC+iWW3Nx9eDNv8nrCU0RqonI/EZFmkg2l/vXJmZZb9NSTZfUuKCIiTdMKKgNyJrmJiEhmpFpBtaSSm4iIpEnJTUREco26JUVEJOeoW1JERHJNQs/+SIuSm4iIpEfJTUREck2BuiVFRCTnqOUmIiI5R08oERGRnNOEZ0ua2aWEsTv7AB8Do9z9zTqWbwNcB5xJGAN0KXCXu/+mru8pbHyIIiKSl1INeNXAzE4F7iOM0zkUeAd4ycz61fFt/wccD1wIGPAdwviedVLLTURE0tKEgpLLgSfc/ZHo84/M7HjgEuDq6gub2VeAEcBAdy+OJs9tyBcpuYmISHoaUVASdS8eyLajt7wCHFHLaicC7wGXm9lZwEbgJeAad19X1/cpuYmISCb0BIoI18zilgJfrmWd3YEvAZuBk4GuwG8J195OqevLlNyyzL6Ti5IOIWsdMOaxpEPIWpNPPy/pECSP
ZPA+t0JCO/E0d18NYGaXAS+bWS93r54oq6woIiLScI0rKCkGyoBe1ab3ApbU8k2LgUUViS3ySfRvXUUoSm4iIpKmRiQ3dy8B/gMcW23WsYSqyZq8DexsZh1j0wZF/86rK0R1S4qISFoKGn+f2z3AH81sEiFxXUy4fvYQgJk9BeDuZ0XL/w/wC+BxM7uRcM3tPuDP7r6sri9Sy01ERNLTyPvc3P1pYBThpuwphGKRE9y9ohXWj1h3Y1QR+WWgC6Fq8hlgPFDvRWa13EREJC0FTXi2pLs/ADxQy7zhNUxz4Cvpfo+Sm4iIpEejAoiISK5pSsstU5TcREQkPUpuIiKSa5pQLZkxSm4iIpIetdxERCTX6JqbiIjkHiU3ERHJNWq5iYhI7lFBiYiI5Bq13EREJPcouYmISK5pDfe5tYpRAcwsZWZ1DinemGVbmpkNj+LpmXQsIiLNppGjAmRSxlpuZjYcGAvs6O7FmfreTDGzccA0d78s6VgaYvmUNUx/YiGp8hS7HtODgSf2rjJ/Y3EJH/5uHqUbykiVp7DTdmanoV0oWVvK5HvmsHrWBvoO784+5+2a0B5kxlH9+nPDkUdTWFDA09On8dDkSdss87U9BjHykCNIpVJ8smI5o175ZwKRimSOrrlJVkqVp/j4sQUccu0e7NBje96+2tnpoC506tuucpmZf11Cn8O7sttXdmTtwo28/+vZ7HR/Fwq3L2DQqX1Yu2AT6xZsTHAvWl5hQQG/GjaCM5/7M0vWreW5757Oq3NmMnPVyspl+nfpyiUHHsopf/lf1mzeTI927erYokiOaAXdkg1OblHLZAawGagYJfUPwFXuXm5mbYCbgNOB7sDHwHXu/rKZ9Se02gCWmxnAk+5+jpkdD1wL7EtozL4HjHL3T5q4bxVx7wLcDRwXTXon2v5/o/k3AqcANwO3ADsBrwE/qGhhmtl2wJ3AOdE2ngB2APZy9+Fm9gQwDBhmZj+MlhkQC2N/M7sV2A+YDlzo7pObY/8a4/OZG2jfqy3te7UFoM8R3Vj63uoqyQ2gdGP4Cy7dUE7bbtsDsN0ORXQf3JENSzZnNugE7N+rN/NWf86CNasBeOG/zrG778HM/2xtvX1vny/wx4+msGZzOB4rNuZ2wheB1tFyS/ea2+nROocDFwEXEkZVBXic8AN/GiFRPQm8YGb7AwuAk6Pl9gH6ACOjzx2A0cAhwHBgdbRem3R3pjoza09Iqpui2A4HFgOvRvMq9AdOBb5NGBRvKCHRVbiCkNh+ABxGOAanxeaPBCYQjkGf6LUgNv824OfAAcAKYIyZJTYg0qaVJezQY+vhbdejDZtXbamyzJ7f6cOiN1fy+iXTeO/Xs9jn3L6ZDjNxvTt0ZPHatZWfl6xbS+8OHassM6BrNwZ07cazJ3+Pv57yfY7q1z/DUYokIAevuS0GfuzuKWCGmQ0CLjez54DvA/3dfX607P1m9mXgIne/1Mwq+nKWxa+5uftf4l9gZucCawjJ7q30d6mK7wEFwLlRzJjZRcAy4OuEIcshHIdz3H11tMzvgXNj2xkJ3F4Rq5mNAo6P7cNqMysBNrj7kti+VLz9hbuPjab9KtqvXYCFTdy/FvPZ26voO6w7u3+jF6s+Xc+H98/jyLsGU1CY/YMUZlJRYQH9u3Tl+397ht4dOvL0Sd/j+P99krUlud+ylfzVGqol001uEyuSRGQCoSvyS4QkMj32gw7QFni9rg2a2cBoG4cCOxJaRYVAvzRjq8mBhO7BtdXiag8MjH2eV5HYIp8Ruicxsy5Ab6CyL8rdU2Y2CWhoNcXUatsm2n4iyW2H7m3YtKKk8vPGFSWV3Y4VFo5dwcFXh0PUbVAHyraUU7K2lLZdqi6Xy5asX0efTp0qP/fu2Ikl69dVXWbdOqYsXUxpeTkL165hzucrGdC1K1OXLc10uCKZkwUts/o0560AKeBgYEjstRdwXj3rvUhIahcREtxQoBRocrckYf+mVItpCDAIeDi2XNU+
ubAvzXls4tuv+LNI7DaMLgPbs37JZjYs20x5aTmL31lFr4O6VFmmXc/tWTEtdMmtW7iJ8i3ltOmcX/VHU5cuoX+XrvTt1JntCwv5xp7Gq3NmVVnmldkzOWyXcI7TbYd2DOjanflrVte0OZGcUdCAV9LS/bU61MwKYq23wwgtkQmE/eld0f1Wg4qmQlHFBDPrAQwGLo112x3QiLhqM5nQXVrs7p83ZgNRl+MSQuJ+PYqxIPq8JLZoCbF9y2aFRQXsc15fJt06C8pT9B3eg067tuPTZxbTZff29DqoC4PP3IVpDy9gzj+WQUEBX7hkNwoKwp/s2Ms+pnRDGeWlKZa+t5qDrx24TTFKLihLpbjhjdd56lsnU1hQyLPTp/HflSv4ySFH8NGypbw6dxZvzJ/Lkf1245XTzqEsVc5t74zn802bkg5dpEXlYrfkzsBoM3uAUPn3M+Bmd//UzMYAT5jZTwlJpTuhQGS2u/8VmEdotXzNzF4ANgKrgGLgAjNbQLgOdSeh5dYcxhCKQZ4zs+uB+YSuxG8BD1VUTDbAfcCVZvYpodrxIkLRyOLYMnOBQ6LK0HXASrLYTkO7sNPQqq21Qd/tU/m+U992HH7ToBrXPfr+fVo0tmwybt4cxs2bU2XavZPeqfL5lrfGcwvjMxmWSLJysFtyDKF18i7wCPAocG8071xCteAdhFsGXgSOIiQ13H0RcAOhCnEpcL+7lxOqFL8ATAN+B/yCcLtBk7n7hiiG2cCzUVxPAt0IibWh7gL+SNi/idG0vxGqMOPLlBCS33Ka55qhiEj2aQXVkgWpVMOiaG1P4GhpZvYB8Ja7/6g5t/uTKd/Lgj+L7PT3tw5OOoSsNfn0+i5ti1TVrVu3Rl8aG3LZvfX+Tk25/yeJXnrLrwqBRjKz3Qg3gY8HtgcuILQ2L0gyLhGRJLSGm7hbVXIzs2uAa2qZ/aa7f7WFvrqc8FSWOwldudOBr7r7+y30fSIiWSunCkrcfXgLxtFQD7H1xuvqWuy5R+6+gHAvn4iIqOXWvNx9JVlehSgikvOU3EREJNfkVLekiIgIQEEDq+yTpOQmIiLpyf7cpuQmIiLpUbekiIjkHN3nJiIiuUfJTUREco26JUVEJOeoW1JERHKPbgUQEZFc05RuSTO7lDAWaB/gY2CUu7/ZgPW+BIwDZrj7vvUtn+54biIikucKyut/1cTMTiUM/nwrMBR4B3jJzOoc/9LMugFPAa81NEYlNxERSU/jByu9HHjC3R9x90+i8TAXA5fU842PEgaantDQEJXcREQkLQXlqXpf1ZlZG+BA4JVqs14Bjqjtu6JuzF7AzenEqOQmIiJpKUjV/6pBT6AIWFpt+lKgd00rmNl+wA3AGe5elk6MKijJMp9vaZd0CFmraIPOxWrTZfOhSYeQtVa3fTfpEHJOJu5zM7O2wNPAFe4+J931ldxERCQ9jbsVoBgoI3QxxvUCltSwfB9gL+BxM3s8mlYIFJhZKXCCu1fv4qykU2EREUlLY7ol3b0E+A9wbLVZxxKqJqtbBOwHDIm9HgJmRu9rWqeSWm4iIpKWJnRL3gP80cwmAW8DFwM7E5IWZvYUgLuf5e5bgGnxlc1sGbDZ3atMr4mSm4iIpKeGasiGcPenzawHcB2h23EaoXtxXrRInfe7pUPJTURE0tOEp2+5+wPAA7XMG17PujcCNzbke5TcREQkLTXdx5ZtlNxERCQtGhVARERyj5KbiIjkmoKy7M9uSm4iIpKWAo3nJiIiOSf7c5uSm4iIpEfVkiIiknvULSkiIrkmE6MCNJWSm4iIpEfdkiIikmtaQ7Vkqx7yxsxSZnZKC2z3RjOr86nTZna/mY2rZ5nhUYw9mzVAEZEklaXqfyUs8eSWpQngLmBYOiuY2Tgzu7+F4hERyRoFqVS9r6SpW7IG7r4OWJd0HC1p1YermPPHOVAOOw3fib7f7Ftl/pw/zWH19NUAlJeUs2XNFg79/aEAzP2/uayasgqAXU/clZ6H
ZdN5SdMdNWA3rhsxnKKCQp6ZOo2H332vyvyT9t2bnw8/kiVrw5/Inz74kGemTuOwfn255uit50QDe3Rn5PP/5NWZszIaf0t681249bdQXg6nfA0uOL3q/EVL4LrbYeXn0KUz3HEt9N4pzLvzQRg/EVLlcMRBcM2PoaAg47sgzSELkld9mpzcoq65GcBm4Kxo8h+Aq9y93MzaADcBpwPdgY+B69z9ZTPrD4yN1lluZgBPuvs5ZnY8cC2wL+GWwfeAUe7+SSNi/D/gc3e/OPp8c7Ttw919YjRtAXC1u//JzG4ETnH3faN5RcDtwPnRJp8EimLbf4LQ0htmZj+MJg+IhbC/md1KGFV2OnChu09Odz+aS6o8xewnZ7PPz/ehTfc2TL1+Kt0P7E77XdpXLjPgjK3hL35lMevnrgdg5QcrWT93PUNuGUL5lnKm3TKNrl/oynbtc+M8qbCggBu/fAxnP/NXlqxdy1/POo3XZs5i5oqVVZb7x4xP+eWrY6tMmzh/Id98cgwAXXZoy2sXnMdbc+eRK8rK4KbR8Ojd0GtH+O5FcPQXYY/+W5e58wH41nFw4vEwcTLc83u44zr4YFp4PfdYWO70y+C9KXDI0AR2RJouC7od69Nc3ZKnR9s6HLgIuBAYFc17nPDDfxohUT0JvGBm+wMLgJOj5fYhDF43MvrcARgNHAIMB1ZH67VpRHzjom1UGA4UV0wzsz2AvtFyNfkpcEG0b4cTElv8nHUkMIGwr32i14LY/NuAnwMHACuAMWaW2DnrulnraNerHTvstAOF2xXS87CerPzPylqXL55QTM/DQ+ts46KNdLbOFBQVULRDER36deDzqZ9nKPKWt3+f3sz7/HMWrF7NlvJy/vGJ8+U9Bqa9neNtEOPnzGFTaWkLRJmMqZ9Av11g152hzfZwwjHw+ltVl5k5Dw49ILw/dCi8/vbWeZtLYEsplGyB0jLo0S1zsUvzyqduycXAj909Bcwws0HA5Wb2HPB9oL+7z4+Wvd/Mvgxc5O6XmlnFr+oydy+u2KC7/yX+BWZ2LrCGkOyq/S9Vr3HAg2bWh5AkDwauB44Bfk1IcrPcfWEt648C7nD3Z6JYRgLHxWJdbWYlwAZ3XxKLueLtL9x9bDTtV1H8uwC1fV+L2rxqM226bz1HaNO9Detm1dwLu6l4E5uWbaLLPl0A6LBbBxb8dQE7n7Az5SXlrJ6+mna7tMtI3JnQq2NHFq9dW/l5ydp17L9z722WO27QnhzcdxfmrvqcW14fx+K1VY/f1wcP4rH3E2uct4hlxVu7GCG03qZW60cZPBD+/QacdQr8+01Yv6GAVatTDN03JLujTgo9Wqd/Gwb2z2j40pyyIHnVp7mS28QosVWYQOiK/BJQAEyP/dADtAVer2uDZjYw2sahwI6ElmEhjRiG3N1nmNkSQhJbDswCngZ+YWbbR9PH1RJHF0JLbEJse+Vm9i6wawNDmBp7/1n0704klNzSUTyhmB6H9KCgMDQ0u+7XlXWz1/HRLz9i+87b02nPTpXz8sXrM2fz4idOSVkZ39t/P+444TjOfHrrudiOHTpgO/bkzTm50yXZUFdeGrou//4SHLQ/9NoxRVEhzFsIs+bB2GfDcuf/FN7/MCwjrVB59t/FnYkLJSlCS2lLtekb61nvRcKP/0XAIqCUcL2qMd2SAOOBo4FlwFh3n2tmxVFsw4CrG7ndhojve8VJQGKVqm27taVkZUnl55KVJbTpVvNhXTFxBQPOHlBlWt9v9aXvt0IByqe/+5R2vXOn5bZ03Tr6dOpU+bl3p44srdYq+3zTpsr3z0ydxlXDj6wy/4TBg3jlv7MobQU/AOnYqScsWbb189Ll0Kvntsv89ubwfv0GeOUN6NwJnn0R9t8bOkSXdY88FKZ8rOTWarWCP+3m+oE9tNo1pMMILZQJhJZbb3efWe21KFq24lc2XqDRAxgM3Orur0ZFJJ1oWjIeR0huw9naShtHuJZW6/U2d19N6HY9LBZf
AaF7NK4kvg/ZrOPuHdm4ZCOblm2ivLSc4onFdD+g+zbLbfhsA6XrS+m059Yf+1R5ii1rQ65eP3896xesp+t+XTMVeoubungJu3XrRt8undm+sJCv7WW8NnN2lWV27NCh8v2IPXZnVrVik2/sZbz4yYyMxJtJ+w0OLbCFi8N1s3++HgpK4lZ9vvWk/pExcNJXw/s+veC9D6G0NFx3e/9DGLhbRsOXZpRP19x2Bkab2QOEisCfATe7+6dmNgZ4wsx+CkwmVEwOB2a7+1+BeYTWzNfM7AVCi24VoeDjgqiKcRfgTkLrrbHGAQ8Cu1E1uT1C3dfbAO4DrjazT4GPgEsJXZWLY8vMBQ6JKkDXAbVXaCSsoKiA3c/enel3TCdVnqLXsF6079ue+X+eT8cBHel+YEh0xROK6XlYTwpi9dqp0hTTbgr3txe1K2LQJYMoKMqdbsmyVIpfvvo6j3/nJIoKCnj2o4/574oVjPzS4UxbspTXZs7m7AOHMGKPgZSWl7N60yau/OfLlevv0rkzvTt14t35Wd/jnLbttoPrRsEPrggJ7KQTYM8B8JtHYd/BcMwXYdKUUCFZUBBaZdePCuseNwzenQzfOjfM+9Ih2yZGaUXKsr/p1lzJbQyh1fIuIVE9CtwbzTuXUHZ/B6GFtBKYRHQLgLsvMrMbgFsItxA8Fd0KcCrwG2AaMJNQsVilyCQdsetuK9x9eTR5HOEYjKtn9buB3lF8AH+M9nmv2DJ3ESpBpwPtqHorQNbpNqQb3YZULVfrd0rVy5n9Tt728mZhm0KG3pHb9dvjZ89l/Ownqky7763KS67c9cbb3PXG29Rk0Zo1fOnBR1oyvEQNOyy84n58/tb3xw0Pr+qKiuCXV7RkZJJRWdAyq09BqolBRve5TXP3y5olojx37nvnZv9fTULeHPuFpEPIWp+e9WDSIWSt1W3fTTqErNStW7dGd7l8dc8r6/2deum/dyTapZMbd96KiEjm5FG3ZKLM7Brgmlpmv+nuX81kPCIiOS2VB8nN3Yc3QxxN9RDwTC3z6rvlQERE0tEKrrnlRMvN3VeSxdWJIiI5Rd2SIiKSc9RyExGRnFNWlnQE9VJyExGR9KjlJiIiOUfJTUREck1K3ZIiIpJzytVyExGRXKNuSRERyTnqlhQRkVyTagUD8Sq5iYhIetQtKSIiOUfdkiIikmtSqpYUEZGckw9D3oiISH5pDTdxF6RawYVBERGRdBQmHYCIiEhzU3ITEZGco+QmIiI5R8lNRERyjpKbiIjkHN0KIDUys65UO/lx95XJRJNddGxEsp+Sm1Qys92Ah4DhQJvYrAIgBRQlEFZW0LGpn5ntDOzEtol/cjIRZQ8dm8xTcpO4x4GuwPnAZ4QfbQl0bGphZkOBPwGDCck+Lq8Tv45NcpTcJO4Q4DB3n5Z0IFlIx6Z2vwcWABegxF+djk1ClNwkbg7QNukgspSOTe32Boa6+6dJB5KFdGwSompJiRsJ3GZmeyQdSBbSsandR0DvpIPIUjo2CdGzJfOcma2lalfJDoTrAJuB0viy7t45g6ElTsemdmbWPfZxCHArcB3hx3xLfNl8qyTVsckO6paUy5IOIIvp2NSumKqJvwB4pYZp+Vg0oWOTBZTc8py7P5l0DNlKx6ZORycdQBbTsckC6paUSmZWBvRx92XVpvcAlrl73p5l6tjUzsz6AQvcPVVtegGwq7vPTyay5OnYJEcFJRJX/T6cCm2BkkwGkoV0bGo3B9ixhundo3n5TMcmIeqWFMzs8uhtCrjYzNbFZhcBRwIzMh5YFtCxaZCK60fVdQQ2ZTiWbKNjkxAlNwH4UfRvAfADID6GfAkwF7g4wzFlCx2bWpjZb6K3KcJtEhtis4sIN75PyXRc2UDHJnlKboK7DwAws7HASe6+KuGQsoaOTZ32i/4tAPaiavdsCTAZuCvTQWUJHZuEqaBERJrEzB4HRrr7mqRjyTY6NslRcpNKZvZYLbNShOsDM4Gn3f2zzEWV
nDqOxzbc/byWjEVE0qNuSYnbkVAgUQ5UPCB4X0LXyn+Ak4BfmdmR7j4lkQgzq3qV21GEY/NR9HlfQsXxG5kMKtuY2eu1zIqfFD2ZL8O7RF3YDWo1uPsxLRxO3lJyk7i3gXXA+e6+AcDM2gOPAB8CJwBPAXcDI5IKMlPc/RsV783samAjcK67r4+mdQAeZWuyy1czgNOAJcCkaNrBhGcq/p1wwnSpmR3v7q8lEmFmxUeOKAJOJxybd6NphwB9CEPhSAtRt6RUMrPFwDHu/km16XsDr7l7n2h8qlfdvUciQSYkOjYj3H16ten7EI5N3j4c18zuAQrdfVS16XcDKXe/wszuAw5x98OTiDEpZnYvIcGNjN/IbWajgQJ3H5lUbLlON3FLXEfCGWV1vaN5AGvIzxZ/R2DnGqb3AdpnOJZsczbwuxqmPwycG71/hDD8S745C7i/+hNKgAeAMxOIJ2/k44+U1O5vwKNmdiXwXjTtYOAO4K/R50OAfByb6i/A42b2M2BiNO0w4Ha2Hpt8VQDsA/y32vS92fpklxLC9cp8U0C4LaD6/zP71bCsNCMlN4m7GLiHcC2g4m+jFHgMuCL6/AlhVOF8cwnhWuMTwPbRtFLCNbcralknXzxJOCnak6onRVcRjhfAMKpei8oXjwF/iI5N/KToSuDxxKLKA7rmJtuICiUGRh9nVRRQiI5NTcysCPgZ8GO2Dsy5BLgPuMvdy6IHCJe7+8KEwkyEmRUSTn5GsrXLfzHh2Nzt7mW1rStNo+QmIs3GzDoD6KblbenYZJaSm1Qysx0IZ5gjgJ2oVnDk7l9IIq6kmNnzwBnuviZ6Xyt3/2aGwhKRBtA1N4l7APg28CzwDg28ETWHrWDrMViRZCDZzMy6A7dQ+0lR5yTiSoqZTQWGufsqM/uIOv4/yrcTxkxScpO4E4HvuPurSQeSDdz93JreyzYeBYYCvwc+QydFfwE2x97n+/FIhLolpZKZLSTcqOxJx5JtzOwIYJK7lyYdS7YxszXAse7+br0Li2SIbuKWuDuAy82stlGn89nrwCoze8XMrjGzI8xMPR/BMsJj26QaMzvNzGp6MIK0MLXcpJKZvUB4DuBqYDqwJT4/n4smzKwd8EXC/VrDCfdxbQEmAGPd/bbkokuWmZ0KfBc4292V5GLMbD6wCzALGFfxypeRNZKk5CaVorGnaqXrTluZ2UDgWuAMoMjdixIOKTFR0UR/wjMU57HtSVFeF02Y2R6EE6Jh0asi2Y1194sSDC2nKbmJNICZ7UT4gTo6+rcf4Qn44whn4uOTii1pZnZDXfPd/ZeZiiWbRTe7H0J4wk/enxS1NCU32YaZHUR4CseL7r4+eirH5nwupjCzcmA54WHArwLvuvvmuteSfGdmh7D1pOiLQDEwnq0nRfMSCy7H6YK4VDKzXsBzhLPLFLAnMJvwvMlNhBu889X/EAYrHQkcAIw1s3HA5Bqe+J53ogcAfJ1wUvSwu38edd2ucveVyUaXqImEk6K7gIvcfX7C8eQNVUtK3L3AUqAHsCE2/VngK4lElCXc/Qx370dIbH8DhhBGA1hpZs8lGVvSomtKM4CHCDdzd49mXUKowM1ntxJGBLgJ+KeZ/dbMTjazvBoPMQlKbhI3ArjW3VdVmz6LcI1JYA7h6fbTAQc6AMcnGlHyRgOvAL0Io5VXeJ7QHZe33P06dz8S6EZo9X8e/bvIzD5MMrZcp25JiWtHGHeruh0J3ZJ5KxrjbjjwJaAt8B/CtZO7gbeSiywrHAEcFj39Pz59PjUP8JqPOgM9CY8n6w20iT5LC1Fyk7g3gHOAa6LPqajC6yrgtaSCyhLfJhQB3Ae8paFutrF9DdP6Ee6ZzFtm9iCh/N8IXf4VJ0Tj9CSglqXkJnFXAuPN7GBC6+RuwgjLXQiVXnnL3Q9vyHJm9gBwvbsXt3BI2eQV4HLg/OhzKhre5ZfAPxKLKjt0JZwQKZllmG4FkCrMrDehEOBAwjXZycDv3H1xooG1
EtFzFoe4++ykY8kUM9sZGBt93B34ANiD0FI5yt2XJxVba2Fm/wB+oP/Pmo9ablKFuy8B6rwpV+qUd8/ldPfPzGwI8H1CNWkhYYSAMe6+sa51pdJRhGve0kyU3PKcmR3Q0GXdfXJLxiKtV5TEHoteIolTcpP3CTds19fiSBGeHSiCmZ3U0GXd/a8tGYtITZTcZEDSAUir9OcGLqeTIkmEkluea8yz7fK0IlBi3F0PgJCspj9QaYwzCDelyrb+BKxJOohsZGb/0MCdkilquUlj5EVFYGOKbdz9kpaLqNVTRWDtbgXy+QHTzU7JTaR2KraRtDWm2CafR3JvKUpuIrVTsY00hoptsoCSm0gtNJCkNIaKbbKDkptIGqJHTfUjPNW9kru/kUxEIlITJTdpjLyrCIySWsVo3BXX4eIPZlX3ktTIzLYjjG5f00nRU4kElQeU3PKcKgIbbDRQBuwNvEcYoLQX8CvgJ8mF1arkXUWgmQ0GXiBcvy0g/A1tB2wBNgNKbi1EyU1UEdgww4CvufsMM0sBy939bTPbDNwE/DvZ8DJLFYENNpowsO0QYEn0bxfgQeC6pILKB0puoorAhmkHVDyRZSVhROVPgenAF5IKKkGqCGyYg4Fh7r7ezMqB7dx9cjSy+2/Jz7+djFByy3OqCGywGcBgYC4wBbjYzBYAPwQWJRdWMlQR2GAFwIbo/XJgF8CBhYQx76SFKLnJNlQRWKP7gN7R+18B/yKMX7YZODupoCTrTQP2B2YDk4CrzKwMuACYmWRguU7JTSqpIrB27j4m9n6ymfUntOTm6wHSqgiswy1Ah+j9dcA/CKOWFwOnJhVUPihIpVL1LyV5wcyeAXoQutq2qQh097wqmqiNmXUEcPd1SceSDeqrCHR3PWQ7xsy6A6vcXT++LUj95hI3DLjK3WcQWmzLo0q3qwgVgXnNzEaZ2XxgNbDazBaY2U/MLC8eJF2H0YSKwC6E60t7AQcRrk2enFhUWcDMHjOzTvFp7r4SaG9mGrW8BSm5SVxNFYGQvxWBlczsDuBG4GHg2Oj1EHA9cHtykWWFg4Gb3X09UFkRCFwJ3J1oZMk7m5pHQmgHnJXhWPKKrrlJnCoCa/cD4AfuHi+Bf93MnJDwrkwmrKygisBqoq7HgujVzcxKY7OLgK8BS5OILV8ouUmcKgLrNrWWafneA6KKwG0VE7r2U4Sej+pSwA0ZjSjPKLlJJVUE1ukpQgt2ZLXplwB/zHw4WUUVgds6mtBqe51w3TH+2LESYJ67f5ZEYPlC1ZJSI1UEVmVmDwKnAYuBidHkQ4GdgTFAZbeTu/844wFmGVUEBma2G+HkMK+PQxKU3KQKMxsFXE64bgLwGXAPMDqf/wc1s7ENXDTl7se0aDBZJqr6G+nua6tN7wD81t3PSyay7GBm+wEXAQOB89x9sZmdSGi9fZBocDlM3ZJSKaoIvBC4E5gQTT6cUBHYhzwumnD3o5OOIYudDfwcWFttekVFYN4mNzP7CvA88BJwDFsrJwcC5wAnJhJYHlBykzhVBNbDzHoSfpimuPvmpONJkioCG+Qm4HJ3f8DM4sl/HPDTZELKD0puUp0qAmsQ3Yj7GKE4IAXsCcw2s4eAJe5+Y4LhJUUVgfXbF/hnDdNXAt0zHEteyesfLNlGRUVgdaoIDDdq7wwcAGyMTX8R+HYiESXvaGAEoeV2CqHbreL1JaCfu9+SXHhZYSVbr1/HHUC4D1BaiFpuEtcWOM3MjqOGikAz+03FgnlYEfhN4NvuPiUarLTCJ8DuCcWUKHcfD2BmA1BFYG3+B7jTzL5LaMluZ2bDgLuAxxONLMcpuUncYGBy9H636N8l0Wuv2HL5+CPWDVhRw/ROhAcF5y13n2dm+5mZKgK3dR3wBDCP0MKdTugxG0O4P1BaiJKbVFJFYJ3eI7TeRkefKxL8RcA7SQSULVQRWDt33wKcbma/IHTVpoAJ7p6vT27JGCU32YYqAmt0DfCyme1D+P/m8uj9ocCR
iUaWPFUE1qGme0fNLO/vHW1pKiiRSmbWycyeBZYRWiO7RNMfMrMbk4wtae7+DuGevzbALEIhxSLgsOgJ+PlMFYG10GgSyVFykzhVBNbCzPYGtrj72e6+LzCKcA3l62aWtyOUR1QRWLuKe0dvcffXo9cthIdKn59wbDlNyU3ivgmMcvcpVC0ayduKwJjHgKEAZrYr8DdCq+SHwM0JxpUNKioC+7JtReBTiUaWHXTvaAJ0cCVOFYG1i1eSngJMcvcTgDMJwwLls+uAOYSKwI6EisCxwFuoIlD3jiZEBSUSp4rA2hURhiqBcL2t4hrTLKBXIhFlCVUE1kn3jiZEyU3iVBFYu2nAJWb2IiG5XR1N34XwGKq8porAWune0YQouUkld3/HzA4HfsbWisD/ECoCP0o0uORdBfwduAJ4MnY8vkkYfTpvaTSJ2une0eRoPDepFFUElrm7R5+/Qhiy5GPgDnfP6+tuUVVkZ3dfFZvWH9jg7ssSCyxhZrYSuLDaaBKY2SnAw+7eI5nIJJ+p5SZxjxGut3msInA84YJ4Z7Z2xeWlKLmvqjZtbjLRZB1VBEpW0R+exKkiUBpDFYGSddRykzhVBEpjqCJQso6Sm8SpIlAaQxWBknWU3CROFYGSNlUESjZStaRUoYpAEckFSm4iIpJzVC0pIiI5R8lNRERyjpKbiIjkHCU3ERHJOUpuIiKSc/4fpMENJjHpVDkAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# 'species' is a string column: request numeric columns explicitly (the implicit default is deprecated)\n", "corr = iris[iris['species'] != 'virginica'].corr(numeric_only=True)\n", "mask = np.tri(*corr.shape).T  # hide the diagonal and the redundant upper triangle\n", "g = sns.heatmap(corr.abs(), mask=mask, annot=True, cmap='viridis')\n", "g.set_xticklabels(g.get_xticklabels(), rotation=90, fontsize=14)\n", "g.set_yticklabels(g.get_yticklabels(), rotation=0, fontsize=14);" ] }, { "cell_type": "markdown", "id": "a6144913", "metadata": {}, "source": [ "
\n", "   Notes
\n", "

\n", " (i) Correlated data typically has less power to constrain the model: correlated variables allow a wider range of coefficient combinations that explain the data equally well.\n", "

\n", "

\n", " (ii) One solution when dealing with highly correlated variables is to remove one or more of the correlated variables.\n", "

\n", "

\n", " (iii) Another option is to scale all non-binary variables to have a mean of 0 and then use the following prior for the coefficients:\n", "

\n", "
\n", " $\\beta \\sim StudentT(0,\\nu,sd)$\n", "
\n", "

\n", " $sd$ should be chosen to weakly inform us about the expected scale of the coefficients. The normality parameter $\\nu$ is typically chosen in the range 3 to 7.\n", " This prior says that in general we expect the coefficients to be small, but we use heavy tails because occasionally we will find some larger coefficients.\n", " \n", "

\n", "\n", "
\n" ] }, { "cell_type": "markdown", "id": "91fa75de", "metadata": {}, "source": [ "$\\beta \\sim StudentT(0,\\nu,sd)$\n" ] }, { "cell_type": "markdown", "id": "3907e55c", "metadata": {}, "source": [ "## Recipe 2: Dealing with unbalanced classes" ] }, { "cell_type": "code", "execution_count": 6, "id": "f65d95c8", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]\n" ] } ], "source": [ "df = iris.query(\"species == ('setosa', 'versicolor')\")\n", "df = df.iloc[45:]  # keep only the last 5 setosa rows: two unbalanced classes (5 vs 50)\n", "y_3 = pd.Categorical(df['species']).codes\n", "x_n = ['sepal_length', 'sepal_width']\n", "x_3 = df[x_n].values\n", "\n", "print(y_3)  # only a few 0s (setosa) vs many 1s (versicolor): this is why the classes are unbalanced" ] }, { "cell_type": "markdown", "id": "c25dc3ba", "metadata": {}, "source": [ "Doing the usual thing: build the logistic regression..." ] }, { "cell_type": "code", "execution_count": 23, "id": "987cf733", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/cfanelli/Desktop/teaching/BRDS/jupynb_env_new/lib/python3.9/site-packages/deprecat/classic.py:215: FutureWarning: In v4.0, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.\n", " return wrapped_(*args_, **kwargs_)\n", "Auto-assigning NUTS sampler...\n", "Initializing NUTS using jitter+adapt_diag...\n", "Multiprocess sampling (4 chains in 4 jobs)\n", "NUTS: [β, α]\n" ] }, { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n", " \n", " 100.00% [8000/8000 00:04<00:00 Sampling 4 chains, 0 divergences]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/cfanelli/Desktop/teaching/BRDS/jupynb_env_new/lib/python3.9/site-packages/scipy/stats/_continuous_distns.py:624: RuntimeWarning: overflow encountered in _beta_ppf\n", " return _boost._beta_ppf(q, a, b)\n", "/Users/cfanelli/Desktop/teaching/BRDS/jupynb_env_new/lib/python3.9/site-packages/scipy/stats/_continuous_distns.py:624: RuntimeWarning: overflow encountered in _beta_ppf\n", " return _boost._beta_ppf(q, a, b)\n", "/Users/cfanelli/Desktop/teaching/BRDS/jupynb_env_new/lib/python3.9/site-packages/scipy/stats/_continuous_distns.py:624: RuntimeWarning: overflow encountered in _beta_ppf\n", " return _boost._beta_ppf(q, a, b)\n", "/Users/cfanelli/Desktop/teaching/BRDS/jupynb_env_new/lib/python3.9/site-packages/scipy/stats/_continuous_distns.py:624: RuntimeWarning: overflow encountered in _beta_ppf\n", " return _boost._beta_ppf(q, a, b)\n", "Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 10 seconds.\n" ] } ], "source": [ "with pm.Model() as model_3:\n", "    α = pm.Normal('α', mu=0, sd=10)\n", "    β = pm.Normal('β', mu=0, sd=2, shape=len(x_n))\n", "\n", "    μ = α + pm.math.dot(x_3, β)\n", "    θ = pm.math.sigmoid(μ)  # logistic function 1 / (1 + exp(-μ))\n", "    # decision boundary: the sepal_width at which μ = 0 (θ = 0.5), for each sepal_length\n", "    bd = pm.Deterministic('bd', -α/β[1] - β[0]/β[1] * x_3[:,0])\n", "\n", "    yl = pm.Bernoulli('yl', p=θ, observed=y_3)\n", "\n", "    # pin the MultiTrace return type: later cells index trace_3['bd'] directly\n", "    trace_3 = pm.sample(1000, target_accept=0.95, return_inferencedata=False)" ] }, { "cell_type": "code", "execution_count": 24, "id": "976f39e8", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/cfanelli/Desktop/teaching/BRDS/jupynb_env_new/lib/python3.9/site-packages/arviz/plots/hdiplot.py:157: FutureWarning: hdi currently interprets 2d data as (draw, shape) but this will change in a future release to (chain, draw) for coherence with other functions\n", " hdi_data = hdi(y, hdi_prob=hdi_prob, circular=circular, 
multimodal=False, **hdi_kwargs)\n" ] }, { "data": { "text/plain": [ "Text(0, 0.5, 'sepal_width')" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "idx = np.argsort(x_3[:,0])  # sort by sepal_length so the boundary is drawn as a line\n", "bd = trace_3['bd'].mean(0)[idx]  # posterior-mean decision boundary\n", "\n", "plt.scatter(x_3[:,0], x_3[:,1], c=[f'C{x}' for x in y_3])\n", "plt.plot(x_3[:,0][idx], bd, color='k')\n", "\n", "# add a leading chain axis, (draw, obs) -> (1, draw, obs): this is how arviz currently\n", "# interprets 2-D input, so the band is unchanged and its FutureWarning goes away\n", "az.plot_hdi(x_3[:,0], trace_3['bd'][np.newaxis], color='k')\n", "\n", "plt.xlabel(x_n[0])\n", "plt.ylabel(x_n[1]);" ] }, { "cell_type": "markdown", "id": "a4bfe7a3", "metadata": {}, "source": [ "
\n", "   Notes
\n", "

\n", " (i) In case of an unbalanced dataset, logistic regression can run into some trouble: the boundary cannot be determined as accurately as when the dataset is more balanced. \n", "

\n", "

\n", " (ii) The decision boundary is \"shifted\" towards the less abundant class, and the uncertainty band is larger.\n", "

\n", "

\n", " (iii) A balanced dataset is always preferable. If your data are unbalanced, be careful when interpreting the results: check the uncertainty of the model and run posterior predictive checks for consistency. Other options are to add more prior information, if available, and/or to try an alternative model. \n", "

\n", " \n", "\n", "
\n" ] }, { "cell_type": "markdown", "id": "d981ce29", "metadata": {}, "source": [ "## Generalization to multiple classes: Softmax Regression" ] }, { "cell_type": "markdown", "id": "aa61bd29", "metadata": {}, "source": [ "
\n", "   Notes
\n", "

\n", " In order to generalize to multiple classes, two modifications are needed: \n", "

\n", "

\n", " (i) We replace the logistic function with the softmax function (cf. the Boltzmann distribution in physics), defined as:\n", "

\n", "

\n", "
\n", " $\\mathrm{softmax}_{i}(\\mu) = \\frac{\\exp(\\mu_{i})}{\\sum_{k}\\exp(\\mu_{k})}$ \n", "
\n", "

\n", " (ii) We then replace the Bernoulli distribution with the \n", " categorical distribution.\n", " Just as the Bernoulli (a single coin flip) is a special case of the binomial (n coin flips), the categorical (a single roll of a die) is a special case of the multinomial distribution (n rolls of a die).\n", "

\n", "\n", "
\n" ] }, { "cell_type": "code", "execution_count": 80, "id": "250279e5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(150, 4)\n", "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", " 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2\n", " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", " 2 2]\n", "Index(['sepal_length', 'sepal_width', 'petal_length', 'petal_width'], dtype='object')\n", " sepal_length sepal_width petal_length petal_width species\n", "0 5.1 3.5 1.4 0.2 setosa\n", "1 4.9 3.0 1.4 0.2 setosa\n", "2 4.7 3.2 1.3 0.2 setosa\n", "3 4.6 3.1 1.5 0.2 setosa\n", "4 5.0 3.6 1.4 0.2 setosa\n", ".. ... ... ... ... ...\n", "145 6.7 3.0 5.2 2.3 virginica\n", "146 6.3 2.5 5.0 1.9 virginica\n", "147 6.5 3.0 5.2 2.0 virginica\n", "148 6.2 3.4 5.4 2.3 virginica\n", "149 5.9 3.0 5.1 1.8 virginica\n", "\n", "[150 rows x 5 columns]\n" ] } ], "source": [ "# Reuse the iris frame loaded at the top of the notebook instead of re-downloading it\n", "# with sns.load_dataset, which silently rebound `iris` to a second data source.\n", "y_s = pd.Categorical(iris['species']).codes  # sorted categories: setosa=0, versicolor=1, virginica=2\n", "x_n = iris.columns.drop('species')  # feature columns, independent of column order\n", "x_s = iris[x_n].values\n", "\n", "x_s = (x_s - x_s.mean(axis=0)) / x_s.std(axis=0)  # standardize each feature\n", "\n", "print(np.shape(x_s))\n", "\n", "print(y_s)\n", "print(x_n)\n", "iris.head()  # a preview is enough; printing all 150 rows buries the output" ] }, { "cell_type": "code", "execution_count": 62, "id": "8d9ed80a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/cfanelli/Desktop/teaching/BRDS/jupynb_env_new/lib/python3.9/site-packages/deprecat/classic.py:215: FutureWarning: In v4.0, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. 
You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.\n", " return wrapped_(*args_, **kwargs_)\n", "Auto-assigning NUTS sampler...\n", "Initializing NUTS using jitter+adapt_diag...\n", "Multiprocess sampling (4 chains in 4 jobs)\n", "NUTS: [β, α]\n" ] }, { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n", " \n", " 100.00% [12000/12000 00:19<00:00 Sampling 4 chains, 0 divergences]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/cfanelli/Desktop/teaching/BRDS/jupynb_env_new/lib/python3.9/site-packages/scipy/stats/_continuous_distns.py:624: RuntimeWarning: overflow encountered in _beta_ppf\n", " return _boost._beta_ppf(q, a, b)\n", "/Users/cfanelli/Desktop/teaching/BRDS/jupynb_env_new/lib/python3.9/site-packages/scipy/stats/_continuous_distns.py:624: RuntimeWarning: overflow encountered in _beta_ppf\n", " return _boost._beta_ppf(q, a, b)\n", "/Users/cfanelli/Desktop/teaching/BRDS/jupynb_env_new/lib/python3.9/site-packages/scipy/stats/_continuous_distns.py:624: RuntimeWarning: overflow encountered in _beta_ppf\n", " return _boost._beta_ppf(q, a, b)\n", "/Users/cfanelli/Desktop/teaching/BRDS/jupynb_env_new/lib/python3.9/site-packages/scipy/stats/_continuous_distns.py:624: RuntimeWarning: overflow encountered in _beta_ppf\n", " return _boost._beta_ppf(q, a, b)\n", "Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 33 seconds.\n" ] } ], "source": [ "# derive the parameter shapes from the data instead of hard-coding 4 features / 3 classes\n", "n_features = x_s.shape[1]\n", "n_classes = len(np.unique(y_s))\n", "\n", "with pm.Model() as model_s:\n", "    α = pm.Normal('α', mu=0, sd=5, shape=n_classes)\n", "    β = pm.Normal('β', mu=0, sd=5, shape=(n_features, n_classes))\n", "    μ = pm.Deterministic('μ', α + pm.math.dot(x_s, β))  # (n_obs, n_classes) class scores\n", "    θ = tt.nnet.softmax(μ)  # row-wise softmax -> class probabilities\n", "    yl = pm.Categorical('yl', p=θ, observed=y_s)\n", "    # pin the MultiTrace return type: later cells index trace_s['μ'] directly\n", "    trace_s = pm.sample(2000, target_accept=.95, return_inferencedata=False)\n", "    idata_s = az.from_pymc3(trace_s)" ] }, { "cell_type": "code", "execution_count": 63, "id": "1c7e1c4d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "accuracy is: 0.980\n" ] } ], "source": [ "data_pred = trace_s['μ'].mean(axis=0)  # posterior-mean class scores, one row per observation\n", "\n", "# softmax row by row; subtracting each row's max prevents overflow in exp and leaves the result unchanged\n", "shifted = data_pred - data_pred.max(axis=1, keepdims=True)\n", "y_pred = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)\n", "\n", "res_t = np.mean(y_s == np.argmax(y_pred, axis=1))\n", "print(\"accuracy is: {:1.3f}\".format(res_t))\n" ] }, { "cell_type": "code", "execution_count": 64, "id": "bd3265cf", 
"metadata": {}, "outputs": [], "source": [ "from scipy.special import softmax" ] }, { "cell_type": "code", "execution_count": 65, "id": "32eac468", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "accuracy is: 0.980\n" ] } ], "source": [ "y_pred2 = softmax(data_pred, axis=1)  # same probabilities via scipy.special.softmax\n", "res_t2 = np.mean(y_s == np.argmax(y_pred2, axis=1))\n", "print(\"accuracy is: {:1.3f}\".format(res_t2))" ] }, { "cell_type": "code", "execution_count": 78, "id": "57a64211", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
meansdhdi_3%hdi_97%mcse_meanmcse_sdess_bulkess_tailr_hat
α[0]-1.1353.746-7.9216.0670.0480.0396080.05193.01.0
α[1]5.8303.251-0.18111.9870.0440.0315382.05418.01.0
α[2]-4.8463.421-11.2671.4360.0460.0355554.05354.01.0
β[0, 0]-2.6044.099-10.7364.5680.0470.0417711.05642.01.0
β[0, 1]1.9853.267-4.2147.9660.0440.0345404.04740.01.0
β[0, 2]0.6403.265-5.4646.7700.0450.0365341.05130.01.0
β[1, 0]3.1863.414-3.5149.2720.0480.0345167.05913.01.0
β[1, 1]-1.0093.037-7.1354.3670.0450.0324579.05063.01.0
β[1, 2]-2.4083.057-8.3283.1680.0450.0324636.05119.01.0
β[2, 0]-6.3334.256-14.1921.5480.0470.0368195.06219.01.0
β[2, 1]-1.3953.540-8.3085.0340.0420.0357265.06032.01.0
β[2, 2]7.8103.7430.96715.0850.0440.0327166.05728.01.0
β[3, 0]-5.7924.358-14.1322.0980.0510.0387316.05879.01.0
β[3, 1]-1.0783.558-7.5015.8230.0460.0375998.05723.01.0
β[3, 2]6.7443.683-0.05513.7450.0470.0346223.05727.01.0
\n", "
" ], "text/plain": [ " mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk \\\n", "α[0] -1.135 3.746 -7.921 6.067 0.048 0.039 6080.0 \n", "α[1] 5.830 3.251 -0.181 11.987 0.044 0.031 5382.0 \n", "α[2] -4.846 3.421 -11.267 1.436 0.046 0.035 5554.0 \n", "β[0, 0] -2.604 4.099 -10.736 4.568 0.047 0.041 7711.0 \n", "β[0, 1] 1.985 3.267 -4.214 7.966 0.044 0.034 5404.0 \n", "β[0, 2] 0.640 3.265 -5.464 6.770 0.045 0.036 5341.0 \n", "β[1, 0] 3.186 3.414 -3.514 9.272 0.048 0.034 5167.0 \n", "β[1, 1] -1.009 3.037 -7.135 4.367 0.045 0.032 4579.0 \n", "β[1, 2] -2.408 3.057 -8.328 3.168 0.045 0.032 4636.0 \n", "β[2, 0] -6.333 4.256 -14.192 1.548 0.047 0.036 8195.0 \n", "β[2, 1] -1.395 3.540 -8.308 5.034 0.042 0.035 7265.0 \n", "β[2, 2] 7.810 3.743 0.967 15.085 0.044 0.032 7166.0 \n", "β[3, 0] -5.792 4.358 -14.132 2.098 0.051 0.038 7316.0 \n", "β[3, 1] -1.078 3.558 -7.501 5.823 0.046 0.037 5998.0 \n", "β[3, 2] 6.744 3.683 -0.055 13.745 0.047 0.034 6223.0 \n", "\n", " ess_tail r_hat \n", "α[0] 5193.0 1.0 \n", "α[1] 5418.0 1.0 \n", "α[2] 5354.0 1.0 \n", "β[0, 0] 5642.0 1.0 \n", "β[0, 1] 4740.0 1.0 \n", "β[0, 2] 5130.0 1.0 \n", "β[1, 0] 5913.0 1.0 \n", "β[1, 1] 5063.0 1.0 \n", "β[1, 2] 5119.0 1.0 \n", "β[2, 0] 6219.0 1.0 \n", "β[2, 1] 6032.0 1.0 \n", "β[2, 2] 5728.0 1.0 \n", "β[3, 0] 5879.0 1.0 \n", "β[3, 1] 5723.0 1.0 \n", "β[3, 2] 5727.0 1.0 " ] }, "execution_count": 78, "metadata": {}, "output_type": "execute_result" } ], "source": [ "az.summary(idata_s).head(15)" ] }, { "cell_type": "markdown", "id": "38c2e3d4", "metadata": {}, "source": [ "
\n", "   Notes
\n", "

\n", " (i) 98% is the accuracy on the same data used to fit the model (in-sample accuracy); \n", " a true test of the model's performance is to check it on data not used to fit the model (e.g., a held-out test set).\n", "

\n", "

\n", " (ii) You can check that we obtained a wide posterior. This is because softmax normalizes the probabilities so that they sum to 1: adding the same constant to the linear predictor of every species leaves the softmax output unchanged, so the parameters are not uniquely determined by the data. Therefore, although we placed priors on the parameters of all 3 species, in reality we can \"eliminate\" one species from the problem, since its probability can be calculated from those of the other 2 (again, they have to sum up to 1!)\n", "

\n", "

\n", " (iii) Below is a suggested solution that fixes the redundant parameters (those of one species) to a constant value, e.g., zero, so that only the parameters of the remaining 2 species are sampled \n", "

\n", " \n", " \n", " \n", "\n", "
" ] }, { "cell_type": "code", "execution_count": 89, "id": "3573ce58", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/cfanelli/Desktop/teaching/BRDS/jupynb_env_new/lib/python3.9/site-packages/deprecat/classic.py:215: FutureWarning: In v4.0, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.\n", " return wrapped_(*args_, **kwargs_)\n", "Auto-assigning NUTS sampler...\n", "Initializing NUTS using jitter+adapt_diag...\n", "Multiprocess sampling (4 chains in 4 jobs)\n", "NUTS: [β, α]\n" ] }, { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n", " \n", " 100.00% [12000/12000 00:10<00:00 Sampling 4 chains, 0 divergences]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/cfanelli/Desktop/teaching/BRDS/jupynb_env_new/lib/python3.9/site-packages/scipy/stats/_continuous_distns.py:624: RuntimeWarning: overflow encountered in _beta_ppf\n", " return _boost._beta_ppf(q, a, b)\n", "/Users/cfanelli/Desktop/teaching/BRDS/jupynb_env_new/lib/python3.9/site-packages/scipy/stats/_continuous_distns.py:624: RuntimeWarning: overflow encountered in _beta_ppf\n", " return _boost._beta_ppf(q, a, b)\n", "/Users/cfanelli/Desktop/teaching/BRDS/jupynb_env_new/lib/python3.9/site-packages/scipy/stats/_continuous_distns.py:624: RuntimeWarning: overflow encountered in _beta_ppf\n", " return _boost._beta_ppf(q, a, b)\n", "/Users/cfanelli/Desktop/teaching/BRDS/jupynb_env_new/lib/python3.9/site-packages/scipy/stats/_continuous_distns.py:624: RuntimeWarning: overflow encountered in _beta_ppf\n", " return _boost._beta_ppf(q, a, b)\n", "Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 24 seconds.\n" ] } ], "source": [ "with pm.Model() as model_sf:\n", " α = pm.Normal('α', mu=0, sd=2, shape=2)\n", " β = pm.Normal('β', mu=0, sd=2, shape=(4,2))\n", " α_f = tt.concatenate([[0] ,α])\n", " β_f = tt.concatenate([np.zeros((4,1)) , β], axis=1)\n", " μ = pm.Deterministic('μ', α_f + pm.math.dot(x_s, β_f))\n", " θ = tt.nnet.softmax(μ)\n", " yl = pm.Categorical('yl', p=θ, observed=y_s)\n", " trace_sf = pm.sample(2000, target_accept=.92)\n", " idata_sf = az.from_pymc3(trace_sf)" ] }, { "cell_type": "code", "execution_count": 90, "id": "a7df5514", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "accuracy is: 0.973\n" ] } ], "source": [ "data_pred_sf = trace_sf['μ'].mean(axis=0)\n", "\n", "y_pred_sf = softmax(data_pred_sf, axis=1)\n", "res_sf = np.sum(y_s == np.argmax(y_pred_sf, axis=1)) / len(y_s)\n", "print(\"accuracy is: {:1.3f}\".format(res_sf))\n" ] }, 
{ "cell_type": "code", "execution_count": 95, "id": "fdaf3b4c", "metadata": {}, "outputs": [], "source": [ "#az.summary(idata_sf) --- it will complain as one value of mu is 0 by construction " ] }, { "cell_type": "code", "execution_count": null, "id": "adf11cde", "metadata": {}, "outputs": [], "source": [ "cmpd_df = az.compare({'model_s':idata_s, 'model_sf': idata_sf}, method='BB-pseudo-BMA', ic='waic')" ] }, { "cell_type": "markdown", "id": "9852437a", "metadata": {}, "source": [ "## Final remarks: Robust Logistic Regression (extra, for the curious...)" ] }, { "cell_type": "markdown", "id": "2a5eca01", "metadata": {}, "source": [ "Let's restrict the dataset to the species setosa and versicolor only. \n", "Let's complicate the problem by assuming the presence of unusual zeros and/or ones in our dataset, i.e., labels that look inconsistent with their sepal length (outliers)." ] }, { "cell_type": "code", "execution_count": 97, "id": "46f49670", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0.5, 0, 'sepal_length')" ] }, "execution_count": 97, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "iris = sns.load_dataset(\"iris\") \n", "df = iris.query(\"species == ('setosa', 'versicolor')\") \n", "y_0 = pd.Categorical(df['species']).codes \n", "x_n = 'sepal_length' \n", "x_0 = df[x_n].values \n", "y_0 = np.concatenate((y_0, np.ones(6, dtype=int))) \n", "x_0 = np.concatenate((x_0, [4.2, 4.5, 4.0, 4.3, 4.2, 4.4])) \n", "x_c = x_0 - x_0.mean() \n", "plt.plot(x_c, y_0, 'o', color='k');\n", "plt.xlabel(x_n)" ] }, { "cell_type": "markdown", "id": "9e439e03", "metadata": {}, "source": [ "We now have some versicolor (category 1) with unusually short sepal lengths (the 6 points we added)... \n", "\n", "We can handle this with a **mixture model**. We say that the output variable comes, with probability $\\pi$, from random guessing (a 0.5 chance of being category 1), and, with probability $1-\\pi$, from a logistic regression model:\n", "\n", "$p = \\pi \\cdot 0.5 + (1-\\pi) \\cdot \\mathrm{logistic}(\\alpha+X\\beta)$\n", "\n", "Notice that when $\\pi=1$, we get $p=0.5$ (random guess), whereas when $\\pi=0$ we get the logistic regression. \n", "\n", "This model can be implemented with a slight modification of what we saw in mod2_part2. \n", "\n", "N.B. $\\pi$ is a new variable in our model, to which we assign a uniform prior, Beta(1, 1)" ] }, { "cell_type": "code", "execution_count": 101, "id": "daec8bbe", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Auto-assigning NUTS sampler...\n", "Initializing NUTS using jitter+adapt_diag...\n", "Multiprocess sampling (4 chains in 4 jobs)\n", "NUTS: [π, β, α]\n" ] }, { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n", " \n", " 100.00% [16000/16000 00:02<00:00 Sampling 4 chains, 1 divergences]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/cfanelli/Desktop/teaching/BRDS/jupynb_env_new/lib/python3.9/site-packages/scipy/stats/_continuous_distns.py:624: RuntimeWarning: overflow encountered in _beta_ppf\n", " return _boost._beta_ppf(q, a, b)\n", "/Users/cfanelli/Desktop/teaching/BRDS/jupynb_env_new/lib/python3.9/site-packages/scipy/stats/_continuous_distns.py:624: RuntimeWarning: overflow encountered in _beta_ppf\n", " return _boost._beta_ppf(q, a, b)\n", "/Users/cfanelli/Desktop/teaching/BRDS/jupynb_env_new/lib/python3.9/site-packages/scipy/stats/_continuous_distns.py:624: RuntimeWarning: overflow encountered in _beta_ppf\n", " return _boost._beta_ppf(q, a, b)\n", "/Users/cfanelli/Desktop/teaching/BRDS/jupynb_env_new/lib/python3.9/site-packages/scipy/stats/_continuous_distns.py:624: RuntimeWarning: overflow encountered in _beta_ppf\n", " return _boost._beta_ppf(q, a, b)\n", "Sampling 4 chains for 2_000 tune and 2_000 draw iterations (8_000 + 8_000 draws total) took 8 seconds.\n", "There was 1 divergence after tuning. Increase `target_accept` or reparameterize.\n" ] } ], "source": [ "with pm.Model() as model_rlg:\n", " α = pm.Normal('α', mu=0, sd=10)\n", " β = pm.Normal('β', mu=0, sd=10)\n", " \n", " μ = α + x_c * β \n", " θ = pm.Deterministic('θ', pm.math.sigmoid(μ))\n", " bd = pm.Deterministic('bd', -α/β)\n", " \n", " π = pm.Beta('π', 1., 1.) 
\n", " p = π * 0.5 + (1 - π) * θ \n", " \n", " yl = pm.Bernoulli('yl', p=p, observed=y_0)\n", "\n", " trace_rlg = pm.sample(2000, target_accept=0.95, tune = 2000, return_inferencedata=True)" ] }, { "cell_type": "code", "execution_count": 102, "id": "7b37f4b3", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "([,\n", " ,\n", " ,\n", " ,\n", " ,\n", " ,\n", " ,\n", " ,\n", " ],\n", " [Text(-2.0, 0, '3.4'),\n", " Text(-1.5, 0, '3.9'),\n", " Text(-1.0, 0, '4.4'),\n", " Text(-0.5, 0, '4.9'),\n", " Text(0.0, 0, '5.4'),\n", " Text(0.5, 0, '5.9'),\n", " Text(1.0, 0, '6.4'),\n", " Text(1.5, 0, '6.9'),\n", " Text(2.0, 0, '7.4')])" ] }, "execution_count": 102, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "theta = trace_rlg.posterior['θ'].mean(axis=0).mean(axis=0)\n", "idx = np.argsort(x_c)\n", "\n", "np.random.seed(123)\n", "\n", "\n", "plt.plot(x_c[idx], theta[idx], color='C2', lw=3)\n", "\n", "plt.vlines(trace_rlg.posterior['bd'].mean(), 0, 1, color='k')\n", "\n", "bd_hpd = az.hdi(trace_rlg.posterior['bd'])\n", "\n", "\n", "plt.fill_betweenx([0, 1], bd_hpd.bd[0].values, bd_hpd.bd[1].values, color='k', alpha=0.5)\n", "\n", "\n", "plt.scatter(x_c, np.random.normal(y_0, 0.02),\n", " marker='.', color=[f'C{x}' for x in y_0])\n", "\n", "\n", "az.plot_hdi(x_c, trace_rlg.posterior['θ'], color='C2') #green band \n", "\n", "\n", "plt.xlabel(x_n)\n", "plt.ylabel('θ', rotation=0)\n", "# use original scale for xticks\n", "locs, _ = plt.xticks()\n", "plt.xticks(locs, np.round(locs + x_0.mean(), 1))" ] }, { "cell_type": "code", "execution_count": null, "id": "d28aaf11", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "jupynb_env_new", "language": "python", "name": "jupynb_env_new" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.6" } }, "nbformat": 4, "nbformat_minor": 5 }