Skip to content

Commit

Permalink
Merge branch 'master' into flairNLPgh-3488/save-column-corpus-to-files
Browse files (browse the repository at this point in the history)
  • Loading branch information
chelseagzr authored Jul 23, 2024
2 parents 2665170 + 832f56e commit 205e46d
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 51 deletions.
7 changes: 0 additions & 7 deletions flair/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,6 @@ def hf_download(model_name: str) -> str:
)
except HTTPError:
# output information
logger.error("-" * 80)
logger.error(
f"ERROR: The key '{model_name}' was neither found on the ModelHub nor is this a valid path to a file on your system!"
)
logger.error(" -> Please check https://huggingface.co/models?filter=flair for all available models.")
logger.error(" -> Alternatively, point to a model file on your local drive.")
logger.error("-" * 80)
Path(flair.cache_root / "models" / model_folder).rmdir() # remove folder again if not valid
raise

Expand Down
52 changes: 10 additions & 42 deletions flair/models/sequence_tagger_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,8 +677,6 @@ def _fetch_model(model_name) -> str:
"chunk": "flair/chunk-english",
"chunk-fast": "flair/chunk-english-fast",
# Language-specific NER models
"ar-ner": "megantosh/flair-arabic-multi-ner",
"ar-pos": "megantosh/flair-arabic-dialects-codeswitch-egy-lev",
"da-ner": "flair/ner-danish",
"de-ner": "flair/ner-german",
"de-ler": "flair/ner-german-legal",
Expand All @@ -691,37 +689,13 @@ def _fetch_model(model_name) -> str:
}

hu_path: str = "https://nlp.informatik.hu-berlin.de/resources/models"
hunflair_paper_path = hu_path + "/hunflair_smallish_models"
hunflair_main_path = hu_path + "/hunflair_allcorpus_models"

