SSSM core example#

[ ]:
pip install sssm --upgrade
[53]:
import numpy as np
from sssm import ssm
import mne
import matplotlib.pyplot as plt

load data#

[37]:
# Location of the .cnt recording (machine-specific absolute path; adjust locally).
cnt_path = r"C:\Users\bkxcy\workspace\eeg_data_process\data\health\sleep\baoxiaoyu20211201s.cnt"
raw = mne.io.read_raw_cnt(cnt_path)
# Read the samples into memory before any processing is applied.
raw.load_data()

pre-process#

[37]:
# Band-pass filter the recording between 0.7 Hz and 45 Hz.
raw = raw.filter(l_freq=0.7, h_freq=45)
# Resample to 100 Hz, export the signal in microvolts and store it as float16.
resampled = raw.resample(sfreq=100)
data_uv = resampled.get_data(units="uV")
data = data_uv.astype(np.float16)
Reading 0 ... 5604509  =      0.000 ... 22418.036 secs...
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.7 - 45 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.70
- Lower transition bandwidth: 0.70 Hz (-6 dB cutoff frequency: 0.35 Hz)
- Upper passband edge: 45.00 Hz
- Upper transition bandwidth: 11.25 Hz (-6 dB cutoff frequency: 50.62 Hz)
- Filter length: 1179 samples (4.716 s)

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.3s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.4s remaining:    0.0s
[Parallel(n_jobs=1)]: Done  36 out of  36 | elapsed:    4.7s finished

predict plot#

[42]:
# Default SSM model.
model = ssm.SSM()
# Row 17 of `data`, sliced as 17:18 so the input stays two-dimensional.
single_channel = data[17:18]
# NOTE(review): `step` presumably is the stride between predictions in samples — confirm against sssm.
ret = model.predict(single_channel, step=300)
model.plot_predictions()
../_images/example_sssm_example_8_0.png

input 1 epoch(30s) data#

[138]:
# Default model; pass model_name='ckp_last.pt' to ssm.SSM to use that checkpoint instead.
model = ssm.SSM()

# One 30 s epoch at 100 Hz is 3000 samples.
samples_per_epoch = 3000
epoch_id = 200
epoch_start = epoch_id * samples_per_epoch
input_data = data[17:18, epoch_start:epoch_start + samples_per_epoch]

# Wide figure showing the raw trace of the selected epoch.
plt.figure(figsize=(50, 4))
plt.plot(input_data.T)
plt.xlim(0, samples_per_epoch)
plt.show()

# Predict on the single epoch and draw the model's prediction plot.
ret = model.predict(input_data, step=1)
model.plot_predictions()
../_images/example_sssm_example_10_0.png
../_images/example_sssm_example_10_1.png
[135]:
# Model built from the 'ckp_last.pt' checkpoint; call ssm.SSM() with no arguments for the default.
model = ssm.SSM(model_name='ckp_last.pt')

# One 30 s epoch at 100 Hz is 3000 samples.
samples_per_epoch = 3000
epoch_id = 200
epoch_start = epoch_id * samples_per_epoch
input_data = data[17:18, epoch_start:epoch_start + samples_per_epoch]

# Wide figure showing the raw trace of the selected epoch.
plt.figure(figsize=(50, 4))
plt.plot(input_data.T)
plt.xlim(0, samples_per_epoch)
plt.show()

# Predict on the single epoch and draw the model's prediction plot.
ret = model.predict(input_data, step=1)
model.plot_predictions()
../_images/example_sssm_example_11_0.png
../_images/example_sssm_example_11_1.png

input full night data#

[139]:
# Model built from the 'ckp_last.pt' checkpoint.
model = ssm.SSM(model_name='ckp_last.pt')
# Whole recording of row 17, kept two-dimensional by the 17:18 slice.
full_night = data[17:18]
ret = model.predict(full_night, step=300)
model.plot_predictions()
../_images/example_sssm_example_13_0.png

get results#

