Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom RetinaNet save and load, added test case #1928

Open
wants to merge 2 commits into
base: master
Choose a base branch
from

Conversation

gianlucasama
Copy link
Contributor

Fixes #1885 lack of serialization and errors when saving and loading.

Copy link
Contributor

@jbischof jbischof left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is great, but not sure this tests would actually run in its current state.

@@ -244,6 +244,49 @@ def test_saved_model(self, save_format, filename):
restored_output = restored_model(input_batch)
self.assertAllClose(model_output, restored_output)

@pytest.mark.large
def test_custom_saved_model(self, save_format, filename):
class CustomPredictionHead(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we make sense simple convolutions or something with real variables?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, how does the forward pass work if these classes are ill-defined?

Copy link
Contributor Author

@gianlucasama gianlucasama Jul 23, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I ran the test and it worked, because the custom classes are inheriting from the classes that are used as default in the RetinaNet class, they just have different names. If you think we should also test for custom architecture, e.g. creating a totally new and independent class, I'll do it, but I don't think it would change anything.

Copy link
Contributor Author

@gianlucasama gianlucasama Jul 23, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could also test for saving/loading with custom Label Encoder, Label Decoder, Anchor Generator, Backbone, and so on.

restored_model = keras.models.load_model(
save_path,
{
"RetinaNet": keras_cv.models.RetinaNet,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to provide this as a custom object? If so, why isn't it needed in all tests?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This, my friend, I still don't understand, but from what I remember that's the only way to make the whole thing work using the different formats of saving/loading. I'll look into it and let you know why I did it that way.

Copy link
Contributor Author

@gianlucasama gianlucasama Jul 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem is with the "h5" saving/loading format. Without the "RetinaNet": keras_cv.models.RetinaNet "h5" loading doesn't work, it gives you:

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/gianlucasama/.vscode/extensions/ms-python.python-2023.12.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
    cli.main()
  File "/home/gianlucasama/.vscode/extensions/ms-python.python-2023.12.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File "/home/gianlucasama/.vscode/extensions/ms-python.python-2023.12.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/home/gianlucasama/.vscode/extensions/ms-python.python-2023.12.0/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/home/gianlucasama/.vscode/extensions/ms-python.python-2023.12.0/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/home/gianlucasama/.vscode/extensions/ms-python.python-2023.12.0/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "/home/gianlucasama/Github/keras-cv/keras_cv/models/object_detection/retinanet/retinanet_test.py", line 374, in <module>
    t.test_custom_saved_model("h5", "model")
  File "/home/gianlucasama/Github/keras-cv/keras_cv/models/object_detection/retinanet/retinanet_test.py", line 274, in test_custom_saved_model
    restored_model = keras.models.load_model(
  File "/home/gianlucasama/.local/lib/python3.10/site-packages/keras/saving/saving_api.py", line 212, in load_model
    return legacy_sm_saving_lib.load_model(
  File "/home/gianlucasama/.local/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 61, in error_handler
    return fn(*args, **kwargs)
  File "/home/gianlucasama/.local/lib/python3.10/site-packages/keras/saving/legacy/save.py", line 245, in load_model
    return hdf5_format.load_model_from_hdf5(
  File "/home/gianlucasama/.local/lib/python3.10/site-packages/keras/saving/legacy/hdf5_format.py", line 192, in load_model_from_hdf5
    model = model_config_lib.model_from_config(
  File "/home/gianlucasama/.local/lib/python3.10/site-packages/keras/saving/legacy/model_config.py", line 55, in model_from_config
    return deserialize(config, custom_objects=custom_objects)
  File ["/home/gianlucasama/.local/lib/python3.10/site-packages/keras/layers/serialization.py",] line 265, in deserialize
    return legacy_serialization.deserialize_keras_object(
  File "/home/gianlucasama/.local/lib/python3.10/site-packages/keras/saving/legacy/serialization.py", line 486, in deserialize_keras_object
    (cls, cls_config) = class_and_config_for_serialized_keras_object(
  File "/home/gianlucasama/.local/lib/python3.10/site-packages/keras/saving/legacy/serialization.py", line 368, in class_and_config_for_serialized_keras_object
    raise ValueError(
ValueError: Unknown layer: 'RetinaNet'. Please ensure you are using a `keras.utils.custom_object_scope` and that this object is included in the scope. See https://www.tensorflow.org/guide/keras/save_and_serialize#registering_the_custom_object for details.

By the way, I noticed that in RetinaNetTest.test_saved_model, line 228, we are not testing the "h5" saving/loading format and that's probably why we are not adding the "RetinaNet": keras_cv.models.RetinaNet there.

@jbischof
Copy link
Contributor

/gcbrun

@gianlucasama
Copy link
Contributor Author

You were right, the test would not run, because I forgot to commit the parametrization needed to run the test.

@ianstenbit
Copy link
Contributor

Thanks for the PR!

Could you please merge the keras-team/master branch into your dev branch? There are some changes that were made on master that are required to fix our GCB tests. Thank you! 👍

@gianlucasama gianlucasama force-pushed the master branch 2 times, most recently from 107b0a9 to 46e27ac Compare July 31, 2023 11:59
@gianlucasama gianlucasama reopened this Jul 31, 2023
@gianlucasama
Copy link
Contributor Author

Sorry I messed up merging branches, it should be all fine now.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants