Source code for span.tdt.spikedataframe

#!/usr/bin/env python

# spikedataframe.py ---

# Copyright (C) 2012 Phillip Cloud <cpcloud@gmail.com>

# Author: Phillip Cloud <cpcloud@gmail.com>

# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 3
# of the License, or (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.


"""
Example
-------
>>> import span
>>> tank = span.TdtTank('basename/of/some/tank/file')
>>> sp = tank.spik
>>> assert isinstance(sp, span.SpikeDataFrame)
"""
import abc
import functools
import numbers
import types
import warnings

import numpy as np
import pandas as pd
from pandas import Series, DataFrame, DatetimeIndex
from span.utils import samples_per_ms, clear_refrac, LOCAL_TZ
from span.xcorr import xcorr as _xcorr
import six


class SpikeDataFrameBase(DataFrame):
    """Abstract interface for frames that hold sampled spike-train data.

    Subclasses provide the channel count, the sample count, the sampling
    rate and the thresholding / refractory-clearing operations. ``period``
    is derived here from ``fs``.
    """
    __metaclass__ = abc.ABCMeta

    def __init__(self, *args, **kwargs):
        # No extra state of our own: forward everything up the MRO.
        parent = super(SpikeDataFrameBase, self)
        parent.__init__(*args, **kwargs)

    @abc.abstractproperty
    def nchannels(self):
        """Number of channels; supplied by subclasses."""

    @abc.abstractproperty
    def nsamples(self):
        """Number of samples; supplied by subclasses."""

    @abc.abstractproperty
    def fs(self):
        """Sampling rate; supplied by subclasses."""

    @property
    def period(self):
        """Return the period in *nanoseconds*.

        Computed as the reciprocal of ``fs`` scaled by 1e9. The nanosecond
        unit presumes ``fs`` is expressed per second (Hz) -- confirm in
        subclasses.
        """
        seconds_per_sample = 1.0 / self.fs
        return seconds_per_sample * 1e9

    @abc.abstractmethod
    def threshold(self, *args, **kwargs):
        """Threshold the data; supplied by subclasses."""

    @abc.abstractmethod
    def clear_refrac(self, *args, **kwargs):
        """Clear refractory periods; supplied by subclasses."""


