diff --git a/fairness_indicators/example_model.py b/fairness_indicators/example_model.py index e13eab0..a4a1ab4 100644 --- a/fairness_indicators/example_model.py +++ b/fairness_indicators/example_model.py @@ -31,7 +31,7 @@ import tempfile import tensorflow.compat.v1 as tf from tensorflow.compat.v1 import estimator as tf_estimator -import tensorflow_hub as hub +from tensorflow.python.feature_column import feature_column_v2 # pylint: disable=g-deprecated-tf-checker import tensorflow_model_analysis as tfma from tensorflow_model_analysis.addons.fairness.post_export_metrics import fairness_indicators # pylint: disable=unused-import @@ -71,8 +71,9 @@ def parse_function(serialized): filenames=[train_tf_file]).map(parse_function).batch(512) return train_dataset - text_embedding_column = hub.text_embedding_column( - key=text_feature, module_spec=module_spec) + text_embedding_column = feature_column_v2.text_embedding_column( + key=text_feature, module_spec=module_spec + ) classifier = tf_estimator.DNNClassifier( hidden_units=[500, 100],