# /!\ I'm using Python 3
import plotly
import numpy as np
import plotly.plotly as py
import plotly.graph_objs as go
import matplotlib.pyplot as plt
import scipy.io as sio
from scipy import stats
from scipy import special
%matplotlib inline
plotly.offline.init_notebook_mode(connected=True)
from IPython.core.display import display, HTML, Markdown
# Configure MathJax (SVG renderer, STIX-Web font, centered display math) only
# once window.Plotly exists: the script polls every 250 ms and stops polling
# as soon as plotly.js is present, avoiding a race with plotly.js loading.
display(HTML(''.join([
    '<script>',
    'var waitForPlotly = setInterval( function() {',
    'if( typeof(window.Plotly) !== "undefined" ){',
    'MathJax.Hub.Config({ SVG: { font: "STIX-Web" }, displayAlign: "center" });',
    'MathJax.Hub.Queue(["setRenderer", MathJax.Hub, "SVG"]);',
    'clearInterval(waitForPlotly);',
    '}}, 250 );',
    '</script>',
])))
# Colorscales
def colorscale_list(cmap, number_colors, return_rgb_only=False):
    """Sample a colormap into ``number_colors`` discrete Plotly colors.

    Parameters
    ----------
    cmap : str or callable
        Name of a matplotlib colormap (resolved with ``plt.get_cmap``), or any
        callable mapping a float in [0, 1] to an RGBA sequence (for example a
        matplotlib ``Colormap`` instance).
    number_colors : int
        Number of discrete colors to sample.
    return_rgb_only : bool, optional
        If True, return only the list of ``'rgb(r, g, b)'`` strings.

    Returns
    -------
    list
        Either the ``'rgb(...)'`` strings, or a stepped Plotly colorscale made
        of ``[position, color]`` pairs in which color ``i`` spans
        ``[i/number_colors, (i+1)/number_colors]``. Each color appears twice,
        so the scale has hard edges instead of gradients.
    """
    cm = cmap if callable(cmap) else plt.get_cmap(cmap)
    rgb_colors_plotly = []
    rgb_colors_only = []
    for i in range(number_colors):
        # Sample at 1/n, 2/n, ..., 1 (position 0 of the colormap is skipped).
        rgba = np.asarray(cm((i + 1) / number_colors), dtype=float)
        # Drop alpha and convert to built-in floats: under NumPy >= 2 the repr
        # of np.float64 is 'np.float64(...)', which would corrupt the string.
        col = 'rgb{}'.format(tuple(float(v) for v in 255 * rgba[:-1]))
        rgb_colors_only.append(col)
        rgb_colors_plotly.append([i / number_colors, col])
        rgb_colors_plotly.append([(i + 1) / number_colors, col])
    return rgb_colors_only if return_rgb_only else rgb_colors_plotly
from scipy.io import loadmat, whosmat
from numpy.random import randint
def formatted(f):
    """Format ``f`` with at most two decimals, dropping trailing zeros.

    Examples: 3.14159 -> '3.14', 10.5 -> '10.5', 100 -> '100'.
    Values that round to zero are returned as '0', never '-0'.
    """
    s = format(f, '.2f').rstrip('0').rstrip('.')
    # Negative values that round to zero (e.g. -0.001) format as '-0.00'.
    return '0' if s == '-0' else s
$$\dot{V} = -λ V + F c(t)$$
$$ \begin{aligned} & \frac{\Delta V}{\Delta t} = -\lambda V(t) + F c(t) \\ \iff & V(t+\Delta t) - V(t) = \big(-\lambda V(t) + F c(t)\big)\, \Delta t \\ \iff & V(t+\Delta t) = (1-\lambda \Delta t)\, V(t) + F c(t)\, \Delta t \end{aligned} $$
# Model parameters; names follow dV/dt = -λ V + F c(t).
lam, F = 10., 1.         # leak rate λ and input gain F
V_th, V_0 = .5, 0.       # firing threshold and initial voltage
delta_t, t_max = .01, 10.  # Euler step and duration; grid is np.arange(0, t_max, delta_t)
def c(delta_t=delta_t, t_max=t_max, amplitude=lam, noise_magnitude=10, sigma=1., seed=1):
    """Build a smooth, zero-mean, time-varying input on np.arange(0, t_max, delta_t).

    A sine wave plus uniform noise (scaled by ``noise_magnitude``) is convolved
    with a Gaussian of width ``sigma`` centred at t_max/2 ('same' mode). The
    mean is then removed and the result divided by 10.

    Side effect: reseeds NumPy's global RNG with ``seed``.

    NOTE(review): ``amplitude`` is accepted but never used. The divisor is a
    hard-coded 10, which happens to equal the default ``amplitude=lam=10.``;
    confirm whether ``/ amplitude`` was intended. The Gaussian kernel is not
    multiplied by ``delta_t``, so the output scale presumably depends on the
    grid step.
    """
    np.random.seed(seed)
    times = np.arange(0, t_max, delta_t)
    noise = noise_magnitude * np.random.random(len(times))
    noisy_signal = np.sin(times) + noise
    kernel = 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-(times - t_max / 2) ** 2 / (2 * sigma ** 2))
    smoothed = np.convolve(noisy_signal, kernel, mode='same')
    centred = smoothed - np.mean(smoothed)
    return centred / 10
