"""
Trains a CNN on fake and real galaxy images using TensorFlow.
Adapted from https://github.com/ageron/handson-ml
"""
import numpy as np
import os
from datetime import datetime
from glob import glob
import tensorflow as tf
def get_indir(nersc=False):
    """Return the directory that contains the dr5_testtrain data.

    Args:
        nersc: when truthy, use the hard-coded NERSC scratch location;
            otherwise use the ``Downloads`` folder under ``$HOME``.

    Returns:
        str: path of the input directory.

    Raises:
        KeyError: if ``nersc`` is falsy and ``HOME`` is not set in the
            environment.
    """
    if not nersc:
        home = os.environ['HOME']
        return os.path.join(home, 'Downloads')
    return os.path.join('/global/cscratch1/sd/kaylanb', 'obiwan_out')
def get_outdir(nersc=False, knl=False):
    """Return the directory where checkpoint and log files are written.

    The flags are evaluated by truthiness (logical, not bitwise), so any
    truthy/falsy values behave like the corresponding booleans.

    Args:
        nersc: truthy when running at NERSC; selects the scratch location.
        knl: truthy when running on KNL nodes; only consulted when
            ``nersc`` is truthy, where it selects the ``cnn_knl``
            sub-directory instead of ``cnn``.

    Returns:
        str: path of the output directory.

    Raises:
        KeyError: if ``nersc`` is falsy and ``HOME`` is not set in the
            environment.
    """
    if not nersc:
        return os.path.join(os.environ['HOME'], 'Downloads', 'cnn')
    subdir = 'cnn_knl' if knl else 'cnn'
    return os.path.join('/global/cscratch1/sd/kaylanb', 'obiwan_out', subdir)
def get_xtrain_fns(brick, indir):
    """List the ``xtrain_*.npy`` files available for one brick.

    Files are searched for under
    ``<indir>/dr5_testtrain/testtrain/<first 3 chars of brick>/<brick>/``.

    Args:
        brick: brick name; its first three characters name the parent
            directory.
        indir: directory that contains ``dr5_testtrain``.

    Returns:
        list of str: matching file paths, in the order ``glob`` returns them.

    Raises:
        IOError: if no file matches the search pattern.
    """
    pattern = os.path.join(indir, 'dr5_testtrain', 'testtrain',
                           brick[:3], brick, 'xtrain_*.npy')
    matches = glob(pattern)
    if not matches:
        raise IOError('No training data found matching: %s' % pattern)
    return matches
