In the next exercise, we want to do some simple analysis of real spike trains. First, download the data file provided on the website. It contains the recordings of a single neuron in the primary somatosensory cortex of a monkey that was experiencing a vibratory stimulus on the fingertip. There are three variables of relevance in the file: `f1`, `spt`, and `t`. The variable `f1` is a vector that contains the different vibration frequencies that the monkey experienced.

To load a `.mat` file in Python, we need to import the `loadmat` function from the `scipy.io` package with `from scipy.io import loadmat`, and then load the file with `cell = loadmat('simdata.mat')`. In MATLAB it's simpler, since we only have to load the file directly using `load`.

The variable `spt` contains the recorded spike trains. Note that this variable is a cell array: to retrieve the spike trains recorded for the $i$-th stimulus, you need to type `s = cell['spt'][0, i]` (Python) or `s = spt{i}` (MATLAB). Afterwards, `s` is a matrix of $0$s and $1$s as above. The variable `t` contains the corresponding time points.
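As a quick sanity check, here is a minimal sketch of this access pattern (it assumes the data file is called `simdata.mat`, the name used in the code further below):

# Minimal sketch: load the file and inspect one spike-train matrix
from scipy.io import loadmat

cell = loadmat('simdata.mat')   # dictionary keyed by the MATLAB variable names

f1  = cell['f1'][0]             # vector of vibration frequencies (Hz)
spt = cell['spt']               # cell array: one spike-train matrix per stimulus
t   = cell['t'][0]              # time points (ms)

s = spt[0, 0]                   # spike trains for the first stimulus: trials x time, 0s and 1s
print(f1)
print(s.shape, t.shape)         # one row per trial, one column per time point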
# /!\ I'm using Python 3
import plotly
import numpy as np
import plotly.plotly as py
import plotly.graph_objs as go
import matplotlib.pyplot as plt
import scipy.io as sio
from scipy import stats
%matplotlib inline
plotly.offline.init_notebook_mode(connected=True)
from IPython.core.display import display, HTML, Markdown
# The polling here is to ensure that plotly.js has already been loaded before
# setting display alignment in order to avoid a race condition.
display(HTML(
'<script>'
'var waitForPlotly = setInterval( function() {'
'if( typeof(window.Plotly) !== "undefined" ){'
'MathJax.Hub.Config({ SVG: { font: "STIX-Web" }, displayAlign: "center" });'
'MathJax.Hub.Queue(["setRenderer", MathJax.Hub, "SVG"]);'
'clearInterval(waitForPlotly);'
'}}, 250 );'
'</script>'
))
# Colorscales
def colorscale_list(cmap, number_colors, return_rgb_only=False):
    # Sample `number_colors` colors from a matplotlib colormap and return them
    # either as plain 'rgb(...)' strings or as a plotly-style discrete colorscale.
    cm = plt.get_cmap(cmap)
    colors = [np.array(cm(i/number_colors)) for i in range(1, number_colors+1)]
    rgb_colors_plotly = []
    rgb_colors_only = []
    for i, c in enumerate(colors):
        col = 'rgb{}'.format(tuple(255*c[:-1]))
        rgb_colors_only.append(col)
        rgb_colors_plotly.append([i/number_colors, col])
        rgb_colors_plotly.append([(i+1)/number_colors, col])
    return rgb_colors_only if return_rgb_only else rgb_colors_plotly
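As a quick illustration of the helper above (an added usage note, not part of the original analysis): with `return_rgb_only=True` it returns plain `'rgb(...)'` strings, otherwise it returns the stepped `[position, color]` pairs that plotly expects for a discrete colorscale.

# Illustrative usage of colorscale_list: four marker colors from the 'tab10' colormap
example_colors = colorscale_list('tab10', 4, return_rgb_only=True)
print(example_colors)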
from scipy.io import loadmat, whosmat
from numpy.random import randint
def truncate(f):
    # Format a float with at most two decimals and strip trailing zeros.
    return float(format(f, '.2f').rstrip('0').rstrip('.'))
truncate = np.vectorize(truncate)
Plot all the spike trains recorded for one stimulus (e.g. f1 = 8.4 Hz) into the same graph (a rastergram).
cell = loadmat('simdata.mat')
s = cell['spt']
cell['t']
# Rastergram for f1 = 8.4 Hz
traces_spikes = []

for i, spike in enumerate(s[np.where(cell['f1'] == 8.4)][0]):
    # Each trial is drawn as a row of markers: y = (i+1)*spike equals i+1 where a
    # spike occurred and 0 otherwise; the 0s are hidden by the y-axis range below.
    traces_spikes.append(
        go.Scatter(
            x = cell['t'][0],
            y = (i+1)*spike,
            mode = 'markers',
            name = 'spike #{}'.format(i+1),
            marker = dict(
                symbol = 'square'
            ),
            showlegend = False
        )
    )

layout = go.Layout(
    title='$\\text{Rastergram of spike trains for } f_1 = 8.4 \\text{ Hz}$',
    xaxis=dict(
        title='Time (ms)'
    ),
    yaxis=dict(
        range=[0.5, i+2],
        title='Trial',
        ticklen=5,
        gridwidth=2,
        tick0=1,
    )
)

plotly.offline.iplot(go.Figure(data=traces_spikes, layout=layout))
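If plotly is not available, the same rastergram can be sketched with matplotlib's `eventplot`, which takes one array of spike times per trial. This is a minimal sketch reusing the `cell` dictionary and `s` defined above, not part of the original notebook:

# Minimal matplotlib alternative (sketch): rastergram for f1 = 8.4 Hz
s_84 = s[np.where(cell['f1'] == 8.4)][0]                          # trials x time matrix of 0s and 1s
spike_times = [cell['t'][0][trial.astype(bool)] for trial in s_84]  # spike times per trial
plt.eventplot(spike_times, colors='black', linelengths=0.8)
plt.xlabel('Time (ms)')
plt.ylabel('Trial')
plt.title('Rastergram of spike trains for f1 = 8.4 Hz')
plt.show()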
# Rastergram for all the values of f1
traces_spikes = []
backgrounds = []
colors = colorscale_list('tab10', len(cell['f1'][0])+3, return_rgb_only=True)
cum_sum = 0

for j, fr in enumerate(cell['f1'][0]):
    # Stack the trials of each frequency on top of the previous ones (offset by cum_sum)
    for i, spike in enumerate(s[0, j]):
        traces_spikes.append(
            go.Scatter(
                x = cell['t'][0],
                y = (cum_sum+i+1)*spike,
                mode = 'markers',
                name = 'f1 = {} Hz'.format(fr),
                marker = dict(
                    symbol = 'square'
                ),
                line = dict(
                    color = colors[j],
                ),
                showlegend = i==0,
            )
        )
    if j%2 == 1:
        # Shade every other frequency block to make the groups easier to tell apart
        backgrounds.append(
            dict(
                fillcolor='#ccc',
                line=dict(
                    width=0
                ),
                opacity=0.5,
                type='rect',
                x0=0,
                x1=1001,
                y0=cum_sum,
                y1=cum_sum+len(s[0, j]),
                layer='below'
            ))
    cum_sum += len(s[0, j])

layout = go.Layout(
    title='Rastergram of spike trains for different frequencies',
    xaxis=dict(
        range=[0, 1001],
        title='Time (ms)'
    ),
    yaxis=dict(
        range=[1, cum_sum+1],
        title='Trial',
        ticklen=5,
        gridwidth=2,
    ),
    shapes=backgrounds
)

plotly.offline.iplot(go.Figure(data=traces_spikes, layout=layout))
Advanced: additionally, add the standard error of the mean (SEM) as error bars to this plot. Remember that the standard error of the mean is defined as $\mathrm{SEM} = \sigma / \sqrt{N}$, where $N$ is the number of samples and $\sigma$ is the standard deviation of the variable you average.
# Spike counts per trial within the analysis window [t_init, t_fin) (in ms)
t_init, t_fin = 200, 700
i_init = np.where(cell['t'][0] == t_init)[0][0]
i_fin = np.where(cell['t'][0] == t_fin)[0][0]
spike_counts = [np.sum(sj[:, i_init:i_fin], axis=1) for sj in s[0]]

means = list(map(np.mean, spike_counts))
stds = list(map(np.std, spike_counts))

# Average firing rate in spikes/s (window length converted from ms to s)
avg_firing_rate = [mean/((t_fin-t_init)/1000) for mean in means]
# Standard error of the mean: sigma / sqrt(N), with N the number of trials
SEM = [std/np.sqrt(len(spike_counts[i])) for i, std in enumerate(stds)]

print(avg_firing_rate)
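As an added cross-check (not part of the original notebook): scipy provides `scipy.stats.sem`, which computes the same quantity. It uses the sample standard deviation (ddof=1) by default, whereas `np.std` above uses ddof=0, so pass `ddof=0` to reproduce the values exactly.

# Optional cross-check of the SEM values using scipy (ddof=0 to match np.std above)
SEM_check = [stats.sem(counts, ddof=0) for counts in spike_counts]
print(np.allclose(SEM, SEM_check))   # expected: True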
data = [
    go.Scatter(
        x=cell['f1'][0],
        y=avg_firing_rate,
        error_y=dict(
            type='data',
            array=SEM,
            visible=True
        )
    )
]

layout = go.Layout(
    title='Tuning curve of average firing rates with standard error of the mean (SEM) errorbars',
    xaxis=dict(
        title='Vibration Frequency (Hz)'
    ),
    yaxis=dict(
        title='Average firing rate (number of spikes/sec)',
        ticklen=5,
        gridwidth=2,
    )
)
data2 = [
    go.Scatter(
        x=cell['f1'][0],
        y=means,
        error_y=dict(
            type='data',
            array=stds,
            visible=True
        )
    )
]

layout2 = go.Layout(
    title='Average spike counts with standard deviation errorbars',
    xaxis=dict(
        title='Vibration Frequency (Hz)'
    ),
    yaxis=dict(
        title='Average spike count within the simulation period',
        ticklen=5,
        gridwidth=2,
    )
)
plotly.offline.iplot(go.Figure(data=data2, layout=layout2))
plotly.offline.iplot(go.Figure(data=data, layout=layout))
cell['f1'][0]
slope, intercept, r_value, _, std_err = stats.linregress(cell['f1'][0], means)
Markdown('''
## *Linear regression*: Average spike counts
- **Equation**: y = {} x + {}
- **Correlation coefficient**: {}
- **Standard error of the slope**: {}
'''.format(truncate(slope), truncate(intercept), r_value, truncate(std_err)))
cell['f1'][0]
slope, intercept, r_value, _, std_err = stats.linregress(cell['f1'][0], avg_firing_rate)
Markdown('''
## *Linear regression*: Average firing rates
- **Equation**: y = {} x + {}
- **Correlation coefficient**: {}
- **Standard error of the slope**: {}
'''.format(truncate(slope), truncate(intercept), r_value, truncate(std_err)))
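To visualize the fit, here is a short sketch (an addition, not part of the original notebook) that overlays the fitted regression line for the firing-rate tuning curve on the data points, in the same plotly style used throughout:

# Sketch: overlay the fitted regression line on the firing-rate tuning curve
slope, intercept, r_value, _, std_err = stats.linregress(cell['f1'][0], avg_firing_rate)
points = go.Scatter(
    x=cell['f1'][0],
    y=avg_firing_rate,
    mode='markers',
    name='average firing rate',
    error_y=dict(type='data', array=SEM, visible=True)
)
fit_line = go.Scatter(
    x=cell['f1'][0],
    y=slope*cell['f1'][0] + intercept,
    mode='lines',
    name='linear fit'
)
layout_fit = go.Layout(
    title='Tuning curve with linear fit',
    xaxis=dict(title='Vibration Frequency (Hz)'),
    yaxis=dict(title='Average firing rate (spikes/sec)')
)
plotly.offline.iplot(go.Figure(data=[points, fit_line], layout=layout_fit))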