# Source code for sssm.wrap_sssm.detection.detection

"""
Sleep Semantic Segmentation: fast and robust detection of 'Spindle',
'Background', 'Arousal', 'K-complex', 'Slow wave', 'Vertex Sharp' and
'Sawtooth' events in single-channel sleep EEG recordings.
- Author: Xiaoyu Bao
- GitHub:
- License:
"""
import logging
import numpy as np
import pandas as pd
from scipy import signal
from mne.filter import filter_data

from scipy.interpolate import interp1d
from scipy.fftpack import next_fast_len

from matplotlib.colors import Normalize
from matplotlib.cm import get_cmap
import matplotlib.pyplot as plt
import ipywidgets as ipy
from wrap_sssm.utils.io import set_log_level
from sssm_core.model import Model as ssm
from yasa.numba import _detrend, _rms
from yasa.spectral import stft_power
from yasa.others import (
    _zerocrossings,
)

# Package-wide logger; every module of the package logs under the "sssm" name.
logger = logging.getLogger(name="sssm")

# Public API of this module.
__all__ = ["sleep_event_detect", "SleepEventDetect"]

def sleep_event_detect(
    data,
    sf=None,
    wave_name=None,
    device="cuda",
    model_name="model.pt",
    model_path=None,
    step=50,
    event_threshold=None,
    overall_threshold=0.5,
    verbose=False,
):
    """Detects sleep events in the provided data.

    Args:
        data (array-like): Input data for sleep event detection. It must
            support ``.astype`` and ``.shape`` (e.g. a NumPy array).
        sf (float): Sampling frequency in Hz. Required despite the ``None``
            default, which is kept only for signature compatibility.
        wave_name (list, optional): List of wave names. Defaults to
            ['Spindle', 'Background', 'Arousal', 'K-complex', 'Slow wave',
            'Vertex Sharp', 'Sawtooth'].
        device (str, optional): Device to run the model on. Defaults to 'cuda'.
        model_name (str, optional): Name of the model file. Defaults to 'model.pt'.
        model_path (str, optional): Path to the model file. Defaults to None.
        step (int, optional): Step size for detection. Defaults to 50.
        event_threshold (dict, optional): Per-event thresholds. Any event
            missing from this dict falls back to the standard threshold
            {'Spindle': 0.95, 'Background': 0.9, 'Arousal': 0.9,
            'K-complex': 0.6, 'Slow wave': 0.6, 'Vertex Sharp': 0.6,
            'Sawtooth': 0.6}. The caller's dict is not modified.
        overall_threshold (float, optional): Overall detection threshold.
            Defaults to 0.5.
        verbose (bool, optional): Verbose output. Defaults to False.

    Returns:
        SleepEventDetect: Object holding the detections (``event_df``) and
        feature / plotting helpers.

    Raises:
        ValueError: If ``sf`` is not provided.
    """
    # Validate before loading the model: without sf the detector cannot
    # build its time axis, and failing late wastes a full model run.
    if sf is None:
        raise ValueError("sf (sampling frequency) must be provided")

    if wave_name is None:
        wave_name = ['Spindle', 'Background', 'Arousal', 'K-complex',
                     'Slow wave', 'Vertex Sharp', 'Sawtooth']

    set_log_level(verbose)

    thresh_standard = {'Spindle': 0.95, 'Background': 0.9, 'Arousal': 0.9,
                       'K-complex': 0.6, 'Slow wave': 0.6,
                       'Vertex Sharp': 0.6, 'Sawtooth': 0.6}
    # Build a NEW dict (user values override the standard ones) instead of
    # filling missing keys into the caller's dict in place.
    event_threshold = {**thresh_standard, **(event_threshold or {})}

    model = ssm.SSM(device=device, model_name=model_name, model_path=model_path)
    return SleepEventDetect(model, wave_name, data, sf, event_threshold,
                            step, overall_threshold)
