Skip to content

Keras

This is a note about developing Keras.

Tracing

Static values in tf_function

Static values remain constant inside a tf_function: they are fixed when the function is traced. For example, Tensor.shape is static. TF operations, which usually look like tf.some_operation(), are dynamic; for example, tf.shape() is dynamic.

Testing

The testing infra is a good practice that can be applied to other projects as well.

from keras.testing_infra import test_combinations
from keras.testing_infra import test_utils

# `test_combinations.TestCase` is the Keras base class for tests.
# It extends both `tf.test.TestCase` and `absl.testing.parameterized.TestCase`.
class SomeTest(test_combinations.TestCase):
  """Example test case built on the Keras testing infra."""

  @test_combinations.run_with_all_model_types
  def test_model_instrumentation(self):
    """Builds a model from two float64 `Dense(10)` layers via the test helper."""
    # Two separate layer instances, created in order.
    dense_stack = [
        layers_module.Dense(10, dtype=np.float64) for _ in range(2)
    ]
    model = test_utils.get_model_from_layers(dense_stack, input_shape=(1,))
# The following decorators don't control which TF version to use (TF1 or TF2).
# They only decide how to run the tests based on the currently enabled TF version.
# The TF version is specified by the command used for running the tests.
# Keras style for running all modes. (Recommended)
# Runs once in TF 1 mode.
# In TF 2 mode, it runs twice: once with `test_utils.should_run_eagerly()`
# returning `True` and once with it returning `False`.
@test_combinations.run_all_keras_modes
def some_test(self):
    """Compiles the model with the eager/graph setting of the current run."""
    ...
    # `should_run_eagerly` is a function and must be called. Passing the
    # function object itself is always truthy, so the per-run mode selected
    # by the decorator would not be honored.
    model.compile(..., run_eagerly=test_utils.should_run_eagerly())
# Use args to skip TF1 or TF2 modes.
@test_combinations.run_all_keras_modes(always_skip_v1=True)
@test_combinations.run_all_keras_modes(always_skip_eager=True)

# Runs twice in TF1 (graph & eager).
# Runs once in TF2 (eager).
# Graph mode is the TF1 legacy mode, which doesn't exist in TF2.
# The test method itself needs no extra args for `mode`.
@test_combinations.generate(test_combinations.combine(mode=['graph', 'eager']))
def some_test(self):
    """Placeholder test; `generate` parameterizes it over the `mode` values."""
    pass

# Only runs once, in TF2 with eager mode.
# NOTE(review): this comment appears to describe `run_v2_only` on its own;
# confirm how it combines with `run_all_keras_modes` stacked below.
@test_utils.run_v2_only
@test_combinations.run_all_keras_modes(always_skip_v1=True)
@parameterized.parameters("h5", "tf")
def test_keras_saving_functional(self, save_format):
    """Saves the model using the `save_format` supplied by `parameters`.

    Args:
      save_format: One of the values listed in the decorator ("h5" or "tf").
    """
    ...
    model.save(path, save_format=save_format)
# Use one tuple per test case if the test takes multiple args.
@parameterized.parameters(
    ("h5", 0),
    ("tf", 1))
def test_keras_saving_functional(self, save_format, arg2):
    """Saves the model; each tuple in the decorator supplies one set of args.

    Args:
      save_format: First element of each tuple ("h5" or "tf").
      arg2: Second element of each tuple (0 or 1); not used in the visible
        body.
    """
    ...
    model.save(path, save_format=save_format)
# Give names to the tests with different params.
# Plain parameters() would use the param values in the test names.
@parameterized.named_parameters(("name1", "h5"), ("name2", "tf"))
def test_keras_saving_functional(self, save_format):
    """Saves the model once per named case.

    Args:
      save_format: Second element of each tuple ("h5" or "tf"); the first
        element ("name1" / "name2") is the case name.
    """
    ...
    model.save(path, save_format=save_format)
# Exhaust all combinations of the values.
@test_combinations.generate(test_combinations.combine(
      ragged_query=[True, False],
      ragged_value=[True, False],
      ragged_key=[True, False]))
def test_ragged_tensor(self, ragged_query, ragged_value, ragged_key):
    """Receives one combination of the three boolean flags per run.

    Args:
      ragged_query: Boolean taken from the `ragged_query` list above.
      ragged_value: Boolean taken from the `ragged_value` list above.
      ragged_key: Boolean taken from the `ragged_key` list above.
    """
    ...
# To save the model, write it under the test's temporary directory.
def some_test(self):
    """Saves `model` as 'my_model' inside `self.get_temp_dir()`."""
    ...
    save_path = os.path.join(self.get_temp_dir(), 'my_model')
    model.save(save_path)