TFLearn
TF.Learn
tf的高级api:TF.Learn
import tensorflow as tf
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import StandardScaler
from tensorflow.examples.tutorials.mnist import input_data

# Read the MNIST data set through the TF tutorial helper, using 'tmp/data/'
# as its data directory.
mnist = input_data.read_data_sets('tmp/data/')
X_train = mnist.train.images
print(X_train.shape)
X_test = mnist.test.images
y_train = mnist.train.labels.astype("int")
y_test = mnist.test.labels.astype("int")

# Standardize the features: the scaler is fitted on the training set only and
# the same transform is then applied to the test set.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# High-level TF.Learn classifier with two hidden layers (300 and 100 units)
# and 10 output classes.
feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input(X_train)
dnn_clf = tf.contrib.learn.DNNClassifier(
    hidden_units=[300, 100],
    n_classes=10,
    feature_columns=feature_columns,
)
# dnn_clf = tf.contrib.learn.SKCompat(dnn_clf)
dnn_clf.fit(X_train, y_train, batch_size=50, steps=40000)

# Report the results explicitly: as bare expressions they were computed and
# then discarded when this file is run as a script.
y_pred = list(dnn_clf.predict(X_test))
print("accuracy_score:", accuracy_score(y_test, y_pred))
print("evaluate:", dnn_clf.evaluate(X_test, y_test))
tf的低级api
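下面两个链接非常清晰地说明了 softmax 和 cross entropy 的关系:先对 logits 计算 softmax 得到各类别的概率,再用这些概率计算交叉熵损失。
softmax:$\hat{p}_k = \dfrac{\exp(z_k)}{\sum_{j=1}^{K} \exp(z_j)}$


交叉熵
交叉熵计算公式:$H(y, \hat{p}) = -\sum_{k=1}^{K} y_k \log \hat{p}_k$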

使用 TF 提供的函数构建全连接层:Using `dense()` instead of `neuron_layer()`
Note: the book uses `tensorflow.contrib.layers.fully_connected()` rather than `tf.layers.dense()` (which did not exist when this chapter was written). It is now preferable to use `tf.layers.dense()`, because anything in the contrib module may change or be deleted without notice. The `dense()` function is almost identical to the `fully_connected()` function, except for a few minor differences:
- several parameters are renamed: `scope` becomes `name`, `activation_fn` becomes `activation` (and similarly the `_fn` suffix is removed from other parameters such as `normalizer_fn`), `weights_initializer` becomes `kernel_initializer`, etc.
- the default `activation` is now `None` rather than `tf.nn.relu`.
# Common imports
import tensorflow as tf
import numpy as np
import os
# to make this notebook's output stable across runs
def reset_graph(seed=42):
    """Start from a fresh default graph with seeded random generators.

    Args:
        seed: Value handed to both ``np.random.seed`` and
            ``tf.set_random_seed``.
    """
    # NumPy seeding is independent of the TensorFlow graph state.
    np.random.seed(seed)
    # Reset first so the TensorFlow seed is set after the default graph
    # has been replaced.
    tf.reset_default_graph()
    tf.set_random_seed(seed)
# Layer sizes of the network.
n_inputs = 28 * 28  # MNIST
n_hidden1 = 300
n_hidden2 = 100
n_outputs = 10

reset_graph()

# Inputs: a batch of flattened images and their integer class labels.
X = tf.placeholder(tf.float32, shape=(None, n_inputs), name="X")
# `(None,)` declares a 1-D tensor of unknown length. The previous `(None)` is
# simply `None`, which left the rank of `y` unspecified.
y = tf.placeholder(tf.int64, shape=(None,), name="y")

# Two ReLU hidden layers followed by a linear output layer (dense() applies
# no activation unless one is given).
with tf.name_scope("dnn"):
    hidden1 = tf.layers.dense(X, n_hidden1, name="hidden1",
                              activation=tf.nn.relu)
    hidden2 = tf.layers.dense(hidden1, n_hidden2, name="hidden2",
                              activation=tf.nn.relu)
    logits = tf.layers.dense(hidden2, n_outputs, name="outputs")

# Loss: mean of the per-instance softmax cross entropy computed from logits.
with tf.name_scope("loss"):
    xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y,
                                                              logits=logits)
    loss = tf.reduce_mean(xentropy, name="loss")

learning_rate = 0.01

with tf.name_scope("train"):
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    training_op = optimizer.minimize(loss)

# Accuracy: fraction of instances whose label is the top-1 logit.
with tf.name_scope("eval"):
    correct = tf.nn.in_top_k(logits, y, 1)
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

init = tf.global_variables_initializer()
saver = tf.train.Saver()

n_epochs = 20
batch_size = 50
# Mini-batch training; accuracy is reported once per epoch and the final
# model is written to a checkpoint.
with tf.Session() as sess:
    sess.run(init)
    batches_per_epoch = mnist.train.num_examples // batch_size
    for epoch in range(n_epochs):
        for _ in range(batches_per_epoch):
            X_batch, y_batch = mnist.train.next_batch(batch_size)
            sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
        # Training accuracy is measured on the last mini-batch of the epoch.
        acc_train = sess.run(accuracy, feed_dict={X: X_batch, y: y_batch})
        acc_test = sess.run(
            accuracy,
            feed_dict={X: mnist.test.images, y: mnist.test.labels},
        )
        print(epoch, "Train accuracy:", acc_train, "Test accuracy:", acc_test)
    save_path = saver.save(sess, "./my_model_final.ckpt")
不使用提供的函数,自行连接全连接的代码
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('tmp/data/')

# Declare the sizes of the input, hidden and output layers up front.
n_inputs = 28 * 28  # MNIST
n_hidden1 = 300
n_hidden2 = 100
n_outputs = 10

# X has shape (None, n_inputs); None leaves the number of instances per
# training batch unspecified.
X = tf.placeholder(tf.float32, shape=(None, n_inputs), name="X")
# y is a 1-D tensor whose length (the number of instances) is not known in
# advance. `(None,)` expresses that; the previous `(None)` is simply `None`,
# which left the rank unspecified.
y = tf.placeholder(tf.int64, shape=(None,), name="y")

# X acts as the input layer, and all instances of a training batch are
# processed by the DNN together.
# Hidden layers differ only in what they are connected to and in how many
# neurons each of them contains.
# A layer is therefore described by its input X, its number of neurons
# (n_neurons), its name, and the activation function applied to it.
def neuron_layer(X, n_neurons, name, activation=None):
    """Build one fully connected layer computing ``activation(X @ W + b)``.

    Args:
        X: Input tensor; the size of its second dimension must be statically
            known because it determines the shape of the weight matrix.
        n_neurons: Number of units in the layer.
        name: Name scope under which the layer's ops are created.
        activation: ``None`` for a linear layer, the string ``"relu"``, or a
            callable that is applied to the pre-activation tensor.

    Returns:
        The layer's output tensor.

    Raises:
        ValueError: If ``activation`` is not ``None``, ``"relu"`` or a
            callable. Such values used to be ignored silently, which
            produced a linear layer without any warning.
    """
    # Validate before touching the graph so a bad argument creates no ops.
    if not (activation is None or callable(activation)
            or activation == "relu"):
        raise ValueError(
            "activation must be None, 'relu' or a callable, got %r"
            % (activation,)
        )

    with tf.name_scope(name):
        n_inputs = int(X.get_shape()[1])
        stddev = 2 / np.sqrt(n_inputs)
        # A truncated normal initializer avoids overly large weights, which
        # would otherwise slow down gradient descent during training.
        init = tf.truncated_normal((n_inputs, n_neurons), stddev=stddev)
        # Weights and biases of this layer.
        W = tf.Variable(init, name="weights")
        b = tf.Variable(tf.zeros([n_neurons]), name="biases")
        z = tf.matmul(X, W) + b
        if activation is None:
            return z
        if callable(activation):
            return activation(z)
        return tf.nn.relu(z)
"""创建一个Dnn"""
# Two ReLU hidden layers followed by a linear output layer producing logits.
with tf.name_scope('dnn'):
    hidden1 = neuron_layer(X, n_hidden1, "hidden1", activation="relu")
    hidden2 = neuron_layer(hidden1, n_hidden2, "hidden2", activation="relu")
    logits = neuron_layer(hidden2, n_outputs, "outputs")

# The loss is the mean cross entropy over the batch.
with tf.name_scope("loss"):
    xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y,
                                                              logits=logits)
    loss = tf.reduce_mean(xentropy, name="loss")

learning_rate = 0.01

with tf.name_scope("train"):
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    training_op = optimizer.minimize(loss)

# Accuracy: fraction of instances whose label is the top-1 logit.
with tf.name_scope("eval"):
    correct = tf.nn.in_top_k(logits, y, 1)
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

init = tf.global_variables_initializer()
saver = tf.train.Saver()

n_epochs = 40
batch_size = 50
# Mini-batch training; accuracy on the last batch and on the validation set
# is reported once per epoch, then the final model is checkpointed.
with tf.Session() as sess:
    sess.run(init)
    batches_per_epoch = mnist.train.num_examples // batch_size
    for epoch in range(n_epochs):
        for _ in range(batches_per_epoch):
            X_batch, y_batch = mnist.train.next_batch(batch_size)
            sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
        # Training accuracy is measured on the last mini-batch of the epoch.
        acc_train = sess.run(accuracy, feed_dict={X: X_batch, y: y_batch})
        validation_feed = {X: mnist.validation.images,
                           y: mnist.validation.labels}
        acc_val = sess.run(accuracy, feed_dict=validation_feed)
        print(epoch, " Train accuracy:", acc_train, " Val accuracy:", acc_val)
    save_path = saver.save(sess, "./my_model_final.ckpt")
"""调出模型进行测试"""
# Restore the trained model in a fresh session and classify the first 20
# test images.
with tf.Session() as sess:
    # Reuse the path returned by saver.save() instead of repeating the
    # literal, so the save and restore locations cannot drift apart.
    saver.restore(sess, save_path)
    X_new_scaled = mnist.test.images[:20]
    Z = logits.eval(feed_dict={X: X_new_scaled})
    # The predicted class is the index of the largest logit in each row.
    y_pred = np.argmax(Z, axis=1)

print("Predicted classes:", y_pred)
print("Actual classes:   ", mnist.test.labels[:20])
Reference
Last updated
Was this helpful?