Skip to content

Commit

Permalink
Merge branch 'master' into flairNLPgh-3488/save-column-corpus-to-files
Browse files Browse the repository at this point in the history
  • Loading branch information
chelseagzr authored Sep 12, 2024
2 parents 757f0ca + c674212 commit f4466f7
Show file tree
Hide file tree
Showing 19 changed files with 162 additions and 336 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
uses: actions/cache@v3
with:
path: ./cache
key: cache-v1.1
key: cache-v1.2
- name: Run tests
run: |
python -c 'import flair'
Expand Down
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -110,4 +110,4 @@ resources/taggers/
regression_train/
/doc_build/

scripts/
scripts/
7 changes: 6 additions & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,9 @@ sphinx
importlib-metadata
sphinx-multiversion
pydata-sphinx-theme<0.14
sphinx_design
sphinx_design

# previous dependencies that are required to build docs for later versions too.
semver
gensim
bpemb
22 changes: 21 additions & 1 deletion flair/datasets/sequence_labeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
)
from flair.datasets.base import find_train_dev_test_files
from flair.file_utils import cached_path, unpack_file
from flair.tokenization import Tokenizer

log = logging.getLogger("flair")

Expand All @@ -41,6 +42,7 @@ def __init__(
label_column_name: str = "label",
metadata_column_name: str = "metadata",
label_type: str = "ner",
use_tokenizer: Union[bool, Tokenizer] = True,
**corpusargs,
) -> None:
"""Instantiates a MuliFileJsonlCorpus as, e.g., created with doccanos JSONL export.
Expand All @@ -52,9 +54,12 @@ def __init__(
:param train_files: the name of the train files
:param test_files: the name of the test files
:param dev_files: the name of the dev files, if empty, dev data is sampled from train
:param encoding: file encoding (default "utf-8")
:param text_column_name: Name of the text column inside the jsonl files.
:param label_column_name: Name of the label column inside the jsonl files.
:param metadata_column_name: Name of the metadata column inside the jsonl files.
:param label_type: he type of label to predict (default "ner")
:param use_tokenizer: Specify a custom tokenizer to split the text into tokens.
:raises RuntimeError: If no paths are given
"""
Expand All @@ -68,6 +73,7 @@ def __init__(
metadata_column_name=metadata_column_name,
label_type=label_type,
encoding=encoding,
use_tokenizer=use_tokenizer,
)
for train_file in train_files
]
Expand All @@ -86,6 +92,8 @@ def __init__(
label_column_name=label_column_name,
metadata_column_name=metadata_column_name,
label_type=label_type,
encoding=encoding,
use_tokenizer=use_tokenizer,
)
for test_file in test_files
]
Expand All @@ -104,6 +112,8 @@ def __init__(
label_column_name=label_column_name,
metadata_column_name=metadata_column_name,
label_type=label_type,
encoding=encoding,
use_tokenizer=use_tokenizer,
)
for dev_file in dev_files
]
Expand All @@ -128,6 +138,7 @@ def __init__(
label_type: str = "ner",
autofind_splits: bool = True,
name: Optional[str] = None,
use_tokenizer: Union[bool, Tokenizer] = True,
**corpusargs,
) -> None:
"""Instantiates a JsonlCorpus with one file per Dataset (train, dev, and test).
Expand All @@ -136,11 +147,14 @@ def __init__(
:param train_file: the name of the train file
:param test_file: the name of the test file
:param dev_file: the name of the dev file, if None, dev data is sampled from train
:param encoding: file encoding (default "utf-8")
:param text_column_name: Name of the text column inside the JSONL file.
:param label_column_name: Name of the label column inside the JSONL file.
:param metadata_column_name: Name of the metadata column inside the JSONL file.
:param label_type: The type of label to predict (default "ner")
:param autofind_splits: Whether train, test and dev file should be determined automatically
:param name: name of the Corpus see flair.data.Corpus
:param use_tokenizer: Specify a custom tokenizer to split the text into tokens.
"""
# find train, dev and test files if not specified
dev_file, test_file, train_file = find_train_dev_test_files(
Expand All @@ -156,6 +170,7 @@ def __init__(
label_type=label_type,
name=name if data_folder is None else str(data_folder),
encoding=encoding,
use_tokenizer=use_tokenizer,
**corpusargs,
)

Expand All @@ -169,6 +184,7 @@ def __init__(
label_column_name: str = "label",
metadata_column_name: str = "metadata",
label_type: str = "ner",
use_tokenizer: Union[bool, Tokenizer] = True,
) -> None:
"""Instantiates a JsonlDataset and converts all annotated char spans to token tags using the IOB scheme.
Expand All @@ -184,9 +200,12 @@ def __init__(
Args:
path_to_jsonl_file: File to read
encoding: file encoding (default "utf-8")
text_column_name: Name of the text column
label_column_name: Name of the label column
metadata_column_name: Name of the metadata column
label_type: The type of label to predict (default "ner")
use_tokenizer: Specify a custom tokenizer to split the text into tokens.
"""
path_to_json_file = Path(path_to_jsonl_file)

Expand All @@ -203,7 +222,7 @@ def __init__(
raw_text = current_line[text_column_name]
current_labels = current_line[label_column_name]
current_metadatas = current_line.get(self.metadata_column_name, [])
current_sentence = Sentence(raw_text)
current_sentence = Sentence(raw_text, use_tokenizer=use_tokenizer)

self._add_labels_to_sentence(raw_text, current_sentence, current_labels)
self._add_metadatas_to_sentence(current_sentence, current_metadatas)
Expand Down Expand Up @@ -310,6 +329,7 @@ def __init__(
dev_files: the name of the dev files, if empty, dev data is sampled from train
column_delimiter: default is to split on any separatator, but you can overwrite for instance with "\t" to split only on tabs
comment_symbol: if set, lines that begin with this symbol are treated as comments
encoding: file encoding (default "utf-8")
document_separator_token: If provided, sentences that function as document boundaries are so marked
skip_first_line: set to True if your dataset has a header line
in_memory: If set to True, the dataset is kept in memory as Sentence objects, otherwise does disk reads
Expand Down
49 changes: 33 additions & 16 deletions flair/embeddings/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import torch
import transformers
from semver import Version
from packaging.version import Version
from torch.jit import ScriptModule
from transformers import (
CONFIG_MAPPING,
Expand Down Expand Up @@ -65,7 +65,7 @@ def pad_sequence_embeddings(all_hidden_states: List[torch.Tensor]) -> torch.Tens

@torch.jit.script_if_tracing
def truncate_hidden_states(hidden_states: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
return hidden_states[:, :, : input_ids.size()[1]]
return hidden_states[:, :, : input_ids.size(1)]


@torch.jit.script_if_tracing
Expand Down Expand Up @@ -95,14 +95,12 @@ def combine_strided_tensors(
if selected_sentences.size(0) > 1:
start_part = selected_sentences[0, : half_stride + 1]
mid_part = selected_sentences[:, half_stride + 1 : max_length - 1 - half_stride]
mid_part = torch.reshape(mid_part, (mid_part.shape[0] * mid_part.shape[1],) + mid_part.shape[2:])
end_part = selected_sentences[selected_sentences.shape[0] - 1, max_length - half_stride - 1 :]
mid_part = torch.reshape(mid_part, (mid_part.size(0) * mid_part.size(1),) + mid_part.size()[2:])
end_part = selected_sentences[selected_sentences.size(0) - 1, max_length - half_stride - 1 :]
sentence_hidden_state = torch.cat((start_part, mid_part, end_part), dim=0)
sentence_hidden_states[sentence_id, : sentence_hidden_state.shape[0]] = torch.cat(
(start_part, mid_part, end_part), dim=0
)
sentence_hidden_states[sentence_id, : sentence_hidden_state.size(0)] = sentence_hidden_state
else:
sentence_hidden_states[sentence_id, : selected_sentences.shape[1]] = selected_sentences[0, :]
sentence_hidden_states[sentence_id, : selected_sentences.size(1)] = selected_sentences[0, :]

return sentence_hidden_states

Expand Down Expand Up @@ -171,11 +169,30 @@ def fill_mean_token_embeddings(
word_ids: torch.Tensor,
token_lengths: torch.Tensor,
):
for i in torch.arange(all_token_embeddings.shape[0]):
for _id in torch.arange(token_lengths[i]): # type: ignore[call-overload]
all_token_embeddings[i, _id, :] = torch.nan_to_num(
sentence_hidden_states[i][word_ids[i] == _id].mean(dim=0)
)
batch_size, max_tokens, embedding_dim = all_token_embeddings.shape
mask = word_ids >= 0

# sum embeddings for each token
all_token_embeddings.scatter_add_(
1,
word_ids.clamp(min=0).unsqueeze(-1).expand(-1, -1, embedding_dim),
sentence_hidden_states * mask.unsqueeze(-1).float(),
)

# calculate the mean of subtokens
subtoken_counts = torch.zeros_like(all_token_embeddings[:, :, 0])
subtoken_counts.scatter_add_(1, word_ids.clamp(min=0), mask.float())
all_token_embeddings = torch.where(
subtoken_counts.unsqueeze(-1) > 0,
all_token_embeddings / subtoken_counts.unsqueeze(-1),
torch.zeros_like(all_token_embeddings),
)

# Create a mask for valid tokens based on token_lengths
token_mask = torch.arange(max_tokens, device=token_lengths.device)[None, :] < token_lengths[:, None]
all_token_embeddings = all_token_embeddings * token_mask.unsqueeze(-1)
all_token_embeddings = torch.nan_to_num(all_token_embeddings)

return all_token_embeddings


Expand Down Expand Up @@ -1056,7 +1073,7 @@ def __init__(
model, add_prefix_space=True, **transformers_tokenizer_kwargs, **kwargs
)
try:
self.feature_extractor = AutoFeatureExtractor.from_pretrained(model, apply_ocr=False)
self.feature_extractor = AutoFeatureExtractor.from_pretrained(model, apply_ocr=False, **kwargs)
except OSError:
self.feature_extractor = None
else:
Expand Down Expand Up @@ -1222,7 +1239,7 @@ def embedding_length(self) -> int:
def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
if transformers.__version__ >= Version(4, 31, 0):
if Version(transformers.__version__) >= Version("4.31.0"):
assert isinstance(state_dict, dict)
state_dict.pop(f"{prefix}model.embeddings.position_ids", None)
super()._load_from_state_dict(
Expand Down Expand Up @@ -1307,7 +1324,7 @@ def __setstate__(self, state):
self.__dict__[key] = embedding.__dict__[key]

if model_state_dict:
if transformers.__version__ >= Version(4, 31, 0):
if Version(transformers.__version__) >= Version("4.31.0"):
model_state_dict.pop("embeddings.position_ids", None)
self.model.load_state_dict(model_state_dict)

Expand Down
2 changes: 1 addition & 1 deletion flair/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def get_from_cache(url: str, cache_dir: Path) -> Path:
return cache_path

# make HEAD request to check ETag
response = requests.head(url, headers={"User-Agent": "Flair"}, allow_redirects=True)
response = requests.head(url, headers={"User-Agent": "Flair"}, allow_redirects=True, proxies=url_proxies)
if response.status_code != 200:
raise OSError(f"HEAD request failed for url {url} with status code {response.status_code}.")

Expand Down
3 changes: 2 additions & 1 deletion flair/models/entity_mention_linking.py
Original file line number Diff line number Diff line change
Expand Up @@ -1056,12 +1056,13 @@ def evaluate(
embedding_storage_mode: str = "none",
mini_batch_size: int = 32,
main_evaluation_metric: Tuple[str, str] = ("accuracy", "f1-score"),
exclude_labels: List[str] = [],
exclude_labels: Optional[List[str]] = None,
gold_label_dictionary: Optional[Dictionary] = None,
return_loss: bool = True,
k: int = 1,
**kwargs,
) -> Result:
exclude_labels = exclude_labels if exclude_labels is not None else []
if gold_label_dictionary is not None:
raise NotImplementedError("evaluating an EntityMentionLinker with a gold_label_dictionary is not supported")

Expand Down
26 changes: 16 additions & 10 deletions flair/models/pairwise_regression_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import typing
from pathlib import Path
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
Expand All @@ -12,7 +11,7 @@
from flair.data import Corpus, Dictionary, Sentence, TextPair, _iter_dataset
from flair.datasets import DataLoader, FlairDatapointDataset
from flair.nn.model import ReduceTransformerVocabMixin
from flair.training_utils import MetricRegression, Result, store_embeddings
from flair.training_utils import EmbeddingStorageMode, MetricRegression, Result, store_embeddings


class TextPairRegressor(flair.nn.Model[TextPair], ReduceTransformerVocabMixin):
Expand Down Expand Up @@ -91,7 +90,7 @@ def label_type(self):

def get_used_tokens(
self, corpus: Corpus, context_length: int = 0, respect_document_boundaries: bool = True
) -> typing.Iterable[List[str]]:
) -> Iterable[List[str]]:
for sentence_pair in _iter_dataset(corpus.get_all_sentences()):
yield [t.text for t in sentence_pair.first]
yield [t.text for t in sentence_pair.first.left_context(context_length, respect_document_boundaries)]
Expand Down Expand Up @@ -204,10 +203,16 @@ def _get_state_dict(self):
return model_state

@classmethod
def _init_model_with_state_dict(cls, state, **kwargs):
# add DefaultClassifier arguments
def _init_model_with_state_dict(cls, state: Dict[str, Any], **kwargs):
"""Initializes a TextPairRegressor model from a state dictionary (exported by _get_state_dict).
Requires keys 'state_dict', 'document_embeddings', and 'label_type' in the state dictionary.
"""
if "document_embeddings" in state:
state["embeddings"] = state.pop("document_embeddings") # need to rename this parameter
# add Model arguments
for arg in [
"document_embeddings",
"embeddings",
"label_type",
"embed_separately",
"dropout",
Expand Down Expand Up @@ -276,14 +281,15 @@ def evaluate(
data_points: Union[List[TextPair], Dataset],
gold_label_type: str,
out_path: Union[str, Path, None] = None,
embedding_storage_mode: str = "none",
embedding_storage_mode: EmbeddingStorageMode = "none",
mini_batch_size: int = 32,
main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"),
exclude_labels: List[str] = [],
main_evaluation_metric: Tuple[str, str] = ("correlation", "pearson"),
exclude_labels: Optional[List[str]] = None,
gold_label_dictionary: Optional[Dictionary] = None,
return_loss: bool = True,
**kwargs,
) -> Result:
exclude_labels = exclude_labels if exclude_labels is not None else []
# read Dataset into data loader, if list of sentences passed, make Dataset first
if not isinstance(data_points, Dataset):
data_points = FlairDatapointDataset(data_points)
Expand Down
3 changes: 2 additions & 1 deletion flair/models/relation_classifier_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from flair.datasets import DataLoader, FlairDatapointDataset
from flair.embeddings import DocumentEmbeddings, TransformerDocumentEmbeddings
from flair.tokenization import SpaceTokenizer
from flair.training_utils import EmbeddingStorageMode

logger: logging.Logger = logging.getLogger("flair")

Expand Down Expand Up @@ -602,7 +603,7 @@ def predict(
verbose: bool = False,
label_name: Optional[str] = None,
return_loss: bool = False,
embedding_storage_mode: str = "none",
embedding_storage_mode: EmbeddingStorageMode = "none",
) -> Optional[Tuple[torch.Tensor, int]]:
"""Predicts the class labels for the given sentence(s).
Expand Down
9 changes: 5 additions & 4 deletions flair/models/text_regression_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from flair.datasets import DataLoader, FlairDatapointDataset
from flair.embeddings.base import load_embeddings
from flair.nn.model import ReduceTransformerVocabMixin
from flair.training_utils import MetricRegression, Result, store_embeddings
from flair.training_utils import EmbeddingStorageMode, MetricRegression, Result, store_embeddings

log = logging.getLogger("flair")

Expand Down Expand Up @@ -78,7 +78,7 @@ def predict(
mini_batch_size: int = 32,
verbose: bool = False,
label_name: Optional[str] = None,
embedding_storage_mode="none",
embedding_storage_mode: EmbeddingStorageMode = "none",
) -> List[Sentence]:
if label_name is None:
label_name = self.label_name if self.label_name is not None else "label"
Expand Down Expand Up @@ -135,14 +135,15 @@ def evaluate(
data_points: Union[List[Sentence], Dataset],
gold_label_type: str,
out_path: Optional[Union[str, Path]] = None,
embedding_storage_mode: str = "none",
embedding_storage_mode: EmbeddingStorageMode = "none",
mini_batch_size: int = 32,
main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"),
exclude_labels: List[str] = [],
exclude_labels: Optional[List[str]] = None,
gold_label_dictionary: Optional[Dictionary] = None,
return_loss: bool = True,
**kwargs,
) -> Result:
exclude_labels = exclude_labels if exclude_labels is not None else []
# read Dataset into data loader, if list of sentences passed, make Dataset first
if not isinstance(data_points, Dataset):
data_points = FlairDatapointDataset(data_points)
Expand Down
Loading

0 comments on commit f4466f7

Please sign in to comment.