Example
对照《TensorFlow 概述》（Tensorflow.md）中的整体流程，逐段解读代码。
快速测试 TensorFlow 函数
# Quick sanity check of tf.nn.in_top_k using the TensorFlow 1.x graph/session API.
import tensorflow as tf

# Each row of A holds per-class scores; B holds the target class index of each row.
A = [[0.8, 0.6, 0.3], [0.1, 0.6, 0.4]]
B = [1, 1]
# True where the target class is among the top-1 scores of its row.
# Row 0's highest score is class 0 (0.8) and row 1's is class 1 (0.6),
# so this evaluates to [False, True].
out = tf.nn.in_top_k(A, B, 1)
with tf.Session() as sess:
    # tf.initialize_all_variables() is deprecated; this graph has no variables,
    # so the initializer is effectively a no-op here.
    sess.run(tf.global_variables_initializer())
    # Function-call print works on Python 3 (the `print x` statement form does not).
    print(sess.run(out))
train
人脸年龄与性别识别：train.py
import argparse
import os
import time
import tensorflow as tf
import inception_resnet_v1
from utils import inputs, get_files_name
def run_training(image_path, batch_size, epoch, model_path, log_dir, start_lr, wd, kp):
    """Train the age/gender network and write summaries and checkpoints.

    Builds a fresh graph around ``inception_resnet_v1.inference``, which returns
    an age head and a gender head. It minimises the sum of both sparse softmax
    cross-entropy losses plus every loss in ``REGULARIZATION_LOSSES`` with Adam.
    Training resumes from the latest checkpoint in ``model_path`` if one exists.
    It logs TensorBoard summaries to ``log_dir`` and checkpoints every 1000
    steps and once more on exit.

    Args:
        image_path: Passed to ``get_files_name`` to list the input files.
        batch_size: Number of examples per training step.
        epoch: Passed to ``inputs`` as ``num_epochs``. Training ends when the
            input pipeline raises ``tf.errors.OutOfRangeError``.
        model_path: Checkpoint directory (created if missing). Checkpoints are
            written with the prefix ``<model_path>/model.ckpt``, and at most 100
            are kept.
        log_dir: Directory for TensorBoard event files.
        start_lr: Initial learning rate. It is multiplied by 0.9 every 3000
            steps (staircase decay).
        wd: Weight decay forwarded to ``inception_resnet_v1.inference``.
        kp: Keep probability forwarded to ``inception_resnet_v1.inference``.
    """
    # Saver.save refuses to write into a missing parent directory, and the
    # first save happens right after step 0, so make sure it exists up front.
    if not os.path.isdir(model_path):
        os.makedirs(model_path)

    # Build the model into its own graph rather than the process-wide default.
    with tf.Graph().as_default():
        sess = tf.Session()

        # Input batches: images plus age and gender labels.
        images, age_labels, gender_labels, _ = inputs(path=get_files_name(image_path), batch_size=batch_size,
                                                      num_epochs=epoch)

        # Fed True on every training step below; passed to the network as
        # phase_train (presumably toggles dropout / batch-norm training mode).
        train_mode = tf.placeholder(tf.bool)
        age_logits, gender_logits, _ = inception_resnet_v1.inference(images, keep_probability=kp,
                                                                     phase_train=train_mode, weight_decay=wd)
        # To transfer weights from another model, uncomment the line below
        # (restore_from_source is not defined in this script).
        # sess = restore_from_source(sess,'./models')

        # Training loss: mean age + gender cross-entropy plus the L2 terms the
        # network registered in the REGULARIZATION_LOSSES collection.
        age_cross_entropy_mean = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=age_labels, logits=age_logits))
        gender_cross_entropy_mean = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=gender_labels, logits=gender_logits))
        total_loss = tf.add_n(
            [gender_cross_entropy_mean, age_cross_entropy_mean] + tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))

        # Monitoring-only metrics (not part of total_loss).
        # Predicted age = expectation of the softmax over class indices 0..100.
        # This assumes the age head has exactly 101 classes — TODO confirm
        # against inception_resnet_v1.
        age_values = tf.range(0, 101, dtype=tf.float32)
        age = tf.reduce_sum(tf.multiply(tf.nn.softmax(age_logits), age_values), axis=1)
        abs_loss = tf.losses.absolute_difference(age_labels, age)
        gender_acc = tf.reduce_mean(tf.cast(tf.nn.in_top_k(gender_logits, gender_labels, 1), tf.float32))

        tf.summary.scalar("age_cross_entropy", age_cross_entropy_mean)
        tf.summary.scalar("gender_cross_entropy", gender_cross_entropy_mean)
        # Spaces are illegal in summary names (TF would rename "total loss" to
        # "total_loss" with a warning), so use the legal tag directly.
        tf.summary.scalar("total_loss", total_loss)
        tf.summary.scalar("train_abs_age_error", abs_loss)
        tf.summary.scalar("gender_accuracy", gender_acc)

        # Optimiser with staircase exponential learning-rate decay.
        global_step = tf.Variable(0, name="global_step", trainable=False)
        lr = tf.train.exponential_decay(start_lr, global_step=global_step, decay_steps=3000, decay_rate=0.9,
                                        staircase=True)
        optimizer = tf.train.AdamOptimizer(lr)
        tf.summary.scalar("lr", lr)
        # Make the weight update depend on the ops in UPDATE_OPS
        # (batch-normalization statistics updates).
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(total_loss, global_step)

        # Initialise global and local variables. Comment this out when
        # transferring weights from another model.
        sess.run(tf.group(tf.global_variables_initializer(),
                          tf.local_variables_initializer()))

        merged = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(log_dir, sess.graph)

        # To transfer weights from another model, replace the saver/restore
        # section below with (save_to_target is not defined in this script):
        # sess, new_saver = save_to_target(sess,target_path='./models/new/',max_to_keep=100)
        new_saver = tf.train.Saver(max_to_keep=100)
        ckpt = tf.train.get_checkpoint_state(model_path)
        if ckpt and ckpt.model_checkpoint_path:
            new_saver.restore(sess, ckpt.model_checkpoint_path)
            print("restore and continue training!")

        checkpoint_prefix = os.path.join(model_path, "model.ckpt")
        # Read the (possibly restored) step before starting the input threads,
        # so the OutOfRangeError handler below can always report it.
        step = sess.run(global_step)

        # Start the input-pipeline enqueue threads.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            start_time = time.time()
            while not coord.should_stop():
                # One optimisation step; keep only the merged summaries.
                _, summary = sess.run([train_op, merged], {train_mode: True})
                train_writer.add_summary(summary, step)
                if step % 100 == 0:
                    print('%.3f sec' % (time.time() - start_time))
                    start_time = time.time()
                if step % 1000 == 0:
                    save_path = new_saver.save(sess, checkpoint_prefix, global_step=global_step)
                    print("Model saved in file: %s" % save_path)
                step = sess.run(global_step)
        except tf.errors.OutOfRangeError:
            # Raised by the input pipeline once its num_epochs budget is used up.
            print('Done training for %d epochs, %d steps.' % (epoch, step))
        finally:
            try:
                save_path = new_saver.save(sess, checkpoint_prefix, global_step=global_step)
                print("Model saved in file: %s" % save_path)
            finally:
                # Shut down even if the final save failed, so the input threads,
                # the event file and the session are not leaked.
                coord.request_stop()
                coord.join(threads)
                train_writer.close()
                sess.close()
def build_arg_parser():
    """Return the command-line argument parser for this training script.

    Exposed as a function so the options and their defaults can be inspected
    or reused without running training.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--learning_rate", "--lr", type=float, default=1e-3, help="Init learning rate")
    parser.add_argument("--weight_decay", type=float, default=1e-5, help="Set 0 to disable weight decay")
    parser.add_argument("--model_path", type=str, default="./models", help="Path to save models")
    parser.add_argument("--log_path", type=str, default="./train_log", help="Path to save logs")
    parser.add_argument("--epoch", type=int, default=6, help="Number of training epochs")
    parser.add_argument("--images", type=str, default="./data/train", help="Path of tfrecords")
    parser.add_argument("--batch_size", type=int, default=128, help="Batch size")
    parser.add_argument("--keep_prob", type=float, default=0.8, help="Used by dropout")
    parser.add_argument("--cuda", default=False, action="store_true",
                        help="Set this flag to use CUDA when training.")
    return parser


if __name__ == '__main__':
    args = build_arg_parser().parse_args()
    if not args.cuda:
        # Hide every GPU from CUDA before training starts, so the run is CPU-only.
        os.environ['CUDA_VISIBLE_DEVICES'] = ''
    run_training(image_path=args.images, batch_size=args.batch_size, epoch=args.epoch, model_path=args.model_path,
                 log_dir=args.log_path, start_lr=args.learning_rate, wd=args.weight_decay, kp=args.keep_prob)
Last updated
Was this helpful?