class SleepEventDetect:
    """Detections for one EEG recording, plus feature and plotting helpers.

    The constructor runs the model once. The detected events are stored in
    ``event_df``: one row per event, with at least the ``Start``, ``End`` and
    ``label`` columns used throughout this class. ``Start``/``End`` are used
    as sample indices into the data, with ``End`` exclusive.
    """

    def __init__(self, model, wave_name, data, sf, thresh, step, overall_threshold):
        """Initializes the SleepEventDetect class and runs detection.

        Args:
            model: The SSM model used for detection. It must provide
                ``predict(data, step=...)`` and
                ``to_pandas(overall_threshold=..., event_threshold=...)``.
            wave_name (list): List of wave names.
            data (ndarray): Input data for detection. The feature and plotting
                methods index it as ``data[0, :]`` / ``data[:, a:b]``, so a
                2-D single-channel array is expected.
            sf (float): Sampling frequency in Hz.
            thresh (dict): Thresholds for each event.
            step (int): Step size for detection.
            overall_threshold (float): Overall detection threshold.

        Raises:
            ValueError: If ``sf`` is None.
        """
        if sf is None:
            raise ValueError("sf (sampling frequency) must be provided")
        self._model = model
        self._wave_name = wave_name
        self._data = data
        self._sf = sf
        self._event_threshold = thresh
        self._step = step
        # The return value of predict() is not used here; to_pandas() is
        # presumably reading state kept by the model from predict()
        # (TODO confirm against sssm_core).
        self._model.predict(data.astype(np.float16), step=step)
        self.event_df = self._model.to_pandas(
            overall_threshold=overall_threshold,
            event_threshold=self._event_threshold,
        )
        self._times = np.arange(self._data.shape[-1]) / self._sf

    def _as_event_list(self, event_type):
        """Normalizes an ``event_type`` argument (None, str or iterable) to a list."""
        if event_type is None:
            return list(self._wave_name)
        if isinstance(event_type, str):
            return [event_type]
        return list(event_type)

    def summary(self, event=None):
        """Provides a summary of detected events.

        Args:
            event (str, optional): The specific event name to summarize.
                Defaults to None.

        Returns:
            dict or pd.DataFrame: If event is None, prints and returns a dict
            mapping every wave name to its feature DataFrame. Otherwise
            returns the feature DataFrame of that event only.

        Raises:
            ValueError: If the specified event is not in the wave names.
        """
        if event is None:
            features = self.calculate_feature()
            for name, frame in features.items():
                print(name)
                print(frame)
            return features
        if event not in self._wave_name:
            raise ValueError(f'sleep event {event} is not in wave names {self._wave_name}')
        # Only compute the features that were actually requested.
        return self.calculate_feature([event])[event]

    def _get_event_df(self, wave_names):
        """Filters the event dataframe by wave names.

        Args:
            wave_names (list): List of wave names to filter by.

        Returns:
            DataFrame: An independent copy of the rows of ``event_df`` whose
            label is in ``wave_names``. It is a copy so callers can add
            columns without touching ``event_df``.
        """
        return self.event_df[self.event_df['label'].isin(wave_names)].copy()

    def plot_average(self, event_type=None, figsize=(6, 4.5), **kwargs):
        """Plots the average waveform of specified event types.

        Each event contributes a 3-second window of the data centred on the
        midpoint of its ``Start`` and ``End``. Windows that would run past
        either end of the recording are skipped, so all windows share one
        length and can be averaged.

        Args:
            event_type (str or list, optional): Event type(s) to plot.
                Defaults to None (all wave names).
            figsize (tuple, optional): Size of ONE subplot. The figure width
                is multiplied by the number of event types. Defaults to (6, 4.5).
            **kwargs: Forwarded to ``Axes.plot``.

        Returns:
            list or ndarray: Axes of the average sleep event plot (a
            one-element list when a single event type is plotted).

        Raises:
            ValueError: If ``event_type`` is empty or not a subset of the
                wave names.
        """
        event_type = self._as_event_list(event_type)
        # TODO: the subplot layout should depend on the wave characteristics
        # (left for later in the original code).
        if not event_type or not set(event_type).issubset(self._wave_name):
            raise ValueError(f'event_type {event_type} is not correct')

        n_samples = self._data.shape[-1]
        half_win = 1.5 * self._sf
        win_len = int(2 * half_win)

        averages = []
        for name in event_type:
            events = self._get_event_df([name])
            segments = []
            for start, end in zip(events['Start'], events['End']):
                first = int((start + end) / 2 - half_win)
                if first < 0 or first + win_len > n_samples:
                    continue  # window would leave the recording: unequal length
                segments.append(self._data[:, first:first + win_len])
            averages.append(np.mean(segments, axis=0) if segments else None)

        figsize = (len(event_type) * figsize[0], figsize[1])
        _, axs = plt.subplots(1, len(event_type), figsize=figsize)
        if len(event_type) == 1:
            axs = [axs]

        for ax, name, average in zip(axs, event_type, averages):
            if average is None:
                ax.plot()  # no usable event of this type: empty subplot
            else:
                times = np.arange(average.shape[-1]) / self._sf
                ax.plot(times, np.squeeze(average), **kwargs)
            ax.set_title(f"Average {name}")
            ax.set_xlabel("Time (sec)")
            ax.set_ylabel("Amplitude (uV)")
        return axs

    def _get_mask(self):
        """Generates a mask for every event type.

        Returns:
            dict: Maps each wave name to a list of ``range(Start, End)``
            objects, one per detected event of that type.
        """
        mask_dict = {}
        for event in self._wave_name:
            events = self._get_event_df([event])
            # int() because range() rejects float bounds.
            mask_dict[event] = [range(int(start), int(end))
                                for start, end in zip(events['Start'], events['End'])]
        return mask_dict

    def _plot_events(self, ax, event_type, cmap, norm, xrng=None):
        """Plots the detected events on the given axis.

        Args:
            ax (Axes): Matplotlib axis to plot on.
            event_type (list): List of event types to plot.
            cmap: Colormap to use for plotting.
            norm: Normalization for colormap.
            xrng (array-like, optional): Absolute sample indices of the window
                to draw. Only event samples inside it are plotted. Defaults to
                None (whole events).
        """
        mask = self._get_mask()
        data = np.squeeze(self._data)
        for index, name in enumerate(event_type):
            color = cmap(norm(index))
            for segment in mask[name]:
                idx = np.asarray(segment, dtype=int)
                if xrng is not None:
                    # Both are absolute sample indices: keep their overlap.
                    idx = np.intersect1d(idx, np.asarray(xrng, dtype=int))
                    if idx.size == 0:
                        continue
                ax.plot(self._times[idx], data[idx], color=color, label=f'{name}')

    def plot_detection(self, event_type=None, figsize=(12, 4), cmap='Spectral'):
        """Plots the detection of specified event types.

        The whole signal is drawn in black with each event overlaid in colour.
        Sliders then choose the visible window (epoch), its length and the
        y-range.

        Args:
            event_type (str or list, optional): Event type(s) to plot.
                Defaults to None (all wave names).
            figsize (tuple, optional): Figure size for the plot. Defaults to (12, 4).
            cmap (str, optional): Colormap for plotting. Defaults to 'Spectral'.

        Returns:
            interactive: Interactive plot object.

        Raises:
            ValueError: If ``event_type`` is not a subset of the wave names.
        """
        event_type = self._as_event_list(event_type)
        if not set(event_type).issubset(self._wave_name):
            raise ValueError(f'event_type {event_type} is not correct')

        # pyplot.get_cmap still exists in matplotlib >= 3.9, unlike cm.get_cmap.
        cmap = plt.get_cmap(cmap)
        win_size = 10
        n_samples = self._data.shape[-1]
        duration = n_samples / self._sf
        data = np.squeeze(self._data)
        norm = Normalize(vmin=0, vmax=len(event_type) - 1)

        fig, ax = plt.subplots(figsize=figsize)
        (signal_line,) = ax.plot(self._times, data, "k", lw=1, label='Original Data')

        # Events are drawn once over the full recording. update() only moves
        # the view, so it never needs to re-draw (and accumulate) them.
        mask = self._get_mask()
        handles, labels = [], []
        for index, name in enumerate(event_type):
            color = cmap(norm(index))
            # Empty proxy line: gives each event type exactly one legend entry.
            (proxy,) = ax.plot([], [], color=color, label=f'{name}')
            handles.append(proxy)
            labels.append(f'{name}')
            for segment in mask[name]:
                idx = np.asarray(segment, dtype=int)
                ax.plot(self._times[idx], data[idx], color=color)

        ax.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, -0.2),
                  ncol=len(event_type))
        ax.set_xlabel("Time (seconds)")
        ax.set_ylabel("Amplitude (uV)")
        fig.canvas.header_visible = False
        fig.tight_layout()

        layout = ipy.Layout(width="50%", justify_content="center", align_items="center")
        sl_ep = ipy.IntSlider(
            min=0, max=int(duration / win_size), step=1, value=0,
            layout=layout, description="Epoch:",
        )
        sl_amp = ipy.IntSlider(
            min=25, max=500, step=25, value=150, layout=layout,
            orientation="horizontal", description="Amplitude:",
        )
        dd_win = ipy.Dropdown(
            options=[1, 5, 10, 30, 60], value=win_size, description="Window size:",
        )

        def update(epoch, amplitude, win_size):
            """Show window number ``epoch`` of ``win_size`` seconds."""
            sl_ep.max = int(duration / win_size)
            xlim = [epoch * win_size, (epoch + 1) * win_size]
            xrng = np.arange(int(xlim[0] * self._sf),
                             int(min(xlim[1] * self._sf, n_samples)))
            signal_line.set_data(self._times[xrng], data[xrng])
            ax.set_xlim(xlim)
            ax.set_ylim([-amplitude, amplitude])

        return ipy.interact(update, epoch=sl_ep, amplitude=sl_amp, win_size=dd_win)

    def calculate_feature(self, event_type=None, **kwargs):
        """Calculates features for the specified event types.

        Args:
            event_type (str or list, optional): Event type(s) to calculate
                features for. Defaults to None (all wave names).
            **kwargs: Forwarded to the per-event feature methods
                (e.g. ``freq_sp``, ``freq_broad``, ``freq_sw``, ``amp_neg``,
                ``amp_pos``).

        Returns:
            dict: Calculated features (a DataFrame) for each event type.
        """
        feature_dict = {}
        for name in self._as_event_list(event_type):
            if name == "Spindle":
                feature_dict[name] = self.calculate_feature_spindle(**kwargs)
            elif name == "Slow wave":
                feature_dict[name] = self.calculate_feature_slow_wave(**kwargs)
            else:
                feature_dict[name] = self.calculate_feature_other(name, **kwargs)
        return feature_dict

    def calculate_feature_spindle(self, **kwargs):
        """Calculate features of sleep spindles from EEG data.

        Args:
            freq_sp (tuple): Sigma band for spindle features (default: (12, 15)).
            freq_broad (tuple): Broad band for filtering (default: (1, 30)).

        Returns:
            pandas.DataFrame: The spindle rows of ``event_df`` extended with
            Peak (s), Duration (s), Amplitude, RMS, AbsPower, RelPower,
            Frequency, Oscillations and Symmetry. Events shorter than two
            samples keep NaN features. Peak/Symmetry stay NaN when no peak
            is found.

        Notes:
            This function is based on the original implementation provided by
            https://github.com/raphaelvallat/yasa
        """
        freq_sp = kwargs.get('freq_sp', (12, 15))
        freq_broad = kwargs.get('freq_broad', (1, 30))
        data = self._data
        sf = self._sf
        n_samples = data.shape[-1]

        data_df = self._get_event_df(['Spindle'])
        feature_cols = ["Peak", "Duration", "Amplitude", "RMS", "AbsPower",
                        "RelPower", "Frequency", "Oscillations", "Symmetry"]
        for col in feature_cols:
            data_df[col] = np.nan
        if data_df.empty:
            return data_df  # nothing to measure: skip the costly filtering

        # Relative sigma power, interpolated back to one value per sample.
        data_broad = filter_data(data, sf, freq_broad[0], freq_broad[1],
                                 method="fir", verbose=0)
        f, t, Sxx = stft_power(data_broad[0, :], sf, window=2, step=0.2,
                               band=freq_broad, interp=False, norm=True)
        idx_sigma = np.logical_and(f >= freq_sp[0], f <= freq_sp[1])
        func = interp1d(t, Sxx[idx_sigma].sum(0), kind="cubic",
                        bounds_error=False, fill_value=0)
        rel_pow = func(np.arange(n_samples) / sf)

        # Hilbert transform of the sigma-filtered signal.
        data_sigma = filter_data(
            data, sf, freq_sp[0], freq_sp[1],
            l_trans_bandwidth=1.5, h_trans_bandwidth=1.5, method="fir", verbose=0,
        )
        analytic = signal.hilbert(data_sigma, N=next_fast_len(n_samples))[:, :n_samples]
        inst_phase = np.angle(analytic)
        inst_pow = np.square(np.abs(analytic))
        inst_freq = sf / (2 * np.pi) * np.diff(inst_phase, axis=-1)

        distance = 60 * sf / 1000  # at least 60 ms between two oscillation peaks

        # Only the spindle rows (the original iterated over every event).
        for index, start, end in zip(data_df.index, data_df['Start'], data_df['End']):
            start, end = int(start), int(end)
            if end - start < 2:
                continue  # too short to detrend or search for peaks

            segment_broad = data_broad[0, start:end]
            sp_x = np.arange(segment_broad.size, dtype=np.float64)
            sp_det = _detrend(sp_x, segment_broad)

            sp_inst_freq = inst_freq[0, start:end]
            sp_inst_pow = inst_pow[0, start:end]
            peaks, peaks_params = signal.find_peaks(
                sp_det, distance=distance, prominence=(None, None)
            )

            values = {
                "Duration": (end - start) / sf,
                "Amplitude": np.ptp(sp_det),        # peak-to-peak amplitude
                "RMS": _rms(sp_det),                # root mean square
                "AbsPower": np.median(np.log10(sp_inst_pow[sp_inst_pow > 0])),
                "RelPower": np.median(rel_pow[start:end]),
                "Frequency": np.median(sp_inst_freq[sp_inst_freq > 0]),
                "Oscillations": len(peaks),
            }
            if peaks.size:
                # Most prominent peak: its time (s) and relative position.
                pk = peaks[peaks_params["prominences"].argmax()]
                values["Peak"] = (start + pk) / sf
                values["Symmetry"] = pk / sp_det.size

            for col, value in values.items():
                data_df.loc[index, col] = value
        return data_df

    def _slow_wave_metrics(self, sig, times, zero_crossings, start, end, amp_neg, amp_pos):
        """Return the slow-wave metrics of ``sig[start:end]`` as a dict of lists, or None.

        ``sig`` is the filtered 1-D signal, ``times`` its time axis in
        seconds and ``zero_crossings`` the zero-crossing sample indices of the
        whole ``sig``. Returns None when the segment lacks a negative or a
        positive peak, or when there are no zero crossings.
        """
        segment = sig[start:end]
        # find_peaks indices are relative to the segment: shift them to
        # absolute indices of ``sig``, which everything below uses.
        idx_neg_peaks = signal.find_peaks(-1 * segment, height=amp_neg)[0] + start
        idx_pos_peaks = signal.find_peaks(segment, height=amp_pos)[0] + start
        if len(idx_neg_peaks) == 0 or len(idx_pos_peaks) == 0 or len(zero_crossings) == 0:
            return None

        # Make sure the last detected peak is positive; if not, append a fake
        # positive peak one sample after the last negative one.
        if idx_pos_peaks[-1] < idx_neg_peaks[-1]:
            idx_pos_peaks = np.append(idx_pos_peaks, idx_neg_peaks[-1] + 1)

        # For each negative peak, the closest following positive peak.
        pk_sorted = np.searchsorted(idx_pos_peaks, idx_neg_peaks)
        closest_pos_peaks = idx_pos_peaks[pk_sorted] - idx_neg_peaks
        closest_pos_peaks = closest_pos_peaks[np.nonzero(closest_pos_peaks)]
        idx_pos_peaks = idx_neg_peaks + closest_pos_peaks

        sw_ptp = np.abs(sig[idx_neg_peaks]) + sig[idx_pos_peaks]

        # Ensure a zero-crossing exists after the last detected peak.
        last_peak = max(idx_pos_peaks[-1], idx_neg_peaks[-1])
        if zero_crossings[-1] < last_peak:
            zero_crossings = np.append(zero_crossings, last_peak)

        # Distances from the peaks to the surrounding zero-crossings.
        neg_sorted = np.searchsorted(zero_crossings, idx_neg_peaks)
        previous_neg_zc = zero_crossings[neg_sorted - 1] - idx_neg_peaks
        following_neg_zc = zero_crossings[neg_sorted] - idx_neg_peaks
        pos_sorted = np.searchsorted(zero_crossings, idx_pos_peaks)
        following_pos_zc = zero_crossings[pos_sorted] - idx_pos_peaks

        sw_start = times[idx_neg_peaks + previous_neg_zc]
        sw_end = times[idx_pos_peaks + following_pos_zc]
        # Rounded to avoid floating point noise (e.g. 1.9000000002).
        sw_dur = (sw_end - sw_start).round(4)
        sw_midcrossing = times[idx_neg_peaks + following_neg_zc]
        sw_idx_neg = times[idx_neg_peaks]
        sw_idx_pos = times[idx_pos_peaks]
        # Slope between the negative peak and the mid-crossing.
        sw_slope = sw_ptp / (sw_midcrossing - sw_idx_neg)

        return {
            "NegPeak": sw_idx_neg.tolist(),
            "MidCrossing": sw_midcrossing.tolist(),
            "PosPeak": sw_idx_pos.tolist(),
            "ValNegPeak": sig[idx_neg_peaks].tolist(),
            "ValPosPeak": sig[idx_pos_peaks].tolist(),
            "PTP": sw_ptp.tolist(),
            "Slope": sw_slope.tolist(),
            "Frequency": (1 / sw_dur).tolist(),
        }

    def calculate_feature_slow_wave(self, **kwargs):
        """Calculate features of slow waves from EEG data.

        Args:
            freq_sw (tuple): Band for slow wave filtering (default: (0.3, 1.5)).
            amp_neg (tuple): Amplitude range for negative peaks (default: (40, 200)).
            amp_pos (tuple): Amplitude range for positive peaks (default: (10, 150)).

        Returns:
            pandas.DataFrame: The slow-wave rows of ``event_df`` (re-indexed
            from 0) with the extra columns NegPeak, MidCrossing, PosPeak,
            ValNegPeak, ValPosPeak, PTP, Slope and Frequency. Each cell holds
            a list with one entry per slow wave found in that event. Cells are
            None for events where no negative/positive peak pair was found.

        Notes:
            This function is based on the original implementation provided by
            https://github.com/raphaelvallat/yasa
        """
        freq_sw = kwargs.get('freq_sw', (0.3, 1.5))
        amp_neg = kwargs.get('amp_neg', (40, 200))
        amp_pos = kwargs.get('amp_pos', (10, 150))

        columns = ["NegPeak", "MidCrossing", "PosPeak", "ValNegPeak",
                   "ValPosPeak", "PTP", "Slope", "Frequency"]
        sw_events = self._get_event_df(['Slow wave']).reset_index(drop=True)
        rows = [dict.fromkeys(columns) for _ in range(len(sw_events))]

        if not sw_events.empty:
            data_filt = filter_data(
                self._data, self._sf, freq_sw[0], freq_sw[1], method="fir",
                verbose=0, l_trans_bandwidth=0.2, h_trans_bandwidth=0.2,
            )
            sig = data_filt[0, :]
            # Loop-invariant: compute the zero-crossings of the whole signal once.
            zero_crossings = _zerocrossings(sig)
            for row, (start, end) in enumerate(zip(sw_events['Start'], sw_events['End'])):
                metrics = self._slow_wave_metrics(sig, self._times, zero_crossings,
                                                  int(start), int(end), amp_neg, amp_pos)
                if metrics is None:
                    logger.warning("no neg_peaks or pos_peaks")
                    continue
                rows[row] = metrics

        # Side by side: row i of the features belongs to slow-wave event i.
        features = pd.DataFrame(rows, columns=columns)
        return pd.concat([sw_events, features], axis=1)

    def calculate_feature_other(self, i_event_type, **kwargs):
        """Retrieve the rows of another event type from the detections.

        Args:
            i_event_type (str): The event type whose rows are returned.
            **kwargs: Accepted (and ignored) so ``calculate_feature`` can
                forward its keyword arguments uniformly.

        Returns:
            pandas.DataFrame: The rows of ``event_df`` for that event type.
        """
        return self._get_event_df([i_event_type])
if __name__ == "__main__":
    import mne

    # Load the EDF recording, band-pass it and keep the Fpz-Cz channel in uV.
    raw = mne.io.read_raw_edf("./SC4001E0-PSG.edf", preload=True)
    raw.filter(0.1, 40)
    eeg = raw.get_data(["EEG Fpz-Cz"], units="uV")
    print(eeg.shape)
    sf = 100  # sampling frequency passed to the detector (Hz)

    # Detection on the first 50 000 samples only.
    detector = sleep_event_detect(eeg[:, :50000], sf)

    # Feature / summary API.
    detector.calculate_feature()
    detector.summary()
    print(detector.summary())

    # Plotting API.
    axes = detector.plot_average()
    plt.show()
    detector.plot_detection()