Я использую tf2.0
и хочу использовать tf.keras
и tf.data.dataset
для обучения моей сети. Однако я изо всех сил пытаюсь использовать tf.keras.fit
с tf.data.dataset
с несколькими выходами и с функцией настраиваемых потерь вместе.
Моя версия тензорного потока - tf2.0
, а вот пример кода, который я пробовал, но потерпел неудачу.
import tensorflow as tf
import numpy as np
# define model
inputs = tf.keras.Input((512,512,3), name='model_input')
x = tf.keras.layers.Conv2D(filters=256, kernel_size=3, padding='same',kernel_initializer=tf.random_normal_initializer(stddev=0.01), name='conv1')(inputs)
x = tf.keras.layers.Conv2D(filters=256, kernel_size=3, padding='same',kernel_initializer=tf.random_normal_initializer(stddev=0.01), name='conv2')(x)
output1 = tf.keras.layers.Conv2D(filters=256, kernel_size=3, padding='same',kernel_initializer=tf.random_normal_initializer(stddev=0.01), name='output1')(x)
output2 = tf.keras.layers.Conv2D(filters=256, kernel_size=3, padding='same',kernel_initializer=tf.random_normal_initializer(stddev=0.01), name='output2')(x)
model = tf.keras.Model(inputs, [output1, output2])
# define dataset
def parse_func(single_data): # just for example case
input = single_data
output1 = single_data
output2 = single_data
weight1 = output1
weight2 = output2
return input, output1, output2, weight1, weight2
def tf_parse_func(single_data):
return tf.py_function(parse_func, [single_data], [tf.float32, tf.flaot32, tf.float32, tf.flaot32, tf.float32])
data = np.random.rand(10, 512, 512, 3)
dataset = tf.data.Dataset.from_tensor_slices(data)
dataset = dataset.map(tf_parse_func, num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.batch(2, drop_remainder=True)
# def loss func
def loss_fn1(label, pred):
return tf.reduce_mean(tf.keras.losses.MSE(label, pred))
def loss_fn2(label, pred):
return tf.nn.l2_loss(label-pred)
# start training
model.compile(loss={'output1':loss_fn1, 'output2':loss_fn2},
loss_weights={'output1':1, 'output2':2},
optimizer=tf.keras.optimizers.Adam())
model.fit(dataset, epochs=5)
На самом деле я хочу пройти
loss_weights={'output1': 1, 'output2': 2}
нравится
loss_weights={'output1': weight1, 'output2': weight2}
,
но я не знал, как это сделать. Лучше передать weight1/weight2
в качестве параметра функции потерь, но я не знал, как это сделать. Я хочу loss_fn1
использовать output1, weight1
из dataset
и loss_fn2
использовать output2, weight2
.
Когда я запускаю приведенный выше код, я получаю такую ошибку:
2019-10-22 20:47:40.551618: W tensorflow/core/framework/cpu_allocator_impl.cc:81] Allocation of 62914560 exceeds 10% of system memory.
1/Unknown - 0s 28ms/stepTraceback (most recent call last):
File "tools/keras_train_test.py", line 65, in <module>
model.fit(dataset, epochs=5)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training.py", line 728, in fit
use_multiprocessing=use_multiprocessing)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_v2.py", line 324, in fit
total_epochs=epochs)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_v2.py", line 123, in run_one_epoch
batch_outs = execution_function(iterator)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_v2_utils.py", line 86, in execution_function
distributed_function(input_fn))
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py", line 457, in __call__
result = self._call(*args, **kwds)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py", line 503, in _call
self._initialize(args, kwds, add_initializers_to=initializer_map)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py", line 408, in _initialize
*args, **kwds))
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py", line 1848, in _get_concrete_function_internal_garbage_collected
graph_function, _, _ = self._maybe_define_function(args, kwargs)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py", line 2150, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py", line 2041, in _create_graph_function
capture_by_value=self._capture_by_value),
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/func_graph.py", line 915, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py", line 358, in wrapped_fn
return weak_wrapped_fn().__wrapped__(*args, **kwds)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_v2_utils.py", line 66, in distributed_function
model, input_iterator, mode)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_v2_utils.py", line 112, in _prepare_feed_values
inputs, targets, sample_weights = _get_input_from_iterator(inputs)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_v2_utils.py", line 145, in _get_input_from_iterator
x, y, sample_weights = next_element
ValueError: too many values to unpack (expected 3)
Я пробовал много подходов и других способов, которые я нашел, но не могу заставить это работать. Так может ли кто-нибудь мне помочь? Большое спасибо!