Gaussian Processes
Gaussian Processes
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
import pymc as pm
Covariance Functions and Kernels
def exp_quad_kernel(x, knots, l=1):
    """Evaluate the exponentiated-quadratic (squared-exponential) kernel.

    Computes ``exp(-(x - k)**2 / (2 * l**2))`` for every knot ``k``.

    Parameters
    ----------
    x : array_like
        Input location(s): a scalar, a list, or a NumPy array.
    knots : array_like
        1-D sequence of kernel centres.
    l : float, default 1
        Length-scale; larger values make the kernel decay more slowly with
        distance.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(knots),) + np.shape(x)``. Row ``i`` holds the
        kernel between ``knots[i]`` and every entry of ``x``. When ``x`` and
        ``knots`` are the same points, this is the symmetric covariance
        matrix.
    """
    # One vectorized outer difference covers every knot/input pair (the sign
    # is irrelevant because the difference is squared). Converting with
    # asarray also lets plain Python lists be passed for ``x``.
    sq_dist = np.subtract.outer(np.asarray(knots), np.asarray(x)) ** 2
    return np.exp(-sq_dist / (2 * l**2))
# Four evenly spaced 1-D inputs (-1, 0, 1, 2); evaluating the kernel of the
# points against themselves yields their 4x4 covariance matrix.
# For a smaller two-point example, use np.array([-1, 0]) instead.
data = np.arange(-1, 3)
cov = exp_quad_kernel(data, data, l=1)
[[1. 0.60653066 0.13533528 0.011109 ]
[0.60653066 1. 0.60653066 0.13533528]
[0.13533528 0.60653066 1. 0.60653066]
[0.011109 0.13533528 0.60653066 1. ]]
# Left panel: the input locations, each labelled with its index in `data`.
# Right panel: the covariance matrix `cov` as a heatmap, annotated with its values.
_, ax = plt.subplots(1, 2, figsize=(12, 5))
ax[0].plot(data, np.zeros_like(data), 'ko')
for idx, i in enumerate(data):
    ax[0].text(i, 0 + 0.005, idx)

im = ax[1].imshow(cov)
# Annotation colour per cell, indexed by whether the cell's normalised value is
# above 0.5: white text on the low end of the colormap and black text on the high
# end. This assumes the default (viridis-like) colormap, where low is dark and high
# is bright. `colors` was previously undefined here, which raised a NameError.
colors = ('white', 'black')
for i, j in np.ndindex(cov.shape):
    # imshow draws element [row, col] at x=col, y=row, so the label for
    # cov[i, j] belongs at (j, i), not (i, j).
    ax[1].text(j, i, round(cov[i, j], 2),
               color=colors[int(im.norm(cov[i, j]) > 0.5)],
               ha='center', va='center', fontdict={'size': 16})
# Three random functions from a zero-mean Gaussian prior whose covariance is the
# exponentiated-quadratic kernel, one panel per length-scale.
test_points = np.linspace(0, 10, 200)
fig, ax = plt.subplots(2, 2, figsize=(12, 6), sharex=True, sharey=True, constrained_layout=True)
ax = np.ravel(ax)
for idx, l in enumerate((0.2, 1, 2, 10)):
    cov = exp_quad_kernel(test_points, test_points, l)
    # rvs(size=3) returns shape (3, 200); after .T each column is one sampled function.
    ax[idx].plot(test_points, stats.multivariate_normal.rvs(cov=cov, size=3).T)
    # Label each panel; otherwise the four length-scales are indistinguishable.
    ax[idx].set_title(f'l = {l}')
fig.text(0.51, -0.03, 'x', fontsize=16)
fig.text(-0.03, 0.5, 'f(x)', fontsize=16, rotation=90)
Gaussian Processes: Implementation
# Synthetic 1-D regression data: 15 noisy observations of
# f(x) = scale * x * sin(x) at random locations in [0, 10].
rng = np.random.default_rng(42)  # fixed seed so re-running the notebook gives the same data
x = rng.uniform(0, 10, size=15)
scale = 0.50
y = rng.normal(scale * x * np.sin(x), 0.1)  # Gaussian observation noise with sd 0.1
true_x = np.linspace(0, 10, 200)
true_y = scale * true_x * np.sin(true_x)
plt.plot(true_x, true_y, 'k--', label='true f(x)')
# Show the generated observations alongside the true curve.
plt.plot(x, y, 'ko', label='observations')
plt.legend()
# Column vector of shape (15, 1): the GP calls below take X as a 2-D (n_points, n_dims) array.
X = x[:, None]
with pm.Model() as model_reg:
    # Hyperprior for the kernel length-scale.
    l = pm.Gamma('l', 2, 0.5)
    # Exponentiated-quadratic covariance over the single input dimension.
    # (This line and the GP line below were syntax errors in this copy.)
    cov = pm.gp.cov.ExpQuad(1, ls=l)
    # No mean function is passed, so the GP uses PyMC's default zero mean.
    # GP prior with the latent function marginalized out.
    gp = pm.gp.Marginal(cov_func=cov)
    # Standard deviation of the Gaussian observation noise.
    eps = pm.HalfNormal('eps', 25)
    y_pred = gp.marginal_likelihood('y_pred', X=X, y=y, noise=eps)
    trace_reg = pm.sample(2000, return_inferencedata=True, target_accept=0.95)
