Введение

Недавно мы выпустили интерактивное задание, которое обеспечивает вводное понимание того, как классификатор изображений можно использовать в контексте фабрики. Цель этой деятельности — произвести как можно больше «хороших» виджетов, а также свести к минимуму количество «дефектных» виджетов. Робота можно научить классифицировать любую деталь, поступающую с конвейера, а также можно попросить удалить объект, который идентифицирован как «дефектный».

Мы разработали конвейерную ленту и сцены камеры в Unity. У нас есть статья, в которой более подробно рассказывается о том, как нам нравится общаться в Unity и JS здесь.

Первоначально мы думали, что было бы неплохо вызвать конечную точку классификатора изображений на удаленном сервере. Это компенсировало бы требования к вычислениям от клиента к серверу, но не обошлось бы без подводных камней. Для наших требований обучение классификатора изображений должно происходить в режиме реального времени. Таким образом, связь по сети должна быть быстрой. Это будет плохо масштабироваться, если активность используется многими одновременными пользователями, и каждый клиент может отправлять несколько сотен запросов. Должен быть лучший способ сделать это. О том, как можно разместить модель машинного обучения на бессерверном бэкэнде, мы поговорим в другой статье.

К счастью для нас, ответом на наши проблемы стал TensorFlow.js (TF). Наиболее удобной частью было то, что мы могли использовать существующую модель, обучать ее с помощью трансферного обучения и сразу же выполнять выводы по классификатору, не выходя из браузера.

Классификация изображений с помощью трансферного обучения

TF предоставляет две обученные модели, которые мы можем использовать для создания пользовательского классификатора изображений — Knn-classifier (KNN) и MobileNet.

Модель MobileNet — это модель классификатора изображений, разработанная для работы в средах с ограниченными ресурсами, таких как браузер. Однако, если мы используем модель для прогнозирования частей виджета нашей игры, мы не получим нужных прогнозов. Вместо этого мы получим метку, основанную на очень большом обучающем наборе, на котором был обобщен MobileNet.

Трансферное обучение позволяет нам использовать существующие возможности модели MobileNet и «обучать ее новым приемам». В нашем случае мы хотели бы научить его определять разницу между бракованными и нетронутыми изделиями, спускающимися по конвейерной ленте. Для этого нам понадобится еще одна модель — KNN.

API TF MobileNet предоставляет удобный метод для извлечения признаков изображения. Мы можем использовать эти функции вместе с пользовательской меткой, такой как «дефектная передача» или «хорошая передача», и добавлять примеры в нашу модель KNN. Модель KNN берет K (по умолчанию 3 для этой модели) ближайших соседей похожих изображений. За каждое предсказание. мы возьмем выходные данные MobileNet и передадим их в модель KNN, чтобы увидеть, в какую зону попадает изображение. Зона KNN сопоставляется с идентификационной меткой и дает нам наш пользовательский прогноз «дефектная передача» или «хорошая передача».

Реализация нашего классификатора изображений

Активность классификатора изображений была создана с помощью Unity, React, Redux и Tensforflow.js. Мы использовали React для нашего интерактивного графического интерфейса, Redux для управления и хранения состояния активности и модели Tensorflow.js для создания нашего классификатора изображений. Исходники проекта можно посмотреть здесь.

В этой статье мы в основном сосредоточимся на коде Tensorflow.js, используемом для создания классификатора изображений.

Модель MobileNet загружается асинхронно через действие Redux при загрузке игры. Для загрузки модели требуется некоторое время, поэтому мы хотим убедиться, что модель доступна до начала действия.

// Redux Action
import * as MobileNetModule from '@tensorflow-models/mobilenet';
const loadMobileNet = (model) => {
  return {
    type: 'LOAD_MOBILENET',
    model,
  };
};
export const loadMobileNetAsync = () => {
  return (dispatch) => {
    MobileNetModule.load().then((m) => {
      dispatch(loadMobileNet(m));
    });
  };
}

// Redux Reducer
const featureExtractor = (state = {}, action) => {
  switch (action.type) {
    case "LOAD_MOBILENET":
      return action.model;
    default:
      return state;
  }
};
export default featureExtractor;

Классификатор KNN может быть создан синхронно и инициирован как часть состояния по умолчанию для редуктора классификатора.

import * as knn from '@tensorflow-models/knn-classifier';
const classifier = (state = knn.create(), action) => {
  switch (action.type) {
    default:
      return state;
  }
};
export default classifier;

Добавление примеров в KNN

Каждый пример исходит из двоичного файла, отправленного из Unity с использованием протокола Unity comms между Unity и React. Модель MobileNet (обозначенная как экстрактор признаков в функции) предполагает, что файл имеет тип изображения. Обрабатываем бинарник из Unity как образ.