#def BatchGen(X,y,brick,batch_size=32):
def BatchGen(brick, indir, batch_size=32):
    """Yield mini-batches of training data for one brick.

    Every ``xtrain_*.npy`` file found for the brick is loaded together with
    its companion ``ytrain_*.npy`` file (same directory, same suffix) and
    split along the first axis into chunks of at most ``batch_size`` rows.

    Args:
        brick: brick name, passed on to ``get_xtrain_fns``.
        indir: input directory, passed on to ``get_xtrain_fns``.
        batch_size: maximum number of rows per yielded batch.

    Yields:
        tuple: ``(X_batch, y_batch)`` where ``X_batch`` holds the selected
        rows of the x array and ``y_batch`` the matching rows of the y array
        cast to ``np.int32``.

    Raises:
        IOError: from ``get_xtrain_fns`` when the brick has no training files.
    """
    for fn in get_xtrain_fns(brick, indir):
        print('Loading %s' % fn)
        # Load the file currently being iterated; always reading the first
        # file would silently ignore every other file of the brick.
        X = np.load(fn)
        # Swap the prefix in the file name only, so a directory that happens
        # to contain 'xtrain_' is left alone.
        y_fn = os.path.join(os.path.dirname(fn),
                            os.path.basename(fn).replace('xtrain_', 'ytrain_', 1))
        y = np.load(y_fn)
        N = X.shape[0]
        if N == 0:
            # Nothing to train on; an empty batch would make the mean loss NaN
            continue
        # Ceiling division: the fewest chunks that keep every chunk within
        # batch_size, and never an empty chunk.
        n_batches = -(-N // batch_size)
        for i in np.array_split(np.arange(N), n_batches):
            yield X[i, ...], y[i].astype(np.int32)
def get_bricks(fn='cnn_bricks.txt'):
    """Load the list of brick names to train on.

    The file is looked up in ``../../../etc`` relative to the directory of
    this module (an absolute ``fn`` overrides that location, as per
    ``os.path.join``). It is parsed with ``np.loadtxt``, so ``#`` comments
    and blank lines are ignored.

    Args:
        fn: name of the text file listing the bricks.

    Returns:
        numpy.ndarray: array of brick-name strings; at least 1-D even when
        the file holds a single entry.

    Raises:
        IOError: if the brick list file does not exist.
    """
    fn = os.path.join(os.path.dirname(__file__),
                      '../../../etc', fn)
    if not os.path.exists(fn):
        raise IOError('Need to create brick list: %s' % fn)
    # np.loadtxt collapses a one-entry file to a 0-d array; promote it back
    # to 1-D rather than re-reading the raw text (which would also pick up
    # comment lines).
    return np.atleast_1d(np.loadtxt(fn, dtype=str))
def get_checkpoint(epoch, brick, outdir):
    """Build the checkpoint path for a given epoch and brick.

    Args:
        epoch: epoch identifier, formatted with ``%s``.
        brick: brick name, formatted with ``%s``.
        outdir: output directory; checkpoints live in its ``ckpts``
            sub-directory.

    Returns:
        str: ``<outdir>/ckpts/epoch_<epoch>_brick_<brick>.ckpt``
    """
    basename = 'epoch_%s_brick_%s.ckpt' % (epoch, brick)
    ckpt_dir = os.path.join(outdir, 'ckpts')
    return os.path.join(ckpt_dir, basename)
def bookmark_fn(outdir):
    """Return the path of the bookmark file.

    The bookmark is a single-line text file storing the epoch, brick and
    batch number of the last checkpoint.

    Args:
        outdir: output directory; the file lives in its ``ckpts``
            sub-directory.

    Returns:
        str: ``<outdir>/ckpts/last_epoch_brick_batch.txt``
    """
    ckpt_dir = os.path.join(outdir, 'ckpts')
    return os.path.join(ckpt_dir, 'last_epoch_brick_batch.txt')
def get_bookmark(outdir):
    """Read the bookmark file written after the last checkpoint.

    Args:
        outdir: output directory whose bookmark file is read.

    Returns:
        tuple of str: ``(epoch, brick, ith_batch)`` exactly as stored in the
        file; no numeric conversion is done here.

    Raises:
        ValueError: if the file does not hold exactly three fields separated
            by single spaces.
    """
    with open(bookmark_fn(outdir), 'r') as fobj:
        text = fobj.read()
    epoch, brick, ith_batch = text.strip().split(' ')
    return epoch, brick, ith_batch
def get_logdir(outdir):
    """Return a fresh, timestamped log directory path for this run.

    Args:
        outdir: output directory; logs go under its ``logs`` sub-directory.

    Returns:
        str: ``<outdir>/logs/run-<YYYYmmddHHMMSS>/`` (UTC timestamp, with a
        trailing path separator). The directory is not created here.
    """
    now = datetime.utcnow().strftime("%Y%m%d%H%M%S")
    # Join the components once. Prefixing the run directory with the log
    # directory a second time only works for absolute paths and duplicates
    # the prefix for relative ones. The final '' keeps the trailing separator.
    return os.path.join(outdir, 'logs', 'run-{}'.format(now), '')
# ---------------------------------------------------------------------------
# Graph definition (built at import time in the default TensorFlow graph).
# Three conv + average-pool stages, one fully connected layer and a 2-class
# output, followed by the training, evaluation, init/save and summary ops.
# ---------------------------------------------------------------------------

# Input image geometry
height, width, channels = (64, 64, 6)

# Shared by every conv layer: stride 1 with SAME padding keeps height/width
conv_kwargs = dict(strides=1,
                   padding='SAME',
                   activation=tf.nn.relu)
# Shared by every pooling layer: 2x2 window with stride 2 halves height/width
pool_kwargs = dict(ksize=[1, 2, 2, 1],
                   strides=[1, 2, 2, 1],
                   padding='VALID')

with tf.name_scope("inputs"):
    # Batch of images and their integer class labels; batch size left open
    X = tf.placeholder(tf.float32, shape=[None, height, width, channels], name="X")
    y = tf.placeholder(tf.int32, shape=[None], name="y")

# 64x64
with tf.name_scope("layer1"):
    conv1 = tf.layers.conv2d(X, filters=3*channels, kernel_size=7,
                             **conv_kwargs)
    pool1 = tf.nn.avg_pool(conv1, **pool_kwargs)

# 32x32
with tf.name_scope("layer2"):
    conv2 = tf.layers.conv2d(pool1, filters=6*channels, kernel_size=7,
                             **conv_kwargs)
    pool2 = tf.nn.avg_pool(conv2, **pool_kwargs)

# 16x16
with tf.name_scope("layer3"):
    conv3 = tf.layers.conv2d(pool2, filters=9*channels, kernel_size=7,
                             **conv_kwargs)
    pool3 = tf.nn.avg_pool(conv3, **pool_kwargs)
    # next is fc: flatten height x width x filters into one feature vector
    pool3_flat = tf.reshape(pool3,
                            shape=[-1, pool3.shape[1] * pool3.shape[2] * pool3.shape[3]])

with tf.name_scope("fc"):
    fc = tf.layers.dense(pool3_flat, 64, activation=tf.nn.relu, name="fc")

with tf.name_scope("output"):
    logits = tf.layers.dense(fc, 2, name="output")  # 2 classes
    Y_proba = tf.nn.softmax(logits, name="Y_proba")

with tf.name_scope("train"):
    # Mean cross-entropy between the logits and the integer labels
    xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=y)
    loss = tf.reduce_mean(xentropy)
    optimizer = tf.train.AdamOptimizer()
    training_op = optimizer.minimize(loss)

