Введение
Распознавание рукописных цифр является одной из основных задач в области машинного обучения, и у него есть несколько реальных приложений, таких как чтение почтовых номеров, цифр банковских чеков и приложений для форм. В этой статье мы исследуем, как можно использовать алгоритм K-ближайших соседей (KNN) для решения этой задачи с использованием набора данных MNIST. KNN — это простой, но мощный метод непараметрической классификации, который работает, находя предопределенное количество обучающих выборок, ближайших по расстоянию до новой точки, и прогнозируя метку на их основе.
Мотивация
Выбор KNN для этой задачи обусловлен его простотой и эффективностью. Будучи алгоритмом ленивого обучения, KNN не делает предположений о характере базовых данных, что делает его полезным в ситуациях, когда граница решения очень нерегулярна. Кроме того, его способность быстро адаптироваться к изменениям делает KNN отличным выбором для распознавания рукописных цифр, где вариации являются скорее нормой, чем исключением.
В этой статье мы будем обучать KNN для MNIST, полную тетрадь вы можете найти здесь. Теперь давайте подробно обсудим, что мы делаем в этом блокноте.
Загрузка набора данных
Набор данных MNIST («Модифицированный национальный институт стандартов и технологий») представляет собой большую базу данных рукописных цифр, которая широко используется для обучения и тестирования в области машинного обучения. Он содержит 60 000 обучающих изображений и 10 000 тестовых изображений, каждое из которых представляет собой изображение в градациях серого 28x28, связанное с меткой от 0 до 9.
from sklearn.datasets import load_digits # Load data digits = load_digits() X = digits.images y = digits.target
Выполнение исследовательского анализа данных (EDA)
Прежде всего, давайте отобразим наш набор данных.
# Plot the first few images fig, axes = plt.subplots(2, 5, figsize=(10, 5)) axes = axes.ravel() for i in np.arange(0, 10): axes[i].imshow(X[i], cmap='gray') axes[i].set_title("Digit: %s" % y[i]) axes[i].axis('off') plt.subplots_adjust(wspace=0.5)
Следующее построение распределения классов в наборе данных MNIST является важным шагом в исследовательском анализе данных (EDA). Цель этого анализа — выяснить, сбалансирован ли набор данных, то есть каждый класс (в данном случае каждая цифра от 0 до 9) имеет примерно одинаковое количество экземпляров. Несбалансированный набор данных может привести к предвзятости в модели машинного обучения, в результате чего она будет хорошо работать с перепредставленными классами и плохо работать с недопредставленными. Визуализируя распределение классов, мы можем убедиться, что каждый класс адекватно представлен в нашем наборе данных. Это определяет нашу стратегию обучения модели машинного обучения и позволяет нам при необходимости принимать корректирующие меры, такие как увеличение данных, методы повторной выборки или корректировка весов классов.
# Plot the distribution of classes plt.figure(figsize=(10, 5)) plt.bar(unique, counts) plt.xticks(unique) plt.xlabel("Digit") plt.ylabel("Frequency") plt.title("Distribution of digits in MNIST") plt.show()
Теперь давайте визуализируем среднее изображение для каждой цифры (от 0 до 9) в наборе данных MNIST. Это делается для того, чтобы глубже понять общие закономерности, лежащие в основе каждого класса цифр. Усредняя все изображения, принадлежащие к определенному классу, мы получаем репрезентативное изображение, которое инкапсулирует общие черты этой цифры, как она появляется в нашем наборе данных. Эти «усредненные изображения» дают нам общий контур или «прототипную» цифру, с которой может столкнуться наша модель. Наблюдение за этими изображениями может пролить свет на уникальные характеристики каждой цифры и дать нам визуальное представление о функциях, которые модель должна изучить, чтобы эффективно различать разные классы.
# Compute average images for each digit avg_images = np.array([np.mean(X[y == i], axis=0) for i in range(10)]) # Plot average images fig, axes = plt.subplots(2, 5, figsize=(10, 5)) axes = axes.ravel() for i in np.arange(0, 10): axes[i].imshow(avg_images[i], cmap='gray') axes[i].set_title("Digit: %s" % i) axes[i].axis('off') plt.subplots_adjust(wspace=0.5)
Затем мы выполняем этот анализ распределения интенсивности пикселей, чтобы лучше понять распределение значений пикселей в нашем наборе данных MNIST. Изменяя данные 3D-изображения в 2D-массив, а затем сглаживая их до 1D-массива, мы можем вместе исследовать все значения пикселей. Затем на гистограмме отображается частота этих значений пикселей в диапазоне от 0 (черный) до 255 (белый). Это дает представление об общих характеристиках наших изображений. Например, если большинство значений пикселей ближе к нижнему пределу (около 0), это будет означать, что наши изображения, как правило, темные, с несколькими отчетливыми белыми областями, представляющими цифры. Это понимание может быть полезно для этапов предварительной обработки, таких как нормализация или стандартизация, а также может помочь в уточнении нашей модели машинного обучения, предоставляя представление о данных, из которых она будет учиться.
# Reshape X to 2D X_2D = X.reshape(X.shape[0], -1) # Histogram of pixel intensities plt.figure(figsize=(10, 5)) plt.hist(X_2D.ravel(), bins=30, color='gray', alpha=0.7) plt.title("Distribution of pixel intensities") plt.xlabel("Pixel intensity") plt.ylabel("Frequency") plt.show()
Далее у нас есть визуализация t-SNE. t-SNE (t-Distributed Stochastic Neighbor Embedding) — это алгоритм машинного обучения, особенно подходящий для визуализации многомерных наборов данных, таких как данные MNIST, которые состоят из изображений размером 8x8 пикселей (64-мерный набор данных). Когда становится непрактичным визуализировать такие многомерные данные, методы уменьшения размерности, такие как t-SNE, помогают представить эти данные в 2D или 3D, сохраняя важные структуры в разных масштабах. Представленная здесь визуализация t-SNE переводит эти многомерные цифровые изображения в 2D, сохраняя при этом относительные расстояния между различными точками данных. Каждая точка на графике соответствует цифре MNIST, а цвет соответствует фактической метке цифры. Кластеры похожих точек данных одного и того же цвета указывают на то, что цифры одного и того же типа имеют схожие паттерны интенсивности пикселей. Таким образом, мы используем t-SNE, чтобы лучше понять и получить представление о наших многомерных данных, визуально идентифицируя кластеры или группы похожих экземпляров и, возможно, выбросы.
Подготовка набора данных для KNN
Подготовка набора данных для модели KNN в первую очередь включает нормализацию значений оттенков серого и изменение формы массивов данных. Нормализация, которая масштабирует все значения оттенков серого в диапазоне от 0 до 1, помогает модели обучаться быстрее и снижает вероятность застревания в локальных оптимумах. Изменение формы массивов данных с трехмерного на двухмерное упрощает ввод для модели KNN, которая ожидает двумерный массив.
# Load data digits = load_digits() X = digits.images.reshape((len(digits.images), -1)) y = digits.target
Разделение данных
Первым важным шагом является разделение нашего набора данных MNIST на обучающий набор и набор для тестирования. Это делается для оценки производительности нашей модели на невидимых данных после обучения. Для этой цели используется функция train_test_split
из модуля model_selection
sklearn. Параметр test_size
установлен на 0,2, что означает, что 20% данных зарезервировано для тестирования. Параметр stratify=y
гарантирует, что распределение меток остается одинаковым как в обучающем, так и в тестовом наборах. Параметр random_state
используется для воспроизведения одних и тех же обучающих и тестовых наборов в нескольких прогонах.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
Определение гиперпараметров
Производительность KNN сильно зависит от выбора его гиперпараметров. Мы рассматриваем три основных гиперпараметра:
n_neighbors
: количество рассматриваемых соседей. Пробуем значения от 3 до 11.weights
: Весовая функция, используемая при прогнозировании. Мы рассматриваем два распространенных варианта: «равномерный» (все точки в каждом районе имеют одинаковый вес) и «расстояние» (точки взвешиваются в зависимости от расстояния, обратного их расстоянию, поэтому более близкие соседи имеют большее влияние).p
: Параметр мощности для метрики Минковского. Когда p=1, это эквивалентно использованию manhattan_distance, а euclidean_distance используется для p=2.
parameters = { 'n_neighbors': [3, 5, 7, 9, 11], 'weights': ['uniform', 'distance'], 'p': [1, 2] }
Инициализация модели и поиск по сетке
Мы начинаем с инициализации модели KNN с помощью KNeighborsClassifier()
. Затем мы выполняем поиск по сетке, чтобы найти оптимальное сочетание гиперпараметров для нашей модели KNN. Поиск по сетке — это процесс, в котором вы определяете набор возможных значений для различных гиперпараметров и пробуете все возможные комбинации. В этом случае гиперпараметры и их возможные значения определяются в словаре parameters
.
Функция GridSearchCV
из модуля model_selection
sklearn используется для выполнения поиска по сетке. Эта функция обучает модель KNN для каждой комбинации гиперпараметров и использует перекрестную проверку для оценки производительности каждой модели. Параметр cv
указывает количество сгибов для перекрестной проверки, а параметр scoring
указывает метрику, используемую для оценки.
knn = KNeighborsClassifier() grid_search = GridSearchCV(knn, parameters, cv=5, scoring='accuracy') grid_search.fit(X_train, y_train)
Оценка результатов поиска по сетке
После завершения поиска по сетке мы можем вывести наилучшую комбинацию гиперпараметров и соответствующую оценку точности. Это делается с помощью атрибутов best_params_
и best_score_
объекта grid_search
.
print(grid_search.best_params_) print(grid_search.best_score_)
Обучение модели оптимальными гиперпараметрами
Наконец, мы обучаем новую модель KNN, используя лучшие гиперпараметры, найденные с помощью поиска по сетке. Затем модель подгоняется с использованием обучающих данных.
knn_best = KNeighborsClassifier(**grid_search.best_params_) knn_best.fit(X_train, y_train)
Этот процесс позволяет нам найти наиболее подходящий набор гиперпараметров для нашей модели KNN, позволяя ей предсказывать рукописные цифры с высокой степенью точности.
Тестирование и оценка модели
После обучения нашей модели KNN с лучшим набором гиперпараметров следующим шагом будет оценка ее производительности на тестовом наборе. Это делается с помощью метода predict
модели knn_best
для прогнозирования меток для тестовых данных. Чтобы более подробно понять производительность модели, мы создаем отчет о классификации, используя функцию classification_report
sklearn. Отчет о классификации предоставляет ключевые показатели, такие как точность, полнота и оценка F1 для каждого класса, а также общую точность модели.
y_pred = knn_best.predict(X_test) conf_mat = confusion_matrix(y_test, y_pred) sns.heatmap(conf_mat, annot=True, fmt='d') plt.show()
Согласно отчету о классификации, модель K ближайших соседей (KNN) демонстрирует превосходную производительность в наборе данных MNIST. Точность модели составляет приблизительно 99 %, что указывает на то, что модель предсказывает правильный класс для 99 % тестовых изображений. Значения точности, отзыва и F1-оценки неизменно высоки для разных цифр, демонстрируя, что производительность модели не смещена в сторону определенных классов. Тем не менее, есть небольшое снижение производительности для цифры «8» с точки зрения точности и для цифры «9» с точки зрения отзыва, что позволяет предположить незначительные ошибки классификации в этих категориях. В целом, несмотря на эти незначительные расхождения, модель KNN оказалась очень эффективной при распознавании рукописных цифр в наборе данных MNIST.
Визуализация предсказания тестового изображения
В дополнение к общим показателям производительности часто бывает полезно посмотреть на прогнозы модели для отдельных экземпляров. Здесь мы выбираем изображение из тестового набора, используем модель для предсказания его метки, а затем визуализируем изображение вместе с истинными и предсказанными метками.
# Send a testing image to the model, get prediction, and visualize it test_image_index = 0 test_image = X_test[test_image_index] test_image_label = y_test[test_image_index] test_image_pred = knn_best.predict(test_image.reshape(1, -1)) plt.imshow(test_image.reshape(8, 8), cmap='gray') plt.title(f"True label: {test_image_label}, Predicted label: {test_image_pred[0]}") plt.show()
Заключение
В этой статье мы рассмотрели, как можно эффективно применять алгоритм K-ближайших соседей (KNN) для распознавания рукописных цифр с использованием набора данных MNIST. Процесс включал загрузку и изучение набора данных, подготовку его для модели KNN и настройку гиперпараметров модели с помощью поиска по сетке. Затем мы обучили модель с использованием оптимальных гиперпараметров и оценили ее производительность с помощью отчета о классификации, матрицы путаницы и визуализации отдельных прогнозов. Модель KNN, несмотря на свою простоту, смогла достичь высокой степени точности в этой задаче, продемонстрировав свою полезность для задач классификации на основе изображений. Эти результаты усиливают возможности методов машинного обучения в автоматизации и точном выполнении сложных задач, таких как распознавание цифр, в различных реальных приложениях.
Это все люди. Да пребудет с тобой сила.