import arviz as az
array([[<Axes: title={'center': 'l'}>, <Axes: title={'center': 'l'}>],
[<Axes: title={'center': 'eps'}>, <Axes: title={'center': 'eps'}>]],
# Prediction grid: 100 evenly spaced inputs spanning the rounded range of the
# training x, as a (100, 1) column.
X_new = np.linspace(np.floor(x.min()), np.ceil(x.max()), num=100).reshape(-1, 1)
with model_reg:
    # Register the GP conditional at X_new as the model variable 'f_pred'.
    # NOTE(review): PyMC rejects duplicate variable names, so re-running this cell
    # on the same model presumably fails; rebuild model_reg first.
    f_pred = gp.conditional('f_pred', X_new)
    # Draw f_pred for each posterior sample in trace_reg.
    pred_samples = pm.sample_posterior_predictive(trace_reg, var_names=['f_pred'])
<xarray.Dataset> Dimensions: (chain: 2, draw: 2000, f_pred_dim_2: 100) Coordinates: * chain (chain) int64 0 1 * draw (draw) int64 0 1 2 3 4 5 6 ... 1994 1995 1996 1997 1998 1999 * f_pred_dim_2 (f_pred_dim_2) int64 0 1 2 3 4 5 6 7 ... 93 94 95 96 97 98 99 Data variables: f_pred (chain, draw, f_pred_dim_2) float64 0.0928 0.04929 ... -2.191 Attributes: created_at: 2024-03-26T01:56:59.862834 arviz_version: 0.15.1 inference_library: pymc inference_library_version: 5.10.4
<xarray.Dataset> Dimensions: (y_pred_dim_0: 15) Coordinates: * y_pred_dim_0 (y_pred_dim_0) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 Data variables: y_pred (y_pred_dim_0) float64 -1.12 -0.4835 2.889 ... 0.9431 0.8937 Attributes: created_at: 2024-03-26T01:56:59.864786 arviz_version: 0.15.1 inference_library: pymc inference_library_version: 5.10.4
(100, 2000)
(100, 1)
# Posterior-predictive draws of the latent function at X_new, one line per draw,
# with the observations overlaid.
# Pool chains and draws into one 'sample' dimension. Averaging over axis 0 (the
# chain axis) would average independent draws and understate the spread.
f_pred_draws = (pred_samples.posterior_predictive['f_pred']
                .stack(sample=('chain', 'draw'))
                .transpose(..., 'sample'))
_, ax = plt.subplots(figsize=(12, 5))
ax.plot(X_new, f_pred_draws.values, 'C1-', alpha=0.3)
ax.plot(X, y, 'ko')
_, ax = plt.subplots(figsize=(12, 5))
# Shaded percentile bands of the posterior predictive (this line was a syntax
# error in this copy). plot_gp_dist takes samples with one row per draw and one
# column per input point, so pool chains and draws first; averaging over the
# chain axis would shrink the bands.
f_pred_draws = (pred_samples.posterior_predictive['f_pred']
                .stack(sample=('chain', 'draw'))
                .transpose('sample', ...))
pm.gp.util.plot_gp_dist(ax, f_pred_draws.values, X_new, palette='viridis', plot_samples=False)
# Plug-in prediction: evaluate the GP predictive at the posterior-mean
# hyperparameters, then shade +/-1 and +/-2 standard deviations around the mean.
_, ax = plt.subplots(figsize=(12,5))
point = {name: trace_reg.posterior[name].mean() for name in ('l', 'eps')}
mu, var = gp.predict(X_new, point=point, diag=True, model = model_reg)
sd = var ** 0.5
for n_sd in (1, 2):
    ax.fill_between(X_new.flatten(), mu - n_sd * sd, mu + n_sd * sd, color="C1", alpha=0.3)
Real-World Example: Spawning Salmon
The plot below shows the relationship between the number of spawning salmon in a particular stream and the number of fry recruited into the population in the spring.
Biological knowledge suggests this relationship is not linear, and we would like to model it.
import requests
import pandas as pd
import io
# NOTE(review): target_url is empty in this copy of the notebook. Set it to the URL
# of the whitespace-delimited salmon data table before running this cell.
target_url = ''
if not target_url:
    raise ValueError("target_url is empty: set it to the URL of the salmon data file")
# Time out instead of hanging forever on an unresponsive server.
response = requests.get(target_url, timeout=30)
# Fail loudly on an HTTP error instead of parsing an error page as data.
response.raise_for_status()
download = response.content
# Whitespace-separated table whose first column becomes the index. The raw string
# avoids the invalid '\s' escape-sequence warning.
salmon_data = pd.read_table(io.StringIO(download.decode('utf-8')), sep=r'\s+', index_col=0)
salmon_data.plot.scatter(x='spawners', y='recruits', s=50);
#we have prior knowledge about fish population growth, and we can include a linear mean function as a prior
with pm.Model() as salmon_model:
    # Lengthscale of the exponentiated-quadratic kernel
    ρ = pm.HalfCauchy('ρ', 1)
    # Amplitude: the kernel is scaled by η**2, so η is the GP's marginal standard deviation
    η = pm.HalfCauchy('η', 1)
    # Linear mean function (PyMC's Linear mean has a default intercept of 0).
    # NOTE(review): the right-hand side of this line was missing (a syntax error) in
    # this copy. The slope used here, the average recruits-per-spawner ratio, is a
    # reconstruction; confirm it against the original analysis.
    M = pm.gp.mean.Linear(coeffs=(salmon_data.recruits / salmon_data.spawners).mean())
    K = (η**2) * pm.gp.cov.ExpQuad(1, ρ)

with salmon_model:
    # Standard deviation of the Gaussian observation noise
    σ = pm.HalfCauchy('σ', 1)
    recruit_gp = pm.gp.Marginal(mean_func=M, cov_func=K)
    # Spawner counts are passed as an (n, 1) input column; recruits are the observed outputs.
    recruit_gp.marginal_likelihood('recruits', X=salmon_data.spawners.values.reshape(-1, 1),
                                   y=salmon_data.recruits.values, noise=σ)
with salmon_model:
    # Draw 1000 posterior samples per chain, returned as ArviZ InferenceData.
    salmon_trace = pm.sample(draws=1000, return_inferencedata=True)
# Trace plots for the lengthscale, amplitude and noise scale.
az.plot_trace(salmon_trace, var_names=['ρ', 'η', 'σ'])
# Prediction grid of 100 points from 0 to 500 spawners, as a (100, 1) column.
X_pred = np.linspace(0, 500, 100).reshape(-1, 1)
with salmon_model:
    salmon_pred3 = recruit_gp.conditional('salmon_pred3', X_pred)
    salmon_samples = pm.sample_posterior_predictive(salmon_trace, var_names=['salmon_pred3'])
# Pool chains and draws into one 'sample' dimension so each line is a genuine draw.
# Averaging over the chain axis would blend independent draws and understate the spread.
pred3_draws = (salmon_samples.posterior_predictive['salmon_pred3']
               .stack(sample=('chain', 'draw'))
               .transpose(..., 'sample'))
_, ax = plt.subplots(figsize=(12, 5))
ax.plot(X_pred, pred3_draws.values, 'C1-', alpha=0.3)
ax.plot(salmon_data['spawners'].values, salmon_data['recruits'].values, 'ko')
What happens if the population gets very large, e.g., at 600 or 800 spawners?
# Extend the prediction grid to 800 spawners: 100 points as a (100, 1) column.
X_pred = np.linspace(0, 800, 100).reshape(-1, 1)
with salmon_model:
    salmon_pred4 = recruit_gp.conditional('salmon_pred4', X_pred)
    salmon_samples = pm.sample_posterior_predictive(salmon_trace, var_names=['salmon_pred4'])
# Pool chains and draws into one 'sample' dimension so each line is a genuine draw.
# Averaging over the chain axis would blend independent draws and understate the spread.
pred4_draws = (salmon_samples.posterior_predictive['salmon_pred4']
               .stack(sample=('chain', 'draw'))
               .transpose(..., 'sample'))
_, ax = plt.subplots(figsize=(12, 5))
ax.plot(X_pred, pred4_draws.values, 'C1-', alpha=0.3)
ax.plot(salmon_data['spawners'].values, salmon_data['recruits'].values, 'ko')
_, ax = plt.subplots(figsize=(12, 5))
# Posterior-predictive mean and standard deviation at each prediction point,
# computed over all chains and draws together. The previous code took the std
# across chains only (a couple of values per draw) and averaged it over draws,
# which badly underestimates the spread.
pred4 = salmon_samples.posterior_predictive['salmon_pred4']
mu = pred4.mean(dim=('chain', 'draw')).values
sd = pred4.std(dim=('chain', 'draw')).values
# Inner band: +/-1 sd; outer band: +/-3 sd.
ax.fill_between(X_pred.flatten(), mu - sd, mu + sd, color="C1", alpha=0.3)
ax.fill_between(X_pred.flatten(), mu - 3 * sd, mu + 3 * sd, color="C1", alpha=0.3)
ax.plot(salmon_data['spawners'].values, salmon_data['recruits'].values, 'ko')