# Preview the smoothed input c(t) on the simulation time grid.
plotly.offline.iplot(
    go.Figure(
        data=[
            go.Scatter(x=np.arange(0, t_max, delta_t), y=c(), mode='lines'),
        ]
    )
)
def voltage_diff_equation(lam=lam, F=F, V_0=0., V_th=V_th, delta_t=delta_t, t_max=t_max, c=c(),
                          return_spikes=False):
    """Simulate the integrate-and-fire voltage with the forward Euler scheme.

    dV/dt = -lam*V + F*c(t) is discretized as
    V[i] = (1 - lam*delta_t) * V[i-1] + F*delta_t * c[i-1].
    Whenever V[i-1] >= V_th, a spike is recorded at index i-1 and V[i] is left
    at V_0 (the reset) instead of being integrated.

    Parameters
    ----------
    lam, F : float
        Leak rate and input gain.
    V_0 : float
        Initial voltage, also used as the post-spike reset value.
    V_th : float
        Firing threshold.
    delta_t, t_max : float
        Time step and duration; the grid is np.arange(0, t_max, delta_t).
    c : array-like
        Input sampled on that grid (indices 0..n-2 are read). The parameter
        shadows the module-level function ``c``. The default ``c()`` is
        evaluated once, at definition time, with the default grid, so callers
        changing ``delta_t``/``t_max`` must pass a matching ``c``.
    return_spikes : bool, optional
        If True, also return the spike indicator array.

    Returns
    -------
    V or (V, spikes)
        Float voltage trace of length n, and optionally a 0/1 spike array of
        the same length.
    """
    n = len(np.arange(0, t_max, delta_t))
    # dtype=float so an integer V_0 does not silently truncate every update.
    V = np.full(n, V_0, dtype=float)
    spikes = np.zeros(n)
    a, b = 1 - lam * delta_t, F * delta_t
    for i in range(1, n):
        if V[i - 1] >= V_th:
            # Spike at i-1; V[i] keeps its initial value V_0, i.e. the reset.
            spikes[i - 1] = 1
        else:
            V[i] = a * V[i - 1] + b * c[i - 1]
    # The loop only inspects V[0..n-2]; a threshold crossing on the final
    # sample must be recorded as well.
    if n and V[-1] >= V_th:
        spikes[-1] = 1
    return (V, spikes) if return_spikes else V
