AT2 – Neuromodeling: Problem set #2 QUANTITATIVE MODELS OF BEHAVIOR

PROBLEM 3: The drift diffusion model of decision-making.

In a two-alternative forced choice task (2AFC task), subjects are asked to choose between two alternative actions. Here we consider the case where a subject receives a visual motion stimulus (a set of points on a screen moving in different directions) and then needs to indicate whether the points were moving upwards or downwards. If such a motion stimulus is ambiguous or "noisy", the task can be quite difficult. We will assume that the motion stimulus continues until the subject has made a choice. This scenario is well described by the "drift-diffusion model", in which the subject compares the firing rate $m_A$ of an upward-motion-sensitive neuron with the firing rate $m_B$ of a downward-motion-sensitive neuron and integrates the difference between the two:

$$\dot{x} = m_A − m_B + σ η(t)$$

where $η(t)$ is a noise term (Gaussian white noise with unit standard deviation) that simulates the noisiness of real neurons and $σ$ sets the noise level. If the integration variable $x$ rises above the threshold $μ$, the subject chooses outcome $A$; if $x$ falls below the threshold $−μ$, the subject chooses outcome $B$.
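The decision rule can be sketched as a small helper that scans a sampled trajectory for its first threshold crossing. This is only an illustration: the function name `first_decision`, the threshold value `mu=0.4` and the hand-made trajectory are assumptions, not part of the problem statement.

```python
import numpy as np

def first_decision(x, mu, delta_t):
    """Return (choice, time) for the first threshold crossing of trajectory x.

    choice is 'A' if x first exceeds +mu, 'B' if it first falls below -mu.
    If neither threshold is reached, (None, None) is returned.
    """
    x = np.asarray(x, dtype=float)
    # Indices of all samples that lie outside the interval [-mu, mu]
    crossed = np.flatnonzero((x > mu) | (x < -mu))
    if crossed.size == 0:
        return None, None
    i = int(crossed[0])
    choice = 'A' if x[i] > mu else 'B'
    return choice, i * delta_t

# Hand-made trajectory (illustrative values only, not simulated data)
trajectory = [0.0, 0.1, -0.05, 0.2, 0.45, 0.3]
print(first_decision(trajectory, mu=0.4, delta_t=0.1))
```

Only the first crossing matters here, because the stimulus (and hence the integration) stops as soon as the subject has made a choice.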

Any ordinary differential equation can be solved numerically using the Euler method, i.e. using the approximation:

$$x(t + ∆t) = x(t) + \dot{x} ∆t \qquad (6)$$
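As a quick sanity check of equation (6), here is a minimal sketch that applies the Euler method to a noise-free test equation with a known solution. The test equation $\dot{x} = -x$ with $x(0) = 1$ and the helper name `euler` are assumptions chosen purely for illustration.

```python
import numpy as np

def euler(f, x_0, delta_t, n_steps):
    """Integrate dx/dt = f(x, t) with the explicit Euler method."""
    x = np.empty(n_steps + 1)
    x[0] = x_0
    for i in range(n_steps):
        # x(t + dt) = x(t) + dx/dt * dt, with dx/dt evaluated at time t
        x[i + 1] = x[i] + f(x[i], i * delta_t) * delta_t
    return x

# Test equation (an assumption for illustration): dx/dt = -x, x(0) = 1,
# whose exact solution is x(t) = exp(-t).
x_num = euler(lambda x, t: -x, x_0=1.0, delta_t=0.001, n_steps=1000)
error = abs(x_num[-1] - np.exp(-1.0))
print(error)
```

The Euler method is first-order accurate, so the error at a fixed end time shrinks roughly in proportion to $∆t$.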

For stochastic differential equations, i.e. those that have a noise term, the random part grows with the square root of the time step, a technical issue that you can ignore for now. It leads to the following discrete approximation of the drift-diffusion model, in which $η(t)$ is drawn independently from a standard normal distribution at every time step:

$$x(t+∆t) = x(t) + (m_A - m_B) ∆t + σ η(t) \sqrt{∆t} \qquad (7)$$
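One way to see what the $\sqrt{∆t}$ scaling does is to check that the statistics of $x$ at a fixed end time do not depend on the step size. For equation (7) with $x(0) = 0$, the mean of $x(t)$ should be close to $(m_A - m_B)\,t$ and the variance close to $σ^2 t$. The sketch below tests this numerically. The helper name `endpoint_stats`, the number of runs, the seed and the two step sizes are assumptions for illustration.

```python
import numpy as np

def endpoint_stats(delta_t, t_max=1.0, drift=0.05, sigma=0.5, n_runs=2000, seed=0):
    """Mean and variance of x(t_max) over many runs of equation (7), x(0) = 0."""
    rng = np.random.default_rng(seed)
    n_steps = int(round(t_max / delta_t))
    # One standard normal sample per run and per time step
    noise = rng.standard_normal((n_runs, n_steps))
    increments = drift * delta_t + sigma * np.sqrt(delta_t) * noise
    # Summing all increments of a run gives its end point x(t_max)
    x_end = increments.sum(axis=1)
    return x_end.mean(), x_end.var()

for dt in (0.01, 0.001):
    mean, var = endpoint_stats(dt)
    print(dt, mean, var)
```

Had the noise been scaled by $∆t$ instead of $\sqrt{∆t}$, the variance at $t = 1$ would shrink as the step size is reduced, so the simulated behavior would depend on an arbitrary numerical choice.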

In [339]:
# /!\ I'm using Python 3

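import plotly.offline  # explicit import so that plotly.offline is available
import numpy as np
# plotly.plotly (unused here) was dropped: it fails to import in plotly >= 4
import plotly.graph_objs as go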

import matplotlib.pyplot as plt
import scipy.io as sio
from scipy import stats
%matplotlib inline

plotly.offline.init_notebook_mode(connected=True)
from IPython.display import display, HTML, Markdown
# The polling here is to ensure that plotly.js has already been loaded before
# setting display alignment in order to avoid a race condition.
display(HTML(
    '<script>'
        'var waitForPlotly = setInterval( function() {'
            'if( typeof(window.Plotly) !== "undefined" ){'
                'MathJax.Hub.Config({ SVG: { font: "STIX-Web" }, displayAlign: "center" });'
                'MathJax.Hub.Queue(["setRenderer", MathJax.Hub, "SVG"]);'
                'clearInterval(waitForPlotly);'
            '}}, 250 );'
    '</script>'
))

# Colorscales
def colorscale_list(cmap, number_colors, return_rgb_only=False):
    cm = plt.get_cmap(cmap)
    colors = [np.array(cm(i/number_colors)) for i in range(1, number_colors+1)]
    rgb_colors_plotly = []
    rgb_colors_only = []
    for i, c in enumerate(colors):
        # Convert the RGB channels (alpha dropped) to integers in 0..255
        col = 'rgb({}, {}, {})'.format(*(int(round(255*v)) for v in c[:-1]))
        rgb_colors_only.append(col)
        rgb_colors_plotly.append([i/number_colors, col])
        rgb_colors_plotly.append([(i+1)/number_colors, col])
    return rgb_colors_only if return_rgb_only else rgb_colors_plotly

(a) Assume $m_A = 1$ and $m_B = 0.95$. Plot several runs of the drift-diffusion model, always starting with $x(0) = 0$. Choose a step size of $∆t = 0.1 \, \mathrm{ms}$ and a noise level $σ = 1/2$, and iterate the Euler method over $10000$ time steps, up to time $t = 1 \, \mathrm{s}$.

In [607]:
delta_t = 0.0001
t_max = 1

m_A, m_B = 1, 0.95

sigma = 0.5

x_0 = 0
In [440]:
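def drift_diffusion(m_A=m_A, m_B=m_B, sigma=sigma, delta_t=delta_t, t_max=t_max, x_0=x_0):
    """Simulate one run of the discretized drift-diffusion model (equation 7)."""
    # Note: the defaults are bound to the global values at definition time.
    x = np.zeros(len(np.arange(0, t_max, delta_t)))
    x[0], sqrt_delta_t = x_0, np.sqrt(delta_t)

    # Deterministic drift per step and scale of the noise per step
    term1, term2 = delta_t*(m_A-m_B), sigma*sqrt_delta_t

    for i in range(1, len(x)):
        x[i] = x[i-1] + term1 + term2*np.random.standard_normal()

    return x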


layout = go.Layout(
    title= '$\\text{Runs of drift-diffusion model, for }'+'σ = {}, \\, m_A = {}, \\, m_B = {}$'.format(sigma, m_A, m_B),
    hovermode= 'closest',
    xaxis= dict(
        title= 'Time (in seconds)',
        ticklen= 5,
        zeroline= False,
        gridwidth= 2,
    ),
    yaxis=dict(
        title= '$\\text{Integration variable } x$',
        ticklen= 5,
        gridwidth= 2,
    ),
    showlegend= True
) 


legend_every = 1


# Plot ten independent runs of the drift-diffusion model
traces_dd = []

for i in range(10):
    traces_dd.append(
        go.Scatter(
            x = np.arange(0, t_max, delta_t), 
            y = drift_diffusion(),
            mode = 'lines',
            name = 'Run {}'.format(i+1),
            line = dict(
                width = 2,
                dash = 'solid'
            ),
            hoverlabel = dict(
                namelength = -1
            ),
            showlegend = (i % legend_every == 0)
        )
    )

plotly.offline.iplot(go.Figure(data=traces_dd, layout=layout))