export const addExample = (binary, label) => {
  const { featureExtractor, classifier } = store.getState();
  preProcessImageWithCallback(binary, (img) => {
    const features = featureExtractor.infer(img, "conv_preds");
    classifier.addExample(features, label);
  });
};

Здесь важно отметить, что мы делаем это асинхронно и вызываем функцию preProcessImageWithCallback только после срабатывания события загрузки изображения. Функция обратного вызова использует модель MobileNet (featureExtractor) для создания набора извлеченных функций для модели KNN (классификатор), чтобы добавить пример.

const preProcessImageWithCallback = async (binary, callback) => {
  let img = new Image(800, 600);
  img.src = `data:image/png;base64, ${binary}`;
  img.onload = () => {
    callback(img);
  };
};

Прогнозирование изображений

Как только KNN будет построена с достаточным набором примеров, прогнозы станут удивительно точными. Функция предсказания возвращает обещание, которое разрешается, как только метод classifier.predict возвращает результат. Если метка результата не соответствует классам меток KNN, мы отклоняем обещание прогнозирования и перестраиваем модель.

Мы также отправляем класс defaultLabel в тех случаях, когда метод прогнозирования вызывается до того, как у KNN появятся примеры.

export const predict = (binary, uid) => {
  return new Promise((resolve, reject) => {
    const { featureExtractor, classifier, labelClasses } = store.getState();
preProcessImageWithCallback(binary, (img) => {
      const features = featureExtractor.infer(img, "conv_preds");
      classifier
        .predictClass(features)
        .then((result) => {
          const labelClass = getLabelClassByName(
            labelClasses.list,
            result.label
          );
          const confidences = result.confidences;
          if (labelClass) {
            resolve({
              ...labelClass,
              uid,
              confidences,
            });
          } else {
            reject("rebuilding model");
          }
        })
        .catch((_) => {
          resolve({
            ...defaultLabelClass,
            uid,
            rgb: {
              a: 0,
              r: 0,
              g: 0,
              b: 0,
            },
          });
        });
    });
  });
};

Повторная загрузка и удаление этикеток

Наш интерфейс активности позволяет учащемуся удалять отдельные метки. К сожалению, такое поведение напрямую не поддерживалось в KNN API. Чтобы поддерживать эту функцию, нам пришлось удалить все примеры, а затем перезагружать примеры из состояния Redux для каждого удаления.

Очистить весь набор меток класса довольно просто.

export const deleteExamples = (label) => {
  const { classifier } = store.getState();
  try {
    classifier.clearClass(label);
  } catch (e) {}
};

Перезагрузка примеров включала в себя взятие всех меток классов KNN и добавление примеров обратно в KNN. Метод classDatasetMatrices KNN возвращает только имена меток и не предоставляет полный объект для каждой метки. Мы сохраняем полный объект в редьюсере labelClasses и перебираем каждый объект, очищаем все примеры, а затем добавляем каждый пример для каждого изображения. Это немного неэффективно с вычислительной точки зрения, и его можно улучшить, добавив прямую поддержку удаления отдельных меток непосредственно в объект KNN. К счастью, задачу можно выполнить с небольшим количеством примеров и не возникает проблем, даже если функция перезагрузки примеров равна O(n²) .

export const reloadAllExamples = () => {
  const { classifier, labelClasses } = store.getState();
  const labels = Object.keys(classifier.classDatasetMatrices);
labelClasses.list.forEach((label) => {
    const labelObject = labels[label.id - 1];
    classifier.clearClass(labelObject);
    label.images.forEach((img) => addExample(img, label.name));
  });
};

Вывод

Надеемся, что эта статья поможет демистифицировать использование классификатора изображений в приложении Javascript. В обзоре нам удалось создать классификатор изображений, сначала извлекая функции нашего изображения с использованием модели MobileNet, вводя эти функции в модель KNN с меткой, а затем вызывая функцию прогнозирования KNN, используя функции MobileNet последующих изображений для прогноз. Нам удалось настроить модель в режиме реального времени без особых накладных расходов. Мы рекомендуем модели TensorFlow.js для любого разработчика JS, заинтересованного в удобном использовании машинного обучения в своих приложениях без зависимости от удаленного сервера. Обязательно ознакомьтесь с другими моделями TF.js для вашего следующего проекта.

Этот материал основан на работе, поддержанной Национальным научным фондом в рамках гранта номер 1937063.

Гитхаб-репозиторий

Онлайн демо

Быстрое прототипирование с Unity и React

Развертывание модели машинного обучения на бессерверной серверной части с помощью SAM CLI