Как использовать оптимизатор в прямом проходе в PyTorch

Я хочу использовать оптимизатор в прямом проходе пользовательской функции, но он не работает. Мой код выглядит следующим образом:

class MyFct(Function):

   @staticmethod
   def forward(ctx, *args):
       input, weight, bias = args[0], args[1], args[2]

       y = torch.tensor([[0]], dtype=torch.float, requires_grad=True) #initial guess
       loss_fn = lambda y_star: (input + weight - y_star)**2

       learning_rate = 1e-4
       optimizer = torch.optim.Adam([y], lr=learning_rate)
       for t in range(5000):
           y_star = y
           print(y_star)
           loss = loss_fn(y_star)
           if t % 100 == 99:
               print(t, loss.item())
           optimizer.zero_grad()
           loss.backward()
           optimizer.step() 

       return y_star

И это мои тестовые данные:

x = torch.tensor([[2]], dtype=torch.float, requires_grad=True)
w = torch.tensor([[2]], dtype=torch.float, requires_grad=True)
y = torch.tensor([[6]], dtype=torch.float)

fct= MyFct.apply
y_hat = fct(x, w, None)

Я всегда получаю RuntimeError: элемент 0 тензоров не требует grad и не имеет grad_fn.

Кроме того, я тестировал оптимизацию вне форварда, и она работает, так что я думаю, что это что-то с контекстом? Согласно документации, «аргументы тензорного типа, отслеживающие историю (например, с require_grad = True), будут преобразованы в те, которые не отслеживают историю до вызова, и их использование будет зарегистрировано на графике», см. https://pytorch.org/docs/stable/notes/exnding.html. Это проблема? Есть ли способ обойти это?

Я новичок в PyTorch, и мне интересно, что я упускаю из виду. Любая помощь и объяснение приветствуются.


person SimonAda    schedule 27.02.2020    source источник


Ответы (1)


Думаю, я нашел здесь ответ: https://github.com/pytorch/pytorch/issues/8847, т.е. мне нужно обернуть опримизацию with torch.enable_grad():.

Однако я до сих пор не понимаю, почему необходимо преобразовывать исходные тензоры в те, которые не отслеживают историю в forward ().

person SimonAda    schedule 27.02.2020