При прохождении курсов ML/глубокого обучения обычно имеется доступ к предварительно загруженным наборам данных, таким как MNIST, FashionMNIST и т. д., поскольку они поставляются с такими платформами глубокого обучения, как Tensorflow. Но в реальных приложениях ваш набор данных, скорее всего, не будет доступен в таком виде. Кроме того, он может быть слишком большим, чтобы сразу поместиться в память.

TFRecords предлагает уникальное решение проблемы обучения моделей с большими наборами данных. Запись тензорного потока (tfrecord) — это формат двоичного файла, предназначенный для эффективного хранения и загрузки больших наборов данных. Они упрощают обучение модели машинного обучения и хорошо работают с различными библиотеками глубокого обучения и машинного обучения.

Как?

Здесь я покажу, как использовать tfrecords. Во-первых, вам нужно преобразовать свой набор данных в tfrecords, прочитать файл tfrecord и, наконец, как обучить модель машинного обучения с использованием tfrecords.

В этой части я бы сосредоточился исключительно на преобразовании вашего набора данных в tfrecords.

Для этого я буду использовать набор данныхRSNA Screening Mammography обнаружения рака молочной железы. Он содержит более 50 000 медицинских изображений и имеет размер более 300 ГБ, поэтому он идеально подходит для этого. Он также имеет файл train.csv, который содержит другую полезную информацию об изображениях и цели.

Первым шагом обычно является загрузка набора данных в сеанс ноутбука juypter или на локальный компьютер. Я использую блокнот Kaggle, так что это довольно просто.

Вот снимок набора данных и первых двух столбцов файла train.csv.

Основное внимание здесь должно быть уделено изображению и соответствующему ему целевому значению в качестве записи/примера в файле tfrecord.

Первый шаг — указать количество записей, которые мы хотим в каждом файле tfrecord. Обычно количество записей зависит от вашего набора данных. Я бы использовал здесь 1000 записей.

NUM_RECORDS = 1000

Затем, исходя из количества записей, мы определяем, сколько файлов tfrecord нам потребуется для хранения всего набора данных.

num_tfrecords = max(train.shape) // NUM_RECORDS

if max(train.shape) % NUM_RECORDS:
    num_tfrecords += 1  # add one record if there are any remaining samples

Затем мы определяем некоторые вспомогательные функции.

Чтобы использовать tfrecords, мы должны определить вспомогательные функции, которые помогут превратить данные в объект tf.train.Feature, который впоследствии будет использоваться в качестве функции в буфере протокола tf.train.Example. Здесь нам понадобятся разные вспомогательные функции для обработки разных типов данных в нашем наборе данных. Поскольку в наборе данных есть два типа данных, np.ndarray и int64, определены две вспомогательные функции.

Вот вспомогательные функции:

def image_feature(value):
    """Returns a bytes_list from a string / byte."""
    return tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[tf.io.encode_png(value).numpy()])
    )

def int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

tf.io.encode_png конвертирует и сжимает тензоры в png. tf.io также содержит функцию encode_jpg для преобразования тензоров в jpg.

Здесь я показываю дополнительный шаг для обработки изображений из файлов dicom. Файлы Dicom (`.dcm`) используются для хранения медицинских изображений, а также содержат информацию о пациентах. Чтобы прочитать файлы .dcm в python, используйте библиотеку pydicom. Вот пример

import pydicom as dicom

dcm_path = '.../10006/1459541791.dcm'
ds = dicom.dcm_read(dcm_path) #read dcm file from directory
image = ds.pixel_array #get image

С помощью этой информации мы создаем функцию предварительной обработки, которая считывает файл dcm из пути и преобразует его в тензор (224, 224, 1).

def process_image(dcm_path):
    """Read Image from path and resize it to (224, 224)"""
    ds = dicom.dcmread(image_path)
    image = cv2.resize(ds.pixel_array, [224, 224]).reshape(224, 224, 1)
    return image

Далее мы создаем функцию, которая записывает каждый пример в файл tfrecord. Функция принимает предварительно обработанное изображение и целевое значение в качестве входных данных и возвращает экземпляр tf.train.Example . Каждая функция обрабатывается с помощью вспомогательных функций, определенных выше.

def create_example(image, target):
    """Write example """
    feature = {
        "image": image_feature(image),
        "target": int64_feature(target)
    }   
    return tf.train.Example(features=tf.train.Features(feature=feature))

Наконец, с помощьюtf.io.TFRecordWriter мы записываем набор данных в файлы tfrecord.

for tfrec_num in range(num_tfrecords):
    samples = train['image_id'][(tfrec_num * NUM_RECORDS) : ((tfrec_num + 1) * NUM_RECORDS)]
    tf_dir = tfrecords_dir + f'/tfrecord_{tfrec_num * NUM_RECORDS}-{(tfrec_num + 1) * NUM_RECORDS}.tfrec'
    with tf.io.TFRecordWriter(tf_dir) as writer:
        for sample in samples:
            image_path = train[train['image_id'] == sample]['image_path'].iloc[0]
            image = process_image(image_path)
            target = train[train['image_id'] == sample]['cancer'].iloc[0]
            record = create_example(image, target)
            writer.write(record.SerializeToString())

Приведенный выше код записывает набор данных в файлы tfrecord. Первые две строки получают сэмплы/данные для добавления в файл tfrecord и называют файл tfrecord_1000–2000.tfrec (для сэмплов с 1000 по 2000).

Данные записываются в файл tfrecord по одному. Таким образом, используя цикл for, мы перебираем сэмплы и с помощью tf.io.TFRecordWriter записываем каждый сэмпл (пример) в файл tfrecord. В последней строке обратите внимание на использование методов write и SerializeToString. Метод write записывает пример, а SerializeToString преобразует пример в двоичную строку.

Во второй части мы сосредоточимся на чтении данных из файла tfrecord и обучении глубокой нейронной сети с файлами tfrecord.

Вы можете прочитать это здесь: TFRecords (Часть 2): Чтение и обучение моделей с помощью Tfrecords.