
Add TPU and evaluation saving for GLUE finetuning #234

Open · wants to merge 7 commits into base: master

Conversation

chenmoneygithub (Contributor)

No description provided.

google-cla bot commented Jun 25, 2022

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@chenmoneygithub chenmoneygithub changed the title Add TP Add TPU and evaluation saving for GLUE finetuning Jun 25, 2022
@@ -161,12 +179,14 @@ def __init__(
kernel_initializer=initializer,
name="logits",
)
self._drop_out = tf.keras.layers.Dropout(dropout)
Member

_dropout_layer

Contributor Author

done!

model = keras.models.load_model(FLAGS.saved_model_input, compile=False)
# model = keras.models.load_model(FLAGS.saved_model_input, compile=False)
model = keras.models.load_model(
"gs://chenmoney-testing-east/" + FLAGS.saved_model_input,
Member

I think this needs an update!

Contributor Author

done!

@@ -270,6 +306,58 @@ def preprocess_data(inputs, labels):
f"The best hyperparameters found are:\nLearning Rate: {best_hp['lr']}"
)

if FLAGS.saved_evaluations_output:
filenames = {
"cola": "CoLA.tsv",
Member

are these the exact filenames needed for GLUE submission?

Contributor Author

yes, they require the exact names to match

@@ -53,6 +53,12 @@
"The directory to save the finetuned model.",
)

flags.DEFINE_string(
"saved_evaluations_output",
Member

maybe something a little more directly named here? tsv_prediction_output?

Contributor Author

done!

examples/bert/bert_finetune_glue.py
with tf.io.gfile.GFile(filename, "w") as f:
# Write the required headline for GLUE.
f.write("index\tprediction\n")
for i in range(test_ds.cardinality()):
Member

Seems like we should be able to do this with a more readable tf.data loop here. for idx, x in enumerate(dataset): works directly; let's avoid calling iter, range, and next ourselves.
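A minimal sketch of the suggested pattern: a tf.data.Dataset is directly iterable, so enumerate() gives the batch index without any manual iter()/next() bookkeeping. The toy dataset below is a hypothetical stand-in for test_ds.

```python
import tensorflow as tf

# Hypothetical stand-in for the GLUE test dataset.
test_ds = tf.data.Dataset.from_tensor_slices([10, 20, 30]).batch(2)

# A Dataset is iterable, so enumerate() works directly on it.
for idx, batch in enumerate(test_ds):
    # idx is a plain Python int; batch is a tf.Tensor of examples.
    print(idx, batch.numpy().tolist())
# Prints:
# 0 [10, 20]
# 1 [30]
```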

)

@tf.function
def eval_step(iterator):
Member

can we just use the actual keras evaluate here?

Contributor Author

We need the predicted labels to write to the file; evaluate only returns metrics?
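To illustrate the distinction being discussed: Model.evaluate returns only scalar loss/metric values, while Model.predict returns per-example outputs, from which predicted labels can be derived. The tiny model and random data below are hypothetical stand-ins, not the PR's BERT model.

```python
import numpy as np
from tensorflow import keras

# Hypothetical toy classifier: 4 features in, 3 classes out.
model = keras.Sequential(
    [keras.Input(shape=(4,)), keras.layers.Dense(3, activation="softmax")]
)
model.compile(loss="sparse_categorical_crossentropy", metrics=["accuracy"])

x = np.random.rand(8, 4).astype("float32")
y = np.zeros(8, dtype="int32")

metrics = model.evaluate(x, y, verbose=0)  # scalars only: [loss, accuracy]
probs = model.predict(x, verbose=0)        # shape (8, 3): one row per example
labels = np.argmax(probs, axis=-1)         # predicted label per example
```

So for writing a GLUE submission file, a predict-style pass (or a custom loop) is needed; evaluate alone cannot produce the per-example labels.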

@@ -161,12 +173,14 @@ def __init__(
kernel_initializer=initializer,
name="logits",
)
self._drop_out_layer = tf.keras.layers.Dropout(dropout)
Member

drop_out -> dropout

Contributor Author

done

@@ -178,6 +192,7 @@ def __init__(self, model_config):

def build(self, hp):
model = keras.models.load_model(FLAGS.saved_model_input, compile=False)
model = model.bert_model
Member

I don't think this will work, right? This seems to imply we save the whole model in bert_train.py, which we do not.

Contributor Author

done


labelname = labelnames.get(FLAGS.task_name)
with tf.io.gfile.GFile(filename, "w") as f:
# Write the required headline for GLUE.
Member

I think you can use the builtin csv module with a different delimiter to write TSVs:
https://docs.python.org/3/library/csv.html
https://stackoverflow.com/q/29895602

We should prefer that over doing this writing ourselves.
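A sketch of that suggestion: csv.writer with delimiter="\t" handles the row formatting, including the required "index\tprediction" header from the diff. The in-memory buffer and predictions list below are hypothetical stand-ins for tf.io.gfile.GFile(filename, "w") and the model's outputs.

```python
import csv
import io

predictions = [0, 1, 1]  # hypothetical per-example predicted labels

buf = io.StringIO()  # stands in for tf.io.gfile.GFile(filename, "w")
writer = csv.writer(buf, delimiter="\t", lineterminator="\n")
writer.writerow(["index", "prediction"])  # required GLUE header line
for idx, label in enumerate(predictions):
    writer.writerow([idx, label])

print(buf.getvalue())
# index	prediction
# 0	0
# 1	1
# 2	1
```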

Contributor Author

done

@mattdangerw mattdangerw self-assigned this Mar 22, 2023