三木社区

 找回密码
 立即注册
搜索
热搜: 活动 交友 discuz
查看: 441|回复: 0
打印 上一主题 下一主题

读取文件

[复制链接]

1562

主题

1564

帖子

4904

积分

博士

Rank: 8Rank: 8

积分
4904
跳转到指定楼层
楼主
发表于 2017-9-18 08:01:06 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
TensorFlow读取数据共有三种方法:
  • Feeding:当TensorFlow运行每步计算的时候,从Python获取数据。在Graph的设计阶段,用placeholder占住Graph的位置,完成Graph的表达;当Graph传给Session后,在运算时再把需要的数据从Python传过来。
  • Preloaded data:数据直接预加载到TensorFlow的Graph中,再把Graph传入Session运行。只适用于小数据。
  • Reading from file:在Graph中定义好文件读取的运算节点,把Graph传入Session运行时,执行读取文件的运算,这样可以避免在Python和TensorFlow C++执行环境之间反复传递数据。

本文讲解Reading from file的代码。
其他关于TensorFlow的学习笔记,请点击入门教程
实现
  1. #!/usr/bin/env python
  2. # -*- coding=utf-8 -*-
  3. # @author: 陈水平
  4. # @date: 2017-02-19
  5. # @description: modified program to illustrate reading from file based on TF offitial tutorial
  6. # @ref: https://www.tensorflow.org/programmers_guide/reading_data

  7. def read_my_file_format(filename_queue):
  8.   """从文件名队列读取一行数据
  9.   
  10.   输入:
  11.   -----
  12.   filename_queue:文件名队列,举个例子,可以使用`tf.train.string_input_producer(["file0.csv", "file1.csv"])`方法创建一个包含两个CSV文件的队列
  13.   
  14.   输出:
  15.   -----
  16.   一个样本:`[features, label]`
  17.   """
  18.   reader = tf.SomeReader()  # 创建Reader
  19.   key, record_string = reader.read(filename_queue)  # 读取一行记录
  20.   example, label = tf.some_decoder(record_string)  # 解析该行记录
  21.   processed_example = some_processing(example)  # 对特征进行预处理
  22.   return processed_example, label

  23. def input_pipeline(filenames, batch_size, num_epochs=None):
  24.   """ 从一组文件中读取一个批次数据
  25.   
  26.   输入:
  27.   -----
  28.   filenames:文件名列表,如`["file0.csv", "file1.csv"]`
  29.   batch_size:每次读取的样本数
  30.   num_epochs:每个文件的读取次数
  31.   
  32.   输出:
  33.   -----
  34.   一批样本,`[[example1, label1], [example2, label2], ...]`
  35.   """
  36.   filename_queue = tf.train.string_input_producer(filenames, num_epochs=num_epochs, shuffle=True)  # 创建文件名队列
  37.   example, label = read_my_file_format(filename_queue)  # 读取一个样本
  38.   # 将样本放进样本队列,每次输出一个批次样本
  39.   #   - min_after_dequeue:定义输出样本后的队列最小样本数,越大随机性越强,但start up时间和内存占用越多
  40.   #   - capacity:队列大小,必须比min_after_dequeue大
  41.   min_after_dequeue = 10000
  42.   capacity = min_after_dqueue + 3 * batch_size
  43.   example_batch, label_batch = tf.train.shuffle_batch(
  44.     [example, label], batch_size=batch_size, capacity=capacity,
  45.     min_after_dequeue=min_after_dequeue)
  46.   return example_batch, label_batch
  47.   
  48. def main(_):
  49.   x, y = input_pipeline(['file0.csv', 'file1.csv'], 1000, 5)
  50.   train_op = some_func(x, y)
  51.   init_op = tf.global_variables_initializer()
  52.   local_init_op = tf.local_variables_initializer()  # local variables like epoch_num, batch_size
  53.   sess = tf.Session()
  54.   
  55.   sess.run(init_op)
  56.   sess.run(local_init_op)
  57.   
  58.   # `QueueRunner`用于创建一系列线程,反复地执行`enqueue` op
  59.   # `Coordinator`用于让这些线程一起结束
  60.   # 典型应用场景:
  61.   #   - 多线程准备样本数据,执行enqueue将样本放进一个队列
  62.   #   - 一个训练线程从队列执行dequeu获取一批样本,执行training op
  63.   # `tf.train`的许多函数会在graph中添加`QueueRunner`对象,如`tf.train.string_input_producer`
  64.   # 在执行training op之前,需要保证Queue里有数据,因此需要先执行`start_queue_runners`
  65.   coord = tf.train.Coordinator()
  66.   threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  67.   
  68.   try:
  69.     while not coord.should_stop():
  70.       sess.run(train_op)
  71.   except tf.errors.OutOfRangeError:
  72.     print 'Done training -- epoch limit reached'
  73.   finally:
  74.     coord.request_stop()
  75.   
  76.   # Wait for threads to finish  
  77.   coord.join(threads)
  78.   sess.close()
  79.   
  80. if __name__ == '__main__':
  81.   tf.app.run()
复制代码


回复

使用道具 举报

Archiver|手机版|小黑屋|三木电子社区 ( 辽ICP备11000133号-4 )

辽公网安备 21021702000620号

GMT+8, 2025-5-9 23:57 , Processed in 0.030290 second(s), 22 queries .

Powered by Discuz! X3.3

© 2001-2017 Comsenz Inc.

快速回复 返回顶部 返回列表