
cifar10 transfer learning example #2029

Closed
wants to merge 1,518 commits

Conversation

@Deanplayerljx (Author):

No description provided.

tfx-copybara and others added 30 commits April 23, 2020 00:41
PiperOrigin-RevId: 307995072
- Uses absolute import rather than relative import
- Uses native keras model with the generic trainer.
- Uses hyphen (-) instead of underscore (_)

PiperOrigin-RevId: 308160455
PiperOrigin-RevId: 308187069
PiperOrigin-RevId: 308194360
…hon function component.

PiperOrigin-RevId: 308194919
- All pushers now always copy the model into a ModelPush artifact if the push succeeded.
- Introduced `Versioning` semantics to be used across multiple Pushers. There are two methods in Versioning: UNIX_TIMESTAMP and MODEL_ARTIFACT_ID.
- Unified MLMD custom properties:
  - `pushed` is a boolean flag for whether the push was successful. (not changed)
  - `pushed_model` points to the URI of the foreign serving system. (CAIP pusher and default pusher changed)
  - `pushed_version` stores the version value generated according to the Versioning semantics. It may be omitted if the foreign serving system lacks a version concept (e.g. the BQML pusher).

Closes #1553

PiperOrigin-RevId: 308236911
PiperOrigin-RevId: 308296241
PiperOrigin-RevId: 308337128
PiperOrigin-RevId: 308688404
PiperOrigin-RevId: 308713076
Please approve this CL. It will be submitted automatically, and its GitHub pull request will be marked as merged.

Imported from GitHub PR #1667

Also included:
- Fix some tests which fail when executed externally, mostly due to
usage of testdata before the CWD change, or required environment variables.
- Refresh test dependencies and remove unnecessary ones.

Copybara import of the project:

  - fff7078 Refreshes the contributing.md. by Zhitao Li <[email protected]>
  - e0b304f Merge fff7078 into dd6a3... by Zhitao <[email protected]>

COPYBARA_INTEGRATE_REVIEW=#1667 from zhitaoli:check_test fff7078
PiperOrigin-RevId: 308741011
Usage:
```
$ pylint <path_to_file>
```
(A pylintrc in the working directory should be picked up by default.)

The new pylintrc is based on the TF pylintrc, but adds some more exceptions to accommodate the existing code base. There are still 345 warnings in the tfx codebase as of today.

The new github action will check all incoming PRs with pylint and pytest. These checks will run against modified / added files only.

PiperOrigin-RevId: 308741511
PiperOrigin-RevId: 308768208
PiperOrigin-RevId: 308834214
PiperOrigin-RevId: 308870509
Because the dataset is quite small, accuracy oftentimes drops under 0.9 (I've seen 0.68 in my test). Lowering the accuracy threshold to 0.6 to make tests stable.

PiperOrigin-RevId: 308927472
PiperOrigin-RevId: 308943211
Needed because of upstream https://issues.apache.org/jira/browse/BEAM-4032, as the portability stager is now used for Dataflow jobs as well.

PiperOrigin-RevId: 308968052
PiperOrigin-RevId: 309064363
PiperOrigin-RevId: 309135057
…ders.

The executor can be used with all container launchers.

PiperOrigin-RevId: 309135987
PiperOrigin-RevId: 309241542
PiperOrigin-RevId: 309346614
`packaging` is added as a dependency of `pytest`

PiperOrigin-RevId: 309421408
PiperOrigin-RevId: 309439164
@1025KB (Collaborator) left a comment:

Can we use MNIST for the image example? We just removed CIFAR-10 from the examples.

We want to keep a reasonable number of examples; otherwise it would be hard to maintain.

@davidzats-eng (Contributor):

@1025KB I understand the concern about too many example types. However, this example demonstrates how TFX can perform high-quality image classification on real-world datasets. As such, I don't think the MNIST dataset is appropriate.

@zhitaoli (Contributor):

@davidzats-eng Can you make sure you set yourself as owner of this example once it is pulled in? Can we also think about a secondary endorser (someone familiar with the modeling technique)?

@davidzats-eng (Contributor):

@zhitaoli Ack will do.

```
class_name='SparseCategoricalAccuracy',
threshold=tfma.config.MetricThreshold(
    value_threshold=tfma.GenericValueThreshold(
        lower_bound={'value': 0.8})))
```
```
  return dataset


def _build_keras_model() -> tf.keras.Model:
  """Creates a MobileNet model pretrained on ImageNet for classifying
```
Contributor:

Nit: first row of comment should be standalone summary. If more content needed, then skip a line and add in more details.


```
  # Freeze all layers in the base model except last conv block
  for layer in base_model.layers:
    if '13' not in layer.name:
```
Contributor:

Is there a way to get this programmatically instead of hard-coding? Also some use-cases may decide to freeze more or less of the model. Can we make this easily changeable?
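The example later adds a `_freeze_model_by_percentage` helper along these lines. A minimal sketch of the idea, deriving the freeze boundary from the layer count instead of a hard-coded layer name; the helper below is illustrative, not the PR's exact code:

```python
import tensorflow as tf


def freeze_model_by_percentage(model: tf.keras.Model, percentage: float) -> None:
  """Freezes the first `percentage` of layers, leaving the rest trainable.

  Illustrative helper: instead of matching a hard-coded layer name like
  '13', the freeze boundary is computed from the layer count, so callers
  can easily freeze more or less of the model.
  """
  if not 0.0 <= percentage <= 1.0:
    raise ValueError('percentage must be in [0, 1]')
  num_frozen = int(len(model.layers) * percentage)
  for layer in model.layers[:num_frozen]:
    layer.trainable = False
  for layer in model.layers[num_frozen:]:
    layer.trainable = True
```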

```
  # We resize CIFAR10 images to match that size
  image_features = tf.image.resize(image_features, [224, 224])

  image_features = tf.ensure_shape(image_features, (None, 224, 224, 3))
```
Contributor:

Is this just a check or did the computation not work without it?

@Deanplayerljx (Author), Jun 25, 2020:

Hmm, it's actually redundant, since resize provides the tensor shape information. It would be necessary if we didn't have the resize call, because then tfx would complain that "all the dimensions except the batch dimension should be known".
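The redundancy described here can be checked directly: `tf.image.resize` with a constant target size already produces a statically known spatial shape, so `tf.ensure_shape` only adds a (harmless) assertion on top. A small sketch:

```python
import tensorflow as tf

# After resizing to a constant size, the spatial dimensions are statically
# known, so the ensure_shape call below is a redundant (but harmless) check.
batch = tf.zeros([2, 32, 32, 3])
resized = tf.image.resize(batch, [224, 224])
checked = tf.ensure_shape(resized, (None, 224, 224, 3))
```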

```
class_name='SparseCategoricalAccuracy',
threshold=tfma.config.MetricThreshold(
    value_threshold=tfma.GenericValueThreshold(
        lower_bound={'value': 0.8}),
```
Contributor:

What value for accuracy are we getting here? Can / should this be adjusted?

Author:

It was about 0.9; a lower_bound of 0.8 should be fine.

Contributor:

Was it consistently greater than 0.8 with the small dataset? If it wasn't, we will have flaky tests, so let's adjust it down. If it was, then it's ok to leave as-is / resolve.

```
        lower_bound={'value': 0.8}),
    change_threshold=tfma.GenericChangeThreshold(
        direction=tfma.MetricDirection.HIGHER_IS_BETTER,
        absolute={'value': -1e-10})))
```
Contributor:

This seems extremely tight; how about -1e-3 or so? By the way, where did this number come from? It seems too tight for most use-cases.
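A sketch of the loosened threshold the reviewer suggests, mirroring the config quoted above with -1e-3 substituted for -1e-10 (allowing up to 1e-3 absolute accuracy regression before a candidate model is rejected):

```python
import tensorflow_model_analysis as tfma

# Same threshold as in the example, with the change threshold loosened from
# -1e-10 to the suggested -1e-3.
threshold = tfma.config.MetricThreshold(
    value_threshold=tfma.GenericValueThreshold(lower_bound={'value': 0.8}),
    change_threshold=tfma.GenericChangeThreshold(
        direction=tfma.MetricDirection.HIGHER_IS_BETTER,
        absolute={'value': -1e-3}))
```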

Author:

Sure, we can do that. I copied it from the iris example, I remember.

Contributor:

@1025KB Should we change our examples? This seems too tight for most use-cases.

```
  model = _build_keras_model()

  steps_per_epoch = _TRAIN_DATA_SIZE / _TRAIN_BATCH_SIZE
  epochs = int(fn_args.train_steps / steps_per_epoch)
```
Contributor:

@1025KB is this the right thing to do here? IIRC there were issues with multi-epoch training.
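For reference, the quoted computation works out like this; the 50,000 train records come from the example's own comment, while the batch size and train-step count below are illustrative assumptions, not the example's actual constants:

```python
# Worked numbers for the epochs computation quoted above.
TRAIN_DATA_SIZE = 50000    # from the example's comment
TRAIN_BATCH_SIZE = 64      # assumption for this sketch
train_steps = 4688         # assumption: roughly 6 epochs worth of steps

steps_per_epoch = TRAIN_DATA_SIZE / TRAIN_BATCH_SIZE   # 781.25
epochs = int(train_steps / steps_per_epoch)            # truncates toward zero
```

Note that `int()` truncates, so any partial final epoch is silently dropped.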


```
from tfx.components.trainer.executor import TrainerFnArgs

# cifar10 dataset has 50000 train records, and 10000 val records
```
Contributor:

s/val/eval?

Author:

Hmmm, I think the usual terms are train/validation/test sets?

Contributor:

val was unclear to me. so then s/val/validation :)

```
      input_shape=(224, 224, 3), include_top=False, weights='imagenet',
      pooling='avg')

  _freeze_model_by_percentage(base_model, 0.9)
```
Contributor:

My understanding is that the best practice is to train for a bit using the new final layer before unfreezing any part of the model to ensure that the new weights do not pollute the existing model. Can we please do that here?

Author:

I have tried that setting, i.e. train new layers for 6 epochs and unfreeze all layers for another 6 epochs. But it didn't show improvement... Should I still add it?

Contributor:

It's good to highlight best practices, so yes. And maybe write a comment that this part is "optional".
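A minimal sketch of that two-phase recipe, assuming a Keras model with a separable pretrained base; the function name, learning rates, and epoch counts below are illustrative, not the example's actual code:

```python
import tensorflow as tf


def two_phase_finetune(model, base_model, train_ds,
                       head_epochs=6, finetune_epochs=6):
  """Two-phase transfer learning sketch (illustrative names and values).

  Phase 1 freezes the pretrained base so gradients from the randomly
  initialized head cannot disturb the pretrained weights; phase 2 unfreezes
  the base and continues at a much lower learning rate.
  """
  # Phase 1: train only the new head.
  base_model.trainable = False
  model.compile(optimizer=tf.keras.optimizers.Adam(1e-3),
                loss='sparse_categorical_crossentropy', metrics=['accuracy'])
  model.fit(train_ds, epochs=head_epochs, verbose=0)

  # Phase 2: unfreeze and fine-tune. Recompile so the trainable change
  # takes effect, and use a smaller learning rate.
  base_model.trainable = True
  model.compile(optimizer=tf.keras.optimizers.Adam(1e-5),
                loss='sparse_categorical_crossentropy', metrics=['accuracy'])
  model.fit(train_ds, epochs=finetune_epochs, verbose=0)
  return model
```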

```
  image_features = tf.map_fn(tf.keras.applications.mobilenet.preprocess_input,
                             image_features, dtype=tf.float32)

  outputs[transformed_name(IMAGE_KEY)] = (image_features)
```
Contributor:

Why the parens around (image_features)?

```
  # We resize CIFAR10 images to match that size
  image_features = tf.image.resize(image_features, [224, 224])

  image_features = tf.map_fn(tf.keras.applications.mobilenet.preprocess_input,
```
Contributor:

Nit: Can we combine the two map functions into one? It might be more readable?
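A sketch of a combined version; this assumes MobileNet's `preprocess_input` (elementwise scaling) can be applied to the whole resized batch at once, which would make the per-element `tf.map_fn` unnecessary:

```python
import tensorflow as tf


def preprocess_images(image_features):
  """Resize and MobileNet-preprocess a batch in one pass (a sketch).

  Both ops are batch-aware: resize works on a 4-D batch, and MobileNet's
  preprocess_input is elementwise scaling, so no per-element map is needed.
  """
  resized = tf.image.resize(image_features, [224, 224])
  return tf.keras.applications.mobilenet.preprocess_input(resized)
```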


```
  # The MobileNet we use was trained on ImageNet, which has image size 224 x 224.
  # We resize CIFAR10 images to match that size
  image_features = tf.image.resize(image_features, [224, 224])
```
Contributor:

There are multiple ways of resizing. Why do we believe this provides the best quality for our dataset?

@Deanplayerljx (Author), Jul 6, 2020:

Hmm, I think the default option (bilinear) is the most common practice. It worked well, as the model got ~90% accuracy on the validation set.

Contributor:

Let's push to see how high we can go. Just because we get ~90% accuracy on an easy dataset doesn't mean that we can't push for more.


```
                     serving_model_dir: Text,
                     metadata_path: Text,
                     direct_num_workers: int) -> pipeline.Pipeline:
  """Implements the cifar10 image classification pipeline using TFX."""
```


nit: standardize capitalization in documentation (i.e., "CIFAR10" vs. "cifar10")


ditto in other files

```
Finally, run the `meta_data_writer.py` script to write the metadata into the model
```
python ~/cifar10/meta_data_writer.py -model_file PATH_TO_MODEL -label_file data/labels.txt -export_directory exported
Contributor:

It is great that you created this script. As a next step, let's see whether we can integrate it as part of the tfx pipeline. Let's leave a todo and follow up with another pull request.



```
  return dataset


def _freeze_model_by_percentage(model: tf.keras.Model,
```
Contributor:

There is also a model.trainable parameter; let's make sure that it is not overriding the settings here. Can you please verify?

@Deanplayerljx (Author), Jul 13, 2020:

I will verify that

Author:

model.trainable is like a meta switch that turns all parameters trainable or not. If we freeze some layers in the model and then call model.trainable=True, all layers will be unfrozen; if we call model.trainable=False after freezing some of the layers, all layers will be frozen.

Contributor:

Makes sense. The question is what is the default and does it get in the way.

Author:

The default for model.trainable is True. As long as we don't modify it it will not get in the way.
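The behavior described in this thread can be demonstrated in a few lines (a standalone sketch, not code from the PR): per-layer freezing survives until the model-level switch is touched, at which point the setting propagates recursively to all layers.

```python
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(4),
    tf.keras.layers.Dense(1),
])

assert model.trainable                      # default is True, as noted above

model.layers[0].trainable = False           # freeze just one layer
assert not model.layers[0].trainable
assert model.layers[1].trainable

model.trainable = False                     # the "meta switch": freezes everything
assert not any(layer.trainable for layer in model.layers)
```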

@googlebot:

We found a Contributor License Agreement for you (the sender of this pull request), but were unable to find agreements for all the commit author(s) or Co-authors. If you authored these, maybe you used a different email address in the git commits than was used to sign the CLA (login here to double check)? If these were authored by someone else, then they will need to sign a CLA as well, and confirm that they're okay with these being contributed to Google.
In order to pass this check, please resolve this problem and then comment @googlebot I fixed it.. If the bot doesn't comment, it means it doesn't think anything has changed.


@Deanplayerljx (Author):

@googlebot I fixed it.

@googlebot:

All (the pull request submitter and all commit authors) CLAs are signed, but one or more commits were authored or co-authored by someone other than the pull request submitter.

We need to confirm that all authors are ok with their commits being contributed to this project. Please have them confirm that by leaving a comment that contains only @googlebot I consent. in this pull request.

Note to project maintainer: There may be cases where the author cannot leave a comment, or the comment is not properly detected as consent. In those cases, you can manually confirm consent of the commit author(s), and set the cla label to yes (if enabled on your project).


@Deanplayerljx (Author):

It seems I messed up this PR with a bunch of other people's commits. Will open a new clean PR for this.