[140]:
# Every event type uses the same threshold of 0.8.
event_labels = [
    'Spindle',
    'Background',
    'Arousal',
    'K-complex',
    'Slow wave',
    'Vertex Sharp',
    'Sawtooth',
]
thresholds = dict.fromkeys(event_labels, 0.8)
# Export the predicted events as a DataFrame.
df = model.to_pandas(event_threshold=thresholds)
df
[140]:
Start End Duration label predict_proba epoch_id
1 600 900 300 Arousal 0.996692 0
2 900 1800 900 Background 0.982475 0
4 2100 3300 1200 Background 0.956050 0
6 3600 4500 900 Arousal 0.982866 0
7 4500 4800 300 Background 0.995384 0
... ... ... ... ... ... ...
3296 2234100 2234400 300 Arousal 0.998787 0
3297 2234400 2235900 1500 Background 0.955171 0
3298 2235900 2240700 4800 Arousal 0.974922 0
3299 2240700 2241300 600 Background 0.932963 0
3300 2241300 2241804 504 Arousal 0.951562 0

2266 rows × 6 columns

[141]:
def _events_by_confidence(events, label):
    """Return the rows of `events` carrying `label`, ordered by descending predict_proba.

    The result gets a fresh 0..n-1 index.
    """
    selected = events[events['label'] == label]
    ordered = selected.sort_values('predict_proba', ascending=False)
    return ordered.reset_index(drop=True)


spindles_df = _events_by_confidence(df, 'Spindle')
kc_df = _events_by_confidence(df, 'K-complex')
[142]:
# Display the spindle events table built in the previous cell.
spindles_df
[142]:
Start End Duration label predict_proba epoch_id
0 1168800 1169100 300 Spindle 0.999875 0
1 1933200 1933500 300 Spindle 0.999850 0
2 1992000 1992300 300 Spindle 0.999827 0
3 1007400 1007700 300 Spindle 0.999812 0
4 1062600 1062900 300 Spindle 0.999789 0
... ... ... ... ... ... ...
248 1856400 1856700 300 Spindle 0.802168 0
249 1246200 1246500 300 Spindle 0.802033 0
250 1549200 1549500 300 Spindle 0.801724 0
251 860700 861000 300 Spindle 0.801277 0
252 1446000 1446600 600 Spindle 0.801162 0

253 rows × 6 columns

[143]:
# Display the K-complex events table built earlier.
kc_df
[143]:
Start End Duration label predict_proba epoch_id
0 891600 891900 300 K-complex 0.941446 0
1 695700 696000 300 K-complex 0.936204 0
2 1420200 1420500 300 K-complex 0.933121 0
3 957900 958200 300 K-complex 0.923686 0
4 1834500 1834800 300 K-complex 0.915197 0
5 797400 797700 300 K-complex 0.905449 0
6 814200 814500 300 K-complex 0.887404 0
7 1919400 1919700 300 K-complex 0.886306 0
8 735900 736200 300 K-complex 0.885708 0
9 1841400 1841700 300 K-complex 0.882073 0
10 665400 665700 300 K-complex 0.880695 0
11 719400 719700 300 K-complex 0.880499 0
12 1851300 1851600 300 K-complex 0.878774 0
13 1845000 1845300 300 K-complex 0.877886 0
14 684600 684900 300 K-complex 0.873030 0
15 1908300 1908600 300 K-complex 0.871325 0
16 963900 964200 300 K-complex 0.871283 0
17 1458900 1459200 300 K-complex 0.871187 0
18 644400 644700 300 K-complex 0.865622 0
19 786300 786600 300 K-complex 0.856225 0
20 1802100 1802400 300 K-complex 0.854385 0
21 874800 875100 300 K-complex 0.853656 0
22 735300 735600 300 K-complex 0.843922 0
23 783900 784200 300 K-complex 0.842278 0
24 1133700 1134000 300 K-complex 0.836957 0
25 1793100 1793400 300 K-complex 0.836822 0
26 671700 672000 300 K-complex 0.829499 0
27 669000 669300 300 K-complex 0.826817 0
28 1126800 1127100 300 K-complex 0.825165 0
29 842100 842400 300 K-complex 0.824742 0
30 687600 687900 300 K-complex 0.822328 0
31 1243200 1243500 300 K-complex 0.819598 0
32 1305300 1305600 300 K-complex 0.818472 0
33 816000 816300 300 K-complex 0.804663 0
34 633000 633300 300 K-complex 0.802887 0

