Tflearn Ошибка формы ввода

Ошибка

«Невозможно передать значение формы (128, 1) для Tensor 'TargetsData/Y:0', которое имеет форму '(?,)'».

Код

У меня 4 класса и словарный запас состоит из 17355 слов.

tf.reset_default_graph()
net = tflearn.input_data(shape=(None,trainX.shape[1]),name='input')
net = tflearn.fully_connected(net, 200, activation='ReLU')
net = tflearn.fully_connected(net, 25, activation='ReLU')
net = tflearn.fully_connected(net, 4, activation='softmax')
net = tflearn.regression(net, optimizer='sgd', 
                         learning_rate=0.1, 
                         to_one_hot = True,n_classes =4,
                         loss='categorical_crossentropy')

model = tflearn.DNN(net)
model.fit(trainX, trainY, validation_set=0.1, show_metric=True, batch_size=128, n_epoch=100)

trainX.shape = 12384,17355, trainY.shape = 12384,1, testX.shape = 1376,17355, testY.shape = 1376,1


person Amazon_Warrior    schedule 26.08.2017    source источник


Ответы (1)


Что вызывает эту ошибку?

Ошибка "Невозможно передать значение формы... для тензора 'TargetsData/Y:0', форма которого..." в основном вызвана тем, что форма trainY отличается от формы заполнителя оценщика ( слой регрессии.

Почему?

В вашем случае основная проблема заключается в том, что форма trainY (?, 1), которая является двумерным тензором, но форма заполнителя (?,), которая является одномерным тензором. Итак, мы получаем эту ошибку.

Как это решить?

Преобразуйте trainY в одномерный тензор. Поскольку вы установили to_one_hot = True в слое регрессии, поэтому форма-заполнитель представляет собой одномерный тензор, который содержит индексы класса. Для получения подробной информации вы можете проверить исходный код: tflearn/tflearn/layers /estimator.py о регрессии:

   with tf.name_scope(pscope):
        p_shape = [None] if to_one_hot else input_shape
        placeholder = tf.placeholder(shape=p_shape, dtype=dtype, name="Y") 

Итак, нам нужно изменить форму поезда Y с (12384,1) на (12384,) перед подачей в модель.

person Yangguang    schedule 28.08.2017