Пример использования тензорного потока и кераса

Новый мир

Интерфейсы мозг-компьютер (BCI) - это новая технология, которая изменит определение того, что значит быть человеком: роботизированные протезы, цифровое сознание и телепатия - лишь небольшая часть того, что могут выполнять эти устройства. BCI станут частью следующей эпохи технологической революции. В этой статье показан пример того, как глубокое обучение можно применить к общедоступным данным для классификации нейронных колебаний. Нервные колебания, или мозговые волны, присутствуют в головном мозге из-за возбуждающих (ВПСП) и тормозных постсинаптических потенциалов (IPSP). Вариации мозговых волн при выполнении различных задач позволяют исследователям и ученым определять закономерности и использовать их для определенных целей. Цель этого эксперимента - использовать машинное обучение для прогнозирования использования рук: правая рука против левой.

Что такое нейронные сети?

Нейронные сети - это часть глубокого обучения, глубокое обучение - это сегмент машинного обучения, а машинное обучение - это отрасль искусственного интеллекта. Эта справочная информация может показаться сложной, но ее не так важно запоминать. Важно то, что искусственные нейронные сети были предназначены для моделирования биологических нейронных сетей в нашем мозгу. Точно так же, как у последнего есть дендриты для получения информации и аксоны для отправки информации, у первого есть узлы, которые получают информацию с одного конца и распространяют ее на другой.

Биологические нейронные сети также утяжеляют связи за счет синаптической пластичности, которая изменяет (в первую очередь) концентрацию глутаматных рецепторов на синаптической мембране. Его искусственный аналог манипулирует цифровыми весами для оптимизации для наиболее точного результата. Серия скрытых слоев между входным и выходным слоями будет составлять модель. Каждая связь между каждым узлом имеет веса и смещения, которые помогают повысить точность результатов.

Функции активации - еще одна важная часть искусственных нейронных сетей. Эти функции определяют, передавать ли вывод следующему узлу или, по крайней мере, в какой форме его следует передавать. Функции активации добавляют к сети элемент нелинейности, так что модель не является линейной функцией. Они также обеспечивают поддержание данных в определенном диапазоне. Например, функция активации softmax будет выводить определенные значения для разных значений ввода (0 или 1).

Создание нейронной сети

Создание такой модели не так сложно, как вы думаете, особенно с использованием различных библиотек машинного обучения на Python. Sklearn невероятно популярен благодаря своим встроенным функциям и моделям, но не тот, который большинство людей использовали бы для глубокого обучения. PyTorch - еще один популярный вариант, созданный Facebook. Однако я буду использовать вариант TensorFlow. TensorFlow (хотя установка на Mac M1 является абсолютной проблемой) - отличный вариант для глубокого обучения. Возможно, потому, что он был разработан Google, TensorFlow прост в изучении и использовании. Многие онлайн-учебники используют TensorFlow, и существует так много онлайн-форумов, на которых можно найти ответы на вопросы по этому поводу, что вам будет трудно не найти возможный ответ на вашу проблему.

Ниже приведен пример модели, которую я использовал, чтобы предсказать, есть ли у человека сердечное заболевание.

"""Building the model"""
model = Sequential([
    Dense(units=16, input_shape=(13, ), activation='relu'),
    Dense(units=32, activation='relu'),
    Dense(units=32, activation='relu'),
    Dense(units=2, activation='softmax')
])

model.summary()

