Более глубокий взгляд на то, как операции потока управления работают в TensorFlow

Вступление

Tensorflow - одна из самых популярных платформ глубокого обучения, которая сыграла ключевую роль в продвижении глубокого обучения. Я использую Tensorflow более двух лет, но я видел много странного и непредсказуемого поведения при использовании потока управления.

Недавно (19 апреля 2019 г.) я посмотрел видео собственных внутренних тренингов команды TensorFlow, которое было очень полезно и прояснило, как работают операции управления потоком. Однозначно рекомендую посмотреть это видео.

В видео подробно рассматриваются функции tf.cond () и tf. While_loop. Итак, я решил написать этот пост, чтобы подробнее рассказать о том, как работает tf.cond (), и предоставить несколько примеров для иллюстрации. Надеюсь, я расскажу о tf. while_loop в следующем посте.

Примечание. В этом посте я расскажу о низкоуровневых операциях. Есть и другие операции, такие как функциональные операции, которые выходят за рамки этого сообщения в блоге.

Переключить и объединить

Две важные операции, которые используются при построении графика, - это Switch и Merge. Поэтому в этом разделе я расскажу, как они работают, и приведу несколько примеров, чтобы познакомиться с их странным поведением!

Как видите, коммутатор получает два входа: данные и предикат, и обеспечивает два выхода: данные и мертвый тензор !. Кроме того, Merge принимает два (или более двух) входных данных и предоставляет один выход, который является данными. Я собираюсь перейти к более подробной информации ниже.

Переключить

Во-первых, давайте рассмотрим переключатель. Если вы посетите веб-сайт Tensorflow, вы можете найти это определение и краткое описание работы коммутатора:

Пересылает data на выходной порт, определенный pred. Если pred истинно, ввод data перенаправляется на output_true. В противном случае данные поступают в output_false.

Как я упоминал ранее, Switch получает два входа. Один из них - это предикат, который является логическим тензором (истинным или ложным), а другой - это данные, которые должны быть переданы. Предикат определяет, должны ли данные передаваться ветвью output_true или ветвью output_false. Но одна странность здесь - это концепция мертвого тензора. Независимо от того, является ли предикат истинным или ложным; всегда есть два выхода: один из них - данные, а другой - мертвый тензор. Если pred истинно, мертвый тензор отправляется вместе с output_false (и наоборот).

Нет четкой ссылки, объясняющей, почему и как полезны мертвые тензоры, но кажется, что они полезны для распределенной обработки, и их существование является деталью реализации. например, здесь можно найти как-то убедительный ответ:

Прежде всего, мертвые тензоры - это деталь реализации конструкций потока управления TensorFlow: tf.cond() и tf.while_loop(). Эти конструкции позволяют TensorFlow определять, выполнять ли подграф на основе значения, зависящего от данных.

Давайте посмотрим на примере, чтобы прояснить ситуацию. Я черпал вдохновение из этого поста, чтобы предоставить этот пример.

import tensorflow as tf
from tensorflow.python.ops import control_flow_ops
x_0, x_1 = control_flow_ops.switch(tf.constant(1.0), False)
x_2, x_3 = control_flow_ops.switch(tf.constant(2.0), True)
print(x_0, x_1, x_2, x_3)
with tf.Session() as sess:
    print(sess.run(x_0))    # prints 1.0
    print(sess.run(x_3))    # prints 2.0
'''
output:
Tensor("Switch:0", shape=(), dtype=float32) Tensor("Switch:1", shape=(), dtype=float32) Tensor("Switch_1:0", shape=(), dtype=float32) Tensor("Switch_1:1", shape=(), dtype=float32)
1.0
2.0
'''

Итак, давайте посмотрим, что происходит в этом примере. Я создал рисунок, который иллюстрирует происходящее.

Думаю, из рисунка видно, что происходит. например, в x_0, x_1 = control_flow_ops.switch(tf.constant(1.0), False) предикат ложный; следовательно, tf.constant(1.0) пересылается в ветвь output_false, а тензор мертвых данных - в ветвь output_true.