class SpikeDataFrame(SpikeDataFrameBase):
    """Class encapsulating a Pandas DataFrame with extensions for analyzing
    spike train data.

    See the pandas DataFrame documentation for constructor details.
    """

    def __init__(self, *args, **kwargs):
        self.super.__init__(*args, **kwargs)
        self.isclean = False

    @property
    def nchannels(self):
        """Number of channels (columns)."""
        return self.shape[1]

    @property
    def nsamples(self):
        """Number of samples (rows)."""
        return self.shape[0]

    @property
    def fs(self):
        """Sampling rate computed as ``1e9 / self.index.freq.n``.

        NOTE(review): this presumes ``self.index.freq`` is a nanosecond-based
        offset, so that ``freq.n`` is the sample spacing in ns -- confirm.
        """
        return 1e9 / self.index.freq.n

    def threshold(self, threshes):
        """Threshold spikes.

        Parameters
        ----------
        threshes : scalar or array_like
            Either a single threshold used for every channel, or one
            threshold per channel.

        Raises
        ------
        AssertionError
            * If `threshes` is not a scalar or a vector of length equal to
              the number of channels.

        Returns
        -------
        threshed : array_like
            Boolean frame. When every threshold is negative, a sample is
            marked if it is *below* its channel's threshold; otherwise it is
            marked if it is *above* it.
        """
        # np.asarray lets lists/tuples (array_like) work as well as ndarrays;
        # both ``.size`` and the ``< 0`` comparison below need an array.
        threshes = np.asarray(threshes)

        if threshes.ndim == 0:
            threshes = np.repeat(threshes, self.nchannels)

        if threshes.size != self.nchannels:
            raise AssertionError('number of threshold values must be 1 '
                                 '(same for all channels) or {0}, different '
                                 'threshold for each '
                                 'channel'.format(self.nchannels))

        cmpf = self.lt if np.all(threshes < 0) else self.gt
        thr = threshes.item() if threshes.size == 1 else threshes

        # axis=1 aligns the per-channel Series against the columns.
        return cmpf(Series(thr, index=self.columns), axis=1)

    @property
    def _constructor(self):
        # Make pandas operations that build new frames return this subclass.
        return self.__class__

    def clear_refrac(self, ms=2, inplace=False):
        """Remove spikes from the refractory period of all channels.

        Parameters
        ----------
        ms : real, optional, default 2
            The length of the refractory period in milliseconds.
        inplace : bool, optional, default False
            If True, modify this object and return None; otherwise modify and
            return a copy.

        Raises
        ------
        TypeError
            * If `ms` is not an instance of ``numbers.Real``.
        ValueError
            * If `ms` is less than 0.

        Returns
        -------
        r : SpikeDataFrame or None
            The thresholded and refractory-period-cleared array of booleans
            indicating the sample point at which a spike was above threshold.
            None when `inplace` is True.

        Notes
        -----
        This method DOES NOT modify the object inplace by default.
        """
        if not isinstance(ms, numbers.Real):
            raise TypeError('ms must be a real number')

        if ms < 0:
            raise ValueError('refractory period must be a nonnegative real '
                             'number')

        if not ms:
            # A zero-length refractory period leaves the data untouched.
            if not inplace:
                return self
            return

        ms_fs = samples_per_ms(self.fs, ms)

        # Work on a copy unless in-place modification was explicitly requested.
        df = self if inplace else self.copy()

        # NOTE(review): ``clear_refrac`` writes into the array it is given.
        # This relies on ``df.values`` being a writable view of the frame's
        # data (a single-dtype frame) -- confirm for the pandas version in use.
        clear_refrac(df.values, ms_fs)

        if not inplace:
            return df

    def prune_spikes(self, remove_null=True):
        """Reduce a cleared spike array to the minimum necessary to bin and
        compute correlations.

        Each column is filtered by its own boolean mask; the resulting
        columns are realigned on the union of their indices.

        Parameters
        ----------
        remove_null : bool, optional, default True
            Also exclude null entries when building each column's mask.

        Returns
        -------
        b : DataFrame
            Boolean frame, sorted by index, with entries missing after
            realignment treated as False.
        """
        # Presumably the columns are boolean spike indicators -- the mask is
        # used directly for boolean indexing.
        if remove_null:
            def mask(col):
                return col & col.notnull()
        else:
            def mask(col):
                return col

        pruned = [col[mask(col)] for _, col in self.iteritems()]
        reduc = pd.concat(pruned, axis=1)
        df = self._constructor(reduc.values, reduc.index, self.columns)
        df.sort_index(axis=0, inplace=True)
        df.fillna(0, inplace=True)
        return df.astype(bool)

    def bin(self, bin_size, how='sum', *args, **kwargs):
        """Resample into bins of `bin_size`, aggregating with `how`.

        Thin wrapper around ``DataFrame.resample``; extra positional and
        keyword arguments are forwarded unchanged.
        """
        return self.resample(bin_size, how=how, *args, **kwargs)

    @classmethod
    def xcorr(cls, binned, maxlags=None, detrend=None, scale_type=None,
              sortlevel='shank i', nan_auto=False):
        """Compute the cross correlation of binned data.

        Parameters
        ----------
        binned : array_like
            Data of which to compute the cross-correlation.
        maxlags : int, optional
            Maximum number of lags to return from the cross correlation.
            Defaults to None and computes the full cross correlation.
        detrend : callable or None, optional
            Callable used to detrend. Defaults to ``None``
        scale_type : str, optional
            Method of scaling. Defaults to ``None``.
        sortlevel : str, optional
            How to sort the index of the returned cross-correlation.
            Defaults to "shank i" so the xcorrs are ordered by their
            physical ordering.
        nan_auto : bool, optional
            If ``True`` then the autocorrelation values will be ``NaN``.
            Defaults to ``False``.

        Raises
        ------
        AssertionError
            * If detrend is not a callable object
            * If scale_type is not a string or is not None
        ValueError
            * If sortlevel is not ``None`` and is not a string or a number in
              the list of level names or level indices.

        Returns
        -------
        xc : DataFrame
            The cross correlation of all the columns of the data, indexed by
            lags and columned by channel pair.

        See Also
        --------
        span.xcorr.xcorr
            General cross correlation function.
        SpikeDataFrame.clear_refrac
            Clear the refractory period of a channel or array of channels.
        """
        # Explicit raises instead of ``assert`` so validation survives
        # ``python -O``; the exception type callers see is unchanged.
        if not (detrend is None or callable(detrend)):
            raise AssertionError('detrend must be a callable class or '
                                 'function or None')

        # ``type(None)`` rather than ``types.NoneType``: the latter does not
        # exist on Python 3 before 3.10.
        if not isinstance(scale_type, six.string_types + (type(None),)):
            raise AssertionError('scale_type must be a string or None')

        xc = _xcorr(binned, maxlags=maxlags, detrend=detrend,
                    scale_type=scale_type)

        if nan_auto:
            # HACK for channel names: NaN out the lag-0 entries whose
            # 'channel i' and 'channel j' column levels are equal.
            xc0 = xc.ix[0]
            names = xc0.index.names
            chi_ind = names.index('channel i')
            chj_ind = names.index('channel j')

            selector = lambda x: x[chi_ind] == x[chj_ind]
            xc.ix[0, xc0.select(selector).index] = np.nan

        with warnings.catch_warnings():
            warnings.simplefilter('ignore', FutureWarning)
            xc.sortlevel(level=sortlevel, axis=1, inplace=True)

        return xc

    def interval_jitter(self, window=100, unit='ms'):
        """Basic jitter samples by some window in units of `unit`.

        Each timestamp is moved to a random point in the `window`-wide
        interval (in `unit`) that contains it, and the frame is re-sorted.

        Parameters
        ----------
        window : int, optional
            The size of the jitter window.
        unit : str, optional
            The time units of the jitter window (a numpy datetime unit code).

        Returns
        -------
        df : SpikeDataFrame
        """
        new_index = self._interval_jitter_reindex(self.index, window, unit)
        df = self._constructor(self.values, new_index, self.columns)
        df.sort_index(inplace=True)
        return df

    def jitter_channel(self, orig_index, orig_indices, index_where, channel,
                       window, unit='ms'):
        """Jitter the timestamps of one channel and return a sorted Series.

        The jittered version of `index_where` is written into
        `orig_index` at positions `orig_indices`; `channel`'s values are then
        re-indexed by it.

        NOTE(review): this writes into ``orig_index.values`` in place, so the
        caller's index object is mutated -- confirm that is intended.
        """
        new_index = self._interval_jitter_reindex(index_where, window, unit)
        orig_index.values[orig_indices] = new_index.values
        s = Series(channel.values, index=orig_index, name=channel.name)
        return s.sort_index()

    def _interval_jitter_reindex(self, index, window, unit):
        """Return a DatetimeIndex with each timestamp of `index` moved to a
        uniformly random point of the `window`-`unit` interval containing it.
        """
        values = index.values
        dt = values.dtype  # original datetime dtype, restored at the end

        dt_unit = 'datetime64[%s]' % unit
        td_unit = 'timedelta64[%s]' % unit

        # Express timestamps as int64 counts of `unit`, so that the window
        # arithmetic happens in the same unit as `window`. (Dividing raw
        # nanosecond counts by a window given in e.g. ms is wrong; int64
        # avoids overflow where the platform ``int`` is 32-bit.)
        ticks = values.astype(dt_unit).astype(np.int64)

        # start of the window-length window containing each timestamp
        start_ticks = (ticks // window) * window
        start = start_ticks.astype(np.int64).astype(dt_unit)

        # shift from beginning of jitter window by U * window, U ~ U[0, 1)
        rt = np.random.rand(values.size) * window
        rand_time = rt.astype(np.int64).astype(td_unit)

        shifted = (start + rand_time).astype(dt)
        return DatetimeIndex(shifted, tz=LOCAL_TZ)

    # Reimplement methods whose results pandas' DataFrame does not construct
    # as this subclass.

    @property
    def super(self):
        return super(SpikeDataFrame, self)

    def _call_super_method(self, method_name, *args, **kwargs):
        method = getattr(self.super, method_name)
        result = method(*args, **kwargs)
        # In-place variants return None; don't wrap that into an empty frame.
        if result is None:
            return None
        return self._constructor(result)

    def dot(self, *args, **kwargs):
        return self._call_super_method('dot', *args, **kwargs)

    def sort_index(self, *args, **kwargs):
        return self._call_super_method('sort_index', *args, **kwargs)


spike_xcorr = SpikeDataFrame.xcorr

Project Versions

This Page