# PROBLEM 3: The drift-diffusion model of decision-making

In a two-alternative forced-choice task (2AFC task), subjects are asked to choose between two alternative actions. Here we consider the case where a subject views a visual motion stimulus (a set of points on a screen that move in different directions) and must then indicate whether the points were moving upwards or downwards. If the motion stimulus is ambiguous or "noisy", the task can be quite difficult. We assume that the motion stimulus continues until the subject has made a choice. This scenario is well described by the drift-diffusion model, in which the subject compares the firing rate $m_A$ of an upward-motion-sensitive neuron with the firing rate $m_B$ of a downward-motion-sensitive neuron and integrates the difference between the two:

$$\dot{x} = m_A - m_B + \sigma \eta(t)$$

where $\eta(t)$ is a noise term (zero-mean Gaussian white noise of unit intensity; in the discretized model below, each sample of $\eta$ is drawn from a standard normal distribution) that models the noisiness of real neurons, and $\sigma$ sets the noise level. If the integration variable $x$ exceeds a threshold $\mu$, the subject chooses outcome $A$; if $x$ falls below $-\mu$, the subject chooses outcome $B$.
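The threshold rule can be sketched as a function that scans a sampled trajectory for the first point where $|x| \ge \mu$ and reports which bound was reached and when. This is a minimal sketch under stated assumptions: the name `decide`, the use of `>=` to count "reaching" a bound, and the sample trajectory are illustrative, not part of the problem.

```python
import numpy as np

def decide(x, mu, delta_t):
    """Apply the threshold rule to a sampled trajectory x[0], x[1], ...

    Hypothetical helper (not from the original problem). Returns ('A', t)
    if x first reaches +mu, ('B', t) if it first reaches -mu, or
    (None, None) if the trajectory never leaves the interval (-mu, mu).
    """
    x = np.asarray(x, dtype=float)
    crossed = np.flatnonzero(np.abs(x) >= mu)  # indices where either bound is reached
    if crossed.size == 0:
        return None, None
    i = crossed[0]                             # the first crossing determines the choice
    return ('A' if x[i] > 0 else 'B'), i * delta_t

# Illustrative trajectory sampled every 1 ms; it first leaves (-0.1, 0.1) at index 3
choice, t_decision = decide([0.0, 0.05, -0.08, -0.12, 0.2], mu=0.1, delta_t=0.001)
```

Here the trajectory reaches $-\mu$ before $+\mu$, so the choice is $B$ even though $x$ later rises above $+\mu$: only the first crossing counts.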

Ordinary differential equations such as this one can be solved numerically with the Euler method, i.e. using the approximation:

$$x(t + \Delta t) = x(t) + \dot{x}\, \Delta t \qquad (6)$$
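Equation (6) can be sketched as a generic forward-Euler loop. To make the approximation error visible, the sketch uses the test equation $\dot{x} = -x$, chosen only because its exact solution $x(t) = e^{-t}$ is known; it is not part of the problem, and the helper name `euler` is hypothetical.

```python
import numpy as np

def euler(f, x0, delta_t, n_steps):
    """Forward Euler (hypothetical helper): x(t + dt) = x(t) + f(x, t) * dt."""
    x = np.empty(n_steps + 1)
    x[0] = x0
    for i in range(n_steps):
        t = i * delta_t
        x[i + 1] = x[i] + f(x[i], t) * delta_t
    return x

# Illustrative test ODE dx/dt = -x with x(0) = 1, integrated up to t = 1
dt = 0.01
x_num = euler(lambda x, t: -x, 1.0, dt, 100)
error = abs(x_num[-1] - np.exp(-1.0))
print("Euler:", x_num[-1], " exact:", np.exp(-1.0), " error:", error)
```

Halving $\Delta t$ roughly halves the error, as expected for a first-order method. In the noise-free drift-diffusion model ($\sigma = 0$) the right-hand side $m_A - m_B$ is constant, so Euler reproduces $x(t) = (m_A - m_B)\,t$ exactly up to floating-point rounding.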

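For stochastic differential equations, i.e. equations with a noise term, the random part of each step scales with the square root of the time step rather than with the time step itself, because the variance of the accumulated noise must grow linearly in time independently of $\Delta t$. This is a technical point that you can ignore for now. It leads to the following discrete approximation of the drift-diffusion model (the Euler-Maruyama scheme):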

$$x(t + \Delta t) = x(t) + (m_A - m_B)\, \Delta t + \sigma \eta(t) \sqrt{\Delta t} \qquad (7)$$
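Why $\sqrt{\Delta t}$: over $n = t/\Delta t$ steps the independent noise increments add up, so their variances add. With a per-step factor $\sigma\sqrt{\Delta t}$ the variance of $x(t)$ is $n\,\sigma^2 \Delta t = \sigma^2 t$, independent of the step size; with a factor $\sigma\,\Delta t$ it would be $\sigma^2 t\,\Delta t$, so the noise would vanish as $\Delta t \to 0$. The sketch below checks this empirically for pure noise ($m_A = m_B$); the helper name `endpoint_variance`, the seed, and the number of runs are illustrative assumptions.

```python
import numpy as np

def endpoint_variance(sigma, delta_t, t_end, n_runs, noise_scale, rng):
    """Empirical variance of x(t_end) for pure noise (m_A = m_B), x(0) = 0.

    Hypothetical helper: noise_scale(delta_t) is the factor multiplying
    sigma * eta at each step (np.sqrt for equation (7), identity otherwise).
    """
    n_steps = int(round(t_end / delta_t))
    # One row per run; summing a row's increments gives that run's x(t_end)
    increments = sigma * noise_scale(delta_t) * rng.standard_normal((n_runs, n_steps))
    return increments.sum(axis=1).var()

rng = np.random.default_rng(0)
sigma, t_end = 0.5, 1.0
results = {}
for dt in (1e-2, 1e-3):
    v_sqrt = endpoint_variance(sigma, dt, t_end, 2000, np.sqrt, rng)
    v_lin = endpoint_variance(sigma, dt, t_end, 2000, lambda d: d, rng)
    results[dt] = (v_sqrt, v_lin)
    print(f"dt={dt:g}: sqrt(dt) scaling -> {v_sqrt:.3f}, dt scaling -> {v_lin:.2e}")
```

With the $\sqrt{\Delta t}$ factor both step sizes give a variance close to $\sigma^2 t = 0.25$, while the linear factor gives a variance that shrinks tenfold when $\Delta t$ does.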

```python
# /!\ I'm using Python 3

import plotly
import numpy as np
import plotly.graph_objs as go

import matplotlib.pyplot as plt
import scipy.io as sio
from scipy import stats
%matplotlib inline

plotly.offline.init_notebook_mode(connected=True)
from IPython.display import display, HTML, Markdown
# The polling here is to ensure that plotly.js has already been loaded before
# setting display alignment in order to avoid a race condition.
display(HTML(
    '<script>'
    'var waitForPlotly = setInterval( function() {'
    'if( typeof(window.Plotly) !== "undefined" ){'
    'MathJax.Hub.Config({ SVG: { font: "STIX-Web" }, displayAlign: "center" });'
    'MathJax.Hub.Queue(["setRenderer", MathJax.Hub, "SVG"]);'
    'clearInterval(waitForPlotly);'
    '}}, 250 );'
    '</script>'
))

# Colorscales
def colorscale_list(cmap, number_colors, return_rgb_only=False):
    """Sample `number_colors` colors from a matplotlib colormap.

    Returns a list of 'rgb(r, g, b)' strings if return_rgb_only is True,
    otherwise a stepwise plotly colorscale [[position, color], ...].
    """
    cm = plt.get_cmap(cmap)
    colors = [np.array(cm(i / number_colors)) for i in range(1, number_colors + 1)]
    rgb_colors_plotly = []
    rgb_colors_only = []
    for i, c in enumerate(colors):
        # Drop the alpha channel and convert to 0-255 integers
        col = 'rgb({}, {}, {})'.format(*(int(round(255 * v)) for v in c[:3]))
        rgb_colors_only.append(col)
        # Two stops per color give sharp (non-interpolated) color bands
        rgb_colors_plotly.append([i / number_colors, col])
        rgb_colors_plotly.append([(i + 1) / number_colors, col])
    return rgb_colors_only if return_rgb_only else rgb_colors_plotly
```


## (a) Assume $m_A = 1$ and $m_B = 0.95$. Plot several runs of the drift-diffusion model, always starting with $x(0) = 0$. Choose a step width of $\Delta t = 0.1 \, \mathrm{ms}$, a noise level $\sigma = 1/2$, and iterate the Euler method over $10000$ time steps up to time $t = 1 \, \mathrm{s}$.

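```python
delta_t = 0.0001    # time step: 0.1 ms, expressed in seconds
t_max = 1           # total simulated time: 1 s

m_A, m_B = 1, 0.95  # firing rates of the upward- and downward-sensitive neurons

sigma = 0.5         # noise level

x_0 = 0             # initial value of the integration variable
```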

```python
def drift_diffusion(m_A=m_A, m_B=m_B, sigma=sigma, delta_t=delta_t, t_max=t_max, x_0=x_0):
    """Simulate one run of the drift-diffusion model with the Euler scheme (7).

    Returns x at times 0, delta_t, ..., t_max, i.e. n_steps + 1 values
    with n_steps = t_max / delta_t (10000 steps for the default parameters).
    """
    n_steps = int(round(t_max / delta_t))
    x = np.zeros(n_steps + 1)
    x[0] = x_0

    drift_step = (m_A - m_B) * delta_t      # deterministic part of each step
    noise_scale = sigma * np.sqrt(delta_t)  # noise part scales with sqrt(delta_t)

    for i in range(1, n_steps + 1):
        x[i] = x[i-1] + drift_step + noise_scale * np.random.standard_normal()

    return x

# Time axis matching the output of drift_diffusion()
times = np.linspace(0, t_max, int(round(t_max / delta_t)) + 1)

layout = go.Layout(
    title=r'$\text{{Runs of the drift-diffusion model, for }}\sigma = {}, \, m_A = {}, \, m_B = {}$'.format(sigma, m_A, m_B),
    hovermode='closest',
    xaxis=dict(
        title='Time (s)',
        ticklen=5,
        zeroline=False,
        gridwidth=2,
    ),
    yaxis=dict(
        title=r'$\text{Integration variable } x$',
        ticklen=5,
        gridwidth=2,
    ),
    showlegend=True
)

legend_every = 1  # show a legend entry for every legend_every-th run

# Plot several independent runs of the drift-diffusion model
traces_dd = []

for i in range(10):
    traces_dd.append(
        go.Scatter(
            x=times,
            y=drift_diffusion(),
            mode='lines',
            name='Run {}'.format(i+1),
            line=dict(width=2, dash='solid'),
            hoverlabel=dict(namelength=-1),
            showlegend=(i % legend_every == 0)
        )
    )

plotly.offline.iplot(go.Figure(data=traces_dd, layout=layout))
```
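A handful of runs gives only a qualitative picture. As a complementary sketch, many runs of equation (7) can be simulated at once and their values at $t = 1\,\mathrm{s}$ compared with the theoretical mean $(m_A - m_B)\,t = 0.05$ and standard deviation $\sigma\sqrt{t} = 0.5$. No thresholds are applied here (the runs are free, as in the plot above); the number of runs and the seed are illustrative assumptions.

```python
import math
import numpy as np

m_A, m_B, sigma = 1.0, 0.95, 0.5
delta_t, t_max, x_0 = 1e-4, 1.0, 0.0
n_runs = 400                                   # illustrative choice
n_steps = int(round(t_max / delta_t))

rng = np.random.default_rng(1)
dx = rng.standard_normal((n_runs, n_steps))    # one row of eta samples per run
dx *= sigma * np.sqrt(delta_t)                 # noise part of each step (equation 7)
dx += (m_A - m_B) * delta_t                    # drift part of each step
x_end = x_0 + dx.sum(axis=1)                   # x(t_max) for every run

# Fraction of runs ending above zero, versus the Gaussian prediction Phi(0.1)
frac_positive = (x_end > 0).mean()
z = (m_A - m_B) * t_max / (sigma * math.sqrt(t_max))
expected_frac = 0.5 * (1 + math.erf(z / math.sqrt(2)))

print("mean x(1 s):", x_end.mean(), " theory:", (m_A - m_B) * t_max)
print("std  x(1 s):", x_end.std(), " theory:", sigma * math.sqrt(t_max))
print("fraction > 0:", frac_positive, " theory:", expected_frac)
```

The drift accumulated after one second is ten times smaller than the spread caused by the noise, so only about 54% of free runs end on the "upward" side. This illustrates why the stimulus is hard to discriminate with these parameters.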