cifar10 transfer learning example #2029
Conversation
- Uses absolute import rather than relative import - Uses native keras model with the generic trainer. - Uses hyphen(-) instead of underscore(_) PiperOrigin-RevId: 308160455
…hon function component. PiperOrigin-RevId: 308194919
- All pushers now always copy the model into a ModelPush artifact if the push succeeded. - Introduced a `Versioning` semantic to be used across multiple Pushers. There are two methods in Versioning: UNIX_TIMESTAMP and MODEL_ARTIFACT_ID. - Unified MLMD custom properties: - `pushed` is a boolean flag for whether the push was successful or not. (not changed) - `pushed_model` points to the URI of the foreign serving system. (CAIP pusher and default pusher changed) - `pushed_version` stores the version value that is generated according to the Versioning semantic. It might be omitted if the foreign serving system lacks the version concept (e.g. BQML pusher). Closes #1553 PiperOrigin-RevId: 308236911
…imental`. PiperOrigin-RevId: 308662021
Please approve this CL. It will be submitted automatically, and its GitHub pull request will be marked as merged. Imported from GitHub PR #1667 Also included: - Fix some tests which fail when executed externally, mostly due to usage of testdata before the CWD change, or required environment variables. - Refreshes test dependencies and removes unnecessary ones. Copybara import of the project: - fff7078 Refreshes the contributing.md. by Zhitao Li <[email protected]> - e0b304f Merge fff7078 into dd6a3... by Zhitao <[email protected]> COPYBARA_INTEGRATE_REVIEW=#1667 from zhitaoli:check_test fff7078 PiperOrigin-RevId: 308741011
Usage: $ pylint <path_to_file> (pylintrc in the working directory should be picked up by default.) The new pylintrc is based on the TF pylintrc, but adds some more exceptions to accommodate the existing code base. There are 345 warnings in the tfx codebase as of today. The new GitHub action will check all incoming PRs with pylint and pytest. These checks will run against modified / added files only. PiperOrigin-RevId: 308741511
Because the dataset is quite small, it oftentimes drops under 0.9 (I've seen 0.68 in my test). Lowering the accuracy threshold to 0.6 to make tests stable. PiperOrigin-RevId: 308927472
Needed because of upstream https://issues.apache.org/jira/browse/BEAM-4032, as the portability stager is now used for Dataflow jobs as well. PiperOrigin-RevId: 308968052
…ders. The executor can be used with all container launchers. PiperOrigin-RevId: 309135987
`packaging` is added as a dependency of `pytest` PiperOrigin-RevId: 309421408
Can we use MNIST for the image example? We just removed CIFAR-10 from the examples; we want to keep a reasonable number of examples, otherwise it would be hard to maintain.
@1025KB I understand the concern about too many example types. However, this example demonstrates how TFX can perform high-quality image classification on real-world datasets. As such, I don't think the MNIST dataset is appropriate.
@davidzats-eng Can you make sure you set yourself as owner of this example once pulled in? Can we also think about a secondary endorser (who is familiar with the modeling technique or so)?
@zhitaoli Ack, will do.
```python
class_name='SparseCategoricalAccuracy',
threshold=tfma.config.MetricThreshold(
    value_threshold=tfma.GenericValueThreshold(
        lower_bound={'value': 0.8})))
```
Let's also add in a change threshold for completeness: https://github.com/tensorflow/model-analysis/blob/d18828330cd1efc47d35c8458350979a8d62fd15/tensorflow_model_analysis/proto/config.proto#L194
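For reference, a hedged sketch of what the combined threshold could look like, using the same `tfma` APIs that already appear in this example. The `-1e-3` slack value is an illustrative assumption (a later comment in this thread suggests loosening the change threshold), not a value from the example itself:

```python
import tensorflow_model_analysis as tfma

# Sketch: value_threshold is an absolute bar the candidate must clear;
# change_threshold additionally requires the candidate not to regress
# (beyond a small slack) relative to the blessed baseline model.
threshold = tfma.config.MetricThreshold(
    value_threshold=tfma.GenericValueThreshold(
        lower_bound={'value': 0.8}),
    change_threshold=tfma.GenericChangeThreshold(
        direction=tfma.MetricDirection.HIGHER_IS_BETTER,
        absolute={'value': -1e-3}))  # assumed slack, see discussion below
```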
```python
  return dataset


def _build_keras_model() -> tf.keras.Model:
  """Creates a MobileNet model pretrained on ImageNet for classifying
```
Nit: the first line of the docstring should be a standalone summary. If more content is needed, skip a line and add the details.
```python
# Freeze all layers in the base model except last conv block
for layer in base_model.layers:
  if '13' not in layer.name:
```
Is there a way to get this programmatically instead of hard-coding? Also some use-cases may decide to freeze more or less of the model. Can we make this easily changeable?
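One way to make the freeze point configurable, as suggested: compute a cutoff index from a fraction instead of matching layer names. A minimal sketch (the example later adds a `_freeze_model_by_percentage` helper; this version only assumes the model has an ordered `.layers` list whose entries expose a boolean `.trainable` attribute, as `tf.keras.Model` does):

```python
def freeze_model_by_percentage(model, percentage):
  """Freeze the first `percentage` of a model's layers; train the rest.

  Sketch under the assumption that `model.layers` is ordered from input
  to output, so freezing a prefix keeps the early feature extractors fixed.
  """
  if not 0.0 <= percentage <= 1.0:
    raise ValueError('percentage must be in [0, 1]')
  cutoff = int(len(model.layers) * percentage)
  for i, layer in enumerate(model.layers):
    layer.trainable = i >= cutoff
```

Callers can then tune how much of the base model to fine-tune by passing a different fraction, rather than editing hard-coded layer names.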
```python
# We resize CIFAR10 images to match that size
image_features = tf.image.resize(image_features, [224, 224])

image_features = tf.ensure_shape(image_features, (None, 224, 224, 3))
```
Is this just a check or did the computation not work without it?
Hmm, it's actually redundant, since resize will provide the tensor shape information. It would be necessary if we didn't have the resize function, because then TFX will complain that "all the dimensions except batch dimension should be known".
```python
class_name='SparseCategoricalAccuracy',
threshold=tfma.config.MetricThreshold(
    value_threshold=tfma.GenericValueThreshold(
        lower_bound={'value': 0.8}),
```
What value for accuracy are we getting here? Can / should this be adjusted?
It was about 0.9; a lower_bound of 0.8 should be fine.
Was it consistently greater than 0.8 with the small dataset? If it wasn't, then we will have flaky tests, so let's adjust it down. If it was, then it's ok to leave as-is / resolve.
```python
        lower_bound={'value': 0.8}),
    change_threshold=tfma.GenericChangeThreshold(
        direction=tfma.MetricDirection.HIGHER_IS_BETTER,
        absolute={'value': -1e-10})))
```
This seems extremely tight; how about -1e-3 or so? By the way, where did this number come from? It seems too tight for most use-cases.
Sure, we can do that. I copied it from the iris example, I remember.
@1025KB Should we change our examples? This seems too tight for most use-cases.
```python
model = _build_keras_model()

steps_per_epoch = _TRAIN_DATA_SIZE / _TRAIN_BATCH_SIZE
epochs = int(fn_args.train_steps / steps_per_epoch)
```
@1025KB is this the right thing to do here? IIRC there were issues with multi-epoch training.
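One source of subtle bugs in the snippet above is the float division: `_TRAIN_DATA_SIZE / _TRAIN_BATCH_SIZE` yields a fractional steps-per-epoch, and `int(...)` then silently truncates. A hedged sketch of a safer derivation (hypothetical helper name; floor division plus a floor of one epoch so a small step budget still trains at least once):

```python
def compute_epochs(train_steps, train_data_size, batch_size):
  """Derive a whole number of epochs from a total training-step budget.

  Uses floor division throughout and guarantees at least one epoch,
  which avoids epochs == 0 when train_steps < steps_per_epoch.
  """
  steps_per_epoch = train_data_size // batch_size
  return max(1, train_steps // steps_per_epoch)
```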
```python
from tfx.components.trainer.executor import TrainerFnArgs

# cifar10 dataset has 50000 train records, and 10000 val records
```
s/val/eval?
Hmm, I think the usual terms are train/validation/test sets?
"val" was unclear to me, so then s/val/validation :)
```python
    input_shape=(224, 224, 3), include_top=False, weights='imagenet',
    pooling='avg')

_freeze_model_by_percentage(base_model, 0.9)
```
My understanding is that the best practice is to train for a bit using the new final layer before unfreezing any part of the model to ensure that the new weights do not pollute the existing model. Can we please do that here?
I have tried that setting, i.e. train new layers for 6 epochs and unfreeze all layers for another 6 epochs. But it didn't show improvement... Should I still add it?
It's good to highlight best practices, so yes. And maybe write a comment that this part is "optional".
```python
image_features = tf.map_fn(tf.keras.applications.mobilenet.preprocess_input,
                           image_features, dtype=tf.float32)

outputs[transformed_name(IMAGE_KEY)] = (image_features)
```
Why the parens around (image_features)?
```python
# We resize CIFAR10 images to match that size
image_features = tf.image.resize(image_features, [224, 224])

image_features = tf.map_fn(tf.keras.applications.mobilenet.preprocess_input,
```
Nit: Can we combine the two map functions into one? It might be more readable.
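The nit above (a single pass instead of two) could be done by composing the per-element transforms and mapping once. The composition helper is plain Python (hypothetical name, not part of the example):

```python
def compose(*fns):
  """Compose unary functions left to right: compose(f, g)(x) == g(f(x))."""
  def composed(x):
    for fn in fns:
      x = fn(x)
    return x
  return composed

# With TF this would allow a single tf.map_fn call, e.g. (unverified sketch):
#   preprocess = compose(lambda img: tf.image.resize(img, [224, 224]),
#                        tf.keras.applications.mobilenet.preprocess_input)
#   image_features = tf.map_fn(preprocess, image_features, dtype=tf.float32)
```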
```python
# The MobileNet we use was trained on ImageNet, which has image size 224 x 224.
# We resize CIFAR10 images to match that size
image_features = tf.image.resize(image_features, [224, 224])
```
There are multiple ways of resizing. Why do we believe this provides the best quality for our dataset?
See https://www.tensorflow.org/api_docs/python/tf/image/ResizeMethod for options.
Hmm, I think the default option (bilinear) is the most common practice. It worked well, as the model got ~90% accuracy on the validation set.
Let's push to see how high we can go. Just because we get ~90% accuracy on an easy dataset doesn't mean that we can't push for more.
tfx/examples/cifar10/README.md (Outdated)
```
@@ -0,0 +1,54 @@
# CIFAR-10 Transfer Learning Example
```
Let's also add a unit test like so:
https://github.com/tensorflow/tfx/blob/master/tfx/examples/mnist/mnist_pipeline_native_keras_e2e_test.py
```python
    serving_model_dir: Text,
    metadata_path: Text,
    direct_num_workers: int) -> pipeline.Pipeline:
  """Implements the cifar10 image classification pipeline using TFX."""
```
nit: standardize capitalization in documentation (i.e., "CIFAR10" vs. "cifar10")
ditto in other files
```
Finally, run the `metadata_writer.py` script to write the metadata into the model:
```
python ~/cifar10/meta_data_writer -model_file PATH_TO_MODEL -label_file data/labels.txt -export_directory exported
It is great that you created this script. As a next step, let's see whether we can integrate it as part of the TFX pipeline. Let's leave a TODO and follow up with another pull request.
```python
  return dataset


def _freeze_model_by_percentage(model: tf.keras.Model,
```
There is also a model.trainable parameter; let's make sure it is not overriding the settings here. Can you please verify?
I will verify that
model.trainable is like a meta switch that turns all parameters trainable or not. If we freeze some layers in the model and then call model.trainable=True, all layers will be unfrozen; if we call model.trainable=False after freezing some of the layers, all layers will be frozen.
Makes sense. The question is what is the default and does it get in the way.
The default for model.trainable is True. As long as we don't modify it it will not get in the way.
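The cascade described in this thread can be illustrated with a toy stand-in (hypothetical `ToyModel`/`ToyLayer` classes that mirror the behavior described above, not Keras itself): assigning `trainable` on the model pushes the value down to every layer, clobbering any per-layer choices made earlier.

```python
class ToyLayer:
  def __init__(self, name):
    self.name = name
    self.trainable = True


class ToyModel:
  """Toy model whose `trainable` setter cascades to all layers,
  mimicking the meta-switch behavior discussed in this thread."""

  def __init__(self, layers):
    self.layers = layers
    self._trainable = True

  @property
  def trainable(self):
    return self._trainable

  @trainable.setter
  def trainable(self, value):
    self._trainable = value
    for layer in self.layers:
      layer.trainable = value  # overrides any per-layer freezing
```

So per-layer freezing is safe as long as nothing touches the model-level flag afterwards, which matches the conclusion above.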
We found a Contributor License Agreement for you (the sender of this pull request), but were unable to find agreements for all the commit author(s) or Co-authors. If you authored these, maybe you used a different email address in the git commits than was used to sign the CLA (login here to double check)? If these were authored by someone else, then they will need to sign a CLA as well, and confirm that they're okay with these being contributed to Google. ℹ️ Googlers: Go here for more info.
@googlebot I fixed it.
All (the pull request submitter and all commit authors) CLAs are signed, but one or more commits were authored or co-authored by someone other than the pull request submitter. We need to confirm that all authors are ok with their commits being contributed to this project. Please have them confirm that by leaving a comment that contains only

Note to project maintainer: There may be cases where the author cannot leave a comment, or the comment is not properly detected as consent. In those cases, you can manually confirm consent of the commit author(s), and set the

ℹ️ Googlers: Go here for more info.
It seems I messed up this PR with a bunch of other people's commits. Will open a new clean PR for this.
No description provided.