Пример использования тензорного потока и кераса
Новый мир
Интерфейсы мозг-компьютер (BCI) - это новая технология, которая изменит определение того, что значит быть человеком: роботизированные протезы, цифровое сознание и телепатия - лишь небольшая часть того, что могут выполнять эти устройства. BCI станут частью следующей эпохи технологической революции. В этой статье показан пример того, как глубокое обучение можно применить к общедоступным данным для классификации нейронных колебаний. Нервные колебания, или мозговые волны, присутствуют в головном мозге из-за возбуждающих (ВПСП) и тормозных постсинаптических потенциалов (IPSP). Вариации мозговых волн при выполнении различных задач позволяют исследователям и ученым определять закономерности и использовать их для определенных целей. Цель этого эксперимента - использовать машинное обучение для прогнозирования использования рук: правая рука против левой.
Что такое нейронные сети?
Нейронные сети - это часть глубокого обучения, глубокое обучение - это сегмент машинного обучения, а машинное обучение - это отрасль искусственного интеллекта. Эта справочная информация может показаться сложной, но ее не так важно запоминать. Важно то, что искусственные нейронные сети были предназначены для моделирования биологических нейронных сетей в нашем мозгу. Точно так же, как у последнего есть дендриты для получения информации и аксоны для отправки информации, у первого есть узлы, которые получают информацию с одного конца и распространяют ее на другой.
Биологические нейронные сети также утяжеляют связи за счет синаптической пластичности, которая изменяет (в первую очередь) концентрацию глутаматных рецепторов на синаптической мембране. Его искусственный аналог манипулирует цифровыми весами для оптимизации для наиболее точного результата. Серия скрытых слоев между входным и выходным слоями будет составлять модель. Каждая связь между каждым узлом имеет веса и смещения, которые помогают повысить точность результатов.
Функции активации - еще одна важная часть искусственных нейронных сетей. Эти функции определяют, передавать ли вывод следующему узлу или, по крайней мере, в какой форме его следует передавать. Функции активации добавляют к сети элемент нелинейности, так что модель не является линейной функцией. Они также обеспечивают поддержание данных в определенном диапазоне. Например, функция активации softmax будет выводить определенные значения для разных значений ввода (0 или 1).
Создание нейронной сети
Создание такой модели не так сложно, как вы думаете, особенно с использованием различных библиотек машинного обучения на Python. Sklearn невероятно популярен благодаря своим встроенным функциям и моделям, но не тот, который большинство людей использовали бы для глубокого обучения. PyTorch - еще один популярный вариант, созданный Facebook. Однако я буду использовать вариант TensorFlow. TensorFlow (хотя установка на Mac M1 является абсолютной проблемой) - отличный вариант для глубокого обучения. Возможно, потому, что он был разработан Google, TensorFlow прост в изучении и использовании. Многие онлайн-учебники используют TensorFlow, и существует так много онлайн-форумов, на которых можно найти ответы на вопросы по этому поводу, что вам будет трудно не найти возможный ответ на вашу проблему.
Ниже приведен пример модели, которую я использовал, чтобы предсказать, есть ли у человека сердечное заболевание.
"""Building the model""" model = Sequential([ Dense(units=16, input_shape=(13, ), activation='relu'), Dense(units=32, activation='relu'), Dense(units=32, activation='relu'), Dense(units=2, activation='softmax') ]) model.summary() model.compile(optimizer=Adam(learning_rate=0.0001), loss='sparse_categorical_crossentropy', metrics=['accuracy']) model.fit(x=train_features, y=train_labels, batch_size=10, epochs=50, shuffle=True, verbose=2) """Predictions""" predictions = model.predict(x=test_features, batch_size=10, verbose=0) rounded_predictions = np.argmax(predictions, axis=-1) cm = confusion_matrix(y_true=test_labels, y_pred=rounded_predictions)
Эта модель была построена с использованием одиннадцати строк кода. Чтобы создать точную модель, не потребовалось много времени. Модель имеет один входной слой с 16 единицами, два скрытых слоя с 32 единицами и один выходной слой с 2 единицами для каждого предсказанного класса.
Следует упомянуть важную особенность: функции активации: relu означает выпрямленные линейные блоки и является базовой функцией активации для любой нейронной сети. Простым примером этой функции активации является любая линейная функция от x с областью, большей или равной нулю. Диапазон зависит от смещения (точки пересечения по оси Y) функции. Эта функция хороша по многим причинам, но две из них заключаются в том, что она учитывает нелинейности и взаимодействие функций. Последнее лучше всего пояснить на примере: возраст играет большую роль для пациентов с сердечными заболеваниями. Однако вес и рост также будут влиять на результат. Для человека низкого роста идеальный вес будет отличаться от идеального веса для человека высокого роста. Если у вас избыточный вес в молодом возрасте по сравнению с пожилым, ваш прогноз будет отличаться. Таким образом, функции имеют эффект взаимодействия.
Функция активации softmax - это стандартная функция финальной активации нейронных сетей. Эта функция нормализует выходные данные по результатам, чтобы наилучшим образом выбрать прогнозируемый класс.
Ранее в коде я разделил свой набор данных на наборы для обучения и тестирования. Если вы тестируете (прогнозируете) на тех же данных, на которых тренировались, проверка точности вашей модели будет нарушена. Тестировать нужно на отдельном наборе данных. Это также поможет вам убедиться, что вы не переусердствуете с данными. (Переобучение данных относится к сценарию, в котором ваша модель настолько хорошо обучена на своих практических данных, что не может точно предсказать какие-либо другие данные.) Если результаты, которые мы получаем из нашего набора тестов, также достаточно хороши, мы можем быть уверены в том, что созданная модель.
Ниже приведены результаты обучения последних десяти эпох этой модели:
Epoch 40/50 242/242 - 0s - loss: 0.2496 - accuracy: 0.9215 Epoch 41/50 242/242 - 0s - loss: 0.2457 - accuracy: 0.9215 Epoch 42/50 242/242 - 0s - loss: 0.2509 - accuracy: 0.9215 Epoch 43/50 242/242 - 0s - loss: 0.2457 - accuracy: 0.9132 Epoch 44/50 242/242 - 0s - loss: 0.2343 - accuracy: 0.9256 Epoch 45/50 242/242 - 0s - loss: 0.2350 - accuracy: 0.9256 Epoch 46/50 242/242 - 0s - loss: 0.2352 - accuracy: 0.9298 Epoch 47/50 242/242 - 0s - loss: 0.2271 - accuracy: 0.9256 Epoch 48/50 242/242 - 0s - loss: 0.2266 - accuracy: 0.9298 Epoch 49/50 242/242 - 0s - loss: 0.2193 - accuracy: 0.9339 Epoch 50/50 242/242 - 0s - loss: 0.2203 - accuracy: 0.9339 Confusion matrix, without normalization Process finished with exit code 0
Точность модели составила 93,39% с потерей 0,2203. Когда убытки начинают увеличиваться, это свидетельствует о переобучении. Зная это, пятьдесят эпох, вероятно, было хорошим местом, чтобы остановиться.
То, что точность валидации не выше, чем точность обучения, хороший знак. (Если бы было наоборот, это обычно было бы красным флажком.) Точность обучения и тестирования модели дала положительные результаты.
Эксперимент
1. Сбор данных
Первым шагом этого эксперимента был сбор данных. У меня нет ЭЭГ-гарнитуры для сбора данных от себя, поэтому мне нужно было найти кое-что в Интернете. К счастью, PhysioNet смогла предоставить данные, использованные в исследовании, опубликованном в 2004 году. Эти данные имеют 109 субъектов с 14 запусками на каждого предмета. Я случайным образом выбрал испытуемый номер 4. Для разных задач использовались разные прогоны: в забегах 3, 4, 7, 8, 11 и 12 испытуемые визуализировали движение левой или правой рукой, а в забегах 5, 6, 9, 10, 13, и 14 испытуемых представили движение в руках или ногах.
"""Finding patient data""" # T0 corresponds to rest # T1 corresponds to left fist # T2 corresponds to right fist raw3 = mne.io.read_raw_edf('S004R03.edf', preload=True) raw4 = mne.io.read_raw_edf('S004R04.edf', preload=True) raw7 = mne.io.read_raw_edf('S004R07.edf', preload=True) raw8 = mne.io.read_raw_edf('S004R08.edf', preload=True) raw11 = mne.io.read_raw_edf('S004R11.edf', preload=True) raw12 = mne.io.read_raw_edf('S004R12.edf', preload=True)
Мне удалось скачать тематические файлы и открыть их с помощью MNE-Python.
2. Объединяйте и очищайте файлы
Эти файлы сами по себе не содержат большого количества данных ЭЭГ, поэтому я объединил их, а затем отфильтровал. Фильтрация необходима для большинства ЭЭГ, поскольку большинство файлов пациентов будут содержать большие артефакты, которые искажают анализ.
raws = mne.concatenate_raws([raw3, raw4, raw7, raw8, raw11, raw12]) # raws.plot() # strip channel names of "." characters raws.rename_channels(lambda x: x.strip('.')) # Apply band-pass filter raws.filter(7., 30., method='iir') # raws.plot()
3. Создание эпох
Мы хотим создавать эпохи для анализа с помощью ЭЭГ. Данные этих эпох станут функциями, переданными в модель.
events = events_from_annotations(raws) picks = pick_types(raws.info, meg=False, eeg=True, stim=False, eog=False, exclude='bads') epochs = Epochs(raws, events[0], event_id, tmin, tmax, proj=True, picks=picks, baseline=None, preload=True)
4. Создание элементов и меток.
Функции и метки сопоставимы с X и Y. Компонент Y - это просто метки T0, T1, T2, которые пришли с данными. Т0 бесполезен (пациент в состоянии покоя), а Т1, Т2 заменены на 0 и 1 соответственно. Однако X-компоненту нужно немного доработать. В настоящее время у нас есть временная информация по эпохам. Это данные, относящиеся ко времени, когда это происходит. Используя алгоритм CSP от MNE, мы можем извлекать пространственную информацию как объекты.
Y = epochs.events[:, -1] - 2 epochs_data = epochs.get_data() csp = CSP(n_components=4, log=True, reg=None) X = csp.fit_transform(epochs_data, Y)
5. Создание и подгонка модели.
Этот процесс очень похож на тот, который использовался в качестве примера ранее. Я создал четыре плотно связанных слоя, а затем пропустил обучающие элементы и метки через модель.
model = Sequential([ Dense(units=8, input_shape=(4,), activation='relu'), Dense(units=16, activation='relu'), Dense(units=16, activation='relu'), Dense(units=2, activation='softmax') ]) model.summary() train_features, test_features, train_labels, test_labels = train_test_split(X, Y, train_size=0.80, test_size=0.20) model.compile(optimizer=Adam(learning_rate=0.001), loss='sparse_categorical_crossentropy', metrics=['accuracy']) model.fit(x=train_features, y=train_labels, validation_data=(test_features, test_labels), batch_size=10, epochs=75, shuffle=True, verbose=2)
6. Прогнозирование данных испытаний.
Этот процесс аналогичен описанному выше.
"""Predictions""" predictions = model.predict(x=test_features, batch_size=10, verbose=0) rounded_predictions = np.argmax(predictions, axis=-1) cm = confusion_matrix(y_true=test_labels, y_pred=rounded_predictions) """Confusion Matrix""" def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion Matrix', cmap=plt.cm.Blues): plt.imshow(cm, interpolation='nearest', cmap=cmap) plt.title(title) plt.colorbar() tick_marks = np.arange(len(classes)) plt.xticks(tick_marks, classes, rotation=45) plt.yticks(tick_marks, classes) if normalize: cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] print("Normalized confusion matrix") else: print("Confusion matrix, without normalization") thresh = cm.max() / 2. for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): plt.text(j, i, cm[i, j], horizontalalignment="center", color="white" if cm[i, j] > thresh else "black") plt.tight_layout() plt.ylabel('True label') plt.xlabel('Predicted label') plt.show() cm_plot_labels = ['Left Hand', 'Right Hand'] plot_confusion_matrix(cm=cm, classes=cm_plot_labels, title="Confusion Matrix") plt.show()
Результат выглядит следующим образом:
/Applications/anaconda3/envs/tf/bin/python /Applications/anaconda3/envs/tf/EEGBCI.py Extracting EDF parameters from /Applications/anaconda3/envs/tf/S004R03.edf... EDF file detected Setting channel info structure... Creating raw.info structure... Reading 0 ... 19679 = 0.000 ... 122.994 secs... Extracting EDF parameters from /Applications/anaconda3/envs/tf/S004R04.edf... EDF file detected Setting channel info structure... Creating raw.info structure... Reading 0 ... 19679 = 0.000 ... 122.994 secs... Extracting EDF parameters from /Applications/anaconda3/envs/tf/S004R07.edf... EDF file detected Setting channel info structure... Creating raw.info structure... Reading 0 ... 19679 = 0.000 ... 122.994 secs... Extracting EDF parameters from /Applications/anaconda3/envs/tf/S004R08.edf... EDF file detected Setting channel info structure... Creating raw.info structure... Reading 0 ... 19679 = 0.000 ... 122.994 secs... Extracting EDF parameters from /Applications/anaconda3/envs/tf/S004R11.edf... EDF file detected Setting channel info structure... Creating raw.info structure... Reading 0 ... 19679 = 0.000 ... 122.994 secs... Extracting EDF parameters from /Applications/anaconda3/envs/tf/S004R12.edf... EDF file detected Setting channel info structure... Creating raw.info structure... Reading 0 ... 19679 = 0.000 ... 122.994 secs... Filtering raw data in 6 contiguous segments Setting up band-pass filter from 7 - 30 Hz IIR filter parameters --------------------- Butterworth bandpass zero-phase (two-pass forward and reverse) non-causal filter: - Filter order 16 (effective, after forward-backward) - Cutoffs at 7.00, 30.00 Hz: -6.02, -6.02 dB Used Annotations descriptions: ['T0', 'T1', 'T2'] Not setting metadata Not setting metadata 90 matching events found No baseline correction applied 0 projection items activated Loading data for 90 events and 801 original time points ... 0 bad epochs dropped Computing rank from data with rank=None Using tolerance 0.00013 (2.2e-16 eps * 64 dim * 9.3e+09 max singular value) Estimated rank (mag): 64 MAG: rank 64 computed from 64 data channels with 0 projectors Reducing data rank from 64 -> 64 Estimating covariance using EMPIRICAL Done. Computing rank from data with rank=None Using tolerance 0.00013 (2.2e-16 eps * 64 dim * 9.3e+09 max singular value) Estimated rank (mag): 64 MAG: rank 64 computed from 64 data channels with 0 projectors Reducing data rank from 64 -> 64 Estimating covariance using EMPIRICAL Done. 2021-08-19 16:44:50.608926: I tensorflow/core/platform/cpu_feature_guard.cc:145] This TensorFlow binary is optimized with Intel(R) MKL-DNN to use the following CPU instructions in performance critical operations: SSE4.1 SSE4.2 To enable them in non-MKL-DNN operations, rebuild TensorFlow with the appropriate compiler flags. 2021-08-19 16:44:50.610742: I tensorflow/core/common_runtime/process_util.cc:115] Creating new thread pool with default inter op setting: 8. Tune using inter_op_parallelism_threads for best performance. Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense (Dense) (None, 8) 40 _________________________________________________________________ dense_1 (Dense) (None, 16) 144 _________________________________________________________________ dense_2 (Dense) (None, 16) 272 _________________________________________________________________ dense_3 (Dense) (None, 2) 34 ================================================================= Total params: 490 Trainable params: 490 Non-trainable params: 0 _________________________________________________________________ Train on 72 samples, validate on 18 samples Epoch 1/75 72/72 - 1s - loss: 0.6999 - accuracy: 0.4583 - val_loss: 0.7009 - val_accuracy: 0.5000 Epoch 2/75 72/72 - 0s - loss: 0.6887 - accuracy: 0.4722 - val_loss: 0.6888 - val_accuracy: 0.5556 Epoch 3/75 72/72 - 0s - loss: 0.6801 - accuracy: 0.5139 - val_loss: 0.6803 - val_accuracy: 0.5556 Epoch 4/75 72/72 - 0s - loss: 0.6742 - accuracy: 0.5694 - val_loss: 0.6692 - val_accuracy: 0.7222 Epoch 5/75 72/72 - 0s - loss: 0.6658 - accuracy: 0.6111 - val_loss: 0.6617 - val_accuracy: 0.7778 Epoch 6/75 72/72 - 0s - loss: 0.6600 - accuracy: 0.6528 - val_loss: 0.6550 - val_accuracy: 0.7778 Epoch 7/75 72/72 - 0s - loss: 0.6543 - accuracy: 0.6667 - val_loss: 0.6475 - val_accuracy: 0.8333 Epoch 8/75 72/72 - 0s - loss: 0.6475 - accuracy: 0.6806 - val_loss: 0.6397 - val_accuracy: 0.8333 Epoch 9/75 72/72 - 0s - loss: 0.6411 - accuracy: 0.7500 - val_loss: 0.6315 - val_accuracy: 0.7778 Epoch 10/75 72/72 - 0s - loss: 0.6333 - accuracy: 0.7639 - val_loss: 0.6217 - val_accuracy: 0.7778 Epoch 11/75 72/72 - 0s - loss: 0.6245 - accuracy: 0.7361 - val_loss: 0.6101 - val_accuracy: 0.7778 Epoch 12/75 72/72 - 0s - loss: 0.6141 - accuracy: 0.7778 - val_loss: 0.5976 - val_accuracy: 0.7778 Epoch 13/75 72/72 - 0s - loss: 0.6010 - accuracy: 0.8056 - val_loss: 0.5843 - val_accuracy: 0.7778 Epoch 14/75 72/72 - 0s - loss: 0.5875 - accuracy: 0.8333 - val_loss: 0.5687 - val_accuracy: 0.8333 Epoch 15/75 72/72 - 0s - loss: 0.5735 - accuracy: 0.8472 - val_loss: 0.5511 - val_accuracy: 0.8333 Epoch 16/75 72/72 - 0s - loss: 0.5563 - accuracy: 0.8611 - val_loss: 0.5322 - val_accuracy: 0.8333 Epoch 17/75 72/72 - 0s - loss: 0.5400 - accuracy: 0.8750 - val_loss: 0.5114 - val_accuracy: 0.8889 Epoch 18/75 72/72 - 0s - loss: 0.5221 - accuracy: 0.8750 - val_loss: 0.4916 - val_accuracy: 0.8889 Epoch 19/75 72/72 - 0s - loss: 0.5065 - accuracy: 0.8889 - val_loss: 0.4721 - val_accuracy: 0.8889 Epoch 20/75 72/72 - 0s - loss: 0.4900 - accuracy: 0.8750 - val_loss: 0.4512 - val_accuracy: 0.8889 Epoch 21/75 72/72 - 0s - loss: 0.4773 - accuracy: 0.8611 - val_loss: 0.4321 - val_accuracy: 0.9444 Epoch 22/75 72/72 - 0s - loss: 0.4588 - accuracy: 0.9028 - val_loss: 0.4109 - val_accuracy: 0.9444 Epoch 23/75 72/72 - 0s - loss: 0.4490 - accuracy: 0.8889 - val_loss: 0.3907 - val_accuracy: 0.9444 Epoch 24/75 72/72 - 0s - loss: 0.4278 - accuracy: 0.9028 - val_loss: 0.3750 - val_accuracy: 0.9444 Epoch 25/75 72/72 - 0s - loss: 0.4179 - accuracy: 0.8889 - val_loss: 0.3530 - val_accuracy: 0.9444 Epoch 26/75 72/72 - 0s - loss: 0.4014 - accuracy: 0.9167 - val_loss: 0.3346 - val_accuracy: 0.9444 Epoch 27/75 72/72 - 0s - loss: 0.3901 - accuracy: 0.9028 - val_loss: 0.3214 - val_accuracy: 1.0000 Epoch 28/75 72/72 - 0s - loss: 0.3799 - accuracy: 0.8889 - val_loss: 0.3024 - val_accuracy: 1.0000 Epoch 29/75 72/72 - 0s - loss: 0.3651 - accuracy: 0.9028 - val_loss: 0.2878 - val_accuracy: 0.9444 Epoch 30/75 72/72 - 0s - loss: 0.3569 - accuracy: 0.9167 - val_loss: 0.2733 - val_accuracy: 0.9444 Epoch 31/75 72/72 - 0s - loss: 0.3464 - accuracy: 0.9167 - val_loss: 0.2605 - val_accuracy: 1.0000 Epoch 32/75 72/72 - 0s - loss: 0.3412 - accuracy: 0.9028 - val_loss: 0.2508 - val_accuracy: 1.0000 Epoch 33/75 72/72 - 0s - loss: 0.3320 - accuracy: 0.9028 - val_loss: 0.2402 - val_accuracy: 0.9444 Epoch 34/75 72/72 - 0s - loss: 0.3250 - accuracy: 0.9167 - val_loss: 0.2285 - val_accuracy: 1.0000 Epoch 35/75 72/72 - 0s - loss: 0.3188 - accuracy: 0.9028 - val_loss: 0.2196 - val_accuracy: 1.0000 Epoch 36/75 72/72 - 0s - loss: 0.3119 - accuracy: 0.9167 - val_loss: 0.2093 - val_accuracy: 1.0000 Epoch 37/75 72/72 - 0s - loss: 0.3040 - accuracy: 0.9167 - val_loss: 0.2006 - val_accuracy: 1.0000 Epoch 38/75 72/72 - 0s - loss: 0.3024 - accuracy: 0.9167 - val_loss: 0.1945 - val_accuracy: 0.9444 Epoch 39/75 72/72 - 0s - loss: 0.2962 - accuracy: 0.9167 - val_loss: 0.1835 - val_accuracy: 1.0000 Epoch 40/75 72/72 - 0s - loss: 0.2908 - accuracy: 0.9167 - val_loss: 0.1767 - val_accuracy: 1.0000 Epoch 41/75 72/72 - 0s - loss: 0.2878 - accuracy: 0.9167 - val_loss: 0.1717 - val_accuracy: 1.0000 Epoch 42/75 72/72 - 0s - loss: 0.2839 - accuracy: 0.9167 - val_loss: 0.1648 - val_accuracy: 1.0000 Epoch 43/75 72/72 - 0s - loss: 0.2819 - accuracy: 0.9306 - val_loss: 0.1588 - val_accuracy: 1.0000 Epoch 44/75 72/72 - 0s - loss: 0.2795 - accuracy: 0.9167 - val_loss: 0.1600 - val_accuracy: 1.0000 Epoch 45/75 72/72 - 0s - loss: 0.2895 - accuracy: 0.8889 - val_loss: 0.1636 - val_accuracy: 0.9444 Epoch 46/75 72/72 - 0s - loss: 0.2839 - accuracy: 0.9167 - val_loss: 0.1453 - val_accuracy: 1.0000 Epoch 47/75 72/72 - 0s - loss: 0.2719 - accuracy: 0.9306 - val_loss: 0.1406 - val_accuracy: 1.0000 Epoch 48/75 72/72 - 0s - loss: 0.2679 - accuracy: 0.9306 - val_loss: 0.1375 - val_accuracy: 1.0000 Epoch 49/75 72/72 - 0s - loss: 0.2674 - accuracy: 0.9306 - val_loss: 0.1336 - val_accuracy: 1.0000 Epoch 50/75 72/72 - 0s - loss: 0.2741 - accuracy: 0.9028 - val_loss: 0.1349 - val_accuracy: 1.0000 Epoch 51/75 72/72 - 0s - loss: 0.2609 - accuracy: 0.9306 - val_loss: 0.1285 - val_accuracy: 1.0000 Epoch 52/75 72/72 - 0s - loss: 0.2667 - accuracy: 0.9028 - val_loss: 0.1261 - val_accuracy: 1.0000 Epoch 53/75 72/72 - 0s - loss: 0.2603 - accuracy: 0.9167 - val_loss: 0.1220 - val_accuracy: 1.0000 Epoch 54/75 72/72 - 0s - loss: 0.2563 - accuracy: 0.9306 - val_loss: 0.1216 - val_accuracy: 1.0000 Epoch 55/75 72/72 - 0s - loss: 0.2539 - accuracy: 0.9306 - val_loss: 0.1168 - val_accuracy: 1.0000 Epoch 56/75 72/72 - 0s - loss: 0.2544 - accuracy: 0.9306 - val_loss: 0.1143 - val_accuracy: 1.0000 Epoch 57/75 72/72 - 0s - loss: 0.2542 - accuracy: 0.9167 - val_loss: 0.1117 - val_accuracy: 1.0000 Epoch 58/75 72/72 - 0s - loss: 0.2544 - accuracy: 0.9167 - val_loss: 0.1088 - val_accuracy: 1.0000 Epoch 59/75 72/72 - 0s - loss: 0.2497 - accuracy: 0.9306 - val_loss: 0.1071 - val_accuracy: 1.0000 Epoch 60/75 72/72 - 0s - loss: 0.2452 - accuracy: 0.9306 - val_loss: 0.1085 - val_accuracy: 1.0000 Epoch 61/75 72/72 - 0s - loss: 0.2578 - accuracy: 0.8889 - val_loss: 0.1142 - val_accuracy: 1.0000 Epoch 62/75 72/72 - 0s - loss: 0.2608 - accuracy: 0.8889 - val_loss: 0.1109 - val_accuracy: 1.0000 Epoch 63/75 72/72 - 0s - loss: 0.2617 - accuracy: 0.9028 - val_loss: 0.0988 - val_accuracy: 1.0000 Epoch 64/75 72/72 - 0s - loss: 0.2458 - accuracy: 0.9306 - val_loss: 0.0993 - val_accuracy: 1.0000 Epoch 65/75 72/72 - 0s - loss: 0.2463 - accuracy: 0.9167 - val_loss: 0.0968 - val_accuracy: 1.0000 Epoch 66/75 72/72 - 0s - loss: 0.2417 - accuracy: 0.9306 - val_loss: 0.0943 - val_accuracy: 1.0000 Epoch 67/75 72/72 - 0s - loss: 0.2418 - accuracy: 0.9306 - val_loss: 0.0934 - val_accuracy: 1.0000 Epoch 68/75 72/72 - 0s - loss: 0.2410 - accuracy: 0.9306 - val_loss: 0.0922 - val_accuracy: 1.0000 Epoch 69/75 72/72 - 0s - loss: 0.2406 - accuracy: 0.9306 - val_loss: 0.0966 - val_accuracy: 1.0000 Epoch 70/75 72/72 - 0s - loss: 0.2416 - accuracy: 0.9028 - val_loss: 0.0922 - val_accuracy: 1.0000 Epoch 71/75 72/72 - 0s - loss: 0.2402 - accuracy: 0.9306 - val_loss: 0.0875 - val_accuracy: 1.0000 Epoch 72/75 72/72 - 0s - loss: 0.2380 - accuracy: 0.9306 - val_loss: 0.0862 - val_accuracy: 1.0000 Epoch 73/75 72/72 - 0s - loss: 0.2390 - accuracy: 0.9167 - val_loss: 0.0849 - val_accuracy: 1.0000 Epoch 74/75 72/72 - 0s - loss: 0.2350 - accuracy: 0.9306 - val_loss: 0.0834 - val_accuracy: 1.0000 Epoch 75/75 72/72 - 0s - loss: 0.2315 - accuracy: 0.9306 - val_loss: 0.0863 - val_accuracy: 1.0000 Confusion matrix, without normalization Process finished with exit code 0
Возможно, я установил слишком высокие эпохи, поскольку потери начинают расти, но это быстрое и легкое решение, которое приведет к аналогичным результатам. Точность обучающей выборки составила 93,06%, а точность тестовой выборки - 100%. Обычно это большой красный флаг. Огромный красный флаг. Однако вы заметите, что набор для тестирования довольно невелик. Рассмотрены не сотни дел - всего восемнадцать. Поэтому неудивительно, что точность валидации настолько высока.
Используя больший размер тестирования, можно сгенерировать матрицу неточностей, подобную приведенной ниже.
Другой вариант
Вместо загрузки всех необходимых данных на ваш компьютер, MNE предлагает альтернативу именно для этого набора данных. Созданная мной функция find_patient подключится к MNE и загрузит набор данных прямо в каталог, в котором вы запускаете свой код.
"""Importing modules""" import numpy as np import matplotlib.pyplot as plt from mne import Epochs, pick_types, find_events, events_from_annotations from mne.channels import read_layout from mne.io import concatenate_raws, read_raw_edf from mne.datasets import eegbci from mne.decoding import CSP import tensorflow as tf from tensorflow import keras from tensorflow.keras.models import Sequential, load_model from tensorflow.keras.layers import Activation, Dense from tensorflow.keras.optimizers import Adam from tensorflow.keras.metrics import categorical_crossentropy from sklearn.metrics import confusion_matrix from sklearn.model_selection import train_test_split import itertools """Initializing main variables""" # avoid classification of evoked responses by using epochs that start 1s after # cue onset. tmin, tmax = -1., 4. event_id = { 'Left Hand': 2, 'Right Hand': 3 } # Use these runs runs = [3, 4, 7, 8, 11, 12] # motor imagery: left hand vs right hand """Finding correct patient""" def find_patient(subject, runs): raw_fnames = eegbci.load_data(subject, runs) raw_files = [read_raw_edf(f, preload=True) for f in raw_fnames] raw = concatenate_raws(raw_files) # strip channel names of "." characters raw.rename_channels(lambda x: x.strip('.')) # Apply band-pass filter raw.filter(7., 30., method='iir') events = events_from_annotations(raw) picks = pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, exclude='bads') # Read epochs (train will be done only between 1 and 2s) # Testing will be done with a running classifier epochs = Epochs(raw, events[0], event_id, tmin, tmax, proj=True, picks=picks, baseline=None, preload=True) Y = epochs.events[:, -1] - 2 epochs_data = epochs.get_data() csp = CSP(n_components=4, log=True, reg=None) X = csp.fit_transform(epochs_data, Y) data = [X, Y] return data subject4 = 4 training_patient = find_patient(subject4, runs) """Building the model""" model = Sequential([ Dense(units=8, input_shape=(4,), activation='relu'), Dense(units=16, activation='relu'), Dense(units=16, activation='relu'), Dense(units=2, activation='softmax') ]) model.summary() train_features, test_features, train_labels, test_labels = train_test_split(training_patient[0], training_patient[1], train_size=0.90, test_size=0.10) model.compile(optimizer=Adam(learning_rate=0.001), loss='sparse_categorical_crossentropy', metrics=['accuracy']) model.fit(x=train_features, y=train_labels, validation_data=(test_features, test_labels), batch_size=10, epochs=80, shuffle=True, verbose=2) """Predictions""" # subject5 = 5 # testing_patient = find_patient(subject5, runs) predictions = model.predict(x=test_features, batch_size=10, verbose=0) rounded_predictions = np.argmax(predictions, axis=-1) cm = confusion_matrix(y_true=test_labels, y_pred=rounded_predictions) """Confusion Matrix""" def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion Matrix', cmap=plt.cm.Blues): plt.imshow(cm, interpolation='nearest', cmap=cmap) plt.title(title) plt.colorbar() tick_marks = np.arange(len(classes)) plt.xticks(tick_marks, classes, rotation=45) plt.yticks(tick_marks, classes) if normalize: cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] print("Normalized confusion matrix") else: print("Confusion matrix, without normalization") thresh = cm.max() / 2. for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): plt.text(j, i, cm[i, j], horizontalalignment="center", color="white" if cm[i, j] > thresh else "black") plt.tight_layout() plt.ylabel('True label') plt.xlabel('Predicted label') plt.show() cm_plot_labels = ['Left Hand', 'Right Hand'] plot_confusion_matrix(cm=cm, classes=cm_plot_labels, title="Confusion Matrix") plt.show()
Вывод
Используя онлайн-данные, собранные в опубликованном исследовании, я смог классифицировать мысли пациентов, чтобы определить, думают ли они о движении в их левой или правой руке. Успех классификации составил 93,06% на обучающей выборке и 100% на тестовой. Повторный прогон модели был выполнен с большим набором тестов для лучшего обзора модели, и точность обучения составила 95,16%, а точность тестирования - 88,89%.
«Мы можем видеть только небольшое расстояние впереди, но мы можем видеть там много того, что нужно сделать». - Алан Тьюринг
Это пример того, как машинное обучение и BCI пересекаются. Однако этот пример можно легко распространить на управление протезами с помощью нейронных сетей и нейронных колебаний. Кроме того, это пример того, как медицина и наука о данных пересекаются и полагаются друг на друга для роста; настоящие мутуалистические отношения. Любой, у кого есть доступ к компьютеру, может начать программировать на Python и тестировать онлайн-наборы данных.
BCI - это будущее медицины, и они будут продолжать расти и развиваться. Поскольку внедрение ИМК в больницах и среди пациентов будет увеличиваться, не будет времени, когда мы вспомним, что у них их не было. Кроме того, затронуты не только миллионы пациентов, которые будут спасены с помощью этой технологии; BCI могут применяться практически в любой отрасли, от потребительских товаров до военной. Это начало медицинской и технической революции.