tensorflow next_batch против пользовательского next_batch?

Я пытаюсь написать функцию, которая может получать пакеты данных, аналогичную next_batch в tensorflow.

next_batch можно увидеть здесь: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/learn/python/learn/datasets/mnist.py

Это код, который я написал.

class Sampler:

def __init__(self, data):        
    self.x, self.y = data
    self.N, = self.y.shape
    self.start = 0
    self.shuffle = np.arange(self.N)
    np.random.shuffle(self.shuffle)
    self.x = self.x[self.shuffle]
    self.y = self.y[self.shuffle]

def sample(self, s):
    start = self.start
    end = np.minimum(start+s, self.N)
    data = (self.x[start:end], self.y[start:end])
    self.start += s   
    if self.start >= self.N - 1:
        self.start = 0
        np.random.shuffle(self.shuffle)
        self.x = self.x[self.shuffle]
        self.y = self.y[self.shuffle]
    return data

Я чувствую, что это естественный подход, но, хотя я могу получить точность 99%+ с классификацией, используя next_batch, я могу получить только около 50%, используя мою функцию «выборка».

Может ли кто-нибудь помочь мне понять, что происходит?


person girl-meets-world    schedule 27.02.2018    source источник
comment
Насколько я могу судить, ваш код делает почти то же самое, что и функция next_batch из примера mnist. Единственным отличием является то, что класс DataSet в примере сглаживает входные данные из (x,y,z,1) в (x,y*z), а затем также нормализует все данные из [0,256] в [0,1]. Ни один из них не должен влиять на точность немедленно, но в зависимости от того, как вы тренируетесь, они могут иметь эффект.   -  person cmxu    schedule 27.02.2018
comment
Большое спасибо - это решило мою проблему. Я бы отметил это как правильный ответ, но это комментарий, поэтому я думаю, что не могу этого сделать. Пожалуйста, не стесняйтесь написать это в ответ, я поставлю галочку! :) Еще раз большое спасибо.   -  person girl-meets-world    schedule 27.02.2018
comment
НП, спасибо за 5 долларов :) jkjk   -  person cmxu    schedule 27.02.2018


Ответы (1)


Прямой cp из моего комментария, но...

Насколько я могу судить, ваш код делает почти то же самое, что и функция next_batch из примера mnist. Единственным отличием является то, что класс DataSet в примере сглаживает входные данные из (x,y,z,1) в (x,y*z), а затем также нормализует все данные из [0,256] в [0,1]. Ни один из них не должен влиять на точность немедленно, но в зависимости от того, как вы тренируетесь, они могут иметь эффект.

person cmxu    schedule 27.02.2018