{ "cells": [ { "cell_type": "markdown", "id": "e9b462e4", "metadata": {}, "source": [ "# Generalization to Bayesian Softmax Regression" ] }, { "cell_type": "markdown", "id": "8510f432", "metadata": {}, "source": [ "Ref: Chap 4 of Mar18\n", "\n", "https://cfteach.github.io/brds/referencesmd.html" ] }, { "cell_type": "code", "execution_count": 2, "id": "5015e78e", "metadata": {}, "outputs": [], "source": [ "import pymc3 as pm\n", "import numpy as np\n", "import pandas as pd\n", "import theano.tensor as tt\n", "import seaborn as sns\n", "import scipy.stats as stats\n", "from scipy.special import expit as logistic\n", "import matplotlib.pyplot as plt\n", "import arviz as az\n", "import requests\n", "import io " ] }, { "cell_type": "code", "execution_count": 3, "id": "6c2e3cbf", "metadata": {}, "outputs": [], "source": [ "az.style.use('arviz-darkgrid')" ] }, { "cell_type": "code", "execution_count": 4, "id": "224771aa", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
|   | sepal_length | sepal_width | petal_length | petal_width | species |
|---|---|---|---|---|---|
| 0 | 5.1 | 3.5 | 1.4 | 0.2 | setosa |
| 1 | 4.9 | 3.0 | 1.4 | 0.2 | setosa |
| 2 | 4.7 | 3.2 | 1.3 | 0.2 | setosa |
| 3 | 4.6 | 3.1 | 1.5 | 0.2 | setosa |
| 4 | 5.0 | 3.6 | 1.4 | 0.2 | setosa |

(i) Correlated data typically has less power to constrain the model: correlated variables translate into a wider range of coefficient combinations that can explain the data.

(ii) One solution when dealing with highly correlated variables is to remove one or more of them.

(iii) Another option is to scale all non-binary variables to have a mean of 0, and then use a Student-t prior on the coefficients:

$$\beta \sim \text{StudentT}(0, \nu, sd)$$

$sd$ should be chosen to weakly inform us about the expected value for the scale. The normality parameter $\nu$ is typically chosen to be in the range (3, 7). This prior says that in general we expect the coefficients to be small, but we use wide tails because occasionally we will find some larger coefficients.

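To see what "wide tails" means in practice, the following minimal sketch compares the density of a Student-t with that of a standard normal at a few points. It uses only the standard library; the function names `student_t_pdf` and `normal_pdf` and the choice $\nu = 4$, $sd = 1$ are illustrative assumptions, not part of the original notebook.

```python
import math

def student_t_pdf(x, nu, mu=0.0, sd=1.0):
    """Density of a Student-t with nu degrees of freedom, location mu and scale sd."""
    z = (x - mu) / sd
    log_norm = (math.lgamma((nu + 1) / 2) - math.lgamma(nu / 2)
                - 0.5 * math.log(nu * math.pi) - math.log(sd))
    return math.exp(log_norm - (nu + 1) / 2 * math.log1p(z * z / nu))

def normal_pdf(x, mu=0.0, sd=1.0):
    """Density of a normal distribution with mean mu and standard deviation sd."""
    z = (x - mu) / sd
    return math.exp(-0.5 * z * z) / (sd * math.sqrt(2 * math.pi))

# Compare the two densities near the centre and far out in the tail.
for x in (0.0, 2.0, 5.0):
    print(f"x={x}: student-t={student_t_pdf(x, nu=4):.6f}  normal={normal_pdf(x):.6f}")
```

Near zero the two densities are similar, while far from zero the Student-t assigns much more probability than the normal, which is what lets an occasional large coefficient through.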
(i) With an unbalanced dataset, logistic regression can run into trouble: the boundary cannot be determined as accurately as when the dataset is more balanced.

(ii) The decision boundary is "shifted" towards the less abundant class, and the uncertainty band is larger.

(iii) It is always good to have a balanced dataset. If you do have unbalanced data, be careful when interpreting results: check the uncertainty of the model, and run some posterior predictive checks for consistency. Another option is to include more prior information, if available, and/or to run an alternative model.

In order to generalize to multiple classes, two modifications are needed:

(i) We use a softmax (see also the Boltzmann distribution in physics), which is defined as:

$$\text{softmax}_i(\mu) = \frac{\exp(\mu_i)}{\sum_k \exp(\mu_k)}$$

(ii) We then replace the Bernoulli distribution with the categorical distribution. Just as the Bernoulli (a single coin flip) is a special case of the Binomial (n coin flips), the categorical (a single roll of a die) is a special case of the multinomial distribution (n rolls of a die).

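The two modifications can be illustrated with a minimal, standard-library sketch: a softmax that turns a vector of scores into class probabilities, and a single categorical draw from those probabilities. The helper names `softmax` and `sample_categorical` and the example scores are assumptions made for illustration; the notebook itself relies on PyMC3 for both steps.

```python
import math
import random

def softmax(mu):
    """Map a list of real-valued scores to probabilities that sum to 1."""
    m = max(mu)  # subtracting the maximum avoids overflow and leaves the result unchanged
    exps = [math.exp(v - m) for v in mu]
    total = sum(exps)
    return [e / total for e in exps]

def sample_categorical(probs, rng):
    """Draw one class index according to probs (a single roll of a loaded die)."""
    u = rng.random()
    cumulative = 0.0
    for k, p in enumerate(probs):
        cumulative += p
        if u < cumulative:
            return k
    return len(probs) - 1  # guard against floating-point round-off

probs = softmax([2.0, 1.0, 0.1])          # illustrative scores for three classes
rng = random.Random(0)                    # fixed seed for reproducibility
draws = [sample_categorical(probs, rng) for _ in range(1000)]
counts = [draws.count(k) for k in range(len(probs))]
print(probs, counts)
```

With two classes and scores `[x, 0]`, the first softmax output equals the logistic function of `x`, which is how softmax regression reduces to logistic regression.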
|   | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
| α[0] | -1.135 | 3.746 | -7.921 | 6.067 | 0.048 | 0.039 | 6080.0 | 5193.0 | 1.0 |
| α[1] | 5.830 | 3.251 | -0.181 | 11.987 | 0.044 | 0.031 | 5382.0 | 5418.0 | 1.0 |
| α[2] | -4.846 | 3.421 | -11.267 | 1.436 | 0.046 | 0.035 | 5554.0 | 5354.0 | 1.0 |
| β[0, 0] | -2.604 | 4.099 | -10.736 | 4.568 | 0.047 | 0.041 | 7711.0 | 5642.0 | 1.0 |
| β[0, 1] | 1.985 | 3.267 | -4.214 | 7.966 | 0.044 | 0.034 | 5404.0 | 4740.0 | 1.0 |
| β[0, 2] | 0.640 | 3.265 | -5.464 | 6.770 | 0.045 | 0.036 | 5341.0 | 5130.0 | 1.0 |
| β[1, 0] | 3.186 | 3.414 | -3.514 | 9.272 | 0.048 | 0.034 | 5167.0 | 5913.0 | 1.0 |
| β[1, 1] | -1.009 | 3.037 | -7.135 | 4.367 | 0.045 | 0.032 | 4579.0 | 5063.0 | 1.0 |
| β[1, 2] | -2.408 | 3.057 | -8.328 | 3.168 | 0.045 | 0.032 | 4636.0 | 5119.0 | 1.0 |
| β[2, 0] | -6.333 | 4.256 | -14.192 | 1.548 | 0.047 | 0.036 | 8195.0 | 6219.0 | 1.0 |
| β[2, 1] | -1.395 | 3.540 | -8.308 | 5.034 | 0.042 | 0.035 | 7265.0 | 6032.0 | 1.0 |
| β[2, 2] | 7.810 | 3.743 | 0.967 | 15.085 | 0.044 | 0.032 | 7166.0 | 5728.0 | 1.0 |
| β[3, 0] | -5.792 | 4.358 | -14.132 | 2.098 | 0.051 | 0.038 | 7316.0 | 5879.0 | 1.0 |
| β[3, 1] | -1.078 | 3.558 | -7.501 | 5.823 | 0.046 | 0.037 | 5998.0 | 5723.0 | 1.0 |
| β[3, 2] | 6.744 | 3.683 | -0.055 | 13.745 | 0.047 | 0.034 | 6223.0 | 5727.0 | 1.0 |

(i) 98% is the accuracy on the data used to fit the model; a true test of the model's performance is to check it on data not used for fitting.

(ii) You can check that we obtained a wide posterior. This is a consequence of the fact that softmax normalizes the probabilities to 1. We placed priors on the parameters of all 3 species, but in reality we can "eliminate" one species from the problem: its probability can be calculated from the other 2 once we know theirs (again, they have to sum to 1!).

(iii) Below is a suggested solution that fixes the extra parameters to some value, e.g., zero.

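The redundancy described in (ii) can be checked numerically with a minimal sketch: adding the same constant to every class score leaves the softmax probabilities unchanged, so one class can be anchored at zero without losing anything. The `softmax` helper and the example scores below are illustrative assumptions, not code from the original notebook.

```python
import math

def softmax(mu):
    """Map a list of real-valued scores to probabilities that sum to 1."""
    m = max(mu)  # subtracting the maximum avoids overflow and leaves the result unchanged
    exps = [math.exp(v - m) for v in mu]
    total = sum(exps)
    return [e / total for e in exps]

mu = [1.5, -0.3, 0.8]                     # illustrative scores for three classes
shifted = [v + 10.0 for v in mu]          # same scores plus an arbitrary constant
anchored = [v - mu[0] for v in mu]        # first class fixed at 0, others measured relative to it

p = softmax(mu)
p_shifted = softmax(shifted)
p_anchored = softmax(anchored)

# All three parameterizations give the same class probabilities.
print(p)
print(p_shifted)
print(p_anchored)
```

Because infinitely many score vectors produce the same probabilities, the data cannot pin down all of them at once, which shows up as the wide posterior; fixing one class's parameters removes that freedom.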