В моем предыдущем посте я подробно рассказал, что такое федеративное обучение. Здесь я расскажу, как создать вашу собственную модель на основе федеративного обучения, используя фреймворк под названием Цветок.

Мы рассмотрим кросс-девайсный и асинхронный дизайн. Это очень похоже на GBoard и Siri, где локальная модель находится на граничном устройстве (в данном случае на вашем телефоне / Mac).

Давайте сразу перейдем к компонентам, необходимым для построения вашей собственной модели. Федеративная система обучения нуждается в двух частях

  1. Сервер
  2. Клиент.

Специалист по данным имеет полный контроль над сервером. На сервере размещается логика агрегации и проверяется, что все устройства имеют самые последние и обновленные параметры модели.

У клиентов (устройств) есть локальная модель, работающая на локальных данных.

В нашем случае мы будем выполнять следующие шаги.

  1. Мы построим простую модель нейронной сети на основе pytorch для чтения изображений и их классификации.
  2. Сначала мы обучим модель локальным данным в клиенте. Давайте начнем с 3 устройств, поэтому у нас есть 3 локально работающих модели на 3 отдельных устройствах.
  3. После того, как наша модель обучена и у нас есть параметры модели, мы пытаемся подключиться к серверу.
  4. Затем сервер либо принимает, либо отклоняет приглашение подключиться на основе некоторой политики. Здесь мы просто воспользуемся политикой «первым пришел - первым обслужен».
  5. Если соединение проходит, клиент отправляет параметры модели на сервер.
  6. Сервер ожидает всех 3 параметров модели, а затем агрегирует их, таким образом используя все данные во всех моделях.
  7. Это может происходить в течение любого количества эпох, сколько мы хотим обучать данные.
  8. Затем сервер отправляет клиентам обновленные весовые параметры.
  9. Теперь клиент будет использовать веса для классификации изображений.

Давайте создадим файл caller server.py и добавим следующие строки:

import flwr as fl

# Start Flower server for three rounds of federated learning
if __name__ == "__main__":
    strategy = fl.server.strategy.FedAvg(
        fraction_fit=0.1,
        min_available_clients=3
)
fl.server.start_server("[::]:8080", config={"num_rounds": 3}, , strategy=strategy)

Это все, что нам нужно для запуска сервера, привязанного к localhost. Стратегия - это наша политика. Num_rounds указывает, что обучение будет продолжаться 3 раунда. В каждом раунде может быть свой набор клиентов в зависимости от того, какие устройства подключались первыми. Fraction_fit отбирает 10% всех доступных клиентов в каждом раунде. min_available_client - минимальное количество клиентов, которые необходимо подключить для начала обучения. Вы можете найти различные способы определения своей политики здесь.

Вы можете разместить свой server.py в AWS в EC2 или Sagemaker. Или запустите его на своей рабочей станции.

Теперь напишем нашего клиента. Вы можете найти блокнот colab здесь и код на git здесь. Flwr основан на GRPC, которого нет в бесплатной совместной версии. Вы можете создать экземпляр докера или запустить его на своей рабочей станции.

Единственная разница между ИНС, не основанной на FL, и FL заключается в подключении к серверу и получении обновленных весов. Здесь мы рассмотрим раздел записной книжки «Федеративное обучение».

class CifarClient(fl.client.NumPyClient):
def get_parameters(self):
    return [val.cpu().numpy() for _, val in   net.state_dict().items()]
def set_parameters(self, parameters):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)
def fit(self, parameters, config):
    self.set_parameters(parameters)
    train(net, trainloader, epochs=1)
    return self.get_parameters(), len(trainloader), {}
def evaluate(self, parameters, config):
    self.set_parameters(parameters)
    loss, accuracy = test(net, testloader)
    return float(loss), len(testloader), {"accuracy": accuracy}

Это самый важный класс, реализующий flowr. функция -

  1. get_parameters: возвращает параметры модели на сервер в виде списка ndarrays NumPy.
  2. set_parameters: устанавливает параметры модели в клиенте из списка NumPy ndarrays.
  3. fit: устанавливает параметры модели, обучает модель в клиенте и возвращает обновленные параметры модели на сервер.
  4. оценить: устанавливает параметры модели, оценивает модель на локальном наборе тестовых данных в клиенте и возвращает результат на сервер.

теперь вы можете запустить последнюю ячейку и проверить, как на точность вашей модели повлияло централизованное обучение. (Убедитесь, что у вас есть 3 копии client.py, поскольку мы упоминали min_available_clients = 3.

Использованная литература:



Https://github.com/adap/flower