get predicted data#

[144]:
# Row 0 of spindles_df; its Start/End are used as sample indices into `data`.
_id = 0
spindle_event = spindles_df.loc[_id]
plt.plot(data[17, spindle_event.Start:spindle_event.End])
[144]:
[<matplotlib.lines.Line2D at 0x18ce10fbb80>]
../_images/example_sssm_example_20_1.png
[145]:
# Row 0 of kc_df; its Start/End are used as sample indices into `data`.
_id = 0
kc_event = kc_df.loc[_id]
plt.plot(data[17, kc_event.Start:kc_event.End])
[145]:
[<matplotlib.lines.Line2D at 0x18ce1462760>]
../_images/example_sssm_example_21_1.png

get feature#

[146]:
# Move the model's feature tensor to the CPU, then convert it to a NumPy array.
feature_on_cpu = model.feature.cpu()
feature = feature_on_cpu.numpy()

# Axis order: n_epoch, n_sample, m, n
feature.shape
[146]:
(1, 7472, 128, 15)
[148]:
# Flatten each sample's (m, n) feature map of the first epoch into one vector.
# The array is re-derived from the model instead of overwriting `feature` from
# itself, so re-running this cell always yields the same (n_sample, n_feature)
# array rather than reshaping an already-flattened one.
epoch_feature = model.feature.cpu().numpy()[0]
feature = epoch_feature.reshape(epoch_feature.shape[0], -1)
feature.shape  # n_sample, n_feature
[148]:
(7472, 1920)
[149]:
# t-SNE is used for the 2-D embedding (a PCA import was tried earlier and dropped).
from sklearn.manifold import TSNE
import pandas as pd
import seaborn as sns

# `init` and `learning_rate` are pinned to the values this notebook was run
# with: sklearn warns that both defaults change in 1.2, which would otherwise
# alter the embedding on newer versions.
# The variable keeps its historical name `pca` because the following cell
# calls `pca.fit_transform`; it holds a TSNE estimator.
pca = TSNE(n_components=2, random_state=0, init='random', learning_rate=200.0)
[150]:
# Embed the flattened features in 2-D. Note: `pca` holds a TSNE estimator, not a PCA one.
transformed_feature = pca.fit_transform(X=feature)
C:\Users\bkxcy\anaconda3\envs\torch_cuda\lib\site-packages\sklearn\manifold\_t_sne.py:795: FutureWarning: The default initialization in TSNE will change from 'random' to 'pca' in 1.2.
  warnings.warn(
C:\Users\bkxcy\anaconda3\envs\torch_cuda\lib\site-packages\sklearn\manifold\_t_sne.py:805: FutureWarning: The default learning rate in TSNE will change from 200.0 to 'auto' in 1.2.
  warnings.warn(
[151]:
# Split the 2-D embedding into its two coordinate columns.
x_component = transformed_feature[:, 0]
y_component = transformed_feature[:, 1]

# Short event codes, looked up by the values in `ret`.
# NOTE(review): assumes `ret` iterates as one integer class index per embedded sample — confirm.
event_codes = ['SS', 'KC', 'SW', 'SAW', 'VSW', 'BG', 'MA']
event_types = [event_codes[class_idx] for class_idx in ret]

feature_df = pd.DataFrame({
    'x': x_component,
    'y': y_component,
    'Event Type': event_types,
})
[152]:
# Scatter plot of the 2-D embedding (computed with t-SNE, despite the `pca`
# variable name used upstream), one point per sample, coloured by event type.
plt.figure(figsize=(10, 10), dpi=100)
sns.scatterplot(data=feature_df, x='x', y='y', hue='Event Type', palette='bright', s=10, alpha=0.8)
plt.title('Sleep events class map')
# Legend shows the short codes from the 'Event Type' column.
# NOTE(review): these presumably stand for Spindle, K-complex, Slow-wave,
# Sawtooth, Vertex Sharp, Background and Arousal — confirm the mapping.
plt.legend(markerscale=1)
plt.show()
../_images/example_sssm_example_28_0.png