Next, we want to investigate a simple model of how real neurons create action potentials. In a second step, we want to build a simple model of how the vibratory stimulus from Exercise (2) can be translated into a spike train.
# /!\ I'm using Python 3
import plotly
import numpy as np
import plotly.plotly as py
import plotly.graph_objs as go
import matplotlib.pyplot as plt
import scipy.io as sio
from scipy import stats
from scipy import special
%matplotlib inline
plotly.offline.init_notebook_mode(connected=True)
from IPython.core.display import display, HTML, Markdown
# The polling here is to ensure that plotly.js has already been loaded before
# setting display alignment in order to avoid a race condition.
display(HTML(
    '<script>'
        'var waitForPlotly = setInterval( function() {'
            'if( typeof(window.Plotly) !== "undefined" ){'
                'MathJax.Hub.Config({ SVG: { font: "STIX-Web" }, displayAlign: "center" });'
                'MathJax.Hub.Queue(["setRenderer", MathJax.Hub, "SVG"]);'
                'clearInterval(waitForPlotly);'
            '}}, 250 );'
    '</script>'
))
# Colorscales
def colorscale_list(cmap, number_colors, return_rgb_only=False):
    """Sample a matplotlib colormap into plotly-style colour strings.

    The colormap is evaluated at ``i/number_colors`` for
    ``i = 1..number_colors`` and the alpha channel is dropped.

    Parameters
    ----------
    cmap : colormap name (or object) accepted by ``plt.get_cmap``.
    number_colors : number of colours to sample.
    return_rgb_only : if True, return only the list of ``'rgb(r, g, b)'``
        strings. Otherwise return a stepwise colorscale: a list of
        ``[position, colour]`` pairs in which the i-th colour is listed at
        both ``i/number_colors`` and ``(i+1)/number_colors``.
    """
    cm = plt.get_cmap(cmap)
    rgb_colors_only = []
    rgb_colors_plotly = []
    for i in range(number_colors):
        rgba = np.array(cm((i + 1)/number_colors))
        # Convert to built-in floats before formatting: NumPy >= 2 renders
        # NumPy scalars inside a tuple as 'np.float64(...)', which would
        # produce an invalid colour string.
        channels = tuple((255*rgba[:-1]).tolist())
        col = 'rgb{}'.format(channels)
        rgb_colors_only.append(col)
        rgb_colors_plotly.append([i/number_colors, col])
        rgb_colors_plotly.append([(i + 1)/number_colors, col])
    return rgb_colors_only if return_rgb_only else rgb_colors_plotly
from scipy.io import loadmat, whosmat
from numpy.random import randint
def formatted(f):
    """Format a number with at most two decimals, without trailing zeros.

    The value is rendered with two decimals; trailing zeros are then removed,
    followed by a trailing decimal point if one is left.
    """
    two_decimals = '{:.2f}'.format(f)
    without_zeros = two_decimals.rstrip('0')
    return without_zeros.rstrip('.')
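$$C \frac{ {\rm d}V}{ {\rm d}t} = g_L (E_L - V(t))+I \qquad (1)$$
where $V(t)$ is the membrane voltage (mV), $C$ the membrane capacitance (nF), $g_L$ the leak conductance (μS), $E_L$ the leak reversal potential (mV) and $I$ the injected current (nA).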
This equation (and any other differential equation) can be solved numerically using the Euler method, i.e., using the approximation:
$$V(t+∆t) = V(t) + \frac{ {\rm d}V (t)}{ {\rm d}t} \,∆t \qquad (2)$$
Your task will be to implement this method for the above differential equation with initial condition $V (0) = E_L$. Choose a stepwidth of $∆t = 1 \text{ ms}$ and iterate the Euler method over $100$ time steps up to time $t = 100 \text{ ms}$.
$$\begin{align*} & C \frac{ ΔV}{ Δt} = g_L (E_L - V(t))+I\\ ⟺ & \frac{ ΔV}{ Δt} = \frac 1 C (g_L (E_L - V(t))+I)\\ ⟺ & V(t+Δt) - V(t) = \frac {Δt} C (g_L (E_L - V(t))+I)\\ ⟺ & V(t+Δt) = \Big(1- \frac {Δt \cdot g_L} C\Big) V(t) + \frac {Δt \, (g_L \, E_L + I)} C \end{align*}$$
# Model parameters (units as used in the figure titles of this notebook)
C = 1.       # membrane capacitance (nF)
g_L = 0.1    # leak conductance (μS)
E_L = -70.   # leak reversal potential (mV)
I = 1.       # injected current (nA)
# Time grid, in milliseconds
delta_t = 1.
t_max = 100.


def voltage_diff_equation(C=C, g_L=g_L, E_L=E_L, delta_t=delta_t, t_max=t_max, I=I):
    """Integrate the passive membrane equation with the explicit Euler method.

    Starting from ``V[0] = E_L``, iterates
    ``V[k+1] = (1 - delta_t*g_L/C) * V[k] + delta_t*(g_L*E_L + I)/C``.

    Returns an array with one voltage per sample of
    ``np.arange(0, t_max, delta_t)``.
    """
    V = np.zeros(len(np.arange(0, t_max, delta_t)))
    decay = 1 - delta_t*g_L/C
    drive = delta_t*(g_L*E_L + I)/C
    V[0] = E_L

    for k in range(len(V) - 1):
        V[k+1] = decay*V[k] + drive
    return V
# Figure: Euler solution of the membrane equation for the default current I.
# The title is a LaTeX string; the thin-space command is written '\\;'
# because a bare backslash-semicolon is an invalid Python escape sequence.
layout1 = go.Layout(
    title=(
        '$\\text{Voltage across a neuron\'s membrane for an injected current }'
        + 'I = {}'.format(I) + '\\text{ nA (with }'
        + '\\; C = {}'.format(C) + '\\text{ nF, }'
        + '\\; g_L = {}'.format(g_L) + '\\text{ μS, }'
        + '\\; E_L = {}'.format(E_L) + ' \\text{ mV)}$'
    ),
    hovermode='closest',
    xaxis=dict(title='Time (ms)', ticklen=5, zeroline=False, gridwidth=2),
    yaxis=dict(title='Voltage (mV)', ticklen=5, gridwidth=2),
)
# Plotting the evolution (this trace is reused when comparing with the
# exact solution).
trace_V = go.Scatter(
    x=np.arange(0, t_max, delta_t),
    y=voltage_diff_equation(),
    mode='lines',
    line=dict(width=2, dash='solid'),
    hoverlabel=dict(namelength=-1),  # -1: never truncate the trace name
    name='Numerical solution (Euler method)',
)
plotly.offline.iplot(go.Figure(data=[trace_V], layout=layout1))
# Figure: Euler solutions for a range of injected currents.
# '\\;' (LaTeX thin space) is escaped explicitly; a bare backslash-semicolon
# is an invalid Python escape sequence.
layout2 = go.Layout(
    title=(
        '$\\text{Voltage across a neuron\'s membrane for different injected currents (with }'
        + '\\; C = {}'.format(C) + '\\text{ nF, }'
        + '\\; g_L = {}'.format(g_L) + '\\text{ μS, }'
        + '\\; E_L = {}'.format(E_L) + ' \\text{ mV)}$'
    ),
    hovermode='closest',
    xaxis=dict(title='Time (ms)', ticklen=5, zeroline=False, gridwidth=2),
    yaxis=dict(title='Voltage (mV)', ticklen=5, gridwidth=2),
    showlegend=True,
)
legend_every = 4  # only every 4th trace gets a legend entry
values = list(np.arange(-10, 10.5, 0.5))
# A few extra colours are requested so that the palest shades of the
# colormap are skipped (see the index offset below).
colors = colorscale_list('Blues', len(values)+3, return_rgb_only=True)
# Plotting the evolution
traces = []
for i, intensity in enumerate(values):
    traces.append(
        go.Scatter(
            x=np.arange(0, t_max, delta_t),
            y=voltage_diff_equation(I=intensity),
            mode='lines',
            name='Voltage for an injected current I={} nA'.format(intensity),
            line=dict(width=2, color=colors[i+2], shape='spline', dash='solid'),
            hoverlabel=dict(namelength=-1),
            showlegend=(i % legend_every == 0),
        )
    )
plotly.offline.iplot(go.Figure(data=traces, layout=layout2))
$$\begin{align*} & \frac{ {\rm d}V}{ {\rm d}t} = \frac {g_L} C (E_L - V(t))+I/C\\ ⟺ & \frac{ {\rm d}V}{ {\rm d}t} + \frac {g_L} C V(t) = \frac {g_L \, E_L +I} C\\ ⟺ & \frac{ {\rm d}V}{ {\rm d}t} \, {\rm e}^{\frac {g_L \, t} C} + \frac {g_L} C V(t) \, {\rm e}^{\frac {g_L \, t} C} = \frac {g_L \, E_L +I} C \, {\rm e}^{\frac {g_L \, t} C}\\ ⟺ & \frac{ {\rm d}}{ {\rm d}t} \left[V(t) \, {\rm e}^{\frac {g_L \, t} C}\right] = \frac {g_L \, E_L +I} C \, {\rm e}^{\frac {g_L \, t} C}\\ ⟺ & V(t) \, {\rm e}^{\frac {g_L \, t} C} = \frac C {g_L} \frac {g_L \, E_L +I} C \, {\rm e}^{\frac {g_L \, t} C} + \text{const}\\ ⟺ & V(t) \, {\rm e}^{\frac {g_L \, t} C} = \left(E_L + \frac I {g_L}\right) \, {\rm e}^{\frac {g_L \, t} C} + \text{const}\\ ⟺ & V(t) \, = \left(E_L + \frac I {g_L}\right) + \text{const} × {\rm e}^{-\frac {g_L \, t} C}\\ \end{align*}$$
And with $V(0) = E_L$:
$$V(t) \, = E_L + \frac I {g_L}\left(1-{\rm e}^{-\frac {g_L \, t} C}\right)$$
# Figure: Euler solution compared with the closed-form solution
# V(t) = E_L + I/g_L * (1 - exp(-g_L*t/C)).
# '\\;' (LaTeX thin space) is escaped explicitly; a bare backslash-semicolon
# is an invalid Python escape sequence.
layout1 = go.Layout(
    title=(
        '$\\text{Voltage across a neuron\'s membrane for an injected current }'
        + 'I = {}'.format(I) + '\\text{ nA (with }'
        + '\\; C = {}'.format(C) + '\\text{ nF, }'
        + '\\; g_L = {}'.format(g_L) + '\\text{ μS, }'
        + '\\; E_L = {}'.format(E_L) + ' \\text{ mV)}$'
    ),
    hovermode='closest',
    xaxis=dict(title='Time (ms)', ticklen=5, zeroline=False, gridwidth=2),
    yaxis=dict(title='Voltage (mV)', ticklen=5, gridwidth=2),
)
# Plotting the evolution
trace_theoretical = go.Scatter(
    x=np.arange(0, t_max, delta_t),
    y=E_L + I/g_L*(1-np.exp(- g_L/C * np.arange(0, t_max, delta_t))),
    mode='lines',
    line=dict(width=3, dash='dash'),
    hoverlabel=dict(namelength=-1),
    name='Exact solution',
)
plotly.offline.iplot(go.Figure(data=[trace_V, trace_theoretical], layout=layout1))
We will now equip the passive membrane with a very simple action-potential-generating mechanism. For that purpose, we will assume that every time the voltage $V$ surpasses a threshold $V_{th}$, the neuron fires an action potential (=spike), and the membrane voltage is reset to $V = E_L$.
V_th = -63  # spiking threshold (mV)


def voltage_threshold(C=C, g_L=g_L, E_L=E_L, delta_t=delta_t, t_max=t_max, I=I, V_th=V_th, return_spikes=False):
    """Membrane voltage with a threshold-and-reset spiking mechanism.

    Between spikes the voltage follows the closed-form solution
    ``E_L + I/g_L * (1 - exp(-g_L/C * (t - t_0)))``, where ``t_0`` is the time
    of the last reset. Whenever the previous sample is at or above ``V_th``,
    a spike is recorded at that previous sample and the current sample is
    left at ``E_L``. A crossing at the very last sample is therefore not
    recorded.

    Returns the voltage array, or ``(voltage, spikes)`` if ``return_spikes``
    is true, where ``spikes`` is a 0/1 array of the same length.
    """
    times = np.arange(0, t_max, delta_t)
    # Force a float array: np.full would otherwise inherit an integer dtype
    # from an integer E_L and silently truncate every voltage written below.
    V = np.full(len(times), E_L, dtype=float)
    spikes = np.zeros(len(times))
    t_0 = 0

    for i in range(1, len(times)):
        t = times[i]
        if V[i-1] >= V_th:
            # Spike at the previous sample; V[i] keeps its reset value E_L.
            t_0 = t
            spikes[i-1] = 1
        else:
            V[i] = E_L + I/g_L*(1-np.exp(- g_L/C * (t-t_0)))
    return (V, spikes) if return_spikes else V
# Figure: voltage with spiking threshold for the default current I.
# '\\;' (LaTeX thin space) is escaped explicitly; a bare backslash-semicolon
# is an invalid Python escape sequence.
layout3 = go.Layout(
    title=(
        '$\\text{Voltage across a neuron\'s membrane with spiking threshold } V_{th}'
        + ' = {}'.format(V_th) + '\\text{ mV for an injected current }'
        + 'I = {}'.format(I) + '\\text{ nA }' + '\\\\ \\text{(with }'
        + '\\; C = {}'.format(C) + '\\text{ nF, }'
        + '\\; g_L = {}'.format(g_L) + '\\text{ μS, }'
        + '\\; E_L = {}'.format(E_L) + ' \\text{ mV)}$'
    ),
    hovermode='closest',
    xaxis=dict(title='Time (ms)', ticklen=5, zeroline=False, gridwidth=2),
    yaxis=dict(title='Voltage (mV)', ticklen=5, gridwidth=2),
)
# Plotting the evolution
trace_Vth = go.Scatter(
    x=np.arange(0, t_max, delta_t),
    y=voltage_threshold(),
    mode='lines',
    line=dict(width=2, dash='solid'),
    hoverlabel=dict(namelength=-1),
    # voltage_threshold evaluates the closed-form solution between resets,
    # so the trace is not an Euler integration.
    name='Exact solution with threshold reset',
)
plotly.offline.iplot(go.Figure(data=[trace_Vth], layout=layout3))
$$\begin{align*} & \; V(t) \, \overset{\text{eventually}}{>} V_{th}\\ ⟺ & \; \lim\limits_{t \to +∞} V(t) > V_{th}\\ ⟺ & \; E_L + \frac I {g_L} > V_{th}\\ ⟺ & \; I > g_L (V_{th} - E_L) \end{align*}$$
Here, with $g_L = 0.1 \text{ μS}$, $E_L = -70 \text{ mV}$ and $V_{th} = -63 \text{ mV}$:
$$V(t) \, \overset{\text{eventually}}{>} V_{th} ⟺ I > 0.7 \text{ nA}$$
# Figure: threshold model for currents between 0.1 nA and 0.75 nA.
# '\\;' (LaTeX thin space) is escaped explicitly; a bare backslash-semicolon
# is an invalid Python escape sequence.
layout4 = go.Layout(
    title=(
        '$\\text{Voltage across a neuron\'s membrane with spiking threshold } V_{th}'
        + ' = {}'.format(V_th) + '\\text{ mV for different injected currents }'
        + '\\\\ \\text{(with }'
        + '\\; C = {}'.format(C) + '\\text{ nF, }'
        + '\\; g_L = {}'.format(g_L) + '\\text{ μS, }'
        + '\\; E_L = {}'.format(E_L) + ' \\text{ mV)}$'
    ),
    hovermode='closest',
    xaxis=dict(title='Time (ms)', ticklen=5, zeroline=False, gridwidth=2),
    yaxis=dict(title='Voltage (mV)', ticklen=5, gridwidth=2),
)
legend_every = 1
values = list(np.arange(0.1, .8, 0.05))
colors = colorscale_list('Blues', len(values)+2, return_rgb_only=True)
# Plotting the evolution
traces = []
for i, intensity in enumerate(values):
    traces.append(
        go.Scatter(
            x=np.arange(0, t_max, delta_t),
            y=voltage_threshold(I=intensity),
            mode='lines',
            # formatted() hides floating-point noise from np.arange
            name='Voltage with I={} nA'.format(formatted(intensity)),
            line=dict(width=2, color=colors[i+1], shape='spline', dash='solid'),
            hoverlabel=dict(namelength=-1),
            showlegend=(i % legend_every == 0),
        )
    )
plotly.offline.iplot(go.Figure(data=traces, layout=layout4))
# Figure: threshold model for currents 0.8 nA and 1.3 nA.
# '\\;' (LaTeX thin space) is escaped explicitly; a bare backslash-semicolon
# is an invalid Python escape sequence.
layout4 = go.Layout(
    title=(
        '$\\text{Voltage across a neuron\'s membrane with spiking threshold } V_{th}'
        + ' = {}'.format(V_th) + '\\text{ mV for different injected currents }'
        + '\\\\ \\text{(with }'
        + '\\; C = {}'.format(C) + '\\text{ nF, }'
        + '\\; g_L = {}'.format(g_L) + '\\text{ μS, }'
        + '\\; E_L = {}'.format(E_L) + ' \\text{ mV)}$'
    ),
    hovermode='closest',
    xaxis=dict(title='Time (ms)', ticklen=5, zeroline=False, gridwidth=2),
    yaxis=dict(title='Voltage (mV)', ticklen=5, gridwidth=2),
    legend=dict(x=-.1, y=-0.2),
)
legend_every = 1
values = list(np.arange(0.8, 1.7, 0.5))
colors = colorscale_list('Blues', len(values)+2, return_rgb_only=True)
# Plotting the evolution
traces = []
for i, intensity in enumerate(values):
    traces.append(
        go.Scatter(
            x=np.arange(0, t_max, delta_t),
            y=voltage_threshold(I=intensity),
            mode='lines',
            # formatted() keeps legend entries free of floating-point noise,
            # consistently with the low-current figure.
            name='Voltage with I={} nA'.format(formatted(intensity)),
            line=dict(width=2, color=colors[i+1], shape='spline', dash='solid'),
            hoverlabel=dict(namelength=-1),
            showlegend=(i % legend_every == 0),
        )
    )
plotly.offline.iplot(go.Figure(data=traces, layout=layout4))
# Figure: threshold model for currents 1.3 nA and 1.8 nA.
# '\\;' (LaTeX thin space) is escaped explicitly; a bare backslash-semicolon
# is an invalid Python escape sequence.
layout4 = go.Layout(
    title=(
        '$\\text{Voltage across a neuron\'s membrane with spiking threshold } V_{th}'
        + ' = {}'.format(V_th) + '\\text{ mV for different injected currents }'
        + '\\\\ \\text{(with }'
        + '\\; C = {}'.format(C) + '\\text{ nF, }'
        + '\\; g_L = {}'.format(g_L) + '\\text{ μS, }'
        + '\\; E_L = {}'.format(E_L) + ' \\text{ mV)}$'
    ),
    hovermode='closest',
    xaxis=dict(title='Time (ms)', ticklen=5, zeroline=False, gridwidth=2),
    yaxis=dict(title='Voltage (mV)', ticklen=5, gridwidth=2),
    legend=dict(x=-.1, y=-0.2),
)
legend_every = 1
values = list(np.arange(1.3, 2.2, 0.5))
colors = colorscale_list('Blues', len(values)+2, return_rgb_only=True)
# Plotting the evolution
traces = []
for i, intensity in enumerate(values):
    traces.append(
        go.Scatter(
            x=np.arange(0, t_max, delta_t),
            y=voltage_threshold(I=intensity),
            mode='lines',
            # formatted() keeps legend entries free of floating-point noise,
            # consistently with the low-current figure.
            name='Voltage with I={} nA'.format(formatted(intensity)),
            line=dict(width=2, color=colors[i+1], shape='spline', dash='solid'),
            hoverlabel=dict(namelength=-1),
            showlegend=(i % legend_every == 0),
        )
    )
plotly.offline.iplot(go.Figure(data=traces, layout=layout4))
# Rastergram: one row of spike markers per injected current.
values = list(np.arange(0.8, 10, 0.5))
colors = colorscale_list('Blues', len(values)+3, return_rgb_only=True)
# voltage_threshold is deterministic, so each current is simulated once and
# the spike train is reused for both the figure and the printed counts.
spike_trains = [voltage_threshold(I=intensity, return_spikes=True)[1] for intensity in values]
traces = []
backgrounds = []
for i, intensity in enumerate(values):
    traces.append(
        go.Scatter(
            x=np.arange(0, t_max, delta_t),
            # Spikes are drawn at height `intensity`; samples without a spike
            # land on 0, below the visible y-range set in the layout.
            y=intensity*spike_trains[i],
            mode='markers',
            marker=dict(symbol='square', color=colors[i+2]),
            name='I={} nA'.format(formatted(intensity)),
            showlegend=False,
        )
    )

    # Shade every other row to make the rows easier to tell apart.
    if i % 2 == 1:
        backgrounds.append(
            dict(
                fillcolor='rgb(230, 230, 230)',
                line=dict(width=0),
                opacity=0.45,
                type='rect',
                x0=0,
                x1=t_max,
                y0=intensity-0.2,
                y1=intensity+0.2,
                layer='below',
            ))


# '\\;' (LaTeX thin space) is escaped explicitly; a bare backslash-semicolon
# is an invalid Python escape sequence.
layout5 = go.Layout(
    title=(
        '$\\text{Rastergram of spike trains with spiking threshold } V_{th}'
        + ' = {}'.format(V_th) + '\\text{ mV for different injected currents }'
        + '\\\\ \\text{(with }'
        + '\\; C = {}'.format(C) + '\\text{ nF, }'
        + '\\; g_L = {}'.format(g_L) + '\\text{ μS, }'
        + '\\; E_L = {}'.format(E_L) + ' \\text{ mV)}$'
    ),
    hovermode='closest',
    xaxis=dict(title='Time (ms)', ticklen=5, zeroline=False, gridwidth=2),
    yaxis=dict(
        title='Injected current causing the spike train (nA)',
        range=[values[0]-0.5, values[-1]+0.5],
        tickvals=values,
    ),
    legend=dict(x=-.1, y=-0.2),
    shapes=backgrounds,
)
plotly.offline.iplot(go.Figure(data=traces, layout=layout5))
# Number of spikes per injected current
print([sum(train) for train in spike_trains])
# Range of injected currents (nA) covered by the tuning curve
I_init, I_fin = 0, 10


def number_spikes(I):
    """Return the total number of spikes emitted for the injected current I."""
    _, spike_train = voltage_threshold(I=I, return_spikes=True)
    return np.sum(spike_train)


# Allow number_spikes to be applied elementwise to an array of currents.
number_spikes = np.vectorize(number_spikes)
# Tuning curve: spike count as a function of the injected current, evaluated
# at 1000 evenly spaced currents between I_init and I_fin.
xs = np.linspace(I_init, I_fin, 1000)
data = [go.Scatter(x=xs, y=number_spikes(xs))]
layout_tc = go.Layout(
    title='Tuning curve',
    xaxis={'title': 'Input current (nA)'},
    yaxis={
        'title': 'Number of spikes within 100 ms',
        'ticklen': 5,
        'gridwidth': 2,
    },
)
plotly.offline.iplot(go.Figure(data=data, layout=layout_tc))
V_th = -63  # spiking threshold (mV)
rp = 5  # refractory period, in time steps (= 5*delta_t)


def voltage_threshold_rp(rp=rp, C=C, g_L=g_L, E_L=E_L, delta_t=delta_t, t_max=t_max, I=I, V_th=V_th, return_spikes=False):
    """Threshold-and-reset membrane model with a refractory period.

    Between spikes the voltage follows the closed-form solution
    ``E_L + I/g_L * (1 - exp(-g_L/C * (t - t_0)))``. When the previous sample
    is at or above ``V_th``, a spike is recorded at that sample, the voltage
    is reset to ``E_L`` and then held there for ``rp`` further time steps
    before charging resumes.

    The hold lasts exactly ``rp`` steps, so ``rp=0`` reproduces the model
    without refractory period and each unit of ``rp`` adds one step. (The
    hold used to last ``rp + 1`` steps for any ``rp > 0``.)

    Returns the voltage array, or ``(voltage, spikes)`` if ``return_spikes``
    is true, where ``spikes`` is a 0/1 array of the same length.
    """
    times = np.arange(0, t_max, delta_t)
    # Float dtype regardless of the type of E_L (an integer E_L would
    # otherwise yield an integer array and truncate the voltages).
    V = np.full(len(times), E_L, dtype=float)
    spikes = np.zeros(len(times))
    t_0 = 0
    remaining = 0  # refractory steps still to wait

    for i in range(1, len(times)):
        t = times[i]
        if remaining > 0:
            # Still refractory: V[i] stays at E_L and the charging origin
            # moves along with the current time.
            remaining -= 1
            t_0 = t
        elif V[i-1] >= V_th:
            # Spike at the previous sample; V[i] keeps its reset value E_L.
            spikes[i-1] = 1
            remaining = rp
            t_0 = t
        else:
            V[i] = E_L + I/g_L*(1-np.exp(- g_L/C * (t-t_0)))
    return (V, spikes) if return_spikes else V
# Figure: threshold model with refractory period for the default current I.
# '\\;' (LaTeX thin space) is escaped explicitly; a bare backslash-semicolon
# is an invalid Python escape sequence.
layout_rp = go.Layout(
    title=(
        '$\\text{Voltage across a neuron\'s membrane with spiking threshold } V_{th}'
        + ' = {}'.format(V_th) + '\\text{ mV for an injected current }'
        + 'I = {}'.format(I) + '\\text{ nA }' + '\\\\ \\text{(with }'
        + '\\; C = {}'.format(C) + '\\text{ nF, }'
        + '\\; g_L = {}'.format(g_L) + '\\text{ μS, }'
        + '\\; E_L = {}'.format(E_L) + ' \\text{ mV)}$'
    ),
    hovermode='closest',
    xaxis=dict(title='Time (ms)', ticklen=5, zeroline=False, gridwidth=2),
    yaxis=dict(title='Voltage (mV)', ticklen=5, gridwidth=2),
)
# Plotting the evolution
trace_rp = go.Scatter(
    x=np.arange(0, t_max, delta_t),
    y=voltage_threshold_rp(),
    mode='lines',
    line=dict(width=2, dash='solid'),
    hoverlabel=dict(namelength=-1),
    # voltage_threshold_rp evaluates the closed-form solution between
    # resets, so the trace is not an Euler integration.
    name='Exact solution with threshold reset and refractory period',
)
plotly.offline.iplot(go.Figure(data=[trace_rp], layout=layout_rp))
# Range of injected currents (nA) covered by the tuning curves
I_init, I_fin = 0, 10


def number_spikes_rp(I, rp=rp):
    """Return the total number of spikes for current I and refractory period rp."""
    _, spike_train = voltage_threshold_rp(I=I, rp=rp, return_spikes=True)
    return np.sum(spike_train)


# Allow number_spikes_rp to be applied elementwise to an array of currents.
number_spikes_rp = np.vectorize(number_spikes_rp)
# Figure: tuning curves (spike count vs. injected current), one per
# refractory period.
# '\\;' (LaTeX thin space) is escaped explicitly; a bare backslash-semicolon
# is an invalid Python escape sequence.
layout_tc_rp = go.Layout(
    title=(
        '$\\text{Tuning curves of a neuron with spiking threshold } V_{th}'
        + ' = {}'.format(V_th) + '\\text{ mV for different refractory periods }'
        + '\\\\ \\text{(with }'
        + '\\; C = {}'.format(C) + '\\text{ nF, }'
        + '\\; g_L = {}'.format(g_L) + '\\text{ μS, }'
        + '\\; E_L = {}'.format(E_L) + ' \\text{ mV)}$'
    ),
    hovermode='closest',
    xaxis=dict(title='Input current (nA)', ticklen=5, zeroline=False, gridwidth=2),
    yaxis=dict(title='Number of spikes within 100 ms', ticklen=5, gridwidth=2),
)
legend_every = 2
values = list(np.arange(0, 11, 1))
colors = colorscale_list('Greens', len(values)+3, return_rgb_only=True)
# Plotting the evolution
traces = []
for i, refp in enumerate(values):
    # The grid of currents is built once per trace and shared by x and y.
    currents = np.linspace(I_init, I_fin, 1000)
    traces.append(
        go.Scatter(
            x=currents,
            y=number_spikes_rp(currents, rp=refp),
            mode='lines',
            name='Refractory period of {} ms'.format(refp),
            line=dict(width=2, color=colors[i+2], shape='spline', dash='solid'),
            hoverlabel=dict(namelength=-1),
            showlegend=(i % legend_every == 0),
            fill='tozeroy',
        )
    )
plotly.offline.iplot(go.Figure(data=traces, layout=layout_tc_rp))
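$$C \frac{ {\rm d}V}{ {\rm d}t} = g_L (E_L - V(t))+ I + ση(t)$$
where $η(t)$ is Gaussian white noise with zero mean and unit variance, and $σ$ sets the magnitude of the noise.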
sigma = 0.5  # noise magnitude
rp = 5  # refractory period, in time steps


def voltage_white_noise(sigma=sigma, C=C, g_L=g_L, E_L=E_L, delta_t=delta_t, t_max=t_max, I=I,
                        V_th=V_th, rp=rp, return_spikes=False):
    """Noisy threshold-and-reset membrane model, integrated with Euler steps.

    Each integration step computes
    ``V[i] = a*V[i-1] + b + c*xi`` with ``a = 1 - delta_t*g_L/C``,
    ``b = delta_t*(g_L*E_L + I)/C``, ``c = sigma*sqrt(delta_t)/C`` and ``xi``
    a standard normal draw from NumPy's global random state.

    When the previous sample is at or above ``V_th``, a spike is recorded at
    that sample and the current sample is left at ``E_L``; the voltage then
    stays at ``E_L`` for ``rp`` further steps before integration resumes.

    Returns the voltage array, or ``(voltage, spikes)`` if ``return_spikes``
    is true, where ``spikes`` is a 0/1 array of the same length.
    """
    n_steps = len(np.arange(0, t_max, delta_t))
    # Float dtype regardless of the type of E_L (an integer E_L would
    # otherwise yield an integer array and truncate the voltages).
    V = np.full(n_steps, E_L, dtype=float)
    a = 1 - delta_t*g_L/C
    b = delta_t*(g_L*E_L + I)/C
    # The noise enters the right-hand side of C dV/dt = ..., so, like the
    # drift terms, it is divided by C (no effect for the default C = 1).
    c = sigma*np.sqrt(delta_t)/C

    spikes = np.zeros(n_steps)
    refractory_period = 0  # 0: not refractory; otherwise counts up to rp + 1

    for i in range(1, n_steps):
        if refractory_period:
            if refractory_period <= rp:
                refractory_period += 1
                continue
            # Refractory period over: resume integrating on this step.
            refractory_period = 0
        elif V[i-1] >= V_th:
            # Spike at the previous sample; V[i] keeps its reset value E_L.
            spikes[i-1] = 1
            if rp > 0:
                refractory_period = 1
            continue
        V[i] = a*V[i-1] + b + c*np.random.standard_normal()
    return (V, spikes) if return_spikes else V
# Figure: noisy voltage traces for several noise magnitudes.
# '\\;' (LaTeX thin space) is escaped explicitly; a bare backslash-semicolon
# is an invalid Python escape sequence.
layout_sig = go.Layout(
    title=(
        '$\\text{Voltage across a neuron\'s membrane with spiking threshold } V_{th}'
        + ' = {}'.format(V_th) + '\\text{ mV for noise magnitudes } σ'
        + '\\\\ \\text{(with a refractory period of }'
        + '\\;  {}'.format(rp) + '\\text{ ms, }'
        + '\\; I = {}'.format(I) + '\\text{ nA, }'
        + '\\; C = {}'.format(C) + '\\text{ nF, }'
        + '\\; g_L = {}'.format(g_L) + '\\text{ μS, }'
        + '\\; E_L = {}'.format(E_L) + ' \\text{ mV)}$'
    ),
    hovermode='closest',
    xaxis=dict(title='Time (ms)', ticklen=5, zeroline=False, gridwidth=2),
    yaxis=dict(title='Voltage (mV)', ticklen=5, gridwidth=2),
)
legend_every = 1
delta_values = 0.5
values = list(np.arange(0.1, 2.1, delta_values))
colors = colorscale_list('rainbow', len(values), return_rgb_only=True)
# Plotting the evolution
traces = []
for i, sig in enumerate(values):
    traces.append(
        go.Scatter(
            x=np.arange(0, t_max, delta_t),
            y=voltage_white_noise(sigma=sig),
            mode='lines',
            name='Voltage for σ={} nA'.format(sig),
            line=dict(width=2, color=colors[i], dash='solid'),
            hoverlabel=dict(namelength=-1),
            showlegend=(i % legend_every == 0),
        )
    )
plotly.offline.iplot(go.Figure(data=traces, layout=layout_sig))
# Rastergram: one row of spike markers per noise magnitude.
delta_values = 0.5
values = list(np.arange(0.2, 10, delta_values))
colors = colorscale_list('Reds', len(values)+3, return_rgb_only=True)
traces = []
backgrounds = []
for i, sig in enumerate(values):
    traces.append(
        go.Scatter(
            x=np.arange(0, t_max, delta_t),
            # Spikes are drawn at height `sig`; samples without a spike land
            # on 0, below the visible y-range set in the layout.
            y=sig*voltage_white_noise(sigma=sig, return_spikes=True)[1],
            mode='markers',
            marker=dict(symbol='square', color=colors[i+2]),
            name='σ={} nA'.format(formatted(sig)),
            showlegend=False,
        )
    )

    # Shade every other row to make the rows easier to tell apart.
    if i % 2 == 1:
        backgrounds.append(
            dict(
                fillcolor='rgb(230, 230, 230)',
                line=dict(width=0),
                opacity=0.45,
                type='rect',
                x0=0,
                x1=t_max,
                y0=sig-0.2,
                y1=sig+0.2,
                layer='below',
            ))


# '\\;' (LaTeX thin space) is escaped explicitly; a bare backslash-semicolon
# is an invalid Python escape sequence.
layout5 = go.Layout(
    title=(
        '$\\text{Rastergram of spike trains with spiking threshold } V_{th}'
        + ' = {}'.format(V_th) + '\\text{ mV for different noise magnitudes } σ'
        + '\\\\ \\text{(with a refractory period of }'
        + ' {}'.format(rp) + '\\text{ ms, }'
        + '\\; I = {}'.format(I) + '\\text{ nA, }'
        + '\\; C = {}'.format(C) + '\\text{ nF, }'
        + '\\; g_L = {}'.format(g_L) + '\\text{ μS, }'
        + '\\; E_L = {}'.format(E_L) + ' \\text{ mV)}$'
    ),
    hovermode='closest',
    xaxis=dict(title='Time (ms)', ticklen=5, zeroline=False, gridwidth=2),
    yaxis=dict(
        title='White noise magnitude (nA)',
        range=[values[0]-0.1, values[-1]+0.5],
        tickvals=values,
    ),
    legend=dict(x=-.1, y=-0.2),
    shapes=backgrounds,
)
plotly.offline.iplot(go.Figure(data=traces, layout=layout5))
t_max2 = 1000  # duration of the time-varying simulations (ms)
sigma2 = 1  # noise magnitude
rp = 2  # refractory period, in time steps


def voltage_varying_current(sigma=sigma2, C=C, g_L=g_L, E_L=E_L, delta_t=delta_t, t_max=t_max2, I=lambda t: I,
                        V_th=V_th, rp=rp, return_spikes=False):
    """Noisy threshold-and-reset membrane model driven by a current I(t).

    ``I`` is a callable mapping a time to a current; by default it returns
    the module-level constant ``I``. Each integration step computes
    ``V[i] = a*V[i-1] + b + delta_t*I(t_i)/C + c*xi`` with
    ``a = 1 - delta_t*g_L/C``, ``b = delta_t*g_L*E_L/C``,
    ``c = sigma*sqrt(delta_t)/C`` and ``xi`` a standard normal draw from
    NumPy's global random state.

    When the previous sample is at or above ``V_th``, a spike is recorded at
    that sample and the current sample is left at ``E_L``; the voltage then
    stays at ``E_L`` for ``rp`` further steps before integration resumes.

    Returns the voltage array, or ``(voltage, spikes)`` if ``return_spikes``
    is true, where ``spikes`` is a 0/1 array of the same length.
    """
    times = np.arange(0, t_max, delta_t)
    # Float dtype regardless of the type of E_L (an integer E_L would
    # otherwise yield an integer array and truncate the voltages).
    V = np.full(len(times), E_L, dtype=float)
    a = 1 - delta_t*g_L/C
    b = delta_t*(g_L*E_L)/C
    # The noise enters the right-hand side of C dV/dt = ..., so, like the
    # drift terms, it is divided by C (no effect for the default C = 1).
    c = sigma*np.sqrt(delta_t)/C

    spikes = np.zeros(len(times))
    refractory_period = 0  # 0: not refractory; otherwise counts up to rp + 1

    for i in range(1, len(times)):
        t = times[i]
        if refractory_period:
            if refractory_period <= rp:
                refractory_period += 1
                continue
            # Refractory period over: resume integrating on this step.
            refractory_period = 0
        elif V[i-1] >= V_th:
            # Spike at the previous sample; V[i] keeps its reset value E_L.
            spikes[i-1] = 1
            if rp > 0:
                refractory_period = 1
            continue
        V[i] = a*V[i-1] + b + delta_t*I(t)/C + c*np.random.standard_normal()
    return (V, spikes) if return_spikes else V
def Intensity1(f1):
    """Build a cosine input current modulated by a Gaussian envelope.

    Returns a function of time ``t`` computing
    ``2*cos(f1*(t - 450)/200) * exp(-(t - 450)**2 / (2*150**2))``.
    """
    std = 150
    offset = 450

    def current(t):
        shifted = t - offset
        envelope = np.exp(-shifted**2/(2*std**2))
        return 2*np.cos(f1*shifted/200)*envelope

    return current
def Intensity2(f1):
    """Build a cosine input current that is switched on for 200 <= t <= 700.

    Returns a function of a scalar time ``t`` computing
    ``3*cos(f1*(t - 200)/200)`` inside the window and ``0`` outside it.
    """
    t_low, t_high = 200, 700

    def current(t):
        if t_low <= t <= t_high:
            return 3*np.cos(f1*(t-t_low)/200)
        return 0

    return current
# Figures: the two input-current profiles for f1 = 8.4, sampled at 1000
# evenly spaced times between t_init and t_fin.
t_init, t_fin = 0, 1000
xs = np.linspace(t_init, t_fin, 1000)
data1 = [go.Scatter(x=xs, y=list(map(Intensity1(8.4), xs)))]
data2 = [go.Scatter(x=xs, y=list(map(Intensity2(8.4), xs)))]
layout_I = go.Layout(
    title='$\\text{Input current for } f_1 ≝ 8.4 \\text{ Hz}$',
    xaxis={'title': 'Time (ms)'},
    yaxis={'title': 'Input current (nA)', 'ticklen': 5, 'gridwidth': 2},
)
for current_data in (data1, data2):
    plotly.offline.iplot(go.Figure(data=current_data, layout=layout_I))
# Rastergram: bundle_size simulated trials per stimulation frequency, driven
# by the Intensity1 current profile.
sigma1 = 1.
values = np.array([8.4, 12., 15.7, 19.6, 23.6, 25.9, 27.7, 35.])
colors = colorscale_list('tab10', len(values)+3, return_rgb_only=True)
traces = []
backgrounds = []
bundle_size = 10
for j, fr in enumerate(values):
    for i in range(bundle_size):
        row = j*bundle_size+i+1  # y position of this trial
        traces.append(
            go.Scatter(
                x=np.arange(0, t_max2, delta_t),
                # Spikes are drawn at height `row`; samples without a spike
                # land on 0, below the visible y-range set in the layout.
                y=row*voltage_varying_current(I=Intensity1(fr),
                                              sigma=sigma1, return_spikes=True)[1],
                mode='markers',
                name='f1 = {} Hz'.format(fr),
                # The colour is set on the marker: in 'markers' mode a line
                # colour is not applied to the points.
                marker=dict(symbol='square', color=colors[j]),
                showlegend=(i == 0),
            )
        )
    # Shade every other bundle to make the frequencies easier to tell apart.
    if j % 2 == 1:
        backgrounds.append(
            dict(
                fillcolor='#ccc',
                line=dict(width=0),
                opacity=0.5,
                type='rect',
                x0=0,
                x1=t_max2,
                y0=j*bundle_size,
                y1=(j+1)*bundle_size,
                layer='below',
            ))

# '\\;' (LaTeX thin space) is escaped explicitly; a bare backslash-semicolon
# is an invalid Python escape sequence. The constant current I is not part of
# the title because these simulations are driven by I(t) instead.
layout_f = go.Layout(
    title=(
        '$\\text{Rastergram of spike trains with spiking threshold } V_{th}'
        + ' = {}'.format(V_th)
        + '\\text{ mV for time varying input currents } I(t)'
        + '\\\\ \\text{(with a refractory period of }'
        + '\\; {}'.format(rp) + '\\text{ ms, }'
        + '\\; σ = {}'.format(sigma1) + '\\text{ nA, }'
        + '\\; C = {}'.format(C) + '\\text{ nF, }'
        + '\\; g_L = {}'.format(g_L) + '\\text{ μS, }'
        + '\\; E_L = {}'.format(E_L) + ' \\text{ mV)}$'
    ),
    hovermode='closest',
    xaxis=dict(title='Time (ms)', ticklen=5, zeroline=False, gridwidth=2),
    yaxis=dict(
        title='$\\text{Stimulation frequency } f_1 \\text{(Hz)}$',
        # Upper bound = total number of rows (independent of the loop variable)
        range=[1, len(values)*bundle_size],
    ),
    shapes=backgrounds,
)
plotly.offline.iplot(go.Figure(data=traces, layout=layout_f))
# Rastergram: bundle_size simulated trials per stimulation frequency, driven
# by the Intensity2 current profile.
sigma2 = 1.5
values = np.array([8.4, 12., 15.7, 19.6, 23.6, 25.9, 27.7, 35.])
colors = colorscale_list('tab10', len(values)+3, return_rgb_only=True)
traces = []
backgrounds = []
bundle_size = 10
for j, fr in enumerate(values):
    for i in range(bundle_size):
        row = j*bundle_size+i+1  # y position of this trial
        traces.append(
            go.Scatter(
                x=np.arange(0, t_max2, delta_t),
                # Spikes are drawn at height `row`; samples without a spike
                # land on 0, below the visible y-range set in the layout.
                y=row*voltage_varying_current(I=Intensity2(fr),
                                              sigma=sigma2, return_spikes=True)[1],
                mode='markers',
                name='f1 = {} Hz'.format(fr),
                # The colour is set on the marker: in 'markers' mode a line
                # colour is not applied to the points.
                marker=dict(symbol='square', color=colors[j]),
                showlegend=(i == 0),
            )
        )
    # Shade every other bundle to make the frequencies easier to tell apart.
    if j % 2 == 1:
        backgrounds.append(
            dict(
                fillcolor='#ccc',
                line=dict(width=0),
                opacity=0.5,
                type='rect',
                x0=0,
                x1=t_max2,
                y0=j*bundle_size,
                y1=(j+1)*bundle_size,
                layer='below',
            ))

# '\\;' (LaTeX thin space) is escaped explicitly; a bare backslash-semicolon
# is an invalid Python escape sequence. The constant current I is not part of
# the title because these simulations are driven by I(t) instead.
layout_f = go.Layout(
    title=(
        '$\\text{Rastergram of spike trains with spiking threshold } V_{th}'
        + ' = {}'.format(V_th)
        + '\\text{ mV for time varying input currents } I(t)'
        + '\\\\ \\text{(with a refractory period of }'
        + '\\; {}'.format(rp) + '\\text{ ms, }'
        + '\\; σ = {}'.format(sigma2) + '\\text{ nA, }'
        + '\\; C = {}'.format(C) + '\\text{ nF, }'
        + '\\; g_L = {}'.format(g_L) + '\\text{ μS, }'
        + '\\; E_L = {}'.format(E_L) + ' \\text{ mV)}$'
    ),
    hovermode='closest',
    xaxis=dict(title='Time (ms)', ticklen=5, zeroline=False, gridwidth=2),
    yaxis=dict(
        title='$\\text{Stimulation frequency } f_1 \\text{(Hz)}$',
        # Upper bound = total number of rows (independent of the loop variable)
        range=[1, len(values)*bundle_size],
    ),
    shapes=backgrounds,
)
plotly.offline.iplot(go.Figure(data=traces, layout=layout_f))
# Tuning curve: spike statistics inside the window [t_init, t_fin) for each
# stimulation frequency, over bundle_size simulated trials.
t_init, t_fin = 200, 700
s = []
for fr in values:
    # One row per trial, one column per time sample
    s.append(np.vstack([voltage_varying_current(I=Intensity2(fr), sigma=sigma2, return_spikes=True)[1]
                        for _ in range(bundle_size)]))
# NOTE: t_init/t_fin are used directly as column indices, which matches
# times in ms only while delta_t is 1 ms.
spike_counts = [np.sum(sj[:, t_init:t_fin], axis=1) for sj in s]
means = list(map(np.mean, spike_counts))
stds = list(map(np.std, spike_counts))
window_seconds = (t_fin-t_init)/1000
avg_firing_rate = [mean/window_seconds for mean in means]
# The SEM is expressed in spikes/sec, like avg_firing_rate: the standard
# error of the counts is divided by the same window length. Without this
# rescaling the error bars would be in a different unit than the curve.
SEM = [std/np.sqrt(len(counts))/window_seconds
       for std, counts in zip(stds, spike_counts)]
print(spike_counts)
data = [
    go.Scatter(
        x=values,
        y=avg_firing_rate,
        error_y=dict(type='data', array=SEM, visible=True),
    )
]
layout = go.Layout(
    title='Tuning curve of average firing rates with standard error of the mean (SEM) errorbars',
    xaxis=dict(title='Vibration Frequency (Hz)'),
    yaxis=dict(
        title='Average firing rate (number of spikes/sec)',
        ticklen=5,
        gridwidth=2,
    ),
)
data2 = [
    go.Scatter(
        x=values,
        y=means,
        error_y=dict(type='data', array=stds, visible=True),
    )
]
layout2 = go.Layout(
    title='Average spike counts with standard deviation errorbars',
    xaxis=dict(title='Vibration Frequency (Hz)'),
    yaxis=dict(
        title='Average spike count within the simulation period',
        ticklen=5,
        gridwidth=2,
    ),
)
plotly.offline.iplot(go.Figure(data=data2, layout=layout2))
plotly.offline.iplot(go.Figure(data=data, layout=layout))
# Shape of the spike matrix for the first frequency (displayed by the notebook)
s[0].shape