Reproducibility is critical to any scientific endeavour, and machine learning is no exception. Releasing the code that generates a paper’s results is an important step in addressing this, but random aspects of neural network training - including data shuffling, augmentation and network initialization - make exact replication of results difficult. Two common approaches for handling these difficulties are:
- repeating experiments multiple times and reporting statistics; and
- managing the random state.
This post looks at the second point, particularly as it applies to training tensorflow’s keras.Models. We’ll be focusing on two properties of our programs:
- Determinism: our programs should produce exactly the same outputs for the same inputs.
- Pre-emptibility: our programs should be able to be interrupted and restarted without affecting the results.
Note that just because our programs are deterministic doesn’t mean there aren’t sources of pseudo-randomness - just that those sources need to be configurable such that they behave identically on every run. This is commonly done by setting a program’s random seed, but as we’ll see that’s not necessarily the end of the story - particularly if we want our programs to be pre-emptible. For more information about setting random seeds and random number generation in tensorflow, check out tensorflow’s random numbers guide.
This is part 1 of a 2-part series looking at deterministic, pre-emptible tensorflow. Part 2 takes a deep dive into data augmentation.
Modules, Checkpoints and BackupAndRestore
Before we get into the specifics of training deterministic pre-emptible models, it’s important that we understand the mechanism by which we’ll be saving and restoring our training state. We’ll be using 2 key classes provided in tensorflow:
- `tf.Module`: base class for objects that track dependencies, where dependencies are defined as savable objects assigned as attributes. Most public `tf.keras` classes, including `Model` and `Layer`, subclass this.
- `tf.train.Checkpoint`: for saving and restoring `Module`s, including any tracked dependencies.
The following example shows simple usage.
```python
import tensorflow as tf


class Foo(tf.Module):
    def __init__(self):
        self.bar = tf.Variable(0, dtype=tf.int64)


foo = Foo()
foo.bar.assign(2)
chkpt = tf.train.Checkpoint(root=foo)
chkpt.write("/tmp/foo")
del foo, chkpt

fresh_foo = Foo()
fresh_chkpt = tf.train.Checkpoint(root=fresh_foo)
assert fresh_foo.bar.numpy() == 0
fresh_chkpt.read("/tmp/foo")
assert fresh_foo.bar.numpy() == 2
print("Completed successfully")
```
During training, our training state will be made up of:
- the `Model`, including any trainable weights, the random state of any stochastic operations, and the `Optimizer` if provided in `compile`;
- the training data `Iterator`, including random state associated with any data augmentation or shuffle operations, or buffers for operations like `prefetch`; and
- any other stateful callbacks, like `ReduceLROnPlateau` or `EarlyStopping`.
For convenience, we’ll be using the `BackupAndRestore` callback to manage when training state is saved or restored via a `Checkpoint`. Unfortunately, `BackupAndRestore` only stores the state of the `Model` - not the other aspects of our training state listed above. A simple work-around is to include the other elements in our `Model`’s training state. Since `Model`s are `Module`s, they automatically track dependencies assigned as attributes.
```python
model._train_iter = train_iter
model._callbacks = callbacks
# now `model`'s state will include the state of `train_iter` and `callbacks`
```
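To see that this works, we can write a checkpoint of the model and confirm it now covers the extra state. A minimal sketch, assuming `model`, `train_iter` and `callbacks` are defined as above:

```python
# Since `Model` is a `Module`, the attributes assigned above are tracked,
# so a checkpoint of the model also captures iterator and callback state.
chkpt = tf.train.Checkpoint(model=model)
path = chkpt.write("/tmp/train_state")
# later - e.g. after pre-emption - restore everything in one call
chkpt.read(path)
```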
Random Seeds and Weight Initialization
Probably the largest source of non-determinism - and the simplest to fix - is weight initialization. We can make this deterministic by calling tf.random.set_seed.
```python
tf.random.set_seed(seed)
```
Note this will only affect operations created after this call that use the global seed, including `tf.keras.layers` weight initialization and `Dataset.shuffle`. It will not affect the global generator state. If our program uses the global generator, we should also set its state.
```python
tf.random.get_global_generator().reset_from_seed(seed)
```
Alternatively, we can replace the global generator with a fresh one.
```python
tf.random.set_global_generator(tf.random.Generator.from_seed(seed))
```
The former will affect operations already created by methods on the global generator, while the latter will not (though it may cause garbage-collection related breakages if there are no references to the old global generator). The results are equivalent if either is called before any other `get_global_generator` calls.
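To illustrate the difference, here is a minimal sketch: `reset_from_seed` mutates the generator in place, so existing references see the reset, while `set_global_generator` swaps in a new object, leaving old references pointing at the previous generator.

```python
import tensorflow as tf

g = tf.random.get_global_generator()  # hold a reference to the global generator

tf.random.get_global_generator().reset_from_seed(0)
print(g.uniform(()))  # affected: `g` is the (now reset) global generator

tf.random.set_global_generator(tf.random.Generator.from_seed(0))
print(g.uniform(()))  # unaffected by the swap: `g` is still the old generator
```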
Data Pipelining
Most training data pipelines will have up to 3 sources of randomness:
- random operations involved in data augmentations, like random image rotations and/or flips;
- race conditions associated with parallel map functions for data loading and augmentation; and
- dataset shuffling.
While researching this post I realized incorporating the first two in a deterministic and pre-emptible manner with tensorflow is distinctly non-trivial. To keep things simple I refactored that section into an entirely separate article.
In this post, we’ll use a relatively straight-forward pipeline without augmentation that uses a `shuffle` and a uniform `map` function via the `tf.data.Dataset` interface. While `Dataset`s can be saved in checkpoints, they won’t contain all the state we want - in our instance, the shuffle buffer. Instead, we want to work with the `Iterator` of an infinite dataset.
```python
dataset: tf.data.Dataset = load_dataset(split='train')
examples_per_epoch = len(dataset)
dataset = dataset.repeat()
dataset = dataset.shuffle(shuffle_buffer, seed=0)
dataset = dataset.batch(batch_size)
dataset = dataset.map(
    map_func,
    num_parallel_calls=tf.data.experimental.AUTOTUNE,
    deterministic=True,  # not strictly necessary - this is the default behaviour
)
steps_per_epoch = examples_per_epoch // batch_size  # ignore final fractional batch
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
train_iter = iter(dataset)  # this can be saved in checkpoints to track buffer state
```
Note that we apply `repeat` before `shuffle`. This has the following consequences (illustrated in the sketch after this list):
- we have one persistent shuffle buffer, meaning we won’t need to refill it from scratch each epoch;
- examples will bleed from one epoch to the next - i.e. every epoch will have slightly different examples; and
- our dataset has infinite length.
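The epoch-bleeding behaviour is easy to see on a toy dataset. A minimal sketch (element count and buffer size chosen arbitrarily):

```python
import tensorflow as tf

ds = tf.data.Dataset.range(4)

# shuffle after repeat: one persistent buffer, elements bleed across epochs
bleed = ds.repeat(2).shuffle(4, seed=0)
print(list(bleed.as_numpy_iterator()))  # an element may even appear twice in a row

# shuffle before repeat: each pass is a complete, separately shuffled epoch
clean = ds.shuffle(4, seed=0).repeat(2)
print(list(clean.as_numpy_iterator()))  # each half contains exactly 0..3
```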
Floating Point Determinism
There was a time when GPU operations were mostly non-deterministic due to race conditions in floating point operations. This is still the default for many operations, but most can now be made deterministic by setting the `TF_DETERMINISTIC_OPS` environment variable.
```bash
export TF_DETERMINISTIC_OPS=1
```
Alternatively, we can set it inside python.
```python
import os

os.environ["TF_DETERMINISTIC_OPS"] = "1"
```
See nvidia’s determinism repository for full details. Note there is not universal coverage for deterministic implementations - exceptions include `tf.gather` gradients, `tf.math.segment_*` operations and sparse-dense matrix multiplications (though the repository discusses ongoing work to resolve these).
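As an aside, more recent TensorFlow releases (2.8 and later, newer than what this post targets) expose this as an API call, which also makes ops without deterministic implementations raise instead of silently varying. A hedged sketch, if you’re on a newer version:

```python
import tensorflow as tf

# TF 2.8+ replacement for TF_DETERMINISTIC_OPS; ops lacking deterministic
# implementations raise an error rather than produce run-to-run differences.
tf.config.experimental.enable_op_determinism()
```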
Operations with Random State
Some operations like `Dropout` are intended to be stochastic. Unfortunately, despite the official guide for random number generation discouraging their use, most implementations in tensorflow use base `tf.random` operations, rather than `tf.random.stateless_*` variants or `Generator` methods. Hopefully this will change in subsequent releases, but for the moment we can re-implement those necessary for our networks. A simple dropout implementation is given below.
```python
from typing import Optional

import tensorflow as tf


# We register it so we don't serialize then deserialize a `tf.keras.layers.Dropout`
@tf.keras.utils.register_keras_serializable("PreEmptible")
class Dropout(tf.keras.layers.Layer):
    def __init__(self, rate: float, seed: Optional[int] = None, **kwargs):
        super().__init__(**kwargs)
        self._rate = rate
        self._seed = seed
        self._rng = None

    def get_config(self):
        config = super().get_config()
        config.update(dict(rate=self.rate, seed=self.seed))
        return config

    @property
    def rate(self) -> float:
        return self._rate

    @property
    def seed(self) -> Optional[int]:
        return self._seed

    def build(self, input_shape):
        if self.built:
            return
        assert self._rng is None
        if self.seed is None:
            # fork the global generator so this layer gets its own state
            self._rng = tf.random.get_global_generator().split(1)[0]
        else:
            self._rng = tf.random.Generator.from_seed(self.seed)
        super().build(input_shape)

    def _apply_training(self, inputs):
        # keep elements where uniform noise exceeds the drop rate, rescaling
        # survivors so the expected value is unchanged
        mask = self._rng.uniform(tf.shape(inputs)) > self.rate
        return tf.where(mask, inputs / (1 - self.rate), tf.zeros_like(inputs))

    @tf.function
    def call(self, inputs, training: Optional[bool] = None):  # pylint: disable=arguments-differ
        assert self._rng is not None
        if training is None:
            training = tf.keras.backend.learning_phase()
        if training:
            return self._apply_training(inputs)
        return tf.identity(inputs)
```
Because both `Layer`s and `Generator`s are `Module`s, assigning our `Generator` as an attribute of our `Layer` means the `Generator` state will be saved anywhere our `Layer` is, including when it’s part of a `Model`.
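As a sanity check, we can round-trip a checkpoint and confirm the layer reproduces the same mask. A minimal sketch (shapes and paths are arbitrary):

```python
layer = Dropout(0.5, seed=0)
layer.build((None, 8))

chkpt = tf.train.Checkpoint(layer=layer)
path = chkpt.write("/tmp/dropout")  # captures the generator state

x = tf.ones((2, 8))
a = layer(x, training=True)  # advances the generator
chkpt.read(path)             # rewinds it
b = layer(x, training=True)
assert (a.numpy() == b.numpy()).all()  # identical dropout masks
```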
Callbacks
Callbacks provide a flexible interface for users to inject custom behaviour into a training loop (e.g. `Model.fit`). Most implementations in `tf.keras.callbacks` are responsible for logging, saving, or otherwise providing user feedback on the training process. However, a couple directly influence the training process and maintain their own state based on performance across multiple epochs: `ReduceLROnPlateau` and `EarlyStopping`.
There are three things we need to do to ensure their state is included with our model’s state:
- implement wrappers that extend `Module`;
- change stateful attributes to non-trainable `Variable`s rather than python primitives; and
- assign them as attributes to the `Model`.
In the following example, we demonstrate how we can wrap `ReduceLROnPlateau` such that it can be used in a pre-emptible training process.
```python
from typing import Optional

import tensorflow as tf
from tensorflow.keras import callbacks as base  # the stock keras callbacks we wrap


def variable_property(name: str, dtype: tf.DType, doc: Optional[str] = None, **kwargs):
    """
    Get a property that wraps `tf.Variable` assignment.

    Useful for augmenting a base class to save values in `tf.Variable`s rather
    than as attributes.
    """
    attr_name = f"_variable_{name}"

    def getx(self):
        return tf.keras.backend.get_value(getattr(self, attr_name))

    @tf.Module.with_name_scope
    def setx(self, value):
        variable = getattr(self, attr_name, None)
        if variable is None:
            variable = tf.Variable(value, dtype=dtype, name=name, **kwargs)
            setattr(self, attr_name, variable)
        else:
            variable.assign(value)

    def delx(self):
        delattr(self, attr_name)

    return property(getx, setx, delx, doc)


class CallbackModule(tf.Module):
    def set_model(self, model: tf.keras.Model):
        # pylint: disable=protected-access
        old_model = getattr(self, "model", None)
        if old_model is not None:
            del old_model._callbacks[self.name]
        if not hasattr(model, "_callbacks"):
            model._callbacks = {}
        callbacks = model._callbacks
        # pylint: enable=protected-access
        assert self.name not in callbacks
        callbacks[self.name] = self
        self.model = model  # pylint: disable=attribute-defined-outside-init


class ReduceLROnPlateau(base.ReduceLROnPlateau, CallbackModule):
    def __init__(self, **kwargs):
        CallbackModule.__init__(self, name=None)
        self._supports_tf_logs = True
        self._config = dict(kwargs)
        base.ReduceLROnPlateau.__init__(self, **self._config)

    best = variable_property("best", tf.float32, trainable=False)
    wait = variable_property("wait", tf.int64, trainable=False)
    cooldown_counter = variable_property("cooldown_counter", tf.int64, trainable=False)
```
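Once `set_model` has been called (Keras does this early in `fit`, or we can call it ourselves), the callback’s `Variable`s are tracked by the model and therefore included in any checkpoint of it. A minimal sketch, assuming `model` is defined as above:

```python
reduce_lr = ReduceLROnPlateau(monitor="val_loss", factor=0.2, patience=5)
reduce_lr.set_model(model)  # registers the callback in `model._callbacks`

# `best`, `wait` and `cooldown_counter` now live in tf.Variables, so this
# checkpoint captures the learning rate schedule's state too.
chkpt = tf.train.Checkpoint(model=model)
chkpt.write("/tmp/callback_state")
```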
Custom Fit
While the intent of this post was to create an implementation that used everyone’s favourite `Model.fit`, I wasn’t able to find a way. I suspect the issue is related to how keras iterates over the data, but there also seem to be some issues with the optimizer (deterministic results are achievable with `Model.fit` if using the `SGD` optimizer and without a `shuffle` transform in the data pipeline). Having said that, writing a custom `fit` implementation with the same interface isn’t too onerous.
```python
from typing import Iterable, Optional, Tuple

import tensorflow as tf


def as_infinite_iterator(
    dataset: tf.data.Dataset, steps_per_epoch: Optional[int] = None
) -> Tuple[tf.data.Iterator, int]:
    """
    Get an iterator for an infinite dataset and steps_per_epoch.

    Args:
        dataset: possibly infinite dataset.
        steps_per_epoch: number of steps per epoch if `dataset` has infinite
            cardinality, otherwise `None` or `dataset`'s cardinality.

    Returns:
        iterator: tf.data.Iterator of possibly repeated `dataset`.
        steps_per_epoch: number of elements in iterator considered one epoch.

    Raises:
        ValueError: if `dataset` has a cardinality inconsistent with
            `steps_per_epoch`.
    """
    cardinality = tf.keras.backend.get_value(dataset.cardinality())
    if steps_per_epoch is None:
        steps_per_epoch = cardinality
        if cardinality == tf.data.INFINITE_CARDINALITY:
            raise ValueError(
                "steps_per_epoch must be provided if dataset has infinite "
                "cardinality"
            )
        dataset = dataset.repeat()
    elif cardinality != tf.data.INFINITE_CARDINALITY:
        assert cardinality == steps_per_epoch
        dataset = dataset.repeat()
    return iter(dataset), steps_per_epoch


def fit_custom(
    model: tf.keras.Model,
    x: tf.data.Dataset,
    epochs: int = 1,
    initial_epoch: int = 0,
    validation_freq: int = 1,
    validation_data: Optional[tf.data.Dataset] = None,
    steps_per_epoch: Optional[int] = None,
    validation_steps: Optional[int] = None,
    callbacks: Iterable[tf.keras.callbacks.Callback] = (),
    verbose: bool = True,
) -> tf.keras.callbacks.History:
    """Custom fit implementation. See `tf.keras.Model.fit` for more info."""
    train_func = model.make_train_function()
    train_iter, steps_per_epoch = as_infinite_iterator(x, steps_per_epoch)
    model._train_iter = train_iter

    cb = tf.keras.callbacks.CallbackList(
        callbacks=callbacks, add_history=True, add_progbar=verbose, model=model
    )
    cb.set_params(dict(epochs=epochs, verbose=int(verbose), steps=steps_per_epoch))
    cb.on_train_begin()

    initial_epoch = model._maybe_load_initial_epoch_from_ckpt(initial_epoch)
    model.stop_training = False
    training_logs = None
    for epoch in range(initial_epoch, epochs):
        model.reset_metrics()
        cb.on_epoch_begin(epoch)
        logs = None
        for step in range(steps_per_epoch):
            cb.on_train_batch_begin(step)
            logs = train_func(train_iter)
            cb.on_train_batch_end(step, logs)
            if model.stop_training:
                break
        assert logs is not None
        epoch_logs = logs
        if validation_data is not None and model._should_eval(epoch, validation_freq):
            logs = model.evaluate(
                validation_data,
                steps=validation_steps,
                callbacks=callbacks,
                return_dict=True,
            )
            epoch_logs.update({"val_" + name: val for name, val in logs.items()})
        cb.on_epoch_end(epoch, epoch_logs)
        training_logs = epoch_logs
        if model.stop_training:
            break
    cb.on_train_end(logs=training_logs)
    del model._train_iter
    return model.history
```
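Tying everything together, a hedged sketch of a full training entry point is shown below; the `make_model` helper and the backup directory are hypothetical, and depending on your TensorFlow version `BackupAndRestore` may live under `tf.keras.callbacks.experimental` rather than `tf.keras.callbacks`.

```python
tf.random.set_seed(0)

model = make_model()  # hypothetical builder returning a compiled keras Model
reduce_lr = ReduceLROnPlateau(monitor="loss", factor=0.2, patience=5)
backup = tf.keras.callbacks.experimental.BackupAndRestore("/tmp/backup")

# `train_iter` state, callback variables and optimizer slots are all tracked
# on `model`, so `BackupAndRestore`'s checkpoints cover the full training state.
history = fit_custom(
    model,
    dataset,  # the pipeline constructed earlier
    epochs=10,
    steps_per_epoch=steps_per_epoch,
    callbacks=[reduce_lr, backup],
)
```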
Complete Example
The above functionality is all implemented in my kblocks repository. You can see a complete example (including data augmentation) here.
Conclusion
As we’ve seen, there’s a lot more to deterministic and pre-emptible training than just setting the random seed and adding a `BackupAndRestore`.
Think I’m wrong? Found a typo or bug? Have a question or just think something needs further explanation? Let me know in the comments.