Skip to content

Commit

Permalink
make ZeroShotClassification support the roberta-large-mnli (#422)
Browse files Browse the repository at this point in the history
* make ZeroShotClassification support the roberta-large-mnli

I don't know if it's correct, so a review would be appreciated

* change error message for ZeroShotClassification and Roberta models

* Revert token type id update

---------

Co-authored-by: Charles Samuels <[email protected]>
Co-authored-by: guillaume-be <[email protected]>
  • Loading branch information
3 people authored Oct 21, 2023
1 parent 7a5c42d commit ce48aa0
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ hf-tokenizers = ["tokenizers"]
features = ["doc-only"]

[dependencies]
rust_tokenizers = "8.1"
rust_tokenizers = "8.1.1"
tch = "0.13.0"
serde_json = "1"
serde = { version = "1", features = ["derive"] }
Expand Down
4 changes: 2 additions & 2 deletions src/pipelines/zero_shot_classification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,13 +322,13 @@ impl ZeroShotClassificationOption {
}
}
ModelType::Roberta => {
if let ConfigOption::Bert(config) = model_config {
if let ConfigOption::Roberta(config) = model_config {
Ok(Self::Roberta(
RobertaForSequenceClassification::new(var_store.root(), config)?,
))
} else {
Err(RustBertError::InvalidConfigurationError(
"You can only supply a BertConfig for Roberta!".to_string(),
"You can only supply a RobertaConfig for Roberta!".to_string(),
))
}
}
Expand Down

0 comments on commit ce48aa0

Please sign in to comment.