model.compile(optimizer=Adam(learning_rate=0.0001), loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(x=train_features, y=train_labels, batch_size=10, epochs=50, shuffle=True, verbose=2)

"""Predictions"""
predictions = model.predict(x=test_features, batch_size=10, verbose=0)
rounded_predictions = np.argmax(predictions, axis=-1)
cm = confusion_matrix(y_true=test_labels, y_pred=rounded_predictions)

Эта модель была построена с использованием одиннадцати строк кода. Чтобы создать точную модель, не потребовалось много времени. Модель имеет один входной слой с 16 единицами, два скрытых слоя с 32 единицами и один выходной слой с 2 ​​единицами для каждого предсказанного класса.

Следует упомянуть важную особенность: функции активации: relu означает выпрямленные линейные блоки и является базовой функцией активации для любой нейронной сети. Простым примером этой функции активации является любая линейная функция от x с областью, большей или равной нулю. Диапазон зависит от смещения (точки пересечения по оси Y) функции. Эта функция хороша по многим причинам, но две из них заключаются в том, что она учитывает нелинейности и взаимодействие функций. Последнее лучше всего пояснить на примере: возраст играет большую роль для пациентов с сердечными заболеваниями. Однако вес и рост также будут влиять на результат. Для человека низкого роста идеальный вес будет отличаться от идеального веса для человека высокого роста. Если у вас избыточный вес в молодом возрасте по сравнению с пожилым, ваш прогноз будет отличаться. Таким образом, функции имеют эффект взаимодействия.

Функция активации softmax - это стандартная функция финальной активации нейронных сетей. Эта функция нормализует выходные данные по результатам, чтобы наилучшим образом выбрать прогнозируемый класс.

Ранее в коде я разделил свой набор данных на наборы для обучения и тестирования. Если вы тестируете (прогнозируете) на тех же данных, на которых тренировались, проверка точности вашей модели будет нарушена. Тестировать нужно на отдельном наборе данных. Это также поможет вам убедиться, что вы не переусердствуете с данными. (Переобучение данных относится к сценарию, в котором ваша модель настолько хорошо обучена на своих практических данных, что не может точно предсказать какие-либо другие данные.) Если результаты, которые мы получаем из нашего набора тестов, также достаточно хороши, мы можем быть уверены в том, что созданная модель.

Ниже приведены результаты обучения последних десяти эпох этой модели:

Epoch 40/50
242/242 - 0s - loss: 0.2496 - accuracy: 0.9215
Epoch 41/50
242/242 - 0s - loss: 0.2457 - accuracy: 0.9215
Epoch 42/50
242/242 - 0s - loss: 0.2509 - accuracy: 0.9215
Epoch 43/50
242/242 - 0s - loss: 0.2457 - accuracy: 0.9132
Epoch 44/50
242/242 - 0s - loss: 0.2343 - accuracy: 0.9256
Epoch 45/50
242/242 - 0s - loss: 0.2350 - accuracy: 0.9256
Epoch 46/50
242/242 - 0s - loss: 0.2352 - accuracy: 0.9298
Epoch 47/50
242/242 - 0s - loss: 0.2271 - accuracy: 0.9256
Epoch 48/50
242/242 - 0s - loss: 0.2266 - accuracy: 0.9298
Epoch 49/50
242/242 - 0s - loss: 0.2193 - accuracy: 0.9339
Epoch 50/50
242/242 - 0s - loss: 0.2203 - accuracy: 0.9339
Confusion matrix, without normalization
Process finished with exit code 0

Точность модели составила 93,39% с потерей 0,2203. Когда убытки начинают увеличиваться, это свидетельствует о переобучении. Зная это, пятьдесят эпох, вероятно, было хорошим местом, чтобы остановиться.

То, что точность валидации не выше, чем точность обучения, хороший знак. (Если бы было наоборот, это обычно было бы красным флажком.) Точность обучения и тестирования модели дала положительные результаты.

Эксперимент

1. Сбор данных

Первым шагом этого эксперимента был сбор данных. У меня нет ЭЭГ-гарнитуры для сбора данных от себя, поэтому мне нужно было найти кое-что в Интернете. К счастью, PhysioNet смогла предоставить данные, использованные в исследовании, опубликованном в 2004 году. Эти данные имеют 109 субъектов с 14 запусками на каждого предмета. Я случайным образом выбрал испытуемый номер 4. Для разных задач использовались разные прогоны: в забегах 3, 4, 7, 8, 11 и 12 испытуемые визуализировали движение левой или правой рукой, а в забегах 5, 6, 9, 10, 13, и 14 испытуемых представили движение в руках или ногах.

"""Finding patient data"""
# T0 corresponds to rest
# T1 corresponds to left fist
# T2 corresponds to right fist
raw3 = mne.io.read_raw_edf('S004R03.edf', preload=True)
raw4 = mne.io.read_raw_edf('S004R04.edf', preload=True)
raw7 = mne.io.read_raw_edf('S004R07.edf', preload=True)
raw8 = mne.io.read_raw_edf('S004R08.edf', preload=True)
raw11 = mne.io.read_raw_edf('S004R11.edf', preload=True)
raw12 = mne.io.read_raw_edf('S004R12.edf', preload=True)

Мне удалось скачать тематические файлы и открыть их с помощью MNE-Python.

2. Объединяйте и очищайте файлы

Эти файлы сами по себе не содержат большого количества данных ЭЭГ, поэтому я объединил их, а затем отфильтровал. Фильтрация необходима для большинства ЭЭГ, поскольку большинство файлов пациентов будут содержать большие артефакты, которые искажают анализ.

raws = mne.concatenate_raws([raw3, raw4, raw7, raw8, raw11, raw12])
# raws.plot()

# strip channel names of "." characters
raws.rename_channels(lambda x: x.strip('.'))

# Apply band-pass filter
raws.filter(7., 30., method='iir')
# raws.plot()

3. Создание эпох

Мы хотим создавать эпохи для анализа с помощью ЭЭГ. Данные этих эпох станут функциями, переданными в модель.

events = events_from_annotations(raws)

picks = pick_types(raws.info, meg=False, eeg=True, stim=False, eog=False,
                   exclude='bads')

epochs = Epochs(raws, events[0], event_id, tmin, tmax, proj=True, picks=picks, baseline=None, preload=True)

4. Создание элементов и меток.

Функции и метки сопоставимы с X и Y. Компонент Y - это просто метки T0, T1, T2, которые пришли с данными. Т0 бесполезен (пациент в состоянии покоя), а Т1, Т2 заменены на 0 и 1 соответственно. Однако X-компоненту нужно немного доработать. В настоящее время у нас есть временная информация по эпохам. Это данные, относящиеся ко времени, когда это происходит. Используя алгоритм CSP от MNE, мы можем извлекать пространственную информацию как объекты.

Y = epochs.events[:, -1] - 2

epochs_data = epochs.get_data()
csp = CSP(n_components=4, log=True, reg=None)

X = csp.fit_transform(epochs_data, Y)

5. Создание и подгонка модели.

Этот процесс очень похож на тот, который использовался в качестве примера ранее. Я создал четыре плотно связанных слоя, а затем пропустил обучающие элементы и метки через модель.

model = Sequential([
    Dense(units=8, input_shape=(4,), activation='relu'),
    Dense(units=16, activation='relu'),
    Dense(units=16, activation='relu'),
    Dense(units=2, activation='softmax')
])

model.summary()

train_features, test_features, train_labels, test_labels = train_test_split(X, Y, train_size=0.80, test_size=0.20)

model.compile(optimizer=Adam(learning_rate=0.001), loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(x=train_features, y=train_labels, validation_data=(test_features, test_labels), batch_size=10, epochs=75, shuffle=True, verbose=2)

6. Прогнозирование данных испытаний.

Этот процесс аналогичен описанному выше.

"""Predictions"""
predictions = model.predict(x=test_features, batch_size=10, verbose=0)
rounded_predictions = np.argmax(predictions, axis=-1)
cm = confusion_matrix(y_true=test_labels, y_pred=rounded_predictions)

"""Confusion Matrix"""
def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion Matrix', cmap=plt.cm.Blues):
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print("Confusion matrix, without normalization")

    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, cm[i, j],
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")
    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.show()

cm_plot_labels = ['Left Hand', 'Right Hand']
plot_confusion_matrix(cm=cm, classes=cm_plot_labels, title="Confusion Matrix")
plt.show()

Результат выглядит следующим образом:

/Applications/anaconda3/envs/tf/bin/python /Applications/anaconda3/envs/tf/EEGBCI.py
Extracting EDF parameters from /Applications/anaconda3/envs/tf/S004R03.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19679  =      0.000 ...   122.994 secs...
Extracting EDF parameters from /Applications/anaconda3/envs/tf/S004R04.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19679  =      0.000 ...   122.994 secs...
Extracting EDF parameters from /Applications/anaconda3/envs/tf/S004R07.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19679  =      0.000 ...   122.994 secs...
Extracting EDF parameters from /Applications/anaconda3/envs/tf/S004R08.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19679  =      0.000 ...   122.994 secs...
Extracting EDF parameters from /Applications/anaconda3/envs/tf/S004R11.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19679  =      0.000 ...   122.994 secs...
Extracting EDF parameters from /Applications/anaconda3/envs/tf/S004R12.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19679  =      0.000 ...   122.994 secs...
Filtering raw data in 6 contiguous segments
Setting up band-pass filter from 7 - 30 Hz
IIR filter parameters
---------------------
Butterworth bandpass zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 16 (effective, after forward-backward)
- Cutoffs at 7.00, 30.00 Hz: -6.02, -6.02 dB
Used Annotations descriptions: ['T0', 'T1', 'T2']
Not setting metadata
Not setting metadata
90 matching events found
No baseline correction applied
0 projection items activated
Loading data for 90 events and 801 original time points ...
0 bad epochs dropped
Computing rank from data with rank=None
    Using tolerance 0.00013 (2.2e-16 eps * 64 dim * 9.3e+09  max singular value)
    Estimated rank (mag): 64
    MAG: rank 64 computed from 64 data channels with 0 projectors
Reducing data rank from 64 -> 64
Estimating covariance using EMPIRICAL
Done.
Computing rank from data with rank=None
    Using tolerance 0.00013 (2.2e-16 eps * 64 dim * 9.3e+09  max singular value)
    Estimated rank (mag): 64
    MAG: rank 64 computed from 64 data channels with 0 projectors
Reducing data rank from 64 -> 64
Estimating covariance using EMPIRICAL
Done.
2021-08-19 16:44:50.608926: I tensorflow/core/platform/cpu_feature_guard.cc:145] This TensorFlow binary is optimized with Intel(R) MKL-DNN to use the following CPU instructions in performance critical operations:  SSE4.1 SSE4.2
To enable them in non-MKL-DNN operations, rebuild TensorFlow with the appropriate compiler flags.
2021-08-19 16:44:50.610742: I tensorflow/core/common_runtime/process_util.cc:115] Creating new thread pool with default inter op setting: 8. Tune using inter_op_parallelism_threads for best performance.
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 8)                 40        
_________________________________________________________________
dense_1 (Dense)              (None, 16)                144       
_________________________________________________________________
dense_2 (Dense)              (None, 16)                272       
_________________________________________________________________
dense_3 (Dense)              (None, 2)                 34        
=================================================================
Total params: 490
Trainable params: 490
Non-trainable params: 0
_________________________________________________________________
Train on 72 samples, validate on 18 samples
Epoch 1/75
72/72 - 1s - loss: 0.6999 - accuracy: 0.4583 - val_loss: 0.7009 - val_accuracy: 0.5000
Epoch 2/75
72/72 - 0s - loss: 0.6887 - accuracy: 0.4722 - val_loss: 0.6888 - val_accuracy: 0.5556
Epoch 3/75
72/72 - 0s - loss: 0.6801 - accuracy: 0.5139 - val_loss: 0.6803 - val_accuracy: 0.5556
Epoch 4/75
72/72 - 0s - loss: 0.6742 - accuracy: 0.5694 - val_loss: 0.6692 - val_accuracy: 0.7222
Epoch 5/75
72/72 - 0s - loss: 0.6658 - accuracy: 0.6111 - val_loss: 0.6617 - val_accuracy: 0.7778
Epoch 6/75
72/72 - 0s - loss: 0.6600 - accuracy: 0.6528 - val_loss: 0.6550 - val_accuracy: 0.7778
Epoch 7/75
72/72 - 0s - loss: 0.6543 - accuracy: 0.6667 - val_loss: 0.6475 - val_accuracy: 0.8333
Epoch 8/75
72/72 - 0s - loss: 0.6475 - accuracy: 0.6806 - val_loss: 0.6397 - val_accuracy: 0.8333
Epoch 9/75
72/72 - 0s - loss: 0.6411 - accuracy: 0.7500 - val_loss: 0.6315 - val_accuracy: 0.7778
Epoch 10/75
72/72 - 0s - loss: 0.6333 - accuracy: 0.7639 - val_loss: 0.6217 - val_accuracy: 0.7778
Epoch 11/75
72/72 - 0s - loss: 0.6245 - accuracy: 0.7361 - val_loss: 0.6101 - val_accuracy: 0.7778
Epoch 12/75
72/72 - 0s - loss: 0.6141 - accuracy: 0.7778 - val_loss: 0.5976 - val_accuracy: 0.7778
Epoch 13/75
72/72 - 0s - loss: 0.6010 - accuracy: 0.8056 - val_loss: 0.5843 - val_accuracy: 0.7778
Epoch 14/75
72/72 - 0s - loss: 0.5875 - accuracy: 0.8333 - val_loss: 0.5687 - val_accuracy: 0.8333
Epoch 15/75
72/72 - 0s - loss: 0.5735 - accuracy: 0.8472 - val_loss: 0.5511 - val_accuracy: 0.8333
Epoch 16/75
72/72 - 0s - loss: 0.5563 - accuracy: 0.8611 - val_loss: 0.5322 - val_accuracy: 0.8333
Epoch 17/75
72/72 - 0s - loss: 0.5400 - accuracy: 0.8750 - val_loss: 0.5114 - val_accuracy: 0.8889
Epoch 18/75
72/72 - 0s - loss: 0.5221 - accuracy: 0.8750 - val_loss: 0.4916 - val_accuracy: 0.8889
Epoch 19/75
72/72 - 0s - loss: 0.5065 - accuracy: 0.8889 - val_loss: 0.4721 - val_accuracy: 0.8889
Epoch 20/75
72/72 - 0s - loss: 0.4900 - accuracy: 0.8750 - val_loss: 0.4512 - val_accuracy: 0.8889
Epoch 21/75
72/72 - 0s - loss: 0.4773 - accuracy: 0.8611 - val_loss: 0.4321 - val_accuracy: 0.9444
Epoch 22/75
72/72 - 0s - loss: 0.4588 - accuracy: 0.9028 - val_loss: 0.4109 - val_accuracy: 0.9444
Epoch 23/75
72/72 - 0s - loss: 0.4490 - accuracy: 0.8889 - val_loss: 0.3907 - val_accuracy: 0.9444
Epoch 24/75
72/72 - 0s - loss: 0.4278 - accuracy: 0.9028 - val_loss: 0.3750 - val_accuracy: 0.9444
Epoch 25/75
72/72 - 0s - loss: 0.4179 - accuracy: 0.8889 - val_loss: 0.3530 - val_accuracy: 0.9444
Epoch 26/75
72/72 - 0s - loss: 0.4014 - accuracy: 0.9167 - val_loss: 0.3346 - val_accuracy: 0.9444
Epoch 27/75
72/72 - 0s - loss: 0.3901 - accuracy: 0.9028 - val_loss: 0.3214 - val_accuracy: 1.0000
Epoch 28/75
72/72 - 0s - loss: 0.3799 - accuracy: 0.8889 - val_loss: 0.3024 - val_accuracy: 1.0000
Epoch 29/75
72/72 - 0s - loss: 0.3651 - accuracy: 0.9028 - val_loss: 0.2878 - val_accuracy: 0.9444
Epoch 30/75
72/72 - 0s - loss: 0.3569 - accuracy: 0.9167 - val_loss: 0.2733 - val_accuracy: 0.9444
Epoch 31/75
72/72 - 0s - loss: 0.3464 - accuracy: 0.9167 - val_loss: 0.2605 - val_accuracy: 1.0000
Epoch 32/75
72/72 - 0s - loss: 0.3412 - accuracy: 0.9028 - val_loss: 0.2508 - val_accuracy: 1.0000
Epoch 33/75
72/72 - 0s - loss: 0.3320 - accuracy: 0.9028 - val_loss: 0.2402 - val_accuracy: 0.9444
Epoch 34/75
72/72 - 0s - loss: 0.3250 - accuracy: 0.9167 - val_loss: 0.2285 - val_accuracy: 1.0000
Epoch 35/75
72/72 - 0s - loss: 0.3188 - accuracy: 0.9028 - val_loss: 0.2196 - val_accuracy: 1.0000
Epoch 36/75
72/72 - 0s - loss: 0.3119 - accuracy: 0.9167 - val_loss: 0.2093 - val_accuracy: 1.0000
Epoch 37/75
72/72 - 0s - loss: 0.3040 - accuracy: 0.9167 - val_loss: 0.2006 - val_accuracy: 1.0000
Epoch 38/75
72/72 - 0s - loss: 0.3024 - accuracy: 0.9167 - val_loss: 0.1945 - val_accuracy: 0.9444
Epoch 39/75
72/72 - 0s - loss: 0.2962 - accuracy: 0.9167 - val_loss: 0.1835 - val_accuracy: 1.0000
Epoch 40/75
72/72 - 0s - loss: 0.2908 - accuracy: 0.9167 - val_loss: 0.1767 - val_accuracy: 1.0000
Epoch 41/75
72/72 - 0s - loss: 0.2878 - accuracy: 0.9167 - val_loss: 0.1717 - val_accuracy: 1.0000
Epoch 42/75
72/72 - 0s - loss: 0.2839 - accuracy: 0.9167 - val_loss: 0.1648 - val_accuracy: 1.0000
Epoch 43/75
72/72 - 0s - loss: 0.2819 - accuracy: 0.9306 - val_loss: 0.1588 - val_accuracy: 1.0000
Epoch 44/75
72/72 - 0s - loss: 0.2795 - accuracy: 0.9167 - val_loss: 0.1600 - val_accuracy: 1.0000
Epoch 45/75
72/72 - 0s - loss: 0.2895 - accuracy: 0.8889 - val_loss: 0.1636 - val_accuracy: 0.9444
Epoch 46/75
72/72 - 0s - loss: 0.2839 - accuracy: 0.9167 - val_loss: 0.1453 - val_accuracy: 1.0000
Epoch 47/75
72/72 - 0s - loss: 0.2719 - accuracy: 0.9306 - val_loss: 0.1406 - val_accuracy: 1.0000
Epoch 48/75
72/72 - 0s - loss: 0.2679 - accuracy: 0.9306 - val_loss: 0.1375 - val_accuracy: 1.0000
Epoch 49/75
72/72 - 0s - loss: 0.2674 - accuracy: 0.9306 - val_loss: 0.1336 - val_accuracy: 1.0000
Epoch 50/75
72/72 - 0s - loss: 0.2741 - accuracy: 0.9028 - val_loss: 0.1349 - val_accuracy: 1.0000
Epoch 51/75
72/72 - 0s - loss: 0.2609 - accuracy: 0.9306 - val_loss: 0.1285 - val_accuracy: 1.0000
Epoch 52/75
72/72 - 0s - loss: 0.2667 - accuracy: 0.9028 - val_loss: 0.1261 - val_accuracy: 1.0000
Epoch 53/75
72/72 - 0s - loss: 0.2603 - accuracy: 0.9167 - val_loss: 0.1220 - val_accuracy: 1.0000
Epoch 54/75
72/72 - 0s - loss: 0.2563 - accuracy: 0.9306 - val_loss: 0.1216 - val_accuracy: 1.0000
Epoch 55/75
72/72 - 0s - loss: 0.2539 - accuracy: 0.9306 - val_loss: 0.1168 - val_accuracy: 1.0000
Epoch 56/75
72/72 - 0s - loss: 0.2544 - accuracy: 0.9306 - val_loss: 0.1143 - val_accuracy: 1.0000
Epoch 57/75
72/72 - 0s - loss: 0.2542 - accuracy: 0.9167 - val_loss: 0.1117 - val_accuracy: 1.0000
Epoch 58/75
72/72 - 0s - loss: 0.2544 - accuracy: 0.9167 - val_loss: 0.1088 - val_accuracy: 1.0000
Epoch 59/75
72/72 - 0s - loss: 0.2497 - accuracy: 0.9306 - val_loss: 0.1071 - val_accuracy: 1.0000
Epoch 60/75
72/72 - 0s - loss: 0.2452 - accuracy: 0.9306 - val_loss: 0.1085 - val_accuracy: 1.0000
Epoch 61/75
72/72 - 0s - loss: 0.2578 - accuracy: 0.8889 - val_loss: 0.1142 - val_accuracy: 1.0000
Epoch 62/75
72/72 - 0s - loss: 0.2608 - accuracy: 0.8889 - val_loss: 0.1109 - val_accuracy: 1.0000
Epoch 63/75
72/72 - 0s - loss: 0.2617 - accuracy: 0.9028 - val_loss: 0.0988 - val_accuracy: 1.0000
Epoch 64/75
72/72 - 0s - loss: 0.2458 - accuracy: 0.9306 - val_loss: 0.0993 - val_accuracy: 1.0000
Epoch 65/75
72/72 - 0s - loss: 0.2463 - accuracy: 0.9167 - val_loss: 0.0968 - val_accuracy: 1.0000
Epoch 66/75
72/72 - 0s - loss: 0.2417 - accuracy: 0.9306 - val_loss: 0.0943 - val_accuracy: 1.0000
Epoch 67/75
72/72 - 0s - loss: 0.2418 - accuracy: 0.9306 - val_loss: 0.0934 - val_accuracy: 1.0000
Epoch 68/75
72/72 - 0s - loss: 0.2410 - accuracy: 0.9306 - val_loss: 0.0922 - val_accuracy: 1.0000
Epoch 69/75
72/72 - 0s - loss: 0.2406 - accuracy: 0.9306 - val_loss: 0.0966 - val_accuracy: 1.0000
Epoch 70/75
72/72 - 0s - loss: 0.2416 - accuracy: 0.9028 - val_loss: 0.0922 - val_accuracy: 1.0000
Epoch 71/75
72/72 - 0s - loss: 0.2402 - accuracy: 0.9306 - val_loss: 0.0875 - val_accuracy: 1.0000
Epoch 72/75
72/72 - 0s - loss: 0.2380 - accuracy: 0.9306 - val_loss: 0.0862 - val_accuracy: 1.0000
Epoch 73/75
72/72 - 0s - loss: 0.2390 - accuracy: 0.9167 - val_loss: 0.0849 - val_accuracy: 1.0000
Epoch 74/75
72/72 - 0s - loss: 0.2350 - accuracy: 0.9306 - val_loss: 0.0834 - val_accuracy: 1.0000
Epoch 75/75
72/72 - 0s - loss: 0.2315 - accuracy: 0.9306 - val_loss: 0.0863 - val_accuracy: 1.0000
Confusion matrix, without normalization
Process finished with exit code 0