hu_model_map = {
# English NER models
"ner": "/".join([hu_path, "ner", "en-ner-conll03-v0.4.pt"]),
"ner-pooled": "/".join([hu_path, "ner-pooled", "en-ner-conll03-pooled-v0.5.pt"]),
"ner-fast": "/".join([hu_path, "ner-fast", "en-ner-fast-conll03-v0.4.pt"]),
"ner-ontonotes": "/".join([hu_path, "ner-ontonotes", "en-ner-ontonotes-v0.4.pt"]),
"ner-ontonotes-fast": "/".join([hu_path, "ner-ontonotes-fast", "en-ner-ontonotes-fast-v0.4.pt"]),
# Multilingual NER models
"ner-multi": "/".join([hu_path, "multi-ner", "quadner-large.pt"]),
"multi-ner": "/".join([hu_path, "multi-ner", "quadner-large.pt"]),
"ner-multi-fast": "/".join([hu_path, "multi-ner-fast", "ner-multi-fast.pt"]),
# English POS models
"upos": "/".join([hu_path, "upos", "en-pos-ontonotes-v0.4.pt"]),
"upos-fast": "/".join([hu_path, "upos-fast", "en-upos-ontonotes-fast-v0.4.pt"]),
"pos": "/".join([hu_path, "pos", "en-pos-ontonotes-v0.5.pt"]),
"pos-fast": "/".join([hu_path, "pos-fast", "en-pos-ontonotes-fast-v0.5.pt"]),
# Multilingual POS models
"pos-multi": "/".join([hu_path, "multi-pos", "pos-multi-v0.1.pt"]),
"multi-pos": "/".join([hu_path, "multi-pos", "pos-multi-v0.1.pt"]),
"pos-multi-fast": "/".join([hu_path, "multi-pos-fast", "pos-multi-fast.pt"]),
"multi-pos-fast": "/".join([hu_path, "multi-pos-fast", "pos-multi-fast.pt"]),
# English SRL models
"frame": "/".join([hu_path, "frame", "en-frame-ontonotes-v0.4.pt"]),
"frame-fast": "/".join([hu_path, "frame-fast", "en-frame-ontonotes-fast-v0.4.pt"]),
"frame-large": "/".join([hu_path, "frame-large", "frame-large.pt"]),
# English chunking models
"chunk": "/".join([hu_path, "chunk", "en-chunk-conll2000-v0.4.pt"]),
"chunk-fast": "/".join([hu_path, "chunk-fast", "en-chunk-conll2000-fast-v0.4.pt"]),
# Danish models
"da-pos": "/".join([hu_path, "da-pos", "da-pos-v0.1.pt"]),
"da-ner": "/".join([hu_path, "NER-danish", "da-ner-v0.1.pt"]),
Expand All @@ -730,13 +704,14 @@ def _fetch_model(model_name) -> str:
"de-pos-tweets": "/".join([hu_path, "de-pos-tweets", "de-pos-twitter-v0.1.pt"]),
"de-ner": "/".join([hu_path, "de-ner", "de-ner-conll03-v0.4.pt"]),
"de-ner-germeval": "/".join([hu_path, "de-ner-germeval", "de-ner-germeval-0.4.1.pt"]),
"de-ler": "/".join([hu_path, "de-ner-legal", "de-ner-legal.pt"]),
"de-ner-legal": "/".join([hu_path, "de-ner-legal", "de-ner-legal.pt"]),
# Arabic models
"ar-ner": "/".join([hu_path, "arabic", "ar-ner.pt"]),
"ar-pos": "/".join([hu_path, "arabic", "ar-pos.pt"]),
# French models
"fr-ner": "/".join([hu_path, "fr-ner", "fr-ner-wikiner-0.4.pt"]),
# Dutch models
"nl-ner": "/".join([hu_path, "nl-ner", "nl-ner-bert-conll02-v0.8.pt"]),
"nl-ner-rnn": "/".join([hu_path, "nl-ner-rnn", "nl-ner-conll02-v0.5.pt"]),
"nl-ner-rnn": "/".join([hu_path, "nl-ner-rnn", "nl-ner-conll02-v0.14.0.pt"]),
# Malayalam models
"ml-pos": "https://raw.githubusercontent.com/qburst/models-repository/master/FlairMalayalamModels/malayalam-xpos-model.pt",
"ml-upos": "https://raw.githubusercontent.com/qburst/models-repository/master/FlairMalayalamModels/malayalam-upos-model.pt",
Expand All @@ -748,20 +723,13 @@ def _fetch_model(model_name) -> str:
"pucpr-flair-clinical-pos-tagging-best-model.pt",
]
),
# Keyphrase models
"keyphrase": "/".join([hu_path, "keyphrase", "keyphrase-en-scibert.pt"]),
"negation-speculation": "/".join([hu_path, "negation-speculation", "negation-speculation-model.pt"]),
"negation-speculation": "/".join([hu_path, "negation-speculation-v14", "negation-speculation-v0.14.0.pt"]),
# Biomedical models
"hunflair-paper-cellline": "/".join([hunflair_paper_path, "cellline", "hunflair-celline-v1.0.pt"]),
"hunflair-paper-chemical": "/".join([hunflair_paper_path, "chemical", "hunflair-chemical-v1.0.pt"]),
"hunflair-paper-disease": "/".join([hunflair_paper_path, "disease", "hunflair-disease-v1.0.pt"]),
"hunflair-paper-gene": "/".join([hunflair_paper_path, "gene", "hunflair-gene-v1.0.pt"]),
"hunflair-paper-species": "/".join([hunflair_paper_path, "species", "hunflair-species-v1.0.pt"]),
"hunflair-cellline": "/".join([hunflair_main_path, "cellline", "hunflair-celline-v1.0.pt"]),
"hunflair-chemical": "/".join([hunflair_main_path, "huner-chemical", "hunflair-chemical-full-v1.0.pt"]),
"hunflair-disease": "/".join([hunflair_main_path, "huner-disease", "hunflair-disease-full-v1.0.pt"]),
"hunflair-gene": "/".join([hunflair_main_path, "huner-gene", "hunflair-gene-full-v1.0.pt"]),
"hunflair-species": "/".join([hunflair_main_path, "huner-species", "hunflair-species-full-v1.1.pt"]),
"hunflair-cellline": "/".join([hunflair_main_path, "huner-cellline", "hunflair-cellline.pt"]),
"hunflair-chemical": "/".join([hunflair_main_path, "huner-chemical", "hunflair-chemical.pt"]),
"hunflair-disease": "/".join([hunflair_main_path, "huner-disease", "hunflair-disease.pt"]),
"hunflair-gene": "/".join([hunflair_main_path, "huner-gene", "hunflair-gene.pt"]),
"hunflair-species": "/".join([hunflair_main_path, "huner-species", "hunflair-species.pt"]),
}

