AT2 – Neuromodeling: Problem set #3 SPIKE TRAINS

PROBLEM 2: Analysis of spike trains.

In the next exercise, we want to do some simple analysis of real spike trains. First, download the data-file provided on the website. The data file contains the recordings of a single neuron in the primary somatosensory cortex of a monkey that was experiencing a vibratory stimulus on the fingertip. There are three variables of relevance in the file: f1, spt, and t. The variable f1 is a vector that contains the different vibration frequencies that the monkey experienced.

To load a .mat file in Python, import the loadmat function from the scipy.io package (from scipy.io import loadmat) and then load the file with cell = loadmat('simdata.mat'). In MATLAB it is simpler: load the file directly with load. The variable spt contains the recorded spike trains. Note that this variable is a cell array: to retrieve the spike trains recorded for the $i$-th stimulus, type s=cell['spt'][0,i] (Python, where indices start at 0) or s=spt{i} (MATLAB, where indices start at 1). Afterwards, s is a matrix of $0$s and $1$s as above, with one row per trial and one column per time point. The variable t contains the respective time points.
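To see why the Python indexing needs the extra `[0]`, here is a minimal sketch that writes a small synthetic file with `scipy.io.savemat` into memory and reads it back with `loadmat`. The frequencies, trial count and random spikes are made up for illustration and are not the contents of simdata.mat.

```python
import io

import numpy as np
from scipy.io import loadmat, savemat

rng = np.random.default_rng(0)
t = np.arange(0, 1001, 5)              # time points in ms (0, 5, ..., 1000)
f1 = np.array([10.0, 20.0, 30.0])      # hypothetical stimulus frequencies (Hz)

# A MATLAB cell array corresponds to a NumPy object array: here one
# (trials x time points) matrix of 0s and 1s per stimulus.
spt = np.empty((1, len(f1)), dtype=object)
for i in range(len(f1)):
    spt[0, i] = (rng.random((10, len(t))) < 0.05).astype(np.uint8)

buf = io.BytesIO()                     # in-memory stand-in for a .mat file
savemat(buf, {'spt': spt, 't': t, 'f1': f1})
buf.seek(0)
cell = loadmat(buf)

# loadmat keeps MATLAB's 2-D shapes: vectors come back as 1 x n rows and the
# cell array as a 1 x n object array, hence the [0] and [0, i] indexing.
print(cell['f1'].shape, cell['t'].shape, cell['spt'].shape)
s = cell['spt'][0, 1]                  # spike trains for the 2nd stimulus
print(s.shape, s.dtype)
```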

In [20]:
# /!\ I'm using Python 3

import numpy as np
import matplotlib.pyplot as plt
import plotly
import plotly.graph_objs as go
from scipy import stats
from scipy.io import loadmat
from IPython.display import display, HTML, Markdown
%matplotlib inline

plotly.offline.init_notebook_mode(connected=True)
# The polling here is to ensure that plotly.js has already been loaded before
# setting display alignment in order to avoid a race condition.
display(HTML(
    '<script>'
        'var waitForPlotly = setInterval( function() {'
            'if( typeof(window.Plotly) !== "undefined" ){'
                'MathJax.Hub.Config({ SVG: { font: "STIX-Web" }, displayAlign: "center" });'
                'MathJax.Hub.Queue(["setRenderer", MathJax.Hub, "SVG"]);'
                'clearInterval(waitForPlotly);'
            '}}, 250 );'
    '</script>'
))

# Colorscales: sample `number_colors` colors from a matplotlib colormap and
# return them as Plotly 'rgb(r, g, b)' strings, either as a plain list or as
# a discrete Plotly colorscale of [position, color] pairs
def colorscale_list(cmap, number_colors, return_rgb_only=False):
    cm = plt.get_cmap(cmap)
    colors = [np.array(cm(i/number_colors)) for i in range(1, number_colors+1)]
    rgb_colors_plotly = []
    rgb_colors_only = []
    for i, c in enumerate(colors):
        # Format the channels explicitly: str(tuple(255*c[:-1])) renders as
        # 'np.float64(...)' under NumPy >= 2, which is not a valid color string
        col = 'rgb({:.0f}, {:.0f}, {:.0f})'.format(*(255*c[:3]))
        rgb_colors_only.append(col)
        rgb_colors_plotly.append([i/number_colors, col])
        rgb_colors_plotly.append([(i+1)/number_colors, col])
    return rgb_colors_only if return_rgb_only else rgb_colors_plotly

# Round to two decimals and drop trailing zeros (vectorized so that it also
# accepts NumPy arrays)
def truncate(f):
    return float(format(f, '.2f').rstrip('0').rstrip('.'))
truncate = np.vectorize(truncate)

(a) Plot all the spike trains for the first stimulus (f1=8.4 Hz) into the same graph.

In [2]:
cell = loadmat('simdata.mat')
s = cell['spt']
cell['t']
Out[2]:
array([[   0,    5,   10,   15,   20,   25,   30,   35,   40,   45,   50,
          55,   60,   65,   70,   75,   80,   85,   90,   95,  100,  105,
         110,  115,  120,  125,  130,  135,  140,  145,  150,  155,  160,
         165,  170,  175,  180,  185,  190,  195,  200,  205,  210,  215,
         220,  225,  230,  235,  240,  245,  250,  255,  260,  265,  270,
         275,  280,  285,  290,  295,  300,  305,  310,  315,  320,  325,
         330,  335,  340,  345,  350,  355,  360,  365,  370,  375,  380,
         385,  390,  395,  400,  405,  410,  415,  420,  425,  430,  435,
         440,  445,  450,  455,  460,  465,  470,  475,  480,  485,  490,
         495,  500,  505,  510,  515,  520,  525,  530,  535,  540,  545,
         550,  555,  560,  565,  570,  575,  580,  585,  590,  595,  600,
         605,  610,  615,  620,  625,  630,  635,  640,  645,  650,  655,
         660,  665,  670,  675,  680,  685,  690,  695,  700,  705,  710,
         715,  720,  725,  730,  735,  740,  745,  750,  755,  760,  765,
         770,  775,  780,  785,  790,  795,  800,  805,  810,  815,  820,
         825,  830,  835,  840,  845,  850,  855,  860,  865,  870,  875,
         880,  885,  890,  895,  900,  905,  910,  915,  920,  925,  930,
         935,  940,  945,  950,  955,  960,  965,  970,  975,  980,  985,
         990,  995, 1000]], dtype=uint16)
In [3]:
# Rastergram for f1=8.4 Hz
traces_spikes = []

for i, spike in enumerate(s[np.where(cell['f1'] == 8.4)][0]):
    traces_spikes.append(
        go.Scatter(
            x = cell['t'][0],
            y = (i+1)*spike,
            mode = 'markers',
            name = 'spike #{}'.format(i+1),
            marker = dict(
                symbol = 'square'
            ),
            showlegend = False
        )
    )

layout = go.Layout(
    title='$\\text{Rastergram of spike trains for } f_1 = 8.4 \\text{ Hz}$',
    xaxis=dict(
        title='Time (ms)'
    ),
    yaxis=dict(
        range=[0.5, i+2],
        title='Trial',
        ticklen= 5,
        gridwidth= 2,
        tick0 = 1,
        
    )
)

plotly.offline.iplot(go.Figure(data=traces_spikes, layout=layout))
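In the cell above, every trial is its own trace and the bins without a spike are plotted at $y = 0$, just outside the visible y-range. A raster plot only needs the (time, trial) coordinates of the $1$ entries, so an alternative is to extract them all at once with `np.nonzero`. The sketch below does this on a small made-up matrix; `raster_points` is a hypothetical helper, not part of the original solution.

```python
import numpy as np

t = np.arange(0, 30, 5)                    # 6 time bins of 5 ms (made-up example)
spikes = np.array([[0, 1, 0, 0, 1, 0],     # trial 1
                   [1, 0, 0, 1, 0, 0],     # trial 2
                   [0, 0, 0, 0, 0, 1]],    # trial 3
                  dtype=np.uint8)

def raster_points(spikes, t, first_row=1):
    """Return (times, rows) of all spikes in a trials x time binary matrix.

    Trial k (0-based) is placed on row first_row + k, matching the 1-based
    trial numbering of the rastergram.
    """
    trial_idx, bin_idx = np.nonzero(spikes)    # row-major order: trial by trial
    return t[bin_idx], trial_idx + first_row

times, rows = raster_points(spikes, t)
print(times)  # [ 5 20  0 15 25]
print(rows)   # [1 1 2 2 3]
```

The two arrays can be passed as `x` and `y` of a single `go.Scatter(mode='markers')` trace, which draws only the spikes; for the stacked raster in (b), `first_row` would be `cum_sum + 1`.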

(b) Plot all the spike trains into the same graph. (Advanced: Use alternating white and grey colors in the background to indicate the different stimuli.)

In [4]:
# Rastergram for all the values of f1
traces_spikes = []
backgrounds = []

colors = colorscale_list('tab10', len(cell['f1'][0])+3, return_rgb_only=True)

cum_sum = 0
for j, fr in enumerate(cell['f1'][0]):
    for i, spike in enumerate(s[0, j]):
        traces_spikes.append(
            go.Scatter(
                x = cell['t'][0],
                y = (cum_sum+i+1)*spike,
                mode = 'markers',
                name = 'f1 = {} Hz'.format(fr),
                marker = dict(
                    symbol = 'square',
                    color = colors[j],
                ),
                showlegend = i==0,
            )
        )
    if j%2 == 1:
        backgrounds.append(
            dict(
                fillcolor='#ccc',
                line=dict(
                    width=0
                ),
                opacity=0.5,
                type='rect',
                x0=0,
                x1=1001,
                # Center the band on the trial rows cum_sum+1, ..., cum_sum+len(s[0, j])
                y0=cum_sum+0.5,
                y1=cum_sum+len(s[0, j])+0.5,
                layer='below'
            ))
    cum_sum += len(s[0, j])
In [13]:
layout = go.Layout(
    title='Rastergram of spike trains for different frequencies',
    xaxis=dict(
        range=[0,1001],
        title='Time (ms)'
    ),
    yaxis=dict(
        range=[0.5, cum_sum+0.5],
        title='Trial',
        ticklen= 5,
        gridwidth= 2,
    ),
    shapes=backgrounds
)

plotly.offline.iplot(go.Figure(data=traces_spikes, layout=layout))
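The stacking in (b) amounts to a cumulative sum over the number of trials per stimulus. A minimal sketch, using hypothetical trial counts, that computes the row offset of each stimulus block and the y-extent of a grey band behind every second block; padding each band by half a row is a choice made here so that the square markers sit inside it. `stimulus_blocks` is a hypothetical helper.

```python
import numpy as np

def stimulus_blocks(n_trials):
    """Row offsets and shaded bands for a raster stacked by stimulus.

    n_trials[j] is the number of trials recorded for stimulus j. Trials of
    stimulus j occupy rows offsets[j] + 1, ..., offsets[j] + n_trials[j].
    Every second stimulus (j = 1, 3, ...) gets a band (y0, y1) spanning its
    rows, padded by half a row on each side.
    """
    n_trials = np.asarray(n_trials)
    offsets = np.concatenate(([0], np.cumsum(n_trials)[:-1]))
    bands = [(float(offsets[j]) + 0.5, float(offsets[j] + n_trials[j]) + 0.5)
             for j in range(1, len(n_trials), 2)]
    return offsets, bands

offsets, bands = stimulus_blocks([10, 10, 12, 8])   # hypothetical trial counts
print(offsets)  # [ 0 10 20 32]
print(bands)    # [(10.5, 20.5), (32.5, 40.5)]
```

Each band maps to one `shapes` entry with `type='rect'`, the `y0`/`y1` above and `layer='below'`; a y-axis range of `[0.5, total + 0.5]`, with `total` the overall number of trials, then shows every trial row in full.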

(c) Count the number of spikes in each trial that fall within the stimulation period ($200$ to $700$ ms). For each stimulus, compute the average spike count and the standard deviation of the spike counts. Plot the tuning curve of the neuron, i.e. its average firing rate (spike count per second) against the stimulus frequency.

Advanced: Additionally, add the standard error of the mean (SEM) as error bars in this plot. Remember that the standard error of the mean is defined as $SEM = \sigma/\sqrt{N}$, where $N$ is the number of samples and $\sigma$ is the standard deviation of the variable you average.
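A minimal sketch of the computation for a single stimulus, on a made-up binary matrix: a boolean mask on the time vector selects the bins in $[200, 700)$ ms, the row sums are the spike counts, and dividing by the window length of 0.5 s converts both the mean count and its SEM into spikes per second. It uses the population standard deviation (`ddof=0`, NumPy's default), as `np.std` does in the solution below; `ddof=1` would give the sample estimate instead. `window_stats` is a hypothetical helper.

```python
import numpy as np

def window_stats(spikes, t, t_init=200, t_fin=700):
    """Spike-count statistics of a trials x time binary matrix in [t_init, t_fin) ms.

    Returns the per-trial counts, the mean firing rate (spikes/s) and the SEM
    of that rate, i.e. sigma / sqrt(N) divided by the window length.
    """
    in_window = (t >= t_init) & (t < t_fin)        # bins inside the window
    counts = spikes[:, in_window].sum(axis=1)      # spikes per trial
    duration = (t_fin - t_init) / 1000             # window length in s
    n = len(counts)
    rate = counts.mean() / duration
    sem = counts.std() / np.sqrt(n) / duration     # population std (ddof=0)
    return counts, rate, sem

t = np.arange(0, 1001, 5)
spikes = np.zeros((4, len(t)), dtype=np.uint8)     # 4 made-up trials
spikes[0, [40, 50, 60]] = 1     # 200, 250, 300 ms: 3 spikes in the window
spikes[1, [10, 45]] = 1         # 50 ms outside, 225 ms inside: 1
spikes[2, [41, 139, 140]] = 1   # 205 and 695 ms inside, 700 ms outside: 2
spikes[3, [100, 120]] = 1       # 500, 600 ms: 2

counts, rate, sem = window_stats(spikes, t)
print(counts)  # [3 1 2 2]
print(rate)    # 4.0
```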

In [15]:
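t_init, t_fin = 200, 700                   # stimulation period (ms)
duration = (t_fin - t_init)/1000           # length of the window (s)

# Column indices of the window: bins t_init, ..., t_fin - 5 ms
t_ms = cell['t'][0]
i_init = np.where(t_ms == t_init)[0][0]
i_fin = np.where(t_ms == t_fin)[0][0]

# Spike count of every trial within the window, one array per stimulus
spike_counts = [np.sum(sj[:, i_init:i_fin], axis=1) for sj in s[0]]

means = list(map(np.mean, spike_counts))
stds = list(map(np.std, spike_counts))

# Firing rate = spike count / duration. The SEM is divided by the same factor
# so that the error bars have the same unit (spikes/s) as the tuning curve.
avg_firing_rate = [mean/duration for mean in means]
SEM = [std/np.sqrt(len(counts))/duration for std, counts in zip(stds, spike_counts)]

print(avg_firing_rate)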

data = [
    go.Scatter(
        x=cell['f1'][0],
        y=avg_firing_rate,
        error_y=dict(
            type='data',
            array=SEM,
            visible=True
        )
    )
]


layout = go.Layout(
    title='Tuning curve of average firing rates with standard error of the mean (SEM) errorbars',
    xaxis=dict(
        title='Vibration Frequency (Hz)'
    ),
    yaxis=dict(
        title='Average firing rate (number of spikes/sec)',
        ticklen= 5,
        gridwidth= 2,
    )
)



data2 = [
    go.Scatter(
        x=cell['f1'][0],
        y=means,
        error_y=dict(
            type='data',
            array=stds,
            visible=True
        )
    )
]

layout2 = go.Layout(
    title='Average spike counts with standard deviation errorbars',
    xaxis=dict(
        title='Vibration Frequency (Hz)'
    ),
    yaxis=dict(
        title='Average spike count within the stimulation period',
        ticklen= 5,
        gridwidth= 2,
    )
)

plotly.offline.iplot(go.Figure(data=data2, layout=layout2))

plotly.offline.iplot(go.Figure(data=data, layout=layout))
[33.0, 38.399999999999999, 47.200000000000003, 59.799999999999997, 71.200000000000003, 79.0, 83.599999999999994, 104.59999999999999]
In [28]:
slope, intercept, r_value, _, std_err = stats.linregress(cell['f1'][0], means)

Markdown('''
## *Linear regression*: Average spike counts

- **Equation**: y = {} x + {}
- **Correlation coefficient**: {}
- **Standard error of the slope**: {}
'''.format(truncate(slope), truncate(intercept), r_value, truncate(std_err)))
Out[28]:

Linear regression: Average spike counts

  • Equation: y = 1.39 x + 3.08
  • Correlation coefficient: 0.9972302875778214
  • Standard error of the slope: 0.04
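For reference, a sketch of the quantities `stats.linregress` reports, written out with NumPy on made-up data: the least-squares slope $b = S_{xy}/S_{xx}$, the intercept $a = \bar y - b\,\bar x$, the Pearson correlation $r = S_{xy}/\sqrt{S_{xx}S_{yy}}$, and the standard error of the slope $\sqrt{(1 - r^2)\,S_{yy} / ((n-2)\,S_{xx})}$, where $S$ denotes sums of products of deviations from the means. `linear_fit` is a hypothetical helper, not SciPy's implementation.

```python
import numpy as np

def linear_fit(x, y):
    """Least-squares line y = slope * x + intercept, Pearson r, and the
    standard error of the slope (residual variance with n - 2 degrees of
    freedom, divided by Sxx)."""
    x, y = np.asarray(x, dtype=float), np.asarray(y, dtype=float)
    n = len(x)
    dx, dy = x - x.mean(), y - y.mean()
    sxx, syy, sxy = dx @ dx, dy @ dy, dx @ dy
    slope = sxy / sxx
    intercept = y.mean() - slope * x.mean()
    r = sxy / np.sqrt(sxx * syy)
    # (1 - r^2) * Syy is the residual sum of squares of the fitted line
    stderr = np.sqrt((1 - r**2) * syy / (n - 2) / sxx)
    return slope, intercept, r, stderr

x = np.array([10.0, 20.0, 30.0, 40.0])   # made-up frequencies (Hz)
y = np.array([15.0, 30.0, 41.0, 58.0])   # made-up mean spike counts
slope, intercept, r, stderr = linear_fit(x, y)
```

Calling `stats.linregress(x, y)` on the same arrays should return the same slope, intercept, `rvalue` and `stderr` up to floating-point rounding.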
In [27]:
slope, intercept, r_value, _, std_err = stats.linregress(cell['f1'][0], avg_firing_rate)

Markdown('''
## *Linear regression*: Average firing rates

- **Equation**: y = {} x + {}
- **Correlation coefficient**: {}
- **Standard error of the slope**: {}
'''.format(truncate(slope), truncate(intercept), r_value, truncate(std_err)))

Since the firing rate is the spike count divided by the 0.5 s stimulation window, this fit is the spike-count fit scaled by a factor of 2: the slope, the intercept and the standard error of the slope double, while the correlation coefficient is unchanged.