Возможно, я установил слишком высокие эпохи, поскольку потери начинают расти, но это быстрое и легкое решение, которое приведет к аналогичным результатам. Точность обучающей выборки составила 93,06%, а точность тестовой выборки - 100%. Обычно это большой красный флаг. Огромный красный флаг. Однако вы заметите, что набор для тестирования довольно невелик. Рассмотрены не сотни дел - всего восемнадцать. Поэтому неудивительно, что точность валидации настолько высока.

Используя больший размер тестирования, можно сгенерировать матрицу неточностей, подобную приведенной ниже.

Другой вариант

Вместо загрузки всех необходимых данных на ваш компьютер, MNE предлагает альтернативу именно для этого набора данных. Созданная мной функция find_patient подключится к MNE и загрузит набор данных прямо в каталог, в котором вы запускаете свой код.

"""Importing modules"""
import numpy as np
import matplotlib.pyplot as plt

from mne import Epochs, pick_types, find_events, events_from_annotations
from mne.channels import read_layout
from mne.io import concatenate_raws, read_raw_edf
from mne.datasets import eegbci
from mne.decoding import CSP

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import Activation, Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import categorical_crossentropy

from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
import itertools

"""Initializing main variables"""
# avoid classification of evoked responses by using epochs that start 1s after
# cue onset.
tmin, tmax = -1., 4.
event_id = {
    'Left Hand': 2,
    'Right Hand': 3
}
# Use these runs
runs = [3, 4, 7, 8, 11, 12]  # motor imagery: left hand vs right hand

