InvalidArgumentError: Входные данные для изменения формы — это тензор с 0 значениями, но запрошенная форма имеет 54912

Очень начинающий вопрос, надеюсь, все в порядке

Я пытаюсь обучить эту модель из GitHub с помощью Набор данных MAPS, и я сделал новые .tfrecords с этим кодом для набора поездов. Он основан на коде здесь, но Я изменил некоторые вещи, чтобы освободить место для другого входа (еще один MIDI-файл, который я просто называю «tempo MIDI»).

def create_train_set(tempopath, train_list, outdir, min_length, max_length):
  # train_list = list of wav paths selected for  

  train_file_pairs = []

  # find matching midi files

  for wav_path in train_list:
    midi_file = ''
    tempo_midi_file = ''

    if os.path.isfile(wav_path + '.mid'):
      midi_file = wav_path + '.mid'
    if os.path.isfile(wav_path + '.midi'):
      midi_file = wav_path + '.midi'

    if os.path.isfile(tempopath + os.path.basename(wav_path) + '_tempo.mid'):
      tempo_midi_file = tempopath + os.path.basename(wav_path) + '_tempo.mid'
    if os.path.isfile(tempopath + os.path.basename(wav_path) + '_tempo.midi'):
      tempo_midi_file = tempopath + os.path.basename(wav_path) + '_tempo.midi'

    wav_file = wav_path + '.wav'   
    train_file_pairs.append((wav_file, midi_file, tempo_midi_file))

  train_output_name = os.path.join(outdir, 'train.tfrecord')

  with tf.python_io.TFRecordWriter(train_output_name) as writer:
    for idx, pair in enumerate(train_file_pairs):
      print('{} of {}: {}'.format(idx, len(train_file_pairs), pair[0]))
      # load the wav data
      wav_data = tf.gfile.Open(pair[0], 'rb').read()
      # load the midi data and convert to a notesequence
      ns = midi_io.midi_file_to_note_sequence(pair[1])
      tempo = midi_io.midi_file_to_note_sequence(pair[2])
      # aldu = audio_label_data_utils.py
      for example in aldu.process_record(          
          wav_data, ns, tempo, pair[0], min_length, max_length,
          sample_rate):       
        writer.write(example.SerializeToString())

с tf.Example следующим образом:

  example = tf.train.Example(
      features=tf.train.Features(
          feature={
              'id':
                  tf.train.Feature(
                      bytes_list=tf.train.BytesList(
                          value=[example_id.encode('utf-8')])),
              'sequence':
                  tf.train.Feature(
                      bytes_list=tf.train.BytesList(
                          value=[ns.SerializeToString()])),
              'audio':
                  tf.train.Feature(
                      bytes_list=tf.train.BytesList(value=[wav_data])),
              'tempo':
                  tf.train.Feature(
                      bytes_list=tf.train.BytesList(
                          value=[velocity_range.SerializeToString()])),                        
              'velocity_range':
                  tf.train.Feature(
                      bytes_list=tf.train.BytesList(
                          value=[velocity_range.SerializeToString()])),          
          })) 

Однако, когда я пытаюсь обучить модель, я получаю это сообщение об ошибке (я пометил скрипты py строкой печати, чтобы знать, куда все идет):

Running wav_to_spec from data.py
Running _wav_to_mel in data.py
Running wav_to_num_frames from data.py
Running wav_to_spec from data.py
Running _wav_to_mel in data.py
Running wav_to_num_frames from data.py

E0611 07:56:55.419340  8436 error_handling.py:70] Error recorded from training_loop: Input to reshape is a tensor with 0 values, but the requested shape has 54912
         [[{{node Reshape_8}}]]
         [[IteratorGetNext]]
I0611 07:56:55.420338  8436 error_handling.py:96] training_loop marked as finished
W0611 07:56:55.421335  8436 error_handling.py:130] Reraising captured error
Traceback (most recent call last):
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1356, in _do_call
    return fn(*args)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1341, in _run_fn
    options, feed_dict, fetch_list, target_list, run_metadata)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1429, in _call_tf_sessionrun
    run_metadata)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input to reshape is a tensor with 0 values, but the requested shape has 54912
         [[{{node Reshape_8}}]]
         [[IteratorGetNext]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "onsets_frames_transcription_train.py", line 128, in <module>
    console_entry_point()
  File "onsets_frames_transcription_train.py", line 124, in console_entry_point
    tf.app.run(main)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\platform\app.py", line 40, in run
    _run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\absl\app.py", line 300, in run
    _run_main(main, args)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\absl\app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "onsets_frames_transcription_train.py", line 120, in main
    additional_trial_info=additional_trial_info)
  File "onsets_frames_transcription_train.py", line 95, in run
    num_steps=FLAGS.num_steps)
  File "C:\Users\User\magenta\magenta\models\onsets_frames_transcription\train_util.py", line 134, in train
    estimator.train(input_fn=transcription_data, max_steps=num_steps)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow_estimator\python\estimator\tpu\tpu_estimator.py", line 2876, in train
    rendezvous.raise_errors()
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow_estimator\python\estimator\tpu\error_handling.py", line 131, in raise_errors
    six.reraise(typ, value, traceback)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\six.py", line 693, in reraise
    raise value
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow_estimator\python\estimator\tpu\tpu_estimator.py", line 2871, in train
    saving_listeners=saving_listeners)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow_estimator\python\estimator\estimator.py", line 367, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow_estimator\python\estimator\estimator.py", line 1158, in _train_model
    return self._train_model_default(input_fn, hooks, saving_listeners)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow_estimator\python\estimator\estimator.py", line 1192, in _train_model_default
    saving_listeners)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow_estimator\python\estimator\estimator.py", line 1484, in _train_with_estimator_spec
    _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\training\monitored_session.py", line 754, in run
    run_metadata=run_metadata)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\training\monitored_session.py", line 1252, in run
    run_metadata=run_metadata)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\training\monitored_session.py", line 1353, in run
    raise six.reraise(*original_exc_info)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\six.py", line 693, in reraise
    raise value
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\training\monitored_session.py", line 1338, in run
    return self._sess.run(*args, **kwargs)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\training\monitored_session.py", line 1411, in run
    run_metadata=run_metadata)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\training\monitored_session.py", line 1169, in run
    return self._sess.run(*args, **kwargs)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 950, in run
    run_metadata_ptr)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1173, in _run
    feed_dict_tensor, options, run_metadata)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1350, in _do_run
    run_metadata)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1370, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input to reshape is a tensor with 0 values, but the requested shape has 54912
         [[{{node Reshape_8}}]]
         [[IteratorGetNext]]

Из этого я понял, что проблема заключается в wav_to_num_frames, но это единственный код для этого.

def wav_to_num_frames(wav_audio, frames_per_second):
  """Transforms a wav-encoded audio string into number of frames."""
  print("Running wav_to_num_frames from data")
  w = wave.open(six.BytesIO(wav_audio))
  return np.int32(w.getnframes() / w.getframerate() * frames_per_second)

У меня не было этой проблемы, когда я пытался обучить модель с помощью tfrecords, созданных с помощью исходного кода, поэтому я не знаю, что не так.


person intrusiveVoyager    schedule 11.06.2020    source источник


Ответы (1)


Оказывается, проблема заключалась не в самих созданных .tfrecords, а в размере тензоров, которые я назначил для вновь добавленных данных. На этот вопрос нет конкретного ответа, поскольку он очень специфичен для этой ситуации.

person intrusiveVoyager    schedule 13.06.2020