Problem Set 5: Efficient balanced networks

In [1]:
# /!\ I'm using Python 3

import plotly
import numpy as np
import plotly.plotly as py
import plotly.graph_objs as go

import matplotlib.pyplot as plt
import scipy.io as sio
from scipy import stats
from scipy import special

%matplotlib inline

plotly.offline.init_notebook_mode(connected=True)
# `display`, `HTML` and `Markdown` are public names of IPython.display; importing
# `display` from IPython.core.display is deprecated in recent IPython releases.
from IPython.display import display, HTML, Markdown
# The polling here is to ensure that plotly.js has already been loaded before
# setting display alignment in order to avoid a race condition.
display(HTML(
    '<script>'
        'var waitForPlotly = setInterval( function() {'
            'if( typeof(window.Plotly) !== "undefined" ){'
                'MathJax.Hub.Config({ SVG: { font: "STIX-Web" }, displayAlign: "center" });'
                'MathJax.Hub.Queue(["setRenderer", MathJax.Hub, "SVG"]);'
                'clearInterval(waitForPlotly);'
            '}}, 250 );'
    '</script>'
))

# Colorscales
def colorscale_list(cmap, number_colors, return_rgb_only=False):
    """Sample a matplotlib colormap into `number_colors` discrete plotly colors.

    Colors are taken at cm(1/number_colors), cm(2/number_colors), ..., cm(1),
    keeping only the first three channels (the last, alpha, channel is dropped)
    and multiplying them by 255.

    Parameters
    ----------
    cmap : anything accepted by ``plt.get_cmap`` (e.g. a colormap name).
    number_colors : int
        Number of discrete colors to sample.
    return_rgb_only : bool
        If True, return only the list of color strings.

    Returns
    -------
    list
        If `return_rgb_only`: a list of 'rgb(r, g, b)' strings.
        Otherwise a stepwise plotly colorscale: for color i the two stops
        [i/number_colors, col] and [(i+1)/number_colors, col], so each color
        occupies a flat band.
    """
    cm = plt.get_cmap(cmap)
    colors = [np.array(cm(i / number_colors)) for i in range(1, number_colors + 1)]
    rgb_colors_plotly = []
    rgb_colors_only = []
    for i, c in enumerate(colors):
        # .tolist() yields plain Python floats. Formatting a tuple of numpy scalars
        # directly gives 'rgb(np.float64(...), ...)' under NumPy >= 2, which is
        # not a valid plotly color string.
        col = 'rgb{}'.format(tuple((255 * c[:-1]).tolist()))
        rgb_colors_only.append(col)
        rgb_colors_plotly.append([i / number_colors, col])
        rgb_colors_plotly.append([(i + 1) / number_colors, col])
    return rgb_colors_only if return_rgb_only else rgb_colors_plotly

from scipy.io import loadmat, whosmat
from numpy.random import randint

def formatted(f):
    """Format a number with at most two decimals, dropping trailing zeros.

    Examples: 1.5 -> '1.5', 2.0 -> '2', 0.126 -> '0.13', 100 -> '100'.
    Values that round to zero are returned as '0' (never '-0').
    """
    text = format(f, '.2f').rstrip('0').rstrip('.')
    # A small negative value rounds to '-0.00', which would otherwise become '-0'.
    return '0' if text == '-0' else text

1. Implementation of the integrate-and-fire (IF) neuron

$$\dot{V} = -λ V + F c(t)$$


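Discretizing with the forward Euler method (time step $\Delta t$):

$$ \begin{align*} & \frac{V(t+\Delta t) - V(t)}{\Delta t} \approx -\lambda V(t) + F c(t)\\ \iff & V(t+\Delta t) - V(t) \approx \left(-\lambda V(t) + F c(t)\right) \Delta t\\ \iff & V(t+\Delta t) \approx (1-\lambda \Delta t)\, V(t) + F c(t)\, \Delta t \end{align*} $$

Whenever $V$ reaches the threshold $V_{th}$, the neuron emits a spike and $V$ is reset to its initial value $V_0$.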

In [28]:
# Simulation constants shared by the cells below.
lam = 10.      # leak rate λ
F = 1.         # input weight F
V_th = .5      # spiking threshold V_th
V_0 = 0.       # initial membrane potential

delta_t = .01  # Euler time step Δt
t_max = 10.    # length of the simulated window


def c(delta_t=delta_t, t_max=t_max, amplitude=lam, noise_magnitude=10, sigma=1., seed=1):
    """Build a smooth, zero-mean input signal c(t) on the grid np.arange(0, t_max, delta_t).

    A sine wave plus uniform noise in [0, noise_magnitude) is smoothed by convolving
    it (mode='same') with a Gaussian of width `sigma` centred on t_max/2. The result
    is mean-centred and divided by a fixed factor of 10.

    Notes
    -----
    * `seed` re-seeds NumPy's global random generator, so calling c() also fixes the
      state seen by subsequent np.random calls.
    * `amplitude` is accepted but not used by this implementation.

    Returns
    -------
    ndarray with one sample per point of the time grid.
    """
    np.random.seed(seed)
    grid = np.arange(0, t_max, delta_t)
    noisy = np.sin(grid) + noise_magnitude * np.random.random(len(grid))

    norm = 1 / (sigma * np.sqrt(2 * np.pi))
    kernel = norm * np.exp(-(grid - t_max / 2) ** 2 / (2 * sigma ** 2))

    smoothed = np.convolve(noisy, kernel, mode='same')
    centred = smoothed - np.mean(smoothed)
    return centred / 10

# Quick look at the input signal c(t) over the whole simulation window.
plotly.offline.iplot(
    go.Figure(
        data=[go.Scatter(x=np.arange(0, t_max, delta_t), y=c(), mode='lines')]
    )
)
In [29]:
def voltage_diff_equation(lam=lam, F=F, V_0=0., V_th=V_th, delta_t=delta_t, t_max=t_max, c=c(),
                          return_spikes=False):
    """Simulate an integrate-and-fire neuron with the forward Euler scheme.

    Integrates dV/dt = -lam*V + F*c(t) as
        V[i] = (1 - lam*delta_t) * V[i-1] + F*delta_t * c[i-1].
    Whenever V[i-1] >= V_th, a spike is recorded at sample i-1 and V[i] is reset
    to V_0 instead of being integrated.

    Parameters
    ----------
    lam : float
        Leak rate λ.
    F : float
        Input weight.
    V_0 : float
        Initial membrane potential, also used as the reset value after a spike.
    V_th : float
        Firing threshold.
    delta_t, t_max : float
        Time step and duration; the grid is np.arange(0, t_max, delta_t) (n samples).
    c : array-like
        Input samples; c[i-1] drives step i, so at least n-1 entries are needed.
    return_spikes : bool
        If True, also return the spike train.

    Returns
    -------
    V : ndarray of float, shape (n,)
    spikes : ndarray of float, shape (n,), only if `return_spikes`
        1 at samples where V reached V_th, 0 elsewhere.
    """
    n = len(np.arange(0, t_max, delta_t))
    # Force a float array: with an integer V_0 (e.g. 0), np.full would infer an
    # integer dtype and silently truncate every Euler update.
    V = np.full(n, V_0, dtype=float)
    spikes = np.zeros(n)

    a, b = 1 - lam * delta_t, F * delta_t

    for i in range(1, n):
        if V[i - 1] >= V_th:
            spikes[i - 1] = 1
            V[i] = V_0  # reset after the spike
        else:
            V[i] = a * V[i - 1] + b * c[i - 1]

    # The loop only inspects V[i-1], so a crossing on the final sample would
    # otherwise never be recorded.
    if n and V[-1] >= V_th:
        spikes[-1] = 1

    return (V, spikes) if return_spikes else V


# Layout of the membrane-potential plot. The title is LaTeX rendered by MathJax;
# LaTeX braces inside strings passed to str.format are doubled ({{ }}) so that
# format leaves them intact.
layout1 = go.Layout(
    title=(
        '$\\text{Integrate-and-Fire neuron encoding a time-varying 1D input }'
        ' c(t) '
        '\\text{ (with }'
        + '\\; λ = {}, '.format(lam)
        + '\\; F = {}, '.format(F)
        + '\\; V_{{th}} = {}'.format(V_th)
        + ' \\text{ ) }$'
    ),
    hovermode='closest',
    xaxis=dict(
        title='Time (ms)',
        ticklen=5,
        zeroline=False,
        gridwidth=2,
    ),
    yaxis=dict(
        title='Voltage (mV)',
        ticklen=5,
        gridwidth=2,
    ),
)



# Plot the membrane potential obtained with the Euler scheme.
trace_V = go.Scatter(
    x=np.arange(0, t_max, delta_t),
    y=voltage_diff_equation(),
    mode='lines',
    line=dict(width=2, dash='solid'),
    hoverlabel=dict(namelength=-1),
    name='Numerical solution (Euler method)',
)

plotly.offline.iplot(go.Figure(data=[trace_V], layout=layout1))
In [38]:
def x(lam=lam, x_0=0., delta_t=delta_t, t_max=t_max, c=c(), rescale=True):
    """Integrate the target signal dx/dt = -lam*x + c(t) with forward Euler.

    x[i] = (1 - lam*delta_t) * x[i-1] + delta_t * c[i-1], starting from x_0,
    on the grid np.arange(0, t_max, delta_t).

    Parameters
    ----------
    lam : float
        Leak rate λ.
    x_0 : float
        Initial value.
    delta_t, t_max : float
        Time step and duration of the grid (n samples).
    c : array-like
        Input samples; c[i-1] drives step i, so at least n-1 entries are needed.
    rescale : bool
        If True, divide the signal by its standard deviation.

    Returns
    -------
    ndarray of float, shape (n,)
        The integrated signal. When `rescale` is True it is normalized to unit
        standard deviation; a signal with zero spread is returned unscaled
        rather than divided by zero.
    """
    n = len(np.arange(0, t_max, delta_t))
    # dtype=float: an integer x_0 would otherwise produce an int array and
    # truncate every update.
    signal = np.full(n, x_0, dtype=float)

    decay = 1 - lam * delta_t
    for i in range(1, n):
        signal[i] = decay * signal[i - 1] + delta_t * c[i - 1]

    if not rescale:
        return signal
    spread = signal.std()
    # A constant signal (e.g. all-zero input with x_0 == 0) would give NaN/inf.
    return signal / spread if spread > 0 else signal

# Target signal x(t) before normalization.
plotly.offline.iplot(
    go.Figure(
        data=[go.Scatter(x=np.arange(0, t_max, delta_t), y=x(rescale=False), mode='lines')]
    )
)
In [46]:
V, spikes = voltage_diff_equation(return_spikes=True)

# Filtered spike train: causal exponential filter
#     r[i] = sum_{k <= i} spikes[k] * exp(-lam * (i - k) * delta_t).
# np.convolve defaults to mode='full' (length 2n-1); only the first n samples lie
# on the simulation time grid, so truncate to len(spikes).
r = np.convolve(spikes, np.exp(-lam*np.arange(0, t_max, delta_t)))[:len(spikes)]

plotly.offline.iplot(go.Figure(data=[go.Scatter(
    x=np.arange(0, t_max, delta_t),
    y=r,
    mode='lines'
)]))
In [48]:
# Compare the target signal x, its spike-based estimate x_hat = (V_th/F) * r and the
# spike train. The title is LaTeX rendered by MathJax; LaTeX braces inside strings
# passed to str.format are doubled ({{ }}) so that format leaves them intact.
layout2 = go.Layout(
    title=(
        '$\\text{Plot of the signal } x \\text{, the estimate } \\hat{x} \\text{ and the spike train }'
        '\\text{ (with }'
        + '\\; λ = {}, '.format(lam)
        + '\\; F = {}, '.format(F)
        + '\\; V_{{th}} = {}'.format(V_th)
        + ' \\text{ ) }$'
    ),
    hovermode='closest',
    xaxis=dict(
        title='Time (ms)',
        ticklen=5,
        zeroline=False,
        gridwidth=2,
    ),
    yaxis=dict(
        title='Value',
        ticklen=5,
        gridwidth=2,
    ),
)



# Plotting

trace_x = go.Scatter(
    x=np.arange(0, t_max, delta_t),
    # Unscaled x: with reset to 0, V ≈ F (x - x_hat), so x_hat = (V_th/F) r lives on
    # the scale of x itself. The default x() rescales to unit std, which would make
    # the two curves incomparable.
    y=x(rescale=False),
    mode='lines',
    line=dict(
        width=2,
        dash='solid'
    ),
    hoverlabel=dict(
        namelength=-1
    ),
    name='Signal'
)


trace_x_hat = go.Scatter(
    x=np.arange(0, t_max, delta_t),
    # r may come from a 'full' convolution (length 2n-1); keep only the samples on
    # the time grid (the spike train has exactly one sample per grid point).
    y=(V_th/F)*r[:len(spikes)],
    mode='lines',
    line=dict(
        width=2,
        dash='solid'
    ),
    hoverlabel=dict(
        namelength=-1
    ),
    name='Estimate'
)


trace_sp = go.Scatter(
    x=np.arange(0, t_max, delta_t),
    y=spikes,
    mode='lines',
    line=dict(
        width=2,
        dash='solid'
    ),
    hoverlabel=dict(
        namelength=-1
    ),
    name='Spike train'
)

plotly.offline.iplot(go.Figure(data=[trace_x, trace_x_hat, trace_sp], layout=layout2))