"""Finding correct patient"""
def find_patient(subject, runs):
    raw_fnames = eegbci.load_data(subject, runs)
    raw_files = [read_raw_edf(f, preload=True) for f in raw_fnames]
    raw = concatenate_raws(raw_files)

    # strip channel names of "." characters
    raw.rename_channels(lambda x: x.strip('.'))

    # Apply band-pass filter
    raw.filter(7., 30., method='iir')

    events = events_from_annotations(raw)

    picks = pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
                       exclude='bads')

    # Read epochs (train will be done only between 1 and 2s)
    # Testing will be done with a running classifier
    epochs = Epochs(raw, events[0], event_id, tmin, tmax, proj=True, picks=picks,
                    baseline=None, preload=True)

    Y = epochs.events[:, -1] - 2

    epochs_data = epochs.get_data()
    csp = CSP(n_components=4, log=True, reg=None)

    X = csp.fit_transform(epochs_data, Y)

    data = [X, Y]

    return data

subject4 = 4
training_patient = find_patient(subject4, runs)

"""Building the model"""
model = Sequential([
    Dense(units=8, input_shape=(4,), activation='relu'),
    Dense(units=16, activation='relu'),
    Dense(units=16, activation='relu'),
    Dense(units=2, activation='softmax')
])

model.summary()

