Лучшее понимание размеров PyTorch за счет визуализации процесса суммирования по трехмерному тензору

Когда я начал выполнять некоторые базовые операции с тензорами PyTorch, такие как суммирование, для одномерных тензоров это выглядело просто и довольно просто:

>> x = torch.tensor([1, 2, 3])
>> torch.sum(x)
tensor(6)

Однако, как только я начал экспериментировать с тензорами 2D и 3D и суммировать по строкам и столбцам, меня больше всего смущал второй параметрdimof torch.sum.

Начнем с того, что говорится в официальной документации:

torch.sum (input, dim, keepdim = False, dtype = None) → Tensor

Возвращает сумму каждой строки входного тензора в заданном измерении dim.

Я не совсем понимаю это объяснение. Мы можем подвести итог по столбцам, так почему же можно упомянуть, что он просто «возвращает сумму каждой строки»? Это было мое первое непонимание.

Однако, как я уже сказал, более важной проблемой было направление каждого измерения. Вот что я имею в виду. Когда мы описываем форму двумерного тензора, мы говорим, что он содержит несколько строк и несколько столбцов. Итак, для тензора 2x3 у нас есть 2 строки и 3 столбца:

>> x = torch.tensor([
     [1, 2, 3],
     [4, 5, 6]
   ])
>> x.shape
torch.Size([2, 3])

Мы указываем сначала строки (2 строки), а затем столбцы (3 столбца), верно? Это привело меня к выводу, что первое измерение (dim = 0) остается для строк, а второе (dim = 1) для столбцов. Исходя из того, что размер dim = 0 означает построчное, я ожидал, что torch.sum(x, dim=0) приведет к 1x2 тензор (1 + 2 + 3 и 4 + 5 + 6 для исхода tensor[6, 15]). Но оказалось, что я получил другое: тензор 1x3.

>> torch.sum(x, dim=0)
tensor([5, 7, 9])

Я был удивлен, увидев, что реальность оказалась противоположной тому, что я ожидал, потому что я наконец получил результат tensor[6, 15], но при передаче параметра dim = 1:

>> torch.sum(x, dim=1)
tensor([6, 15])

Так почему это так? Я нашел статью Aerin Kim 🙏, в которой разбирается та же путаница, но для матриц NumPy, где мы передаем второй параметр под названием axis. Сумма NumPy почти идентична той, что есть в PyTorch, за исключением того, что dim в PyTorch называется axis в NumPy:

numpy.sum (a, axis = None, dtype = None, out = None, keepdims = False)

Ключом к пониманию того, как работают dim в PyTorch и axis в NumPy, был этот абзац из статьи Aerin:

Способ понять «ось» числовой суммы состоит в том, что она сворачивает указанную ось. Таким образом, когда он сворачивает ось 0 (строку), он становится всего одной строкой (суммируется по столбцам).

Она очень хорошо объясняет функционирование параметра axis в numpy.sum. Однако становится сложнее, когда мы вводим третье измерение. Когда мы посмотрим на форму трехмерного тензора, мы заметим, что новое измерение добавляется в начало и занимает первую позицию (выделено полужирным шрифтом ниже), то есть третье измерение становится dim=0.

>> y = torch.tensor([
     [
       [1, 2, 3],
       [4, 5, 6]
     ],
     [
       [1, 2, 3],
       [4, 5, 6]
     ],
     [
       [1, 2, 3],
       [4, 5, 6]
     ]
   ])
>> y.shape
torch.Size([3, 2, 3])

Да, это довольно запутанно. Вот почему я думаю, что некоторые базовые визуализации процесса суммирования по различным измерениям в значительной степени способствуют лучшему пониманию.

Первое измерение (dim = 0) этого трехмерного тензора является наивысшим и содержит 3 двумерных тензора. Итак, чтобы подвести итог, мы должны сложить 3 его элемента друг над другом:

>> torch.sum(y, dim=0)
tensor([[ 3,  6,  9],
        [12, 15, 18]])

Вот как это работает:

Для второго измерения (dim = 1) мы должны свернуть строки:

>> torch.sum(y, dim=1)
tensor([[5, 7, 9],
        [5, 7, 9],
        [5, 7, 9]])

И, наконец, третье измерение обрушивается на столбцы:

>> torch.sum(y, dim=2)
tensor([[ 6, 15],
        [ 6, 15],
        [ 6, 15]])

Если вы, как и я, недавно начали изучать PyTorch или NumPy, я надеюсь, что эти базовые анимированные примеры помогут вам лучше понять, как работают измерения, не только для sum, но и для других методов, таких как хорошо.

Спасибо за прочтение!

Использованная литература:

[1] А. Ким, Интуиция оси Numpy Sum