TensorFlow: Reading images in queue without shuffling


I have a training set of 614 images which have already been shuffled. I want to read the images in order, in batches of 5. Because my labels are arranged in the same order, any shuffling of the images when they are read into a batch will result in incorrect labelling.

These are the functions that read and add the images to a batch:

# Add files to the queue as a batch:
def add_to_batch(image):

    print('Adding to batch')
    image_batch = tf.train.batch([image], batch_size=5, num_threads=1, capacity=614)

    # Add summary
    tf.image_summary('images', image_batch, max_images=30)

    return image_batch

# Read files in the queue and process them:
def get_batch():

    # Create filename queue of images to read
    filenames = [('/media/jessica/jessica/tensorflow/streetview/training/original/train_%d.png' % i) for i in range(1, 614)]
    filename_queue = tf.train.string_input_producer(filenames, shuffle=False, capacity=614)
    reader = tf.WholeFileReader()
    key, value = reader.read(filename_queue)

    # Read and process the image
    # Image is 500 x 275:
    my_image = tf.image.decode_png(value)
    my_image_float = tf.cast(my_image, tf.float32)
    my_image_float = tf.reshape(my_image_float, [275, 500, 4])

    return add_to_batch(my_image_float)

This function performs the prediction:

def inference(x):

    < perform convolution, pooling etc. >

    return y_conv

This function calculates the loss and performs the optimisation:

def train_step(y_label, y_conv):

    """ Calculate loss """
    # Cross-entropy
    loss = -tf.reduce_sum(y_label * tf.log(y_conv + 1e-9))

    # Add summary
    tf.scalar_summary('loss', loss)

    """ Optimisation """
    opt = tf.train.AdamOptimizer().minimize(loss)

    return loss

This is my main function:

def main():

    # Training
    images = get_batch()
    y_conv = inference(images)
    loss = train_step(y_label, y_conv)

    # Write and merge summaries
    writer = tf.train.SummaryWriter('/media/jessica/jessica/tensorflow/streetview/summarylogs/log_5', graph_def=sess.graph_def)
    merged = tf.merge_all_summaries()

    """ Run session """
    sess.run(tf.initialize_all_variables())
    tf.train.start_queue_runners(sess=sess)

    print "Running..."
    for step in range(5):

        # y_1 = <get correct labels here>

        # Train
        loss_value = sess.run(train_step, feed_dict={y_label: y_1})
        print "Step %d, Loss %g" % (step, loss_value)

        # Save summary
        summary_str = sess.run(merged, feed_dict={y_label: y_1})
        writer.add_summary(summary_str, step)

    print('Finished')

if __name__ == '__main__':
    main()

When I check the image_summary, the images do not seem to be in sequence. Or rather, what is happening is:

Images 1-5: discarded, images 6-10: read, images 11-15: discarded, images 16-20: read, etc.

So it looks like I am getting my batches twice, throwing away the first one and using the second one? I have tried a few remedies but nothing seems to work. I feel like my understanding of what happens when I call images = get_batch() and sess.run() is fundamentally wrong.

Your batch operation is backed by a FIFOQueue, so every time you use its output, it advances the queue's state.

Your first session.run call uses images 1-5 in the computation of train_step; the second session.run asks for the computation of image_summary, which pulls the next batch (images 6-10) and uses those in the visualization.
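The mechanism can be seen with a toy analogy in plain Python (not TensorFlow): a queue is stateful, so two separate reads consume two separate batches, just like two sess.run calls that each evaluate the batch op. The names here (next_batch, the 1..20 range) are illustrative only.

```python
from collections import deque

# Stand-in for the filename/image queue: "images" 1..20 in order.
queue = deque(range(1, 21))

def next_batch(n=5):
    """Dequeue n items, advancing the queue's state each call."""
    return [queue.popleft() for _ in range(n)]

train_input = next_batch()    # first sess.run: images 1-5 feed train_step
summary_input = next_batch()  # second sess.run: images 6-10 feed the summary
```

Each call consumes state, which is exactly why the training step and the summary see different (alternating) batches.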

If you want to visualize things without affecting the state of the input, it helps to cache the queue values in variables and define your summaries with the variables as inputs, rather than depending on the live queue.

(image_batch_live,) = tf.train.batch([image], batch_size=5, num_threads=1, capacity=614)

image_batch = tf.Variable(
    tf.zeros((batch_size, image_size, image_size, color_channels)),
    trainable=False,
    name="input_values_cached")

advance_batch = tf.assign(image_batch, image_batch_live)

This way image_batch is a static value which you can use both for computing the loss and for visualization. Between steps you call sess.run(advance_batch) to advance the queue.
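The same caching idea, again as a pure-Python sketch rather than TensorFlow: dequeue exactly once per step into a cached value, and let every consumer read the cache. The advance_batch / cached_batch names mirror the variable pattern above but are purely illustrative.

```python
from collections import deque

queue = deque(range(1, 21))   # stand-in for the input queue
cached_batch = None           # plays the role of the image_batch variable

def advance_batch(n=5):
    """Dequeue once and cache the result (analogue of sess.run(advance_batch))."""
    global cached_batch
    cached_batch = [queue.popleft() for _ in range(n)]

advance_batch()               # one explicit dequeue per training step
loss_input = cached_batch     # the loss reads the cached batch...
summary_input = cached_batch  # ...and the summary reads the SAME batch;
                              # no extra dequeue happens here
```

Both consumers now see identical data, and the queue only advances when you explicitly ask it to.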

A minor wrinkle with this approach -- by default the saver will save the image_batch variable to the checkpoint. If you ever change your batch size, the checkpoint restore will fail with a dimension mismatch. To work around that you would need to specify the list of variables to restore manually, and run the initializers for the rest.