with tf.name_scope("eval"):
    # A sample counts as correct when its label has the highest logit
    correct = tf.nn.in_top_k(logits, y, 1)
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

with tf.name_scope("init_and_save"):
    # Created once, after every variable (including the optimizer's) exists.
    # Building a second initializer/Saver would only add redundant ops.
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()

# Scalar summaries written to the event file during training
loss_summary = tf.summary.scalar('loss', loss)
accur_summary = tf.summary.scalar('accuracy', accuracy)
# Script entry point: train the network brick by brick, writing a checkpoint
# and a bookmark after each brick so an interrupted job can be resumed.
if __name__ == '__main__':
    from argparse import ArgumentParser
    parser = ArgumentParser()
    parser.add_argument('--outdir', type=str, default=None,
                        help='optional output directory for the checkpoint and log files')
    args = parser.parse_args()

    knl = False
    config = None
    if 'isKNL' in os.environ:
        # Set in slurm_job_knl.sh
        knl = True
        # Environment values are strings: convert before comparing with 68
        # and before storing in the integer-valued config field.
        n_threads = int(os.environ['OMP_NUM_THREADS'])
        if n_threads != 68:
            raise ValueError('OMP_NUM_THREADS must be 68 on KNL, got %d'
                             % n_threads)
        config = tf.ConfigProto()
        config.intra_op_parallelism_threads = n_threads
        config.inter_op_parallelism_threads = 1

    nersc = 'CSCRATCH' in os.environ

    indir = get_indir(nersc=nersc)
    outdir = get_outdir(nersc=nersc, knl=knl)
    if args.outdir is not None:
        outdir = args.outdir

    # Train
    n_epochs = 4
    batch_size = 16

    bricks = get_bricks()

    # Checkpoints and the bookmark both live in <outdir>/ckpts; make sure it
    # exists before the first save.
    ckpt_dir = os.path.dirname(bookmark_fn(outdir))
    if not os.path.isdir(ckpt_dir):
        os.makedirs(ckpt_dir)

    file_writer = tf.summary.FileWriter(get_logdir(outdir),
                                        tf.get_default_graph())
    try:
        # The run counts as a restart when the checkpoint of the very first
        # (epoch 0, first brick) step exists; the bookmark then says where
        # training stopped.
        first_epoch, first_brick, first_batch = '0', bricks[0], '0'
        fn = get_checkpoint(first_epoch, first_brick, outdir) + '.meta'
        if os.path.exists(fn):
            last_epoch, last_brick, last_batch = get_bookmark(outdir)
            ckpt_fn = get_checkpoint(last_epoch, last_brick, outdir)
        else:
            last_epoch, last_brick, last_batch = first_epoch, first_brick, first_batch
            ckpt_fn = None
        # Index of the bookmarked brick. Not +1: that creates a bug where the
        # last brick skips all epochs, so the bookmarked brick is trained again.
        last_ibrick = np.where(bricks == last_brick)[0][0]

        with tf.Session(config=config) as sess:
            if ckpt_fn is None:
                sess.run(init)
                print('Starting from scratch')
            else:
                saver.restore(sess, ckpt_fn)
                print('Restored ckpt %s' % ckpt_fn)

            batch_index = int(last_batch)
            # NOTE(review): the range runs up to and including n_epochs;
            # confirm that the extra epoch is intended.
            for epoch in range(int(last_epoch), n_epochs+1):
                for ibrick, brick in enumerate(bricks):
                    # Don't repeat bricks when restart from ckpt
                    if ibrick < last_ibrick:
                        print('skipping: epoch,ibrick,last_ibrick', epoch, ibrick, last_ibrick)
                        continue
                    data_gen = BatchGen(brick, indir, batch_size)
                    for X_, y_ in data_gen:
                        sess.run(training_op, feed_dict={X: X_, y: y_})
                        batch_index += 1
                        if batch_index % 2 == 0:
                            step = batch_index
                            # One forward pass provides both summaries and
                            # the accuracy value for this batch.
                            loss_str, accur_str, acc_train = sess.run(
                                [loss_summary, accur_summary, accuracy],
                                feed_dict={X: X_, y: y_})
                            file_writer.add_summary(loss_str, step)
                            file_writer.add_summary(accur_str, step)
                            print(epoch, "Train accuracy:", acc_train)
                    # Save progress
                    fn = get_checkpoint(epoch, brick, outdir)
                    saver.save(sess, fn)
                    print('Wrote ckpt %s' % fn)
                    with open(bookmark_fn(outdir), 'w') as f:
                        f.write('%d %s %d' % (epoch, brick, batch_index))
                    print('Updated %s' % bookmark_fn(outdir))
                # Reset last_ibrick so use all bricks in next epoch
                last_ibrick = 0
    finally:
        # Flush pending summaries and release the event file, even on error
        file_writer.close()