# Layout for the voltage-trace figure. The LaTeX title is built from raw
# strings: plain '\;' is an invalid escape sequence (warning since Python 3.6).
layout1 = go.Layout(
    title=(
        r'$\text{Integrate-and-Fire neuron encoding a time-varying 1D input }'
        + r' c(t) '
        + r'\text{ (with }'
        + r'\; λ = {}, '.format(lam)
        + r'\; F = {}, '.format(F)
        # No comma after V_{th}: the label reads "V_th = <value>".
        + r'\; V_{th}' + ' = {}'.format(V_th)
        + r' \text{ ) }$'
    ),
    hovermode='closest',
    xaxis=dict(
        title='Time (ms)',
        ticklen=5,
        zeroline=False,
        gridwidth=2,
    ),
    yaxis=dict(
        title='Voltage (mV)',
        ticklen=5,
        gridwidth=2,
    ),
)
# Plotting the evolution: the Euler-integrated voltage trace drawn with layout1.
trace_V = go.Scatter(
    x=np.arange(0, t_max, delta_t),
    y=voltage_diff_equation(),
    mode='lines',
    line={'width': 2, 'dash': 'solid'},
    hoverlabel={'namelength': -1},
    name='Numerical solution (Euler method)',
)
plotly.offline.iplot(go.Figure(data=[trace_V], layout=layout1))
def x(lam=lam, x_0=0., delta_t=delta_t, t_max=t_max, c=c(), rescale=True):
    """Integrate the leaky target signal dx/dt = -lam*x + c(t) with forward Euler.

    x[i] = (1 - lam*delta_t) * x[i-1] + delta_t * c[i-1], starting from x_0 on
    the grid np.arange(0, t_max, delta_t).

    Parameters
    ----------
    lam : float
        Leak rate.
    x_0 : float
        Initial value.
    delta_t, t_max : float
        Time step and duration.
    c : array-like
        Input on the grid (indices 0..n-2 are read). The default ``c()`` is
        evaluated once, at definition time, with the default grid.
    rescale : bool, optional
        If True, divide the result by its standard deviation. A signal with
        zero std (e.g. all-zero input) is returned unscaled.

    Returns
    -------
    numpy.ndarray
        Float array of length n.
    """
    n = len(np.arange(0, t_max, delta_t))
    # dtype=float so an integer x_0 does not truncate every update.
    signal = np.full(n, x_0, dtype=float)
    decay = 1 - lam * delta_t
    for i in range(1, n):
        signal[i] = decay * signal[i - 1] + delta_t * c[i - 1]
    if not rescale:
        return signal
    std = signal.std()
    # Dividing a constant signal by its zero std would produce NaNs.
    return signal / std if std > 0 else signal
# Raw (not rescaled) target signal x(t).
plotly.offline.iplot(
    go.Figure(
        data=[go.Scatter(x=np.arange(0, t_max, delta_t), y=x(rescale=False), mode='lines')]
    )
)
# Simulate with default parameters and filter the spike train with a causal
# exponential kernel exp(-lam * t) sampled on the simulation grid.
V, spikes = voltage_diff_equation(return_spikes=True)
# np.convolve defaults to mode='full' (length 2n-1); keep the first n samples
# so r[k] depends only on spikes at indices <= k and matches the time axis.
r = np.convolve(spikes, np.exp(-lam*np.arange(0, t_max, delta_t)))[:len(spikes)]
plotly.offline.iplot(go.Figure(data=[go.Scatter(
    x=np.arange(0, t_max, delta_t),
    y=r,
    mode='lines'
)]))
# Layout for the signal / estimate / spike-train figure. The LaTeX title is
# built from raw strings: plain '\hat' and '\;' are invalid escape sequences.
layout2 = go.Layout(
    title=(
        r'$\text{Plot of the signal } x \text{, the estimate } \hat{x} \text{ and the spike train }'
        + r'\text{ (with }'
        + r'\; λ = {}, '.format(lam)
        + r'\; F = {}, '.format(F)
        # No comma after V_{th}: the label reads "V_th = <value>".
        + r'\; V_{th}' + ' = {}'.format(V_th)
        + r' \text{ ) }$'
    ),
    hovermode='closest',
    xaxis=dict(
        title='Time (ms)',
        ticklen=5,
        zeroline=False,
        gridwidth=2,
    ),
    yaxis=dict(
        title='Value',
        ticklen=5,
        gridwidth=2,
    ),
)
# Plotting: rescaled target x(), the spike-based estimate (V_th / F) * r and
# the raw spike train, all drawn on one figure with layout2.
trace_x = go.Scatter(
    x=np.arange(0, t_max, delta_t),
    y=x(),
    mode='lines',
    line={'width': 2, 'dash': 'solid'},
    hoverlabel={'namelength': -1},
    name='Signal',
)
trace_x_hat = go.Scatter(
    x=np.arange(0, t_max, delta_t),
    y=(V_th / F) * r,
    mode='lines',
    line={'width': 2, 'dash': 'solid'},
    hoverlabel={'namelength': -1},
    name='Estimate',
)
trace_sp = go.Scatter(
    x=np.arange(0, t_max, delta_t),
    y=spikes,
    mode='lines',
    line={'width': 2, 'dash': 'solid'},
    hoverlabel={'namelength': -1},
    name='Spike train',
)
plotly.offline.iplot(
    go.Figure(data=[trace_x, trace_x_hat, trace_sp], layout=layout2)
)