train_features, test_features, train_labels, test_labels = train_test_split(training_patient[0], training_patient[1], train_size=0.90, test_size=0.10)

model.compile(optimizer=Adam(learning_rate=0.001), loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(x=train_features, y=train_labels, validation_data=(test_features, test_labels), batch_size=10, epochs=80, shuffle=True, verbose=2)

"""Predictions"""
# subject5 = 5
# testing_patient = find_patient(subject5, runs)

predictions = model.predict(x=test_features, batch_size=10, verbose=0)
rounded_predictions = np.argmax(predictions, axis=-1)
cm = confusion_matrix(y_true=test_labels, y_pred=rounded_predictions)

"""Confusion Matrix"""
def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion Matrix', cmap=plt.cm.Blues):
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print("Confusion matrix, without normalization")

    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, cm[i, j],
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")
    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.show()

cm_plot_labels = ['Left Hand', 'Right Hand']
plot_confusion_matrix(cm=cm, classes=cm_plot_labels, title="Confusion Matrix")
plt.show()

Вывод

Используя онлайн-данные, собранные в опубликованном исследовании, я смог классифицировать мысли пациентов, чтобы определить, думают ли они о движении в их левой или правой руке. Успех классификации составил 93,06% на обучающей выборке и 100% на тестовой. Повторный прогон модели был выполнен с большим набором тестов для лучшего обзора модели, и точность обучения составила 95,16%, а точность тестирования - 88,89%.

«Мы можем видеть только небольшое расстояние впереди, но мы можем видеть там много того, что нужно сделать». - Алан Тьюринг

Это пример того, как машинное обучение и BCI пересекаются. Однако этот пример можно легко распространить на управление протезами с помощью нейронных сетей и нейронных колебаний. Кроме того, это пример того, как медицина и наука о данных пересекаются и полагаются друг на друга для роста; настоящие мутуалистические отношения. Любой, у кого есть доступ к компьютеру, может начать программировать на Python и тестировать онлайн-наборы данных.

BCI - это будущее медицины, и они будут продолжать расти и развиваться. Поскольку внедрение ИМК в больницах и среди пациентов будет увеличиваться, не будет времени, когда мы вспомним, что у них их не было. Кроме того, затронуты не только миллионы пациентов, которые будут спасены с помощью этой технологии; BCI могут применяться практически в любой отрасли, от потребительских товаров до военной. Это начало медицинской и технической революции.