In a two-alternative forced-choice task (2AFC task), subjects are asked to choose between two alternative actions. Here we consider the case where a subject sees a visual motion stimulus (a set of points on a screen moving in different directions) and then has to indicate whether the points were moving upwards or downwards. If the motion stimulus is ambiguous or "noisy", the task can be quite difficult. We assume that the motion stimulus continues until the subject has made a choice. This scenario is well described by the drift-diffusion model, in which the subject compares the firing rate $m_A$ of an upward-motion-sensitive neuron with the firing rate $m_B$ of a downward-motion-sensitive neuron and integrates the difference between the two:
$$\dot{x} = m_A - m_B + σ\, η(t)$$
where $η(t)$ is a noise term (Gaussian white noise with unit standard deviation) that models the noisiness of real neurons. If the integration variable $x$ exceeds the threshold $μ$, the subject decides in favor of outcome $A$; if $x$ falls below the threshold $-μ$, the subject decides in favor of outcome $B$.
Any ordinary differential equation can be solved numerically using the Euler method, i.e. using the approximation:
$$x(t + \Delta t) = x(t) + \dot{x}\, \Delta t \qquad (6)$$
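For stochastic differential equations, i.e. those that have a noise term, the random part scales with the square root of the time step: the variance of the noise accumulated over a step grows linearly with $\Delta t$, so its standard deviation grows as $\sqrt{\Delta t}$. This is a technical point that you can ignore for now. It leads to the following discrete approximation of the drift-diffusion model, in which $η(t)$ is drawn independently from a standard normal distribution at every step: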
$$x(t + \Delta t) = x(t) + (m_A - m_B)\, \Delta t + σ\, η(t) \sqrt{\Delta t} \qquad (7)$$
# /!\ I'm using Python 3
import plotly
import numpy as np
import plotly.plotly as py
import plotly.graph_objs as go
import matplotlib.pyplot as plt
import scipy.io as sio
from scipy import stats
%matplotlib inline
plotly.offline.init_notebook_mode(connected=True)
from IPython.core.display import display, HTML, Markdown
# Inject a small script into the notebook page. It polls every 250 ms until
# window.Plotly is defined, and only then switches MathJax to the SVG renderer
# (STIX-Web font, centered display math) and stops polling. Waiting for
# plotly.js first avoids a race condition between the two libraries.
display(HTML(''.join([
    '<script>',
    'var waitForPlotly = setInterval( function() {',
    'if( typeof(window.Plotly) !== "undefined" ){',
    'MathJax.Hub.Config({ SVG: { font: "STIX-Web" }, displayAlign: "center" });',
    'MathJax.Hub.Queue(["setRenderer", MathJax.Hub, "SVG"]);',
    'clearInterval(waitForPlotly);',
    '}}, 250 );',
    '</script>',
])))
# Colorscales
def colorscale_list(cmap, number_colors, return_rgb_only=False):
    """Sample ``number_colors`` evenly spaced colors from a colormap.

    Parameters
    ----------
    cmap : str or callable
        Name of a matplotlib colormap (looked up with ``plt.get_cmap``), or any
        callable that maps a float in (0, 1] to an RGBA sequence (for example a
        matplotlib ``Colormap`` instance).
    number_colors : int
        Number of colors to sample. Color k (k = 1..number_colors) is taken at
        position k/number_colors, so the very low end of the colormap is skipped.
    return_rgb_only : bool, optional
        If True, return only the list of ``'rgb(r, g, b)'`` strings.

    Returns
    -------
    list
        Either the color strings, or a plotly discrete colorscale: a list of
        ``[position, color]`` pairs in which the k-th color (k from 0) occupies
        the band [k/number_colors, (k+1)/number_colors].
    """
    colormap = cmap if callable(cmap) else plt.get_cmap(cmap)
    rgb_colors_only = []
    rgb_colors_plotly = []
    for i in range(number_colors):
        rgba = colormap((i + 1)/number_colors)
        # Emit integer channels. Formatting a tuple of NumPy floats directly gives
        # 'rgb(np.float64(247.0), ...)' under NumPy >= 2, which plotly cannot parse.
        red, green, blue = (int(round(255*float(channel))) for channel in rgba[:3])
        col = 'rgb({}, {}, {})'.format(red, green, blue)
        rgb_colors_only.append(col)
        # Two stops per color produce hard edges (a discrete colorscale).
        rgb_colors_plotly.append([i/number_colors, col])
        rgb_colors_plotly.append([(i + 1)/number_colors, col])
    return rgb_colors_only if return_rgb_only else rgb_colors_plotly
delta_t = 0.0001
t_max = 1
m_A, m_B = 1, 0.95
sigma = 0.5
x_0 = 0


def drift_diffusion(m_A=m_A, m_B=m_B, sigma=sigma, delta_t=delta_t, t_max=t_max, x_0=x_0):
    """Simulate one trajectory of the drift-diffusion model with the update of eq. (7).

    Parameters
    ----------
    m_A, m_B : float
        Firing rates of the two competing populations; the drift is m_A - m_B.
    sigma : float
        Noise amplitude.
    delta_t : float
        Integration time step.
    t_max : float
        The trajectory has ``len(np.arange(0, t_max, delta_t))`` samples, so it
        lines up with that time axis.
    x_0 : float
        Initial value of the integration variable.

    Returns
    -------
    numpy.ndarray
        ``x[0] == x_0`` and ``x[k]`` is the value after k steps (no thresholds
        are applied here). Empty if the time axis is empty.
    """
    n_steps = len(np.arange(0, t_max, delta_t))
    if n_steps == 0:
        return np.zeros(0)
    drift = delta_t*(m_A - m_B)
    noise = sigma*np.sqrt(delta_t)*np.random.standard_normal(n_steps - 1)
    # x[k] = x_0 + sum of the first k increments; cumsum replaces the per-step Python loop.
    return x_0 + np.concatenate(([0.0], np.cumsum(drift + noise)))
# Layout for the figure showing individual drift-diffusion runs.
layout = go.Layout(
    # Raw strings keep the LaTeX backslashes literal ('\,' in a normal literal is an
    # invalid escape sequence: SyntaxWarning since Python 3.12); the text is unchanged.
    # Only the second literal goes through .format because the first one contains
    # literal braces ('\text{...}').
    title=r'$\text{Runs of drift-diffusion model, for }'
          + r'σ = {}, \, m_A = {}, \, m_B = {}$'.format(sigma, m_A, m_B),
    hovermode='closest',
    xaxis=dict(
        title='Time (in seconds)',
        ticklen=5,
        zeroline=False,
        gridwidth=2,
    ),
    yaxis=dict(
        title=r'$\text{Integration variable } x$',
        ticklen=5,
        gridwidth=2,
    ),
    showlegend=True,
)
legend_every = 1
# Plot 10 independent runs of the drift-diffusion integrator x(t).
# The same time axis is built once, and delta_t/t_max are passed explicitly so
# each simulated trace has exactly as many samples as the axis. This holds even
# if these globals are rebound after drift_diffusion was defined, since its
# defaults are frozen at definition time.
time_axis = np.arange(0, t_max, delta_t)
traces_dd = []
for i in range(10):
    traces_dd.append(
        go.Scatter(
            x=time_axis,
            y=drift_diffusion(delta_t=delta_t, t_max=t_max),
            mode='lines',
            name='Run {}'.format(i+1),
            line=dict(
                width=2,
                dash='solid',
            ),
            hoverlabel=dict(
                namelength=-1,
            ),
            showlegend=(i % legend_every == 0),
        )
    )
plotly.offline.iplot(go.Figure(data=traces_dd, layout=layout))
def reaction_time_slow(m_A=m_A, m_B=m_B, sigma=sigma, delta_t=delta_t, x_0=x_0, mu=0.8, t_nd=0.1):
    """Simulate one drift-diffusion trial step by step until a threshold is crossed.

    Starting from ``x_0``, applies the update of eq. (7) until ``x`` leaves the
    interval [-mu, mu].

    Parameters
    ----------
    m_A, m_B, sigma, delta_t, x_0 :
        Model parameters as in ``drift_diffusion``.
    mu : float
        Decision threshold; outcome A at +mu, outcome B at -mu.
    t_nd : float
        Fixed time (in seconds) added to every reaction time. It defaults to
        the 0.1 s offset the original code used (presumably a non-decision
        time; confirm).

    Returns
    -------
    (float, int)
        Reaction time in seconds (``t_nd + delta_t * number_of_steps``), and the
        outcome: 1 for A (x ended above mu), 0 for B.

    Note: the loop never terminates if x cannot leave [-mu, mu], e.g. with
    sigma == 0 and m_A == m_B.
    """
    drift_step = delta_t*(m_A - m_B)
    noise_scale = sigma*np.sqrt(delta_t)
    x = x_0
    n_steps = 0
    while -mu <= x <= mu:
        n_steps += 1
        x += drift_step + noise_scale*np.random.standard_normal()
    # Reaction time in seconds, and the outcome (1 for A, 0 for B).
    return t_nd + delta_t*n_steps, 1 if x > mu else 0
def reaction_times(m_A=m_A, m_B=m_B, m_E=None, sigma=sigma, delta_t=delta_t, x_0=x_0, mu=0.8,
                   number_trials=100000, n_max=150000):
    """Histogram the first threshold-crossing step of a frozen-noise approximation of the DDM.

    For each trial a single standard-normal sample ``eta`` is drawn, and the
    trajectory is approximated as

        x(n) = x_0 + n*delta_t*m_E + sqrt(n)*sigma*sqrt(delta_t)*eta,

    which has the same distribution at step n as eq. (7). It is not a full
    random walk, because the noise is frozen per trial. With s = sqrt(n),
    "x(n) reaches +mu (or -mu)" becomes a quadratic equation in s, solved with
    ``np.roots``.

    Parameters
    ----------
    m_A, m_B : float
        Used only when ``m_E`` is None, giving m_E = m_A - m_B.
    m_E : float or None
        Drift (evidence strength). An explicit 0 is respected.
    sigma, delta_t, x_0 : float
        Model parameters as in ``drift_diffusion``.
    mu : float
        Decision threshold; outcome A at +mu, outcome B at -mu.
    number_trials : int
        Number of simulated trials.
    n_max : int
        Length of the returned histograms (maximum step index recorded).

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        ``outcome_A[n]`` / ``outcome_B[n]`` count trials whose earliest crossing
        (rounded up to step n) was at +mu / -mu. Each trial is counted at most
        once. Trials that never cross, or cross at a step >= n_max, are dropped.
    """
    if m_E is None:
        m_E = m_A - m_B
    term1, term2 = delta_t*m_E, sigma*np.sqrt(delta_t)
    outcome_A, outcome_B = np.zeros(n_max), np.zeros(n_max)
    for _ in range(number_trials):
        eta = np.random.standard_normal()
        # Earliest s = sqrt(n) >= 0 at which the trajectory reaches +mu / -mu.
        s_A = _earliest_nonnegative_root(term1, term2*eta, x_0 - mu)
        s_B = _earliest_nonnegative_root(term1, term2*eta, x_0 + mu)
        s_first = min(s_A, s_B)
        if np.isinf(s_first):
            continue  # this trial never reaches either threshold
        n = int(np.ceil(s_first**2))
        if n < n_max:
            if s_A <= s_B:
                outcome_A[n] += 1
            else:
                outcome_B[n] += 1
    return outcome_A, outcome_B


def _earliest_nonnegative_root(a, b, c):
    """Return the smallest real root s >= 0 of a*s**2 + b*s + c, or inf if there is none.

    The sign filter must be applied to s itself: a negative s squared would give
    a positive step count that the trajectory never actually reaches.
    """
    roots = np.roots([a, b, c])
    real_roots = roots[roots.imag == 0].real
    non_negative = real_roots[real_roots >= 0]
    return non_negative.min() if non_negative.size else np.inf
def reaction_times2(m_A=m_A, m_B=m_B, m_E=None, sigma=sigma, delta_t=delta_t, x_0=x_0,
                    number_trials=1000, n_max=150000, mu=0.8):
    """Alternative, fully vectorized variant of ``reaction_times`` (``reaction_times`` is preferred).

    Uses the same frozen-noise trajectory per trial,
    x(n) = x_0 + n*delta_t*m_E + sqrt(n)*sigma*sqrt(delta_t)*eta,
    evaluates it at every step n < n_max, and counts the trials above mu (A) and
    below -mu (B) at each step.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        Step-to-step change in the number of trials beyond +mu and beyond -mu
        (first entry 0). NOTE(review): these are net changes, not first-passage
        counts. Entries can be negative when trials move back inside the
        thresholds.

    Memory: builds a number_trials x n_max float matrix (about 1.2 GB with the
    defaults).
    """
    if m_E is None:
        m_E = m_A - m_B
    # Per trial: [drift per step, frozen noise scale]; combined with [n, sqrt(n)].
    coeffs = np.vstack((np.full(number_trials, delta_t*m_E),
                        sigma*np.sqrt(delta_t)*np.random.standard_normal(number_trials))).T
    steps = np.arange(n_max)
    time_steps = np.vstack((steps, np.sqrt(steps)))
    trial_at_fixed_time = coeffs.dot(time_steps) + x_0
    # Number of trials beyond each threshold at every step (vectorized over columns).
    cumul_A = np.count_nonzero(trial_at_fixed_time > mu, axis=0)
    cumul_B = np.count_nonzero(trial_at_fixed_time < -mu, axis=0)
    return np.hstack((np.zeros(1), np.diff(cumul_A))), np.hstack((np.zeros(1), np.diff(cumul_B)))
number_trials = 1000
RT = np.zeros(number_trials)
outcome = np.zeros(number_trials)
for i in range(number_trials):
    RT[i], outcome[i] = reaction_time_slow()

# One overlaid histogram of reaction times per outcome (1 -> A, 0 -> B).
trace1, trace2 = (
    go.Histogram(x=RT[outcome == code], opacity=0.75, name=label)
    for code, label in ((1, 'A'), (0, 'B'))
)
layout_bar = go.Layout(
    barmode='overlay',
    title='Reaction time for outcome A and B',
    xaxis={'title': 'Reaction time (in seconds)'},
    yaxis={'title': 'Number of trials with this reaction time'},
)
fig = go.Figure(data=[trace1, trace2], layout=layout_bar)
plotly.offline.iplot(fig)
number_trials = 10000
RT, outcome = np.zeros(number_trials), np.zeros(number_trials)
for i in range(number_trials):
    RT[i], outcome[i] = reaction_time_slow()

# Same plot as above, with ten times more simulated trials.
trace1 = go.Histogram(opacity=0.75, name='A', x=RT[outcome == 1])
trace2 = go.Histogram(opacity=0.75, name='B', x=RT[outcome == 0])
layout_bar = go.Layout(
    title='Reaction time for outcome A and B',
    barmode='overlay',
    yaxis={'title': 'Number of trials with this reaction time'},
    xaxis={'title': 'Reaction time (in seconds)'},
)
fig = go.Figure(layout=layout_bar, data=[trace1, trace2])
plotly.offline.iplot(fig)
n_max = 150000
t_max = 0.1 + n_max*delta_t
delta_bin = 0.20
A, B = reaction_times(number_trials=100000, n_max=n_max)

# Prepend the fixed 0.1 s offset (the same one included in t_max above), then
# sum the per-step counts over bins of exactly delta_bin seconds. The previous
# np.array_split into int(t_max/delta_bin) chunks made every bin slightly wider
# than delta_bin (about 0.2013 s here), so the bars drifted away from their x labels.
offset_steps = int(round(0.1/delta_t))
steps_per_bin = int(round(delta_bin/delta_t))
bin_starts = np.arange(0, offset_steps + n_max, steps_per_bin)
A = np.add.reduceat(np.hstack((np.zeros(offset_steps), A)), bin_starts)
B = np.add.reduceat(np.hstack((np.zeros(offset_steps), B)), bin_starts)
# Draw each bar at the center of the time interval it covers.
bin_centers = bin_starts*delta_t + delta_bin/2

trace_A = go.Bar(
    x=bin_centers,
    y=A,
    opacity=0.75,
    name='A',
)
trace_B = go.Bar(
    x=bin_centers,
    y=B,
    opacity=0.75,
    name='B',
)
layout_bar = go.Layout(
    barmode='overlay',
    title='Reaction time for outcome A and B',
    xaxis=dict(
        title='Reaction time (in seconds)',
    ),
    yaxis=dict(
        title='Number of trials with this reaction time',
    ),
)
plotly.offline.iplot(
    go.Figure(
        data=[trace_A, trace_B],
        layout=layout_bar)
)
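For an unbiased starting point ($x_0 = 0$), the drift-diffusion model predicts that the probability of choosing outcome $A$ is
$$p_A = \frac{1}{1+ \exp(-β(m_A - m_B))} \qquad (8)$$
where $β = 2μ/σ^2$. The choice probability is therefore a sigmoid function of the evidence $m_A - m_B$. It becomes steeper as the threshold $μ$ increases and flatter as the noise $σ$ increases.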
number_trials = 100
mu = 5
sigma = 0.5
beta = 2*mu/(sigma**2)


def sigmoid(x, beta=beta):
    """Logistic function 1/(1 + exp(-beta*x)), i.e. p_A of eq. (8) with x = m_A - m_B.

    Accepts scalars, lists or arrays and works element-wise with plain NumPy,
    so no ``np.vectorize`` wrapper is needed.
    """
    return 1/(1 + np.exp(-beta*np.asarray(x, dtype=float)))


def prob_outcome_A(m_E, mu=mu, sigma=sigma, number_trials=1500):
    """Estimate the probability of outcome A for evidence strength ``m_E`` by simulation.

    Runs ``reaction_times`` with ``number_trials`` trials and returns the
    fraction of recorded decisions that were A. Returns NaN if no decision was
    recorded at all (instead of a 0/0 warning).
    """
    A, B = reaction_times(m_E=m_E, mu=mu, sigma=sigma, number_trials=number_trials)
    a, b = np.sum(A), np.sum(B)
    return a/(a + b) if a + b > 0 else np.nan


# reaction_times handles one scalar m_E per call, so vectorize over m_E arrays.
prob_outcome_A = np.vectorize(prob_outcome_A)
# Shared layout for the "probability of outcome A versus m_E" figures.
layout_final = go.Layout(
    title=r'$\text{Probability of outcome A with respect to } m_E$',
    hovermode='closest',
    xaxis={
        'title': '$m_E = m_A - m_B$',
        'ticklen': 5,
        'zeroline': False,
        'gridwidth': 2,
    },
    yaxis={
        'title': 'Probability',
        'ticklen': 5,
        'gridwidth': 2,
    },
    showlegend=True,
)
sigma = 0.5
display(Markdown("### For different values of $μ$, and $σ$ = {}".format(sigma)))
legend_every = 1
values = [2, 3, 5]
colors1 = colorscale_list('Greens', len(values)+3, return_rgb_only=True)
colors2 = colorscale_list('Reds', len(values)+3, return_rgb_only=True)
# Evidence strengths m_E = m_A - m_B at which both curves are evaluated.
m_E_grid = np.linspace(-0.2, 0.2, 40)
# For each threshold μ: the simulated probability of A (thick line with markers)
# next to the analytical prediction p_A of eq. (8) (thin line).
# Raw strings keep the LaTeX backslashes literal; '\q' in a normal literal is an
# invalid escape sequence (SyntaxWarning since Python 3.12). The text is unchanged.
traces_mu = []
for i, mu in enumerate(values):
    traces_mu.append(
        go.Scatter(
            x=m_E_grid,
            y=prob_outcome_A(m_E_grid, mu=mu, sigma=sigma),
            mode='lines+markers',
            name=r'$\qquad \quad \text{Proba. of A with } μ =' + '{}$'.format(mu),
            line=dict(
                width=3,
                color=colors1[i+2],
                shape='spline',
                dash='solid',
            ),
            hoverlabel=dict(
                namelength=-1,
            ),
            showlegend=(i % legend_every == 0),
        )
    )
    traces_mu.append(
        go.Scatter(
            x=m_E_grid,
            y=sigmoid(m_E_grid, beta=2*mu/(sigma**2)),
            mode='lines',
            name=r'$\qquad \quad p_A \text{ with } μ =' + '{}$'.format(mu),
            line=dict(
                width=2,
                color=colors2[i+2],
                shape='spline',
                dash='solid',
            ),
            hoverlabel=dict(
                namelength=-1,
            ),
            showlegend=(i % legend_every == 0),
        )
    )
plotly.offline.iplot(go.Figure(data=traces_mu, layout=layout_final))
mu = 3
display(Markdown("### For different values of $σ$, and $μ$ = {}".format(mu)))
legend_every = 1
values = [0.3, 0.7, 1.1]
colors1 = colorscale_list('Purples', len(values)+3, return_rgb_only=True)
colors2 = colorscale_list('Oranges', len(values)+3, return_rgb_only=True)
# Evidence strengths m_E = m_A - m_B at which both curves are evaluated.
m_E_grid = np.linspace(-0.2, 0.2, 40)
# For each noise level σ: the simulated probability of A (thick line with markers)
# next to the analytical prediction p_A of eq. (8) (thin line).
# Raw strings keep the LaTeX backslashes literal; '\q' in a normal literal is an
# invalid escape sequence (SyntaxWarning since Python 3.12). The text is unchanged.
traces_sigma = []
for i, sigma in enumerate(values):
    traces_sigma.append(
        go.Scatter(
            x=m_E_grid,
            y=prob_outcome_A(m_E_grid, mu=mu, sigma=sigma),
            mode='lines+markers',
            name=r'$\qquad \quad \text{Proba. of A with } σ =' + '{}$'.format(sigma),
            line=dict(
                width=3,
                color=colors1[i+2],
                shape='spline',
                dash='solid',
            ),
            hoverlabel=dict(
                namelength=-1,
            ),
            showlegend=(i % legend_every == 0),
        )
    )
    traces_sigma.append(
        go.Scatter(
            x=m_E_grid,
            y=sigmoid(m_E_grid, beta=2*mu/(sigma**2)),
            mode='lines',
            name=r'$\qquad \quad p_A \text{ with } σ =' + '{}$'.format(sigma),
            line=dict(
                width=2,
                color=colors2[i+2],
                shape='spline',
                dash='solid',
            ),
            hoverlabel=dict(
                namelength=-1,
            ),
            showlegend=(i % legend_every == 0),
        )
    )
plotly.offline.iplot(go.Figure(data=traces_sigma, layout=layout_final))