Introduction to TensorFlow Estimator

In this post I am going to introduce the tf.estimator library. So first of all, what is this library trying to do? When writing TensorFlow code, there are a lot of repetitive operations that we need to do:

  • read the data in batches
  • process the data, e.g. convert images to floats
  • run a for loop and take a few gradient descent steps
  • save model weights to disk
  • output metrics to tensorboard
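To make the list concrete, here is a rough numpy-only sketch of the loop you would otherwise write by hand, using a toy linear regression. It only illustrates the bookkeeping that tf.estimator takes off your hands. It is not how the library works internally, and the data, learning rate and the `checkpoints`/`history` containers are all made up for this example.

```python
import numpy as np

# Toy linear-regression data (made up for this sketch).
rng = np.random.default_rng(0)
num_features, batch_size, epochs, learning_rate = 3, 32, 5, 0.05
X = rng.normal(size=(1000, num_features))
true_w = rng.normal(size=(num_features, 1))
y = X @ true_w + 0.1 * rng.normal(size=(1000, 1))

w = np.zeros((num_features, 1))
checkpoints = {}  # step -> copy of weights, standing in for "save model weights to disk"
history = []      # (step, loss) pairs, standing in for "output metrics to tensorboard"
step = 0
for epoch in range(epochs):                      # loop over the data several times
    order = rng.permutation(len(X))              # reshuffle every epoch
    for start in range(0, len(X), batch_size):   # read the data in batches
        idx = order[start:start + batch_size]
        err = X[idx] @ w - y[idx]
        grad = 2 * X[idx].T @ err / len(idx)     # gradient of the mean squared error
        w -= learning_rate * grad                # one gradient descent step
        if step % 50 == 0:
            history.append((step, float(np.mean(err ** 2))))
        if step % 100 == 0:
            checkpoints[step] = w.copy()
        step += 1
```

Every one of those concerns (batching, the step counter, checkpointing, logging) is interleaved with the actual model here, which is exactly what the estimator separates out.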

The Keras library makes this quite a bit easier, but there are times when you might need to use plain old TensorFlow (it gets quite hacky to implement some multiple-output models and GANs in Keras). The estimator library handles all this boilerplate without placing restrictions on what you can do with your model.

Unfortunately, the docs are a little bit lacking and most examples use TFRecords, which you do not have to use! This post will give you some boilerplate code that you can (hopefully) build on.

Below is a summary plot of how TensorFlow estimators work. You need to define two functions (an input_fn and a model_fn) and a dictionary of parameters.
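As a rough plain-Python sketch of how the pieces fit together: the params dictionary is handed to both functions, input_fn produces batches of (features, labels), and model_fn decides what to compute from the mode. Everything here (the mode strings, the toy model and the `run` driver) is a hypothetical stand-in. In particular, the real Estimator calls model_fn once to build a graph rather than once per batch as this toy driver does.

```python
import numpy as np

# Stand-ins for tf.estimator.ModeKeys.TRAIN / EVAL / PREDICT.
TRAIN, EVAL, PREDICT = "train", "eval", "infer"

params = dict(num_features=2, batch_size=4)  # hypothetical configuration

def input_fn(params, mode):
    """Yield (features, labels) batches; features is a dict, as in this post."""
    rng = np.random.default_rng(0)
    X = rng.normal(size=(10, params["num_features"]))
    y = X.sum(axis=1, keepdims=True)
    bs = params["batch_size"]
    for start in range(0, len(X), bs):
        yield {"X": X[start:start + bs]}, y[start:start + bs]

def model_fn(features, labels, mode, params):
    """Return a dict describing what to compute in this mode (a toy 'spec')."""
    preds = features["X"].sum(axis=1, keepdims=True)  # a fixed toy "model"
    if mode == PREDICT:
        return {"predictions": preds}
    return {"loss": float(np.mean((preds - labels) ** 2))}

def run(mode, input_fn, model_fn, params):
    """Hypothetical driver: the loop an estimator writes for you."""
    return [model_fn(features, labels, mode, params)
            for features, labels in input_fn(params, mode)]

eval_out = run(EVAL, input_fn, model_fn, params)
pred_out = run(PREDICT, input_fn, model_fn, params)
```

The point of the split is that you only ever write the two functions and the dictionary; the driver is the part you get for free.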

Let’s go step by step. First, the params. This is where all the configuration parameters of your model and input live. This might look something like this:

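params = dict(
    # Input parameters (example values)
    num_features=5, batch_size=32, epochs=5,
    # Model parameters (dropout and learning_rate are example values)
    shape=[32, 32, 1],
    dropout=0.0, learning_rate=1e-3,
)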

Next we need to load the input. Here there are two convenient functions we can use. If the data can be held in memory, one of them creates a dataset directly from the arrays. Alternatively, the other creates a dataset from a generator. With either interface, we can specify the batch size and number of epochs simply by calling dataset.batch() and dataset.repeat().
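The order of those two calls matters for how many steps training runs. Here is a small arithmetic sketch (the function names are hypothetical) for 10,000 examples, as in the fake data below, with an assumed batch size of 32 and 5 epochs. The input_fn below calls repeat() before batch(), and with these assumed values that gives 1563 batches, which matches the final step in the training log further down.

```python
import math

def batches_repeat_then_batch(num_examples, batch_size, epochs):
    """Batches from dataset.repeat(epochs).batch(batch_size): the epochs are joined
    into one stream, so only the very last batch can be partial."""
    return math.ceil(num_examples * epochs / batch_size)

def batches_batch_then_repeat(num_examples, batch_size, epochs):
    """Batches from dataset.batch(batch_size).repeat(epochs): every epoch ends with
    its own partial batch."""
    return epochs * math.ceil(num_examples / batch_size)

print(batches_repeat_then_batch(10000, 32, 5))  # 1563
print(batches_batch_then_repeat(10000, 32, 5))  # 1565
```

In both cases batch() keeps the final partial batch unless you pass drop_remainder=True.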

Below I will show the generator interface because I tend to use it more for real problems. For a similar reason, I will also pass features as a dictionary, which allows one to combine tensors of different shapes and sizes.
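To see why a dictionary helps, here is a numpy sketch of batching a stream of (features, label) pairs whose features have different shapes: each key is stacked on its own. The generator, the 'img' key and the helper names are hypothetical, and this only roughly mimics what dataset.batch() does for nested structures.

```python
import numpy as np

def example_stream(n, num_features=5):
    """Hypothetical generator yielding ({'X': ..., 'img': ...}, label) pairs in which
    the two feature arrays have different shapes."""
    rng = np.random.default_rng(0)
    for _ in range(n):
        features = {"X": rng.normal(size=(num_features,)),
                    "img": rng.normal(size=(4, 4))}
        yield features, rng.normal(size=(1,))

def _collate(feats, labels):
    # Stack each dictionary key on its own, so keys may have different shapes.
    return {k: np.stack([f[k] for f in feats]) for k in feats[0]}, np.stack(labels)

def batch_stream(stream, batch_size):
    """Group (features, label) pairs into batches, roughly as dataset.batch() does
    for a dictionary of features (including a final partial batch)."""
    feats, labels = [], []
    for f, label in stream:
        feats.append(f)
        labels.append(label)
        if len(feats) == batch_size:
            yield _collate(feats, labels)
            feats, labels = [], []
    if feats:
        yield _collate(feats, labels)

batches = list(batch_stream(example_stream(10), 4))
```

Each batch is then a dictionary like {'X': (4, 5), 'img': (4, 4, 4)} plus a (4, 1) label array, so the model can pull out whichever inputs it needs by name.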

import numpy as np
import tensorflow as tf

beta = np.random.normal(size=(params['num_features'], 1)) * 2.0

def input_fn(params, mode):
    # Generate fake data
    X = np.random.normal(size=(10000, params['num_features']))
    y = X @ beta + np.random.normal(size=(10000, 1))
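    # If the data fits in memory we could instead use
    # dataset = tf.data.Dataset.from_tensor_slices(({'X': X.astype(np.float32)}, y.astype(np.float32)))
    # but I find in practice the generator interface is more useful.
    def data_it():
        for i in range(10000):
            yield {'X': X[i, :]}, y[i]
    data_types = ({'X': tf.float32}, tf.float32)
    data_shapes = ({
        'X': (params['num_features'], )
    }, (1, ))
    dataset = tf.data.Dataset.from_generator(data_it, data_types, data_shapes)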
    # If we are training then shuffle and repeat the dataset
    if mode == tf.estimator.ModeKeys.TRAIN:
        dataset = dataset.shuffle(2500) # Shuffle needs a buffer size
        dataset = dataset.repeat(params['epochs'])
    dataset = dataset.batch(params['batch_size'])
    return dataset
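On the "shuffle needs a buffer size" comment: dataset.shuffle fills a buffer with buffer_size elements, randomly samples elements from that buffer, and replaces each selected element with the next one from the input. Below is a plain-Python sketch of that idea (not the tf.data implementation; the function name is hypothetical).

```python
import random

def buffered_shuffle(stream, buffer_size, seed=0):
    """Buffer-based shuffle in the spirit of dataset.shuffle(buffer_size)."""
    rng = random.Random(seed)
    buffer = []
    for item in stream:
        if len(buffer) < buffer_size:   # first fill the buffer
            buffer.append(item)
            continue
        i = rng.randrange(buffer_size)  # emit a random buffered element...
        yield buffer[i]
        buffer[i] = item                # ...and put the new element in its slot
    rng.shuffle(buffer)                 # drain the rest in random order
    yield from buffer

out = list(buffered_shuffle(range(100), 10))
```

A buffer at least as large as the dataset gives a full shuffle. A smaller one (2500 of 10,000 here) only mixes elements locally; for example, the first element emitted always comes from the first buffer_size inputs, and a buffer of 1 does not shuffle at all.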

Obligatory mention of which allows you to transform the data after you have a dataset.

Anyway, moving on, the next thing we need is a model_fn. This function takes the features, labels, mode and params, and returns a tf.estimator.EstimatorSpec describing what to do in that mode. What you put in the spec is really up to you. The docs are rather scant as of writing this post. The only argument that is always required is mode, although training also needs loss and train_op, evaluation needs loss, and prediction needs predictions. The loss gets logged to TensorBoard, predictions is a dictionary of outputs returned in predict mode, train_op is executed at every step of the training loop, and eval_metric_ops is executed during the evaluation stage.

from tensorflow.keras.layers import Dense, Dropout

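def model_fn(features, labels, mode, params):
    output = features['X']
    # Note: the last layer (width 1) also uses relu, so predictions are
    # non-negative; a linear final layer is more usual for regression.
    for width in params['shape']:
        output = Dense(width, activation='relu')(output)
        if params['dropout'] > 0 and mode == tf.estimator.ModeKeys.TRAIN:
            # Pass training=True explicitly: without it Keras uses its global
            # learning phase, which defaults to inference in graph mode.
            output = Dropout(params['dropout'])(output, training=True)
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
            mode,
            predictions={  # We can put whatever we want here
                'values': output,
            })
    loss = tf.reduce_mean((output - labels) ** 2)
    if mode == tf.estimator.ModeKeys.EVAL:
        metrics = dict(
            mae=tf.metrics.mean_absolute_error(labels, output)
        )
        return tf.estimator.EstimatorSpec(
            mode, loss=loss, eval_metric_ops=metrics)
    optimiser = tf.train.AdamOptimizer(params['learning_rate'])
    train_op = optimiser.minimize(
        loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(
        mode, loss=loss, train_op=train_op)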

Depending on the mode, we produce a different EstimatorSpec. Bear in mind that model_fn is only called to construct the graph (once for training, and once more for each evaluation or prediction run), not at every step! Any operation that needs to run during the loop must therefore be passed through train_op or eval_metric_ops, which are executed during training and evaluation respectively.
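A pure-Python analogy (no TensorFlow involved) of the point above: the builder runs once and returns an op, and only that op runs on every step. The names are hypothetical. The same applies to a Python side effect such as a print placed directly in model_fn, which happens when the graph is built rather than at every step.

```python
def build_model(params, log):
    """Analogue of model_fn: runs once and returns the op to execute each step."""
    log.append("model built")  # a side effect here happens only at construction
    state = {"w": 0.0}

    def train_op(x, y):
        # This body runs on every step, like a real train_op.
        grad = 2 * (state["w"] * x - y) * x  # d/dw of (w*x - y)**2
        state["w"] -= params["learning_rate"] * grad
        return state["w"]

    return train_op

log = []
train_op = build_model({"learning_rate": 0.1}, log)
for _ in range(100):
    w = train_op(1.0, 3.0)
```

After 100 steps `log` still holds a single entry while `w` has converged to 3.0.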

Now we can bring it all together!

run_config = tf.estimator.RunConfig(
    save_summary_steps=200,  # Save tensorboard output every 200 steps
    save_checkpoints_steps=3000,  # Save the model weights every 3000 steps
    keep_checkpoint_max=10,  # Keep the last 10 saved weights
)

estimator = tf.estimator.Estimator(model_fn,
                                   model_dir='my_model_dir/',
                                   config=run_config,
                                   params=params)

train_spec = tf.estimator.TrainSpec(input_fn)
eval_spec = tf.estimator.EvalSpec(input_fn, throttle_secs=2)

tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
INFO:tensorflow:Using config: {'_model_dir': 'my_model_dir/', '_tf_random_seed': None, '_save_summary_steps': 200, '_save_checkpoints_steps': 3000, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
, '_keep_checkpoint_max': 10, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': < object at 0x7f677d23af28>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps 3000 or save_checkpoints_secs None.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/training/ Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/ calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/ where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into my_model_dir/model.ckpt.
INFO:tensorflow:loss = 10.592407, step = 0
INFO:tensorflow:global_step/sec: 169.577
INFO:tensorflow:loss = 6.0479317, step = 100 (0.590 sec)
INFO:tensorflow:global_step/sec: 224.544
INFO:tensorflow:loss = 4.1276617, step = 200 (0.447 sec)
INFO:tensorflow:global_step/sec: 229.509
INFO:tensorflow:loss = 7.515444, step = 300 (0.434 sec)
INFO:tensorflow:global_step/sec: 218.804
INFO:tensorflow:loss = 5.87257, step = 400 (0.459 sec)
INFO:tensorflow:global_step/sec: 227.757
INFO:tensorflow:loss = 15.559208, step = 500 (0.437 sec)
INFO:tensorflow:global_step/sec: 217.511
INFO:tensorflow:loss = 8.777319, step = 600 (0.460 sec)
INFO:tensorflow:global_step/sec: 228.692
INFO:tensorflow:loss = 9.553991, step = 700 (0.437 sec)
INFO:tensorflow:global_step/sec: 213.618
INFO:tensorflow:loss = 9.763461, step = 800 (0.471 sec)
INFO:tensorflow:global_step/sec: 231.125
INFO:tensorflow:loss = 12.304417, step = 900 (0.429 sec)
INFO:tensorflow:global_step/sec: 228.433
INFO:tensorflow:loss = 4.2876644, step = 1000 (0.441 sec)
INFO:tensorflow:global_step/sec: 223.483
INFO:tensorflow:loss = 10.412739, step = 1100 (0.446 sec)
INFO:tensorflow:global_step/sec: 233.307
INFO:tensorflow:loss = 7.404874, step = 1200 (0.427 sec)
INFO:tensorflow:global_step/sec: 214.08
INFO:tensorflow:loss = 8.00625, step = 1300 (0.467 sec)
INFO:tensorflow:global_step/sec: 225.272
INFO:tensorflow:loss = 6.8261833, step = 1400 (0.445 sec)
INFO:tensorflow:global_step/sec: 248.704
INFO:tensorflow:loss = 5.6619678, step = 1500 (0.401 sec)
INFO:tensorflow:Saving checkpoints for 1563 into my_model_dir/model.ckpt.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2019-11-19T18:41:40Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from my_model_dir/model.ckpt-1563
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [10/100]
INFO:tensorflow:Evaluation [20/100]
INFO:tensorflow:Evaluation [30/100]
INFO:tensorflow:Evaluation [40/100]
INFO:tensorflow:Evaluation [50/100]
INFO:tensorflow:Evaluation [60/100]
INFO:tensorflow:Evaluation [70/100]
INFO:tensorflow:Evaluation [80/100]
INFO:tensorflow:Evaluation [90/100]
INFO:tensorflow:Evaluation [100/100]
INFO:tensorflow:Finished evaluation at 2019-11-19-18:41:41
INFO:tensorflow:Saving dict for global step 1563: global_step = 1563, loss = 7.899214, mae = 1.9335026
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 1563: my_model_dir/model.ckpt-1563
INFO:tensorflow:Loss for final step: 9.399655.

({'global_step': 1563, 'loss': 7.899214, 'mae': 1.9335026}, [])

You can find lots more options in the RunConfig documentation.
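With the settings above, a quick bit of arithmetic explains the log. Training runs for 1563 steps, so with save_checkpoints_steps=3000 the only checkpoints are at step 0 and at the final step (both visible in the log), and the log shows a single evaluation after the final checkpoint. The helper below is a simplified, hypothetical model of step-based periodic actions, not the actual hook logic.

```python
def periodic_steps(total_steps, every, include_final=True):
    """Steps at which an every-`every`-steps action would fire during a run of
    total_steps steps, counting from step 0 and optionally adding the final step."""
    steps = list(range(0, total_steps, every))
    if include_final and (not steps or steps[-1] != total_steps):
        steps.append(total_steps)
    return steps

# Checkpoints: save_checkpoints_steps=3000 over the 1563 training steps.
print(periodic_steps(1563, 3000))                           # [0, 1563]
# Loss lines: every 100 steps (log_step_count_steps=100), last one at 1500.
print(periodic_steps(1563, 100, include_final=False)[-1])  # 1500
```

If you want evaluations during training, lower save_checkpoints_steps (or use save_checkpoints_secs) so checkpoints land inside the run.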

In the model directory, the estimator has saved the weights and logged the TensorBoard output. This means that if we point an estimator at the same directory, it will automatically load the latest saved weights. Good stuff!
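To illustrate how the step number travels with the saved weights, here is a hypothetical helper that finds the most recent checkpoint step by scanning file names with the model.ckpt-1563 style prefix seen in the log. TensorFlow itself uses tf.train.latest_checkpoint, which reads the `checkpoint` index file in the directory instead. The .index/.data-00000-of-00001/.meta suffixes are the usual TF 1.x checkpoint files, created here as empty placeholders.

```python
import os
import re
import tempfile

def latest_checkpoint_step(model_dir, prefix="model.ckpt"):
    """Hypothetical helper: the highest N among files named '<prefix>-N.<suffix>'."""
    pattern = re.compile(re.escape(prefix) + r"-(\d+)\.")
    steps = []
    for name in os.listdir(model_dir):
        match = pattern.match(name)
        if match:
            steps.append(int(match.group(1)))
    return max(steps) if steps else None

# Usage with empty placeholder files named like the checkpoints in the log.
with tempfile.TemporaryDirectory() as model_dir:
    for name in ["checkpoint", "model.ckpt-0.index", "model.ckpt-1563.index",
                 "model.ckpt-1563.data-00000-of-00001", "model.ckpt-1563.meta"]:
        open(os.path.join(model_dir, name), "w").close()
    step = latest_checkpoint_step(model_dir)
print(step)  # 1563
```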

!ls ./my_model_dir/

We can get predictions by calling the .predict method on the estimator. This returns the predictions as a generator.

est_preds = estimator.predict(input_fn)
predictions = np.stack([x['values'] for x in est_preds], axis=0)
predictions.shape
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from my_model_dir/model.ckpt-1563
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.

(10000, 1)
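Here is a self-contained sketch of the stacking step with a fake generator standing in for estimator.predict (the generator is hypothetical). Each 'values' entry has shape (1,), which is why the result above is (10000, 1). Because predict returns a generator, the usual iterator tools such as itertools.islice also work when you only want the first few predictions.

```python
import itertools
import numpy as np

def fake_predict(n):
    """Stand-in for estimator.predict(input_fn): yields one dict per example."""
    for i in range(n):
        yield {"values": np.array([float(i)], dtype=np.float32)}

# Each 'values' entry has shape (1,), so stacking along axis 0 gives (n, 1).
predictions = np.stack([p["values"] for p in fake_predict(5)], axis=0)
print(predictions.shape)  # (5, 1)

# Because predict returns a generator, you can also take just the first few.
first_three = list(itertools.islice(fake_predict(10000), 3))
```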

I hope that this gives enough of an overview for you to read the docs on the various options that you can use.