cache_dir = Path("models")
Expand Down
12 changes: 11 additions & 1 deletion flair/nn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,17 @@ def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "Model":
continue

# if the model cannot be fetched, load as a file
state = model_path if isinstance(model_path, dict) else load_torch_state(str(model_path))
try:
state = model_path if isinstance(model_path, dict) else load_torch_state(str(model_path))
except Exception:
log.error("-" * 80)
log.error(
f"ERROR: The key '{model_path}' was neither found on the ModelHub nor is this a valid path to a file on your system!"
)
log.error(" -> Please check https://huggingface.co/models?filter=flair for all available models.")
log.error(" -> Alternatively, point to a model file on your local drive.")
log.error("-" * 80)
raise ValueError(f"Could not find any model with name '{model_path}'")

# try to get model class from state
cls_name = state.pop("__cls__", None)
Expand Down
2 changes: 1 addition & 1 deletion flair/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ def train_custom(
if inspect.isclass(sampler):
sampler = sampler()
# set dataset to sample from
sampler.set_dataset(train_data) # type: ignore[union-attr]
sampler.set_dataset(train_data)
shuffle = False

# this field stores the names of all dynamic embeddings in the model (determined after first forward pass)
Expand Down
7 changes: 7 additions & 0 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ def test_write_to_and_load_from_directory(tasks_base_path):
assert loaded_corpus.train[0].to_tagged_string() == corpus.train[0].to_tagged_string()


@pytest.mark.integration()
def test_hipe_2022_corpus(tasks_base_path):
# This test covers the complete HIPE 2022 dataset.
# https://github.com/hipe-eval/HIPE-2022-data
Expand Down Expand Up @@ -699,6 +700,7 @@ def test_hipe_2022(dataset_version="v2.1", add_document_separator=True):
test_hipe_2022(dataset_version="v2.1", add_document_separator=False)


@pytest.mark.integration()
def test_icdar_europeana_corpus(tasks_base_path):
# This test covers the complete ICDAR Europeana corpus:
# https://github.com/stefan-it/historic-domain-adaptation-icdar
Expand All @@ -716,6 +718,7 @@ def check_number_sentences(reference: int, actual: int, split_name: str):
check_number_sentences(len(corpus.test), gold_stats[language]["test"], "test")


@pytest.mark.integration()
def test_masakhane_corpus(tasks_base_path):
# This test covers the complete MasakhaNER dataset, including support for v1 and v2.
supported_versions = ["v1", "v2"]
Expand Down Expand Up @@ -799,6 +802,7 @@ def check_number_sentences(reference: int, actual: int, split_name: str, languag
check_number_sentences(len(corpus.test), gold_stats["test"], "test", language, version)


@pytest.mark.integration()
def test_nermud_corpus(tasks_base_path):
# This test covers the NERMuD dataset. Official stats can be found here:
# https://github.com/dhfbk/KIND/tree/main/evalita-2023
Expand Down Expand Up @@ -826,6 +830,7 @@ def test_german_ler_corpus(tasks_base_path):
assert len(corpus.test) == 6673, "Mismatch in number of sentences for test split"


@pytest.mark.integration()
def test_masakha_pos_corpus(tasks_base_path):
# This test covers the complete MasakhaPOS dataset.
supported_versions = ["v1"]
Expand Down Expand Up @@ -894,6 +899,7 @@ def check_number_sentences(reference: int, actual: int, split_name: str, languag
check_number_sentences(len(corpus.test), gold_stats["test"], "test", language, version)


@pytest.mark.integration()
def test_german_mobie(tasks_base_path):
corpus = flair.datasets.NER_GERMAN_MOBIE()

Expand Down Expand Up @@ -978,6 +984,7 @@ def test_jsonl_corpus_loads_metadata(tasks_base_path):
assert dataset.sentences[2].get_metadata("from") == 125


@pytest.mark.integration()
def test_ontonotes_download():
from urllib.parse import urlparse

Expand Down

0 comments on commit 205e46d

Please sign in to comment.