A Recurrent Neural Network (RNN) on the MNIST Dataset

The input data for this post is MNIST, short for Modified National Institute of Standards and Technology: a dataset of scanned handwritten-digit images collected by that institute, each paired with its label, and lightly preprocessed so that machine learning algorithms can consume it. The dataset can be downloaded from Professor Yann LeCun's website.
Earlier posts in this series followed the official TensorFlow tutorials and modeled the MNIST data with softmax regression and a CNN. For completeness, this post fits an RNN to the same data; the specific RNN used is an LSTM.
For the theory behind RNNs and LSTMs, see this article.
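Why an RNN can read an image at all: each 28x28 MNIST image is treated as a sequence, with the network consuming one 28-pixel row per time step for 28 steps and classifying after the last row. A minimal numpy sketch of that reinterpretation (the sample is random noise, purely to show the shapes):

import numpy as np

flat_image = np.random.rand(784).astype(np.float32)  # stand-in for one flattened MNIST sample
sequence = flat_image.reshape(28, 28)                # (n_steps, n_input): row t becomes step t
print(sequence.shape)                                # (28, 28)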
# coding: utf-8
# @author: 陈水平
# @date: 2017-02-14
#

# In[1]:

import tensorflow as tf
import numpy as np


# In[2]:

sess = tf.InteractiveSession()


# In[3]:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('mnist/', one_hot=True)
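If you want to confirm what read_data_sets returned, a quick shape check (assuming the download succeeded; these split sizes are the standard ones):

print(mnist.train.images.shape)  # (55000, 784): flattened 28x28 images
print(mnist.train.labels.shape)  # (55000, 10): one-hot rows because one_hot=True
print(mnist.test.images.shape)   # (10000, 784)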


# In[4]:

learning_rate = 0.001
batch_size = 128

n_input = 28    # each image row is 28 pixels wide
n_steps = 28    # one time step per image row, 28 rows in total
n_hidden = 128  # LSTM hidden-state size
n_classes = 10  # digits 0-9

x = tf.placeholder(tf.float32, [None, n_steps, n_input])
y = tf.placeholder(tf.float32, [None, n_classes])


# In[5]:

# First version: use the built-in BasicLSTMCell. Note that the hand-rolled
# version in the next cell redefines RNN and is the one actually trained below.
def RNN(x, weight, biases):
    # x shape: (batch_size, n_steps, n_input)
    # desired shape: list of n_steps with element shape (batch_size, n_input)
    x = tf.transpose(x, [1, 0, 2])
    x = tf.reshape(x, [-1, n_input])
    x = tf.split(x, n_steps, 0)  # TF 1.x argument order: (value, num_splits, axis)
    outputs = list()
    lstm = tf.nn.rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)
    # Initial (c, h) state; it must match the batch size, not the step count.
    state = (tf.zeros([batch_size, n_hidden]),) * 2
    with tf.variable_scope("myrnn2") as scope:
        for i in range(n_steps):  # consume all n_steps rows, including the last
            if i > 0:
                scope.reuse_variables()
            output, state = lstm(x[i], state)
            outputs.append(output)
    final = tf.matmul(outputs[-1], weight) + biases
    return final
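The three reshaping calls at the top of RNN convert the (batch, step, feature) tensor into the list-of-steps layout the loop consumes. The same transformation in plain numpy, shapes only (names are illustrative):

import numpy as np

batch = np.zeros((128, 28, 28))         # (batch_size, n_steps, n_input)
steps_first = batch.transpose(1, 0, 2)  # (n_steps, batch_size, n_input)
flat = steps_first.reshape(-1, 28)      # (n_steps * batch_size, n_input)
steps = np.split(flat, 28, axis=0)      # list of n_steps arrays, each (batch_size, n_input)
print("%d arrays of shape %s" % (len(steps), steps[0].shape))  # 28 arrays of shape (128, 28)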


# In[6]:

def RNN(x, n_steps, n_input, n_hidden, n_classes):
    # Parameters:
    # Input gate: input, previous output, and bias
    ix = tf.Variable(tf.truncated_normal([n_input, n_hidden], -0.1, 0.1))
    im = tf.Variable(tf.truncated_normal([n_hidden, n_hidden], -0.1, 0.1))
    ib = tf.Variable(tf.zeros([1, n_hidden]))
    # Forget gate: input, previous output, and bias
    fx = tf.Variable(tf.truncated_normal([n_input, n_hidden], -0.1, 0.1))
    fm = tf.Variable(tf.truncated_normal([n_hidden, n_hidden], -0.1, 0.1))
    fb = tf.Variable(tf.zeros([1, n_hidden]))
    # Memory cell: input, state, and bias
    cx = tf.Variable(tf.truncated_normal([n_input, n_hidden], -0.1, 0.1))
    cm = tf.Variable(tf.truncated_normal([n_hidden, n_hidden], -0.1, 0.1))
    cb = tf.Variable(tf.zeros([1, n_hidden]))
    # Output gate: input, previous output, and bias
    ox = tf.Variable(tf.truncated_normal([n_input, n_hidden], -0.1, 0.1))
    om = tf.Variable(tf.truncated_normal([n_hidden, n_hidden], -0.1, 0.1))
    ob = tf.Variable(tf.zeros([1, n_hidden]))
    # Classifier weights and biases
    w = tf.Variable(tf.truncated_normal([n_hidden, n_classes]))
    b = tf.Variable(tf.zeros([n_classes]))

    # Definition of the cell computation
    def lstm_cell(i, o, state):
        input_gate = tf.sigmoid(tf.matmul(i, ix) + tf.matmul(o, im) + ib)
        forget_gate = tf.sigmoid(tf.matmul(i, fx) + tf.matmul(o, fm) + fb)
        update = tf.tanh(tf.matmul(i, cx) + tf.matmul(o, cm) + cb)
        state = forget_gate * state + input_gate * update
        output_gate = tf.sigmoid(tf.matmul(i, ox) + tf.matmul(o, om) + ob)
        return output_gate * tf.tanh(state), state

    # Unrolled LSTM loop; the initial state and output are zero and should
    # not be updated by the optimizer, hence trainable=False.
    outputs = list()
    state = tf.Variable(tf.zeros([batch_size, n_hidden]), trainable=False)
    output = tf.Variable(tf.zeros([batch_size, n_hidden]), trainable=False)

    # x shape: (batch_size, n_steps, n_input)
    # desired shape: list of n_steps with element shape (batch_size, n_input)
    x = tf.transpose(x, [1, 0, 2])
    x = tf.reshape(x, [-1, n_input])
    x = tf.split(x, n_steps, 0)  # TF 1.x argument order: (value, num_splits, axis)
    for i in x:
        output, state = lstm_cell(i, output, state)
        outputs.append(output)
    logits = tf.matmul(outputs[-1], w) + b
    return logits
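For reference, lstm_cell above is the standard LSTM update written out by hand. In the usual notation (code weight names in parentheses):

i_t = \sigma(x_t W_{ix} + h_{t-1} W_{im} + b_i)         (ix, im, ib)
f_t = \sigma(x_t W_{fx} + h_{t-1} W_{fm} + b_f)         (fx, fm, fb)
\tilde{c}_t = \tanh(x_t W_{cx} + h_{t-1} W_{cm} + b_c)  (cx, cm, cb)
c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t         (the state variable)
o_t = \sigma(x_t W_{ox} + h_{t-1} W_{om} + b_o)         (ox, om, ob)
h_t = o_t \odot \tanh(c_t)                              (the returned output)

Only the final h_t is used: the classifier maps it to the 10 class logits through w and b.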


# In[7]:

pred = RNN(x, n_steps, n_input, n_hidden, n_classes)

# TF 1.x requires the named arguments labels= and logits= here.
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=pred))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Initializing the variables
init = tf.global_variables_initializer()
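Note that RNN returns raw logits rather than probabilities: softmax_cross_entropy_with_logits applies the softmax internally, which is more numerically stable than composing the two ops yourself. A small numpy sketch of the unfused computation it replaces (values are illustrative):

import numpy as np

logits = np.array([2.0, 0.5, -1.0])
label = np.array([1.0, 0.0, 0.0])      # one-hot target

probs = np.exp(logits - logits.max())  # subtract the max for numerical stability
probs /= probs.sum()
loss = -np.sum(label * np.log(probs))  # cross-entropy against the one-hot label
print(loss)                            # ~0.2413, what the fused op returns on this input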


# In[8]:

# Launch the graph
sess.run(init)
for step in range(20000):
    batch_x, batch_y = mnist.train.next_batch(batch_size)
    batch_x = batch_x.reshape((batch_size, n_steps, n_input))
    sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})

    if step % 50 == 0:
        # One run call evaluates both tensors without recomputing the forward pass twice.
        acc, loss = sess.run([accuracy, cost], feed_dict={x: batch_x, y: batch_y})
        print("Iter " + str(step) + ", Minibatch Loss= " + "{:.6f}".format(loss) +
              ", Training Accuracy= " + "{:.5f}".format(acc))
print("Optimization Finished!")


# In[9]:

# Calculate accuracy for 128 mnist test images. The batch must equal
# batch_size because the initial-state variables fix the batch dimension.
test_len = batch_size
test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input))
test_label = mnist.test.labels[:test_len]
print("Testing Accuracy: " + str(sess.run(accuracy, feed_dict={x: test_data, y: test_label})))
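The cell above only scores the first 128 test images. To score all 10,000, loop over full batches and average; because every batch has the same size, the mean of per-batch accuracies equals the overall accuracy. A sketch reusing the session and placeholders above (the partial final batch is skipped, since the initial-state variables fix the batch size at 128):

accs = []
n_test = mnist.test.images.shape[0]
for start in range(0, n_test - batch_size + 1, batch_size):
    data = mnist.test.images[start:start + batch_size].reshape((-1, n_steps, n_input))
    labels = mnist.test.labels[start:start + batch_size]
    accs.append(sess.run(accuracy, feed_dict={x: data, y: labels}))
print("Full test accuracy: " + str(sum(accs) / len(accs)))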
The output is as follows:
Iter 0, Minibatch Loss= 2.540429, Training Accuracy= 0.07812
Iter 50, Minibatch Loss= 2.423611, Training Accuracy= 0.06250
Iter 100, Minibatch Loss= 2.318830, Training Accuracy= 0.13281
Iter 150, Minibatch Loss= 2.276640, Training Accuracy= 0.13281
Iter 200, Minibatch Loss= 2.276727, Training Accuracy= 0.12500
Iter 250, Minibatch Loss= 2.267064, Training Accuracy= 0.16406
Iter 300, Minibatch Loss= 2.234139, Training Accuracy= 0.19531
Iter 350, Minibatch Loss= 2.295060, Training Accuracy= 0.12500
Iter 400, Minibatch Loss= 2.261856, Training Accuracy= 0.16406
Iter 450, Minibatch Loss= 2.220284, Training Accuracy= 0.17969
Iter 500, Minibatch Loss= 2.276015, Training Accuracy= 0.13281
Iter 550, Minibatch Loss= 2.220499, Training Accuracy= 0.14062
Iter 600, Minibatch Loss= 2.219574, Training Accuracy= 0.11719
Iter 650, Minibatch Loss= 2.189177, Training Accuracy= 0.25781
Iter 700, Minibatch Loss= 2.195167, Training Accuracy= 0.19531
Iter 750, Minibatch Loss= 2.226459, Training Accuracy= 0.18750
Iter 800, Minibatch Loss= 2.148620, Training Accuracy= 0.23438
Iter 850, Minibatch Loss= 2.122925, Training Accuracy= 0.21875
Iter 900, Minibatch Loss= 2.065122, Training Accuracy= 0.24219
...
Iter 19350, Minibatch Loss= 0.001304, Training Accuracy= 1.00000
Iter 19400, Minibatch Loss= 0.000144, Training Accuracy= 1.00000
Iter 19450, Minibatch Loss= 0.000907, Training Accuracy= 1.00000
Iter 19500, Minibatch Loss= 0.002555, Training Accuracy= 1.00000
Iter 19550, Minibatch Loss= 0.002018, Training Accuracy= 1.00000
Iter 19600, Minibatch Loss= 0.000853, Training Accuracy= 1.00000
Iter 19650, Minibatch Loss= 0.001035, Training Accuracy= 1.00000
Iter 19700, Minibatch Loss= 0.007034, Training Accuracy= 0.99219
Iter 19750, Minibatch Loss= 0.000608, Training Accuracy= 1.00000
Iter 19800, Minibatch Loss= 0.002913, Training Accuracy= 1.00000
Iter 19850, Minibatch Loss= 0.003484, Training Accuracy= 1.00000
Iter 19900, Minibatch Loss= 0.005693, Training Accuracy= 1.00000
Iter 19950, Minibatch Loss= 0.001904, Training Accuracy= 1.00000
Optimization Finished!

Testing Accuracy: 0.992188

