Several classical conditioning experiments can be qualitatively reproduced by the Rescorla-Wagner model (see Sections 9.1 and 9.2 of Dayan & Abbott). To explain these experimental observations, we assume that the objective of an animal is to predict the presence of certain events, such as a food reward. Let us denote the presence or absence of such a reward (also called the unconditioned stimulus, or UCS) by $r = 1$ or $r = 0$, respectively. Other events, such as stimuli (also called conditioned stimuli, or CS), may or may not predict the occurrence of this reward; we denote their presence by $u = 1$ and their absence by $u = 0$. The organism's task is then to predict whether a reward is present, depending on whether the stimulus was present. We denote the animal's prediction by $v$ and write
$$v = wu \qquad \text{(1)}$$
where $w$ is a free parameter that the animal needs to learn.
After every trial of a conditioning experiment, this parameter is learned or updated using the Rescorla-Wagner learning rule:
$$w \rightarrow w + \epsilon \delta u \qquad \text{(2)}$$
where $\delta = r - v$ is the prediction error and $\epsilon > 0$ is the learning rate. For a single stimulus this is the delta rule, i.e. stochastic gradient descent on the squared error $(r - v)^2$, which is why the simulation function below is called `gradient_descent`.
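For concreteness, here is a minimal sketch of a single Rescorla-Wagner trial, written as a standalone function (the name `rescorla_wagner_update` and the default learning rate $\epsilon = 0.1$ are illustrative choices, not part of the model):

# Minimal sketch of one Rescorla-Wagner trial (illustrative helper; epsilon = 0.1 by default)
def rescorla_wagner_update(w, u, r, epsilon=0.1):
    v = w * u                        # prediction, Eq. (1)
    delta = r - v                    # prediction error
    return w + epsilon * delta * u   # update, Eq. (2)

# Example: five acquisition trials (stimulus paired with reward) drive w towards 1
w = 0.
for _ in range(5):
    w = rescorla_wagner_update(w, u=1, r=1)
# w is now 1 - 0.9**5, roughly 0.41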
# /!\ I'm using Python 3
import plotly
import numpy as np
import plotly.plotly as py
import plotly.graph_objs as go
import matplotlib.pyplot as plt
import scipy.io as sio
from scipy import stats
%matplotlib inline
plotly.offline.init_notebook_mode(connected=True)
from IPython.core.display import display, HTML, Markdown
# The polling here is to ensure that plotly.js has already been loaded before
# setting display alignment in order to avoid a race condition.
display(HTML(
'<script>'
'var waitForPlotly = setInterval( function() {'
'if( typeof(window.Plotly) !== "undefined" ){'
'MathJax.Hub.Config({ SVG: { font: "STIX-Web" }, displayAlign: "center" });'
'MathJax.Hub.Queue(["setRenderer", MathJax.Hub, "SVG"]);'
'clearInterval(waitForPlotly);'
'}}, 250 );'
'</script>'
))
# Colorscales
def colorscale_list(cmap, number_colors, return_rgb_only=False):
cm = plt.get_cmap(cmap)
colors = [np.array(cm(i/number_colors)) for i in range(1, number_colors+1)]
rgb_colors_plotly = []
rgb_colors_only = []
for i, c in enumerate(colors):
col = 'rgb{}'.format(tuple(255*c[:-1]))
rgb_colors_only.append(col)
rgb_colors_plotly.append([i/number_colors, col])
rgb_colors_plotly.append([(i+1)/number_colors, col])
return rgb_colors_only if return_rgb_only else rgb_colors_plotly
# Acquisition followed by extinction: the stimulus is present on every trial,
# but the reward is removed after trial `index_change`.
nb_trials = 50
index_change = 25
u = np.ones(nb_trials)
r = np.ones(nb_trials)
r[index_change:] = 0
layout= go.Layout(
title= 'Trials of a conditioning experiment',
hovermode= 'closest',
xaxis= dict(
title= 'Trial number...',
ticklen= 5,
zeroline= False,
gridwidth= 2,
),
yaxis=dict(
title= 'Value',
ticklen= 5,
gridwidth= 2,
),
showlegend= True
)
trace_stimulus = go.Scatter(
x = list(range(1, nb_trials+1)),
y = u,
mode = 'markers',
name = 'Stimulus u'
)
trace_reward = go.Scatter(
x = list(range(1, nb_trials+1)),
y = r,
mode = 'lines',
name = 'Reward r',
line = dict(
dash = 'dash',
width = 3
)
)
plotly.offline.iplot(
go.Figure(
data=[trace_stimulus, trace_reward],
layout=layout)
)
epsilon = 0.1
def gradient_descent(epsilon=epsilon):
    # Rescorla-Wagner rule, Eq. (2): w <- w + epsilon * (r - w*u) * u,
    # applied trial by trial, starting from w = 0.
    w = np.zeros(nb_trials)
    for i in range(nb_trials):
        w_prev = w[i-1] if i > 0 else 0.
        w[i] = w_prev + epsilon*(r[i] - w_prev*u[i])*u[i]
    return w
trace_w = go.Scatter(
x = list(range(1, nb_trials+1)),
y = gradient_descent(),
mode = 'lines+markers',
name = 'Prediction v',
line = dict(
width = 3
)
)
plotly.offline.iplot(
go.Figure(
data=[trace_stimulus, trace_reward, trace_w],
layout=layout)
)
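Before sweeping over $\epsilon$, it is worth working out what the update does in a block of trials where the stimulus is present and the reward is constant ($u = 1$, $r$ fixed). Eq. (2) then reduces to $w \rightarrow w + \epsilon(r - w)$, so the error $r - w$ is multiplied by $(1 - \epsilon)$ on every trial:

$$w_n = r + (1 - \epsilon)^n (w_0 - r)$$

The prediction therefore converges monotonically to $r$ for $0 < \epsilon < 1$, converges through damped oscillations for $1 < \epsilon < 2$, and oscillates without damping or diverges for $\epsilon \geq 2$. The three sweeps below illustrate these regimes.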
display(Markdown("### For different values of `epsilon` $∈ [0.1, 1]$"))
number_lines = 10
legend_every = 1
value_min, value_max = .1, 1.
delta_val = (value_max-value_min)/(number_lines-1)
colors = colorscale_list('Greens', number_lines+3, return_rgb_only=True)
# Plotting the evolution of the free parameter w
traces_eps = []
for i, epsilon in enumerate([round(value_min + i*delta_val,2) for i in range(number_lines)]):
traces_eps.append(
go.Scatter(
x = list(range(1, nb_trials+1)),
y = gradient_descent(epsilon=epsilon)*u,
mode = 'lines',
name = 'Prediction with epsilon = {}'.format(epsilon),
line = dict(
width = 1,
color = colors[i+2],
shape = 'spline',
dash = 'solid'
),
hoverlabel = dict(
namelength = -1
),
showlegend = (i % legend_every == 0)
)
)
plotly.offline.iplot(go.Figure(data=[trace_stimulus, trace_reward]+traces_eps, layout=layout))
display(Markdown("### For different values of `epsilon` $∈ [1, 2[$"))
number_lines = 10
legend_every = 1
value_min, value_max = 1, 1.9
delta_val = (value_max-value_min)/(number_lines-1)
colors = colorscale_list('Greens', number_lines+3, return_rgb_only=True)
# Plotting the evolution of the free parameter w
traces_eps = []
for i, epsilon in enumerate([round(value_min + i*delta_val,2) for i in range(number_lines)]):
traces_eps.append(
go.Scatter(
x = list(range(1, nb_trials+1)),
y = gradient_descent(epsilon=epsilon)*u,
mode = 'lines',
name = 'Prediction with epsilon = {}'.format(epsilon),
line = dict(
width = 1,
color = colors[i+2],
shape = 'spline',
dash = 'solid'
),
hoverlabel = dict(
namelength = -1
),
showlegend = (i % legend_every == 0)
)
)
plotly.offline.iplot(go.Figure(data=[trace_stimulus, trace_reward]+traces_eps, layout=layout))
display(Markdown("### For different values of `epsilon` $≥ 2$"))
number_lines = 5
legend_every = 1
value_min, value_max = 2, 2.05
delta_val = (value_max-value_min)/(number_lines-1)
colors = colorscale_list('Greens', number_lines+3, return_rgb_only=True)
# Plotting the evolution of the free parameter w
traces_eps = []
for i, epsilon in enumerate([round(value_min + i*delta_val,2) for i in range(number_lines)]):
traces_eps.append(
go.Scatter(
x = list(range(1, nb_trials+1)),
y = gradient_descent(epsilon=epsilon)*u,
mode = 'lines',
name = 'Prediction with epsilon = {}'.format(epsilon),
line = dict(
width = 1,
color = colors[i+2],
shape = 'spline',
dash = 'solid'
),
hoverlabel = dict(
namelength = -1
),
showlegend = (i % legend_every == 0)
)
)
layout2 = go.Layout(
title= 'Trials of a conditioning experiment',
hovermode= 'closest',
xaxis= dict(
range=[10, 50],
title= 'Trial number...',
ticklen= 5,
zeroline= False,
gridwidth= 2,
),
yaxis=dict(
range=[-10, 10],
title= 'Value',
ticklen= 5,
gridwidth= 2,
),
showlegend= True
)
plotly.offline.iplot(go.Figure(data=[trace_stimulus, trace_reward]+traces_eps, layout=layout2))
# Partial reinforcement: the stimulus is still present on every trial,
# but the reward is now delivered at random with probability `proba`.
proba = 0.4
np.random.seed(1)
r = (np.random.rand(nb_trials) < proba).astype(int)
trace_reward = go.Scatter(
x = list(range(1, nb_trials+1)),
y = r,
mode = 'lines',
name = 'Random reward r with proba p={}'.format(proba),
line = dict(
dash = 'dash'
)
)
plotly.offline.iplot(
go.Figure(
data=[trace_stimulus, trace_reward],
layout=layout)
)
trace_w = go.Scatter(
x = list(range(1, nb_trials+1)),
y = gradient_descent()*u,
mode = 'lines+markers',
name = 'Prediction with epsilon=0.1',
line = dict(
width = 3
)
)
plotly.offline.iplot(
go.Figure(
data=[trace_stimulus, trace_reward, trace_w],
layout=layout)
)
display(Markdown("### For different values of `epsilon` $∈ [0.1, 1]$"))
number_lines = 10
legend_every = 1
value_min, value_max = .1, 1
delta_val = (value_max-value_min)/(number_lines-1)
colors = colorscale_list('Greens', number_lines+3, return_rgb_only=True)
# Plotting the evolution of the free parameter w
traces_eps = []
for i, epsilon in enumerate([round(value_min + i*delta_val, 2) for i in range(number_lines)]):
traces_eps.append(
go.Scatter(
x = list(range(1, nb_trials+1)),
y = gradient_descent(epsilon=epsilon)*u,
mode = 'lines',
name = 'Prediction with epsilon = {}'.format(epsilon),
line = dict(
width = 1,
color = colors[i+2],
shape = 'spline',
dash = 'solid'
),
hoverlabel = dict(
namelength = -1
),
showlegend = (i % legend_every == 0)
)
)
plotly.offline.iplot(go.Figure(data=[trace_stimulus, trace_reward]+traces_eps, layout=layout))
display(Markdown("### For different values of `epsilon` $∈ [1, 2[$"))
number_lines = 10
legend_every = 1
value_min, value_max = 1, 1.9
delta_val = (value_max-value_min)/(number_lines-1)
colors = colorscale_list('Greens', number_lines+3, return_rgb_only=True)
# Plotting the evolution of the free parameter w
traces_eps = []
for i, epsilon in enumerate([round(value_min + i*delta_val, 2) for i in range(number_lines)]):
traces_eps.append(
go.Scatter(
x = list(range(1, nb_trials+1)),
y = gradient_descent(epsilon=epsilon)*u,
mode = 'lines',
name = 'Prediction with epsilon = {}'.format(epsilon),
line = dict(
width = 1,
color = colors[i+2],
shape = 'spline',
dash = 'solid'
),
hoverlabel = dict(
namelength = -1
),
showlegend = (i % legend_every == 0)
)
)
plotly.offline.iplot(go.Figure(data=[trace_stimulus, trace_reward]+traces_eps, layout=layout))
display(Markdown("### For different values of `epsilon` $≥2$"))
number_lines = 5
legend_every = 1
value_min, value_max = 2, 2.05
delta_val = (value_max-value_min)/(number_lines-1)
colors = colorscale_list('Greens', number_lines+3, return_rgb_only=True)
# Plotting the evolution of the free parameter w
traces_eps = []
for i, epsilon in enumerate([round(value_min + i*delta_val, 2) for i in range(number_lines)]):
traces_eps.append(
go.Scatter(
x = list(range(1, nb_trials+1)),
y = gradient_descent(epsilon=epsilon)*u,
mode = 'lines',
name = 'Prediction with epsilon = {}'.format(epsilon),
line = dict(
width = 1,
color = colors[i+2],
shape = 'spline',
dash = 'solid'
),
hoverlabel = dict(
namelength = -1
),
showlegend = (i % legend_every == 0)
)
)
layout3 = go.Layout(
title= 'Trials of a conditioning experiment',
hovermode= 'closest',
xaxis= dict(
range=[0, 50],
title= 'Trial number...',
ticklen= 5,
zeroline= False,
gridwidth= 2,
),
yaxis=dict(
range=[-10, 10],
title= 'Value',
ticklen= 5,
gridwidth= 2,
),
showlegend= True
)
plotly.offline.iplot(go.Figure(data=[trace_stimulus, trace_reward]+traces_eps, layout=layout3))
# Back to deterministic conditioning: stimulus and reward are both present on every trial.
u = np.ones(nb_trials)
r = np.ones(nb_trials)
trace_reward = go.Scatter(
x = list(range(1, nb_trials+1)),
y = r,
mode = 'lines',
name = 'Reward r',
line = dict(
dash = 'dash',
width = 3
)
)
trace_w_1D = go.Scatter(
x = list(range(1, nb_trials+1)),
y = gradient_descent()*u,
mode = 'lines',
name = 'Prediction v based on one (the first) stimulus only',
line = dict(
width = 1
)
)
plotly.offline.iplot(
go.Figure(
data=[trace_stimulus, trace_reward, trace_w_1D],
layout=layout)
)
def gradient_descent_2D(epsilons=(0.1, 0.1)):
    # Rescorla-Wagner rule for two stimuli: w <- w + (r - w·u) * eps * u,
    # where w, u and eps are 2-component vectors (one entry per stimulus).
    w = np.zeros((nb_trials, 2))
    eps = np.array(epsilons)
    for i in range(nb_trials):
        w_prev = w[i-1] if i > 0 else np.zeros(2)
        w[i] = w_prev + (r[i] - w_prev.dot(u[i])) * eps * u[i]
    return w
# Blocking paradigm: stimulus 1 and the reward are present on every trial;
# stimulus 2 is only introduced after trial `index_change`.
nb_trials = 50
index_change = 25
u = np.ones((nb_trials, 2))
r = np.ones(nb_trials)
u[:index_change, -1] = 0
trace_stimulus_2D = []
trace_stimulus_2D.append(go.Scatter(
x = list(range(1, nb_trials+1)),
y = u[:,0],
mode = 'markers',
name = 'Stimulus 1',
marker = dict(
size = 9
)
))
trace_stimulus_2D.append(go.Scatter(
x = list(range(1, nb_trials+1)),
y = u[:,1],
mode = 'markers',
name = 'Stimulus 2',
marker = dict(
symbol = 'diamond',
color = 'yellow'
),
))
trace_reward = go.Scatter(
x = list(range(1, nb_trials+1)),
y = r,
mode = 'lines',
name = 'Reward r',
line = dict(
dash = 'dash',
width = 4,
color = 'orange'
)
)
trace_w = go.Scatter(
x = list(range(1, nb_trials+1)),
y = np.array([w.dot(ui) for w, ui in zip(gradient_descent_2D(), u)]),
mode = 'lines',
name = 'Prediction v for the two stimuli',
line = dict(
width = 3,
dash = 'dash',
color = 'green'
)
)
plotly.offline.iplot(
go.Figure(
data=trace_stimulus_2D+[trace_reward, trace_w, trace_w_1D],
layout=layout)
)
# Learned weights in the blocking experiment: the weight of the pre-trained
# stimulus 1 stays close to 1, while the weight of the added stimulus 2 barely grows.
gradient_descent_2D()
# Overshadowing-like setup: both stimuli are now present on every trial,
# with different learning rates for the two stimuli in the call below.
u = np.ones((nb_trials, 2))
trace_stimulus_2D = []
trace_stimulus_2D.append(go.Scatter(
x = list(range(1, nb_trials+1)),
y = u[:,0],
mode = 'markers',
name = 'Stimulus 1',
marker = dict(
size = 9
)
))
trace_stimulus_2D.append(go.Scatter(
x = list(range(1, nb_trials+1)),
y = u[:,1],
mode = 'markers',
name = 'Stimulus 2',
marker = dict(
symbol = 'diamond',
color = 'yellow'
),
))
trace_reward = go.Scatter(
x = list(range(1, nb_trials+1)),
y = r,
mode = 'lines',
name = 'Reward r',
line = dict(
dash = 'dash',
width = 4,
color = 'orange'
)
)
trace_w = go.Scatter(
x = list(range(1, nb_trials+1)),
y = np.array([w.dot(ui) for w, ui in zip(gradient_descent_2D(epsilons=(0.2, 0.1)), u)]),
mode = 'lines',
name = 'Prediction v for the two stimuli',
line = dict(
width = 2,
dash = 'dash',
color = 'green'
)
)
plotly.offline.iplot(
go.Figure(
data=trace_stimulus_2D+[trace_reward, trace_w, trace_w_1D],
layout=layout)
)
# Learned weights when both stimuli are always present: the stimulus with the
# larger learning rate ends up with the larger share of the total prediction.
gradient_descent_2D(epsilons=(0.2, 0.1))