广东湘恒智能科技有限公司
主营产品: 西门子PLC代理商,plc变频器,伺服电机,人机界面,触摸屏,线缆,DP接头
SIEMENS浙江省衢州市 西门子代理商——西门子华东一级总代理


对MNIST数据集中的测试数据进行预测,测试模型准确率。

import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_lenet5_forward
import mnist_lenet5_backward
import numpy as np

TEST_INTERVAL_SECS = 5
#+++++++++++++++++++++++++++++++修改读入的大小
BATCH_SIZE = 500#0 #batch
STEPS = 2

def test(mnist):
with tf.Graph().as_default() as g:
x = tf.placeholder(tf.float32,[
BATCH_SIZE,#mnist.test.num_examples,
mnist_lenet5_forward.IMAGE_SIZE,
mnist_lenet5_forward.IMAGE_SIZE,
mnist_lenet5_forward.NUM_CHANNELS])
y_ = tf.placeholder(tf.float32, [None, mnist_lenet5_forward.OUTPUT_NODE])
y = mnist_lenet5_forward.forward(x,False,None)

ema = tf.train.ExponentialMovingAverage(mnist_lenet5_backward.MOVING_AVERAGE_DECAY)
ema_restore = ema.variables_to_restore()
saver = tf.train.Saver(ema_restore)

#判断预测值和实际值是否相同
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
#求平均得到准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

while True:
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state(mnist_lenet5_backward.MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)

#根据读入的模型名字切分出该模型是属于迭代了多少次保存的
global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]

for i in range(STEPS):
#读取一个batch的数据
xs, ys = mnist.test.next_batch(BATCH_SIZE)
reshaped_x = np.reshape(xs,(
BATCH_SIZE,#mnist.test.num_examples,
mnist_lenet5_forward.IMAGE_SIZE,
mnist_lenet5_forward.IMAGE_SIZE,
mnist_lenet5_forward.NUM_CHANNELS))
#计算出测试集上准确率
accuracy_score = sess.run(accuracy, feed_dict={x:reshaped_x,y_:ys})
print("After %s training step(s), test accuracy = %g" % (global_step, accuracy_score))
else:
print('No checkpoint file found')
return
#每隔5秒寻找一次是否有最新的模型
time.sleep(TEST_INTERVAL_SECS)

def main():
mnist = input_data.read_data_sets("./data/", one_hot=True)
test(mnist)

if __name__ == '__main__':
main()


展开全文
相关产品
拨打电话 微信咨询 发送询价