Следует отметить одну важную вещь: я выполнил x_0 и x_3 в tf.Session (), которые содержат данные (тензор). Если я попытаюсь запустить и выполнить мертвый тензор, я столкнусь с ошибкой. Всякий раз, когда вы пытаетесь выполнить и получить мертвый тензор в Session.run (), это приведет к ошибке. например, следующий код вызывает известную и часто возникающую ошибку:

with tf.Session() as sess:
    print(sess.run(x_1))
'''
output:
InvalidArgumentError: Retval[0] does not have value
'''

Теперь, думаю, для Switch этого достаточно. давайте посмотрим, как работает Merge.

Объединить

Слияние - еще один оператор, который необходим для построения графа tf.cond ().

Слияние может принимать более одного входа, но только один из них должен содержать данные, а другие должны быть мертвыми тензорами. В противном случае мы столкнемся с каким-то случайным и непредсказуемым поведением. Давайте посмотрим, как работает Merge в последнем примере:

with tf.Session() as sess:
    print(sess.run(control_flow_ops.merge([x_0, x_1])))       
    print(sess.run(control_flow_ops.merge([x_1, x_0])))       
    print(sess.run(control_flow_ops.merge([x_2, x_3])))   
    print(sess.run(control_flow_ops.merge([x_3, x_2])))     
    print(sess.run(control_flow_ops.merge([x_0, x_1, x_2])))
'''
output:
Merge(output=1.0, value_index=0)
Merge(output=1.0, value_index=1)
Merge(output=2.0, value_index=1)
Merge(output=2.0, value_index=0)
Merge(output=1.0, value_index=0)
Merge(output=2.0, value_index=2)
'''

Он ведет себя полностью в соответствии с нашими ожиданиями. Но все становится немного неожиданным и странным, когда вы загружаете два тензора, у которых есть данные, в Merge.

with tf.Session() as sess:
    print(sess.run(control_flow_ops.merge([x_1, x_0, x_3]))) 
    print(sess.run(control_flow_ops.merge([x_0, x_3])))
    print(sess.run(control_flow_ops.merge([x_3, x_0])))
'''
output:
Merge(output=1.0, value_index=1)
Merge(output=1.0, value_index=0)
Merge(output=2.0, value_index=0)
'''

Иногда он возвращает значение x_0, а иногда значение x_3. Так что будьте осторожны с таким поведением.

Примечание: мертвые тензоры распространяются по вычислительному графу, пока не достигнут операций слияния.

tf.cond ()

Теперь, я думаю, мы хорошо понимаем, как работают Switch и Merge. Пришло время погрузиться в tf.cond (). Я рассматриваю простой случай, когда входными аргументами являются pred, true_fn и false_fn.

tf.cond(pred, true_fn, false_fn)

Я собираюсь рассмотреть простой пример, чтобы представить эту концепцию. учтите следующее условие:

tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))

Я построил вычислительный граф для этого простого примера, и вы можете найти его на рисунке 4.

Прежде всего следует упомянуть, что для каждого входа есть переключатель. Под вводом я подразумеваю аргументы функций true и false в tf.cond (). В этом примере есть три входа (x, y и z), и, как следствие, есть три переключателя на вычислительном графе.

Для true_fn выходы переключателя испускаются из истинной ветви. Для false_fn выходы переключателя испускаются из ветви false. В зависимости от результата условия (независимо от того, меньше ли x, чем y или нет), предикат может быть истинным или ложным, и одна из ветвей (левая или правая) будет выполнена. Важно отметить, что операции tf.add() и tf.square() идут после переключателей. В результате в этом примере будет выполнен только один из них, а другой останется нетронутым.

Кроме того, я считаю, что эта картина немного неправильная. Я думаю, что мертвые тензоры распространяются через операцию сложения или возведения в квадрат, пока не встретятся с операциями слияния. Операции слияния удаляют мертвые тензоры и обеспечивают только один вывод.

Надеюсь, вы узнали кое-что о tf.cond (), и вам стало удобнее работать с этим API. Я собираюсь закончить этот пост, приведя один противоречивый пример и объясняя, как то, что мы уже узнали, может помочь нам понять внутреннюю работу. На веб-сайте TensorFlow вы можете найти следующее утверждение:

ПРЕДУПРЕЖДЕНИЕ: любые тензоры или операции, созданные вне true_fn и false_fn, будут выполняться независимо от того, какая ветвь выбрана во время выполнения. Хотя такое поведение согласуется с моделью потока данных TensorFlow, оно часто удивляет пользователей, ожидавших более ленивой семантики.

Итак, я собираюсь привести пример, чтобы прояснить, о чем говорится в этом предупреждении. Я привожу два примера: в первом все операции определены в true_fn и false_fn, а во втором примере некоторые операции определены вне этих функций. Я собираюсь построить и визуализировать вычислительный граф, чтобы проиллюстрировать, почему происходит такое поведение.

Пример 1:

import tensorflow as tf
x = tf.constant(3.0)
y = tf.constant(2.0)
def true_fn():
    z = tf.multiply(x, y)
    print_output = tf.Print(z, [z], "The value I want to print!!")
    return tf.add(x, print_output)
def false_fn():
    return tf.square(y)
result = tf.cond(x < y, true_fn, false_fn)
with tf.Session() as sess:
    print(sess.run(result))
## output: 4.0
'''
if you keep everything the same and just changing x to x = tf.constant(1.0), the predicate becomes true and the output will be as the following:
3.0
The value I want to print!![2]
'''

Здесь важно сосредоточить внимание на том, что все тензоры и операции были созданы внутри функций. Итак, имеется три входных аргумента и, следовательно, на графике существует три переключателя. Построить вычислительный граф для этого случая будет несложно.

если предикат становится истинным (x будет меньше y), будет выполняться true_fn (левая ветвь), а правая не будет выполняться и останется нетронутой (и наоборот).

Примечание: я использовал функцию tf.Print (), чтобы что-то напечатать на вычислительном графике и получить доступ к значению тензора на графике. Использование tf.Print () немного сложно, и я не собираюсь здесь объяснять, как это работает. Об этой функции есть отличный пост в блоге здесь.

Примечание. Когда предикат ложен (x ›y), выполняется false_fn (правая ветвь), и в результате tf.Print () получает только мертвые тензоры и ничего не печатает.

Пример 2:

Пример 1 был немного скучным, и результат полностью оправдал наши ожидания. В этом примере все становится еще интереснее.

x = tf.constant(3.0)
y = tf.constant(2.0)
z = tf.multiply(x, y)
print_output = tf.Print(z, [z], "The value I want to print!!")
def true_fn():
    return tf.add(x, print_output)
def false_fn():
    return tf.square(y)
result = tf.cond(x < y, true_fn, false_fn)
with tf.Session() as sess:
    print(sess.run(result))
'''
output:
4.0
The value I want to print!![6]
'''

В этом примере предикат false (x ›y), и мы ожидаем, что false_fn executes и true_fn останутся нетронутыми. Однако мы видим, что вывод содержит «Значение, которое я хочу напечатать !! [6]», которое принадлежит true_fn. На первый взгляд, такое поведение может показаться немного странным, но оно полностью соответствует тому, что мы видели и понимали до сих пор. Некоторые из тензоров (z и print_output) определены вне функции, и в результате они будут помещены перед переключателем в вычислительном графе. давайте нарисуем график, чтобы прояснить этот момент:

Вы можете видеть на рисунке 6, что операции умножения и печати находятся за пределами (перед) переключателями. Таким образом, независимо от того, является ли предикат истинным или ложным, эти две операции будут выполнены в обоих случаях.

Итак, понимая переключение и слияние и понимая, как работает tf.cond (), надеюсь, вы увидите, что это поведение согласуется с моделью потока данных TensorFlow, и в этом нет ничего плохого.

Я собираюсь закончить этот пост здесь. Спасибо, что дочитали пост до конца. Пожалуйста, дайте мне знать, если я допустил ошибку или что-то не так. Надеюсь, я расскажу о tf. while_loop () в следующем посте.

использованная литература