diff --git a/.gitignore b/.gitignore
index 3c7ce29ca..6ff5d3657 100644
--- a/.gitignore
+++ b/.gitignore
@@ -277,7 +277,7 @@ crashlytics.properties
crashlytics-build.properties
fabric.properties
-# Editor-based Rest Client
+# Editor-based Rest HanLPClient
.idea/httpRequests
# Android studio 3.1+ serialized cache file
@@ -285,4 +285,9 @@ fabric.properties
.idea
*.iml
data
-.vscode/settings.json
+.vscode
+*.pkl
+*.pdf
+_static/
+_build/
+_templates/
\ No newline at end of file
diff --git a/.travis.yml b/.travis.yml
new file mode 100644
index 000000000..d979760ad
--- /dev/null
+++ b/.travis.yml
@@ -0,0 +1,21 @@
+language: python
+cache: pip
+python:
+ - '3.6'
+install:
+ - pip install .
+deploy:
+ provider: pypi
+ username: __token__
+ password:
+ secure: KU0S/z54UMdS3rJT0fNndVnvhKB48YBzpwBZQZAOUJafFyqw1Nm366cpn9OdyWPQ54LolQNEKyQZc7xDpV89j1ukKQ1aGgZ5rXD8zrAqivcWEzEWzRpO8uPGbGT0TSJDfd3zX8vHO5UznmW2nNuRJfHFkEmB/27TlZAs2ph/SrEGvuBQOFgQZMShzFWGRKL+kEXX946qlw1EdLe2XvpK7jkWQpG9c8S5mNhbqBMAofVAXyNoHqX3FrPdEvN9MY9iRx3FxusHBqHeRLwrPHK2aQLVUE5D0WE1NzKwNZ4UxbY4PfiESYDueqGR8O/awpuLwg+6itk6FbtExAIAZyDLvGS4o88AGks6VJlJKwdT0LZ6cR1+WOGXyewSjHiJmjdBnFCtvyjn/O6sDEIDmku4FINuNIcmXy2bYwns9D3lNzb2EYpSTu5A9Q4EAAWZ4t0DsWBSRJmuauv6VNTHOENPRXR3fA9honp6GWiEh+4b/yfIaT9p0VnkR7D3KoN27eNmouU4s68hAfnFVPnB/OWU/DNoWs2PbLo4ztficmGOcOyDbS4BjrLjxuyU3aAHYIeXAff6A3I/a1tz+QknYCOJz/ZnQ3e4FC+2lm/cCGzPTfi+IVQ7QJryAY8hbblDX48PHCzVLa0PPer+v2NZVrnfddMZoLd1ox65hM2gHuy6NkQ=
+ on:
+ python: 3.6
+ branch: master
+env:
+ global:
+ - HANLP_VERBOSE=0
+script:
+ - python -m unittest discover ./tests
+notifications:
+ email: false
diff --git a/README.md b/README.md
index 1e9e19409..5576466bb 100644
--- a/README.md
+++ b/README.md
@@ -1,325 +1,110 @@
# HanLP: Han Language Processing
-[中文](https://github.com/hankcs/HanLP/tree/doc-zh) | [1.x](https://github.com/hankcs/HanLP/tree/1.x) | [forum](https://bbs.hankcs.com/) | [docker](https://github.com/WalterInSH/hanlp-jupyter-docker)
+[中文](https://github.com/hankcs/HanLP/tree/doc-zh) | [docs](https://hanlp.hankcs.com/docs/) | [1.x](https://github.com/hankcs/HanLP/tree/1.x) | [forum](https://bbs.hankcs.com/) | [docker](https://github.com/WalterInSH/hanlp-jupyter-docker)
-The multilingual NLP library for researchers and companies, built on TensorFlow 2.0, for advancing state-of-the-art deep learning techniques in both academia and industry. HanLP was designed from day one to be efficient, user friendly and extendable. It comes with pretrained models for various human languages including English, Chinese and many others. Currently, HanLP 2.0 is in alpha stage with more killer features on the roadmap. Discussions are welcomed on our [forum](https://bbs.hankcs.com/), while bug reports and feature requests are reserved for GitHub issues. For Java users, please checkout the [1.x](https://github.com/hankcs/HanLP/tree/1.x) branch.
+The multilingual NLP library for researchers and companies, built on PyTorch and TensorFlow 2.x, for advancing state-of-the-art deep learning techniques in both academia and industry. HanLP was designed from day one to be efficient, user friendly and extendable. It comes with pretrained models for 104 human languages including English, Chinese and many others.
- ## Installation
+Thanks to open-access corpora like Universal Dependencies and OntoNotes, HanLP 2.1 now offers 10 joint tasks on 104 languages: tokenization, lemmatization, part-of-speech tagging, token feature extraction, dependency parsing, constituency parsing, semantic role labeling, semantic dependency parsing, abstract meaning representation (AMR) parsing.
-```bash
-pip install hanlp
-```
-
-HanLP requires Python 3.6 or later. GPU/TPU is suggested but not mandatory.
-
-## Quick Start
-
-### Tokenization
-
-For an end user, the basic workflow starts with loading some pretrained models from disk or Internet. Each model has an identifier, which could be one path on your computer or an URL to any public servers. To tokenize Chinese, let's load a tokenizer called `CTB6_CONVSEG` with 2 lines of code.
-
-```python
->>> import hanlp
->>> tokenizer = hanlp.load('CTB6_CONVSEG')
-```
-
-HanLP will automatically resolve the identifier `CTB6_CONVSEG` to an [URL](https://file.hankcs.com/hanlp/cws/ctb6-convseg-cws_20191230_184525.zip), then download it and unzip it. Due to the huge network traffic, it could fail temporally then you need to retry or manually download and unzip it to the path shown in your terminal .
-
-Once the model is loaded, you can then tokenize one sentence through calling the tokenizer as a function:
-
-```python
->>> tokenizer('商品和服务')
-['商品', '和', '服务']
-```
-
-If you're processing English, a rule based function should be good enough.
-
-```python
->>> tokenizer = hanlp.utils.rules.tokenize_english
->>> tokenizer("Don't go gentle into that good night.")
-['Do', "n't", 'go', 'gentle', 'into', 'that', 'good', 'night', '.']
-```
-
-#### Going Further
-
-However, you can predict much faster. In the era of deep learning, batched computation usually gives a linear scale-up factor of `batch_size`. So, you can predict multiple sentences at once, at the cost of GPU memory.
-
-```python
->>> tokenizer(['萨哈夫说,伊拉克将同联合国销毁伊拉克大规模杀伤性武器特别委员会继续保持合作。',
- '上海华安工业(集团)公司董事长谭旭光和秘书张晚霞来到美国纽约现代艺术博物馆参观。',
- 'HanLP支援臺灣正體、香港繁體,具有新詞辨識能力的中文斷詞系統'])
-[['萨哈夫', '说', ',', '伊拉克', '将', '同', '联合国', '销毁', '伊拉克', '大', '规模', '杀伤性', '武器', '特别', '委员会', '继续', '保持', '合作', '。'],
- ['上海', '华安', '工业', '(', '集团', ')', '公司', '董事长', '谭旭光', '和', '秘书', '张晚霞', '来到', '美国', '纽约', '现代', '艺术', '博物馆', '参观', '。'],
- ['HanLP', '支援', '臺灣', '正體', '、', '香港', '繁體', ',', '具有', '新詞', '辨識', '能力', '的', '中文', '斷詞', '系統']]
-```
-
-That's it! You're now ready to employ the latest DL models from HanLP in your research and work. Here are some tips if you want to go further.
-
-- Print `hanlp.pretrained.ALL` to list all the pretrained models available in HanLP.
-
-- Use `hanlp.pretrained.*` to browse pretrained models by categories of NLP tasks. You can use the variables to identify them too.
-
- ```python
- >>> hanlp.pretrained.cws.CTB6_CONVSEG
- 'https://file.hankcs.com/hanlp/cws/ctb6-convseg-cws_20191230_184525.zip'
- ```
-
-### Part-of-Speech Tagging
-
-Taggers take lists of tokens as input, then outputs one tag for each token.
-
-```python
->>> tagger = hanlp.load(hanlp.pretrained.pos.PTB_POS_RNN_FASTTEXT_EN)
->>> tagger([['I', 'banked', '2', 'dollars', 'in', 'a', 'bank', '.'],
- ['Is', 'this', 'the', 'future', 'of', 'chamber', 'music', '?']])
-[['PRP', 'VBD', 'CD', 'NNS', 'IN', 'DT', 'NN', '.'],
- ['VBZ', 'DT', 'DT', 'NN', 'IN', 'NN', 'NN', '.']]
-```
-
-The language solely depends on which model you load.
-
-```python
->>> tagger = hanlp.load(hanlp.pretrained.pos.CTB9_POS_ALBERT_BASE)
->>> tagger(['我', '的', '希望', '是', '希望', '和平'])
-['PN', 'DEG', 'NN', 'VC', 'VV', 'NN']
-```
-
-Did you notice the different pos tags for the same word `希望` ("hope")? The first one means "my dream" as a noun while the later means "want" as a verb. This tagger uses fasttext[^fasttext] as its embedding layer, which is free from OOV.
-
-### Named Entity Recognition
-
-The NER component requires tokenized tokens as input, then outputs the entities along with their types and spans.
-
-```python
->>> recognizer = hanlp.load(hanlp.pretrained.ner.CONLL03_NER_BERT_BASE_UNCASED_EN)
->>> recognizer(["President", "Obama", "is", "speaking", "at", "the", "White", "House"])
-[('Obama', 'PER', 1, 2), ('White House', 'LOC', 6, 8)]
-```
+For end users, HanLP offers light-weighted RESTful APIs and native Python APIs.
-Recognizers take lists of tokens as input, so don't forget to wrap your sentence with `list`. For the outputs, each tuple stands for `(entity, type, begin, end)`.
+## RESTful APIs
-```python
->>> recognizer = hanlp.load(hanlp.pretrained.ner.MSRA_NER_BERT_BASE_ZH)
->>> recognizer([list('上海华安工业(集团)公司董事长谭旭光和秘书张晚霞来到美国纽约现代艺术博物馆参观。'),
- list('萨哈夫说,伊拉克将同联合国销毁伊拉克大规模杀伤性武器特别委员会继续保持合作。')])
-[[('上海华安工业(集团)公司', 'NT', 0, 12), ('谭旭光', 'NR', 15, 18), ('张晚霞', 'NR', 21, 24), ('美国', 'NS', 26, 28), ('纽约现代艺术博物馆', 'NS', 28, 37)],
- [('萨哈夫', 'NR', 0, 3), ('伊拉克', 'NS', 5, 8), ('联合国销毁伊拉克大规模杀伤性武器特别委员会', 'NT', 10, 31)]]
-```
+Tiny packages in several KBs for agile development and mobile applications. An auth key is required and [a free one can be applied here](https://bbs.hankcs.com/t/apply-for-free-hanlp-restful-apis/3178) under [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/) license.
-This `MSRA_NER_BERT_BASE_ZH` is the state-of-the-art NER model based on BERT[^bert]. You can read its evaluation log through:
+ ### Python
```bash
-$ cat ~/.hanlp/ner/ner_bert_base_msra_20200104_185735/test.log
-20-01-04 18:55:02 INFO Evaluation results for test.tsv - loss: 1.4949 - f1: 0.9522 - speed: 113.37 sample/sec
-processed 177342 tokens with 5268 phrases; found: 5316 phrases; correct: 5039.
-accuracy: 99.37%; precision: 94.79%; recall: 95.65%; FB1: 95.22
- NR: precision: 96.39%; recall: 97.83%; FB1: 97.10 1357
- NS: precision: 96.70%; recall: 95.79%; FB1: 96.24 2610
- NT: precision: 89.47%; recall: 93.13%; FB1: 91.27 1349
+pip install hanlp_restful
```
-### Syntactic Dependency Parsing
-
-Parsing lies in the core of NLP. Without parsing, one cannot claim to be a NLP researcher or engineer. But using HanLP, it takes no more than two lines of code.
+Create a client with our API endpoint and your auth.
```python
->>> syntactic_parser = hanlp.load(hanlp.pretrained.dep.PTB_BIAFFINE_DEP_EN)
->>> print(syntactic_parser([('Is', 'VBZ'), ('this', 'DT'), ('the', 'DT'), ('future', 'NN'), ('of', 'IN'), ('chamber', 'NN'), ('music', 'NN'), ('?', '.')]))
-1 Is _ VBZ _ _ 4 cop _ _
-2 this _ DT _ _ 4 nsubj _ _
-3 the _ DT _ _ 4 det _ _
-4 future _ NN _ _ 0 root _ _
-5 of _ IN _ _ 4 prep _ _
-6 chamber _ NN _ _ 7 nn _ _
-7 music _ NN _ _ 5 pobj _ _
-8 ? _ . _ _ 4 punct _ _
+from hanlp_restful import HanLPClient
+HanLP = HanLPClient('https://hanlp.hankcs.com/api', auth='your_auth', language='mul')
```
-Parsers take both tokens and part-of-speech tags as input. The output is a tree in CoNLL-X format[^conllx], which can be manipulated through the `CoNLLSentence` class. Similar codes for Chinese:
+### Java
-```python
->>> syntactic_parser = hanlp.load(hanlp.pretrained.dep.CTB7_BIAFFINE_DEP_ZH)
->>> print(syntactic_parser([('蜡烛', 'NN'), ('两', 'CD'), ('头', 'NN'), ('烧', 'VV')]))
-1 蜡烛 _ NN _ _ 4 nsubj _ _
-2 两 _ CD _ _ 3 nummod _ _
-3 头 _ NN _ _ 4 dep _ _
-4 烧 _ VV _ _ 0 root _ _
-```
-
-### Semantic Dependency Parsing
+Insert the following dependency into your `pom.xml`.
-A graph is a generalized tree, which conveys more information about the semantic relations between tokens.
-
-```python
->>> semantic_parser = hanlp.load(hanlp.pretrained.sdp.SEMEVAL15_PAS_BIAFFINE_EN)
->>> print(semantic_parser([('Is', 'VBZ'), ('this', 'DT'), ('the', 'DT'), ('future', 'NN'), ('of', 'IN'), ('chamber', 'NN'), ('music', 'NN'), ('?', '.')]))
-1 Is _ VBZ _ _ 0 ROOT _ _
-2 this _ DT _ _ 1 verb_ARG1 _ _
-3 the _ DT _ _ 0 ROOT _ _
-4 future _ NN _ _ 1 verb_ARG2 _ _
-4 future _ NN _ _ 3 det_ARG1 _ _
-4 future _ NN _ _ 5 prep_ARG1 _ _
-5 of _ IN _ _ 0 ROOT _ _
-6 chamber _ NN _ _ 0 ROOT _ _
-7 music _ NN _ _ 5 prep_ARG2 _ _
-7 music _ NN _ _ 6 noun_ARG1 _ _
-8 ? _ . _ _ 0 ROOT _ _
+```xml
+
+ com.hankcs.hanlp.restful
+ hanlp-restful
+ 0.0.2
+
```
-HanLP implements the biaffine[^biaffine] model which delivers the SOTA performance.
+Create a client with our API endpoint and your auth.
-```python
->>> semantic_parser = hanlp.load(hanlp.pretrained.sdp.SEMEVAL16_NEWS_BIAFFINE_ZH)
->>> print(semantic_parser([('蜡烛', 'NN'), ('两', 'CD'), ('头', 'NN'), ('烧', 'VV')]))
-1 蜡烛 _ NN _ _ 3 Poss _ _
-1 蜡烛 _ NN _ _ 4 Pat _ _
-2 两 _ CD _ _ 3 Quan _ _
-3 头 _ NN _ _ 4 Loc _ _
-4 烧 _ VV _ _ 0 Root _ _
+```java
+HanLPClient HanLP = new HanLPClient("https://hanlp.hankcs.com/api", "your_auth", "mul");
```
-The output is a `CoNLLSentence` too. However, it's not a tree but a graph in which one node can have multiple heads, e.g. `蜡烛` has two heads (ID 3 and 4).
-
-### Pipelines
+### Quick Start
-Since parsers require part-of-speech tagging and tokenization, while taggers expects tokenization to be done beforehand, wouldn't it be nice if we have a pipeline to connect the inputs and outputs, like a computation graph?
+No matter which language you uses, the same interface can be used to parse a document.
```python
-pipeline = hanlp.pipeline() \
- .append(hanlp.utils.rules.split_sentence, output_key='sentences') \
- .append(tokenizer, output_key='tokens') \
- .append(tagger, output_key='part_of_speech_tags') \
- .append(syntactic_parser, input_key=('tokens', 'part_of_speech_tags'), output_key='syntactic_dependencies') \
- .append(semantic_parser, input_key=('tokens', 'part_of_speech_tags'), output_key='semantic_dependencies')
+HanLP.parse("In 2021, HanLPv2.1 delivers state-of-the-art multilingual NLP techniques to production environment. 2021年、HanLPv2.1は次世代の最先端多言語NLP技術を本番環境に導入します。2021年 HanLPv2.1为生产环境带来次世代最先进的多语种NLP技术。")
```
-Notice that the first pipe is an old-school Python function `split_sentence`, which splits the input text into a list of sentences. Then the later DL components can utilize the batch processing seamlessly. This results in a pipeline with one input (text) pipe, multiple flow pipes and one output (parsed document). You can print out the pipeline to check its structure.
+## Native APIs
-```python
->>> pipeline
-[None->LambdaComponent->sentences, sentences->NgramConvTokenizer->tokens, tokens->RNNPartOfSpeechTagger->part_of_speech_tags, ('tokens', 'part_of_speech_tags')->BiaffineDependencyParser->syntactic_dependencies, ('tokens', 'part_of_speech_tags')->BiaffineSemanticDependencyParser->semantic_dependencies]
+```bash
+pip install hanlp
```
-This time, let's feed in a whole document `text`, which might be the scenario in your daily work.
-
-```python
->>> print(pipeline(text))
-{
- "sentences": [
- "Jobs and Wozniak co-founded Apple in 1976 to sell Wozniak's Apple I personal computer.",
- "Together the duo gained fame and wealth a year later with the Apple II."
- ],
- "tokens": [
- ["Jobs", "and", "Wozniak", "co-founded", "Apple", "in", "1976", "to", "sell", "Wozniak", "'s", "", "Apple", "I", "personal", "computer", "."],
- ["Together", "the", "duo", "gained", "fame", "and", "wealth", "a", "year", "later", "with", "the", "Apple", "II", "."]
- ],
- "part_of_speech_tags": [
- ["NNS", "CC", "NNP", "VBD", "NNP", "IN", "CD", "TO", "VB", "NNP", "POS", "``", "NNP", "PRP", "JJ", "NN", "."],
- ["IN", "DT", "NN", "VBD", "NN", "CC", "NN", "DT", "NN", "RB", "IN", "DT", "NNP", "NNP", "."]
- ],
- "syntactic_dependencies": [
- [[4, "nsubj"], [1, "cc"], [1, "conj"], [0, "root"], [4, "dobj"], [4, "prep"], [6, "pobj"], [9, "aux"], [4, "xcomp"], [16, "poss"], [10, "possessive"], [16, "punct"], [16, "nn"], [16, "nn"], [16, "amod"], [9, "dobj"], [4, "punct"]],
- [[4, "advmod"], [3, "det"], [4, "nsubj"], [0, "root"], [4, "dobj"], [5, "cc"], [5, "conj"], [9, "det"], [10, "npadvmod"], [4, "advmod"], [4, "prep"], [14, "det"], [14, "nn"], [11, "pobj"], [4, "punct"]]
- ],
- "semantic_dependencies": [
- [[[2], ["coord_ARG1"]], [[4, 9], ["verb_ARG1", "verb_ARG1"]], [[2], ["coord_ARG2"]], [[6, 8], ["prep_ARG1", "comp_MOD"]], [[4], ["verb_ARG2"]], [[0], ["ROOT"]], [[6], ["prep_ARG2"]], [[0], ["ROOT"]], [[8], ["comp_ARG1"]], [[11], ["poss_ARG2"]], [[0], ["ROOT"]], [[0], ["ROOT"]], [[0], ["ROOT"]], [[0], ["ROOT"]], [[0], ["ROOT"]], [[9, 11, 12, 14, 15], ["verb_ARG3", "poss_ARG1", "punct_ARG1", "noun_ARG1", "adj_ARG1"]], [[0], ["ROOT"]]],
- [[[0], ["ROOT"]], [[0], ["ROOT"]], [[1, 2, 4], ["adj_ARG1", "det_ARG1", "verb_ARG1"]], [[1, 10], ["adj_ARG1", "adj_ARG1"]], [[6], ["coord_ARG1"]], [[4], ["verb_ARG2"]], [[6], ["coord_ARG2"]], [[0], ["ROOT"]], [[8], ["det_ARG1"]], [[9], ["noun_ARG1"]], [[0], ["ROOT"]], [[0], ["ROOT"]], [[0], ["ROOT"]], [[11, 12, 13], ["prep_ARG2", "det_ARG1", "noun_ARG1"]], [[0], ["ROOT"]]]
- ]
-}
-```
+HanLP requires Python 3.6 or later. GPU/TPU is suggested but not mandatory.
-The output for Chinese looks similar to the English one.
+### Quick Start
```python
->>> print(pipeline(text))
-{
- "sentences": [
- "HanLP是一系列模型与算法组成的自然语言处理工具包,目标是普及自然语言处理在生产环境中的应用。",
- "HanLP具备功能完善、性能高效、架构清晰、语料时新、可自定义的特点。",
- "内部算法经过工业界和学术界考验,配套书籍《自然语言处理入门》已经出版。"
- ],
- "tokens": [
- ["HanLP", "是", "一", "系列", "模型", "与", "算法", "组成", "的", "自然", "语言", "处理", "工具包", ",", "目标", "是", "普及", "自然", "语言", "处理", "在", "生产", "环境", "中", "的", "应用", "。"],
- ["HanLP", "具备", "功能", "完善", "、", "性能", "高效", "、", "架构", "清晰", "、", "语料", "时", "新", "、", "可", "自", "定义", "的", "特点", "。"],
- ["内部", "算法", "经过", "工业界", "和", "学术界", "考验", ",", "配套", "书籍", "《", "自然", "语言", "处理", "入门", "》", "已经", "出版", "。"]
- ],
- "part_of_speech_tags": [
- ["NR", "VC", "CD", "M", "NN", "CC", "NN", "VV", "DEC", "NN", "NN", "VV", "NN", "PU", "NN", "VC", "VV", "NN", "NN", "VV", "P", "NN", "NN", "LC", "DEG", "NN", "PU"],
- ["NR", "VV", "NN", "VA", "PU", "NN", "VA", "PU", "NN", "VA", "PU", "NN", "LC", "VA", "PU", "VV", "P", "VV", "DEC", "NN", "PU"],
- ["NN", "NN", "P", "NN", "CC", "NN", "NN", "PU", "VV", "NN", "PU", "NN", "NN", "NN", "NN", "PU", "AD", "VV", "PU"]
- ],
- "syntactic_dependencies": [
- [[2, "top"], [0, "root"], [4, "nummod"], [11, "clf"], [7, "conj"], [7, "cc"], [8, "nsubj"], [11, "rcmod"], [8, "cpm"], [11, "nn"], [12, "nsubj"], [2, "ccomp"], [12, "dobj"], [2, "punct"], [16, "top"], [2, "conj"], [16, "ccomp"], [19, "nn"], [20, "nsubj"], [17, "conj"], [26, "assmod"], [23, "nn"], [24, "lobj"], [21, "plmod"], [21, "assm"], [20, "dobj"], [2, "punct"]],
- [[2, "nsubj"], [0, "root"], [4, "nsubj"], [20, "rcmod"], [4, "punct"], [7, "nsubj"], [4, "conj"], [4, "punct"], [10, "nsubj"], [4, "conj"], [4, "punct"], [13, "lobj"], [14, "loc"], [4, "conj"], [4, "punct"], [18, "mmod"], [18, "advmod"], [4, "conj"], [4, "cpm"], [2, "dobj"], [2, "punct"]],
- [[2, "nn"], [18, "nsubj"], [18, "prep"], [6, "conj"], [6, "cc"], [7, "nn"], [3, "pobj"], [18, "punct"], [10, "rcmod"], [15, "nn"], [15, "punct"], [15, "nn"], [15, "nn"], [15, "nn"], [18, "nsubj"], [15, "punct"], [18, "advmod"], [0, "root"], [18, "punct"]]
- ],
- "semantic_dependencies": [
- [[[2], ["Exp"]], [[0], ["Aft"]], [[4], ["Quan"]], [[0], ["Aft"]], [[8], ["Poss"]], [[7], ["mConj"]], [[8], ["Datv"]], [[11], ["rProd"]], [[8], ["mAux"]], [[11], ["Desc"]], [[12], ["Datv"]], [[2], ["dClas"]], [[2, 12], ["Clas", "Cont"]], [[2, 12], ["mPunc", "mPunc"]], [[16], ["Exp"]], [[17], ["mMod"]], [[2], ["eSucc"]], [[19], ["Desc"]], [[20], ["Pat"]], [[26], ["rProd"]], [[23], ["mPrep"]], [[23], ["Desc"]], [[20], ["Loc"]], [[23], ["mRang"]], [[0], ["Aft"]], [[16], ["Clas"]], [[16], ["mPunc"]]],
- [[[2], ["Poss"]], [[0], ["Aft"]], [[4], ["Exp"]], [[0], ["Aft"]], [[4], ["mPunc"]], [[0], ["Aft"]], [[4], ["eCoo"]], [[4, 7], ["mPunc", "mPunc"]], [[0], ["Aft"]], [[0], ["Aft"]], [[7, 10], ["mPunc", "mPunc"]], [[0], ["Aft"]], [[12], ["mTime"]], [[0], ["Aft"]], [[14], ["mPunc"]], [[0], ["Aft"]], [[0], ["Aft"]], [[20], ["Desc"]], [[18], ["mAux"]], [[0], ["Aft"]], [[0], ["Aft"]]],
- [[[2], ["Desc"]], [[7, 9, 18], ["Exp", "Agt", "Exp"]], [[4], ["mPrep"]], [[0], ["Aft"]], [[6], ["mPrep"]], [[7], ["Datv"]], [[0], ["Aft"]], [[7], ["mPunc"]], [[7], ["eCoo"]], [[0], ["Aft"]], [[0], ["Aft"]], [[13], ["Desc"]], [[0], ["Aft"]], [[0], ["Aft"]], [[0], ["Aft"]], [[0], ["Aft"]], [[18], ["mTime"]], [[0], ["Aft"]], [[18], ["mPunc"]]]
- ]
-}
+import hanlp
+HanLP = hanlp.load(hanlp.pretrained.mtl.CLOSE_TOK_POS_NER_SRL_DEP_SDP_CON_ELECTRA_SMALL_ZH)
+HanLP(['In 2021, HanLPv2.1 delivers state-of-the-art multilingual NLP techniques to production environment.',
+ '2021年、HanLPv2.1は次世代の最先端多言語NLP技術を本番環境に導入します。',
+ '2021年 HanLPv2.1为生产环境带来次世代最先进的多语种NLP技术。'])
```
-The output is a json `dict`, which most people are familiar with.
-
-- Feel free to add more pre/post-processing to the pipeline, including cleaning, custom dictionary etc.
-- Use `pipeline.save('zh.json')` to save your pipeline and deploy it to your production server.
+In particular, the Python `HanLPClient` can also be used as a callable function following the same semantics. See [docs](https://hanlp.hankcs.com/docs/) for more details.
## Train Your Own Models
-To write DL models is not hard, the real hard thing is to write a model able to reproduce the score in papers. The snippet below shows how to train a 97% F1 cws model on MSR corpus.
+To write DL models is not hard, the real hard thing is to write a model able to reproduce the scores in papers. The snippet below shows how to surpass the state-of-the-art tokenizer in 9 minutes.
```python
-tokenizer = NgramConvTokenizer()
-save_dir = 'data/model/cws/convseg-msr-nocrf-noembed'
-tokenizer.fit(SIGHAN2005_MSR_TRAIN,
- SIGHAN2005_MSR_VALID,
- save_dir,
- word_embed={'class_name': 'HanLP>Word2VecEmbedding',
- 'config': {
- 'trainable': True,
- 'filepath': CONVSEG_W2V_NEWS_TENSITE_CHAR,
- 'expand_vocab': False,
- 'lowercase': False,
- }},
- optimizer=tf.keras.optimizers.Adam(learning_rate=0.001,
- epsilon=1e-8, clipnorm=5),
- epochs=100,
- window_size=0,
- metrics='f1',
- weight_norm=True)
-tokenizer.evaluate(SIGHAN2005_MSR_TEST, save_dir=save_dir)
-```
-
-The training and evaluation logs are as follows.
-
-```
-Train for 783 steps, validate for 87 steps
-Epoch 1/100
-783/783 [==============================] - 177s 226ms/step - loss: 15.6354 - f1: 0.8506 - val_loss: 9.9109 - val_f1: 0.9081
-Epoch 2/100
-236/783 [========>.....................] - ETA: 1:41 - loss: 9.0359 - f1: 0.9126
-...
-19-12-28 20:55:59 INFO Trained 100 epochs in 3 h 55 m 42 s, each epoch takes 2 m 21 s
-19-12-28 20:56:06 INFO Evaluation results for msr_test_gold.utf8 - loss: 3.6579 - f1: 0.9715 - speed: 1173.80 sample/sec
+tokenizer = TransformerTaggingTokenizer()
+save_dir = 'data/model/cws/sighan2005_pku_bert_base_96.61'
+tokenizer.fit(
+ SIGHAN2005_PKU_TRAIN_ALL,
+ SIGHAN2005_PKU_TEST, # Conventionally, no devset is used. See Tian et al. (2020).
+ save_dir,
+ 'bert-base-chinese',
+ max_seq_len=300,
+ char_level=True,
+ hard_constraint=True,
+ sampler_builder=SortingSamplerBuilder(batch_size=32),
+ epochs=3,
+ adam_epsilon=1e-6,
+ warmup_steps=0.1,
+ weight_decay=0.01,
+ word_dropout=0.1,
+ seed=1609422632,
+)
+tokenizer.evaluate(SIGHAN2005_PKU_TEST, save_dir)
```
-Similarly, you can train a sentiment classifier to classify the comments of hotels.
+The result is guaranteed to be `96.66` as the random feed is fixed. Different from some overclaining papers and projects, HanLP promises every digit in our scores are reproducible. Any issues on reproducibility will be treated and solved as a top-priority fatal bug.
-```python
-save_dir = 'data/model/classification/chnsenticorp_bert_base'
-classifier = TransformerClassifier(TransformerTextTransform(y_column=0))
-classifier.fit(CHNSENTICORP_ERNIE_TRAIN, CHNSENTICORP_ERNIE_VALID, save_dir,
- transformer='chinese_L-12_H-768_A-12')
-classifier.load(save_dir)
-print(classifier('前台客房服务态度非常好!早餐很丰富,房价很干净。再接再厉!'))
-classifier.evaluate(CHNSENTICORP_ERNIE_TEST, save_dir=save_dir)
-```
+## Performance
-Due to the size of models, and the fact that corpora are domain specific, HanLP has limited plan to distribute pretrained text classification models.
+
lang corpora model tok pos ner dep con srl sdp lem fea amr fine coarse ctb pku 863 ud pku msra ontonotes SemEval16 DM PAS PSD mul UD2.7 OntoNotes5 small 98.30 - - - - 91.72 - - 74.86 74.66 74.29 65.73 - 88.52 92.56 83.84 84.65 81.13 - base 99.59 - - - - 95.95 - - 80.31 85.84 80.22 74.61 - 93.23 95.16 86.57 92.91 90.30 - zh open small 97.25 - 96.66 - - - - - 95.00 84.57 87.62 73.40 84.57 - - - - - - base 97.50 - 97.07 - - - - - 96.04 87.11 89.84 77.78 87.11 - - - - - - close small 96.70 95.93 96.87 97.56 95.05 - 96.22 95.74 76.79 84.44 88.13 75.81 74.28 - - - - - - base 97.52 96.44 96.99 97.59 95.29 - 96.48 95.72 77.77 85.29 88.57 76.52 73.76 - - - - - -
-For more training scripts, please refer to [`tests/train`](https://github.com/hankcs/HanLP/tree/master/tests/train). We are also working hard to release more examples in [`tests/demo`](https://github.com/hankcs/HanLP/tree/master/tests/demo). Serving, documentations and more pretrained models are on the way too.
+- Multilingual models are temporary ones which will be replaced in one week.
+- AMR models will be released once our paper gets accepted.
## Citing
@@ -336,15 +121,15 @@ If you use HanLP in your research, please cite this repository.
## License
-HanLP is licensed under **Apache License 2.0**. You can use HanLP in your commercial products for free. We would appreciate it if you add a link to HanLP on your website.
+### Codes
-## References
+HanLP is licensed under **Apache License 2.0**. You can use HanLP in your commercial products for free. We would appreciate it if you add a link to HanLP on your website.
-[^fasttext]: A. Joulin, E. Grave, P. Bojanowski, and T. Mikolov, “Bag of Tricks for Efficient Text Classification,” vol. cs.CL. 07-Jul-2016.
+### Models
-[^bert]: J. Devlin, M.-W. Chang, K. Lee, and K. Toutanova, “BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding,” arXiv.org, vol. cs.CL. 10-Oct-2018.bert
+Unless specified, all models in HanLP are licensed under [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/) .
-[^biaffine]: T. Dozat and C. D. Manning, “Deep Biaffine Attention for Neural Dependency Parsing.,” ICLR, 2017.
+## References
-[^conllx]: Buchholz, S., & Marsi, E. (2006, June). CoNLL-X shared task on multilingual dependency parsing. In *Proceedings of the tenth conference on computational natural language learning* (pp. 149-164). Association for Computational Linguistics.
+https://hanlp.hankcs.com/docs/references.html
diff --git a/docs/Makefile b/docs/Makefile
new file mode 100644
index 000000000..d4bb2cbb9
--- /dev/null
+++ b/docs/Makefile
@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS ?=
+SPHINXBUILD ?= sphinx-build
+SOURCEDIR = .
+BUILDDIR = _build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/docs/annotations/con/ctb.md b/docs/annotations/con/ctb.md
new file mode 100644
index 000000000..94af82667
--- /dev/null
+++ b/docs/annotations/con/ctb.md
@@ -0,0 +1,54 @@
+
+
+# Chinese Tree Bank
+
+See also [The Bracketing Guidelines for the Penn Chinese Treebank (3.0)](https://repository.upenn.edu/cgi/viewcontent.cgi?article=1040&context=ircs_reports).
+
+| Tag | Definition | 定义 | 例子 |
+|------|----------------------------------------------|----------------------------------------------------|-------------------|
+| ADJP | adjective phrase | 形容词短语,以形容词为中心词 | 不完全、大型 |
+| ADVP | adverbial phrase headed by AD (adverb) | 副词短语,以副词为中心词 | 非常、很 |
+| CLP | classifier phrase | 由量词构成的短语 | 系列、大批 |
+| CP | clause headed by C (complementizer) | 从句,通过带补语(如“的”、“吗”等) | 张三喜欢李四吗? |
+| DNP | phrase formed by ‘‘XP + DEG’’ | 结构为XP + DEG(的)的短语,其中XP可以是ADJP、DP、QP、PP等等,用于修饰名词短语。 | 大型的、前几年的、五年的、在上海的 |
+| DP | determiner phrase | 限定词短语,通常由限定词和数量词构成 | 这三个、任何 |
+| DVP | phrase formed by ‘‘XP + DEV’’ | 结构为XP+地的短评,用于修饰动词短语VP | 心情失落地、大批地 |
+| FRAG | fragment | 片段 | (完) |
+| INTJ | interjection | 插话,感叹语 | 哈哈、切 |
+| IP | simple clause headed by I (INFL) | 简单子句或句子,通常不带补语(如“的”、“吗”等) | 张三喜欢李四。 |
+| LCP | phrase formed by ‘‘XP + LC’’ | 用于表本地点+方位词(LC)的短语 | 生活中、田野上 |
+| LST | list marker | 列表短语,包括标点符号 | 一. |
+| MSP | some particles | 其他小品词 | 所、而、来、去 |
+| NN | common noun | 名词 | HanLP、技术 |
+| NP | noun phrase | 名词短语,中心词通常为名词 | 美好生活、经济水平 |
+| PP | preposition phrase | 介词短语,中心词通常为介词 | 在北京、据报道 |
+| PRN | parenthetical | 插入语 | ,(张三说), |
+| QP | quantifier phrase | 量词短语 | 三个、五百辆 |
+| ROOT | root node | 根节点 | 根节点 |
+| UCP | unidentical coordination phrase | 不对称的并列短语,指并列词两侧的短语类型不致 | (养老、医疗)保险 |
+| VCD | coordinated verb compound | 复合动词 | 出版发行 |
+| VCP | verb compounds formed by VV + VC | VV + VC形式的动词短语 | 看作是 |
+| VNV | verb compounds formed by A-not-A or A-one-A | V不V形式的动词短语 | 能不能、信不信 |
+| VP | verb phrase | 动词短语,中心词通常为动词 | 完成任务、努力工作 |
+| VPT | potential form V-de-R or V-bu-R | V不R、V得R形式的动词短语 | 打不赢、打得过 |
+| VRD | verb resultative compound | 动补结构短语 | 研制成功、降下来 |
+| VSB | verb compounds formed by a modifier + a head | 修饰语+中心词构成的动词短语 | 拿来支付、仰头望去 |
\ No newline at end of file
diff --git a/docs/annotations/con/index.md b/docs/annotations/con/index.md
new file mode 100644
index 000000000..156aa9d05
--- /dev/null
+++ b/docs/annotations/con/index.md
@@ -0,0 +1,7 @@
+# Constituency Parsing
+
+```{toctree}
+ctb
+ptb
+```
+
diff --git a/docs/annotations/con/ptb.md b/docs/annotations/con/ptb.md
new file mode 100644
index 000000000..70addaac5
--- /dev/null
+++ b/docs/annotations/con/ptb.md
@@ -0,0 +1,53 @@
+
+
+# Penn Treebank
+
+| Tag | Description |
+|--------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| ADJP | Adjective Phrase. |
+| ADVP | Adverb Phrase. |
+| CONJP | Conjunction Phrase. |
+| FRAG | Fragment. |
+| INTJ | Interjection. Corresponds approximately to the part-of-speech tag UH. |
+| LST | List marker. Includes surrounding punctuation. |
+| NAC | Not a Constituent; used to show the scope of certain prenominal modifiers within an NP. |
+| NP | Noun Phrase. |
+| NX | - Used within certain complex NPs to mark the head of the NP. Corresponds very roughly to N-bar level but used quite differently. |
+| PP | Prepositional Phrase. |
+| PRN | Parenthetical |
+| PRT | Particle. Category for words that should be tagged RP. |
+| QP | Quantifier Phrase (i.e. complex measure/amount phrase); used within NP. |
+| ROOT | No description |
+| RRC | Reduced Relative Clause. |
+| S | conjunction or a wh-word and that does not exhibit subject-verb inversion. |
+| SBAR | Clause introduced by a (possibly empty) subordinating conjunction. |
+| SBARQ | - Direct question introduced by a wh-word or a wh-phrase. Indirect questions and relative clauses should be bracketed as SBAR, not SBARQ. |
+| SINV | - Inverted declarative sentence, i.e. one in which the subject follows the tensed verb or modal. |
+| SQ | Inverted yes/no question, or main clause of a wh-question, following the wh-phrase in SBARQ. |
+| UCP | Unlike Coordinated Phrase. |
+| VP | Vereb Phrase. |
+| WHADJP | Wh-adjective Phrase. Adjectival phrase containing a wh-adverb, as in how hot. |
+| WHADVP | - Wh-adverb Phrase. Introduces a clause with an NP gap. May be null (containing the 0 complementizer) or lexical, containing a wh-adverb such as how or why. |
+| WHNP | - Wh-noun Phrase. Introduces a clause with an NP gap. May be null (containing the 0 complementizer) or lexical, containing some wh-word, e.g. who, which book, whose daughter, none of which, or how many leopards. |
+| WHPP | - Wh-prepositional Phrase. Prepositional phrase containing a wh-noun phrase (such as of which or by whose authority) that either introduces a PP gap or is contained by a WHNP. |
+| X | - Unknown, uncertain, or unbracketable. X is often used for bracketing typos and in bracketing the…the-constructions. |
+
diff --git a/docs/annotations/dep/index.md b/docs/annotations/dep/index.md
new file mode 100644
index 000000000..f6619e33a
--- /dev/null
+++ b/docs/annotations/dep/index.md
@@ -0,0 +1,7 @@
+# Dependency Parsing
+
+```{toctree}
+sd
+ud
+```
+
diff --git a/docs/annotations/dep/sd.md b/docs/annotations/dep/sd.md
new file mode 100644
index 000000000..687d72416
--- /dev/null
+++ b/docs/annotations/dep/sd.md
@@ -0,0 +1,155 @@
+
+
+# Stanford Dependencies
+
+See also [Stanford typed dependencies manual](https://nlp.stanford.edu/software/dependencies_manual.pdf).
+
+## English
+
+| Tag | Description |
+|------------|-----------------------------------|
+| abbrev | abbreviation modifier |
+| acomp | adjectival complement |
+| advcl | adverbial clause modifier |
+| advmod | adverbial modifier |
+| agent | agent |
+| amod | adjectival modifier |
+| appos | appositional modifier |
+| arg | argument |
+| attr | attributive |
+| aux | auxiliary |
+| auxpass | passive auxiliary |
+| cc | coordination |
+| ccomp | clausal complement |
+| comp | complement |
+| complm | complementizer |
+| conj | conjunct |
+| cop | copula |
+| csubj | clausal subject |
+| csubjpass | clausal passive subject |
+| dep | dependent |
+| det | determiner |
+| discourse | discourse element |
+| dobj | direct object |
+| expl | expletive |
+| goeswith | goes with |
+| iobj | indirect object |
+| mark | marker |
+| mod | modifier |
+| mwe | multi-word expression |
+| neg | negation modifier |
+| nn | noun compound modifier |
+| npadvmod | noun phrase as adverbial modifier |
+| nsubj | nominal subject |
+| nsubjpass | passive nominal subject |
+| num | numeric modifier |
+| number | element of compound number |
+| obj | object |
+| parataxis | parataxis |
+| pcomp | prepositional complement |
+| pobj | object of a preposition |
+| poss | possession modifier |
+| possessive | possessive modifier |
+| preconj | preconjunct |
+| pred | predicate |
+| predet | predeterminer |
+| prep | prepositional modifier |
+| prepc | prepositional clausal modifier |
+| prt | phrasal verb particle |
+| punct | punctuation |
+| purpcl | purpose clause modifier |
+| quantmod | quantifier phrase modifier |
+| rcmod | relative clause modifier |
+| ref | referent |
+| rel | relative |
+| root | root |
+| sdep | semantic dependent |
+| subj | subject |
+| tmod | temporal modifier |
+| vmod | verb modifier |
+| xcomp | open clausal complement |
+| xsubj | controlling subject |
+
+## Chinese
+
+| Tag | Description |
+|------------|-----------------------------------|
+| abbrev | abbreviation modifier |
+| acomp | adjectival complement |
+| advcl | adverbial clause modifier |
+| advmod | adverbial modifier |
+| agent | agent |
+| amod | adjectival modifier |
+| appos | appositional modifier |
+| arg | argument |
+| attr | attributive |
+| aux | auxiliary |
+| auxpass | passive auxiliary |
+| cc | coordination |
+| ccomp | clausal complement |
+| comp | complement |
+| complm | complementizer |
+| conj | conjunct |
+| cop | copula |
+| csubj | clausal subject |
+| csubjpass | clausal passive subject |
+| dep | dependent |
+| det | determiner |
+| discourse | discourse element |
+| dobj | direct object |
+| expl | expletive |
+| goeswith | goes with |
+| iobj | indirect object |
+| mark | marker |
+| mod | modifier |
+| mwe | multi-word expression |
+| neg | negation modifier |
+| nn | noun compound modifier |
+| npadvmod | noun phrase as adverbial modifier |
+| nsubj | nominal subject |
+| nsubjpass | passive nominal subject |
+| num | numeric modifier |
+| number | element of compound number |
+| obj | object |
+| parataxis | parataxis |
+| pcomp | prepositional complement |
+| pobj | object of a preposition |
+| poss | possession modifier |
+| possessive | possessive modifier |
+| preconj | preconjunct |
+| pred | predicate |
+| predet | predeterminer |
+| prep | prepositional modifier |
+| prepc | prepositional clausal modifier |
+| prt | phrasal verb particle |
+| punct | punctuation |
+| purpcl | purpose clause modifier |
+| quantmod | quantifier phrase modifier |
+| rcmod | relative clause modifier |
+| ref | referent |
+| rel | relative |
+| root | root |
+| sdep | semantic dependent |
+| subj | subject |
+| tmod | temporal modifier |
+| vmod | verb modifier |
+| xcomp | open clausal complement |
\ No newline at end of file
diff --git a/docs/annotations/dep/ud.md b/docs/annotations/dep/ud.md
new file mode 100644
index 000000000..0971c9ccf
--- /dev/null
+++ b/docs/annotations/dep/ud.md
@@ -0,0 +1,67 @@
+
+
+# Universal Dependencies
+
+See also [Universal Dependencies](https://universaldependencies.org/docs/u/dep/index.html).
+
+| Tag | Description |
+|------------|----------------------------------------------|
+| acl | clausal modifier of noun (adjectival clause) |
+| advcl | adverbial clause modifier |
+| advmod | adverbial modifier |
+| amod | adjectival modifier |
+| appos | appositional modifier |
+| aux | auxiliary |
+| auxpass | passive auxiliary |
+| case | case marking |
+| cc | coordinating conjunction |
+| ccomp | clausal complement |
+| compound | compound |
+| conj | conjunct |
+| cop | copula |
+| csubj | clausal subject |
+| csubjpass | clausal passive subject |
+| dep | unspecified dependency |
+| det | determiner |
+| discourse | discourse element |
+| dislocated | dislocated elements |
+| dobj | direct object |
+| expl | expletive |
+| foreign | foreign words |
+| goeswith | goes with |
+| iobj | indirect object |
+| list | list |
+| mark | marker |
+| mwe | multi-word expression |
+| name | name |
+| neg | negation modifier |
+| nmod | nominal modifier |
+| nsubj | nominal subject |
+| nsubjpass | passive nominal subject |
+| nummod | numeric modifier |
+| parataxis | parataxis |
+| punct | punctuation |
+| remnant | remnant in ellipsis |
+| reparandum | overridden disfluency |
+| root | root |
+| vocative | vocative |
+| xcomp | open clausal complement |
\ No newline at end of file
diff --git a/docs/annotations/index.md b/docs/annotations/index.md
new file mode 100644
index 000000000..55d1c55cf
--- /dev/null
+++ b/docs/annotations/index.md
@@ -0,0 +1,12 @@
+# Annotations
+
+
+```{toctree}
+pos/index
+ner/index
+dep/index
+sdp/index
+srl/index
+con/index
+```
+
diff --git a/docs/annotations/ner/index.md b/docs/annotations/ner/index.md
new file mode 100644
index 000000000..d81f89dba
--- /dev/null
+++ b/docs/annotations/ner/index.md
@@ -0,0 +1,9 @@
+# Named Entity Recognition
+
+
+```{toctree}
+pku
+msra
+ontonotes
+```
+
diff --git a/docs/annotations/ner/msra.md b/docs/annotations/ner/msra.md
new file mode 100644
index 000000000..ecee90b12
--- /dev/null
+++ b/docs/annotations/ner/msra.md
@@ -0,0 +1,56 @@
+
+
+# msra
+
+| Category | Subcategory | Tag-set of Format-1 | Tag-set of Format-2 |
+|----------|----------------|---------------------|---------------------|
+| NAMEX | Person | P | PERSON |
+| | Location | L | LOCATION |
+| | Organization | 〇 | ORGANIZATION |
+| TIMEX | Date | dat | DATE |
+| | Duration | dur | DURATION |
+| | Time | tim | TIME |
+| NUMEX | Percent | per | PERCENT |
+| | Money | mon | MONEY |
+| | Frequency | fre | FREQUENCY |
+| | Integer | int | INTEGER |
+| | Fraction | fra | FRACTION |
+| | Decimal | dec | DECIMAL |
+| | Ordinal | ord | ORDINAL |
+| | Rate | rat | RATE |
+| MEASUREX | Age | age | AGE |
+| | Weight | wei | WEIGHT |
+| | Length | len | LENGTH |
+| | Temperature | tem | TEMPERATURE |
+| | Angle | ang | ANGLE |
+| | Area | are | AREA |
+| | Capacity | cap | CAPACITY |
+| | Speed | spe | SPEED |
+| | Acceleration | acc | ACCELERATION |
+| | Other measures | mea | MEASURE |
+| ADDREX | Email | ema | EMAIL |
+| | Phone | pho | PHONE |
+| | Fax | fax | FAX |
+| | Telex | tel | TELEX |
+| | WWW | WWW | WWW |
+| | Postalcode | pos | POSTALCODE |
+
diff --git a/docs/annotations/ner/ontonotes.md b/docs/annotations/ner/ontonotes.md
new file mode 100644
index 000000000..d2e0e5d59
--- /dev/null
+++ b/docs/annotations/ner/ontonotes.md
@@ -0,0 +1,43 @@
+
+
+# ontonotes
+
+| TAG | Description |
+|--------------|------------------------------------------------------|
+| PERSON | People, including fictional |
+| NORP | Nationalities or religious or political groups |
+| FACILITY | Buildings, airports, highways, bridges, etc. |
+| ORGANIZATION | Companies, agencies, institutions, etc. |
+| GPE | Countries, cities, states |
+| LOCATION | Non-GPE locations, mountain ranges, bodies of water |
+| PRODUCT | Vehicles, weapons, foods, etc. (Not services) |
+| EVENT | Named hurricanes, battles, wars, sports events, etc. |
+| WORK OF ART | Titles of books, songs, etc. |
+| LAW | Named documents made into laws |
+| DATE | Absolute or relative dates or periods |
+| TIME | Times smaller than a day |
+| PERCENT | Percentage |
+| MONEY | Monetary values, including unit |
+| QUANTITY | Measurements, as of weight or distance |
+| ORDINAL | “first”, “second” |
+| CARDINAL | Numerals that do not fall under another type |
+
diff --git a/docs/annotations/ner/pku.md b/docs/annotations/ner/pku.md
new file mode 100644
index 000000000..c096a2876
--- /dev/null
+++ b/docs/annotations/ner/pku.md
@@ -0,0 +1,28 @@
+
+
+# pku
+
+| 序号 | 词性 | 名称 | 帮助记忆的诠释 | 例子及注解 |
+| ---- | ---- | -------- | ------------------------------------------------------ | ------------------------------------------------------------ |
+| 1 | nr | 人名 | 名词代码n和“人(ren)”的声母并在一起。 | 1. 汉族人及与汉族起名方式相同的非汉族人的姓和名单独切分,并分别标注为nr。张/nr 仁伟/nr, 欧阳/nr 修/nr, 阮/nr 志雄/nr, 朴/nr 贞爱/nr汉族人除有单姓和复姓外,还有双姓,即有的女子出嫁后,在原来的姓上加上丈夫的姓。如:陈方安生。这种情况切分、标注为:陈/nr 方/nr 安生/nr;唐姜氏,切分、标注为:唐/nr 姜氏/nr。2. 姓名后的职务、职称或称呼要分开。江/nr 主席/n, 小平/nr 同志/n, 江/nr 总书记/n,张/nr 教授/n, 王/nr 部长/n, 陈/nr 老总/n, 李/nr 大娘/n, 刘/nr 阿姨/n, 龙/nr 姑姑/n3. 对人的简称、尊称等若为两个字,则合为一个切分单位,并标以nr。老张/nr, 大李/nr, 小郝/nr, 郭老/nr, 陈总/nr4. 明显带排行的亲属称谓要切分开,分不清楚的则不切开。三/m 哥/n, 大婶/n, 大/a 女儿/n, 大哥/n, 小弟/n, 老爸/n5. 一些著名作者的或不易区分姓和名的笔名通常作为一个切分单位。鲁迅/nr, 茅盾/nr, 巴金/nr, 三毛/nr, 琼瑶/nr, 白桦/nr6. 外国人或少数民族的译名(包括日本人的姓名)不予切分,标注为nr。克林顿/nr, 叶利钦/nr, 才旦卓玛/nr, 小林多喜二/nr, 北研二/nr,华盛顿/nr, 爱因斯坦/nr有些西方人的姓名中有小圆点,也不分开。卡尔·马克思/nr |
+| 2 | ns | 地名 | 名词代码n和处所词代码s并在一起。 | 安徽/ns,深圳/ns,杭州/ns,拉萨/ns,哈尔滨/ns, 呼和浩特/ns, 乌鲁木齐/ns,长江/ns,黄海/ns,太平洋/ns, 泰山/ns, 华山/ns,亚洲/ns, 海南岛/ns,太湖/ns,白洋淀/ns, 俄罗斯/ns,哈萨克斯坦/ns,彼得堡/ns, 伏尔加格勒/ns 1. 国名不论长短,作为一个切分单位。中国/ns, 中华人民共和国/ns, 日本国/ns, 美利坚合众国/ns, 美国/ns2. 地名后有“省”、“市”、“县”、“区”、“乡”、“镇”、“村”、“旗”、“州”、“都”、“府”、“道”等单字的行政区划名称时,不切分开,作为一个切分单位。四川省/ns, 天津市/ns,景德镇/ns沙市市/ns, 牡丹江市/ns,正定县/ns,海淀区/ns, 通州区/ns,东升乡/ns, 双桥镇/ns 南化村/ns,华盛顿州/ns,俄亥俄州/ns,东京都/ns, 大阪府/ns,北海道/ns, 长野县/ns,开封府/ns,宣城县/ns3. 地名后的行政区划有两个以上的汉字,则将地名同行政区划名称切开,不过要将地名同行政区划名称用方括号括起来,并标以短语NS。[芜湖/ns 专区/n] NS,[宣城/ns 地区/n]ns,[内蒙古/ns 自治区/n]NS,[深圳/ns 特区/n]NS, [厦门/ns 经济/n 特区/n]NS, [香港/ns 特别/a 行政区/n]NS,[香港/ns 特区/n]NS, [华盛顿/ns 特区/n]NS,4. 地名后有表示地形地貌的一个字的普通名词,如“江、河、山、洋、海、岛、峰、湖”等,不予切分。鸭绿江/ns,亚马逊河/ns, 喜马拉雅山/ns, 珠穆朗玛峰/ns,地中海/ns,大西洋/ns,洞庭湖/ns, 塞普路斯岛/ns 5. 地名后接的表示地形地貌的普通名词若有两个以上汉字,则应切开。然后将地名同该普通名词标成短语NS。[台湾/ns 海峡/n]NS,[华北/ns 平原/n]NS,[帕米尔/ns 高原/n]NS, [南沙/ns 群岛/n]NS,[京东/ns 大/a 峡谷/n]NS [横断/b 山脉/n]NS6.地名后有表示自然区划的一个字的普通名词,如“ 街,路,道,巷,里,町,庄,村,弄,堡”等,不予切分。 中关村/ns,长安街/ns,学院路/ns, 景德镇/ns, 吴家堡/ns, 庞各庄/ns, 三元里/ns,彼得堡/ns, 北菜市巷/ns, 7.地名后接的表示自然区划的普通名词若有两个以上汉字,则应切开。然后将地名同自然区划名词标成短语NS。[米市/ns 大街/n]NS, [蒋家/nz 胡同/n]NS , [陶然亭/ns 公园/n]NS , 8. 大小地名相连时的标注方式为:北京市/ns 海淀区/ns 海淀镇/ns [南/f 大街/n]NS [蒋家/nz 胡同/n]NS 24/m 号/q , |
+| 3 | nt | 机构团体 | “团”的声母为t,名词代码n和t并在一起。 | (参见2。短语标记说明--NT)联合国/nt,中共中央/nt,国务院/nt, 北京大学/nt1.大多数团体、机构、组织的专有名称一般是短语型的,较长,且含有地名或人名等专名,再组合,标注为短语NT。[中国/ns 计算机/n 学会/n]NT, [香港/ns 钟表业/n 总会/n]NT, [烟台/ns 大学/n]NT, [香港/ns 理工大学/n]NT, [华东/ns 理工大学/n]NT,[合肥/ns 师范/n 学院/n]NT, [北京/ns 图书馆/n]NT, [富士通/nz 株式会社/n]NT, [香山/ns 植物园/n]NT, [安娜/nz 美容院/n]NT,[上海/ns 手表/n 厂/n]NT, [永和/nz 烧饼铺/n]NT,[北京/ns 国安/nz 队/n]NT,2. 对于在国际或中国范围内的知名的唯一的团体、机构、组织的名称即使前面没有专名,也标为nt或NT。联合国/nt,国务院/nt,外交部/nt, 财政部/nt,教育部/nt, 国防部/nt,[世界/n 贸易/n 组织/n]NT, [国家/n 教育/vn 委员会/n]NT,[信息/n 产业/n 部/n]NT,[全国/n 信息/n 技术/n 标准化/vn 委员会/n]NT,[全国/n 总/b 工会/n]NT,[全国/n 人民/n 代表/n 大会/n]NT,美国的“国务院”,其他国家的“外交部、财政部、教育部”,必须在其所属国的国名之后出现时,才联合标注为NT。[美国/ns 国务院/n]NT,[法国/ns 外交部/n]NT,[美/j 国会/n]NT,日本有些政府机构名称很特别,无论是否出现在“日本”国名之后都标为nt。[日本/ns 外务省/nt]NT,[日/j 通产省/nt]NT通产省/nt 3. 前后相连有上下位关系的团体机构组织名称的处理方式如下:[联合国/nt 教科文/j 组织/n]NT, [中国/ns 银行/n 北京/ns 分行/n]NT,[河北省/ns 正定县/ns 西平乐乡/ns 南化村/ns 党支部/n]NT, 当下位名称含有专名(如“北京/ns 分行/n”、“南化村/ns 党支部/n”、“昌平/ns 分校/n”)时,也可脱离前面的上位名称单独标注为NT。[中国/ns 银行/n]NT [北京/ns 分行/n]NT,北京大学/nt [昌平/ns 分校/n]NT,4. 团体、机构、组织名称中用圆括号加注简称时:[宝山/ns 钢铁/n (/w 宝钢/j )/w 总/b 公司/n]NT,[宝山/ns 钢铁/n 总/b 公司/n]NT,(/w 宝钢/j )/w |
\ No newline at end of file
diff --git a/docs/annotations/pos/863.md b/docs/annotations/pos/863.md
new file mode 100644
index 000000000..8c500b924
--- /dev/null
+++ b/docs/annotations/pos/863.md
@@ -0,0 +1,55 @@
+
+
+# 863
+
+| 词性 | 名称 | 说明 | 例子 |
+| ---- | -------------- | ------------------------------------------------ | ------------------------------------------------------------ |
+| ng | 普通名词 | 普通名词(ng),表示事物的名称 | 人 马 书 教师 飞机 电冰箱 阿姨 桌子 木头道德 理论 历史 思想 文化 因素 作风 哲学 |
+| nt | 时间名词 | 时间名词(nt),包括一般所说的时量词 | 年 月 日 分 秒现在 过去 昨天 去年 将来 宋朝 星期一 |
+| nd | 方位名词 | 方位名词(nd),表示位置的相对方向 | 上 下 左 右 前 后 里 外 中 东 西 南 北前边 左面 里头 中间 外部 |
+| nl | 处所名词 | 处所名词(nl),表示处所 | 空中 高处 隔壁 门口 附近 边疆 一旁 野外 |
+| nh | 人名 | 人名(nh),表示人的名称的专有名词 | 华罗庚 阿凡提 诸葛亮 司马相如 松赞干布 卡尔·马克思 |
+| ns | 地名 | 地名(ns),表示地理区域名称的专有名词 | 亚洲 大西洋 地中海 阿尔卑斯山 加拿大中国 北京 浙江 景德镇 呼和浩特 中关村 |
+| nn | 族名 | 族名(nn),表示民族或部落名称的专有名词 | 回族 藏族 壮族 蒙古族 维吾尔族 哈萨克族 |
+| ni | 机构名 | 机构名(ni),表示团体、组织、机构名称的专有名词 | 联合国 教育部 北京大学 中国科学院 |
+| nz | 其他专有名词 | 其他专有名词(nz) | 五粮液 宫爆鸡丁 桑塔纳 |
+| vt | 及物动词 | 及物动词(vt),能够带宾语 | 吃 打 擦 洗 喂 借 送 买 捧 提 填喜欢 告诉 接受 羡慕 考虑 调查 同意 发动 |
+| vi | 不及物动词 | 不及物动词(vi),不能够带宾语 | 病 休息 咳嗽 瘫痪 游泳 睡觉 |
+| vl | 联系动词 | 联系动词(vl),表示关系的判断 | 是 |
+| vu | 能愿动词 | 能愿动词(vu),表示可能、意愿 | 能够 能 应该 可以 可能 情愿 愿意 要 |
+| vd | 趋向动词 | 趋向动词(vd),表示趋向 | (走)上 (趴)下 (进)来 (回)去(跑)上来 (掉)下去 (提)起来 (扔)过去 |
+| aq | 性质形容词 | 性质形容词(aq),表示性质 | 好 高 美 大 勇敢 危险 漂亮 干净 伟大 |
+| as | 状态形容词 | 状态形容词(as),表示状态 | 雪白 黢黑 通红 冰凉 绿油油 亮堂堂 白花花 冷冰冰 |
+| in | 名词性习用语 | 名词性习用语(in) | 海市蜃楼 井底之蛙 蛛丝马迹 |
+| iv | 动词性习用语 | 动词性习用语(iv) | 跑龙套 打官腔 吃老本 与时俱进 励精图治 |
+| ia | 形容词性习用语 | 形容词性习用语(ia) | 丰富多彩 艰苦朴素 光明正大 |
+| ic | 连词性习用语 | 连词性习用语(ic) | 总而言之 由此可见 综上所述 |
+| jn | 名词性缩略语 | 名词性缩略语(jn) | 人大 五四 奥运 |
+| jv | 动词性缩略语 | 动词性缩略语(jv) | 调研 离退休 |
+| ja | 形容词性缩略语 | 形容词性缩略语(ja) | 短平快 高精尖 |
+| gn | 名词性语素字 | 名词性语素字(gn) | 民 农 材 |
+| gv | 动词性语素字 | 动词性语素字(gv) | 抒 究 涤 |
+| ga | 形容词性语素字 | 形容词性语素字(ga) | 殊 遥 伟 |
+| wp | 标点符号 | 标点符号(wp),如: | , 。 、 ; ? ! : “” …… |
+| ws | 非汉字字符串 | 非汉字字符串(ws),如: | office windows |
+| wu | 其他未知的符号 | 其他未知的符号(wu) | |
+
diff --git a/docs/annotations/pos/ctb.md b/docs/annotations/pos/ctb.md
new file mode 100644
index 000000000..7c2940b1c
--- /dev/null
+++ b/docs/annotations/pos/ctb.md
@@ -0,0 +1,62 @@
+
+
+# ctb
+
+[The Part-Of-Speech Tagging Guidelines for the Penn Chinese Treebank (3.0)](https://repository.upenn.edu/cgi/viewcontent.cgi?article=1039&context=ircs_reports).
+
+| Tag | Description | Chinese | Chinese Description | Examples |
+|-----|-------------------------------------------------------|---------|---------------------------------------------------------|-------------------------|
+| AD | adverb | 副词 | 副词 | 仍然、很、大大、约 |
+| AS | aspect marker | 动态助词 | 助词 | 了、着、过 |
+| BA | XX in ba-construction | 把字句 | 当“把”、“将”出现在结构“NP0 + BA + NP1+VP”时的词性 | 把、将 |
+| CC | coordinating conjunction | 并列连接词 | 并列连词 | 与、和、或者、还是 |
+| CD | cardinal number | 概数词 | 数词或表达数量的词 | 一百、好些、若干 |
+| CS | subordinating conjunction | 从属连词 | 从属连词 | 如果、那么、就 |
+| DEC | XX in a relative clause | 补语成分“的” | 当“的”或“之”作补语标记或名词化标记时的词性,其结构为:S/VP DEC {NP},如,喜欢旅游的大学生 | 的、之 |
+| DEG | associative XX | 属格“的” | 当“的”或“之”作所有格时的词性,其结构为:NP/PP/JJ/DT DEG {NP}, 如,他的车、经济的发展 | 的、之 |
+| DER | XX in V-de const. and V-de-R | 表结果的“得” | 当“得”出现在结构“V-得-R”时的词性,如,他跑得很快 | 得 |
+| DEV | XX bevore VP | 表方式的“地” | 当“地”出现在结构“X-地-VP”时的词性,如,高兴地说 | 地 |
+| DT | determiner | 限定词 | 代冠词,通常用来修饰名词 | 这、那、该、每、各 |
+| ETC | for words XX, XX XX | 表示省略 | “等”、“等等”的词性 | 绝 咕咕 寸、寸寸 |
+| FW | foreign words | 外来语 | 外来词 | 卡拉、A型 |
+| IJ | interjection | 句首感叹词 | 感叹词,通常出现在句子首部 | 啊 |
+| JJ | other noun-modifier | 其他名词修饰语 | 形容词 | 共同、新 |
+| LB | X in long bei-const | 长句式表被动 | 当“被”、“叫”、“给”出现在结构“NP0 + LB + NP1+ VP”结构时 的词性,如,他被我训了一顿 | 被、叫、给 |
+| LC | localizer | 方位词 | 方位词 | 前、旁、到、在内 |
+| M | measure word | 量词 | 量词 | 个、群、公里 |
+| MSP | other particle | 其他小品词 | 其他虚词,包括“所”、“以”、“来”和“而”等出现在VP前的词 | 所、以、来、而 |
+| NN | common noun | 其他名词 | 除专有名词和时间名词外的所有名词 | 桌子、生活、经济 |
+| NOI | noise that characters are written in the wrong order | 噪声 | 当“把”、“将”出现在结构“NP0 + BA + NP1+VP”时的词性 | 事/NOI 类/NOI 各/NOI 故/NOI |
+| NR | proper noun | 专有名词 | 专有名词,通常表示地名、人名、机构名等 | 北京、乔丹、微软 |
+| NT | temporal noun | 时间名词 | 表本时间概念的名词 | 一月、汉朝、当今 |
+| OD | ordinal number | 序数词 | 序列词 | 第一百 |
+| ON | onomatopoeia | 象声词 | 象声词 | 哗哗、呼、咯吱 |
+| P | preposition excl. XX and XX | 介词 | 介词 | 从、对、根据 |
+| PN | pronoun | 代词 | 代词,通常用来指代名词 | 我、这些、其、自己 |
+| PU | punctuation | 标点符号 | 标点符号 | ?、。、; |
+| SB | XX in short bei-const | 短句式表被动 | 当“被”、“给”出现在NP0 +SB+ VP结果时的词性,如,他被训了 一顿 | 被、叫 |
+| SP | sentence final particle | 句末助词 | 经常出现在句尾的词 | 吧、呢、啊、啊 |
+| URL | web address | 网址 | 网址 | www.hankcs.com |
+| VA | predicative adjective | 表语形容词 | 可以接在“很”后面的形容词谓语 | 雪白、厉害 |
+| VC | XX | 系动词 | 系词,表示“是”或“非”概念的动词 | 是、为、非 |
+| VE | XX as the main verb | 动词有无 | 表示“有”或“无”概念的动词 | 有、没有、无 |
+| VV | other verb | 其他动词 | 其他普通动词,包括情态词、控制动词、动作动词、心理动词等等 | 可能、要、走、喜欢 |
diff --git a/docs/annotations/pos/index.md b/docs/annotations/pos/index.md
new file mode 100644
index 000000000..03b9f07c0
--- /dev/null
+++ b/docs/annotations/pos/index.md
@@ -0,0 +1,10 @@
+# Part-of-Speech Tagging
+
+
+```{toctree}
+ctb
+pku
+863
+ud
+```
+
diff --git a/docs/annotations/pos/pku.md b/docs/annotations/pos/pku.md
new file mode 100644
index 000000000..6127b50c6
--- /dev/null
+++ b/docs/annotations/pos/pku.md
@@ -0,0 +1,68 @@
+
+
+# pku
+
+| 序号 | 词性 | 名称 | 帮助记忆的诠释 | 例子及注解 |
+| ---- | ---- | -------- | ------------------------------------------------------ | ------------------------------------------------------------ |
+| 1 | Ag | 形语素 | 形容词性语素。形容词代码为a,语素代码g前面置以A。 | 绿色/n 似/d 锦/Ag , |
+| 2 | a | 形容词 | 取英语形容词adjective的第1个字母 | [重要/a 步伐/n]NP ,美丽/a ,看似/v 抽象/a , |
+| 3 | ad | 副形词 | 直接作状语的形容词。形容词代码a和副词代码d并在一起。 | [积极/ad 谋求/v]V-ZZ ,幻象/n 易/ad 逝/Vg , |
+| 4 | an | 名形词 | 具有名词功能的形容词。形容词代码a和名词代码n并在一起。 | [外交/n 和/c 安全/an]NP-BL , |
+| 5 | Bg | 区别语素 | 区别词性语素。区别词代码为b,语素代码g前面置以B。 | 赤/Ag 橙/Bg 黄/a 绿/a 青/a 蓝/a 紫/a , |
+| 6 | b | 区别词 | 取汉字“别”的声母。 | 女/b 司机/n, 金/b 手镯/n, 慢性/b 胃炎/n, 古/b 钱币/n, 副/b 主任/n, 总/b 公司/n 单音节区别词和单音节名词或名语素组合,作为一个词,并标以名词词性n。 |
+| 7 | c | 连词 | 取英语连词conjunction的第1个字母。 | 合作/vn 与/c 伙伴/n |
+| 8 | Dg | 副语素 | 副词性语素。副词代码为d,语素代码g前面置以D。 | 了解/v 甚/Dg 深/a ,煞/Dg 是/v 喜人/a , |
+| 9 | d | 副词 | 取adverb的第2个字母,因其第1个字母已用于形容词。 | 进一步/d 发展/v , |
+| 10 | e | 叹词 | 取英语叹词exclamation的第1个字母。 | 啊/e ,/w 那/r 金灿灿/z 的/u 麦穗/n , |
+| 11 | f | 方位词 | 取汉字“方”。 | 军人/n 的/u 眼睛/n 里/f 不/d 是/v 没有/v 风景/n , |
+| 12 | h | 前接成分 | 取英语head的第1个字母。 | 许多/m 非/h 主角/n 人物/n ,办事处/n 的/u “/w 准/h 政府/n ”/w 功能/n 不断/d 加强/v , |
+| 13 | i | 成语 | 取英语成语idiom的第1个字母。 | 一言一行/i ,义无反顾/i , |
+| 14 | j | 简称略语 | 取汉字“简”的声母。 | [德/j 外长/n]NP ,文教/j , |
+| 15 | k | 后接成分 | 后接成分。 | 少年儿童/l 朋友/n 们/k ,身体/n 健康/a 者/k , |
+| 16 | l | 习用语 | 习用语尚未成为成语,有点“临时性”,取“临”的声母。 | 少年儿童/l 朋友/n 们/k ,落到实处/l , |
+| 17 | Mg | 数语素 | 数词性语素。数词代码为m,语素代码g前面置以M。 | 甲/Mg 减下/v 的/u 人/n 让/v 乙/Mg 背上/v ,凡/d “/w 寅/Mg 年/n ”/w 中/f 出生/v 的/u 人/n 生肖/n 都/d 属/v 虎/n , |
+| 18 | m | 数词 | 取英语numeral的第3个字母,n,u已有他用。 | 1.数量词组应切分为数词和量词。 三/m 个/q, 10/m 公斤/q, 一/m 盒/q 点心/n ,但少数数量词已是词典的登录单位,则不再切分。 一个/m , 一些/m ,2. 基数、序数、小数、分数、百分数一律不予切分,为一个切分单位,标注为 m 。一百二十三/m,20万/m, 123.54/m, 一个/m, 第一/m, 第三十五/m, 20%/m, 三分之二/m, 千分之三十/m, 几十/m 人/n, 十几万/m 元/q, 第一百零一/m 个/q ,3. 约数,前加副词、形容词或后加“来、多、左右”等助数词的应予分开。约/d 一百/m 多/m 万/m,仅/d 一百/m 个/q, 四十/m 来/m 个/q,二十/m 余/m 只/q, 十几/m 个/q,三十/m 左右/m ,两个数词相连的及“成百”、“上千”等则不予切分。五六/m 年/q, 七八/m 天/q,十七八/m 岁/q, 成百/m 学生/n,上千/m 人/n, 4.表序关系的“数+名”结构,应予切分。二/m 连/n , 三/m 部/n , |
+| 19 | Ng | 名语素 | 名词性语素。名词代码为n,语素代码g前面置以N。 | 出/v 过/u 两/m 天/q 差/Ng, 理/v 了/u 一/m 次/q 发/Ng, |
+| 20 | n | 名词 | 取英语名词noun的第1个字母。 | (参见 动词--v)岗位/n , 城市/n , 机会/n ,她/r 是/v 责任/n 编辑/n , |
+| 21 | nr | 人名 | 名词代码n和“人(ren)”的声母并在一起。 | 1. 汉族人及与汉族起名方式相同的非汉族人的姓和名单独切分,并分别标注为nr。张/nr 仁伟/nr, 欧阳/nr 修/nr, 阮/nr 志雄/nr, 朴/nr 贞爱/nr汉族人除有单姓和复姓外,还有双姓,即有的女子出嫁后,在原来的姓上加上丈夫的姓。如:陈方安生。这种情况切分、标注为:陈/nr 方/nr 安生/nr;唐姜氏,切分、标注为:唐/nr 姜氏/nr。2. 姓名后的职务、职称或称呼要分开。江/nr 主席/n, 小平/nr 同志/n, 江/nr 总书记/n,张/nr 教授/n, 王/nr 部长/n, 陈/nr 老总/n, 李/nr 大娘/n, 刘/nr 阿姨/n, 龙/nr 姑姑/n3. 对人的简称、尊称等若为两个字,则合为一个切分单位,并标以nr。老张/nr, 大李/nr, 小郝/nr, 郭老/nr, 陈总/nr4. 明显带排行的亲属称谓要切分开,分不清楚的则不切开。三/m 哥/n, 大婶/n, 大/a 女儿/n, 大哥/n, 小弟/n, 老爸/n5. 一些著名作者的或不易区分姓和名的笔名通常作为一个切分单位。鲁迅/nr, 茅盾/nr, 巴金/nr, 三毛/nr, 琼瑶/nr, 白桦/nr6. 外国人或少数民族的译名(包括日本人的姓名)不予切分,标注为nr。克林顿/nr, 叶利钦/nr, 才旦卓玛/nr, 小林多喜二/nr, 北研二/nr,华盛顿/nr, 爱因斯坦/nr有些西方人的姓名中有小圆点,也不分开。卡尔·马克思/nr |
+| 22 | ns | 地名 | 名词代码n和处所词代码s并在一起。 | 安徽/ns,深圳/ns,杭州/ns,拉萨/ns,哈尔滨/ns, 呼和浩特/ns, 乌鲁木齐/ns,长江/ns,黄海/ns,太平洋/ns, 泰山/ns, 华山/ns,亚洲/ns, 海南岛/ns,太湖/ns,白洋淀/ns, 俄罗斯/ns,哈萨克斯坦/ns,彼得堡/ns, 伏尔加格勒/ns 1. 国名不论长短,作为一个切分单位。中国/ns, 中华人民共和国/ns, 日本国/ns, 美利坚合众国/ns, 美国/ns2. 地名后有“省”、“市”、“县”、“区”、“乡”、“镇”、“村”、“旗”、“州”、“都”、“府”、“道”等单字的行政区划名称时,不切分开,作为一个切分单位。四川省/ns, 天津市/ns,景德镇/ns沙市市/ns, 牡丹江市/ns,正定县/ns,海淀区/ns, 通州区/ns,东升乡/ns, 双桥镇/ns 南化村/ns,华盛顿州/ns,俄亥俄州/ns,东京都/ns, 大阪府/ns,北海道/ns, 长野县/ns,开封府/ns,宣城县/ns3. 地名后的行政区划有两个以上的汉字,则将地名同行政区划名称切开,不过要将地名同行政区划名称用方括号括起来,并标以短语NS。[芜湖/ns 专区/n] NS,[宣城/ns 地区/n]ns,[内蒙古/ns 自治区/n]NS,[深圳/ns 特区/n]NS, [厦门/ns 经济/n 特区/n]NS, [香港/ns 特别/a 行政区/n]NS,[香港/ns 特区/n]NS, [华盛顿/ns 特区/n]NS,4. 地名后有表示地形地貌的一个字的普通名词,如“江、河、山、洋、海、岛、峰、湖”等,不予切分。鸭绿江/ns,亚马逊河/ns, 喜马拉雅山/ns, 珠穆朗玛峰/ns,地中海/ns,大西洋/ns,洞庭湖/ns, 塞普路斯岛/ns 5. 地名后接的表示地形地貌的普通名词若有两个以上汉字,则应切开。然后将地名同该普通名词标成短语NS。[台湾/ns 海峡/n]NS,[华北/ns 平原/n]NS,[帕米尔/ns 高原/n]NS, [南沙/ns 群岛/n]NS,[京东/ns 大/a 峡谷/n]NS [横断/b 山脉/n]NS6.地名后有表示自然区划的一个字的普通名词,如“ 街,路,道,巷,里,町,庄,村,弄,堡”等,不予切分。 中关村/ns,长安街/ns,学院路/ns, 景德镇/ns, 吴家堡/ns, 庞各庄/ns, 三元里/ns,彼得堡/ns, 北菜市巷/ns, 7.地名后接的表示自然区划的普通名词若有两个以上汉字,则应切开。然后将地名同自然区划名词标成短语NS。[米市/ns 大街/n]NS, [蒋家/nz 胡同/n]NS , [陶然亭/ns 公园/n]NS , 8. 大小地名相连时的标注方式为:北京市/ns 海淀区/ns 海淀镇/ns [南/f 大街/n]NS [蒋家/nz 胡同/n]NS 24/m 号/q , |
+| 23 | nt | 机构团体 | “团”的声母为t,名词代码n和t并在一起。 | (参见2。短语标记说明--NT)联合国/nt,中共中央/nt,国务院/nt, 北京大学/nt1.大多数团体、机构、组织的专有名称一般是短语型的,较长,且含有地名或人名等专名,再组合,标注为短语NT。[中国/ns 计算机/n 学会/n]NT, [香港/ns 钟表业/n 总会/n]NT, [烟台/ns 大学/n]NT, [香港/ns 理工大学/n]NT, [华东/ns 理工大学/n]NT,[合肥/ns 师范/n 学院/n]NT, [北京/ns 图书馆/n]NT, [富士通/nz 株式会社/n]NT, [香山/ns 植物园/n]NT, [安娜/nz 美容院/n]NT,[上海/ns 手表/n 厂/n]NT, [永和/nz 烧饼铺/n]NT,[北京/ns 国安/nz 队/n]NT,2. 对于在国际或中国范围内的知名的唯一的团体、机构、组织的名称即使前面没有专名,也标为nt或NT。联合国/nt,国务院/nt,外交部/nt, 财政部/nt,教育部/nt, 国防部/nt,[世界/n 贸易/n 组织/n]NT, [国家/n 教育/vn 委员会/n]NT,[信息/n 产业/n 部/n]NT,[全国/n 信息/n 技术/n 标准化/vn 委员会/n]NT,[全国/n 总/b 工会/n]NT,[全国/n 人民/n 代表/n 大会/n]NT,美国的“国务院”,其他国家的“外交部、财政部、教育部”,必须在其所属国的国名之后出现时,才联合标注为NT。[美国/ns 国务院/n]NT,[法国/ns 外交部/n]NT,[美/j 国会/n]NT,日本有些政府机构名称很特别,无论是否出现在“日本”国名之后都标为nt。[日本/ns 外务省/nt]NT,[日/j 通产省/nt]NT通产省/nt 3. 前后相连有上下位关系的团体机构组织名称的处理方式如下:[联合国/nt 教科文/j 组织/n]NT, [中国/ns 银行/n 北京/ns 分行/n]NT,[河北省/ns 正定县/ns 西平乐乡/ns 南化村/ns 党支部/n]NT, 当下位名称含有专名(如“北京/ns 分行/n”、“南化村/ns 党支部/n”、“昌平/ns 分校/n”)时,也可脱离前面的上位名称单独标注为NT。[中国/ns 银行/n]NT [北京/ns 分行/n]NT,北京大学/nt [昌平/ns 分校/n]NT,4. 团体、机构、组织名称中用圆括号加注简称时:[宝山/ns 钢铁/n (/w 宝钢/j )/w 总/b 公司/n]NT,[宝山/ns 钢铁/n 总/b 公司/n]NT,(/w 宝钢/j )/w |
+| 24 | nx | 外文字符 | 外文字符。 | A/nx 公司/n ,B/nx 先生/n ,X/nx 君/Ng ,24/m K/nx 镀金/n ,C/nx 是/v 光速/n ,Windows98/nx ,PentiumIV/nx ,I LOVE THIS GAME/nx , |
+| 25 | nz | 其他专名 | “专”的声母的第1个字母为z,名词代码n和z并在一起。 | (参见2。短语标记说明--NZ)除人名、国名、地名、团体、机构、组织以外的其他专有名词都标以nz。满族/nz,俄罗斯族/nz,汉语/nz,罗马利亚语/nz, 捷克语/nz,中文/nz, 英文/nz, 满人/nz, 哈萨克人/nz, 诺贝尔奖/nz, 茅盾奖/nz, 1.包含专有名称(或简称)的交通线,标以nz;短语型的,标为NZ。津浦路/nz, 石太线/nz, [京/j 九/j 铁路/n]NZ, [京/j 津/j 高速/b 公路/n]NZ, 2. 历史上重要事件、运动等专有名称一般是短语型的,按短语型专有名称处理,标以NZ。[卢沟桥/ns 事件/n]NZ, [西安/ns 事变/n]NZ,[五四/t 运动/n]NZ, [明治/nz 维新/n]NZ,[甲午/t 战争/n]NZ,3.专有名称后接多音节的名词,如“语言”、“文学”、“文化”、“方式”、“精神”等,失去专指性,则应分开。欧洲/ns 语言/n, 法国/ns 文学/n, 西方/ns 文化/n, 贝多芬/nr 交响乐/n, 雷锋/nr 精神/n, 美国/ns 方式/n,日本/ns 料理/n, 宋朝/t 古董/n 4. 商标(包括专名及后接的“牌”、“型”等)是专指的,标以nz,但其后所接的商品仍标以普通名词n。康师傅/nr 方便面/n, 中华牌/nz 香烟/n, 牡丹III型/nz 电视机/n, 联想/nz 电脑/n, 鳄鱼/nz 衬衣/n, 耐克/nz 鞋/n5. 以序号命名的名称一般不认为是专有名称。2/m 号/q 国道/n ,十一/m 届/q 三中全会/j如果前面有专名,合起来作为短语型专名。[中国/ns 101/m 国道/n]NZ, [中共/j 十一/m 届/q 三中全会/j]NZ,6. 书、报、杂志、文档、报告、协议、合同等的名称通常有书名号加以标识,不作为专有名词。由于这些名字往往较长,名字本身按常规处理。《/w 宁波/ns 日报/n 》/w ,《/w 鲁迅/nr 全集/n 》/w,中华/nz 读书/vn 报/n, 杜甫/nr 诗选/n,少数书名、报刊名等专有名称,则不切分。红楼梦/nz, 人民日报/nz,儒林外史/nz 7. 当有些专名无法分辨它们是人名还是地名或机构名时,暂标以nz。[巴黎/ns 贝尔希/nz 体育馆/n]NT,其中“贝尔希”只好暂标为nz。 |
+| 26 | o | 拟声词 | 取英语拟声词onomatopoeia的第1个字母。 | 哈哈/o 一/m 笑/v ,装载机/n 隆隆/o 推进/v , |
+| 27 | p | 介词 | 取英语介词prepositional的第1个字母。 | 对/p 子孙后代/n 负责/v ,以/p 煤/n 养/v 农/Ng ,为/p 治理/v 荒山/n 服务/v , 把/p 青年/n 推/v 上/v 了/u 领导/vn 岗位/n , |
+| 28 | q | 量词 | 取英语quantity的第1个字母。 | (参见数词m)首/m 批/q ,一/m 年/q , |
+| 29 | Rg | 代语素 | 代词性语素。代词代码为r,在语素的代码g前面置以R。 | 读者/n 就/d 是/v 这/r 两/m 棵/q 小树/n 扎根/v 于/p 斯/Rg 、/w 成长/v 于/p 斯/Rg 的/u 肥田/n 沃土/n , |
+| 30 | r | 代词 | 取英语代词pronoun的第2个字母,因p已用于介词。 | 单音节代词“本”、“每”、“各”、“诸”后接单音节名词时,和后接的单音节名词合为代词;当后接双音节名词时,应予切分。本报/r, 每人/r, 本社/r, 本/r 地区/n, 各/r 部门/n |
+| 31 | s | 处所词 | 取英语space的第1个字母。 | 家里/s 的/u 电脑/n 都/d 联通/v 了/u 国际/n 互联网/n ,西部/s 交通/n 咽喉/n , |
+| 32 | Tg | 时语素 | 时间词性语素。时间词代码为t,在语素的代码g前面置以T。 | 3日/t 晚/Tg 在/p 总统府/n 发表/v 声明/n ,尊重/v 现/Tg 执政/vn 当局/n 的/u 权威/n , |
+| 33 | t | 时间词 | 取英语time的第1个字母。 | 1. 年月日时分秒,按年、月、日、时、分、秒切分,标注为t 。1997年/t 3月/t 19日/t 下午/t 2时/t 18分/t若数字后无表示时间的“年、月、日、时、分、秒”等的标为数词m。1998/m 中文/n 信息/n 处理/vn 国际/n 会议/n 2. 历史朝代的名称虽然有专有名词的性质,仍标注为t。西周/t, 秦朝/t, 东汉/t, 南北朝/t, 清代/t“牛年、虎年”等一律不予切分,标注为:牛年/t, 虎年/t, 甲午年/t, 甲午/t 战争/n, 庚子/t 赔款/n, 戊戌/t 变法/n |
+| 34 | u | 助词 | 取英语助词auxiliary。 | [[俄罗斯/ns 和/c 北约/j]NP-BL 之间/f [战略/n 伙伴/n 关系/n]NP 的/u 建立/vn]NP 填平/v 了/u [[欧洲/ns 安全/a 政治/n]NP 的/u 鸿沟/n]NP |
+| 35 | Vg | 动语素 | 动词性语素。动词代码为v。在语素的代码g前面置以V。 | 洗/v 了/u 一个/m 舒舒服服/z 的/u 澡/Vg |
+| 36 | v | 动词 | 取英语动词verb的第一个字母。 | (参见 名词--n)[[[欧盟/j 扩大/v]S 的/u [历史性/n 决定/n]NP]NP 和/c [北约/j 开放/v]S]NP-BL [为/p [创建/v [一/m 种/q 新/a 的/u 欧洲/ns 安全/a 格局/n]NP]VP-SBI]PP-MD [奠定/v 了/u 基础/n]V-SBI ,, |
+| 37 | vd | 副动词 | 直接作状语的动词。动词和副词的代码并在一起。 | 形势/n 会/v 持续/vd 好转/v ,认为/v 是/v 电话局/n 收/v 错/vd 了/u 费/n , |
+| 38 | vn | 名动词 | 指具有名词功能的动词。动词和名词的代码并在一起。 | 引起/v 人们/n 的/u 关注/vn 和/c 思考/vn ,收费/vn 电话/n 的/u 号码/n , |
+| 39 | w | 标点符号 | | ”/w :/w |
+| 40 | x | 非语素字 | 非语素字只是一个符号,字母x通常用于代表未知数、符号。 | |
+| 41 | Yg | 语气语素 | 语气词性语素。语气词代码为y。在语素的代码g前面置以Y。 | 唯/d 大力/d 者/k 能/v 致/v 之/u 耳/Yg |
+| 42 | y | 语气词 | 取汉字“语”的声母。 | 会/v 泄露/v 用户/n 隐私/n 吗/y ,又/d 何在/v 呢/y ? |
+| 43 | z | 状态词 | 取汉字“状”的声母的前一个字母。 | 取得/v 扎扎实实/z 的/u 突破性/n 进展/vn ,四季/n 常青/z 的/u 热带/n 树木/n ,短短/z 几/m 年/q 间, |
\ No newline at end of file
diff --git a/docs/annotations/pos/ud.md b/docs/annotations/pos/ud.md
new file mode 100644
index 000000000..3df741381
--- /dev/null
+++ b/docs/annotations/pos/ud.md
@@ -0,0 +1,44 @@
+
+
+# Universal Dependencies
+
+See also [Universal Dependencies](https://universaldependencies.org/u/pos/).
+
+| Tag | Description |
+|------------|----------------------------------------------|
+| ADJ | adjective |
+| ADP | adposition |
+| ADV | adverb |
+| AUX | auxiliary |
+| CCONJ | coordinating conjunction |
+| DET | determiner |
+| INTJ | interjection |
+| NOUN | noun |
+| NUM | numeral |
+| PART | particle |
+| PRON | pronoun |
+| PROPN | proper noun |
+| PUNCT | punctuation |
+| SCONJ | subordinating conjunction |
+| SYM | symbol |
+| VERB | verb |
+| X | other |
\ No newline at end of file
diff --git a/docs/annotations/sdp/dm.md b/docs/annotations/sdp/dm.md
new file mode 100644
index 000000000..fd81d19b3
--- /dev/null
+++ b/docs/annotations/sdp/dm.md
@@ -0,0 +1,3 @@
+# The reduction of Minimal Recursion Semantics
+
+Please refer to [Minimal Recursion Semantics An Introduction](https://www.cl.cam.ac.uk/~aac10/papers/mrs.pdf).
diff --git a/docs/annotations/sdp/index.md b/docs/annotations/sdp/index.md
new file mode 100644
index 000000000..80a316819
--- /dev/null
+++ b/docs/annotations/sdp/index.md
@@ -0,0 +1,9 @@
+# Semantic Dependency Parsing
+
+```{toctree}
+dm
+pas
+psd
+semeval16
+```
+
diff --git a/docs/annotations/sdp/pas.md b/docs/annotations/sdp/pas.md
new file mode 100644
index 000000000..d07c88399
--- /dev/null
+++ b/docs/annotations/sdp/pas.md
@@ -0,0 +1,3 @@
+# Predicate-Argument Structures
+
+Please refer to [Probabilistic disambiguation models for wide-coverage HPSG parsing](https://www.aclweb.org/anthology/P05-1011.pdf).
diff --git a/docs/annotations/sdp/psd.md b/docs/annotations/sdp/psd.md
new file mode 100644
index 000000000..84b2271b3
--- /dev/null
+++ b/docs/annotations/sdp/psd.md
@@ -0,0 +1,3 @@
+# Prague Czech-English Dependency Treebank
+
+Please refer to [Prague Czech-English Dependency Treebank](http://ufal.mff.cuni.cz/pcedt2.0/en/index.html).
diff --git a/docs/annotations/sdp/semeval16.md b/docs/annotations/sdp/semeval16.md
new file mode 100644
index 000000000..401071204
--- /dev/null
+++ b/docs/annotations/sdp/semeval16.md
@@ -0,0 +1,98 @@
+
+
+# SemEval2016
+
+See also [SemEval-2016 Task 9](https://www.hankcs.com/nlp/sdp-corpus.html).
+
+| 关系类型 | Tag | Description | Example |
+|--------|---------------|--------------------|-----------------------------|
+| 施事关系 | Agt | Agent | 我送她一束花 (我 <– 送) |
+| 当事关系 | Exp | Experiencer | 我跑得快 (跑 –> 我) |
+| 感事关系 | Aft | Affection | 我思念家乡 (思念 –> 我) |
+| 领事关系 | Poss | Possessor | 他有一本好读 (他 <– 有) |
+| 受事关系 | Pat | Patient | 他打了小明 (打 –> 小明) |
+| 客事关系 | Cont | Content | 他听到鞭炮声 (听 –> 鞭炮声) |
+| 成事关系 | Prod | Product | 他写了本小说 (写 –> 小说) |
+| 源事关系 | Orig | Origin | 我军缴获敌人四辆坦克 (缴获 –> 坦克) |
+| 涉事关系 | Datv | Dative | 他告诉我个秘密 ( 告诉 –> 我 ) |
+| 比较角色 | Comp | Comitative | 他成绩比我好 (他 –> 我) |
+| 属事角色 | Belg | Belongings | 老赵有俩女儿 (老赵 <– 有) |
+| 类事角色 | Clas | Classification | 他是中学生 (是 –> 中学生) |
+| 依据角色 | Accd | According | 本庭依法宣判 (依法 <– 宣判) |
+| 缘故角色 | Reas | Reason | 他在愁女儿婚事 (愁 –> 婚事) |
+| 意图角色 | Int | Intention | 为了金牌他拼命努力 (金牌 <– 努力) |
+| 结局角色 | Cons | Consequence | 他跑了满头大汗 (跑 –> 满头大汗) |
+| 方式角色 | Mann | Manner | 球慢慢滚进空门 (慢慢 <– 滚) |
+| 工具角色 | Tool | Tool | 她用砂锅熬粥 (砂锅 <– 熬粥) |
+| 材料角色 | Malt | Material | 她用小米熬粥 (小米 <– 熬粥) |
+| 时间角色 | Time | Time | 唐朝有个李白 (唐朝 <– 有) |
+| 空间角色 | Loc | Location | 这房子朝南 (朝 –> 南) |
+| 历程角色 | Proc | Process | 火车正在过长江大桥 (过 –> 大桥) |
+| 趋向角色 | Dir | Direction | 部队奔向南方 (奔 –> 南) |
+| 范围角色 | Sco | Scope | 产品应该比质量 (比 –> 质量) |
+| 数量角色 | Quan | Quantity | 一年有365天 (有 –> 天) |
+| 数量数组 | Qp | Quantity-phrase | 三本书 (三 –> 本) |
+| 频率角色 | Freq | Frequency | 他每天看书 (每天 <– 看) |
+| 顺序角色 | Seq | Sequence | 他跑第一 (跑 –> 第一) |
+| 描写角色 | Desc(Feat) | Description | 他长得胖 (长 –> 胖) |
+| 宿主角色 | Host | Host | 住房面积 (住房 <– 面积) |
+| 名字修饰角色 | Nmod | Name-modifier | 果戈里大街 (果戈里 <– 大街) |
+| 时间修饰角色 | Tmod | Time-modifier | 星期一上午 (星期一 <– 上午) |
+| 反角色 | r + main role | | 打篮球的小姑娘 (打篮球 <– 姑娘) |
+| 嵌套角色 | d + main role | | 爷爷看见孙子在跑 (看见 –> 跑) |
+| 并列关系 | eCoo | event Coordination | 我喜欢唱歌和跳舞 (唱歌 –> 跳舞) |
+| 选择关系 | eSelt | event Selection | 您是喝茶还是喝咖啡 (茶 –> 咖啡) |
+| 等同关系 | eEqu | event Equivalent | 他们三个人一起走 (他们 –> 三个人) |
+| 先行关系 | ePrec | event Precedent | 首先,先 |
+| 顺承关系 | eSucc | event Successor | 随后,然后 |
+| 递进关系 | eProg | event Progression | 况且,并且 |
+| 转折关系 | eAdvt | event adversative | 却,然而 |
+| 原因关系 | eCau | event Cause | 因为,既然 |
+| 结果关系 | eResu | event Result | 因此,以致 |
+| 推论关系 | eInf | event Inference | 才,则 |
+| 条件关系 | eCond | event Condition | 只要,除非 |
+| 假设关系 | eSupp | event Supposition | 如果,要是 |
+| 让步关系 | eConc | event Concession | 纵使,哪怕 |
+| 手段关系 | eMetd | event Method | |
+| 目的关系 | ePurp | event Purpose | 为了,以便 |
+| 割舍关系 | eAban | event Abandonment | 与其,也不 |
+| 选取关系 | ePref | event Preference | 不如,宁愿 |
+| 总括关系 | eSum | event Summary | 总而言之 |
+| 分叙关系 | eRect | event Recount | 例如,比方说 |
+| 连词标记 | mConj | Recount Marker | 和,或 |
+| 的字标记 | mAux | Auxiliary | 的,地,得 |
+| 介词标记 | mPrep | Preposition | 把,被 |
+| 语气标记 | mTone | Tone | 吗,呢 |
+| 时间标记 | mTime | Time | 才,曾经 |
+| 范围标记 | mRang | Range | 都,到处 |
+| 程度标记 | mDegr | Degree | 很,稍微 |
+| 频率标记 | mFreq | Frequency Marker | 再,常常 |
+| 趋向标记 | mDir | Direction Marker | 上去,下来 |
+| 插入语标记 | mPars | Parenthesis Marker | 总的来说,众所周知 |
+| 否定标记 | mNeg | Negation Marker | 不,没,未 |
+| 情态标记 | mMod | Modal Marker | 幸亏,会,能 |
+| 标点标记 | mPunc | Punctuation Marker | ,。! |
+| 重复标记 | mPept | Repetition Marker | 走啊走 (走 –> 走) |
+| 多数标记 | mMaj | Majority Marker | 们,等 |
+| 实词虚化标记 | mVain | Vain Marker | |
+| 离合标记 | mSepa | Seperation Marker | 吃了个饭 (吃 –> 饭) 洗了个澡 (洗 –> 澡) |
+| 根节点 | Root | Root | 全句核心节点 |
\ No newline at end of file
diff --git a/docs/annotations/srl/cpb.md b/docs/annotations/srl/cpb.md
new file mode 100644
index 000000000..d79a23a67
--- /dev/null
+++ b/docs/annotations/srl/cpb.md
@@ -0,0 +1,45 @@
+
+
+# Chinese Proposition Bank
+
+| | 标签 | 角色 | 例子 |
+|------|----------|-------|-------------------------|
+| 中心角色 | ARG0 | 施事者 | (ARG0 中国政府)提供援助 |
+| | ARG1 | 受事者 | 中国政府提供(ARG1援助) |
+| | ARG2 | 与谓词相关 | 失业率控制(ARG2在百分之十内) |
+| | ARG3 | 与谓词相关 | (ARG3从城市)扩大到农村 |
+| | ARG4 | 与谓词相关 | 提高(ARG4 百分之二十) |
+| 附属角色 | ARGM-ADV | 状语 | (ARGM-ADV共同)承担 |
+| | ARGM-BNF | 受益者 | (ARGM-BNF为其他国家)进行融资 |
+| | ARGM-CND | 条件 | (ARGM-CND如果成功),他就留下 |
+| | ARGM-DIR | 方向 | (ARGM-DIR向和平)迈出一大步 |
+| | ARGM-EXT | 范围 | 在北京逗留 (ARGM-EXT两天) |
+| | ARGM-FRQ | 频率 | 每半年执行(ARGM-FRQ —次) |
+| | ARGM-LOC | 地点、位置 | (ARGM-LOC在机场)被捕获 |
+| | ARGM-MNR | 方式 | (ARGM-MNR以中英文)发行 |
+| | ARGM-PRP | 目的或原因 | (ARGM-PRP由于危机)而破产 |
+| | ARGM-TMP | 时间 | 公司 (ARGM-TMP去年)成立 |
+| | ARGM-TPC | 主题 | (ARGM-TPC稳定政策),核心是... |
+| | ARGM-DIS | 话语标记 | (ARGM-DIS)因此,他感到不公 |
+| | ARGM-CRD | 并列论元 | (ARGM-CRD与台湾)非正式接触 |
+| | ARGM-PRD | 次谓词 | 指控廉政公署五人(ARGM-PRD 接受贿赂) |
+
diff --git a/docs/annotations/srl/index.md b/docs/annotations/srl/index.md
new file mode 100644
index 000000000..64f456d01
--- /dev/null
+++ b/docs/annotations/srl/index.md
@@ -0,0 +1,7 @@
+# Semantic Role Labeling
+
+```{toctree}
+cpb
+propbank
+```
+
diff --git a/docs/annotations/srl/propbank.md b/docs/annotations/srl/propbank.md
new file mode 100644
index 000000000..f1156edcb
--- /dev/null
+++ b/docs/annotations/srl/propbank.md
@@ -0,0 +1,51 @@
+
+
+# English PropBank
+
+| Role | Description |
+|------|----------------------------------------|
+| ARG0 | agent |
+| ARG1 | patient |
+| ARG2 | instrument, benefactive, attribute |
+| ARG3 | starting point, benefactive, attribute |
+| ARG4 | ending point |
+| ARGM | modifier |
+| COM | Comitative |
+| LOC | Locative |
+| DIR | Directional |
+| GOL | Goal |
+| MNR | Manner |
+| TMP | Temporal |
+| EXT | Extent |
+| REC | Reciprocals |
+| PRD | Secondary Predication |
+| PRP | Purpose |
+| CAU | Cause |
+| DIS | Discourse |
+| ADV | Adverbials |
+| ADJ | Adjectival |
+| MOD | Modal |
+| NEG | Negation |
+| DSP | Direct Speech |
+| LVB | Light Verb |
+| CXN | Construction |
+
diff --git a/docs/api/common/configurable.rst b/docs/api/common/configurable.rst
new file mode 100644
index 000000000..f00d6deec
--- /dev/null
+++ b/docs/api/common/configurable.rst
@@ -0,0 +1,11 @@
+.. _api/configurable:
+
+configurable
+====================
+
+
+.. autoclass:: hanlp_common.configurable.Configurable
+ :members:
+
+.. autoclass:: hanlp_common.configurable.AutoConfigurable
+ :members:
diff --git a/docs/api/common/conll.rst b/docs/api/common/conll.rst
new file mode 100644
index 000000000..8ce6d4c0a
--- /dev/null
+++ b/docs/api/common/conll.rst
@@ -0,0 +1,14 @@
+.. _api/conll:
+
+conll
+====================
+
+
+.. autoclass:: hanlp_common.conll.CoNLLWord
+ :members:
+
+.. autoclass:: hanlp_common.conll.CoNLLUWord
+ :members:
+
+.. autoclass:: hanlp_common.conll.CoNLLSentence
+ :members:
\ No newline at end of file
diff --git a/docs/api/common/constant.rst b/docs/api/common/constant.rst
new file mode 100644
index 000000000..d6203cd91
--- /dev/null
+++ b/docs/api/common/constant.rst
@@ -0,0 +1,6 @@
+constant
+====================
+
+
+.. automodule:: hanlp_common.constant
+ :members:
diff --git a/docs/api/common/document.rst b/docs/api/common/document.rst
new file mode 100644
index 000000000..d8f433cd4
--- /dev/null
+++ b/docs/api/common/document.rst
@@ -0,0 +1,9 @@
+.. _api/document:
+
+document
+====================
+
+.. currentmodule:: hanlp_common
+
+.. autoclass:: hanlp_common.document.Document
+ :members:
diff --git a/docs/api/common/index.md b/docs/api/common/index.md
new file mode 100644
index 000000000..2fb12360d
--- /dev/null
+++ b/docs/api/common/index.md
@@ -0,0 +1,11 @@
+# hanlp_common
+
+Common API shared between `hanlp` and `restful`.
+
+```{toctree}
+document
+conll
+configurable
+constant
+```
+
diff --git a/docs/api/hanlp/common/component.rst b/docs/api/hanlp/common/component.rst
new file mode 100644
index 000000000..7b511b44c
--- /dev/null
+++ b/docs/api/hanlp/common/component.rst
@@ -0,0 +1,7 @@
+component
+=================
+
+.. currentmodule:: hanlp.common
+
+.. autoclass:: hanlp.common.component.Component
+ :members:
diff --git a/docs/api/hanlp/common/dataset.md b/docs/api/hanlp/common/dataset.md
new file mode 100644
index 000000000..774811fb5
--- /dev/null
+++ b/docs/api/hanlp/common/dataset.md
@@ -0,0 +1,64 @@
+# dataset
+
+This module provides base definition for datasets, dataloaders and samplers.
+
+## datasets
+
+```{eval-rst}
+.. currentmodule:: hanlp.common
+
+.. autoclass:: hanlp.common.dataset.Transformable
+ :members:
+
+.. autoclass:: hanlp.common.dataset.TransformableDataset
+ :members:
+ :special-members:
+ :exclude-members: __init__, __repr__
+```
+
+## dataloaders
+
+```{eval-rst}
+.. currentmodule:: hanlp.common
+
+.. autoclass:: hanlp.common.dataset.PadSequenceDataLoader
+ :members:
+ :special-members:
+ :exclude-members: __init__, __repr__
+
+.. autoclass:: hanlp.common.dataset.PrefetchDataLoader
+ :members:
+ :special-members:
+ :exclude-members: __init__, __repr__
+```
+
+## samplers
+
+```{eval-rst}
+.. currentmodule:: hanlp.common
+
+.. autoclass:: hanlp.common.dataset.BucketSampler
+ :members:
+
+.. autoclass:: hanlp.common.dataset.KMeansSampler
+ :members:
+
+.. autoclass:: hanlp.common.dataset.SortingSampler
+ :members:
+```
+
+## sampler builders
+
+```{eval-rst}
+.. currentmodule:: hanlp.common
+
+.. autoclass:: hanlp.common.dataset.SamplerBuilder
+ :members:
+
+.. autoclass:: hanlp.common.dataset.SortingSamplerBuilder
+ :members:
+
+.. autoclass:: hanlp.common.dataset.KMeansSamplerBuilder
+ :members:
+
+```
\ No newline at end of file
diff --git a/docs/api/hanlp/common/index.md b/docs/api/hanlp/common/index.md
new file mode 100644
index 000000000..c185c32cb
--- /dev/null
+++ b/docs/api/hanlp/common/index.md
@@ -0,0 +1,13 @@
+# common
+
+Common base classes.
+
+```{toctree}
+structure
+vocab
+transform
+dataset
+component
+torch_component
+```
+
diff --git a/docs/api/hanlp/common/structure.md b/docs/api/hanlp/common/structure.md
new file mode 100644
index 000000000..695fdaaf2
--- /dev/null
+++ b/docs/api/hanlp/common/structure.md
@@ -0,0 +1,12 @@
+# structure
+
+```{eval-rst}
+.. currentmodule:: hanlp.common
+
+.. autoclass:: hanlp.common.structure.ConfigTracker
+ :members:
+
+.. autoclass:: hanlp.common.structure.History
+ :members:
+
+```
diff --git a/docs/api/hanlp/common/torch_component.md b/docs/api/hanlp/common/torch_component.md
new file mode 100644
index 000000000..bb9faccb7
--- /dev/null
+++ b/docs/api/hanlp/common/torch_component.md
@@ -0,0 +1,9 @@
+# torch_component
+
+```{eval-rst}
+.. currentmodule:: hanlp.common.torch_component
+
+.. autoclass:: hanlp.common.torch_component.TorchComponent
+ :members:
+
+```
diff --git a/docs/api/hanlp/common/transform.md b/docs/api/hanlp/common/transform.md
new file mode 100644
index 000000000..126f961a4
--- /dev/null
+++ b/docs/api/hanlp/common/transform.md
@@ -0,0 +1,9 @@
+# transform
+
+```{eval-rst}
+.. currentmodule:: hanlp.common
+
+.. autoclass:: hanlp.common.transform.VocabDict
+ :members:
+
+```
diff --git a/docs/api/hanlp/common/vocab.md b/docs/api/hanlp/common/vocab.md
new file mode 100644
index 000000000..0d8c5ad37
--- /dev/null
+++ b/docs/api/hanlp/common/vocab.md
@@ -0,0 +1,11 @@
+# vocab
+
+```{eval-rst}
+.. currentmodule:: hanlp.common
+
+.. autoclass:: hanlp.common.transform.Vocab
+ :members:
+ :special-members:
+ :exclude-members: __init__, __repr__, __call__, __str__
+
+```
diff --git a/docs/api/hanlp/components/classifiers.md b/docs/api/hanlp/components/classifiers.md
new file mode 100644
index 000000000..ef2485ac8
--- /dev/null
+++ b/docs/api/hanlp/components/classifiers.md
@@ -0,0 +1,9 @@
+# classifiers
+
+```{eval-rst}
+.. currentmodule:: hanlp.components.classifiers
+
+.. autoclass:: hanlp.components.classifiers.transformer_classifier.TransformerClassifier
+ :members:
+
+```
diff --git a/docs/api/hanlp/components/eos.md b/docs/api/hanlp/components/eos.md
new file mode 100644
index 000000000..f4e69b9fd
--- /dev/null
+++ b/docs/api/hanlp/components/eos.md
@@ -0,0 +1,9 @@
+# eos
+
+```{eval-rst}
+.. currentmodule:: hanlp.components.eos
+
+.. autoclass:: hanlp.components.eos.ngram.NgramSentenceBoundaryDetector
+ :members:
+
+```
diff --git a/docs/api/hanlp/components/index.md b/docs/api/hanlp/components/index.md
new file mode 100644
index 000000000..8a1d627a2
--- /dev/null
+++ b/docs/api/hanlp/components/index.md
@@ -0,0 +1,16 @@
+# components
+
+NLP components.
+
+```{toctree}
+mtl/index
+classifiers
+eos
+tokenizers/index
+lemmatizer
+taggers/index
+ner/index
+parsers/index
+srl/index
+```
+
diff --git a/docs/api/hanlp/components/lemmatizer.md b/docs/api/hanlp/components/lemmatizer.md
new file mode 100644
index 000000000..00f822441
--- /dev/null
+++ b/docs/api/hanlp/components/lemmatizer.md
@@ -0,0 +1,9 @@
+# lemmatizer
+
+```{eval-rst}
+.. currentmodule:: hanlp.components.lemmatizer
+
+.. autoclass:: TransformerLemmatizer
+ :members:
+
+```
diff --git a/docs/api/hanlp/components/mtl/index.md b/docs/api/hanlp/components/mtl/index.md
new file mode 100644
index 000000000..7f36c9a9b
--- /dev/null
+++ b/docs/api/hanlp/components/mtl/index.md
@@ -0,0 +1,9 @@
+# mtl
+
+Multi-Task Learning (MTL) framework.
+
+```{toctree}
+mtl
+tasks/index
+```
+
diff --git a/docs/api/hanlp/components/mtl/mtl.md b/docs/api/hanlp/components/mtl/mtl.md
new file mode 100644
index 000000000..4993180f5
--- /dev/null
+++ b/docs/api/hanlp/components/mtl/mtl.md
@@ -0,0 +1,9 @@
+# MultiTaskLearning
+
+```{eval-rst}
+.. currentmodule:: hanlp.components.mtl
+
+.. autoclass:: hanlp.components.mtl.multi_task_learning.MultiTaskLearning
+ :members:
+
+```
diff --git a/docs/api/hanlp/components/mtl/tasks/con.md b/docs/api/hanlp/components/mtl/tasks/con.md
new file mode 100644
index 000000000..e94018e2e
--- /dev/null
+++ b/docs/api/hanlp/components/mtl/tasks/con.md
@@ -0,0 +1,12 @@
+# con
+
+Constituency parsing.
+
+```{eval-rst}
+.. currentmodule:: hanlp.components.mtl
+
+.. autoclass:: hanlp.components.mtl.tasks.constituency.CRFConstituencyParsing
+ :members:
+ :exclude-members: execute_training_loop, fit_dataloader
+
+```
diff --git a/docs/api/hanlp/components/mtl/tasks/dep.md b/docs/api/hanlp/components/mtl/tasks/dep.md
new file mode 100644
index 000000000..18e4d70eb
--- /dev/null
+++ b/docs/api/hanlp/components/mtl/tasks/dep.md
@@ -0,0 +1,12 @@
+# dep
+
+Dependency parsing.
+
+```{eval-rst}
+.. currentmodule:: hanlp.components.mtl
+
+.. autoclass:: hanlp.components.mtl.tasks.dep.BiaffineDependencyParsing
+ :members:
+ :exclude-members: execute_training_loop, fit_dataloader
+
+```
diff --git a/docs/api/hanlp/components/mtl/tasks/index.md b/docs/api/hanlp/components/mtl/tasks/index.md
new file mode 100644
index 000000000..ae9879cc2
--- /dev/null
+++ b/docs/api/hanlp/components/mtl/tasks/index.md
@@ -0,0 +1,17 @@
+# tasks
+
+Multi-Task Learning (MTL) tasks.
+
+```{toctree}
+task
+con
+dep
+sdp
+ud
+lem
+pos
+tok
+ner/index
+srl/index
+```
+
diff --git a/docs/api/hanlp/components/mtl/tasks/lem.md b/docs/api/hanlp/components/mtl/tasks/lem.md
new file mode 100644
index 000000000..67c3b223c
--- /dev/null
+++ b/docs/api/hanlp/components/mtl/tasks/lem.md
@@ -0,0 +1,12 @@
+# lem
+
+Lemmatization.
+
+```{eval-rst}
+.. currentmodule:: hanlp.components.mtl
+
+.. autoclass:: hanlp.components.mtl.tasks.lem.TransformerLemmatization
+ :members:
+ :exclude-members: execute_training_loop, fit_dataloader
+
+```
diff --git a/docs/api/hanlp/components/mtl/tasks/ner/biaffine_ner.md b/docs/api/hanlp/components/mtl/tasks/ner/biaffine_ner.md
new file mode 100644
index 000000000..e38854b2f
--- /dev/null
+++ b/docs/api/hanlp/components/mtl/tasks/ner/biaffine_ner.md
@@ -0,0 +1,12 @@
+# biaffine_ner
+
+Biaffine Named Entity Recognition.
+
+```{eval-rst}
+.. currentmodule:: hanlp.components.mtl
+
+.. autoclass:: hanlp.components.mtl.tasks.ner.biaffine_ner.BiaffineNamedEntityRecognition
+ :members:
+ :exclude-members: execute_training_loop, fit_dataloader
+
+```
diff --git a/docs/api/hanlp/components/mtl/tasks/ner/index.md b/docs/api/hanlp/components/mtl/tasks/ner/index.md
new file mode 100644
index 000000000..fa9e239fc
--- /dev/null
+++ b/docs/api/hanlp/components/mtl/tasks/ner/index.md
@@ -0,0 +1,9 @@
+# ner
+
+Named Entity Recognition.
+
+```{toctree}
+tag_ner
+biaffine_ner
+```
+
diff --git a/docs/api/hanlp/components/mtl/tasks/ner/tag_ner.md b/docs/api/hanlp/components/mtl/tasks/ner/tag_ner.md
new file mode 100644
index 000000000..4dadbeb8e
--- /dev/null
+++ b/docs/api/hanlp/components/mtl/tasks/ner/tag_ner.md
@@ -0,0 +1,12 @@
+# tag_ner
+
+Tagging based Named Entity Recognition.
+
+```{eval-rst}
+.. currentmodule:: hanlp.components.mtl
+
+.. autoclass:: hanlp.components.mtl.tasks.ner.tag_ner.TaggingNamedEntityRecognition
+ :members:
+ :exclude-members: execute_training_loop, fit_dataloader
+
+```
diff --git a/docs/api/hanlp/components/mtl/tasks/pos.md b/docs/api/hanlp/components/mtl/tasks/pos.md
new file mode 100644
index 000000000..1965c93bf
--- /dev/null
+++ b/docs/api/hanlp/components/mtl/tasks/pos.md
@@ -0,0 +1,12 @@
+# pos
+
+Part-of-speech tagging.
+
+```{eval-rst}
+.. currentmodule:: hanlp.components.mtl
+
+.. autoclass:: hanlp.components.mtl.tasks.pos.TransformerTagging
+ :members:
+ :exclude-members: execute_training_loop, fit_dataloader
+
+```
diff --git a/docs/api/hanlp/components/mtl/tasks/sdp.md b/docs/api/hanlp/components/mtl/tasks/sdp.md
new file mode 100644
index 000000000..39076e388
--- /dev/null
+++ b/docs/api/hanlp/components/mtl/tasks/sdp.md
@@ -0,0 +1,12 @@
+# sdp
+
+Semantic Dependency Parsing.
+
+```{eval-rst}
+.. currentmodule:: hanlp.components.mtl
+
+.. autoclass:: hanlp.components.mtl.tasks.sdp.BiaffineSemanticDependencyParsing
+ :members:
+ :exclude-members: execute_training_loop, fit_dataloader
+
+```
diff --git a/docs/api/hanlp/components/mtl/tasks/srl/bio_srl.md b/docs/api/hanlp/components/mtl/tasks/srl/bio_srl.md
new file mode 100644
index 000000000..8398f304f
--- /dev/null
+++ b/docs/api/hanlp/components/mtl/tasks/srl/bio_srl.md
@@ -0,0 +1,12 @@
+# bio_srl
+
+BIO Tagging based Semantic Role Labeling.
+
+```{eval-rst}
+.. currentmodule:: hanlp.components.mtl
+
+.. autoclass:: hanlp.components.mtl.tasks.srl.bio_srl.SpanBIOSemanticRoleLabeling
+ :members:
+ :exclude-members: execute_training_loop, fit_dataloader
+
+```
diff --git a/docs/api/hanlp/components/mtl/tasks/srl/index.md b/docs/api/hanlp/components/mtl/tasks/srl/index.md
new file mode 100644
index 000000000..2c5b03e81
--- /dev/null
+++ b/docs/api/hanlp/components/mtl/tasks/srl/index.md
@@ -0,0 +1,9 @@
+# srl
+
+Semantic Role Labeling.
+
+```{toctree}
+bio_srl
+rank_srl
+```
+
diff --git a/docs/api/hanlp/components/mtl/tasks/srl/rank_srl.md b/docs/api/hanlp/components/mtl/tasks/srl/rank_srl.md
new file mode 100644
index 000000000..dd34106c3
--- /dev/null
+++ b/docs/api/hanlp/components/mtl/tasks/srl/rank_srl.md
@@ -0,0 +1,12 @@
+# rank_srl
+
+Span Ranking Semantic Role Labeling.
+
+```{eval-rst}
+.. currentmodule:: hanlp.components.mtl
+
+.. autoclass:: hanlp.components.mtl.tasks.srl.rank_srl.SpanRankingSemanticRoleLabeling
+ :members:
+ :exclude-members: execute_training_loop, fit_dataloader
+
+```
diff --git a/docs/api/hanlp/components/mtl/tasks/task.md b/docs/api/hanlp/components/mtl/tasks/task.md
new file mode 100644
index 000000000..3349cd721
--- /dev/null
+++ b/docs/api/hanlp/components/mtl/tasks/task.md
@@ -0,0 +1,10 @@
+# Task
+
+```{eval-rst}
+.. currentmodule:: hanlp.components.mtl
+
+.. autoclass:: hanlp.components.mtl.tasks.Task
+ :members:
+ :exclude-members: execute_training_loop, fit_dataloader
+
+```
diff --git a/docs/api/hanlp/components/mtl/tasks/tok.md b/docs/api/hanlp/components/mtl/tasks/tok.md
new file mode 100644
index 000000000..939229c83
--- /dev/null
+++ b/docs/api/hanlp/components/mtl/tasks/tok.md
@@ -0,0 +1,12 @@
+# tok
+
+Tokenization.
+
+```{eval-rst}
+.. currentmodule:: hanlp.components.mtl
+
+.. autoclass:: hanlp.components.mtl.tasks.tok.tag_tok.TaggingTokenization
+ :members:
+ :exclude-members: execute_training_loop, fit_dataloader
+
+```
diff --git a/docs/api/hanlp/components/mtl/tasks/ud.md b/docs/api/hanlp/components/mtl/tasks/ud.md
new file mode 100644
index 000000000..7334fc2b3
--- /dev/null
+++ b/docs/api/hanlp/components/mtl/tasks/ud.md
@@ -0,0 +1,12 @@
+# ud
+
+Universal Dependencies Parsing (lemmatization, features, PoS tagging and dependency parsing).
+
+```{eval-rst}
+.. currentmodule:: hanlp.components.mtl
+
+.. autoclass:: hanlp.components.mtl.tasks.ud.UniversalDependenciesParsing
+ :members:
+ :exclude-members: execute_training_loop, fit_dataloader
+
+```
diff --git a/docs/api/hanlp/components/ner/biaffine_ner.md b/docs/api/hanlp/components/ner/biaffine_ner.md
new file mode 100644
index 000000000..68021686d
--- /dev/null
+++ b/docs/api/hanlp/components/ner/biaffine_ner.md
@@ -0,0 +1,11 @@
+# biaffine_ner
+
+Biaffine Named Entity Recognition.
+
+```{eval-rst}
+.. currentmodule:: hanlp.components.ner.transformer_ner
+
+.. autoclass:: hanlp.components.ner.biaffine_ner.biaffine_ner.BiaffineNamedEntityRecognizer
+ :members:
+
+```
diff --git a/docs/api/hanlp/components/ner/index.md b/docs/api/hanlp/components/ner/index.md
new file mode 100644
index 000000000..45415bfca
--- /dev/null
+++ b/docs/api/hanlp/components/ner/index.md
@@ -0,0 +1,10 @@
+# ner
+
+Named Entity Recognition.
+
+```{toctree}
+transformer_ner
+rnn_ner
+biaffine_ner
+```
+
diff --git a/docs/api/hanlp/components/ner/rnn_ner.md b/docs/api/hanlp/components/ner/rnn_ner.md
new file mode 100644
index 000000000..d575f2ca0
--- /dev/null
+++ b/docs/api/hanlp/components/ner/rnn_ner.md
@@ -0,0 +1,11 @@
+# rnn_ner
+
+Tagging based Named Entity Recognition.
+
+```{eval-rst}
+.. currentmodule:: hanlp.components.ner.rnn_ner
+
+.. autoclass:: hanlp.components.ner.rnn_ner.RNNNamedEntityRecognizer
+ :members:
+
+```
diff --git a/docs/api/hanlp/components/ner/transformer_ner.md b/docs/api/hanlp/components/ner/transformer_ner.md
new file mode 100644
index 000000000..d0a81bfae
--- /dev/null
+++ b/docs/api/hanlp/components/ner/transformer_ner.md
@@ -0,0 +1,11 @@
+# transformer_ner
+
+Tagging based Named Entity Recognition.
+
+```{eval-rst}
+.. currentmodule:: hanlp.components.ner.transformer_ner
+
+.. autoclass:: hanlp.components.ner.transformer_ner.TransformerNamedEntityRecognizer
+ :members:
+
+```
diff --git a/docs/api/hanlp/components/parsers/biaffine_dep.md b/docs/api/hanlp/components/parsers/biaffine_dep.md
new file mode 100644
index 000000000..d3b2dfea8
--- /dev/null
+++ b/docs/api/hanlp/components/parsers/biaffine_dep.md
@@ -0,0 +1,11 @@
+# biaffine_dep
+
+Biaffine dependency parser.
+
+```{eval-rst}
+.. currentmodule:: hanlp.components
+
+.. autoclass:: hanlp.components.parsers.biaffine.biaffine_dep.BiaffineDependencyParser
+ :members:
+
+```
diff --git a/docs/api/hanlp/components/parsers/biaffine_sdp.md b/docs/api/hanlp/components/parsers/biaffine_sdp.md
new file mode 100644
index 000000000..b92ffa8f0
--- /dev/null
+++ b/docs/api/hanlp/components/parsers/biaffine_sdp.md
@@ -0,0 +1,11 @@
+# biaffine_sdp
+
+Biaffine dependency parser.
+
+```{eval-rst}
+.. currentmodule:: hanlp.components
+
+.. autoclass:: hanlp.components.parsers.biaffine.biaffine_sdp.BiaffineSemanticDependencyParser
+ :members:
+
+```
diff --git a/docs/api/hanlp/components/parsers/crf_constituency_parser.md b/docs/api/hanlp/components/parsers/crf_constituency_parser.md
new file mode 100644
index 000000000..c01d9d48f
--- /dev/null
+++ b/docs/api/hanlp/components/parsers/crf_constituency_parser.md
@@ -0,0 +1,11 @@
+# crf_constituency_parser
+
+Biaffine dependency parser.
+
+```{eval-rst}
+.. currentmodule:: hanlp.components
+
+.. autoclass:: hanlp.components.parsers.constituency.crf_constituency_parser.CRFConstituencyParser
+ :members:
+
+```
diff --git a/docs/api/hanlp/components/parsers/index.md b/docs/api/hanlp/components/parsers/index.md
new file mode 100644
index 000000000..0e83569a8
--- /dev/null
+++ b/docs/api/hanlp/components/parsers/index.md
@@ -0,0 +1,11 @@
+# parsers
+
+Parsers.
+
+```{toctree}
+biaffine_dep
+biaffine_sdp
+ud_parser
+crf_constituency_parser
+```
+
diff --git a/docs/api/hanlp/components/parsers/ud_parser.md b/docs/api/hanlp/components/parsers/ud_parser.md
new file mode 100644
index 000000000..88973b0e4
--- /dev/null
+++ b/docs/api/hanlp/components/parsers/ud_parser.md
@@ -0,0 +1,11 @@
+# ud_parser
+
+Universal Dependencies Parsing (lemmatization, features, PoS tagging and dependency parsing).
+
+```{eval-rst}
+.. currentmodule:: hanlp.components
+
+.. autoclass:: hanlp.components.parsers.ud.ud_parser.UniversalDependenciesParser
+ :members:
+
+```
diff --git a/docs/api/hanlp/components/srl/index.md b/docs/api/hanlp/components/srl/index.md
new file mode 100644
index 000000000..0fdbef3c5
--- /dev/null
+++ b/docs/api/hanlp/components/srl/index.md
@@ -0,0 +1,9 @@
+# srl
+
+Semantic Role Labelers.
+
+```{toctree}
+span_rank
+span_bio
+```
+
diff --git a/docs/api/hanlp/components/srl/span_bio.md b/docs/api/hanlp/components/srl/span_bio.md
new file mode 100644
index 000000000..06fc0a408
--- /dev/null
+++ b/docs/api/hanlp/components/srl/span_bio.md
@@ -0,0 +1,11 @@
+# span_bio
+
+Span BIO tagging based SRL.
+
+```{eval-rst}
+.. currentmodule:: hanlp.components.srl.span_bio.span_bio
+
+.. autoclass:: SpanBIOSemanticRoleLabeler
+ :members:
+
+```
diff --git a/docs/api/hanlp/components/srl/span_rank.md b/docs/api/hanlp/components/srl/span_rank.md
new file mode 100644
index 000000000..4531ec88d
--- /dev/null
+++ b/docs/api/hanlp/components/srl/span_rank.md
@@ -0,0 +1,11 @@
+# span_rank
+
+Span Rank based SRL.
+
+```{eval-rst}
+.. currentmodule:: hanlp.components.srl.span_rank.span_rank
+
+.. autoclass:: SpanRankingSemanticRoleLabeler
+ :members:
+
+```
diff --git a/docs/api/hanlp/components/taggers/index.md b/docs/api/hanlp/components/taggers/index.md
new file mode 100644
index 000000000..5eb65b8dd
--- /dev/null
+++ b/docs/api/hanlp/components/taggers/index.md
@@ -0,0 +1,9 @@
+# taggers
+
+Taggers.
+
+```{toctree}
+transformer_tagger
+rnn_tagger
+```
+
diff --git a/docs/api/hanlp/components/taggers/rnn_tagger.md b/docs/api/hanlp/components/taggers/rnn_tagger.md
new file mode 100644
index 000000000..0178efc59
--- /dev/null
+++ b/docs/api/hanlp/components/taggers/rnn_tagger.md
@@ -0,0 +1,11 @@
+# rnn_tagger
+
+RNN based tagger.
+
+```{eval-rst}
+.. currentmodule:: hanlp.components
+
+.. autoclass:: hanlp.components.taggers.rnn_tagger.RNNTagger
+ :members:
+
+```
diff --git a/docs/api/hanlp/components/taggers/transformer_tagger.md b/docs/api/hanlp/components/taggers/transformer_tagger.md
new file mode 100644
index 000000000..3f28ac08d
--- /dev/null
+++ b/docs/api/hanlp/components/taggers/transformer_tagger.md
@@ -0,0 +1,11 @@
+# transformer_tagger
+
+Transformer based tagger.
+
+```{eval-rst}
+.. currentmodule:: hanlp.components
+
+.. autoclass:: hanlp.components.taggers.transformers.transformer_tagger.TransformerTagger
+ :members:
+
+```
diff --git a/docs/api/hanlp/components/tokenizers/index.md b/docs/api/hanlp/components/tokenizers/index.md
new file mode 100644
index 000000000..24fdd3c1c
--- /dev/null
+++ b/docs/api/hanlp/components/tokenizers/index.md
@@ -0,0 +1,9 @@
+# tokenizers
+
+Tokenizers.
+
+```{toctree}
+transformer
+multi_criteria
+```
+
diff --git a/docs/api/hanlp/components/tokenizers/multi_criteria.md b/docs/api/hanlp/components/tokenizers/multi_criteria.md
new file mode 100644
index 000000000..8d95daf1f
--- /dev/null
+++ b/docs/api/hanlp/components/tokenizers/multi_criteria.md
@@ -0,0 +1,11 @@
+# multi_criteria
+
+Transformer based Multi-Criteria Word tokenizer.
+
+```{eval-rst}
+.. currentmodule:: hanlp.components.tokenizers.multi_criteria_cws_transformer
+
+.. autoclass:: hanlp.components.tokenizers.multi_criteria_cws_transformer.MultiCriteriaTransformerTaggingTokenizer
+ :members:
+
+```
diff --git a/docs/api/hanlp/components/tokenizers/transformer.md b/docs/api/hanlp/components/tokenizers/transformer.md
new file mode 100644
index 000000000..5078b5f06
--- /dev/null
+++ b/docs/api/hanlp/components/tokenizers/transformer.md
@@ -0,0 +1,11 @@
+# transformer
+
+Transformer based tokenizer.
+
+```{eval-rst}
+.. currentmodule:: hanlp.components.tokenizers.transformer
+
+.. autoclass:: hanlp.components.tokenizers.transformer.TransformerTaggingTokenizer
+ :members:
+
+```
diff --git a/docs/api/hanlp/datasets/con/constituency_dataset.md b/docs/api/hanlp/datasets/con/constituency_dataset.md
new file mode 100644
index 000000000..a67454b84
--- /dev/null
+++ b/docs/api/hanlp/datasets/con/constituency_dataset.md
@@ -0,0 +1,8 @@
+# constituency_dataset
+
+```{eval-rst}
+
+.. autoclass:: hanlp.components.parsers.constituency.constituency_dataset.ConstituencyDataset
+ :members:
+
+```
diff --git a/docs/api/hanlp/datasets/con/index.md b/docs/api/hanlp/datasets/con/index.md
new file mode 100644
index 000000000..759d7f019
--- /dev/null
+++ b/docs/api/hanlp/datasets/con/index.md
@@ -0,0 +1,9 @@
+# con
+
+Constituency parsing datasets.
+
+```{toctree}
+constituency_dataset
+resources
+```
+
diff --git a/docs/api/hanlp/datasets/con/resources.md b/docs/api/hanlp/datasets/con/resources.md
new file mode 100644
index 000000000..a000c8ba1
--- /dev/null
+++ b/docs/api/hanlp/datasets/con/resources.md
@@ -0,0 +1,52 @@
+# resources
+
+## Chinese Treebank
+
+
+### CTB8
+
+
+
+````{margin} **Discussion**
+```{seealso}
+About our data split on [our forum](https://bbs.hankcs.com/t/topic/3024).
+```
+````
+
+```{eval-rst}
+
+
+.. autodata:: hanlp.datasets.parsing.ctb8.CTB8_BRACKET_LINE_NOEC_TRAIN
+.. autodata:: hanlp.datasets.parsing.ctb8.CTB8_BRACKET_LINE_NOEC_DEV
+.. autodata:: hanlp.datasets.parsing.ctb8.CTB8_BRACKET_LINE_NOEC_TEST
+
+```
+
+### CTB9
+
+````{margin} **Discussion**
+```{seealso}
+About our data split on [our forum](https://bbs.hankcs.com/t/topic/3024).
+```
+````
+
+```{eval-rst}
+
+
+.. autodata:: hanlp.datasets.parsing.ctb9.CTB9_BRACKET_LINE_NOEC_TRAIN
+.. autodata:: hanlp.datasets.parsing.ctb9.CTB9_BRACKET_LINE_NOEC_DEV
+.. autodata:: hanlp.datasets.parsing.ctb9.CTB9_BRACKET_LINE_NOEC_TEST
+
+```
+
+## English Treebank
+
+### PTB
+
+```{eval-rst}
+
+.. autodata:: hanlp.datasets.parsing.ptb.PTB_TRAIN
+.. autodata:: hanlp.datasets.parsing.ptb.PTB_DEV
+.. autodata:: hanlp.datasets.parsing.ptb.PTB_TEST
+
+```
diff --git a/docs/api/hanlp/datasets/dep/conll_dataset.md b/docs/api/hanlp/datasets/dep/conll_dataset.md
new file mode 100644
index 000000000..ed5c4e71e
--- /dev/null
+++ b/docs/api/hanlp/datasets/dep/conll_dataset.md
@@ -0,0 +1,10 @@
+# conll
+
+```{eval-rst}
+.. currentmodule:: hanlp.datasets.parsing.conll_dataset
+
+
+.. autoclass:: CoNLLParsingDataset
+ :members:
+
+```
diff --git a/docs/api/hanlp/datasets/dep/index.md b/docs/api/hanlp/datasets/dep/index.md
new file mode 100644
index 000000000..94539efb0
--- /dev/null
+++ b/docs/api/hanlp/datasets/dep/index.md
@@ -0,0 +1,9 @@
+# dep
+
+Dependency parsing datasets.
+
+```{toctree}
+conll_dataset
+resources
+```
+
diff --git a/docs/api/hanlp/datasets/dep/resources.md b/docs/api/hanlp/datasets/dep/resources.md
new file mode 100644
index 000000000..9ed134d1b
--- /dev/null
+++ b/docs/api/hanlp/datasets/dep/resources.md
@@ -0,0 +1,109 @@
+# resources
+
+## Chinese Treebank
+
+### CTB5
+
+```{eval-rst}
+
+.. automodule:: hanlp.datasets.parsing.ctb5
+ :members:
+
+```
+
+### CTB7
+
+```{eval-rst}
+
+.. automodule:: hanlp.datasets.parsing.ctb7
+ :members:
+
+```
+
+### CTB8
+
+```{eval-rst}
+
+.. Attention::
+
+ We propose a new data split for CTB which is different from the academia conventions with the following 3 advantages.
+
+ - Easy to reproduce. Files ending with ``8`` go to dev set, ending with ``9`` go to the test set, otherwise go to the training set.
+ - Full use of CTB8. The academia conventional split omits 50 gold files while we recall them.
+ - More balanced split across genres. Proportions of samples in each genres are similar.
+
+ We also use Stanford Dependencies 3.3.0 which offers fine-grained relations and more grammars than the conventional
+ head finding rules introduced by :cite:`zhang-clark-2008-tale`.
+
+ Therefore, scores on our preprocessed CTB8 are not directly comparable to those in most literatures. We have
+ experimented the same model on the conventionally baked CTB8 and the scores could be 4~5 points higher.
+ We believe it's worthy since HanLP is made for practical purposes, not just for producing pretty numbers.
+
+```
+
+````{margin} **Discussion**
+```{seealso}
+We have a discussion on [our forum](https://bbs.hankcs.com/t/topic/3024).
+```
+````
+
+```{eval-rst}
+
+
+.. autodata:: hanlp.datasets.parsing.ctb8.CTB8_SD330_TRAIN
+.. autodata:: hanlp.datasets.parsing.ctb8.CTB8_SD330_DEV
+.. autodata:: hanlp.datasets.parsing.ctb8.CTB8_SD330_TEST
+
+```
+
+### CTB9
+
+```{eval-rst}
+
+.. Attention::
+
+ Similar preprocessing and splits with CTB8 are applied. See the notice above.
+
+```
+
+
+```{eval-rst}
+
+
+.. autodata:: hanlp.datasets.parsing.ctb9.CTB9_SD330_TRAIN
+.. autodata:: hanlp.datasets.parsing.ctb9.CTB9_SD330_DEV
+.. autodata:: hanlp.datasets.parsing.ctb9.CTB9_SD330_TEST
+
+```
+
+## English Treebank
+
+### PTB
+
+```{eval-rst}
+
+.. autodata:: hanlp.datasets.parsing.ptb.PTB_SD330_TRAIN
+.. autodata:: hanlp.datasets.parsing.ptb.PTB_SD330_DEV
+.. autodata:: hanlp.datasets.parsing.ptb.PTB_SD330_TEST
+
+```
+
+## Universal Dependencies
+
+### Languages
+
+```{eval-rst}
+
+.. automodule:: hanlp.datasets.parsing.ud.ud27
+ :members:
+
+```
+
+### Multilingual
+
+```{eval-rst}
+
+.. automodule:: hanlp.datasets.parsing.ud.ud27m
+ :members:
+
+```
diff --git a/docs/api/hanlp/datasets/eos/eos.md b/docs/api/hanlp/datasets/eos/eos.md
new file mode 100644
index 000000000..459af2a41
--- /dev/null
+++ b/docs/api/hanlp/datasets/eos/eos.md
@@ -0,0 +1,9 @@
+# eos
+
+```{eval-rst}
+.. currentmodule:: hanlp.datasets.eos.eos
+
+.. autoclass:: SentenceBoundaryDetectionDataset
+ :members:
+
+```
diff --git a/docs/api/hanlp/datasets/eos/index.md b/docs/api/hanlp/datasets/eos/index.md
new file mode 100644
index 000000000..07c0c0827
--- /dev/null
+++ b/docs/api/hanlp/datasets/eos/index.md
@@ -0,0 +1,9 @@
+# eos
+
+Sentence boundary detection datasets.
+
+```{toctree}
+eos
+resources
+```
+
diff --git a/docs/api/hanlp/datasets/eos/resources.md b/docs/api/hanlp/datasets/eos/resources.md
new file mode 100644
index 000000000..fabd9a7c5
--- /dev/null
+++ b/docs/api/hanlp/datasets/eos/resources.md
@@ -0,0 +1,10 @@
+# resources
+
+## nn_eos
+
+```{eval-rst}
+
+.. automodule:: hanlp.datasets.eos.nn_eos
+ :members:
+
+```
\ No newline at end of file
diff --git a/docs/api/hanlp/datasets/index.md b/docs/api/hanlp/datasets/index.md
new file mode 100644
index 000000000..923760839
--- /dev/null
+++ b/docs/api/hanlp/datasets/index.md
@@ -0,0 +1,26 @@
+# datasets
+
+```{eval-rst}
+NLP datasets grouped by tasks. For each task, we provide at least one ``torch.utils.data.Dataset`` compatible class
+and several open-source resources. Their file format and description can be found in their ``Dataset.load_file``
+documents. Their contents are split into ``TRAIN``, ``DEV`` and ``TEST`` portions, each of them is stored in
+a Python constant which can be fetched using :meth:`~hanlp.utils.io_util.get_resource`.
+```
+
+````{margin} **Professionals use Linux**
+```{note}
+Many preprocessing scripts written by professionals make heavy use of Linux/Unix tool chains like shell, perl, gcc,
+etc., which is not available or buggy on Windows. You may need a *nix evironment to run these scripts.
+```
+````
+
+```{toctree}
+eos/index
+tok/index
+pos/index
+ner/index
+dep/index
+srl/index
+con/index
+```
+
diff --git a/docs/api/hanlp/datasets/ner/index.md b/docs/api/hanlp/datasets/ner/index.md
new file mode 100644
index 000000000..b024dab76
--- /dev/null
+++ b/docs/api/hanlp/datasets/ner/index.md
@@ -0,0 +1,10 @@
+# ner
+
+NER datasets.
+
+```{toctree}
+tsv
+json
+resources
+```
+
diff --git a/docs/api/hanlp/datasets/ner/json.md b/docs/api/hanlp/datasets/ner/json.md
new file mode 100644
index 000000000..393766597
--- /dev/null
+++ b/docs/api/hanlp/datasets/ner/json.md
@@ -0,0 +1,9 @@
+# json
+
+```{eval-rst}
+.. currentmodule:: hanlp.datasets.ner.json_ner
+
+.. autoclass:: JsonNERDataset
+ :members:
+
+```
diff --git a/docs/api/hanlp/datasets/ner/resources.md b/docs/api/hanlp/datasets/ner/resources.md
new file mode 100644
index 000000000..4656b4fca
--- /dev/null
+++ b/docs/api/hanlp/datasets/ner/resources.md
@@ -0,0 +1,50 @@
+# resources
+
+## CoNLL 2003
+
+```{eval-rst}
+
+.. automodule:: hanlp.datasets.ner.conll03
+ :members:
+
+```
+
+## MSRA
+
+```{eval-rst}
+
+.. automodule:: hanlp.datasets.ner.msra
+ :members:
+
+```
+
+## OntoNotes5
+
+```{eval-rst}
+
+.. autodata:: hanlp.datasets.srl.ontonotes5.chinese.ONTONOTES5_CONLL12_CHINESE_TRAIN
+.. autodata:: hanlp.datasets.srl.ontonotes5.chinese.ONTONOTES5_CONLL12_CHINESE_DEV
+.. autodata:: hanlp.datasets.srl.ontonotes5.chinese.ONTONOTES5_CONLL12_CHINESE_TEST
+
+.. autodata:: hanlp.datasets.srl.ontonotes5.chinese.ONTONOTES5_CONLL12_NER_CHINESE_TRAIN
+.. autodata:: hanlp.datasets.srl.ontonotes5.chinese.ONTONOTES5_CONLL12_NER_CHINESE_DEV
+.. autodata:: hanlp.datasets.srl.ontonotes5.chinese.ONTONOTES5_CONLL12_NER_CHINESE_TEST
+
+```
+
+## Resume
+
+```{eval-rst}
+
+.. automodule:: hanlp.datasets.ner.resume
+ :members:
+```
+
+## Weibo
+
+
+```{eval-rst}
+
+.. automodule:: hanlp.datasets.ner.weibo
+ :members:
+```
\ No newline at end of file
diff --git a/docs/api/hanlp/datasets/ner/tsv.md b/docs/api/hanlp/datasets/ner/tsv.md
new file mode 100644
index 000000000..d6fed10e5
--- /dev/null
+++ b/docs/api/hanlp/datasets/ner/tsv.md
@@ -0,0 +1,9 @@
+# tsv
+
+```{eval-rst}
+.. currentmodule:: hanlp.datasets.ner.tsv
+
+.. autoclass:: TSVTaggingDataset
+ :members:
+
+```
diff --git a/docs/api/hanlp/datasets/pos/index.md b/docs/api/hanlp/datasets/pos/index.md
new file mode 100644
index 000000000..06671727a
--- /dev/null
+++ b/docs/api/hanlp/datasets/pos/index.md
@@ -0,0 +1,12 @@
+# pos
+
+PoS datasets.
+
+```{eval-rst}
+PoS is a normal tagging task which uses :class:`hanlp.datasets.ner.tsv.TSVTaggingDataset` for loading.
+```
+
+```{toctree}
+resources
+```
+
diff --git a/docs/api/hanlp/datasets/pos/resources.md b/docs/api/hanlp/datasets/pos/resources.md
new file mode 100644
index 000000000..186074e5a
--- /dev/null
+++ b/docs/api/hanlp/datasets/pos/resources.md
@@ -0,0 +1,32 @@
+# resources
+
+## CTB5
+
+```{eval-rst}
+
+.. automodule:: hanlp.datasets.pos.ctb5
+ :members:
+
+```
+
+## CTB8
+
+```{eval-rst}
+
+.. autodata:: hanlp.datasets.parsing.ctb8.CTB8_POS_TRAIN
+.. autodata:: hanlp.datasets.parsing.ctb8.CTB8_POS_DEV
+.. autodata:: hanlp.datasets.parsing.ctb8.CTB8_POS_TEST
+
+```
+
+## CTB9
+
+
+```{eval-rst}
+
+
+.. autodata:: hanlp.datasets.parsing.ctb9.CTB9_POS_TRAIN
+.. autodata:: hanlp.datasets.parsing.ctb9.CTB9_POS_DEV
+.. autodata:: hanlp.datasets.parsing.ctb9.CTB9_POS_TEST
+
+```
\ No newline at end of file
diff --git a/docs/api/hanlp/datasets/srl/conll2012_dataset.md b/docs/api/hanlp/datasets/srl/conll2012_dataset.md
new file mode 100644
index 000000000..4fad716de
--- /dev/null
+++ b/docs/api/hanlp/datasets/srl/conll2012_dataset.md
@@ -0,0 +1,8 @@
+# conll2012_dataset
+
+```{eval-rst}
+
+.. autoclass:: hanlp.datasets.srl.conll2012.CoNLL2012SRLDataset
+ :members:
+
+```
diff --git a/docs/api/hanlp/datasets/srl/index.md b/docs/api/hanlp/datasets/srl/index.md
new file mode 100644
index 000000000..c73520e46
--- /dev/null
+++ b/docs/api/hanlp/datasets/srl/index.md
@@ -0,0 +1,9 @@
+# srl
+
+Semantic Role Labeling datasets.
+
+```{toctree}
+conll2012_dataset
+resources
+```
+
diff --git a/docs/api/hanlp/datasets/srl/resources.md b/docs/api/hanlp/datasets/srl/resources.md
new file mode 100644
index 000000000..656dd2632
--- /dev/null
+++ b/docs/api/hanlp/datasets/srl/resources.md
@@ -0,0 +1,16 @@
+# resources
+
+## OntoNotes 5
+
+### Chinese
+
+```{eval-rst}
+
+.. autodata:: hanlp.datasets.srl.ontonotes5.chinese.ONTONOTES5_CONLL12_CHINESE_TRAIN
+ :noindex:
+.. autodata:: hanlp.datasets.srl.ontonotes5.chinese.ONTONOTES5_CONLL12_CHINESE_DEV
+ :noindex:
+.. autodata:: hanlp.datasets.srl.ontonotes5.chinese.ONTONOTES5_CONLL12_CHINESE_TEST
+ :noindex:
+
+```
diff --git a/docs/api/hanlp/datasets/tok/index.md b/docs/api/hanlp/datasets/tok/index.md
new file mode 100644
index 000000000..146f1396c
--- /dev/null
+++ b/docs/api/hanlp/datasets/tok/index.md
@@ -0,0 +1,10 @@
+# tok
+
+Tokenization datasets.
+
+```{toctree}
+txt
+mcws_dataset
+resources
+```
+
diff --git a/docs/api/hanlp/datasets/tok/mcws_dataset.md b/docs/api/hanlp/datasets/tok/mcws_dataset.md
new file mode 100644
index 000000000..88a3db198
--- /dev/null
+++ b/docs/api/hanlp/datasets/tok/mcws_dataset.md
@@ -0,0 +1,9 @@
+# mcws_dataset
+
+```{eval-rst}
+.. currentmodule:: hanlp.datasets.cws.multi_criteria_cws.mcws_dataset
+
+.. autoclass:: MultiCriteriaTextTokenizingDataset
+ :members:
+
+```
diff --git a/docs/api/hanlp/datasets/tok/resources.md b/docs/api/hanlp/datasets/tok/resources.md
new file mode 100644
index 000000000..1e2c6d701
--- /dev/null
+++ b/docs/api/hanlp/datasets/tok/resources.md
@@ -0,0 +1,76 @@
+# resources
+
+## sighan2005
+
+[The Second International Chinese Word Segmentation Bakeoff](http://sighan.cs.uchicago.edu/bakeoff2005/) took place over the summer of 2005.
+
+### pku
+
+```{eval-rst}
+
+.. automodule:: hanlp.datasets.cws.sighan2005.pku
+ :members:
+
+```
+
+### msr
+
+```{eval-rst}
+
+.. automodule:: hanlp.datasets.cws.sighan2005.msr
+ :members:
+
+```
+
+### as
+
+```{eval-rst}
+
+.. automodule:: hanlp.datasets.cws.sighan2005.as_
+ :members:
+
+```
+
+### cityu
+
+```{eval-rst}
+
+.. automodule:: hanlp.datasets.cws.sighan2005.cityu
+ :members:
+
+```
+
+## CTB6
+
+```{eval-rst}
+
+.. automodule:: hanlp.datasets.cws.ctb6
+ :members:
+
+```
+
+## CTB8
+
+
+```{eval-rst}
+
+.. automodule:: hanlp.datasets.parsing.ctb8
+
+.. autodata:: CTB8_CWS_TRAIN
+.. autodata:: CTB8_CWS_DEV
+.. autodata:: CTB8_CWS_TEST
+
+```
+
+## CTB9
+
+
+```{eval-rst}
+
+.. automodule:: hanlp.datasets.parsing.ctb9
+
+.. autodata:: CTB9_CWS_TRAIN
+.. autodata:: CTB9_CWS_DEV
+.. autodata:: CTB9_CWS_TEST
+
+```
\ No newline at end of file
diff --git a/docs/api/hanlp/datasets/tok/txt.md b/docs/api/hanlp/datasets/tok/txt.md
new file mode 100644
index 000000000..6c9952b4e
--- /dev/null
+++ b/docs/api/hanlp/datasets/tok/txt.md
@@ -0,0 +1,9 @@
+# txt
+
+```{eval-rst}
+.. currentmodule:: hanlp.datasets.tokenization.txt
+
+.. autoclass:: TextTokenizingDataset
+ :members:
+
+```
diff --git a/docs/api/hanlp/hanlp.rst b/docs/api/hanlp/hanlp.rst
new file mode 100644
index 000000000..b4b6836cd
--- /dev/null
+++ b/docs/api/hanlp/hanlp.rst
@@ -0,0 +1,10 @@
+.. _api/main:
+
+hanlp
+==========
+
+.. currentmodule:: hanlp
+
+.. autofunction:: load
+
+.. autofunction:: pipeline
\ No newline at end of file
diff --git a/docs/api/hanlp/index.md b/docs/api/hanlp/index.md
new file mode 100644
index 000000000..8dcb226e6
--- /dev/null
+++ b/docs/api/hanlp/index.md
@@ -0,0 +1,13 @@
+# hanlp
+
+Core API for `hanlp`.
+
+```{toctree}
+hanlp
+common/index
+components/index
+pretrained/index
+datasets/index
+utils/index
+layers/index
+```
\ No newline at end of file
diff --git a/docs/api/hanlp/layers/decoders/biaffine_ner.md b/docs/api/hanlp/layers/decoders/biaffine_ner.md
new file mode 100644
index 000000000..d1a08a0d7
--- /dev/null
+++ b/docs/api/hanlp/layers/decoders/biaffine_ner.md
@@ -0,0 +1,9 @@
+# biaffine_ner
+
+
+```{eval-rst}
+
+.. autoclass:: hanlp.components.ner.biaffine_ner.biaffine_ner_model.BiaffineNamedEntityRecognitionDecoder
+ :members:
+
+```
diff --git a/docs/api/hanlp/layers/decoders/index.md b/docs/api/hanlp/layers/decoders/index.md
new file mode 100644
index 000000000..56f331ea1
--- /dev/null
+++ b/docs/api/hanlp/layers/decoders/index.md
@@ -0,0 +1,7 @@
+# decoders
+
+```{toctree}
+linear_crf
+biaffine_ner
+```
+
diff --git a/docs/api/hanlp/layers/decoders/linear_crf.md b/docs/api/hanlp/layers/decoders/linear_crf.md
new file mode 100644
index 000000000..3266bfb4b
--- /dev/null
+++ b/docs/api/hanlp/layers/decoders/linear_crf.md
@@ -0,0 +1,9 @@
+# linear_crf
+
+
+```{eval-rst}
+
+.. autoclass:: hanlp.components.mtl.tasks.pos.LinearCRFDecoder
+ :members:
+
+```
diff --git a/docs/api/hanlp/layers/embeddings/char_cnn.md b/docs/api/hanlp/layers/embeddings/char_cnn.md
new file mode 100644
index 000000000..92224aaac
--- /dev/null
+++ b/docs/api/hanlp/layers/embeddings/char_cnn.md
@@ -0,0 +1,12 @@
+# char_cnn
+
+
+```{eval-rst}
+
+.. autoclass:: hanlp.layers.embeddings.char_cnn.CharCNN
+ :members:
+
+.. autoclass:: hanlp.layers.embeddings.char_cnn.CharCNNEmbedding
+ :members:
+
+```
diff --git a/docs/api/hanlp/layers/embeddings/char_rnn.md b/docs/api/hanlp/layers/embeddings/char_rnn.md
new file mode 100644
index 000000000..e5d481655
--- /dev/null
+++ b/docs/api/hanlp/layers/embeddings/char_rnn.md
@@ -0,0 +1,12 @@
+# char_rnn
+
+
+```{eval-rst}
+
+.. autoclass:: hanlp.layers.embeddings.char_rnn.CharRNN
+ :members:
+
+.. autoclass:: hanlp.layers.embeddings.char_rnn.CharRNNEmbedding
+ :members:
+
+```
diff --git a/docs/api/hanlp/layers/embeddings/embedding.md b/docs/api/hanlp/layers/embeddings/embedding.md
new file mode 100644
index 000000000..493f0d856
--- /dev/null
+++ b/docs/api/hanlp/layers/embeddings/embedding.md
@@ -0,0 +1,15 @@
+# embedding
+
+
+```{eval-rst}
+
+.. autoclass:: hanlp.layers.embeddings.embedding.Embedding
+ :members:
+
+.. autoclass:: hanlp.layers.embeddings.embedding.ConcatModuleList
+ :members:
+
+.. autoclass:: hanlp.layers.embeddings.embedding.EmbeddingList
+ :members:
+
+```
diff --git a/docs/api/hanlp/layers/embeddings/fasttext.md b/docs/api/hanlp/layers/embeddings/fasttext.md
new file mode 100644
index 000000000..5e8c60b4e
--- /dev/null
+++ b/docs/api/hanlp/layers/embeddings/fasttext.md
@@ -0,0 +1,11 @@
+# fasttext
+
+```{eval-rst}
+
+.. autoclass:: hanlp.layers.embeddings.fast_text.FastTextEmbedding
+ :members:
+
+.. autoclass:: hanlp.layers.embeddings.fast_text.FastTextEmbeddingModule
+ :members:
+
+```
diff --git a/docs/api/hanlp/layers/embeddings/index.md b/docs/api/hanlp/layers/embeddings/index.md
new file mode 100644
index 000000000..a3650d272
--- /dev/null
+++ b/docs/api/hanlp/layers/embeddings/index.md
@@ -0,0 +1,11 @@
+# embeddings
+
+```{toctree}
+embedding
+word2vec
+fasttext
+char_cnn
+char_rnn
+transformer
+```
+
diff --git a/docs/api/hanlp/layers/embeddings/transformer.md b/docs/api/hanlp/layers/embeddings/transformer.md
new file mode 100644
index 000000000..8fcbe4c0f
--- /dev/null
+++ b/docs/api/hanlp/layers/embeddings/transformer.md
@@ -0,0 +1,12 @@
+# transformer
+
+
+```{eval-rst}
+
+.. autoclass:: hanlp.layers.embeddings.contextual_word_embedding.ContextualWordEmbedding
+ :members:
+
+.. autoclass:: hanlp.layers.embeddings.contextual_word_embedding.ContextualWordEmbeddingModule
+ :members:
+
+```
diff --git a/docs/api/hanlp/layers/embeddings/word2vec.md b/docs/api/hanlp/layers/embeddings/word2vec.md
new file mode 100644
index 000000000..07883a704
--- /dev/null
+++ b/docs/api/hanlp/layers/embeddings/word2vec.md
@@ -0,0 +1,11 @@
+# word2vec
+
+```{eval-rst}
+
+.. autoclass:: hanlp.layers.embeddings.word2vec.Word2VecEmbedding
+ :members:
+
+.. autoclass:: hanlp.layers.embeddings.word2vec.Word2VecEmbeddingModule
+ :members:
+
+```
diff --git a/docs/api/hanlp/layers/index.md b/docs/api/hanlp/layers/index.md
new file mode 100644
index 000000000..4fdc68a82
--- /dev/null
+++ b/docs/api/hanlp/layers/index.md
@@ -0,0 +1,8 @@
+# layers
+
+```{toctree}
+embeddings/index
+transformers/index
+decoders/index
+```
+
diff --git a/docs/api/hanlp/layers/transformers/encoder.md b/docs/api/hanlp/layers/transformers/encoder.md
new file mode 100644
index 000000000..771355511
--- /dev/null
+++ b/docs/api/hanlp/layers/transformers/encoder.md
@@ -0,0 +1,9 @@
+# encoder
+
+
+```{eval-rst}
+
+.. autoclass:: hanlp.layers.transformers.encoder.TransformerEncoder
+ :members:
+
+```
diff --git a/docs/api/hanlp/layers/transformers/index.md b/docs/api/hanlp/layers/transformers/index.md
new file mode 100644
index 000000000..bab7edb21
--- /dev/null
+++ b/docs/api/hanlp/layers/transformers/index.md
@@ -0,0 +1,7 @@
+# transformers
+
+```{toctree}
+encoder
+tokenizer
+```
+
diff --git a/docs/api/hanlp/layers/transformers/tokenizer.md b/docs/api/hanlp/layers/transformers/tokenizer.md
new file mode 100644
index 000000000..36a934d6f
--- /dev/null
+++ b/docs/api/hanlp/layers/transformers/tokenizer.md
@@ -0,0 +1,9 @@
+# tokenizer
+
+
+```{eval-rst}
+
+.. autoclass:: hanlp.transform.transformer_tokenizer.TransformerSequenceTokenizer
+ :members:
+
+```
diff --git a/docs/api/hanlp/pretrained/dep.md b/docs/api/hanlp/pretrained/dep.md
new file mode 100644
index 000000000..6d4373470
--- /dev/null
+++ b/docs/api/hanlp/pretrained/dep.md
@@ -0,0 +1,8 @@
+# dep
+
+```{eval-rst}
+
+.. automodule:: hanlp.pretrained.dep
+ :members:
+
+```
\ No newline at end of file
diff --git a/docs/api/hanlp/pretrained/eos.md b/docs/api/hanlp/pretrained/eos.md
new file mode 100644
index 000000000..a804aeb05
--- /dev/null
+++ b/docs/api/hanlp/pretrained/eos.md
@@ -0,0 +1,9 @@
+# eos
+
+
+```{eval-rst}
+
+.. automodule:: hanlp.pretrained.eos
+ :members:
+
+```
\ No newline at end of file
diff --git a/docs/api/hanlp/pretrained/fasttext.md b/docs/api/hanlp/pretrained/fasttext.md
new file mode 100644
index 000000000..58ad516b8
--- /dev/null
+++ b/docs/api/hanlp/pretrained/fasttext.md
@@ -0,0 +1,8 @@
+# fasttext
+
+```{eval-rst}
+
+.. automodule:: hanlp.pretrained.fasttext
+ :members:
+
+```
\ No newline at end of file
diff --git a/docs/api/hanlp/pretrained/glove.md b/docs/api/hanlp/pretrained/glove.md
new file mode 100644
index 000000000..f43841531
--- /dev/null
+++ b/docs/api/hanlp/pretrained/glove.md
@@ -0,0 +1,8 @@
+# glove
+
+```{eval-rst}
+
+.. automodule:: hanlp.pretrained.glove
+ :members:
+
+```
\ No newline at end of file
diff --git a/docs/api/hanlp/pretrained/index.md b/docs/api/hanlp/pretrained/index.md
new file mode 100644
index 000000000..f79247ae9
--- /dev/null
+++ b/docs/api/hanlp/pretrained/index.md
@@ -0,0 +1,21 @@
+# pretrained
+
+```{eval-rst}
+NLP components grouped by tasks. For each task, we provide at least one ``torch.utils.data.Dataset`` compatible class
+and several open-source resources. Each of them is stored in a Python constant which can be fetched
+using :meth:`hanlp.load`.
+```
+
+```{toctree}
+mtl
+eos
+tok
+pos
+ner
+dep
+sdp
+word2vec
+glove
+fasttext
+```
+
diff --git a/docs/api/hanlp/pretrained/mtl.md b/docs/api/hanlp/pretrained/mtl.md
new file mode 100644
index 000000000..a81edb65c
--- /dev/null
+++ b/docs/api/hanlp/pretrained/mtl.md
@@ -0,0 +1,8 @@
+# mtl
+
+```{eval-rst}
+
+.. automodule:: hanlp.pretrained.mtl
+ :members:
+
+```
\ No newline at end of file
diff --git a/docs/api/hanlp/pretrained/ner.md b/docs/api/hanlp/pretrained/ner.md
new file mode 100644
index 000000000..26b1d078b
--- /dev/null
+++ b/docs/api/hanlp/pretrained/ner.md
@@ -0,0 +1,8 @@
+# ner
+
+```{eval-rst}
+
+.. automodule:: hanlp.pretrained.ner
+ :members:
+
+```
\ No newline at end of file
diff --git a/docs/api/hanlp/pretrained/pos.md b/docs/api/hanlp/pretrained/pos.md
new file mode 100644
index 000000000..f9d61d6ff
--- /dev/null
+++ b/docs/api/hanlp/pretrained/pos.md
@@ -0,0 +1,8 @@
+# pos
+
+```{eval-rst}
+
+.. automodule:: hanlp.pretrained.pos
+ :members:
+
+```
\ No newline at end of file
diff --git a/docs/api/hanlp/pretrained/sdp.md b/docs/api/hanlp/pretrained/sdp.md
new file mode 100644
index 000000000..f03b950a7
--- /dev/null
+++ b/docs/api/hanlp/pretrained/sdp.md
@@ -0,0 +1,8 @@
+# sdp
+
+```{eval-rst}
+
+.. automodule:: hanlp.pretrained.sdp
+ :members:
+
+```
\ No newline at end of file
diff --git a/docs/api/hanlp/pretrained/tok.md b/docs/api/hanlp/pretrained/tok.md
new file mode 100644
index 000000000..784013daa
--- /dev/null
+++ b/docs/api/hanlp/pretrained/tok.md
@@ -0,0 +1,8 @@
+# tok
+
+```{eval-rst}
+
+.. automodule:: hanlp.pretrained.tok
+ :members:
+
+```
\ No newline at end of file
diff --git a/docs/api/hanlp/pretrained/word2vec.md b/docs/api/hanlp/pretrained/word2vec.md
new file mode 100644
index 000000000..b880d7095
--- /dev/null
+++ b/docs/api/hanlp/pretrained/word2vec.md
@@ -0,0 +1,8 @@
+# word2vec
+
+```{eval-rst}
+
+.. automodule:: hanlp.pretrained.word2vec
+ :members:
+
+```
\ No newline at end of file
diff --git a/docs/api/hanlp/utils/index.md b/docs/api/hanlp/utils/index.md
new file mode 100644
index 000000000..f6ab20ec0
--- /dev/null
+++ b/docs/api/hanlp/utils/index.md
@@ -0,0 +1,7 @@
+# utils
+
+Utilities.
+
+```{toctree}
+io_util
+```
diff --git a/docs/api/hanlp/utils/io_util.md b/docs/api/hanlp/utils/io_util.md
new file mode 100644
index 000000000..80a5b6b3d
--- /dev/null
+++ b/docs/api/hanlp/utils/io_util.md
@@ -0,0 +1,10 @@
+# io_util
+
+```{eval-rst}
+
+.. currentmodule:: hanlp.utils
+
+.. automodule:: hanlp.utils.io_util
+ :members:
+
+```
diff --git a/docs/api/restful.rst b/docs/api/restful.rst
new file mode 100644
index 000000000..d243c4cfa
--- /dev/null
+++ b/docs/api/restful.rst
@@ -0,0 +1,11 @@
+.. _api/hanlp_restful:
+
+hanlp_restful
+====================
+
+.. currentmodule:: hanlp_restful
+
+.. autoclass:: HanLPClient
+ :members:
+ :special-members:
+ :exclude-members: __init__, __repr__, __weakref__
\ No newline at end of file
diff --git a/docs/api/restful_java.md b/docs/api/restful_java.md
new file mode 100644
index 000000000..602c0bc66
--- /dev/null
+++ b/docs/api/restful_java.md
@@ -0,0 +1,21 @@
+# Java RESTful API
+
+Add the following dependency into the `pom.xml` file of your project.
+
+```xml
+
+ com.hankcs.hanlp.restful
+ hanlp-restful
+ 0.0.2
+
+```
+
+Obtain an `auth` from any compatible service provider, then initiate a `HanLPClient` and call its `parse`
+interface.
+
+```java
+HanLPClient client = new HanLPClient("https://hanlp.hankcs.com/api", null); // Replace null with your auth
+System.out.println(client.parse("2021年HanLPv2.1为生产环境带来次世代最先进的多语种NLP技术。英首相与特朗普通电话讨论华为与苹果公司。"));
+```
+
+Refer to our testcases and [data format](../data_format) for more details.
\ No newline at end of file
diff --git a/docs/api/trie/dictionary.md b/docs/api/trie/dictionary.md
new file mode 100644
index 000000000..bf770517a
--- /dev/null
+++ b/docs/api/trie/dictionary.md
@@ -0,0 +1,11 @@
+# dictionary
+
+```{eval-rst}
+.. currentmodule:: hanlp_trie
+
+.. autoclass:: hanlp_trie.dictionary.DictInterface
+ :members:
+
+.. autoclass:: hanlp_trie.dictionary.TrieDict
+ :members:
+```
diff --git a/docs/api/trie/index.md b/docs/api/trie/index.md
new file mode 100644
index 000000000..4a32b7ab2
--- /dev/null
+++ b/docs/api/trie/index.md
@@ -0,0 +1,9 @@
+# hanlp_trie
+
+HanLP trie/dictionary interface and referential implementation.
+
+```{toctree}
+trie
+dictionary
+```
+
diff --git a/docs/api/trie/trie.md b/docs/api/trie/trie.md
new file mode 100644
index 000000000..d99c718e9
--- /dev/null
+++ b/docs/api/trie/trie.md
@@ -0,0 +1,11 @@
+# trie
+
+```{eval-rst}
+.. currentmodule:: hanlp_trie
+
+.. autoclass:: hanlp_trie.trie.Node
+ :members:
+
+.. autoclass:: hanlp_trie.trie.Trie
+ :members:
+```
diff --git a/docs/conf.py b/docs/conf.py
new file mode 100644
index 000000000..0806d045a
--- /dev/null
+++ b/docs/conf.py
@@ -0,0 +1,156 @@
+# -- Project information -----------------------------------------------------
+import os
+import sys
+import os
+
+sys.path.append(os.path.abspath('..'))
+sys.path.append(os.path.abspath('../plugins/hanlp_common'))
+sys.path.append(os.path.abspath('../plugins/hanlp_trie'))
+sys.path.append(os.path.abspath('../plugins/hanlp_restful'))
+import hanlp
+
+project = 'HanLP'
+copyright = '2020, hankcs'
+author = 'hankcs'
+
+# The short X.Y version.
+version = hanlp.__version__
+# The full version, including alpha/beta/rc tags.
+release = hanlp.__version__
+
+# The language for content autogenerated by Sphinx. Refer to documentation
+# for a list of supported languages.
+language = 'en'
+
+master_doc = "index"
+
+# -- General configuration ---------------------------------------------------
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+extensions = [
+ "myst_nb",
+ "sphinx_copybutton",
+ "sphinx_togglebutton",
+ "sphinxcontrib.bibtex",
+ 'sphinx_astrorefs', # astrophysics style, similar to ACL
+ "sphinx_thebe",
+ "sphinx.ext.autodoc",
+ "sphinx.ext.intersphinx",
+ "sphinx.ext.viewcode",
+ "ablog",
+ 'sphinx.ext.napoleon',
+]
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ["_templates"]
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# This pattern also affects html_static_path and html_extra_path.
+exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
+
+intersphinx_mapping = {
+ "python": ("https://docs.python.org/3.8", None),
+ "sphinx": ("https://www.sphinx-doc.org/en/3.x", None),
+}
+nitpick_ignore = [
+ ("py:class", "docutils.nodes.document"),
+ ("py:class", "docutils.parsers.rst.directives.body.Sidebar"),
+]
+autoclass_content = 'both'
+
+numfig = True
+
+myst_admonition_enable = True
+myst_deflist_enable = True
+myst_url_schemes = ("http", "https", "mailto")
+panels_add_bootstrap_css = False
+
+# -- Options for HTML output -------------------------------------------------
+
+# The theme to use for HTML and HTML Help pages. See the documentation for
+# a list of builtin themes.
+#
+html_theme = "sphinx_book_theme"
+html_title = "HanLP documentation"
+html_logo = "_static/logo.png"
+html_favicon = "_static/favicon.png"
+html_copy_source = True
+html_sourcelink_suffix = ""
+
+html_sidebars = {
+ "reference/blog/*": [
+ "sidebar-search-bs.html",
+ "postcard.html",
+ "recentposts.html",
+ "tagcloud.html",
+ "categories.html",
+ "archives.html",
+ "sbt-sidebar-nav.html",
+ "sbt-sidebar-footer.html",
+ ]
+}
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+html_static_path = ["_static"]
+jupyter_execute_notebooks = "cache"
+thebe_config = {
+ "repository_url": "https://github.com/binder-examples/jupyter-stacks-datascience",
+ "repository_branch": "master",
+}
+
+html_theme_options = {
+ "theme_dev_mode": True,
+ "path_to_docs": "docs",
+ "repository_url": "https://github.com/hankcs/HanLP",
+ # "repository_branch": "gh-pages", # For testing
+ # "launch_buttons": {
+ # # "binderhub_url": "https://mybinder.org",
+ # # "jupyterhub_url": "https://datahub.berkeley.edu", # For testing
+ # "colab_url": "https://colab.research.google.com/",
+ # "notebook_interface": "jupyterlab",
+ # "thebe": True,
+ # },
+ "use_edit_page_button": True,
+ "use_issues_button": True,
+ "use_repository_button": True,
+ "use_download_button": True,
+ # For testing
+ # "home_page_in_toc": True,
+ # "single_page": True,
+ # "extra_footer": "Test ", # DEPRECATED KEY
+ # "extra_navbar": "Test ",
+}
+html_baseurl = "https://sphinx-book-theme.readthedocs.io/en/latest/"
+
+# -- ABlog config -------------------------------------------------
+blog_path = "reference/blog"
+blog_post_pattern = "reference/blog/*.md"
+blog_baseurl = "https://sphinx-book-theme.readthedocs.io"
+fontawesome_included = True
+post_auto_image = 1
+post_auto_excerpt = 2
+execution_show_tb = "READTHEDOCS" in os.environ
+
+# Localization
+nb_render_priority = {
+ "gettext": (
+ "application/vnd.jupyter.widget-view+json",
+ "application/javascript",
+ "text/html",
+ "image/svg+xml",
+ "image/png",
+ "image/jpeg",
+ "text/markdown",
+ "text/latex",
+ "text/plain",
+ )
+}
+
+locale_dirs = ['locale/']
+
+# bibtex
+bibtex_default_style = 'unsrtalpha'
diff --git a/docs/configure.md b/docs/configure.md
new file mode 100644
index 000000000..e555e3999
--- /dev/null
+++ b/docs/configure.md
@@ -0,0 +1,53 @@
+# Configuration
+
+## Customize ``HANLP_HOME``
+
+ `HANLP_HOME` is an environment variable which you can customize to any path you like. By default, `HANLP_HOME` resolves to `~/.hanlp` and `%appdata%\hanlp` on *nix and Windows respectively. If you want to temporally redirect `HANLP_HOME` to a different location, say `/data/hanlp`, the following shell command can be very helpful.
+
+```bash
+export HANLP_HOME=/data/hanlp
+```
+
+## Using GPUs
+
+By default, HanLP tries to use the least occupied GPU so that mostly you don't need to worry about it, HanLP makes the best choice for you. This behavior is very useful when you're using a public server shared across your lab or company with your collegues.
+
+HanLP also honors the ``CUDA_VISIBLE_DEVICES`` used by PyTorch and TensorFlow to limit which devices HanLP can choose from. For example, the following command will only keep the `0`th and `1`th GPU.
+
+```bash
+export CUDA_VISIBLE_DEVICES=0,1
+```
+
+```{eval-rst}
+If you need fine grained control over each component, ``hanlp.load(..., devices=...)`` is what you're looking for.
+See documents for :meth:`hanlp.load`.
+```
+
+:::{seealso}
+
+For deep learning beginners, you might need to learn how to set up a working GPU environment first. Here are some
+resources.
+
+- [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit)
+ - It's a good practice to install the driver inside a CUDA package.
+- [PyTorch](https://pytorch.org/get-started/locally/)
+ - If no existing PyTorch found, `pip install hanlp` will have the CPU-only PyTorch installed, which is universal and assumes no GPU or CUDA dependencies.
+ - You will need to install a GPU-enabled PyTorch according to your CUDA and OS versions.
+- Cloud servers
+ - There are many cloud service providing out-of-box deep learning images. HanLP works fine on these platforms.
+ They could save your time and efforts.
+
+:::
+
+## Using mirror sites
+
+By default, we maintain a global CDN to host the models. However, in some regions the downloading speed can
+be slow occasionally. If you happen to be in one of those regions, you can find some third party mirror sites
+on our [bbs](https://bbs.hankcs.com/). When you find a working URL, say
+[http://mirrors-hk.miduchina.com/hanlp/](http://mirrors-hk.miduchina.com/hanlp/) , you can set a `HANLP_URL`
+environment variable and HanLP will pick it up at the next startup.
+
+```bash
+export HANLP_URL=http://mirrors-hk.miduchina.com/hanlp/
+```
+
diff --git a/docs/contributing.md b/docs/contributing.md
new file mode 100644
index 000000000..9bc0062f6
--- /dev/null
+++ b/docs/contributing.md
@@ -0,0 +1,60 @@
+# Contributing Guide
+
+Thank you for being interested in contributing to the `HanLP`! You
+are awesome ✨.
+
+This guideline contains information about our conventions around coding style, pull request workflow, commit messages and more.
+
+This page contains information to help you get started with development on this
+project.
+
+## Development
+
+### Set-up
+
+Get the source code of this project using git:
+
+```bash
+git clone https://github.com/hankcs/HanLP
+cd HanLP
+pip install -e plugins/hanlp_trie
+pip install -e plugins/hanlp_common
+pip install -e plugins/hanlp_restful
+pip install -e .
+```
+
+To work on this project, you need Python 3.6 or newer.
+
+### Running Tests
+
+This project has a test suite to ensure certain important APIs work properly. The tests can be run using:
+
+```console
+$ python plugins/hanlp_trie/tests/test_trie.py
+```
+
+:::{tip}
+It's hard to cover every API especially those of deep learning models, due to the limited computation resource of CI. However, we suggest all inference APIs to be tested at least.
+
+:::
+
+## Repository structure
+
+This repository is a split into a few critical folders:
+
+hanlp/
+: The HanLP core package, containing the Python code.
+
+plugins/
+: Contains codes shared across several individual packages or non core APIs.
+
+docs/
+: The documentation for HanLP, which is in markdown format mostly.
+: The build configuration is contained in `conf.py`.
+
+tests/
+: Testing infrastructure that uses `unittest` to ensure the output of API is what we expect it to be.
+
+.github/workflows/
+: Contains Continuous-integration (CI) workflows, run on commits/PRs to the GitHub repository.
+
diff --git a/docs/data_format.md b/docs/data_format.md
new file mode 100644
index 000000000..f8293a946
--- /dev/null
+++ b/docs/data_format.md
@@ -0,0 +1,104 @@
+---
+jupytext:
+ formats: ipynb,md:myst
+ text_representation:
+ extension: .md
+ format_name: myst
+ format_version: '0.8'
+ jupytext_version: 1.4.2
+kernelspec:
+ display_name: Python 3
+ language: python
+ name: python3
+---
+
+# Data Format
+
+
+## Input Format
+
+### RESTful Input
+
+#### Definition
+
+To make a RESTful call, one needs to send a `json` HTTP POST request to the server, which contains at least a `text`
+field or a `tokens` field. The input to RESTful API is very flexible. It can be one of the following 3 formats:
+
+1. It can be a document of raw `str` filled into `text`. The server will split it into sentences.
+1. It can be a `list` of sentences, each sentence is a raw `str`, filled into `text`.
+1. It can be a `list` of tokenized sentences, each sentence is a list of `str` typed tokens, filled into `tokens`.
+
+```{eval-rst}
+Additionally, fine-grained controls are performed with the arguments defined in
+:meth:`hanlp_restful.HanLPClient.parse`.
+```
+
+
+#### Examples
+
+```shell script
+curl -X POST "https://hanlp.hankcs.com/api/parse" \
+ -H "accept: application/json" -H "Content-Type: application/json"
+ -d "{\"text\":\"2021年HanLPv2.1为生产环境带来次世代最先进的多语种NLP技术。阿婆主来到北京立方庭参观自然语义科技公司。\",\"tokens\":null,\"tasks\":null,\"skip_tasks\":null,\"language\":null}"
+```
+
+### Model Input
+
+The input format to models is specified per model and per tasks. Generally speaking, if a model has no tokenizer built in, then its input is
+a sentence in `list[str]` form, or multiple such sentences nested in a `list`.
+
+If a model has a tokenizer built in, each sentence is in `str` form.
+Additionally, you can use `skip_tasks='tok*'` to ask the model to use your tokenized inputs instead of tokenizing
+them, in which case, each of your sentence needs to be in `list[str]` form, as if there is no tokenizer.
+
+```{eval-rst}
+For any model, its input is of sentence level, which means you have to split a document into sentences beforehand.
+You may want to try :class:`~hanlp.components.eos.ngram.NgramSentenceBoundaryDetector` for sentence splitting.
+```
+
+## Output Format
+
+
+```{eval-rst}
+The outputs of both :class:`~hanlp_restful.HanLPClient` and
+:class:`~hanlp.components.mtl.multi_task_learning.MultiTaskLearning` are unified as the same
+:class:`~hanlp_common.document.Document` format.
+```
+
+For example, the following RESTful codes will output such an instance.
+
+```{code-cell} ipython3
+:tags: [output_scroll]
+from hanlp_restful import HanLPClient
+HanLP = HanLPClient('https://hanlp.hankcs.com/api', auth=None) # Fill in your auth
+print(HanLP('2021年HanLPv2.1为生产环境带来次世代最先进的多语种NLP技术。英首相与特朗普通电话讨论华为与苹果公司。'))
+```
+
+The outputs above is represented as a `json` dictionary where each key is a model name and its value is
+the output of the corresponding model.
+For each output, if it's a nested `list` then it contains multiple sentences otherwise it's just one single sentence.
+
+We make the following naming convention of NLP tasks, each consists of 3 letters.
+
+````{margin} **How about annotations?**
+```{seealso}
+Each NLP task can exploit multiple datasets with their annotations, see our [annotations](annotations/index) for details.
+```
+````
+
+### Naming Convention
+
+| key | Task | Chinese |
+| ---- | ------------------------------------------------------------ | ------------ |
+| tok | Tokenization. Each element is a token. | 分词 |
+| pos | Part-of-Speech Tagging. Each element is a tag. | 词性标注 |
+| lem | Lemmatization. Each element is a lemma. | 词干提取 |
+| fea | Features of Universal Dependencies. Each element is a feature. | 词法语法特征 |
+| ner | Named Entity Recognition. Each element is a tuple of `(entity, type, begin, end)`, where `begin` and `end` are exclusive offsets. | 命名实体识别 |
+| dep | Dependency Parsing. Each element is a tuple of `(head, relation)` where `head` starts with index `0` and `ROOT` has index `-1`. | 依存句法分析 |
+| con | Constituency Parsing. Each list is a bracketed constituent. | 短语成分分析 |
+| srl | Semantic Role Labeling. Similar to `ner`, each element is tuple (arg/pred, label, begin, end), where the predicate is labeled as `PRED`. | 语义角色标注 |
+| sdp | Semantic Dependency Parsing. Similar to `dep`, however each token can have zero or zero or multiple heads and corresponding relations. | 语义依存分析 |
+| amr | Abstract Meaning Representation. Each AMR graph is represented as list of logical triples. See [AMR guidelines](https://github.com/amrisi/amr-guidelines/blob/master/amr.md#example). | 抽象意义表示 |
+
+When there are multiple models performing the same task, the keys are appended with a secondary identifier. For example, `tok/fine` and `tok/corase` means a fine-grained tokenization model and a coarse-grained one.
\ No newline at end of file
diff --git a/docs/index.md b/docs/index.md
new file mode 100644
index 000000000..c645ba87d
--- /dev/null
+++ b/docs/index.md
@@ -0,0 +1,65 @@
+# HanLP: Han Language Processing
+
+[![GitHub stars](https://img.shields.io/github/stars/hankcs/HanLP)](https://github.com/hankcs/HanLP/stargazers) [![GitHub forks](https://img.shields.io/github/forks/hankcs/HanLP)](https://github.com/hankcs/HanLP/network) ![pypi](https://img.shields.io/pypi/v/HanLP) [![Downloads](https://pepy.tech/badge/HanLP)](https://pepy.tech/project/HanLP) [![GitHub license](https://img.shields.io/github/license/hankcs/HanLP)](https://github.com/hankcs/HanLP/blob/master/LICENSE)
+
+The multilingual NLP library for researchers and companies, built on PyTorch and TensorFlow 2.x, for advancing
+state-of-the-art deep learning techniques in both academia and industry. HanLP was designed from day one to be
+efficient, user friendly and extendable. It comes with pretrained models for various human languages
+including English, Chinese and many others.
+
+
+
+## Tutorials
+
+```{toctree}
+:maxdepth: 1
+:caption: Introduction
+
+tutorial
+install
+configure
+data_format
+annotations/index
+contributing
+GitHub repository
+```
+
+## Python API
+
+```{toctree}
+:caption: Python API
+:maxdepth: 2
+
+api/hanlp/index
+api/common/index
+api/restful
+api/trie/index
+```
+
+## Java API
+
+```{toctree}
+:maxdepth: 1
+:caption: Java API
+
+1.x API
+api/restful_java
+```
+
+## References
+
+```{toctree}
+:caption: References
+:maxdepth: 2
+
+references
+```
+
+
+## Acknowledgements
+
+HanLPv2.1 is heavily inspired by [AllenNLP](https://allennlp.org/) and [SuPar](https://pypi.org/project/supar/).
+
+[pypi-badge]: https://img.shields.io/pypi/v/hanlp.svg
+[pypi-link]: https://pypi.org/project/hanlp
+
diff --git a/docs/install.md b/docs/install.md
new file mode 100644
index 000000000..bd84396fe
--- /dev/null
+++ b/docs/install.md
@@ -0,0 +1,20 @@
+# Install
+
+## Install Native Package
+
+The native package runs locally which can be installed via pip.
+
+```
+pip install hanlp
+```
+
+HanLP requires Python 3.6 or later. GPU/TPU is suggested but not mandatory. Depending on your preference, HanLP offers the following flavors:
+
+| Flavor | Description |
+| ------- | ------------------------------------------------------------ |
+| default | This installs the default version which delivers the most commonly used functionalities. However, some heavy dependencies like TensorFlow are not installed. |
+| full | For experts who seek to maximize the efficiency via TensorFlow, `pip install hanlp[full]` installs every dependency HanLP will use in production. |
+
+## Install models
+
+In short, you don't need to manually install any models. Instead, they are automatically downloaded to a directory called `HANLP_HOME` when you call `hanlp.load`.
diff --git a/docs/references.bib b/docs/references.bib
new file mode 100644
index 000000000..ae4e1eb9f
--- /dev/null
+++ b/docs/references.bib
@@ -0,0 +1,351 @@
+%% This BibTeX bibliography file was created using BibDesk.
+%% https://bibdesk.sourceforge.io/
+
+%% Created for hankcs at 2020-12-31 15:16:06 -0500
+
+
+%% Saved with string encoding Unicode (UTF-8)
+
+
+
+@inproceedings{pennington-etal-2014-glove,
+ address = {Doha, Qatar},
+ author = {Pennington, Jeffrey and Socher, Richard and Manning, Christopher},
+ booktitle = {Proceedings of the 2014 Conference on Empirical Methods in Natural Language Processing ({EMNLP})},
+ date-added = {2020-12-31 15:07:29 -0500},
+ date-modified = {2020-12-31 15:07:29 -0500},
+ doi = {10.3115/v1/D14-1162},
+ month = oct,
+ pages = {1532--1543},
+ publisher = {Association for Computational Linguistics},
+ title = {{G}lo{V}e: Global Vectors for Word Representation},
+ url = {https://www.aclweb.org/anthology/D14-1162},
+ year = {2014},
+ Bdsk-Url-1 = {https://www.aclweb.org/anthology/D14-1162},
+ Bdsk-Url-2 = {https://doi.org/10.3115/v1/D14-1162}}
+
+@incollection{he2018dual,
+ author = {He, Han and Wu, Lei and Yang, Xiaokun and Yan, Hua and Gao, Zhimin and Feng, Yi and Townsend, George},
+ booktitle = {Information Technology-New Generations},
+ date-added = {2020-12-31 15:03:58 -0500},
+ date-modified = {2020-12-31 15:03:58 -0500},
+ pages = {421--426},
+ publisher = {Springer},
+ title = {Dual long short-term memory networks for sub-character representation learning},
+ year = {2018}}
+
+@inproceedings{devlin-etal-2019-bert,
+ abstract = {We introduce a new language representation model called BERT, which stands for Bidirectional Encoder Representations from Transformers. Unlike recent language representation models (Peters et al., 2018a; Radford et al., 2018), BERT is designed to pre-train deep bidirectional representations from unlabeled text by jointly conditioning on both left and right context in all layers. As a result, the pre-trained BERT model can be fine-tuned with just one additional output layer to create state-of-the-art models for a wide range of tasks, such as question answering and language inference, without substantial task-specific architecture modifications. BERT is conceptually simple and empirically powerful. It obtains new state-of-the-art results on eleven natural language processing tasks, including pushing the GLUE score to 80.5 (7.7 point absolute improvement), MultiNLI accuracy to 86.7{\%} (4.6{\%} absolute improvement), SQuAD v1.1 question answering Test F1 to 93.2 (1.5 point absolute improvement) and SQuAD v2.0 Test F1 to 83.1 (5.1 point absolute improvement).},
+ address = {Minneapolis, Minnesota},
+ author = {Devlin, Jacob and Chang, Ming-Wei and Lee, Kenton and Toutanova, Kristina},
+ booktitle = {Proceedings of the 2019 Conference of the North {A}merican Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers)},
+ date-added = {2020-12-31 14:46:54 -0500},
+ date-modified = {2020-12-31 14:46:54 -0500},
+ doi = {10.18653/v1/N19-1423},
+ month = jun,
+ pages = {4171--4186},
+ publisher = {Association for Computational Linguistics},
+ title = {{BERT}: Pre-training of Deep Bidirectional Transformers for Language Understanding},
+ url = {https://www.aclweb.org/anthology/N19-1423},
+ year = {2019},
+ Bdsk-Url-1 = {https://www.aclweb.org/anthology/N19-1423},
+ Bdsk-Url-2 = {https://doi.org/10.18653/v1/N19-1423}}
+
+@inproceedings{Lan2020ALBERT:,
+ author = {Zhenzhong Lan and Mingda Chen and Sebastian Goodman and Kevin Gimpel and Piyush Sharma and Radu Soricut},
+ booktitle = {International Conference on Learning Representations},
+ date-added = {2020-12-31 14:44:52 -0500},
+ date-modified = {2020-12-31 14:44:52 -0500},
+ title = {ALBERT: A Lite BERT for Self-supervised Learning of Language Representations},
+ url = {https://openreview.net/forum?id=H1eA7AEtvS},
+ year = {2020},
+ Bdsk-Url-1 = {https://openreview.net/forum?id=H1eA7AEtvS}}
+
+@inproceedings{wang-xu-2017-convolutional,
+ abstract = {Character-based sequence labeling framework is flexible and efficient for Chinese word segmentation (CWS). Recently, many character-based neural models have been applied to CWS. While they obtain good performance, they have two obvious weaknesses. The first is that they heavily rely on manually designed bigram feature, i.e. they are not good at capturing $n$-gram features automatically. The second is that they make no use of full word information. For the first weakness, we propose a convolutional neural model, which is able to capture rich $n$-gram features without any feature engineering. For the second one, we propose an effective approach to integrate the proposed model with word embeddings. We evaluate the model on two benchmark datasets: PKU and MSR. Without any feature engineering, the model obtains competitive performance {---} 95.7{\%} on PKU and 97.3{\%} on MSR. Armed with word embeddings, the model achieves state-of-the-art performance on both datasets {---} 96.5{\%} on PKU and 98.0{\%} on MSR, without using any external labeled resource.},
+ address = {Taipei, Taiwan},
+ author = {Wang, Chunqi and Xu, Bo},
+ booktitle = {Proceedings of the Eighth International Joint Conference on Natural Language Processing (Volume 1: Long Papers)},
+ date-added = {2020-12-31 14:42:35 -0500},
+ date-modified = {2020-12-31 14:42:35 -0500},
+ month = nov,
+ pages = {163--172},
+ publisher = {Asian Federation of Natural Language Processing},
+ title = {Convolutional Neural Network with Word Embeddings for {C}hinese Word Segmentation},
+ url = {https://www.aclweb.org/anthology/I17-1017},
+ year = {2017},
+ Bdsk-Url-1 = {https://www.aclweb.org/anthology/I17-1017}}
+
+@inproceedings{bertbaseline,
+ author = {He, Han and Choi, Jinho},
+ booktitle = {The Thirty-Third International Flairs Conference},
+ date-added = {2020-12-31 14:32:43 -0500},
+ date-modified = {2020-12-31 14:32:43 -0500},
+ title = {Establishing Strong Baselines for the New Decade: Sequence Tagging, Syntactic and Semantic Parsing with BERT},
+ year = {2020}}
+
+@article{bojanowski2017enriching,
+ author = {Bojanowski, Piotr and Grave, Edouard and Joulin, Armand and Mikolov, Tomas},
+ date-added = {2020-12-25 22:31:59 -0500},
+ date-modified = {2020-12-25 22:31:59 -0500},
+ issn = {2307-387X},
+ journal = {Transactions of the Association for Computational Linguistics},
+ pages = {135--146},
+ title = {Enriching Word Vectors with Subword Information},
+ volume = {5},
+ year = {2017}}
+
+@article{collins-koo-2005-discriminative,
+ author = {Collins, Michael and Koo, Terry},
+ date-added = {2020-12-25 17:25:59 -0500},
+ date-modified = {2020-12-25 17:25:59 -0500},
+ doi = {10.1162/0891201053630273},
+ journal = {Computational Linguistics},
+ number = {1},
+ pages = {25--70},
+ title = {Discriminative Reranking for Natural Language Parsing},
+ url = {https://www.aclweb.org/anthology/J05-1003},
+ volume = {31},
+ year = {2005},
+ Bdsk-Url-1 = {https://www.aclweb.org/anthology/J05-1003},
+ Bdsk-Url-2 = {https://doi.org/10.1162/0891201053630273}}
+
+@inproceedings{zhang-clark-2008-tale,
+ address = {Honolulu, Hawaii},
+ author = {Zhang, Yue and Clark, Stephen},
+ booktitle = {Proceedings of the 2008 Conference on Empirical Methods in Natural Language Processing},
+ date-added = {2020-12-25 15:10:10 -0500},
+ date-modified = {2020-12-25 15:10:10 -0500},
+ month = oct,
+ pages = {562--571},
+ publisher = {Association for Computational Linguistics},
+ title = {A Tale of Two Parsers: {I}nvestigating and Combining Graph-based and Transition-based Dependency Parsing},
+ url = {https://www.aclweb.org/anthology/D08-1059},
+ year = {2008},
+ Bdsk-Url-1 = {https://www.aclweb.org/anthology/D08-1059}}
+
+@inproceedings{pradhan-etal-2012-conll,
+ address = {Jeju Island, Korea},
+ author = {Pradhan, Sameer and Moschitti, Alessandro and Xue, Nianwen and Uryupina, Olga and Zhang, Yuchen},
+ booktitle = {Joint Conference on {EMNLP} and {C}o{NLL} - Shared Task},
+ date-added = {2020-12-24 23:42:41 -0500},
+ date-modified = {2020-12-24 23:42:41 -0500},
+ month = jul,
+ pages = {1--40},
+ publisher = {Association for Computational Linguistics},
+ title = {{C}o{NLL}-2012 Shared Task: Modeling Multilingual Unrestricted Coreference in {O}nto{N}otes},
+ url = {https://www.aclweb.org/anthology/W12-4501},
+ year = {2012},
+ Bdsk-Url-1 = {https://www.aclweb.org/anthology/W12-4501}}
+
+@inproceedings{levow-2006-third,
+ address = {Sydney, Australia},
+ author = {Levow, Gina-Anne},
+ booktitle = {Proceedings of the Fifth {SIGHAN} Workshop on {C}hinese Language Processing},
+ date-added = {2020-12-24 23:21:14 -0500},
+ date-modified = {2020-12-24 23:21:14 -0500},
+ month = jul,
+ pages = {108--117},
+ publisher = {Association for Computational Linguistics},
+ title = {The Third International {C}hinese Language Processing Bakeoff: Word Segmentation and Named Entity Recognition},
+ url = {https://www.aclweb.org/anthology/W06-0115},
+ year = {2006},
+ Bdsk-Url-1 = {https://www.aclweb.org/anthology/W06-0115}}
+
+@inproceedings{tjong-kim-sang-de-meulder-2003-introduction,
+ author = {Tjong Kim Sang, Erik F. and De Meulder, Fien},
+ booktitle = {Proceedings of the Seventh Conference on Natural Language Learning at {HLT}-{NAACL} 2003},
+ date-added = {2020-12-24 23:19:00 -0500},
+ date-modified = {2020-12-24 23:19:00 -0500},
+ pages = {142--147},
+ title = {Introduction to the {C}o{NLL}-2003 Shared Task: Language-Independent Named Entity Recognition},
+ url = {https://www.aclweb.org/anthology/W03-0419},
+ year = {2003},
+ Bdsk-Url-1 = {https://www.aclweb.org/anthology/W03-0419}}
+
+@inproceedings{koehn2005europarl,
+ author = {Koehn, Philipp},
+ booktitle = {MT summit},
+ date-added = {2020-12-24 23:06:03 -0500},
+ date-modified = {2020-12-24 23:06:03 -0500},
+ organization = {Citeseer},
+ pages = {79--86},
+ title = {Europarl: A parallel corpus for statistical machine translation},
+ volume = {5},
+ year = {2005}}
+
+@inproceedings{Schweter:Ahmed:2019,
+ author = {Stefan Schweter and Sajawel Ahmed},
+ booktitle = {Proceedings of the 15th Conference on Natural Language Processing (KONVENS)},
+ date-added = {2020-12-24 23:03:23 -0500},
+ date-modified = {2020-12-24 23:03:23 -0500},
+ location = {Erlangen, Germany},
+ note = {accepted},
+ title = {{Deep-EOS: General-Purpose Neural Networks for Sentence Boundary Detection}},
+ year = 2019}
+
+@incollection{he2019effective,
+ author = {He, Han and Wu, Lei and Yan, Hua and Gao, Zhimin and Feng, Yi and Townsend, George},
+ booktitle = {Smart Intelligent Computing and Applications},
+ date-added = {2020-12-24 19:35:03 -0500},
+ date-modified = {2020-12-24 19:35:03 -0500},
+ pages = {133--142},
+ publisher = {Springer},
+ title = {Effective neural solution for multi-criteria word segmentation},
+ year = {2019}}
+
+@inproceedings{dozat2017stanford,
+ author = {Dozat, Timothy and Qi, Peng and Manning, Christopher D},
+ booktitle = {Proceedings of the CoNLL 2017 Shared Task: Multilingual Parsing from Raw Text to Universal Dependencies},
+ date-added = {2020-12-24 15:02:18 -0500},
+ date-modified = {2020-12-24 15:02:18 -0500},
+ pages = {20--30},
+ title = {Stanford's graph-based neural dependency parser at the conll 2017 shared task},
+ year = {2017}}
+
+@inproceedings{he-etal-2018-jointly,
+ abstract = {Recent BIO-tagging-based neural semantic role labeling models are very high performing, but assume gold predicates as part of the input and cannot incorporate span-level features. We propose an end-to-end approach for jointly predicting all predicates, arguments spans, and the relations between them. The model makes independent decisions about what relationship, if any, holds between every possible word-span pair, and learns contextualized span representations that provide rich, shared input features for each decision. Experiments demonstrate that this approach sets a new state of the art on PropBank SRL without gold predicates.},
+ address = {Melbourne, Australia},
+ author = {He, Luheng and Lee, Kenton and Levy, Omer and Zettlemoyer, Luke},
+ booktitle = {Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)},
+ date-added = {2020-12-24 14:23:45 -0500},
+ date-modified = {2020-12-24 14:23:45 -0500},
+ doi = {10.18653/v1/P18-2058},
+ month = jul,
+ pages = {364--369},
+ publisher = {Association for Computational Linguistics},
+ title = {Jointly Predicting Predicates and Arguments in Neural Semantic Role Labeling},
+ url = {https://www.aclweb.org/anthology/P18-2058},
+ year = {2018},
+ Bdsk-Url-1 = {https://www.aclweb.org/anthology/P18-2058},
+ Bdsk-Url-2 = {https://doi.org/10.18653/v1/P18-2058}}
+
+@inproceedings{yu-etal-2020-named,
+ abstract = {Named Entity Recognition (NER) is a fundamental task in Natural Language Processing, concerned with identifying spans of text expressing references to entities. NER research is often focused on flat entities only (flat NER), ignoring the fact that entity references can be nested, as in [Bank of [China]] (Finkel and Manning, 2009). In this paper, we use ideas from graph-based dependency parsing to provide our model a global view on the input via a biaffine model (Dozat and Manning, 2017). The biaffine model scores pairs of start and end tokens in a sentence which we use to explore all spans, so that the model is able to predict named entities accurately. We show that the model works well for both nested and flat NER through evaluation on 8 corpora and achieving SoTA performance on all of them, with accuracy gains of up to 2.2 percentage points.},
+ address = {Online},
+ author = {Yu, Juntao and Bohnet, Bernd and Poesio, Massimo},
+ booktitle = {Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics},
+ date-added = {2020-12-24 13:35:09 -0500},
+ date-modified = {2020-12-24 13:35:09 -0500},
+ doi = {10.18653/v1/2020.acl-main.577},
+ month = jul,
+ pages = {6470--6476},
+ publisher = {Association for Computational Linguistics},
+ title = {Named Entity Recognition as Dependency Parsing},
+ url = {https://www.aclweb.org/anthology/2020.acl-main.577},
+ year = {2020},
+ Bdsk-Url-1 = {https://www.aclweb.org/anthology/2020.acl-main.577},
+ Bdsk-Url-2 = {https://doi.org/10.18653/v1/2020.acl-main.577}}
+
+@inproceedings{10.1145/1457838.1457895,
+ abstract = {Many computer applications require the storage of large amounts of information within the computer's memory where it will be readily available for reference and updating. Quite commonly, more storage space is required than is available in the computer's high-speed working memory. It is, therefore, a common practice to equip computers with magnetic tapes, disks, or drums, or a combination of these to provide additional storage. This additional storage is always slower in operation than the computer's working memory and therefore care must be taken when using it to avoid excessive operating time.},
+ address = {New York, NY, USA},
+ author = {De La Briandais, Rene},
+ booktitle = {Papers Presented at the the March 3-5, 1959, Western Joint Computer Conference},
+ date-added = {2020-12-24 13:07:31 -0500},
+ date-modified = {2020-12-24 13:07:31 -0500},
+ doi = {10.1145/1457838.1457895},
+ isbn = {9781450378659},
+ location = {San Francisco, California},
+ numpages = {4},
+ pages = {295--298},
+ publisher = {Association for Computing Machinery},
+ series = {IRE-AIEE-ACM '59 (Western)},
+ title = {File Searching Using Variable Length Keys},
+ url = {https://doi.org/10.1145/1457838.1457895},
+ year = {1959},
+ Bdsk-Url-1 = {https://doi.org/10.1145/1457838.1457895}}
+
+@article{lafferty2001conditional,
+ author = {Lafferty, John and McCallum, Andrew and Pereira, Fernando CN},
+ date-added = {2020-12-24 11:46:30 -0500},
+ date-modified = {2020-12-24 12:08:29 -0500},
+ journal = {Departmental Papers (CIS)},
+ title = {Conditional random fields: Probabilistic models for segmenting and labeling sequence data},
+ year = {2001}}
+
+@inproceedings{clark-etal-2019-bam,
+ abstract = {It can be challenging to train multi-task neural networks that outperform or even match their single-task counterparts. To help address this, we propose using knowledge distillation where single-task models teach a multi-task model. We enhance this training with teacher annealing, a novel method that gradually transitions the model from distillation to supervised learning, helping the multi-task model surpass its single-task teachers. We evaluate our approach by multi-task fine-tuning BERT on the GLUE benchmark. Our method consistently improves over standard single-task and multi-task training.},
+ address = {Florence, Italy},
+ author = {Clark, Kevin and Luong, Minh-Thang and Khandelwal, Urvashi and Manning, Christopher D. and Le, Quoc V.},
+ booktitle = {Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics},
+ date-added = {2020-12-24 11:26:54 -0500},
+ date-modified = {2020-12-24 11:26:54 -0500},
+ doi = {10.18653/v1/P19-1595},
+ month = jul,
+ pages = {5931--5937},
+ publisher = {Association for Computational Linguistics},
+ title = {{BAM}! Born-Again Multi-Task Networks for Natural Language Understanding},
+ url = {https://www.aclweb.org/anthology/P19-1595},
+ year = {2019},
+ Bdsk-Url-1 = {https://www.aclweb.org/anthology/P19-1595},
+ Bdsk-Url-2 = {https://doi.org/10.18653/v1/P19-1595}}
+
+@inproceedings{kondratyuk-straka-2019-75,
+ address = {Hong Kong, China},
+ author = {Kondratyuk, Dan and Straka, Milan},
+ booktitle = {Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)},
+ date-added = {2020-12-23 23:51:07 -0500},
+ date-modified = {2020-12-23 23:51:07 -0500},
+ pages = {2779--2795},
+ publisher = {Association for Computational Linguistics},
+ title = {75 Languages, 1 Model: Parsing Universal Dependencies Universally},
+ url = {https://www.aclweb.org/anthology/D19-1279},
+ year = {2019},
+ Bdsk-Url-1 = {https://www.aclweb.org/anthology/D19-1279}}
+
+@inproceedings{dozat:17a,
+ author = {Dozat, Timothy and Manning, Christopher D.},
+ booktitle = {Proceedings of the 5th International Conference on Learning Representations},
+ date-added = {2020-12-23 23:46:20 -0500},
+ date-modified = {2020-12-23 23:46:20 -0500},
+ series = {ICLR'17},
+ title = {{Deep Biaffine Attention for Neural Dependency Parsing}},
+ url = {https://openreview.net/pdf?id=Hk95PK9le},
+ year = {2017},
+ Bdsk-Url-1 = {http://arxiv.org/abs/1611.01734},
+ Bdsk-Url-2 = {https://openreview.net/pdf?id=Hk95PK9le}}
+
+@inproceedings{smith-smith-2007-probabilistic,
+ address = {Prague, Czech Republic},
+ author = {Smith, David A. and Smith, Noah A.},
+ booktitle = {Proceedings of the 2007 Joint Conference on Empirical Methods in Natural Language Processing and Computational Natural Language Learning ({EMNLP}-{C}o{NLL})},
+ date-added = {2020-12-23 21:46:06 -0500},
+ date-modified = {2020-12-23 21:46:06 -0500},
+ month = jun,
+ pages = {132--140},
+ publisher = {Association for Computational Linguistics},
+ title = {Probabilistic Models of Nonprojective Dependency Trees},
+ url = {https://www.aclweb.org/anthology/D07-1014},
+ year = {2007},
+ Bdsk-Url-1 = {https://www.aclweb.org/anthology/D07-1014}}
+
+@inproceedings{ijcai2020-560,
+ author = {Zhang, Yu and Zhou, Houquan and Li, Zhenghua},
+ booktitle = {Proceedings of the Twenty-Ninth International Joint Conference on Artificial Intelligence, {IJCAI-20}},
+ date-added = {2020-12-23 21:36:56 -0500},
+ date-modified = {2020-12-23 21:36:56 -0500},
+ doi = {10.24963/ijcai.2020/560},
+ editor = {Christian Bessiere},
+ month = {7},
+ note = {Main track},
+ pages = {4046--4053},
+ publisher = {International Joint Conferences on Artificial Intelligence Organization},
+ title = {Fast and Accurate Neural CRF Constituency Parsing},
+ url = {https://doi.org/10.24963/ijcai.2020/560},
+ year = {2020},
+ Bdsk-Url-1 = {https://doi.org/10.24963/ijcai.2020/560}}
+
+@inproceedings{buchholz-marsi-2006-conll,
+ address = {New York City},
+ author = {Buchholz, Sabine and Marsi, Erwin},
+ booktitle = {Proceedings of the Tenth Conference on Computational Natural Language Learning ({C}o{NLL}-X)},
+ date-added = {2020-12-22 22:57:41 -0500},
+ date-modified = {2020-12-22 22:57:41 -0500},
+ month = jun,
+ pages = {149--164},
+ publisher = {Association for Computational Linguistics},
+ title = {{C}o{NLL}-{X} Shared Task on Multilingual Dependency Parsing},
+ url = {https://www.aclweb.org/anthology/W06-2920},
+ year = {2006},
+ Bdsk-Url-1 = {https://www.aclweb.org/anthology/W06-2920}}
diff --git a/docs/references.rst b/docs/references.rst
new file mode 100644
index 000000000..66a81ec58
--- /dev/null
+++ b/docs/references.rst
@@ -0,0 +1,6 @@
+References
+==================
+
+.. bibliography:: references.bib
+ :cited:
+ :style: astrostyle
\ No newline at end of file
diff --git a/docs/tutorial.md b/docs/tutorial.md
new file mode 100644
index 000000000..0d5be3447
--- /dev/null
+++ b/docs/tutorial.md
@@ -0,0 +1,109 @@
+---
+jupytext:
+ formats: ipynb,md:myst
+ text_representation:
+ extension: .md
+ format_name: myst
+ format_version: '0.8'
+ jupytext_version: 1.4.2
+kernelspec:
+ display_name: Python 3
+ language: python
+ name: python3
+---
+
+# Tutorial
+
+Natural Language Processing is an exciting field consists of many closely related tasks like lexical analysis
+and parsing. Each task involves many datasets and models, both requiring a high degree of expertise.
+Things get even more complex when dealing with multilingual text, as there's simply no datasets for some
+low-resource languages. However, with HanLP 2.1, core NLP tasks have been made easy to access and efficient in
+production environment. In this tutorial, we'll walk through the APIs in HanLP step by step.
+
+HanLP offers out-of-the-box RESTful API and native Python API which shares very similar interfaces
+while they are designed for different scenes.
+
+## RESTful API
+
+RESTful API is an endpoint where you send your documents to then get the parsed annotations back.
+We are hosting a **non-commercial** API service and you are welcome to [apply for an auth key](https://bbs.hankcs.com/t/apply-for-free-hanlp-restful-apis/3178).
+An auth key is a password which gives you access to our API and protects our server from being abused.
+Once obtained such an auth key, you can parse your document with our RESTful client which can be installed via:
+
+````{margin} **NonCommercial**
+```{seealso}
+Our models and RESTful APIs are under the [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/) licence.
+```
+````
+
+```bash
+pip install hanlp_restful
+```
+
+```{eval-rst}
+Then initiate a :class:`~hanlp_restful.HanLPClient` with your auth key and send a document to have it parsed.
+```
+
+```{code-cell} ipython3
+:tags: [output_scroll]
+from hanlp_restful import HanLPClient
+HanLP = HanLPClient('https://hanlp.hankcs.com/api', auth=None, language='mul') # Fill in your auth
+
+print(HanLP('In 2021, HanLPv2.1 delivers state-of-the-art multilingual NLP techniques to production environment. ' \
+ '2021年、HanLPv2.1は次世代の最先端多言語NLP技術を本番環境に導入します。' \
+ '2021年 HanLPv2.1为生产环境带来次世代最先进的多语种NLP技术。'))
+```
+````{margin} **But what does these annotations mean?**
+```{seealso}
+See our [data format](data_format) and [annotations](annotations/index) for details.
+```
+````
+
+
+## Visualization
+
+```{eval-rst}
+:class:`~hanlp_common.document.Document` has a handy method :meth:`~hanlp_common.document.Document.pretty_print`
+which offsers visualization in any mono-width text environment.
+```
+
+````{margin} **Non-ASCII**
+```{note}
+Non-ASCII text might screw in which case copying it into a `.tsv` editor will align it correctly.
+```
+````
+
+```{code-cell} ipython3
+from hanlp_restful import HanLPClient
+HanLP = HanLPClient('https://hanlp.hankcs.com/api', auth=None, language='mul') # Fill in your auth
+HanLP('In 2021, HanLPv2.1 delivers state-of-the-art multilingual NLP techniques to production environment. ' \
+ '2021年、HanLPv2.1は次世代の最先端多言語NLP技術を本番環境に導入します。' \
+ '2021年 HanLPv2.1为生产环境带来次世代最先进的多语种NLP技术。').pretty_print()
+```
+
+## Native API
+
+If you want to run our models locally or you want to implement your own RESTful server, you can call the native API
+and it behaves just like a RESTful one.
+
+```{eval-rst}
+Then initiate a :class:`~hanlp_restful.HanLPClient` with your auth key and send a document to have it parsed.
+```
+````{margin} **Sentences Required**
+```{seealso}
+As MTL doesn't predict sentence boundaries, inputs have to be split beforehand.
+See our [data format](data_format) for details.
+```
+````
+
+```{code-cell} ipython3
+:tags: [output_scroll]
+import hanlp
+HanLP = hanlp.load(hanlp.pretrained.mtl.UD_ONTONOTES_TOK_POS_LEM_FEA_NER_SRL_DEP_SDP_CON_MT5_BASE)
+print(HanLP(['In 2021, HanLPv2.1 delivers state-of-the-art multilingual NLP techniques to production environment.',
+ '2021年、HanLPv2.1は次世代の最先端多言語NLP技術を本番環境に導入します。',
+ '2021年 HanLPv2.1为生产环境带来次世代最先进的多语种NLP技术。']))
+```
+
+Due to the fact that the service provider is very likely running a different model or having different settings, the
+RESTful and native results might be slightly different.
\ No newline at end of file
diff --git a/hanlp/__init__.py b/hanlp/__init__.py
index a7195ef52..24945abd3 100644
--- a/hanlp/__init__.py
+++ b/hanlp/__init__.py
@@ -1,61 +1,58 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-06-13 18:05
-import os
-
-if not int(os.environ.get('HANLP_SHOW_TF_LOG', 0)):
- os.environ['VERBOSE'] = '0'
- os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
- os.environ['TF_CPP_MIN_VLOG_LEVEL'] = '0'
- import absl.logging, logging
-
- logging.getLogger('tensorflow').setLevel(logging.ERROR)
- logging.root.removeHandler(absl.logging._absl_handler)
- exec('absl.logging._warn_preinit_stderr = False') # prevent exporting _warn_preinit_stderr
-
-import hanlp.callbacks
import hanlp.common
import hanlp.components
-import hanlp.datasets
-import hanlp.layers
-import hanlp.losses
-import hanlp.metrics
-import hanlp.optimizers
import hanlp.pretrained
import hanlp.utils
-
from hanlp.version import __version__
-if not os.environ.get('HANLP_GREEDY_GPU', None):
- exec('from hanlp.utils.tf_util import nice_gpu')
- exec('nice_gpu()')
+hanlp.utils.ls_resource_in_module(hanlp.pretrained)
-exec('''
-from hanlp.utils.util import ls_resource_in_module
-ls_resource_in_module(hanlp.pretrained)
-''')
+def load(save_dir: str, verbose=None, **kwargs) -> hanlp.common.component.Component:
+ """Load pretrained component from an identifier.
+
+ Args:
+ save_dir (str): The identifier to the saved component. It could be a remote URL or a local path.
+ verbose: ``True`` to print loading progress.
+ **kwargs: Arguments passed to `Component.load`
+
+ ``devices`` is a useful arguments to specify the GPU devices component will use.
+
+ Examples::
+
+ import hanlp
+ # Load component onto the 0-th GPU.
+ hanlp.load(..., devices=0)
+ # Load component onto the 0-th and 1-th GPU using data parallelization.
+ hanlp.load(..., devices=[0,1])
+
+ .. Note::
+ A component can have dependencies on other components or resources, which will be recursively loaded. So it's
+ common to see multiple downloading per single load.
+
+ Returns:
+ A pretrained component.
-def load(save_dir: str, meta_filename='meta.json', transform_only=False, load_kwargs=None,
- **kwargs) -> hanlp.common.component.Component:
- """
- Load saved component from identifier.
- :param save_dir: The identifier to the saved component.
- :param meta_filename: The meta file of that saved component, which stores the class_path and version.
- :param transform_only: Whether to load transform only.
- :param load_kwargs: The arguments passed to `load`
- :param kwargs: Additional arguments parsed to the `from_meta` method.
- :return: A pretrained component.
"""
save_dir = hanlp.pretrained.ALL.get(save_dir, save_dir)
from hanlp.utils.component_util import load_from_meta_file
- return load_from_meta_file(save_dir, meta_filename, transform_only=transform_only, load_kwargs=load_kwargs, **kwargs)
+ if verbose is None:
+ from hanlp_common.constant import HANLP_VERBOSE
+ verbose = HANLP_VERBOSE
+ return load_from_meta_file(save_dir, 'meta.json', verbose=verbose, **kwargs)
def pipeline(*pipes) -> hanlp.components.pipeline.Pipeline:
- """
- Creates a pipeline of components.
- :param pipes: Components if pre-defined any.
- :return: A pipeline
+ """Creates a pipeline of components. It's made for bundling KerasComponents. For TorchComponent, use
+ :class:`~hanlp.components.mtl.multi_task_learning.MultiTaskLearning` instead.
+
+ Args:
+ *pipes: Components if pre-defined any.
+
+ Returns:
+ A pipeline, which is list of components in order.
+
"""
return hanlp.components.pipeline.Pipeline(*pipes)
diff --git a/hanlp/common/__init__.py b/hanlp/common/__init__.py
index 08dd68e3f..308bfb628 100644
--- a/hanlp/common/__init__.py
+++ b/hanlp/common/__init__.py
@@ -1,8 +1,3 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-08-26 14:45
-from . import component
-from . import constant
-from . import structure
-from . import transform
-from . import vocab
diff --git a/hanlp/common/component.py b/hanlp/common/component.py
index b81ab44a4..9c9c4b6d2 100644
--- a/hanlp/common/component.py
+++ b/hanlp/common/component.py
@@ -2,530 +2,35 @@
# Author: hankcs
# Date: 2019-08-26 14:45
import inspect
-import logging
-import math
-import os
-import warnings
from abc import ABC, abstractmethod
-from typing import Any, Dict, Optional, List
+from typing import Any
-import numpy as np
-import tensorflow as tf
+from hanlp_common.configurable import Configurable
-import hanlp
-import hanlp.version
-from hanlp.callbacks.fine_csv_logger import FineCSVLogger
-from hanlp.common.structure import SerializableDict
-from hanlp.common.transform import Transform
-from hanlp.common.vocab import Vocab
-from hanlp.metrics.chunking.iobes import IOBES_F1
-from hanlp.optimizers.adamw.optimization import AdamWeightDecay
-from hanlp.utils import io_util
-from hanlp.utils.io_util import get_resource, tempdir_human, save_json, load_json
-from hanlp.utils.log_util import init_logger, logger
-from hanlp.utils.reflection import class_path_of, str_to_type
-from hanlp.utils.string_util import format_metrics, format_scores
-from hanlp.utils.tf_util import size_of_dataset, summary_of_model, get_callback_by_class
-from hanlp.utils.time_util import Timer, now_datetime
-from hanlp.utils.util import merge_dict
-
-
-class Component(ABC):
-
- def __init__(self) -> None:
- super().__init__()
- self.meta = {
- 'class_path': class_path_of(self),
- 'hanlp_version': hanlp.version.__version__,
- }
+class Component(Configurable, ABC):
@abstractmethod
def predict(self, data: Any, **kwargs):
- """
- Predict on data
- :param data: Any type of data subject to sub-classes
- :param kwargs: Additional arguments
- """
- raise NotImplementedError('%s.%s()' % (self.__class__.__name__, inspect.stack()[0][3]))
-
- def __call__(self, data, **kwargs):
- return self.predict(data, **kwargs)
-
- @staticmethod
- def from_meta(meta: dict, **kwargs):
- """
-
- Parameters
- ----------
- meta
- kwargs
-
- Returns
- -------
- Component
- """
- cls = meta.get('class_path', None)
- assert cls, f'{meta} doesn\'t contain class_path field'
- cls = str_to_type(cls)
- return cls.from_meta(meta)
-
-
-class KerasComponent(Component, ABC):
- def __init__(self, transform: Transform) -> None:
- super().__init__()
- self.model: Optional[tf.keras.Model] = None
- self.config = SerializableDict()
- self.transform = transform
- # share config with transform for convenience, so we don't need to pass args around
- if self.transform.config:
- for k, v in self.transform.config.items():
- self.config[k] = v
- self.transform.config = self.config
-
- def evaluate(self, input_path: str, save_dir=None, output=False, batch_size=128, logger: logging.Logger = None,
- callbacks: List[tf.keras.callbacks.Callback] = None, warm_up=True, verbose=True, **kwargs):
- input_path = get_resource(input_path)
- file_prefix, ext = os.path.splitext(input_path)
- name = os.path.basename(file_prefix)
- if not name:
- name = 'evaluate'
- if save_dir and not logger:
- logger = init_logger(name=name, root_dir=save_dir, level=logging.INFO if verbose else logging.WARN,
- mode='w')
- tst_data = self.transform.file_to_dataset(input_path, batch_size=batch_size)
- samples = size_of_dataset(tst_data)
- num_batches = math.ceil(samples / batch_size)
- if warm_up:
- for x, y in tst_data:
- self.model.predict_on_batch(x)
- break
- if output:
- assert save_dir, 'Must pass save_dir in order to output'
- if isinstance(output, bool):
- output = os.path.join(save_dir, name) + '.predict' + ext
- elif isinstance(output, str):
- output = output
- else:
- raise RuntimeError('output ({}) must be of type bool or str'.format(repr(output)))
- timer = Timer()
- loss, score, output = self.evaluate_dataset(tst_data, callbacks, output, num_batches)
- delta_time = timer.stop()
- speed = samples / delta_time.delta_seconds
+ """Predict on data. This is the base class for all components, including rule based and statistical ones.
- if logger:
- f1: IOBES_F1 = None
- for metric in self.model.metrics:
- if isinstance(metric, IOBES_F1):
- f1 = metric
- break
- extra_report = ''
- if f1:
- overall, by_type, extra_report = f1.state.result(full=True, verbose=False)
- extra_report = ' \n' + extra_report
- logger.info('Evaluation results for {} - '
- 'loss: {:.4f} - {} - speed: {:.2f} sample/sec{}'
- .format(name + ext, loss,
- format_scores(score) if isinstance(score, dict) else format_metrics(self.model.metrics),
- speed, extra_report))
- if output:
- logger.info('Saving output to {}'.format(output))
- with open(output, 'w', encoding='utf-8') as out:
- self.evaluate_output(tst_data, out, num_batches, self.model.metrics)
+ Args:
+ data: Any type of data subject to sub-classes
+ kwargs: Additional arguments
- return loss, score, speed
+ Returns: Any predicted annotations.
- def evaluate_dataset(self, tst_data, callbacks, output, num_batches):
- loss, score = self.model.evaluate(tst_data, callbacks=callbacks, steps=num_batches)
- return loss, score, output
-
- def evaluate_output(self, tst_data, out, num_batches, metrics: List[tf.keras.metrics.Metric]):
- # out.write('x\ty_true\ty_pred\n')
- for metric in metrics:
- metric.reset_states()
- for idx, batch in enumerate(tst_data):
- outputs = self.model.predict_on_batch(batch[0])
- for metric in metrics:
- metric(batch[1], outputs, outputs._keras_mask if hasattr(outputs, '_keras_mask') else None)
- self.evaluate_output_to_file(batch, outputs, out)
- print('\r{}/{} {}'.format(idx + 1, num_batches, format_metrics(metrics)), end='')
- print()
-
- def evaluate_output_to_file(self, batch, outputs, out):
- for x, y_gold, y_pred in zip(self.transform.X_to_inputs(batch[0]),
- self.transform.Y_to_outputs(batch[1], gold=True),
- self.transform.Y_to_outputs(outputs, gold=False)):
- out.write(self.transform.input_truth_output_to_str(x, y_gold, y_pred))
-
- def _capture_config(self, config: Dict,
- exclude=(
- 'trn_data', 'dev_data', 'save_dir', 'kwargs', 'self', 'logger', 'verbose',
- 'dev_batch_size', '__class__')):
- """
- Save arguments to config
-
- Parameters
- ----------
- config
- `locals()`
- exclude
- """
- if 'kwargs' in config:
- config.update(config['kwargs'])
- config = dict(
- (key, tf.keras.utils.serialize_keras_object(value)) if hasattr(value, 'get_config') else (key, value) for
- key, value in config.items())
- for key in exclude:
- config.pop(key, None)
- self.config.update(config)
-
- def save_meta(self, save_dir, filename='meta.json', **kwargs):
- self.meta['create_time']: now_datetime()
- self.meta.update(kwargs)
- save_json(self.meta, os.path.join(save_dir, filename))
-
- def load_meta(self, save_dir, filename='meta.json'):
- save_dir = get_resource(save_dir)
- metapath = os.path.join(save_dir, filename)
- if os.path.isfile(metapath):
- self.meta.update(load_json(metapath))
-
- def save_config(self, save_dir, filename='config.json'):
- self.config.save_json(os.path.join(save_dir, filename))
-
- def load_config(self, save_dir, filename='config.json'):
- save_dir = get_resource(save_dir)
- self.config.load_json(os.path.join(save_dir, filename))
-
- def save_weights(self, save_dir, filename='model.h5'):
- self.model.save_weights(os.path.join(save_dir, filename))
-
- def load_weights(self, save_dir, filename='model.h5', **kwargs):
- assert self.model.built or self.model.weights, 'You must call self.model.built() in build_model() ' \
- 'in order to load it'
- save_dir = get_resource(save_dir)
- self.model.load_weights(os.path.join(save_dir, filename))
-
- def save_vocabs(self, save_dir, filename='vocabs.json'):
- vocabs = SerializableDict()
- for key, value in vars(self.transform).items():
- if isinstance(value, Vocab):
- vocabs[key] = value.to_dict()
- vocabs.save_json(os.path.join(save_dir, filename))
-
- def load_vocabs(self, save_dir, filename='vocabs.json'):
- save_dir = get_resource(save_dir)
- vocabs = SerializableDict()
- vocabs.load_json(os.path.join(save_dir, filename))
- for key, value in vocabs.items():
- vocab = Vocab()
- vocab.copy_from(value)
- setattr(self.transform, key, vocab)
-
- def load_transform(self, save_dir) -> Transform:
- """
- Try to load transform only. This method might fail due to the fact it avoids building the model.
- If it do fail, then you have to use `load` which might be too heavy but that's the best we can do.
- :param save_dir: The path to load.
"""
- save_dir = get_resource(save_dir)
- self.load_config(save_dir)
- self.load_vocabs(save_dir)
- self.transform.build_config()
- self.transform.lock_vocabs()
- return self.transform
-
- def save(self, save_dir: str, **kwargs):
- self.save_config(save_dir)
- self.save_vocabs(save_dir)
- self.save_weights(save_dir)
-
- def load(self, save_dir: str, logger=hanlp.utils.log_util.logger, **kwargs):
- self.meta['load_path'] = save_dir
- save_dir = get_resource(save_dir)
- self.load_config(save_dir)
- self.load_vocabs(save_dir)
- self.build(**merge_dict(self.config, training=False, logger=logger, **kwargs, overwrite=True, inplace=True))
- self.load_weights(save_dir, **kwargs)
- self.load_meta(save_dir)
-
- @property
- def input_shape(self) -> List:
- return self.transform.output_shapes[0]
-
- def build(self, logger, **kwargs):
- self.transform.build_config()
- self.model = self.build_model(**merge_dict(self.config, training=kwargs.get('training', None),
- loss=kwargs.get('loss', None)))
- self.transform.lock_vocabs()
- optimizer = self.build_optimizer(**self.config)
- loss = self.build_loss(
- **self.config if 'loss' in self.config else dict(list(self.config.items()) + [('loss', None)]))
- # allow for different
- metrics = self.build_metrics(**merge_dict(self.config, metrics=kwargs.get('metrics', 'accuracy'),
- logger=logger, overwrite=True))
- if not isinstance(metrics, list):
- if isinstance(metrics, tf.keras.metrics.Metric):
- metrics = [metrics]
- if not self.model.built:
- sample_inputs = self.sample_data
- if sample_inputs is not None:
- self.model(sample_inputs)
- else:
- if len(self.transform.output_shapes[0]) == 1 and self.transform.output_shapes[0][0] is None:
- x_shape = self.transform.output_shapes[0]
- else:
- x_shape = list(self.transform.output_shapes[0])
- for i, shape in enumerate(x_shape):
- x_shape[i] = [None] + shape # batch + X.shape
- self.model.build(input_shape=x_shape)
- self.compile_model(optimizer, loss, metrics)
- return self.model, optimizer, loss, metrics
-
- def compile_model(self, optimizer, loss, metrics):
- self.model.compile(optimizer=optimizer, loss=loss, metrics=metrics, run_eagerly=self.config.run_eagerly)
-
- def build_optimizer(self, optimizer, **kwargs):
- if isinstance(optimizer, (str, dict)):
- custom_objects = {'AdamWeightDecay': AdamWeightDecay}
- optimizer: tf.keras.optimizers.Optimizer = tf.keras.utils.deserialize_keras_object(optimizer,
- module_objects=vars(tf.keras.optimizers),
- custom_objects=custom_objects)
- self.config.optimizer = tf.keras.utils.serialize_keras_object(optimizer)
- return optimizer
-
- def build_loss(self, loss, **kwargs):
- if not loss:
- loss = tf.keras.losses.SparseCategoricalCrossentropy(
- reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
- from_logits=True)
- elif isinstance(loss, (str, dict)):
- loss = tf.keras.utils.deserialize_keras_object(loss, module_objects=tf.keras.losses)
- if isinstance(loss, tf.keras.losses.Loss):
- self.config.loss = tf.keras.utils.serialize_keras_object(loss)
- return loss
-
- def build_transform(self, **kwargs):
- return self.transform
-
- def build_vocab(self, trn_data, logger):
- train_examples = self.transform.fit(trn_data, **self.config)
- self.transform.summarize_vocabs(logger)
- return train_examples
-
- def build_metrics(self, metrics, logger: logging.Logger, **kwargs):
- metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
- return [metric]
-
- @abstractmethod
- def build_model(self, **kwargs) -> tf.keras.Model:
- pass
-
- def fit(self, trn_data, dev_data, save_dir, batch_size, epochs, run_eagerly=False, logger=None, verbose=True,
- **kwargs):
- self._capture_config(locals())
- self.transform = self.build_transform(**self.config)
- if not save_dir:
- save_dir = tempdir_human()
- if not logger:
- logger = init_logger(name='train', root_dir=save_dir, level=logging.INFO if verbose else logging.WARN)
- logger.info('Hyperparameter:\n' + self.config.to_json())
- num_examples = self.build_vocab(trn_data, logger)
- # assert num_examples, 'You forgot to return the number of training examples in your build_vocab'
- logger.info('Building...')
- train_steps_per_epoch = math.ceil(num_examples / batch_size) if num_examples else None
- self.config.train_steps = train_steps_per_epoch * epochs if num_examples else None
- model, optimizer, loss, metrics = self.build(**merge_dict(self.config, logger=logger, training=True))
- logger.info('Model built:\n' + summary_of_model(self.model))
- self.save_config(save_dir)
- self.save_vocabs(save_dir)
- self.save_meta(save_dir)
- trn_data = self.build_train_dataset(trn_data, batch_size, num_examples)
- dev_data = self.build_valid_dataset(dev_data, batch_size)
- callbacks = self.build_callbacks(save_dir, **merge_dict(self.config, overwrite=True, logger=logger))
- # need to know #batches, otherwise progbar crashes
- dev_steps = math.ceil(size_of_dataset(dev_data) / batch_size)
- checkpoint = get_callback_by_class(callbacks, tf.keras.callbacks.ModelCheckpoint)
- timer = Timer()
- try:
- history = self.train_loop(**merge_dict(self.config, trn_data=trn_data, dev_data=dev_data, epochs=epochs,
- num_examples=num_examples,
- train_steps_per_epoch=train_steps_per_epoch, dev_steps=dev_steps,
- callbacks=callbacks, logger=logger, model=model, optimizer=optimizer,
- loss=loss,
- metrics=metrics, overwrite=True))
- except KeyboardInterrupt:
- print()
- if not checkpoint or checkpoint.best in (np.Inf, -np.Inf):
- self.save_weights(save_dir)
- logger.info('Aborted with model saved')
- else:
- logger.info(f'Aborted with model saved with best {checkpoint.monitor} = {checkpoint.best:.4f}')
- # noinspection PyTypeChecker
- history: tf.keras.callbacks.History() = get_callback_by_class(callbacks, tf.keras.callbacks.History)
- delta_time = timer.stop()
- best_epoch_ago = 0
- if history and hasattr(history, 'epoch'):
- trained_epoch = len(history.epoch)
- logger.info('Trained {} epochs in {}, each epoch takes {}'.
- format(trained_epoch, delta_time, delta_time / trained_epoch if trained_epoch else delta_time))
- io_util.save_json(history.history, io_util.path_join(save_dir, 'history.json'), cls=io_util.NumpyEncoder)
- monitor_history: List = history.history.get(checkpoint.monitor, None)
- if monitor_history:
- best_epoch_ago = len(monitor_history) - monitor_history.index(checkpoint.best)
- if checkpoint and monitor_history and checkpoint.best != monitor_history[-1]:
- logger.info(f'Restored the best model saved with best '
- f'{checkpoint.monitor} = {checkpoint.best:.4f} '
- f'saved {best_epoch_ago} epochs ago')
- self.load_weights(save_dir) # restore best model
- return history
-
- def train_loop(self, trn_data, dev_data, epochs, num_examples, train_steps_per_epoch, dev_steps, model, optimizer,
- loss, metrics, callbacks,
- logger, **kwargs):
- history = self.model.fit(trn_data, epochs=epochs, steps_per_epoch=train_steps_per_epoch,
- validation_data=dev_data,
- callbacks=callbacks,
- validation_steps=dev_steps,
- ) # type:tf.keras.callbacks.History
- return history
-
- def build_valid_dataset(self, dev_data, batch_size):
- dev_data = self.transform.file_to_dataset(dev_data, batch_size=batch_size, shuffle=False)
- return dev_data
-
- def build_train_dataset(self, trn_data, batch_size, num_examples):
- trn_data = self.transform.file_to_dataset(trn_data, batch_size=batch_size,
- shuffle=True,
- repeat=-1 if self.config.train_steps else None)
- return trn_data
-
- def build_callbacks(self, save_dir, logger, **kwargs):
- metrics = kwargs.get('metrics', 'accuracy')
- if isinstance(metrics, (list, tuple)):
- metrics = metrics[-1]
- monitor = f'val_{metrics}'
- checkpoint = tf.keras.callbacks.ModelCheckpoint(
- os.path.join(save_dir, 'model.h5'),
- # verbose=1,
- monitor=monitor, save_best_only=True,
- mode='max',
- save_weights_only=True)
- logger.debug(f'Monitor {checkpoint.monitor} for checkpoint')
- tensorboard_callback = tf.keras.callbacks.TensorBoard(
- log_dir=io_util.makedirs(io_util.path_join(save_dir, 'logs')))
- csv_logger = FineCSVLogger(os.path.join(save_dir, 'train.log'), separator=' | ', append=True)
- callbacks = [checkpoint, tensorboard_callback, csv_logger]
- lr_decay_per_epoch = self.config.get('lr_decay_per_epoch', None)
- if lr_decay_per_epoch:
- learning_rate = self.model.optimizer.get_config().get('learning_rate', None)
- if not learning_rate:
- logger.warning('Learning rate decay not supported for optimizer={}'.format(repr(self.model.optimizer)))
- else:
- logger.debug(f'Created LearningRateScheduler with lr_decay_per_epoch={lr_decay_per_epoch}')
- callbacks.append(tf.keras.callbacks.LearningRateScheduler(
- lambda epoch: learning_rate / (1 + lr_decay_per_epoch * epoch)))
- anneal_factor = self.config.get('anneal_factor', None)
- if anneal_factor:
- callbacks.append(tf.keras.callbacks.ReduceLROnPlateau(factor=anneal_factor,
- patience=self.config.get('anneal_patience', 10)))
- early_stopping_patience = self.config.get('early_stopping_patience', None)
- if early_stopping_patience:
- callbacks.append(tf.keras.callbacks.EarlyStopping(monitor=monitor, mode='max',
- verbose=1,
- patience=early_stopping_patience))
- return callbacks
-
- def on_train_begin(self):
- """
- Callback before the training starts
- """
- pass
-
- def predict(self, data: Any, batch_size=None, **kwargs):
- assert self.model, 'Please call fit or load before predict'
- if not data:
- return []
- data, flat = self.transform.input_to_inputs(data)
-
- if not batch_size:
- batch_size = self.config.batch_size
-
- dataset = self.transform.inputs_to_dataset(data, batch_size=batch_size, gold=kwargs.get('gold', False))
-
- results = []
- num_samples = 0
- data_is_list = isinstance(data, list)
- for idx, batch in enumerate(dataset):
- samples_in_batch = tf.shape(batch[-1] if isinstance(batch[-1], tf.Tensor) else batch[-1][0])[0]
- if data_is_list:
- inputs = data[num_samples:num_samples + samples_in_batch]
- else:
- inputs = None # if data is a generator, it's usually one-time, not able to transform into a list
- for output in self.predict_batch(batch, inputs=inputs, **kwargs):
- results.append(output)
- num_samples += samples_in_batch
-
- if flat:
- return results[0]
- return results
-
- def predict_batch(self, batch, inputs=None, **kwargs):
- X = batch[0]
- Y = self.model.predict_on_batch(X)
- for output in self.transform.Y_to_outputs(Y, X=X, inputs=inputs, batch=batch, **kwargs):
- yield output
-
- @property
- def sample_data(self):
- return None
+ raise NotImplementedError('%s.%s()' % (self.__class__.__name__, inspect.stack()[0][3]))
- @staticmethod
- def from_meta(meta: dict, **kwargs):
+ def __call__(self, data: Any, **kwargs):
"""
+ A shortcut for :func:`~hanlp.common.component.predict`.
- Parameters
- ----------
- meta
- kwargs
+ Args:
+ data:
+ **kwargs:
- Returns
- -------
- KerasComponent
+ Returns:
"""
- cls = str_to_type(meta['class_path'])
- obj: KerasComponent = cls()
- assert 'load_path' in meta, f'{meta} doesn\'t contain load_path field'
- obj.load(meta['load_path'])
- return obj
-
- def export_model_for_serving(self, export_dir=None, version=1, overwrite=False, show_hint=False):
- assert self.model, 'You have to fit or load a model before exporting it'
- if not export_dir:
- assert 'load_path' in self.meta, 'When not specifying save_dir, load_path has to present'
- export_dir = get_resource(self.meta['load_path'])
- model_path = os.path.join(export_dir, str(version))
- if os.path.isdir(model_path) and not overwrite:
- logger.info(f'{model_path} exists, skip since overwrite = {overwrite}')
- return export_dir
- logger.info(f'Exporting to {export_dir} ...')
- tf.saved_model.save(self.model, model_path)
- logger.info(f'Successfully exported model to {export_dir}')
- if show_hint:
- logger.info(f'You can serve it through \n'
- f'tensorflow_model_server --model_name={os.path.splitext(os.path.basename(self.meta["load_path"]))[0]} '
- f'--model_base_path={export_dir} --rest_api_port=8888')
- return export_dir
-
- def serve(self, export_dir=None, grpc_port=8500, rest_api_port=0, overwrite=False, dry_run=False):
- export_dir = self.export_model_for_serving(export_dir, show_hint=False, overwrite=overwrite)
- if not dry_run:
- del self.model # free memory
- logger.info('The inputs of exported model is shown below.')
- os.system(f'saved_model_cli show --all --dir {export_dir}/1')
- cmd = f'nohup tensorflow_model_server --model_name={os.path.splitext(os.path.basename(self.meta["load_path"]))[0]} ' \
- f'--model_base_path={export_dir} --port={grpc_port} --rest_api_port={rest_api_port} ' \
- f'>serve.log 2>&1 &'
- logger.info(f'Running ...\n{cmd}')
- if not dry_run:
- os.system(cmd)
+ return self.predict(data, **kwargs)
diff --git a/hanlp/common/constant.py b/hanlp/common/constant.py
deleted file mode 100644
index 7e328538d..000000000
--- a/hanlp/common/constant.py
+++ /dev/null
@@ -1,9 +0,0 @@
-# -*- coding:utf-8 -*-
-# Author: hankcs
-# Date: 2019-06-13 22:41
-import os
-
-PAD = ''
-UNK = ''
-ROOT = ''
-HANLP_URL = os.getenv('HANLP_URL', 'https://file.hankcs.com/hanlp/')
diff --git a/hanlp/common/dataset.py b/hanlp/common/dataset.py
new file mode 100644
index 000000000..5f490863f
--- /dev/null
+++ b/hanlp/common/dataset.py
@@ -0,0 +1,793 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-05-09 20:27
+import math
+import random
+import warnings
+from abc import ABC, abstractmethod
+from copy import copy
+from logging import Logger
+from typing import Union, List, Callable, Iterable, Dict
+
+import torch
+import torch.multiprocessing as mp
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import Dataset, DataLoader, Sampler
+from torch.utils.data.dataset import IterableDataset
+
+from hanlp_common.constant import IDX
+from hanlp_common.configurable import AutoConfigurable
+from hanlp.common.transform import TransformList, VocabDict, EmbeddingNamedTransform
+from hanlp.common.vocab import Vocab
+from hanlp.components.parsers.alg import kmeans
+from hanlp.utils.io_util import read_cells, get_resource
+from hanlp.utils.torch_util import dtype_of
+from hanlp_common.util import isdebugging, merge_list_of_dict, k_fold
+
+
+class Transformable(ABC):
+ def __init__(self, transform: Union[Callable, List] = None) -> None:
+ """An object which can be transformed with a list of functions. It can be imaged as an objected being passed
+ through a list of functions, while these functions are kept in a list.
+
+ Args:
+ transform: A transform function or a list of functions.
+ """
+ super().__init__()
+ if isinstance(transform, list) and not isinstance(transform, TransformList):
+ transform = TransformList(*transform)
+ self.transform: Union[Callable, TransformList] = transform
+
+ def append_transform(self, transform: Callable):
+ """Append a transform to its list of transforms.
+
+ Args:
+ transform: A new transform to be appended.
+
+ Returns: Itself.
+
+ """
+ assert transform is not None, 'None transform not allowed'
+ if not self.transform:
+ self.transform = TransformList(transform)
+ elif not isinstance(self.transform, TransformList):
+ if self.transform != transform:
+ self.transform = TransformList(self.transform, transform)
+ else:
+ if transform not in self.transform:
+ self.transform.append(transform)
+ return self
+
+ def insert_transform(self, index: int, transform: Callable):
+ """Insert a transform to a certain position.
+
+ Args:
+ index: A certain position.
+ transform: A new transform.
+
+ Returns: Dataset itself.
+
+ """
+ assert transform is not None, 'None transform not allowed'
+ if not self.transform:
+ self.transform = TransformList(transform)
+ elif not isinstance(self.transform, TransformList):
+ if self.transform != transform:
+ self.transform = TransformList(self.transform)
+ self.transform.insert(index, transform)
+ else:
+ if transform not in self.transform:
+ self.transform.insert(index, transform)
+ return self
+
+ def transform_sample(self, sample: dict, inplace=False) -> dict:
+ """Apply transforms to a sample.
+
+ Args:
+ sample: A sample, which is a ``dict`` holding features.
+ inplace: ``True`` to apply transforms inplace.
+
+ .. Attention::
+ If any transform modifies existing features, it will modify again and again when ``inplace=True``.
+ For example, if a transform insert a ``BOS`` token to a list inplace, and it is called twice,
+ then 2 ``BOS`` will be inserted which might not be an intended result.
+
+ Returns:
+
+ """
+ if not inplace:
+ sample = copy(sample)
+ if self.transform:
+ sample = self.transform(sample)
+ return sample
+
+
+class TransformableDataset(Transformable, Dataset, ABC):
+
+ def __init__(self,
+ data: Union[str, List],
+ transform: Union[Callable, List] = None,
+ cache=None,
+ generate_idx=None) -> None:
+ """A :class:`~torch.utils.data.Dataset` which can be applied with a list of transform functions.
+
+ Args:
+ data: The local or remote path to a dataset, or a list of samples where each sample is a dict.
+ transform: Predefined transform(s).
+ cache: ``True`` to enable caching, so that transforms won't be called twice.
+ generate_idx: Create a :const:`~hanlp_common.constants.IDX` field for each sample to store its order in dataset. Useful for prediction when
+ samples are re-ordered by a sampler.
+ """
+ super().__init__(transform)
+ if generate_idx is None:
+ generate_idx = isinstance(data, list)
+ data = self.load_data(data, generate_idx)
+ assert data, 'No samples loaded'
+ assert isinstance(data[0],
+ dict), f'TransformDataset expects each sample to be a dict but got {type(data[0])} instead.'
+ self.data = data
+ if cache:
+ self.cache = [None] * len(data)
+ else:
+ self.cache = None
+
+ def load_data(self, data, generate_idx=False):
+ """A intermediate step between constructor and calling the actual file loading method.
+
+ Args:
+ data: If data is a file, this method calls :meth:`~hanlp.common.dataset.TransformableDataset.load_file`
+ to load it.
+ generate_idx: Create a :const:`~hanlp_common.constants.IDX` field for each sample to store its order in dataset. Useful for prediction when
+ samples are re-ordered by a sampler.
+
+ Returns: Loaded samples.
+
+ """
+ if self.should_load_file(data):
+ if isinstance(data, str):
+ data = get_resource(data)
+ data = list(self.load_file(data))
+ if generate_idx:
+ for i, each in enumerate(data):
+ each[IDX] = i
+ # elif isinstance(data, list):
+ # data = self.load_list(data)
+ return data
+
+ # noinspection PyMethodMayBeStatic
+ # def load_list(self, data: list) -> List[Dict[str, Any]]:
+ # return data
+
+ def should_load_file(self, data) -> bool:
+ """Determines whether data is a filepath.
+
+ Args:
+ data: Data to check.
+
+ Returns: ``True`` to indicate it's a filepath.
+
+ """
+ return isinstance(data, str)
+
+ @abstractmethod
+ def load_file(self, filepath: str):
+ """The actual file loading logic.
+
+ Args:
+ filepath: The path to a dataset.
+ """
+ pass
+
+ def __getitem__(self, index: Union[int, slice]) -> Union[dict, List[dict]]:
+ """ Get the index-th sample in this dataset.
+
+ Args:
+ index: Either a integer index of a list of indices.
+
+ Returns: Either a sample or or list of samples depending on how many indices are passed in.
+
+ """
+ # if isinstance(index, (list, tuple)):
+ # assert len(index) == 1
+ # index = index[0]
+ if isinstance(index, slice):
+ indices = range(*index.indices(len(self)))
+ return [self[i] for i in indices]
+
+ if self.cache:
+ cache = self.cache[index]
+ if cache:
+ return cache
+ sample = self.data[index]
+ sample = self.transform_sample(sample)
+ if self.cache:
+ self.cache[index] = sample
+ return sample
+
+ def __len__(self) -> int:
+ return len(self.data)
+
+ def __repr__(self) -> str:
+ return f'{len(self)} samples: {self[0]} ...'
+
+ def purge_cache(self):
+ """Purges all cache. If cache is not enabled, this method enables it.
+ """
+ self.cache = [None] * len(self.data)
+
+ def split(self, *ratios):
+ """Split dataset into subsets.
+
+ Args:
+ *ratios: The ratios for each subset. They can be any type of numbers which will be normalized. For example,
+ ``8, 1, 1`` are equivalent to ``0.8, 0.1, 0.1``.
+
+ Returns:
+ list[TransformableDataset]: A list of subsets.
+ """
+ ratios = [x / sum(ratios) for x in ratios]
+ chunks = []
+ prev = 0
+ for r in ratios:
+ cur = prev + math.ceil(len(self) * r)
+ chunks.append([prev, cur])
+ prev = cur
+ chunks[-1][1] = len(self)
+ outputs = []
+ for b, e in chunks:
+ dataset = copy(self)
+ dataset.data = dataset.data[b:e]
+ if dataset.cache:
+ dataset.cache = dataset.cache[b:e]
+ outputs.append(dataset)
+ return outputs
+
+ def k_fold(self, k, i):
+ """Perform k-fold sampling.
+
+ Args:
+ k (int): Number of folds.
+ i (int): The i-th fold.
+
+ Returns:
+ TransformableDataset: The i-th fold subset of this dataset.
+
+ """
+ assert 0 <= i <= k, f'Invalid split {i}'
+ train_indices, test_indices = k_fold(k, len(self), i)
+ return self.subset(train_indices), self.subset(test_indices)
+
+ def subset(self, indices):
+ """Create a subset given indices of samples.
+
+ Args:
+ indices: Indices of samples.
+
+ Returns:
+ TransformableDataset: The a subset of this dataset.
+ """
+ dataset = copy(self)
+ dataset.data = [dataset.data[i] for i in indices]
+ if dataset.cache:
+ dataset.cache = [dataset.cache[i] for i in indices]
+ return dataset
+
+ def shuffle(self):
+ """Shuffle this dataset inplace.
+ """
+ if not self.cache:
+ random.shuffle(self.data)
+ else:
+ z = list(zip(self.data, self.cache))
+ random.shuffle(z)
+ self.data, self.cache = zip(*z)
+
+ def prune(self, criterion: Callable, logger: Logger = None):
+ """Prune (to discard) samples according to a criterion.
+
+ Args:
+ criterion: A functions takes a sample as input and output ``True`` if the sample needs to be pruned.
+ logger: If any, log statistical messages using it.
+
+ Returns:
+ int: Size before pruning.
+ """
+ # noinspection PyTypeChecker
+ size_before = len(self)
+ good_ones = [i for i, s in enumerate(self) if not criterion(s)]
+ self.data = [self.data[i] for i in good_ones]
+ if self.cache:
+ self.cache = [self.cache[i] for i in good_ones]
+ if logger:
+ size_after = len(self)
+ num_pruned = size_before - size_after
+ logger.info(f'Pruned [yellow]{num_pruned} ({num_pruned / size_before:.1%})[/yellow] '
+ f'samples out of {size_before}.')
+ return size_before
+
+
+class TransformSequentialDataset(Transformable, IterableDataset, ABC):
+ pass
+
+
+class DeviceDataLoader(DataLoader):
+ def __init__(self, dataset, batch_size=32, shuffle=False, sampler=None,
+ batch_sampler=None, num_workers=None, collate_fn=None,
+ pin_memory=False, drop_last=False, timeout=0,
+ worker_init_fn=None, multiprocessing_context=None,
+ device=None, **kwargs):
+ if batch_sampler is not None:
+ batch_size = 1
+ if num_workers is None:
+ if isdebugging():
+ num_workers = 0
+ else:
+ num_workers = 2
+ # noinspection PyArgumentList
+ super(DeviceDataLoader, self).__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle,
+ sampler=sampler,
+ batch_sampler=batch_sampler, num_workers=num_workers,
+ collate_fn=collate_fn,
+ pin_memory=pin_memory, drop_last=drop_last, timeout=timeout,
+ worker_init_fn=worker_init_fn,
+ multiprocessing_context=multiprocessing_context, **kwargs)
+ self.device = device
+
+ def __iter__(self):
+ for raw_batch in super(DeviceDataLoader, self).__iter__():
+ if self.device is not None:
+ for field, data in raw_batch.items():
+ if isinstance(data, torch.Tensor):
+ data = data.to(self.device)
+ raw_batch[field] = data
+ yield raw_batch
+
+ def collate_fn(self, samples):
+ return merge_list_of_dict(samples)
+
+
+class PadSequenceDataLoader(DataLoader):
+
+ def __init__(self, dataset, batch_size=32, shuffle=False, sampler=None,
+ batch_sampler=None, num_workers=0, collate_fn=None,
+ pin_memory=False, drop_last=False, timeout=0,
+ worker_init_fn=None, multiprocessing_context=None,
+ pad: dict = None, vocabs: VocabDict = None, device=None, **kwargs):
+ """ A dataloader commonly used for NLP tasks. It offers the following convenience.
+
+ - Bachify each field of samples into a :class:`~torch.Tensor` if the field name satisfies the following criterion.
+ - Name ends with _id, _ids, _count, _offset, _span, mask
+ - Name is in `pad` dict.
+
+ - Pad each field according to field name, the vocabs and pad dict.
+ - Move :class:`~torch.Tensor` onto device.
+
+ Args:
+ dataset: A :class:`~torch.utils.data.Dataset` to be bachified.
+ batch_size: Max size of each batch.
+ shuffle: ``True`` to shuffle batches.
+ sampler: A :class:`~torch.utils.data.Sampler` to sample samples from data.
+ batch_sampler: A :class:`~torch.utils.data.Sampler` to sample batches form all batches.
+ num_workers: Number of workers for multi-thread loading. Note that multi-thread loading aren't always
+ faster.
+ collate_fn: A function to perform batchifying. It must be set to ``None`` in order to make use of the
+ features this class offers.
+ pin_memory: If samples are loaded in the Dataset on CPU and would like to be pushed to
+ the GPU, enabling pin_memory can speed up the transfer. It's not useful since most data field are
+ not in Tensor type.
+ drop_last: Drop the last batch since it could be half-empty.
+ timeout: For multi-worker loading, set a timeout to wait for a worker.
+ worker_init_fn: Init function for multi-worker.
+ multiprocessing_context: Context for multiprocessing.
+ pad: A dict holding field names and their padding values.
+ vocabs: A dict of vocabs so padding value can be fetched from it.
+ device: The device tensors will be moved onto.
+ **kwargs: Other arguments will be passed to :meth:`torch.utils.data.Dataset.__init__`
+ """
+ if device == -1:
+ device = None
+ if collate_fn is None:
+ collate_fn = self.collate_fn
+ if num_workers is None:
+ if isdebugging():
+ num_workers = 0
+ else:
+ num_workers = 2
+ if batch_sampler is None:
+ assert batch_size, 'batch_size has to be specified when batch_sampler is None'
+ else:
+ batch_size = 1
+ shuffle = None
+ drop_last = None
+ # noinspection PyArgumentList
+ super(PadSequenceDataLoader, self).__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle,
+ sampler=sampler,
+ batch_sampler=batch_sampler, num_workers=num_workers,
+ collate_fn=collate_fn,
+ pin_memory=pin_memory, drop_last=drop_last, timeout=timeout,
+ worker_init_fn=worker_init_fn,
+ multiprocessing_context=multiprocessing_context, **kwargs)
+ self.vocabs = vocabs
+ if isinstance(dataset, TransformableDataset) and dataset.transform:
+ transform = dataset.transform
+ if not isinstance(transform, TransformList):
+ transform = []
+ for each in transform:
+ if isinstance(each, EmbeddingNamedTransform):
+ if pad is None:
+ pad = {}
+ if each.dst not in pad:
+ pad[each.dst] = 0
+ self.pad = pad
+ self.device = device
+
+ def __iter__(self):
+ for raw_batch in super(PadSequenceDataLoader, self).__iter__():
+ for field, data in raw_batch.items():
+ if isinstance(data, torch.Tensor):
+ continue
+ vocab_key = field[:-len('_id')] if field.endswith('_id') else None
+ vocab: Vocab = self.vocabs.get(vocab_key, None) if self.vocabs and vocab_key else None
+ if vocab:
+ pad = vocab.safe_pad_token_idx
+ dtype = torch.long
+ elif self.pad is not None and field in self.pad:
+ pad = self.pad[field]
+ dtype = dtype_of(pad)
+ elif field.endswith('_offset') or field.endswith('_id') or field.endswith(
+ '_count') or field.endswith('_ids') or field.endswith('_score') or field.endswith(
+ '_length') or field.endswith('_span'):
+ # guess some common fields to pad
+ pad = 0
+ dtype = torch.long
+ elif field.endswith('_mask'):
+ pad = False
+ dtype = torch.bool
+ else:
+ # no need to pad
+ continue
+ data = self.pad_data(data, pad, dtype)
+ raw_batch[field] = data
+ if self.device is not None:
+ for field, data in raw_batch.items():
+ if isinstance(data, torch.Tensor):
+ data = data.to(self.device)
+ raw_batch[field] = data
+ yield raw_batch
+
+ @staticmethod
+ def pad_data(data: Union[torch.Tensor, Iterable], pad, dtype=None, device=None):
+ """Perform the actual padding for a given data.
+
+ Args:
+ data: Data to be padded.
+ pad: Padding value.
+ dtype: Data type.
+ device: Device to be moved onto.
+
+ Returns:
+ torch.Tensor: A ``torch.Tensor``.
+ """
+ if isinstance(data[0], torch.Tensor):
+ data = pad_sequence(data, True, pad)
+ elif isinstance(data[0], Iterable):
+ inner_is_iterable = False
+ for each in data:
+ if len(each):
+ if isinstance(each[0], Iterable):
+ inner_is_iterable = True
+ if len(each[0]):
+ if not dtype:
+ dtype = dtype_of(each[0][0])
+ else:
+ inner_is_iterable = False
+ if not dtype:
+ dtype = dtype_of(each[0])
+ break
+ if inner_is_iterable:
+ max_seq_len = len(max(data, key=len))
+ max_word_len = len(max([chars for words in data for chars in words], key=len))
+ ids = torch.zeros(len(data), max_seq_len, max_word_len, dtype=dtype, device=device)
+ for i, words in enumerate(data):
+ for j, chars in enumerate(words):
+ ids[i][j][:len(chars)] = torch.tensor(chars, dtype=dtype, device=device)
+ data = ids
+ else:
+ data = pad_sequence([torch.tensor(x, dtype=dtype, device=device) for x in data], True, pad)
+ elif isinstance(data, list):
+ data = torch.tensor(data, dtype=dtype, device=device)
+ return data
+
+ def collate_fn(self, samples):
+ return merge_list_of_dict(samples)
+
+
+def _prefetch_generator(dataloader, queue, batchify=None):
+ while True:
+ for batch in dataloader:
+ if batchify:
+ batch = batchify(batch)
+ queue.put(batch)
+
+
+class PrefetchDataLoader(DataLoader):
+ def __init__(self, dataloader: torch.utils.data.DataLoader, prefetch: int = 10, batchify: Callable = None) -> None:
+ """ A dataloader wrapper which speeds up bachifying using multi-processing. It works best for dataloaders
+ of which the bachify takes very long time. But it introduces extra GPU memory consumption since prefetched
+ batches are stored in a ``Queue`` on GPU.
+
+ .. Caution::
+
+ PrefetchDataLoader only works in spawn mode with the following initialization code:
+
+ Examples::
+
+ if __name__ == '__main__':
+ import torch
+
+ torch.multiprocessing.set_start_method('spawn')
+
+ And these 2 lines **MUST** be put into ``if __name__ == '__main__':`` block.
+
+ Args:
+ dataloader: A :class:`~torch.utils.data.DatasetLoader` to be prefetched.
+ prefetch: Number of batches to prefetch.
+ batchify: A bachify function called on each batch of samples. In which case, the inner dataloader shall
+ return samples without really bachify them.
+ """
+ super().__init__(dataset=dataloader)
+ self._batchify = batchify
+ self.prefetch = None if isdebugging() else prefetch
+ if self.prefetch:
+ self._fire_process(dataloader, prefetch)
+
+ def _fire_process(self, dataloader, prefetch):
+ self.queue = mp.Queue(prefetch)
+ self.process = mp.Process(target=_prefetch_generator, args=(dataloader, self.queue, self._batchify))
+ self.process.start()
+
+ def __iter__(self):
+ if not self.prefetch:
+ for batch in self.dataset:
+ if self._batchify:
+ batch = self._batchify(batch)
+ yield batch
+ else:
+ size = len(self)
+ while size:
+ batch = self.queue.get()
+ yield batch
+ size -= 1
+
+ def close(self):
+ """Close this dataloader and terminates internal processes and queue. It's recommended to call this method to
+ before a program can gracefully shutdown.
+ """
+ if self.prefetch:
+ self.queue.close()
+ self.process.terminate()
+
+ @property
+ def batchify(self):
+ return self._batchify
+
+ @batchify.setter
+ def batchify(self, batchify):
+ self._batchify = batchify
+ if not self.prefetch:
+ prefetch = vars(self.queue).get('maxsize', 10)
+ self.close()
+ self._fire_process(self.dataset, prefetch)
+
+
+class BucketSampler(Sampler):
+ # noinspection PyMissingConstructor
+ def __init__(self, buckets: Dict[float, List[int]], batch_max_tokens, batch_size=None, shuffle=False):
+ """A bucketing based sampler which groups samples into buckets then creates batches from each bucket.
+
+ Args:
+ buckets: A dict of which keys are some statistical numbers of each bucket, and values are the indices of
+ samples in each bucket.
+ batch_max_tokens: Maximum tokens per batch.
+ batch_size: Maximum samples per batch.
+ shuffle: ``True`` to shuffle batches and samples in a batch.
+ """
+ self.shuffle = shuffle
+ self.sizes, self.buckets = zip(*[
+ (size, bucket) for size, bucket in buckets.items()
+ ])
+ # the number of chunks in each bucket, which is clipped by
+ # range [1, len(bucket)]
+ if batch_size:
+ self.chunks = [
+ max(batch_size, min(len(bucket), max(round(size * len(bucket) / batch_max_tokens), 1)))
+ for size, bucket in zip(self.sizes, self.buckets)
+ ]
+ else:
+ self.chunks = [
+ min(len(bucket), max(round(size * len(bucket) / batch_max_tokens), 1))
+ for size, bucket in zip(self.sizes, self.buckets)
+ ]
+
+ def __iter__(self):
+ # if shuffle, shuffle both the buckets and samples in each bucket
+ range_fn = torch.randperm if self.shuffle else torch.arange
+ for i in range_fn(len(self.buckets)).tolist():
+ split_sizes = [(len(self.buckets[i]) - j - 1) // self.chunks[i] + 1 for j in range(self.chunks[i])]
+ # DON'T use `torch.chunk` which may return wrong number of chunks
+ for batch in range_fn(len(self.buckets[i])).split(split_sizes):
+ yield [self.buckets[i][j] for j in batch.tolist()]
+
+ def __len__(self):
+ return sum(self.chunks)
+
+
+class KMeansSampler(BucketSampler):
+ def __init__(self, lengths, batch_max_tokens, batch_size=None, shuffle=False, n_buckets=1):
+ """A bucket sampler which groups samples using KMeans on their lengths.
+
+ Args:
+ lengths: Lengths of each sample, usually measured by number of tokens.
+ batch_max_tokens: Maximum tokens per batch.
+ batch_size: Maximum samples per batch.
+ shuffle: ``True`` to shuffle batches. Samples in the same batch won't be shuffled since the ordered sequence
+ is helpful to speed up RNNs.
+ n_buckets: Number of buckets. Clusters in terms of KMeans.
+ """
+ if n_buckets > len(lengths):
+ n_buckets = 1
+ self.n_buckets = n_buckets
+ self.lengths = lengths
+ buckets = dict(zip(*kmeans(self.lengths, n_buckets)))
+ super().__init__(buckets, batch_max_tokens, batch_size, shuffle)
+
+
+class SortingSampler(Sampler):
+ # noinspection PyMissingConstructor
+ def __init__(self, lengths: List[int], batch_size=None, batch_max_tokens=None, shuffle=False) -> None:
+ """A sampler which sort samples according to their lengths. It takes a continuous chunk of sorted samples to
+ make a batch.
+
+ Args:
+ lengths: Lengths of each sample, usually measured by number of tokens.
+ batch_max_tokens: Maximum tokens per batch.
+ batch_size: Maximum samples per batch.
+ shuffle: ``True`` to shuffle batches and samples in a batch.
+ """
+ # assert any([batch_size, batch_max_tokens]), 'At least one of batch_size and batch_max_tokens is required'
+ self.shuffle = shuffle
+ self.batch_size = batch_size
+ # self.batch_max_tokens = batch_max_tokens
+ self.batch_indices = []
+ num_tokens = 0
+ mini_batch = []
+ for i in torch.argsort(torch.tensor(lengths), descending=True).tolist():
+ # if batch_max_tokens:
+ if (batch_max_tokens is None or num_tokens + lengths[i] <= batch_max_tokens) and (
+ batch_size is None or len(mini_batch) < batch_size):
+ mini_batch.append(i)
+ num_tokens += lengths[i]
+ else:
+ if not mini_batch: # this sequence is longer than batch_max_tokens
+ mini_batch.append(i)
+ self.batch_indices.append(mini_batch)
+ mini_batch = []
+ num_tokens = 0
+ else:
+ self.batch_indices.append(mini_batch)
+ mini_batch = [i]
+ num_tokens = lengths[i]
+ if mini_batch:
+ self.batch_indices.append(mini_batch)
+
+ def __iter__(self):
+ # if self.shuffle:
+ # random.shuffle(self.batch_indices)
+ for batch in self.batch_indices:
+ yield batch
+
+ def __len__(self) -> int:
+ return len(self.batch_indices)
+
+
+class SamplerBuilder(AutoConfigurable, ABC):
+ @abstractmethod
+ def build(self, lengths: List[int], shuffle=False, gradient_accumulation=1, **kwargs) -> Sampler:
+ """Build a ``Sampler`` given statistics of samples and other arguments.
+
+ Args:
+ lengths: The lengths of samples.
+ shuffle: ``True`` to shuffle batches. Note samples in each mini-batch are not necessarily shuffled.
+ gradient_accumulation: Number of mini-batches per update step.
+ **kwargs: Other arguments to be passed to the constructor of the sampler.
+ """
+ pass
+
+ def __call__(self, lengths: List[int], shuffle=False, **kwargs) -> Sampler:
+ return self.build(lengths, shuffle, **kwargs)
+
+ def scale(self, gradient_accumulation):
+ r"""Scale down the ``batch_size`` and ``batch_max_tokens`` to :math:`\frac{1}{\text{gradient_accumulation}}`
+ of them respectively.
+
+ Args:
+ gradient_accumulation: Number of mini-batches per update step.
+
+ Returns:
+ tuple(int,int): batch_size, batch_max_tokens
+ """
+ batch_size = self.batch_size
+ batch_max_tokens = self.batch_max_tokens
+ if gradient_accumulation:
+ if batch_size:
+ batch_size //= gradient_accumulation
+ if batch_max_tokens:
+ batch_max_tokens //= gradient_accumulation
+ return batch_size, batch_max_tokens
+
+
+class SortingSamplerBuilder(SortingSampler, SamplerBuilder):
+ # noinspection PyMissingConstructor
+ def __init__(self, batch_size=None, batch_max_tokens=None) -> None:
+ """Builds a :class:`~hanlp.common.dataset.SortingSampler`.
+
+ Args:
+ batch_max_tokens: Maximum tokens per batch.
+ batch_size: Maximum samples per batch.
+ """
+ self.batch_max_tokens = batch_max_tokens
+ self.batch_size = batch_size
+
+ def build(self, lengths: List[int], shuffle=False, gradient_accumulation=1, **kwargs) -> Sampler:
+ batch_size, batch_max_tokens = self.scale(gradient_accumulation)
+ return SortingSampler(lengths, batch_size, batch_max_tokens, shuffle)
+
+ def __len__(self) -> int:
+ return 1
+
+
+class KMeansSamplerBuilder(KMeansSampler, SamplerBuilder):
+ # noinspection PyMissingConstructor
+ def __init__(self, batch_max_tokens, batch_size=None, n_buckets=1):
+ """Builds a :class:`~hanlp.common.dataset.KMeansSampler`.
+
+ Args:
+ batch_max_tokens: Maximum tokens per batch.
+ batch_size: Maximum samples per batch.
+ n_buckets: Number of buckets. Clusters in terms of KMeans.
+ """
+ self.n_buckets = n_buckets
+ self.batch_size = batch_size
+ self.batch_max_tokens = batch_max_tokens
+
+ def build(self, lengths: List[int], shuffle=False, gradient_accumulation=1, **kwargs) -> Sampler:
+ batch_size, batch_max_tokens = self.scale(gradient_accumulation)
+ return KMeansSampler(lengths, batch_max_tokens, batch_size, shuffle, self.n_buckets)
+
+ def __len__(self) -> int:
+ return 1
+
+
+class TableDataset(TransformableDataset):
+ def __init__(self,
+ data: Union[str, List],
+ transform: Union[Callable, List] = None,
+ cache=None,
+ delimiter='auto',
+ strip=True,
+ headers=None) -> None:
+ self.headers = headers
+ self.strip = strip
+ self.delimiter = delimiter
+ super().__init__(data, transform, cache)
+
+ def load_file(self, filepath: str):
+ for idx, cells in enumerate(read_cells(filepath, strip=self.strip, delimiter=self.delimiter)):
+ if not idx and not self.headers:
+ self.headers = cells
+ if any(len(h) > 32 for h in self.headers):
+ warnings.warn('As you did not pass in `headers` to `TableDataset`, the first line is regarded as '
+ 'headers. However, the length for some headers are too long (>32), which might be '
+ 'wrong. To make sure, pass `headers=...` explicitly.')
+ else:
+ yield dict(zip(self.headers, cells))
diff --git a/hanlp/common/document.py b/hanlp/common/document.py
deleted file mode 100644
index f433d8c5d..000000000
--- a/hanlp/common/document.py
+++ /dev/null
@@ -1,64 +0,0 @@
-# -*- coding:utf-8 -*-
-# Author: hankcs
-# Date: 2019-12-31 04:16
-import json
-from typing import List
-
-from hanlp.common.structure import SerializableDict
-from hanlp.components.parsers.conll import CoNLLSentence
-from hanlp.utils.util import collapse_json
-
-
-class Sentence(SerializableDict):
- KEY_WORDS = 'words'
- KEY_POS = 'pos'
- KEY_NER = 'ner'
-
- def __init__(self, **kwargs) -> None:
- super().__init__()
- self.update(kwargs)
-
- @property
- def words(self) -> List[str]:
- return self.get(Sentence.KEY_WORDS)
-
- @words.setter
- def words(self, words: List[str]):
- self[Sentence.KEY_WORDS] = words
-
-
-class Document(SerializableDict):
- def __init__(self) -> None:
- super().__init__()
- # self.sentences = []
- # self.tokens = []
- # self.part_of_speech_tags = []
- # self.named_entities = []
- # self.syntactic_dependencies = []
- # self.semantic_dependencies = []
-
- def __missing__(self, key):
- value = []
- self[key] = value
- return value
-
- def to_dict(self) -> dict:
- return dict((k, v) for k, v in self.items() if v)
-
- def to_json(self, ensure_ascii=False, indent=2) -> str:
- text = json.dumps(self.to_dict(), ensure_ascii=ensure_ascii, indent=indent)
- text = collapse_json(text, 4)
- return text
-
- def __str__(self) -> str:
- return self.to_json()
-
- def to_conll(self) -> List[CoNLLSentence]:
- # try to find if any field is conll type
- if self.semantic_dependencies and isinstance(self.semantic_dependencies[0], CoNLLSentence):
- return self.semantic_dependencies
- if self.syntactic_dependencies and isinstance(self.syntactic_dependencies[0], CoNLLSentence):
- return self.syntactic_dependencies
- for k, v in self.items():
- if len(v) and isinstance(v[0], CoNLLSentence):
- return v
diff --git a/hanlp/common/keras_component.py b/hanlp/common/keras_component.py
new file mode 100644
index 000000000..488eae1f0
--- /dev/null
+++ b/hanlp/common/keras_component.py
@@ -0,0 +1,498 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-08-26 14:45
+import logging
+import math
+import os
+from abc import ABC, abstractmethod
+from typing import Optional, List, Any, Dict
+
+import numpy as np
+import tensorflow as tf
+
+import hanlp.utils
+from hanlp_common.io import save_json,load_json
+from hanlp.callbacks.fine_csv_logger import FineCSVLogger
+from hanlp.common.component import Component
+from hanlp.common.transform_tf import Transform
+from hanlp.common.vocab_tf import VocabTF
+from hanlp.metrics.chunking.iobes_tf import IOBES_F1_TF
+from hanlp.optimizers.adamw import AdamWeightDecay
+from hanlp.utils import io_util
+from hanlp.utils.io_util import get_resource, tempdir_human
+from hanlp.utils.log_util import init_logger, logger
+from hanlp.utils.string_util import format_scores
+from hanlp.utils.tf_util import format_metrics, size_of_dataset, summary_of_model, get_callback_by_class
+from hanlp.utils.time_util import Timer, now_datetime
+from hanlp_common.reflection import str_to_type, classpath_of
+from hanlp_common.structure import SerializableDict
+from hanlp_common.util import merge_dict
+
+
+class KerasComponent(Component, ABC):
+ def __init__(self, transform: Transform) -> None:
+ super().__init__()
+ self.meta = {
+ 'class_path': classpath_of(self),
+ 'hanlp_version': hanlp.version.__version__,
+ }
+ self.model: Optional[tf.keras.Model] = None
+ self.config = SerializableDict()
+ self.transform = transform
+ # share config with transform for convenience, so we don't need to pass args around
+ if self.transform.config:
+ for k, v in self.transform.config.items():
+ self.config[k] = v
+ self.transform.config = self.config
+
+ def evaluate(self, input_path: str, save_dir=None, output=False, batch_size=128, logger: logging.Logger = None,
+ callbacks: List[tf.keras.callbacks.Callback] = None, warm_up=True, verbose=True, **kwargs):
+ input_path = get_resource(input_path)
+ file_prefix, ext = os.path.splitext(input_path)
+ name = os.path.basename(file_prefix)
+ if not name:
+ name = 'evaluate'
+ if save_dir and not logger:
+ logger = init_logger(name=name, root_dir=save_dir, level=logging.INFO if verbose else logging.WARN,
+ mode='w')
+ tst_data = self.transform.file_to_dataset(input_path, batch_size=batch_size)
+ samples = self.num_samples_in(tst_data)
+ num_batches = math.ceil(samples / batch_size)
+ if warm_up:
+ for x, y in tst_data:
+ self.model.predict_on_batch(x)
+ break
+ if output:
+ assert save_dir, 'Must pass save_dir in order to output'
+ if isinstance(output, bool):
+ output = os.path.join(save_dir, name) + '.predict' + ext
+ elif isinstance(output, str):
+ output = output
+ else:
+ raise RuntimeError('output ({}) must be of type bool or str'.format(repr(output)))
+ timer = Timer()
+ eval_outputs = self.evaluate_dataset(tst_data, callbacks, output, num_batches, **kwargs)
+ loss, score, output = eval_outputs[0], eval_outputs[1], eval_outputs[2]
+ delta_time = timer.stop()
+ speed = samples / delta_time.delta_seconds
+
+ if logger:
+ f1: IOBES_F1_TF = None
+ for metric in self.model.metrics:
+ if isinstance(metric, IOBES_F1_TF):
+ f1 = metric
+ break
+ extra_report = ''
+ if f1:
+ overall, by_type, extra_report = f1.state.result(full=True, verbose=False)
+ extra_report = ' \n' + extra_report
+ logger.info('Evaluation results for {} - '
+ 'loss: {:.4f} - {} - speed: {:.2f} sample/sec{}'
+ .format(name + ext, loss,
+ format_scores(score) if isinstance(score, dict) else format_metrics(self.model.metrics),
+ speed, extra_report))
+ if output:
+ logger.info('Saving output to {}'.format(output))
+ with open(output, 'w', encoding='utf-8') as out:
+ self.evaluate_output(tst_data, out, num_batches, self.model.metrics)
+
+ return loss, score, speed
+
+ def num_samples_in(self, dataset):
+ return size_of_dataset(dataset)
+
+ def evaluate_dataset(self, tst_data, callbacks, output, num_batches, **kwargs):
+ loss, score = self.model.evaluate(tst_data, callbacks=callbacks, steps=num_batches)
+ return loss, score, output
+
+ def evaluate_output(self, tst_data, out, num_batches, metrics: List[tf.keras.metrics.Metric]):
+ # out.write('x\ty_true\ty_pred\n')
+ for metric in metrics:
+ metric.reset_states()
+ for idx, batch in enumerate(tst_data):
+ outputs = self.model.predict_on_batch(batch[0])
+ for metric in metrics:
+ metric(batch[1], outputs, outputs._keras_mask if hasattr(outputs, '_keras_mask') else None)
+ self.evaluate_output_to_file(batch, outputs, out)
+ print('\r{}/{} {}'.format(idx + 1, num_batches, format_metrics(metrics)), end='')
+ print()
+
+ def evaluate_output_to_file(self, batch, outputs, out):
+ for x, y_gold, y_pred in zip(self.transform.X_to_inputs(batch[0]),
+ self.transform.Y_to_outputs(batch[1], gold=True),
+ self.transform.Y_to_outputs(outputs, gold=False)):
+ out.write(self.transform.input_truth_output_to_str(x, y_gold, y_pred))
+
+ def _capture_config(self, config: Dict,
+ exclude=(
+ 'trn_data', 'dev_data', 'save_dir', 'kwargs', 'self', 'logger', 'verbose',
+ 'dev_batch_size', '__class__')):
+ """
+ Save arguments to config
+
+ Parameters
+ ----------
+ config
+ `locals()`
+ exclude
+ """
+ if 'kwargs' in config:
+ config.update(config['kwargs'])
+ config = dict(
+ (key, tf.keras.utils.serialize_keras_object(value)) if hasattr(value, 'get_config') else (key, value) for
+ key, value in config.items())
+ for key in exclude:
+ config.pop(key, None)
+ self.config.update(config)
+
+ def save_meta(self, save_dir, filename='meta.json', **kwargs):
+ self.meta['create_time']: now_datetime()
+ self.meta.update(kwargs)
+ save_json(self.meta, os.path.join(save_dir, filename))
+
+ def load_meta(self, save_dir, filename='meta.json'):
+ save_dir = get_resource(save_dir)
+ metapath = os.path.join(save_dir, filename)
+ if os.path.isfile(metapath):
+ self.meta.update(load_json(metapath))
+
+ def save_config(self, save_dir, filename='config.json'):
+ self.config.save_json(os.path.join(save_dir, filename))
+
+ def load_config(self, save_dir, filename='config.json'):
+ save_dir = get_resource(save_dir)
+ self.config.load_json(os.path.join(save_dir, filename))
+
+ def save_weights(self, save_dir, filename='model.h5'):
+ self.model.save_weights(os.path.join(save_dir, filename))
+
+ def load_weights(self, save_dir, filename='model.h5', **kwargs):
+ assert self.model.built or self.model.weights, 'You must call self.model.built() in build_model() ' \
+ 'in order to load it'
+ save_dir = get_resource(save_dir)
+ self.model.load_weights(os.path.join(save_dir, filename))
+
+ def save_vocabs(self, save_dir, filename='vocabs.json'):
+ vocabs = SerializableDict()
+ for key, value in vars(self.transform).items():
+ if isinstance(value, VocabTF):
+ vocabs[key] = value.to_dict()
+ vocabs.save_json(os.path.join(save_dir, filename))
+
+ def load_vocabs(self, save_dir, filename='vocabs.json'):
+ save_dir = get_resource(save_dir)
+ vocabs = SerializableDict()
+ vocabs.load_json(os.path.join(save_dir, filename))
+ for key, value in vocabs.items():
+ vocab = VocabTF()
+ vocab.copy_from(value)
+ setattr(self.transform, key, vocab)
+
+ def load_transform(self, save_dir) -> Transform:
+ """
+ Try to load transform only. This method might fail due to the fact it avoids building the model.
+ If it do fail, then you have to use `load` which might be too heavy but that's the best we can do.
+ :param save_dir: The path to load.
+ """
+ save_dir = get_resource(save_dir)
+ self.load_config(save_dir)
+ self.load_vocabs(save_dir)
+ self.transform.build_config()
+ self.transform.lock_vocabs()
+ return self.transform
+
+ def save(self, save_dir: str, **kwargs):
+ self.save_config(save_dir)
+ self.save_vocabs(save_dir)
+ self.save_weights(save_dir)
+
+ def load(self, save_dir: str, logger=hanlp.utils.log_util.logger, **kwargs):
+ self.meta['load_path'] = save_dir
+ save_dir = get_resource(save_dir)
+ self.load_config(save_dir)
+ self.load_vocabs(save_dir)
+ self.build(**merge_dict(self.config, training=False, logger=logger, **kwargs, overwrite=True, inplace=True))
+ self.load_weights(save_dir, **kwargs)
+ self.load_meta(save_dir)
+
+ @property
+ def input_shape(self) -> List:
+ return self.transform.output_shapes[0]
+
+ def build(self, logger, **kwargs):
+ self.transform.build_config()
+ self.model = self.build_model(**merge_dict(self.config, training=kwargs.get('training', None),
+ loss=kwargs.get('loss', None)))
+ self.transform.lock_vocabs()
+ optimizer = self.build_optimizer(**self.config)
+ loss = self.build_loss(
+ **self.config if 'loss' in self.config else dict(list(self.config.items()) + [('loss', None)]))
+ # allow for different
+ metrics = self.build_metrics(**merge_dict(self.config, metrics=kwargs.get('metrics', 'accuracy'),
+ logger=logger, overwrite=True))
+ if not isinstance(metrics, list):
+ if isinstance(metrics, tf.keras.metrics.Metric):
+ metrics = [metrics]
+ if not self.model.built:
+ sample_inputs = self.sample_data
+ if sample_inputs is not None:
+ self.model(sample_inputs)
+ else:
+ if len(self.transform.output_shapes[0]) == 1 and self.transform.output_shapes[0][0] is None:
+ x_shape = self.transform.output_shapes[0]
+ else:
+ x_shape = list(self.transform.output_shapes[0])
+ for i, shape in enumerate(x_shape):
+ x_shape[i] = [None] + shape # batch + X.shape
+ self.model.build(input_shape=x_shape)
+ self.compile_model(optimizer, loss, metrics)
+ return self.model, optimizer, loss, metrics
+
+ def compile_model(self, optimizer, loss, metrics):
+ self.model.compile(optimizer=optimizer, loss=loss, metrics=metrics, run_eagerly=self.config.run_eagerly)
+
+ def build_optimizer(self, optimizer, **kwargs):
+ if isinstance(optimizer, (str, dict)):
+ custom_objects = {'AdamWeightDecay': AdamWeightDecay}
+ optimizer: tf.keras.optimizers.Optimizer = tf.keras.utils.deserialize_keras_object(optimizer,
+ module_objects=vars(tf.keras.optimizers),
+ custom_objects=custom_objects)
+ self.config.optimizer = tf.keras.utils.serialize_keras_object(optimizer)
+ return optimizer
+
+ def build_loss(self, loss, **kwargs):
+ if not loss:
+ loss = tf.keras.losses.SparseCategoricalCrossentropy(
+ reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
+ from_logits=True)
+ elif isinstance(loss, (str, dict)):
+ loss = tf.keras.utils.deserialize_keras_object(loss, module_objects=vars(tf.keras.losses))
+ if isinstance(loss, tf.keras.losses.Loss):
+ self.config.loss = tf.keras.utils.serialize_keras_object(loss)
+ return loss
+
+ def build_transform(self, **kwargs):
+ return self.transform
+
+ def build_vocab(self, trn_data, logger):
+ train_examples = self.transform.fit(trn_data, **self.config)
+ self.transform.summarize_vocabs(logger)
+ return train_examples
+
+ def build_metrics(self, metrics, logger: logging.Logger, **kwargs):
+ metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
+ return [metric]
+
+ @abstractmethod
+ def build_model(self, **kwargs) -> tf.keras.Model:
+ pass
+
+ def fit(self, trn_data, dev_data, save_dir, batch_size, epochs, run_eagerly=False, logger=None, verbose=True,
+ **kwargs):
+ self._capture_config(locals())
+ self.transform = self.build_transform(**self.config)
+ if not save_dir:
+ save_dir = tempdir_human()
+ if not logger:
+ logger = init_logger(name='train', root_dir=save_dir, level=logging.INFO if verbose else logging.WARN)
+ logger.info('Hyperparameter:\n' + self.config.to_json())
+ num_examples = self.build_vocab(trn_data, logger)
+ # assert num_examples, 'You forgot to return the number of training examples in your build_vocab'
+ logger.info('Building...')
+ train_steps_per_epoch = math.ceil(num_examples / batch_size) if num_examples else None
+ self.config.train_steps = train_steps_per_epoch * epochs if num_examples else None
+ model, optimizer, loss, metrics = self.build(**merge_dict(self.config, logger=logger, training=True))
+ logger.info('Model built:\n' + summary_of_model(self.model))
+ self.save_config(save_dir)
+ self.save_vocabs(save_dir)
+ self.save_meta(save_dir)
+ trn_data = self.build_train_dataset(trn_data, batch_size, num_examples)
+ dev_data = self.build_valid_dataset(dev_data, batch_size)
+ callbacks = self.build_callbacks(save_dir, **merge_dict(self.config, overwrite=True, logger=logger))
+ # need to know #batches, otherwise progbar crashes
+ dev_steps = math.ceil(self.num_samples_in(dev_data) / batch_size)
+ checkpoint = get_callback_by_class(callbacks, tf.keras.callbacks.ModelCheckpoint)
+ timer = Timer()
+ try:
+ history = self.train_loop(**merge_dict(self.config, trn_data=trn_data, dev_data=dev_data, epochs=epochs,
+ num_examples=num_examples,
+ train_steps_per_epoch=train_steps_per_epoch, dev_steps=dev_steps,
+ callbacks=callbacks, logger=logger, model=model, optimizer=optimizer,
+ loss=loss,
+ metrics=metrics, overwrite=True))
+ except KeyboardInterrupt:
+ print()
+ if not checkpoint or checkpoint.best in (np.Inf, -np.Inf):
+ self.save_weights(save_dir)
+ logger.info('Aborted with model saved')
+ else:
+ logger.info(f'Aborted with model saved with best {checkpoint.monitor} = {checkpoint.best:.4f}')
+ # noinspection PyTypeChecker
+ history: tf.keras.callbacks.History() = get_callback_by_class(callbacks, tf.keras.callbacks.History)
+ delta_time = timer.stop()
+ best_epoch_ago = 0
+ if history and hasattr(history, 'epoch'):
+ trained_epoch = len(history.epoch)
+ logger.info('Trained {} epochs in {}, each epoch takes {}'.
+ format(trained_epoch, delta_time, delta_time / trained_epoch if trained_epoch else delta_time))
+ save_json(history.history, io_util.path_join(save_dir, 'history.json'), cls=io_util.NumpyEncoder)
+ monitor_history: List = history.history.get(checkpoint.monitor, None)
+ if monitor_history:
+ best_epoch_ago = len(monitor_history) - monitor_history.index(checkpoint.best)
+ if checkpoint and monitor_history and checkpoint.best != monitor_history[-1]:
+ logger.info(f'Restored the best model saved with best '
+ f'{checkpoint.monitor} = {checkpoint.best:.4f} '
+ f'saved {best_epoch_ago} epochs ago')
+ self.load_weights(save_dir) # restore best model
+ return history
+
+ def train_loop(self, trn_data, dev_data, epochs, num_examples, train_steps_per_epoch, dev_steps, model, optimizer,
+ loss, metrics, callbacks,
+ logger, **kwargs):
+ history = self.model.fit(trn_data, epochs=epochs, steps_per_epoch=train_steps_per_epoch,
+ validation_data=dev_data,
+ callbacks=callbacks,
+ validation_steps=dev_steps,
+ ) # type:tf.keras.callbacks.History
+ return history
+
+ def build_valid_dataset(self, dev_data, batch_size):
+ dev_data = self.transform.file_to_dataset(dev_data, batch_size=batch_size, shuffle=False)
+ return dev_data
+
+ def build_train_dataset(self, trn_data, batch_size, num_examples):
+ trn_data = self.transform.file_to_dataset(trn_data, batch_size=batch_size,
+ shuffle=True,
+ repeat=-1 if self.config.train_steps else None)
+ return trn_data
+
+ def build_callbacks(self, save_dir, logger, **kwargs):
+ metrics = kwargs.get('metrics', 'accuracy')
+ if isinstance(metrics, (list, tuple)):
+ metrics = metrics[-1]
+ monitor = f'val_{metrics}'
+ checkpoint = tf.keras.callbacks.ModelCheckpoint(
+ os.path.join(save_dir, 'model.h5'),
+ # verbose=1,
+ monitor=monitor, save_best_only=True,
+ mode='max',
+ save_weights_only=True)
+ logger.debug(f'Monitor {checkpoint.monitor} for checkpoint')
+ tensorboard_callback = tf.keras.callbacks.TensorBoard(
+ log_dir=io_util.makedirs(io_util.path_join(save_dir, 'logs')))
+ csv_logger = FineCSVLogger(os.path.join(save_dir, 'train.log'), separator=' | ', append=True)
+ callbacks = [checkpoint, tensorboard_callback, csv_logger]
+ lr_decay_per_epoch = self.config.get('lr_decay_per_epoch', None)
+ if lr_decay_per_epoch:
+ learning_rate = self.model.optimizer.get_config().get('learning_rate', None)
+ if not learning_rate:
+ logger.warning('Learning rate decay not supported for optimizer={}'.format(repr(self.model.optimizer)))
+ else:
+ logger.debug(f'Created LearningRateScheduler with lr_decay_per_epoch={lr_decay_per_epoch}')
+ callbacks.append(tf.keras.callbacks.LearningRateScheduler(
+ lambda epoch: learning_rate / (1 + lr_decay_per_epoch * epoch)))
+ anneal_factor = self.config.get('anneal_factor', None)
+ if anneal_factor:
+ callbacks.append(tf.keras.callbacks.ReduceLROnPlateau(factor=anneal_factor,
+ patience=self.config.get('anneal_patience', 10)))
+ early_stopping_patience = self.config.get('early_stopping_patience', None)
+ if early_stopping_patience:
+ callbacks.append(tf.keras.callbacks.EarlyStopping(monitor=monitor, mode='max',
+ verbose=1,
+ patience=early_stopping_patience))
+ return callbacks
+
+ def on_train_begin(self):
+ """
+ Callback before the training starts
+ """
+ pass
+
+ def predict(self, data: Any, batch_size=None, **kwargs):
+ assert self.model, 'Please call fit or load before predict'
+ if not data:
+ return []
+ data, flat = self.transform.input_to_inputs(data)
+
+ if not batch_size:
+ batch_size = self.config.batch_size
+
+ dataset = self.transform.inputs_to_dataset(data, batch_size=batch_size, gold=kwargs.get('gold', False))
+
+ results = []
+ num_samples = 0
+ data_is_list = isinstance(data, list)
+ for idx, batch in enumerate(dataset):
+ samples_in_batch = tf.shape(batch[-1] if isinstance(batch[-1], tf.Tensor) else batch[-1][0])[0]
+ if data_is_list:
+ inputs = data[num_samples:num_samples + samples_in_batch]
+ else:
+ inputs = None # if data is a generator, it's usually one-time, not able to transform into a list
+ for output in self.predict_batch(batch, inputs=inputs, **kwargs):
+ results.append(output)
+ num_samples += samples_in_batch
+
+ if flat:
+ return results[0]
+ return results
+
+ def predict_batch(self, batch, inputs=None, **kwargs):
+ X = batch[0]
+ Y = self.model.predict_on_batch(X)
+ for output in self.transform.Y_to_outputs(Y, X=X, inputs=inputs, batch=batch, **kwargs):
+ yield output
+
+ @property
+ def sample_data(self):
+ return None
+
+ @staticmethod
+ def from_meta(meta: dict, **kwargs):
+ """
+
+ Parameters
+ ----------
+ meta
+ kwargs
+
+ Returns
+ -------
+ KerasComponent
+
+ """
+ cls = str_to_type(meta['class_path'])
+ obj: KerasComponent = cls()
+ assert 'load_path' in meta, f'{meta} doesn\'t contain load_path field'
+ obj.load(meta['load_path'])
+ return obj
+
+ def export_model_for_serving(self, export_dir=None, version=1, overwrite=False, show_hint=False):
+ assert self.model, 'You have to fit or load a model before exporting it'
+ if not export_dir:
+ assert 'load_path' in self.meta, 'When not specifying save_dir, load_path has to present'
+ export_dir = get_resource(self.meta['load_path'])
+ model_path = os.path.join(export_dir, str(version))
+ if os.path.isdir(model_path) and not overwrite:
+ logger.info(f'{model_path} exists, skip since overwrite = {overwrite}')
+ return export_dir
+ logger.info(f'Exporting to {export_dir} ...')
+ tf.saved_model.save(self.model, model_path)
+ logger.info(f'Successfully exported model to {export_dir}')
+ if show_hint:
+ logger.info(f'You can serve it through \n'
+ f'tensorflow_model_server --model_name={os.path.splitext(os.path.basename(self.meta["load_path"]))[0]} '
+ f'--model_base_path={export_dir} --rest_api_port=8888')
+ return export_dir
+
+ def serve(self, export_dir=None, grpc_port=8500, rest_api_port=0, overwrite=False, dry_run=False):
+ export_dir = self.export_model_for_serving(export_dir, show_hint=False, overwrite=overwrite)
+ if not dry_run:
+ del self.model # free memory
+ logger.info('The inputs of exported model is shown below.')
+ os.system(f'saved_model_cli show --all --dir {export_dir}/1')
+ cmd = f'nohup tensorflow_model_server --model_name={os.path.splitext(os.path.basename(self.meta["load_path"]))[0]} ' \
+ f'--model_base_path={export_dir} --port={grpc_port} --rest_api_port={rest_api_port} ' \
+ f'>serve.log 2>&1 &'
+ logger.info(f'Running ...\n{cmd}')
+ if not dry_run:
+ os.system(cmd)
diff --git a/hanlp/common/structure.py b/hanlp/common/structure.py
index 5e1d3d415..cc9c61c75 100644
--- a/hanlp/common/structure.py
+++ b/hanlp/common/structure.py
@@ -1,99 +1,69 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-08-26 14:58
-import json
-
-from hanlp.utils.io_util import save_json, save_pickle, load_pickle, load_json, filename_is_json
-
-
-class Serializable(object):
- """
- A super class for save/load operations.
- """
-
- def save(self, path, fmt=None):
- if not fmt:
- if filename_is_json(path):
- self.save_json(path)
- else:
- self.save_pickle(path)
- elif fmt in ['json', 'jsonl']:
- self.save_json(path)
- else:
- self.save_pickle(path)
-
- def load(self, path, fmt=None):
- if not fmt:
- if filename_is_json(path):
- self.load_json(path)
- else:
- self.load_pickle(path)
- elif fmt in ['json', 'jsonl']:
- self.load_json(path)
- else:
- self.load_pickle(path)
-
- def save_pickle(self, path):
- """Save to path
-
- Parameters
- ----------
- path : str
- file path
- """
- save_pickle(self, path)
-
- def load_pickle(self, path):
- """Load from path
+from typing import Dict
- Parameters
- ----------
- path : str
- file path
-
- Returns
- -------
- Serializable
- An object
- """
- item = load_pickle(path)
- return self.copy_from(item)
+from hanlp_common.configurable import Configurable
+from hanlp_common.reflection import classpath_of
+from hanlp_common.structure import SerializableDict
- def save_json(self, path):
- save_json(self.to_dict(), path)
- def load_json(self, path):
- item = load_json(path)
- return self.copy_from(item)
+class ConfigTracker(Configurable):
- # @abstractmethod
- def copy_from(self, item):
- self.__dict__ = item.__dict__
- # raise NotImplementedError('%s.%s()' % (self.__class__.__name__, inspect.stack()[0][3]))
+ def __init__(self, locals_: Dict, exclude=('kwargs', 'self', '__class__', 'locals_')) -> None:
+ """This base class helps sub-classes to capture their arguments passed to ``__init__``, and also their types so
+ that they can be deserialized from a config in dict form.
- def to_json(self, ensure_ascii=False, indent=2) -> str:
- return json.dumps(self, ensure_ascii=ensure_ascii, indent=indent, default=lambda o: repr(o))
+ Args:
+ locals_: Obtained by :meth:`locals`.
+ exclude: Arguments to be excluded.
- def to_dict(self) -> dict:
- return self.__dict__
+ Examples:
+ >>> class MyClass(ConfigTracker):
+ >>> def __init__(self, i_need_this='yes') -> None:
+ >>> super().__init__(locals())
+ >>> obj = MyClass()
+ >>> print(obj.config)
+ {'i_need_this': 'yes', 'classpath': 'test_config_tracker.MyClass'}
+ """
+ if 'kwargs' in locals_:
+ locals_.update(locals_['kwargs'])
+ self.config = SerializableDict(
+ (k, v.config if hasattr(v, 'config') else v) for k, v in locals_.items() if k not in exclude)
+ self.config['classpath'] = classpath_of(self)
+
+
+class History(object):
+ def __init__(self):
+ """ A history of training context. It records how many steps have passed and provides methods to decide whether
+ an update should be performed, and to caculate number of training steps given dataloader size and
+ ``gradient_accumulation``.
+ """
+ self.num_mini_batches = 0
-class SerializableDict(Serializable, dict):
+ def step(self, gradient_accumulation):
+ """ Whether the training procedure should perform an update.
- def save_json(self, path):
- save_json(self, path)
+ Args:
+ gradient_accumulation: Number of batches per update.
- def copy_from(self, item):
- if isinstance(item, dict):
- self.clear()
- self.update(item)
+ Returns:
+ bool: ``True`` to update.
+ """
+ self.num_mini_batches += 1
+ return self.num_mini_batches % gradient_accumulation == 0
- def __getattr__(self, key):
- if key.startswith('__'):
- return dict.__getattr__(key)
- return self.__getitem__(key)
+ def num_training_steps(self, num_batches, gradient_accumulation):
+ """ Caculate number of training steps.
- def __setattr__(self, key, value):
- return self.__setitem__(key, value)
+ Args:
+ num_batches: Size of dataloader.
+ gradient_accumulation: Number of batches per update.
+ Returns:
+ """
+ return len(
+ [i for i in range(self.num_mini_batches + 1, self.num_mini_batches + num_batches + 1) if
+ i % gradient_accumulation == 0])
diff --git a/hanlp/common/torch_component.py b/hanlp/common/torch_component.py
new file mode 100644
index 000000000..aef2715ea
--- /dev/null
+++ b/hanlp/common/torch_component.py
@@ -0,0 +1,631 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-05-08 21:20
+import logging
+import os
+import re
+import time
+from abc import ABC, abstractmethod
+from typing import Optional, Dict, List, Union, Callable
+
+import torch
+from torch import nn
+from torch.utils.data import DataLoader
+
+import hanlp
+from hanlp.common.component import Component
+from hanlp.common.dataset import TransformableDataset
+from hanlp.common.transform import VocabDict
+from hanlp.utils.io_util import get_resource, basename_no_ext
+from hanlp.utils.log_util import init_logger, flash
+from hanlp.utils.torch_util import cuda_devices, set_seed
+from hanlp_common.configurable import Configurable
+from hanlp_common.constant import IDX, HANLP_VERBOSE
+from hanlp_common.reflection import classpath_of
+from hanlp_common.structure import SerializableDict
+from hanlp_common.util import merge_dict, isdebugging
+
+
+class TorchComponent(Component, ABC):
+ def __init__(self, **kwargs) -> None:
+ """The base class for all components using PyTorch as backend. It provides common workflows of building vocabs,
+ datasets, dataloaders and models. These workflows are more of a conventional guideline than en-forced
+ protocols, which means subclass has the freedom to override or completely skip some steps.
+
+ Args:
+ **kwargs: Addtional arguments to be stored in the ``config`` property.
+ """
+ super().__init__()
+ self.model: Optional[torch.nn.Module] = None
+ self.config = SerializableDict(**kwargs)
+ self.vocabs = VocabDict()
+
+ def _capture_config(self, locals_: Dict,
+ exclude=(
+ 'trn_data', 'dev_data', 'save_dir', 'kwargs', 'self', 'logger', 'verbose',
+ 'dev_batch_size', '__class__', 'devices', 'eval_trn')):
+ """Save arguments to config
+
+ Args:
+ locals_: Dict:
+ exclude: (Default value = ('trn_data')
+ 'dev_data':
+ 'save_dir':
+ 'kwargs':
+ 'self':
+ 'logger':
+ 'verbose':
+ 'dev_batch_size':
+ '__class__':
+ 'devices'):
+
+ Returns:
+
+
+ """
+ if 'kwargs' in locals_:
+ locals_.update(locals_['kwargs'])
+ locals_ = dict((k, v) for k, v in locals_.items() if k not in exclude and not k.startswith('_'))
+ self.config.update(locals_)
+ return self.config
+
+ def save_weights(self, save_dir, filename='model.pt', trainable_only=True, **kwargs):
+ """Save model weights to a directory.
+
+ Args:
+ save_dir: The directory to save weights into.
+ filename: A file name for weights.
+ trainable_only: ``True`` to only save trainable weights. Useful when the model contains lots of static
+ embeddings.
+ **kwargs: Not used for now.
+ """
+ model = self.model_
+ state_dict = model.state_dict()
+ if trainable_only:
+ trainable_names = set(n for n, p in model.named_parameters() if p.requires_grad)
+ state_dict = dict((n, p) for n, p in state_dict.items() if n in trainable_names)
+ torch.save(state_dict, os.path.join(save_dir, filename))
+
+ def load_weights(self, save_dir, filename='model.pt', **kwargs):
+ """Load weights from a directory.
+
+ Args:
+ save_dir: The directory to load weights from.
+ filename: A file name for weights.
+ **kwargs: Not used.
+ """
+ save_dir = get_resource(save_dir)
+ filename = os.path.join(save_dir, filename)
+ # flash(f'Loading model: {filename} [blink]...[/blink][/yellow]')
+ self.model_.load_state_dict(torch.load(filename, map_location='cpu'), strict=False)
+ # flash('')
+
+ def save_config(self, save_dir, filename='config.json'):
+ """Save config into a directory.
+
+ Args:
+ save_dir: The directory to save config.
+ filename: A file name for config.
+ """
+ self._savable_config.save_json(os.path.join(save_dir, filename))
+
+ def load_config(self, save_dir, filename='config.json', **kwargs):
+ """Load config from a directory.
+
+ Args:
+ save_dir: The directory to load config.
+ filename: A file name for config.
+ **kwargs: K-V pairs to override config.
+ """
+ save_dir = get_resource(save_dir)
+ self.config.load_json(os.path.join(save_dir, filename))
+ self.config.update(kwargs) # overwrite config loaded from disk
+ for k, v in self.config.items():
+ if isinstance(v, dict) and 'classpath' in v:
+ self.config[k] = Configurable.from_config(v)
+ self.on_config_ready(**self.config)
+
+ def save_vocabs(self, save_dir, filename='vocabs.json'):
+ """Save vocabularies to a directory.
+
+ Args:
+ save_dir: The directory to save vocabularies.
+ filename: The name for vocabularies.
+ """
+ if hasattr(self, 'vocabs'):
+ self.vocabs.save_vocabs(save_dir, filename)
+
+ def load_vocabs(self, save_dir, filename='vocabs.json'):
+ """Load vocabularies from a directory.
+
+ Args:
+ save_dir: The directory to load vocabularies.
+ filename: The name for vocabularies.
+ """
+ if hasattr(self, 'vocabs'):
+ self.vocabs = VocabDict()
+ self.vocabs.load_vocabs(save_dir, filename)
+
+ def save(self, save_dir: str, **kwargs):
+ """Save this component to a directory.
+
+ Args:
+ save_dir: The directory to save this component.
+ **kwargs: Not used.
+ """
+ self.save_config(save_dir)
+ self.save_vocabs(save_dir)
+ self.save_weights(save_dir)
+
+ def load(self, save_dir: str, devices=None, verbose=HANLP_VERBOSE, **kwargs):
+ """Load from a local/remote component.
+
+ Args:
+ save_dir: An identifier which can be a local path or a remote URL or a pre-defined string.
+ devices: The devices this component will be moved onto.
+ verbose: ``True`` to log loading progress.
+ **kwargs: To override some configs.
+ """
+ save_dir = get_resource(save_dir)
+ # flash('Loading config and vocabs [blink][yellow]...[/yellow][/blink]')
+ if devices is None and self.model:
+ devices = self.devices
+ self.load_config(save_dir, **kwargs)
+ self.load_vocabs(save_dir)
+ if verbose:
+ flash('Building model [blink][yellow]...[/yellow][/blink]')
+ self.model = self.build_model(
+ **merge_dict(self.config, training=False, **kwargs, overwrite=True,
+ inplace=True))
+ if verbose:
+ flash('')
+ self.load_weights(save_dir, **kwargs)
+ self.to(devices)
+ self.model.eval()
+
+ def fit(self,
+ trn_data,
+ dev_data,
+ save_dir,
+ batch_size,
+ epochs,
+ devices=None,
+ logger=None,
+ seed=None,
+ finetune: Union[bool, str] = False,
+ eval_trn=True,
+ _device_placeholder=False,
+ **kwargs):
+ """Fit to data, triggers the training procedure. For training set and dev set, they shall be local or remote
+ files.
+
+ Args:
+ trn_data: Training set.
+ dev_data: Development set.
+ save_dir: The directory to save trained component.
+ batch_size: The number of samples in a batch.
+ epochs: Number of epochs.
+ devices: Devices this component will live on.
+ logger: Any :class:`logging.Logger` instance.
+ seed: Random seed to reproduce this training.
+ finetune: ``True`` to load from ``save_dir`` instead of creating a randomly initialized component. ``str``
+ to specify a different ``save_dir`` to load from.
+ eval_trn: Evaluate training set after each update. This can slow down the training but provides a quick
+ diagnostic for debugging.
+ _device_placeholder: ``True`` to create a placeholder tensor which triggers PyTorch to occupy devices so
+ other components won't take these devices as first choices.
+ **kwargs: Hyperparameters used by sub-classes.
+
+ Returns:
+ Any results sub-classes would like to return. Usually the best metrics on training set.
+
+ """
+ # Common initialization steps
+ config = self._capture_config(locals())
+ if not logger:
+ logger = self.build_logger('train', save_dir)
+ if not seed:
+ self.config.seed = 233 if isdebugging() else int(time.time())
+ set_seed(self.config.seed)
+ logger.info(self._savable_config.to_json(sort=True))
+ if isinstance(devices, list) or devices is None or isinstance(devices, float):
+ flash('[yellow]Querying CUDA devices [blink]...[/blink][/yellow]')
+ devices = -1 if isdebugging() else cuda_devices(devices)
+ flash('')
+ # flash(f'Available GPUs: {devices}')
+ if isinstance(devices, list):
+ first_device = (devices[0] if devices else -1)
+ elif isinstance(devices, dict):
+ first_device = next(iter(devices.values()))
+ elif isinstance(devices, int):
+ first_device = devices
+ else:
+ first_device = -1
+ if _device_placeholder and first_device >= 0:
+ _dummy_placeholder = self._create_dummy_placeholder_on(first_device)
+ if finetune:
+ if isinstance(finetune, str):
+ self.load(finetune, devices=devices)
+ else:
+ self.load(save_dir, devices=devices)
+ logger.info(
+ f'Finetune model loaded with {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}'
+ f'/{sum(p.numel() for p in self.model.parameters())} trainable/total parameters.')
+ self.on_config_ready(**self.config)
+ trn = self.build_dataloader(**merge_dict(config, data=trn_data, batch_size=batch_size, shuffle=True,
+ training=True, device=first_device, logger=logger, vocabs=self.vocabs,
+ overwrite=True))
+ dev = self.build_dataloader(**merge_dict(config, data=dev_data, batch_size=batch_size, shuffle=False,
+ training=None, device=first_device, logger=logger, vocabs=self.vocabs,
+ overwrite=True)) if dev_data else None
+ if not finetune:
+ flash('[yellow]Building model [blink]...[/blink][/yellow]')
+ self.model = self.build_model(**merge_dict(config, training=True))
+ flash('')
+ logger.info(f'Model built with {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}'
+ f'/{sum(p.numel() for p in self.model.parameters())} trainable/total parameters.')
+ assert self.model, 'build_model is not properly implemented.'
+ _description = repr(self.model)
+ if len(_description.split('\n')) < 10:
+ logger.info(_description)
+ self.save_config(save_dir)
+ self.save_vocabs(save_dir)
+ self.to(devices, logger)
+ if _device_placeholder and first_device >= 0:
+ del _dummy_placeholder
+ criterion = self.build_criterion(**merge_dict(config, trn=trn))
+ optimizer = self.build_optimizer(**merge_dict(config, trn=trn, criterion=criterion))
+ metric = self.build_metric(**self.config)
+ if hasattr(trn.dataset, '__len__') and dev and hasattr(dev.dataset, '__len__'):
+ logger.info(f'{len(trn.dataset)}/{len(dev.dataset)} samples in trn/dev set.')
+ trn_size = len(trn) // self.config.get('gradient_accumulation', 1)
+ ratio_width = len(f'{trn_size}/{trn_size}')
+ else:
+ ratio_width = None
+ return self.execute_training_loop(**merge_dict(config, trn=trn, dev=dev, epochs=epochs, criterion=criterion,
+ optimizer=optimizer, metric=metric, logger=logger,
+ save_dir=save_dir,
+ devices=devices,
+ ratio_width=ratio_width,
+ trn_data=trn_data,
+ dev_data=dev_data,
+ eval_trn=eval_trn,
+ overwrite=True))
+
+ def build_logger(self, name, save_dir):
+ """Build a :class:`logging.Logger`.
+
+ Args:
+ name: The name of this logger.
+ save_dir: The directory this logger should save logs into.
+
+ Returns:
+ logging.Logger: A logger.
+ """
+ logger = init_logger(name=name, root_dir=save_dir, level=logging.INFO, fmt="%(message)s")
+ return logger
+
+ @abstractmethod
+ def build_dataloader(self, data, batch_size, shuffle=False, device=None, logger: logging.Logger = None,
+ **kwargs) -> DataLoader:
+ """Build dataloader for training, dev and test sets. It's suggested to build vocabs in this method if they are
+ not built yet.
+
+ Args:
+ data: Data representing samples, which can be a path or a list of samples.
+ batch_size: Number of samples per batch.
+ shuffle: Whether to shuffle this dataloader.
+ device: Device tensors should be loaded onto.
+ logger: Logger for reporting some message if dataloader takes a long time or if vocabs has to be built.
+ **kwargs: Arguments from ``**self.config``.
+ """
+ pass
+
+ def build_vocabs(self, **kwargs):
+ """Override this method to build vocabs.
+
+ Args:
+ **kwargs: The subclass decides the method signature.
+ """
+ pass
+
+ @property
+ def _savable_config(self):
+ def convert(k, v):
+ if not isinstance(v, SerializableDict) and hasattr(v, 'config'):
+ v = v.config
+ elif isinstance(v, (set, tuple)):
+ v = list(v)
+ if isinstance(v, dict):
+ v = dict(convert(_k, _v) for _k, _v in v.items())
+ return k, v
+
+ config = SerializableDict(
+ convert(k, v) for k, v in sorted(self.config.items()))
+ config.update({
+ # 'create_time': now_datetime(),
+ 'classpath': classpath_of(self),
+ 'hanlp_version': hanlp.__version__,
+ })
+ return config
+
+ @abstractmethod
+ def build_optimizer(self, **kwargs):
+ """Implement this method to build an optimizer.
+
+ Args:
+ **kwargs: The subclass decides the method signature.
+ """
+ pass
+
+ @abstractmethod
+ def build_criterion(self, decoder, **kwargs):
+ """Implement this method to build criterion (loss function).
+
+ Args:
+ decoder: The model or decoder.
+ **kwargs: The subclass decides the method signature.
+ """
+ pass
+
+ @abstractmethod
+ def build_metric(self, **kwargs):
+ """Implement this to build metric(s).
+
+ Args:
+ **kwargs: The subclass decides the method signature.
+ """
+ pass
+
+ @abstractmethod
+ def execute_training_loop(self, trn: DataLoader, dev: DataLoader, epochs, criterion, optimizer, metric, save_dir,
+ logger: logging.Logger, devices, ratio_width=None,
+ **kwargs):
+ """Implement this to run training loop.
+
+ Args:
+ trn: Training set.
+ dev: Development set.
+ epochs: Number of epochs.
+ criterion: Loss function.
+ optimizer: Optimizer(s).
+ metric: Metric(s)
+ save_dir: The directory to save this component.
+ logger: Logger for reporting progress.
+ devices: Devices this component and dataloader will live on.
+ ratio_width: The width of dataset size measured in number of characters. Used for logger to align messages.
+ **kwargs: Other hyper-parameters passed from sub-class.
+ """
+ pass
+
+ @abstractmethod
+ def fit_dataloader(self, trn: DataLoader, criterion, optimizer, metric, logger: logging.Logger, **kwargs):
+ """Fit onto a dataloader.
+
+ Args:
+ trn: Training set.
+ criterion: Loss function.
+ optimizer: Optimizer.
+ metric: Metric(s).
+ logger: Logger for reporting progress.
+ **kwargs: Other hyper-parameters passed from sub-class.
+ """
+ pass
+
+ @abstractmethod
+ def evaluate_dataloader(self, data: DataLoader, criterion: Callable, metric=None, output=False, **kwargs):
+ """Evaluate on a dataloader.
+
+ Args:
+ data: Dataloader which can build from any data source.
+ criterion: Loss function.
+ metric: Metric(s).
+ output: Whether to save outputs into some file.
+ **kwargs: Not used.
+ """
+ pass
+
+ @abstractmethod
+ def build_model(self, training=True, **kwargs) -> torch.nn.Module:
+ """Build model.
+
+ Args:
+ training: ``True`` if called during training.
+ **kwargs: ``**self.config``.
+ """
+ raise NotImplementedError
+
+ def evaluate(self, tst_data, save_dir=None, logger: logging.Logger = None, batch_size=None, output=False, **kwargs):
+ """Evaluate test set.
+
+ Args:
+ tst_data: Test set, which is usually a file path.
+ save_dir: The directory to save evaluation scores or predictions.
+ logger: Logger for reporting progress.
+ batch_size: Batch size for test dataloader.
+ output: Whether to save outputs into some file.
+ **kwargs: Not used.
+
+ Returns:
+ (metric, outputs) where outputs are the return values of ``evaluate_dataloader``.
+ """
+ if not self.model:
+ raise RuntimeError('Call fit or load before evaluate.')
+ if isinstance(tst_data, str):
+ tst_data = get_resource(tst_data)
+ filename = os.path.basename(tst_data)
+ else:
+ filename = None
+ if output is True:
+ output = self.generate_prediction_filename(tst_data if isinstance(tst_data, str) else 'test.txt', save_dir)
+ if logger is None:
+ _logger_name = basename_no_ext(filename) if filename else None
+ logger = self.build_logger(_logger_name, save_dir)
+ if not batch_size:
+ batch_size = self.config.get('batch_size', 32)
+ data = self.build_dataloader(**merge_dict(self.config, data=tst_data, batch_size=batch_size, shuffle=False,
+ device=self.devices[0], logger=logger, overwrite=True))
+ dataset = data
+ while dataset and hasattr(dataset, 'dataset'):
+ dataset = dataset.dataset
+ num_samples = len(dataset) if dataset else None
+ if output and isinstance(dataset, TransformableDataset):
+ def add_idx(samples):
+ for idx, sample in enumerate(samples):
+ if sample:
+ sample[IDX] = idx
+
+ add_idx(dataset.data)
+ if dataset.cache:
+ add_idx(dataset.cache)
+
+ criterion = self.build_criterion(**self.config)
+ metric = self.build_metric(**self.config)
+ start = time.time()
+ outputs = self.evaluate_dataloader(data, criterion=criterion, filename=filename, output=output, input=tst_data,
+ save_dir=save_dir,
+ test=True,
+ num_samples=num_samples,
+ **merge_dict(self.config, batch_size=batch_size, metric=metric,
+ logger=logger, **kwargs))
+ elapsed = time.time() - start
+ if logger:
+ if num_samples:
+ logger.info(f'speed: {num_samples / elapsed:.0f} samples/second')
+ else:
+ logger.info(f'speed: {len(data) / elapsed:.0f} batches/second')
+ return metric, outputs
+
+ def generate_prediction_filename(self, tst_data, save_dir):
+ assert isinstance(tst_data,
+ str), 'tst_data has be a str in order to infer the output name'
+ output = os.path.splitext(os.path.basename(tst_data))
+ output = os.path.join(save_dir, output[0] + '.pred' + output[1])
+ return output
+
+ def to(self,
+ devices=Union[int, float, List[int], Dict[str, Union[int, torch.device]]],
+ logger: logging.Logger = None, verbose=HANLP_VERBOSE):
+ """Move this component to devices.
+
+ Args:
+ devices: Target devices.
+ logger: Logger for printing progress report, as copying a model from CPU to GPU can takes several seconds.
+ verbose: ``True`` to print progress when logger is None.
+ """
+ if devices == -1 or devices == [-1]:
+ devices = []
+ elif isinstance(devices, (int, float)) or devices is None:
+ devices = cuda_devices(devices)
+ if devices:
+ if logger:
+ logger.info(f'Using GPUs: [on_blue][cyan][bold]{devices}[/bold][/cyan][/on_blue]')
+ if isinstance(devices, list):
+ if verbose:
+ flash(f'Moving model to GPUs {devices} [blink][yellow]...[/yellow][/blink]')
+ self.model = self.model.to(devices[0])
+ if len(devices) > 1 and not isdebugging() and not isinstance(self.model, nn.DataParallel):
+ self.model = self.parallelize(devices)
+ elif isinstance(devices, dict):
+ for name, module in self.model.named_modules():
+ for regex, device in devices.items():
+ try:
+ on_device: torch.device = next(module.parameters()).device
+ except StopIteration:
+ continue
+ if on_device == device:
+ continue
+ if isinstance(device, int):
+ if on_device.index == device:
+ continue
+ if re.match(regex, name):
+ if not name:
+ name = '*'
+ flash(f'Moving module [yellow]{name}[/yellow] to [on_yellow][magenta][bold]{device}'
+ f'[/bold][/magenta][/on_yellow]: [red]{regex}[/red]\n')
+ module.to(device)
+ else:
+ raise ValueError(f'Unrecognized devices {devices}')
+ if verbose:
+ flash('')
+ else:
+ if logger:
+ logger.info('Using CPU')
+
+ def parallelize(self, devices: List[Union[int, torch.device]]):
+ return nn.DataParallel(self.model, device_ids=devices)
+
+ @property
+ def devices(self):
+ """The devices this component lives on.
+ """
+ if self.model is None:
+ return None
+ # next(parser.model.parameters()).device
+ if hasattr(self.model, 'device_ids'):
+ return self.model.device_ids
+ device: torch.device = next(self.model.parameters()).device
+ return [device]
+
+ @property
+ def device(self):
+ """The first device this component lives on.
+ """
+ devices = self.devices
+ if not devices:
+ return None
+ return devices[0]
+
+ def on_config_ready(self, **kwargs):
+ """Called when config is ready, either during ``fit`` ot ``load``. Subclass can perform extra initialization
+ tasks in this callback.
+
+ Args:
+ **kwargs: Not used.
+ """
+ pass
+
+ @property
+ def model_(self) -> nn.Module:
+ """
+ The actual model when it's wrapped by a `DataParallel`
+
+ Returns: The "real" model
+
+ """
+ if isinstance(self.model, nn.DataParallel):
+ return self.model.module
+ return self.model
+
+ # noinspection PyMethodOverriding
+ @abstractmethod
+ def predict(self, data: Union[str, List[str]], batch_size: int = None, **kwargs):
+ """Predict on data fed by user. Users shall avoid directly call this method since it is not guarded with
+ ``torch.no_grad`` and will introduces unnecessary gradient computation. Use ``__call__`` instead.
+
+ Args:
+ data: Sentences or tokens.
+ batch_size: Decoding batch size.
+ **kwargs: Used in sub-classes.
+ """
+ pass
+
+ @staticmethod
+ def _create_dummy_placeholder_on(device):
+ if device < 0:
+ device = 'cpu:0'
+ return torch.zeros(16, 16, device=device)
+
+ @torch.no_grad()
+ def __call__(self, data, batch_size=None, **kwargs):
+ """Predict on data fed by user. This method calls :meth:`~hanlp.common.torch_component.predict` but decorates
+ it with ``torch.no_grad``.
+
+ Args:
+ data: Sentences or tokens.
+ batch_size: Decoding batch size.
+ **kwargs: Used in sub-classes.
+ """
+ return super().__call__(data, **merge_dict(self.config, overwrite=True,
+ batch_size=batch_size or self.config.get('batch_size', None),
+ **kwargs))
diff --git a/hanlp/common/transform.py b/hanlp/common/transform.py
index 70fe81b8e..c46c35693 100644
--- a/hanlp/common/transform.py
+++ b/hanlp/common/transform.py
@@ -1,297 +1,522 @@
# -*- coding:utf-8 -*-
# Author: hankcs
-# Date: 2019-10-27 14:22
-import inspect
+# Date: 2020-05-03 14:44
+import logging
+import os
from abc import ABC, abstractmethod
-from typing import Generator, Tuple, Union, Iterable, Any
+from typing import Tuple, Union, List
-import tensorflow as tf
-
-from hanlp.common.structure import SerializableDict
+from hanlp_common.constant import EOS, PAD
+from hanlp_common.structure import SerializableDict
+from hanlp_common.configurable import Configurable
from hanlp.common.vocab import Vocab
from hanlp.utils.io_util import get_resource
-from hanlp.utils.log_util import logger
+from hanlp_common.io import load_json
+from hanlp_common.reflection import classpath_of, str_to_type
+from hanlp.utils.string_util import ispunct
-class Transform(ABC):
+class ToIndex(ABC):
- def __init__(self, config: SerializableDict = None, map_x=True, map_y=True, **kwargs) -> None:
+ def __init__(self, vocab: Vocab = None) -> None:
super().__init__()
- self.map_y = map_y
- self.map_x = map_x
- if kwargs:
- if not config:
- config = SerializableDict()
- for k, v in kwargs.items():
- config[k] = v
- self.config = config
- self.output_types = None
- self.output_shapes = None
- self.padding_values = None
+ if vocab is None:
+ vocab = Vocab()
+ self.vocab = vocab
@abstractmethod
- def fit(self, trn_path: str, **kwargs) -> int:
+ def __call__(self, sample):
+ pass
+
+ def save_vocab(self, save_dir, filename='vocab.json'):
+ vocab = SerializableDict()
+ vocab.update(self.vocab.to_dict())
+ vocab.save_json(os.path.join(save_dir, filename))
+
+ def load_vocab(self, save_dir, filename='vocab.json'):
+ save_dir = get_resource(save_dir)
+ vocab = SerializableDict()
+ vocab.load_json(os.path.join(save_dir, filename))
+ self.vocab.copy_from(vocab)
+
+
+class FieldToIndex(ToIndex):
+
+ def __init__(self, src, vocab: Vocab, dst=None) -> None:
+ super().__init__(vocab)
+ self.src = src
+ if not dst:
+ dst = f'{src}_id'
+ self.dst = dst
+
+ def __call__(self, sample: dict):
+ sample[self.dst] = self.vocab(sample[self.src])
+ return sample
+
+ def save_vocab(self, save_dir, filename=None):
+ if not filename:
+ filename = f'{self.dst}_vocab.json'
+ super().save_vocab(save_dir, filename)
+
+ def load_vocab(self, save_dir, filename=None):
+ if not filename:
+ filename = f'{self.dst}_vocab.json'
+ super().load_vocab(save_dir, filename)
+
+
+class VocabList(list):
+
+ def __init__(self, *fields) -> None:
+ super().__init__()
+ for each in fields:
+ self.append(FieldToIndex(each))
+
+ def append(self, item: Union[str, Tuple[str, Vocab], Tuple[str, str, Vocab], FieldToIndex]) -> None:
+ if isinstance(item, str):
+ item = FieldToIndex(item)
+ elif isinstance(item, (list, tuple)):
+ if len(item) == 2:
+ item = FieldToIndex(src=item[0], vocab=item[1])
+ elif len(item) == 3:
+ item = FieldToIndex(src=item[0], dst=item[1], vocab=item[2])
+ else:
+ raise ValueError(f'Unsupported argument length: {item}')
+ elif isinstance(item, FieldToIndex):
+ pass
+ else:
+ raise ValueError(f'Unsupported argument type: {item}')
+ super(self).append(item)
+
+ def save_vocab(self, save_dir):
+ for each in self:
+ each.save_vocab(save_dir, None)
+
+ def load_vocab(self, save_dir):
+ for each in self:
+ each.load_vocab(save_dir, None)
+
+
+class VocabDict(SerializableDict):
+
+ def __init__(self, *args, **kwargs) -> None:
+ """A dict holding :class:`hanlp.common.vocab.Vocab` instances. When used a transform, it transforms the field
+ corresponding to each :class:`hanlp.common.vocab.Vocab` into indices.
+
+ Args:
+ *args: A list of vocab names.
+ **kwargs: Names and corresponding :class:`hanlp.common.vocab.Vocab` instances.
"""
- Build the vocabulary from training file
+ vocabs = dict(kwargs)
+ for each in args:
+ vocabs[each] = Vocab()
+ super().__init__(vocabs)
- Parameters
- ----------
- trn_path : path to training set
- kwargs
+ def save_vocabs(self, save_dir, filename='vocabs.json'):
+ """Save vocabularies to a directory.
- Returns
- -------
- int
- How many samples in the training set
+ Args:
+ save_dir: The directory to save vocabularies.
+ filename: The name for vocabularies.
"""
- raise NotImplementedError('%s.%s()' % (self.__class__.__name__, inspect.stack()[0][3]))
+ vocabs = SerializableDict()
+ for key, value in self.items():
+ if isinstance(value, Vocab):
+ vocabs[key] = value.to_dict()
+ vocabs.save_json(os.path.join(save_dir, filename))
- def build_config(self):
+ def load_vocabs(self, save_dir, filename='vocabs.json', vocab_cls=Vocab):
+ """Load vocabularies from a directory.
+
+ Args:
+ save_dir: The directory to load vocabularies.
+ filename: The name for vocabularies.
"""
- By default, call build_types_shapes_values, usually called in component's build method.
- You can perform other building task here. Remember to call super().build_config
+ save_dir = get_resource(save_dir)
+ vocabs = SerializableDict()
+ vocabs.load_json(os.path.join(save_dir, filename))
+ self._load_vocabs(self, vocabs, vocab_cls)
+
+ @staticmethod
+ def _load_vocabs(vd, vocabs: dict, vocab_cls=Vocab):
"""
- self.output_types, self.output_shapes, self.padding_values = self.create_types_shapes_values()
- # We prefer list over shape here, as it's easier to type [] than ()
- # if isinstance(self.output_shapes, tuple):
- # self.output_shapes = list(self.output_shapes)
- # for i, shapes in enumerate(self.output_shapes):
- # if isinstance(shapes, tuple):
- # self.output_shapes[i] = list(shapes)
- # for j, shape in enumerate(shapes):
- # if isinstance(shape, tuple):
- # shapes[j] = list(shape)
- @abstractmethod
- def create_types_shapes_values(self) -> Tuple[Tuple, Tuple, Tuple]:
+ Args:
+ vd:
+ vocabs:
+ vocab_cls: Default class for the new vocab
"""
- Create dataset related values,
+ for key, value in vocabs.items():
+ if 'idx_to_token' in value:
+ cls = value.get('type', None)
+ if cls:
+ cls = str_to_type(cls)
+ else:
+ cls = vocab_cls
+ vocab = cls()
+ vocab.copy_from(value)
+ vd[key] = vocab
+ else: # nested Vocab
+ # noinspection PyTypeChecker
+ vd[key] = nested = VocabDict()
+ VocabDict._load_vocabs(nested, value, vocab_cls)
+
+ def lock(self):
"""
- raise NotImplementedError('%s.%s()' % (self.__class__.__name__, inspect.stack()[0][3]))
-
- @abstractmethod
- def file_to_inputs(self, filepath: str, gold=True):
+ Lock each vocabs.
"""
- Transform file to inputs. The inputs are defined as raw features (e.g. words) to be processed into more
- features (e.g. forms and characters)
+ for key, value in self.items():
+ if isinstance(value, Vocab):
+ value.lock()
- Parameters
- ----------
- filepath
- gold
- """
- raise NotImplementedError('%s.%s()' % (self.__class__.__name__, inspect.stack()[0][3]))
+ @property
+ def mutable(self):
+ status = [v.mutable for v in self.values() if isinstance(v, Vocab)]
+ return len(status) == 0 or any(status)
- def inputs_to_samples(self, inputs, gold=False):
- if gold:
- yield from inputs
- else:
- for x in inputs:
- yield x, self.padding_values[-1]
+ def __call__(self, sample: dict):
+ for key, value in self.items():
+ if isinstance(value, Vocab):
+ field = sample.get(key, None)
+ if field is not None:
+ sample[f'{key}_id'] = value(field)
+ return sample
+
+ def __getattr__(self, key):
+ if key.startswith('__'):
+ return dict.__getattr__(key)
+ return self.__getitem__(key)
+
+ def __setattr__(self, key, value):
+ return self.__setitem__(key, value)
- def file_to_samples(self, filepath: str, gold=True):
+ def __getitem__(self, k: str) -> Vocab:
+ return super().__getitem__(k)
+
+ def __setitem__(self, k: str, v: Vocab) -> None:
+ super().__setitem__(k, v)
+
+ def summary(self, logger: logging.Logger = None):
+ """Log a summary of vocabs using a given logger.
+
+ Args:
+ logger: The logger to use.
"""
- Transform file to samples
- Parameters
- ----------
- filepath
- gold
+ for key, value in self.items():
+ if isinstance(value, Vocab):
+ report = value.summary(verbose=False)
+ if logger:
+ logger.info(f'{key}{report}')
+ else:
+ print(f'{key}{report}')
+
+ def put(self, **kwargs):
+ """Put names and corresponding :class:`hanlp.common.vocab.Vocab` instances into self.
+
+ Args:
+ **kwargs: Names and corresponding :class:`hanlp.common.vocab.Vocab` instances.
"""
- filepath = get_resource(filepath)
- inputs = self.file_to_inputs(filepath, gold)
- yield from self.inputs_to_samples(inputs, gold)
-
- def file_to_dataset(self, filepath: str, gold=True, map_x=None, map_y=None, batch_size=32, shuffle=None,
- repeat=None,
- drop_remainder=False,
- prefetch=1,
- cache=True,
- **kwargs) -> tf.data.Dataset:
+ for k, v in kwargs.items():
+ self[k] = v
+
+
+class NamedTransform(ABC):
+ def __init__(self, src: str, dst: str = None) -> None:
+ if dst is None:
+ dst = src
+ self.dst = dst
+ self.src = src
+
+ @abstractmethod
+ def __call__(self, sample: dict) -> dict:
+ return sample
+
+
+class ConfigurableTransform(Configurable, ABC):
+ @property
+ def config(self):
+ return dict([('classpath', classpath_of(self))] +
+ [(k, v) for k, v in self.__dict__.items() if not k.startswith('_')])
+
+ @classmethod
+ def from_config(cls, config: dict):
"""
- Transform file to dataset
-
- Parameters
- ----------
- filepath
- gold : bool
- Whether it's processing gold data or not. Example: there is usually a column for gold answer
- when gold = True.
- map_x : bool
- Whether call map_x or not. Default to self.map_x
- map_y : bool
- Whether call map_y or not. Default to self.map_y
- batch_size
- shuffle
- repeat
- prefetch
- kwargs
-
- Returns
- -------
+ Args:
+ config:
+ kwargs:
+ config: dict:
+
+ Returns:
+
+
"""
+ cls = config.get('classpath', None)
+ assert cls, f'{config} doesn\'t contain classpath field'
+ cls = str_to_type(cls)
+ config = dict(config)
+ config.pop('classpath')
+ return cls(**config)
- # debug
- # for sample in self.file_to_samples(filepath):
- # pass
-
- def generator():
- inputs = self.file_to_inputs(filepath, gold)
- samples = self.inputs_to_samples(inputs, gold)
- yield from samples
-
- return self.samples_to_dataset(generator, map_x, map_y, batch_size, shuffle, repeat, drop_remainder, prefetch,
- cache)
-
- def inputs_to_dataset(self, inputs, gold=False, map_x=None, map_y=None, batch_size=32, shuffle=None, repeat=None,
- drop_remainder=False,
- prefetch=1, cache=False, **kwargs) -> tf.data.Dataset:
- # debug
- # for sample in self.inputs_to_samples(inputs):
- # pass
-
- def generator():
- samples = self.inputs_to_samples(inputs, gold)
- yield from samples
-
- return self.samples_to_dataset(generator, map_x, map_y, batch_size, shuffle, repeat, drop_remainder, prefetch,
- cache)
-
- def samples_to_dataset(self, samples: Generator, map_x=None, map_y=None, batch_size=32, shuffle=None, repeat=None,
- drop_remainder=False,
- prefetch=1, cache=True) -> tf.data.Dataset:
- output_types, output_shapes, padding_values = self.output_types, self.output_shapes, self.padding_values
- if not all(v for v in [output_shapes, output_shapes,
- padding_values]):
- # print('Did you forget to call build_config() on your transform?')
- self.build_config()
- output_types, output_shapes, padding_values = self.output_types, self.output_shapes, self.padding_values
- assert all(v for v in [output_shapes, output_shapes,
- padding_values]), 'Your create_types_shapes_values returns None, which is not allowed'
- # if not callable(samples):
- # samples = Transform.generator_to_callable(samples)
- dataset = tf.data.Dataset.from_generator(samples, output_types=output_types, output_shapes=output_shapes)
- if cache:
- logger.debug('Dataset cache enabled')
- dataset = dataset.cache(cache if isinstance(cache, str) else '')
- if shuffle:
- if isinstance(shuffle, bool):
- shuffle = 1024
- dataset = dataset.shuffle(shuffle)
- if repeat:
- dataset = dataset.repeat(repeat)
- if batch_size:
- dataset = dataset.padded_batch(batch_size, output_shapes, padding_values, drop_remainder)
- if prefetch:
- dataset = dataset.prefetch(prefetch)
- if map_x is None:
- map_x = self.map_x
- if map_y is None:
- map_y = self.map_y
- if map_x or map_y:
- def mapper(X, Y):
- if map_x:
- X = self.x_to_idx(X)
- if map_y:
- Y = self.y_to_idx(Y)
- return X, Y
-
- dataset = dataset.map(mapper, num_parallel_calls=tf.data.experimental.AUTOTUNE)
- return dataset
- @abstractmethod
- def x_to_idx(self, x) -> Union[tf.Tensor, Tuple]:
- raise NotImplementedError('%s.%s()' % (self.__class__.__name__, inspect.stack()[0][3]))
+class ConfigurableNamedTransform(NamedTransform, ConfigurableTransform, ABC):
+ pass
- @abstractmethod
- def y_to_idx(self, y) -> tf.Tensor:
- raise NotImplementedError('%s.%s()' % (self.__class__.__name__, inspect.stack()[0][3]))
- def lock_vocabs(self):
- for key, value in vars(self).items():
- if isinstance(value, Vocab):
- value.lock()
+class EmbeddingNamedTransform(ConfigurableNamedTransform, ABC):
- def summarize_vocabs(self, logger=None, header='Vocab summary:'):
- output = header + '\n'
- vocabs = {}
- for key, value in vars(self).items():
- if isinstance(value, Vocab):
- vocabs[key] = value
- # tag vocab comes last usually
- for key, value in sorted(vocabs.items(), key=lambda kv: len(kv[1]), reverse=True):
- output += f'{key}' + value.summary(verbose=False) + '\n'
- output = output.strip()
- if logger:
- logger.info(output)
- else:
- print(output)
+ def __init__(self, output_dim: int, src: str, dst: str) -> None:
+ super().__init__(src, dst)
+ self.output_dim = output_dim
+
+
+class RenameField(NamedTransform):
+
+ def __call__(self, sample: dict):
+ sample[self.dst] = sample.pop(self.src)
+ return sample
+
+
+class CopyField(object):
+ def __init__(self, src, dst) -> None:
+ self.dst = dst
+ self.src = src
+
+ def __call__(self, sample: dict) -> dict:
+ sample[self.dst] = sample[self.src]
+ return sample
+
+
+class FilterField(object):
+ def __init__(self, *keys) -> None:
+ self.keys = keys
+
+ def __call__(self, sample: dict):
+ sample = dict((k, sample[k]) for k in self.keys)
+ return sample
+
+
+class TransformList(list):
+ """Composes several transforms together.
+
+ Args:
+ transforms(list of ``Transform`` objects): list of transforms to compose.
+ Example:
+
+ Returns:
+
+ >>> transforms.TransformList(
+ >>> transforms.CenterCrop(10),
+ >>> transforms.ToTensor(),
+ >>> )
+ """
+
+ def __init__(self, *transforms) -> None:
+ super().__init__()
+ self.extend(transforms)
+
+ def __call__(self, sample):
+ for t in self:
+ sample = t(sample)
+ return sample
+
+ def index_by_type(self, t):
+ for i, trans in enumerate(self):
+ if isinstance(trans, t):
+ return i
+
+
+class LowerCase(object):
+ def __init__(self, src, dst=None) -> None:
+ if dst is None:
+ dst = src
+ self.src = src
+ self.dst = dst
+
+ def __call__(self, sample: dict) -> dict:
+ src = sample[self.src]
+ if isinstance(src, str):
+ sample[self.dst] = src.lower()
+ elif isinstance(src, list):
+ sample[self.dst] = [x.lower() for x in src]
+ return sample
+
+
+class ToChar(object):
+ def __init__(self, src, dst='char', max_word_length=None, min_word_length=None, pad=PAD) -> None:
+ if dst is None:
+ dst = src
+ self.src = src
+ self.dst = dst
+ self.max_word_length = max_word_length
+ self.min_word_length = min_word_length
+ self.pad = pad
+
+ def __call__(self, sample: dict) -> dict:
+ src = sample[self.src]
+ if isinstance(src, str):
+ sample[self.dst] = self.to_chars(src)
+ elif isinstance(src, list):
+ sample[self.dst] = [self.to_chars(x) for x in src]
+ return sample
+
+ def to_chars(self, word: str):
+ chars = list(word)
+ if self.min_word_length and len(chars) < self.min_word_length:
+ chars = chars + [self.pad] * (self.min_word_length - len(chars))
+ if self.max_word_length:
+ chars = chars[:self.max_word_length]
+ return chars
+
+
+class AppendEOS(NamedTransform):
+
+ def __init__(self, src: str, dst: str = None, eos=EOS) -> None:
+ super().__init__(src, dst)
+ self.eos = eos
+
+ def __call__(self, sample: dict) -> dict:
+ sample[self.dst] = sample[self.src] + [self.eos]
+ return sample
+
+
+class WhitespaceTokenizer(NamedTransform):
+
+ def __call__(self, sample: dict) -> dict:
+ src = sample[self.src]
+ if isinstance(src, str):
+ sample[self.dst] = self.tokenize(src)
+ elif isinstance(src, list):
+ sample[self.dst] = [self.tokenize(x) for x in src]
+ return sample
@staticmethod
- def generator_to_callable(generator: Generator):
- return lambda: (x for x in generator)
+ def tokenize(text: str):
+ return text.split()
- def str_to_idx(self, X, Y) -> Tuple[Union[tf.Tensor, Tuple], tf.Tensor]:
- return self.x_to_idx(X), self.y_to_idx(Y)
- def X_to_inputs(self, X: Union[tf.Tensor, Tuple[tf.Tensor]]) -> Iterable:
- return [repr(x) for x in X]
+class NormalizeDigit(object):
+ def __init__(self, src, dst=None) -> None:
+ if dst is None:
+ dst = src
+ self.src = src
+ self.dst = dst
- def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, inputs=None, X=None, batch=None) -> Iterable:
- return [repr(y) for y in Y]
+ @staticmethod
+ def transform(word: str):
+ new_word = ""
+ for char in word:
+ if char.isdigit():
+ new_word += '0'
+ else:
+ new_word += char
+ return new_word
- def XY_to_inputs_outputs(self, X: Union[tf.Tensor, Tuple[tf.Tensor]],
- Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False) -> Iterable:
- """
- Convert predicted tensors to outputs
+ def __call__(self, sample: dict) -> dict:
+ src = sample[self.src]
+ if isinstance(src, str):
+ sample[self.dst] = self.transform(src)
+ elif isinstance(src, list):
+ sample[self.dst] = [self.transform(x) for x in src]
+ return sample
- Parameters
- ----------
- X : Union[tf.Tensor, Tuple[tf.Tensor]]
- The inputs of model
- Y : Union[tf.Tensor, Tuple[tf.Tensor]]
- The outputs of model
- Returns
- -------
+class Bigram(NamedTransform):
- """
- return [(x, y) for x, y in zip(self.X_to_inputs(X), self.Y_to_outputs(Y, gold))]
+ def __init__(self, src: str, dst: str = None) -> None:
+ if not dst:
+ dst = f'{src}_bigram'
+ super().__init__(src, dst)
- def input_is_single_sample(self, input: Any) -> bool:
- return False
+ def __call__(self, sample: dict) -> dict:
+ src: List = sample[self.src]
+ dst = src + [EOS]
+ dst = [dst[i] + dst[i + 1] for i in range(len(src))]
+ sample[self.dst] = dst
+ return sample
- def input_to_inputs(self, input: Any) -> Tuple[Any, bool]:
- """
- If input is one sample, convert it to a list which contains this unique sample
- Parameters
- ----------
- input :
- sample or samples
+class FieldLength(NamedTransform):
- Returns
- -------
- (inputs, converted) : Tuple[Any, bool]
+ def __init__(self, src: str, dst: str = None, delta=0) -> None:
+ self.delta = delta
+ if not dst:
+ dst = f'{src}_length'
+ super().__init__(src, dst)
- """
- flat = self.input_is_single_sample(input)
- if flat:
- input = [input]
- return input, flat
+ def __call__(self, sample: dict) -> dict:
+ sample[self.dst] = len(sample[self.src]) + self.delta
+ return sample
- def input_truth_output_to_str(self, input, truth, output):
- """
- Convert input truth output to string representation, usually for writing to file during evaluation
- Parameters
- ----------
- input
- truth
- output
+class BMESOtoIOBES(object):
+ def __init__(self, field='tag') -> None:
+ self.field = field
+
+ def __call__(self, sample: dict) -> dict:
+ sample[self.field] = [self.convert(y) for y in sample[self.field]]
+ return sample
- Returns
- -------
+ @staticmethod
+ def convert(y: str):
+ if y.startswith('M-'):
+ return 'I-'
+ return y
+
+
+class NormalizeToken(ConfigurableNamedTransform):
+
+ def __init__(self, mapper: Union[str, dict], src: str, dst: str = None) -> None:
+ super().__init__(src, dst)
+ self.mapper = mapper
+ if isinstance(mapper, str):
+ mapper = get_resource(mapper)
+ if isinstance(mapper, str):
+ self._table = load_json(mapper)
+ elif isinstance(mapper, dict):
+ self._table = mapper
+ else:
+ raise ValueError(f'Unrecognized mapper type {mapper}')
+
+ def __call__(self, sample: dict) -> dict:
+ src = sample[self.src]
+ if self.src == self.dst:
+ sample[f'{self.src}_'] = src
+ if isinstance(src, str):
+ src = self.convert(src)
+ else:
+ src = [self.convert(x) for x in src]
+ sample[self.dst] = src
+ return sample
+
+ def convert(self, token) -> str:
+ return self._table.get(token, token)
+
+
+class PunctuationMask(ConfigurableNamedTransform):
+ def __init__(self, src: str, dst: str = None) -> None:
+ """Mask out all punctuations (set mask of punctuations to False)
+
+ Args:
+ src:
+ dst:
+
+ Returns:
"""
- return '\t'.join([input, truth, output]) + '\n'
+ if not dst:
+ dst = f'{src}_punct_mask'
+ super().__init__(src, dst)
+
+ def __call__(self, sample: dict) -> dict:
+ src = sample[self.src]
+ if isinstance(src, str):
+ dst = not ispunct(src)
+ else:
+ dst = [not ispunct(x) for x in src]
+ sample[self.dst] = dst
+ return sample
+
+
+class NormalizeCharacter(NormalizeToken):
+ def convert(self, token) -> str:
+ return ''.join([NormalizeToken.convert(self, c) for c in token])
diff --git a/hanlp/common/transform_tf.py b/hanlp/common/transform_tf.py
new file mode 100644
index 000000000..58c1dc679
--- /dev/null
+++ b/hanlp/common/transform_tf.py
@@ -0,0 +1,297 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-10-27 14:22
+import inspect
+from abc import ABC, abstractmethod
+from typing import Generator, Tuple, Union, Iterable, Any
+
+import tensorflow as tf
+
+from hanlp_common.structure import SerializableDict
+from hanlp.common.vocab_tf import VocabTF
+from hanlp.utils.io_util import get_resource
+from hanlp.utils.log_util import logger
+
+
+class Transform(ABC):
+
+ def __init__(self, config: SerializableDict = None, map_x=True, map_y=True, **kwargs) -> None:
+ super().__init__()
+ self.map_y = map_y
+ self.map_x = map_x
+ if kwargs:
+ if not config:
+ config = SerializableDict()
+ for k, v in kwargs.items():
+ config[k] = v
+ self.config = config
+ self.output_types = None
+ self.output_shapes = None
+ self.padding_values = None
+
+ @abstractmethod
+ def fit(self, trn_path: str, **kwargs) -> int:
+ """
+ Build the vocabulary from training file
+
+ Parameters
+ ----------
+ trn_path : path to training set
+ kwargs
+
+ Returns
+ -------
+ int
+ How many samples in the training set
+ """
+ raise NotImplementedError('%s.%s()' % (self.__class__.__name__, inspect.stack()[0][3]))
+
+ def build_config(self):
+ """
+ By default, call build_types_shapes_values, usually called in component's build method.
+ You can perform other building task here. Remember to call super().build_config
+ """
+ self.output_types, self.output_shapes, self.padding_values = self.create_types_shapes_values()
+ # We prefer list over shape here, as it's easier to type [] than ()
+ # if isinstance(self.output_shapes, tuple):
+ # self.output_shapes = list(self.output_shapes)
+ # for i, shapes in enumerate(self.output_shapes):
+ # if isinstance(shapes, tuple):
+ # self.output_shapes[i] = list(shapes)
+ # for j, shape in enumerate(shapes):
+ # if isinstance(shape, tuple):
+ # shapes[j] = list(shape)
+
+ @abstractmethod
+ def create_types_shapes_values(self) -> Tuple[Tuple, Tuple, Tuple]:
+ """
+ Create dataset related values,
+ """
+ raise NotImplementedError('%s.%s()' % (self.__class__.__name__, inspect.stack()[0][3]))
+
+ @abstractmethod
+ def file_to_inputs(self, filepath: str, gold=True):
+ """
+ Transform file to inputs. The inputs are defined as raw features (e.g. words) to be processed into more
+ features (e.g. forms and characters)
+
+ Parameters
+ ----------
+ filepath
+ gold
+ """
+ raise NotImplementedError('%s.%s()' % (self.__class__.__name__, inspect.stack()[0][3]))
+
+ def inputs_to_samples(self, inputs, gold=False):
+ if gold:
+ yield from inputs
+ else:
+ for x in inputs:
+ yield x, self.padding_values[-1]
+
+ def file_to_samples(self, filepath: str, gold=True):
+ """
+ Transform file to samples
+ Parameters
+ ----------
+ filepath
+ gold
+ """
+ filepath = get_resource(filepath)
+ inputs = self.file_to_inputs(filepath, gold)
+ yield from self.inputs_to_samples(inputs, gold)
+
+ def file_to_dataset(self, filepath: str, gold=True, map_x=None, map_y=None, batch_size=32, shuffle=None,
+ repeat=None,
+ drop_remainder=False,
+ prefetch=1,
+ cache=True,
+ **kwargs) -> tf.data.Dataset:
+ """
+ Transform file to dataset
+
+ Parameters
+ ----------
+ filepath
+ gold : bool
+ Whether it's processing gold data or not. Example: there is usually a column for gold answer
+ when gold = True.
+ map_x : bool
+ Whether call map_x or not. Default to self.map_x
+ map_y : bool
+ Whether call map_y or not. Default to self.map_y
+ batch_size
+ shuffle
+ repeat
+ prefetch
+ kwargs
+
+ Returns
+ -------
+
+ """
+
+ # debug
+ # for sample in self.file_to_samples(filepath):
+ # pass
+
+ def generator():
+ inputs = self.file_to_inputs(filepath, gold)
+ samples = self.inputs_to_samples(inputs, gold)
+ yield from samples
+
+ return self.samples_to_dataset(generator, map_x, map_y, batch_size, shuffle, repeat, drop_remainder, prefetch,
+ cache)
+
+ def inputs_to_dataset(self, inputs, gold=False, map_x=None, map_y=None, batch_size=32, shuffle=None, repeat=None,
+ drop_remainder=False,
+ prefetch=1, cache=False, **kwargs) -> tf.data.Dataset:
+ # debug
+ # for sample in self.inputs_to_samples(inputs):
+ # pass
+
+ def generator():
+ samples = self.inputs_to_samples(inputs, gold)
+ yield from samples
+
+ return self.samples_to_dataset(generator, map_x, map_y, batch_size, shuffle, repeat, drop_remainder, prefetch,
+ cache)
+
+ def samples_to_dataset(self, samples: Generator, map_x=None, map_y=None, batch_size=32, shuffle=None, repeat=None,
+ drop_remainder=False,
+ prefetch=1, cache=True) -> tf.data.Dataset:
+ output_types, output_shapes, padding_values = self.output_types, self.output_shapes, self.padding_values
+ if not all(v for v in [output_shapes, output_shapes,
+ padding_values]):
+ # print('Did you forget to call build_config() on your transform?')
+ self.build_config()
+ output_types, output_shapes, padding_values = self.output_types, self.output_shapes, self.padding_values
+ assert all(v for v in [output_shapes, output_shapes,
+ padding_values]), 'Your create_types_shapes_values returns None, which is not allowed'
+ # if not callable(samples):
+ # samples = Transform.generator_to_callable(samples)
+ dataset = tf.data.Dataset.from_generator(samples, output_types=output_types, output_shapes=output_shapes)
+ if cache:
+ logger.debug('Dataset cache enabled')
+ dataset = dataset.cache(cache if isinstance(cache, str) else '')
+ if shuffle:
+ if isinstance(shuffle, bool):
+ shuffle = 1024
+ dataset = dataset.shuffle(shuffle)
+ if repeat:
+ dataset = dataset.repeat(repeat)
+ if batch_size:
+ dataset = dataset.padded_batch(batch_size, output_shapes, padding_values, drop_remainder)
+ if prefetch:
+ dataset = dataset.prefetch(prefetch)
+ if map_x is None:
+ map_x = self.map_x
+ if map_y is None:
+ map_y = self.map_y
+ if map_x or map_y:
+ def mapper(X, Y):
+ if map_x:
+ X = self.x_to_idx(X)
+ if map_y:
+ Y = self.y_to_idx(Y)
+ return X, Y
+
+ dataset = dataset.map(mapper, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ return dataset
+
+ @abstractmethod
+ def x_to_idx(self, x) -> Union[tf.Tensor, Tuple]:
+ raise NotImplementedError('%s.%s()' % (self.__class__.__name__, inspect.stack()[0][3]))
+
+ @abstractmethod
+ def y_to_idx(self, y) -> tf.Tensor:
+ raise NotImplementedError('%s.%s()' % (self.__class__.__name__, inspect.stack()[0][3]))
+
+ def lock_vocabs(self):
+ for key, value in vars(self).items():
+ if isinstance(value, VocabTF):
+ value.lock()
+
+ def summarize_vocabs(self, logger=None, header='Vocab summary:'):
+ output = header + '\n'
+ vocabs = {}
+ for key, value in vars(self).items():
+ if isinstance(value, VocabTF):
+ vocabs[key] = value
+ # tag vocab comes last usually
+ for key, value in sorted(vocabs.items(), key=lambda kv: len(kv[1]), reverse=True):
+ output += f'{key}' + value.summary(verbose=False) + '\n'
+ output = output.strip()
+ if logger:
+ logger.info(output)
+ else:
+ print(output)
+
+ @staticmethod
+ def generator_to_callable(generator: Generator):
+ return lambda: (x for x in generator)
+
+ def str_to_idx(self, X, Y) -> Tuple[Union[tf.Tensor, Tuple], tf.Tensor]:
+ return self.x_to_idx(X), self.y_to_idx(Y)
+
+ def X_to_inputs(self, X: Union[tf.Tensor, Tuple[tf.Tensor]]) -> Iterable:
+ return [repr(x) for x in X]
+
+ def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, inputs=None, X=None, batch=None) -> Iterable:
+ return [repr(y) for y in Y]
+
+ def XY_to_inputs_outputs(self, X: Union[tf.Tensor, Tuple[tf.Tensor]],
+ Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False) -> Iterable:
+ """
+ Convert predicted tensors to outputs
+
+ Parameters
+ ----------
+ X : Union[tf.Tensor, Tuple[tf.Tensor]]
+ The inputs of model
+ Y : Union[tf.Tensor, Tuple[tf.Tensor]]
+ The outputs of model
+
+ Returns
+ -------
+
+ """
+ return [(x, y) for x, y in zip(self.X_to_inputs(X), self.Y_to_outputs(Y, gold))]
+
+ def input_is_single_sample(self, input: Any) -> bool:
+ return False
+
+ def input_to_inputs(self, input: Any) -> Tuple[Any, bool]:
+ """
+ If input is one sample, convert it to a list which contains this unique sample
+
+ Parameters
+ ----------
+ input :
+ sample or samples
+
+ Returns
+ -------
+ (inputs, converted) : Tuple[Any, bool]
+
+ """
+ flat = self.input_is_single_sample(input)
+ if flat:
+ input = [input]
+ return input, flat
+
+ def input_truth_output_to_str(self, input, truth, output):
+ """
+ Convert input truth output to string representation, usually for writing to file during evaluation
+
+ Parameters
+ ----------
+ input
+ truth
+ output
+
+ Returns
+ -------
+
+ """
+ return '\t'.join([input, truth, output]) + '\n'
diff --git a/hanlp/common/trie.py b/hanlp/common/trie.py
deleted file mode 100644
index 34ef95f10..000000000
--- a/hanlp/common/trie.py
+++ /dev/null
@@ -1,101 +0,0 @@
-# -*- coding:utf-8 -*-
-# Author: hankcs
-# Date: 2020-01-04 23:46
-from typing import Dict, Any, List, Tuple, Iterable, Sequence, Union, Set
-
-
-class Node(object):
- def __init__(self, value=None) -> None:
- self._children = {}
- self._value = value
-
- def _add_child(self, char, value, overwrite=False):
- child = self._children.get(char)
- if child is None:
- child = Node(value)
- self._children[char] = child
- elif overwrite:
- child._value = value
- return child
-
- def transit(self, key):
- state = self
- for char in key:
- state = state._children.get(char)
- if state is None:
- break
- return state
-
-
-class Trie(Node):
- def __init__(self, tokens: Union[Dict[str, Any], Set[str]] = None) -> None:
- super().__init__()
- if tokens:
- if isinstance(tokens, set):
- for k in tokens:
- self[k] = True
- else:
- for k, v in tokens.items():
- self[k] = v
-
- def __contains__(self, key):
- return self[key] is not None
-
- def __getitem__(self, key):
- state = self.transit(key)
- if state is None:
- return None
- return state._value
-
- def __setitem__(self, key, value):
- state = self
- for i, char in enumerate(key):
- if i < len(key) - 1:
- state = state._add_child(char, None, False)
- else:
- state = state._add_child(char, value, True)
-
- def __delitem__(self, key):
- state = self.transit(key)
- if state is not None:
- state._value = None
-
- def update(self, dic: Dict[str, Any]):
- for k, v in dic.items():
- self[k] = v
- return self
-
- def parse(self, text: Sequence[str]) -> List[Tuple[Union[str, Sequence[str]], Any, int, int]]:
- found = []
- for i in range(len(text)):
- state = self
- for j in range(i, len(text)):
- state = state.transit(text[j])
- if state:
- if state._value is not None:
- found.append((text[i: j + 1], state._value, i, j + 1))
- else:
- break
- return found
-
- def parse_longest(self, text: Sequence[str]) -> List[Tuple[Union[str, Sequence[str]], Any, int, int]]:
- found = []
- i = 0
- while i < len(text):
- state = self.transit(text[i])
- if state:
- to = i + 1
- end = to
- value = state._value
- for to in range(i + 1, len(text)):
- state = state.transit(text[to])
- if not state:
- break
- if state._value is not None:
- value = state._value
- end = to + 1
- if value is not None:
- found.append((text[i:end], value, i, end))
- i = end - 1
- i += 1
- return found
diff --git a/hanlp/common/vocab.py b/hanlp/common/vocab.py
index 74a0fc11f..bd63c7fa3 100644
--- a/hanlp/common/vocab.py
+++ b/hanlp/common/vocab.py
@@ -1,17 +1,26 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-06-13 22:42
+from collections import Counter
from typing import List, Dict, Union, Iterable
-from hanlp.common.structure import Serializable
-from hanlp.common.constant import PAD, UNK
-import tensorflow as tf
-from tensorflow.python.ops.lookup_ops import index_table_from_tensor, index_to_string_table_from_tensor
+from hanlp_common.constant import UNK, PAD
+from hanlp_common.structure import Serializable
+from hanlp_common.reflection import classpath_of
class Vocab(Serializable):
def __init__(self, idx_to_token: List[str] = None, token_to_idx: Dict = None, mutable=True, pad_token=PAD,
unk_token=UNK) -> None:
+ """Vocabulary base class which converts tokens to indices and vice versa.
+
+ Args:
+ idx_to_token: id to token mapping.
+ token_to_idx: token to id mapping.
+ mutable: ``True`` to allow adding new tokens, ``False`` to map OOV to ``unk``.
+ pad_token: The token representing padding.
+ unk_token: The token representing OOV.
+ """
super().__init__()
if idx_to_token:
t2i = dict((token, idx) for idx, token in enumerate(idx_to_token))
@@ -20,23 +29,29 @@ def __init__(self, idx_to_token: List[str] = None, token_to_idx: Dict = None, mu
token_to_idx = t2i
if token_to_idx is None:
token_to_idx = {}
- if pad_token:
+ if pad_token is not None:
token_to_idx[pad_token] = len(token_to_idx)
- if unk_token:
- token_to_idx[unk_token] = len(token_to_idx)
+ if unk_token is not None:
+ token_to_idx[unk_token] = token_to_idx.get(unk_token, len(token_to_idx))
self.token_to_idx = token_to_idx
- self.idx_to_token: list = None
+ self.idx_to_token: List[str] = None
self.mutable = mutable
self.pad_token = pad_token
self.unk_token = unk_token
- self.token_to_idx_table: tf.lookup.StaticHashTable = None
- self.idx_to_token_table = None
def __setitem__(self, token: str, idx: int):
assert self.mutable, 'Update an immutable Vocab object is not allowed'
self.token_to_idx[token] = idx
def __getitem__(self, key: Union[str, int, List]) -> Union[int, str, List]:
+ """ Get the index/indices associated with a token or a list of tokens or vice versa.
+
+ Args:
+ key: ``str`` for token(s) and ``int`` for index/indices.
+
+ Returns: Associated indices or tokens.
+
+ """
if isinstance(key, str):
return self.get_idx(key)
elif isinstance(key, int):
@@ -58,9 +73,19 @@ def __contains__(self, key: Union[str, int]):
return False
def add(self, token: str) -> int:
+ """ Tries to add a token into a vocab and returns its id. If it has already been there, its id will be returned
+ and the vocab won't be updated. If the vocab is locked, an assertion failure will occur.
+
+ Args:
+ token: A new or existing token.
+
+ Returns:
+ Its associated id.
+
+ """
assert self.mutable, 'It is not allowed to call add on an immutable Vocab'
assert isinstance(token, str), f'Token type must be str but got {type(token)} from {token}'
- assert token, 'Token must not be None or length 0'
+ assert token is not None, 'Token must not be None'
idx = self.token_to_idx.get(token, None)
if idx is None:
idx = len(self.token_to_idx)
@@ -68,21 +93,28 @@ def add(self, token: str) -> int:
return idx
def update(self, tokens: Iterable[str]) -> None:
- """
- Update the vocab with these tokens by adding them to vocab one by one.
- Parameters
- ----------
- tokens
+ """Update the vocab with these tokens by adding them to vocab one by one.
+
+ Args:
+ tokens (Iterable[str]): A list of tokens.
"""
assert self.mutable, 'It is not allowed to update an immutable Vocab'
for token in tokens:
self.add(token)
def get_idx(self, token: str) -> int:
- if type(token) is list:
- idx = [self.get_idx(t) for t in token]
- else:
- idx = self.token_to_idx.get(token, None)
+ """Get the idx of a token. If it's not there, it will be added to the vocab when the vocab is locked otherwise
+ the id of UNK will be returned.
+
+ Args:
+ token: A token.
+
+ Returns:
+ The id of that token.
+
+ """
+ assert isinstance(token, str), 'token has to be `str`'
+ idx = self.token_to_idx.get(token, None)
if idx is None:
if self.mutable:
idx = len(self.token_to_idx)
@@ -94,10 +126,18 @@ def get_idx(self, token: str) -> int:
def get_idx_without_add(self, token: str) -> int:
idx = self.token_to_idx.get(token, None)
if idx is None:
- idx = self.token_to_idx.get(self.unk_token, None)
+ idx = self.token_to_idx.get(self.safe_unk_token, None)
return idx
def get_token(self, idx: int) -> str:
+ """Get the token using its index.
+
+ Args:
+ idx: The index to a token.
+
+ Returns:
+
+ """
if self.idx_to_token:
return self.idx_to_token[idx]
@@ -113,11 +153,16 @@ def __len__(self):
return len(self.token_to_idx)
def lock(self):
+ """Lock this vocab up so that it won't accept new tokens.
+
+ Returns:
+ Itself.
+
+ """
if self.locked:
return self
self.mutable = False
self.build_idx_to_token()
- self.build_lookup_table()
return self
def build_idx_to_token(self):
@@ -126,27 +171,31 @@ def build_idx_to_token(self):
for token, idx in self.token_to_idx.items():
self.idx_to_token[idx] = token
- def build_lookup_table(self):
- tensor = tf.constant(self.idx_to_token, dtype=tf.string)
- self.token_to_idx_table = index_table_from_tensor(tensor, num_oov_buckets=1 if self.unk_idx is None else 0,
- default_value=-1 if self.unk_idx is None else self.unk_idx)
- # self.idx_to_token_table = index_to_string_table_from_tensor(self.idx_to_token, self.safe_unk_token)
-
def unlock(self):
+ """Unlock this vocab so that new tokens can be added in.
+
+ Returns:
+ Itself.
+
+ """
if not self.locked:
return
self.mutable = True
self.idx_to_token = None
- self.idx_to_token_table = None
- self.token_to_idx_table = None
return self
@property
def locked(self):
+ """
+ ``True`` indicates this vocab is locked.
+ """
return not self.mutable
@property
def unk_idx(self):
+ """
+ The index of ``UNK`` token.
+ """
if self.unk_token is None:
return None
else:
@@ -154,6 +203,9 @@ def unk_idx(self):
@property
def pad_idx(self):
+ """
+ The index of ``PAD`` token.
+ """
if self.pad_token is None:
return None
else:
@@ -161,12 +213,24 @@ def pad_idx(self):
@property
def tokens(self):
+ """
+ A set of all tokens in this vocab.
+ """
return self.token_to_idx.keys()
def __str__(self) -> str:
return self.token_to_idx.__str__()
def summary(self, verbose=True) -> str:
+ """Get or print a summary of this vocab.
+
+ Args:
+ verbose: ``True`` to print the summary to stdout.
+
+ Returns:
+ Summary in text form.
+
+ """
# report = 'Length: {}\n'.format(len(self))
# report += 'Samples: {}\n'.format(str(list(self.token_to_idx.keys())[:min(50, len(self))]))
# report += 'Mutable: {}'.format(self.mutable)
@@ -177,21 +241,29 @@ def summary(self, verbose=True) -> str:
print(report)
return report
- def __call__(self, some_token: Union[str, List[str]]) -> Union[int, List[int]]:
- if isinstance(some_token, list):
+ def __call__(self, some_token: Union[str, Iterable[str]]) -> Union[int, List[int]]:
+ if isinstance(some_token, (list, tuple, set)):
indices = []
+ if len(some_token) and isinstance(some_token[0], (list, tuple, set)):
+ for sent in some_token:
+ inside = []
+ for token in sent:
+ inside.append(self.get_idx(token))
+ indices.append(inside)
+ return indices
for token in some_token:
indices.append(self.get_idx(token))
return indices
else:
return self.get_idx(some_token)
- def lookup(self, token_tensor: tf.Tensor) -> tf.Tensor:
- if self.mutable:
- self.lock()
- return self.token_to_idx_table.lookup(token_tensor)
-
def to_dict(self) -> dict:
+ """Convert this vocab to a dict so that it can be json serialized.
+
+ Returns:
+ A dict.
+
+ """
idx_to_token = self.idx_to_token
pad_token = self.pad_token
unk_token = self.unk_token
@@ -201,13 +273,27 @@ def to_dict(self) -> dict:
return items
def copy_from(self, item: dict):
+ """Copy properties from a dict so that it can json de-serialized.
+
+ Args:
+ item: A dict holding ``token_to_idx``
+
+ Returns:
+ Itself.
+
+ """
for key, value in item.items():
setattr(self, key, value)
self.token_to_idx = {k: v for v, k in enumerate(self.idx_to_token)}
- if not self.mutable:
- self.build_lookup_table()
+ return self
def lower(self):
+ """Convert all tokens to lower case.
+
+ Returns:
+ Itself.
+
+ """
self.unlock()
token_to_idx = self.token_to_idx
self.token_to_idx = {}
@@ -217,6 +303,8 @@ def lower(self):
@property
def first_token(self):
+ """The first token in this vocab.
+ """
if self.idx_to_token:
return self.idx_to_token[0]
if self.token_to_idx:
@@ -224,18 +312,18 @@ def first_token(self):
return None
def merge(self, other):
+ """Merge this with another vocab inplace.
+
+ Args:
+ other (Vocab): Another vocab.
+ """
for word, idx in other.token_to_idx.items():
self.get_idx(word)
@property
def safe_pad_token(self) -> str:
- """
- Get the pad token safely. It always returns a pad token, which is the token
- closest to pad if not presented in the vocab.
-
- Returns
- -------
- str pad token
+ """Get the pad token safely. It always returns a pad token, which is the pad token or the first token
+ if pad does not present in the vocab.
"""
if self.pad_token:
return self.pad_token
@@ -245,17 +333,15 @@ def safe_pad_token(self) -> str:
@property
def safe_pad_token_idx(self) -> int:
+ """Get the idx to the pad token safely. It always returns an index, which corresponds to the pad token or the
+ first token if pad does not present in the vocab.
+ """
return self.token_to_idx.get(self.safe_pad_token, 0)
@property
def safe_unk_token(self) -> str:
- """
- Get the unk token safely. It always returns a unk token, which is the token
- closest to unk if not presented in the vocab.
-
- Returns
- -------
- str pad token
+ """Get the unk token safely. It always returns a unk token, which is the unk token or the first token if unk
+ does not presented in the vocab.
"""
if self.unk_token:
return self.unk_token
@@ -263,6 +349,127 @@ def safe_unk_token(self) -> str:
return self.first_token
return UNK
+ def __repr__(self) -> str:
+ if self.idx_to_token is not None:
+ return self.idx_to_token.__repr__()
+ return self.token_to_idx.__repr__()
+
+ def extend(self, tokens: Iterable[str]):
+ self.unlock()
+ self(tokens)
+
+ def reload_idx_to_token(self, idx_to_token: List[str], pad_idx=0, unk_idx=1):
+ self.idx_to_token = idx_to_token
+ self.token_to_idx = dict((s, i) for i, s in enumerate(idx_to_token))
+ if pad_idx is not None:
+ self.pad_token = idx_to_token[pad_idx]
+ if unk_idx is not None:
+ self.unk_token = idx_to_token[unk_idx]
+
+ def set_unk_as_safe_unk(self):
+ """Set ``self.unk_token = self.safe_unk_token``. It's useful when the dev/test set contains OOV labels.
+ """
+ self.unk_token = self.safe_unk_token
+
+ def clear(self):
+ self.unlock()
+ self.token_to_idx.clear()
+
+
+class CustomVocab(Vocab):
+ def to_dict(self) -> dict:
+ d = super().to_dict()
+ d['type'] = classpath_of(self)
+ return d
+
+
+class LowercaseVocab(CustomVocab):
+ def get_idx(self, token: str) -> int:
+ idx = self.token_to_idx.get(token, None)
+ if idx is None:
+ idx = self.token_to_idx.get(token.lower(), None)
+ if idx is None:
+ if self.mutable:
+ idx = len(self.token_to_idx)
+ self.token_to_idx[token] = idx
+ else:
+ idx = self.token_to_idx.get(self.unk_token, None)
+ return idx
+
+
+class VocabWithNone(CustomVocab):
+ def get_idx(self, token: str) -> int:
+ if token is None:
+ return -1
+ return super().get_idx(token)
+
+
+class VocabWithFrequency(CustomVocab):
+
+ def __init__(self, counter: Counter = None, min_occur_cnt=0, pad_token=PAD, unk_token=UNK, specials=None) -> None:
+ super().__init__(None, None, True, pad_token, unk_token)
+ if specials:
+ for each in specials:
+ counter.pop(each, None)
+ self.add(each)
+ self.frequencies = [1] * len(self)
+ if counter:
+ for token, freq in counter.most_common():
+ if freq >= min_occur_cnt:
+ self.add(token)
+ self.frequencies.append(freq)
+ self.lock()
+
+ def to_dict(self) -> dict:
+ d = super().to_dict()
+ d['frequencies'] = self.frequencies
+ return d
+
+ def copy_from(self, item: dict):
+ super().copy_from(item)
+ self.frequencies = item['frequencies']
+
+ def get_frequency(self, token):
+ idx = self.get_idx(token)
+ if idx is not None:
+ return self.frequencies[idx]
+ return 0
+
+
+class VocabCounter(CustomVocab):
+
+ def __init__(self, idx_to_token: List[str] = None, token_to_idx: Dict = None, mutable=True, pad_token=PAD,
+ unk_token=UNK) -> None:
+ super().__init__(idx_to_token, token_to_idx, mutable, pad_token, unk_token)
+ self.counter = Counter()
+
+ def get_idx(self, token: str) -> int:
+ if self.mutable:
+ self.counter[token] += 1
+ return super().get_idx(token)
+
+ def trim(self, min_frequency):
+ assert self.mutable
+ specials = {self.unk_token, self.pad_token}
+ survivors = list((token, freq) for token, freq in self.counter.most_common()
+ if freq >= min_frequency and token not in specials)
+ survivors = [(x, -1) for x in specials if x] + survivors
+ self.counter = Counter(dict(survivors))
+ self.token_to_idx = dict()
+ self.idx_to_token = None
+ for token, freq in survivors:
+ idx = len(self.token_to_idx)
+ self.token_to_idx[token] = idx
+
+ def copy_from(self, item: dict):
+ super().copy_from(item)
+ self.counter = Counter(item['counter'].items()) if 'counter' in item else Counter()
+
+ def to_dict(self) -> dict:
+ d = super().to_dict()
+ d['counter'] = dict(self.counter.items())
+ return d
+
def create_label_vocab() -> Vocab:
return Vocab(pad_token=None, unk_token=None)
diff --git a/hanlp/common/vocab_tf.py b/hanlp/common/vocab_tf.py
new file mode 100644
index 000000000..d723073dd
--- /dev/null
+++ b/hanlp/common/vocab_tf.py
@@ -0,0 +1,270 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-06-13 22:42
+from typing import List, Dict, Union, Iterable
+
+from hanlp_common.structure import Serializable
+from hanlp_common.constant import PAD, UNK
+import tensorflow as tf
+from tensorflow.python.ops.lookup_ops import index_table_from_tensor
+
+
+class VocabTF(Serializable):
+ def __init__(self, idx_to_token: List[str] = None, token_to_idx: Dict = None, mutable=True, pad_token=PAD,
+ unk_token=UNK) -> None:
+ super().__init__()
+ if idx_to_token:
+ t2i = dict((token, idx) for idx, token in enumerate(idx_to_token))
+ if token_to_idx:
+ t2i.update(token_to_idx)
+ token_to_idx = t2i
+ if token_to_idx is None:
+ token_to_idx = {}
+ if pad_token:
+ token_to_idx[pad_token] = len(token_to_idx)
+ if unk_token:
+ token_to_idx[unk_token] = len(token_to_idx)
+ self.token_to_idx = token_to_idx
+ self.idx_to_token: list = None
+ self.mutable = mutable
+ self.pad_token = pad_token
+ self.unk_token = unk_token
+ self.token_to_idx_table: tf.lookup.StaticHashTable = None
+ self.idx_to_token_table = None
+
+ def __setitem__(self, token: str, idx: int):
+ assert self.mutable, 'Update an immutable Vocab object is not allowed'
+ self.token_to_idx[token] = idx
+
+ def __getitem__(self, key: Union[str, int, List]) -> Union[int, str, List]:
+ if isinstance(key, str):
+ return self.get_idx(key)
+ elif isinstance(key, int):
+ return self.get_token(key)
+ elif isinstance(key, list):
+ if len(key) == 0:
+ return []
+ elif isinstance(key[0], str):
+ return [self.get_idx(x) for x in key]
+ elif isinstance(key[0], int):
+ return [self.get_token(x) for x in key]
+
+ def __contains__(self, key: Union[str, int]):
+ if isinstance(key, str):
+ return key in self.token_to_idx
+ elif isinstance(key, int):
+ return 0 <= key < len(self.idx_to_token)
+ else:
+ return False
+
+ def add(self, token: str) -> int:
+ assert self.mutable, 'It is not allowed to call add on an immutable Vocab'
+ assert isinstance(token, str), f'Token type must be str but got {type(token)} from {token}'
+ assert token, 'Token must not be None or length 0'
+ idx = self.token_to_idx.get(token, None)
+ if idx is None:
+ idx = len(self.token_to_idx)
+ self.token_to_idx[token] = idx
+ return idx
+
+ def update(self, tokens: Iterable[str]) -> None:
+ """Update the vocab with these tokens by adding them to vocab one by one.
+
+ Args:
+ tokens: Iterable[str]:
+
+ Returns:
+
+
+ """
+ assert self.mutable, 'It is not allowed to update an immutable Vocab'
+ for token in tokens:
+ self.add(token)
+
+ def get_idx(self, token: str) -> int:
+ idx = self.token_to_idx.get(token, None)
+ if idx is None:
+ if self.mutable:
+ idx = len(self.token_to_idx)
+ self.token_to_idx[token] = idx
+ else:
+ idx = self.token_to_idx.get(self.unk_token, None)
+ return idx
+
+ def get_idx_without_add(self, token: str) -> int:
+ idx = self.token_to_idx.get(token, None)
+ if idx is None:
+ idx = self.token_to_idx.get(self.safe_unk_token, None)
+ return idx
+
+ def get_token(self, idx: int) -> str:
+ if self.idx_to_token:
+ return self.idx_to_token[idx]
+
+ if self.mutable:
+ for token in self.token_to_idx:
+ if self.token_to_idx[token] == idx:
+ return token
+
+ def has_key(self, token):
+ return token in self.token_to_idx
+
+ def __len__(self):
+ return len(self.token_to_idx)
+
+ def lock(self):
+ if self.locked:
+ return self
+ self.mutable = False
+ self.build_idx_to_token()
+ self.build_lookup_table()
+ return self
+
+ def build_idx_to_token(self):
+ max_idx = max(self.token_to_idx.values())
+ self.idx_to_token = [None] * (max_idx + 1)
+ for token, idx in self.token_to_idx.items():
+ self.idx_to_token[idx] = token
+
+ def build_lookup_table(self):
+ tensor = tf.constant(self.idx_to_token, dtype=tf.string)
+ self.token_to_idx_table = index_table_from_tensor(tensor, num_oov_buckets=1 if self.unk_idx is None else 0,
+ default_value=-1 if self.unk_idx is None else self.unk_idx)
+ # self.idx_to_token_table = index_to_string_table_from_tensor(self.idx_to_token, self.safe_unk_token)
+
+ def unlock(self):
+ if not self.locked:
+ return
+ self.mutable = True
+ self.idx_to_token = None
+ self.idx_to_token_table = None
+ self.token_to_idx_table = None
+ return self
+
+ @property
+ def locked(self):
+ return not self.mutable
+
+ @property
+ def unk_idx(self):
+ if self.unk_token is None:
+ return None
+ else:
+ return self.token_to_idx.get(self.unk_token, None)
+
+ @property
+ def pad_idx(self):
+ if self.pad_token is None:
+ return None
+ else:
+ return self.token_to_idx.get(self.pad_token, None)
+
+ @property
+ def tokens(self):
+ return self.token_to_idx.keys()
+
+ def __str__(self) -> str:
+ return self.token_to_idx.__str__()
+
+ def summary(self, verbose=True) -> str:
+ # report = 'Length: {}\n'.format(len(self))
+ # report += 'Samples: {}\n'.format(str(list(self.token_to_idx.keys())[:min(50, len(self))]))
+ # report += 'Mutable: {}'.format(self.mutable)
+ # report = report.strip()
+ report = '[{}] = '.format(len(self))
+ report += str(list(self.token_to_idx.keys())[:min(50, len(self))])
+ if verbose:
+ print(report)
+ return report
+
+ def __call__(self, some_token: Union[str, List[str]]) -> Union[int, List[int]]:
+ if isinstance(some_token, list):
+ indices = []
+ for token in some_token:
+ indices.append(self.get_idx(token))
+ return indices
+ else:
+ return self.get_idx(some_token)
+
+ def lookup(self, token_tensor: tf.Tensor) -> tf.Tensor:
+ if self.mutable:
+ self.lock()
+ return self.token_to_idx_table.lookup(token_tensor)
+
+ def to_dict(self) -> dict:
+ idx_to_token = self.idx_to_token
+ pad_token = self.pad_token
+ unk_token = self.unk_token
+ mutable = self.mutable
+ items = locals().copy()
+ items.pop('self')
+ return items
+
+ def copy_from(self, item: dict):
+ for key, value in item.items():
+ setattr(self, key, value)
+ self.token_to_idx = {k: v for v, k in enumerate(self.idx_to_token)}
+ if not self.mutable:
+ self.build_lookup_table()
+
+ def lower(self):
+ self.unlock()
+ token_to_idx = self.token_to_idx
+ self.token_to_idx = {}
+ for token in token_to_idx.keys():
+ self.add(token.lower())
+ return self
+
+ @property
+ def first_token(self):
+ if self.idx_to_token:
+ return self.idx_to_token[0]
+ if self.token_to_idx:
+ return next(iter(self.token_to_idx))
+ return None
+
+ def merge(self, other):
+ for word, idx in other.token_to_idx.items():
+ self.get_idx(word)
+
+ @property
+ def safe_pad_token(self) -> str:
+ """Get the pad token safely. It always returns a pad token, which is the token
+ closest to pad if not presented in the vocab.
+
+ Args:
+
+ Returns:
+
+
+ """
+ if self.pad_token:
+ return self.pad_token
+ if self.first_token:
+ return self.first_token
+ return PAD
+
+ @property
+ def safe_pad_token_idx(self) -> int:
+ return self.token_to_idx.get(self.safe_pad_token, 0)
+
+ @property
+ def safe_unk_token(self) -> str:
+ """Get the unk token safely. It always returns a unk token, which is the token
+ closest to unk if not presented in the vocab.
+
+ Args:
+
+ Returns:
+
+
+ """
+ if self.unk_token:
+ return self.unk_token
+ if self.first_token:
+ return self.first_token
+ return UNK
+
+
+def create_label_vocab() -> VocabTF:
+ return VocabTF(pad_token=None, unk_token=None)
diff --git a/hanlp/components/__init__.py b/hanlp/components/__init__.py
index 5aaed45a1..47f2f4638 100644
--- a/hanlp/components/__init__.py
+++ b/hanlp/components/__init__.py
@@ -1,5 +1,4 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-08-26 16:10
-from .pipeline import Pipeline
-from . import tok
\ No newline at end of file
+from .pipeline import Pipeline
\ No newline at end of file
diff --git a/hanlp/components/amr/__init__.py b/hanlp/components/amr/__init__.py
new file mode 100644
index 000000000..b195bb373
--- /dev/null
+++ b/hanlp/components/amr/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-08-20 17:35
diff --git a/hanlp/components/classifiers/transformer_classifier.py b/hanlp/components/classifiers/transformer_classifier.py
index a70d9b65a..33126be54 100644
--- a/hanlp/components/classifiers/transformer_classifier.py
+++ b/hanlp/components/classifiers/transformer_classifier.py
@@ -1,189 +1,383 @@
# -*- coding:utf-8 -*-
# Author: hankcs
-# Date: 2019-11-10 13:19
-
-import math
-from typing import Union, Tuple, List, Any, Iterable
-
-import tensorflow as tf
-from bert.tokenization.bert_tokenization import FullTokenizer
-
-from hanlp.common.component import KerasComponent
-from hanlp.common.structure import SerializableDict
-from hanlp.layers.transformers.loader import build_transformer
-from hanlp.optimizers.adamw import create_optimizer
-from hanlp.transform.table import TableTransform
-from hanlp.utils.log_util import logger
-from hanlp.utils.util import merge_locals_kwargs
-import numpy as np
-
-
-class TransformerTextTransform(TableTransform):
-
- def __init__(self, config: SerializableDict = None, map_x=False, map_y=True, x_columns=None,
- y_column=-1, skip_header=True, delimiter='auto', multi_label=False, **kwargs) -> None:
- super().__init__(config, map_x, map_y, x_columns, y_column, multi_label, skip_header, delimiter, **kwargs)
- self.tokenizer: FullTokenizer = None
-
- def inputs_to_samples(self, inputs, gold=False):
- tokenizer = self.tokenizer
- max_length = self.config.max_length
- num_features = None
- pad_token = None if self.label_vocab.mutable else tokenizer.convert_tokens_to_ids(['[PAD]'])[0]
- for (X, Y) in super().inputs_to_samples(inputs, gold):
- if self.label_vocab.mutable:
- yield None, Y
- continue
- if isinstance(X, str):
- X = (X,)
- if num_features is None:
- num_features = self.config.num_features
- assert num_features == len(X), f'Numbers of features {num_features} ' \
- f'inconsistent with current {len(X)}={X}'
- text_a = X[0]
- text_b = X[1] if len(X) > 1 else None
- tokens_a = self.tokenizer.tokenize(text_a)
- tokens_b = self.tokenizer.tokenize(text_b) if text_b else None
- tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
- segment_ids = [0] * len(tokens)
- if tokens_b:
- tokens += tokens_b
- segment_ids += [1] * len(tokens_b)
- token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
- attention_mask = [1] * len(token_ids)
- diff = max_length - len(token_ids)
- if diff < 0:
- token_ids = token_ids[:max_length]
- attention_mask = attention_mask[:max_length]
- segment_ids = segment_ids[:max_length]
- elif diff > 0:
- token_ids += [pad_token] * diff
- attention_mask += [0] * diff
- segment_ids += [0] * diff
-
- assert len(token_ids) == max_length, "Error with input length {} vs {}".format(len(token_ids), max_length)
- assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(len(attention_mask), max_length)
- assert len(segment_ids) == max_length, "Error with input length {} vs {}".format(len(segment_ids), max_length)
-
-
- label = Y
- yield (token_ids, attention_mask, segment_ids), label
-
- def create_types_shapes_values(self) -> Tuple[Tuple, Tuple, Tuple]:
- max_length = self.config.max_length
- types = (tf.int32, tf.int32, tf.int32), tf.string
- shapes = ([max_length], [max_length], [max_length]), [None,] if self.config.multi_label else []
- values = (0, 0, 0), self.label_vocab.safe_pad_token
- return types, shapes, values
-
- def x_to_idx(self, x) -> Union[tf.Tensor, Tuple]:
- logger.fatal('map_x should always be set to True')
- exit(1)
-
- def y_to_idx(self, y) -> tf.Tensor:
- if self.config.multi_label:
- #need to change index to binary vector
- mapped = tf.map_fn(fn=lambda x: tf.cast(self.label_vocab.lookup(x), tf.int32), elems=y, fn_output_signature=tf.TensorSpec(dtype=tf.dtypes.int32, shape=[None,]))
- one_hots = tf.one_hot(mapped, len(self.label_vocab))
- idx = tf.reduce_sum(one_hots, -2)
- else:
- idx = self.label_vocab.lookup(y)
- return idx
+# Date: 2020-06-08 16:31
+import logging
+from abc import ABC
+from typing import Callable, Union
+from typing import List
+
+import torch
+from torch import nn
+from torch.utils.data import DataLoader
+
+from hanlp_common.constant import IDX
+from hanlp.common.dataset import TableDataset, SortingSampler, PadSequenceDataLoader, TransformableDataset
+from hanlp.common.torch_component import TorchComponent
+from hanlp.common.vocab import Vocab
+from hanlp.components.distillation.schedulers import LinearTeacherAnnealingScheduler
+from hanlp.layers.scalar_mix import ScalarMixWithDropoutBuilder
+from hanlp.layers.transformers.encoder import TransformerEncoder
+from hanlp.layers.transformers.pt_imports import PreTrainedModel, AutoTokenizer, BertTokenizer
+from hanlp.layers.transformers.utils import transformer_sliding_window, build_optimizer_scheduler_with_transformer
+from hanlp.metrics.accuracy import CategoricalAccuracy
+from hanlp.transform.transformer_tokenizer import TransformerTextTokenizer
+from hanlp.utils.time_util import CountdownTimer
+from hanlp_common.util import merge_locals_kwargs, merge_dict, isdebugging
+
- def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, inputs=None, X=None, batch=None) -> Iterable:
- # Prediction to be Y > 0:
- if self.config.multi_label:
- preds = Y
+class TransformerClassificationModel(nn.Module):
+
+ def __init__(self,
+ transformer: PreTrainedModel,
+ num_labels: int,
+ max_seq_length=512) -> None:
+ super().__init__()
+ self.max_seq_length = max_seq_length
+ self.transformer = transformer
+ self.dropout = nn.Dropout(transformer.config.hidden_dropout_prob)
+ self.classifier = nn.Linear(transformer.config.hidden_size, num_labels)
+
+ def forward(self, input_ids, attention_mask, token_type_ids):
+ seq_length = input_ids.size(-1)
+ if seq_length > self.max_seq_length:
+ sequence_output = transformer_sliding_window(self.transformer, input_ids,
+ max_pieces=self.max_seq_length, ret_cls='max')
else:
- preds = tf.argmax(Y, axis=-1)
- for y in preds:
- yield self.label_vocab.idx_to_token[y]
+ sequence_output = self.transformer(input_ids, attention_mask, token_type_ids)[0][:, 0, :]
+ sequence_output = self.dropout(sequence_output)
+ logits = self.classifier(sequence_output)
+ return logits
- def input_is_single_sample(self, input: Any) -> bool:
- return isinstance(input, (str, tuple))
+class TransformerComponent(TorchComponent, ABC):
+ def __init__(self, **kwargs) -> None:
+ """ The base class for transorfmer based components. If offers methods to build transformer tokenizers
+ , optimizers and models.
-class TransformerClassifier(KerasComponent):
+ Args:
+ **kwargs: Passed to config.
+ """
+ super().__init__(**kwargs)
+ self.transformer_tokenizer = None
- def __init__(self, bert_text_transform=None) -> None:
- if not bert_text_transform:
- bert_text_transform = TransformerTextTransform()
- super().__init__(bert_text_transform)
- self.model: tf.keras.Model
- self.transform: TransformerTextTransform = bert_text_transform
+ def build_optimizer(self,
+ trn,
+ epochs,
+ lr,
+ adam_epsilon,
+ weight_decay,
+ warmup_steps,
+ transformer_lr=None,
+ teacher=None,
+ **kwargs):
+ num_training_steps = len(trn) * epochs // self.config.get('gradient_accumulation', 1)
+ if transformer_lr is None:
+ transformer_lr = lr
+ transformer = self.model.encoder.transformer
+ optimizer, scheduler = build_optimizer_scheduler_with_transformer(self.model, transformer,
+ lr, transformer_lr,
+ num_training_steps, warmup_steps,
+ weight_decay, adam_epsilon)
+ if teacher:
+ lambda_scheduler = LinearTeacherAnnealingScheduler(num_training_steps)
+ scheduler = (scheduler, lambda_scheduler)
+ return optimizer, scheduler
- # noinspection PyMethodOverriding
- def fit(self, trn_data: Any, dev_data: Any, save_dir: str, transformer: str, max_length: int = 128,
- optimizer='adamw', warmup_steps_ratio=0.1, use_amp=False, batch_size=32,
- epochs=3, logger=None, verbose=1, **kwargs):
+ def fit(self, trn_data, dev_data, save_dir,
+ transformer=None,
+ lr=5e-5,
+ transformer_lr=None,
+ adam_epsilon=1e-8,
+ weight_decay=0,
+ warmup_steps=0.1,
+ batch_size=32,
+ gradient_accumulation=1,
+ grad_norm=5.0,
+ transformer_grad_norm=None,
+ average_subwords=False,
+ scalar_mix: Union[ScalarMixWithDropoutBuilder, int] = None,
+ word_dropout=None,
+ hidden_dropout=None,
+ max_sequence_length=None,
+ ret_raw_hidden_states=False,
+ batch_max_tokens=None,
+ epochs=3,
+ logger=None,
+ devices: Union[float, int, List[int]] = None,
+ **kwargs):
return super().fit(**merge_locals_kwargs(locals(), kwargs))
- def evaluate_output(self, tst_data, out, num_batches, metric):
- out.write('sentence\tpred\tgold\n')
- total, correct, score = 0, 0, 0
- for idx, batch in enumerate(tst_data):
- outputs = self.model.predict_on_batch(batch[0])[0]
- outputs = tf.argmax(outputs, axis=1)
- for X, Y_pred, Y_gold, in zip(batch[0][0], outputs, batch[1]):
- feature = ' '.join(self.transform.tokenizer.convert_ids_to_tokens(X.numpy(), skip_special_tokens=True))
- feature = feature.replace(' ##', '') # fix sub-word generated by BERT tagger
- out.write('{}\t{}\t{}\n'.format(feature,
- self._y_id_to_str(Y_pred),
- self._y_id_to_str(Y_gold)))
- total += 1
- correct += int(tf.equal(Y_pred, Y_gold).numpy())
- score = correct / total
- print('\r{}/{} {}: {:.2f}'.format(idx + 1, num_batches, metric, score * 100), end='')
- print()
- return score
-
- def _y_id_to_str(self, Y_pred) -> str:
- return self.transform.label_vocab.idx_to_token[Y_pred.numpy()]
-
- def build_loss(self, loss, **kwargs):
- if loss:
- assert isinstance(loss, tf.keras.losses.loss), 'Must specify loss as an instance in tf.keras.losses'
- return loss
- elif self.config.multi_label:
- #Loss to be BinaryCrossentropy for multi-label:
- loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)
+ def on_config_ready(self, **kwargs):
+ super().on_config_ready(**kwargs)
+ if 'albert_chinese' in self.config.transformer:
+ self.transformer_tokenizer = BertTokenizer.from_pretrained(self.config.transformer, use_fast=True)
else:
- loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+ self.transformer_tokenizer = AutoTokenizer.from_pretrained(self.config.transformer, use_fast=True)
+
+ def build_transformer(self, training=True):
+ transformer = TransformerEncoder(self.config.transformer, self.transformer_tokenizer,
+ self.config.average_subwords,
+ self.config.scalar_mix, self.config.word_dropout,
+ self.config.max_sequence_length, self.config.ret_raw_hidden_states,
+ training=training)
+ transformer_layers = self.config.get('transformer_layers', None)
+ if transformer_layers:
+ transformer.transformer.encoder.layer = transformer.transformer.encoder.layer[:-transformer_layers]
+ return transformer
+
+
+class TransformerClassifier(TransformerComponent):
+
+ def __init__(self, **kwargs) -> None:
+ """A classifier using transformer as encoder.
+
+ Args:
+ **kwargs: Passed to config.
+ """
+ super().__init__(**kwargs)
+ self.model: TransformerClassificationModel = None
+
+ def build_criterion(self, **kwargs):
+ criterion = nn.CrossEntropyLoss()
+ return criterion
+
+ def build_metric(self, **kwargs):
+ return CategoricalAccuracy()
+
+ def execute_training_loop(self, trn: DataLoader, dev: DataLoader, epochs, criterion, optimizer, metric, save_dir,
+ logger: logging.Logger, devices, **kwargs):
+ best_epoch, best_metric = 0, -1
+ timer = CountdownTimer(epochs)
+ ratio_width = len(f'{len(trn)}/{len(trn)}')
+ for epoch in range(1, epochs + 1):
+ logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
+ self.fit_dataloader(trn, criterion, optimizer, metric, logger)
+ if dev:
+ self.evaluate_dataloader(dev, criterion, metric, logger, ratio_width=ratio_width)
+ report = f'{timer.elapsed_human}/{timer.total_time_human}'
+ dev_score = metric.get_metric()
+ if dev_score > best_metric:
+ self.save_weights(save_dir)
+ best_metric = dev_score
+ report += ' [red]saved[/red]'
+ timer.log(report, ratio_percentage=False, newline=True, ratio=False)
+
+ @property
+ def label_vocab(self):
+ return self.vocabs[self.config.label_key]
+
+ def fit_dataloader(self, trn: DataLoader, criterion, optimizer, metric, logger: logging.Logger, **kwargs):
+ self.model.train()
+ timer = CountdownTimer(len(trn))
+ optimizer, scheduler = optimizer
+ total_loss = 0
+ metric.reset()
+ for batch in trn:
+ optimizer.zero_grad()
+ logits = self.feed_batch(batch)
+ target = batch['label_id']
+ loss = self.compute_loss(criterion, logits, target, batch)
+ loss.backward()
+ optimizer.step()
+ scheduler.step()
+ total_loss += loss.item()
+ self.update_metric(metric, logits, target)
+ timer.log(f'loss: {total_loss / (timer.current + 1):.4f} acc: {metric.get_metric():.2%}',
+ ratio_percentage=None,
+ logger=logger)
+ del loss
+ return total_loss / timer.total
+
+ def update_metric(self, metric, logits: torch.Tensor, target, output=None):
+ metric(logits, target)
+ if output:
+ label_ids = logits.argmax(-1)
+ return label_ids
+
+ def compute_loss(self, criterion, logits, target, batch):
+ loss = criterion(logits, target)
return loss
+ def feed_batch(self, batch) -> torch.LongTensor:
+ logits = self.model(*[batch[key] for key in ['input_ids', 'attention_mask', 'token_type_ids']])
+ return logits
+
# noinspection PyMethodOverriding
- def build_optimizer(self, optimizer, use_amp, train_steps, warmup_steps, **kwargs):
- if optimizer == 'adamw':
- opt = create_optimizer(init_lr=5e-5, num_train_steps=train_steps, num_warmup_steps=warmup_steps)
- # opt = tfa.optimizers.AdamW(learning_rate=3e-5, epsilon=1e-08, weight_decay=0.01)
- # opt = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08)
- self.config.optimizer = tf.keras.utils.serialize_keras_object(opt)
- lr_config = self.config.optimizer['config']['learning_rate']['config']
- if hasattr(lr_config['decay_schedule_fn'], 'get_config'):
- lr_config['decay_schedule_fn'] = dict(
- (k, v) for k, v in lr_config['decay_schedule_fn'].get_config().items() if not k.startswith('_'))
- else:
- opt = super().build_optimizer(optimizer)
- if use_amp:
- # loss scaling is currently required when using mixed precision
- opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, 'dynamic')
- return opt
+ def evaluate_dataloader(self,
+ data: DataLoader,
+ criterion: Callable,
+ metric,
+ logger,
+ ratio_width=None,
+ filename=None,
+ output=None,
+ **kwargs):
+ self.model.eval()
+ timer = CountdownTimer(len(data))
+ total_loss = 0
+ metric.reset()
+ num_samples = 0
+ if output:
+ output = open(output, 'w')
+ for batch in data:
+ logits = self.feed_batch(batch)
+ target = batch['label_id']
+ loss = self.compute_loss(criterion, logits, target, batch)
+ total_loss += loss.item()
+ label_ids = self.update_metric(metric, logits, target, output)
+ if output:
+ labels = [self.vocabs[self.config.label_key].idx_to_token[i] for i in label_ids.tolist()]
+ for i, label in enumerate(labels):
+ # text_a text_b pred gold
+ columns = [batch[self.config.text_a_key][i]]
+ if self.config.text_b_key:
+ columns.append(batch[self.config.text_b_key][i])
+ columns.append(label)
+ columns.append(batch[self.config.label_key][i])
+ output.write('\t'.join(columns))
+ output.write('\n')
+ num_samples += len(target)
+ report = f'loss: {total_loss / (timer.current + 1):.4f} acc: {metric.get_metric():.2%}'
+ if filename:
+ report = f'{filename} {report} {num_samples / timer.elapsed:.0f} samples/sec'
+ timer.log(report, ratio_percentage=None, logger=logger, ratio_width=ratio_width)
+ if output:
+ output.close()
+ return total_loss / timer.total
# noinspection PyMethodOverriding
- def build_model(self, transformer, max_length, **kwargs):
- model, self.transform.tokenizer = build_transformer(transformer, max_length, len(self.transform.label_vocab),
- tagging=False)
+ def build_model(self, transformer, training=True, **kwargs) -> torch.nn.Module:
+ # config: PretrainedConfig = AutoConfig.from_pretrained(transformer)
+ # config.num_labels = len(self.vocabs.label)
+ # config.hidden_dropout_prob = self.config.hidden_dropout_prob
+ transformer = self.build_transformer(training=training).transformer
+ model = TransformerClassificationModel(transformer, len(self.vocabs.label))
+ # truncated_normal_(model.classifier.weight, mean=0.02, std=0.05)
return model
- def build_vocab(self, trn_data, logger):
- train_examples = super().build_vocab(trn_data, logger)
- warmup_steps_per_epoch = math.ceil(train_examples * self.config.warmup_steps_ratio / self.config.batch_size)
- self.config.warmup_steps = warmup_steps_per_epoch * self.config.epochs
- return train_examples
+ # noinspection PyMethodOverriding
+ def build_dataloader(self, data, batch_size, shuffle, device, text_a_key, text_b_key,
+ label_key,
+ logger: logging.Logger = None,
+ sorting=True,
+ **kwargs) -> DataLoader:
+ if not batch_size:
+ batch_size = self.config.batch_size
+ dataset = self.build_dataset(data)
+ dataset.append_transform(self.vocabs)
+ if self.vocabs.mutable:
+ if not any([text_a_key, text_b_key]):
+ if len(dataset.headers) == 2:
+ self.config.text_a_key = dataset.headers[0]
+ self.config.label_key = dataset.headers[1]
+ elif len(dataset.headers) >= 3:
+ self.config.text_a_key, self.config.text_b_key, self.config.label_key = dataset.headers[0], \
+ dataset.headers[1], \
+ dataset.headers[-1]
+ else:
+ raise ValueError('Wrong dataset format')
+ report = {'text_a_key', 'text_b_key', 'label_key'}
+ report = dict((k, self.config[k]) for k in report)
+ report = [f'{k}={v}' for k, v in report.items() if v]
+ report = ', '.join(report)
+ logger.info(f'Guess [bold][blue]{report}[/blue][/bold] according to the headers of training dataset: '
+ f'[blue]{dataset}[/blue]')
+ self.build_vocabs(dataset, logger)
+ dataset.purge_cache()
+ # if self.config.transform:
+ # dataset.append_transform(self.config.transform)
+ dataset.append_transform(TransformerTextTokenizer(tokenizer=self.transformer_tokenizer,
+ text_a_key=self.config.text_a_key,
+ text_b_key=self.config.text_b_key,
+ max_seq_length=self.config.max_seq_length,
+ truncate_long_sequences=self.config.truncate_long_sequences,
+ output_key=''))
+ batch_sampler = None
+ if sorting and not isdebugging():
+ if dataset.cache and len(dataset) > 1000:
+ timer = CountdownTimer(len(dataset))
+ lens = []
+ for idx, sample in enumerate(dataset):
+ lens.append(len(sample['input_ids']))
+ timer.log('Pre-processing and caching dataset [blink][yellow]...[/yellow][/blink]',
+ ratio_percentage=None)
+ else:
+ lens = [len(sample['input_ids']) for sample in dataset]
+ batch_sampler = SortingSampler(lens, batch_size=batch_size, shuffle=shuffle,
+ batch_max_tokens=self.config.batch_max_tokens)
+ return PadSequenceDataLoader(dataset, batch_size, shuffle, batch_sampler=batch_sampler, device=device)
- def build_metrics(self, metrics, logger, **kwargs):
- if self.config.multi_label:
- metric = tf.keras.metrics.BinaryAccuracy('binary_accuracy')
+ def build_dataset(self, data) -> TransformableDataset:
+ if isinstance(data, str):
+ dataset = TableDataset(data, cache=True)
+ elif isinstance(data, TableDataset):
+ dataset = data
+ elif isinstance(data, list):
+ dataset = TableDataset(data)
else:
- metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
- return [metric]
\ No newline at end of file
+ raise ValueError(f'Unsupported data {data}')
+ return dataset
+
+ def predict(self, data: Union[str, List[str]], batch_size: int = None, **kwargs):
+ if not data:
+ return []
+ flat = isinstance(data, str) or isinstance(data, tuple)
+ if flat:
+ data = [data]
+ samples = []
+ for idx, d in enumerate(data):
+ sample = {IDX: idx}
+ if self.config.text_b_key:
+ sample[self.config.text_a_key] = d[0]
+ sample[self.config.text_b_key] = d[1]
+ else:
+ sample[self.config.text_a_key] = d
+ samples.append(sample)
+ dataloader = self.build_dataloader(samples,
+ sorting=False,
+ **merge_dict(self.config,
+ batch_size=batch_size,
+ shuffle=False,
+ device=self.device,
+ overwrite=True)
+ )
+ labels = [None] * len(data)
+ vocab = self.vocabs.label
+ for batch in dataloader:
+ logits = self.feed_batch(batch)
+ pred = logits.argmax(-1)
+ pred = pred.tolist()
+ for idx, tag in zip(batch[IDX], pred):
+ labels[idx] = vocab.idx_to_token[tag]
+ if flat:
+ return labels[0]
+ return labels
+
+ def fit(self, trn_data, dev_data, save_dir,
+ text_a_key=None,
+ text_b_key=None,
+ label_key=None,
+ transformer=None,
+ max_seq_length=512,
+ truncate_long_sequences=True,
+ # hidden_dropout_prob=0.0,
+ lr=5e-5,
+ transformer_lr=None,
+ adam_epsilon=1e-6,
+ weight_decay=0,
+ warmup_steps=0.1,
+ batch_size=32,
+ batch_max_tokens=None,
+ epochs=3,
+ logger=None,
+ # transform=None,
+ devices: Union[float, int, List[int]] = None,
+ **kwargs):
+ return super().fit(**merge_locals_kwargs(locals(), kwargs))
+
+ def build_vocabs(self, trn, logger, **kwargs):
+ self.vocabs.label = Vocab(pad_token=None, unk_token=None)
+ for each in trn:
+ pass
+ self.vocabs.lock()
+ self.vocabs.summary(logger)
diff --git a/hanlp/components/classifiers/transformer_classifier_tf.py b/hanlp/components/classifiers/transformer_classifier_tf.py
new file mode 100644
index 000000000..3e94ffbb9
--- /dev/null
+++ b/hanlp/components/classifiers/transformer_classifier_tf.py
@@ -0,0 +1,194 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-11-10 13:19
+
+import math
+from typing import Union, Tuple, Any, Iterable
+import tensorflow as tf
+from bert.tokenization.bert_tokenization import FullTokenizer
+from hanlp.common.keras_component import KerasComponent
+from hanlp_common.structure import SerializableDict
+from hanlp.layers.transformers.loader_tf import build_transformer
+from hanlp.optimizers.adamw import create_optimizer
+from hanlp.transform.table import TableTransform
+from hanlp.utils.log_util import logger
+from hanlp_common.util import merge_locals_kwargs
+
+
+class TransformerTextTransform(TableTransform):
+
+ def __init__(self, config: SerializableDict = None, map_x=False, map_y=True, x_columns=None,
+ y_column=-1, skip_header=True, delimiter='auto', multi_label=False, **kwargs) -> None:
+ super().__init__(config, map_x, map_y, x_columns, y_column, multi_label, skip_header, delimiter, **kwargs)
+ self.tokenizer: FullTokenizer = None
+
+ def inputs_to_samples(self, inputs, gold=False):
+ tokenizer = self.tokenizer
+ max_length = self.config.max_length
+ num_features = None
+ pad_token = None if self.label_vocab.mutable else tokenizer.convert_tokens_to_ids(['[PAD]'])[0]
+ for (X, Y) in super().inputs_to_samples(inputs, gold):
+ if self.label_vocab.mutable:
+ yield None, Y
+ continue
+ if isinstance(X, str):
+ X = (X,)
+ if num_features is None:
+ num_features = self.config.num_features
+ assert num_features == len(X), f'Numbers of features {num_features} ' \
+ f'inconsistent with current {len(X)}={X}'
+ text_a = X[0]
+ text_b = X[1] if len(X) > 1 else None
+ tokens_a = self.tokenizer.tokenize(text_a)
+ tokens_b = self.tokenizer.tokenize(text_b) if text_b else None
+ tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
+ segment_ids = [0] * len(tokens)
+ if tokens_b:
+ tokens += tokens_b
+ segment_ids += [1] * len(tokens_b)
+ token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
+ attention_mask = [1] * len(token_ids)
+ diff = max_length - len(token_ids)
+ if diff < 0:
+ logger.warning(
+ f'Input tokens {tokens} exceed the max sequence length of {max_length - 2}. '
+ f'The exceeded part will be truncated and ignored. '
+ f'You are recommended to split your long text into several sentences within '
+ f'{max_length - 2} tokens beforehand.')
+ token_ids = token_ids[:max_length]
+ attention_mask = attention_mask[:max_length]
+ segment_ids = segment_ids[:max_length]
+ elif diff > 0:
+ token_ids += [pad_token] * diff
+ attention_mask += [0] * diff
+ segment_ids += [0] * diff
+
+ assert len(token_ids) == max_length, "Error with input length {} vs {}".format(len(token_ids), max_length)
+ assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(len(attention_mask),
+ max_length)
+ assert len(segment_ids) == max_length, "Error with input length {} vs {}".format(len(segment_ids),
+ max_length)
+
+ label = Y
+ yield (token_ids, attention_mask, segment_ids), label
+
+ def create_types_shapes_values(self) -> Tuple[Tuple, Tuple, Tuple]:
+ max_length = self.config.max_length
+ types = (tf.int32, tf.int32, tf.int32), tf.string
+ shapes = ([max_length], [max_length], [max_length]), [None, ] if self.config.multi_label else []
+ values = (0, 0, 0), self.label_vocab.safe_pad_token
+ return types, shapes, values
+
+ def x_to_idx(self, x) -> Union[tf.Tensor, Tuple]:
+ logger.fatal('map_x should always be set to True')
+ exit(1)
+
+ def y_to_idx(self, y) -> tf.Tensor:
+ if self.config.multi_label:
+ # need to change index to binary vector
+ mapped = tf.map_fn(fn=lambda x: tf.cast(self.label_vocab.lookup(x), tf.int32), elems=y,
+ fn_output_signature=tf.TensorSpec(dtype=tf.dtypes.int32, shape=[None, ]))
+ one_hots = tf.one_hot(mapped, len(self.label_vocab))
+ idx = tf.reduce_sum(one_hots, -2)
+ else:
+ idx = self.label_vocab.lookup(y)
+ return idx
+
+ def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, inputs=None, X=None,
+ batch=None) -> Iterable:
+ # Prediction to be Y > 0:
+ if self.config.multi_label:
+ preds = Y
+ else:
+ preds = tf.argmax(Y, axis=-1)
+ for y in preds:
+ yield self.label_vocab.idx_to_token[y]
+
+ def input_is_single_sample(self, input: Any) -> bool:
+ return isinstance(input, (str, tuple))
+
+
+class TransformerClassifierTF(KerasComponent):
+
+ def __init__(self, bert_text_transform=None) -> None:
+ if not bert_text_transform:
+ bert_text_transform = TransformerTextTransform()
+ super().__init__(bert_text_transform)
+ self.model: tf.keras.Model
+ self.transform: TransformerTextTransform = bert_text_transform
+
+ # noinspection PyMethodOverriding
+ def fit(self, trn_data: Any, dev_data: Any, save_dir: str, transformer: str, max_length: int = 128,
+ optimizer='adamw', warmup_steps_ratio=0.1, use_amp=False, batch_size=32,
+ epochs=3, logger=None, verbose=1, **kwargs):
+ return super().fit(**merge_locals_kwargs(locals(), kwargs))
+
+ def evaluate_output(self, tst_data, out, num_batches, metric):
+ out.write('sentence\tpred\tgold\n')
+ total, correct, score = 0, 0, 0
+ for idx, batch in enumerate(tst_data):
+ outputs = self.model.predict_on_batch(batch[0])
+ outputs = tf.argmax(outputs, axis=1)
+ for X, Y_pred, Y_gold, in zip(batch[0][0], outputs, batch[1]):
+ feature = ' '.join(self.transform.tokenizer.convert_ids_to_tokens(X.numpy()))
+ feature = feature.replace(' ##', '') # fix sub-word generated by BERT tagger
+ out.write('{}\t{}\t{}\n'.format(feature,
+ self._y_id_to_str(Y_pred),
+ self._y_id_to_str(Y_gold)))
+ total += 1
+ correct += int(tf.equal(Y_pred, Y_gold).numpy())
+ score = correct / total
+ print('\r{}/{} {}: {:.2f}'.format(idx + 1, num_batches, metric, score * 100), end='')
+ print()
+ return score
+
+ def _y_id_to_str(self, Y_pred) -> str:
+ return self.transform.label_vocab.idx_to_token[Y_pred.numpy()]
+
+ def build_loss(self, loss, **kwargs):
+ if loss:
+ assert isinstance(loss, tf.keras.losses.loss), 'Must specify loss as an instance in tf.keras.losses'
+ return loss
+ elif self.config.multi_label:
+ # Loss to be BinaryCrossentropy for multi-label:
+ loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)
+ else:
+ loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+ return loss
+
+ # noinspection PyMethodOverriding
+ def build_optimizer(self, optimizer, use_amp, train_steps, warmup_steps, **kwargs):
+ if optimizer == 'adamw':
+ opt = create_optimizer(init_lr=5e-5, num_train_steps=train_steps, num_warmup_steps=warmup_steps)
+ # opt = tfa.optimizers.AdamW(learning_rate=3e-5, epsilon=1e-08, weight_decay=0.01)
+ # opt = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08)
+ self.config.optimizer = tf.keras.utils.serialize_keras_object(opt)
+ lr_config = self.config.optimizer['config']['learning_rate']['config']
+ if hasattr(lr_config['decay_schedule_fn'], 'get_config'):
+ lr_config['decay_schedule_fn'] = dict(
+ (k, v) for k, v in lr_config['decay_schedule_fn'].config().items() if not k.startswith('_'))
+ else:
+ opt = super().build_optimizer(optimizer)
+ if use_amp:
+ # loss scaling is currently required when using mixed precision
+ opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, 'dynamic')
+ return opt
+
+ # noinspection PyMethodOverriding
+ def build_model(self, transformer, max_length, **kwargs):
+ model, self.transform.tokenizer = build_transformer(transformer, max_length, len(self.transform.label_vocab),
+ tagging=False)
+ return model
+
+ def build_vocab(self, trn_data, logger):
+ train_examples = super().build_vocab(trn_data, logger)
+ warmup_steps_per_epoch = math.ceil(train_examples * self.config.warmup_steps_ratio / self.config.batch_size)
+ self.config.warmup_steps = warmup_steps_per_epoch * self.config.epochs
+ return train_examples
+
+ def build_metrics(self, metrics, logger, **kwargs):
+ if self.config.multi_label:
+ metric = tf.keras.metrics.BinaryAccuracy('binary_accuracy')
+ else:
+ metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
+ return [metric]
diff --git a/tests/debug/__init__.py b/hanlp/components/coref/__init__.py
similarity index 62%
rename from tests/debug/__init__.py
rename to hanlp/components/coref/__init__.py
index bda4ea248..83097d68d 100644
--- a/tests/debug/__init__.py
+++ b/hanlp/components/coref/__init__.py
@@ -1,3 +1,3 @@
# -*- coding:utf-8 -*-
# Author: hankcs
-# Date: 2020-01-01 17:33
\ No newline at end of file
+# Date: 2020-07-05 19:56
\ No newline at end of file
diff --git a/hanlp/components/coref/end_to_end.py b/hanlp/components/coref/end_to_end.py
new file mode 100644
index 000000000..4971609b7
--- /dev/null
+++ b/hanlp/components/coref/end_to_end.py
@@ -0,0 +1,270 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-07-05 20:28
+import logging
+from typing import Union, List, Callable, Dict, Any, Tuple
+
+import torch
+from hanlp.layers.transformers.utils import get_optimizers
+from alnlp.metrics.conll_coref_scores import ConllCorefScores
+from alnlp.metrics.mention_recall import MentionRecall
+from alnlp.models.coref import CoreferenceResolver
+from alnlp.modules.initializers import InitializerApplicator
+from alnlp.modules.util import lengths_to_mask
+from alnlp.training.optimizers import make_parameter_groups
+from torch.utils.data import DataLoader
+from hanlp.common.dataset import PadSequenceDataLoader
+from hanlp.common.torch_component import TorchComponent
+from hanlp.common.transform import FieldLength
+from hanlp.datasets.coref.conll12coref import CONLL12CorefDataset
+from hanlp.layers.context_layer import LSTMContextualEncoder
+from hanlp.layers.embeddings.embedding import Embedding
+from hanlp.layers.feed_forward import FeedForward
+from hanlp.utils.time_util import CountdownTimer
+from hanlp.utils.torch_util import clip_grad_norm
+from hanlp_common.util import merge_locals_kwargs
+
+
+class CoreferenceResolverModel(CoreferenceResolver):
+ # noinspection PyMethodOverriding
+ def forward(self, batch: dict) -> Dict[str, torch.Tensor]:
+ batch['mask'] = mask = lengths_to_mask(batch['text_length'])
+ return super().forward(batch, batch['spans'], batch.get('span_labels'), mask=mask)
+
+
+class EndToEndCoreferenceResolver(TorchComponent):
+ def build_optimizer(self,
+ trn,
+ epochs,
+ lr,
+ adam_epsilon,
+ weight_decay,
+ warmup_steps,
+ transformer_lr,
+ **kwargs):
+ # noinspection PyProtectedMember
+ transformer = getattr(self.model._text_field_embedder, 'transformer', None)
+ if transformer:
+ model = self.model
+ num_training_steps = len(trn) * epochs // self.config.get('gradient_accumulation', 1)
+
+ optimizer_grouped_parameters = make_parameter_groups(list(self.model.named_parameters()),
+ [([".*transformer.*"], {"lr": transformer_lr})])
+ optimizer, linear_scheduler = get_optimizers(model,
+ num_training_steps,
+ learning_rate=lr,
+ adam_epsilon=adam_epsilon,
+ weight_decay=weight_decay,
+ warmup_steps=warmup_steps,
+ optimizer_grouped_parameters=optimizer_grouped_parameters
+ )
+ else:
+ optimizer = torch.optim.Adam(self.model.parameters(), self.config.lr)
+ linear_scheduler = None
+ reduce_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+ optimizer=optimizer,
+ mode='max',
+ factor=0.5,
+ patience=2,
+ verbose=True,
+ )
+ return optimizer, reduce_lr_scheduler, linear_scheduler
+
+ def build_criterion(self, **kwargs):
+ pass
+
+ def build_metric(self, **kwargs) -> Tuple[MentionRecall, ConllCorefScores]:
+ return self.model._mention_recall, self.model._conll_coref_scores
+
+ def execute_training_loop(self,
+ trn: DataLoader,
+ dev: DataLoader,
+ epochs,
+ criterion,
+ optimizer,
+ metric,
+ save_dir,
+ logger: logging.Logger,
+ devices,
+ **kwargs):
+ best_epoch, best_metric = 0, -1
+ mention_recall, conll_coref_scores = self.build_metric()
+ optimizer, reduce_lr_scheduler, linear_scheduler = optimizer
+ timer = CountdownTimer(epochs)
+ ratio_width = len(f'{len(trn)}/{len(trn)}')
+ for epoch in range(1, epochs + 1):
+ logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
+ self.fit_dataloader(trn, criterion, optimizer, metric, logger, linear_scheduler=linear_scheduler)
+ if dev:
+ self.evaluate_dataloader(dev, criterion, metric, logger, ratio_width=ratio_width)
+ report = f'{timer.elapsed_human}/{timer.total_time_human}'
+ dev_score = conll_coref_scores.get_metric()[-1]
+ reduce_lr_scheduler.step(dev_score)
+ if dev_score > best_metric:
+ self.save_weights(save_dir)
+ best_metric = dev_score
+ report += ' [red]saved[/red]'
+ timer.log(report, ratio_percentage=False, newline=True, ratio=False)
+
+ def fit_dataloader(self,
+ trn: DataLoader,
+ criterion,
+ optimizer,
+ metric,
+ logger: logging.Logger,
+ linear_scheduler=None,
+ **kwargs):
+ self.model.train()
+ timer = CountdownTimer(len(trn))
+ total_loss = 0
+ self.reset_metrics()
+ for batch in trn:
+ optimizer.zero_grad()
+ output_dict = self.feed_batch(batch)
+ loss = output_dict['loss']
+ loss.backward()
+ if self.config.grad_norm:
+ clip_grad_norm(self.model, self.config.grad_norm)
+ optimizer.step()
+ if linear_scheduler:
+ linear_scheduler.step()
+ total_loss += loss.item()
+ timer.log(self.report_metrics(total_loss / (timer.current + 1)), ratio_percentage=None, logger=logger)
+ del loss
+ return total_loss / timer.total
+
+ # noinspection PyMethodOverriding
+ def evaluate_dataloader(self,
+ data: DataLoader,
+ criterion: Callable,
+ metric,
+ logger,
+ ratio_width=None,
+ output=False,
+ **kwargs):
+ self.model.eval()
+ self.reset_metrics()
+ timer = CountdownTimer(len(data))
+ total_loss = 0
+ self.reset_metrics()
+ for batch in data:
+ output_dict = self.feed_batch(batch)
+ loss = output_dict['loss']
+ total_loss += loss.item()
+ timer.log(self.report_metrics(total_loss / (timer.current + 1)), ratio_percentage=None, logger=logger,
+ ratio_width=ratio_width)
+ del loss
+ return total_loss / timer.total
+
+ def build_model(self,
+ training=True,
+ **kwargs) -> torch.nn.Module:
+ # noinspection PyTypeChecker
+ model = CoreferenceResolverModel(
+ self.config.embed.module(vocabs=self.vocabs, training=training),
+ self.config.context_layer,
+ self.config.mention_feedforward,
+ self.config.antecedent_feedforward,
+ self.config.feature_size,
+ self.config.max_span_width,
+ self.config.spans_per_word,
+ self.config.max_antecedents,
+ self.config.coarse_to_fine,
+ self.config.inference_order,
+ self.config.lexical_dropout,
+ InitializerApplicator([
+ [".*linear_layers.*weight", {"type": "xavier_normal"}],
+ [".*scorer._module.weight", {"type": "xavier_normal"}],
+ ["_distance_embedding.weight", {"type": "xavier_normal"}],
+ ["_span_width_embedding.weight", {"type": "xavier_normal"}],
+ ["_context_layer._module.weight_ih.*", {"type": "xavier_normal"}],
+ ["_context_layer._module.weight_hh.*", {"type": "orthogonal"}]
+ ])
+ )
+ return model
+
+ def build_dataloader(self,
+ data,
+ batch_size,
+ shuffle,
+ device,
+ logger: logging.Logger,
+ **kwargs) -> DataLoader:
+ dataset = CONLL12CorefDataset(data, [FieldLength('text')])
+ if isinstance(self.config.embed, Embedding):
+ transform = self.config.embed.transform(vocabs=self.vocabs)
+ if transform:
+ dataset.append_transform(transform)
+ dataset.append_transform(self.vocabs)
+ if isinstance(data, str):
+ dataset.purge_cache() # Enable cache
+ if self.vocabs.mutable:
+ self.build_vocabs(dataset)
+ return PadSequenceDataLoader(batch_size=batch_size,
+ shuffle=shuffle,
+ device=device,
+ dataset=dataset,
+ pad={'spans': 0, 'span_labels': -1})
+
+ def predict(self, data: Union[str, List[str]], batch_size: int = None, **kwargs):
+ pass
+
+ # noinspection PyMethodOverriding
+ def fit(self,
+ trn_data,
+ dev_data,
+ save_dir,
+ batch_size,
+ embed: Embedding,
+ mention_feedforward: FeedForward,
+ antecedent_feedforward: FeedForward,
+ feature_size: int,
+ max_span_width: int,
+ spans_per_word: float,
+ max_antecedents: int,
+ lr=1e-3,
+ transformer_lr=1e-5,
+ adam_epsilon=1e-6,
+ weight_decay=0.01,
+ warmup_steps=0.1,
+ epochs=150,
+ grad_norm=None,
+ coarse_to_fine: bool = False,
+ inference_order: int = 1,
+ lexical_dropout: float = 0.2,
+ context_layer: LSTMContextualEncoder = None,
+ devices=None,
+ logger=None,
+ seed=None,
+ **kwargs
+ ):
+ return super().fit(**merge_locals_kwargs(locals(), kwargs))
+
+ def update_metric(self, metric, logits: torch.Tensor, target, output=None):
+ metric(logits, target)
+ if output:
+ label_ids = logits.argmax(-1)
+ return label_ids
+
+ def compute_loss(self, criterion, logits, target, batch):
+ loss = criterion(logits, target)
+ return loss
+
+ def feed_batch(self, batch) -> Dict[str, Any]:
+ output_dict = self.model(batch)
+ return output_dict
+
+ def build_vocabs(self, dataset, **kwargs):
+ if self.vocabs:
+ for each in dataset:
+ pass
+ self.vocabs.lock()
+ self.vocabs.summary()
+
+ def reset_metrics(self):
+ for each in self.build_metric():
+ each.reset()
+
+ def report_metrics(self, loss):
+ mention_recall, conll_coref_scores = self.build_metric()
+ return f'loss:{loss:.4f} mention_recall:{mention_recall.get_metric():.2%} f1:{conll_coref_scores.get_metric()[-1]:.2%}'
diff --git a/hanlp/components/distillation/__init__.py b/hanlp/components/distillation/__init__.py
new file mode 100644
index 000000000..9627975df
--- /dev/null
+++ b/hanlp/components/distillation/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-10-17 20:29
diff --git a/hanlp/components/distillation/distillable_component.py b/hanlp/components/distillation/distillable_component.py
new file mode 100644
index 000000000..afc0e41eb
--- /dev/null
+++ b/hanlp/components/distillation/distillable_component.py
@@ -0,0 +1,54 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-10-17 20:30
+from abc import ABC
+from copy import copy
+
+import hanlp
+from hanlp.common.torch_component import TorchComponent
+from hanlp.components.distillation.losses import KnowledgeDistillationLoss
+from hanlp.components.distillation.schedulers import TemperatureScheduler
+from hanlp.utils.torch_util import cuda_devices
+from hanlp_common.util import merge_locals_kwargs
+
+
+class DistillableComponent(TorchComponent, ABC):
+
+ # noinspection PyMethodMayBeStatic,PyTypeChecker
+ def build_teacher(self, teacher: str, devices) -> TorchComponent:
+ return hanlp.load(teacher, load_kwargs={'devices': devices})
+
+ def distill(self,
+ teacher: str,
+ trn_data,
+ dev_data,
+ save_dir,
+ batch_size=None,
+ epochs=None,
+ kd_criterion='kd_ce_loss',
+ temperature_scheduler='flsw',
+ devices=None,
+ logger=None,
+ seed=None,
+ **kwargs):
+ devices = devices or cuda_devices()
+ if isinstance(kd_criterion, str):
+ kd_criterion = KnowledgeDistillationLoss(kd_criterion)
+ if isinstance(temperature_scheduler, str):
+ temperature_scheduler = TemperatureScheduler.from_name(temperature_scheduler)
+ teacher = self.build_teacher(teacher, devices=devices)
+ self.vocabs = teacher.vocabs
+ config = copy(teacher.config)
+ batch_size = batch_size or config.get('batch_size', None)
+ epochs = epochs or config.get('epochs', None)
+ config.update(kwargs)
+ return super().fit(**merge_locals_kwargs(locals(),
+ config,
+ excludes=('self', 'kwargs', '__class__', 'config')))
+
+ @property
+ def _savable_config(self):
+ config = super(DistillableComponent, self)._savable_config
+ if 'teacher' in config:
+ config.teacher = config.teacher.load_path
+ return config
diff --git a/hanlp/components/distillation/losses.py b/hanlp/components/distillation/losses.py
new file mode 100644
index 000000000..4c042938e
--- /dev/null
+++ b/hanlp/components/distillation/losses.py
@@ -0,0 +1,285 @@
+# Adopted from https://github.com/airaria/TextBrewer
+# Apache License Version 2.0
+
+import torch
+import torch.nn.functional as F
+
+from hanlp_common.configurable import AutoConfigurable
+
+
+def kd_mse_loss(logits_S, logits_T, temperature=1):
+ '''
+ Calculate the mse loss between logits_S and logits_T
+
+ :param logits_S: Tensor of shape (batch_size, length, num_labels) or (batch_size, num_labels)
+ :param logits_T: Tensor of shape (batch_size, length, num_labels) or (batch_size, num_labels)
+ :param temperature: A float or a tensor of shape (batch_size, length) or (batch_size,)
+ '''
+ if isinstance(temperature, torch.Tensor) and temperature.dim() > 0:
+ temperature = temperature.unsqueeze(-1)
+ beta_logits_T = logits_T / temperature
+ beta_logits_S = logits_S / temperature
+ loss = F.mse_loss(beta_logits_S, beta_logits_T)
+ return loss
+
+
+def kd_ce_loss(logits_S, logits_T, temperature=1):
+ '''
+ Calculate the cross entropy between logits_S and logits_T
+
+ :param logits_S: Tensor of shape (batch_size, length, num_labels) or (batch_size, num_labels)
+ :param logits_T: Tensor of shape (batch_size, length, num_labels) or (batch_size, num_labels)
+ :param temperature: A float or a tensor of shape (batch_size, length) or (batch_size,)
+ '''
+ if isinstance(temperature, torch.Tensor) and temperature.dim() > 0:
+ temperature = temperature.unsqueeze(-1)
+ beta_logits_T = logits_T / temperature
+ beta_logits_S = logits_S / temperature
+ p_T = F.softmax(beta_logits_T, dim=-1)
+ loss = -(p_T * F.log_softmax(beta_logits_S, dim=-1)).sum(dim=-1).mean()
+ return loss
+
+
+def att_mse_loss(attention_S, attention_T, mask=None):
+ '''
+ * Calculates the mse loss between `attention_S` and `attention_T`.
+ * If the `inputs_mask` is given, masks the positions where ``input_mask==0``.
+
+ :param torch.Tensor logits_S: tensor of shape (*batch_size*, *num_heads*, *length*, *length*)
+ :param torch.Tensor logits_T: tensor of shape (*batch_size*, *num_heads*, *length*, *length*)
+ :param torch.Tensor mask: tensor of shape (*batch_size*, *length*)
+ '''
+ if mask is None:
+ attention_S_select = torch.where(attention_S <= -1e-3, torch.zeros_like(attention_S), attention_S)
+ attention_T_select = torch.where(attention_T <= -1e-3, torch.zeros_like(attention_T), attention_T)
+ loss = F.mse_loss(attention_S_select, attention_T_select)
+ else:
+ mask = mask.to(attention_S).unsqueeze(1).expand(-1, attention_S.size(1), -1) # (bs, num_of_heads, len)
+ valid_count = torch.pow(mask.sum(dim=2), 2).sum()
+ loss = (F.mse_loss(attention_S, attention_T, reduction='none') * mask.unsqueeze(-1) * mask.unsqueeze(
+ 2)).sum() / valid_count
+ return loss
+
+
+def att_mse_sum_loss(attention_S, attention_T, mask=None):
+ '''
+ * Calculates the mse loss between `attention_S` and `attention_T`.
+ * If the the shape is (*batch_size*, *num_heads*, *length*, *length*), sums along the `num_heads` dimension and then calcuates the mse loss between the two matrices.
+ * If the `inputs_mask` is given, masks the positions where ``input_mask==0``.
+
+ :param torch.Tensor logits_S: tensor of shape (*batch_size*, *num_heads*, *length*, *length*) or (*batch_size*, *length*, *length*)
+ :param torch.Tensor logits_T: tensor of shape (*batch_size*, *num_heads*, *length*, *length*) or (*batch_size*, *length*, *length*)
+ :param torch.Tensor mask: tensor of shape (*batch_size*, *length*)
+ '''
+ if len(attention_S.size()) == 4:
+ attention_T = attention_T.sum(dim=1)
+ attention_S = attention_S.sum(dim=1)
+ if mask is None:
+ attention_S_select = torch.where(attention_S <= -1e-3, torch.zeros_like(attention_S), attention_S)
+ attention_T_select = torch.where(attention_T <= -1e-3, torch.zeros_like(attention_T), attention_T)
+ loss = F.mse_loss(attention_S_select, attention_T_select)
+ else:
+ mask = mask.to(attention_S)
+ valid_count = torch.pow(mask.sum(dim=1), 2).sum()
+ loss = (F.mse_loss(attention_S, attention_T, reduction='none') * mask.unsqueeze(-1) * mask.unsqueeze(
+ 1)).sum() / valid_count
+ return loss
+
+
+def att_ce_loss(attention_S, attention_T, mask=None):
+ '''
+
+ * Calculates the cross-entropy loss between `attention_S` and `attention_T`, where softmax is to applied on ``dim=-1``.
+ * If the `inputs_mask` is given, masks the positions where ``input_mask==0``.
+
+ :param torch.Tensor logits_S: tensor of shape (*batch_size*, *num_heads*, *length*, *length*)
+ :param torch.Tensor logits_T: tensor of shape (*batch_size*, *num_heads*, *length*, *length*)
+ :param torch.Tensor mask: tensor of shape (*batch_size*, *length*)
+ '''
+ probs_T = F.softmax(attention_T, dim=-1)
+ if mask is None:
+ probs_T_select = torch.where(attention_T <= -1e-3, torch.zeros_like(attention_T), probs_T)
+ loss = -((probs_T_select * F.log_softmax(attention_S, dim=-1)).sum(dim=-1)).mean()
+ else:
+ mask = mask.to(attention_S).unsqueeze(1).expand(-1, attention_S.size(1), -1) # (bs, num_of_heads, len)
+ loss = -((probs_T * F.log_softmax(attention_S, dim=-1) * mask.unsqueeze(2)).sum(
+ dim=-1) * mask).sum() / mask.sum()
+ return loss
+
+
+def att_ce_mean_loss(attention_S, attention_T, mask=None):
+ '''
+ * Calculates the cross-entropy loss between `attention_S` and `attention_T`, where softmax is to applied on ``dim=-1``.
+ * If the shape is (*batch_size*, *num_heads*, *length*, *length*), averages over dimension `num_heads` and then computes cross-entropy loss between the two matrics.
+ * If the `inputs_mask` is given, masks the positions where ``input_mask==0``.
+
+ :param torch.tensor logits_S: tensor of shape (*batch_size*, *num_heads*, *length*, *length*) or (*batch_size*, *length*, *length*)
+ :param torch.tensor logits_T: tensor of shape (*batch_size*, *num_heads*, *length*, *length*) or (*batch_size*, *length*, *length*)
+ :param torch.tensor mask: tensor of shape (*batch_size*, *length*)
+ '''
+ if len(attention_S.size()) == 4:
+ attention_S = attention_S.mean(dim=1) # (bs, len, len)
+ attention_T = attention_T.mean(dim=1)
+ probs_T = F.softmax(attention_T, dim=-1)
+ if mask is None:
+ probs_T_select = torch.where(attention_T <= -1e-3, torch.zeros_like(attention_T), probs_T)
+ loss = -((probs_T_select * F.log_softmax(attention_S, dim=-1)).sum(dim=-1)).mean()
+ else:
+ mask = mask.to(attention_S)
+ loss = -((probs_T * F.log_softmax(attention_S, dim=-1) * mask.unsqueeze(1)).sum(
+ dim=-1) * mask).sum() / mask.sum()
+ return loss
+
+
+def hid_mse_loss(state_S, state_T, mask=None):
+ '''
+ * Calculates the mse loss between `state_S` and `state_T`, which are the hidden state of the models.
+ * If the `inputs_mask` is given, masks the positions where ``input_mask==0``.
+ * If the hidden sizes of student and teacher are different, 'proj' option is required in `inetermediate_matches` to match the dimensions.
+
+ :param torch.Tensor state_S: tensor of shape (*batch_size*, *length*, *hidden_size*)
+ :param torch.Tensor state_T: tensor of shape (*batch_size*, *length*, *hidden_size*)
+ :param torch.Tensor mask: tensor of shape (*batch_size*, *length*)
+ '''
+ if mask is None:
+ loss = F.mse_loss(state_S, state_T)
+ else:
+ mask = mask.to(state_S)
+ valid_count = mask.sum() * state_S.size(-1)
+ loss = (F.mse_loss(state_S, state_T, reduction='none') * mask.unsqueeze(-1)).sum() / valid_count
+ return loss
+
+
+def cos_loss(state_S, state_T, mask=None):
+ '''
+ * Computes the cosine similarity loss between the inputs. This is the loss used in DistilBERT, see `DistilBERT `_
+ * If the `inputs_mask` is given, masks the positions where ``input_mask==0``.
+ * If the hidden sizes of student and teacher are different, 'proj' option is required in `inetermediate_matches` to match the dimensions.
+
+ :param torch.Tensor state_S: tensor of shape (*batch_size*, *length*, *hidden_size*)
+ :param torch.Tensor state_T: tensor of shape (*batch_size*, *length*, *hidden_size*)
+ :param torch.Tensor mask: tensor of shape (*batch_size*, *length*)
+ '''
+ if mask is None:
+ state_S = state_S.view(-1, state_S.size(-1))
+ state_T = state_T.view(-1, state_T.size(-1))
+ else:
+ mask = mask.to(state_S).unsqueeze(-1).expand_as(state_S) # (bs,len,dim)
+ state_S = torch.masked_select(state_S, mask).view(-1, mask.size(-1)) # (bs * select, dim)
+ state_T = torch.masked_select(state_T, mask).view(-1, mask.size(-1)) # (bs * select, dim)
+
+ target = state_S.new(state_S.size(0)).fill_(1)
+ loss = F.cosine_embedding_loss(state_S, state_T, target, reduction='mean')
+ return loss
+
+
+def pkd_loss(state_S, state_T, mask=None):
+ '''
+ * Computes normalized vector mse loss at position 0 along `length` dimension. This is the loss used in BERT-PKD, see `Patient Knowledge Distillation for BERT Model Compression `_.
+ * If the hidden sizes of student and teacher are different, 'proj' option is required in `inetermediate_matches` to match the dimensions.
+
+ :param torch.Tensor state_S: tensor of shape (*batch_size*, *length*, *hidden_size*)
+ :param torch.Tensor state_T: tensor of shape (*batch_size*, *length*, *hidden_size*)
+ :param mask: not used.
+ '''
+
+ cls_T = state_T[:, 0] # (batch_size, hidden_dim)
+ cls_S = state_S[:, 0] # (batch_size, hidden_dim)
+ normed_cls_T = cls_T / torch.norm(cls_T, dim=1, keepdim=True)
+ normed_cls_S = cls_S / torch.norm(cls_S, dim=1, keepdim=True)
+ loss = (normed_cls_S - normed_cls_T).pow(2).sum(dim=-1).mean()
+ return loss
+
+
+def fsp_loss(state_S, state_T, mask=None):
+ r'''
+ * Takes in two lists of matrics `state_S` and `state_T`. Each list contains two matrices of the shape (*batch_size*, *length*, *hidden_size*). Computes the similarity matrix between the two matrices in `state_S` ( with the resulting shape (*batch_size*, *hidden_size*, *hidden_size*) ) and the ones in B ( with the resulting shape (*batch_size*, *hidden_size*, *hidden_size*) ), then computes the mse loss between the similarity matrices:
+
+ .. math::
+
+ loss = mean((S_{1}^T \cdot S_{2} - T_{1}^T \cdot T_{2})^2)
+
+ * It is a Variant of FSP loss in `A Gift from Knowledge Distillation: Fast Optimization, Network Minimization and Transfer Learning `_.
+ * If the `inputs_mask` is given, masks the positions where ``input_mask==0``.
+ * If the hidden sizes of student and teacher are different, 'proj' option is required in `inetermediate_matches` to match the dimensions.
+
+ :param torch.tensor state_S: list of two tensors, each tensor is of the shape (*batch_size*, *length*, *hidden_size*)
+ :param torch.tensor state_T: list of two tensors, each tensor is of the shape (*batch_size*, *length*, *hidden_size*)
+ :param torch.tensor mask: tensor of the shape (*batch_size*, *length*)
+
+ Example in `intermediate_matches`::
+
+ intermediate_matches = [
+ {'layer_T':[0,0], 'layer_S':[0,0], 'feature':'hidden','loss': 'fsp', 'weight' : 1, 'proj':['linear',384,768]},
+ ...]
+ '''
+ if mask is None:
+ state_S_0 = state_S[0] # (batch_size , length, hidden_dim)
+ state_S_1 = state_S[1] # (batch_size, length, hidden_dim)
+ state_T_0 = state_T[0]
+ state_T_1 = state_T[1]
+ gram_S = torch.bmm(state_S_0.transpose(1, 2), state_S_1) / state_S_1.size(
+ 1) # (batch_size, hidden_dim, hidden_dim)
+ gram_T = torch.bmm(state_T_0.transpose(1, 2), state_T_1) / state_T_1.size(1)
+ else:
+ mask = mask.to(state_S[0]).unsqueeze(-1)
+ lengths = mask.sum(dim=1, keepdim=True)
+ state_S_0 = state_S[0] * mask
+ state_S_1 = state_S[1] * mask
+ state_T_0 = state_T[0] * mask
+ state_T_1 = state_T[1] * mask
+ gram_S = torch.bmm(state_S_0.transpose(1, 2), state_S_1) / lengths
+ gram_T = torch.bmm(state_T_0.transpose(1, 2), state_T_1) / lengths
+ loss = F.mse_loss(gram_S, gram_T)
+ return loss
+
+
+def mmd_loss(state_S, state_T, mask=None):
+ r'''
+ * Takes in two lists of matrices `state_S` and `state_T`. Each list contains 2 matrices of the shape (*batch_size*, *length*, *hidden_size*). `hidden_size` of matrices in `State_S` doesn't need to be the same as that of `state_T`. Computes the similarity matrix between the two matrices in `state_S` ( with the resulting shape (*batch_size*, *length*, *length*) ) and the ones in B ( with the resulting shape (*batch_size*, *length*, *length*) ), then computes the mse loss between the similarity matrices:
+
+ .. math::
+
+ loss = mean((S_{1} \cdot S_{2}^T - T_{1} \cdot T_{2}^T)^2)
+
+ * It is a Variant of the NST loss in `Like What You Like: Knowledge Distill via Neuron Selectivity Transfer `_
+ * If the `inputs_mask` is given, masks the positions where ``input_mask==0``.
+
+ :param torch.tensor state_S: list of two tensors, each tensor is of the shape (*batch_size*, *length*, *hidden_size*)
+ :param torch.tensor state_T: list of two tensors, each tensor is of the shape (*batch_size*, *length*, *hidden_size*)
+ :param torch.tensor mask: tensor of the shape (*batch_size*, *length*)
+
+ Example in `intermediate_matches`::
+
+ intermediate_matches = [
+ {'layer_T':[0,0], 'layer_S':[0,0], 'feature':'hidden','loss': 'nst', 'weight' : 1},
+ ...]
+ '''
+ state_S_0 = state_S[0] # (batch_size , length, hidden_dim_S)
+ state_S_1 = state_S[1] # (batch_size , length, hidden_dim_S)
+ state_T_0 = state_T[0] # (batch_size , length, hidden_dim_T)
+ state_T_1 = state_T[1] # (batch_size , length, hidden_dim_T)
+ if mask is None:
+ gram_S = torch.bmm(state_S_0, state_S_1.transpose(1, 2)) / state_S_1.size(2) # (batch_size, length, length)
+ gram_T = torch.bmm(state_T_0, state_T_1.transpose(1, 2)) / state_T_1.size(2)
+ loss = F.mse_loss(gram_S, gram_T)
+ else:
+ mask = mask.to(state_S[0])
+ valid_count = torch.pow(mask.sum(dim=1), 2).sum()
+ gram_S = torch.bmm(state_S_0, state_S_1.transpose(1, 2)) / state_S_1.size(1) # (batch_size, length, length)
+ gram_T = torch.bmm(state_T_0, state_T_1.transpose(1, 2)) / state_T_1.size(1)
+ loss = (F.mse_loss(gram_S, gram_T, reduction='none') * mask.unsqueeze(-1) * mask.unsqueeze(
+ 1)).sum() / valid_count
+ return loss
+
+
+class KnowledgeDistillationLoss(AutoConfigurable):
+ def __init__(self, name) -> None:
+ super().__init__()
+ self.name = name
+ import sys
+ thismodule = sys.modules[__name__]
+ self._loss = getattr(thismodule, name)
+
+ def __call__(self, *args, **kwargs):
+ return self._loss(*args, **kwargs)
diff --git a/hanlp/components/distillation/schedulers.py b/hanlp/components/distillation/schedulers.py
new file mode 100644
index 000000000..44c725f1a
--- /dev/null
+++ b/hanlp/components/distillation/schedulers.py
@@ -0,0 +1,124 @@
+# Adopted from https://github.com/airaria/TextBrewer
+# Apache License Version 2.0
+from abc import ABC, abstractmethod
+
+import torch
+
+# x is between 0 and 1
+from hanlp_common.configurable import AutoConfigurable
+
+
+def linear_growth_weight_scheduler(x):
+ return x
+
+
+def linear_decay_weight_scheduler(x):
+ return 1 - x
+
+
+def constant_temperature_scheduler(logits_S, logits_T, base_temperature):
+ '''
+ Remember to detach logits_S
+ '''
+ return base_temperature
+
+
+def flsw_temperature_scheduler_builder(beta, gamma, eps=1e-4, *args):
+ '''
+ adapted from arXiv:1911.07471
+ '''
+
+ def flsw_temperature_scheduler(logits_S, logits_T, base_temperature):
+ v = logits_S.detach()
+ t = logits_T.detach()
+ with torch.no_grad():
+ v = v / (torch.norm(v, dim=-1, keepdim=True) + eps)
+ t = t / (torch.norm(t, dim=-1, keepdim=True) + eps)
+ w = torch.pow((1 - (v * t).sum(dim=-1)), gamma)
+ tau = base_temperature + (w.mean() - w) * beta
+ return tau
+
+ return flsw_temperature_scheduler
+
+
+def cwsm_temperature_scheduler_builder(beta, *args):
+ '''
+ adapted from arXiv:1911.07471
+ '''
+
+ def cwsm_temperature_scheduler(logits_S, logits_T, base_temperature):
+ v = logits_S.detach()
+ with torch.no_grad():
+ v = torch.softmax(v, dim=-1)
+ v_max = v.max(dim=-1)[0]
+ w = 1 / (v_max + 1e-3)
+ tau = base_temperature + (w.mean() - w) * beta
+ return tau
+
+ return cwsm_temperature_scheduler
+
+
+class LinearTeacherAnnealingScheduler(object):
+ def __init__(self, num_training_steps: int) -> None:
+ super().__init__()
+ self._num_training_steps = num_training_steps
+ self._current_training_steps = 0
+
+ def step(self):
+ self._current_training_steps += 1
+
+ def __float__(self):
+ return self._current_training_steps / self._num_training_steps
+
+
+class TemperatureScheduler(ABC, AutoConfigurable):
+
+ def __init__(self, base_temperature) -> None:
+ super().__init__()
+ self.base_temperature = base_temperature
+
+ def __call__(self, logits_S, logits_T):
+ return self.forward(logits_S, logits_T)
+
+ @abstractmethod
+ def forward(self, logits_S, logits_T):
+ raise NotImplementedError()
+
+ @staticmethod
+ def from_name(name):
+ classes = {
+ 'constant': ConstantScheduler,
+ 'flsw': FlswScheduler,
+ 'cwsm': CwsmScheduler,
+ }
+ assert name in classes, f'Unsupported temperature scheduler {name}. Expect one from {list(classes.keys())}.'
+ return classes[name]()
+
+
+class FunctionalScheduler(TemperatureScheduler):
+
+ def __init__(self, scheduler_func, base_temperature) -> None:
+ super().__init__(base_temperature)
+ self._scheduler_func = scheduler_func
+
+ def forward(self, logits_S, logits_T):
+ return self._scheduler_func(logits_S, logits_T, self.base_temperature)
+
+
+class ConstantScheduler(TemperatureScheduler):
+ def forward(self, logits_S, logits_T):
+ return self.base_temperature
+
+
+class FlswScheduler(FunctionalScheduler):
+ def __init__(self, beta=1, gamma=1, eps=1e-4, base_temperature=8):
+ super().__init__(flsw_temperature_scheduler_builder(beta, gamma, eps), base_temperature)
+ self.beta = beta
+ self.gamma = gamma
+ self.eps = eps
+
+
+class CwsmScheduler(FunctionalScheduler):
+ def __init__(self, beta=1, base_temperature=8):
+ super().__init__(cwsm_temperature_scheduler_builder(beta), base_temperature)
+ self.beta = beta
diff --git a/tests/demo/__init__.py b/hanlp/components/eos/__init__.py
similarity index 62%
rename from tests/demo/__init__.py
rename to hanlp/components/eos/__init__.py
index 94beb0dd9..a034fb194 100644
--- a/tests/demo/__init__.py
+++ b/hanlp/components/eos/__init__.py
@@ -1,3 +1,3 @@
# -*- coding:utf-8 -*-
# Author: hankcs
-# Date: 2019-12-30 18:01
\ No newline at end of file
+# Date: 2020-07-26 20:19
\ No newline at end of file
diff --git a/hanlp/components/eos/ngram.py b/hanlp/components/eos/ngram.py
new file mode 100644
index 000000000..aa8db8bc8
--- /dev/null
+++ b/hanlp/components/eos/ngram.py
@@ -0,0 +1,311 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-07-26 20:19
+import logging
+from collections import Counter
+from typing import Union, List, Callable
+
+import torch
+from torch import nn, optim
+from torch.nn import BCEWithLogitsLoss
+from torch.utils.data import DataLoader
+
+from hanlp.common.dataset import PadSequenceDataLoader
+from hanlp.common.torch_component import TorchComponent
+from hanlp.common.vocab import Vocab
+from hanlp.datasets.eos.eos import SentenceBoundaryDetectionDataset
+from hanlp.metrics.f1 import F1
+from hanlp.utils.time_util import CountdownTimer
+from hanlp_common.util import merge_locals_kwargs
+
+
+class NgramSentenceBoundaryDetectionModel(nn.Module):
+
+ def __init__(self,
+ char_vocab_size,
+ embedding_size=128,
+ rnn_type: str = 'LSTM',
+ rnn_size=256,
+ rnn_layers=1,
+ rnn_bidirectional=False,
+ dropout=0.2,
+ **kwargs
+ ):
+ super(NgramSentenceBoundaryDetectionModel, self).__init__()
+ self.embed = nn.Embedding(num_embeddings=char_vocab_size,
+ embedding_dim=embedding_size)
+ rnn_type = rnn_type.lower()
+ if rnn_type == 'lstm':
+ self.rnn = nn.LSTM(input_size=embedding_size,
+ hidden_size=rnn_size,
+ num_layers=rnn_layers,
+ dropout=self.dropout if rnn_layers > 1 else 0.0,
+ bidirectional=rnn_bidirectional,
+ batch_first=True)
+ elif rnn_type == 'gru':
+ self.rnn = nn.GRU(input_size=self.embdding_size,
+ hidden_size=rnn_size,
+ num_layers=rnn_layers,
+ dropout=self.dropout if rnn_layers > 1 else 0.0,
+ bidirectional=rnn_bidirectional,
+ batch_first=True)
+ else:
+ raise NotImplementedError(f"'{rnn_type}' has to be one of [LSTM, GRU]")
+ self.dropout = nn.Dropout(p=dropout) if dropout else None
+ self.dense = nn.Linear(in_features=rnn_size * (2 if rnn_bidirectional else 1),
+ out_features=1)
+
+ def forward(self, x: torch.Tensor):
+ output = self.embed(x)
+ self.rnn.flatten_parameters()
+ output, _ = self.rnn(output)
+ if self.dropout:
+ output = self.dropout(output[:, -1, :])
+ output = output.squeeze(1)
+ output = self.dense(output).squeeze(-1)
+ return output
+
+
+class NgramSentenceBoundaryDetector(TorchComponent):
+
+ def __init__(self, **kwargs) -> None:
+ """A sentence boundary detector using ngram as features and LSTM as encoder (:cite:`Schweter:Ahmed:2019`).
+ It predicts whether a punctuation marks an ``EOS``.
+
+ .. Note::
+ This component won't work on text without the punctuations defined in its config. It's always
+ recommended to understand how it works before using it. The predefined punctuations can be listed by the
+ following codes.
+
+ >>> print(eos.config.eos_chars)
+
+ Args:
+ **kwargs: Passed to config.
+ """
+ super().__init__(**kwargs)
+
+ def build_optimizer(self, **kwargs):
+ optimizer = optim.Adam(self.model.parameters(), lr=self.config.lr)
+ return optimizer
+
+ def build_criterion(self, **kwargs):
+ return BCEWithLogitsLoss()
+
+ def build_metric(self, **kwargs):
+ return F1()
+
+ def execute_training_loop(self,
+ trn: DataLoader,
+ dev: DataLoader,
+ epochs,
+ criterion,
+ optimizer,
+ metric,
+ save_dir,
+ logger: logging.Logger,
+ devices,
+ **kwargs):
+ best_epoch, best_metric = 0, -1
+ timer = CountdownTimer(epochs)
+ ratio_width = len(f'{len(trn)}/{len(trn)}')
+ for epoch in range(1, epochs + 1):
+ logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
+ self.fit_dataloader(trn, criterion, optimizer, metric, logger)
+ if dev:
+ self.evaluate_dataloader(dev, criterion, metric, logger, ratio_width=ratio_width)
+ report = f'{timer.elapsed_human}/{timer.total_time_human}'
+ dev_score = metric.score
+ if dev_score > best_metric:
+ self.save_weights(save_dir)
+ best_metric = dev_score
+ report += ' [red]saved[/red]'
+ timer.log(report, ratio_percentage=False, newline=True, ratio=False)
+
+ def fit_dataloader(self,
+ trn: DataLoader,
+ criterion,
+ optimizer,
+ metric,
+ logger: logging.Logger,
+ **kwargs):
+ self.model.train()
+ timer = CountdownTimer(len(trn))
+ total_loss = 0
+ self.reset_metrics(metric)
+ for batch in trn:
+ optimizer.zero_grad()
+ prediction = self.feed_batch(batch)
+ loss = self.compute_loss(prediction, batch, criterion)
+ self.update_metrics(batch, prediction, metric)
+ loss.backward()
+ if self.config.grad_norm:
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_norm)
+ optimizer.step()
+ total_loss += loss.item()
+ timer.log(self.report_metrics(total_loss / (timer.current + 1), metric), ratio_percentage=None,
+ logger=logger)
+ del loss
+ return total_loss / timer.total
+
+ def compute_loss(self, prediction, batch, criterion):
+ loss = criterion(prediction, batch['label_id'])
+ return loss
+
+ # noinspection PyMethodOverriding
+ def evaluate_dataloader(self,
+ data: DataLoader,
+ criterion: Callable,
+ metric,
+ logger,
+ ratio_width=None,
+ output=False,
+ **kwargs):
+ self.model.eval()
+ self.reset_metrics(metric)
+ timer = CountdownTimer(len(data))
+ total_loss = 0
+ for batch in data:
+ prediction = self.feed_batch(batch)
+ self.update_metrics(batch, prediction, metric)
+ loss = self.compute_loss(prediction, batch, criterion)
+ total_loss += loss.item()
+ timer.log(self.report_metrics(total_loss / (timer.current + 1), metric), ratio_percentage=None,
+ logger=logger,
+ ratio_width=ratio_width)
+ del loss
+ return total_loss / timer.total, metric
+
+ def build_model(self, training=True, **kwargs) -> torch.nn.Module:
+ model = NgramSentenceBoundaryDetectionModel(**self.config, char_vocab_size=len(self.vocabs.char))
+ return model
+
+ def build_dataloader(self, data, batch_size, shuffle, device, logger: logging.Logger, **kwargs) -> DataLoader:
+ dataset = SentenceBoundaryDetectionDataset(data, **self.config, transform=[self.vocabs])
+ if isinstance(data, str):
+ dataset.purge_cache()
+ if not self.vocabs:
+ self.build_vocabs(dataset, logger)
+ return PadSequenceDataLoader(dataset, batch_size=batch_size, shuffle=shuffle, device=device,
+ pad={'label_id': .0})
+
+ def predict(self, data: Union[str, List[str]], batch_size: int = None, strip=True, **kwargs):
+ """Sentence split.
+
+ Args:
+ data: A paragraph or a list of paragraphs.
+ batch_size: Number of samples per batch.
+ strip: Strip out blank characters at the head and tail of each sentence.
+
+ Returns:
+ A list of sentences or a list of lists of sentences.
+ """
+ if not data:
+ return []
+ self.model.eval()
+ flat = isinstance(data, str)
+ if flat:
+ data = [data]
+ samples = []
+ eos_chars = self.config.eos_chars
+ window_size = self.config.window_size
+ for doc_id_, corpus in enumerate(data):
+ corpus = list(corpus)
+ for i, c in enumerate(corpus):
+ if c in eos_chars:
+ window = corpus[max(0, i - window_size): i + window_size + 1]
+ samples.append({'char': window, 'offset_': i, 'doc_id_': doc_id_})
+ eos_prediction = [[] for _ in range(len(data))]
+ if samples:
+ dataloader = self.build_dataloader(samples, **self.config, device=self.device, shuffle=False, logger=None)
+ for batch in dataloader:
+ logits = self.feed_batch(batch)
+ prediction = (logits > 0).tolist()
+ for doc_id_, offset_, eos in zip(batch['doc_id_'], batch['offset_'], prediction):
+ if eos:
+ eos_prediction[doc_id_].append(offset_)
+ outputs = []
+ for corpus, output in zip(data, eos_prediction):
+ sents_per_document = []
+ prev_offset = 0
+ for offset in output:
+ offset += 1
+ sents_per_document.append(corpus[prev_offset:offset])
+ prev_offset = offset
+ if prev_offset != len(corpus):
+ sents_per_document.append(corpus[prev_offset:])
+ if strip:
+ sents_per_document = [x.strip() for x in sents_per_document]
+ sents_per_document = [x for x in sents_per_document if x]
+ outputs.append(sents_per_document)
+ if flat:
+ outputs = outputs[0]
+ return outputs
+
+ # noinspection PyMethodOverriding
+ def fit(self,
+ trn_data,
+ dev_data,
+ save_dir,
+ epochs=5,
+ append_after_sentence=None,
+ eos_chars=None,
+ eos_char_min_freq=200,
+ eos_char_is_punct=True,
+ char_min_freq=None,
+ window_size=5,
+ batch_size=32,
+ lr=0.001,
+ grad_norm=None,
+ loss_reduction='sum',
+ embedding_size=128,
+ rnn_type: str = 'LSTM',
+ rnn_size=256,
+ rnn_layers=1,
+ rnn_bidirectional=False,
+ dropout=0.2,
+ devices=None,
+ logger=None,
+ seed=None,
+ **kwargs
+ ):
+ return super().fit(**merge_locals_kwargs(locals(), kwargs))
+
+ def build_vocabs(self, dataset: SentenceBoundaryDetectionDataset, logger, **kwargs):
+ char_min_freq = self.config.char_min_freq
+ if char_min_freq:
+ has_cache = dataset.cache is not None
+ char_counter = Counter()
+ for each in dataset:
+ for c in each['char']:
+ char_counter[c] += 1
+ self.vocabs.char = vocab = Vocab()
+ for c, f in char_counter.items():
+ if f >= char_min_freq:
+ vocab.add(c)
+ if has_cache:
+ dataset.purge_cache()
+ for each in dataset:
+ pass
+ else:
+ self.vocabs.char = Vocab()
+ for each in dataset:
+ pass
+ self.config.eos_chars = dataset.eos_chars
+ self.vocabs.lock()
+ self.vocabs.summary(logger)
+
+ def reset_metrics(self, metrics):
+ metrics.reset()
+
+ def report_metrics(self, loss, metrics):
+ return f'loss: {loss:.4f} {metrics}'
+
+ def update_metrics(self, batch: dict, prediction: torch.FloatTensor, metrics):
+ def nonzero_offsets(y):
+ return set(y.nonzero().squeeze(-1).tolist())
+
+ metrics(nonzero_offsets(prediction > 0), nonzero_offsets(batch['label_id']))
+
+ def feed_batch(self, batch):
+ prediction = self.model(batch['char_id'])
+ return prediction
diff --git a/hanlp/components/lambda_wrapper.py b/hanlp/components/lambda_wrapper.py
index 9863a2f6a..3049c5552 100644
--- a/hanlp/components/lambda_wrapper.py
+++ b/hanlp/components/lambda_wrapper.py
@@ -4,14 +4,16 @@
from typing import Callable, Any
from hanlp.common.component import Component
-from hanlp.utils.reflection import class_path_of, object_from_class_path, str_to_type
+from hanlp_common.reflection import classpath_of, object_from_classpath, str_to_type
class LambdaComponent(Component):
def __init__(self, function: Callable) -> None:
super().__init__()
+ self.config = {}
self.function = function
- self.meta['function'] = class_path_of(function)
+ self.config['function'] = classpath_of(function)
+ self.config['classpath'] = classpath_of(self)
def predict(self, data: Any, **kwargs):
unpack = kwargs.pop('_hanlp_unpack', None)
@@ -20,8 +22,8 @@ def predict(self, data: Any, **kwargs):
return self.function(data, **kwargs)
@staticmethod
- def from_meta(meta: dict, **kwargs):
- cls = str_to_type(meta['class_path'])
+ def from_config(meta: dict, **kwargs):
+ cls = str_to_type(meta['classpath'])
function = meta['function']
- function = object_from_class_path(function)
+ function = object_from_classpath(function)
return cls(function)
diff --git a/hanlp/components/lemmatizer.py b/hanlp/components/lemmatizer.py
new file mode 100644
index 000000000..b75bad46f
--- /dev/null
+++ b/hanlp/components/lemmatizer.py
@@ -0,0 +1,42 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-08 18:35
+from typing import List
+
+from hanlp.common.transform import TransformList
+from hanlp.components.parsers.ud.lemma_edit import gen_lemma_rule, apply_lemma_rule
+from hanlp.components.taggers.transformers.transformer_tagger import TransformerTagger
+
+
+def add_lemma_rules_to_sample(sample: dict):
+ if 'tag' in sample and 'lemma' not in sample:
+ lemma_rules = [gen_lemma_rule(word, lemma)
+ if lemma != "_" else "_"
+ for word, lemma in zip(sample['token'], sample['tag'])]
+ sample['lemma'] = sample['tag'] = lemma_rules
+ return sample
+
+
+class TransformerLemmatizer(TransformerTagger):
+
+ def __init__(self, **kwargs) -> None:
+ """A transition based lemmatizer using transformer as encoder.
+
+ Args:
+ **kwargs: Predefined config.
+ """
+ super().__init__(**kwargs)
+
+ def build_dataset(self, data, transform=None, **kwargs):
+ if not isinstance(transform, list):
+ transform = TransformList()
+ transform.append(add_lemma_rules_to_sample)
+ return super().build_dataset(data, transform, **kwargs)
+
+ def prediction_to_human(self, pred, vocab: List[str], batch, token=None):
+ if token is None:
+ token = batch['token']
+ rules = super().prediction_to_human(pred, vocab, batch)
+ for token_per_sent, rule_per_sent in zip(token, rules):
+ lemma_per_sent = [apply_lemma_rule(t, r) for t, r in zip(token_per_sent, rule_per_sent)]
+ yield lemma_per_sent
diff --git a/tests/demo/zh/__init__.py b/hanlp/components/mtl/__init__.py
similarity index 62%
rename from tests/demo/zh/__init__.py
rename to hanlp/components/mtl/__init__.py
index f9a39bdee..9264a7dc1 100644
--- a/tests/demo/zh/__init__.py
+++ b/hanlp/components/mtl/__init__.py
@@ -1,3 +1,3 @@
# -*- coding:utf-8 -*-
# Author: hankcs
-# Date: 2020-01-01 17:53
\ No newline at end of file
+# Date: 2020-06-20 19:54
\ No newline at end of file
diff --git a/hanlp/components/mtl/multi_task_learning.py b/hanlp/components/mtl/multi_task_learning.py
new file mode 100644
index 000000000..9408d8128
--- /dev/null
+++ b/hanlp/components/mtl/multi_task_learning.py
@@ -0,0 +1,751 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-06-20 19:55
+import functools
+import itertools
+import logging
+from collections import defaultdict
+from copy import copy
+from typing import Union, List, Callable, Dict, Optional, Any, Iterable, Tuple, Set
+from itertools import chain
+import numpy as np
+import torch
+from alnlp.modules import util
+from toposort import toposort
+from torch.utils.data import DataLoader
+
+from hanlp_common.constant import IDX, BOS, EOS
+from hanlp.common.dataset import PadSequenceDataLoader, PrefetchDataLoader
+from hanlp_common.document import Document
+from hanlp.common.structure import History
+from hanlp.common.torch_component import TorchComponent
+from hanlp.common.transform import FieldLength, TransformList
+from hanlp.components.mtl.tasks import Task
+from hanlp.layers.embeddings.contextual_word_embedding import ContextualWordEmbedding, ContextualWordEmbeddingModule
+from hanlp.layers.embeddings.embedding import Embedding
+from hanlp.layers.transformers.pt_imports import optimization
+from hanlp.layers.transformers.utils import pick_tensor_for_each_token
+from hanlp.metrics.metric import Metric
+from hanlp.metrics.mtl import MetricDict
+from hanlp.transform.transformer_tokenizer import TransformerSequenceTokenizer
+from hanlp_common.visualization import markdown_table
+from hanlp.utils.time_util import CountdownTimer
+from hanlp.utils.torch_util import clip_grad_norm
+from hanlp_common.util import merge_locals_kwargs, topological_sort, reorder, prefix_match
+
+
+class MultiTaskModel(torch.nn.Module):
+
+ def __init__(self,
+ encoder: torch.nn.Module,
+ scalar_mixes: torch.nn.ModuleDict,
+ decoders: torch.nn.ModuleDict,
+ use_raw_hidden_states: dict) -> None:
+ super().__init__()
+ self.use_raw_hidden_states = use_raw_hidden_states
+ self.encoder: ContextualWordEmbeddingModule = encoder
+ self.scalar_mixes = scalar_mixes
+ self.decoders = decoders
+
+
+class MultiTaskDataLoader(DataLoader):
+
+ def __init__(self, training=True, tau: float = 0.8, **dataloaders) -> None:
+ # noinspection PyTypeChecker
+ super().__init__(None)
+ self.tau = tau
+ self.training = training
+ self.dataloaders: Dict[str, DataLoader] = dataloaders if dataloaders else {}
+ # self.iterators = dict((k, iter(v)) for k, v in dataloaders.items())
+
+ def __len__(self) -> int:
+ if self.dataloaders:
+ return sum(len(x) for x in self.dataloaders.values())
+ return 0
+
+ def __iter__(self):
+ if self.training:
+ sampling_weights, total_size = self.sampling_weights
+ task_names = list(self.dataloaders.keys())
+ iterators = dict((k, itertools.cycle(v)) for k, v in self.dataloaders.items())
+ for i in range(total_size):
+ task_name = np.random.choice(task_names, p=sampling_weights)
+ yield task_name, next(iterators[task_name])
+ else:
+ for task_name, dataloader in self.dataloaders.items():
+ for batch in dataloader:
+ yield task_name, batch
+
+ @property
+ def sampling_weights(self):
+ sampling_weights = self.sizes
+ total_size = sum(sampling_weights)
+ Z = sum(pow(v, self.tau) for v in sampling_weights)
+ sampling_weights = [pow(v, self.tau) / Z for v in sampling_weights]
+ return sampling_weights, total_size
+
+ @property
+ def sizes(self):
+ return [len(v) for v in self.dataloaders.values()]
+
+
+class MultiTaskLearning(TorchComponent):
+
+ def __init__(self, **kwargs) -> None:
+ """ A multi-task learning (MTL) framework. It shares the same encoder across multiple decoders. These decoders
+ can have dependencies on each other which will be properly handled during decoding. To integrate a component
+ into this MTL framework, a component needs to implement the :class:`~hanlp.components.mtl.tasks.Task` interface.
+
+ This framework mostly follows the architecture of :cite:`clark-etal-2019-bam`, with additional scalar mix
+ tricks (:cite:`kondratyuk-straka-2019-75`) allowing each task to attend to any subset of layers. We also
+ experimented with knowledge distillation on single tasks, the performance gain was nonsignificant on a large
+ dataset. In the near future, we have no plan to invest more efforts in distillation, since most datasets HanLP
+ uses are relatively large, and our hardware is relatively powerful.
+
+ Args:
+ **kwargs: Arguments passed to config.
+ """
+ super().__init__(**kwargs)
+ self.model: Optional[MultiTaskModel] = None
+ self.tasks: Dict[str, Task] = None
+ self.vocabs = None
+
+ def build_dataloader(self,
+ data,
+ batch_size,
+ shuffle=False,
+ device=None,
+ logger: logging.Logger = None,
+ gradient_accumulation=1,
+ tau: float = 0.8,
+ prune=None,
+ prefetch=None,
+ tasks_need_custom_eval=None,
+ debug=False,
+ **kwargs) -> DataLoader:
+ dataloader = MultiTaskDataLoader(training=shuffle, tau=tau)
+ for i, (task_name, task) in enumerate(self.tasks.items()):
+ encoder_transform, transform = self.build_transform(task)
+ training = None
+ if data == 'trn':
+ if debug:
+ _data = task.dev
+ else:
+ _data = task.trn
+ training = True
+ elif data == 'dev':
+ _data = task.dev
+ training = False
+ elif data == 'tst':
+ _data = task.tst
+ training = False
+ else:
+ _data = data
+ if isinstance(data, str):
+ logger.info(f'[yellow]{i + 1} / {len(self.tasks)}[/yellow] Building [blue]{data}[/blue] dataset for '
+ f'[cyan]{task_name}[/cyan] ...')
+ # Adjust Tokenizer according to task config
+ config = copy(task.config)
+ config.pop('transform', None)
+ task_dataloader: DataLoader = task.build_dataloader(_data, transform, training, device, logger,
+ tokenizer=encoder_transform.tokenizer,
+ gradient_accumulation=gradient_accumulation,
+ cache=isinstance(data, str), **config)
+ # if prune:
+ # # noinspection PyTypeChecker
+ # task_dataset: TransformDataset = task_dataloader.dataset
+ # size_before = len(task_dataset)
+ # task_dataset.prune(prune)
+ # size_after = len(task_dataset)
+ # num_pruned = size_before - size_after
+ # logger.info(f'Pruned [yellow]{num_pruned} ({num_pruned / size_before:.1%})[/yellow] '
+ # f'samples out of {size_before}.')
+ dataloader.dataloaders[task_name] = task_dataloader
+ if data == 'trn':
+ sampling_weights, total_size = dataloader.sampling_weights
+ headings = ['task', '#batches', '%batches', '#scaled', '%scaled', '#epoch']
+ matrix = []
+ min_epochs = []
+ for (task_name, dataset), weight in zip(dataloader.dataloaders.items(), sampling_weights):
+ epochs = len(dataset) / weight / total_size
+ matrix.append(
+ [f'{task_name}', len(dataset), f'{len(dataset) / total_size:.2%}', int(total_size * weight),
+ f'{weight:.2%}', f'{epochs:.2f}'])
+ min_epochs.append(epochs)
+ longest = int(torch.argmax(torch.tensor(min_epochs)))
+ table = markdown_table(headings, matrix)
+ rows = table.splitlines()
+ cells = rows[longest + 2].split('|')
+ cells[-2] = cells[-2].replace(f'{min_epochs[longest]:.2f}',
+ f'[bold][red]{min_epochs[longest]:.2f}[/red][/bold]')
+ rows[longest + 2] = '|'.join(cells)
+ logger.info(f'[bold][yellow]{"Samples Distribution": ^{len(rows[0])}}[/yellow][/bold]')
+ logger.info('\n'.join(rows))
+ if prefetch and isinstance(data, str) and (data == 'trn' or not tasks_need_custom_eval):
+ dataloader = PrefetchDataLoader(dataloader, prefetch=prefetch)
+
+ return dataloader
+
+ def build_transform(self, task: Task) -> Tuple[TransformerSequenceTokenizer, TransformList]:
+ encoder: ContextualWordEmbedding = self.config.encoder
+ encoder_transform: TransformerSequenceTokenizer = task.build_tokenizer(encoder.transform())
+ length_transform = FieldLength('token', 'token_length')
+ transform = TransformList(encoder_transform, length_transform)
+ extra_transform = self.config.get('transform', None)
+ if extra_transform:
+ transform.insert(0, extra_transform)
+ return encoder_transform, transform
+
+ def build_optimizer(self,
+ trn,
+ epochs,
+ adam_epsilon,
+ weight_decay,
+ warmup_steps,
+ lr,
+ encoder_lr,
+ **kwargs):
+ model = self.model_
+ encoder = model.encoder
+ num_training_steps = len(trn) * epochs // self.config.get('gradient_accumulation', 1)
+ encoder_parameters = list(encoder.parameters())
+ parameter_groups: List[Dict[str, Any]] = []
+
+ decoders = model.decoders
+ decoder_optimizers = dict()
+ for k, task in self.tasks.items():
+ decoder: torch.nn.Module = decoders[k]
+ decoder_parameters = list(decoder.parameters())
+ if task.separate_optimizer:
+ decoder_optimizers[k] = task.build_optimizer(decoder=decoder, **kwargs)
+ else:
+ task_lr = task.lr or lr
+ parameter_groups.append({"params": decoder_parameters, 'lr': task_lr})
+ parameter_groups.append({"params": encoder_parameters, 'lr': encoder_lr})
+ no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
+ no_decay_parameters = set()
+ for n, p in model.named_parameters():
+ if any(nd in n for nd in no_decay):
+ no_decay_parameters.add(p)
+ no_decay_by_lr = defaultdict(list)
+ for group in parameter_groups:
+ _lr = group['lr']
+ ps = group['params']
+ group['params'] = decay_parameters = []
+ group['weight_decay'] = weight_decay
+ for p in ps:
+ if p in no_decay_parameters:
+ no_decay_by_lr[_lr].append(p)
+ else:
+ decay_parameters.append(p)
+ for _lr, ps in no_decay_by_lr.items():
+ parameter_groups.append({"params": ps, 'lr': _lr, 'weight_decay': 0.0})
+ # noinspection PyTypeChecker
+ encoder_optimizer = optimization.AdamW(
+ parameter_groups,
+ lr=lr,
+ weight_decay=weight_decay,
+ eps=adam_epsilon,
+ )
+ encoder_scheduler = optimization.get_linear_schedule_with_warmup(encoder_optimizer,
+ num_training_steps * warmup_steps,
+ num_training_steps)
+ return encoder_optimizer, encoder_scheduler, decoder_optimizers
+
+ def build_criterion(self, **kwargs):
+ return dict((k, v.build_criterion(decoder=self.model_.decoders[k], **kwargs)) for k, v in self.tasks.items())
+
+ def build_metric(self, **kwargs):
+ metrics = MetricDict()
+ for key, task in self.tasks.items():
+ metric = task.build_metric(**kwargs)
+ assert metric, f'Please implement `build_metric` of {type(task)} to return a metric.'
+ metrics[key] = metric
+ return metrics
+
+ def execute_training_loop(self, trn: DataLoader, dev: DataLoader, epochs, criterion, optimizer, metric, save_dir,
+ logger: logging.Logger, devices, patience=0.5, **kwargs):
+ if isinstance(patience, float):
+ patience = int(patience * epochs)
+ best_epoch, best_metric = 0, -1
+ timer = CountdownTimer(epochs)
+ ratio_width = len(f'{len(trn)}/{len(trn)}')
+ epoch = 0
+ history = History()
+ for epoch in range(1, epochs + 1):
+ logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
+ self.fit_dataloader(trn, criterion, optimizer, metric, logger, history, ratio_width=ratio_width,
+ **self.config)
+ if dev:
+ self.evaluate_dataloader(dev, criterion, metric, logger, ratio_width=ratio_width, input='dev')
+ report = f'{timer.elapsed_human}/{timer.total_time_human}'
+ dev_score = metric.score
+ if dev_score > best_metric:
+ self.save_weights(save_dir)
+ best_metric = dev_score
+ best_epoch = epoch
+ report += ' [red]saved[/red]'
+ else:
+ report += f' ({epoch - best_epoch})'
+ if epoch - best_epoch >= patience:
+ report += ' early stop'
+ break
+ timer.log(report, ratio_percentage=False, newline=True, ratio=False)
+ for d in [trn, dev]:
+ if isinstance(d, PrefetchDataLoader):
+ d.close()
+ if best_epoch != epoch:
+ logger.info(f'Restoring best model saved [red]{epoch - best_epoch}[/red] epochs ago')
+ self.load_weights(save_dir)
+ return best_metric
+
+ # noinspection PyMethodOverriding
+ def fit_dataloader(self,
+ trn: DataLoader,
+ criterion,
+ optimizer,
+ metric,
+ logger: logging.Logger,
+ history: History,
+ ratio_width=None,
+ gradient_accumulation=1,
+ encoder_grad_norm=None,
+ decoder_grad_norm=None,
+ patience=0.5,
+ eval_trn=False,
+ **kwargs):
+ self.model.train()
+ encoder_optimizer, encoder_scheduler, decoder_optimizers = optimizer
+ timer = CountdownTimer(len(trn))
+ total_loss = 0
+ self.reset_metrics(metric)
+ model = self.model_
+ encoder_parameters = model.encoder.parameters()
+ decoder_parameters = model.decoders.parameters()
+ for idx, (task_name, batch) in enumerate(trn):
+ decoder_optimizer = decoder_optimizers.get(task_name, None)
+ output_dict, _ = self.feed_batch(batch, task_name)
+ loss = self.compute_loss(batch, output_dict[task_name]['output'], criterion[task_name],
+ self.tasks[task_name])
+ if gradient_accumulation and gradient_accumulation > 1:
+ loss /= gradient_accumulation
+ loss.backward()
+ total_loss += float(loss.item())
+ if history.step(gradient_accumulation):
+ if self.config.get('grad_norm', None):
+ clip_grad_norm(model, self.config.grad_norm)
+ if encoder_grad_norm:
+ torch.nn.utils.clip_grad_norm_(encoder_parameters, encoder_grad_norm)
+ if decoder_grad_norm:
+ torch.nn.utils.clip_grad_norm_(decoder_parameters, decoder_grad_norm)
+ encoder_optimizer.step()
+ encoder_optimizer.zero_grad()
+ encoder_scheduler.step()
+ if decoder_optimizer:
+ if isinstance(decoder_optimizer, tuple):
+ decoder_optimizer, decoder_scheduler = decoder_optimizer
+ else:
+ decoder_scheduler = None
+ decoder_optimizer.step()
+ decoder_optimizer.zero_grad()
+ if decoder_scheduler:
+ decoder_scheduler.step()
+ if eval_trn:
+ self.decode_output(output_dict, batch, task_name)
+ self.update_metrics(batch, output_dict, metric, task_name)
+ timer.log(self.report_metrics(total_loss / (timer.current + 1), metric if eval_trn else None),
+ ratio_percentage=None,
+ ratio_width=ratio_width,
+ logger=logger)
+ del loss
+ del output_dict
+ return total_loss / timer.total
+
+ def report_metrics(self, loss, metrics: MetricDict):
+ return f'loss: {loss:.4f} {metrics.cstr()}' if metrics else f'loss: {loss:.4f}'
+
+ # noinspection PyMethodOverriding
+ @torch.no_grad()
+ def evaluate_dataloader(self,
+ data: MultiTaskDataLoader,
+ criterion,
+ metric: MetricDict,
+ logger,
+ ratio_width=None,
+ input: str = None,
+ **kwargs):
+ self.model.eval()
+ self.reset_metrics(metric)
+ tasks_need_custom_eval = self.config.get('tasks_need_custom_eval', None)
+ tasks_need_custom_eval = tasks_need_custom_eval or {}
+ tasks_need_custom_eval = dict((k, None) for k in tasks_need_custom_eval)
+ for each in tasks_need_custom_eval:
+ tasks_need_custom_eval[each] = data.dataloaders.pop(each)
+ timer = CountdownTimer(len(data) + len(tasks_need_custom_eval))
+ total_loss = 0
+ for idx, (task_name, batch) in enumerate(data):
+ output_dict, _ = self.feed_batch(batch, task_name)
+ loss = self.compute_loss(batch, output_dict[task_name]['output'], criterion[task_name],
+ self.tasks[task_name])
+ total_loss += loss.item()
+ self.decode_output(output_dict, batch, task_name)
+ self.update_metrics(batch, output_dict, metric, task_name)
+ timer.log(self.report_metrics(total_loss / (timer.current + 1), metric), ratio_percentage=None,
+ logger=logger,
+ ratio_width=ratio_width)
+ del loss
+ del output_dict
+
+ for task_name, dataset in tasks_need_custom_eval.items():
+ task = self.tasks[task_name]
+ decoder = self.model_.decoders[task_name]
+ task.evaluate_dataloader(
+ dataset, task.build_criterion(decoder=decoder),
+ metric=metric[task_name],
+ input=task.dev if input == 'dev' else task.tst,
+ split=input,
+ decoder=decoder,
+ h=functools.partial(self._encode, task_name=task_name,
+ cls_is_bos=task.cls_is_bos, sep_is_eos=task.sep_is_eos)
+ )
+ data.dataloaders[task_name] = dataset
+ timer.log(self.report_metrics(total_loss / (timer.current + 1), metric), ratio_percentage=None,
+ logger=logger,
+ ratio_width=ratio_width)
+
+ return total_loss / timer.total, metric, data
+
+ def build_model(self, training=False, **kwargs) -> torch.nn.Module:
+ tasks = self.tasks
+ encoder: ContextualWordEmbedding = self.config.encoder
+ encoder_size = encoder.get_output_dim()
+ scalar_mixes = torch.nn.ModuleDict()
+ decoders = torch.nn.ModuleDict()
+ use_raw_hidden_states = dict()
+ for task_name, task in tasks.items():
+ decoder = task.build_model(encoder_size, training=training, **task.config)
+ assert decoder, f'Please implement `build_model` of {type(task)} to return a decoder.'
+ decoders[task_name] = decoder
+ if task.scalar_mix:
+ scalar_mix = task.scalar_mix.build()
+ scalar_mixes[task_name] = scalar_mix
+ # Activate scalar mix starting from 0-th layer
+ encoder.scalar_mix = 0
+ use_raw_hidden_states[task_name] = task.use_raw_hidden_states
+ encoder.ret_raw_hidden_states = any(use_raw_hidden_states.values())
+ return MultiTaskModel(encoder.module(training=training), scalar_mixes, decoders, use_raw_hidden_states)
+
+ def predict(self,
+ data: Union[str, List[str]],
+ batch_size: int = None,
+ tasks: Optional[Union[str, List[str]]] = None,
+ skip_tasks: Optional[Union[str, List[str]]] = None,
+ **kwargs) -> Document:
+ doc = Document()
+ if not data:
+ return doc
+
+ target_tasks = self.resolve_tasks(tasks, skip_tasks)
+ flatten_target_tasks = [self.tasks[t] for group in target_tasks for t in group]
+ cls_is_bos = any([x.cls_is_bos for x in flatten_target_tasks])
+ sep_is_eos = any([x.sep_is_eos for x in flatten_target_tasks])
+ # Now build the dataloaders and execute tasks
+ first_task_name: str = list(target_tasks[0])[0]
+ first_task: Task = self.tasks[first_task_name]
+ encoder_transform, transform = self.build_transform(first_task)
+ # Override the tokenizer config of the 1st task
+ encoder_transform.sep_is_eos = sep_is_eos
+ encoder_transform.cls_is_bos = cls_is_bos
+ average_subwords = self.model.encoder.average_subwords
+ flat = first_task.input_is_flat(data)
+ if flat:
+ data = [data]
+ device = self.device
+ samples = first_task.build_samples(data, cls_is_bos=cls_is_bos, sep_is_eos=sep_is_eos)
+ dataloader = first_task.build_dataloader(samples, transform=transform, device=device)
+ results = defaultdict(list)
+ order = []
+ for batch in dataloader:
+ order.extend(batch[IDX])
+ # Run the first task, let it make the initial batch for the successors
+ output_dict = self.predict_task(first_task, first_task_name, batch, results, run_transform=True,
+ cls_is_bos=cls_is_bos, sep_is_eos=sep_is_eos)
+ # Run each task group in order
+ for group_id, group in enumerate(target_tasks):
+ # We could parallelize this in the future
+ for task_name in group:
+ if task_name == first_task_name:
+ continue
+ output_dict = self.predict_task(self.tasks[task_name], task_name, batch, results, output_dict,
+ run_transform=True, cls_is_bos=cls_is_bos, sep_is_eos=sep_is_eos)
+ if group_id == 0:
+ # We are kind of hard coding here. If the first task is a tokenizer,
+ # we need to convert the hidden and mask to token level
+ if first_task_name.startswith('tok'):
+ spans = []
+ tokens = []
+ for span_per_sent, token_per_sent in zip(output_dict[first_task_name]['prediction'],
+ results[first_task_name][-len(batch[IDX]):]):
+ if cls_is_bos:
+ span_per_sent = [(-1, 0)] + span_per_sent
+ token_per_sent = [BOS] + token_per_sent
+ if sep_is_eos:
+ span_per_sent = span_per_sent + [(span_per_sent[-1][0] + 1, span_per_sent[-1][1] + 1)]
+ token_per_sent = token_per_sent + [EOS]
+ # The offsets start with 0 while [CLS] is zero
+ if average_subwords:
+ span_per_sent = [list(range(x[0] + 1, x[1] + 1)) for x in span_per_sent]
+ else:
+ span_per_sent = [x[0] + 1 for x in span_per_sent]
+ spans.append(span_per_sent)
+ tokens.append(token_per_sent)
+ spans = PadSequenceDataLoader.pad_data(spans, 0, torch.long, device=device)
+ output_dict['hidden'] = pick_tensor_for_each_token(output_dict['hidden'], spans,
+ average_subwords)
+ batch['token_token_span'] = spans
+ batch['token'] = tokens
+ # noinspection PyTypeChecker
+ batch['token_length'] = torch.tensor([len(x) for x in tokens], dtype=torch.long, device=device)
+ batch.pop('mask', None)
+ # Put results into doc in the order of tasks
+ for k in self.config.task_names:
+ v = results.get(k, None)
+ if v is None:
+ continue
+ doc[k] = reorder(v, order)
+ # Allow task to perform finalization on document
+ for group in target_tasks:
+ for task_name in group:
+ task = self.tasks[task_name]
+ task.finalize_document(doc, task_name)
+ # If no tok in doc, use raw input as tok
+ if not any(k.startswith('tok') for k in doc):
+ doc['tok'] = data
+ if flat:
+ for k, v in list(doc.items()):
+ doc[k] = v[0]
+ # If there is only one field, don't bother to wrap it
+ # if len(doc) == 1:
+ # return list(doc.values())[0]
+ return doc
+
+ def resolve_tasks(self, tasks, skip_tasks) -> List[Iterable[str]]:
+ # Now we decide which tasks to perform and their orders
+ tasks_in_topological_order = self._tasks_in_topological_order
+ task_topological_order = self._task_topological_order
+ computation_graph = self._computation_graph
+ target_tasks = self._resolve_task_name(tasks)
+ if not target_tasks:
+ target_tasks = tasks_in_topological_order
+ else:
+ target_topological_order = defaultdict(set)
+ for task_name in target_tasks:
+ for dependency in topological_sort(computation_graph, task_name):
+ target_topological_order[task_topological_order[dependency]].add(dependency)
+ target_tasks = [item[1] for item in sorted(target_topological_order.items())]
+ if skip_tasks:
+ skip_tasks = self._resolve_task_name(skip_tasks)
+ target_tasks = [x - skip_tasks for x in target_tasks]
+ target_tasks = [x for x in target_tasks if x]
+ assert target_tasks, f'No task to perform due to `tasks = {tasks}`.'
+ # Sort target tasks within the same group in a defined order
+ target_tasks = [sorted(x, key=lambda _x: self.config.task_names.index(_x)) for x in target_tasks]
+ return target_tasks
+
+ def predict_task(self, task: Task, output_key, batch, results, output_dict=None, run_transform=True,
+ cls_is_bos=True, sep_is_eos=True):
+ output_dict, batch = self.feed_batch(batch, output_key, output_dict, run_transform, cls_is_bos, sep_is_eos,
+ results)
+ self.decode_output(output_dict, batch, output_key)
+ results[output_key].extend(task.prediction_to_result(output_dict[output_key]['prediction'], batch))
+ return output_dict
+
+ def _resolve_task_name(self, dependencies):
+ resolved_dependencies = set()
+ if isinstance(dependencies, str):
+ if dependencies in self.tasks:
+ resolved_dependencies.add(dependencies)
+ elif dependencies.endswith('*'):
+ resolved_dependencies.update(x for x in self.tasks if x.startswith(dependencies[:-1]))
+ else:
+ prefix_matched = prefix_match(dependencies, self.config.task_names)
+ assert prefix_matched, f'No prefix matching for {dependencies}. ' \
+ f'Check your dependencies definition: {list(self.tasks.values())}'
+ resolved_dependencies.add(prefix_matched)
+ elif isinstance(dependencies, Iterable):
+ resolved_dependencies.update(set(chain.from_iterable(self._resolve_task_name(x) for x in dependencies)))
+ return resolved_dependencies
+
+ def fit(self,
+ encoder: Embedding,
+ tasks: Dict[str, Task],
+ save_dir,
+ epochs,
+ patience=0.5,
+ lr=1e-3,
+ encoder_lr=5e-5,
+ adam_epsilon=1e-8,
+ weight_decay=0.0,
+ warmup_steps=0.1,
+ gradient_accumulation=1,
+ grad_norm=5.0,
+ encoder_grad_norm=None,
+ decoder_grad_norm=None,
+ tau: float = 0.8,
+ transform=None,
+ # prune: Callable = None,
+ eval_trn=True,
+ prefetch=None,
+ tasks_need_custom_eval=None,
+ _device_placeholder=False,
+ devices=None,
+ logger=None,
+ seed=None,
+ **kwargs):
+ trn_data, dev_data, batch_size = 'trn', 'dev', None
+ task_names = list(tasks.keys())
+ return super().fit(**merge_locals_kwargs(locals(), kwargs, excludes=('self', 'kwargs', '__class__', 'tasks')),
+ **tasks)
+
+ # noinspection PyAttributeOutsideInit
+ def on_config_ready(self, **kwargs):
+ self.tasks = dict((key, task) for key, task in self.config.items() if isinstance(task, Task))
+ computation_graph = dict()
+ for task_name, task in self.tasks.items():
+ dependencies = task.dependencies
+ resolved_dependencies = self._resolve_task_name(dependencies)
+ computation_graph[task_name] = resolved_dependencies
+
+ # We can cache this order
+ tasks_in_topological_order = list(toposort(computation_graph))
+ task_topological_order = dict()
+ for i, group in enumerate(tasks_in_topological_order):
+ for task_name in group:
+ task_topological_order[task_name] = i
+ self._tasks_in_topological_order = tasks_in_topological_order
+ self._task_topological_order = task_topological_order
+ self._computation_graph = computation_graph
+
+ @staticmethod
+ def reset_metrics(metrics: Dict[str, Metric]):
+ for metric in metrics.values():
+ metric.reset()
+
+ def feed_batch(self,
+ batch: Dict[str, Any],
+ task_name,
+ output_dict=None,
+ run_transform=False,
+ cls_is_bos=False,
+ sep_is_eos=False,
+ results=None) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ h, output_dict = self._encode(batch, task_name, output_dict, cls_is_bos, sep_is_eos)
+ task = self.tasks[task_name]
+ if run_transform:
+ batch = task.transform_batch(batch, results=results, cls_is_bos=cls_is_bos, sep_is_eos=sep_is_eos)
+ batch['mask'] = mask = util.lengths_to_mask(batch['token_length'])
+ output_dict[task_name] = {
+ 'output': task.feed_batch(h,
+ batch=batch,
+ mask=mask,
+ decoder=self.model.decoders[task_name]),
+ 'mask': mask
+ }
+ return output_dict, batch
+
+ def _encode(self, batch, task_name, output_dict=None, cls_is_bos=False, sep_is_eos=False):
+ model = self.model
+ if output_dict:
+ hidden, raw_hidden = output_dict['hidden'], output_dict['raw_hidden']
+ else:
+ hidden = model.encoder(batch)
+ if isinstance(hidden, tuple):
+ hidden, raw_hidden = hidden
+ else:
+ raw_hidden = None
+ output_dict = {'hidden': hidden, 'raw_hidden': raw_hidden}
+ hidden_states = raw_hidden if model.use_raw_hidden_states[task_name] else hidden
+ if task_name in model.scalar_mixes:
+ scalar_mix = model.scalar_mixes[task_name]
+ h = scalar_mix(hidden_states)
+ else:
+ h = hidden_states
+ # If the task doesn't need cls while h has cls, remove cls
+ task = self.tasks[task_name]
+ if cls_is_bos and not task.cls_is_bos:
+ h = h[:, 1:, :]
+ if sep_is_eos and not task.sep_is_eos:
+ h = h[:, :-1, :]
+ return h, output_dict
+
+ def decode_output(self, output_dict, batch, task_name=None):
+ if not task_name:
+ for task_name, task in self.tasks.items():
+ output_per_task = output_dict.get(task_name, None)
+ if output_per_task is not None:
+ output_per_task['prediction'] = task.decode_output(
+ output_per_task['output'],
+ output_per_task['mask'],
+ batch, self.model.decoders[task_name])
+ else:
+ output_per_task = output_dict[task_name]
+ output_per_task['prediction'] = self.tasks[task_name].decode_output(
+ output_per_task['output'],
+ output_per_task['mask'],
+ batch,
+ self.model.decoders[task_name])
+
+ def update_metrics(self, batch: Dict[str, Any], output_dict: Dict[str, Any], metrics: MetricDict, task_name):
+ task = self.tasks[task_name]
+ output_per_task = output_dict.get(task_name, None)
+ if output_per_task:
+ output = output_per_task['output']
+ prediction = output_per_task['prediction']
+ metric = metrics.get(task_name, None)
+ task.update_metrics(batch, output, prediction, metric)
+
+ def compute_loss(self,
+ batch: Dict[str, Any],
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
+ criterion: Callable,
+ task: Task) -> torch.FloatTensor:
+ return task.compute_loss(batch, output, criterion)
+
+ def evaluate(self, save_dir=None, logger: logging.Logger = None, batch_size=None, output=False, **kwargs):
+ rets = super().evaluate('tst', save_dir, logger, batch_size, output, **kwargs)
+ tst = rets[-1]
+ if isinstance(tst, PrefetchDataLoader):
+ tst.close()
+ return rets
+
+ def save_vocabs(self, save_dir, filename='vocabs.json'):
+ for task_name, task in self.tasks.items():
+ task.save_vocabs(save_dir, f'{task_name}_{filename}')
+
+ def load_vocabs(self, save_dir, filename='vocabs.json'):
+ for task_name, task in self.tasks.items():
+ task.load_vocabs(save_dir, f'{task_name}_{filename}')
+
+ def parallelize(self, devices: List[Union[int, torch.device]]):
+ raise NotImplementedError('Parallelization is not implemented yet.')
+
+ def __call__(self, data, batch_size=None, **kwargs) -> Document:
+ return super().__call__(data, batch_size, **kwargs)
+
+ def __getitem__(self, task_name: str) -> Task:
+ return self.tasks[task_name]
+
+ def __delitem__(self, task_name: str):
+ del self.tasks[task_name]
+ del self.model.decoders[task_name]
+ del self._computation_graph[task_name]
+ self._task_topological_order.pop(task_name)
+ for group in self._tasks_in_topological_order:
+ group: set = group
+ group.discard(task_name)
+
+ def __repr__(self):
+ return repr(self.config)
+
+ def items(self):
+ yield from self.tasks.items()
diff --git a/hanlp/components/mtl/tasks/__init__.py b/hanlp/components/mtl/tasks/__init__.py
new file mode 100644
index 000000000..7afaff0cb
--- /dev/null
+++ b/hanlp/components/mtl/tasks/__init__.py
@@ -0,0 +1,289 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-08-02 16:51
+import logging
+import os
+import warnings
+from abc import ABC, abstractmethod
+from copy import copy
+from typing import Callable, Dict, Any, Union, Iterable, List
+
+import torch
+from torch.utils.data import DataLoader
+
+from hanlp_common.constant import BOS, EOS
+from hanlp.common.dataset import SamplerBuilder, SortingSamplerBuilder, TransformableDataset, KMeansSamplerBuilder
+from hanlp_common.document import Document
+from hanlp.common.structure import ConfigTracker
+from hanlp.common.torch_component import TorchComponent
+from hanlp.layers.scalar_mix import ScalarMixWithDropoutBuilder
+from hanlp.metrics.metric import Metric
+from hanlp.metrics.mtl import MetricDict
+from hanlp.transform.transformer_tokenizer import TransformerSequenceTokenizer
+from hanlp.utils.time_util import CountdownTimer
+
+
+class Task(ConfigTracker, TorchComponent, ABC):
+ # noinspection PyMissingConstructor
+ def __init__(self,
+ trn: str = None,
+ dev: str = None,
+ tst: str = None,
+ sampler_builder: SamplerBuilder = None,
+ dependencies: str = None,
+ scalar_mix: ScalarMixWithDropoutBuilder = None,
+ use_raw_hidden_states=False,
+ lr=None,
+ separate_optimizer=False,
+ cls_is_bos=False,
+ sep_is_eos=False,
+ **kwargs) -> None:
+ """
+ A task in the multi-task learning framework
+
+ Args:
+ trn: Path to training set.
+ dev: Path to dev set.
+ tst: Path to test set.
+ sampler_builder: A builder which builds a sampler.
+ dependencies: Its dependencies on other tasks.
+ scalar_mix: A builder which builds a `ScalarMixWithDropout` object.
+ use_raw_hidden_states: Whether to use raw hidden states from transformer without any pooling.
+ lr: Learning rate for this task.
+ separate_optimizer: Use customized separate optimizer for this task.
+ cls_is_bos: ``True`` to treat the first token as ``BOS``.
+ sep_is_eos: ``True`` to treat the last token as ``EOS``.
+ **kwargs: Not used.
+ """
+ ConfigTracker.__init__(self, locals())
+ for f, n in zip([trn, dev, tst], ['trn', 'dev', 'tst']):
+ if f and os.path.isfile(f): # anonymize local file names
+ self.config.pop(n)
+ self.separate_optimizer = separate_optimizer
+ self.lr = lr
+ self.use_raw_hidden_states = use_raw_hidden_states
+ if sampler_builder is None:
+ sampler_builder = SortingSamplerBuilder(batch_size=32)
+ self.sampler_builder: Union[SortingSamplerBuilder, KMeansSamplerBuilder] = sampler_builder
+ self.dependencies = dependencies
+ self.tst = tst
+ self.dev = dev
+ self.trn = trn
+ self.scalar_mix = scalar_mix
+ self.cls_is_bos = cls_is_bos
+ self.sep_is_eos = sep_is_eos
+
+ @abstractmethod
+ def build_dataloader(self,
+ data,
+ transform: Callable = None,
+ training=False,
+ device=None,
+ logger: logging.Logger = None,
+ cache=False,
+ gradient_accumulation=1,
+ **kwargs) -> DataLoader:
+ """
+ Build a dataloader for training or evaluation.
+
+ Args:
+ data: Either a path or a list of samples.
+ transform: The transform from MTL, which is usually [TransformerSequenceTokenizer, FieldLength('token')]
+ training: Whether this method is called on training set.
+ device: The device dataloader is intended to work with.
+ logger: Logger for printing message indicating progress.
+ cache: Whether the dataloader should be cached.
+ gradient_accumulation: Gradient accumulation to be passed to sampler builder.
+ **kwargs: Additional experimental arguments.
+ """
+ pass
+
+ def build_optimizer(self, decoder: torch.nn.Module, **kwargs):
+ pass
+
+ def build_batch_wise_scheduler(self, decoder: torch.nn.Module, **kwargs):
+ pass
+
+ @abstractmethod
+ def compute_loss(self,
+ batch: Dict[str, Any],
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
+ criterion,
+ ) -> Union[torch.FloatTensor, Dict[str, torch.FloatTensor]]:
+ pass
+
+ @abstractmethod
+ def decode_output(self,
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
+ mask: torch.BoolTensor,
+ batch: Dict[str, Any], decoder: torch.nn.Module, **kwargs) -> Union[Dict[str, Any], Any]:
+ pass
+
+ @abstractmethod
+ def update_metrics(self,
+ batch: Dict[str, Any],
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
+ prediction: Dict[str, Any],
+ metric: Union[MetricDict, Metric]):
+ pass
+
+ # noinspection PyMethodOverriding
+ @abstractmethod
+ def build_model(self, encoder_size, training=True, **kwargs) -> torch.nn.Module:
+ pass
+
+ @abstractmethod
+ def build_metric(self, **kwargs):
+ pass
+
+ def fit_dataloader(self, trn: DataLoader, criterion, optimizer, metric, logger: logging.Logger, **kwargs):
+ pass
+
+ def evaluate_dataloader(self, data: DataLoader, criterion: Callable, output=False, **kwargs):
+ pass
+
+ def execute_training_loop(self, trn: DataLoader, dev: DataLoader, epochs, criterion, optimizer, metric, save_dir,
+ logger: logging.Logger, devices, **kwargs):
+ pass
+
+ # noinspection PyMethodMayBeStatic
+ def compute_lens(self, data: Union[List[Dict[str, Any]], str], dataset: TransformableDataset,
+ input_ids='token_input_ids', length_field='token'):
+ """
+
+ Args:
+ data: Samples to be measured or path to dataset during training time.
+ dataset: During training time, use this dataset to measure the length of each sample inside.
+ input_ids: Field name corresponds to input ids.
+ length_field: Fall back to this field during prediction as input_ids may not be generated yet.
+
+ Returns:
+
+ Length list of this samples
+
+ """
+ if isinstance(data, str):
+ if not dataset.cache:
+ warnings.warn(f'Caching for the dataset is not enabled, '
+ f'try `dataset.purge_cache()` if possible. The dataset is {dataset}.')
+ timer = CountdownTimer(len(dataset))
+ for each in dataset:
+ timer.log('Preprocessing and caching samples [blink][yellow]...[/yellow][/blink]')
+ timer.erase()
+ return [len(x[input_ids]) for x in dataset]
+ return [len(x[length_field]) for x in data]
+
+ def feed_batch(self,
+ h: torch.FloatTensor,
+ batch: Dict[str, torch.Tensor],
+ mask: torch.BoolTensor,
+ decoder: torch.nn.Module):
+ return decoder(h, batch=batch, mask=mask)
+
+ def input_is_flat(self, data) -> bool:
+ """
+ Check whether the data is flat (meaning that it's only a single sample, not even batched).
+
+ Returns:
+ bool: ``True`` to indicate the input data is flat.
+ """
+ raise NotImplementedError(
+ '`input_is_flat()` needs to be implemented for the task component to accept raw input from user.'
+ )
+
+ @abstractmethod
+ def prediction_to_result(self, prediction: Dict[str, Any], batch: Dict[str, Any]) -> List:
+ raise NotImplementedError()
+
+ # noinspection PyMethodMayBeStatic
+ def transform_batch(self,
+ batch: Dict[str, Any],
+ # inputs: List[List[str]],
+ results: Dict[str, Any] = None,
+ cls_is_bos=False,
+ sep_is_eos=False) -> Dict[str, Any]:
+ """
+ Let the task transform the batch before feeding the batch into its decoder. The default behavior is to
+ adjust the head and tail of tokens, according to ``cls_is_bos``, ``sep_is_eos`` passed in and the two
+ settings of the task itself.
+
+ Args:
+ batch: A batch of samples.
+ results: Predicted results from other tasks which might be useful for this task to utilize. Say a dep task
+ uses both token and pos as features, then it will need both tok and pos results to make a batch.
+ cls_is_bos: First token in this batch is BOS.
+ sep_is_eos: Last token in this batch is EOS.
+
+ Returns:
+ A batch.
+
+ """
+ if cls_is_bos != self.cls_is_bos or sep_is_eos != self.sep_is_eos:
+ batch = copy(batch)
+ tokens = self._adjust_token(batch, cls_is_bos, sep_is_eos, 'token')
+ delta = len(tokens[0]) - len(batch['token'][0])
+ batch['token_length'] = batch['token_length'] + delta
+ batch['token'] = tokens
+ if 'token_' in batch:
+ if isinstance(batch['token_'][0], list):
+ batch['token_'] = self._adjust_token(batch, cls_is_bos, sep_is_eos, 'token_')
+ else:
+ batch['token_'] = tokens
+ return batch
+
+ def _adjust_token(self, batch, cls_is_bos, sep_is_eos, token_key):
+ tokens = []
+ for sent in batch[token_key]:
+ if cls_is_bos:
+ if not self.cls_is_bos:
+ sent = sent[1:]
+ elif self.cls_is_bos:
+ sent = [BOS] + sent
+ if sep_is_eos:
+ if not self.sep_is_eos:
+ sent = sent[:-1]
+ elif self.sep_is_eos:
+ sent = sent + [EOS]
+ tokens.append(sent)
+ return tokens
+
+ # noinspection PyMethodMayBeStatic
+ def build_samples(self, inputs, cls_is_bos=False, sep_is_eos=False):
+ """
+ Build samples for this task. Called when this task is the first task. Default behaviour is to take inputs as
+ list of tokens and put these tokens into a dict per sample.
+
+ Args:
+ inputs: Inputs from users, usually a list of lists of tokens.
+ cls_is_bos: Insert BOS to the head of each sentence.
+ sep_is_eos: Append EOS to the tail of each sentence.
+
+ Returns:
+ List of samples.
+
+ """
+ if cls_is_bos:
+ inputs = [[BOS] + x for x in inputs]
+ if sep_is_eos:
+ inputs = [x + [EOS] for x in inputs]
+ return [{'token': token} for token in inputs]
+
+ def build_tokenizer(self, tokenizer: TransformerSequenceTokenizer):
+ """Build a transformer tokenizer for this task.
+
+ Args:
+ tokenizer: A tokenizer which is shared but can be adjusted to provide per-task settings.
+
+ Returns:
+ A TransformerSequenceTokenizer.
+
+ """
+ if tokenizer.cls_is_bos != self.cls_is_bos or tokenizer.sep_is_eos != self.sep_is_eos:
+ tokenizer = copy(tokenizer)
+ tokenizer.cls_is_bos = self.cls_is_bos
+ tokenizer.sep_is_eos = self.sep_is_eos
+ return tokenizer
+
+ # noinspection PyMethodMayBeStatic
+ def finalize_document(self, doc: Document, task_name: str):
+ pass
diff --git a/hanlp/components/mtl/tasks/amr.py b/hanlp/components/mtl/tasks/amr.py
new file mode 100644
index 000000000..ce1277f99
--- /dev/null
+++ b/hanlp/components/mtl/tasks/amr.py
@@ -0,0 +1,174 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-12 16:05
+import logging
+from typing import Dict, Any, List, Union, Iterable, Callable
+
+import torch
+from stog.data.dataset_readers.amr_parsing.amr import AMRGraph
+from stog.data.dataset_readers.amr_parsing.node_utils import NodeUtilities
+from stog.data.dataset_readers.amr_parsing.postprocess.node_restore import NodeRestore
+from torch.utils.data import DataLoader
+
+from hanlp_common.constant import CLS
+from hanlp.common.dataset import PrefetchDataLoader, SamplerBuilder
+from hanlp.common.transform import VocabDict
+from hanlp.components.amr.amr_parser.graph_amr_decoder import GraphAbstractMeaningRepresentationDecoder
+from hanlp.components.amr.amr_parser.graph_parser import GraphAbstractMeaningRepresentationParser
+from hanlp.components.amr.amr_parser.postprocess import PostProcessor
+from hanlp.components.amr.amr_parser.work import parse_batch
+from hanlp.components.mtl.tasks import Task
+from hanlp.datasets.parsing.amr import batchify, get_concepts
+from hanlp.layers.scalar_mix import ScalarMixWithDropoutBuilder
+from hanlp.metrics.amr.smatch_eval import SmatchScores, get_amr_utils
+from hanlp.metrics.f1 import F1_
+from hanlp.metrics.metric import Metric
+from hanlp.metrics.mtl import MetricDict
+from hanlp.utils.io_util import get_resource
+from hanlp_common.util import merge_list_of_dict, merge_locals_kwargs
+
+
+class GraphAbstractMeaningRepresentationParsing(Task, GraphAbstractMeaningRepresentationParser):
+
+ def __init__(self,
+ trn: str = None,
+ dev: str = None,
+ tst: str = None,
+ sampler_builder: SamplerBuilder = None,
+ dependencies: str = None,
+ scalar_mix: ScalarMixWithDropoutBuilder = None,
+ use_raw_hidden_states=False,
+ lr=1e-3,
+ separate_optimizer=False,
+ cls_is_bos=True,
+ sep_is_eos=False,
+ char2concept_dim=128,
+ cnn_filters=((3, 256),),
+ concept_char_dim=32,
+ concept_dim=300,
+ dropout=0.2,
+ embed_dim=512,
+ eval_every=20,
+ ff_embed_dim=1024,
+ graph_layers=2,
+ inference_layers=4,
+ num_heads=8,
+ rel_dim=100,
+ snt_layers=4,
+ unk_rate=0.33,
+ vocab_min_freq=5,
+ beam_size=8,
+ alpha=0.6,
+ max_time_step=100,
+ amr_version='2.0',
+ **kwargs) -> None:
+ super().__init__(**merge_locals_kwargs(locals(), kwargs))
+ self.vocabs = VocabDict()
+ utils_dir = get_resource(get_amr_utils(amr_version))
+ self.sense_restore = NodeRestore(NodeUtilities.from_json(utils_dir))
+
+ def build_dataloader(self,
+ data,
+ transform: Callable = None,
+ training=False,
+ device=None,
+ logger: logging.Logger = None,
+ cache=False,
+ gradient_accumulation=1,
+ **kwargs) -> DataLoader:
+ if isinstance(data, list):
+ data = GraphAbstractMeaningRepresentationParser.build_samples(self, data)
+ dataset, lens = GraphAbstractMeaningRepresentationParser.build_dataset(self, data, logger=logger,
+ transform=transform, training=training)
+ if self.vocabs.mutable:
+ GraphAbstractMeaningRepresentationParser.build_vocabs(self, dataset, logger)
+ dataloader = PrefetchDataLoader(
+ DataLoader(batch_sampler=self.sampler_builder.build(lens, shuffle=training,
+ gradient_accumulation=gradient_accumulation),
+ dataset=dataset,
+ collate_fn=merge_list_of_dict,
+ num_workers=0), batchify=self.build_batchify(device, training),
+ prefetch=None)
+ return dataloader
+
+ def compute_loss(self,
+ batch: Dict[str, Any],
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
+ criterion) -> Union[torch.FloatTensor, Dict[str, torch.FloatTensor]]:
+ concept_loss, arc_loss, rel_loss, graph_arc_loss = output
+ concept_loss, concept_correct, concept_total = concept_loss
+ rel_loss, rel_correct, rel_total = rel_loss
+ loss = concept_loss + arc_loss + rel_loss
+ return loss
+
+ def decode_output(self,
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
+ mask: torch.BoolTensor,
+ batch: Dict[str, Any],
+ decoder: torch.nn.Module, **kwargs) -> Union[Dict[str, Any], Any]:
+ return output
+
+ def update_metrics(self,
+ batch: Dict[str, Any],
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
+ prediction: Dict[str, Any],
+ metric: Union[MetricDict, Metric]):
+ pass
+
+ def build_model(self, encoder_size, training=True, **kwargs) -> torch.nn.Module:
+ return GraphAbstractMeaningRepresentationDecoder(vocabs=self.vocabs, encoder_size=encoder_size, **self.config)
+
+ def build_metric(self, **kwargs):
+ return SmatchScores({'Smatch': F1_(0, 0, 0)})
+
+ def input_is_flat(self, data) -> bool:
+ return GraphAbstractMeaningRepresentationParser.input_is_flat(self, data)
+
+ def prediction_to_result(self, prediction: Dict[str, Any], batch: Dict[str, Any]) -> List:
+ pp = PostProcessor(self.vocabs['rel'])
+ for concept, relation, score in zip(prediction['concept'], prediction['relation'], prediction['score']):
+ amr = pp.to_amr(concept, relation)
+ amr_graph = AMRGraph(amr)
+ self.sense_restore.restore_graph(amr_graph)
+ yield amr_graph
+
+ def evaluate_dataloader(self,
+ data: DataLoader,
+ criterion: Callable,
+ metric=None,
+ output=False,
+ input=None,
+ decoder=None,
+ h=None,
+ split=None,
+ **kwargs):
+ # noinspection PyTypeChecker
+ GraphAbstractMeaningRepresentationParser.evaluate_dataloader(self, data, logger=None, metric=metric,
+ input=input, model=decoder, h=lambda x: h(x)[0],
+ use_fast=True)
+
+ def feed_batch(self,
+ h: torch.FloatTensor,
+ batch: Dict[str, torch.Tensor],
+ mask: torch.BoolTensor,
+ decoder: torch.nn.Module):
+ if decoder.training:
+ return super().feed_batch(h, batch, mask, decoder)
+ beam_size = self.config.get('beam_size', 8)
+ alpha = self.config.get('alpha', 0.6)
+ max_time_step = self.config.get('max_time_step', 100)
+ res = parse_batch(decoder, batch, beam_size, alpha, max_time_step, h=h)
+ return res
+
+ def transform_batch(self, batch: Dict[str, Any], results: Dict[str, Any] = None, cls_is_bos=False,
+ sep_is_eos=False) -> Dict[str, Any]:
+ batch = super().transform_batch(batch, results, cls_is_bos, sep_is_eos)
+ batch['lemma'] = [[CLS] + x for x in results['lem']]
+ copy_seq = merge_list_of_dict(
+ [get_concepts({'token': t[1:], 'lemma': l[1:]}, self.vocabs.predictable_concept) for t, l in
+ zip(batch['token'], batch['lemma'])])
+ copy_seq.pop('token')
+ copy_seq.pop('lemma')
+ batch.update(copy_seq)
+ ret = batchify(batch, self.vocabs, device=batch['token_input_ids'].device)
+ return ret
diff --git a/hanlp/components/mtl/tasks/constituency.py b/hanlp/components/mtl/tasks/constituency.py
new file mode 100644
index 000000000..ba78ebd65
--- /dev/null
+++ b/hanlp/components/mtl/tasks/constituency.py
@@ -0,0 +1,170 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-11-29 16:52
+import logging
+from typing import Dict, Any, List, Union, Iterable, Callable
+
+import torch
+from phrasetree.tree import Tree
+
+from hanlp_common.constant import BOS, EOS
+from hanlp_common.document import Document
+from hanlp.components.parsers.biaffine.biaffine_dep import BiaffineDependencyParser
+from torch.utils.data import DataLoader
+
+from hanlp.common.dataset import SamplerBuilder, PadSequenceDataLoader
+from hanlp.common.transform import VocabDict
+from hanlp.components.mtl.tasks import Task
+from hanlp.components.parsers.constituency.crf_constituency_model import CRFConstituencyDecoder
+from hanlp.components.parsers.constituency.crf_constituency_parser import CRFConstituencyParser
+from hanlp.layers.scalar_mix import ScalarMixWithDropoutBuilder
+from hanlp.metrics.metric import Metric
+from hanlp.metrics.mtl import MetricDict
+from hanlp.utils.time_util import CountdownTimer
+from hanlp_common.util import merge_locals_kwargs, prefix_match
+
+
+class CRFConstituencyParsing(Task, CRFConstituencyParser):
+ def __init__(self,
+ trn: str = None,
+ dev: str = None,
+ tst: str = None,
+ sampler_builder: SamplerBuilder = None,
+ dependencies: str = None,
+ scalar_mix: ScalarMixWithDropoutBuilder = None,
+ use_raw_hidden_states=False,
+ lr=None,
+ separate_optimizer=False,
+ cls_is_bos=True,
+ sep_is_eos=True,
+ delete=('', ':', '``', "''", '.', '?', '!', '-NONE-', 'TOP', ',', 'S1'),
+ equal=(('ADVP', 'PRT'),),
+ mbr=True,
+ n_mlp_span=500,
+ n_mlp_label=100,
+ mlp_dropout=.33,
+ no_subcategory=True,
+ **kwargs
+ ) -> None:
+ r"""Two-stage CRF Parsing (:cite:`ijcai2020-560`).
+
+ Args:
+ trn: Path to training set.
+ dev: Path to dev set.
+ tst: Path to test set.
+ sampler_builder: A builder which builds a sampler.
+ dependencies: Its dependencies on other tasks.
+ scalar_mix: A builder which builds a `ScalarMixWithDropout` object.
+ use_raw_hidden_states: Whether to use raw hidden states from transformer without any pooling.
+ lr: Learning rate for this task.
+ separate_optimizer: Use customized separate optimizer for this task.
+ cls_is_bos: ``True`` to treat the first token as ``BOS``.
+ sep_is_eos: ``True`` to treat the last token as ``EOS``.
+ delete: Constituencies to be deleted from training and evaluation.
+ equal: Constituencies that are regarded as equal during evaluation.
+ mbr: ``True`` to enable Minimum Bayes Risk (MBR) decoding (:cite:`smith-smith-2007-probabilistic`).
+ n_mlp_span: Number of features for span decoder.
+ n_mlp_label: Number of features for label decoder.
+ mlp_dropout: Dropout applied to MLPs.
+ no_subcategory: Strip out subcategories.
+ **kwargs: Not used.
+ """
+ if isinstance(equal, tuple):
+ equal = dict(equal)
+ super().__init__(**merge_locals_kwargs(locals(), kwargs))
+ self.vocabs = VocabDict()
+
+ # noinspection DuplicatedCode
+ def build_dataloader(self,
+ data,
+ transform: Callable = None,
+ training=False,
+ device=None,
+ logger: logging.Logger = None,
+ cache=False,
+ gradient_accumulation=1,
+ **kwargs) -> DataLoader:
+ dataset = CRFConstituencyParsing.build_dataset(self, data, transform)
+ if isinstance(data, str):
+ dataset.purge_cache()
+ if self.vocabs.mutable:
+ CRFConstituencyParsing.build_vocabs(self, dataset, logger)
+ if dataset.cache:
+ timer = CountdownTimer(len(dataset))
+ # noinspection PyCallByClass
+ BiaffineDependencyParser.cache_dataset(self, dataset, timer, training, logger)
+ return PadSequenceDataLoader(
+ batch_sampler=self.sampler_builder.build(self.compute_lens(data, dataset), shuffle=training,
+ gradient_accumulation=gradient_accumulation),
+ device=device,
+ dataset=dataset)
+
+ def feed_batch(self,
+ h: torch.FloatTensor,
+ batch: Dict[str, torch.Tensor],
+ mask: torch.BoolTensor,
+ decoder: torch.nn.Module):
+ return {
+ 'output': decoder(h),
+ 'mask': CRFConstituencyParser.compute_mask(
+ self, batch, offset=1 if 'constituency' in batch or batch['token'][0][-1] == EOS else -1)
+ }
+
+ def compute_loss(self,
+ batch: Dict[str, Any],
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
+ criterion) -> Union[torch.FloatTensor, Dict[str, torch.FloatTensor]]:
+ out, mask = output['output'], output['mask']
+ loss, span_probs = CRFConstituencyParser.compute_loss(self, out, batch['chart_id'], mask, crf_decoder=criterion)
+ output['span_probs'] = span_probs
+ return loss
+
+ def decode_output(self,
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
+ mask: torch.BoolTensor,
+ batch: Dict[str, Any],
+ decoder: torch.nn.Module, **kwargs) -> Union[Dict[str, Any], Any]:
+ out, mask = output['output'], output['mask']
+ tokens = []
+ for sent in batch['token']:
+ if sent[0] == BOS:
+ sent = sent[1:]
+ if sent[-1] == EOS:
+ sent = sent[:-1]
+ tokens.append(sent)
+ return CRFConstituencyParser.decode_output(self, out, mask, batch, output.get('span_probs', None),
+ decoder=decoder, tokens=tokens)
+
+ def update_metrics(self,
+ batch: Dict[str, Any],
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
+ prediction: Dict[str, Any], metric: Union[MetricDict, Metric]):
+ return CRFConstituencyParser.update_metrics(self, metric, batch, prediction)
+
+ def build_model(self, encoder_size, training=True, **kwargs) -> torch.nn.Module:
+ return CRFConstituencyDecoder(n_labels=len(self.vocabs.chart), n_hidden=encoder_size)
+
+ def build_metric(self, **kwargs):
+ return CRFConstituencyParser.build_metric(self)
+
+ def input_is_flat(self, data) -> bool:
+ return CRFConstituencyParser.input_is_flat(self, data)
+
+ def prediction_to_result(self, prediction: List, batch: Dict[str, Any]) -> List:
+ return prediction
+
+ def finalize_document(self, doc: Document, task_name: str):
+ pos_key = prefix_match('pos', doc)
+ pos: List[List[str]] = doc.get(pos_key, None)
+ if pos:
+ for tree, pos_per_sent in zip(doc[task_name], pos):
+ tree: Tree = tree
+ offset = 0
+ for subtree in tree.subtrees(lambda t: t.height() == 2):
+ tag = subtree.label()
+ if tag == '_':
+ subtree.set_label(pos_per_sent[offset])
+ offset += 1
+
+ def build_samples(self, inputs, cls_is_bos=False, sep_is_eos=False):
+ return CRFConstituencyParser.build_samples(self, inputs)
diff --git a/hanlp/components/mtl/tasks/dep.py b/hanlp/components/mtl/tasks/dep.py
new file mode 100644
index 000000000..af408e397
--- /dev/null
+++ b/hanlp/components/mtl/tasks/dep.py
@@ -0,0 +1,167 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-08-13 21:39
+import logging
+from typing import Dict, Any, Union, Iterable, List
+
+import torch
+from torch.optim import Adam
+from torch.optim.lr_scheduler import ExponentialLR
+from torch.utils.data import DataLoader
+
+from hanlp.common.dataset import SamplerBuilder, PadSequenceDataLoader
+from hanlp.common.transform import VocabDict, TransformList
+from hanlp.components.mtl.tasks import Task
+from hanlp.components.parsers.biaffine.biaffine_dep import BiaffineDependencyParser
+from hanlp.components.parsers.biaffine.biaffine_model import BiaffineDecoder
+from hanlp.datasets.parsing.conll_dataset import append_bos
+from hanlp.layers.scalar_mix import ScalarMixWithDropoutBuilder
+from hanlp.metrics.metric import Metric
+from hanlp.metrics.mtl import MetricDict
+from hanlp.utils.time_util import CountdownTimer
+from hanlp_common.constant import EOS
+from hanlp_common.util import merge_locals_kwargs
+
+
+class BiaffineDependencyParsing(Task, BiaffineDependencyParser):
+ def __init__(self,
+ trn: str = None,
+ dev: str = None,
+ tst: str = None,
+ sampler_builder: SamplerBuilder = None,
+ dependencies: str = None,
+ scalar_mix: ScalarMixWithDropoutBuilder = None,
+ use_raw_hidden_states=False,
+ lr=2e-3, separate_optimizer=False,
+ cls_is_bos=True,
+ sep_is_eos=False,
+ punct=False,
+ tree=False,
+ proj=False,
+ n_mlp_arc=500,
+ n_mlp_rel=100,
+ mlp_dropout=.33,
+ mu=.9,
+ nu=.9,
+ epsilon=1e-12,
+ decay=.75,
+ decay_steps=5000,
+ use_pos=False,
+ max_seq_len=None,
+ **kwargs) -> None:
+ """Biaffine dependency parsing (:cite:`dozat:17a`).
+
+ Args:
+ trn: Path to training set.
+ dev: Path to dev set.
+ tst: Path to test set.
+ sampler_builder: A builder which builds a sampler.
+ dependencies: Its dependencies on other tasks.
+ scalar_mix: A builder which builds a `ScalarMixWithDropout` object.
+ use_raw_hidden_states: Whether to use raw hidden states from transformer without any pooling.
+ lr: Learning rate for this task.
+ separate_optimizer: Use customized separate optimizer for this task.
+ cls_is_bos: ``True`` to treat the first token as ``BOS``.
+ sep_is_eos: ``True`` to treat the last token as ``EOS``.
+ punct: ``True`` to include punctuations in evaluation.
+ tree: ``True`` to enforce tree constraint.
+ proj: ``True`` for projective parsing.
+ n_mlp_arc: Number of features for arc representation.
+ n_mlp_rel: Number of features for rel representation.
+ mlp_dropout: Dropout applied to MLPs.
+ mu: First coefficient used for computing running averages of gradient and its square in Adam.
+ nu: Second coefficient used for computing running averages of gradient and its square in Adam.
+ epsilon: Term added to the denominator to improve numerical stability
+ decay: Decay rate for exceptional lr scheduler.
+ decay_steps: Decay every ``decay_steps`` steps.
+ use_pos: Use pos feature.
+ max_seq_len: Prune samples longer than this length.
+ **kwargs: Not used.
+ """
+ super().__init__(**merge_locals_kwargs(locals(), kwargs))
+ self.vocabs = VocabDict()
+
+ def update_metrics(self, batch: Dict[str, Any],
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
+ prediction: Dict[str, Any], metric: Union[MetricDict, Metric]):
+ BiaffineDependencyParser.update_metric(self, *prediction, batch['arc'], batch['rel_id'], output[1],
+ batch.get('punct_mask', None), metric, batch)
+
+ def decode_output(self,
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
+ mask: torch.BoolTensor,
+ batch: Dict[str, Any],
+ decoder, **kwargs) -> Union[Dict[str, Any], Any]:
+ (arc_scores, rel_scores), mask = output
+ return BiaffineDependencyParser.decode(self, arc_scores, rel_scores, mask, batch)
+
+ def compute_loss(self, batch: Dict[str, Any],
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any], criterion) -> \
+ Union[torch.FloatTensor, Dict[str, torch.FloatTensor]]:
+ (arc_scores, rel_scores), mask = output
+ return BiaffineDependencyParser.compute_loss(self, arc_scores, rel_scores, batch['arc'], batch['rel_id'], mask,
+ criterion,
+ batch)
+
+ def build_model(self, encoder_size, training=True, **kwargs) -> torch.nn.Module:
+ return BiaffineDecoder(encoder_size, self.config.n_mlp_arc, self.config.n_mlp_rel, self.config.mlp_dropout,
+ len(self.vocabs.rel))
+
+ def build_metric(self, **kwargs):
+ return BiaffineDependencyParser.build_metric(self, **kwargs)
+
+ def build_dataloader(self, data, transform: TransformList = None, training=False, device=None,
+ logger: logging.Logger = None, gradient_accumulation=1, **kwargs) -> DataLoader:
+ transform.insert(0, append_bos)
+ dataset = BiaffineDependencyParser.build_dataset(self, data, transform)
+ if isinstance(data, str):
+ dataset.purge_cache()
+ if self.vocabs.mutable:
+ BiaffineDependencyParser.build_vocabs(self, dataset, logger, transformer=True)
+ if dataset.cache:
+ timer = CountdownTimer(len(dataset))
+ BiaffineDependencyParser.cache_dataset(self, dataset, timer, training, logger)
+ max_seq_len = self.config.get('max_seq_len', None)
+ if max_seq_len and isinstance(data, str):
+ dataset.prune(lambda x: len(x['token_input_ids']) > 510, logger)
+ return PadSequenceDataLoader(
+ batch_sampler=self.sampler_builder.build(self.compute_lens(data, dataset, length_field='FORM'),
+ shuffle=training, gradient_accumulation=gradient_accumulation),
+ device=device,
+ dataset=dataset,
+ pad=self.get_pad_dict())
+
+ def feed_batch(self, h: torch.FloatTensor, batch: Dict[str, torch.Tensor], mask: torch.BoolTensor,
+ decoder: torch.nn.Module):
+ logits = super().feed_batch(h, batch, mask, decoder)
+ mask = mask.clone()
+ mask[:, 0] = 0
+ return logits, mask
+
+ def build_optimizer(self, decoder: torch.nn.Module, **kwargs):
+ config = self.config
+ optimizer = Adam(decoder.parameters(),
+ config.lr,
+ (config.mu, config.nu),
+ config.epsilon)
+ scheduler = ExponentialLR(optimizer, config.decay ** (1 / config.decay_steps))
+ return optimizer, scheduler
+
+ def input_is_flat(self, data) -> bool:
+ return BiaffineDependencyParser.input_is_flat(self, data, self.config.use_pos)
+
+ def prediction_to_result(self, prediction: Dict[str, Any], batch: Dict[str, Any]) -> List:
+ arcs, rels = prediction
+ arcs = arcs[:, 1:] # Skip the ROOT
+ rels = rels[:, 1:]
+ arcs = arcs.tolist()
+ rels = rels.tolist()
+ vocab = self.vocabs['rel'].idx_to_token
+ for arcs_per_sent, rels_per_sent, tokens in zip(arcs, rels, batch['token']):
+ tokens = tokens[1:]
+ sent_len = len(tokens)
+ result = list(zip(arcs_per_sent[:sent_len], [vocab[r] for r in rels_per_sent[:sent_len]]))
+ yield result
+
+ def build_samples(self, inputs, cls_is_bos=False, sep_is_eos=False):
+ return [{'FORM': token + [EOS] if sep_is_eos else []} for token in inputs]
diff --git a/hanlp/components/mtl/tasks/dep_2nd.py b/hanlp/components/mtl/tasks/dep_2nd.py
new file mode 100644
index 000000000..a63f1593a
--- /dev/null
+++ b/hanlp/components/mtl/tasks/dep_2nd.py
@@ -0,0 +1,112 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-08-07 14:14
+import logging
+from typing import Dict, Any, Union, Iterable, Callable, List
+
+import torch
+from torch.utils.data import DataLoader
+
+from hanlp.common.dataset import SamplerBuilder, PadSequenceDataLoader
+from hanlp.common.transform import VocabDict
+from hanlp.components.mtl.tasks import Task
+from hanlp.components.parsers.biaffine.biaffine_2nd_dep import BiaffineSecondaryParser, BiaffineJointDecoder, \
+ BiaffineSeparateDecoder
+from hanlp.layers.scalar_mix import ScalarMixWithDropoutBuilder
+from hanlp.metrics.metric import Metric
+from hanlp.metrics.mtl import MetricDict
+from hanlp_common.util import merge_locals_kwargs
+from alnlp.modules import util
+
+
+class BiaffineSecondaryDependencyDecoder(torch.nn.Module):
+ def __init__(self, hidden_size, config) -> None:
+ super().__init__()
+ self.decoder = BiaffineJointDecoder(hidden_size, config) if config.joint \
+ else BiaffineSeparateDecoder(hidden_size, config)
+
+ def forward(self, contextualized_embeddings: torch.FloatTensor, batch: Dict[str, torch.Tensor], mask=None):
+ if mask is None:
+ mask = util.lengths_to_mask(batch['token_length'])
+ else:
+ mask = mask.clone()
+ scores = self.decoder(contextualized_embeddings, mask)
+ mask[:, 0] = 0
+ return scores, mask
+
+
+class BiaffineSecondaryDependencyParsing(Task, BiaffineSecondaryParser):
+
+ def __init__(self, trn: str = None, dev: str = None, tst: str = None, sampler_builder: SamplerBuilder = None,
+ dependencies: str = None, scalar_mix: ScalarMixWithDropoutBuilder = None, use_raw_hidden_states=False,
+ lr=2e-3, separate_optimizer=False,
+ punct=False,
+ tree=False,
+ apply_constraint=True,
+ n_mlp_arc=500,
+ n_mlp_rel=100,
+ mlp_dropout=.33,
+ pad_rel=None,
+ joint=True,
+ mu=.9,
+ nu=.9,
+ epsilon=1e-12,
+ cls_is_bos=True,
+ **kwargs) -> None:
+ super().__init__(**merge_locals_kwargs(locals(), kwargs))
+ self.vocabs = VocabDict()
+
+ def build_dataloader(self, data, transform: Callable = None, training=False, device=None,
+ logger: logging.Logger = None, gradient_accumulation=1, **kwargs) -> DataLoader:
+ dataset = BiaffineSecondaryParser.build_dataset(self, data, transform)
+ if isinstance(data, str):
+ dataset.purge_cache()
+ if self.vocabs.mutable:
+ BiaffineSecondaryParser.build_vocabs(self, dataset, logger, transformer=True)
+ return PadSequenceDataLoader(
+ batch_sampler=self.sampler_builder.build(self.compute_lens(data, dataset), shuffle=training,
+ gradient_accumulation=gradient_accumulation),
+ device=device,
+ dataset=dataset,
+ pad={'arc': 0, 'arc_2nd': False})
+
+ def update_metrics(self, batch: Dict[str, Any],
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
+ prediction: Dict[str, Any], metric: Union[MetricDict, Metric]):
+
+ BiaffineSecondaryParser.update_metric(self, *prediction, batch['arc'], batch['rel_id'], output[1],
+ batch['punct_mask'], metric, batch)
+
+ def decode_output(self, output: Dict[str, Any], batch: Dict[str, Any], decoder, **kwargs) \
+ -> Union[Dict[str, Any], Any]:
+ return BiaffineSecondaryParser.decode(self, *output[0], output[1], batch)
+
+ def compute_loss(self, batch: Dict[str, Any],
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any], criterion) -> \
+ Union[torch.FloatTensor, Dict[str, torch.FloatTensor]]:
+ return BiaffineSecondaryParser.compute_loss(self, *output[0], batch['arc'], batch['rel_id'], output[1],
+ criterion, batch)
+
+ def build_model(self, encoder_size, training=True, **kwargs) -> torch.nn.Module:
+ return BiaffineSecondaryDependencyDecoder(encoder_size, self.config)
+
+ def build_metric(self, **kwargs):
+ return BiaffineSecondaryParser.build_metric(self, **kwargs)
+
+ def build_criterion(self, **kwargs):
+ return BiaffineSecondaryParser.build_criterion(self, **kwargs)
+
+ def build_optimizer(self, decoder: torch.nn.Module, **kwargs):
+ config = self.config
+ optimizer = torch.optim.Adam(decoder.parameters(),
+ config.lr,
+ (config.mu, config.nu),
+ config.epsilon)
+ return optimizer
+
+ def input_is_flat(self, data) -> bool:
+ return BiaffineSecondaryParser.input_is_flat(self, data)
+
+ def prediction_to_result(self, prediction: Dict[str, Any], batch: Dict[str, Any]) -> List:
+ outputs = []
+ return BiaffineSecondaryParser.predictions_to_human(self, prediction, outputs, batch['token'], use_pos=False)
diff --git a/hanlp/components/mtl/tasks/lem.py b/hanlp/components/mtl/tasks/lem.py
new file mode 100644
index 000000000..7c4e487ed
--- /dev/null
+++ b/hanlp/components/mtl/tasks/lem.py
@@ -0,0 +1,129 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-09 16:37
+import logging
+from typing import Dict, Any, Union, Iterable, Callable, List
+
+import torch
+from hanlp.common.dataset import SamplerBuilder, PadSequenceDataLoader
+from hanlp.common.transform import VocabDict
+from hanlp.components.lemmatizer import TransformerLemmatizer
+from hanlp.components.mtl.tasks import Task
+from hanlp.layers.scalar_mix import ScalarMixWithDropoutBuilder
+from hanlp.metrics.metric import Metric
+from hanlp.metrics.mtl import MetricDict
+from hanlp_common.util import merge_locals_kwargs
+from torch.utils.data import DataLoader
+
+
+class LinearDecoder(torch.nn.Module):
+ def __init__(self,
+ hidden_size,
+ num_labels) -> None:
+ super().__init__()
+ self.classifier = torch.nn.Linear(hidden_size, num_labels)
+
+ def forward(self, contextualized_embeddings: torch.FloatTensor, batch: Dict[str, torch.Tensor], mask=None):
+ return self.classifier(contextualized_embeddings)
+
+
+class TransformerLemmatization(Task, TransformerLemmatizer):
+
+ def __init__(self,
+ trn: str = None,
+ dev: str = None,
+ tst: str = None,
+ sampler_builder: SamplerBuilder = None,
+ dependencies: str = None,
+ scalar_mix: ScalarMixWithDropoutBuilder = None,
+ use_raw_hidden_states=False,
+ lr=1e-3,
+ separate_optimizer=False,
+ cls_is_bos=False,
+ sep_is_eos=False,
+ max_seq_len=None,
+ sent_delimiter=None,
+ char_level=False,
+ hard_constraint=False,
+ token_key='token', **kwargs) -> None:
+ """ Transition based lemmatization (:cite:`kondratyuk-straka-2019-75`).
+
+ Args:
+ trn: Path to training set.
+ dev: Path to dev set.
+ tst: Path to test set.
+ sampler_builder: A builder which builds a sampler.
+ dependencies: Its dependencies on other tasks.
+ scalar_mix: A builder which builds a `ScalarMixWithDropout` object.
+ use_raw_hidden_states: Whether to use raw hidden states from transformer without any pooling.
+ lr: Learning rate for this task.
+ separate_optimizer: Use customized separate optimizer for this task.
+ cls_is_bos: ``True`` to treat the first token as ``BOS``.
+ sep_is_eos: ``True`` to treat the last token as ``EOS``.
+ max_seq_len: Sentences longer than ``max_seq_len`` will be split into shorter ones if possible.
+ sent_delimiter: Delimiter between sentences, like period or comma, which indicates a long sentence can
+ be split here.
+ char_level: Whether the sequence length is measured at char level, which is never the case for
+ lemmatization.
+ hard_constraint: Whether to enforce hard length constraint on sentences. If there is no ``sent_delimiter``
+ in a sentence, it will be split at a token anyway.
+ token_key: The key to tokens in dataset. This should always be set to ``token`` in MTL.
+ **kwargs: Not used.
+ """
+ super().__init__(**merge_locals_kwargs(locals(), kwargs))
+ self.vocabs = VocabDict()
+
+ def build_dataloader(self,
+ data: List[List[str]],
+ transform: Callable = None,
+ training=False,
+ device=None,
+ logger: logging.Logger = None,
+ cache=False,
+ gradient_accumulation=1,
+ **kwargs) -> DataLoader:
+ args = dict((k, self.config[k]) for k in
+ ['delimiter', 'max_seq_len', 'sent_delimiter', 'char_level', 'hard_constraint'] if k in self.config)
+ dataset = self.build_dataset(data, cache=cache, transform=transform, **args)
+ dataset.append_transform(self.vocabs)
+ if self.vocabs.mutable:
+ self.build_vocabs(dataset, logger)
+ return PadSequenceDataLoader(
+ batch_sampler=self.sampler_builder.build(self.compute_lens(data, dataset, 'token_input_ids', 'token'),
+ shuffle=training, gradient_accumulation=gradient_accumulation),
+ device=device,
+ dataset=dataset)
+
+ def compute_loss(self,
+ batch: Dict[str, Any],
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
+ criterion) -> Union[torch.FloatTensor, Dict[str, torch.FloatTensor]]:
+ return TransformerLemmatizer.compute_loss(self, criterion, output, batch['tag_id'], batch['mask'])
+
+ def decode_output(self,
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
+ mask: torch.BoolTensor,
+ batch: Dict[str, Any],
+ decoder,
+ **kwargs) -> Union[Dict[str, Any], Any]:
+ return TransformerLemmatizer.decode_output(self, output, mask, batch, decoder)
+
+ def update_metrics(self,
+ batch: Dict[str, Any],
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
+ prediction: Dict[str, Any],
+ metric: Union[MetricDict, Metric]):
+ return TransformerLemmatizer.update_metrics(self, metric, output, batch['tag_id'], batch['mask'])
+
+ def build_model(self, encoder_size, training=True, **kwargs) -> torch.nn.Module:
+ return LinearDecoder(encoder_size, len(self.vocabs['tag']))
+
+ def build_metric(self, **kwargs):
+ return TransformerLemmatizer.build_metric(self, **kwargs)
+
+ def input_is_flat(self, data) -> bool:
+ return TransformerLemmatizer.input_is_flat(self, data)
+
+ def prediction_to_result(self, prediction: Dict[str, Any], batch: Dict[str, Any]) -> Union[List, Dict]:
+ return TransformerLemmatizer.prediction_to_human(self, prediction, self.vocabs['tag'].idx_to_token, batch,
+ token=batch['token'])
diff --git a/hanlp/components/mtl/tasks/ner/__init__.py b/hanlp/components/mtl/tasks/ner/__init__.py
new file mode 100644
index 000000000..473c47f61
--- /dev/null
+++ b/hanlp/components/mtl/tasks/ner/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-03 14:34
diff --git a/hanlp/components/mtl/tasks/ner/biaffine_ner.py b/hanlp/components/mtl/tasks/ner/biaffine_ner.py
new file mode 100644
index 000000000..5b2c30489
--- /dev/null
+++ b/hanlp/components/mtl/tasks/ner/biaffine_ner.py
@@ -0,0 +1,106 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-08-05 01:49
+import logging
+from copy import copy
+from typing import Dict, Any, Union, Iterable, List
+
+import torch
+from torch.utils.data import DataLoader
+
+from hanlp.common.dataset import SamplerBuilder, PadSequenceDataLoader
+from hanlp.common.transform import VocabDict, TransformList
+from hanlp.components.mtl.tasks import Task
+from hanlp.components.ner.biaffine_ner.biaffine_ner import BiaffineNamedEntityRecognizer
+from hanlp.components.ner.biaffine_ner.biaffine_ner_model import BiaffineNamedEntityRecognitionDecoder
+from hanlp.datasets.ner.json_ner import unpack_ner
+from hanlp.layers.scalar_mix import ScalarMixWithDropoutBuilder
+from hanlp.metrics.metric import Metric
+from hanlp.metrics.mtl import MetricDict
+from hanlp_common.util import merge_locals_kwargs
+
+
+class BiaffineNamedEntityRecognition(Task, BiaffineNamedEntityRecognizer):
+
+ def __init__(self, trn: str = None, dev: str = None, tst: str = None, sampler_builder: SamplerBuilder = None,
+ dependencies: str = None, scalar_mix: ScalarMixWithDropoutBuilder = None, use_raw_hidden_states=False,
+ lr=None, separate_optimizer=False,
+ doc_level_offset=True, is_flat_ner=True, tagset=None, ret_tokens=' ',
+ ffnn_size=150, loss_reduction='mean', **kwargs) -> None:
+ """An implementation of Named Entity Recognition as Dependency Parsing (:cite:`yu-etal-2020-named`). It treats
+ every possible span as a candidate of entity and predicts its entity label. Non-entity spans are assigned NULL
+ label to be excluded. The label prediction is done with a biaffine layer (:cite:`dozat:17a`). As it makes no
+ assumption about the spans, it naturally supports flat NER and nested NER.
+
+ Args:
+ trn: Path to training set.
+ dev: Path to dev set.
+ tst: Path to test set.
+ sampler_builder: A builder which builds a sampler.
+ dependencies: Its dependencies on other tasks.
+ scalar_mix: A builder which builds a `ScalarMixWithDropout` object.
+ use_raw_hidden_states: Whether to use raw hidden states from transformer without any pooling.
+ lr: Learning rate for this task.
+ separate_optimizer: Use customized separate optimizer for this task.
+ doc_level_offset: ``True`` to indicate the offsets in ``jsonlines`` are of document level.
+ is_flat_ner: ``True`` for flat NER, otherwise nested NER.
+ tagset: Optional tagset to prune entities outside of this tagset from datasets.
+ ret_tokens: A delimiter between tokens in entities so that the surface form of an entity can be rebuilt.
+ ffnn_size: Feedforward size for MLPs extracting the head/tail representations.
+ loss_reduction: The loss reduction used in aggregating losses.
+ **kwargs: Not used.
+ """
+ super().__init__(**merge_locals_kwargs(locals(), kwargs))
+ self.vocabs = VocabDict()
+
+ def update_metrics(self, batch: Dict[str, Any],
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
+ prediction: Dict[str, Any], metric: Union[MetricDict, Metric]):
+ BiaffineNamedEntityRecognizer.update_metrics(self, batch, prediction, metric)
+
+ def decode_output(self,
+ output: Dict[str, Any],
+ mask: torch.BoolTensor,
+ batch: Dict[str, Any],
+ decoder,
+ **kwargs) -> Union[Dict[str, Any], Any]:
+ return self.get_pred_ner(batch['token'], output['candidate_ner_scores'])
+
+ def compute_loss(self, batch: Dict[str, Any],
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any], criterion) -> \
+ Union[torch.FloatTensor, Dict[str, torch.FloatTensor]]:
+ return output['loss']
+
+ def build_dataloader(self, data,
+ transform: TransformList = None,
+ training=False,
+ device=None,
+ logger: logging.Logger = None,
+ gradient_accumulation=1,
+ **kwargs) -> DataLoader:
+ transform = copy(transform)
+ transform.append(unpack_ner)
+ dataset = BiaffineNamedEntityRecognizer.build_dataset(self, data, self.vocabs, transform)
+ if self.vocabs.mutable:
+ BiaffineNamedEntityRecognizer.build_vocabs(self, dataset, logger, self.vocabs)
+ return PadSequenceDataLoader(
+ batch_sampler=self.sampler_builder.build(self.compute_lens(data, dataset), shuffle=training,
+ gradient_accumulation=gradient_accumulation),
+ device=device,
+ dataset=dataset)
+
+ def build_model(self, encoder_size, training=True, **kwargs) -> torch.nn.Module:
+ return BiaffineNamedEntityRecognitionDecoder(encoder_size, self.config.ffnn_size, len(self.vocabs.label),
+ self.config.loss_reduction)
+
+ def build_metric(self, **kwargs):
+ return BiaffineNamedEntityRecognizer.build_metric(self, **kwargs)
+
+ def input_is_flat(self, data) -> bool:
+ return BiaffineNamedEntityRecognizer.input_is_flat(data)
+
+ def prediction_to_result(self, prediction: Dict[str, Any], batch: Dict[str, Any]) -> List:
+ results = []
+ BiaffineNamedEntityRecognizer.prediction_to_result(batch['token'], prediction, results,
+ ret_tokens=self.config.get('ret_tokens', ' '))
+ return results
diff --git a/hanlp/components/mtl/tasks/ner/tag_ner.py b/hanlp/components/mtl/tasks/ner/tag_ner.py
new file mode 100644
index 000000000..340ab1f32
--- /dev/null
+++ b/hanlp/components/mtl/tasks/ner/tag_ner.py
@@ -0,0 +1,162 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-03 14:35
+import logging
+from typing import Union, List, Dict, Any, Iterable, Callable, Set
+
+import torch
+from hanlp_trie import DictInterface
+from torch.utils.data import DataLoader
+
+from hanlp.common.dataset import SamplerBuilder, PadSequenceDataLoader
+from hanlp.common.transform import VocabDict
+from hanlp.components.mtl.tasks import Task
+from hanlp.components.ner.transformer_ner import TransformerNamedEntityRecognizer
+from hanlp.layers.crf.crf import CRF
+from hanlp.layers.scalar_mix import ScalarMixWithDropoutBuilder
+from hanlp.metrics.metric import Metric
+from hanlp.metrics.mtl import MetricDict
+from hanlp_common.util import merge_locals_kwargs
+
+
+class LinearCRFDecoder(torch.nn.Module):
+ def __init__(self,
+ hidden_size,
+ num_labels,
+ secondary_encoder=None,
+ crf=False) -> None:
+ super().__init__()
+ self.secondary_encoder = secondary_encoder
+ self.classifier = torch.nn.Linear(hidden_size, num_labels)
+ self.crf = CRF(num_labels) if crf else None
+
+ def forward(self, contextualized_embeddings: torch.FloatTensor, batch: Dict[str, torch.Tensor], mask=None):
+ if self.secondary_encoder:
+ contextualized_embeddings = self.secondary_encoder(contextualized_embeddings, mask=mask)
+ return self.classifier(contextualized_embeddings)
+
+
+class TaggingNamedEntityRecognition(Task, TransformerNamedEntityRecognizer):
+
+ def __init__(self,
+ trn: str = None,
+ dev: str = None,
+ tst: str = None,
+ sampler_builder: SamplerBuilder = None,
+ dependencies: str = None,
+ scalar_mix: ScalarMixWithDropoutBuilder = None,
+ use_raw_hidden_states=False,
+ lr=1e-3,
+ separate_optimizer=False,
+ max_seq_len=None,
+ sent_delimiter=None,
+ char_level=False,
+ hard_constraint=False,
+ tagging_scheme=None,
+ crf=False,
+ delimiter_in_entity=None,
+ secondary_encoder=None,
+ token_key='token',
+ dict_whitelist: Union[DictInterface, Union[Dict[str, Any], Set[str]]] = None,
+ dict_blacklist: Union[DictInterface, Union[Dict[str, Any], Set[str]]] = None,
+ **kwargs) -> None:
+ r"""A simple tagger using a linear layer with an optional CRF (:cite:`lafferty2001conditional`) layer for
+ NER task. It can utilize whitelist gazetteers which is dict mapping from entity name to entity type.
+ During decoding, it performs longest-prefix-matching of these words to override the prediction from
+ underlining statistical model. It also uses a blacklist to mask out mis-predicted entities.
+
+ .. Note:: For algorithm beginners, longest-prefix-matching is the prerequisite to understand what dictionary can
+ do and what it can't do. The tutorial in `this book `_ can be very helpful.
+
+ Args:
+ trn: Path to training set.
+ dev: Path to dev set.
+ tst: Path to test set.
+ sampler_builder: A builder which builds a sampler.
+ dependencies: Its dependencies on other tasks.
+ scalar_mix: A builder which builds a `ScalarMixWithDropout` object.
+ use_raw_hidden_states: Whether to use raw hidden states from transformer without any pooling.
+ lr: Learning rate for this task.
+ separate_optimizer: Use customized separate optimizer for this task.
+ max_seq_len: Sentences longer than ``max_seq_len`` will be split into shorter ones if possible.
+ sent_delimiter: Delimiter between sentences, like period or comma, which indicates a long sentence can
+ be split here.
+ char_level: Whether the sequence length is measured at char level, which is never the case for
+ lemmatization.
+ hard_constraint: Whether to enforce hard length constraint on sentences. If there is no ``sent_delimiter``
+ in a sentence, it will be split at a token anyway.
+ token_key: The key to tokens in dataset. This should always be set to ``token`` in MTL.
+ crf: ``True`` to enable CRF (:cite:`lafferty2001conditional`).
+ delimiter_in_entity: The delimiter between tokens in entity, which is used to rebuild entity by joining
+ tokens during decoding.
+ secondary_encoder: An optional secondary encoder to provide enhanced representation by taking the hidden
+ states from the main encoder as input.
+ token_key: The key to tokens in dataset. This should always be set to ``token`` in MTL.
+ dict_whitelist: A :class:`dict` or a :class:`~hanlp_trie.dictionary.DictInterface` of gazetteers to be
+ included into the final results.
+ dict_blacklist: A :class:`set` or a :class:`~hanlp_trie.dictionary.DictInterface` of badcases to be
+ excluded from the final results.
+ **kwargs:
+ """
+ super().__init__(**merge_locals_kwargs(locals(), kwargs))
+ self.vocabs = VocabDict()
+ self.secondary_encoder = secondary_encoder
+ self.dict_whitelist = dict_whitelist
+ self.dict_blacklist = dict_blacklist
+
+ def build_dataloader(self,
+ data,
+ transform: Callable = None,
+ training=False,
+ device=None,
+ logger: logging.Logger = None,
+ cache=False,
+ gradient_accumulation=1,
+ **kwargs) -> DataLoader:
+ args = dict((k, self.config[k]) for k in
+ ['delimiter', 'max_seq_len', 'sent_delimiter', 'char_level', 'hard_constraint'] if k in self.config)
+ dataset = self.build_dataset(data, cache=cache, transform=transform, **args)
+ dataset.append_transform(self.vocabs)
+ if self.vocabs.mutable:
+ self.build_vocabs(dataset, logger)
+ return PadSequenceDataLoader(
+ batch_sampler=self.sampler_builder.build(
+ self.compute_lens(data, dataset, 'token_input_ids', 'token'),
+ shuffle=training, gradient_accumulation=gradient_accumulation),
+ device=device,
+ dataset=dataset)
+
+ def compute_loss(self,
+ batch: Dict[str, Any],
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
+ criterion) -> Union[torch.FloatTensor, Dict[str, torch.FloatTensor]]:
+ return TransformerNamedEntityRecognizer.compute_loss(self, criterion, output, batch['tag_id'], batch['mask'])
+
+ def decode_output(self,
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
+ mask: torch.BoolTensor,
+ batch: Dict[str, Any],
+ decoder,
+ **kwargs) -> Union[Dict[str, Any], Any]:
+ return TransformerNamedEntityRecognizer.decode_output(self, output, batch['mask'], batch, decoder)
+
+ def update_metrics(self,
+ batch: Dict[str, Any],
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
+ prediction: Dict[str, Any],
+ metric: Union[MetricDict, Metric]):
+ return TransformerNamedEntityRecognizer.update_metrics(self, metric, output, batch['tag_id'], batch['mask'],
+ batch, prediction)
+
+ def build_model(self, encoder_size, training=True, **kwargs) -> torch.nn.Module:
+ return LinearCRFDecoder(encoder_size, len(self.vocabs['tag']), self.secondary_encoder, self.config.crf)
+
+ def build_metric(self, **kwargs):
+ return TransformerNamedEntityRecognizer.build_metric(self, **kwargs)
+
+ def input_is_flat(self, data) -> bool:
+ return TransformerNamedEntityRecognizer.input_is_flat(self, data)
+
+ def prediction_to_result(self, prediction: Dict[str, Any], batch: Dict[str, Any]) -> Union[List, Dict]:
+ return TransformerNamedEntityRecognizer.prediction_to_human(self, prediction, self.vocabs['tag'].idx_to_token,
+ batch)
diff --git a/hanlp/components/mtl/tasks/pos.py b/hanlp/components/mtl/tasks/pos.py
new file mode 100644
index 000000000..974a9176e
--- /dev/null
+++ b/hanlp/components/mtl/tasks/pos.py
@@ -0,0 +1,154 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-10-19 18:56
+import logging
+from typing import Dict, Any, Union, Iterable, Callable, List
+
+import torch
+from torch.utils.data import DataLoader
+
+from hanlp.common.dataset import SamplerBuilder, PadSequenceDataLoader
+from hanlp.common.transform import VocabDict
+from hanlp.components.mtl.tasks import Task
+from hanlp.components.taggers.transformers.transformer_tagger import TransformerTagger
+from hanlp.layers.crf.crf import CRF
+from hanlp.layers.scalar_mix import ScalarMixWithDropoutBuilder
+from hanlp.metrics.metric import Metric
+from hanlp.metrics.mtl import MetricDict
+from hanlp_common.util import merge_locals_kwargs
+
+
+class LinearCRFDecoder(torch.nn.Module):
+ def __init__(self,
+ hidden_size,
+ num_labels,
+ crf=False) -> None:
+ """A linear layer with an optional CRF (:cite:`lafferty2001conditional`) layer on top of it.
+
+ Args:
+ hidden_size: Size of hidden states.
+ num_labels: Size of tag set.
+ crf: ``True`` to enable CRF (:cite:`lafferty2001conditional`).
+ """
+ super().__init__()
+ self.classifier = torch.nn.Linear(hidden_size, num_labels)
+ self.crf = CRF(num_labels) if crf else None
+
+ def forward(self, contextualized_embeddings: torch.FloatTensor, batch: Dict[str, torch.Tensor], mask=None):
+ """
+
+ Args:
+ contextualized_embeddings: Hidden states for contextual layer.
+ batch: A dict of a batch.
+ mask: Mask for tokens.
+
+ Returns:
+ Logits. Users are expected to call ``CRF.decode`` on these emissions during decoding and ``CRF.forward``
+ during training.
+
+ """
+ return self.classifier(contextualized_embeddings)
+
+
+class TransformerTagging(Task, TransformerTagger):
+
+ def __init__(self,
+ trn: str = None,
+ dev: str = None,
+ tst: str = None,
+ sampler_builder: SamplerBuilder = None,
+ dependencies: str = None,
+ scalar_mix: ScalarMixWithDropoutBuilder = None,
+ use_raw_hidden_states=False,
+ lr=1e-3,
+ separate_optimizer=False,
+ cls_is_bos=False,
+ sep_is_eos=False,
+ max_seq_len=None,
+ sent_delimiter=None,
+ char_level=False,
+ hard_constraint=False,
+ crf=False,
+ token_key='token', **kwargs) -> None:
+ """A simple tagger using a linear layer with an optional CRF (:cite:`lafferty2001conditional`) layer for
+ any tagging tasks including PoS tagging and many others.
+
+ Args:
+ trn: Path to training set.
+ dev: Path to dev set.
+ tst: Path to test set.
+ sampler_builder: A builder which builds a sampler.
+ dependencies: Its dependencies on other tasks.
+ scalar_mix: A builder which builds a `ScalarMixWithDropout` object.
+ use_raw_hidden_states: Whether to use raw hidden states from transformer without any pooling.
+ lr: Learning rate for this task.
+ separate_optimizer: Use customized separate optimizer for this task.
+ cls_is_bos: ``True`` to treat the first token as ``BOS``.
+ sep_is_eos: ``True`` to treat the last token as ``EOS``.
+ max_seq_len: Sentences longer than ``max_seq_len`` will be split into shorter ones if possible.
+ sent_delimiter: Delimiter between sentences, like period or comma, which indicates a long sentence can
+ be split here.
+ char_level: Whether the sequence length is measured at char level, which is never the case for
+ lemmatization.
+ hard_constraint: Whether to enforce hard length constraint on sentences. If there is no ``sent_delimiter``
+ in a sentence, it will be split at a token anyway.
+ crf: ``True`` to enable CRF (:cite:`lafferty2001conditional`).
+ token_key: The key to tokens in dataset. This should always be set to ``token`` in MTL.
+ **kwargs: Not used.
+ """
+ super().__init__(**merge_locals_kwargs(locals(), kwargs))
+ self.vocabs = VocabDict()
+
+ def build_dataloader(self,
+ data,
+ transform: Callable = None,
+ training=False,
+ device=None,
+ logger: logging.Logger = None,
+ cache=False,
+ gradient_accumulation=1,
+ **kwargs) -> DataLoader:
+ args = dict((k, self.config[k]) for k in
+ ['delimiter', 'max_seq_len', 'sent_delimiter', 'char_level', 'hard_constraint'] if k in self.config)
+ dataset = self.build_dataset(data, cache=cache, transform=transform, **args)
+ dataset.append_transform(self.vocabs)
+ if self.vocabs.mutable:
+ self.build_vocabs(dataset, logger)
+ return PadSequenceDataLoader(
+ batch_sampler=self.sampler_builder.build(self.compute_lens(data, dataset, 'token_input_ids', 'token'),
+ shuffle=training, gradient_accumulation=gradient_accumulation),
+ device=device,
+ dataset=dataset)
+
+ def compute_loss(self,
+ batch: Dict[str, Any],
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
+ criterion) -> Union[torch.FloatTensor, Dict[str, torch.FloatTensor]]:
+ return TransformerTagger.compute_loss(self, criterion, output, batch['tag_id'], batch['mask'])
+
+ def decode_output(self,
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
+ mask: torch.BoolTensor,
+ batch: Dict[str, Any],
+ decoder,
+ **kwargs) -> Union[Dict[str, Any], Any]:
+ return TransformerTagger.decode_output(self, output, mask, batch, decoder)
+
+ def update_metrics(self,
+ batch: Dict[str, Any],
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
+ prediction: Dict[str, Any],
+ metric: Union[MetricDict, Metric]):
+ return TransformerTagger.update_metrics(self, metric, output, batch['tag_id'], batch['mask'])
+
+ def build_model(self, encoder_size, training=True, **kwargs) -> torch.nn.Module:
+ return LinearCRFDecoder(encoder_size, len(self.vocabs['tag']), self.config.crf)
+
+ def build_metric(self, **kwargs):
+ return TransformerTagger.build_metric(self, **kwargs)
+
+ def input_is_flat(self, data) -> bool:
+ return TransformerTagger.input_is_flat(self, data)
+
+ def prediction_to_result(self, prediction: Dict[str, Any], batch: Dict[str, Any]) -> Union[List, Dict]:
+ return TransformerTagger.prediction_to_human(self, prediction, self.vocabs['tag'].idx_to_token, batch)
diff --git a/hanlp/components/mtl/tasks/sdp.py b/hanlp/components/mtl/tasks/sdp.py
new file mode 100644
index 000000000..42dda531b
--- /dev/null
+++ b/hanlp/components/mtl/tasks/sdp.py
@@ -0,0 +1,169 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-08-13 21:39
+import logging
+from typing import Dict, Any, Union, Iterable, List
+
+import torch
+from torch.optim import Adam
+from torch.optim.lr_scheduler import ExponentialLR
+from torch.utils.data import DataLoader
+
+from hanlp.common.dataset import SamplerBuilder, PadSequenceDataLoader
+from hanlp.common.transform import VocabDict, TransformList
+from hanlp.components.mtl.tasks import Task
+from hanlp.components.parsers.biaffine.biaffine_model import BiaffineDecoder
+from hanlp.components.parsers.biaffine.biaffine_sdp import BiaffineSemanticDependencyParser
+from hanlp.layers.scalar_mix import ScalarMixWithDropoutBuilder
+from hanlp.metrics.metric import Metric
+from hanlp.metrics.mtl import MetricDict
+from hanlp.utils.time_util import CountdownTimer
+from hanlp_common.util import merge_locals_kwargs
+
+
+class BiaffineSemanticDependencyParsing(Task, BiaffineSemanticDependencyParser):
+ def __init__(self,
+ trn: str = None,
+ dev: str = None,
+ tst: str = None,
+ sampler_builder: SamplerBuilder = None,
+ dependencies: str = None,
+ scalar_mix: ScalarMixWithDropoutBuilder = None,
+ use_raw_hidden_states=False,
+ lr=2e-3, separate_optimizer=False,
+ punct=False,
+ tree=True,
+ pad_rel=None,
+ apply_constraint=False,
+ single_root=True,
+ no_zero_head=None,
+ n_mlp_arc=500,
+ n_mlp_rel=100,
+ mlp_dropout=.33,
+ mu=.9,
+ nu=.9,
+ epsilon=1e-12,
+ decay=.75,
+ decay_steps=5000,
+ cls_is_bos=True,
+ use_pos=False,
+ **kwargs) -> None:
+ r"""Implementation of "Stanford's graph-based neural dependency parser at
+ the conll 2017 shared task" (:cite:`dozat2017stanford`).
+
+ Args:
+ trn: Path to training set.
+ dev: Path to dev set.
+ tst: Path to test set.
+ sampler_builder: A builder which builds a sampler.
+ dependencies: Its dependencies on other tasks.
+ scalar_mix: A builder which builds a `ScalarMixWithDropout` object.
+ use_raw_hidden_states: Whether to use raw hidden states from transformer without any pooling.
+ lr: Learning rate for this task.
+ separate_optimizer: Use customized separate optimizer for this task.
+ punct: ``True`` to include punctuations in evaluation.
+ pad_rel: Padding token for relations.
+ apply_constraint: Enforce constraints (see following parameters).
+ single_root: Force single root.
+ no_zero_head: Every token has at least one head.
+ n_mlp_arc: Number of features for arc representation.
+ n_mlp_rel: Number of features for rel representation.
+ mlp_dropout: Dropout applied to MLPs.
+ mu: First coefficient used for computing running averages of gradient and its square in Adam.
+ nu: Second coefficient used for computing running averages of gradient and its square in Adam.
+ epsilon: Term added to the denominator to improve numerical stability
+ decay: Decay rate for exceptional lr scheduler.
+ decay_steps: Decay every ``decay_steps`` steps.
+ cls_is_bos: ``True`` to treat the first token as ``BOS``.
+ use_pos: Use pos feature.
+ **kwargs: Not used.
+ """
+ super().__init__(**merge_locals_kwargs(locals(), kwargs))
+ self.vocabs = VocabDict()
+
+ def update_metrics(self, batch: Dict[str, Any],
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
+ prediction: Dict[str, Any], metric: Union[MetricDict, Metric]):
+ BiaffineSemanticDependencyParser.update_metric(self, *prediction, batch['arc'], batch['rel_id'], output[1],
+ output[-1], metric, batch)
+
+ def decode_output(self,
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
+ mask: torch.BoolTensor,
+ batch: Dict[str, Any],
+ decoder, **kwargs) -> Union[Dict[str, Any], Any]:
+ (arc_scores, rel_scores), mask, punct_mask = output
+ return BiaffineSemanticDependencyParser.decode(self, arc_scores, rel_scores, mask, batch)
+
+ def compute_loss(self, batch: Dict[str, Any],
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any], criterion) -> \
+ Union[torch.FloatTensor, Dict[str, torch.FloatTensor]]:
+ (arc_scores, rel_scores), mask, punct_mask = output
+ return BiaffineSemanticDependencyParser.compute_loss(self, arc_scores, rel_scores, batch['arc'],
+ batch['rel_id'], mask, criterion,
+ batch)
+
+ def build_model(self, encoder_size, training=True, **kwargs) -> torch.nn.Module:
+ return BiaffineDecoder(encoder_size, self.config.n_mlp_arc, self.config.n_mlp_rel, self.config.mlp_dropout,
+ len(self.vocabs.rel))
+
+ def build_metric(self, **kwargs):
+ return BiaffineSemanticDependencyParser.build_metric(self, **kwargs)
+
+ def build_dataloader(self, data, transform: TransformList = None, training=False, device=None,
+ logger: logging.Logger = None, gradient_accumulation=1, **kwargs) -> DataLoader:
+ if isinstance(data, list):
+ data = BiaffineSemanticDependencyParser.build_samples(self, data, self.config.use_pos)
+ dataset = BiaffineSemanticDependencyParser.build_dataset(self, data, transform)
+ if isinstance(data, str):
+ dataset.purge_cache()
+ if self.vocabs.mutable:
+ BiaffineSemanticDependencyParser.build_vocabs(self, dataset, logger, transformer=True)
+ if dataset.cache:
+ timer = CountdownTimer(len(dataset))
+ BiaffineSemanticDependencyParser.cache_dataset(self, dataset, timer, training, logger)
+ return PadSequenceDataLoader(
+ batch_sampler=self.sampler_builder.build(self.compute_lens(data, dataset), shuffle=training,
+ gradient_accumulation=gradient_accumulation),
+ device=device,
+ dataset=dataset,
+ pad=self.get_pad_dict())
+
+ def feed_batch(self, h: torch.FloatTensor, batch: Dict[str, torch.Tensor], mask: torch.BoolTensor,
+ decoder: torch.nn.Module):
+ logits = super().feed_batch(h, batch, mask, decoder)
+ arc_scores = logits[0]
+ mask = mask.clone()
+ mask[:, 0] = 0
+ mask = self.convert_to_3d_mask(arc_scores, mask)
+ punct_mask = self.convert_to_3d_puncts(batch.get('punct_mask', None), mask)
+ return logits, mask, punct_mask
+
+ def build_optimizer(self, decoder: torch.nn.Module, **kwargs):
+ config = self.config
+ optimizer = Adam(decoder.parameters(),
+ config.lr,
+ (config.mu, config.nu),
+ config.epsilon)
+ scheduler = ExponentialLR(optimizer, config.decay ** (1 / config.decay_steps))
+ return optimizer, scheduler
+
+ def input_is_flat(self, data) -> bool:
+ return BiaffineSemanticDependencyParser.input_is_flat(self, data, self.config.use_pos)
+
+ def prediction_to_result(self, prediction: Dict[str, Any], batch: Dict[str, Any]) -> List:
+ arcs, rels = prediction
+ arcs = arcs[:, 1:, :] # Skip the ROOT
+ rels = rels[:, 1:, :]
+ arcs = arcs.tolist()
+ rels = rels.tolist()
+ vocab = self.vocabs['rel'].idx_to_token
+ for arcs_per_sent, rels_per_sent, tokens in zip(arcs, rels, batch['token']):
+ tokens = tokens[1:]
+ sent_len = len(tokens)
+ result = []
+ for a, r in zip(arcs_per_sent[:sent_len], rels_per_sent[:sent_len]):
+ heads = [i for i in range(sent_len + 1) if a[i]]
+ deprels = [vocab[r[i]] for i in range(sent_len + 1) if a[i]]
+ result.append(list(zip(heads, deprels)))
+ yield result
diff --git a/hanlp/components/mtl/tasks/srl/__init__.py b/hanlp/components/mtl/tasks/srl/__init__.py
new file mode 100644
index 000000000..34bb68447
--- /dev/null
+++ b/hanlp/components/mtl/tasks/srl/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-04 16:49
diff --git a/hanlp/components/mtl/tasks/srl/bio_srl.py b/hanlp/components/mtl/tasks/srl/bio_srl.py
new file mode 100644
index 000000000..1e0832a6e
--- /dev/null
+++ b/hanlp/components/mtl/tasks/srl/bio_srl.py
@@ -0,0 +1,123 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-04 16:50
+import logging
+from typing import Dict, Any, List, Union, Iterable, Callable
+
+import torch
+from torch.utils.data import DataLoader
+
+from hanlp.common.dataset import PadSequenceDataLoader, SamplerBuilder
+from hanlp.common.transform import VocabDict
+from hanlp.components.mtl.tasks import Task
+from hanlp.components.srl.span_bio.baffine_tagging import BiaffineTaggingDecoder
+from hanlp.components.srl.span_bio.span_bio import SpanBIOSemanticRoleLabeler
+from hanlp.layers.scalar_mix import ScalarMixWithDropoutBuilder
+from hanlp.metrics.metric import Metric
+from hanlp.metrics.mtl import MetricDict
+from hanlp_common.util import merge_locals_kwargs
+import torch.nn.functional as F
+
+
+class SpanBIOSemanticRoleLabeling(Task, SpanBIOSemanticRoleLabeler):
+
+ def __init__(self,
+ trn: str = None,
+ dev: str = None,
+ tst: str = None,
+ sampler_builder: SamplerBuilder = None,
+ dependencies: str = None,
+ scalar_mix: ScalarMixWithDropoutBuilder = None,
+ use_raw_hidden_states=False,
+ lr=None,
+ separate_optimizer=False,
+ cls_is_bos=False,
+ sep_is_eos=False,
+ crf=False,
+ n_mlp_rel=300,
+ mlp_dropout=0.2,
+ loss_reduction='mean',
+ doc_level_offset=True,
+ **kwargs) -> None:
+ """A span based Semantic Role Labeling task using BIO scheme for tagging the role of each token. Given a
+ predicate and a token, it uses biaffine (:cite:`dozat:17a`) to predict their relations as one of BIO-ROLE.
+
+ Args:
+ trn: Path to training set.
+ dev: Path to dev set.
+ tst: Path to test set.
+ sampler_builder: A builder which builds a sampler.
+ dependencies: Its dependencies on other tasks.
+ scalar_mix: A builder which builds a `ScalarMixWithDropout` object.
+ use_raw_hidden_states: Whether to use raw hidden states from transformer without any pooling.
+ lr: Learning rate for this task.
+ separate_optimizer: Use customized separate optimizer for this task.
+ cls_is_bos: ``True`` to treat the first token as ``BOS``.
+ sep_is_eos: ``True`` to treat the last token as ``EOS``.
+ crf: ``True`` to enable CRF (:cite:`lafferty2001conditional`).
+ n_mlp_rel: Output size of MLPs for representing predicate and tokens.
+ mlp_dropout: Dropout applied to MLPs.
+ loss_reduction: Loss reduction for aggregating losses.
+ doc_level_offset: ``True`` to indicate the offsets in ``jsonlines`` are of document level.
+ **kwargs: Not used.
+ """
+ super().__init__(**merge_locals_kwargs(locals(), kwargs))
+ self.vocabs = VocabDict()
+
+ def build_dataloader(self, data, transform: Callable = None, training=False, device=None,
+ logger: logging.Logger = None, cache=False, gradient_accumulation=1, **kwargs) -> DataLoader:
+ dataset = self.build_dataset(data, transform=[transform, self.vocabs])
+ if self.vocabs.mutable:
+ SpanBIOSemanticRoleLabeler.build_vocabs(self, dataset, logger)
+ return PadSequenceDataLoader(
+ batch_sampler=self.sampler_builder.build(self.compute_lens(data, dataset), shuffle=training,
+ gradient_accumulation=gradient_accumulation),
+ device=device,
+ dataset=dataset)
+
+ def compute_loss(self, batch: Dict[str, Any],
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any], criterion) -> \
+ Union[torch.FloatTensor, Dict[str, torch.FloatTensor]]:
+ pred, mask = output
+ return SpanBIOSemanticRoleLabeler.compute_loss(self, criterion, pred, batch['srl_id'], mask)
+
+ def decode_output(self,
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
+ mask: torch.BoolTensor,
+ batch: Dict[str, Any],
+ decoder: torch.nn.Module, **kwargs) -> Union[Dict[str, Any], Any]:
+ pred, mask = output
+ return SpanBIOSemanticRoleLabeler.decode_output(self, pred, mask, batch, decoder)
+
+ def update_metrics(self, batch: Dict[str, Any],
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
+ prediction: Dict[str, Any], metric: Union[MetricDict, Metric]):
+ return SpanBIOSemanticRoleLabeler.update_metrics(self, metric, prediction, batch)
+
+ def build_model(self, encoder_size, training=True, **kwargs) -> torch.nn.Module:
+ return BiaffineTaggingDecoder(
+ len(self.vocabs['srl']),
+ encoder_size,
+ self.config.n_mlp_rel,
+ self.config.mlp_dropout,
+ self.config.crf,
+ )
+
+ def feed_batch(self, h: torch.FloatTensor, batch: Dict[str, torch.Tensor], mask: torch.BoolTensor,
+ decoder: torch.nn.Module):
+ pred = decoder(h)
+ mask3d = self.compute_mask(mask)
+ if self.config.crf:
+ token_index = mask3d[0]
+ pred = pred.flatten(end_dim=1)[token_index]
+ pred = F.log_softmax(pred, dim=-1)
+ return pred, mask3d
+
+ def build_metric(self, **kwargs):
+ return SpanBIOSemanticRoleLabeler.build_metric(self)
+
+ def input_is_flat(self, data) -> bool:
+ return SpanBIOSemanticRoleLabeler.input_is_flat(self, data)
+
+ def prediction_to_result(self, prediction: List, batch: Dict[str, Any]) -> List:
+ yield from SpanBIOSemanticRoleLabeler.prediction_to_result(self, prediction, batch)
diff --git a/hanlp/components/mtl/tasks/srl/rank_srl.py b/hanlp/components/mtl/tasks/srl/rank_srl.py
new file mode 100644
index 000000000..2bdcba8b5
--- /dev/null
+++ b/hanlp/components/mtl/tasks/srl/rank_srl.py
@@ -0,0 +1,120 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-08-05 15:43
+import logging
+from typing import Union, List, Dict, Any, Iterable, Callable
+
+import torch
+from torch.utils.data import DataLoader
+
+from hanlp.common.dataset import SamplerBuilder, PadSequenceDataLoader
+from hanlp.common.transform import VocabDict
+from hanlp.components.mtl.tasks import Task
+from hanlp.components.srl.span_rank.span_rank import SpanRankingSemanticRoleLabeler
+from hanlp.components.srl.span_rank.span_ranking_srl_model import SpanRankingSRLDecoder
+from hanlp.layers.scalar_mix import ScalarMixWithDropoutBuilder
+from hanlp.metrics.metric import Metric
+from hanlp.metrics.mtl import MetricDict
+from hanlp_common.util import merge_locals_kwargs
+
+
+class SpanRankingSemanticRoleLabeling(Task, SpanRankingSemanticRoleLabeler):
+
+ def __init__(self, trn: str = None, dev: str = None, tst: str = None, sampler_builder: SamplerBuilder = None,
+ dependencies: str = None, scalar_mix: ScalarMixWithDropoutBuilder = None, use_raw_hidden_states=False,
+ lr=1e-3, separate_optimizer=False,
+ lexical_dropout=0.5,
+ dropout=0.2,
+ span_width_feature_size=20,
+ ffnn_size=150,
+ ffnn_depth=2,
+ argument_ratio=0.8,
+ predicate_ratio=0.4,
+ max_arg_width=30,
+ mlp_label_size=100,
+ enforce_srl_constraint=False,
+ use_gold_predicates=False,
+ doc_level_offset=True,
+ use_biaffine=False,
+ loss_reduction='mean',
+ with_argument=' ',
+ **kwargs) -> None:
+ r""" An implementation of "Jointly Predicting Predicates and Arguments in Neural Semantic Role Labeling"
+ (:cite:`he-etal-2018-jointly`). It generates candidates triples of (predicate, arg_start, arg_end) and rank them.
+
+ Args:
+ trn: Path to training set.
+ dev: Path to dev set.
+ tst: Path to test set.
+ sampler_builder: A builder which builds a sampler.
+ dependencies: Its dependencies on other tasks.
+ scalar_mix: A builder which builds a `ScalarMixWithDropout` object.
+ use_raw_hidden_states: Whether to use raw hidden states from transformer without any pooling.
+ lr: Learning rate for this task.
+ separate_optimizer: Use customized separate optimizer for this task.
+ lexical_dropout: Dropout applied to hidden states of encoder.
+ dropout: Dropout used for other layers except the encoder.
+ span_width_feature_size: Span width feature size.
+ ffnn_size: Feedforward size.
+ ffnn_depth: Number of layers of feedforward MLPs.
+ argument_ratio: Ratio of candidate arguments over number of tokens.
+ predicate_ratio: Ratio of candidate predicates over number of tokens.
+ max_arg_width: Maximum argument width.
+ mlp_label_size: Feature size for label representation.
+ enforce_srl_constraint: Enforce SRL constraints (number of core ARGs etc.).
+ use_gold_predicates: Use gold predicates instead of predicting them.
+ doc_level_offset: ``True`` to indicate the offsets in ``jsonlines`` are of document level.
+ use_biaffine: ``True`` to use biaffine (:cite:`dozat:17a`) instead of lineary layer for label prediction.
+ loss_reduction: The loss reduction used in aggregating losses.
+ with_argument: The delimiter between tokens in arguments to be used for joining tokens for outputs.
+ **kwargs: Not used.
+ """
+ super().__init__(**merge_locals_kwargs(locals(), kwargs))
+ self.vocabs = VocabDict()
+
+ def build_dataloader(self, data, transform: Callable = None, training=False, device=None,
+ logger: logging.Logger = None, gradient_accumulation=1, **kwargs) -> DataLoader:
+ dataset = self.build_dataset(data, isinstance(data, list), logger, transform)
+ return PadSequenceDataLoader(
+ batch_sampler=self.sampler_builder.build(self.compute_lens(data, dataset), shuffle=training,
+ gradient_accumulation=gradient_accumulation),
+ device=device,
+ dataset=dataset)
+
+ def update_metrics(self, batch: Dict[str, Any],
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
+ prediction: Dict[str, Any], metric: Union[MetricDict, Metric]):
+ return SpanRankingSemanticRoleLabeler.update_metrics(self, batch, {'prediction': prediction},
+ tuple(metric.values()))
+
+ def decode_output(self,
+ output: Dict[str, Any],
+ mask: torch.BoolTensor,
+ batch: Dict[str, Any],
+ decoder, **kwargs) -> Union[Dict[str, Any], Any]:
+ return SpanRankingSemanticRoleLabeler.decode_output(self, output, batch)
+
+ def compute_loss(self, batch: Dict[str, Any],
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any], criterion) -> \
+ Union[torch.FloatTensor, Dict[str, torch.FloatTensor]]:
+ return output['loss']
+
+ def build_model(self, encoder_size, training=True, **kwargs) -> torch.nn.Module:
+ return SpanRankingSRLDecoder(encoder_size, len(self.vocabs.srl_label), self.config)
+
+ def build_metric(self, **kwargs):
+ predicate_f1, end_to_end_f1 = SpanRankingSemanticRoleLabeler.build_metric(self, **kwargs)
+ return MetricDict({'predicate': predicate_f1, 'e2e': end_to_end_f1})
+
+ def build_criterion(self, **kwargs):
+ pass
+
+ def input_is_flat(self, data) -> bool:
+ return SpanRankingSemanticRoleLabeler.input_is_flat(self, data)
+
+ def prediction_to_result(self, prediction: Dict[str, Any], batch: Dict[str, Any]) -> List:
+ return SpanRankingSemanticRoleLabeler.format_dict_to_results(batch['token'], prediction, exclusive_offset=True,
+ with_predicate=True,
+ with_argument=self.config.get('with_argument',
+ ' '),
+ label_first=True)
diff --git a/tests/script/__init__.py b/hanlp/components/mtl/tasks/tok/__init__.py
similarity index 62%
rename from tests/script/__init__.py
rename to hanlp/components/mtl/tasks/tok/__init__.py
index a2812a725..d56f3cdb0 100644
--- a/tests/script/__init__.py
+++ b/hanlp/components/mtl/tasks/tok/__init__.py
@@ -1,3 +1,3 @@
# -*- coding:utf-8 -*-
# Author: hankcs
-# Date: 2020-01-01 17:34
\ No newline at end of file
+# Date: 2020-08-11 16:34
\ No newline at end of file
diff --git a/hanlp/components/mtl/tasks/tok/reg_tok.py b/hanlp/components/mtl/tasks/tok/reg_tok.py
new file mode 100644
index 000000000..e01a4866b
--- /dev/null
+++ b/hanlp/components/mtl/tasks/tok/reg_tok.py
@@ -0,0 +1,109 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-08-02 16:51
+import logging
+from typing import Union, List, Dict, Any, Iterable, Tuple
+
+import torch
+from alnlp.modules import util
+from torch import Tensor
+from torch.utils.data import DataLoader
+
+from hanlp.common.dataset import SamplerBuilder, PadSequenceDataLoader
+from hanlp.common.transform import FieldLength, TransformList
+from hanlp.components.mtl.tasks import Task
+from hanlp.datasets.tokenization.txt import TextTokenizingDataset
+from hanlp.layers.scalar_mix import ScalarMixWithDropoutBuilder
+from hanlp.layers.transformers.pt_imports import PreTrainedTokenizer
+from hanlp.metrics.chunking.binary_chunking_f1 import BinaryChunkingF1
+from hanlp.transform.transformer_tokenizer import TransformerSequenceTokenizer
+from hanlp_common.util import merge_locals_kwargs
+
+
+def generate_token_span_tuple(sample: dict):
+ prefix_mask = sample.get('text_prefix_mask', None)
+ if prefix_mask:
+ sample['span_tuple'] = spans = []
+ previous_prefix = 0
+ prefix_mask_ = prefix_mask[1:-1]
+ for i, mask in enumerate(prefix_mask_):
+ if i and mask:
+ spans.append((previous_prefix, i))
+ previous_prefix = i
+ spans.append((previous_prefix, len(prefix_mask_)))
+ return sample
+
+
+class RegressionTokenizingDecoder(torch.nn.Linear):
+
+ def __init__(self, in_features: int, out_features: int = 1, bias: bool = ...) -> None:
+ super().__init__(in_features, out_features, bias)
+
+ # noinspection PyMethodOverriding
+ def forward(self, input: Tensor, **kwargs) -> Tensor:
+ return super().forward(input[:, 1:-1, :]).squeeze_(-1)
+
+
+class RegressionTokenization(Task):
+
+ def __init__(self, trn: str = None, dev: str = None, tst: str = None, sampler_builder: SamplerBuilder = None,
+ dependencies: str = None, scalar_mix: ScalarMixWithDropoutBuilder = None,
+ use_raw_hidden_states=True, lr=1e-3, separate_optimizer=False, delimiter=None,
+ max_seq_len=None, sent_delimiter=None) -> None:
+ super().__init__(**merge_locals_kwargs(locals()))
+
+ def build_criterion(self, **kwargs):
+ return torch.nn.BCEWithLogitsLoss(reduction='mean')
+
+ def build_metric(self, **kwargs):
+ return BinaryChunkingF1()
+
+ # noinspection PyMethodOverriding
+ def build_model(self, encoder_size, training=True, **kwargs) -> torch.nn.Module:
+ return RegressionTokenizingDecoder(encoder_size)
+
+ def predict(self, data: Union[str, List[str]], batch_size: int = None, **kwargs):
+ pass
+
+ def build_dataloader(self,
+ data,
+ transform: TransformList = None,
+ training=False,
+ device=None,
+ logger: logging.Logger = None,
+ tokenizer: PreTrainedTokenizer = None,
+ **kwargs) -> DataLoader:
+ assert tokenizer
+ dataset = TextTokenizingDataset(data, cache=isinstance(data, str), delimiter=self.config.sent_delimiter,
+ generate_idx=isinstance(data, list),
+ max_seq_len=self.config.max_seq_len,
+ sent_delimiter=self.config.sent_delimiter,
+ transform=[
+ TransformerSequenceTokenizer(tokenizer,
+ 'text',
+ ret_prefix_mask=True,
+ ret_subtokens=True,
+ ),
+ FieldLength('text_input_ids', 'text_input_ids_length', delta=-2),
+ generate_token_span_tuple])
+ return PadSequenceDataLoader(
+ batch_sampler=self.sampler_builder.build(self.compute_lens(data, dataset, 'text_input_ids', 'text'),
+ shuffle=training),
+ device=device,
+ dataset=dataset)
+
+ def decode_output(self,
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
+ batch: Dict[str, Any], **kwargs) -> List[Tuple[int, int]]:
+ spans = BinaryChunkingF1.decode_spans(output > 0, batch['text_input_ids_length'])
+ return spans
+
+ def update_metrics(self, batch: Dict[str, Any],
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
+ prediction: List[Tuple[int, int]], metric: BinaryChunkingF1):
+ metric.update(prediction, batch['span_tuple'])
+
+ def compute_loss(self, batch: Dict[str, Any],
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any], criterion):
+ mask = util.lengths_to_mask(batch['text_input_ids_length'])
+ return criterion(output[mask], batch['text_prefix_mask'][:, 1:-1][mask].to(torch.float))
diff --git a/hanlp/components/mtl/tasks/tok/tag_tok.py b/hanlp/components/mtl/tasks/tok/tag_tok.py
new file mode 100644
index 000000000..422325ab9
--- /dev/null
+++ b/hanlp/components/mtl/tasks/tok/tag_tok.py
@@ -0,0 +1,198 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-08-11 16:35
+import logging
+from typing import Dict, Any, Union, Iterable, List, Set
+
+import torch
+from torch.utils.data import DataLoader
+
+from hanlp.common.dataset import SamplerBuilder, PadSequenceDataLoader
+from hanlp.common.transform import VocabDict, TransformList
+from hanlp.components.mtl.tasks import Task
+from hanlp.components.tokenizers.transformer import TransformerTaggingTokenizer
+from hanlp.layers.crf.crf import CRF
+from hanlp.layers.scalar_mix import ScalarMixWithDropoutBuilder
+from hanlp.metrics.metric import Metric
+from hanlp.metrics.mtl import MetricDict
+from hanlp.transform.transformer_tokenizer import TransformerSequenceTokenizer
+from hanlp_common.util import merge_locals_kwargs
+from hanlp_trie import DictInterface, TrieDict
+
+
+class LinearCRFDecoder(torch.nn.Module):
+ def __init__(self,
+ hidden_size,
+ num_labels,
+ crf=False) -> None:
+ super().__init__()
+ self.classifier = torch.nn.Linear(hidden_size, num_labels)
+ self.crf = CRF(num_labels) if crf else None
+
+ def forward(self, contextualized_embeddings: torch.FloatTensor, batch: Dict[str, torch.Tensor], mask=None):
+ return self.classifier(contextualized_embeddings[:, 1:-1, :])
+
+
+class TaggingTokenization(Task, TransformerTaggingTokenizer):
+
+ def __init__(self,
+ trn: str = None,
+ dev: str = None,
+ tst: str = None,
+ sampler_builder: SamplerBuilder = None,
+ dependencies: str = None,
+ scalar_mix: ScalarMixWithDropoutBuilder = None,
+ use_raw_hidden_states=False,
+ lr=1e-3, separate_optimizer=False,
+ cls_is_bos=True,
+ sep_is_eos=True,
+ delimiter=None,
+ max_seq_len=None, sent_delimiter=None, char_level=False, hard_constraint=False,
+ transform=None,
+ tagging_scheme='BMES',
+ crf=False,
+ token_key='token', **kwargs) -> None:
+ """Tokenization which casts a chunking problem into a tagging problem.
+ This task has to create batch of tokens containing both [CLS] and [SEP] since it's usually the first task
+ and later tasks might need them.
+
+ Args:
+ trn: Path to training set.
+ dev: Path to dev set.
+ tst: Path to test set.
+ sampler_builder: A builder which builds a sampler.
+ dependencies: Its dependencies on other tasks.
+ scalar_mix: A builder which builds a `ScalarMixWithDropout` object.
+ use_raw_hidden_states: Whether to use raw hidden states from transformer without any pooling.
+ lr: Learning rate for this task.
+ separate_optimizer: Use customized separate optimizer for this task.
+ cls_is_bos: ``True`` to treat the first token as ``BOS``.
+ sep_is_eos: ``True`` to treat the last token as ``EOS``.
+ delimiter: Delimiter used to split a line in the corpus.
+ max_seq_len: Sentences longer than ``max_seq_len`` will be split into shorter ones if possible.
+ sent_delimiter: Delimiter between sentences, like period or comma, which indicates a long sentence can
+ be split here.
+ char_level: Whether the sequence length is measured at char level.
+ hard_constraint: Whether to enforce hard length constraint on sentences. If there is no ``sent_delimiter``
+ in a sentence, it will be split at a token anyway.
+ transform: An optional transform to be applied to samples. Usually a character normalization transform is
+ passed in.
+ tagging_scheme: Either ``BMES`` or ``BI``.
+ crf: ``True`` to enable CRF (:cite:`lafferty2001conditional`).
+ token_key: The key to tokens in dataset. This should always be set to ``token`` in MTL.
+ **kwargs: Not used.
+ """
+ super().__init__(**merge_locals_kwargs(locals(), kwargs))
+ self.transform = transform
+ self.vocabs = VocabDict()
+
+ def build_dataloader(self, data, transform: TransformList = None, training=False, device=None,
+ logger: logging.Logger = None, cache=False, gradient_accumulation=1, **kwargs) -> DataLoader:
+ args = dict((k, self.config[k]) for k in
+ ['delimiter', 'max_seq_len', 'sent_delimiter', 'char_level', 'hard_constraint'] if k in self.config)
+ # We only need those transforms before TransformerTokenizer
+ transformer_index = transform.index_by_type(TransformerSequenceTokenizer)
+ assert transformer_index is not None
+ transform = transform[:transformer_index + 1]
+ if self.transform:
+ transform.insert(0, self.transform)
+ transform.append(self.last_transform())
+ dataset = self.build_dataset(data, cache=cache, transform=transform, **args)
+ if self.vocabs.mutable:
+ self.build_vocabs(dataset, logger)
+ return PadSequenceDataLoader(
+ batch_sampler=self.sampler_builder.build(self.compute_lens(data, dataset, 'token_input_ids'),
+ shuffle=training, gradient_accumulation=gradient_accumulation),
+ device=device,
+ dataset=dataset)
+
+ def compute_loss(self,
+ batch: Dict[str, Any],
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
+ criterion) -> Union[torch.FloatTensor, Dict[str, torch.FloatTensor]]:
+ return TransformerTaggingTokenizer.compute_loss(self, criterion, output, batch['tag_id'], batch['mask'])
+
+ def decode_output(self, output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
+ mask: torch.BoolTensor, batch: Dict[str, Any], decoder, **kwargs) -> Union[Dict[str, Any], Any]:
+ return TransformerTaggingTokenizer.decode_output(self, output, mask, batch, decoder)
+
+ def update_metrics(self, batch: Dict[str, Any],
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
+ prediction: Dict[str, Any], metric: Union[MetricDict, Metric]):
+ TransformerTaggingTokenizer.update_metrics(self, metric, output, batch['tag_id'], None, batch, prediction)
+
+ def build_model(self, encoder_size, training=True, **kwargs) -> torch.nn.Module:
+ return LinearCRFDecoder(encoder_size, len(self.vocabs['tag']), self.config.crf)
+
+ def build_metric(self, **kwargs):
+ return TransformerTaggingTokenizer.build_metric(self)
+
+ def build_criterion(self, model=None, **kwargs):
+ return TransformerTaggingTokenizer.build_criterion(self, model=model, reduction='mean')
+
+ def input_is_flat(self, data) -> bool:
+ return TransformerTaggingTokenizer.input_is_flat(self, data)
+
+ def prediction_to_result(self, prediction: Dict[str, Any], batch: Dict[str, Any]) -> Union[List, Dict]:
+ return TransformerTaggingTokenizer.prediction_to_human(self, prediction, None, batch, rebuild_span=True)
+
+ def build_tokenizer(self, tokenizer: TransformerSequenceTokenizer):
+ # The transform for tokenizer needs very special settings, ensure these settings are set properly.
+ return TransformerSequenceTokenizer(
+ tokenizer.tokenizer,
+ tokenizer.input_key,
+ tokenizer.output_key,
+ tokenizer.max_seq_length,
+ tokenizer.truncate_long_sequences,
+ ret_subtokens=True,
+ ret_subtokens_group=True,
+ ret_token_span=True,
+ cls_is_bos=True,
+ sep_is_eos=True,
+ use_fast=tokenizer.tokenizer.is_fast,
+ dict_force=self.dict_force,
+ strip_cls_sep=False,
+ )
+
+ def build_samples(self, inputs, cls_is_bos=False, sep_is_eos=False):
+ return [{self.config.token_key: sent} for sent in inputs]
+
+ @property
+ def dict_force(self) -> DictInterface:
+ return TransformerTaggingTokenizer.dict_force.fget(self)
+
+ @dict_force.setter
+ def dict_force(self, dictionary: Union[DictInterface, Union[Dict[str, Any], Set[str]]]):
+ if not isinstance(dictionary, DictInterface):
+ dictionary = TrieDict(dictionary)
+ self.config.dict_force = dictionary
+
+ @property
+ def dict_combine(self) -> DictInterface:
+ return TransformerTaggingTokenizer.dict_combine.fget(self)
+
+ @dict_combine.setter
+ def dict_combine(self, dictionary: Union[DictInterface, Union[Dict[str, Any], Set[str]]]):
+ # noinspection PyArgumentList
+ TransformerTaggingTokenizer.dict_combine.fset(self, dictionary)
+
+ def transform_batch(self, batch: Dict[str, Any], results: Dict[str, Any] = None, cls_is_bos=False,
+ sep_is_eos=False) -> Dict[str, Any]:
+ """
+ This method is overrode to honor the zero indexed token used in custom dict. Although for a tokenizer,
+ cls_is_bos = sep_is_eos = True, its tokens don't contain [CLS] or [SEP]. This behaviour is adopted from the
+ early versions and it is better kept to avoid migration efforts.
+
+
+ Args:
+ batch: A batch of samples.
+ results: Predicted results from other tasks which might be useful for this task to utilize. Say a dep task
+ uses both token and pos as features, then it will need both tok and pos results to make a batch.
+ cls_is_bos: First token in this batch is BOS.
+ sep_is_eos: Last token in this batch is EOS.
+
+ Returns:
+ A batch.
+
+ """
+ return batch
diff --git a/hanlp/components/mtl/tasks/ud.py b/hanlp/components/mtl/tasks/ud.py
new file mode 100644
index 000000000..dd11f3fcc
--- /dev/null
+++ b/hanlp/components/mtl/tasks/ud.py
@@ -0,0 +1,170 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-17 21:54
+import logging
+from typing import Dict, Any, List, Union, Iterable, Callable
+
+import torch
+from torch.utils.data import DataLoader
+
+from hanlp.common.dataset import SamplerBuilder, PadSequenceDataLoader
+from hanlp_common.document import Document
+from hanlp.common.transform import VocabDict, PunctuationMask
+from hanlp.components.mtl.tasks import Task
+from hanlp_common.conll import CoNLLUWord
+from hanlp.components.parsers.ud.ud_model import UniversalDependenciesDecoder
+from hanlp.components.parsers.ud.ud_parser import UniversalDependenciesParser
+from hanlp.components.parsers.ud.util import generate_lemma_rule, append_bos
+from hanlp.layers.scalar_mix import ScalarMixWithDropoutBuilder
+from hanlp.metrics.metric import Metric
+from hanlp.metrics.mtl import MetricDict
+from hanlp_common.util import merge_locals_kwargs
+
+
+class UniversalDependenciesParsing(Task, UniversalDependenciesParser):
+
+ def __init__(self,
+ trn: str = None,
+ dev: str = None,
+ tst: str = None,
+ sampler_builder: SamplerBuilder = None,
+ dependencies: str = None,
+ scalar_mix: ScalarMixWithDropoutBuilder = None,
+ use_raw_hidden_states=False,
+ lr=None,
+ separate_optimizer=False,
+ cls_is_bos=True,
+ sep_is_eos=False,
+ n_mlp_arc=768,
+ n_mlp_rel=256,
+ mlp_dropout=.33,
+ tree=False,
+ proj=False,
+ punct=False,
+ max_seq_len=None,
+ **kwargs) -> None:
+ r"""Universal Dependencies Parsing (lemmatization, features, PoS tagging and dependency parsing) implementation
+ of "75 Languages, 1 Model: Parsing Universal Dependencies Universally" (:cite:`kondratyuk-straka-2019-75`).
+
+ Args:
+ trn: Path to training set.
+ dev: Path to dev set.
+ tst: Path to test set.
+ sampler_builder: A builder which builds a sampler.
+ dependencies: Its dependencies on other tasks.
+ scalar_mix: A builder which builds a `ScalarMixWithDropout` object.
+ use_raw_hidden_states: Whether to use raw hidden states from transformer without any pooling.
+ lr: Learning rate for this task.
+ separate_optimizer: Use customized separate optimizer for this task.
+ cls_is_bos: ``True`` to treat the first token as ``BOS``.
+ sep_is_eos: ``True`` to treat the last token as ``EOS``.
+ n_mlp_arc: Number of features for arc representation.
+ n_mlp_rel: Number of features for rel representation.
+ mlp_dropout: Dropout applied to MLPs.
+ tree: ``True`` to enforce tree constraint.
+ proj: ``True`` for projective parsing.
+ punct: ``True`` to include punctuations in evaluation.
+ max_seq_len: Prune samples longer than this length. Useful for reducing GPU consumption.
+ **kwargs: Not used.
+ """
+ super().__init__(**merge_locals_kwargs(locals(), kwargs))
+ self.vocabs = VocabDict()
+
+ def build_dataloader(self, data, transform: Callable = None, training=False, device=None,
+ logger: logging.Logger = None, cache=False, gradient_accumulation=1, **kwargs) -> DataLoader:
+ _transform = [generate_lemma_rule, append_bos, self.vocabs, transform]
+ if isinstance(data, str) and not self.config.punct:
+ _transform.append(PunctuationMask('token', 'punct_mask'))
+ dataset = UniversalDependenciesParser.build_dataset(self, data, _transform)
+ if self.vocabs.mutable:
+ UniversalDependenciesParser.build_vocabs(self, dataset, logger, transformer=True)
+ max_seq_len = self.config.get('max_seq_len', None)
+ if max_seq_len and isinstance(data, str):
+ dataset.prune(lambda x: len(x['token_input_ids']) > max_seq_len, logger)
+ return PadSequenceDataLoader(
+ batch_sampler=self.sampler_builder.build(self.compute_lens(data, dataset, length_field='token'),
+ shuffle=training, gradient_accumulation=gradient_accumulation),
+ device=device,
+ dataset=dataset,
+ pad={'arc': 0})
+
+ def compute_loss(self, batch: Dict[str, Any],
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any], criterion) -> \
+ Union[torch.FloatTensor, Dict[str, torch.FloatTensor]]:
+ return output[0]['loss'] / 4 # we have 4 tasks
+
+ def decode_output(self, output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
+ mask: torch.BoolTensor, batch: Dict[str, Any], decoder: torch.nn.Module, **kwargs) -> Union[
+ Dict[str, Any], Any]:
+ return UniversalDependenciesParser.decode_output(self, *output, batch)
+
+ def update_metrics(self, batch: Dict[str, Any],
+ output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
+ prediction: Dict[str, Any], metric: Union[MetricDict, Metric]):
+ UniversalDependenciesParser.update_metrics(self, metric, batch, *output)
+
+ # noinspection PyMethodOverriding
+ def build_model(self,
+ encoder_size,
+ n_mlp_arc,
+ n_mlp_rel,
+ mlp_dropout,
+ training=True,
+ **kwargs) -> torch.nn.Module:
+ return UniversalDependenciesDecoder(
+ encoder_size,
+ n_mlp_arc,
+ n_mlp_rel,
+ mlp_dropout,
+ len(self.vocabs.rel),
+ len(self.vocabs.lemma),
+ len(self.vocabs.pos),
+ len(self.vocabs.feat),
+ 0,
+ 0
+ )
+
+ def build_metric(self, **kwargs):
+ return UniversalDependenciesParser.build_metric(self)
+
+ def input_is_flat(self, data) -> bool:
+ return UniversalDependenciesParser.input_is_flat(self, data)
+
+ def prediction_to_result(self, prediction: Dict[str, Any], batch: Dict[str, Any]) -> List:
+ yield from UniversalDependenciesParser.prediction_to_human(self, prediction, batch)
+
+ def feed_batch(self, h: torch.FloatTensor, batch: Dict[str, torch.Tensor], mask: torch.BoolTensor,
+ decoder: torch.nn.Module):
+ mask = self.compute_mask(batch)
+ output_dict = decoder(h, batch, mask)
+ if decoder.training:
+ mask = mask.clone()
+ mask[:, 0] = 0
+ return output_dict, mask
+
+ def finalize_document(self, doc: Document, task_name: str):
+ lem = []
+ pos = []
+ feat = []
+ dep = []
+ for sent in doc[task_name]:
+ sent: List[CoNLLUWord] = sent
+ lem.append([x.lemma for x in sent])
+ pos.append([x.upos for x in sent])
+ feat.append([x.feats for x in sent])
+ dep.append([(x.head, x.deprel) for x in sent])
+ promoted = 0
+ if 'lem' not in doc:
+ doc['lem'] = lem
+ promoted += 1
+ if 'pos' not in doc:
+ doc['pos'] = pos
+ promoted += 1
+ if 'feat' not in doc:
+ doc['fea'] = feat
+ promoted += 1
+ if 'dep' not in doc:
+ doc['dep'] = dep
+ promoted += 1
+ if promoted == 4:
+ doc.pop(task_name)
diff --git a/hanlp/components/ner/__init__.py b/hanlp/components/ner/__init__.py
new file mode 100644
index 000000000..625a51750
--- /dev/null
+++ b/hanlp/components/ner/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-07-21 17:22
\ No newline at end of file
diff --git a/hanlp/components/ner/biaffine_ner/__init__.py b/hanlp/components/ner/biaffine_ner/__init__.py
new file mode 100644
index 000000000..c53450b4d
--- /dev/null
+++ b/hanlp/components/ner/biaffine_ner/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-07-21 18:41
\ No newline at end of file
diff --git a/hanlp/components/ner/biaffine_ner/biaffine_ner.py b/hanlp/components/ner/biaffine_ner/biaffine_ner.py
new file mode 100644
index 000000000..85b0a5a0f
--- /dev/null
+++ b/hanlp/components/ner/biaffine_ner/biaffine_ner.py
@@ -0,0 +1,401 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-07-09 18:13
+import logging
+from typing import Union, List, Callable, Dict, Any
+
+from hanlp_common.constant import IDX
+from hanlp.common.structure import History
+from hanlp.components.ner.biaffine_ner.biaffine_ner_model import BiaffineNamedEntityRecognitionModel
+from hanlp.datasets.ner.json_ner import JsonNERDataset, unpack_ner
+from hanlp.layers.transformers.utils import build_optimizer_scheduler_with_transformer
+import torch
+from torch.utils.data import DataLoader
+from hanlp.common.dataset import PadSequenceDataLoader
+from hanlp.common.torch_component import TorchComponent
+from hanlp.common.transform import FieldLength, TransformList
+from hanlp.common.vocab import Vocab
+from hanlp.layers.embeddings.embedding import Embedding
+from hanlp.metrics.f1 import F1
+from hanlp.utils.time_util import CountdownTimer
+from hanlp_common.util import merge_locals_kwargs, reorder
+
+
+class BiaffineNamedEntityRecognizer(TorchComponent):
+
+ def __init__(self, **kwargs) -> None:
+ """An implementation of Named Entity Recognition as Dependency Parsing (:cite:`yu-etal-2020-named`). It treats
+ every possible span as a candidate of entity and predicts its entity label. Non-entity spans are assigned NULL
+ label to be excluded. The label prediction is done with a biaffine layer (:cite:`dozat:17a`). As it makes no
+ assumption about the spans, it naturally supports flat NER and nested NER.
+
+ Args:
+ **kwargs: Predefined config.
+ """
+ super().__init__(**kwargs)
+ self.model: BiaffineNamedEntityRecognitionModel = None
+
+ def build_optimizer(self,
+ trn,
+ epochs,
+ lr,
+ adam_epsilon,
+ weight_decay,
+ warmup_steps,
+ transformer_lr,
+ **kwargs):
+ # noinspection PyProtectedMember
+ if self.use_transformer:
+ num_training_steps = len(trn) * epochs // self.config.get('gradient_accumulation', 1)
+ optimizer, scheduler = build_optimizer_scheduler_with_transformer(self.model,
+ self._get_transformer(),
+ lr, transformer_lr,
+ num_training_steps, warmup_steps,
+ weight_decay, adam_epsilon)
+ else:
+ optimizer = torch.optim.Adam(self.model.parameters(), self.config.lr)
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+ optimizer=optimizer,
+ mode='max',
+ factor=0.5,
+ patience=2,
+ verbose=True,
+ )
+ return optimizer, scheduler
+
+ @property
+ def use_transformer(self):
+ return 'token' not in self.vocabs
+
+ def _get_transformer(self):
+ return getattr(self.model_.embed, 'transformer', None)
+
+ def build_criterion(self, **kwargs):
+ pass
+
+ # noinspection PyProtectedMember
+ def build_metric(self, **kwargs) -> F1:
+ return F1()
+
+ def execute_training_loop(self,
+ trn: DataLoader,
+ dev: DataLoader,
+ epochs,
+ criterion,
+ optimizer,
+ metric,
+ save_dir,
+ logger: logging.Logger,
+ devices,
+ gradient_accumulation=1,
+ **kwargs):
+ best_epoch, best_metric = 0, -1
+ optimizer, scheduler = optimizer
+ history = History()
+ timer = CountdownTimer(epochs)
+ ratio_width = len(f'{len(trn)}/{len(trn)}')
+ for epoch in range(1, epochs + 1):
+ logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
+ self.fit_dataloader(trn, criterion, optimizer, metric, logger, history=history,
+ gradient_accumulation=gradient_accumulation,
+ linear_scheduler=scheduler if self._get_transformer() else None)
+ if dev:
+ self.evaluate_dataloader(dev, criterion, metric, logger, ratio_width=ratio_width)
+ report = f'{timer.elapsed_human}/{timer.total_time_human}'
+ dev_score = metric.score
+ if not self._get_transformer():
+ scheduler.step(dev_score)
+ if dev_score > best_metric:
+ self.save_weights(save_dir)
+ best_metric = dev_score
+ report += ' [red]saved[/red]'
+ timer.log(report, ratio_percentage=False, newline=True, ratio=False)
+ return best_metric
+
+ def fit_dataloader(self,
+ trn: DataLoader,
+ criterion,
+ optimizer,
+ metric,
+ logger: logging.Logger,
+ linear_scheduler=None,
+ history: History = None,
+ gradient_accumulation=1,
+ **kwargs):
+ self.model.train()
+ timer = CountdownTimer(history.num_training_steps(len(trn), gradient_accumulation=gradient_accumulation))
+ total_loss = 0
+ self.reset_metrics(metric)
+ for batch in trn:
+ optimizer.zero_grad()
+ output_dict = self.feed_batch(batch)
+ self.update_metrics(batch, output_dict, metric)
+ loss = output_dict['loss']
+ if gradient_accumulation and gradient_accumulation > 1:
+ loss /= gradient_accumulation
+ loss.backward()
+ total_loss += loss.item()
+ if history.step(gradient_accumulation):
+ if self.config.grad_norm:
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_norm)
+ optimizer.step()
+ if linear_scheduler:
+ linear_scheduler.step()
+ timer.log(self.report_metrics(total_loss / (timer.current + 1), metric), ratio_percentage=None,
+ logger=logger)
+ del loss
+ return total_loss / timer.total
+
+ # noinspection PyMethodOverriding
+ @torch.no_grad()
+ def evaluate_dataloader(self,
+ data: DataLoader,
+ criterion: Callable,
+ metric,
+ logger,
+ ratio_width=None,
+ output=False,
+ **kwargs):
+ self.model.eval()
+ self.reset_metrics(metric)
+ timer = CountdownTimer(len(data))
+ total_loss = 0
+ if output:
+ fp = open(output, 'w')
+ for batch in data:
+ output_dict = self.feed_batch(batch)
+ if output:
+ for sent, pred, gold in zip(batch['token'], output_dict['prediction'], batch['ner']):
+ fp.write('Tokens\t' + ' '.join(sent) + '\n')
+ fp.write('Pred\t' + '\t'.join(
+ ['[' + ' '.join(sent[x:y + 1]) + f']/{label}' for x, y, label in pred]) + '\n')
+ fp.write('Gold\t' + '\t'.join(
+ ['[' + ' '.join(sent[x:y + 1]) + f']/{label}' for x, y, label in gold]) + '\n')
+ fp.write('\n')
+ self.update_metrics(batch, output_dict, metric)
+ loss = output_dict['loss']
+ total_loss += loss.item()
+ timer.log(self.report_metrics(total_loss / (timer.current + 1), metric), ratio_percentage=None,
+ logger=logger,
+ ratio_width=ratio_width)
+ del loss
+ if output:
+ fp.close()
+ return total_loss / timer.total, metric
+
+ def build_model(self,
+ training=True,
+ **kwargs) -> torch.nn.Module:
+ # noinspection PyTypeChecker
+ # embed: torch.nn.Embedding = self.config.embed.module(vocabs=self.vocabs)[0].embed
+ model = BiaffineNamedEntityRecognitionModel(self.config,
+ self.config.embed.module(vocabs=self.vocabs),
+ self.config.context_layer,
+ len(self.vocabs.label))
+ return model
+
+ # noinspection PyMethodOverriding
+ def build_dataloader(self, data, batch_size, shuffle, device, logger: logging.Logger = None, vocabs=None,
+ sampler_builder=None,
+ gradient_accumulation=1,
+ **kwargs) -> DataLoader:
+ if vocabs is None:
+ vocabs = self.vocabs
+ transform = TransformList(unpack_ner, FieldLength('token'))
+ if isinstance(self.config.embed, Embedding):
+ transform.append(self.config.embed.transform(vocabs=vocabs))
+ transform.append(self.vocabs)
+ dataset = self.build_dataset(data, vocabs, transform)
+ if vocabs.mutable:
+ self.build_vocabs(dataset, logger, vocabs)
+ if 'token' in vocabs:
+ lens = [x['token'] for x in dataset]
+ else:
+ lens = [len(x['token_input_ids']) for x in dataset]
+ if sampler_builder:
+ sampler = sampler_builder.build(lens, shuffle, gradient_accumulation)
+ else:
+ sampler = None
+ return PadSequenceDataLoader(batch_sampler=sampler,
+ device=device,
+ dataset=dataset)
+
+ def build_dataset(self, data, vocabs, transform):
+ dataset = JsonNERDataset(data, transform=transform,
+ doc_level_offset=self.config.get('doc_level_offset', True),
+ tagset=self.config.get('tagset', None))
+ dataset.append_transform(vocabs)
+ if isinstance(data, str):
+ dataset.purge_cache() # Enable cache
+ return dataset
+
+ def predict(self, data: Union[List[str], List[List[str]]], batch_size: int = None, ret_tokens=True, **kwargs):
+ if not data:
+ return []
+ flat = self.input_is_flat(data)
+ if flat:
+ data = [data]
+ dataloader = self.build_dataloader([{'token': x} for x in data], batch_size, False, self.device)
+ predictions = []
+ orders = []
+ for batch in dataloader:
+ output_dict = self.feed_batch(batch)
+ token = batch['token']
+ prediction = output_dict['prediction']
+ self.prediction_to_result(token, prediction, predictions, ret_tokens)
+ orders.extend(batch[IDX])
+ predictions = reorder(predictions, orders)
+ if flat:
+ return predictions[0]
+ return predictions
+
+ @staticmethod
+ def prediction_to_result(token, prediction, predictions: List, ret_tokens: Union[bool, str]):
+ for tokens, ner in zip(token, prediction):
+ prediction_per_sent = []
+ for i, (b, e, l) in enumerate(ner):
+ if ret_tokens is not None:
+ entity = tokens[b: e + 1]
+ if isinstance(ret_tokens, str):
+ entity = ret_tokens.join(entity)
+ prediction_per_sent.append((entity, l, b, e + 1))
+ else:
+ prediction_per_sent.append((b, e + 1, l))
+ predictions.append(prediction_per_sent)
+
+ @staticmethod
+ def input_is_flat(data):
+ return isinstance(data[0], str)
+
+ # noinspection PyMethodOverriding
+ def fit(self,
+ trn_data,
+ dev_data,
+ save_dir,
+ embed: Embedding,
+ context_layer,
+ sampler='sorting',
+ n_buckets=32,
+ batch_size=50,
+ lexical_dropout=0.5,
+ ffnn_size=150,
+ is_flat_ner=True,
+ doc_level_offset=True,
+ lr=1e-3,
+ transformer_lr=1e-5,
+ adam_epsilon=1e-6,
+ weight_decay=0.01,
+ warmup_steps=0.1,
+ grad_norm=5.0,
+ epochs=50,
+ loss_reduction='sum',
+ gradient_accumulation=1,
+ ret_tokens=True,
+ tagset=None,
+ sampler_builder=None,
+ devices=None,
+ logger=None,
+ seed=None,
+ **kwargs
+ ):
+ """
+
+ Args:
+ trn_data: Path to training set.
+ dev_data: Path to dev set.
+ save_dir: The directory to save trained component.
+ embed: Embeddings to use.
+ context_layer: A contextualization layer (transformer or RNN).
+ sampler: Sampler to use.
+ n_buckets: Number of buckets to use in KMeans sampler.
+ batch_size: The number of samples in a batch.
+ lexical_dropout: Dropout applied to hidden states of context layer.
+ ffnn_size: Feedforward size for MLPs extracting the head/tail representations.
+ is_flat_ner: ``True`` for flat NER, otherwise nested NER.
+ doc_level_offset: ``True`` to indicate the offsets in ``jsonlines`` are of document level.
+ lr: Learning rate for decoder.
+ transformer_lr: Learning rate for encoder.
+ adam_epsilon: The epsilon to use in Adam.
+ weight_decay: The weight decay to use.
+ warmup_steps: The number of warmup steps.
+ grad_norm: Gradient norm for clipping.
+ epochs: The number of epochs to train.
+ loss_reduction: The loss reduction used in aggregating losses.
+ gradient_accumulation: Number of mini-batches per update step.
+ ret_tokens: A delimiter between tokens in entities so that the surface form of an entity can be rebuilt.
+ tagset: Optional tagset to prune entities outside of this tagset from datasets.
+ sampler_builder: The builder to build sampler, which will override batch_size.
+ devices: Devices this component will live on.
+ logger: Any :class:`logging.Logger` instance.
+ seed: Random seed to reproduce this training.
+ **kwargs: Not used.
+
+ Returns:
+ The best metrics on training set.
+ """
+ return super().fit(**merge_locals_kwargs(locals(), kwargs))
+
+ def build_vocabs(self, dataset, logger, vocabs, lock=True, label_vocab_name='label', **kwargs):
+ vocabs[label_vocab_name] = label_vocab = Vocab(pad_token=None, unk_token=None)
+ # Use null to indicate no relationship
+ label_vocab.add('')
+ timer = CountdownTimer(len(dataset))
+ for each in dataset:
+ timer.log('Building NER vocab [blink][yellow]...[/yellow][/blink]')
+ label_vocab.set_unk_as_safe_unk()
+ if lock:
+ vocabs.lock()
+ vocabs.summary(logger)
+
+ def reset_metrics(self, metrics):
+ metrics.reset()
+
+ def report_metrics(self, loss, metrics):
+ return f'loss: {loss:.4f} {metrics}'
+
+ def feed_batch(self, batch) -> Dict[str, Any]:
+ output_dict = self.model(batch)
+ output_dict['prediction'] = self.get_pred_ner(batch['token'], output_dict['candidate_ner_scores'])
+ return output_dict
+
+ def update_metrics(self, batch: dict, prediction: Union[Dict, List], metrics):
+ if isinstance(prediction, dict):
+ prediction = prediction['prediction']
+ assert len(prediction) == len(batch['ner'])
+ for pred, gold in zip(prediction, batch['ner']):
+ metrics(set(pred), set(gold))
+
+ def get_pred_ner(self, sentences, span_scores):
+ is_flat_ner = self.config.is_flat_ner
+ candidates = []
+ for sid, sent in enumerate(sentences):
+ for s in range(len(sent)):
+ for e in range(s, len(sent)):
+ candidates.append((sid, s, e))
+
+ top_spans = [[] for _ in range(len(sentences))]
+ span_scores_cpu = span_scores.tolist()
+ for i, type in enumerate(torch.argmax(span_scores, dim=-1).tolist()):
+ if type > 0:
+ sid, s, e = candidates[i]
+ top_spans[sid].append((s, e, type, span_scores_cpu[i][type]))
+
+ top_spans = [sorted(top_span, reverse=True, key=lambda x: x[3]) for top_span in top_spans]
+ sent_pred_mentions = [[] for _ in range(len(sentences))]
+ for sid, top_span in enumerate(top_spans):
+ for ns, ne, t, _ in top_span:
+ for ts, te, _ in sent_pred_mentions[sid]:
+ if ns < ts <= ne < te or ts < ns <= te < ne:
+ # for both nested and flat ner no clash is allowed
+ break
+ if is_flat_ner and (ns <= ts <= te <= ne or ts <= ns <= ne <= te):
+ # for flat ner nested mentions are not allowed
+ break
+ else:
+ sent_pred_mentions[sid].append((ns, ne, t))
+ pred_mentions = set((sid, s, e, t) for sid, spr in enumerate(sent_pred_mentions) for s, e, t in spr)
+ prediction = [[] for _ in range(len(sentences))]
+ idx_to_label = self.vocabs['label'].idx_to_token
+ for sid, s, e, t in sorted(pred_mentions):
+ prediction[sid].append((s, e, idx_to_label[t]))
+ return prediction
diff --git a/hanlp/components/ner/biaffine_ner/biaffine_ner_model.py b/hanlp/components/ner/biaffine_ner/biaffine_ner_model.py
new file mode 100644
index 000000000..78fa6acd3
--- /dev/null
+++ b/hanlp/components/ner/biaffine_ner/biaffine_ner_model.py
@@ -0,0 +1,127 @@
+from typing import Dict
+
+import torch
+import torch.nn.functional as F
+from alnlp.modules import util
+from alnlp.modules.time_distributed import TimeDistributed
+from torch import nn
+
+from ...parsers.biaffine.biaffine import Biaffine
+
+
+def initializer_1d(input_tensor, initializer):
+ assert len(input_tensor.size()) == 1
+ input_tensor = input_tensor.view(-1, 1)
+ input_tensor = initializer(input_tensor)
+ return input_tensor.view(-1)
+
+
+class BiaffineNamedEntityRecognitionModel(nn.Module):
+
+ def __init__(self, config, embed: torch.nn.Module, context_layer: torch.nn.Module, label_space_size):
+ super(BiaffineNamedEntityRecognitionModel, self).__init__()
+ self.config = config
+ self.lexical_dropout = float(self.config.lexical_dropout)
+ self.label_space_size = label_space_size
+
+ # Initialize layers and parameters
+ self.word_embedding_dim = embed.get_output_dim() # get the embedding dim
+ self.embed = embed
+ # Initialize context layer
+ self.context_layer = context_layer
+ context_layer_output_dim = context_layer.get_output_dim()
+
+ self.decoder = BiaffineNamedEntityRecognitionDecoder(context_layer_output_dim, config.ffnn_size,
+ label_space_size, config.loss_reduction)
+
+ def forward(self,
+ batch: Dict[str, torch.Tensor]
+ ):
+ keys = 'token_length', 'begin_offset', 'end_offset', 'label_id'
+ sent_lengths, gold_starts, gold_ends, gold_labels = [batch.get(k, None) for k in keys]
+ masks = util.lengths_to_mask(sent_lengths)
+ num_sentences, max_sent_length = masks.size()
+ raw_embeddings = self.embed(batch, mask=masks)
+
+ raw_embeddings = F.dropout(raw_embeddings, self.lexical_dropout, self.training)
+
+ contextualized_embeddings = self.context_layer(raw_embeddings, masks)
+ return self.decoder.decode(contextualized_embeddings, gold_starts, gold_ends, gold_labels, masks,
+ max_sent_length,
+ num_sentences, sent_lengths)
+
+
+class BiaffineNamedEntityRecognitionDecoder(nn.Module):
+ def __init__(self, hidden_size, ffnn_size, label_space_size, loss_reduction='sum') -> None:
+ """An implementation of the biaffine decoder in "Named Entity Recognition as Dependency Parsing"
+ (:cite:`yu-etal-2020-named`).
+
+ Args:
+ hidden_size: Size of hidden states.
+ ffnn_size: Feedforward size for MLPs extracting the head/tail representations.
+ label_space_size: Size of tag set.
+ loss_reduction: The loss reduction used in aggregating losses.
+ """
+ super().__init__()
+ self.loss_reduction = loss_reduction
+
+ # MLPs
+ def new_mlp():
+ return TimeDistributed(nn.Linear(hidden_size, ffnn_size))
+
+ self.start_mlp = new_mlp()
+ self.end_mlp = new_mlp()
+ self.biaffine = Biaffine(ffnn_size, label_space_size)
+
+ def forward(self, contextualized_embeddings: torch.FloatTensor, batch: Dict[str, torch.Tensor], mask=None):
+ keys = 'token_length', 'begin_offset', 'end_offset', 'label_id'
+ sent_lengths, gold_starts, gold_ends, gold_labels = [batch.get(k, None) for k in keys]
+ if mask is None:
+ mask = util.lengths_to_mask(sent_lengths)
+ num_sentences, max_sent_length = mask.size()
+ return self.decode(contextualized_embeddings, gold_starts, gold_ends, gold_labels, mask,
+ max_sent_length,
+ num_sentences, sent_lengths)
+
+ def get_dense_span_labels(self, span_starts, span_ends, span_labels, max_sentence_length):
+ num_sentences, max_spans_num = span_starts.size()
+
+ sentence_indices = torch.arange(0, num_sentences, device=span_starts.device).unsqueeze(1).expand(-1,
+ max_spans_num)
+
+ sparse_indices = torch.cat([sentence_indices.unsqueeze(2), span_starts.unsqueeze(2), span_ends.unsqueeze(2)],
+ dim=2)
+ rank = 3
+ dense_labels = torch.sparse.LongTensor(sparse_indices.view(num_sentences * max_spans_num, rank).t(),
+ span_labels.view(-1),
+ torch.Size([num_sentences] + [max_sentence_length] * (rank - 1))) \
+ .to_dense()
+ return dense_labels
+
+ def decode(self, contextualized_embeddings, gold_starts, gold_ends, gold_labels, masks, max_sent_length,
+ num_sentences, sent_lengths):
+ # Apply MLPs to starts and ends, [num_sentences, max_sentences_length,emb]
+ candidate_starts_emb = self.start_mlp(contextualized_embeddings)
+ candidate_ends_emb = self.end_mlp(contextualized_embeddings)
+ candidate_ner_scores = self.biaffine(candidate_starts_emb, candidate_ends_emb).permute([0, 2, 3, 1])
+
+ """generate candidate spans with argument pruning"""
+ # Generate masks
+ candidate_scores_mask = masks.unsqueeze(1) & masks.unsqueeze(2)
+ device = sent_lengths.device
+ sentence_ends_leq_starts = (
+ ~util.lengths_to_mask(torch.arange(max_sent_length, device=device), max_sent_length)) \
+ .unsqueeze_(0).expand(num_sentences, -1, -1)
+ candidate_scores_mask &= sentence_ends_leq_starts
+ candidate_ner_scores = candidate_ner_scores[candidate_scores_mask]
+ predict_dict = {
+ "candidate_ner_scores": candidate_ner_scores,
+
+ }
+ if gold_starts is not None:
+ gold_ner_labels = self.get_dense_span_labels(gold_starts, gold_ends, gold_labels, max_sent_length)
+ loss = torch.nn.functional.cross_entropy(candidate_ner_scores,
+ gold_ner_labels[candidate_scores_mask],
+ reduction=self.loss_reduction)
+ predict_dict['loss'] = loss
+ return predict_dict
diff --git a/hanlp/components/ner/rnn_ner.py b/hanlp/components/ner/rnn_ner.py
new file mode 100644
index 000000000..c41772031
--- /dev/null
+++ b/hanlp/components/ner/rnn_ner.py
@@ -0,0 +1,69 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-06-12 18:00
+from typing import Any
+
+import torch
+from alnlp.metrics import span_utils
+
+from hanlp.components.taggers.rnn_tagger import RNNTagger
+from hanlp.metrics.chunking.conlleval import SpanF1
+from hanlp_common.util import merge_locals_kwargs
+
+
+class RNNNamedEntityRecognizer(RNNTagger):
+
+ def __init__(self, **kwargs) -> None:
+ """An old-school RNN tagger using word2vec or fasttext embeddings.
+
+ Args:
+ **kwargs: Predefined config.
+ """
+ super().__init__(**kwargs)
+
+ def build_metric(self, **kwargs):
+ return SpanF1(self.tagging_scheme)
+
+ def evaluate_dataloader(self, data, criterion, logger=None, ratio_width=None, **kwargs):
+ loss, metric = super().evaluate_dataloader(data, criterion, logger, ratio_width, **kwargs)
+ if logger:
+ logger.info(metric.result(True, False)[-1])
+ return loss, metric
+
+ def fit(self, trn_data, dev_data, save_dir, batch_size=50, epochs=100, embed=100, rnn_input=None, rnn_hidden=256,
+ drop=0.5, lr=0.001, patience=10, crf=True, optimizer='adam', token_key='token', tagging_scheme=None,
+ anneal_factor: float = 0.5, delimiter=None, anneal_patience=2, devices=None,
+ token_delimiter=None,
+ logger=None,
+ verbose=True, **kwargs):
+ return super().fit(**merge_locals_kwargs(locals(), kwargs))
+
+ def update_metrics(self, metric, logits, y, mask, batch, prediction):
+ logits = self.decode_output(logits, mask, batch)
+ if isinstance(logits, torch.Tensor):
+ logits = logits.tolist()
+ metric(self._id_to_tags(logits), batch['tag'])
+
+ def predict(self, tokens: Any, batch_size: int = None, **kwargs):
+ return super().predict(tokens, batch_size, **kwargs)
+
+ def predict_data(self, data, batch_size, **kwargs):
+ outputs = super().predict_data(data, batch_size)
+ tagging_scheme = self.tagging_scheme
+ if tagging_scheme == 'IOBES':
+ entities = [span_utils.iobes_tags_to_spans(y) for y in outputs]
+ elif tagging_scheme == 'BIO':
+ entities = [span_utils.bio_tags_to_spans(y) for y in outputs]
+ elif tagging_scheme == 'BIOUL':
+ entities = [span_utils.bioul_tags_to_spans(y) for y in outputs]
+ else:
+ raise ValueError(f'Unrecognized tag scheme {tagging_scheme}')
+ for i, (tokens, es) in enumerate(zip(data, entities)):
+ outputs[i] = [(self.config.token_delimiter.join(tokens[b:e + 1]), t, b, e + 1) for t, (b, e) in es]
+ return outputs
+
+ def save_config(self, save_dir, filename='config.json'):
+ if self.config.token_delimiter is None:
+ self.config.token_delimiter = '' if all(
+ [len(x) == 1 for x in self.vocabs[self.config.token_key].idx_to_token[-100:]]) else ' '
+ super().save_config(save_dir, filename)
diff --git a/hanlp/components/ner/transformer_ner.py b/hanlp/components/ner/transformer_ner.py
new file mode 100644
index 000000000..a75fd5f0a
--- /dev/null
+++ b/hanlp/components/ner/transformer_ner.py
@@ -0,0 +1,217 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-10-07 11:08
+import functools
+from typing import Union, List, Dict, Any, Set
+
+import torch
+from hanlp_trie import DictInterface, TrieDict
+
+from hanlp.common.dataset import SamplerBuilder
+from hanlp.components.taggers.transformers.transformer_tagger import TransformerTagger
+from hanlp.metrics.chunking.sequence_labeling import get_entities
+from hanlp.metrics.f1 import F1
+from hanlp.datasets.ner.json_ner import prune_ner_tagset
+from hanlp.utils.string_util import guess_delimiter
+from hanlp_common.util import merge_locals_kwargs
+
+
+class TransformerNamedEntityRecognizer(TransformerTagger):
+
+ def __init__(self, **kwargs) -> None:
+ r"""A simple tagger using transformers and a linear layer with an optional CRF
+ (:cite:`lafferty2001conditional`) layer for
+ NER task. It can utilize whitelist gazetteers which is dict mapping from entity name to entity type.
+ During decoding, it performs longest-prefix-matching of these words to override the prediction from
+ underlining statistical model. It also uses a blacklist to mask out mis-predicted entities.
+
+ .. Note:: For algorithm beginners, longest-prefix-matching is the prerequisite to understand what dictionary can
+ do and what it can't do. The tutorial in `this book `_ can be very helpful.
+
+ Args:
+ **kwargs: Not used.
+ """
+ super().__init__(**kwargs)
+
+ def build_metric(self, **kwargs):
+ return F1()
+
+ # noinspection PyMethodOverriding
+ def update_metrics(self, metric, logits, y, mask, batch, prediction):
+ for p, g in zip(prediction, self.tag_to_span(batch['tag'], batch)):
+ pred = set(p)
+ gold = set(g)
+ metric(pred, gold)
+
+ # noinspection PyMethodOverriding
+ def decode_output(self, logits, mask, batch, model=None):
+ output = super().decode_output(logits, mask, batch, model)
+ if isinstance(output, torch.Tensor):
+ output = output.tolist()
+ prediction = self.id_to_tags(output, [len(x) for x in batch['token']])
+ return self.tag_to_span(prediction, batch)
+
+ def tag_to_span(self, batch_tags, batch):
+ spans = []
+ sents = batch[self.config.token_key]
+ dict_whitelist = self.dict_whitelist
+ dict_blacklist = self.dict_blacklist
+ for tags, tokens in zip(batch_tags, sents):
+ if dict_whitelist:
+ for start, end, label in dict_whitelist.tokenize(tokens):
+ if (tags[start].startswith('B') and tags[end - 1].startswith('E')) or all(
+ x.startswith('S') for x in tags[start:end]):
+ if end - start == 1:
+ tags[start] = 'S-' + label
+ else:
+ tags[start] = 'B-' + label
+ for i in range(start + 1, end - 1):
+ tags[i] = 'I-' + label
+ tags[end - 1] = 'E-' + label
+ entities = get_entities(tags)
+ if dict_blacklist:
+ pruned = []
+ delimiter_in_entity = self.config.get('delimiter_in_entity', ' ')
+ for label, start, end in entities:
+ entity = delimiter_in_entity.join(tokens[start:end])
+ if entity not in dict_blacklist:
+ pruned.append((label, start, end))
+ entities = pruned
+ spans.append(entities)
+ return spans
+
+ def decorate_spans(self, spans, batch):
+ batch_ner = []
+ delimiter_in_entity = self.config.get('delimiter_in_entity', ' ')
+ for spans_per_sent, tokens in zip(spans, batch.get(f'{self.config.token_key}_', batch[self.config.token_key])):
+ ner_per_sent = []
+ for label, start, end in spans_per_sent:
+ ner_per_sent.append((delimiter_in_entity.join(tokens[start:end]), label, start, end))
+ batch_ner.append(ner_per_sent)
+ return batch_ner
+
+ def generate_prediction_filename(self, tst_data, save_dir):
+ return super().generate_prediction_filename(tst_data.replace('.tsv', '.txt'), save_dir)
+
+ def prediction_to_human(self, pred, vocab, batch):
+ return self.decorate_spans(pred, batch)
+
+ def input_is_flat(self, tokens):
+ return tokens and isinstance(tokens, list) and isinstance(tokens[0], str)
+
+ def fit(self, trn_data, dev_data, save_dir, transformer,
+ delimiter_in_entity=None,
+ average_subwords=False,
+ word_dropout: float = 0.2,
+ hidden_dropout=None,
+ layer_dropout=0,
+ scalar_mix=None,
+ grad_norm=5.0,
+ lr=5e-5,
+ transformer_lr=None,
+ adam_epsilon=1e-8,
+ weight_decay=0,
+ warmup_steps=0.1,
+ crf=False,
+ secondary_encoder=None,
+ reduction='sum',
+ batch_size=32,
+ sampler_builder: SamplerBuilder = None,
+ epochs=3,
+ tagset=None,
+ token_key=None,
+ max_seq_len=None,
+ sent_delimiter=None,
+ char_level=False,
+ hard_constraint=False,
+ transform=None,
+ logger=None,
+ seed=None,
+ devices: Union[float, int, List[int]] = None,
+ **kwargs):
+ """Fit component to training set.
+
+ Args:
+ trn_data: Training set.
+ dev_data: Development set.
+ save_dir: The directory to save trained component.
+ transformer: An identifier of a pre-trained transformer.
+ delimiter_in_entity: The delimiter between tokens in entity, which is used to rebuild entity by joining
+ tokens during decoding.
+ average_subwords: ``True`` to average subword representations.
+ word_dropout: Dropout rate to randomly replace a subword with MASK.
+ hidden_dropout: Dropout rate applied to hidden states.
+ layer_dropout: Randomly zero out hidden states of a transformer layer.
+ scalar_mix: Layer attention.
+ grad_norm: Gradient norm for clipping.
+ lr: Learning rate for decoder.
+ transformer_lr: Learning for encoder.
+ adam_epsilon: The epsilon to use in Adam.
+ weight_decay: The weight decay to use.
+ warmup_steps: The number of warmup steps.
+ crf: ``True`` to enable CRF (:cite:`lafferty2001conditional`).
+ secondary_encoder: An optional secondary encoder to provide enhanced representation by taking the hidden
+ states from the main encoder as input.
+ reduction: The loss reduction used in aggregating losses.
+ batch_size: The number of samples in a batch.
+ sampler_builder: The builder to build sampler, which will override batch_size.
+ epochs: The number of epochs to train.
+ tagset: Optional tagset to prune entities outside of this tagset from datasets.
+ token_key: The key to tokens in dataset.
+ max_seq_len: The maximum sequence length. Sequence longer than this will be handled by sliding
+ window.
+ sent_delimiter: Delimiter between sentences, like period or comma, which indicates a long sentence can
+ be split here.
+ char_level: Whether the sequence length is measured at char level, which is never the case for
+ lemmatization.
+ hard_constraint: Whether to enforce hard length constraint on sentences. If there is no ``sent_delimiter``
+ in a sentence, it will be split at a token anyway.
+ transform: An optional transform to be applied to samples. Usually a character normalization transform is
+ passed in.
+ devices: Devices this component will live on.
+ logger: Any :class:`logging.Logger` instance.
+ seed: Random seed to reproduce this training.
+ **kwargs: Not used.
+
+ Returns:
+ The best metrics on training set.
+ """
+ return super().fit(**merge_locals_kwargs(locals(), kwargs))
+
+ def build_vocabs(self, trn, logger, **kwargs):
+ super().build_vocabs(trn, logger, **kwargs)
+ if self.config.get('delimiter_in_entity', None) is None:
+ # Check the first sample to guess the delimiter between tokens in a NE
+ tokens = trn[0][self.config.token_key]
+ delimiter_in_entity = guess_delimiter(tokens)
+ logger.info(f'Guess the delimiter between tokens in named entity could be [blue]"{delimiter_in_entity}'
+ f'"[/blue]. If not, specify `delimiter_in_entity` in `fit()`')
+ self.config.delimiter_in_entity = delimiter_in_entity
+
+ def build_dataset(self, data, transform=None, **kwargs):
+ dataset = super().build_dataset(data, transform, **kwargs)
+ if isinstance(data, str):
+ tagset = self.config.get('tagset', None)
+ if tagset:
+ dataset.append_transform(functools.partial(prune_ner_tagset, tagset=tagset))
+ return dataset
+
+ @property
+ def dict_whitelist(self) -> DictInterface:
+ return self.config.get('dict_whitelist', None)
+
+ @dict_whitelist.setter
+ def dict_whitelist(self, dictionary: Union[DictInterface, Union[Dict[str, Any], Set[str]]]):
+ if not isinstance(dictionary, DictInterface):
+ dictionary = TrieDict(dictionary)
+ self.config.dict_whitelist = dictionary
+
+ @property
+ def dict_blacklist(self) -> DictInterface:
+ return self.config.get('dict_blacklist', None)
+
+ @dict_blacklist.setter
+ def dict_blacklist(self, dictionary: Union[DictInterface, Union[Dict[str, Any], Set[str]]]):
+ if not isinstance(dictionary, DictInterface):
+ dictionary = TrieDict(dictionary)
+ self.config.dict_blacklist = dictionary
diff --git a/hanlp/components/ner.py b/hanlp/components/ner_tf.py
similarity index 81%
rename from hanlp/components/ner.py
rename to hanlp/components/ner_tf.py
index 79f61f00a..9f8283287 100644
--- a/hanlp/components/ner.py
+++ b/hanlp/components/ner_tf.py
@@ -5,16 +5,16 @@
from typing import Union, Any, Tuple, Iterable
import tensorflow as tf
-from hanlp.components.taggers.transformers.transformer_transform import TransformerTransform
+from hanlp.components.taggers.transformers.transformer_transform_tf import TransformerTransform
-from hanlp.common.transform import Transform
+from hanlp.common.transform_tf import Transform
-from hanlp.common.component import KerasComponent
-from hanlp.components.taggers.ngram_conv.ngram_conv_tagger import NgramConvTagger
-from hanlp.components.taggers.rnn_tagger import RNNTagger
-from hanlp.components.taggers.transformers.transformer_tagger import TransformerTagger
-from hanlp.metrics.chunking.sequence_labeling import get_entities, iobes_to_span
-from hanlp.utils.util import merge_locals_kwargs
+from hanlp.common.keras_component import KerasComponent
+from hanlp.components.taggers.ngram_conv.ngram_conv_tagger import NgramConvTaggerTF
+from hanlp.components.taggers.rnn_tagger_tf import RNNTaggerTF
+from hanlp.components.taggers.transformers.transformer_tagger_tf import TransformerTaggerTF
+from hanlp.metrics.chunking.sequence_labeling import iobes_to_span
+from hanlp_common.util import merge_locals_kwargs
class IOBES_NamedEntityRecognizer(KerasComponent, ABC):
@@ -26,12 +26,13 @@ def predict_batch(self, batch, inputs=None):
class IOBES_Transform(Transform):
- def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, inputs=None, X=None, batch=None) -> Iterable:
+ def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, inputs=None, X=None,
+ batch=None) -> Iterable:
for words, tags in zip(inputs, super().Y_to_outputs(Y, gold, inputs=inputs, X=X, batch=batch)):
yield from iobes_to_span(words, tags)
-class RNNNamedEntityRecognizer(RNNTagger, IOBES_NamedEntityRecognizer):
+class RNNNamedEntityRecognizerTF(RNNTaggerTF, IOBES_NamedEntityRecognizer):
def fit(self, trn_data: str, dev_data: str = None, save_dir: str = None, embeddings=100, embedding_trainable=False,
rnn_input_dropout=0.2, rnn_units=100, rnn_output_dropout=0.2, epochs=20, logger=None,
@@ -52,7 +53,7 @@ def build_loss(self, loss, **kwargs):
return super().build_loss(loss, **kwargs)
-class NgramConvNamedEntityRecognizer(NgramConvTagger, IOBES_NamedEntityRecognizer):
+class NgramConvNamedEntityRecognizerTF(NgramConvTaggerTF, IOBES_NamedEntityRecognizer):
def fit(self, trn_data: Any, dev_data: Any, save_dir: str, word_embed: Union[str, int, dict] = 200,
ngram_embed: Union[str, int, dict] = 50, embedding_trainable=True, window_size=4, kernel_size=3,
@@ -69,7 +70,7 @@ class IOBES_TransformerTransform(IOBES_Transform, TransformerTransform):
pass
-class TransformerNamedEntityRecognizer(TransformerTagger):
+class TransformerNamedEntityRecognizerTF(TransformerTaggerTF):
def __init__(self, transform: TransformerTransform = None) -> None:
if not transform:
diff --git a/hanlp/components/parsers/alg.py b/hanlp/components/parsers/alg.py
index 302c86a05..27076ec12 100644
--- a/hanlp/components/parsers/alg.py
+++ b/hanlp/components/parsers/alg.py
@@ -1,81 +1,761 @@
-# -*- coding:utf-8 -*-
-# Author: hankcs
-# Date: 2019-12-26 19:49
-# Ported from the PyTorch implementation https://github.com/zysite/biaffine-parser
-from typing import List
+# MIT License
+#
+# Copyright (c) 2020 Yu Zhang
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
-import tensorflow as tf
+import torch
-def nonzero(t: tf.Tensor) -> tf.Tensor:
- return tf.where(t > 0)
+from hanlp_common.conll import isprojective
-def view(t: tf.Tensor, *dims) -> tf.Tensor:
- return tf.reshape(t, dims)
+def kmeans(x, k, max_it=32):
+ r"""
+ KMeans algorithm for clustering the sentences by length.
+ Args:
+ x (list[int]):
+ The list of sentence lengths.
+ k (int):
+ The number of clusters.
+ This is an approximate value. The final number of clusters can be less or equal to `k`.
+ max_it (int):
+ Maximum number of iterations.
+ If centroids does not converge after several iterations, the algorithm will be early stopped.
-def arange(n: int) -> tf.Tensor:
- return tf.range(n)
+ Returns:
+ list[float], list[list[int]]:
+ The first list contains average lengths of sentences in each cluster.
+ The second is the list of clusters holding indices of data points.
-
-def randperm(n: int) -> tf.Tensor:
- return tf.random.shuffle(arange(n))
-
-
-def tolist(t: tf.Tensor) -> List:
- if isinstance(t, tf.Tensor):
- t = t.numpy()
- return t.tolist()
-
-
-def kmeans(x, k):
- """
- See https://github.com/zysite/biaffine-parser/blob/master/parser/utils/alg.py#L7
- :param x:
- :param k:
- :return:
+ Examples:
+ >>> x = torch.randint(10,20,(10,)).tolist()
+ >>> x
+ [15, 10, 17, 11, 18, 13, 17, 19, 18, 14]
+ >>> centroids, clusters = kmeans(x, 3)
+ >>> centroids
+ [10.5, 14.0, 17.799999237060547]
+ >>> clusters
+ [[1, 3], [0, 5, 9], [2, 4, 6, 7, 8]]
"""
- x = tf.constant(x, dtype=tf.float32)
- # count the frequency of each datapoint
- d, indices, f = tf.unique_with_counts(x, tf.int32)
- f = tf.cast(f, tf.float32)
- # calculate the sum of the values of the same datapoints
- total = d * f
+
+ # the number of clusters must not be greater than the number of datapoints
+ x, k = torch.tensor(x, dtype=torch.float), min(len(x), k)
+ # collect unique datapoints
+ d = x.unique()
# initialize k centroids randomly
- c, old = tf.random.shuffle(d)[:k], None
- # assign labels to each datapoint based on centroids
- dists = tf.abs(tf.expand_dims(d, -1) - c)
- y = tf.argmin(dists, axis=-1, output_type=tf.int32)
- dists = tf.gather_nd(dists, tf.transpose(tf.stack([tf.range(tf.shape(dists)[0], dtype=tf.int32), y])))
- # make sure number of datapoints is greater than that of clusters
- assert len(x) >= k, f"unable to assign {len(x)} datapoints to {k} clusters"
-
- while old is None or not tf.reduce_all(c == old):
+ c = d[torch.randperm(len(d))[:k]]
+ # assign each datapoint to the cluster with the closest centroid
+ dists, y = torch.abs_(x.unsqueeze(-1) - c).min(-1)
+
+ for _ in range(max_it):
# if an empty cluster is encountered,
- # choose the farthest datapoint from the biggest cluster
- # and move that the empty one
- for i in range(k):
- if not tf.reduce_any(y == i):
- mask = tf.cast(y == tf.expand_dims(tf.range(k, dtype=tf.int32), -1), tf.float32)
- lens = tf.reduce_sum(mask, axis=-1)
- biggest = view(nonzero(mask[tf.argmax(lens)]), -1)
- farthest = tf.argmax(tf.gather(dists, biggest))
- tf.tensor_scatter_nd_update(y, tf.expand_dims(tf.expand_dims(biggest[farthest], -1), -1), [i])
- mask = tf.cast(y == tf.expand_dims(tf.range(k, dtype=tf.int32), -1), tf.float32)
+ # choose the farthest datapoint from the biggest cluster and move that the empty one
+ mask = torch.arange(k).unsqueeze(-1).eq(y)
+ none = torch.where(~mask.any(-1))[0].tolist()
+ while len(none) > 0:
+ for i in none:
+ # the biggest cluster
+ b = torch.where(mask[mask.sum(-1).argmax()])[0]
+ # the datapoint farthest from the centroid of cluster b
+ f = dists[b].argmax()
+ # update the assigned cluster of f
+ y[b[f]] = i
+ # re-calculate the mask
+ mask = torch.arange(k).unsqueeze(-1).eq(y)
+ none = torch.where(~mask.any(-1))[0].tolist()
# update the centroids
- c, old = tf.cast(tf.reduce_sum(total * mask, axis=-1), tf.float32) / tf.cast(tf.reduce_sum(f * mask, axis=-1),
- tf.float32), c
+ c, old = (x * mask).sum(-1) / mask.sum(-1), c
# re-assign all datapoints to clusters
- dists = tf.abs(tf.expand_dims(d, -1) - c)
- y = tf.argmin(dists, axis=-1, output_type=tf.int32)
- dists = tf.gather_nd(dists, tf.transpose(tf.stack([tf.range(tf.shape(dists)[0], dtype=tf.int32), y])))
+ dists, y = torch.abs_(x.unsqueeze(-1) - c).min(-1)
+ # stop iteration early if the centroids converge
+ if c.equal(old):
+ break
# assign all datapoints to the new-generated clusters
- # without considering the empty ones
- y, (assigned, _) = tf.gather(y, indices), tf.unique(y)
+ # the empty ones are discarded
+ assigned = y.unique().tolist()
# get the centroids of the assigned clusters
- centroids = tf.gather(c, assigned).numpy().tolist()
+ centroids = c[assigned].tolist()
# map all values of datapoints to buckets
- clusters = [tf.squeeze(tf.where(y == i), axis=-1).numpy().tolist() for i in assigned]
+ clusters = [torch.where(y.eq(i))[0].tolist() for i in assigned]
return centroids, clusters
+
+
+def eisner(scores, mask):
+ r"""
+ First-order Eisner algorithm for projective decoding.
+
+ References:
+ - Ryan McDonald, Koby Crammer and Fernando Pereira. 2005.
+ `Online Large-Margin Training of Dependency Parsers`_.
+
+ Args:
+ scores (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
+ Scores of all dependent-head pairs.
+ mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
+ The mask to avoid parsing over padding tokens.
+ The first column serving as pseudo words for roots should be ``False``.
+
+ Returns:
+ ~torch.Tensor:
+ A tensor with shape ``[batch_size, seq_len]`` for the resulting projective parse trees.
+
+ Examples:
+ >>> scores = torch.tensor([[[-13.5026, -18.3700, -13.0033, -16.6809],
+ [-36.5235, -28.6344, -28.4696, -31.6750],
+ [ -2.9084, -7.4825, -1.4861, -6.8709],
+ [-29.4880, -27.6905, -26.1498, -27.0233]]])
+ >>> mask = torch.tensor([[False, True, True, True]])
+ >>> eisner(scores, mask)
+ tensor([[0, 2, 0, 2]])
+
+ .. _Online Large-Margin Training of Dependency Parsers:
+ https://www.aclweb.org/anthology/P05-1012/
+ """
+
+ lens = mask.sum(1)
+ batch_size, seq_len, _ = scores.shape
+ scores = scores.permute(2, 1, 0)
+ s_i = torch.full_like(scores, float('-inf'))
+ s_c = torch.full_like(scores, float('-inf'))
+ p_i = scores.new_zeros(seq_len, seq_len, batch_size).long()
+ p_c = scores.new_zeros(seq_len, seq_len, batch_size).long()
+ s_c.diagonal().fill_(0)
+
+ for w in range(1, seq_len):
+ n = seq_len - w
+ starts = p_i.new_tensor(range(n)).unsqueeze(0)
+ # ilr = C(i->r) + C(j->r+1)
+ ilr = stripe(s_c, n, w) + stripe(s_c, n, w, (w, 1))
+ # [batch_size, n, w]
+ il = ir = ilr.permute(2, 0, 1)
+ # I(j->i) = max(C(i->r) + C(j->r+1) + s(j->i)), i <= r < j
+ il_span, il_path = il.max(-1)
+ s_i.diagonal(-w).copy_(il_span + scores.diagonal(-w))
+ p_i.diagonal(-w).copy_(il_path + starts)
+ # I(i->j) = max(C(i->r) + C(j->r+1) + s(i->j)), i <= r < j
+ ir_span, ir_path = ir.max(-1)
+ s_i.diagonal(w).copy_(ir_span + scores.diagonal(w))
+ p_i.diagonal(w).copy_(ir_path + starts)
+
+ # C(j->i) = max(C(r->i) + I(j->r)), i <= r < j
+ cl = stripe(s_c, n, w, (0, 0), 0) + stripe(s_i, n, w, (w, 0))
+ cl_span, cl_path = cl.permute(2, 0, 1).max(-1)
+ s_c.diagonal(-w).copy_(cl_span)
+ p_c.diagonal(-w).copy_(cl_path + starts)
+ # C(i->j) = max(I(i->r) + C(r->j)), i < r <= j
+ cr = stripe(s_i, n, w, (0, 1)) + stripe(s_c, n, w, (1, w), 0)
+ cr_span, cr_path = cr.permute(2, 0, 1).max(-1)
+ s_c.diagonal(w).copy_(cr_span)
+ s_c[0, w][lens.ne(w)] = float('-inf')
+ p_c.diagonal(w).copy_(cr_path + starts + 1)
+
+ def backtrack(p_i, p_c, heads, i, j, complete):
+ if i == j:
+ return
+ if complete:
+ r = p_c[i, j]
+ backtrack(p_i, p_c, heads, i, r, False)
+ backtrack(p_i, p_c, heads, r, j, True)
+ else:
+ r, heads[j] = p_i[i, j], i
+ i, j = sorted((i, j))
+ backtrack(p_i, p_c, heads, i, r, True)
+ backtrack(p_i, p_c, heads, j, r + 1, True)
+
+ preds = []
+ p_c = p_c.permute(2, 0, 1).cpu()
+ p_i = p_i.permute(2, 0, 1).cpu()
+ for i, length in enumerate(lens.tolist()):
+ heads = p_c.new_zeros(length + 1, dtype=torch.long)
+ backtrack(p_i[i], p_c[i], heads, 0, length, True)
+ preds.append(heads.to(mask.device))
+
+ return pad(preds, total_length=seq_len).to(mask.device)
+
+
+def backtrack(p_i, p_c, heads, i, j, complete):
+ if i == j:
+ return
+ if complete:
+ r = p_c[i, j]
+ backtrack(p_i, p_c, heads, i, r, False)
+ backtrack(p_i, p_c, heads, r, j, True)
+ else:
+ r, heads[j] = p_i[i, j], i
+ i, j = sorted((i, j))
+ backtrack(p_i, p_c, heads, i, r, True)
+ backtrack(p_i, p_c, heads, j, r + 1, True)
+
+
+def stripe(x, n, w, offset=(0, 0), dim=1):
+ """r'''Returns a diagonal stripe of the tensor.
+
+ Args:
+ x: Tensor
+ n: int
+ w: int
+ offset: tuple (Default value = (0)
+ dim: int (Default value = 1)
+ Example:
+ 0):
+
+ Returns:
+
+ >>> x = torch.arange(25).view(5, 5)
+ >>> x
+ tensor([[ 0, 1, 2, 3, 4],
+ [ 5, 6, 7, 8, 9],
+ [10, 11, 12, 13, 14],
+ [15, 16, 17, 18, 19],
+ [20, 21, 22, 23, 24]])
+ >>> stripe(x, 2, 3, (1, 1))
+ tensor([[ 6, 7, 8],
+ [12, 13, 14]])
+ >>> stripe(x, 2, 3, dim=0)
+ tensor([[ 0, 5, 10],
+ [ 6, 11, 16]])
+ """
+ x, seq_len = x.contiguous(), x.size(1)
+ stride, numel = list(x.stride()), x[0, 0].numel()
+ stride[0] = (seq_len + 1) * numel
+ stride[1] = (1 if dim == 1 else seq_len) * numel
+ return x.as_strided(size=(n, w, *x.shape[2:]),
+ stride=stride,
+ storage_offset=(offset[0] * seq_len + offset[1]) * numel)
+
+
+def cky(scores, mask):
+ r"""
+ The implementation of `Cocke-Kasami-Younger`_ (CKY) algorithm to parse constituency trees.
+
+ References:
+ - Yu Zhang, Houquan Zhou and Zhenghua Li. 2020.
+ `Fast and Accurate Neural CRF Constituency Parsing`_.
+
+ Args:
+ scores (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
+ Scores of all candidate constituents.
+ mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
+ The mask to avoid parsing over padding tokens.
+ For each square matrix in a batch, the positions except upper triangular part should be masked out.
+
+ Returns:
+ Sequences of factorized predicted bracketed trees that are traversed in pre-order.
+
+ Examples:
+ >>> scores = torch.tensor([[[ 2.5659, 1.4253, -2.5272, 3.3011],
+ [ 1.3687, -0.5869, 1.0011, 3.3020],
+ [ 1.2297, 0.4862, 1.1975, 2.5387],
+ [-0.0511, -1.2541, -0.7577, 0.2659]]])
+ >>> mask = torch.tensor([[[False, True, True, True],
+ [False, False, True, True],
+ [False, False, False, True],
+ [False, False, False, False]]])
+ >>> cky(scores, mask)
+ [[(0, 3), (0, 1), (1, 3), (1, 2), (2, 3)]]
+
+ .. _Cocke-Kasami-Younger:
+ https://en.wikipedia.org/wiki/CYK_algorithm
+ .. _Fast and Accurate Neural CRF Constituency Parsing:
+ https://www.ijcai.org/Proceedings/2020/560/
+ """
+
+ lens = mask[:, 0].sum(-1)
+ scores = scores.permute(1, 2, 0)
+ seq_len, seq_len, batch_size = scores.shape
+ s = scores.new_zeros(seq_len, seq_len, batch_size)
+ p = scores.new_zeros(seq_len, seq_len, batch_size).long()
+
+ for w in range(1, seq_len):
+ n = seq_len - w
+ starts = p.new_tensor(range(n)).unsqueeze(0)
+
+ if w == 1:
+ s.diagonal(w).copy_(scores.diagonal(w))
+ continue
+ # [n, w, batch_size]
+ s_span = stripe(s, n, w - 1, (0, 1)) + stripe(s, n, w - 1, (1, w), 0)
+ # [batch_size, n, w]
+ s_span = s_span.permute(2, 0, 1)
+ # [batch_size, n]
+ s_span, p_span = s_span.max(-1)
+ s.diagonal(w).copy_(s_span + scores.diagonal(w))
+ p.diagonal(w).copy_(p_span + starts + 1)
+
+ def backtrack(p, i, j):
+ if j == i + 1:
+ return [(i, j)]
+ split = p[i][j]
+ ltree = backtrack(p, i, split)
+ rtree = backtrack(p, split, j)
+ return [(i, j)] + ltree + rtree
+
+ p = p.permute(2, 0, 1).tolist()
+ trees = [backtrack(p[i], 0, length) for i, length in enumerate(lens.tolist())]
+
+ return trees
+
+
+def istree(sequence, proj=False, multiroot=False):
+ r"""
+ Checks if the arcs form an valid dependency tree.
+
+ Args:
+ sequence (list[int]):
+ A list of head indices.
+ proj (bool):
+ If ``True``, requires the tree to be projective. Default: ``False``.
+ multiroot (bool):
+ If ``False``, requires the tree to contain only a single root. Default: ``True``.
+
+ Returns:
+ ``True`` if the arcs form an valid tree, ``False`` otherwise.
+
+ Examples:
+ >>> istree([3, 0, 0, 3], multiroot=True)
+ True
+ >>> istree([3, 0, 0, 3], proj=True)
+ False
+ """
+
+ if proj and not isprojective(sequence):
+ return False
+ n_roots = sum(head == 0 for head in sequence)
+ if n_roots == 0:
+ return False
+ if not multiroot and n_roots > 1:
+ return False
+ if any(i == head for i, head in enumerate(sequence, 1)):
+ return False
+ return next(tarjan(sequence), None) is None
+
+
+def tarjan(sequence):
+ r"""
+ Tarjan algorithm for finding Strongly Connected Components (SCCs) of a graph.
+
+ Args:
+ sequence (list):
+ List of head indices.
+
+ Yields:
+ A list of indices that make up a SCC. All self-loops are ignored.
+
+ Examples:
+ >>> next(tarjan([2, 5, 0, 3, 1])) # (1 -> 5 -> 2 -> 1) is a cycle
+ [2, 5, 1]
+ """
+
+ sequence = [-1] + sequence
+ # record the search order, i.e., the timestep
+ dfn = [-1] * len(sequence)
+ # record the the smallest timestep in a SCC
+ low = [-1] * len(sequence)
+ # push the visited into the stack
+ stack, onstack = [], [False] * len(sequence)
+
+ def connect(i, timestep):
+ dfn[i] = low[i] = timestep[0]
+ timestep[0] += 1
+ stack.append(i)
+ onstack[i] = True
+
+ for j, head in enumerate(sequence):
+ if head != i:
+ continue
+ if dfn[j] == -1:
+ yield from connect(j, timestep)
+ low[i] = min(low[i], low[j])
+ elif onstack[j]:
+ low[i] = min(low[i], dfn[j])
+
+ # a SCC is completed
+ if low[i] == dfn[i]:
+ cycle = [stack.pop()]
+ while cycle[-1] != i:
+ onstack[cycle[-1]] = False
+ cycle.append(stack.pop())
+ onstack[i] = False
+ # ignore the self-loop
+ if len(cycle) > 1:
+ yield cycle
+
+ timestep = [0]
+ for i in range(len(sequence)):
+ if dfn[i] == -1:
+ yield from connect(i, timestep)
+
+
+def chuliu_edmonds(s):
+ r"""
+ ChuLiu/Edmonds algorithm for non-projective decoding.
+
+ Some code is borrowed from `tdozat's implementation`_.
+ Descriptions of notations and formulas can be found in
+ `Non-projective Dependency Parsing using Spanning Tree Algorithms`_.
+
+ Notes:
+ The algorithm does not guarantee to parse a single-root tree.
+
+ References:
+ - Ryan McDonald, Fernando Pereira, Kiril Ribarov and Jan Hajic. 2005.
+ `Non-projective Dependency Parsing using Spanning Tree Algorithms`_.
+
+ Args:
+ s (~torch.Tensor): ``[seq_len, seq_len]``.
+ Scores of all dependent-head pairs.
+
+ Returns:
+ ~torch.Tensor:
+ A tensor with shape ``[seq_len]`` for the resulting non-projective parse tree.
+
+ .. _tdozat's implementation:
+ https://github.com/tdozat/Parser-v3
+ .. _Non-projective Dependency Parsing using Spanning Tree Algorithms:
+ https://www.aclweb.org/anthology/H05-1066/
+ """
+
+ s[0, 1:] = float('-inf')
+ # prevent self-loops
+ s.diagonal()[1:].fill_(float('-inf'))
+ # select heads with highest scores
+ tree = s.argmax(-1)
+ # return the cycle finded by tarjan algorithm lazily
+ cycle = next(tarjan(tree.tolist()[1:]), None)
+ # if the tree has no cycles, then it is a MST
+ if not cycle:
+ return tree
+ # indices of cycle in the original tree
+ cycle = torch.tensor(cycle)
+ # indices of noncycle in the original tree
+ noncycle = torch.ones(len(s)).index_fill_(0, cycle, 0)
+ noncycle = torch.where(noncycle.gt(0))[0]
+
+ def contract(s):
+ # heads of cycle in original tree
+ cycle_heads = tree[cycle]
+ # scores of cycle in original tree
+ s_cycle = s[cycle, cycle_heads]
+
+ # calculate the scores of cycle's potential dependents
+ # s(c->x) = max(s(x'->x)), x in noncycle and x' in cycle
+ s_dep = s[noncycle][:, cycle]
+ # find the best cycle head for each noncycle dependent
+ deps = s_dep.argmax(1)
+ # calculate the scores of cycle's potential heads
+ # s(x->c) = max(s(x'->x) - s(a(x')->x') + s(cycle)), x in noncycle and x' in cycle
+ # a(v) is the predecessor of v in cycle
+ # s(cycle) = sum(s(a(v)->v))
+ s_head = s[cycle][:, noncycle] - s_cycle.view(-1, 1) + s_cycle.sum()
+ # find the best noncycle head for each cycle dependent
+ heads = s_head.argmax(0)
+
+ contracted = torch.cat((noncycle, torch.tensor([-1])))
+ # calculate the scores of contracted graph
+ s = s[contracted][:, contracted]
+ # set the contracted graph scores of cycle's potential dependents
+ s[:-1, -1] = s_dep[range(len(deps)), deps]
+ # set the contracted graph scores of cycle's potential heads
+ s[-1, :-1] = s_head[heads, range(len(heads))]
+
+ return s, heads, deps
+
+ # keep track of the endpoints of the edges into and out of cycle for reconstruction later
+ s, heads, deps = contract(s)
+
+ # y is the contracted tree
+ y = chuliu_edmonds(s)
+ # exclude head of cycle from y
+ y, cycle_head = y[:-1], y[-1]
+
+ # fix the subtree with no heads coming from the cycle
+ # len(y) denotes heads coming from the cycle
+ subtree = y < len(y)
+ # add the nodes to the new tree
+ tree[noncycle[subtree]] = noncycle[y[subtree]]
+ # fix the subtree with heads coming from the cycle
+ subtree = ~subtree
+ # add the nodes to the tree
+ tree[noncycle[subtree]] = cycle[deps[subtree]]
+ # fix the root of the cycle
+ cycle_root = heads[cycle_head]
+ # break the cycle and add the root of the cycle to the tree
+ tree[cycle[cycle_root]] = noncycle[cycle_head]
+
+ return tree
+
+
+def mst(scores, mask, multiroot=False):
+ r"""
+ MST algorithm for decoding non-pojective trees.
+ This is a wrapper for ChuLiu/Edmonds algorithm.
+
+ The algorithm first runs ChuLiu/Edmonds to parse a tree and then have a check of multi-roots,
+ If ``multiroot=True`` and there indeed exist multi-roots, the algorithm seeks to find
+ best single-root trees by iterating all possible single-root trees parsed by ChuLiu/Edmonds.
+ Otherwise the resulting trees are directly taken as the final outputs.
+
+ Args:
+ scores (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
+ Scores of all dependent-head pairs.
+ mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
+ The mask to avoid parsing over padding tokens.
+ The first column serving as pseudo words for roots should be ``False``.
+ muliroot (bool):
+ Ensures to parse a single-root tree If ``False``.
+
+ Returns:
+ ~torch.Tensor:
+ A tensor with shape ``[batch_size, seq_len]`` for the resulting non-projective parse trees.
+
+ Examples:
+ >>> scores = torch.tensor([[[-11.9436, -13.1464, -6.4789, -13.8917],
+ [-60.6957, -60.2866, -48.6457, -63.8125],
+ [-38.1747, -49.9296, -45.2733, -49.5571],
+ [-19.7504, -23.9066, -9.9139, -16.2088]]])
+ >>> scores[:, 0, 1:] = float('-inf')
+ >>> scores.diagonal(0, 1, 2)[1:].fill_(float('-inf'))
+ >>> mask = torch.tensor([[False, True, True, True]])
+ >>> mst(scores, mask)
+ tensor([[0, 2, 0, 2]])
+ """
+
+ batch_size, seq_len, _ = scores.shape
+ scores = scores.detach().cpu().unbind()
+
+ preds = []
+ for i, length in enumerate(mask.sum(1).tolist()):
+ s = scores[i][:length + 1, :length + 1]
+ tree = chuliu_edmonds(s)
+ roots = torch.where(tree[1:].eq(0))[0] + 1
+ if not multiroot and len(roots) > 1:
+ s_root = s[:, 0]
+ s_best = float('-inf')
+ s = s.index_fill(1, torch.tensor(0), float('-inf'))
+ for root in roots:
+ s[:, 0] = float('-inf')
+ s[root, 0] = s_root[root]
+ t = chuliu_edmonds(s)
+ s_tree = s[1:].gather(1, t[1:].unsqueeze(-1)).sum()
+ if s_tree > s_best:
+ s_best, tree = s_tree, t
+ preds.append(tree)
+
+ return pad(preds, total_length=seq_len).to(mask.device)
+
+
+def eisner2o(scores, mask):
+ r"""
+ Second-order Eisner algorithm for projective decoding.
+ This is an extension of the first-order one that further incorporates sibling scores into tree scoring.
+
+ References:
+ - Ryan McDonald and Fernando Pereira. 2006.
+ `Online Learning of Approximate Dependency Parsing Algorithms`_.
+
+ Args:
+ scores (~torch.Tensor, ~torch.Tensor):
+ A tuple of two tensors representing the first-order and second-order scores repectively.
+ The first (``[batch_size, seq_len, seq_len]``) holds scores of all dependent-head pairs.
+ The second (``[batch_size, seq_len, seq_len, seq_len]``) holds scores of all dependent-head-sibling triples.
+ mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
+ The mask to avoid parsing over padding tokens.
+ The first column serving as pseudo words for roots should be ``False``.
+
+ Returns:
+ ~torch.Tensor:
+ A tensor with shape ``[batch_size, seq_len]`` for the resulting projective parse trees.
+
+ Examples:
+ >>> s_arc = torch.tensor([[[ -2.8092, -7.9104, -0.9414, -5.4360],
+ [-10.3494, -7.9298, -3.6929, -7.3985],
+ [ 1.1815, -3.8291, 2.3166, -2.7183],
+ [ -3.9776, -3.9063, -1.6762, -3.1861]]])
+ >>> s_sib = torch.tensor([[[[ 0.4719, 0.4154, 1.1333, 0.6946],
+ [ 1.1252, 1.3043, 2.1128, 1.4621],
+ [ 0.5974, 0.5635, 1.0115, 0.7550],
+ [ 1.1174, 1.3794, 2.2567, 1.4043]],
+ [[-2.1480, -4.1830, -2.5519, -1.8020],
+ [-1.2496, -1.7859, -0.0665, -0.4938],
+ [-2.6171, -4.0142, -2.9428, -2.2121],
+ [-0.5166, -1.0925, 0.5190, 0.1371]],
+ [[ 0.5827, -1.2499, -0.0648, -0.0497],
+ [ 1.4695, 0.3522, 1.5614, 1.0236],
+ [ 0.4647, -0.7996, -0.3801, 0.0046],
+ [ 1.5611, 0.3875, 1.8285, 1.0766]],
+ [[-1.3053, -2.9423, -1.5779, -1.2142],
+ [-0.1908, -0.9699, 0.3085, 0.1061],
+ [-1.6783, -2.8199, -1.8853, -1.5653],
+ [ 0.3629, -0.3488, 0.9011, 0.5674]]]])
+ >>> mask = torch.tensor([[False, True, True, True]])
+ >>> eisner2o((s_arc, s_sib), mask)
+ tensor([[0, 2, 0, 2]])
+
+ .. _Online Learning of Approximate Dependency Parsing Algorithms:
+ https://www.aclweb.org/anthology/E06-1011/
+ """
+
+ # the end position of each sentence in a batch
+ lens = mask.sum(1)
+ s_arc, s_sib = scores
+ batch_size, seq_len, _ = s_arc.shape
+ # [seq_len, seq_len, batch_size]
+ s_arc = s_arc.permute(2, 1, 0)
+ # [seq_len, seq_len, seq_len, batch_size]
+ s_sib = s_sib.permute(2, 1, 3, 0)
+ s_i = torch.full_like(s_arc, float('-inf'))
+ s_s = torch.full_like(s_arc, float('-inf'))
+ s_c = torch.full_like(s_arc, float('-inf'))
+ p_i = s_arc.new_zeros(seq_len, seq_len, batch_size).long()
+ p_s = s_arc.new_zeros(seq_len, seq_len, batch_size).long()
+ p_c = s_arc.new_zeros(seq_len, seq_len, batch_size).long()
+ s_c.diagonal().fill_(0)
+
+ for w in range(1, seq_len):
+ # n denotes the number of spans to iterate,
+ # from span (0, w) to span (n, n+w) given width w
+ n = seq_len - w
+ starts = p_i.new_tensor(range(n)).unsqueeze(0)
+ # I(j->i) = max(I(j->r) + S(j->r, i)), i < r < j |
+ # C(j->j) + C(i->j-1))
+ # + s(j->i)
+ # [n, w, batch_size]
+ il = stripe(s_i, n, w, (w, 1)) + stripe(s_s, n, w, (1, 0), 0)
+ il += stripe(s_sib[range(w, n + w), range(n)], n, w, (0, 1))
+ # [n, 1, batch_size]
+ il0 = stripe(s_c, n, 1, (w, w)) + stripe(s_c, n, 1, (0, w - 1))
+ # il0[0] are set to zeros since the scores of the complete spans starting from 0 are always -inf
+ il[:, -1] = il0.index_fill_(0, lens.new_tensor(0), 0).squeeze(1)
+ il_span, il_path = il.permute(2, 0, 1).max(-1)
+ s_i.diagonal(-w).copy_(il_span + s_arc.diagonal(-w))
+ p_i.diagonal(-w).copy_(il_path + starts + 1)
+ # I(i->j) = max(I(i->r) + S(i->r, j), i < r < j |
+ # C(i->i) + C(j->i+1))
+ # + s(i->j)
+ # [n, w, batch_size]
+ ir = stripe(s_i, n, w) + stripe(s_s, n, w, (0, w), 0)
+ ir += stripe(s_sib[range(n), range(w, n + w)], n, w)
+ ir[0] = float('-inf')
+ # [n, 1, batch_size]
+ ir0 = stripe(s_c, n, 1) + stripe(s_c, n, 1, (w, 1))
+ ir[:, 0] = ir0.squeeze(1)
+ ir_span, ir_path = ir.permute(2, 0, 1).max(-1)
+ s_i.diagonal(w).copy_(ir_span + s_arc.diagonal(w))
+ p_i.diagonal(w).copy_(ir_path + starts)
+
+ # [n, w, batch_size]
+ slr = stripe(s_c, n, w) + stripe(s_c, n, w, (w, 1))
+ slr_span, slr_path = slr.permute(2, 0, 1).max(-1)
+ # S(j, i) = max(C(i->r) + C(j->r+1)), i <= r < j
+ s_s.diagonal(-w).copy_(slr_span)
+ p_s.diagonal(-w).copy_(slr_path + starts)
+ # S(i, j) = max(C(i->r) + C(j->r+1)), i <= r < j
+ s_s.diagonal(w).copy_(slr_span)
+ p_s.diagonal(w).copy_(slr_path + starts)
+
+ # C(j->i) = max(C(r->i) + I(j->r)), i <= r < j
+ cl = stripe(s_c, n, w, (0, 0), 0) + stripe(s_i, n, w, (w, 0))
+ cl_span, cl_path = cl.permute(2, 0, 1).max(-1)
+ s_c.diagonal(-w).copy_(cl_span)
+ p_c.diagonal(-w).copy_(cl_path + starts)
+ # C(i->j) = max(I(i->r) + C(r->j)), i < r <= j
+ cr = stripe(s_i, n, w, (0, 1)) + stripe(s_c, n, w, (1, w), 0)
+ cr_span, cr_path = cr.permute(2, 0, 1).max(-1)
+ s_c.diagonal(w).copy_(cr_span)
+ # disable multi words to modify the root
+ s_c[0, w][lens.ne(w)] = float('-inf')
+ p_c.diagonal(w).copy_(cr_path + starts + 1)
+
+ def backtrack(p_i, p_s, p_c, heads, i, j, flag):
+ if i == j:
+ return
+ if flag == 'c':
+ r = p_c[i, j]
+ backtrack(p_i, p_s, p_c, heads, i, r, 'i')
+ backtrack(p_i, p_s, p_c, heads, r, j, 'c')
+ elif flag == 's':
+ r = p_s[i, j]
+ i, j = sorted((i, j))
+ backtrack(p_i, p_s, p_c, heads, i, r, 'c')
+ backtrack(p_i, p_s, p_c, heads, j, r + 1, 'c')
+ elif flag == 'i':
+ r, heads[j] = p_i[i, j], i
+ if r == i:
+ r = i + 1 if i < j else i - 1
+ backtrack(p_i, p_s, p_c, heads, j, r, 'c')
+ else:
+ backtrack(p_i, p_s, p_c, heads, i, r, 'i')
+ backtrack(p_i, p_s, p_c, heads, r, j, 's')
+
+ preds = []
+ p_i = p_i.permute(2, 0, 1).cpu()
+ p_s = p_s.permute(2, 0, 1).cpu()
+ p_c = p_c.permute(2, 0, 1).cpu()
+ for i, length in enumerate(lens.tolist()):
+ heads = p_c.new_zeros(length + 1, dtype=torch.long)
+ backtrack(p_i[i], p_s[i], p_c[i], heads, 0, length, 'c')
+ preds.append(heads.to(mask.device))
+
+ return pad(preds, total_length=seq_len).to(mask.device)
+
+
+def pad(tensors, padding_value=0, total_length=None):
+ size = [len(tensors)] + [max(tensor.size(i) for tensor in tensors)
+ for i in range(len(tensors[0].size()))]
+ if total_length is not None:
+ assert total_length >= size[1]
+ size[1] = total_length
+ out_tensor = tensors[0].data.new(*size).fill_(padding_value)
+ for i, tensor in enumerate(tensors):
+ out_tensor[i][[slice(0, i) for i in tensor.size()]] = tensor
+ return out_tensor
+
+
+def decode_dep(s_arc, mask, tree=False, proj=False):
+ r"""
+ Args:
+ s_arc (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
+ Scores of all possible arcs.
+ mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
+ The mask for covering the unpadded tokens.
+ tree (bool):
+ If ``True``, ensures to output well-formed trees. Default: ``False``.
+ proj (bool):
+ If ``True``, ensures to output projective trees. Default: ``False``.
+
+ Returns:
+ ~torch.Tensor, ~torch.Tensor:
+ Predicted arcs and labels of shape ``[batch_size, seq_len]``.
+ """
+
+ lens = mask.sum(1)
+ arc_preds = s_arc.argmax(-1)
+ bad = [not istree(seq[1:i + 1], proj) for i, seq in zip(lens.tolist(), arc_preds.tolist())]
+ if tree and any(bad):
+ if proj:
+ alg = eisner
+ else:
+ alg = mst
+ s_arc.diagonal(0, 1, 2)[1:].fill_(float('-inf'))
+ arc_preds[bad] = alg(s_arc[bad], mask[bad])
+
+ return arc_preds
diff --git a/hanlp/components/parsers/alg_tf.py b/hanlp/components/parsers/alg_tf.py
new file mode 100644
index 000000000..75536dc79
--- /dev/null
+++ b/hanlp/components/parsers/alg_tf.py
@@ -0,0 +1,289 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-12-26 19:49
+# Ported from the PyTorch implementation https://github.com/zysite/biaffine-parser
+from typing import List
+import numpy as np
+import tensorflow as tf
+from collections import defaultdict
+
+
+def nonzero(t: tf.Tensor) -> tf.Tensor:
+ return tf.where(t > 0)
+
+
+def view(t: tf.Tensor, *dims) -> tf.Tensor:
+ return tf.reshape(t, dims)
+
+
+def arange(n: int) -> tf.Tensor:
+ return tf.range(n)
+
+
+def randperm(n: int) -> tf.Tensor:
+ return tf.random.shuffle(arange(n))
+
+
+def tolist(t: tf.Tensor) -> List:
+ if isinstance(t, tf.Tensor):
+ t = t.numpy()
+ return t.tolist()
+
+
+def kmeans(x, k, seed=None):
+ """See https://github.com/zysite/biaffine-parser/blob/master/parser/utils/alg.py#L7
+
+ Args:
+ x(list): Lengths of sentences
+ k(int):
+ seed: (Default value = None)
+
+ Returns:
+
+
+ """
+ x = tf.constant(x, dtype=tf.float32)
+ # count the frequency of each datapoint
+ d, indices, f = tf.unique_with_counts(x, tf.int32)
+ f = tf.cast(f, tf.float32)
+ # calculate the sum of the values of the same datapoints
+ total = d * f
+ # initialize k centroids randomly
+ c, old = tf.random.shuffle(d, seed)[:k], None
+ # assign labels to each datapoint based on centroids
+ dists = tf.abs(tf.expand_dims(d, -1) - c)
+ y = tf.argmin(dists, axis=-1, output_type=tf.int32)
+ dists = tf.gather_nd(dists, tf.transpose(tf.stack([tf.range(tf.shape(dists)[0], dtype=tf.int32), y])))
+ # make sure number of datapoints is greater than that of clusters
+ assert len(d) >= k, f"unable to assign {len(d)} datapoints to {k} clusters"
+
+ while old is None or not tf.reduce_all(c == old):
+ # if an empty cluster is encountered,
+ # choose the farthest datapoint from the biggest cluster
+ # and move that the empty one
+ for i in range(k):
+ if not tf.reduce_any(y == i):
+ mask = tf.cast(y == tf.expand_dims(tf.range(k, dtype=tf.int32), -1), tf.float32)
+ lens = tf.reduce_sum(mask, axis=-1)
+ biggest = view(nonzero(mask[tf.argmax(lens)]), -1)
+ farthest = tf.argmax(tf.gather(dists, biggest))
+ tf.tensor_scatter_nd_update(y, tf.expand_dims(tf.expand_dims(biggest[farthest], -1), -1), [i])
+ mask = tf.cast(y == tf.expand_dims(tf.range(k, dtype=tf.int32), -1), tf.float32)
+ # update the centroids
+ c, old = tf.cast(tf.reduce_sum(total * mask, axis=-1), tf.float32) / tf.cast(tf.reduce_sum(f * mask, axis=-1),
+ tf.float32), c
+ # re-assign all datapoints to clusters
+ dists = tf.abs(tf.expand_dims(d, -1) - c)
+ y = tf.argmin(dists, axis=-1, output_type=tf.int32)
+ dists = tf.gather_nd(dists, tf.transpose(tf.stack([tf.range(tf.shape(dists)[0], dtype=tf.int32), y])))
+ # assign all datapoints to the new-generated clusters
+ # without considering the empty ones
+ y, (assigned, _) = tf.gather(y, indices), tf.unique(y)
+ # get the centroids of the assigned clusters
+ centroids = tf.gather(c, assigned).numpy().tolist()
+ # map all values of datapoints to buckets
+ clusters = [tf.squeeze(tf.where(y == i), axis=-1).numpy().tolist() for i in assigned]
+
+ return centroids, clusters
+
+
+# ***************************************************************
+class Tarjan:
+ """Computes Tarjan's algorithm for finding strongly connected components (cycles) of a graph"""
+
+ def __init__(self, prediction, tokens):
+ """
+
+ Parameters
+ ----------
+ prediction : numpy.ndarray
+ a predicted dependency tree where prediction[dep_idx] = head_idx
+ tokens : numpy.ndarray
+ the tokens we care about (i.e. exclude _GO, _EOS, and _PAD)
+ """
+ self._edges = defaultdict(set)
+ self._vertices = set((0,))
+ for dep, head in enumerate(prediction[tokens]):
+ self._vertices.add(dep + 1)
+ self._edges[head].add(dep + 1)
+ self._indices = {}
+ self._lowlinks = {}
+ self._onstack = defaultdict(lambda: False)
+ self._SCCs = []
+
+ index = 0
+ stack = []
+ for v in self.vertices:
+ if v not in self.indices:
+ self.strongconnect(v, index, stack)
+
+ # =============================================================
+ def strongconnect(self, v, index, stack):
+ """
+
+ Args:
+ v:
+ index:
+ stack:
+
+ Returns:
+
+ """
+
+ self._indices[v] = index
+ self._lowlinks[v] = index
+ index += 1
+ stack.append(v)
+ self._onstack[v] = True
+ for w in self.edges[v]:
+ if w not in self.indices:
+ self.strongconnect(w, index, stack)
+ self._lowlinks[v] = min(self._lowlinks[v], self._lowlinks[w])
+ elif self._onstack[w]:
+ self._lowlinks[v] = min(self._lowlinks[v], self._indices[w])
+
+ if self._lowlinks[v] == self._indices[v]:
+ self._SCCs.append(set())
+ while stack[-1] != v:
+ w = stack.pop()
+ self._onstack[w] = False
+ self._SCCs[-1].add(w)
+ w = stack.pop()
+ self._onstack[w] = False
+ self._SCCs[-1].add(w)
+ return
+
+ # ======================
+ @property
+ def edges(self):
+ return self._edges
+
+ @property
+ def vertices(self):
+ return self._vertices
+
+ @property
+ def indices(self):
+ return self._indices
+
+ @property
+ def SCCs(self):
+ return self._SCCs
+
+
+def tarjan(parse_probs, length, tokens_to_keep, ensure_tree=True):
+ """Adopted from Timothy Dozat https://github.com/tdozat/Parser/blob/master/lib/models/nn.py
+
+ Args:
+ parse_probs(NDArray): seq_len x seq_len, the probability of arcs
+ length(NDArray): sentence length including ROOT
+ tokens_to_keep(NDArray): mask matrix
+ ensure_tree: (Default value = True)
+
+ Returns:
+
+
+ """
+ if ensure_tree:
+ I = np.eye(len(tokens_to_keep))
+ # block loops and pad heads
+ parse_probs = parse_probs * tokens_to_keep * (1 - I)
+ parse_preds = np.argmax(parse_probs, axis=1)
+ tokens = np.arange(1, length)
+ roots = np.where(parse_preds[tokens] == 0)[0] + 1
+ # ensure at least one root
+ if len(roots) < 1:
+ # The current root probabilities
+ root_probs = parse_probs[tokens, 0]
+ # The current head probabilities
+ old_head_probs = parse_probs[tokens, parse_preds[tokens]]
+ # Get new potential root probabilities
+ new_root_probs = root_probs / old_head_probs
+ # Select the most probable root
+ new_root = tokens[np.argmax(new_root_probs)]
+ # Make the change
+ parse_preds[new_root] = 0
+ # ensure at most one root
+ elif len(roots) > 1:
+ # The probabilities of the current heads
+ root_probs = parse_probs[roots, 0]
+ # Set the probability of depending on the root zero
+ parse_probs[roots, 0] = 0
+ # Get new potential heads and their probabilities
+ new_heads = np.argmax(parse_probs[roots][:, tokens], axis=1) + 1
+ new_head_probs = parse_probs[roots, new_heads] / root_probs
+ # Select the most probable root
+ new_root = roots[np.argmin(new_head_probs)]
+ # Make the change
+ parse_preds[roots] = new_heads
+ parse_preds[new_root] = 0
+ # remove cycles
+ tarjan = Tarjan(parse_preds, tokens)
+ for SCC in tarjan.SCCs:
+ if len(SCC) > 1:
+ dependents = set()
+ to_visit = set(SCC)
+ while len(to_visit) > 0:
+ node = to_visit.pop()
+ if not node in dependents:
+ dependents.add(node)
+ to_visit.update(tarjan.edges[node])
+ # The indices of the nodes that participate in the cycle
+ cycle = np.array(list(SCC))
+ # The probabilities of the current heads
+ old_heads = parse_preds[cycle]
+ old_head_probs = parse_probs[cycle, old_heads]
+ # Set the probability of depending on a non-head to zero
+ non_heads = np.array(list(dependents))
+ parse_probs[np.repeat(cycle, len(non_heads)), np.repeat([non_heads], len(cycle), axis=0).flatten()] = 0
+ # Get new potential heads and their probabilities
+ new_heads = np.argmax(parse_probs[cycle][:, tokens], axis=1) + 1
+ new_head_probs = parse_probs[cycle, new_heads] / old_head_probs
+ # Select the most probable change
+ change = np.argmax(new_head_probs)
+ changed_cycle = cycle[change]
+ old_head = old_heads[change]
+ new_head = new_heads[change]
+ # Make the change
+ parse_preds[changed_cycle] = new_head
+ tarjan.edges[new_head].add(changed_cycle)
+ tarjan.edges[old_head].remove(changed_cycle)
+ return parse_preds
+ else:
+ # block and pad heads
+ parse_probs = parse_probs * tokens_to_keep
+ parse_preds = np.argmax(parse_probs, axis=1)
+ return parse_preds
+
+
+def rel_argmax(rel_probs, length, root, ensure_tree=True):
+ """Fix the relation prediction by heuristic rules
+
+ Args:
+ rel_probs(NDArray): seq_len x rel_size
+ length: real sentence length
+ ensure_tree: (Default value = True)
+ root:
+
+ Returns:
+
+
+ """
+ if ensure_tree:
+ tokens = np.arange(1, length)
+ rel_preds = np.argmax(rel_probs, axis=1)
+ roots = np.where(rel_preds[tokens] == root)[0] + 1
+ if len(roots) < 1:
+ rel_preds[1 + np.argmax(rel_probs[tokens, root])] = root
+ elif len(roots) > 1:
+ root_probs = rel_probs[roots, root]
+ rel_probs[roots, root] = 0
+ new_rel_preds = np.argmax(rel_probs[roots], axis=1)
+ new_rel_probs = rel_probs[roots, new_rel_preds] / root_probs
+ new_root = roots[np.argmin(new_rel_probs)]
+ rel_preds[roots] = new_rel_preds
+ rel_preds[new_root] = root
+ return rel_preds
+ else:
+ rel_preds = np.argmax(rel_probs, axis=1)
+ return rel_preds
diff --git a/hanlp/components/parsers/biaffine/__init__.py b/hanlp/components/parsers/biaffine/__init__.py
index 12a4372f1..c9bddbdd8 100644
--- a/hanlp/components/parsers/biaffine/__init__.py
+++ b/hanlp/components/parsers/biaffine/__init__.py
@@ -1,3 +1,3 @@
# -*- coding:utf-8 -*-
# Author: hankcs
-# Date: 2019-12-26 23:03
\ No newline at end of file
+# Date: 2020-05-08 20:43
diff --git a/hanlp/components/parsers/biaffine/biaffine.py b/hanlp/components/parsers/biaffine/biaffine.py
new file mode 100644
index 000000000..0d2212c92
--- /dev/null
+++ b/hanlp/components/parsers/biaffine/biaffine.py
@@ -0,0 +1,98 @@
+# MIT License
+#
+# Copyright (c) 2020 Yu Zhang
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+
+import torch
+import torch.nn as nn
+
+
+class Biaffine(nn.Module):
+ r"""
+ Biaffine layer for first-order scoring.
+
+ This function has a tensor of weights :math:`W` and bias terms if needed.
+ The score :math:`s(x, y)` of the vector pair :math:`(x, y)` is computed as :math:`x^T W y`,
+ in which :math:`x` and :math:`y` can be concatenated with bias terms.
+
+ References:
+ - Timothy Dozat and Christopher D. Manning. 2017.
+ `Deep Biaffine Attention for Neural Dependency Parsing`_.
+
+ Args:
+ n_in (int):
+ The size of the input feature.
+ n_out (int):
+ The number of output channels.
+ bias_x (bool):
+ If ``True``, adds a bias term for tensor :math:`x`. Default: ``True``.
+ bias_y (bool):
+ If ``True``, adds a bias term for tensor :math:`y`. Default: ``True``.
+
+ .. _Deep Biaffine Attention for Neural Dependency Parsing:
+ https://openreview.net/forum?id=Hk95PK9le
+ """
+
+ def __init__(self, n_in, n_out=1, bias_x=True, bias_y=True):
+ super().__init__()
+
+ self.n_in = n_in
+ self.n_out = n_out
+ self.bias_x = bias_x
+ self.bias_y = bias_y
+ self.weight = nn.Parameter(torch.Tensor(n_out, n_in + bias_x, n_in + bias_y))
+
+ self.reset_parameters()
+
+ def __repr__(self):
+ s = f"n_in={self.n_in}, n_out={self.n_out}"
+ if self.bias_x:
+ s += f", bias_x={self.bias_x}"
+ if self.bias_y:
+ s += f", bias_y={self.bias_y}"
+
+ return f"{self.__class__.__name__}({s})"
+
+ def reset_parameters(self):
+ nn.init.zeros_(self.weight)
+
+ def forward(self, x, y):
+ r"""
+ Args:
+ x (torch.Tensor): ``[batch_size, seq_len, n_in]``.
+ y (torch.Tensor): ``[batch_size, seq_len, n_in]``.
+
+ Returns:
+ ~torch.Tensor:
+ A scoring tensor of shape ``[batch_size, n_out, seq_len, seq_len]``.
+ If ``n_out=1``, the dimension for ``n_out`` will be squeezed automatically.
+ """
+
+ if self.bias_x:
+ x = torch.cat((x, torch.ones_like(x[..., :1])), -1)
+ if self.bias_y:
+ y = torch.cat((y, torch.ones_like(y[..., :1])), -1)
+ # [batch_size, n_out, seq_len, seq_len]
+ s = torch.einsum('bxi,oij,byj->boxy', x, self.weight, y)
+ # remove dim 1 if n_out == 1
+ s = s.squeeze(1)
+
+ return s
diff --git a/hanlp/components/parsers/biaffine/biaffine_2nd_dep.py b/hanlp/components/parsers/biaffine/biaffine_2nd_dep.py
new file mode 100644
index 000000000..45bb42dfe
--- /dev/null
+++ b/hanlp/components/parsers/biaffine/biaffine_2nd_dep.py
@@ -0,0 +1,213 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-08-06 13:57
+import functools
+from typing import Union, List, Any
+
+import torch
+from hanlp_common.constant import UNK
+from hanlp.common.transform import TransformList
+from hanlp.common.vocab import Vocab
+from hanlp.components.parsers.biaffine.biaffine import Biaffine
+from hanlp.components.parsers.biaffine.biaffine_model import BiaffineDecoder, \
+ EncoderWithContextualLayer
+from hanlp.components.parsers.biaffine.biaffine_dep import BiaffineDependencyParser
+from hanlp.components.parsers.biaffine.biaffine_sdp import BiaffineSemanticDependencyParser
+from hanlp_common.conll import CoNLLUWord, CoNLLSentence
+from hanlp.components.parsers.parse_alg import add_secondary_arcs_by_preds
+from hanlp.datasets.parsing.conll_dataset import append_bos
+from hanlp.datasets.parsing.semeval15 import unpack_deps_to_head_deprel, merge_head_deprel_with_2nd
+from hanlp.metrics.mtl import MetricDict
+from hanlp_common.util import merge_locals_kwargs
+from transformers import PreTrainedModel, PreTrainedTokenizer
+
+
+class BiaffineSeparateDecoder(torch.nn.Module):
+
+ def __init__(self, hidden_size, config) -> None:
+ super().__init__()
+ self.biaffine_decoder = BiaffineDecoder(hidden_size,
+ config.n_mlp_arc,
+ config.n_mlp_rel,
+ config.mlp_dropout,
+ config.n_rels)
+ self.biaffine_decoder_2nd = BiaffineDecoder(hidden_size,
+ config.n_mlp_arc,
+ config.n_mlp_rel,
+ config.mlp_dropout,
+ config.n_rels_2nd)
+
+ def forward(self, x, mask):
+ return tuple(zip(self.biaffine_decoder(x, mask), self.biaffine_decoder_2nd(x, mask)))
+
+
+class BiaffineJointDecoder(BiaffineDecoder):
+ def __init__(self, hidden_size, config) -> None:
+ super().__init__(hidden_size, config.n_mlp_arc, config.n_mlp_rel, config.mlp_dropout, config.n_rels)
+ # the Biaffine layers for secondary dep
+ self.arc_attn_2nd = Biaffine(n_in=config.n_mlp_arc,
+ bias_x=True,
+ bias_y=False)
+ self.rel_attn_2nd = Biaffine(n_in=config.n_mlp_rel,
+ n_out=config.n_rels,
+ bias_x=True,
+ bias_y=True)
+
+ def forward(self, x, mask=None, **kwargs: Any):
+ arc_d, arc_h, rel_d, rel_h = self.apply_mlps(x)
+ s_arc, s_rel = self.decode(arc_d, arc_h, rel_d, rel_h, mask, self.arc_attn, self.rel_attn)
+ s_arc_2nd, s_rel_2nd = self.decode(arc_d, arc_h, rel_d, rel_h, mask, self.arc_attn_2nd, self.rel_attn_2nd)
+ return (s_arc, s_arc_2nd), (s_rel, s_rel_2nd)
+
+
+class BiaffineSecondaryModel(torch.nn.Module):
+
+ def __init__(self, config, pretrained_embed: torch.Tensor = None, transformer: PreTrainedModel = None,
+ transformer_tokenizer: PreTrainedTokenizer = None):
+ super().__init__()
+ self.encoder = EncoderWithContextualLayer(config, pretrained_embed, transformer, transformer_tokenizer)
+ self.decoder = BiaffineJointDecoder(self.encoder.hidden_size, config) if config.joint \
+ else BiaffineSeparateDecoder(self.encoder.hidden_size, config)
+
+ def forward(self,
+ words=None,
+ feats=None,
+ input_ids=None,
+ token_span=None,
+ mask=None, lens=None, **kwargs):
+ x, mask = self.encoder(words, feats, input_ids, token_span, mask, lens)
+ return self.decoder(x, mask)
+
+
+class BiaffineSecondaryParser(BiaffineDependencyParser):
+
+ def __init__(self) -> None:
+ super().__init__()
+ self.model: BiaffineSecondaryModel = None
+
+ def build_dataset(self, data, bos_transform=None):
+ transform = TransformList(functools.partial(append_bos, pos_key='UPOS'),
+ functools.partial(unpack_deps_to_head_deprel, pad_rel=self.config.pad_rel,
+ arc_key='arc_2nd',
+ rel_key='rel_2nd'))
+ if self.config.joint:
+ transform.append(merge_head_deprel_with_2nd)
+ if bos_transform:
+ transform.append(bos_transform)
+ return super().build_dataset(data, transform)
+
+ def build_criterion(self, **kwargs):
+ # noinspection PyCallByClass
+ return super().build_criterion(**kwargs), (BiaffineSemanticDependencyParser.build_criterion(self, **kwargs))
+
+ def fit(self, trn_data, dev_data, save_dir, feat=None, n_embed=100, pretrained_embed=None, transformer=None,
+ average_subwords=False, word_dropout: float = 0.2, transformer_hidden_dropout=None, layer_dropout=0,
+ scalar_mix: int = None, embed_dropout=.33, n_lstm_hidden=400, n_lstm_layers=3, hidden_dropout=.33,
+ n_mlp_arc=500, n_mlp_rel=100, mlp_dropout=.33, lr=2e-3, transformer_lr=5e-5, mu=.9, nu=.9, epsilon=1e-12,
+ clip=5.0, decay=.75, decay_steps=5000, patience=100, batch_size=None, sampler_builder=None,
+ lowercase=False, epochs=50000, tree=False, punct=False, min_freq=2,
+ apply_constraint=True, joint=False, no_cycle=False, root=None,
+ logger=None,
+ verbose=True, unk=UNK, pad_rel=None, max_sequence_length=512, devices: Union[float, int, List[int]] = None,
+ transform=None, **kwargs):
+ return super().fit(**merge_locals_kwargs(locals(), kwargs))
+
+ def build_vocabs(self, dataset, logger=None, transformer=None):
+ self.vocabs['rel_2nd'] = rel_2nd = Vocab(pad_token=self.config.pad_rel, unk_token=self.config.pad_rel)
+ if self.config.joint:
+ self.vocabs['rel'] = rel_2nd
+ super().build_vocabs(dataset, logger, transformer)
+ self.config.n_rels_2nd = len(rel_2nd)
+
+ def create_model(self, pretrained_embed, transformer):
+ return BiaffineSecondaryModel(self.config, pretrained_embed, transformer, self.transformer_tokenizer)
+
+ def compute_loss(self, arc_scores, rel_scores, arcs, rels, mask, criterion, batch=None):
+ arc_scores_1st, arc_scores_2nd, rel_scores_1st, rel_scores_2nd = self.unpack_scores(arc_scores, rel_scores)
+ loss_1st = super().compute_loss(arc_scores_1st, rel_scores_1st, arcs, rels, mask, criterion[0], batch)
+ mask = self.compute_mask(arc_scores_2nd, batch, mask)
+ # noinspection PyCallByClass
+ loss_2st = BiaffineSemanticDependencyParser.compute_loss(self, arc_scores_2nd, rel_scores_2nd,
+ batch['arc_2nd'], batch['rel_2nd_id'], mask,
+ criterion[1], batch)
+ return loss_1st + loss_2st
+
+ @staticmethod
+ def compute_mask(arc_scores_2nd, batch, mask_1st):
+ mask = batch.get('mask_2nd', None)
+ if mask is None:
+ batch['mask_2nd'] = mask = BiaffineSemanticDependencyParser.convert_to_3d_mask(arc_scores_2nd, mask_1st)
+ return mask
+
+ def unpack_scores(self, arc_scores, rel_scores):
+ arc_scores_1st, arc_scores_2nd = arc_scores
+ rel_scores_1st, rel_scores_2nd = rel_scores
+ return arc_scores_1st, arc_scores_2nd, rel_scores_1st, rel_scores_2nd
+
+ def get_pad_dict(self):
+ d = super(BiaffineSecondaryParser, self).get_pad_dict()
+ d.update({'arc_2nd': False})
+ return d
+
+ def decode(self, arc_scores, rel_scores, mask, batch=None, predicting=None):
+ output_1st, output_2nd = batch.get('outputs', (None, None))
+ if output_1st is None:
+ arc_scores_1st, arc_scores_2nd, rel_scores_1st, rel_scores_2nd = self.unpack_scores(arc_scores, rel_scores)
+ output_1st = super().decode(arc_scores_1st, rel_scores_1st, mask)
+ mask = self.compute_mask(arc_scores_2nd, batch, mask)
+ # noinspection PyCallByClass
+ output_2nd = BiaffineSemanticDependencyParser.decode(self, arc_scores_2nd, rel_scores_2nd, mask, batch)
+ if self.config.get('no_cycle'):
+ assert predicting, 'No cycle constraint for evaluation is not implemented yet. If you are ' \
+ 'interested, welcome to submit a pull request.'
+ root_rel_idx = self.vocabs['rel'].token_to_idx.get(self.config.get('root', None), None)
+ arc_pred_1st, rel_pred_1st, arc_pred_2nd, rel_pred_2nd = *output_1st, *output_2nd
+ arc_scores_2nd = arc_scores_2nd.transpose(1, 2).cpu().detach().numpy()
+ arc_pred_2nd = arc_pred_2nd.cpu().detach().numpy()
+ rel_pred_2nd = rel_pred_2nd.cpu().detach().numpy()
+ trees = arc_pred_1st.cpu().detach().numpy()
+ graphs = []
+ for i, (arc_scores, arc_preds, rel_preds, tree, tokens) in enumerate(
+ zip(arc_scores_2nd, arc_pred_2nd, rel_pred_2nd, trees, batch['token'])):
+ sent_len = len(tokens)
+ graph = add_secondary_arcs_by_preds(arc_scores, arc_preds[:sent_len, :sent_len], rel_preds,
+ tree[:sent_len], root_rel_idx)
+ graphs.append(graph[1:]) # Remove root
+ # if not predicting:
+ # # Write back to torch Tensor
+ # for d, hr in zip(graph):
+ # pass
+ output_2nd = None, graphs
+
+ return tuple(zip(output_1st, output_2nd))
+
+ def update_metric(self, arc_preds, rel_preds, arcs, rels, mask, puncts, metric, batch=None):
+ super().update_metric(arc_preds[0], rel_preds[0], arcs, rels, mask, puncts, metric['1st'], batch)
+ puncts = BiaffineSemanticDependencyParser.convert_to_3d_puncts(puncts, batch['mask_2nd'])
+ # noinspection PyCallByClass
+ BiaffineSemanticDependencyParser.update_metric(self, arc_preds[1], rel_preds[1], batch['arc_2nd'],
+ batch['rel_2nd_id'], batch['mask_2nd'], puncts, metric['2nd'],
+ batch)
+
+ def build_metric(self, **kwargs):
+ # noinspection PyCallByClass
+ return MetricDict({'1st': super().build_metric(**kwargs),
+ '2nd': BiaffineSemanticDependencyParser.build_metric(self, **kwargs)})
+
+ def collect_outputs_extend(self, predictions: list, arc_preds, rel_preds, lens, mask):
+ predictions.extend(rel_preds[1])
+
+ def predictions_to_human(self, predictions, outputs, data, use_pos):
+ rel_vocab = self.vocabs['rel'].idx_to_token
+ for d, graph in zip(data, predictions):
+ sent = CoNLLSentence()
+ for idx, (cell, hrs) in enumerate(zip(d, graph)):
+ if use_pos:
+ token, pos = cell
+ else:
+ token, pos = cell, None
+ head = hrs[0][0]
+ deprel = rel_vocab[hrs[0][1]]
+ deps = [(h, rel_vocab[r]) for h, r in hrs[1:]]
+ sent.append(CoNLLUWord(idx + 1, token, upos=pos, head=head, deprel=deprel, deps=deps))
+ outputs.append(sent)
diff --git a/hanlp/components/parsers/biaffine/biaffine_dep.py b/hanlp/components/parsers/biaffine/biaffine_dep.py
new file mode 100644
index 000000000..2429d1493
--- /dev/null
+++ b/hanlp/components/parsers/biaffine/biaffine_dep.py
@@ -0,0 +1,569 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-05-08 20:51
+import os
+from collections import Counter
+from typing import Union, Any, List
+
+from hanlp.layers.transformers.pt_imports import AutoTokenizer, PreTrainedTokenizer, AutoModel_
+import torch
+from alnlp.modules.util import lengths_to_mask
+from torch import nn
+from torch.optim import Adam
+from torch.optim.lr_scheduler import ExponentialLR
+from torch.utils.data import DataLoader
+from hanlp_common.constant import ROOT, UNK, IDX
+from hanlp.common.dataset import PadSequenceDataLoader
+from hanlp.common.structure import History
+from hanlp.common.torch_component import TorchComponent
+from hanlp.common.transform import LowerCase, FieldLength, PunctuationMask
+from hanlp.common.vocab import Vocab
+from hanlp.components.parsers.alg import decode_dep
+from hanlp.components.parsers.biaffine.biaffine_model import BiaffineDependencyModel
+from hanlp_common.conll import CoNLLWord, CoNLLSentence
+from hanlp.datasets.parsing.conll_dataset import CoNLLParsingDataset, append_bos
+from hanlp.layers.embeddings.util import index_word2vec_with_vocab
+from hanlp.layers.transformers.utils import build_optimizer_scheduler_with_transformer
+from hanlp.metrics.parsing.attachmentscore import AttachmentScore
+from hanlp.transform.transformer_tokenizer import TransformerSequenceTokenizer
+from hanlp.utils.time_util import CountdownTimer
+from hanlp_common.util import isdebugging, merge_locals_kwargs, merge_dict, reorder
+
+
+class BiaffineDependencyParser(TorchComponent):
+ def __init__(self) -> None:
+ """Biaffine dependency parsing (:cite:`dozat:17a`).
+ """
+ super().__init__()
+ self.model: BiaffineDependencyModel = None
+ self.transformer_tokenizer: PreTrainedTokenizer = None
+
+ def predict(self, data: Any, batch_size=None, batch_max_tokens=None, output_format='conllx', **kwargs):
+ if not data:
+ return []
+ use_pos = self.use_pos
+ flat = self.input_is_flat(data, use_pos)
+ if flat:
+ data = [data]
+ samples = self.build_samples(data, use_pos)
+ if not batch_max_tokens:
+ batch_max_tokens = self.config.batch_max_tokens
+ if not batch_size:
+ batch_size = self.config.batch_size
+ dataloader = self.build_dataloader(samples,
+ device=self.devices[0], shuffle=False,
+ **merge_dict(self.config,
+ batch_size=batch_size,
+ batch_max_tokens=batch_max_tokens,
+ overwrite=True,
+ **kwargs))
+ predictions, build_data, data, order = self.before_outputs(data)
+ for batch in dataloader:
+ arc_scores, rel_scores, mask, puncts = self.feed_batch(batch)
+ self.collect_outputs(arc_scores, rel_scores, mask, batch, predictions, order, data, use_pos,
+ build_data)
+ outputs = self.post_outputs(predictions, data, order, use_pos, build_data)
+ if flat:
+ return outputs[0]
+ return outputs
+
+ def build_samples(self, data, use_pos=None):
+ samples = []
+ pos_key = 'CPOS' if 'CPOS' in self.vocabs else 'UPOS'
+ for idx, each in enumerate(data):
+ sample = {IDX: idx}
+ if use_pos:
+ token, pos = zip(*each)
+ sample.update({'FORM': list(token), pos_key: list(pos)})
+ else:
+ token = each
+ sample.update({'FORM': list(token)})
+ samples.append(sample)
+ return samples
+
+ def input_is_flat(self, data, use_pos=None):
+ if use_pos is None:
+ use_pos = 'CPOS' in self.vocabs
+ if use_pos:
+ flat = isinstance(data[0], (list, tuple)) and isinstance(data[0][0], str)
+ else:
+ flat = isinstance(data[0], str)
+ return flat
+
+ def before_outputs(self, data):
+ predictions, order = [], []
+ build_data = data is None
+ if build_data:
+ data = []
+ return predictions, build_data, data, order
+
+ def post_outputs(self, predictions, data, order, use_pos, build_data):
+ predictions = reorder(predictions, order)
+ if build_data:
+ data = reorder(data, order)
+ outputs = []
+ self.predictions_to_human(predictions, outputs, data, use_pos)
+ return outputs
+
+ def predictions_to_human(self, predictions, outputs, data, use_pos):
+ for d, (arcs, rels) in zip(data, predictions):
+ sent = CoNLLSentence()
+ for idx, (cell, a, r) in enumerate(zip(d, arcs, rels)):
+ if use_pos:
+ token, pos = cell
+ else:
+ token, pos = cell, None
+ sent.append(CoNLLWord(idx + 1, token, cpos=pos, head=a, deprel=self.vocabs['rel'][r]))
+ outputs.append(sent)
+
+ def collect_outputs(self, arc_scores, rel_scores, mask, batch, predictions, order, data, use_pos,
+ build_data):
+ lens = [len(token) - 1 for token in batch['token']]
+ arc_preds, rel_preds = self.decode(arc_scores, rel_scores, mask, batch)
+ self.collect_outputs_extend(predictions, arc_preds, rel_preds, lens, mask)
+ order.extend(batch[IDX])
+ if build_data:
+ if use_pos:
+ data.extend(zip(batch['FORM'], batch['CPOS']))
+ else:
+ data.extend(batch['FORM'])
+
+ def collect_outputs_extend(self, predictions: list, arc_preds, rel_preds, lens, mask):
+ predictions.extend(zip([seq.tolist() for seq in arc_preds[mask].split(lens)],
+ [seq.tolist() for seq in rel_preds[mask].split(lens)]))
+
+ @property
+ def use_pos(self):
+ return self.config.get('feat', None) == 'pos'
+
+ def fit(self, trn_data, dev_data, save_dir,
+ feat=None,
+ n_embed=100,
+ pretrained_embed=None,
+ transformer=None,
+ average_subwords=False,
+ word_dropout=0.2,
+ transformer_hidden_dropout=None,
+ layer_dropout=0,
+ scalar_mix: int = None,
+ embed_dropout=.33,
+ n_lstm_hidden=400,
+ n_lstm_layers=3,
+ hidden_dropout=.33,
+ n_mlp_arc=500,
+ n_mlp_rel=100,
+ mlp_dropout=.33,
+ lr=2e-3,
+ transformer_lr=5e-5,
+ mu=.9,
+ nu=.9,
+ epsilon=1e-12,
+ grad_norm=5.0,
+ decay=.75,
+ decay_steps=5000,
+ weight_decay=0,
+ warmup_steps=0.1,
+ separate_optimizer=False,
+ patience=100,
+ lowercase=False,
+ epochs=50000,
+ tree=False,
+ proj=False,
+ punct=False,
+ min_freq=2,
+ logger=None,
+ verbose=True,
+ unk=UNK,
+ max_sequence_length=512,
+ batch_size=None,
+ sampler_builder=None,
+ gradient_accumulation=1,
+ devices: Union[float, int, List[int]] = None,
+ transform=None,
+ secondary_encoder=None,
+ **kwargs):
+ return super().fit(**merge_locals_kwargs(locals(), kwargs))
+
+ def execute_training_loop(self, trn, dev, devices, epochs, logger, patience, save_dir, optimizer,
+ gradient_accumulation, **kwargs):
+ optimizer, scheduler, transformer_optimizer, transformer_scheduler = optimizer
+ criterion = self.build_criterion()
+ best_e, best_metric = 0, self.build_metric()
+ timer = CountdownTimer(epochs)
+ history = History()
+ ratio_width = len(f'{len(trn) // gradient_accumulation}/{len(trn) // gradient_accumulation}')
+ for epoch in range(1, epochs + 1):
+ # train one epoch and update the parameters
+ logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
+ self.fit_dataloader(trn, optimizer, scheduler, criterion, epoch, logger, history,
+ transformer_optimizer, transformer_scheduler,
+ gradient_accumulation=gradient_accumulation)
+ loss, dev_metric = self.evaluate_dataloader(dev, criterion, ratio_width=ratio_width, logger=logger)
+ timer.update()
+ # logger.info(f"{'Dev' + ' ' * ratio_width} loss: {loss:.4f} {dev_metric}")
+ # save the model if it is the best so far
+ report = f"{timer.elapsed_human} / {timer.total_time_human} ETA: {timer.eta_human}"
+ if dev_metric > best_metric:
+ best_e, best_metric = epoch, dev_metric
+ self.save_weights(save_dir)
+ report += ' ([red]saved[/red])'
+ else:
+ if patience != epochs:
+ report += f' ({epoch - best_e}/{patience})'
+ else:
+ report += f' ({epoch - best_e})'
+ logger.info(report)
+ if patience is not None and epoch - best_e >= patience:
+ logger.info(f'LAS has stopped improving for {patience} epochs, early stop.')
+ break
+ timer.stop()
+ if not best_e:
+ self.save_weights(save_dir)
+ elif best_e != epoch:
+ self.load_weights(save_dir)
+ logger.info(f"Max score of dev is {best_metric.score:.2%} at epoch {best_e}")
+ logger.info(f"Average time of each epoch is {timer.elapsed_average_human}")
+ logger.info(f"{timer.elapsed_human} elapsed")
+
+ def build_optimizer(self, epochs, trn, gradient_accumulation, **kwargs):
+ config = self.config
+ model = self.model
+ if isinstance(model, nn.DataParallel):
+ model = model.module
+ if self.config.transformer:
+ transformer = model.encoder.transformer
+ optimizer = Adam(set(model.parameters()) - set(transformer.parameters()),
+ config.lr,
+ (config.mu, config.nu),
+ config.epsilon)
+ if self.config.transformer_lr:
+ num_training_steps = len(trn) * epochs // gradient_accumulation
+ if self.config.separate_optimizer:
+ transformer_optimizer, transformer_scheduler = \
+ build_optimizer_scheduler_with_transformer(transformer,
+ transformer,
+ config.transformer_lr,
+ config.transformer_lr,
+ num_training_steps,
+ config.warmup_steps,
+ config.weight_decay,
+ adam_epsilon=1e-8)
+ else:
+ optimizer, scheduler = build_optimizer_scheduler_with_transformer(model,
+ transformer,
+ config.lr,
+ config.transformer_lr,
+ num_training_steps,
+ config.warmup_steps,
+ config.weight_decay,
+ adam_epsilon=1e-8)
+ transformer_optimizer, transformer_scheduler = None, None
+ else:
+ transformer.requires_grad_(False)
+ transformer_optimizer, transformer_scheduler = None, None
+ else:
+ optimizer = Adam(model.parameters(),
+ config.lr,
+ (config.mu, config.nu),
+ config.epsilon)
+ transformer_optimizer, transformer_scheduler = None, None
+ if self.config.separate_optimizer:
+ scheduler = ExponentialLR(optimizer, config.decay ** (1 / config.decay_steps))
+ # noinspection PyUnboundLocalVariable
+ return optimizer, scheduler, transformer_optimizer, transformer_scheduler
+
+ def build_transformer_tokenizer(self):
+ transformer = self.config.transformer
+ if transformer:
+ transformer_tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(transformer, use_fast=True)
+ else:
+ transformer_tokenizer = None
+ self.transformer_tokenizer = transformer_tokenizer
+ return transformer_tokenizer
+
+ # noinspection PyMethodOverriding
+ def build_dataloader(self,
+ data,
+ shuffle,
+ device,
+ training=False,
+ logger=None,
+ gradient_accumulation=1,
+ sampler_builder=None,
+ batch_size=None,
+ **kwargs) -> DataLoader:
+ dataset = self.build_dataset(data)
+ if self.vocabs.mutable:
+ self.build_vocabs(dataset, logger, self.config.transformer)
+ transformer_tokenizer = self.transformer_tokenizer
+ if transformer_tokenizer:
+ dataset.transform.append(self.build_tokenizer_transform())
+ dataset.append_transform(FieldLength('token', 'sent_length'))
+ if isinstance(data, str):
+ dataset.purge_cache()
+ if len(dataset) > 1000 and isinstance(data, str):
+ timer = CountdownTimer(len(dataset))
+ self.cache_dataset(dataset, timer, training, logger)
+ if self.config.transformer:
+ lens = [len(sample['input_ids']) for sample in dataset]
+ else:
+ lens = [sample['sent_length'] for sample in dataset]
+ if sampler_builder:
+ sampler = sampler_builder.build(lens, shuffle, gradient_accumulation)
+ else:
+ sampler = None
+ loader = PadSequenceDataLoader(dataset=dataset,
+ batch_sampler=sampler,
+ batch_size=batch_size,
+ num_workers=0 if isdebugging() else 2,
+ pad=self.get_pad_dict(),
+ device=device,
+ vocabs=self.vocabs)
+ return loader
+
+ def cache_dataset(self, dataset, timer, training=False, logger=None):
+ for each in dataset:
+ timer.log('Preprocessing and caching samples [blink][yellow]...[/yellow][/blink]')
+
+ def get_pad_dict(self):
+ return {'arc': 0}
+
+ def build_dataset(self, data, bos_transform=None):
+ if not bos_transform:
+ bos_transform = append_bos
+ transform = [bos_transform]
+ if self.config.get('transform', None):
+ transform.append(self.config.transform)
+ if self.config.get('lowercase', False):
+ transform.append(LowerCase('token'))
+ transform.append(self.vocabs)
+ if not self.config.punct:
+ transform.append(PunctuationMask('token', 'punct_mask'))
+ return CoNLLParsingDataset(data, transform=transform)
+
+ def build_tokenizer_transform(self):
+ return TransformerSequenceTokenizer(self.transformer_tokenizer, 'token', '',
+ ret_token_span=True, cls_is_bos=True,
+ max_seq_length=self.config.get('max_sequence_length',
+ 512),
+ truncate_long_sequences=False)
+
+ def build_vocabs(self, dataset, logger=None, transformer=None):
+ rel_vocab = self.vocabs.get('rel', None)
+ if rel_vocab is None:
+ rel_vocab = Vocab(unk_token=None, pad_token=self.config.get('pad_rel', None))
+ self.vocabs.put(rel=rel_vocab)
+ if self.config.get('feat', None) == 'pos' or self.config.get('use_pos', False):
+ self.vocabs['pos'] = Vocab(unk_token=None, pad_token=None)
+
+ timer = CountdownTimer(len(dataset))
+ if transformer:
+ token_vocab = None
+ else:
+ token_vocab = Vocab()
+ self.vocabs.token = token_vocab
+ unk = self.config.get('unk', None)
+ if unk is not None:
+ token_vocab.unk_token = unk
+ if token_vocab and self.config.get('min_freq', None):
+ counter = Counter()
+ for sample in dataset:
+ for form in sample['token']:
+ counter[form] += 1
+ reserved_token = [token_vocab.pad_token, token_vocab.unk_token]
+ if ROOT in token_vocab:
+ reserved_token.append(ROOT)
+ freq_words = reserved_token + [token for token, freq in counter.items() if
+ freq >= self.config.min_freq]
+ token_vocab.token_to_idx.clear()
+ for word in freq_words:
+ token_vocab(word)
+ else:
+ for i, sample in enumerate(dataset):
+ timer.log('vocab building [blink][yellow]...[/yellow][/blink]', ratio_percentage=True)
+ rel_vocab.set_unk_as_safe_unk() # Some relation in dev set is OOV
+ self.vocabs.lock()
+ self.vocabs.summary(logger=logger)
+ if token_vocab:
+ self.config.n_words = len(self.vocabs['token'])
+ if 'pos' in self.vocabs:
+ self.config.n_feats = len(self.vocabs['pos'])
+ self.vocabs['pos'].set_unk_as_safe_unk()
+ self.config.n_rels = len(self.vocabs['rel'])
+ if token_vocab:
+ self.config.pad_index = self.vocabs['token'].pad_idx
+ self.config.unk_index = self.vocabs['token'].unk_idx
+
+ def build_model(self, training=True, **kwargs) -> torch.nn.Module:
+ pretrained_embed, transformer = self.build_embeddings(training=training)
+ if pretrained_embed is not None:
+ self.config.n_embed = pretrained_embed.size(-1)
+ model = self.create_model(pretrained_embed, transformer)
+ return model
+
+ def create_model(self, pretrained_embed, transformer):
+ return BiaffineDependencyModel(self.config,
+ pretrained_embed,
+ transformer,
+ self.transformer_tokenizer)
+
+ def build_embeddings(self, training=True):
+ pretrained_embed = None
+ if self.config.get('pretrained_embed', None):
+ pretrained_embed = index_word2vec_with_vocab(self.config.pretrained_embed, self.vocabs['token'],
+ init='zeros', normalize=True)
+ transformer = self.config.transformer
+ if transformer:
+ transformer = AutoModel_.from_pretrained(transformer, training=training)
+ return pretrained_embed, transformer
+
+ # noinspection PyMethodOverriding
+ def fit_dataloader(self,
+ trn,
+ optimizer,
+ scheduler,
+ criterion,
+ epoch,
+ logger,
+ history: History,
+ transformer_optimizer=None,
+ transformer_scheduler=None,
+ gradient_accumulation=1,
+ **kwargs):
+ self.model.train()
+
+ timer = CountdownTimer(history.num_training_steps(len(trn), gradient_accumulation))
+ metric = self.build_metric(training=True)
+ total_loss = 0
+ for idx, batch in enumerate(trn):
+ arc_scores, rel_scores, mask, puncts = self.feed_batch(batch)
+ arcs, rels = batch['arc'], batch['rel_id']
+ loss = self.compute_loss(arc_scores, rel_scores, arcs, rels, mask, criterion, batch)
+ if gradient_accumulation > 1:
+ loss /= gradient_accumulation
+ loss.backward()
+ total_loss += loss.item()
+ arc_preds, rel_preds = self.decode(arc_scores, rel_scores, mask, batch)
+ self.update_metric(arc_preds, rel_preds, arcs, rels, mask, puncts, metric, batch)
+ if history.step(gradient_accumulation):
+ self._step(optimizer, scheduler, transformer_optimizer, transformer_scheduler)
+ report = self._report(total_loss / (timer.current + 1), metric)
+ timer.log(report, ratio_percentage=False, logger=logger)
+ del loss
+
+ def _step(self, optimizer, scheduler, transformer_optimizer, transformer_scheduler):
+ if self.config.get('grad_norm', None):
+ nn.utils.clip_grad_norm_(self.model.parameters(),
+ self.config.grad_norm)
+ optimizer.step()
+ optimizer.zero_grad()
+ scheduler.step()
+ if self.config.transformer and self.config.transformer_lr and transformer_optimizer:
+ transformer_optimizer.step()
+ transformer_optimizer.zero_grad()
+ transformer_scheduler.step()
+
+ def feed_batch(self, batch):
+ words, feats, lens, puncts = batch.get('token_id', None), batch.get('pos_id', None), batch['sent_length'], \
+ batch.get('punct_mask', None)
+ mask = lengths_to_mask(lens)
+ arc_scores, rel_scores = self.model(words=words, feats=feats, mask=mask, batch=batch, **batch)
+ # ignore the first token of each sentence
+ # RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
+ if self.model.training:
+ mask = mask.clone()
+ mask[:, 0] = 0
+ return arc_scores, rel_scores, mask, puncts
+
+ def _report(self, loss, metric: AttachmentScore):
+ return f'loss: {loss:.4f} {metric}'
+
+ def compute_loss(self, arc_scores, rel_scores, arcs, rels, mask, criterion, batch=None):
+ arc_scores, arcs = arc_scores[mask], arcs[mask]
+ rel_scores, rels = rel_scores[mask], rels[mask]
+ rel_scores = rel_scores[torch.arange(len(arcs)), arcs]
+ arc_loss = criterion(arc_scores, arcs)
+ rel_loss = criterion(rel_scores, rels)
+ loss = arc_loss + rel_loss
+
+ return loss
+
+ # noinspection PyUnboundLocalVariable
+ @torch.no_grad()
+ def evaluate_dataloader(self, loader: PadSequenceDataLoader, criterion, logger=None, filename=None, output=False,
+ ratio_width=None,
+ metric=None,
+ **kwargs):
+ self.model.eval()
+
+ loss = 0
+ if not metric:
+ metric = self.build_metric()
+ if output:
+ fp = open(output, 'w')
+ predictions, build_data, data, order = self.before_outputs(None)
+
+ timer = CountdownTimer(len(loader))
+ use_pos = self.use_pos
+ for batch in loader:
+ arc_scores, rel_scores, mask, puncts = self.feed_batch(batch)
+ if output:
+ self.collect_outputs(arc_scores, rel_scores, mask, batch, predictions, order, data, use_pos,
+ build_data)
+ arcs, rels = batch['arc'], batch['rel_id']
+ loss += self.compute_loss(arc_scores, rel_scores, arcs, rels, mask, criterion, batch).item()
+ arc_preds, rel_preds = self.decode(arc_scores, rel_scores, mask, batch)
+ self.update_metric(arc_preds, rel_preds, arcs, rels, mask, puncts, metric, batch)
+ report = self._report(loss / (timer.current + 1), metric)
+ if filename:
+ report = f'{os.path.basename(filename)} ' + report
+ timer.log(report, ratio_percentage=False, logger=logger, ratio_width=ratio_width)
+ loss /= len(loader)
+ if output:
+ outputs = self.post_outputs(predictions, data, order, use_pos, build_data)
+ for each in outputs:
+ fp.write(f'{each}\n\n')
+ fp.close()
+ logger.info(f'Predictions saved in [underline][yellow]{output}[/yellow][/underline]')
+
+ return loss, metric
+
+ def update_metric(self, arc_preds, rel_preds, arcs, rels, mask, puncts, metric, batch=None):
+ # ignore all punctuation if not specified
+ if not self.config.punct:
+ mask &= puncts
+ metric(arc_preds, rel_preds, arcs, rels, mask)
+
+ def decode(self, arc_scores, rel_scores, mask, batch=None):
+ tree, proj = self.config.tree, self.config.get('proj', False)
+ if tree:
+ arc_preds = decode_dep(arc_scores, mask, tree, proj)
+ else:
+ arc_preds = arc_scores.argmax(-1)
+ rel_preds = rel_scores.argmax(-1)
+ rel_preds = rel_preds.gather(-1, arc_preds.unsqueeze(-1)).squeeze(-1)
+
+ return arc_preds, rel_preds
+
+ def build_criterion(self, **kwargs):
+ criterion = nn.CrossEntropyLoss()
+ return criterion
+
+ def build_metric(self, **kwargs):
+ return AttachmentScore()
+
+ def on_config_ready(self, **kwargs):
+ self.build_transformer_tokenizer() # We have to build tokenizer before building the dataloader and model
+ self.config.patience = min(self.config.patience, self.config.epochs)
+
+ def prediction_to_head_rel(self, arcs: torch.LongTensor, rels: torch.LongTensor, batch: dict):
+ arcs = arcs[:, 1:] # Skip the ROOT
+ rels = rels[:, 1:]
+ arcs = arcs.tolist()
+ rels = rels.tolist()
+ vocab = self.vocabs['rel'].idx_to_token
+ for arcs_per_sent, rels_per_sent, tokens in zip(arcs, rels, batch['token']):
+ tokens = tokens[1:]
+ sent_len = len(tokens)
+ result = list(zip(arcs_per_sent[:sent_len], [vocab[r] for r in rels_per_sent[:sent_len]]))
+ yield result
diff --git a/hanlp/components/parsers/biaffine/biaffine_model.py b/hanlp/components/parsers/biaffine/biaffine_model.py
new file mode 100644
index 000000000..2730190f4
--- /dev/null
+++ b/hanlp/components/parsers/biaffine/biaffine_model.py
@@ -0,0 +1,230 @@
+# -*- coding: utf-8 -*-
+from typing import Any, Tuple
+
+import torch
+import torch.nn as nn
+from torch.nn.utils.rnn import (pack_padded_sequence, pad_packed_sequence,
+ pad_sequence)
+
+from hanlp.components.parsers.biaffine.biaffine import Biaffine
+from hanlp.components.parsers.biaffine.mlp import MLP
+from hanlp.components.parsers.biaffine.variationalbilstm import VariationalLSTM
+from hanlp.layers.dropout import IndependentDropout, SharedDropout, WordDropout
+from hanlp.layers.transformers.encoder import TransformerEncoder
+from hanlp.layers.transformers.pt_imports import PreTrainedModel, PreTrainedTokenizer
+from hanlp.layers.transformers.utils import transformer_encode
+
+
+class EncoderWithContextualLayer(nn.Module):
+ def __init__(self,
+ config,
+ pretrained_embed: torch.Tensor = None,
+ transformer: PreTrainedModel = None,
+ transformer_tokenizer: PreTrainedTokenizer = None,
+ ):
+ super(EncoderWithContextualLayer, self).__init__()
+
+ self.secondary_encoder = config.get('secondary_encoder', None)
+ self.config = config
+
+ if not transformer:
+ self.pad_index = config.pad_index
+ self.unk_index = config.unk_index
+ if config.word_dropout:
+ oov = self.unk_index
+ excludes = [self.pad_index]
+ self.word_dropout = WordDropout(p=config.word_dropout, oov_token=oov, exclude_tokens=excludes)
+ else:
+ self.word_dropout = None
+ if transformer:
+ input_size = 0
+ if self.config.transformer_lr:
+ hidden_size = transformer.config.hidden_size
+ else:
+ input_size = transformer.config.hidden_size
+ hidden_size = config.n_lstm_hidden * 2
+ if config.feat == 'pos':
+ self.feat_embed = nn.Embedding(num_embeddings=config.n_feats,
+ embedding_dim=config.n_embed)
+ self.embed_dropout = IndependentDropout(p=config.embed_dropout)
+ if self.config.transformer_lr:
+ hidden_size += config.n_embed
+ else:
+ input_size += config.n_embed
+ if not self.config.transformer_lr:
+ self.lstm = VariationalLSTM(input_size=input_size,
+ hidden_size=config.n_lstm_hidden,
+ num_layers=config.n_lstm_layers,
+ dropout=config.hidden_dropout, bidirectional=True)
+ else:
+ # the embedding layer
+ input_size = config.n_embed
+ self.word_embed = nn.Embedding(num_embeddings=config.n_words,
+ embedding_dim=config.n_embed)
+ if pretrained_embed is not None:
+ if not isinstance(pretrained_embed, torch.Tensor):
+ pretrained_embed = torch.Tensor(pretrained_embed)
+ self.pretrained = nn.Embedding.from_pretrained(pretrained_embed)
+ nn.init.zeros_(self.word_embed.weight)
+ if config.feat == 'pos':
+ self.feat_embed = nn.Embedding(num_embeddings=config.n_feats,
+ embedding_dim=config.n_embed)
+ self.embed_dropout = IndependentDropout(p=config.embed_dropout)
+ input_size += config.n_embed
+
+ # the word-lstm layer
+ hidden_size = config.n_lstm_hidden * 2
+ self.lstm = VariationalLSTM(input_size=input_size,
+ hidden_size=config.n_lstm_hidden,
+ num_layers=config.n_lstm_layers,
+ dropout=config.hidden_dropout, bidirectional=True)
+ self.hidden_size = hidden_size
+ self.hidden_dropout = SharedDropout(p=config.hidden_dropout)
+ if transformer:
+ transformer = TransformerEncoder(transformer, transformer_tokenizer, config.average_subwords,
+ word_dropout=config.word_dropout,
+ max_sequence_length=config.max_sequence_length)
+ self.transformer = transformer
+
+ def forward(self, words, feats, input_ids, token_span, mask, lens):
+ if mask is None:
+ # get the mask and lengths of given batch
+ mask = words.ne(self.pad_index)
+ if lens is None:
+ lens = mask.sum(dim=1)
+ batch_size, seq_len = mask.shape
+ if self.config.transformer:
+ # trans_embed = self.run_transformer(input_ids, token_span=token_span)
+ trans_embed = self.transformer.forward(input_ids, token_span=token_span)
+ if hasattr(self, 'feat_embed'):
+ feat_embed = self.feat_embed(feats)
+ trans_embed, feat_embed = self.embed_dropout(trans_embed, feat_embed)
+ embed = torch.cat((trans_embed, feat_embed), dim=-1)
+ else:
+ embed = trans_embed
+ if hasattr(self, 'lstm'):
+ x = self.run_rnn(embed, lens, seq_len)
+ else:
+ x = embed
+ if self.secondary_encoder:
+ x = self.secondary_encoder(x, mask)
+ x = self.hidden_dropout(x)
+ else:
+ if self.word_dropout:
+ words = self.word_dropout(words)
+ # set the indices larger than num_embeddings to unk_index
+ ext_mask = words.ge(self.word_embed.num_embeddings)
+ ext_words = words.masked_fill(ext_mask, self.unk_index)
+
+ # get outputs from embedding layers
+ word_embed = self.word_embed(ext_words)
+ if hasattr(self, 'pretrained'):
+ word_embed += self.pretrained(words)
+ if self.config.feat == 'char':
+ feat_embed = self.feat_embed(feats[mask])
+ feat_embed = pad_sequence(feat_embed.split(lens.tolist()), True)
+ elif self.config.feat == 'bert':
+ feat_embed = self.feat_embed(*feats)
+ elif hasattr(self, 'feat_embed'):
+ feat_embed = self.feat_embed(feats)
+ else:
+ feat_embed = None
+ if feat_embed is not None:
+ word_embed, feat_embed = self.embed_dropout(word_embed, feat_embed)
+ # concatenate the word and feat representations
+ embed = torch.cat((word_embed, feat_embed), dim=-1)
+ else:
+ embed = word_embed
+
+ x = self.run_rnn(embed, lens, seq_len)
+ x = self.hidden_dropout(x)
+ return x, mask
+
+ def run_rnn(self, embed, lens, seq_len):
+ x = pack_padded_sequence(embed, lens, True, False)
+ x, _ = self.lstm(x)
+ x, _ = pad_packed_sequence(x, True, total_length=seq_len)
+ return x
+
+ def run_transformer(self, input_ids, token_span):
+ return transformer_encode(self.transformer, input_ids, None, None, token_span,
+ average_subwords=self.config.average_subwords)
+
+
+class BiaffineDecoder(nn.Module):
+ def __init__(self, hidden_size, n_mlp_arc, n_mlp_rel, mlp_dropout, n_rels, arc_dropout=None,
+ rel_dropout=None) -> None:
+ super().__init__()
+ # the MLP layers
+ self.mlp_arc_h = MLP(hidden_size,
+ n_mlp_arc,
+ dropout=arc_dropout or mlp_dropout)
+ self.mlp_arc_d = MLP(hidden_size,
+ n_mlp_arc,
+ dropout=arc_dropout or mlp_dropout)
+ self.mlp_rel_h = MLP(hidden_size,
+ n_mlp_rel,
+ dropout=rel_dropout or mlp_dropout)
+ self.mlp_rel_d = MLP(hidden_size,
+ n_mlp_rel,
+ dropout=rel_dropout or mlp_dropout)
+
+ # the Biaffine layers
+ self.arc_attn = Biaffine(n_in=n_mlp_arc,
+ bias_x=True,
+ bias_y=False)
+ self.rel_attn = Biaffine(n_in=n_mlp_rel,
+ n_out=n_rels,
+ bias_x=True,
+ bias_y=True)
+
+ def forward(self, x, mask=None, **kwargs: Any) -> Tuple[torch.Tensor, torch.Tensor]:
+ arc_d, arc_h, rel_d, rel_h = self.apply_mlps(x)
+
+ s_arc, s_rel = self.decode(arc_d, arc_h, rel_d, rel_h, mask, self.arc_attn, self.rel_attn)
+
+ return s_arc, s_rel
+
+ @staticmethod
+ def decode(arc_d, arc_h, rel_d, rel_h, mask, arc_attn, rel_attn):
+ # get arc and rel scores from the bilinear attention
+ # [batch_size, seq_len, seq_len]
+ s_arc = arc_attn(arc_d, arc_h)
+ # [batch_size, seq_len, seq_len, n_rels]
+ s_rel = rel_attn(rel_d, rel_h).permute(0, 2, 3, 1)
+ if mask is not None:
+ # set the scores that exceed the length of each sentence to -inf
+ s_arc.masked_fill_(~mask.unsqueeze(1), float('-inf'))
+ return s_arc, s_rel
+
+ def apply_mlps(self, x):
+ # apply MLPs to the hidden states
+ arc_d = self.mlp_arc_d(x)
+ arc_h = self.mlp_arc_h(x)
+ rel_d = self.mlp_rel_d(x)
+ rel_h = self.mlp_rel_h(x)
+ return arc_d, arc_h, rel_d, rel_h
+
+
+class BiaffineDependencyModel(nn.Module):
+
+ def __init__(self, config, pretrained_embed: torch.Tensor = None, transformer: PreTrainedModel = None,
+ transformer_tokenizer: PreTrainedTokenizer = None):
+ super().__init__()
+ self.encoder = EncoderWithContextualLayer(config, pretrained_embed, transformer, transformer_tokenizer)
+ self.biaffine_decoder = BiaffineDecoder(self.encoder.hidden_size,
+ config.n_mlp_arc,
+ config.n_mlp_rel,
+ config.mlp_dropout,
+ config.n_rels)
+
+ def forward(self,
+ words=None,
+ feats=None,
+ input_ids=None,
+ token_span=None,
+ mask=None, lens=None, **kwargs):
+ x, mask = self.encoder(words, feats, input_ids, token_span, mask, lens)
+ s_arc, s_rel = self.biaffine_decoder(x, mask)
+
+ return s_arc, s_rel
diff --git a/hanlp/components/parsers/biaffine/biaffine_sdp.py b/hanlp/components/parsers/biaffine/biaffine_sdp.py
new file mode 100644
index 000000000..29dd34b8e
--- /dev/null
+++ b/hanlp/components/parsers/biaffine/biaffine_sdp.py
@@ -0,0 +1,200 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-07-28 15:12
+import functools
+from collections import Counter
+from typing import Union, List
+
+import torch
+from torch import nn
+
+from hanlp_common.constant import UNK
+from hanlp.common.transform import TransformList
+from hanlp.components.parsers.biaffine.biaffine_dep import BiaffineDependencyParser
+from hanlp_common.conll import CoNLLUWord, CoNLLSentence
+from hanlp.datasets.parsing.semeval15 import unpack_deps_to_head_deprel, append_bos_to_form_pos
+from hanlp.metrics.parsing.labeled_f1 import LabeledF1
+from hanlp_common.util import merge_locals_kwargs
+
+
+class BiaffineSemanticDependencyParser(BiaffineDependencyParser):
+ def __init__(self) -> None:
+ """Implementation of "Stanford's graph-based neural dependency parser at
+ the conll 2017 shared task" (:cite:`dozat2017stanford`).
+ """
+ super().__init__()
+
+ def get_pad_dict(self):
+ return {'arc': False}
+
+ def build_metric(self, **kwargs):
+ return LabeledF1()
+
+ # noinspection PyMethodOverriding
+ def build_dataset(self, data, transform=None):
+ transforms = TransformList(functools.partial(append_bos_to_form_pos, pos_key='UPOS'),
+ functools.partial(unpack_deps_to_head_deprel, pad_rel=self.config.pad_rel))
+ if transform:
+ transforms.append(transform)
+ return super(BiaffineSemanticDependencyParser, self).build_dataset(data, transforms)
+
+ def build_criterion(self, **kwargs):
+ return nn.BCEWithLogitsLoss(), nn.CrossEntropyLoss()
+
+ def feed_batch(self, batch):
+ arc_scores, rel_scores, mask, puncts = super().feed_batch(batch)
+ mask = self.convert_to_3d_mask(arc_scores, mask)
+ puncts = self.convert_to_3d_puncts(puncts, mask)
+ return arc_scores, rel_scores, mask, puncts
+
+ @staticmethod
+ def convert_to_3d_puncts(puncts, mask):
+ if puncts is not None:
+ puncts = puncts.unsqueeze(-1).expand_as(mask)
+ return puncts
+
+ @staticmethod
+ def convert_to_3d_mask(arc_scores, mask):
+ # 3d masks
+ mask = mask.unsqueeze(-1).expand_as(arc_scores)
+ mask = mask & mask.transpose(1, 2)
+ return mask
+
+ def compute_loss(self, arc_scores, rel_scores, arcs, rels, mask: torch.BoolTensor, criterion, batch=None):
+ bce, ce = criterion
+ arc_scores, arcs = arc_scores[mask], arcs[mask]
+ rel_scores, rels = rel_scores[mask], rels[mask]
+ rel_scores, rels = rel_scores[arcs], rels[arcs]
+ arc_loss = bce(arc_scores, arcs.to(torch.float))
+ arc_loss_interpolation = self.config.get('arc_loss_interpolation', None)
+ loss = arc_loss * arc_loss_interpolation if arc_loss_interpolation else arc_loss
+ if len(rels):
+ rel_loss = ce(rel_scores, rels)
+ loss += (rel_loss * (1 - arc_loss_interpolation)) if arc_loss_interpolation else rel_loss
+ if arc_loss_interpolation:
+ loss *= 2
+ return loss
+
+ def cache_dataset(self, dataset, timer, training=False, logger=None):
+ if not self.config.apply_constraint:
+ return super(BiaffineSemanticDependencyParser, self).cache_dataset(dataset, timer, training)
+ num_roots = Counter()
+ no_zero_head = True
+ root_rels = Counter()
+ for each in dataset:
+ if training:
+ num_roots[sum([x[0] for x in each['arc']])] += 1
+ no_zero_head &= all([x != '_' for x in each['DEPS']])
+ head_is_root = [i for i in range(len(each['arc'])) if each['arc'][i][0]]
+ if head_is_root:
+ for i in head_is_root:
+ root_rels[each['rel'][i][0]] += 1
+ timer.log('Preprocessing and caching samples [blink][yellow]...[/yellow][/blink]')
+ if training:
+ if self.config.single_root is None:
+ self.config.single_root = len(num_roots) == 1 and num_roots.most_common()[0][0] == 1
+ if self.config.no_zero_head is None:
+ self.config.no_zero_head = no_zero_head
+ root_rel = root_rels.most_common()[0][0]
+ self.config.root_rel_id = self.vocabs['rel'].get_idx(root_rel)
+ if logger:
+ logger.info(f'Training set properties: [blue]single_root = {self.config.single_root}[/blue], '
+ f'[blue]no_zero_head = {no_zero_head}[/blue], '
+ f'[blue]root_rel = {root_rel}[/blue]')
+
+ def decode(self, arc_scores, rel_scores, mask, batch=None):
+ eye = torch.arange(0, arc_scores.size(1), device=arc_scores.device).view(1, 1, -1).expand(
+ arc_scores.size(0), -1, -1)
+ inf = float('inf')
+ arc_scores.scatter_(dim=1, index=eye, value=-inf)
+
+ if self.config.apply_constraint:
+ if self.config.get('single_root', False):
+ root_mask = arc_scores[:, :, 0].argmax(dim=-1).unsqueeze_(-1).expand_as(arc_scores[:, :, 0])
+ arc_scores[:, :, 0] = -inf
+ arc_scores[:, :, 0].scatter_(dim=-1, index=root_mask, value=inf)
+
+ root_rel_id = self.config.root_rel_id
+ rel_scores[:, :, 0, root_rel_id] = inf
+ rel_scores[:, :, 1:, root_rel_id] = -inf
+
+ arc_scores_T = arc_scores.transpose(-1, -2)
+ arc = ((arc_scores > 0) & (arc_scores_T < arc_scores))
+ if self.config.get('no_zero_head', False):
+ arc_scores_fix = arc_scores_T.argmax(dim=-2).unsqueeze_(-1).expand_as(arc_scores)
+ arc.scatter_(dim=-1, index=arc_scores_fix, value=True)
+ else:
+ arc = arc_scores > 0
+ rel = rel_scores.argmax(dim=-1)
+ return arc, rel
+
+ def collect_outputs_extend(self, predictions, arc_preds, rel_preds, lens, mask):
+ predictions.extend(zip(arc_preds.tolist(), rel_preds.tolist(), mask.tolist()))
+ # all_arcs.extend(seq.tolist() for seq in arc_preds[mask].split([x * x for x in lens]))
+ # all_rels.extend(seq.tolist() for seq in rel_preds[mask].split([x * x for x in lens]))
+
+ def predictions_to_human(self, predictions, outputs, data, use_pos):
+ for d, (arcs, rels, masks) in zip(data, predictions):
+ sent = CoNLLSentence()
+ for idx, (cell, a, r) in enumerate(zip(d, arcs[1:], rels[1:])):
+ if use_pos:
+ token, pos = cell
+ else:
+ token, pos = cell, None
+ heads = [i for i in range(len(d) + 1) if a[i]]
+ deprels = [self.vocabs['rel'][r[i]] for i in range(len(d) + 1) if a[i]]
+ sent.append(
+ CoNLLUWord(idx + 1, token, upos=pos, head=None, deprel=None, deps=list(zip(heads, deprels))))
+ outputs.append(sent)
+
+ def fit(self, trn_data, dev_data, save_dir,
+ feat=None,
+ n_embed=100,
+ pretrained_embed=None,
+ transformer=None,
+ average_subwords=False,
+ word_dropout: float = 0.2,
+ transformer_hidden_dropout=None,
+ layer_dropout=0,
+ mix_embedding: int = None,
+ embed_dropout=.33,
+ n_lstm_hidden=400,
+ n_lstm_layers=3,
+ hidden_dropout=.33,
+ n_mlp_arc=500,
+ n_mlp_rel=100,
+ mlp_dropout=.33,
+ arc_dropout=None,
+ rel_dropout=None,
+ arc_loss_interpolation=0.4,
+ lr=2e-3,
+ transformer_lr=5e-5,
+ mu=.9,
+ nu=.9,
+ epsilon=1e-12,
+ clip=5.0,
+ decay=.75,
+ decay_steps=5000,
+ weight_decay=0,
+ warmup_steps=0.1,
+ separate_optimizer=True,
+ patience=100,
+ batch_size=None,
+ sampler_builder=None,
+ lowercase=False,
+ epochs=50000,
+ apply_constraint=False,
+ single_root=None,
+ no_zero_head=None,
+ punct=False,
+ min_freq=2,
+ logger=None,
+ verbose=True,
+ unk=UNK,
+ pad_rel=None,
+ max_sequence_length=512,
+ gradient_accumulation=1,
+ devices: Union[float, int, List[int]] = None,
+ transform=None,
+ **kwargs):
+ return super().fit(**merge_locals_kwargs(locals(), kwargs))
diff --git a/hanlp/components/parsers/biaffine/mlp.py b/hanlp/components/parsers/biaffine/mlp.py
new file mode 100644
index 000000000..5798d853f
--- /dev/null
+++ b/hanlp/components/parsers/biaffine/mlp.py
@@ -0,0 +1,83 @@
+# MIT License
+#
+# Copyright (c) 2020 Yu Zhang
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+
+
+import torch.nn as nn
+
+from hanlp.layers.dropout import SharedDropout
+
+
+class MLP(nn.Module):
+ r"""
+ Applies a linear transformation together with a non-linear activation to the incoming tensor:
+ :math:`y = \mathrm{Activation}(x A^T + b)`
+
+ Args:
+ n_in (~torch.Tensor):
+ The size of each input feature.
+ n_out (~torch.Tensor):
+ The size of each output feature.
+ dropout (float):
+ If non-zero, introduce a :class:`SharedDropout` layer on the output with this dropout ratio. Default: 0.
+ activation (bool):
+ Whether to use activations. Default: True.
+ """
+
+ def __init__(self, n_in, n_out, dropout=0, activation=True):
+ super().__init__()
+
+ self.n_in = n_in
+ self.n_out = n_out
+ self.linear = nn.Linear(n_in, n_out)
+ self.activation = nn.LeakyReLU(negative_slope=0.1) if activation else nn.Identity()
+ self.dropout = SharedDropout(p=dropout)
+
+ self.reset_parameters()
+
+ def __repr__(self):
+ s = f"n_in={self.n_in}, n_out={self.n_out}"
+ if self.dropout.p > 0:
+ s += f", dropout={self.dropout.p}"
+
+ return f"{self.__class__.__name__}({s})"
+
+ def reset_parameters(self):
+ nn.init.orthogonal_(self.linear.weight)
+ nn.init.zeros_(self.linear.bias)
+
+ def forward(self, x):
+ r"""
+ Args:
+ x (~torch.Tensor):
+ The size of each input feature is `n_in`.
+
+ Returns:
+ A tensor with the size of each output feature `n_out`.
+ """
+
+ x = self.linear(x)
+ x = self.activation(x)
+ x = self.dropout(x)
+
+ return x
+
diff --git a/hanlp/components/parsers/biaffine/model.py b/hanlp/components/parsers/biaffine/model.py
deleted file mode 100644
index 93d02e446..000000000
--- a/hanlp/components/parsers/biaffine/model.py
+++ /dev/null
@@ -1,116 +0,0 @@
-# -*- coding:utf-8 -*-
-# Author: hankcs
-# Date: 2019-12-26 23:04
-import tensorflow as tf
-
-from hanlp.components.parsers.biaffine.layers import IndependentDropout, SharedDropout, Biaffine, \
- MLP
-
-
-class BiaffineModel(tf.keras.Model):
-
- def __init__(self, config, embed):
- """
- An implementation of T. Dozat and C. D. Manning, “Deep Biaffine Attention for Neural Dependency Parsing.,” ICLR, 2017.
- Although I have my MXNet implementation, I found zysite's PyTorch implementation is cleaner so I port it to TensorFlow
- :param config:
- :param embed:
- """
- super(BiaffineModel, self).__init__()
-
- normal = tf.keras.initializers.RandomNormal(stddev=1.)
- # the embedding layer
- self.word_embed = tf.keras.layers.Embedding(input_dim=config.n_words,
- output_dim=config.n_embed,
- embeddings_initializer=tf.keras.initializers.zeros() if embed
- else normal,
- name='word_embed')
- self.feat_embed = tf.keras.layers.Embedding(input_dim=config.n_feats,
- output_dim=config.n_embed,
- embeddings_initializer=tf.keras.initializers.zeros() if embed
- else normal,
- name='feat_embed')
- self.embed_dropout = IndependentDropout(p=config.embed_dropout, name='embed_dropout')
-
- # the word-lstm layer
- self.lstm = tf.keras.models.Sequential(name='lstm')
- for _ in range(config.n_lstm_layers):
- self.lstm.add(tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(
- units=config.n_lstm_hidden,
- dropout=config.lstm_dropout,
- recurrent_dropout=config.lstm_dropout,
- return_sequences=True,
- kernel_initializer='orthogonal',
- unit_forget_bias=False, # turns out to hinder performance
- )))
- self.lstm_dropout = SharedDropout(p=config.lstm_dropout, name='lstm_dropout')
-
- # the MLP layers
- self.mlp_arc_h = MLP(n_hidden=config.n_mlp_arc,
- dropout=config.mlp_dropout, name='mlp_arc_h')
- self.mlp_arc_d = MLP(n_hidden=config.n_mlp_arc,
- dropout=config.mlp_dropout, name='mlp_arc_d')
- self.mlp_rel_h = MLP(n_hidden=config.n_mlp_rel,
- dropout=config.mlp_dropout, name='mlp_rel_h')
- self.mlp_rel_d = MLP(n_hidden=config.n_mlp_rel,
- dropout=config.mlp_dropout, name='mlp_rel_d')
-
- # the Biaffine layers
- self.arc_attn = Biaffine(n_in=config.n_mlp_arc,
- bias_x=True,
- bias_y=False, name='arc_attn')
- self.rel_attn = Biaffine(n_in=config.n_mlp_rel,
- n_out=config.n_rels,
- bias_x=True,
- bias_y=True, name='rel_attn')
- if embed is not None:
- self.pretrained = embed
- self.pad_index = tf.constant(config.pad_index, dtype=tf.int64)
- self.unk_index = tf.constant(config.unk_index, dtype=tf.int64)
-
- # noinspection PyMethodOverriding
- def call(self, inputs, mask_inf=True, **kwargs):
- words, feats = inputs
- # batch_size, seq_len = words.shape
- # get the mask and lengths of given batch
- # mask = words.ne(self.pad_index)
- mask = tf.not_equal(words, self.pad_index)
- # set the indices larger than num_embeddings to unk_index
- # ext_mask = words.ge(self.word_embed.num_embeddings)
- ext_mask = tf.greater_equal(words, self.word_embed.input_dim)
- ext_words = tf.where(ext_mask, self.unk_index, words)
-
- # get outputs from embedding layers
- word_embed = self.word_embed(ext_words)
- if hasattr(self, 'pretrained'):
- word_embed += self.pretrained(words)
- feat_embed = self.feat_embed(feats)
- word_embed, feat_embed = self.embed_dropout([word_embed, feat_embed])
- # concatenate the word and feat representations
- embed = tf.concat((word_embed, feat_embed), axis=-1)
-
- x = self.lstm(embed, mask=mask)
- x = self.lstm_dropout(x)
-
- # apply MLPs to the BiLSTM output states
- arc_h = self.mlp_arc_h(x)
- arc_d = self.mlp_arc_d(x)
- rel_h = self.mlp_rel_h(x)
- rel_d = self.mlp_rel_d(x)
-
- # get arc and rel scores from the bilinear attention
- # [batch_size, seq_len, seq_len]
- s_arc = self.arc_attn(arc_d, arc_h)
- # [batch_size, seq_len, seq_len, n_rels]
- s_rel = tf.transpose(self.rel_attn(rel_d, rel_h), [0, 2, 3, 1])
- # set the scores that exceed the length of each sentence to -inf
- if mask_inf:
- s_arc = tf.where(tf.expand_dims(mask, 1), s_arc, float('-inf'))
-
- return s_arc, s_rel
-
- def to_functional(self):
- words = tf.keras.Input(shape=[None], dtype=tf.int64, name='words')
- feats = tf.keras.Input(shape=[None], dtype=tf.int64, name='feats')
- s_arc, s_rel = self.call([words, feats], mask_inf=False)
- return tf.keras.Model(inputs=[words, feats], outputs=[s_arc, s_rel])
diff --git a/hanlp/components/parsers/biaffine/structual_attention.py b/hanlp/components/parsers/biaffine/structual_attention.py
new file mode 100644
index 000000000..510b7f10d
--- /dev/null
+++ b/hanlp/components/parsers/biaffine/structual_attention.py
@@ -0,0 +1,234 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-06-26 10:40
+from typing import Union, List
+
+import torch
+import torch.nn.functional as F
+from alnlp.modules.util import lengths_to_mask
+from torch import nn
+
+from hanlp.common.torch_component import TorchComponent
+from hanlp.components.parsers.biaffine.biaffine_dep import BiaffineDependencyParser
+from hanlp.components.parsers.biaffine.biaffine_model import BiaffineDecoder
+from hanlp.layers.transformers.encoder import TransformerEncoder
+from hanlp.layers.transformers.pt_imports import PreTrainedModel, PreTrainedTokenizer
+from hanlp.metrics.accuracy import CategoricalAccuracy
+from hanlp.transform.transformer_tokenizer import TransformerSequenceTokenizer
+from hanlp_common.util import merge_locals_kwargs
+
+
+class StructuralAttentionLayer(nn.Module):
+
+ def __init__(self, hidden_size, n_mlp_arc, n_mlp_rel, mlp_dropout, n_rels, projeciton=None) -> None:
+ super().__init__()
+ self.biaffine = BiaffineDecoder(hidden_size, n_mlp_arc, n_mlp_rel, mlp_dropout, n_rels)
+ if projeciton:
+ self.projection = nn.Linear(hidden_size, projeciton)
+ hidden_size = projeciton
+ else:
+ self.projection = None
+ self.head_WV = nn.Parameter(torch.randn(n_rels, hidden_size, hidden_size))
+ self.dense = nn.Linear(hidden_size * n_rels, hidden_size)
+ self.activation = nn.GELU()
+
+ def forward(self, x, mask):
+ s_arc, s_rel = self.biaffine(x, mask)
+ p_arc = F.softmax(s_arc, dim=-1) * mask.unsqueeze(-1)
+ p_rel = F.softmax(s_rel, -1)
+ A = p_arc.unsqueeze(-1) * p_rel
+ if self.projection:
+ x = self.projection(x)
+ Ax = torch.einsum('bijk,bih->bihk', A, x)
+ AxW = torch.einsum('bihk,khm->bihk', Ax, self.head_WV)
+ AxW = AxW.flatten(2)
+ x = self.dense(AxW)
+ x = self.activation(x)
+ return s_arc, s_rel, x
+
+
+class StructuralAttentionModel(nn.Module):
+ def __init__(self,
+ config,
+ transformer: PreTrainedModel = None,
+ transformer_tokenizer: PreTrainedTokenizer = None
+ ) -> None:
+ super().__init__()
+ self.encoder = TransformerEncoder(transformer,
+ transformer_tokenizer,
+ config.average_subwords,
+ config.scalar_mix,
+ None, # No word_dropout since SA is predicting masked tokens
+ config.transformer_hidden_dropout,
+ config.layer_dropout,
+ config.max_sequence_length)
+ hidden_size = transformer.config.hidden_size
+ self.sa = StructuralAttentionLayer(hidden_size,
+ config.n_mlp_arc,
+ config.n_mlp_rel,
+ config.mlp_dropout,
+ config.n_rels,
+ config.projection
+ )
+ if config.projection:
+ hidden_size = config.projection
+ self.mlm = nn.Linear(hidden_size, transformer_tokenizer.vocab_size)
+
+ def forward(self,
+ input_ids: torch.LongTensor,
+ attention_mask=None,
+ token_type_ids=None,
+ token_span=None,
+ mask=None,
+ batch=None,
+ **kwargs):
+ h = self.encoder(input_ids, attention_mask, token_type_ids, token_span)
+ s_arc, s_rel, h = self.sa(h, mask)
+ x = self.mlm(h)
+ return s_arc, s_rel, x
+
+
+class MaskedTokenGenerator(object):
+
+ def __init__(self, transformer_tokenizer: PreTrainedTokenizer, mask_prob=0.15) -> None:
+ super().__init__()
+ self.mask_prob = mask_prob
+ self.transformer_tokenizer = transformer_tokenizer
+ self.oov = transformer_tokenizer.mask_token_id
+ self.pad = transformer_tokenizer.pad_token_id
+ self.cls = transformer_tokenizer.cls_token_id
+ self.sep = transformer_tokenizer.sep_token_id
+ self.excludes = [self.pad, self.cls, self.sep]
+
+ def __call__(self, tokens: torch.LongTensor, prefix_mask: torch.LongTensor):
+ padding_mask = tokens.new_ones(tokens.size(), dtype=torch.bool)
+ for pad in self.excludes:
+ padding_mask &= (tokens != pad)
+ padding_mask &= prefix_mask # Only mask prefixes since the others won't be attended
+ # Create a uniformly random mask selecting either the original words or OOV tokens
+ dropout_mask = (tokens.new_empty(tokens.size(), dtype=torch.float).uniform_() < self.mask_prob)
+ oov_mask = dropout_mask & padding_mask
+
+ oov_fill = tokens.new_empty(tokens.size(), dtype=torch.long).fill_(self.oov)
+
+ result = torch.where(oov_mask, oov_fill, tokens)
+ return result, oov_mask
+
+
+class StructuralAttentionParser(BiaffineDependencyParser):
+ def __init__(self) -> None:
+ super().__init__()
+ self.model: StructuralAttentionModel = None
+ self.mlm_generator: MaskedTokenGenerator = None
+
+ def build_model(self, training=True, **kwargs) -> torch.nn.Module:
+ transformer = TransformerEncoder.build_transformer(config=self.config, training=training)
+ model = StructuralAttentionModel(self.config, transformer, self.transformer_tokenizer)
+ return model
+
+ def fit(self, trn_data, dev_data, save_dir,
+ transformer=None,
+ mask_prob=0.15,
+ projection=None,
+ average_subwords=False,
+ transformer_hidden_dropout=None,
+ layer_dropout=0,
+ mix_embedding: int = None,
+ embed_dropout=.33,
+ n_mlp_arc=500,
+ n_mlp_rel=100,
+ mlp_dropout=.33,
+ lr=2e-3,
+ transformer_lr=5e-5,
+ mu=.9,
+ nu=.9,
+ epsilon=1e-12,
+ clip=5.0,
+ decay=.75,
+ decay_steps=5000,
+ patience=100,
+ sampler='kmeans',
+ n_buckets=32,
+ batch_max_tokens=5000,
+ batch_size=None,
+ epochs=50000,
+ tree=False,
+ punct=False,
+ logger=None,
+ verbose=True,
+ max_sequence_length=512,
+ devices: Union[float, int, List[int]] = None,
+ transform=None,
+ **kwargs):
+ return TorchComponent.fit(self, **merge_locals_kwargs(locals(), kwargs))
+
+ def feed_batch(self, batch):
+ if self.model.training:
+ input_ids = batch['input_ids']
+ prefix_mask = batch['prefix_mask']
+ batch['gold_input_ids'] = input_ids
+ batch['input_ids'], batch['input_ids_mask'] = self.mlm_generator(input_ids, prefix_mask)
+ words, feats, lens, puncts = batch.get('token_id', None), batch.get('pos_id', None), batch['sent_length'], \
+ batch.get('punct_mask', None)
+ mask = lengths_to_mask(lens)
+ arc_scores, rel_scores, pred_input_ids = self.model(words=words, feats=feats, mask=mask, batch=batch, **batch)
+ batch['pred_input_ids'] = pred_input_ids
+ # ignore the first token of each sentence
+ # RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
+ if self.model.training:
+ mask = mask.clone()
+ mask[:, 0] = 0
+ return arc_scores, rel_scores, mask, puncts
+
+ def on_config_ready(self, **kwargs):
+ super().on_config_ready(**kwargs)
+ self.mlm_generator = MaskedTokenGenerator(self.transformer_tokenizer, self.config.mask_prob)
+
+ def compute_loss(self, arc_scores, rel_scores, arcs, rels, mask, criterion, batch=None):
+ parse_loss = BiaffineDependencyParser.compute_loss(self, arc_scores, rel_scores, arcs, rels, mask, criterion, batch)
+ if self.model.training:
+ gold_input_ids = batch['gold_input_ids']
+ pred_input_ids = batch['pred_input_ids']
+ input_ids_mask = batch['input_ids_mask']
+ token_span = batch['token_span']
+ gold_input_ids = batch['gold_input_ids'] = gold_input_ids.gather(1, token_span[:, :, 0])
+ input_ids_mask = batch['input_ids_mask'] = input_ids_mask.gather(1, token_span[:, :, 0])
+ mlm_loss = F.cross_entropy(pred_input_ids[input_ids_mask], gold_input_ids[input_ids_mask])
+ loss = parse_loss + mlm_loss
+ return loss
+ return parse_loss
+
+ def build_tokenizer_transform(self):
+ return TransformerSequenceTokenizer(self.transformer_tokenizer, 'token', '', ret_prefix_mask=True,
+ ret_token_span=True, cls_is_bos=True,
+ max_seq_length=self.config.get('max_sequence_length',
+ 512),
+ truncate_long_sequences=False)
+
+ def build_metric(self, training=None, **kwargs):
+ parse_metric = super().build_metric(**kwargs)
+ if training:
+ mlm_metric = CategoricalAccuracy()
+ return parse_metric, mlm_metric
+ return parse_metric
+
+ def update_metric(self, arc_scores, rel_scores, arcs, rels, mask, puncts, metric, batch=None):
+ if isinstance(metric, tuple):
+ parse_metric, mlm_metric = metric
+ super().update_metric(arc_scores, rel_scores, arcs, rels, mask, puncts, parse_metric)
+ gold_input_ids = batch['gold_input_ids']
+ input_ids_mask = batch['input_ids_mask']
+ pred_input_ids = batch['pred_input_ids']
+ pred_input_ids = pred_input_ids[input_ids_mask]
+ gold_input_ids = gold_input_ids[input_ids_mask]
+ if len(pred_input_ids):
+ mlm_metric(pred_input_ids, gold_input_ids)
+ else:
+ super().update_metric(arc_scores, rel_scores, arcs, rels, mask, puncts, metric)
+
+ def _report(self, loss, metric):
+ if isinstance(metric, tuple):
+ parse_metric, mlm_metric = metric
+ return super()._report(loss, parse_metric) + f' {mlm_metric}'
+ else:
+ return super()._report(loss, metric)
diff --git a/hanlp/components/parsers/biaffine/variationalbilstm.py b/hanlp/components/parsers/biaffine/variationalbilstm.py
new file mode 100644
index 000000000..1a5373a6f
--- /dev/null
+++ b/hanlp/components/parsers/biaffine/variationalbilstm.py
@@ -0,0 +1,231 @@
+# MIT License
+#
+# Copyright (c) 2020 Yu Zhang
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+
+import torch
+import torch.nn as nn
+from torch.nn.modules.rnn import apply_permutation
+from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
+
+from hanlp.common.structure import ConfigTracker
+from hanlp.layers.dropout import SharedDropout
+
+
+class VariationalLSTM(nn.Module):
+ r"""
+ LSTM is an variant of the vanilla bidirectional LSTM adopted by Biaffine Parser
+ with the only difference of the dropout strategy.
+ It drops nodes in the LSTM layers (input and recurrent connections)
+ and applies the same dropout mask at every recurrent timesteps.
+
+ APIs are roughly the same as :class:`~torch.nn.LSTM` except that we only allows
+ :class:`~torch.nn.utils.rnn.PackedSequence` as input.
+
+ References:
+ - Timothy Dozat and Christopher D. Manning. 2017.
+ `Deep Biaffine Attention for Neural Dependency Parsing`_.
+
+ Args:
+ input_size (int):
+ The number of expected features in the input.
+ hidden_size (int):
+ The number of features in the hidden state `h`.
+ num_layers (int):
+ The number of recurrent layers. Default: 1.
+ bidirectional (bool):
+ If ``True``, becomes a bidirectional LSTM. Default: ``False``
+ dropout (float):
+ If non-zero, introduces a :class:`SharedDropout` layer on the outputs of each LSTM layer except the last layer.
+ Default: 0.
+
+ .. _Deep Biaffine Attention for Neural Dependency Parsing:
+ https://openreview.net/forum?id=Hk95PK9le
+ """
+
+ def __init__(self, input_size, hidden_size, num_layers=1, bidirectional=False, dropout=0):
+ super().__init__()
+
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+ self.num_layers = num_layers
+ self.bidirectional = bidirectional
+ self.dropout = dropout
+ self.num_directions = 1 + self.bidirectional
+
+ self.f_cells = nn.ModuleList()
+ if bidirectional:
+ self.b_cells = nn.ModuleList()
+ for _ in range(self.num_layers):
+ self.f_cells.append(nn.LSTMCell(input_size=input_size, hidden_size=hidden_size))
+ if bidirectional:
+ self.b_cells.append(nn.LSTMCell(input_size=input_size, hidden_size=hidden_size))
+ input_size = hidden_size * self.num_directions
+
+ self.reset_parameters()
+
+ def __repr__(self):
+ s = f"{self.input_size}, {self.hidden_size}"
+ if self.num_layers > 1:
+ s += f", num_layers={self.num_layers}"
+ if self.bidirectional:
+ s += f", bidirectional={self.bidirectional}"
+ if self.dropout > 0:
+ s += f", dropout={self.dropout}"
+
+ return f"{self.__class__.__name__}({s})"
+
+ def reset_parameters(self):
+ for param in self.parameters():
+ # apply orthogonal_ to weight
+ if len(param.shape) > 1:
+ nn.init.orthogonal_(param)
+ # apply zeros_ to bias
+ else:
+ nn.init.zeros_(param)
+
+ def permute_hidden(self, hx, permutation):
+ if permutation is None:
+ return hx
+ h = apply_permutation(hx[0], permutation)
+ c = apply_permutation(hx[1], permutation)
+
+ return h, c
+
+ def layer_forward(self, x, hx, cell, batch_sizes, reverse=False):
+ hx_0 = hx_i = hx
+ hx_n, output = [], []
+ steps = reversed(range(len(x))) if reverse else range(len(x))
+ if self.training:
+ hid_mask = SharedDropout.get_mask(hx_0[0], self.dropout)
+
+ for t in steps:
+ last_batch_size, batch_size = len(hx_i[0]), batch_sizes[t]
+ if last_batch_size < batch_size:
+ hx_i = [torch.cat((h, ih[last_batch_size:batch_size])) for h, ih in zip(hx_i, hx_0)]
+ else:
+ hx_n.append([h[batch_size:] for h in hx_i])
+ hx_i = [h[:batch_size] for h in hx_i]
+ hx_i = [h for h in cell(x[t], hx_i)]
+ output.append(hx_i[0])
+ if self.training:
+ hx_i[0] = hx_i[0] * hid_mask[:batch_size]
+ if reverse:
+ hx_n = hx_i
+ output.reverse()
+ else:
+ hx_n.append(hx_i)
+ hx_n = [torch.cat(h) for h in zip(*reversed(hx_n))]
+ output = torch.cat(output)
+
+ return output, hx_n
+
+ def forward(self, sequence, hx=None):
+ r"""
+ Args:
+ sequence (~torch.nn.utils.rnn.PackedSequence):
+ A packed variable length sequence.
+ hx (~torch.Tensor, ~torch.Tensor):
+ A tuple composed of two tensors `h` and `c`.
+ `h` of shape ``[num_layers*num_directions, batch_size, hidden_size]`` holds the initial hidden state
+ for each element in the batch.
+ `c` of shape ``[num_layers*num_directions, batch_size, hidden_size]`` holds the initial cell state
+ for each element in the batch.
+ If `hx` is not provided, both `h` and `c` default to zero.
+ Default: ``None``.
+
+ Returns:
+ ~torch.nn.utils.rnn.PackedSequence, (~torch.Tensor, ~torch.Tensor):
+ The first is a packed variable length sequence.
+ The second is a tuple of tensors `h` and `c`.
+ `h` of shape ``[num_layers*num_directions, batch_size, hidden_size]`` holds the hidden state for `t=seq_len`.
+ Like output, the layers can be separated using ``h.view(num_layers, num_directions, batch_size, hidden_size)``
+ and similarly for c.
+ `c` of shape ``[num_layers*num_directions, batch_size, hidden_size]`` holds the cell state for `t=seq_len`.
+ """
+ x, batch_sizes = sequence.data, sequence.batch_sizes.tolist()
+ batch_size = batch_sizes[0]
+ h_n, c_n = [], []
+
+ if hx is None:
+ ih = x.new_zeros(self.num_layers * self.num_directions, batch_size, self.hidden_size)
+ h, c = ih, ih
+ else:
+ h, c = self.permute_hidden(hx, sequence.sorted_indices)
+ h = h.view(self.num_layers, self.num_directions, batch_size, self.hidden_size)
+ c = c.view(self.num_layers, self.num_directions, batch_size, self.hidden_size)
+
+ for i in range(self.num_layers):
+ x = torch.split(x, batch_sizes)
+ if self.training:
+ mask = SharedDropout.get_mask(x[0], self.dropout)
+ x = [i * mask[:len(i)] for i in x]
+ x_i, (h_i, c_i) = self.layer_forward(x=x,
+ hx=(h[i, 0], c[i, 0]),
+ cell=self.f_cells[i],
+ batch_sizes=batch_sizes)
+ if self.bidirectional:
+ x_b, (h_b, c_b) = self.layer_forward(x=x,
+ hx=(h[i, 1], c[i, 1]),
+ cell=self.b_cells[i],
+ batch_sizes=batch_sizes,
+ reverse=True)
+ x_i = torch.cat((x_i, x_b), -1)
+ h_i = torch.stack((h_i, h_b))
+ c_i = torch.stack((c_i, c_b))
+ x = x_i
+ h_n.append(h_i)
+ c_n.append(h_i)
+
+ x = PackedSequence(x,
+ sequence.batch_sizes,
+ sequence.sorted_indices,
+ sequence.unsorted_indices)
+ hx = torch.cat(h_n, 0), torch.cat(c_n, 0)
+ hx = self.permute_hidden(hx, sequence.unsorted_indices)
+
+ return x, hx
+
+
+class VariationalLSTMEncoder(VariationalLSTM, ConfigTracker):
+ def __init__(self,
+ input_size,
+ hidden_size,
+ num_layers=1,
+ bidirectional=False,
+ variational_dropout=0,
+ word_dropout=0,
+ ):
+ super().__init__(input_size, hidden_size, num_layers, bidirectional, variational_dropout)
+ ConfigTracker.__init__(self, locals())
+ self.lstm_dropout = SharedDropout(p=word_dropout)
+
+ # noinspection PyMethodOverriding
+ def forward(self, embed, mask):
+ batch_size, seq_len = mask.shape
+ x = pack_padded_sequence(embed, mask.sum(1), True, False)
+ x, _ = super().forward(x)
+ x, _ = pad_packed_sequence(x, True, total_length=seq_len)
+ x = self.lstm_dropout(x)
+ return x
+
+ def get_output_dim(self):
+ return self.hidden_size * self.num_directions
diff --git a/hanlp/components/parsers/biaffine_parser.py b/hanlp/components/parsers/biaffine_parser.py
deleted file mode 100644
index 9f34612f5..000000000
--- a/hanlp/components/parsers/biaffine_parser.py
+++ /dev/null
@@ -1,339 +0,0 @@
-# -*- coding:utf-8 -*-
-# Author: hankcs
-# Date: 2019-12-22 12:47
-import logging
-import tensorflow as tf
-from typing import List, Tuple, Union
-from hanlp.common.component import KerasComponent
-from hanlp.components.parsers.biaffine.model import BiaffineModel
-from hanlp.components.parsers.conll import CoNLLSentence, CoNLL_DEP_Transform, CoNLL_SDP_Transform
-from hanlp.layers.embeddings import build_embedding
-from hanlp.metrics.parsing.labeled_f1 import LabeledF1
-from hanlp.metrics.parsing.labeled_score import LabeledScore
-from hanlp.utils.util import merge_locals_kwargs, merge_dict
-
-
-class BiaffineDependencyParser(KerasComponent):
- def __init__(self, transform: CoNLL_DEP_Transform = None) -> None:
- if not transform:
- transform = CoNLL_DEP_Transform()
- super().__init__(transform)
- self.transform: CoNLL_DEP_Transform = transform
- self.model: BiaffineModel = None
-
- def build_model(self, pretrained_embed, n_embed, training, **kwargs) -> tf.keras.Model:
- if training:
- self.config.n_words = len(self.transform.form_vocab)
- else:
- self.config.lstm_dropout = 0. # keras will use cuda lstm when config.lstm_dropout is 0
- self.config.n_feats = len(self.transform.cpos_vocab)
- self.config.n_rels = len(self.transform.rel_vocab)
- self.config.pad_index = self.transform.form_vocab.pad_idx
- self.config.unk_index = self.transform.form_vocab.unk_idx
- self.config.bos_index = 2
- pretrained: tf.keras.layers.Embedding = build_embedding(pretrained_embed, self.transform.form_vocab,
- self.transform) if pretrained_embed else None
- if pretrained_embed:
- self.config.n_embed = pretrained.output_dim
- model = BiaffineModel(self.config, pretrained)
- return model
-
- def load_weights(self, save_dir, filename='model.h5', functional=False, **kwargs):
- super().load_weights(save_dir, filename)
- if functional:
- self.model = self.model.to_functional()
-
- def build_vocab(self, trn_data, logger):
- return super().build_vocab(trn_data, logger)
-
- def fit(self, trn_data, dev_data, save_dir,
- n_embed=100,
- pretrained_embed=None,
- embed_dropout=.33,
- n_lstm_hidden=400,
- n_lstm_layers=3,
- lstm_dropout=.33,
- n_mlp_arc=500,
- n_mlp_rel=100,
- mlp_dropout=.33,
- optimizer='adam',
- lr=2e-3,
- mu=.9,
- nu=.9,
- epsilon=1e-12,
- clip=5.0,
- decay=.75,
- decay_steps=5000,
- patience=100,
- arc_loss='sparse_categorical_crossentropy',
- rel_loss='sparse_categorical_crossentropy',
- metrics=('UAS', 'LAS'),
- n_buckets=32,
- batch_size=5000,
- epochs=50000,
- early_stopping_patience=100,
- tree=False,
- punct=False,
- min_freq=2,
- run_eagerly=False, logger=None, verbose=True,
- **kwargs):
- return super().fit(**merge_locals_kwargs(locals(), kwargs))
-
- # noinspection PyMethodOverriding
- def train_loop(self, trn_data, dev_data, epochs, num_examples,
- train_steps_per_epoch, dev_steps, model, optimizer, loss, metrics,
- callbacks, logger: logging.Logger, arc_loss, rel_loss,
- **kwargs):
- arc_loss, rel_loss = loss
- # because we are customizing batching
- train_steps_per_epoch = len(list(iter(trn_data)))
- # progbar: tf.keras.callbacks.ProgbarLogger = callbacks[-1]
- c: tf.keras.callbacks.Callback = None
- metric = self._build_metrics()
- for c in callbacks:
- if not hasattr(c, 'params'):
- params = {'verbose': 1, 'epochs': epochs, 'steps': train_steps_per_epoch}
- c.params = params
- c.set_params(params)
- c.params['metrics'] = ['loss'] + self.config.metrics
- c.params['metrics'] = c.params['metrics'] + [f'val_{k}' for k in c.params['metrics']]
- c.on_train_begin()
- for epoch in range(epochs):
- metric.reset_states()
- for c in callbacks:
- c.params['steps'] = train_steps_per_epoch
- c.on_epoch_begin(epoch)
- for idx, ((words, feats), (arcs, rels)) in enumerate(iter(trn_data)):
- logs = {}
- for c in callbacks:
- c.on_batch_begin(idx, logs)
- mask = tf.not_equal(words, self.config.pad_index) & tf.not_equal(words, self.config.bos_index)
- loss, arc_scores, rel_scores = self.train_batch(words, feats, arcs, rels, mask,
- optimizer, arc_loss, rel_loss)
- self.run_metrics(arcs, rels, arc_scores, rel_scores, words, mask, metric)
- logs['loss'] = loss
- logs.update(metric.to_dict())
- if epoch == epochs - 1:
- self.model.stop_training = True
- for c in callbacks:
- c.on_train_batch_end(idx, logs)
- # evaluate on dev
- metric.reset_states()
- logs = {}
- for idx, ((words, feats), (arcs, rels)) in enumerate(iter(dev_data)):
- arc_scores, rel_scores, loss, mask, arc_preds, rel_preds = self.evaluate_batch(words, feats, arcs, rels,
- arc_loss, rel_loss,
- metric)
- logs['val_loss'] = loss
- logs.update((f'val_{k}', v) for k, v in metric.to_dict().items())
-
- for c in callbacks:
- c.on_epoch_end(epoch, logs)
- if getattr(self.model, 'stop_training', None):
- break
-
- for c in callbacks:
- c.on_train_end()
-
- def evaluate(self, input_path: str, save_dir=None, output=False, batch_size=None, logger: logging.Logger = None,
- callbacks: List[tf.keras.callbacks.Callback] = None, warm_up=True, verbose=True, **kwargs):
- if batch_size is None:
- batch_size = self.config.batch_size
- return super().evaluate(input_path, save_dir, output, batch_size, logger, callbacks, warm_up, verbose, **kwargs)
-
- def evaluate_batch(self, words, feats, arcs, rels, arc_loss, rel_loss, metric):
- mask = tf.not_equal(words, self.config.pad_index) & tf.not_equal(words, self.config.bos_index)
- arc_scores, rel_scores = self.model((words, feats))
- loss = self.get_loss(arc_scores, rel_scores, arcs, rels, mask, arc_loss, rel_loss)
- arc_preds, rel_preds = self.run_metrics(arcs, rels, arc_scores, rel_scores, words, mask, metric)
- return arc_scores, rel_scores, loss, mask, arc_preds, rel_preds
-
- def _build_metrics(self):
- if isinstance(self.config.metrics, tuple):
- self.config.metrics = list(self.config.metrics)
- if self.config.metrics == ['UAS', 'LAS']:
- metric = LabeledScore()
- else:
- metric = LabeledF1()
- return metric
-
- def run_metrics(self, arcs, rels, arc_scores, rel_scores, words, mask, metric):
- arc_preds, rel_preds = self.decode(arc_scores, rel_scores, mask)
- # ignore all punctuation if not specified
- if not self.config.punct:
- mask &= tf.reduce_all(tf.not_equal(tf.expand_dims(words, axis=-1), self.transform.puncts), axis=-1)
- metric(arc_preds, rel_preds, arcs, rels, mask)
- return arc_preds, rel_preds
-
- def train_batch(self, words, feats, arcs, rels, mask, optimizer, arc_loss, rel_loss):
- with tf.GradientTape() as tape:
- arc_scores, rel_scores = self.model((words, feats), training=True)
- loss = self.get_loss(arc_scores, rel_scores, arcs, rels, mask, arc_loss, rel_loss)
- grads = tape.gradient(loss, self.model.trainable_variables)
- optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
- return loss, arc_scores, rel_scores
-
- def get_loss(self, arc_scores, rel_scores, arcs, rels, mask, arc_loss, rel_loss):
- arc_scores, arcs = arc_scores[mask], arcs[mask]
- rel_scores, rels = rel_scores[mask], rels[mask]
- rel_scores = tf.gather_nd(rel_scores, tf.stack([tf.range(len(arcs), dtype=tf.int64), arcs], axis=1))
- arc_loss = arc_loss(arcs, arc_scores)
- rel_loss = rel_loss(rels, rel_scores)
- loss = arc_loss + rel_loss
-
- return loss
-
- def build_optimizer(self, optimizer, **kwargs):
- if optimizer == 'adam':
- scheduler = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate=self.config.lr,
- decay_steps=self.config.decay_steps,
- decay_rate=self.config.decay)
- optimizer = tf.keras.optimizers.Adam(learning_rate=scheduler,
- beta_1=self.config.mu,
- beta_2=self.config.nu,
- epsilon=self.config.epsilon,
- clipnorm=self.config.clip)
- return optimizer
- return super().build_optimizer(optimizer, **kwargs)
-
- # noinspection PyMethodOverriding
- def build_loss(self, arc_loss, rel_loss, **kwargs):
- if arc_loss == 'binary_crossentropy':
- arc_loss = tf.losses.BinaryCrossentropy(from_logits=True)
- else:
- arc_loss = tf.keras.losses.SparseCategoricalCrossentropy(
- from_logits=True) if arc_loss == 'sparse_categorical_crossentropy' else super().build_loss(arc_loss)
- rel_loss = tf.keras.losses.SparseCategoricalCrossentropy(
- from_logits=True) if rel_loss == 'sparse_categorical_crossentropy' else super().build_loss(rel_loss)
- return arc_loss, rel_loss
-
- @property
- def sample_data(self):
- return tf.constant([[2, 3, 4], [2, 5, 0]], dtype=tf.int64), tf.constant([[1, 2, 3], [4, 5, 0]], dtype=tf.int64)
-
- def build_train_dataset(self, trn_data, batch_size, num_examples):
- trn_data = self.transform.file_to_dataset(trn_data, batch_size=batch_size,
- shuffle=True,
- repeat=None)
- return trn_data
-
- # noinspection PyMethodOverriding
- def build_callbacks(self, save_dir, logger, metrics, **kwargs):
- callbacks = super().build_callbacks(save_dir,
- **merge_dict(self.config, overwrite=True, logger=logger, metrics=metrics,
- **kwargs))
- if isinstance(metrics, tuple):
- metrics = list(metrics)
- callbacks.append(self.build_progbar(metrics))
- params = {'verbose': 1, 'epochs': self.config.epochs}
- for c in callbacks:
- c.set_params(params)
- c.set_model(self.model)
- return callbacks
-
- def build_progbar(self, metrics, training=True):
- return tf.keras.callbacks.ProgbarLogger(count_mode='steps',
- stateful_metrics=metrics + [f'val_{k}' for k in metrics] if training
- else [])
-
- def decode(self, arc_scores, rel_scores, mask):
- if self.config.tree:
- # arc_preds = eisner(arc_scores, mask)
- pass
- else:
- arc_preds = tf.argmax(arc_scores, -1)
-
- rel_preds = tf.argmax(rel_scores, -1)
- rel_preds = tf.squeeze(tf.gather(rel_preds, tf.expand_dims(arc_preds, -1), batch_dims=2), axis=-1)
-
- return arc_preds, rel_preds
-
- def evaluate_dataset(self, tst_data, callbacks, output, num_batches):
- arc_loss, rel_loss = self.build_loss(**self.config)
- callbacks = [self.build_progbar(self.config['metrics'])]
- steps_per_epoch = len(list(iter(tst_data)))
- metric = self._build_metrics()
- params = {'verbose': 1, 'epochs': 1, 'metrics': ['loss'] + self.config.metrics, 'steps': steps_per_epoch}
- for c in callbacks:
- c.set_params(params)
- c.on_test_begin()
- c.on_epoch_end(0)
- logs = {}
- if output:
- output = open(output, 'w', encoding='utf-8')
- for idx, ((words, feats), (arcs, rels)) in enumerate(iter(tst_data)):
- for c in callbacks:
- c.on_test_batch_begin(idx, logs)
- arc_scores, rel_scores, loss, mask, arc_preds, rel_preds = self.evaluate_batch(words, feats, arcs, rels,
- arc_loss, rel_loss, metric)
- if output:
- for sent in self.transform.XY_to_inputs_outputs((words, feats, mask), (arc_preds, rel_preds)):
- output.write(str(sent))
- output.write('\n\n')
- logs['loss'] = loss
- logs.update(metric.to_dict())
- for c in callbacks:
- c.on_test_batch_end(idx, logs)
- for c in callbacks:
- c.on_epoch_end(0)
- c.on_test_end()
- if output:
- output.close()
- loss = float(c.progbar._values['loss'][0] / c.progbar._values['loss'][1])
- return loss, metric.to_dict(), False
-
- def predict_batch(self, batch, inputs=None, conll=True, **kwargs):
- ((words, feats), (arcs, rels)) = batch
- mask = tf.not_equal(words, self.config.pad_index) & tf.not_equal(words, self.config.bos_index)
- arc_scores, rel_scores = self.model((words, feats))
- arc_preds, rel_preds = self.decode(arc_scores, rel_scores, mask)
- for sent in self.transform.XY_to_inputs_outputs((words, feats, mask), (arc_preds, rel_preds), gold=False,
- inputs=inputs, conll=conll):
- yield sent
-
- def compile_model(self, optimizer, loss, metrics):
- super().compile_model(optimizer, loss, metrics)
-
-
-class BiaffineSemanticDependencyParser(BiaffineDependencyParser):
- def __init__(self, transform: CoNLL_SDP_Transform = None) -> None:
- if not transform:
- transform = CoNLL_SDP_Transform()
- # noinspection PyTypeChecker
- super().__init__(transform)
- self.transform: CoNLL_SDP_Transform = transform
-
- def fit(self, trn_data, dev_data, save_dir, n_embed=100, pretrained_embed=None, embed_dropout=.33,
- n_lstm_hidden=400, n_lstm_layers=3, lstm_dropout=.33, n_mlp_arc=500, n_mlp_rel=100, mlp_dropout=.33,
- optimizer='adam', lr=2e-3, mu=.9, nu=.9, epsilon=1e-12, clip=5.0, decay=.75, decay_steps=5000, patience=100,
- arc_loss='binary_crossentropy', rel_loss='sparse_categorical_crossentropy',
- metrics=('UF', 'LF'), n_buckets=32, batch_size=5000, epochs=50000, early_stopping_patience=100,
- tree=False, punct=False, min_freq=2, run_eagerly=False, logger=None, verbose=True, **kwargs):
- return super().fit(trn_data, dev_data, save_dir, n_embed, pretrained_embed, embed_dropout, n_lstm_hidden,
- n_lstm_layers, lstm_dropout, n_mlp_arc, n_mlp_rel, mlp_dropout, optimizer, lr, mu, nu,
- epsilon, clip, decay, decay_steps, patience, arc_loss, rel_loss, metrics, n_buckets,
- batch_size, epochs, early_stopping_patience, tree, punct, min_freq, run_eagerly, logger,
- verbose, **kwargs)
-
- def get_loss(self, arc_scores, rel_scores, arcs, rels, mask, arc_loss, rel_loss):
- mask = tf.tile(tf.expand_dims(mask, -1), [1, 1, tf.shape(mask)[-1]])
- mask &= tf.transpose(mask, [0, 2, 1])
- arc_scores, arcs = arc_scores[mask], arcs[mask]
- rel_scores, rels = rel_scores[mask], rels[mask]
- rel_scores, rels = rel_scores[arcs], rels[arcs]
- arc_loss = arc_loss(arcs, arc_scores)
- rel_loss = rel_loss(rels, rel_scores)
- loss = arc_loss + rel_loss
-
- return loss
-
- def decode(self, arc_scores, rel_scores, mask):
- if self.config.tree:
- # arc_preds = eisner(arc_scores, mask)
- raise NotImplemented('Give me some time...')
- else:
- arc_preds = arc_scores > 0
-
- rel_preds = tf.argmax(rel_scores, -1)
-
- return arc_preds, rel_preds
diff --git a/hanlp/components/parsers/biaffine_parser_tf.py b/hanlp/components/parsers/biaffine_parser_tf.py
new file mode 100644
index 000000000..e3775dbee
--- /dev/null
+++ b/hanlp/components/parsers/biaffine_parser_tf.py
@@ -0,0 +1,723 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-12-22 12:47
+import logging
+import math
+import os
+from typing import List
+import numpy as np
+import tensorflow as tf
+
+from hanlp.components.parsers.parse_alg import unique_root, adjust_root_score, chu_liu_edmonds
+from hanlp.layers.transformers.loader_tf import build_transformer
+
+from hanlp.common.keras_component import KerasComponent
+from hanlp.components.parsers.alg_tf import tarjan
+from hanlp.components.parsers.biaffine_tf.model import BiaffineModelTF, StructuralAttentionModel
+from hanlp.transform.conll_tf import CoNLL_DEP_Transform, CoNLL_Transformer_Transform, CoNLL_SDP_Transform
+from hanlp.layers.embeddings.util_tf import build_embedding
+from hanlp.layers.transformers.tf_imports import PreTrainedTokenizer, TFAutoModel, TFPreTrainedModel, AutoTokenizer, \
+ TFAutoModelWithLMHead, BertTokenizerFast, AlbertConfig, BertTokenizer, TFBertModel
+from hanlp.layers.transformers.utils_tf import build_adamw_optimizer
+from hanlp.metrics.parsing.labeled_f1_tf import LabeledF1TF
+from hanlp.metrics.parsing.labeled_score import LabeledScore
+from hanlp_common.util import merge_locals_kwargs
+
+
+class BiaffineDependencyParserTF(KerasComponent):
+ def __init__(self, transform: CoNLL_DEP_Transform = None) -> None:
+ if not transform:
+ transform = CoNLL_DEP_Transform()
+ super().__init__(transform)
+ self.transform: CoNLL_DEP_Transform = transform
+ self.model: BiaffineModelTF = None
+
+ def build_model(self, pretrained_embed, n_embed, training, **kwargs) -> tf.keras.Model:
+ if training:
+ self.config.n_words = len(self.transform.form_vocab)
+ else:
+ self.config.lstm_dropout = 0. # keras will use cuda lstm when config.lstm_dropout is 0
+ self.config.n_feats = len(self.transform.cpos_vocab)
+ self._init_config()
+ pretrained: tf.keras.layers.Embedding = build_embedding(pretrained_embed, self.transform.form_vocab,
+ self.transform) if pretrained_embed else None
+ if pretrained_embed:
+ self.config.n_embed = pretrained.output_dim
+ model = BiaffineModelTF(self.config, pretrained)
+ return model
+
+ def _init_config(self):
+ self.config.n_rels = len(self.transform.rel_vocab)
+ self.config.pad_index = self.transform.form_vocab.pad_idx
+ self.config.unk_index = self.transform.form_vocab.unk_idx
+ self.config.bos_index = 2
+
+ def load_weights(self, save_dir, filename='model.h5', functional=False, **kwargs):
+ super().load_weights(save_dir, filename)
+ if functional:
+ self.model = self.model.to_functional()
+
+ def fit(self, trn_data, dev_data, save_dir,
+ n_embed=100,
+ pretrained_embed=None,
+ embed_dropout=.33,
+ n_lstm_hidden=400,
+ n_lstm_layers=3,
+ lstm_dropout=.33,
+ n_mlp_arc=500,
+ n_mlp_rel=100,
+ mlp_dropout=.33,
+ optimizer='adam',
+ lr=2e-3,
+ mu=.9,
+ nu=.9,
+ epsilon=1e-12,
+ clip=5.0,
+ decay=.75,
+ decay_steps=5000,
+ patience=100,
+ arc_loss='sparse_categorical_crossentropy',
+ rel_loss='sparse_categorical_crossentropy',
+ metrics=('UAS', 'LAS'),
+ n_buckets=32,
+ batch_size=5000,
+ epochs=50000,
+ early_stopping_patience=100,
+ tree=False,
+ punct=False,
+ min_freq=2,
+ run_eagerly=False, logger=None, verbose=True,
+ **kwargs):
+ return super().fit(**merge_locals_kwargs(locals(), kwargs))
+
+ # noinspection PyMethodOverriding
+ def train_loop(self, trn_data, dev_data, epochs, num_examples,
+ train_steps_per_epoch, dev_steps, model, optimizer, loss, metrics,
+ callbacks, logger: logging.Logger, arc_loss, rel_loss,
+ **kwargs):
+ arc_loss, rel_loss = loss
+ # because we are customizing batching
+ train_steps_per_epoch = len(list(iter(trn_data)))
+ # progbar: tf.keras.callbacks.ProgbarLogger = callbacks[-1]
+ c: tf.keras.callbacks.Callback = None
+ metric = self._build_metrics()
+ for c in callbacks:
+ c.params['epochs'] = epochs
+ c.params['trn_data'] = trn_data
+ c.params['metrics'] = ['loss'] + self.config.metrics
+ c.params['metrics'] = c.params['metrics'] + [f'val_{k}' for k in c.params['metrics']]
+ c.on_train_begin()
+ for epoch in range(epochs):
+ metric.reset_states()
+ for c in callbacks:
+ c.params['steps'] = train_steps_per_epoch
+ c.on_epoch_begin(epoch)
+ for idx, ((words, feats), (arcs, rels)) in enumerate(iter(trn_data)):
+ logs = {}
+ for c in callbacks:
+ c.on_batch_begin(idx, logs)
+ mask = tf.not_equal(words, self.config.pad_index) & tf.not_equal(words, self.config.bos_index)
+ loss, arc_scores, rel_scores = self.train_batch(words, feats, arcs, rels, mask,
+ optimizer, arc_loss, rel_loss)
+ self.run_metrics(arcs, rels, arc_scores, rel_scores, words, mask, metric)
+ logs['loss'] = loss
+ logs.update(metric.to_dict())
+ if epoch == epochs - 1:
+ self.model.stop_training = True
+ for c in callbacks:
+ c.on_batch_end(idx, logs)
+ # evaluate on dev
+ metric.reset_states()
+ logs = {}
+ for idx, ((words, feats), (arcs, rels)) in enumerate(iter(dev_data)):
+ arc_scores, rel_scores, loss, mask, arc_preds, rel_preds = self.evaluate_batch(words, feats, arcs, rels,
+ arc_loss, rel_loss,
+ metric)
+ logs['val_loss'] = loss
+ logs.update((f'val_{k}', v) for k, v in metric.to_dict().items())
+
+ for c in callbacks:
+ c.on_epoch_end(epoch, logs)
+ if getattr(self.model, 'stop_training', None):
+ break
+
+ for c in callbacks:
+ c.on_train_end()
+
+ def evaluate(self, input_path: str, save_dir=None, output=False, batch_size=None, logger: logging.Logger = None,
+ callbacks: List[tf.keras.callbacks.Callback] = None, warm_up=False, verbose=True, **kwargs):
+ if batch_size is None:
+ batch_size = self.config.batch_size
+ return super().evaluate(input_path, save_dir, output, batch_size, logger, callbacks, warm_up, verbose, **kwargs)
+
+ def evaluate_batch(self, words, feats, arcs, rels, arc_loss, rel_loss, metric):
+ mask = tf.not_equal(words, self.config.pad_index) & tf.not_equal(words, self.config.bos_index)
+ arc_scores, rel_scores = self.model((words, feats))
+ loss = self.get_loss(arc_scores, rel_scores, arcs, rels, mask, arc_loss, rel_loss)
+ arc_preds, rel_preds = self.run_metrics(arcs, rels, arc_scores, rel_scores, words, mask, metric)
+ return arc_scores, rel_scores, loss, mask, arc_preds, rel_preds
+
+ def _build_metrics(self):
+ if isinstance(self.config.metrics, tuple):
+ self.config.metrics = list(self.config.metrics)
+ if self.config.metrics == ['UAS', 'LAS']:
+ metric = LabeledScore()
+ else:
+ metric = LabeledF1TF()
+ return metric
+
+ def run_metrics(self, arcs, rels, arc_scores, rel_scores, words, mask, metric):
+ arc_preds, rel_preds = self.decode(arc_scores, rel_scores, mask)
+ # ignore all punctuation if not specified
+ if not self.config.punct:
+ mask &= tf.reduce_all(tf.not_equal(tf.expand_dims(words, axis=-1), self.transform.puncts), axis=-1)
+ metric(arc_preds, rel_preds, arcs, rels, mask)
+ return arc_preds, rel_preds
+
+ def train_batch(self, words, feats, arcs, rels, mask, optimizer, arc_loss, rel_loss):
+ with tf.GradientTape() as tape:
+ arc_scores, rel_scores = self.model((words, feats), training=True)
+ loss = self.get_loss(arc_scores, rel_scores, arcs, rels, mask, arc_loss, rel_loss)
+ grads = tape.gradient(loss, self.model.trainable_variables)
+ optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
+ return loss, arc_scores, rel_scores
+
+ def get_loss(self, arc_scores, rel_scores, arcs, rels, mask, arc_loss, rel_loss):
+ arc_scores, arcs = arc_scores[mask], arcs[mask]
+ rel_scores, rels = rel_scores[mask], rels[mask]
+ rel_scores = tf.gather_nd(rel_scores, tf.stack([tf.range(len(arcs), dtype=tf.int64), arcs], axis=1))
+ arc_loss = arc_loss(arcs, arc_scores)
+ rel_loss = rel_loss(rels, rel_scores)
+ loss = arc_loss + rel_loss
+
+ return loss
+
+ def build_optimizer(self, optimizer='adam', lr=2e-3, mu=.9, nu=.9, epsilon=1e-12, clip=5.0, decay=.75,
+ decay_steps=5000, **kwargs):
+ if optimizer == 'adam':
+ scheduler = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate=lr,
+ decay_steps=decay_steps,
+ decay_rate=decay)
+ optimizer = tf.keras.optimizers.Adam(learning_rate=scheduler,
+ beta_1=mu,
+ beta_2=nu,
+ epsilon=epsilon,
+ clipnorm=clip)
+ return optimizer
+ return super().build_optimizer(optimizer, **kwargs)
+
+ # noinspection PyMethodOverriding
+ def build_loss(self, arc_loss, rel_loss, **kwargs):
+ if arc_loss == 'binary_crossentropy':
+ arc_loss = tf.losses.BinaryCrossentropy(from_logits=True)
+ else:
+ arc_loss = tf.keras.losses.SparseCategoricalCrossentropy(
+ from_logits=True) if arc_loss == 'sparse_categorical_crossentropy' else super().build_loss(arc_loss)
+ rel_loss = tf.keras.losses.SparseCategoricalCrossentropy(
+ from_logits=True) if rel_loss == 'sparse_categorical_crossentropy' else super().build_loss(rel_loss)
+ return arc_loss, rel_loss
+
+ @property
+ def sample_data(self):
+ return tf.constant([[2, 3, 4], [2, 5, 0]], dtype=tf.int64), tf.constant([[1, 2, 3], [4, 5, 0]], dtype=tf.int64)
+
+ def num_samples_in(self, dataset):
+ return sum(len(x[0][0]) for x in iter(dataset))
+
+ def build_train_dataset(self, trn_data, batch_size, num_examples):
+ trn_data = self.transform.file_to_dataset(trn_data, batch_size=batch_size,
+ shuffle=True,
+ repeat=None)
+ return trn_data
+
+ # noinspection PyMethodOverriding
+ def build_callbacks(self, save_dir, logger, metrics, **kwargs):
+ callbacks = super().build_callbacks(save_dir, logger, metrics=metrics, **kwargs)
+ if isinstance(metrics, tuple):
+ metrics = list(metrics)
+ callbacks.append(self.build_progbar(metrics))
+ params = {'verbose': 1, 'epochs': 1}
+ for c in callbacks:
+ c.set_params(params)
+ c.set_model(self.model)
+ return callbacks
+
+ def build_progbar(self, metrics, training=True):
+ return tf.keras.callbacks.ProgbarLogger(count_mode='steps',
+ stateful_metrics=metrics + [f'val_{k}' for k in metrics] if training
+ else [])
+
+ def decode(self, arc_scores, rel_scores, mask):
+ if self.config.tree:
+ root_rel_idx = self.transform.root_rel_idx
+ root_rel_onehot = np.eye(len(self.transform.rel_vocab))[root_rel_idx]
+ arc_preds = np.zeros_like(mask, dtype=np.int64)
+ rel_preds = np.zeros_like(mask, dtype=np.int64)
+ for arc, rel, m, arc_pred, rel_pred in zip(arc_scores, rel_scores, mask, arc_preds, rel_preds):
+ length = int(tf.math.count_nonzero(m)) + 1
+ arc = arc[:length, :length]
+ arc_probs = tf.nn.softmax(arc).numpy()
+ m = np.expand_dims(m.numpy()[:length], -1)
+ if self.config.tree == 'tarjan':
+ heads = tarjan(arc_probs, length, m)
+ elif self.config.tree == 'mst':
+ heads, head_probs, tokens = unique_root(arc_probs, m, length)
+ arc = arc.numpy()
+ adjust_root_score(arc, heads, root_rel_idx)
+ heads = chu_liu_edmonds(arc, length)
+ else:
+ raise ValueError(f'Unknown tree algorithm {self.config.tree}')
+ arc_pred[:length] = heads
+ root = np.where(heads[np.arange(1, length)] == 0)[0] + 1
+ rel_prob = tf.nn.softmax(rel[:length, :length, :]).numpy()
+ rel_prob = rel_prob[np.arange(length), heads]
+ rel_prob[root] = root_rel_onehot
+ rel_prob[np.arange(length) != root, np.arange(len(self.transform.rel_vocab)) == root_rel_idx] = 0
+ # rels = rel_argmax(rel_prob, length, root_rel_idx)
+ rels = np.argmax(rel_prob, axis=1)
+ rel_pred[:length] = rels
+ arc_preds = tf.constant(arc_preds)
+ rel_preds = tf.constant(rel_preds)
+ else:
+ arc_preds = tf.argmax(arc_scores, -1)
+ rel_preds = tf.argmax(rel_scores, -1)
+ rel_preds = tf.squeeze(tf.gather(rel_preds, tf.expand_dims(arc_preds, -1), batch_dims=2), axis=-1)
+
+ return arc_preds, rel_preds
+
+ def evaluate_dataset(self, tst_data, callbacks, output, num_batches, ret_scores=None, **kwargs):
+ if 'mask_p' in self.config:
+ self.config['mask_p'] = None
+ arc_loss, rel_loss = self.build_loss(**self.config)
+ callbacks = [self.build_progbar(self.config['metrics'])]
+ steps_per_epoch = len(list(iter(tst_data)))
+ metric = self._build_metrics()
+ params = {'verbose': 1, 'epochs': 1, 'metrics': ['loss'] + self.config.metrics, 'steps': steps_per_epoch}
+ for c in callbacks:
+ c.set_params(params)
+ c.on_test_begin()
+ c.on_epoch_end(0)
+ logs = {}
+ if ret_scores:
+ scores = []
+ if output:
+ ext = os.path.splitext(output)[-1]
+ output = open(output, 'w', encoding='utf-8')
+ for idx, ((words, feats), Y) in enumerate(iter(tst_data)):
+ arcs, rels = Y[0], Y[1]
+ for c in callbacks:
+ c.on_test_batch_begin(idx, logs)
+ arc_scores, rel_scores, loss, mask, arc_preds, rel_preds = self.evaluate_batch(words, feats, arcs, rels,
+ arc_loss, rel_loss, metric)
+ if ret_scores:
+ scores.append((arc_scores.numpy(), rel_scores.numpy(), mask.numpy()))
+ if output:
+ for sent in self.transform.XY_to_inputs_outputs((words, feats, mask), (arc_preds, rel_preds),
+ conll=ext, arc_scores=arc_scores,
+ rel_scores=rel_scores):
+ output.write(str(sent))
+ output.write('\n\n')
+ logs['loss'] = loss
+ logs.update(metric.to_dict())
+ for c in callbacks:
+ c.on_test_batch_end(idx, logs)
+ for c in callbacks:
+ c.on_epoch_end(0)
+ c.on_test_end()
+ if output:
+ output.close()
+ loss = float(c.progbar._values['loss'][0] / c.progbar._values['loss'][1])
+ outputs = loss, metric.to_dict(), False
+ if ret_scores:
+ outputs += (scores,)
+ return outputs
+
+ def predict_batch(self, batch, inputs=None, conll=True, **kwargs):
+ ((words, feats), (arcs, rels)) = batch
+ mask = tf.not_equal(words, self.config.pad_index) & tf.not_equal(words, self.config.bos_index)
+ arc_scores, rel_scores = self.model((words, feats))
+ arc_preds, rel_preds = self.decode(arc_scores, rel_scores, mask)
+ for sent in self.transform.XY_to_inputs_outputs((words, feats, mask), (arc_preds, rel_preds), gold=False,
+ inputs=inputs, conll=conll):
+ yield sent
+
+ def compile_model(self, optimizer, loss, metrics):
+ super().compile_model(optimizer, loss, metrics)
+
+
+class BiaffineSemanticDependencyParserTF(BiaffineDependencyParserTF):
+ def __init__(self, transform: CoNLL_SDP_Transform = None) -> None:
+ if not transform:
+ transform = CoNLL_SDP_Transform()
+ # noinspection PyTypeChecker
+ super().__init__(transform)
+ self.transform: CoNLL_SDP_Transform = transform
+
+ def fit(self, trn_data, dev_data, save_dir, n_embed=100, pretrained_embed=None, embed_dropout=.33,
+ n_lstm_hidden=400, n_lstm_layers=3, lstm_dropout=.33, n_mlp_arc=500, n_mlp_rel=100, mlp_dropout=.33,
+ optimizer='adam', lr=2e-3, mu=.9, nu=.9, epsilon=1e-12, clip=5.0, decay=.75, decay_steps=5000, patience=100,
+ arc_loss='binary_crossentropy', rel_loss='sparse_categorical_crossentropy',
+ metrics=('UF', 'LF'), n_buckets=32, batch_size=5000, epochs=50000, early_stopping_patience=100,
+ tree=False, punct=False, min_freq=2, run_eagerly=False, logger=None, verbose=True, **kwargs):
+ return super().fit(trn_data, dev_data, save_dir, n_embed, pretrained_embed, embed_dropout, n_lstm_hidden,
+ n_lstm_layers, lstm_dropout, n_mlp_arc, n_mlp_rel, mlp_dropout, optimizer, lr, mu, nu,
+ epsilon, clip, decay, decay_steps, patience, arc_loss, rel_loss, metrics, n_buckets,
+ batch_size, epochs, early_stopping_patience, tree, punct, min_freq, run_eagerly, logger,
+ verbose, **kwargs)
+
+ def get_loss(self, arc_scores, rel_scores, arcs, rels, mask, arc_loss, rel_loss):
+ mask = tf.tile(tf.expand_dims(mask, -1), [1, 1, tf.shape(mask)[-1]])
+ mask &= tf.transpose(mask, [0, 2, 1])
+ arc_scores, arcs = arc_scores[mask], arcs[mask]
+ rel_scores, rels = rel_scores[mask], rels[mask]
+ rel_scores, rels = rel_scores[arcs], rels[arcs]
+ arc_loss = arc_loss(arcs, arc_scores)
+ rel_loss = rel_loss(rels, rel_scores)
+ loss = arc_loss + rel_loss
+
+ return loss
+
+ def decode(self, arc_scores, rel_scores, mask):
+ arc_preds = arc_scores > 0
+ rel_preds = tf.argmax(rel_scores, -1)
+
+ return arc_preds, rel_preds
+
+
+class BiaffineTransformerDependencyParserTF(BiaffineDependencyParserTF, tf.keras.callbacks.Callback):
+ def __init__(self, transform: CoNLL_Transformer_Transform = None) -> None:
+ if not transform:
+ transform = CoNLL_Transformer_Transform()
+ super().__init__(transform)
+ self.transform: CoNLL_Transformer_Transform = transform
+
+ def build_model(self, transformer, training, **kwargs) -> tf.keras.Model:
+ transformer = self.build_transformer(training, transformer)
+ model = BiaffineModelTF(self.config, transformer=transformer)
+ return model
+
+ def build_transformer(self, training, transformer):
+ if training:
+ self.config.n_words = len(self.transform.form_vocab)
+ self._init_config()
+ if isinstance(transformer, str):
+ if 'albert_chinese' in transformer:
+ tokenizer = BertTokenizerFast.from_pretrained(transformer, add_special_tokens=False)
+ transformer: TFPreTrainedModel = TFAutoModel.from_pretrained(transformer, name=transformer,
+ from_pt=True)
+ elif transformer.startswith('albert') and transformer.endswith('zh'):
+ transformer, tokenizer, path = build_transformer(transformer)
+ transformer.config = AlbertConfig.from_json_file(os.path.join(path, "albert_config.json"))
+ tokenizer = BertTokenizer.from_pretrained(os.path.join(path, "vocab_chinese.txt"),
+ add_special_tokens=False)
+ elif 'chinese-roberta' in transformer:
+ tokenizer = BertTokenizer.from_pretrained(transformer)
+ transformer = TFBertModel.from_pretrained(transformer, name=transformer, from_pt=True)
+ else:
+ tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(transformer)
+ try:
+ transformer: TFPreTrainedModel = TFAutoModel.from_pretrained(transformer, name=transformer)
+ except (TypeError, OSError):
+ transformer: TFPreTrainedModel = TFAutoModel.from_pretrained(transformer, name=transformer,
+ from_pt=True)
+ elif transformer[0] == 'AutoModelWithLMHead':
+ tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(transformer[1])
+ transformer: TFAutoModelWithLMHead = TFAutoModelWithLMHead.from_pretrained(transformer[1])
+ else:
+ raise ValueError(f'Unknown identifier {transformer}')
+ self.transform.tokenizer = tokenizer
+ if self.config.get('fp16', None) or self.config.get('use_amp', None):
+ policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
+ tf.keras.mixed_precision.experimental.set_policy(policy)
+ # tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})
+ transformer.set_weights([w.astype('float16') for w in transformer.get_weights()])
+ self.transform.transformer_config = transformer.config
+ return transformer
+
+ # noinspection PyMethodOverriding
+ def fit(self, trn_data, dev_data, save_dir, transformer, max_seq_length=256, transformer_dropout=.33,
+ d_positional=None,
+ n_mlp_arc=500, n_mlp_rel=100, mlp_dropout=.33,
+ optimizer='adamw',
+ learning_rate=5e-5,
+ learning_rate_transformer=None,
+ weight_decay_rate=0,
+ epsilon=1e-8,
+ clipnorm=None,
+ fp16=False,
+ warmup_steps_ratio=0,
+ arc_loss='sparse_categorical_crossentropy', rel_loss='sparse_categorical_crossentropy',
+ metrics=('UAS', 'LAS'),
+ batch_size=3000,
+ samples_per_batch=150,
+ max_samples_per_batch=None,
+ epochs=100,
+ tree=False, punct=False, token_mapping=None, run_eagerly=False, logger=None, verbose=True, **kwargs):
+ self.set_params({})
+ return KerasComponent.fit(self, **merge_locals_kwargs(locals(), kwargs))
+
+ @property
+ def sample_data(self):
+ dataset = self.transform.inputs_to_dataset(
+ [[('Hello', 'NN'), ('world', 'NN')], [('HanLP', 'NN'), ('is', 'NN'), ('good', 'NN')]] if self.config.get(
+ 'use_pos', None) else
+ [['Hello', 'world'], ['HanLP', 'is', 'good']])
+ return next(iter(dataset))[0]
+
+ # noinspection PyMethodOverriding
+ def build_optimizer(self, optimizer, learning_rate, epsilon, weight_decay_rate, clipnorm, fp16, train_steps,
+ **kwargs):
+ if optimizer == 'adamw':
+ epochs = self.config['epochs']
+ learning_rate_transformer = kwargs.get('learning_rate_transformer', None)
+ train_steps = math.ceil(self.config.train_examples * epochs / self.config.samples_per_batch)
+ warmup_steps = math.ceil(train_steps * self.config['warmup_steps_ratio'])
+ if learning_rate_transformer is not None:
+ if learning_rate_transformer > 0:
+ self.params['optimizer_transformer'] = build_adamw_optimizer(self.config, learning_rate_transformer,
+ epsilon,
+ clipnorm, train_steps, fp16,
+ math.ceil(warmup_steps),
+ weight_decay_rate)
+ else:
+ self.model.transformer.trainable = False
+ return super().build_optimizer(lr=learning_rate) # use a normal adam for biaffine
+ else:
+ return build_adamw_optimizer(self.config, learning_rate, epsilon, clipnorm, train_steps, fp16,
+ math.ceil(warmup_steps), weight_decay_rate)
+ return super().build_optimizer(optimizer, **kwargs)
+
+ def build_vocab(self, trn_data, logger):
+ self.config.train_examples = train_examples = super().build_vocab(trn_data, logger)
+ return train_examples
+
+ def build_callbacks(self, save_dir, logger, metrics, **kwargs):
+ callbacks = super().build_callbacks(save_dir, logger, metrics=metrics, **kwargs)
+ callbacks.append(self)
+ if not self.params:
+ self.set_params({})
+ return callbacks
+
+ def on_train_begin(self):
+ self.params['accum_grads'] = [tf.Variable(tf.zeros_like(tv.read_value()), trainable=False) for tv in
+ self.model.trainable_variables]
+ self.params['trained_samples'] = 0
+ self.params['transformer_variable_names'] = {x.name for x in self.model.transformer.trainable_variables}
+
+ def train_batch(self, words, feats, arcs, rels, mask, optimizer, arc_loss, rel_loss):
+ with tf.GradientTape() as tape:
+ arc_scores, rel_scores = self.model((words, feats), training=True)
+ loss = self.get_loss(arc_scores, rel_scores, arcs, rels, mask, arc_loss, rel_loss)
+ grads = tape.gradient(loss, self.model.trainable_variables)
+ accum_grads = self.params['accum_grads']
+ for i, grad in enumerate(grads):
+ if grad is not None:
+ accum_grads[i].assign_add(grad)
+ self.params['trained_samples'] += tf.shape(words)[0]
+ if self.params['trained_samples'] >= self.config.samples_per_batch:
+ self._apply_grads(accum_grads)
+ return loss, arc_scores, rel_scores
+
+ def _apply_grads(self, accum_grads):
+ optimizer_transformer = self.params.get('optimizer_transformer', None)
+ if optimizer_transformer:
+ transformer = self.params['transformer_variable_names']
+ trainable_variables = self.model.trainable_variables
+ optimizer_transformer.apply_gradients(
+ (g, w) for g, w in zip(accum_grads, trainable_variables) if w.name in transformer)
+ self.model.optimizer.apply_gradients(
+ (g, w) for g, w in zip(accum_grads, trainable_variables) if w.name not in transformer)
+ else:
+ self.model.optimizer.apply_gradients(zip(accum_grads, self.model.trainable_variables))
+ for tv in accum_grads:
+ tv.assign(tf.zeros_like(tv))
+ # print('Apply grads after', self.params['trained_samples'], 'samples')
+ self.params['trained_samples'] = 0
+
+ def on_epoch_end(self, epoch, logs=None):
+ if self.params['trained_samples']:
+ self._apply_grads(self.params['accum_grads'])
+
+
+class BiaffineTransformerSemanticDependencyParser(BiaffineTransformerDependencyParserTF):
+
+ def __init__(self, transform: CoNLL_Transformer_Transform = None) -> None:
+ if not transform:
+ transform = CoNLL_Transformer_Transform(graph=True)
+ super().__init__(transform)
+
+ def get_loss(self, arc_scores, rel_scores, arcs, rels, mask, arc_loss, rel_loss):
+ return BiaffineSemanticDependencyParserTF.get_loss(self, arc_scores, rel_scores, arcs, rels, mask, arc_loss,
+ rel_loss)
+
+ def fit(self, trn_data, dev_data, save_dir, transformer, max_seq_length=256, transformer_dropout=.33,
+ d_positional=None, n_mlp_arc=500, n_mlp_rel=100, mlp_dropout=.33, optimizer='adamw', learning_rate=5e-5,
+ learning_rate_transformer=None, weight_decay_rate=0, epsilon=1e-8, clipnorm=None, fp16=False,
+ warmup_steps_ratio=0, arc_loss='binary_crossentropy',
+ rel_loss='sparse_categorical_crossentropy', metrics=('UF', 'LF'), batch_size=3000, samples_per_batch=150,
+ max_samples_per_batch=None, epochs=100, tree=False, punct=False, token_mapping=None, enhanced_only=False,
+ run_eagerly=False,
+ logger=None, verbose=True, **kwargs):
+ return super().fit(**merge_locals_kwargs(locals(), kwargs))
+
+ def decode(self, arc_scores, rel_scores, mask):
+ return BiaffineSemanticDependencyParserTF.decode(self, arc_scores, rel_scores, mask)
+
+
+class StructuralAttentionDependencyParserTF(BiaffineTransformerDependencyParserTF):
+
+ def build_model(self, transformer, training, masked_lm_embed=None, **kwargs) -> tf.keras.Model:
+ transformer = self.build_transformer(training, transformer)
+ self.config.num_heads = len(self.transform.rel_vocab)
+ if self.config.get('use_pos', None):
+ self.config.n_pos = len(self.transform.cpos_vocab)
+ if masked_lm_embed:
+ masked_lm_embed = build_embedding(masked_lm_embed, self.transform.form_vocab, self.transform)
+ masked_lm_embed(tf.constant(0)) # build it with sample data
+ masked_lm_embed = tf.transpose(masked_lm_embed._embeddings)
+ return StructuralAttentionModel(self.config, transformer, masked_lm_embed)
+
+ def fit(self, trn_data, dev_data, save_dir, transformer, max_seq_length=256, transformer_dropout=.33,
+ d_positional=None, mask_p=.15, masked_lm_dropout=None, masked_lm_embed=None, joint_pos=False, alpha=0.1,
+ sa_dim=None,
+ num_decoder_layers=1,
+ n_mlp_arc=500,
+ n_mlp_rel=100,
+ mlp_dropout=.33,
+ optimizer='adamw',
+ learning_rate=5e-5,
+ learning_rate_transformer=None, weight_decay_rate=0, epsilon=1e-8, clipnorm=None, fp16=False,
+ warmup_steps_ratio=0, arc_loss='sparse_categorical_crossentropy',
+ rel_loss='sparse_categorical_crossentropy', metrics=('UAS', 'LAS'), batch_size=3000, samples_per_batch=150,
+ epochs=100, tree=False, punct=False, token_mapping=None, run_eagerly=False, logger=None, verbose=True,
+ **kwargs):
+ return super().fit(**merge_locals_kwargs(locals(), kwargs))
+
+ def train_loop(self, trn_data, dev_data, epochs, num_examples, train_steps_per_epoch, dev_steps, model, optimizer,
+ loss, metrics, callbacks, logger: logging.Logger, arc_loss, rel_loss, **kwargs):
+ arc_loss, rel_loss = loss
+ # because we are customizing batching
+ train_steps_per_epoch = len(list(iter(trn_data)))
+ # progbar: tf.keras.callbacks.ProgbarLogger = callbacks[-1]
+ c: tf.keras.callbacks.Callback = None
+ metrics = self._build_metrics()
+ acc: tf.keras.metrics.SparseCategoricalAccuracy = metrics[1]
+ for c in callbacks:
+ if not hasattr(c, 'params'):
+ c.params = {}
+ c.params['epochs'] = epochs
+ c.params['trn_data'] = trn_data
+ c.params['metrics'] = ['loss'] + self.config.metrics + [acc.name]
+ c.params['metrics'] = c.params['metrics'] + [f'val_{k}' for k in c.params['metrics']]
+ c.on_train_begin()
+ for epoch in range(epochs):
+ for metric in metrics:
+ metric.reset_states()
+ for c in callbacks:
+ c.params['steps'] = train_steps_per_epoch
+ c.on_epoch_begin(epoch)
+ for idx, ((words, feats), (arcs, rels, offsets)) in enumerate(iter(trn_data)):
+ logs = {}
+ for c in callbacks:
+ c.on_batch_begin(idx, logs)
+ mask = tf.not_equal(words, self.config.pad_index) & tf.not_equal(words, self.config.bos_index)
+ loss, arc_scores, rel_scores, lm_ids = self.train_batch(words, feats, arcs, rels, offsets, mask,
+ optimizer, arc_loss, rel_loss, acc)
+ self.run_metrics(arcs, rels, arc_scores, rel_scores, words, mask, metrics[0])
+ logs['loss'] = loss
+ logs.update(metrics[0].to_dict())
+ logs[acc.name] = acc.result()
+ if epoch == epochs - 1:
+ self.model.stop_training = True
+ for c in callbacks:
+ c.on_batch_end(idx, logs)
+ # evaluate on dev
+ for metric in metrics:
+ metric.reset_states()
+ logs = {}
+ for idx, ((words, feats), (arcs, rels, offsets)) in enumerate(iter(dev_data)):
+ arc_scores, rel_scores, loss, mask, arc_preds, rel_preds = self.evaluate_batch(words, feats, arcs, rels,
+ arc_loss, rel_loss,
+ metrics[0])
+ logs['val_loss'] = loss
+ logs.update((f'val_{k}', v) for k, v in metrics[0].to_dict().items())
+
+ for c in callbacks:
+ c.on_epoch_end(epoch, logs)
+ if getattr(self.model, 'stop_training', None):
+ break
+
+ for c in callbacks:
+ c.on_train_end()
+
+ # noinspection PyMethodOverriding
+ def train_batch(self, words, feats, arcs, rels, ids, mask, optimizer, arc_loss, rel_loss, metric):
+ with tf.GradientTape() as tape:
+ arc_scores, rel_scores, lm_ids = self.model((words, feats), training=True)
+ loss = self.get_total_loss(words, feats, arcs, rels, arc_scores, rel_scores, arc_loss, rel_loss, ids,
+ lm_ids, mask,
+ metric)
+ grads = tape.gradient(loss, self.model.trainable_variables)
+ accum_grads = self.params['accum_grads']
+ for i, grad in enumerate(grads):
+ if grad is not None:
+ accum_grads[i].assign_add(grad)
+ self.params['trained_samples'] += tf.shape(words)[0]
+ if self.params['trained_samples'] >= self.config.samples_per_batch:
+ self._apply_grads(accum_grads)
+ return loss, arc_scores, rel_scores, lm_ids
+
+ def get_total_loss(self, words, feats, arcs, rels, arc_scores, rel_scores, arc_loss, rel_loss, gold_offsets,
+ pred_ids,
+ mask, metric):
+ masked_lm_loss = self.get_masked_lm_loss(words, feats, gold_offsets, pred_ids, metric)
+ # return masked_lm_loss
+ parser_loss = self.get_loss(arc_scores, rel_scores, arcs, rels, mask, arc_loss, rel_loss)
+ loss = parser_loss + masked_lm_loss * self.config.alpha
+ return loss
+
+ def get_masked_lm_loss(self, words, feats, gold_offsets, pred_ids, metric):
+ if self.config.get('joint_pos', None):
+ gold_ids = tf.gather(feats[-1], gold_offsets, batch_dims=1)
+ else:
+ gold_ids = tf.gather(words, gold_offsets, batch_dims=1)
+ pred_ids = tf.gather(pred_ids, gold_offsets, batch_dims=1)
+ masked_lm_loss = tf.keras.losses.sparse_categorical_crossentropy(gold_ids, pred_ids)
+ mask = gold_offsets != 0
+ if metric:
+ metric(gold_ids, pred_ids, mask)
+ return tf.reduce_mean(tf.boolean_mask(masked_lm_loss, mask))
+
+ def _build_metrics(self):
+ if not self.config['mask_p']:
+ return super()._build_metrics()
+ acc = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
+ return super()._build_metrics(), acc
+
+ def build_train_dataset(self, trn_data, batch_size, num_examples):
+ trn_data = self.transform.file_to_dataset(trn_data, batch_size=batch_size,
+ shuffle=True,
+ repeat=None,
+ cache=False) # Generate different masks every time
+ return trn_data
+
+ def build_loss(self, arc_loss, rel_loss, **kwargs):
+ if arc_loss == 'binary_crossentropy':
+ arc_loss = tf.losses.BinaryCrossentropy(from_logits=False)
+ else:
+ arc_loss = tf.keras.losses.SparseCategoricalCrossentropy(
+ from_logits=True) if arc_loss == 'sparse_categorical_crossentropy' else super().build_loss(arc_loss)
+ rel_loss = tf.keras.losses.SparseCategoricalCrossentropy(
+ from_logits=True) if rel_loss == 'sparse_categorical_crossentropy' else super().build_loss(rel_loss)
+ return arc_loss, rel_loss
+
+ def decode(self, arc_scores, rel_scores, mask):
+ if self.transform.graph:
+ return BiaffineSemanticDependencyParserTF.decode(self, arc_scores, rel_scores, mask)
+ return super().decode(arc_scores, rel_scores, mask)
+
+ def get_loss(self, arc_scores, rel_scores, arcs, rels, mask, arc_loss, rel_loss):
+ if self.transform.graph:
+ return BiaffineSemanticDependencyParserTF.get_loss(self, arc_scores, rel_scores, arcs, rels, mask, arc_loss,
+ rel_loss)
+ return super().get_loss(arc_scores, rel_scores, arcs, rels, mask, arc_loss, rel_loss)
diff --git a/hanlp/components/parsers/biaffine_tf/__init__.py b/hanlp/components/parsers/biaffine_tf/__init__.py
new file mode 100644
index 000000000..12a4372f1
--- /dev/null
+++ b/hanlp/components/parsers/biaffine_tf/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-12-26 23:03
\ No newline at end of file
diff --git a/hanlp/components/parsers/biaffine_tf/alg.py b/hanlp/components/parsers/biaffine_tf/alg.py
new file mode 100644
index 000000000..75536dc79
--- /dev/null
+++ b/hanlp/components/parsers/biaffine_tf/alg.py
@@ -0,0 +1,289 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-12-26 19:49
+# Ported from the PyTorch implementation https://github.com/zysite/biaffine-parser
+from typing import List
+import numpy as np
+import tensorflow as tf
+from collections import defaultdict
+
+
+def nonzero(t: tf.Tensor) -> tf.Tensor:
+ return tf.where(t > 0)
+
+
+def view(t: tf.Tensor, *dims) -> tf.Tensor:
+ return tf.reshape(t, dims)
+
+
+def arange(n: int) -> tf.Tensor:
+ return tf.range(n)
+
+
+def randperm(n: int) -> tf.Tensor:
+ return tf.random.shuffle(arange(n))
+
+
+def tolist(t: tf.Tensor) -> List:
+ if isinstance(t, tf.Tensor):
+ t = t.numpy()
+ return t.tolist()
+
+
+def kmeans(x, k, seed=None):
+ """See https://github.com/zysite/biaffine-parser/blob/master/parser/utils/alg.py#L7
+
+ Args:
+ x(list): Lengths of sentences
+ k(int):
+ seed: (Default value = None)
+
+ Returns:
+
+
+ """
+ x = tf.constant(x, dtype=tf.float32)
+ # count the frequency of each datapoint
+ d, indices, f = tf.unique_with_counts(x, tf.int32)
+ f = tf.cast(f, tf.float32)
+ # calculate the sum of the values of the same datapoints
+ total = d * f
+ # initialize k centroids randomly
+ c, old = tf.random.shuffle(d, seed)[:k], None
+ # assign labels to each datapoint based on centroids
+ dists = tf.abs(tf.expand_dims(d, -1) - c)
+ y = tf.argmin(dists, axis=-1, output_type=tf.int32)
+ dists = tf.gather_nd(dists, tf.transpose(tf.stack([tf.range(tf.shape(dists)[0], dtype=tf.int32), y])))
+ # make sure number of datapoints is greater than that of clusters
+ assert len(d) >= k, f"unable to assign {len(d)} datapoints to {k} clusters"
+
+ while old is None or not tf.reduce_all(c == old):
+ # if an empty cluster is encountered,
+ # choose the farthest datapoint from the biggest cluster
+ # and move that the empty one
+ for i in range(k):
+ if not tf.reduce_any(y == i):
+ mask = tf.cast(y == tf.expand_dims(tf.range(k, dtype=tf.int32), -1), tf.float32)
+ lens = tf.reduce_sum(mask, axis=-1)
+ biggest = view(nonzero(mask[tf.argmax(lens)]), -1)
+ farthest = tf.argmax(tf.gather(dists, biggest))
+ tf.tensor_scatter_nd_update(y, tf.expand_dims(tf.expand_dims(biggest[farthest], -1), -1), [i])
+ mask = tf.cast(y == tf.expand_dims(tf.range(k, dtype=tf.int32), -1), tf.float32)
+ # update the centroids
+ c, old = tf.cast(tf.reduce_sum(total * mask, axis=-1), tf.float32) / tf.cast(tf.reduce_sum(f * mask, axis=-1),
+ tf.float32), c
+ # re-assign all datapoints to clusters
+ dists = tf.abs(tf.expand_dims(d, -1) - c)
+ y = tf.argmin(dists, axis=-1, output_type=tf.int32)
+ dists = tf.gather_nd(dists, tf.transpose(tf.stack([tf.range(tf.shape(dists)[0], dtype=tf.int32), y])))
+ # assign all datapoints to the new-generated clusters
+ # without considering the empty ones
+ y, (assigned, _) = tf.gather(y, indices), tf.unique(y)
+ # get the centroids of the assigned clusters
+ centroids = tf.gather(c, assigned).numpy().tolist()
+ # map all values of datapoints to buckets
+ clusters = [tf.squeeze(tf.where(y == i), axis=-1).numpy().tolist() for i in assigned]
+
+ return centroids, clusters
+
+
+# ***************************************************************
+class Tarjan:
+ """Computes Tarjan's algorithm for finding strongly connected components (cycles) of a graph"""
+
+ def __init__(self, prediction, tokens):
+ """
+
+ Parameters
+ ----------
+ prediction : numpy.ndarray
+ a predicted dependency tree where prediction[dep_idx] = head_idx
+ tokens : numpy.ndarray
+ the tokens we care about (i.e. exclude _GO, _EOS, and _PAD)
+ """
+ self._edges = defaultdict(set)
+ self._vertices = set((0,))
+ for dep, head in enumerate(prediction[tokens]):
+ self._vertices.add(dep + 1)
+ self._edges[head].add(dep + 1)
+ self._indices = {}
+ self._lowlinks = {}
+ self._onstack = defaultdict(lambda: False)
+ self._SCCs = []
+
+ index = 0
+ stack = []
+ for v in self.vertices:
+ if v not in self.indices:
+ self.strongconnect(v, index, stack)
+
+ # =============================================================
+ def strongconnect(self, v, index, stack):
+ """
+
+ Args:
+ v:
+ index:
+ stack:
+
+ Returns:
+
+ """
+
+ self._indices[v] = index
+ self._lowlinks[v] = index
+ index += 1
+ stack.append(v)
+ self._onstack[v] = True
+ for w in self.edges[v]:
+ if w not in self.indices:
+ self.strongconnect(w, index, stack)
+ self._lowlinks[v] = min(self._lowlinks[v], self._lowlinks[w])
+ elif self._onstack[w]:
+ self._lowlinks[v] = min(self._lowlinks[v], self._indices[w])
+
+ if self._lowlinks[v] == self._indices[v]:
+ self._SCCs.append(set())
+ while stack[-1] != v:
+ w = stack.pop()
+ self._onstack[w] = False
+ self._SCCs[-1].add(w)
+ w = stack.pop()
+ self._onstack[w] = False
+ self._SCCs[-1].add(w)
+ return
+
+ # ======================
+ @property
+ def edges(self):
+ return self._edges
+
+ @property
+ def vertices(self):
+ return self._vertices
+
+ @property
+ def indices(self):
+ return self._indices
+
+ @property
+ def SCCs(self):
+ return self._SCCs
+
+
+def tarjan(parse_probs, length, tokens_to_keep, ensure_tree=True):
+ """Adopted from Timothy Dozat https://github.com/tdozat/Parser/blob/master/lib/models/nn.py
+
+ Args:
+ parse_probs(NDArray): seq_len x seq_len, the probability of arcs
+ length(NDArray): sentence length including ROOT
+ tokens_to_keep(NDArray): mask matrix
+ ensure_tree: (Default value = True)
+
+ Returns:
+
+
+ """
+ if ensure_tree:
+ I = np.eye(len(tokens_to_keep))
+ # block loops and pad heads
+ parse_probs = parse_probs * tokens_to_keep * (1 - I)
+ parse_preds = np.argmax(parse_probs, axis=1)
+ tokens = np.arange(1, length)
+ roots = np.where(parse_preds[tokens] == 0)[0] + 1
+ # ensure at least one root
+ if len(roots) < 1:
+ # The current root probabilities
+ root_probs = parse_probs[tokens, 0]
+ # The current head probabilities
+ old_head_probs = parse_probs[tokens, parse_preds[tokens]]
+ # Get new potential root probabilities
+ new_root_probs = root_probs / old_head_probs
+ # Select the most probable root
+ new_root = tokens[np.argmax(new_root_probs)]
+ # Make the change
+ parse_preds[new_root] = 0
+ # ensure at most one root
+ elif len(roots) > 1:
+ # The probabilities of the current heads
+ root_probs = parse_probs[roots, 0]
+ # Set the probability of depending on the root zero
+ parse_probs[roots, 0] = 0
+ # Get new potential heads and their probabilities
+ new_heads = np.argmax(parse_probs[roots][:, tokens], axis=1) + 1
+ new_head_probs = parse_probs[roots, new_heads] / root_probs
+ # Select the most probable root
+ new_root = roots[np.argmin(new_head_probs)]
+ # Make the change
+ parse_preds[roots] = new_heads
+ parse_preds[new_root] = 0
+ # remove cycles
+ tarjan = Tarjan(parse_preds, tokens)
+ for SCC in tarjan.SCCs:
+ if len(SCC) > 1:
+ dependents = set()
+ to_visit = set(SCC)
+ while len(to_visit) > 0:
+ node = to_visit.pop()
+ if not node in dependents:
+ dependents.add(node)
+ to_visit.update(tarjan.edges[node])
+ # The indices of the nodes that participate in the cycle
+ cycle = np.array(list(SCC))
+ # The probabilities of the current heads
+ old_heads = parse_preds[cycle]
+ old_head_probs = parse_probs[cycle, old_heads]
+ # Set the probability of depending on a non-head to zero
+ non_heads = np.array(list(dependents))
+ parse_probs[np.repeat(cycle, len(non_heads)), np.repeat([non_heads], len(cycle), axis=0).flatten()] = 0
+ # Get new potential heads and their probabilities
+ new_heads = np.argmax(parse_probs[cycle][:, tokens], axis=1) + 1
+ new_head_probs = parse_probs[cycle, new_heads] / old_head_probs
+ # Select the most probable change
+ change = np.argmax(new_head_probs)
+ changed_cycle = cycle[change]
+ old_head = old_heads[change]
+ new_head = new_heads[change]
+ # Make the change
+ parse_preds[changed_cycle] = new_head
+ tarjan.edges[new_head].add(changed_cycle)
+ tarjan.edges[old_head].remove(changed_cycle)
+ return parse_preds
+ else:
+ # block and pad heads
+ parse_probs = parse_probs * tokens_to_keep
+ parse_preds = np.argmax(parse_probs, axis=1)
+ return parse_preds
+
+
+def rel_argmax(rel_probs, length, root, ensure_tree=True):
+ """Fix the relation prediction by heuristic rules
+
+ Args:
+ rel_probs(NDArray): seq_len x rel_size
+ length: real sentence length
+ ensure_tree: (Default value = True)
+ root:
+
+ Returns:
+
+
+ """
+ if ensure_tree:
+ tokens = np.arange(1, length)
+ rel_preds = np.argmax(rel_probs, axis=1)
+ roots = np.where(rel_preds[tokens] == root)[0] + 1
+ if len(roots) < 1:
+ rel_preds[1 + np.argmax(rel_probs[tokens, root])] = root
+ elif len(roots) > 1:
+ root_probs = rel_probs[roots, root]
+ rel_probs[roots, root] = 0
+ new_rel_preds = np.argmax(rel_probs[roots], axis=1)
+ new_rel_probs = rel_probs[roots, new_rel_preds] / root_probs
+ new_root = roots[np.argmin(new_rel_probs)]
+ rel_preds[roots] = new_rel_preds
+ rel_preds[new_root] = root
+ return rel_preds
+ else:
+ rel_preds = np.argmax(rel_probs, axis=1)
+ return rel_preds
diff --git a/hanlp/components/parsers/biaffine/layers.py b/hanlp/components/parsers/biaffine_tf/layers.py
similarity index 55%
rename from hanlp/components/parsers/biaffine/layers.py
rename to hanlp/components/parsers/biaffine_tf/layers.py
index b6c50f126..ed25748dc 100644
--- a/hanlp/components/parsers/biaffine/layers.py
+++ b/hanlp/components/parsers/biaffine_tf/layers.py
@@ -3,6 +3,7 @@
# Date: 2019-12-26 23:05
# Ported from the PyTorch implementation https://github.com/zysite/biaffine-parser
import tensorflow as tf
+from params_flow import LayerNormalization
from hanlp.utils.tf_util import tf_bernoulli
@@ -66,6 +67,7 @@ def call(self, x, **kwargs):
class SharedDropout(tf.keras.layers.Layer):
def __init__(self, p=0.5, batch_first=True, trainable=True, name=None, dtype=None, dynamic=False, **kwargs):
+ """Dropout on timesteps with bernoulli distribution"""
super().__init__(trainable, name, dtype, dynamic, **kwargs)
self.p = p
self.batch_first = batch_first
@@ -89,7 +91,7 @@ def call(self, x, training=None, **kwargs):
@staticmethod
def get_mask(x, p):
- mask = tf_bernoulli(tf.shape(x), 1 - p)
+ mask = tf_bernoulli(tf.shape(x), 1 - p, x.dtype)
mask = mask / (1 - p)
return mask
@@ -98,6 +100,7 @@ def get_mask(x, p):
class IndependentDropout(tf.keras.layers.Layer):
def __init__(self, p=0.5, trainable=True, name=None, dtype=None, dynamic=False, **kwargs):
+ """Dropout on the first two dimensions"""
super().__init__(trainable, name, dtype, dynamic, **kwargs)
self.p = p
@@ -115,3 +118,68 @@ def call(self, inputs, training=None, **kwargs):
for item, mask in zip(inputs, masks)]
return inputs
+
+
+class StructuralAttentionLayer(tf.keras.layers.Layer):
+ def __init__(self, config, num_heads, x_dim, trainable=True, name=None, dtype=None, dynamic=False, **kwargs):
+ super().__init__(trainable, name, dtype, dynamic, **kwargs)
+
+ sa_dim = config.get('sa_dim', None)
+ if sa_dim:
+ self.shrink = tf.keras.layers.Dense(sa_dim, name='shrink')
+ x_dim = sa_dim
+
+ self.mlp_arc_h = MLP(n_hidden=config.n_mlp_arc,
+ dropout=config.mlp_dropout, name='mlp_arc_h')
+ self.mlp_arc_d = MLP(n_hidden=config.n_mlp_arc,
+ dropout=config.mlp_dropout, name='mlp_arc_d')
+ self.mlp_rel_h = MLP(n_hidden=config.n_mlp_rel,
+ dropout=config.mlp_dropout, name='mlp_rel_h')
+ self.mlp_rel_d = MLP(n_hidden=config.n_mlp_rel,
+ dropout=config.mlp_dropout, name='mlp_rel_d')
+
+ # the Biaffine layers
+ self.arc_attn = Biaffine(n_in=config.n_mlp_arc,
+ bias_x=True,
+ bias_y=False, name='arc_attn')
+ self.rel_attn = Biaffine(n_in=config.n_mlp_rel,
+ n_out=config.n_rels,
+ bias_x=True,
+ bias_y=True, name='rel_attn')
+ self.heads_WV = self.add_weight(shape=[num_heads, x_dim, x_dim])
+ self.dense = tf.keras.layers.Dense(x_dim)
+ self.layer_norm = LayerNormalization(name="LayerNorm")
+ self.graph = config.get('graph', False)
+
+ def call(self, inputs, mask=None, **kwargs):
+ x = inputs
+ arc_h = self.mlp_arc_h(x)
+ arc_d = self.mlp_arc_d(x)
+ rel_h = self.mlp_rel_h(x)
+ rel_d = self.mlp_rel_d(x)
+
+ # get arc and rel scores from the bilinear attention
+ # [batch_size, seq_len, seq_len]
+ s_arc = self.arc_attn(arc_d, arc_h)
+ # if mask is not None:
+ # negative_infinity = -10000.0
+ # s_arc += (1.0 - mask) * negative_infinity
+
+ # [batch_size, seq_len, seq_len, n_rels]
+ s_rel = tf.transpose(self.rel_attn(rel_d, rel_h), [0, 2, 3, 1])
+ if self.graph:
+ p_arc = tf.nn.sigmoid(s_arc - (1.0 - mask) * 10000.0)
+ else:
+ p_arc = tf.nn.softmax(s_arc - (1.0 - mask) * 10000.0, axis=-1)
+ p_rel = tf.nn.softmax(s_rel, axis=-1)
+ A = tf.expand_dims(p_arc, -1) * p_rel
+ A = tf.transpose(A, [0, 3, 1, 2])
+ if hasattr(self, 'shrink'):
+ x = self.shrink(x)
+ Ax = A @ tf.expand_dims(x, 1)
+ AxW = Ax @ self.heads_WV
+ AxW = tf.transpose(AxW, [0, 2, 1, 3])
+ AxW = tf.reshape(AxW, list(tf.shape(AxW)[:2]) + [-1])
+ x = self.dense(AxW) + x
+ x = self.layer_norm(x)
+ return s_arc, s_rel, x
diff --git a/hanlp/components/parsers/biaffine_tf/model.py b/hanlp/components/parsers/biaffine_tf/model.py
new file mode 100644
index 000000000..d74deff8d
--- /dev/null
+++ b/hanlp/components/parsers/biaffine_tf/model.py
@@ -0,0 +1,235 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-12-26 23:04
+import tensorflow as tf
+from hanlp.layers.transformers.tf_imports import TFPreTrainedModel
+
+from hanlp.components.parsers.biaffine_tf.layers import IndependentDropout, SharedDropout, Biaffine, \
+ MLP, StructuralAttentionLayer
+
+
+class BiaffineModelTF(tf.keras.Model):
+
+ def __init__(self, config, embed=None, transformer: TFPreTrainedModel = None):
+ """An implementation of T. Dozat and C. D. Manning, “Deep Biaffine Attention for Neural Dependency Parsing.,” ICLR, 2017.
+ Although I have my MXNet implementation, I found zysite's PyTorch implementation is cleaner so I port it to TensorFlow
+
+ Args:
+ config: param embed:
+
+ Returns:
+
+ """
+ super(BiaffineModelTF, self).__init__()
+ assert not (embed and transformer), 'Either pre-trained word embed and transformer is supported, but not both'
+ normal = tf.keras.initializers.RandomNormal(stddev=1.)
+ if not transformer:
+ # the embedding layer
+ self.word_embed = tf.keras.layers.Embedding(input_dim=config.n_words,
+ output_dim=config.n_embed,
+ embeddings_initializer=tf.keras.initializers.zeros() if embed
+ else normal,
+ name='word_embed')
+ self.feat_embed = tf.keras.layers.Embedding(input_dim=config.n_feats,
+ output_dim=config.n_embed,
+ embeddings_initializer=tf.keras.initializers.zeros() if embed
+ else normal,
+ name='feat_embed')
+ self.embed_dropout = IndependentDropout(p=config.embed_dropout, name='embed_dropout')
+
+ # the word-lstm layer
+ self.lstm = tf.keras.models.Sequential(name='lstm')
+ for _ in range(config.n_lstm_layers):
+ self.lstm.add(tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(
+ units=config.n_lstm_hidden,
+ dropout=config.lstm_dropout,
+ recurrent_dropout=config.lstm_dropout,
+ return_sequences=True,
+ kernel_initializer='orthogonal',
+ unit_forget_bias=False, # turns out to hinder performance
+ )))
+ self.lstm_dropout = SharedDropout(p=config.lstm_dropout, name='lstm_dropout')
+ else:
+ self.transformer = transformer
+ transformer_dropout = config.get('transformer_dropout', None)
+ if transformer_dropout:
+ self.transformer_dropout = SharedDropout(p=config.transformer_dropout, name='transformer_dropout')
+ d_positional = config.get('d_positional', None)
+ if d_positional:
+ max_seq_length = config.get('max_seq_length', 256)
+ self.position_table = self.add_weight(shape=(max_seq_length, d_positional),
+ initializer='random_normal',
+ trainable=True)
+ # the MLP layers
+ self.mlp_arc_h = MLP(n_hidden=config.n_mlp_arc,
+ dropout=config.mlp_dropout, name='mlp_arc_h')
+ self.mlp_arc_d = MLP(n_hidden=config.n_mlp_arc,
+ dropout=config.mlp_dropout, name='mlp_arc_d')
+ self.mlp_rel_h = MLP(n_hidden=config.n_mlp_rel,
+ dropout=config.mlp_dropout, name='mlp_rel_h')
+ self.mlp_rel_d = MLP(n_hidden=config.n_mlp_rel,
+ dropout=config.mlp_dropout, name='mlp_rel_d')
+
+ # the Biaffine layers
+ self.arc_attn = Biaffine(n_in=config.n_mlp_arc,
+ bias_x=True,
+ bias_y=False, name='arc_attn')
+ self.rel_attn = Biaffine(n_in=config.n_mlp_rel,
+ n_out=config.n_rels,
+ bias_x=True,
+ bias_y=True, name='rel_attn')
+ if embed is not None:
+ self.pretrained = embed
+ self.pad_index = tf.constant(config.pad_index, dtype=tf.int64)
+ self.unk_index = tf.constant(config.unk_index, dtype=tf.int64)
+
+ # noinspection PyMethodOverriding
+ def call(self, inputs, mask_inf=True, **kwargs):
+ # batch_size, seq_len = words.shape
+ # get the mask and lengths of given batch
+ # mask = words.ne(self.pad_index)
+ if hasattr(self, 'lstm'):
+ words, feats = inputs
+ mask = tf.not_equal(words, self.pad_index)
+ # set the indices larger than num_embeddings to unk_index
+ # ext_mask = words.ge(self.word_embed.num_embeddings)
+ ext_mask = tf.greater_equal(words, self.word_embed.input_dim)
+ ext_words = tf.where(ext_mask, self.unk_index, words)
+
+ # get outputs from embedding layers
+ word_embed = self.word_embed(ext_words)
+ if hasattr(self, 'pretrained'):
+ word_embed += self.pretrained(words)
+ feat_embed = self.feat_embed(feats)
+ word_embed, feat_embed = self.embed_dropout([word_embed, feat_embed])
+ # concatenate the word and feat representations
+ embed = tf.concat((word_embed, feat_embed), axis=-1)
+
+ x = self.lstm(embed, mask=mask)
+ x = self.lstm_dropout(x)
+ else:
+ words, (input_ids, input_mask, prefix_offset) = inputs
+ mask = tf.not_equal(words, self.pad_index)
+ x = self.run_transformer(input_ids, input_mask, prefix_offset)
+
+ # apply MLPs to the BiLSTM output states
+ arc_h = self.mlp_arc_h(x)
+ arc_d = self.mlp_arc_d(x)
+ rel_h = self.mlp_rel_h(x)
+ rel_d = self.mlp_rel_d(x)
+
+ # get arc and rel scores from the bilinear attention
+ # [batch_size, seq_len, seq_len]
+ s_arc = self.arc_attn(arc_d, arc_h)
+ # [batch_size, seq_len, seq_len, n_rels]
+ s_rel = tf.transpose(self.rel_attn(rel_d, rel_h), [0, 2, 3, 1])
+ # set the scores that exceed the length of each sentence to -inf
+ if mask_inf:
+ s_arc = tf.where(tf.expand_dims(mask, 1), s_arc, float('-inf'))
+
+ return s_arc, s_rel
+
+ def run_transformer(self, input_ids, input_mask, prefix_offset):
+ if isinstance(self.transformer, TFPreTrainedModel):
+ sequence_output = self.transformer([input_ids, input_mask])
+ sequence_output = sequence_output[0]
+ else:
+ sequence_output = self.transformer([input_ids, tf.zeros_like(input_ids)], mask=input_mask)
+ x = tf.gather(sequence_output, prefix_offset, batch_dims=1)
+ if hasattr(self, 'transformer_dropout'):
+ x = self.transformer_dropout(x)
+ if hasattr(self, 'position_table'):
+ batch_size, seq_length = tf.shape(x)[:2]
+ timing_signal = tf.broadcast_to(self.position_table[:seq_length],
+ [batch_size, seq_length, self.position_table.shape[-1]])
+ x = tf.concat([x, timing_signal], axis=-1)
+ return x
+
+ def to_functional(self):
+ words = tf.keras.Input(shape=[None], dtype=tf.int64, name='words')
+ feats = tf.keras.Input(shape=[None], dtype=tf.int64, name='feats')
+ s_arc, s_rel = self.call([words, feats], mask_inf=False)
+ return tf.keras.Model(inputs=[words, feats], outputs=[s_arc, s_rel])
+
+
+class StructuralAttentionModel(tf.keras.Model):
+ def __init__(self, config, transformer: TFPreTrainedModel = None, masked_lm_embed=None, **kwargs):
+ super().__init__(**kwargs)
+ self.transformer = transformer
+ transformer_dropout = config.get('transformer_dropout', None)
+ if transformer_dropout:
+ self.transformer_dropout = SharedDropout(p=config.transformer_dropout, name='transformer_dropout')
+ d_positional = config.get('d_positional', None)
+ if d_positional:
+ max_seq_length = config.get('max_seq_length', 256)
+ self.position_table = self.add_weight(shape=(max_seq_length, d_positional),
+ initializer='random_normal',
+ trainable=True)
+ self.sa = [StructuralAttentionLayer(config, config.num_heads, transformer.config.hidden_size) for _ in
+ range(config.num_decoder_layers)]
+ self.pad_index = tf.constant(config.pad_index, dtype=tf.int64)
+ masked_lm_dropout = config.get('masked_lm_dropout', None)
+ if masked_lm_dropout:
+ self.masked_lm_dropout = tf.keras.layers.Dropout(masked_lm_dropout, name='masked_lm_dropout')
+ self.use_pos = config.get('use_pos', None)
+ if masked_lm_embed is not None:
+ word_dim, vocab_size = tf.shape(masked_lm_embed)
+ self.projection = tf.keras.layers.Dense(word_dim, name='projection')
+ self.dense = tf.keras.layers.Dense(vocab_size, use_bias=False,
+ kernel_initializer=tf.keras.initializers.constant(
+ masked_lm_embed.numpy()),
+ trainable=False, name='masked_lm')
+ else:
+ self.dense = tf.keras.layers.Dense(config.n_pos if self.use_pos else config.n_words, name='masked_lm')
+
+ def call(self, inputs, training=None, mask=None):
+ if self.use_pos:
+ words, (input_ids, input_mask, prefix_offset, pos) = inputs
+ else:
+ words, (input_ids, input_mask, prefix_offset) = inputs
+
+ x = BiaffineModelTF.run_transformer(self, input_ids, input_mask, prefix_offset)
+ # return None, None, self.dense(x)
+ arcs, rels = [], []
+ mask = tf.not_equal(words, self.pad_index)
+ mask = StructuralAttentionModel.create_attention_mask(tf.shape(x), mask)
+ for sa in self.sa:
+ s_arc, s_rel, x = sa(x, mask=mask)
+ arcs.append(s_arc)
+ rels.append(s_rel)
+
+ if len(self.sa) > 1:
+ arc_scores = tf.reduce_mean(tf.stack(arcs), axis=0)
+ rel_scores = tf.reduce_mean(tf.stack(rels), axis=0)
+ else:
+ arc_scores = arcs[0]
+ rel_scores = rels[0]
+ if training or not self.dense.built:
+ if hasattr(self, 'masked_lm_dropout'):
+ x = self.masked_lm_dropout(x)
+ if hasattr(self, 'projection'):
+ x = self.projection(x)
+ ids = self.dense(x)
+ else:
+ ids = self.dense(x)
+ return arc_scores, rel_scores, ids
+ return arc_scores, rel_scores
+
+ @staticmethod
+ def create_attention_mask(from_shape, input_mask):
+ """Creates 3D attention.
+
+ Args:
+ from_shape: batch_size, from_seq_len, ...]
+ input_mask: batch_size, seq_len]
+
+ Returns:
+ batch_size, from_seq_len, seq_len]
+
+ """
+
+ mask = tf.cast(tf.expand_dims(input_mask, axis=1), tf.float32) # [B, 1, T]
+ ones = tf.expand_dims(tf.ones(shape=from_shape[:2], dtype=tf.float32), axis=-1) # [B, F, 1]
+ mask = ones * mask # broadcast along two dimensions
+
+ return mask # [B, F, T]
diff --git a/hanlp/components/parsers/chu_liu_edmonds.py b/hanlp/components/parsers/chu_liu_edmonds.py
new file mode 100644
index 000000000..897eaa316
--- /dev/null
+++ b/hanlp/components/parsers/chu_liu_edmonds.py
@@ -0,0 +1,314 @@
+# Adopted from https://github.com/allenai/allennlp under Apache Licence 2.0.
+# Changed the packaging.
+
+from typing import List, Set, Tuple, Dict
+import numpy
+
+
+def decode_mst(
+ energy: numpy.ndarray, length: int, has_labels: bool = True
+) -> Tuple[numpy.ndarray, numpy.ndarray]:
+ """Note: Counter to typical intuition, this function decodes the _maximum_
+ spanning tree.
+
+ Decode the optimal MST tree with the Chu-Liu-Edmonds algorithm for
+ maximum spanning arborescences on graphs.
+
+ Adopted from https://github.com/allenai/allennlp/blob/master/allennlp/nn/chu_liu_edmonds.py
+ which is licensed under the Apache License 2.0
+
+ # Parameters
+
+ energy : `numpy.ndarray`, required.
+ A tensor with shape (num_labels, timesteps, timesteps)
+ containing the energy of each edge. If has_labels is `False`,
+ the tensor should have shape (timesteps, timesteps) instead.
+ length : `int`, required.
+ The length of this sequence, as the energy may have come
+ from a padded batch.
+ has_labels : `bool`, optional, (default = True)
+ Whether the graph has labels or not.
+
+ Args:
+ energy: numpy.ndarray:
+ length: int:
+ has_labels: bool: (Default value = True)
+
+ Returns:
+
+ """
+ if has_labels and energy.ndim != 3:
+ raise ValueError("The dimension of the energy array is not equal to 3.")
+ elif not has_labels and energy.ndim != 2:
+ raise ValueError("The dimension of the energy array is not equal to 2.")
+ input_shape = energy.shape
+ max_length = input_shape[-1]
+
+ # Our energy matrix might have been batched -
+ # here we clip it to contain only non padded tokens.
+ if has_labels:
+ energy = energy[:, :length, :length]
+ # get best label for each edge.
+ label_id_matrix = energy.argmax(axis=0)
+ energy = energy.max(axis=0)
+ else:
+ energy = energy[:length, :length]
+ label_id_matrix = None
+ # get original score matrix
+ original_score_matrix = energy
+ # initialize score matrix to original score matrix
+ score_matrix = numpy.array(original_score_matrix, copy=True)
+
+ old_input = numpy.zeros([length, length], dtype=numpy.int32)
+ old_output = numpy.zeros([length, length], dtype=numpy.int32)
+ current_nodes = [True for _ in range(length)]
+ representatives: List[Set[int]] = []
+
+ for node1 in range(length):
+ original_score_matrix[node1, node1] = 0.0
+ score_matrix[node1, node1] = 0.0
+ representatives.append({node1})
+
+ for node2 in range(node1 + 1, length):
+ old_input[node1, node2] = node1
+ old_output[node1, node2] = node2
+
+ old_input[node2, node1] = node2
+ old_output[node2, node1] = node1
+
+ final_edges: Dict[int, int] = {}
+
+ # The main algorithm operates inplace.
+ chu_liu_edmonds(
+ length, score_matrix, current_nodes, final_edges, old_input, old_output, representatives
+ )
+
+ heads = numpy.zeros([max_length], numpy.int32)
+ if has_labels:
+ head_type = numpy.ones([max_length], numpy.int32)
+ else:
+ head_type = None
+
+ for child, parent in final_edges.items():
+ heads[child] = parent
+ if has_labels:
+ head_type[child] = label_id_matrix[parent, child]
+
+ return heads, head_type
+
+
+def chu_liu_edmonds(
+ length: int,
+ score_matrix: numpy.ndarray,
+ current_nodes: List[bool],
+ final_edges: Dict[int, int],
+ old_input: numpy.ndarray,
+ old_output: numpy.ndarray,
+ representatives: List[Set[int]],
+):
+ """Applies the chu-liu-edmonds algorithm recursively
+ to a graph with edge weights defined by score_matrix.
+
+ Note that this function operates in place, so variables
+ will be modified.
+
+ # Parameters
+
+ length : `int`, required.
+ The number of nodes.
+ score_matrix : `numpy.ndarray`, required.
+ The score matrix representing the scores for pairs
+ of nodes.
+ current_nodes : `List[bool]`, required.
+ The nodes which are representatives in the graph.
+ A representative at it's most basic represents a node,
+ but as the algorithm progresses, individual nodes will
+ represent collapsed cycles in the graph.
+ final_edges : `Dict[int, int]`, required.
+ An empty dictionary which will be populated with the
+ nodes which are connected in the maximum spanning tree.
+ old_input : `numpy.ndarray`, required.
+ old_output : `numpy.ndarray`, required.
+ representatives : `List[Set[int]]`, required.
+ A list containing the nodes that a particular node
+ is representing at this iteration in the graph.
+
+ # Returns
+
+ Nothing - all variables are modified in place.
+
+ Args:
+ length: int:
+ score_matrix: numpy.ndarray:
+ current_nodes: List[bool]:
+ final_edges: Dict[int:
+ int]:
+ old_input: numpy.ndarray:
+ old_output: numpy.ndarray:
+ representatives: List[Set[int]]:
+
+ Returns:
+
+ """
+ # Set the initial graph to be the greedy best one.
+ parents = [-1]
+ for node1 in range(1, length):
+ parents.append(0)
+ if current_nodes[node1]:
+ max_score = score_matrix[0, node1]
+ for node2 in range(1, length):
+ if node2 == node1 or not current_nodes[node2]:
+ continue
+
+ new_score = score_matrix[node2, node1]
+ if new_score > max_score:
+ max_score = new_score
+ parents[node1] = node2
+
+ # Check if this solution has a cycle.
+ has_cycle, cycle = _find_cycle(parents, length, current_nodes)
+ # If there are no cycles, find all edges and return.
+ if not has_cycle:
+ final_edges[0] = -1
+ for node in range(1, length):
+ if not current_nodes[node]:
+ continue
+
+ parent = old_input[parents[node], node]
+ child = old_output[parents[node], node]
+ final_edges[child] = parent
+ return
+
+ # Otherwise, we have a cycle so we need to remove an edge.
+ # From here until the recursive call is the contraction stage of the algorithm.
+ cycle_weight = 0.0
+ # Find the weight of the cycle.
+ index = 0
+ for node in cycle:
+ index += 1
+ cycle_weight += score_matrix[parents[node], node]
+
+ # For each node in the graph, find the maximum weight incoming
+ # and outgoing edge into the cycle.
+ cycle_representative = cycle[0]
+ for node in range(length):
+ if not current_nodes[node] or node in cycle:
+ continue
+
+ in_edge_weight = float("-inf")
+ in_edge = -1
+ out_edge_weight = float("-inf")
+ out_edge = -1
+
+ for node_in_cycle in cycle:
+ if score_matrix[node_in_cycle, node] > in_edge_weight:
+ in_edge_weight = score_matrix[node_in_cycle, node]
+ in_edge = node_in_cycle
+
+ # Add the new edge score to the cycle weight
+ # and subtract the edge we're considering removing.
+ score = (
+ cycle_weight
+ + score_matrix[node, node_in_cycle]
+ - score_matrix[parents[node_in_cycle], node_in_cycle]
+ )
+
+ if score > out_edge_weight:
+ out_edge_weight = score
+ out_edge = node_in_cycle
+
+ score_matrix[cycle_representative, node] = in_edge_weight
+ old_input[cycle_representative, node] = old_input[in_edge, node]
+ old_output[cycle_representative, node] = old_output[in_edge, node]
+
+ score_matrix[node, cycle_representative] = out_edge_weight
+ old_output[node, cycle_representative] = old_output[node, out_edge]
+ old_input[node, cycle_representative] = old_input[node, out_edge]
+
+ # For the next recursive iteration, we want to consider the cycle as a
+ # single node. Here we collapse the cycle into the first node in the
+ # cycle (first node is arbitrary), set all the other nodes not be
+ # considered in the next iteration. We also keep track of which
+ # representatives we are considering this iteration because we need
+ # them below to check if we're done.
+ considered_representatives: List[Set[int]] = []
+ for i, node_in_cycle in enumerate(cycle):
+ considered_representatives.append(set())
+ if i > 0:
+ # We need to consider at least one
+ # node in the cycle, arbitrarily choose
+ # the first.
+ current_nodes[node_in_cycle] = False
+
+ for node in representatives[node_in_cycle]:
+ considered_representatives[i].add(node)
+ if i > 0:
+ representatives[cycle_representative].add(node)
+
+ chu_liu_edmonds(
+ length, score_matrix, current_nodes, final_edges, old_input, old_output, representatives
+ )
+
+ # Expansion stage.
+ # check each node in cycle, if one of its representatives
+ # is a key in the final_edges, it is the one we need.
+ found = False
+ key_node = -1
+ for i, node in enumerate(cycle):
+ for cycle_rep in considered_representatives[i]:
+ if cycle_rep in final_edges:
+ key_node = node
+ found = True
+ break
+ if found:
+ break
+
+ previous = parents[key_node]
+ while previous != key_node:
+ child = old_output[parents[previous], previous]
+ parent = old_input[parents[previous], previous]
+ final_edges[child] = parent
+ previous = parents[previous]
+
+
+def _find_cycle(
+ parents: List[int], length: int, current_nodes: List[bool]
+) -> Tuple[bool, List[int]]:
+ added = [False for _ in range(length)]
+ added[0] = True
+ cycle = set()
+ has_cycle = False
+ for i in range(1, length):
+ if has_cycle:
+ break
+ # don't redo nodes we've already
+ # visited or aren't considering.
+ if added[i] or not current_nodes[i]:
+ continue
+ # Initialize a new possible cycle.
+ this_cycle = set()
+ this_cycle.add(i)
+ added[i] = True
+ has_cycle = True
+ next_node = i
+ while parents[next_node] not in this_cycle:
+ next_node = parents[next_node]
+ # If we see a node we've already processed,
+ # we can stop, because the node we are
+ # processing would have been in that cycle.
+ if added[next_node]:
+ has_cycle = False
+ break
+ added[next_node] = True
+ this_cycle.add(next_node)
+
+ if has_cycle:
+ original = next_node
+ cycle.add(original)
+ next_node = parents[original]
+ while next_node != original:
+ cycle.add(next_node)
+ next_node = parents[next_node]
+ break
+
+ return has_cycle, list(cycle)
diff --git a/hanlp/components/parsers/conll.py b/hanlp/components/parsers/conll.py
index 44508f038..980fb4fca 100644
--- a/hanlp/components/parsers/conll.py
+++ b/hanlp/components/parsers/conll.py
@@ -1,586 +1,73 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-12-26 15:37
-from abc import abstractmethod
-from collections import Counter
-from typing import Generator, Tuple, Union, Iterable, Any, List
+from typing import Union
-import tensorflow as tf
-import numpy as np
-from hanlp.common.structure import SerializableDict
-from hanlp.common.transform import Transform
-from hanlp.components.parsers.alg import kmeans, randperm, arange, tolist
-from hanlp.common.constant import ROOT
-from hanlp.common.vocab import Vocab
-from hanlp.utils.io_util import get_resource
+from hanlp.utils.io_util import get_resource, TimingFileIterator
from hanlp.utils.log_util import logger
-from hanlp.utils.string_util import ispunct
-from hanlp.utils.util import merge_locals_kwargs
-class CoNLLWord(SerializableDict):
- def __init__(self, id, form, lemma=None, cpos=None, pos=None, feats=None, head=None, deprel=None, phead=None,
- pdeprel=None):
- """CoNLL format template, see http://anthology.aclweb.org/W/W06/W06-2920.pdf
-
- Parameters
- ----------
- id : int
- Token counter, starting at 1 for each new sentence.
- form : str
- Word form or punctuation symbol.
- lemma : str
- Lemma or stem (depending on the particular treebank) of word form, or an underscore if not available.
- cpos : str
- Coarse-grained part-of-speech tag, where the tagset depends on the treebank.
- pos : str
- Fine-grained part-of-speech tag, where the tagset depends on the treebank.
- feats : str
- Unordered set of syntactic and/or morphological features (depending on the particular treebank),
- or an underscore if not available.
- head : Union[int, List[int]]
- Head of the current token, which is either a value of ID,
- or zero (’0’) if the token links to the virtual root node of the sentence.
- deprel : Union[str, List[str]]
- Dependency relation to the HEAD.
- phead : int
- Projective head of current token, which is either a value of ID or zero (’0’),
- or an underscore if not available.
- pdeprel : str
- Dependency relation to the PHEAD, or an underscore if not available.
- """
- self.id = id
- self.form = form
- self.cpos = cpos
- self.pos = pos
- self.head = head
- self.deprel = deprel
- self.lemma = lemma
- self.feats = feats
- self.phead = phead
- self.pdeprel = pdeprel
-
- def __str__(self):
- if isinstance(self.head, list):
- return '\n'.join('\t'.join(['_' if v is None else v for v in values]) for values in [
- [str(self.id), self.form, self.lemma, self.cpos, self.pos, self.feats,
- None if head is None else str(head), deprel, self.phead, self.pdeprel] for head, deprel in
- zip(self.head, self.deprel)
- ])
- values = [str(self.id), self.form, self.lemma, self.cpos, self.pos, self.feats,
- None if self.head is None else str(self.head), self.deprel, self.phead, self.pdeprel]
- return '\t'.join(['_' if v is None else v for v in values])
-
- @property
- def nonempty_fields(self):
- return list(f for f in
- [self.form, self.lemma, self.cpos, self.pos, self.feats, self.head, self.deprel, self.phead,
- self.pdeprel] if f)
-
-
-class CoNLLSentence(list):
- def __init__(self, words=None):
- """A list of ConllWord
-
- Parameters
- ----------
- words : Sequence[ConllWord]
- words of a sentence
- """
- super().__init__()
- if words:
- self.extend(words)
-
- def __str__(self):
- return '\n'.join([word.__str__() for word in self])
-
- @staticmethod
- def from_str(conll: str):
- """
- Build a CoNLLSentence from CoNLL-X format str
-
- Parameters
- ----------
- conll : str
- CoNLL-X format string
-
- Returns
- -------
- CoNLLSentence
-
- """
- words: List[CoNLLWord] = []
- prev_id = None
- for line in conll.strip().split('\n'):
- if line.startswith('#'):
- continue
- cells = line.split()
- cells[0] = int(cells[0])
- cells[6] = int(cells[6])
- if cells[0] != prev_id:
- words.append(CoNLLWord(*cells))
- else:
- if isinstance(words[-1].head, list):
- words[-1].head.append(cells[6])
- words[-1].deprel.append(cells[7])
- else:
- words[-1].head = [words[-1].head] + [cells[6]]
- words[-1].deprel = [words[-1].deprel] + [cells[7]]
- prev_id = cells[0]
- return CoNLLSentence(words)
+def collapse_enhanced_empty_nodes(sent: list):
+ collapsed = []
+ for cells in sent:
+ if isinstance(cells[0], float):
+ id = cells[0]
+ head, deprel = cells[8].split(':', 1)
+ for x in sent:
+ arrows = [s.split(':', 1) for s in x[8].split('|')]
+ arrows = [(head, f'{head}:{deprel}>{r}') if h == str(id) else (h, r) for h, r in arrows]
+ arrows = sorted(arrows)
+ x[8] = '|'.join(f'{h}:{r}' for h, r in arrows)
+ sent[head][7] += f'>{cells[7]}'
+ else:
+ collapsed.append(cells)
+ return collapsed
-def read_conll(filepath):
+def read_conll(filepath: Union[str, TimingFileIterator], underline_to_none=False, enhanced_collapse_empty_nodes=False):
sent = []
- filepath = get_resource(filepath)
- with open(filepath, encoding='utf-8') as src:
- for line in src:
- if line.startswith('#'):
- continue
- cells = line.strip().split()
- if cells:
+ if isinstance(filepath, str):
+ filepath: str = get_resource(filepath)
+ if filepath.endswith('.conllu') and enhanced_collapse_empty_nodes is None:
+ enhanced_collapse_empty_nodes = True
+ src = open(filepath, encoding='utf-8')
+ else:
+ src = filepath
+ for idx, line in enumerate(src):
+ if line.startswith('#'):
+ continue
+ line = line.strip()
+ cells = line.split('\t')
+ if line and cells:
+ if enhanced_collapse_empty_nodes and '.' in cells[0]:
+ cells[0] = float(cells[0])
+ cells[6] = None
+ else:
+ if '-' in cells[0] or '.' in cells[0]:
+ # sent[-1][1] += cells[1]
+ continue
cells[0] = int(cells[0])
- cells[6] = int(cells[6])
+ if cells[6] != '_':
+ try:
+ cells[6] = int(cells[6])
+ except ValueError:
+ cells[6] = 0
+ logger.exception(f'Wrong CoNLL format {filepath}:{idx + 1}\n{line}')
+ if underline_to_none:
for i, x in enumerate(cells):
if x == '_':
cells[i] = None
- sent.append(cells)
- else:
- yield sent
- sent = []
- if sent:
- yield sent
-
-
-class CoNLLTransform(Transform):
-
- def __init__(self, config: SerializableDict = None, map_x=True, map_y=True, lower=True, n_buckets=32,
- n_tokens_per_batch=5000, min_freq=2,
- **kwargs) -> None:
- super().__init__(**merge_locals_kwargs(locals(), kwargs))
- self.form_vocab: Vocab = None
- self.cpos_vocab: Vocab = None
- self.rel_vocab: Vocab = None
- self.puncts: tf.Tensor = None
-
- def x_to_idx(self, x) -> Union[tf.Tensor, Tuple]:
- form, cpos = x
- return self.form_vocab.token_to_idx_table.lookup(form), self.cpos_vocab.token_to_idx_table.lookup(cpos)
-
- def y_to_idx(self, y):
- head, rel = y
- return head, self.rel_vocab.token_to_idx_table.lookup(rel)
-
- def X_to_inputs(self, X: Union[tf.Tensor, Tuple[tf.Tensor]]) -> Iterable:
- if len(X) == 2:
- form_batch, cposes_batch = X
- mask = tf.not_equal(form_batch, 0)
- elif len(X) == 3:
- form_batch, cposes_batch, mask = X
+ sent.append(cells)
else:
- raise ValueError(f'Expect X to be 2 or 3 elements but got {repr(X)}')
- sents = []
-
- for form_sent, cposes_sent, length in zip(form_batch, cposes_batch,
- tf.math.count_nonzero(mask, axis=-1)):
- forms = tolist(form_sent)[1:length + 1]
- cposes = tolist(cposes_sent)[1:length + 1]
- sents.append([(self.form_vocab.idx_to_token[f],
- self.cpos_vocab.idx_to_token[c]) for f, c in zip(forms, cposes)])
-
- return sents
-
- def lock_vocabs(self):
- super().lock_vocabs()
- self.puncts = tf.constant([i for s, i in self.form_vocab.token_to_idx.items()
- if ispunct(s)], dtype=tf.int64)
-
- def file_to_inputs(self, filepath: str, gold=True):
- assert gold, 'only support gold file for now'
- for sent in read_conll(filepath):
- for i, cell in enumerate(sent):
- form = cell[1]
- cpos = cell[3]
- head = cell[6]
- deprel = cell[7]
- sent[i] = [form, cpos, head, deprel]
+ if enhanced_collapse_empty_nodes:
+ sent = collapse_enhanced_empty_nodes(sent)
yield sent
-
- @property
- def bos(self):
- if self.form_vocab.idx_to_token is None:
- return ROOT
- return self.form_vocab.idx_to_token[2]
-
- def file_to_dataset(self, filepath: str, gold=True, map_x=None, map_y=None, batch_size=5000, shuffle=None,
- repeat=None, drop_remainder=False, prefetch=1, cache=True, **kwargs) -> tf.data.Dataset:
- return super().file_to_dataset(filepath, gold, map_x, map_y, batch_size, shuffle, repeat, drop_remainder,
- prefetch, cache, **kwargs)
-
- def input_is_single_sample(self, input: Any) -> bool:
- return isinstance(input[0][0], str) if len(input[0]) else False
-
- def samples_to_dataset(self, samples: Generator, map_x=None, map_y=None, batch_size=5000, shuffle=None, repeat=None,
- drop_remainder=False, prefetch=1, cache=True) -> tf.data.Dataset:
- if shuffle:
- def generator():
- # custom bucketing, load corpus into memory
- corpus = list(x for x in (samples() if callable(samples) else samples))
- lengths = [self.len_of_sent(i) for i in corpus]
- if len(corpus) < 32:
- n_buckets = 1
- else:
- n_buckets = min(self.config.n_buckets, len(corpus))
- buckets = dict(zip(*kmeans(lengths, n_buckets)))
- sizes, buckets = zip(*[
- (size, bucket) for size, bucket in buckets.items()
- ])
- # the number of chunks in each bucket, which is clipped by
- # range [1, len(bucket)]
- chunks = [min(len(bucket), max(round(size * len(bucket) / batch_size), 1)) for size, bucket in
- zip(sizes, buckets)]
- range_fn = randperm if shuffle else arange
- max_samples_per_batch = self.config.get('max_samples_per_batch', None)
- for i in tolist(range_fn(len(buckets))):
- split_sizes = [(len(buckets[i]) - j - 1) // chunks[i] + 1
- for j in range(chunks[i])] # how many sentences in each batch
- for batch_indices in tf.split(range_fn(len(buckets[i])), split_sizes):
- indices = [buckets[i][j] for j in tolist(batch_indices)]
- if max_samples_per_batch:
- for j in range(0, len(indices), max_samples_per_batch):
- yield from self.batched_inputs_to_batches(corpus, indices[j:j + max_samples_per_batch],
- shuffle)
- else:
- yield from self.batched_inputs_to_batches(corpus, indices, shuffle)
-
- else:
- def generator():
- # custom bucketing, load corpus into memory
- corpus = list(x for x in (samples() if callable(samples) else samples))
- n_tokens = 0
- batch = []
- for idx, sent in enumerate(corpus):
- sent_len = self.len_of_sent(sent)
- if n_tokens + sent_len > batch_size and batch:
- yield from self.batched_inputs_to_batches(corpus, batch, shuffle)
- n_tokens = 0
- batch = []
- n_tokens += sent_len
- batch.append(idx)
- if batch:
- yield from self.batched_inputs_to_batches(corpus, batch, shuffle)
-
- # next(generator())
- return Transform.samples_to_dataset(self, generator, False, False, 0, False, repeat, drop_remainder, prefetch,
- cache)
-
- def len_of_sent(self, sent):
- return 1 + len(sent) # take ROOT into account
-
- @abstractmethod
- def batched_inputs_to_batches(self, corpus, indices, shuffle):
- """
- Convert batched inputs to batches of samples
-
- Parameters
- ----------
- corpus : list
- A list of inputs
- indices : list
- A list of indices, each list belongs to a batch
-
- Returns
- -------
- None
-
- Yields
- -------
- tuple
- tuple of tf.Tensor
- """
- pass
-
-
-class CoNLL_DEP_Transform(CoNLLTransform):
-
- def __init__(self, config: SerializableDict = None, map_x=True, map_y=True, lower=True, n_buckets=32,
- n_tokens_per_batch=5000, min_freq=2, **kwargs) -> None:
- super().__init__(config, map_x, map_y, lower, n_buckets, n_tokens_per_batch, min_freq, **kwargs)
-
- def create_types_shapes_values(self) -> Tuple[Tuple, Tuple, Tuple]:
- types = (tf.int64, tf.int64), (tf.int64, tf.int64)
- shapes = ([None, None], [None, None]), ([None, None], [None, None])
- values = (self.form_vocab.safe_pad_token_idx, self.cpos_vocab.safe_pad_token_idx), (
- 0, self.rel_vocab.safe_pad_token_idx)
- return types, shapes, values
-
- def batched_inputs_to_batches(self, corpus, indices, shuffle):
- """
- Convert batched inputs to batches of samples
-
- Parameters
- ----------
- corpus : list
- A list of inputs
- indices : list
- A list of indices, each list belongs to a batch
-
- Returns
- -------
- None
-
- Yields
- -------
- tuple
- tuple of tf.Tensor
- """
- raw_batch = [[], [], [], []]
- for idx in indices:
- for b in raw_batch:
- b.append([])
- for cells in corpus[idx]:
- for b, c, v in zip(raw_batch, cells,
- [self.form_vocab, self.cpos_vocab, None, self.rel_vocab]):
- b[-1].append(v.get_idx_without_add(c) if v else c)
- batch = []
- for b, v in zip(raw_batch, [self.form_vocab, self.cpos_vocab, None, self.rel_vocab]):
- b = tf.keras.preprocessing.sequence.pad_sequences(b, padding='post',
- value=v.safe_pad_token_idx if v else 0,
- dtype='int64')
- batch.append(b)
- assert len(batch) == 4
- yield (batch[0], batch[1]), (batch[2], batch[3])
-
- def inputs_to_samples(self, inputs, gold=False):
- for sent in inputs:
- sample = []
- if self.config['lower']:
- for i, cell in enumerate(sent):
- cell = list(sent[i])
- cell[0] = cell[0].lower()
- if not gold:
- cell += [0, self.rel_vocab.safe_pad_token]
- sample.append(cell)
- # insert root word with arbitrary fields, anyway it will be masked
- # form, cpos, head, deprel = sample[0]
- sample.insert(0, [self.bos, self.bos, 0, self.bos])
- yield sample
-
- def XY_to_inputs_outputs(self, X: Union[tf.Tensor, Tuple[tf.Tensor]], Y: Union[tf.Tensor, Tuple[tf.Tensor]],
- gold=False, inputs=None, conll=True) -> Iterable:
- (words, feats, mask), (arc_preds, rel_preds) = X, Y
- if inputs is None:
- inputs = self.X_to_inputs(X)
- ys = self.Y_to_outputs((arc_preds, rel_preds, mask), inputs=inputs)
- sents = []
- for x, y in zip(inputs, ys):
- sent = CoNLLSentence()
- for idx, ((form, cpos), (head, deprel)) in enumerate(zip(x, y)):
- if conll:
- sent.append(CoNLLWord(id=idx + 1, form=form, cpos=cpos, head=head, deprel=deprel))
- else:
- sent.append([head, deprel])
- sents.append(sent)
- return sents
-
- def fit(self, trn_path: str, **kwargs) -> int:
- self.form_vocab = Vocab()
- self.form_vocab.add(ROOT) # make root the 2ed elements while 0th is pad, 1st is unk
- self.cpos_vocab = Vocab(pad_token=None, unk_token=None)
- self.rel_vocab = Vocab(pad_token=None, unk_token=None)
- num_samples = 0
- counter = Counter()
- for sent in self.file_to_samples(trn_path, gold=True):
- num_samples += 1
- for idx, (form, cpos, head, deprel) in enumerate(sent):
- if idx == 0:
- root = form
- else:
- counter[form] += 1
- self.cpos_vocab.add(cpos)
- self.rel_vocab.add(deprel)
-
- for token in [token for token, freq in counter.items() if freq >= self.config.min_freq]:
- self.form_vocab.add(token)
- return num_samples
-
- def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, inputs=None, X=None, batch=None) -> Iterable:
- arc_preds, rel_preds, mask = Y
- sents = []
-
- for arc_sent, rel_sent, length in zip(arc_preds, rel_preds,
- tf.math.count_nonzero(mask, axis=-1)):
- arcs = tolist(arc_sent)[1:length + 1]
- rels = tolist(rel_sent)[1:length + 1]
- sents.append([(a, self.rel_vocab.idx_to_token[r]) for a, r in zip(arcs, rels)])
-
- return sents
-
-
-class CoNLL_SDP_Transform(CoNLLTransform):
- def __init__(self, config: SerializableDict = None, map_x=True, map_y=True, lower=True, n_buckets=32,
- n_tokens_per_batch=5000, min_freq=2, **kwargs) -> None:
- super().__init__(config, map_x, map_y, lower, n_buckets, n_tokens_per_batch, min_freq, **kwargs)
- self.orphan_relation = ROOT
-
- def lock_vocabs(self):
- super().lock_vocabs()
- # heuristic to find the orphan relation
- for rel in self.rel_vocab.idx_to_token:
- if 'root' in rel.lower():
- self.orphan_relation = rel
- break
-
- def file_to_inputs(self, filepath: str, gold=True):
- assert gold, 'only support gold file for now'
- for i, sent in enumerate(read_conll(filepath)):
- prev_cells = None
- parsed_sent = []
- heads = []
- rels = []
- for j, cell in enumerate(sent):
- ID = cell[0]
- form = cell[1]
- cpos = cell[3]
- head = cell[6]
- deprel = cell[7]
- if prev_cells and ID != prev_cells[0]: # found end of token
- parsed_sent.append([prev_cells[1], prev_cells[2], heads, rels])
- heads = []
- rels = []
- heads.append(head)
- rels.append(deprel)
- prev_cells = [ID, form, cpos, head, deprel]
- parsed_sent.append([prev_cells[1], prev_cells[2], heads, rels])
- yield parsed_sent
-
- def fit(self, trn_path: str, **kwargs) -> int:
- self.form_vocab = Vocab()
- self.form_vocab.add(ROOT) # make root the 2ed elements while 0th is pad, 1st is unk
- self.cpos_vocab = Vocab(pad_token=None, unk_token=None)
- self.rel_vocab = Vocab(pad_token=None, unk_token=None)
- num_samples = 0
- counter = Counter()
- for sent in self.file_to_samples(trn_path, gold=True):
- num_samples += 1
- for idx, (form, cpos, head, deprel) in enumerate(sent):
- if idx == 0:
- root = form
- else:
- counter[form] += 1
- self.cpos_vocab.add(cpos)
- self.rel_vocab.update(deprel)
-
- for token in [token for token, freq in counter.items() if freq >= self.config.min_freq]:
- self.form_vocab.add(token)
- return num_samples
-
- def inputs_to_samples(self, inputs, gold=False):
- for sent in inputs:
- sample = []
- if self.config['lower']:
- for i, cell in enumerate(sent):
- cell = list(sent[i])
- cell[0] = cell[0].lower()
- if not gold:
- cell += [[0], [self.rel_vocab.safe_pad_token]]
- sample.append(cell)
- # insert root word with arbitrary fields, anyway it will be masked
- form, cpos, head, deprel = sample[0]
- sample.insert(0, [self.bos, self.bos, [0], deprel])
- yield sample
-
- def create_types_shapes_values(self) -> Tuple[Tuple, Tuple, Tuple]:
- types = (tf.int64, tf.int64), (tf.bool, tf.int64)
- shapes = ([None, None], [None, None]), ([None, None, None], [None, None, None])
- values = (self.form_vocab.safe_pad_token_idx, self.cpos_vocab.safe_pad_token_idx), (
- False, self.rel_vocab.safe_pad_token_idx)
- return types, shapes, values
-
- def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, inputs=None, X=None, batch=None) -> Iterable:
- arc_preds, rel_preds, mask = Y
- sents = []
-
- for arc_sent, rel_sent, length in zip(arc_preds, rel_preds,
- tf.math.count_nonzero(mask, axis=-1)):
sent = []
- for arc, rel in zip(tolist(arc_sent[1:, 1:]), tolist(rel_sent[1:, 1:])):
- ar = []
- for idx, (a, r) in enumerate(zip(arc, rel)):
- if a:
- ar.append((idx + 1, self.rel_vocab.idx_to_token[r]))
- if not ar:
- # orphan
- ar.append((0, self.orphan_relation))
- sent.append(ar)
- sents.append(sent)
-
- return sents
- def XY_to_inputs_outputs(self, X: Union[tf.Tensor, Tuple[tf.Tensor]], Y: Union[tf.Tensor, Tuple[tf.Tensor]],
- gold=False, inputs=None, conll=True) -> Iterable:
- (words, feats, mask), (arc_preds, rel_preds) = X, Y
- xs = inputs
- ys = self.Y_to_outputs((arc_preds, rel_preds, mask))
- sents = []
- for x, y in zip(xs, ys):
- sent = CoNLLSentence()
- for idx, ((form, cpos), pred) in enumerate(zip(x, y)):
- head = [p[0] for p in pred]
- deprel = [p[1] for p in pred]
- if conll:
- sent.append(CoNLLWord(id=idx + 1, form=form, cpos=cpos, head=head, deprel=deprel))
- else:
- sent.append([head, deprel])
- sents.append(sent)
- return sents
-
- def batched_inputs_to_batches(self, corpus, indices, shuffle=False):
- """
- Convert batched inputs to batches of samples
-
- Parameters
- ----------
- corpus : list
- A list of inputs
- indices : list
- A list of indices, each list belongs to a batch
+ if sent:
+ if enhanced_collapse_empty_nodes:
+ sent = collapse_enhanced_empty_nodes(sent)
+ yield sent
- Returns
- -------
- None
+ src.close()
- Yields
- -------
- tuple
- tuple of tf.Tensor
- """
- raw_batch = [[], [], [], []]
- max_len = len(max([corpus[i] for i in indices], key=len))
- for idx in indices:
- arc = np.zeros((max_len, max_len), dtype=np.bool)
- rel = np.zeros((max_len, max_len), dtype=np.int64)
- for b in raw_batch[:2]:
- b.append([])
- for m, cells in enumerate(corpus[idx]):
- for b, c, v in zip(raw_batch, cells,
- [self.form_vocab, self.cpos_vocab]):
- b[-1].append(v.get_idx_without_add(c))
- for n, r in zip(cells[2], cells[3]):
- arc[m, n] = True
- rid = self.rel_vocab.get_idx_without_add(r)
- if rid is None:
- logger.warning(f'Relation OOV: {r} not exists in train')
- continue
- rel[m, n] = rid
- raw_batch[-2].append(arc)
- raw_batch[-1].append(rel)
- batch = []
- for b, v in zip(raw_batch, [self.form_vocab, self.cpos_vocab]):
- b = tf.keras.preprocessing.sequence.pad_sequences(b, padding='post',
- value=v.safe_pad_token_idx,
- dtype='int64')
- batch.append(b)
- batch += raw_batch[2:]
- assert len(batch) == 4
- yield (batch[0], batch[1]), (batch[2], batch[3])
diff --git a/hanlp/components/parsers/constituency/__init__.py b/hanlp/components/parsers/constituency/__init__.py
new file mode 100644
index 000000000..9f5f607f3
--- /dev/null
+++ b/hanlp/components/parsers/constituency/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-11-28 19:26
diff --git a/hanlp/components/parsers/constituency/constituency_dataset.py b/hanlp/components/parsers/constituency/constituency_dataset.py
new file mode 100644
index 000000000..1f93f7c1c
--- /dev/null
+++ b/hanlp/components/parsers/constituency/constituency_dataset.py
@@ -0,0 +1,209 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-11-28 19:27
+from typing import List
+
+from phrasetree.tree import Tree
+
+from hanlp_common.constant import EOS, BOS
+from hanlp.common.dataset import TransformableDataset
+
+
+class ConstituencyDataset(TransformableDataset):
+ def load_file(self, filepath: str):
+ with open(filepath) as src:
+ for line in src:
+ line = line.strip()
+ if not line:
+ continue
+ yield {'constituency': Tree.fromstring(line)}
+
+
+def unpack_tree_to_features(sample: dict):
+ tree = sample.get('constituency', None)
+ if tree:
+ words, tags = zip(*tree.pos())
+ chart = [[None] * (len(words) + 1) for _ in range(len(words) + 1)]
+ for i, j, label in factorize(binarize(tree)[0]):
+ # if no_subcategory:
+ # label = label.split('-')[0]
+ chart[i][j] = label
+ sample['token'] = [BOS] + list(words) + [EOS]
+ sample['chart'] = chart
+ return sample
+
+
+def append_bos_eos(sample: dict):
+ if '_con_token' not in sample:
+ sample['_con_token'] = sample['token']
+ sample['token'] = [BOS] + sample['token'] + [EOS]
+ return sample
+
+
+def remove_subcategory(sample: dict):
+ tree: Tree = sample.get('constituency', None)
+ if tree:
+ for subtree in tree.subtrees():
+ label = subtree.label()
+ subtree.set_label(label.split('-')[0])
+ return sample
+
+
+def binarize(tree: Tree):
+ r"""
+ Conducts binarization over the tree.
+
+ First, the tree is transformed to satisfy `Chomsky Normal Form (CNF)`_.
+ Here we call :meth:`~tree.Tree.chomsky_normal_form` to conduct left-binarization.
+ Second, all unary productions in the tree are collapsed.
+
+ Args:
+ tree (tree.Tree):
+ The tree to be binarized.
+
+ Returns:
+ The binarized tree.
+
+ Examples:
+ >>> tree = Tree.fromstring('''
+ (TOP
+ (S
+ (NP (_ She))
+ (VP (_ enjoys) (S (VP (_ playing) (NP (_ tennis)))))
+ (_ .)))
+ ''')
+ >>> print(Tree.binarize(tree))
+ (TOP
+ (S
+ (S|<>
+ (NP (_ She))
+ (VP
+ (VP|<> (_ enjoys))
+ (S+VP (VP|<> (_ playing)) (NP (_ tennis)))))
+ (S|<> (_ .))))
+
+ .. _Chomsky Normal Form (CNF):
+ https://en.wikipedia.org/wiki/Chomsky_normal_form
+ """
+
+ tree: Tree = tree.copy(True)
+ nodes = [tree]
+ while nodes:
+ node = nodes.pop()
+ if isinstance(node, Tree):
+ nodes.extend([child for child in node])
+ if len(node) > 1:
+ for i, child in enumerate(node):
+ if not isinstance(child[0], Tree):
+ node[i] = Tree(f"{node.label()}|<>", [child])
+ tree.chomsky_normal_form('left', 0, 0)
+ tree.collapse_unary()
+
+ return tree
+
+
+def factorize(tree, delete_labels=None, equal_labels=None):
+ r"""
+ Factorizes the tree into a sequence.
+ The tree is traversed in pre-order.
+
+ Args:
+ tree (tree.Tree):
+ The tree to be factorized.
+ delete_labels (set[str]):
+ A set of labels to be ignored. This is used for evaluation.
+ If it is a pre-terminal label, delete the word along with the brackets.
+ If it is a non-terminal label, just delete the brackets (don't delete childrens).
+ In `EVALB`_, the default set is:
+ {'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''}
+ Default: ``None``.
+ equal_labels (dict[str, str]):
+ The key-val pairs in the dict are considered equivalent (non-directional). This is used for evaluation.
+ The default dict defined in `EVALB`_ is: {'ADVP': 'PRT'}
+ Default: ``None``.
+
+ Returns:
+ The sequence of the factorized tree.
+
+ Examples:
+ >>> tree = Tree.fromstring('' (TOP
+ (S
+ (NP (_ She))
+ (VP (_ enjoys) (S (VP (_ playing) (NP (_ tennis)))))
+ (_ .)))
+ '')
+ >>> Tree.factorize(tree)
+ [(0, 5, 'TOP'), (0, 5, 'S'), (0, 1, 'NP'), (1, 4, 'VP'), (2, 4, 'S'), (2, 4, 'VP'), (3, 4, 'NP')]
+ >>> Tree.factorize(tree, delete_labels={'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''})
+ [(0, 5, 'S'), (0, 1, 'NP'), (1, 4, 'VP'), (2, 4, 'S'), (2, 4, 'VP'), (3, 4, 'NP')]
+
+ .. _EVALB:
+ https://nlp.cs.nyu.edu/evalb/
+ """
+
+ def track(tree, i):
+ label = tree.label()
+ if delete_labels is not None and label in delete_labels:
+ label = None
+ if equal_labels is not None:
+ label = equal_labels.get(label, label)
+ if len(tree) == 1 and not isinstance(tree[0], Tree):
+ return (i + 1 if label is not None else i), []
+ j, spans = i, []
+ for child in tree:
+ if isinstance(child, Tree):
+ j, s = track(child, j)
+ spans += s
+ if label is not None and j > i:
+ spans = [(i, j, label)] + spans
+ return j, spans
+
+ return track(tree, 0)[1]
+
+
+def build_tree(tokens: List[str], sequence):
+ r"""
+ Builds a constituency tree from the sequence. The sequence is generated in pre-order.
+ During building the tree, the sequence is de-binarized to the original format (i.e.,
+ the suffixes ``|<>`` are ignored, the collapsed labels are recovered).
+
+ Args:
+ tokens :
+ All tokens in a sentence.
+ sequence (list[tuple]):
+ A list of tuples used for generating a tree.
+ Each tuple consits of the indices of left/right span boundaries and label of the span.
+
+ Returns:
+ A result constituency tree.
+
+ Examples:
+ >>> tree = Tree.totree(['She', 'enjoys', 'playing', 'tennis', '.'], 'TOP')
+ >>> sequence = [(0, 5, 'S'), (0, 4, 'S|<>'), (0, 1, 'NP'), (1, 4, 'VP'), (1, 2, 'VP|<>'),
+ (2, 4, 'S+VP'), (2, 3, 'VP|<>'), (3, 4, 'NP'), (4, 5, 'S|<>')]
+ >>> print(Tree.build_tree(root, sequence))
+ (TOP
+ (S
+ (NP (_ She))
+ (VP (_ enjoys) (S (VP (_ playing) (NP (_ tennis)))))
+ (_ .)))
+ """
+ tree = Tree('TOP', [Tree('_', [t]) for t in tokens])
+ root = tree.label()
+ leaves = [subtree for subtree in tree.subtrees() if not isinstance(subtree[0], Tree)]
+
+ def track(node):
+ i, j, label = next(node)
+ if j == i + 1:
+ children = [leaves[i]]
+ else:
+ children = track(node) + track(node)
+ if label.endswith('|<>'):
+ return children
+ labels = label.split('+')
+ tree = Tree(labels[-1], children)
+ for label in reversed(labels[:-1]):
+ tree = Tree(label, [tree])
+ return [tree]
+
+ return Tree(root, track(iter(sequence)))
diff --git a/hanlp/components/parsers/constituency/crf_constituency_model.py b/hanlp/components/parsers/constituency/crf_constituency_model.py
new file mode 100644
index 000000000..0d2392c71
--- /dev/null
+++ b/hanlp/components/parsers/constituency/crf_constituency_model.py
@@ -0,0 +1,214 @@
+# -*- coding:utf-8 -*-
+# Adopted from https://github.com/yzhangcs/parser
+# MIT License
+#
+# Copyright (c) 2020 Yu Zhang
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+import torch
+from torch import nn
+from hanlp.components.parsers.constituency.treecrf import CRFConstituency
+from hanlp.components.parsers.alg import cky
+from hanlp.components.parsers.biaffine.biaffine import Biaffine
+from hanlp.components.parsers.biaffine.mlp import MLP
+
+
+class CRFConstituencyDecoder(nn.Module):
+ r"""
+ The implementation of CRF Constituency Parser,
+ also called FANCY (abbr. of Fast and Accurate Neural Crf constituencY) Parser.
+
+ References:
+ - Yu Zhang, Houquan Zhou and Zhenghua Li. 2020.
+ `Fast and Accurate Neural CRF Constituency Parsing`_.
+
+ Args:
+ n_words (int):
+ The size of the word vocabulary.
+ n_feats (int):
+ The size of the feat vocabulary.
+ n_labels (int):
+ The number of labels.
+ feat (str):
+ Specifies which type of additional feature to use: ``'char'`` | ``'bert'`` | ``'tag'``.
+ ``'char'``: Character-level representations extracted by CharLSTM.
+ ``'bert'``: BERT representations, other pretrained langugae models like XLNet are also feasible.
+ ``'tag'``: POS tag embeddings.
+ Default: 'char'.
+ n_embed (int):
+ The size of word embeddings. Default: 100.
+ n_feat_embed (int):
+ The size of feature representations. Default: 100.
+ n_char_embed (int):
+ The size of character embeddings serving as inputs of CharLSTM, required if ``feat='char'``. Default: 50.
+ bert (str):
+ Specifies which kind of language model to use, e.g., ``'bert-base-cased'`` and ``'xlnet-base-cased'``.
+ This is required if ``feat='bert'``. The full list can be found in `transformers`.
+ Default: ``None``.
+ n_bert_layers (int):
+ Specifies how many last layers to use. Required if ``feat='bert'``.
+ The final outputs would be the weight sum of the hidden states of these layers.
+ Default: 4.
+ mix_dropout (float):
+ The dropout ratio of BERT layers. Required if ``feat='bert'``. Default: .0.
+ embed_dropout (float):
+ The dropout ratio of input embeddings. Default: .33.
+ n_hidden (int):
+ The size of LSTM hidden states. Default: 400.
+ n_lstm_layers (int):
+ The number of LSTM layers. Default: 3.
+ lstm_dropout (float):
+ The dropout ratio of LSTM. Default: .33.
+ n_mlp_span (int):
+ Span MLP size. Default: 500.
+ n_mlp_label (int):
+ Label MLP size. Default: 100.
+ mlp_dropout (float):
+ The dropout ratio of MLP layers. Default: .33.
+ feat_pad_index (int):
+ The index of the padding token in the feat vocabulary. Default: 0.
+ pad_index (int):
+ The index of the padding token in the word vocabulary. Default: 0.
+ unk_index (int):
+ The index of the unknown token in the word vocabulary. Default: 1.
+
+ .. _Fast and Accurate Neural CRF Constituency Parsing:
+ https://www.ijcai.org/Proceedings/2020/560/
+ .. _transformers:
+ https://github.com/huggingface/transformers
+ """
+
+ def __init__(self,
+ n_labels,
+ n_hidden=400,
+ n_mlp_span=500,
+ n_mlp_label=100,
+ mlp_dropout=.33,
+ **kwargs
+ ):
+ super().__init__()
+
+ # the MLP layers
+ self.mlp_span_l = MLP(n_in=n_hidden, n_out=n_mlp_span, dropout=mlp_dropout)
+ self.mlp_span_r = MLP(n_in=n_hidden, n_out=n_mlp_span, dropout=mlp_dropout)
+ self.mlp_label_l = MLP(n_in=n_hidden, n_out=n_mlp_label, dropout=mlp_dropout)
+ self.mlp_label_r = MLP(n_in=n_hidden, n_out=n_mlp_label, dropout=mlp_dropout)
+
+ # the Biaffine layers
+ self.span_attn = Biaffine(n_in=n_mlp_span, bias_x=True, bias_y=False)
+ self.label_attn = Biaffine(n_in=n_mlp_label, n_out=n_labels, bias_x=True, bias_y=True)
+ self.crf = CRFConstituency()
+ self.criterion = nn.CrossEntropyLoss()
+
+ def forward(self, x, **kwargs):
+ r"""
+ Args:
+ x (~torch.FloatTensor): ``[batch_size, seq_len, hidden_dim]``.
+ Hidden states from encoder.
+
+ Returns:
+ ~torch.Tensor, ~torch.Tensor:
+ The first tensor of shape ``[batch_size, seq_len, seq_len]`` holds scores of all possible spans.
+ The second of shape ``[batch_size, seq_len, seq_len, n_labels]`` holds
+ scores of all possible labels on each span.
+ """
+
+ x_f, x_b = x.chunk(2, -1)
+ x = torch.cat((x_f[:, :-1], x_b[:, 1:]), -1)
+ # apply MLPs to the BiLSTM output states
+ span_l = self.mlp_span_l(x)
+ span_r = self.mlp_span_r(x)
+ label_l = self.mlp_label_l(x)
+ label_r = self.mlp_label_r(x)
+
+ # [batch_size, seq_len, seq_len]
+ s_span = self.span_attn(span_l, span_r)
+ # [batch_size, seq_len, seq_len, n_labels]
+ s_label = self.label_attn(label_l, label_r).permute(0, 2, 3, 1)
+
+ return s_span, s_label
+
+ def loss(self, s_span, s_label, charts, mask, mbr=True):
+ r"""
+ Args:
+ s_span (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
+ Scores of all spans
+ s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
+ Scores of all labels on each span.
+ charts (~torch.LongTensor): ``[batch_size, seq_len, seq_len]``.
+ The tensor of gold-standard labels, in which positions without labels are filled with -1.
+ mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
+ The mask for covering the unpadded tokens in each chart.
+ mbr (bool):
+ If ``True``, returns marginals for MBR decoding. Default: ``True``.
+
+ Returns:
+ ~torch.Tensor, ~torch.Tensor:
+ The training loss and
+ original span scores of shape ``[batch_size, seq_len, seq_len]`` if ``mbr=False``, or marginals otherwise.
+ """
+
+ span_mask = charts.ge(0) & mask
+ span_loss, span_probs = self.crf(s_span, mask, span_mask, mbr)
+ label_loss = self.criterion(s_label[span_mask], charts[span_mask])
+ loss = span_loss + label_loss
+
+ return loss, span_probs
+
+ def decode(self, s_span, s_label, mask):
+ r"""
+ Args:
+ s_span (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
+ Scores of all spans.
+ s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
+ Scores of all labels on each span.
+ mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
+ The mask for covering the unpadded tokens in each chart.
+
+ Returns:
+ list[list[tuple]]:
+ Sequences of factorized labeled trees traversed in pre-order.
+ """
+
+ span_preds = cky(s_span, mask)
+ label_preds = s_label.argmax(-1).tolist()
+ return [[(i, j, labels[i][j]) for i, j in spans] for spans, labels in zip(span_preds, label_preds)]
+
+
+class CRFConstituencyModel(nn.Module):
+
+ def __init__(self, encoder, decoder: CRFConstituencyDecoder) -> None:
+ super().__init__()
+ self.encoder = encoder
+ self.decoder = decoder
+
+ def forward(self, batch):
+ r"""
+ Args:
+ batch (~dict):
+ Batch of input data.
+
+ Returns:
+ ~torch.Tensor, ~torch.Tensor:
+ The first tensor of shape ``[batch_size, seq_len, seq_len]`` holds scores of all possible spans.
+ The second of shape ``[batch_size, seq_len, seq_len, n_labels]`` holds
+ scores of all possible labels on each span.
+ """
+ x = self.encoder(batch)
+ return self.decoder(x)
diff --git a/hanlp/components/parsers/constituency/crf_constituency_parser.py b/hanlp/components/parsers/constituency/crf_constituency_parser.py
new file mode 100644
index 000000000..3104d7cc7
--- /dev/null
+++ b/hanlp/components/parsers/constituency/crf_constituency_parser.py
@@ -0,0 +1,317 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-11-28 21:24
+import logging
+from typing import Union, List
+
+import torch
+from phrasetree.tree import Tree
+from torch.utils.data import DataLoader
+
+from hanlp_common.constant import BOS, EOS, IDX
+from hanlp.common.dataset import TransformableDataset, SamplerBuilder, PadSequenceDataLoader
+from hanlp.common.structure import History
+from hanlp.common.torch_component import TorchComponent
+from hanlp.common.transform import FieldLength, TransformList
+from hanlp.common.vocab import VocabWithNone
+from hanlp.components.classifiers.transformer_classifier import TransformerComponent
+from hanlp.components.parsers.constituency.constituency_dataset import ConstituencyDataset, unpack_tree_to_features, \
+ build_tree, factorize, remove_subcategory
+from hanlp.components.parsers.constituency.crf_constituency_model import CRFConstituencyDecoder, CRFConstituencyModel
+from hanlp.metrics.parsing.span import SpanMetric
+from hanlp.utils.time_util import CountdownTimer
+from hanlp.utils.torch_util import clip_grad_norm
+from hanlp_common.util import merge_locals_kwargs, merge_dict, reorder
+
+
+class CRFConstituencyParser(TorchComponent):
+ def __init__(self, **kwargs) -> None:
+ """Two-stage CRF Parsing (:cite:`ijcai2020-560`).
+
+ Args:
+ **kwargs: Predefined config.
+ """
+ super().__init__(**kwargs)
+ self.model: CRFConstituencyModel = self.model
+
+ def build_optimizer(self, trn, **kwargs):
+ # noinspection PyCallByClass,PyTypeChecker
+ return TransformerComponent.build_optimizer(self, trn, **kwargs)
+
+ def build_criterion(self, decoder=None, **kwargs):
+ return decoder
+
+ def build_metric(self, **kwargs):
+ return SpanMetric()
+
+ def execute_training_loop(self, trn: DataLoader, dev: DataLoader, epochs, criterion, optimizer, metric, save_dir,
+ logger: logging.Logger, devices, ratio_width=None, patience=0.5, eval_trn=True, **kwargs):
+ if isinstance(patience, float):
+ patience = int(patience * epochs)
+ best_epoch, best_metric = 0, -1
+ timer = CountdownTimer(epochs)
+ history = History()
+ for epoch in range(1, epochs + 1):
+ logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
+ self.fit_dataloader(trn, criterion, optimizer, metric, logger, history=history, ratio_width=ratio_width,
+ eval_trn=eval_trn, **self.config)
+ loss, dev_metric = self.evaluate_dataloader(dev, criterion, logger=logger, ratio_width=ratio_width)
+ timer.update()
+ report = f"{timer.elapsed_human} / {timer.total_time_human} ETA: {timer.eta_human}"
+ if dev_metric > best_metric:
+ best_epoch, best_metric = epoch, dev_metric
+ self.save_weights(save_dir)
+ report += ' [red](saved)[/red]'
+ else:
+ report += f' ({epoch - best_epoch})'
+ if epoch - best_epoch >= patience:
+ report += ' early stop'
+ logger.info(report)
+ if epoch - best_epoch >= patience:
+ break
+ if not best_epoch:
+ self.save_weights(save_dir)
+ elif best_epoch != epoch:
+ self.load_weights(save_dir)
+ logger.info(f"Max score of dev is {best_metric} at epoch {best_epoch}")
+ logger.info(f"Average time of each epoch is {timer.elapsed_average_human}")
+ logger.info(f"{timer.elapsed_human} elapsed")
+
+ # noinspection PyMethodOverriding
+ def fit_dataloader(self,
+ trn: DataLoader,
+ criterion,
+ optimizer,
+ metric: SpanMetric,
+ logger: logging.Logger,
+ history: History,
+ gradient_accumulation=1,
+ grad_norm=None,
+ ratio_width=None,
+ eval_trn=True,
+ **kwargs):
+ optimizer, scheduler = optimizer
+ metric.reset()
+ self.model.train()
+ timer = CountdownTimer(history.num_training_steps(len(trn), gradient_accumulation=gradient_accumulation))
+ total_loss = 0
+ for idx, batch in enumerate(trn):
+ out, mask = self.feed_batch(batch)
+ y = batch['chart_id']
+ loss, span_probs = self.compute_loss(out, y, mask)
+ if gradient_accumulation and gradient_accumulation > 1:
+ loss /= gradient_accumulation
+ loss.backward()
+ total_loss += loss.item()
+ if eval_trn:
+ prediction = self.decode_output(out, mask, batch, span_probs)
+ self.update_metrics(metric, batch, prediction)
+ if history.step(gradient_accumulation):
+ self._step(optimizer, scheduler, grad_norm)
+ report = f'loss: {total_loss / (idx + 1):.4f} {metric}' if eval_trn \
+ else f'loss: {total_loss / (idx + 1):.4f}'
+ timer.log(report, logger=logger, ratio_percentage=False, ratio_width=ratio_width)
+ del loss
+ del out
+ del mask
+
+ def decode_output(self, out, mask, batch, span_probs=None, decoder=None, tokens=None):
+ s_span, s_label = out
+ if not decoder:
+ decoder = self.model.decoder
+ if span_probs is None:
+ if self.config.mbr:
+ s_span = decoder.crf(s_span, mask, mbr=True)
+ else:
+ s_span = span_probs
+ chart_preds = decoder.decode(s_span, s_label, mask)
+ idx_to_token = self.vocabs.chart.idx_to_token
+ if tokens is None:
+ tokens = [x[1:-1] for x in batch['token']]
+ trees = [build_tree(token, [(i, j, idx_to_token[label]) for i, j, label in chart]) for token, chart in
+ zip(tokens, chart_preds)]
+ # probs = [prob[:i - 1, 1:i].cpu() for i, prob in zip(lens, s_span.unbind())]
+ return trees
+
+ def update_metrics(self, metric, batch, prediction):
+ # Add pre-terminals (pos tags) back to prediction for safe factorization (deletion based on pos)
+ for pred, gold in zip(prediction, batch['constituency']):
+ pred: Tree = pred
+ gold: Tree = gold
+ for p, g in zip(pred.subtrees(lambda t: t.height() == 2), gold.pos()):
+ token, pos = g
+ p: Tree = p
+ assert p.label() == '_'
+ p.set_label(pos)
+ metric([factorize(tree, self.config.delete, self.config.equal) for tree in prediction],
+ [factorize(tree, self.config.delete, self.config.equal) for tree in batch['constituency']])
+ return metric
+
+ def feed_batch(self, batch: dict):
+ mask = self.compute_mask(batch)
+ s_span, s_label = self.model(batch)
+ return (s_span, s_label), mask
+
+ def compute_mask(self, batch, offset=1):
+ lens = batch['token_length'] - offset
+ seq_len = lens.max()
+ mask = lens.new_tensor(range(seq_len)) < lens.view(-1, 1, 1)
+ mask = mask & mask.new_ones(seq_len, seq_len).triu_(1)
+ return mask
+
+ def compute_loss(self, out, y, mask, crf_decoder=None):
+ if not crf_decoder:
+ crf_decoder = self.model.decoder
+ loss, span_probs = crf_decoder.loss(out[0], out[1], y, mask, self.config.mbr)
+ if loss < 0: # wired negative loss
+ loss *= 0
+ return loss, span_probs
+
+ def _step(self, optimizer, scheduler, grad_norm):
+ clip_grad_norm(self.model, grad_norm)
+ optimizer.step()
+ scheduler.step()
+ optimizer.zero_grad()
+
+ @torch.no_grad()
+ def evaluate_dataloader(self, data, criterion, logger=None, ratio_width=None, metric=None, output=None, **kwargs):
+ self.model.eval()
+ total_loss = 0
+ if not metric:
+ metric = self.build_metric()
+ else:
+ metric.reset()
+ timer = CountdownTimer(len(data))
+ for idx, batch in enumerate(data):
+ out, mask = self.feed_batch(batch)
+ y = batch['chart_id']
+ loss, span_probs = self.compute_loss(out, y, mask)
+ total_loss += loss.item()
+ prediction = self.decode_output(out, mask, batch, span_probs)
+ self.update_metrics(metric, batch, prediction)
+ timer.log(f'loss: {total_loss / (idx + 1):.4f} {metric}', ratio_percentage=False, logger=logger,
+ ratio_width=ratio_width)
+ total_loss /= len(data)
+ if output:
+ output.close()
+ return total_loss, metric
+
+ # noinspection PyMethodOverriding
+ def build_model(self, encoder, training=True, **kwargs) -> torch.nn.Module:
+ decoder = CRFConstituencyDecoder(n_labels=len(self.vocabs.chart), n_hidden=encoder.get_output_dim(), **kwargs)
+ encoder = encoder.module(vocabs=self.vocabs, training=training)
+ return CRFConstituencyModel(encoder, decoder)
+
+ def build_dataloader(self,
+ data,
+ batch_size,
+ sampler_builder: SamplerBuilder = None,
+ gradient_accumulation=1,
+ shuffle=False,
+ device=None,
+ logger: logging.Logger = None,
+ **kwargs) -> DataLoader:
+ if isinstance(data, TransformableDataset):
+ dataset = data
+ else:
+ transform = self.config.encoder.transform()
+ if self.config.get('transform', None):
+ transform = TransformList(self.config.transform, transform)
+ dataset = self.build_dataset(data, transform, logger)
+ if self.vocabs.mutable:
+ # noinspection PyTypeChecker
+ self.build_vocabs(dataset, logger)
+ lens = [len(x['token_input_ids']) for x in dataset]
+ if sampler_builder:
+ sampler = sampler_builder.build(lens, shuffle, gradient_accumulation)
+ else:
+ sampler = None
+ return PadSequenceDataLoader(dataset, batch_size, shuffle, device=device, batch_sampler=sampler)
+
+ def predict(self, data: Union[str, List[str]], batch_size: int = None, **kwargs):
+ if not data:
+ return []
+ flat = self.input_is_flat(data)
+ if flat:
+ data = [data]
+ samples = self.build_samples(data)
+ dataloader = self.build_dataloader(samples, device=self.device,
+ **merge_dict(self.config, batch_size=batch_size, overwrite=True))
+ outputs = []
+ orders = []
+ for idx, batch in enumerate(dataloader):
+ out, mask = self.feed_batch(batch)
+ prediction = self.decode_output(out, mask, batch, span_probs=None)
+ # prediction = [x[0] for x in prediction]
+ outputs.extend(prediction)
+ orders.extend(batch[IDX])
+ outputs = reorder(outputs, orders)
+ if flat:
+ return outputs[0]
+ return outputs
+
+ def input_is_flat(self, data):
+ return isinstance(data[0], str)
+
+ def build_samples(self, data):
+ return [{'token': [BOS] + token + [EOS]} for token in data]
+
+ # noinspection PyMethodOverriding
+ def fit(self,
+ trn_data,
+ dev_data,
+ save_dir,
+ encoder,
+ lr=5e-5,
+ transformer_lr=None,
+ adam_epsilon=1e-8,
+ weight_decay=0,
+ warmup_steps=0.1,
+ grad_norm=1.0,
+ n_mlp_span=500,
+ n_mlp_label=100,
+ mlp_dropout=.33,
+ batch_size=None,
+ batch_max_tokens=5000,
+ gradient_accumulation=1,
+ epochs=30,
+ patience=0.5,
+ mbr=True,
+ sampler_builder=None,
+ delete=('', ':', '``', "''", '.', '?', '!', '-NONE-', 'TOP', ',', 'S1'),
+ equal=(('ADVP', 'PRT'),),
+ no_subcategory=True,
+ eval_trn=True,
+ transform=None,
+ devices=None,
+ logger=None,
+ seed=None,
+ **kwargs):
+ if isinstance(equal, tuple):
+ equal = dict(equal)
+ return super().fit(**merge_locals_kwargs(locals(), kwargs))
+
+ def build_dataset(self, data, transform, logger=None):
+ _transform = [
+ unpack_tree_to_features,
+ self.vocabs,
+ FieldLength('token'),
+ transform
+ ]
+ if self.config.get('no_subcategory', True):
+ _transform.insert(0, remove_subcategory)
+ dataset = ConstituencyDataset(data,
+ transform=_transform,
+ cache=isinstance(data, str))
+ return dataset
+
+ def build_vocabs(self, trn, logger, **kwargs):
+ self.vocabs.chart = VocabWithNone(pad_token=None, unk_token=None)
+ timer = CountdownTimer(len(trn))
+ max_seq_len = 0
+ for each in trn:
+ max_seq_len = max(max_seq_len, len(each['token_input_ids']))
+ timer.log(f'Building vocab [blink][yellow]...[/yellow][/blink] (longest sequence: {max_seq_len})')
+ self.vocabs.chart.set_unk_as_safe_unk()
+ self.vocabs.lock()
+ self.vocabs.summary(logger)
diff --git a/hanlp/components/parsers/constituency/treecrf.py b/hanlp/components/parsers/constituency/treecrf.py
new file mode 100644
index 000000000..910f84bf6
--- /dev/null
+++ b/hanlp/components/parsers/constituency/treecrf.py
@@ -0,0 +1,360 @@
+# -*- coding:utf-8 -*-
+# Adopted from https://github.com/yzhangcs/parser
+# MIT License
+#
+# Copyright (c) 2020 Yu Zhang
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import torch
+import torch.autograd as autograd
+import torch.nn as nn
+
+from hanlp.components.parsers.alg import stripe, istree, eisner, mst, eisner2o
+
+
+class CRFConstituency(nn.Module):
+ r"""
+ TreeCRF for calculating partition functions and marginals in :math:`O(n^3)` for constituency trees.
+
+ References:
+ - Yu Zhang, houquan Zhou and Zhenghua Li. 2020.
+ `Fast and Accurate Neural CRF Constituency Parsing`_.
+
+ .. _Fast and Accurate Neural CRF Constituency Parsing:
+ https://www.ijcai.org/Proceedings/2020/560/
+ """
+
+ @torch.enable_grad()
+ def forward(self, scores, mask, target=None, mbr=False):
+ r"""
+ Args:
+ scores (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
+ Scores of all possible constituents.
+ mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
+ The mask to avoid parsing over padding tokens.
+ For each square matrix in a batch, the positions except upper triangular part should be masked out.
+ target (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
+ The tensor of gold-standard constituents. ``True`` if a constituent exists. Default: ``None``.
+ mbr (bool):
+ If ``True``, marginals will be returned to perform minimum Bayes-risk (MBR) decoding. Default: ``False``.
+
+ Returns:
+ ~torch.Tensor, ~torch.Tensor:
+ The first is the training loss averaged by the number of tokens, which won't be returned if ``target=None``.
+ The second is a tensor of shape ``[batch_size, seq_len, seq_len]``, in which are marginals if ``mbr=True``,
+ or original scores otherwise.
+ """
+
+ training = scores.requires_grad
+ # always enable the gradient computation of scores in order for the computation of marginals
+ logZ = self.inside(scores.requires_grad_(), mask)
+ # marginals are used for decoding, and can be computed by combining the inside pass and autograd mechanism
+ probs = scores
+ if mbr:
+ probs, = autograd.grad(logZ, scores, retain_graph=training)
+ if target is None:
+ return probs
+ loss = (logZ - scores[mask & target].sum()) / mask[:, 0].sum()
+
+ return loss, probs
+
+ def inside(self, scores, mask):
+ lens = mask[:, 0].sum(-1)
+ batch_size, seq_len, _ = scores.shape
+ # [seq_len, seq_len, batch_size]
+ scores, mask = scores.permute(1, 2, 0), mask.permute(1, 2, 0)
+ s = torch.full_like(scores, float('-inf'))
+
+ for w in range(1, seq_len):
+ # n denotes the number of spans to iterate,
+ # from span (0, w) to span (n, n+w) given width w
+ n = seq_len - w
+
+ if w == 1:
+ s.diagonal(w).copy_(scores.diagonal(w))
+ continue
+ # [n, w, batch_size]
+ s_s = stripe(s, n, w - 1, (0, 1)) + stripe(s, n, w - 1, (1, w), 0)
+ # [batch_size, n, w]
+ s_s = s_s.permute(2, 0, 1)
+ if s_s.requires_grad:
+ s_s.register_hook(lambda x: x.masked_fill_(torch.isnan(x), 0))
+ s_s = s_s.logsumexp(-1)
+ s.diagonal(w).copy_(s_s + scores.diagonal(w))
+
+ return s[0].gather(0, lens.unsqueeze(0)).sum()
+
+
+class CRF2oDependency(nn.Module):
+ r"""
+ Second-order TreeCRF for calculating partition functions and marginals in :math:`O(n^3)` for projective dependency trees.
+
+ References:
+ - Yu Zhang, Zhenghua Li and Min Zhang. 2020.
+ `Efficient Second-Order TreeCRF for Neural Dependency Parsing`_.
+
+ .. _Efficient Second-Order TreeCRF for Neural Dependency Parsing:
+ https://www.aclweb.org/anthology/2020.acl-main.302/
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.criterion = nn.CrossEntropyLoss()
+
+ @torch.enable_grad()
+ def forward(self, scores, mask, target=None, mbr=True, partial=False):
+ r"""
+ Args:
+ scores (~torch.Tensor, ~torch.Tensor):
+ Tuple of two tensors `s_arc` and `s_sib`.
+ `s_arc` (``[batch_size, seq_len, seq_len]``) holds Scores of all possible dependent-head pairs.
+ `s_sib` (``[batch_size, seq_len, seq_len, seq_len]``) holds the scores of dependent-head-sibling triples.
+ mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
+ The mask to avoid aggregation on padding tokens.
+ The first column serving as pseudo words for roots should be ``False``.
+ target (~torch.LongTensor): ``[batch_size, seq_len]``.
+ Tensors of gold-standard dependent-head pairs and dependent-head-sibling triples.
+ If partially annotated, the unannotated positions should be filled with -1.
+ Default: ``None``.
+ mbr (bool):
+ If ``True``, marginals will be returned to perform minimum Bayes-risk (MBR) decoding. Default: ``False``.
+ partial (bool):
+ ``True`` indicates that the trees are partially annotated. Default: ``False``.
+
+ Returns:
+ ~torch.Tensor, ~torch.Tensor:
+ The first is the training loss averaged by the number of tokens, which won't be returned if ``target=None``.
+ The second is a tensor of shape ``[batch_size, seq_len, seq_len]``, in which are marginals if ``mbr=True``,
+ or original scores otherwise.
+ """
+
+ s_arc, s_sib = scores
+ training = s_arc.requires_grad
+ batch_size, seq_len, _ = s_arc.shape
+ # always enable the gradient computation of scores in order for the computation of marginals
+ logZ = self.inside((s.requires_grad_() for s in scores), mask)
+ # marginals are used for decoding, and can be computed by combining the inside pass and autograd mechanism
+ probs = s_arc
+ if mbr:
+ probs, = autograd.grad(logZ, s_arc, retain_graph=training)
+
+ if target is None:
+ return probs
+ arcs, sibs = target
+ # the second inside process is needed if use partial annotation
+ if partial:
+ score = self.inside(scores, mask, arcs)
+ else:
+ arc_seq, sib_seq = arcs[mask], sibs[mask]
+ arc_mask, sib_mask = mask, sib_seq.gt(0)
+ sib_seq = sib_seq[sib_mask]
+ s_sib = s_sib[mask][torch.arange(len(arc_seq)), arc_seq]
+ s_arc = s_arc[arc_mask].gather(-1, arc_seq.unsqueeze(-1))
+ s_sib = s_sib[sib_mask].gather(-1, sib_seq.unsqueeze(-1))
+ score = s_arc.sum() + s_sib.sum()
+ loss = (logZ - score) / mask.sum()
+
+ return loss, probs
+
+ def inside(self, scores, mask, cands=None):
+ # the end position of each sentence in a batch
+ lens = mask.sum(1)
+ s_arc, s_sib = scores
+ batch_size, seq_len, _ = s_arc.shape
+ # [seq_len, seq_len, batch_size]
+ s_arc = s_arc.permute(2, 1, 0)
+ # [seq_len, seq_len, seq_len, batch_size]
+ s_sib = s_sib.permute(2, 1, 3, 0)
+ s_i = torch.full_like(s_arc, float('-inf'))
+ s_s = torch.full_like(s_arc, float('-inf'))
+ s_c = torch.full_like(s_arc, float('-inf'))
+ s_c.diagonal().fill_(0)
+
+ # set the scores of arcs excluded by cands to -inf
+ if cands is not None:
+ mask = mask.index_fill(1, lens.new_tensor(0), 1)
+ mask = (mask.unsqueeze(1) & mask.unsqueeze(-1)).permute(2, 1, 0)
+ cands = cands.unsqueeze(-1).index_fill(1, lens.new_tensor(0), -1)
+ cands = cands.eq(lens.new_tensor(range(seq_len))) | cands.lt(0)
+ cands = cands.permute(2, 1, 0) & mask
+ s_arc = s_arc.masked_fill(~cands, float('-inf'))
+
+ for w in range(1, seq_len):
+ # n denotes the number of spans to iterate,
+ # from span (0, w) to span (n, n+w) given width w
+ n = seq_len - w
+ # I(j->i) = logsum(exp(I(j->r) + S(j->r, i)) +, i < r < j
+ # exp(C(j->j) + C(i->j-1)))
+ # + s(j->i)
+ # [n, w, batch_size]
+ il = stripe(s_i, n, w, (w, 1)) + stripe(s_s, n, w, (1, 0), 0)
+ il += stripe(s_sib[range(w, n + w), range(n)], n, w, (0, 1))
+ # [n, 1, batch_size]
+ il0 = stripe(s_c, n, 1, (w, w)) + stripe(s_c, n, 1, (0, w - 1))
+ # il0[0] are set to zeros since the scores of the complete spans starting from 0 are always -inf
+ il[:, -1] = il0.index_fill_(0, lens.new_tensor(0), 0).squeeze(1)
+ if il.requires_grad:
+ il.register_hook(lambda x: x.masked_fill_(torch.isnan(x), 0))
+ il = il.permute(2, 0, 1).logsumexp(-1)
+ s_i.diagonal(-w).copy_(il + s_arc.diagonal(-w))
+ # I(i->j) = logsum(exp(I(i->r) + S(i->r, j)) +, i < r < j
+ # exp(C(i->i) + C(j->i+1)))
+ # + s(i->j)
+ # [n, w, batch_size]
+ ir = stripe(s_i, n, w) + stripe(s_s, n, w, (0, w), 0)
+ ir += stripe(s_sib[range(n), range(w, n + w)], n, w)
+ ir[0] = float('-inf')
+ # [n, 1, batch_size]
+ ir0 = stripe(s_c, n, 1) + stripe(s_c, n, 1, (w, 1))
+ ir[:, 0] = ir0.squeeze(1)
+ if ir.requires_grad:
+ ir.register_hook(lambda x: x.masked_fill_(torch.isnan(x), 0))
+ ir = ir.permute(2, 0, 1).logsumexp(-1)
+ s_i.diagonal(w).copy_(ir + s_arc.diagonal(w))
+
+ # [n, w, batch_size]
+ slr = stripe(s_c, n, w) + stripe(s_c, n, w, (w, 1))
+ if slr.requires_grad:
+ slr.register_hook(lambda x: x.masked_fill_(torch.isnan(x), 0))
+ slr = slr.permute(2, 0, 1).logsumexp(-1)
+ # S(j, i) = logsumexp(C(i->r) + C(j->r+1)), i <= r < j
+ s_s.diagonal(-w).copy_(slr)
+ # S(i, j) = logsumexp(C(i->r) + C(j->r+1)), i <= r < j
+ s_s.diagonal(w).copy_(slr)
+
+ # C(j->i) = logsumexp(C(r->i) + I(j->r)), i <= r < j
+ cl = stripe(s_c, n, w, (0, 0), 0) + stripe(s_i, n, w, (w, 0))
+ cl.register_hook(lambda x: x.masked_fill_(torch.isnan(x), 0))
+ s_c.diagonal(-w).copy_(cl.permute(2, 0, 1).logsumexp(-1))
+ # C(i->j) = logsumexp(I(i->r) + C(r->j)), i < r <= j
+ cr = stripe(s_i, n, w, (0, 1)) + stripe(s_c, n, w, (1, w), 0)
+ cr.register_hook(lambda x: x.masked_fill_(torch.isnan(x), 0))
+ s_c.diagonal(w).copy_(cr.permute(2, 0, 1).logsumexp(-1))
+ # disable multi words to modify the root
+ s_c[0, w][lens.ne(w)] = float('-inf')
+
+ return s_c[0].gather(0, lens.unsqueeze(0)).sum()
+
+ def loss(self, s_arc, s_sib, s_rel, arcs, sibs, rels, mask, mbr=True, partial=False):
+ r"""
+ Args:
+ s_arc (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
+ Scores of all possible arcs.
+ s_sib (~torch.Tensor): ``[batch_size, seq_len, seq_len, seq_len]``.
+ Scores of all possible dependent-head-sibling triples.
+ s_rel (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
+ Scores of all possible labels on each arc.
+ arcs (~torch.LongTensor): ``[batch_size, seq_len]``.
+ The tensor of gold-standard arcs.
+ sibs (~torch.LongTensor): ``[batch_size, seq_len]``.
+ The tensor of gold-standard siblings.
+ rels (~torch.LongTensor): ``[batch_size, seq_len]``.
+ The tensor of gold-standard labels.
+ mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
+ The mask for covering the unpadded tokens.
+ mbr (bool):
+ If ``True``, returns marginals for MBR decoding. Default: ``True``.
+ partial (bool):
+ ``True`` denotes the trees are partially annotated. Default: ``False``.
+
+ Returns:
+ ~torch.Tensor, ~torch.Tensor:
+ The training loss and
+ original arc scores of shape ``[batch_size, seq_len, seq_len]`` if ``mbr=False``, or marginals otherwise.
+ """
+
+ scores, target = (s_arc, s_sib), (arcs, sibs)
+ arc_loss, arc_probs = self.forward(scores, mask, target, mbr, partial)
+ # -1 denotes un-annotated arcs
+ if partial:
+ mask = mask & arcs.ge(0)
+ s_rel, rels = s_rel[mask], rels[mask]
+ s_rel = s_rel[torch.arange(len(rels)), arcs[mask]]
+ rel_loss = self.criterion(s_rel, rels)
+ loss = arc_loss + rel_loss
+ return loss, arc_probs
+
+ # def decode(self, s_arc, s_rel, mask, tree=False, proj=False, alg=None):
+ # r"""
+ # Args:
+ # s_arc (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
+ # Scores of all possible arcs.
+ # s_rel (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
+ # Scores of all possible labels on each arc.
+ # mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
+ # The mask for covering the unpadded tokens.
+ # tree (bool):
+ # If ``True``, ensures to output well-formed trees. Default: ``False``.
+ # proj (bool):
+ # If ``True``, ensures to output projective trees. Default: ``False``.
+ #
+ # Returns:
+ # ~torch.Tensor, ~torch.Tensor:
+ # Predicted arcs and labels of shape ``[batch_size, seq_len]``.
+ # """
+ #
+ # lens = mask.sum(1)
+ # arc_preds = s_arc.argmax(-1)
+ # if tree and not alg:
+ # bad = [not istree(seq[1:i + 1], proj)
+ # for i, seq in zip(lens.tolist(), arc_preds.tolist())]
+ # if any(bad):
+ # alg = eisner if proj else mst
+ # arc_preds[bad] = alg(s_arc[bad], mask[bad])
+ # rel_preds = s_rel.argmax(-1).gather(-1, arc_preds.unsqueeze(-1)).squeeze(-1)
+ #
+ # return arc_preds, rel_preds
+ def decode(self, s_arc, s_sib, s_rel, mask, tree=False, mbr=True, proj=False):
+ r"""
+ Args:
+ s_arc (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
+ Scores of all possible arcs.
+ s_sib (~torch.Tensor): ``[batch_size, seq_len, seq_len, seq_len]``.
+ Scores of all possible dependent-head-sibling triples.
+ s_rel (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
+ Scores of all possible labels on each arc.
+ mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
+ The mask for covering the unpadded tokens.
+ tree (bool):
+ If ``True``, ensures to output well-formed trees. Default: ``False``.
+ mbr (bool):
+ If ``True``, performs MBR decoding. Default: ``True``.
+ proj (bool):
+ If ``True``, ensures to output projective trees. Default: ``False``.
+
+ Returns:
+ ~torch.Tensor, ~torch.Tensor:
+ Predicted arcs and labels of shape ``[batch_size, seq_len]``.
+ """
+
+ lens = mask.sum(1)
+ arc_preds = s_arc.argmax(-1)
+ if tree:
+ bad = [not istree(seq[1:i + 1], proj)
+ for i, seq in zip(lens.tolist(), arc_preds.tolist())]
+ if any(bad):
+ if proj and not mbr:
+ arc_preds = eisner2o((s_arc, s_sib), mask)
+ else:
+ alg = eisner if proj else mst
+ arc_preds[bad] = alg(s_arc[bad], mask[bad])
+ rel_preds = s_rel.argmax(-1).gather(-1, arc_preds.unsqueeze(-1)).squeeze(-1)
+
+ return arc_preds, rel_preds
diff --git a/hanlp/components/parsers/hpsg/__init__.py b/hanlp/components/parsers/hpsg/__init__.py
new file mode 100644
index 000000000..bbd1b5016
--- /dev/null
+++ b/hanlp/components/parsers/hpsg/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-07-22 21:35
\ No newline at end of file
diff --git a/hanlp/components/parsers/hpsg/bracket_eval.py b/hanlp/components/parsers/hpsg/bracket_eval.py
new file mode 100755
index 000000000..b09fb9f93
--- /dev/null
+++ b/hanlp/components/parsers/hpsg/bracket_eval.py
@@ -0,0 +1,151 @@
+import math
+import os.path
+import re
+import subprocess
+import tempfile
+
+from hanlp.components.parsers.hpsg import trees
+from hanlp.datasets.parsing.ptb import _PTB_HOME
+from hanlp.metrics.metric import Metric
+from hanlp.utils.io_util import get_resource, run_cmd, pushd
+from hanlp.utils.log_util import flash
+from hanlp.utils.string_util import ispunct
+
+
+class FScore(Metric):
+
+ def __init__(self, recall, precision, fscore):
+ self.recall = recall
+ self.precision = precision
+ self.fscore = fscore
+
+ def __str__(self):
+ return f"P: {self.precision:.2%} R: {self.recall:.2%} F1: {self.fscore:.2%}"
+
+ @property
+ def score(self):
+ return self.fscore
+
+ def __call__(self, pred, gold):
+ pass
+
+ def reset(self):
+ self.recall = 0
+ self.precision = 0
+ self.fscore = 0
+
+
+def get_evalb_dir():
+ home = os.path.realpath(os.path.join(get_resource(_PTB_HOME), '../EVALB'))
+ evalb_path = os.path.join(home, 'evalb')
+ if not os.path.isfile(evalb_path):
+ flash(f'Compiling evalb to {home}')
+ with pushd(home):
+ run_cmd(f'make')
+ flash('')
+ if not os.path.isfile(evalb_path):
+ raise RuntimeError(f'Failed to compile evalb at {home}')
+ return home
+
+
+def evalb(gold_trees, predicted_trees, ref_gold_path=None, evalb_dir=None):
+ if not evalb_dir:
+ evalb_dir = get_evalb_dir()
+ assert os.path.exists(evalb_dir)
+ evalb_program_path = os.path.join(evalb_dir, "evalb")
+ evalb_spmrl_program_path = os.path.join(evalb_dir, "evalb_spmrl")
+ assert os.path.exists(evalb_program_path) or os.path.exists(evalb_spmrl_program_path)
+
+ if os.path.exists(evalb_program_path):
+ # evalb_param_path = os.path.join(evalb_dir, "COLLINS.prm")
+ evalb_param_path = os.path.join(evalb_dir, "nk.prm")
+ else:
+ evalb_program_path = evalb_spmrl_program_path
+ evalb_param_path = os.path.join(evalb_dir, "spmrl.prm")
+
+ assert os.path.exists(evalb_program_path)
+ assert os.path.exists(evalb_param_path)
+
+ assert len(gold_trees) == len(predicted_trees)
+ for gold_tree, predicted_tree in zip(gold_trees, predicted_trees):
+ assert isinstance(gold_tree, trees.TreebankNode)
+ assert isinstance(predicted_tree, trees.TreebankNode)
+ gold_leaves = list(gold_tree.leaves())
+ predicted_leaves = list(predicted_tree.leaves())
+ assert len(gold_leaves) == len(predicted_leaves)
+ for gold_leaf, predicted_leaf in zip(gold_leaves, predicted_leaves):
+ if gold_leaf.word != predicted_leaf.word:
+ # Maybe -LRB- => (
+ if ispunct(predicted_leaf.word):
+ gold_leaf.word = predicted_leaf.word
+ else:
+ print(f'Predicted word {predicted_leaf.word} does not match gold word {gold_leaf.word}')
+ # assert all(
+ # gold_leaf.word == predicted_leaf.word
+ # for gold_leaf, predicted_leaf in zip(gold_leaves, predicted_leaves))
+
+ temp_dir = tempfile.TemporaryDirectory(prefix="evalb-")
+ gold_path = os.path.join(temp_dir.name, "gold.txt")
+ predicted_path = os.path.join(temp_dir.name, "predicted.txt")
+ output_path = os.path.join(temp_dir.name, "output.txt")
+
+ # DELETE
+ # predicted_path = 'tmp_predictions.txt'
+ # output_path = 'tmp_output.txt'
+ # gold_path = 'tmp_gold.txt'
+
+ with open(gold_path, "w") as outfile:
+ if ref_gold_path is None:
+ for tree in gold_trees:
+ outfile.write("{}\n".format(tree.linearize()))
+ else:
+ with open(ref_gold_path) as goldfile:
+ outfile.write(goldfile.read())
+
+ with open(predicted_path, "w") as outfile:
+ for tree in predicted_trees:
+ outfile.write("{}\n".format(tree.linearize()))
+
+ command = "{} -p {} {} {} > {}".format(
+ evalb_program_path,
+ evalb_param_path,
+ gold_path,
+ predicted_path,
+ output_path,
+ )
+ # print(command)
+ subprocess.run(command, shell=True)
+
+ fscore = FScore(math.nan, math.nan, math.nan)
+ with open(output_path) as infile:
+ for line in infile:
+ match = re.match(r"Bracketing Recall\s+=\s+(\d+\.\d+)", line)
+ if match:
+ fscore.recall = float(match.group(1)) / 100
+ match = re.match(r"Bracketing Precision\s+=\s+(\d+\.\d+)", line)
+ if match:
+ fscore.precision = float(match.group(1)) / 100
+ match = re.match(r"Bracketing FMeasure\s+=\s+(\d+\.\d+)", line)
+ if match:
+ fscore.fscore = float(match.group(1)) / 100
+ break
+
+ success = (
+ not math.isnan(fscore.fscore) or
+ fscore.recall == 0.0 or
+ fscore.precision == 0.0)
+
+ if success:
+ temp_dir.cleanup()
+ else:
+ # print("Error reading EVALB results.")
+ # print("Gold path: {}".format(gold_path))
+ # print("Predicted path: {}".format(predicted_path))
+ # print("Output path: {}".format(output_path))
+ pass
+
+ return fscore
+
+
+if __name__ == '__main__':
+ print(get_evalb_dir())
diff --git a/hanlp/components/parsers/hpsg/const_decoder.pyx b/hanlp/components/parsers/hpsg/const_decoder.pyx
new file mode 100755
index 000000000..6d6e7a002
--- /dev/null
+++ b/hanlp/components/parsers/hpsg/const_decoder.pyx
@@ -0,0 +1,153 @@
+import numpy as np
+cimport numpy as np
+from numpy cimport ndarray
+cimport cython
+
+ctypedef np.float32_t DTYPE_t
+
+ORACLE_PRECOMPUTED_TABLE = {}
+
+@cython.boundscheck(False)
+def decode(int force_gold, int sentence_len, np.ndarray[DTYPE_t, ndim=3] label_scores_chart, int is_train, gold, label_vocab):
+ cdef DTYPE_t NEG_INF = -np.inf
+
+ # Label scores chart is copied so we can modify it in-place for augmentated decode
+ cdef np.ndarray[DTYPE_t, ndim=3] label_scores_chart_copy = label_scores_chart.copy()
+ cdef np.ndarray[DTYPE_t, ndim=2] value_chart = np.zeros((sentence_len+1, sentence_len+1), dtype=np.float32)
+ cdef np.ndarray[int, ndim=2] split_idx_chart = np.zeros((sentence_len+1, sentence_len+1), dtype=np.int32)
+ cdef np.ndarray[int, ndim=2] best_label_chart = np.zeros((sentence_len+1, sentence_len+1), dtype=np.int32)
+
+ cdef int length
+ cdef int left
+ cdef int right
+
+ cdef np.ndarray[DTYPE_t, ndim=1] label_scores_for_span
+
+ cdef int oracle_label_index
+ cdef DTYPE_t label_score
+ cdef int argmax_label_index
+ cdef DTYPE_t left_score
+ cdef DTYPE_t right_score
+
+ cdef int best_split
+ cdef int split_idx # Loop variable for splitting
+ cdef DTYPE_t split_val # best so far
+ cdef DTYPE_t max_split_val
+
+ cdef int label_index_iter
+
+ cdef np.ndarray[int, ndim=2] oracle_label_chart
+ cdef np.ndarray[int, ndim=2] oracle_split_chart
+ if is_train or force_gold:
+ if gold not in ORACLE_PRECOMPUTED_TABLE:
+ oracle_label_chart = np.zeros((sentence_len+1, sentence_len+1), dtype=np.int32)
+ oracle_split_chart = np.zeros((sentence_len+1, sentence_len+1), dtype=np.int32)
+ for length in range(1, sentence_len + 1):
+ for left in range(0, sentence_len + 1 - length):
+ right = left + length
+ oracle_label_chart[left, right] = label_vocab('\t'.join(gold.oracle_label(left, right)))
+ if length == 1:
+ continue
+ oracle_splits = gold.oracle_splits(left, right)
+ oracle_split_chart[left, right] = min(oracle_splits)
+ if not gold.nocache:
+ ORACLE_PRECOMPUTED_TABLE[gold] = oracle_label_chart, oracle_split_chart
+ else:
+ oracle_label_chart, oracle_split_chart = ORACLE_PRECOMPUTED_TABLE[gold]
+
+ for length in range(1, sentence_len + 1):
+ for left in range(0, sentence_len + 1 - length):
+ right = left + length
+
+ if is_train or force_gold:
+ oracle_label_index = oracle_label_chart[left, right]
+
+ if force_gold:
+ label_score = label_scores_chart_copy[left, right, oracle_label_index]
+ best_label_chart[left, right] = oracle_label_index
+
+ else:
+ if is_train:
+ # augment: here we subtract 1 from the oracle label
+ label_scores_chart_copy[left, right, oracle_label_index] -= 1
+
+ # We do argmax ourselves to make sure it compiles to pure C
+ if length < sentence_len:
+ argmax_label_index = 0
+ else:
+ # Not-a-span label is not allowed at the root of the tree
+ argmax_label_index = 1
+
+ label_score = label_scores_chart_copy[left, right, argmax_label_index]
+ for label_index_iter in range(1, label_scores_chart_copy.shape[2]):
+ if label_scores_chart_copy[left, right, label_index_iter] > label_score:
+ argmax_label_index = label_index_iter
+ label_score = label_scores_chart_copy[left, right, label_index_iter]
+ best_label_chart[left, right] = argmax_label_index
+
+ if is_train:
+ # augment: here we add 1 to all label scores
+ label_score += 1
+
+ if length == 1:
+ value_chart[left, right] = label_score
+ continue
+
+ if force_gold:
+ best_split = oracle_split_chart[left, right]
+ else:
+ best_split = left + 1
+ split_val = NEG_INF
+ for split_idx in range(left + 1, right):
+ max_split_val = value_chart[left, split_idx] + value_chart[split_idx, right]
+ if max_split_val > split_val:
+ split_val = max_split_val
+ best_split = split_idx
+
+ value_chart[left, right] = label_score + value_chart[left, best_split] + value_chart[best_split, right]
+ split_idx_chart[left, right] = best_split
+
+ # Now we need to recover the tree by traversing the chart starting at the
+ # root. This iterative implementation is faster than any of my attempts to
+ # use helper functions and recursion
+
+ # All fully binarized trees have the same number of nodes
+ cdef int num_tree_nodes = 2 * sentence_len - 1
+ cdef np.ndarray[int, ndim=1] included_i = np.empty(num_tree_nodes, dtype=np.int32)
+ cdef np.ndarray[int, ndim=1] included_j = np.empty(num_tree_nodes, dtype=np.int32)
+ cdef np.ndarray[int, ndim=1] included_label = np.empty(num_tree_nodes, dtype=np.int32)
+
+ cdef int idx = 0
+ cdef int stack_idx = 1
+ # technically, the maximum stack depth is smaller than this
+ cdef np.ndarray[int, ndim=1] stack_i = np.empty(num_tree_nodes + 5, dtype=np.int32)
+ cdef np.ndarray[int, ndim=1] stack_j = np.empty(num_tree_nodes + 5, dtype=np.int32)
+ stack_i[1] = 0
+ stack_j[1] = sentence_len
+
+ cdef int i, j, k
+ while stack_idx > 0:
+ i = stack_i[stack_idx]
+ j = stack_j[stack_idx]
+ stack_idx -= 1
+ included_i[idx] = i
+ included_j[idx] = j
+ included_label[idx] = best_label_chart[i, j]
+ idx += 1
+ if i + 1 < j:
+ k = split_idx_chart[i, j]
+ stack_idx += 1
+ stack_i[stack_idx] = k
+ stack_j[stack_idx] = j
+ stack_idx += 1
+ stack_i[stack_idx] = i
+ stack_j[stack_idx] = k
+
+ cdef DTYPE_t running_total = 0.0
+ for idx in range(num_tree_nodes):
+ running_total += label_scores_chart[included_i[idx], included_j[idx], included_label[idx]]
+
+ cdef DTYPE_t score = value_chart[0, sentence_len]
+ cdef DTYPE_t augment_amount = round(score - running_total)
+
+ return score, included_i.astype(int), included_j.astype(int), included_label.astype(int), augment_amount
diff --git a/hanlp/components/parsers/hpsg/dep_eval.py b/hanlp/components/parsers/hpsg/dep_eval.py
new file mode 100755
index 000000000..b3b6d2b7d
--- /dev/null
+++ b/hanlp/components/parsers/hpsg/dep_eval.py
@@ -0,0 +1,106 @@
+__author__ = 'max'
+
+import re
+import numpy as np
+
+from hanlp.metrics.metric import Metric
+
+
+def is_uni_punctuation(word):
+ match = re.match("^[^\w\s]+$]", word, flags=re.UNICODE)
+ return match is not None
+
+
+def is_punctuation(word, pos, punct_set=None):
+ if punct_set is None:
+ # Maybe use ispunct
+ return is_uni_punctuation(word)
+ else:
+ return pos in punct_set or pos == 'PU' # for chinese
+
+
+def eval(batch_size, words, postags, heads_pred, types_pred, heads, types, lengths,
+ punct_set=None, symbolic_root=False, symbolic_end=False):
+ ucorr = 0.
+ lcorr = 0.
+ total = 0.
+ ucomplete_match = 0.
+ lcomplete_match = 0.
+
+ ucorr_nopunc = 0.
+ lcorr_nopunc = 0.
+ total_nopunc = 0.
+ ucomplete_match_nopunc = 0.
+ lcomplete_match_nopunc = 0.
+
+ corr_root = 0.
+ total_root = 0.
+ start = 1 if symbolic_root else 0
+ end = 1 if symbolic_end else 0
+ for i in range(batch_size):
+ ucm = 1.
+ lcm = 1.
+ ucm_nopunc = 1.
+ lcm_nopunc = 1.
+ # assert len(heads[i]) == len(heads_pred[i])
+ for j in range(start, lengths[i] - end):
+ word = words[i][j]
+
+ pos = postags[i][j]
+
+ total += 1
+ if heads[i][j] == heads_pred[i][j]:
+ ucorr += 1
+ if types[i][j] == types_pred[i][j]:
+ lcorr += 1
+ else:
+ lcm = 0
+ else:
+ ucm = 0
+ lcm = 0
+
+ if not is_punctuation(word, pos, punct_set):
+ total_nopunc += 1
+ if heads[i][j] == heads_pred[i][j]:
+ ucorr_nopunc += 1
+ if types[i][j] == types_pred[i][j]:
+ lcorr_nopunc += 1
+ else:
+ lcm_nopunc = 0
+ else:
+ ucm_nopunc = 0
+ lcm_nopunc = 0
+
+ if heads_pred[i][j] == 0:
+ total_root += 1
+ corr_root += 1 if int(heads[i][j]) == 0 else 0
+
+ ucomplete_match += ucm
+ lcomplete_match += lcm
+ ucomplete_match_nopunc += ucm_nopunc
+ lcomplete_match_nopunc += lcm_nopunc
+
+ return (ucorr, lcorr, total, ucomplete_match, lcomplete_match), \
+ (ucorr_nopunc, lcorr_nopunc, total_nopunc, ucomplete_match_nopunc, lcomplete_match_nopunc), \
+ (corr_root, total_root), batch_size
+
+
+class SimpleAttachmentScore(Metric):
+
+ def __init__(self, uas, las) -> None:
+ super().__init__()
+ self.las = las
+ self.uas = uas
+
+ @property
+ def score(self):
+ return self.las
+
+ def __call__(self, pred, gold):
+ raise NotImplementedError()
+
+ def reset(self):
+ raise NotImplementedError()
+
+ def __repr__(self):
+ return f"UAS: {self.uas:.2%} LAS: {self.las:.2%}"
diff --git a/hanlp/components/parsers/hpsg/hpsg_dataset.py b/hanlp/components/parsers/hpsg/hpsg_dataset.py
new file mode 100644
index 000000000..5db62b932
--- /dev/null
+++ b/hanlp/components/parsers/hpsg/hpsg_dataset.py
@@ -0,0 +1,49 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-07-22 21:36
+import os
+from typing import Union, List, Callable, Tuple
+
+from hanlp.common.dataset import TransformableDataset
+from hanlp.components.parsers.hpsg.trees import load_trees_from_str
+from hanlp.utils.io_util import read_tsv_as_sents, TimingFileIterator, get_resource
+
+
+class HeadDrivenPhraseStructureDataset(TransformableDataset):
+
+ def __init__(self, data: Union[List, Tuple] = None,
+ transform: Union[Callable, List] = None, cache=None) -> None:
+ super().__init__(data, transform, cache)
+
+ def load_data(self, data, generate_idx=False):
+ if isinstance(data, tuple):
+ data = list(self.load_file(data))
+ return data
+
+ def load_file(self, filepath: tuple):
+ phrase_tree_path = get_resource(filepath[0])
+ dep_tree_path = get_resource(filepath[1])
+ pf = TimingFileIterator(phrase_tree_path)
+ message_prefix = f'Loading {os.path.basename(phrase_tree_path)} and {os.path.basename(dep_tree_path)}'
+ for i, (dep_sent, phrase_sent) in enumerate(zip(read_tsv_as_sents(dep_tree_path), pf)):
+ # Somehow the file contains escaped literals
+ phrase_sent = phrase_sent.replace('\\/', '/')
+
+ token = [x[1] for x in dep_sent]
+ pos = [x[3] for x in dep_sent]
+ head = [int(x[6]) for x in dep_sent]
+ rel = [x[7] for x in dep_sent]
+ phrase_tree = load_trees_from_str(phrase_sent, [head], [rel], [token])
+ assert len(phrase_tree) == 1, f'{phrase_tree_path} must have on tree per line.'
+ phrase_tree = phrase_tree[0]
+
+ yield {
+ 'FORM': token,
+ 'CPOS': pos,
+ 'HEAD': head,
+ 'DEPREL': rel,
+ 'tree': phrase_tree,
+ 'hpsg': phrase_tree.convert()
+ }
+ pf.log(f'{message_prefix} {i + 1} samples [blink][yellow]...[/yellow][/blink]')
+ pf.erase()
diff --git a/hanlp/components/parsers/hpsg/hpsg_decoder.pyx b/hanlp/components/parsers/hpsg/hpsg_decoder.pyx
new file mode 100755
index 000000000..8d7cce532
--- /dev/null
+++ b/hanlp/components/parsers/hpsg/hpsg_decoder.pyx
@@ -0,0 +1,260 @@
+import numpy as np
+cimport numpy as np
+from numpy cimport ndarray
+cimport cython
+
+ctypedef np.float32_t DTYPE_t
+
+ORACLE_PRECOMPUTED_TABLE = {}
+
+@cython.boundscheck(False)
+def decode(int force_gold, int sentence_len, np.ndarray[DTYPE_t, ndim=3] label_scores_chart, np.ndarray[DTYPE_t, ndim=2] type_scores_chart, int is_train, gold, label_vocab, type_vocab):
+ cdef DTYPE_t NEG_INF = -np.inf
+
+ # Label scores chart is copied so we can modify it in-place for augmentated decode
+ cdef np.ndarray[DTYPE_t, ndim=3] label_scores_chart_copy = label_scores_chart.copy()
+ cdef np.ndarray[DTYPE_t, ndim=2] type_scores_chart_copy = type_scores_chart.copy()
+
+ cdef np.ndarray[DTYPE_t, ndim=3] value_one_chart = np.zeros((sentence_len+1, sentence_len+1, sentence_len+1), dtype=np.float32)
+ cdef np.ndarray[DTYPE_t, ndim=3] value_muti_chart = np.zeros((sentence_len+1, sentence_len+1, sentence_len+1), dtype=np.float32)
+
+ cdef np.ndarray[int, ndim=3] split_idx_chart = np.zeros((sentence_len+1, sentence_len+1, sentence_len+1), dtype=np.int32)
+ cdef np.ndarray[int, ndim=3] best_label_chart = np.zeros((sentence_len+1, sentence_len+1, 2), dtype=np.int32)
+ cdef np.ndarray[int, ndim=3] head_chart = np.zeros((sentence_len+1, sentence_len+1, sentence_len+1), dtype=np.int32)
+ cdef np.ndarray[int, ndim=3] father_chart = np.zeros((sentence_len+1, sentence_len+1, sentence_len+1), dtype=np.int32)
+
+
+ cdef int length
+ cdef int left
+ cdef int right
+
+ cdef int child_l
+ cdef int child_r
+ cdef int child_head
+ cdef int child_type
+ cdef int type_id
+
+
+ cdef np.ndarray[DTYPE_t, ndim=1] label_scores_for_span
+
+ cdef int oracle_label_index
+ cdef int oracle_type_index
+ cdef DTYPE_t label_score_one
+ cdef DTYPE_t label_score_empty
+ cdef DTYPE_t dep_score
+ cdef int argmax_label_index
+ cdef int argmax_type_index
+ cdef DTYPE_t left_score
+ cdef DTYPE_t right_score
+ cdef DTYPE_t type_max_score
+
+ cdef int best_split
+ cdef int split_idx # Loop variable for splitting
+ cdef DTYPE_t split_val # best so far
+ cdef DTYPE_t max_split_val
+
+ cdef int label_index_iter, head, father
+
+ if not force_gold:
+
+ for length in range(1, sentence_len + 1):
+ for left in range(0, sentence_len + 1 - length):
+ right = left + length
+
+ if is_train :
+ oracle_label_index = label_vocab.index(gold.oracle_label(left, right))
+
+ # augment: here we subtract 1 from the oracle label
+ label_scores_chart_copy[left, right, oracle_label_index] -= 1
+
+ # We do argmax ourselves to make sure it compiles to pure C
+ #no empty label
+ argmax_label_index = 1
+ if length == 1 or length == sentence_len:
+ argmax_label_index = 2 #sub_head label can not be leaf
+
+ label_score_one = label_scores_chart_copy[left, right, argmax_label_index]
+ for label_index_iter in range(argmax_label_index, label_scores_chart_copy.shape[2]):
+ if label_scores_chart_copy[left, right, label_index_iter] > label_score_one:
+ argmax_label_index = label_index_iter
+ label_score_one = label_scores_chart_copy[left, right, label_index_iter]
+ best_label_chart[left, right, 1] = argmax_label_index
+
+ label_score_empty = label_scores_chart_copy[left, right,0]
+
+ if is_train:
+ # augment: here we add 1 to all label scores
+ label_score_one +=1
+ label_score_empty += 1
+
+ if length == 1:
+ #head is right, index from 1
+ value_one_chart[left, right, right] = label_score_one
+ value_muti_chart[left, right, right] = label_score_empty
+ if value_one_chart[left, right, right] > value_muti_chart[left, right, right]:
+ value_muti_chart[left, right, right] = value_one_chart[left, right, right]
+ best_label_chart[left, right,0] = best_label_chart[left, right,1]
+ else:
+ best_label_chart[left, right,0] = 0 #empty label
+ head_chart[left, right, right] = -1
+
+ continue
+
+ #head also in the empty part
+ for head_l in range(left + 1, right + 1):
+ value_one_chart[left, right, head_l] = NEG_INF
+
+ for split_idx in range(left + 1, right):
+ for head_l in range(left + 1, split_idx + 1):
+ for head_r in range(split_idx + 1, right + 1):
+
+ #head in the right empty part, left father is right
+ #left is one, right is multi
+ dep_score = type_scores_chart_copy[head_l, head_r]
+ if split_idx - left == 1:#leaf can be empty
+ split_val = value_muti_chart[left, split_idx, head_l] + value_muti_chart[split_idx, right, head_r] + dep_score
+ else :
+ split_val = value_one_chart[left, split_idx, head_l] + value_muti_chart[split_idx, right, head_r] + dep_score
+ if split_val > value_one_chart[left, right, head_r]:
+ value_one_chart[left, right, head_r] = split_val
+ split_idx_chart[left, right, head_r] = split_idx
+ head_chart[left, right, head_r] = head_l
+
+ #head in the left empty part, right father is left
+ #left is multi, right is one
+ dep_score = type_scores_chart_copy[head_r, head_l]
+ if right - split_idx == 1:#leaf can be empty
+ split_val = value_muti_chart[split_idx, right, head_r] + value_muti_chart[left, split_idx, head_l] + dep_score
+ else:
+ split_val = value_one_chart[split_idx, right, head_r] + value_muti_chart[left, split_idx, head_l] + dep_score
+ if split_val > value_one_chart[left, right, head_l]:
+ value_one_chart[left, right, head_l] = split_val
+ split_idx_chart[left, right, head_l] = split_idx
+ head_chart[left, right, head_l] = head_r
+
+ for head_l in range(left + 1, right + 1):
+ if label_score_one > label_score_empty:
+ value_muti_chart[left, right, head_l] = value_one_chart[left, right, head_l] + label_score_one
+ else :
+ value_muti_chart[left, right, head_l] = value_one_chart[left, right, head_l] + label_score_empty
+ value_one_chart[left, right, head_l] = value_one_chart[left, right, head_l] + label_score_one
+
+ if label_score_one < label_score_empty:
+ best_label_chart[left, right, 0] = 0
+ else:
+ best_label_chart[left, right,0] = best_label_chart[left, right,1]
+ #add mergein
+
+ # Now we need to recover the tree by traversing the chart starting at the
+ # root. This iterative implementation is faster than any of my attempts to
+ # use helper functions and recursion
+
+ # All fully binarized trees have the same number of nodes
+ cdef int num_tree_nodes = 2 * sentence_len - 1
+ cdef np.ndarray[int, ndim=1] included_i = np.empty(num_tree_nodes, dtype=np.int32)
+ cdef np.ndarray[int, ndim=1] included_j = np.empty(num_tree_nodes, dtype=np.int32)
+
+ cdef np.ndarray[int, ndim=1] included_label = np.empty(num_tree_nodes, dtype=np.int32)
+
+ cdef np.ndarray[int, ndim=1] included_type = np.empty(sentence_len, dtype=np.int32)
+ cdef np.ndarray[int, ndim=1] included_father = np.zeros(sentence_len, dtype=np.int32)# 0 is root
+
+ cdef int idx = 0
+ cdef int stack_idx = 1
+ # technically, the maximum stack depth is smaller than this
+ cdef np.ndarray[int, ndim=1] stack_i = np.empty(num_tree_nodes + 5, dtype=np.int32)
+ cdef np.ndarray[int, ndim=1] stack_j = np.empty(num_tree_nodes + 5, dtype=np.int32)
+ cdef np.ndarray[int, ndim=1] stack_head = np.empty(num_tree_nodes + 5, dtype=np.int32)
+
+ cdef np.ndarray[int, ndim=1] stack_type = np.empty(num_tree_nodes + 5, dtype=np.int32)
+
+ cdef int i, j, k, root_head, nodetype, sub_head
+ if not force_gold:
+ max_split_val = NEG_INF
+ for idxx in range(sentence_len):
+ split_val = value_one_chart[0, sentence_len, idxx + 1] + type_scores_chart[idxx + 1, 0]
+ if split_val > max_split_val:
+ max_split_val = split_val
+ root_head = idxx + 1
+ else:
+ root_head = gold.oracle_head(0, sentence_len)
+ stack_i[1] = 0
+ stack_j[1] = sentence_len
+ stack_head[1] = root_head
+ stack_type[1] = 1
+
+ while stack_idx > 0:
+
+ i = stack_i[stack_idx]
+ j = stack_j[stack_idx]
+ head = stack_head[stack_idx]
+ nodetype = stack_type[stack_idx]
+ stack_idx -= 1
+
+ included_i[idx] = i
+ included_j[idx] = j
+ if force_gold:
+ included_label[idx] = label_vocab.index(gold.oracle_label(i,j))
+ else :
+ if i + 1 == j:
+ nodetype = 0
+ included_label[idx] = best_label_chart[i, j, nodetype]
+
+ idx += 1
+ if i + 1 < j:
+
+ if force_gold:
+ oracle_splits = gold.oracle_splits(i, j)
+ if head > min(oracle_splits): #head index from 1
+ #h in most right, so most left is noempty
+ k = min(oracle_splits)
+ sub_head = gold.oracle_head(i, k)
+ included_father[sub_head - 1] = head
+ else:
+ k = max(oracle_splits)
+ sub_head = gold.oracle_head(k, j)
+ included_father[sub_head - 1] = head
+ else:
+ k = split_idx_chart[i, j, head]
+ sub_head = head_chart[i,j, head]
+ included_father[sub_head - 1] = head
+
+ stack_idx += 1
+ stack_i[stack_idx] = k
+ stack_j[stack_idx] = j
+ if head > k:
+ stack_head[stack_idx] = head
+ stack_type[stack_idx] = 0
+ else :
+ stack_head[stack_idx] = sub_head
+ stack_type[stack_idx] = 1
+ stack_idx += 1
+ stack_i[stack_idx] = i
+ stack_j[stack_idx] = k
+ if head > k:
+ stack_head[stack_idx] = sub_head
+ stack_type[stack_idx] = 1
+ else :
+ stack_head[stack_idx] = head
+ stack_type[stack_idx] = 0
+
+ cdef DTYPE_t running_total = 0.0
+ for idx in range(num_tree_nodes):
+ running_total += label_scores_chart[included_i[idx], included_j[idx], included_label[idx]]
+
+ for idx in range(sentence_len):
+ #root_head father is 0
+ if force_gold:
+ argmax_type_index = type_vocab.index(gold.oracle_type(idx, idx + 1))
+ else :
+ argmax_type_index = 0
+ #root_head father is 0
+ running_total += type_scores_chart[idx + 1, included_father[idx]]
+ included_type[idx] = argmax_type_index
+
+ cdef DTYPE_t score = value_one_chart[0, sentence_len, root_head] + type_scores_chart[root_head, 0]
+ if force_gold:
+ score = running_total
+ cdef DTYPE_t augment_amount = round(score - running_total)
+
+ return score, included_i.astype(int), included_j.astype(int), included_label.astype(int), included_father.astype(int), included_type.astype(int), augment_amount
diff --git a/hanlp/components/parsers/hpsg/hpsg_parser.py b/hanlp/components/parsers/hpsg/hpsg_parser.py
new file mode 100644
index 000000000..895d84b70
--- /dev/null
+++ b/hanlp/components/parsers/hpsg/hpsg_parser.py
@@ -0,0 +1,349 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-07-22 23:40
+import logging
+from typing import Union, List, Callable, Any, Dict
+
+import torch
+from torch.utils.data import DataLoader
+
+from hanlp.common.dataset import PadSequenceDataLoader
+from hanlp.common.structure import History
+from hanlp.common.torch_component import TorchComponent
+from hanlp.common.transform import FieldLength
+from hanlp.common.vocab import Vocab
+from hanlp.components.ner.biaffine_ner.biaffine_ner import BiaffineNamedEntityRecognizer
+from hanlp.components.parsers.hpsg import trees, bracket_eval, dep_eval
+from hanlp.components.parsers.hpsg.bracket_eval import FScore
+from hanlp.components.parsers.hpsg.dep_eval import SimpleAttachmentScore
+from hanlp.components.parsers.hpsg.hpsg_dataset import HeadDrivenPhraseStructureDataset
+from hanlp.components.parsers.hpsg.hpsg_parser_model import ChartParser
+from hanlp.datasets.parsing.conll_dataset import append_bos_eos
+from hanlp.layers.embeddings.embedding import Embedding
+from hanlp.metrics.parsing.attachmentscore import AttachmentScore
+from hanlp.utils.time_util import CountdownTimer
+from hanlp_common.util import merge_locals_kwargs
+
+
+class HeadDrivenPhraseStructureParser(TorchComponent):
+
+ def __init__(self, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.model: ChartParser = None
+
+ # noinspection PyCallByClass
+ def build_optimizer(self,
+ trn,
+ epochs,
+ lr,
+ adam_epsilon,
+ weight_decay,
+ warmup_steps,
+ transformer_lr,
+ **kwargs):
+ return BiaffineNamedEntityRecognizer.build_optimizer(self,
+ trn,
+ epochs,
+ lr,
+ adam_epsilon,
+ weight_decay,
+ warmup_steps,
+ transformer_lr)
+
+ def build_criterion(self, **kwargs):
+ pass
+
+ def build_metric(self, **kwargs):
+ return AttachmentScore()
+
+ def build_model(self, training=True, **kwargs) -> torch.nn.Module:
+ model = ChartParser(self.config.embed.module(vocabs=self.vocabs),
+ self.vocabs.pos, self.vocabs.label, self.vocabs.rel, self.config)
+ return model
+
+ # noinspection PyMethodOverriding
+ def build_dataloader(self, data, batch_size, shuffle, device, logger: logging.Logger, sampler_builder,
+ gradient_accumulation,
+ **kwargs) -> DataLoader:
+ # shuffle = False # We need to find the smallest grad_acc
+ dataset = HeadDrivenPhraseStructureDataset(data, transform=[append_bos_eos])
+ if self.config.get('transform', None):
+ dataset.append_transform(self.config.transform)
+ dataset.append_transform(self.vocabs)
+ if isinstance(self.config.embed, Embedding):
+ transform = self.config.embed.transform(vocabs=self.vocabs)
+ if transform:
+ dataset.append_transform(transform)
+ dataset.append_transform(self.vocabs)
+ field_length = FieldLength('token')
+ dataset.append_transform(field_length)
+ if isinstance(data, str):
+ dataset.purge_cache() # Enable cache
+ if self.vocabs.mutable:
+ self.build_vocabs(dataset, logger)
+ if 'token' in self.vocabs:
+ lens = [x[field_length.dst] for x in dataset]
+ else:
+ lens = [len(x['token_input_ids']) for x in dataset]
+ if sampler_builder:
+ sampler = sampler_builder.build(lens, shuffle, gradient_accumulation)
+ else:
+ sampler = None
+ return PadSequenceDataLoader(batch_sampler=sampler,
+ batch_size=batch_size,
+ device=device,
+ dataset=dataset)
+
+ def predict(self, data: Union[str, List[str]], batch_size: int = None, **kwargs):
+ pass
+
+ def build_vocabs(self, dataset, logger, **kwargs):
+ self.vocabs.rel = Vocab(pad_token=None, unk_token=None)
+ self.vocabs.pos = Vocab(pad_token=None, unk_token=None)
+ self.vocabs.label = label_vocab = Vocab(pad_token='', unk_token=None)
+ label_vocab.add(trees.Sub_Head)
+ for each in dataset:
+ tree = each['hpsg']
+ nodes = [tree]
+ while nodes:
+ node = nodes.pop()
+ if isinstance(node, trees.InternalParseNode):
+ label_vocab.add('\t'.join(node.label))
+ nodes.extend(reversed(node.children))
+ self.vocabs['rel'].set_unk_as_safe_unk()
+ label_vocab.set_unk_as_safe_unk()
+ self.vocabs.lock()
+ self.vocabs.summary(logger)
+
+ def fit(self, trn_data, dev_data, save_dir,
+ embed: Embedding,
+ batch_size=100,
+ epochs=100,
+ sampler='sorting',
+ n_buckets=32,
+ batch_max_tokens=None,
+ sampler_builder=None,
+ attention_dropout=0.2,
+ bert_do_lower_case=True,
+ bert_model='bert-large-uncased',
+ bert_transliterate='',
+ char_lstm_input_dropout=0.2,
+ clip_grad_norm=0.0,
+ const_lada=0.5,
+ d_biaffine=1024,
+ d_char_emb=64,
+ d_ff=2048,
+ d_kv=64,
+ d_label_hidden=250,
+ d_model=1024,
+ dataset='ptb',
+ elmo_dropout=0.5,
+ embedding_dropout=0.2,
+ embedding_path='data/glove.gz',
+ embedding_type='random',
+ lal_combine_as_self=False,
+ lal_d_kv=128,
+ lal_d_proj=128,
+ lal_partitioned=True,
+ lal_pwff=True,
+ lal_q_as_matrix=False,
+ lal_resdrop=False,
+ max_len_dev=0,
+ max_len_train=0,
+ morpho_emb_dropout=0.2,
+ num_heads=8,
+ num_layers=3,
+ pad_left=False,
+ partitioned=True,
+ relu_dropout=0.2,
+ residual_dropout=0.2,
+ sentence_max_len=300,
+ step_decay=True,
+ step_decay_factor=0.5,
+ step_decay_patience=5,
+ tag_emb_dropout=0.2,
+ timing_dropout=0.0,
+ dont_use_encoder=False,
+ use_cat=False,
+ use_chars_lstm=False,
+ use_elmo=False,
+ use_lal=True,
+ use_tags=True,
+ use_words=False,
+ word_emb_dropout=0.4,
+ lr=1e-3,
+ transformer_lr=5e-5,
+ adam_epsilon=1e-6,
+ weight_decay=0.01,
+ warmup_steps=0.1,
+ grad_norm=5.0,
+ gradient_accumulation=1,
+ devices=None, logger=None, seed=None, **kwargs):
+ return super().fit(**merge_locals_kwargs(locals(), kwargs))
+
+ # noinspection PyMethodOverriding
+ def fit_dataloader(self,
+ trn: DataLoader,
+ criterion,
+ optimizer,
+ metric,
+ logger: logging.Logger,
+ history: History,
+ linear_scheduler=None,
+ gradient_accumulation=1,
+ **kwargs):
+ self.model.train()
+ timer = CountdownTimer(history.num_training_steps(len(trn), gradient_accumulation))
+ total_loss = 0
+ self.reset_metrics(metric)
+ for idx, batch in enumerate(trn):
+ output_dict = self.feed_batch(batch)
+ self.update_metrics(batch, output_dict, metric)
+ loss = output_dict['loss']
+ if gradient_accumulation > 1:
+ loss /= gradient_accumulation
+ loss.backward()
+ total_loss += loss.item()
+ if history.step(gradient_accumulation):
+ self._step(optimizer, linear_scheduler)
+ timer.log(self.report_metrics(total_loss / (timer.current + 1), metric), ratio_percentage=None,
+ logger=logger)
+ del loss
+ return total_loss / timer.total
+
+ def _step(self, optimizer, linear_scheduler):
+ if self.config.grad_norm:
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_norm)
+ optimizer.step()
+ optimizer.zero_grad()
+ if linear_scheduler:
+ linear_scheduler.step()
+
+ # noinspection PyMethodOverriding
+ def evaluate_dataloader(self,
+ data: DataLoader,
+ criterion: Callable,
+ metric,
+ logger,
+ ratio_width=None,
+ output=False,
+ **kwargs):
+ self.model.eval()
+ self.reset_metrics(metric)
+ timer = CountdownTimer(len(data))
+ gold_tree = []
+ pred_tree = []
+ pred_head = []
+ pred_type = []
+ gold_type = []
+ gold_word = []
+ gold_pos = []
+ gold_head = []
+ for batch in data:
+ output_dict = self.feed_batch(batch)
+ gold_tree += batch['tree']
+ pred_tree += output_dict['predicted_tree']
+
+ pred_head += output_dict['pred_head']
+ pred_type += output_dict['pred_type']
+
+ gold_type += batch['DEPREL']
+ gold_head += batch['HEAD']
+ gold_pos += batch['CPOS']
+ gold_word += batch['FORM']
+ assert len(pred_head) == len(gold_head)
+ self.update_metrics(batch, output_dict, metric)
+ timer.log('', ratio_percentage=None, ratio_width=ratio_width)
+
+ tree_score: FScore = bracket_eval.evalb(gold_tree, pred_tree)
+ assert len(pred_head) == len(pred_type)
+ assert len(pred_type) == len(gold_type)
+ lens = [len(x) for x in gold_word]
+ stats, stats_nopunc, stats_root, test_total_inst = dep_eval.eval(len(pred_head), gold_word, gold_pos,
+ pred_head,
+ pred_type, gold_head, gold_type,
+ lens, punct_set=None,
+ symbolic_root=False)
+
+ test_ucorrect, test_lcorrect, test_total, test_ucomlpete_match, test_lcomplete_match = stats
+ test_ucorrect_nopunc, test_lcorrect_nopunc, test_total_nopunc, test_ucomlpete_match_nopunc, test_lcomplete_match_nopunc = stats_nopunc
+ test_root_correct, test_total_root = stats_root
+ dep_score = SimpleAttachmentScore(test_ucorrect_nopunc / test_total_nopunc,
+ test_lcorrect_nopunc / test_total_nopunc)
+ timer.log(f'{tree_score} {dep_score}', ratio_percentage=None, ratio_width=ratio_width, logger=logger)
+ return tree_score, dep_score
+
+ def reset_metrics(self, metrics):
+ pass
+ # for m in metrics:
+ # m.reset()
+
+ def report_metrics(self, loss, metrics):
+ return f'loss: {loss:.4f}'
+
+ def feed_batch(self, batch) -> Dict[str, Any]:
+ predicted_tree, loss_or_score = self.model(batch)
+ outputs = {}
+ if isinstance(loss_or_score, torch.Tensor):
+ loss_or_score /= len(batch['hpsg'])
+ loss = loss_or_score
+ outputs['loss'] = loss
+ else:
+ score = loss_or_score
+ outputs['score'] = score
+ if predicted_tree:
+ predicted_tree = [p.convert() for p in predicted_tree]
+ pred_head = [[leaf.father for leaf in tree.leaves()] for tree in predicted_tree]
+ pred_type = [[leaf.type for leaf in tree.leaves()] for tree in predicted_tree]
+ outputs.update({
+ 'predicted_tree': predicted_tree,
+ 'pred_head': pred_head,
+ 'pred_type': pred_type
+ }),
+ return outputs
+
+ def update_metrics(self, batch: dict, output_dict: dict, metrics):
+ pass
+ # assert len(output_dict['prediction']) == len(batch['ner'])
+ # for pred, gold in zip(output_dict['prediction'], batch['ner']):
+ # metrics(set(pred), set(gold))
+
+ def execute_training_loop(self,
+ trn: DataLoader,
+ dev: DataLoader,
+ epochs,
+ criterion,
+ optimizer,
+ metric,
+ save_dir,
+ logger: logging.Logger,
+ devices,
+ **kwargs):
+ best_epoch, best_score = 0, -1
+ optimizer, scheduler = optimizer
+ timer = CountdownTimer(epochs)
+ _len_trn = len(trn) // self.config.gradient_accumulation
+ ratio_width = len(f'{_len_trn}/{_len_trn}')
+ history = History()
+ for epoch in range(1, epochs + 1):
+ logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
+ self.fit_dataloader(trn, criterion, optimizer, metric, logger, history,
+ linear_scheduler=scheduler if self.use_transformer else None, **kwargs)
+ if dev:
+ metric = self.evaluate_dataloader(dev, criterion, metric, logger, ratio_width=ratio_width)
+ report = f'{timer.elapsed_human}/{timer.total_time_human}'
+ dev_score = sum(x.score for x in metric) / len(metric)
+ if not self.use_transformer:
+ scheduler.step(dev_score)
+ if dev_score > best_score:
+ self.save_weights(save_dir)
+ best_score = dev_score
+ report += ' [red]saved[/red]'
+ timer.log(report, ratio_percentage=False, newline=True, ratio=False)
+
+ @property
+ def use_transformer(self):
+ return 'token' not in self.vocabs
+
+ def _get_transformer(self):
+ return getattr(self.model.embed, 'transformer', None)
diff --git a/hanlp/components/parsers/hpsg/hpsg_parser_model.py b/hanlp/components/parsers/hpsg/hpsg_parser_model.py
new file mode 100644
index 000000000..ce2142f11
--- /dev/null
+++ b/hanlp/components/parsers/hpsg/hpsg_parser_model.py
@@ -0,0 +1,1450 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-07-22 23:41
+from typing import List
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.init as init
+from torch.nn.utils.rnn import pad_sequence
+import pyximport
+from hanlp.components.parsers.hpsg.trees import InternalTreebankNode, InternalParseNode
+
+pyximport.install(setup_args={"include_dirs": np.get_include()})
+from hanlp.components.parsers.hpsg import hpsg_decoder
+from hanlp.components.parsers.hpsg import const_decoder
+from hanlp.components.parsers.hpsg import trees
+from alnlp.modules import util
+
+START = ""
+STOP = ""
+UNK = ""
+ROOT = ""
+Sub_Head = ""
+No_Head = ""
+
+DTYPE = torch.bool
+
+TAG_UNK = "UNK"
+
+ROOT_TYPE = ""
+
+# Assumes that these control characters are not present in treebank text
+CHAR_UNK = "\0"
+CHAR_START_SENTENCE = "\1"
+CHAR_START_WORD = "\2"
+CHAR_STOP_WORD = "\3"
+CHAR_STOP_SENTENCE = "\4"
+CHAR_PAD = "\5"
+
+
+def from_numpy(ndarray):
+ return torch.from_numpy(ndarray)
+
+
+class BatchIndices:
+ """Batch indices container class (used to implement packed batches)"""
+
+ def __init__(self, batch_idxs_torch):
+ self.batch_idxs_torch = batch_idxs_torch
+ self.batch_size = int(1 + batch_idxs_torch.max())
+ batch_idxs_np_extra = np.concatenate([[-1], batch_idxs_torch.cpu().numpy(), [-1]])
+ self.boundaries_np = np.nonzero(batch_idxs_np_extra[1:] != batch_idxs_np_extra[:-1])[0]
+ self.seq_lens_np = self.boundaries_np[1:] - self.boundaries_np[:-1]
+ assert len(self.seq_lens_np) == self.batch_size
+ self.max_len = int(np.max(self.boundaries_np[1:] - self.boundaries_np[:-1]))
+
+
+#
+class LockedDropoutFunction(torch.autograd.function.InplaceFunction):
+ @classmethod
+ def forward(cls, ctx, input, batch_idxs=None, p=0.5, train=False, inplace=False):
+ """Tokens in the same batch share the same dropout mask
+
+ Args:
+ ctx:
+ input:
+ batch_idxs: (Default value = None)
+ p: (Default value = 0.5)
+ train: (Default value = False)
+ inplace: (Default value = False)
+
+ Returns:
+
+
+ """
+ if p < 0 or p > 1:
+ raise ValueError("dropout probability has to be between 0 and 1, "
+ "but got {}".format(p))
+
+ ctx.p = p
+ ctx.train = train
+ ctx.inplace = inplace
+
+ if ctx.inplace:
+ ctx.mark_dirty(input)
+ output = input
+ else:
+ output = input.clone()
+
+ if ctx.p > 0 and ctx.train:
+ if batch_idxs:
+ ctx.noise = input.new().resize_(batch_idxs.batch_size, input.size(1))
+ ctx.noise = ctx.noise[batch_idxs.batch_idxs_torch, :]
+ else:
+ ctx.noise = input.new(input.size(0), 1, input.size(2))
+ if ctx.p == 1:
+ ctx.noise.fill_(0)
+ else:
+ ctx.noise.bernoulli_(1 - ctx.p).div_(1 - ctx.p)
+ output.mul_(ctx.noise)
+
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ if ctx.p > 0 and ctx.train:
+ return grad_output.mul(ctx.noise), None, None, None, None
+ else:
+ return grad_output, None, None, None, None
+
+
+#
+class FeatureDropout(nn.Module):
+ """Feature-level dropout: takes an input of size len x num_features and drops
+ each feature with probabibility p. A feature is dropped across the full
+ portion of the input that corresponds to a single batch element.
+
+ Args:
+
+ Returns:
+
+ """
+
+ def __init__(self, p=0.5, inplace=False):
+ super().__init__()
+ if p < 0 or p > 1:
+ raise ValueError("dropout probability has to be between 0 and 1, "
+ "but got {}".format(p))
+ self.p = p
+ self.inplace = inplace
+
+ def forward(self, input, batch_idxs=None):
+ return LockedDropoutFunction.apply(input, batch_idxs, self.p, self.training, self.inplace)
+
+
+#
+class LayerNormalization(nn.Module):
+ def __init__(self, d_hid, eps=1e-3, affine=True):
+ super(LayerNormalization, self).__init__()
+
+ self.eps = eps
+ self.affine = affine
+ if self.affine:
+ self.a_2 = nn.Parameter(torch.ones(d_hid), requires_grad=True)
+ self.b_2 = nn.Parameter(torch.zeros(d_hid), requires_grad=True)
+
+ def forward(self, z):
+ if z.size(-1) == 1:
+ return z
+
+ mu = torch.mean(z, keepdim=True, dim=-1)
+ sigma = torch.std(z, keepdim=True, dim=-1)
+ ln_out = (z - mu.expand_as(z)) / (sigma.expand_as(z) + self.eps)
+ if self.affine:
+ ln_out = ln_out * self.a_2.expand_as(ln_out) + self.b_2.expand_as(ln_out)
+
+ return ln_out
+
+
+#
+class ScaledAttention(nn.Module):
+ def __init__(self, hparams, attention_dropout=0.1):
+ super(ScaledAttention, self).__init__()
+ self.hparams = hparams
+ self.temper = hparams.d_model ** 0.5
+ self.dropout = nn.Dropout(attention_dropout)
+ self.softmax = nn.Softmax(dim=1)
+
+ def forward(self, q, k, v, attn_mask=None):
+ # q: [batch, slot, feat]
+ # k: [batch, slot, feat]
+ # v: [batch, slot, feat]
+
+ attn = torch.bmm(q, k.transpose(1, 2)) / self.temper
+
+ if attn_mask is not None:
+ assert attn_mask.size() == attn.size(), \
+ 'Attention mask shape {} mismatch ' \
+ 'with Attention logit tensor shape ' \
+ '{}.'.format(attn_mask.size(), attn.size())
+
+ attn.data.masked_fill_(attn_mask, -float('inf'))
+
+ attn = self.softmax(attn.transpose(1, 2)).transpose(1, 2)
+ attn = self.dropout(attn)
+ output = torch.bmm(attn, v)
+
+ return output, attn
+
+
+# %%
+
+class ScaledDotProductAttention(nn.Module):
+ def __init__(self, d_model, attention_dropout=0.1):
+ super(ScaledDotProductAttention, self).__init__()
+ self.temper = d_model ** 0.5
+ self.dropout = nn.Dropout(attention_dropout)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, q, k, v, attn_mask=None):
+ # q: [batch, slot, feat] or (batch * d_l) x max_len x d_k
+ # k: [batch, slot, feat] or (batch * d_l) x max_len x d_k
+ # v: [batch, slot, feat] or (batch * d_l) x max_len x d_v
+ # q in LAL is (batch * d_l) x 1 x d_k
+
+ attn = torch.bmm(q, k.transpose(1, 2)) / self.temper # (batch * d_l) x max_len x max_len
+ # in LAL, gives: (batch * d_l) x 1 x max_len
+ # attention weights from each word to each word, for each label
+ # in best model (repeated q): attention weights from label (as vector weights) to each word
+
+ if attn_mask is not None:
+ assert attn_mask.size() == attn.size(), \
+ 'Attention mask shape {} mismatch ' \
+ 'with Attention logit tensor shape ' \
+ '{}.'.format(attn_mask.size(), attn.size())
+
+ attn.data.masked_fill_(attn_mask, -float('inf'))
+
+ attn = self.softmax(attn)
+ # Note that this makes the distribution not sum to 1. At some point it
+ # may be worth researching whether this is the right way to apply
+ # dropout to the attention.
+ # Note that the t2t code also applies dropout in this manner
+ attn = self.dropout(attn)
+ output = torch.bmm(attn, v) # (batch * d_l) x max_len x d_v
+ # in LAL, gives: (batch * d_l) x 1 x d_v
+
+ return output, attn
+
+
+#
+class MultiHeadAttention(nn.Module):
+ """Multi-head attention module"""
+
+ def __init__(self, hparams, n_head, d_model, d_k, d_v, residual_dropout=0.1, attention_dropout=0.1,
+ d_positional=None):
+ super(MultiHeadAttention, self).__init__()
+
+ self.n_head = n_head
+ self.d_k = d_k
+ self.d_v = d_v
+ self.hparams = hparams
+
+ if d_positional is None:
+ self.partitioned = False
+ else:
+ self.partitioned = True
+
+ if self.partitioned:
+ self.d_content = d_model - d_positional
+ self.d_positional = d_positional
+
+ self.w_qs1 = nn.Parameter(torch.FloatTensor(n_head, self.d_content, d_k // 2))
+ self.w_ks1 = nn.Parameter(torch.FloatTensor(n_head, self.d_content, d_k // 2))
+ self.w_vs1 = nn.Parameter(torch.FloatTensor(n_head, self.d_content, d_v // 2))
+
+ self.w_qs2 = nn.Parameter(torch.FloatTensor(n_head, self.d_positional, d_k // 2))
+ self.w_ks2 = nn.Parameter(torch.FloatTensor(n_head, self.d_positional, d_k // 2))
+ self.w_vs2 = nn.Parameter(torch.FloatTensor(n_head, self.d_positional, d_v // 2))
+
+ init.xavier_normal_(self.w_qs1)
+ init.xavier_normal_(self.w_ks1)
+ init.xavier_normal_(self.w_vs1)
+
+ init.xavier_normal_(self.w_qs2)
+ init.xavier_normal_(self.w_ks2)
+ init.xavier_normal_(self.w_vs2)
+ else:
+ self.w_qs = nn.Parameter(torch.FloatTensor(n_head, d_model, d_k))
+ self.w_ks = nn.Parameter(torch.FloatTensor(n_head, d_model, d_k))
+ self.w_vs = nn.Parameter(torch.FloatTensor(n_head, d_model, d_v))
+
+ init.xavier_normal_(self.w_qs)
+ init.xavier_normal_(self.w_ks)
+ init.xavier_normal_(self.w_vs)
+
+ self.attention = ScaledDotProductAttention(d_model, attention_dropout=attention_dropout)
+ self.layer_norm = LayerNormalization(d_model)
+
+ if not self.partitioned:
+ # The lack of a bias term here is consistent with the t2t code, though
+ # in my experiments I have never observed this making a difference.
+ self.proj = nn.Linear(n_head * d_v, d_model, bias=False)
+ else:
+ self.proj1 = nn.Linear(n_head * (d_v // 2), self.d_content, bias=False)
+ self.proj2 = nn.Linear(n_head * (d_v // 2), self.d_positional, bias=False)
+
+ self.residual_dropout = FeatureDropout(residual_dropout)
+
+ def split_qkv_packed(self, inp, qk_inp=None):
+ v_inp_repeated = inp.repeat(self.n_head, 1).view(self.n_head, -1, inp.size(-1)) # n_head x len_inp x d_model
+ if qk_inp is None:
+ qk_inp_repeated = v_inp_repeated
+ else:
+ qk_inp_repeated = qk_inp.repeat(self.n_head, 1).view(self.n_head, -1, qk_inp.size(-1))
+
+ if not self.partitioned:
+ q_s = torch.bmm(qk_inp_repeated, self.w_qs) # n_head x len_inp x d_k
+ k_s = torch.bmm(qk_inp_repeated, self.w_ks) # n_head x len_inp x d_k
+ v_s = torch.bmm(v_inp_repeated, self.w_vs) # n_head x len_inp x d_v
+ else:
+ q_s = torch.cat([
+ torch.bmm(qk_inp_repeated[:, :, :self.d_content], self.w_qs1),
+ torch.bmm(qk_inp_repeated[:, :, self.d_content:], self.w_qs2),
+ ], -1)
+ k_s = torch.cat([
+ torch.bmm(qk_inp_repeated[:, :, :self.d_content], self.w_ks1),
+ torch.bmm(qk_inp_repeated[:, :, self.d_content:], self.w_ks2),
+ ], -1)
+ v_s = torch.cat([
+ torch.bmm(v_inp_repeated[:, :, :self.d_content], self.w_vs1),
+ torch.bmm(v_inp_repeated[:, :, self.d_content:], self.w_vs2),
+ ], -1)
+ return q_s, k_s, v_s
+
+ def pad_and_rearrange(self, q_s, k_s, v_s, batch_idxs, mask):
+ # Input is padded representation: n_head x len_inp x d
+ # Output is packed representation: (n_head * B) x T x d
+ # (along with masks for the attention and output)
+ n_head = self.n_head
+ d_k, d_v = self.d_k, self.d_v
+
+ T = batch_idxs.max_len
+ B = batch_idxs.batch_size
+ q_padded = self.pad_seuence(q_s, batch_idxs)
+ k_padded = self.pad_seuence(k_s, batch_idxs)
+ v_padded = self.pad_seuence(v_s, batch_idxs)
+
+ return (
+ q_padded.view(-1, T, d_k),
+ k_padded.view(-1, T, d_k),
+ v_padded.view(-1, T, d_v),
+ ~mask.unsqueeze(1).expand(B, T, T).repeat(n_head, 1, 1),
+ mask.repeat(n_head, 1),
+ )
+
+ @staticmethod
+ def pad_seuence(q_s, batch_idxs):
+ q_padded = pad_sequence(torch.split(q_s.transpose(0, 1), batch_idxs.seq_lens_np.tolist()), True).transpose(0,
+ 2).contiguous()
+ return q_padded
+
+ def combine_v(self, outputs):
+ # Combine attention information from the different heads
+ n_head = self.n_head
+ outputs = outputs.view(n_head, -1, self.d_v) # n_head x len_inp x d_kv
+
+ if not self.partitioned:
+ # Switch from n_head x len_inp x d_v to len_inp x (n_head * d_v)
+ outputs = torch.transpose(outputs, 0, 1).contiguous().view(-1, n_head * self.d_v)
+
+ # Project back to residual size
+ outputs = self.proj(outputs)
+ else:
+ d_v1 = self.d_v // 2
+ outputs1 = outputs[:, :, :d_v1]
+ outputs2 = outputs[:, :, d_v1:]
+ outputs1 = torch.transpose(outputs1, 0, 1).contiguous().view(-1, n_head * d_v1)
+ outputs2 = torch.transpose(outputs2, 0, 1).contiguous().view(-1, n_head * d_v1)
+ outputs = torch.cat([
+ self.proj1(outputs1),
+ self.proj2(outputs2),
+ ], -1)
+
+ return outputs
+
+ def forward(self, inp, batch_idxs, qk_inp=None, batch=None, batched_inp=None, **kwargs):
+ residual = inp
+ mask = batch['mask']
+ B, T = mask.size()
+
+ # While still using a packed representation, project to obtain the
+ # query/key/value for each head
+ q_s, k_s, v_s = self.split_qkv_packed(inp, qk_inp=qk_inp)
+ # n_head x len_inp x d_kv
+
+ # Switch to padded representation, perform attention, then switch back
+ q_padded, k_padded, v_padded, attn_mask, output_mask = self.pad_and_rearrange(q_s, k_s, v_s, batch_idxs, mask)
+ # (n_head * batch) x len_padded x d_kv
+ outputs_padded, attns_padded = self.attention(
+ q_padded, k_padded, v_padded,
+ attn_mask=attn_mask,
+ )
+ outputs = outputs_padded[output_mask]
+ # (n_head * len_inp) x d_kv
+ outputs = self.combine_v(outputs)
+ # len_inp x d_model
+
+ outputs = self.residual_dropout(outputs, batch_idxs)
+
+ return self.layer_norm(outputs + residual), attns_padded
+
+
+#
+class PositionwiseFeedForward(nn.Module):
+ """A position-wise feed forward module.
+
+ Projects to a higher-dimensional space before applying ReLU, then projects
+ back.
+
+ Args:
+
+ Returns:
+
+ """
+
+ def __init__(self, d_hid, d_ff, relu_dropout=0.1, residual_dropout=0.1):
+ super(PositionwiseFeedForward, self).__init__()
+ self.w_1 = nn.Linear(d_hid, d_ff)
+ self.w_2 = nn.Linear(d_ff, d_hid)
+
+ self.layer_norm = LayerNormalization(d_hid)
+ self.relu_dropout = FeatureDropout(relu_dropout)
+ self.residual_dropout = FeatureDropout(residual_dropout)
+ self.relu = nn.ReLU()
+
+ def forward(self, x, batch_idxs):
+ residual = x
+
+ output = self.w_1(x)
+ output = self.relu_dropout(self.relu(output), batch_idxs)
+ output = self.w_2(output)
+
+ output = self.residual_dropout(output, batch_idxs)
+ return self.layer_norm(output + residual)
+
+
+#
+class PartitionedPositionwiseFeedForward(nn.Module):
+ def __init__(self, d_hid, d_ff, d_positional, relu_dropout=0.1, residual_dropout=0.1):
+ super().__init__()
+ self.d_content = d_hid - d_positional
+ self.w_1c = nn.Linear(self.d_content, d_ff // 2)
+ self.w_1p = nn.Linear(d_positional, d_ff // 2)
+ self.w_2c = nn.Linear(d_ff // 2, self.d_content)
+ self.w_2p = nn.Linear(d_ff // 2, d_positional)
+ self.layer_norm = LayerNormalization(d_hid)
+ self.relu_dropout = FeatureDropout(relu_dropout)
+ self.residual_dropout = FeatureDropout(residual_dropout)
+ self.relu = nn.ReLU()
+
+ def forward(self, x, batch_idxs):
+ residual = x
+ xc = x[:, :self.d_content]
+ xp = x[:, self.d_content:]
+
+ outputc = self.w_1c(xc)
+ outputc = self.relu_dropout(self.relu(outputc), batch_idxs)
+ outputc = self.w_2c(outputc)
+
+ outputp = self.w_1p(xp)
+ outputp = self.relu_dropout(self.relu(outputp), batch_idxs)
+ outputp = self.w_2p(outputp)
+
+ output = torch.cat([outputc, outputp], -1)
+
+ output = self.residual_dropout(output, batch_idxs)
+ return self.layer_norm(output + residual)
+
+
+#
+class MultiLevelEmbedding(nn.Module):
+ def __init__(self,
+ num_embeddings_list,
+ d_embedding,
+ hparams,
+ d_positional=None,
+ max_len=300,
+ normalize=True,
+ dropout=0.1,
+ timing_dropout=0.0,
+ emb_dropouts_list=None,
+ extra_content_dropout=None,
+ word_table_np=None,
+ **kwargs):
+ super().__init__()
+
+ self.d_embedding = d_embedding
+ self.partitioned = d_positional is not None
+ self.hparams = hparams
+
+ if self.partitioned:
+ self.d_positional = d_positional
+ self.d_content = self.d_embedding - self.d_positional
+ else:
+ self.d_positional = self.d_embedding
+ self.d_content = self.d_embedding
+
+ if emb_dropouts_list is None:
+ emb_dropouts_list = [0.0] * len(num_embeddings_list)
+ assert len(emb_dropouts_list) == len(num_embeddings_list)
+
+ if word_table_np is not None:
+ self.pretrain_dim = word_table_np.shape[1]
+ else:
+ self.pretrain_dim = 0
+
+ embs = []
+ emb_dropouts = []
+ cun = len(num_embeddings_list) * 2
+ for i, (num_embeddings, emb_dropout) in enumerate(zip(num_embeddings_list, emb_dropouts_list)):
+ if hparams.use_cat:
+ if i == len(num_embeddings_list) - 1:
+ # last is word
+ emb = nn.Embedding(num_embeddings, self.d_content // cun - self.pretrain_dim, **kwargs)
+ else:
+ emb = nn.Embedding(num_embeddings, self.d_content // cun, **kwargs)
+ else:
+ emb = nn.Embedding(num_embeddings, self.d_content - self.pretrain_dim, **kwargs)
+ embs.append(emb)
+ emb_dropout = FeatureDropout(emb_dropout)
+ emb_dropouts.append(emb_dropout)
+
+ if word_table_np is not None:
+ self.pretrain_emb = nn.Embedding(word_table_np.shape[0], self.pretrain_dim)
+ self.pretrain_emb.weight.data.copy_(torch.from_numpy(word_table_np))
+ self.pretrain_emb.weight.requires_grad_(False)
+ self.pretrain_emb_dropout = FeatureDropout(0.33)
+
+ self.embs = nn.ModuleList(embs)
+ self.emb_dropouts = nn.ModuleList(emb_dropouts)
+
+ if extra_content_dropout is not None:
+ self.extra_content_dropout = FeatureDropout(extra_content_dropout)
+ else:
+ self.extra_content_dropout = None
+
+ if normalize:
+ self.layer_norm = LayerNormalization(d_embedding)
+ else:
+ self.layer_norm = lambda x: x
+
+ self.dropout = FeatureDropout(dropout)
+ self.timing_dropout = FeatureDropout(timing_dropout)
+
+ # Learned embeddings
+ self.max_len = max_len
+ self.position_table = nn.Parameter(torch.FloatTensor(max_len, self.d_positional))
+ init.normal_(self.position_table)
+
+ def forward(self, xs, pre_words_idxs, batch_idxs, extra_content_annotations=None, batch=None, batched_inp=None,
+ **kwargs):
+ B, T, C = batched_inp.size()
+ # extra_content_annotations = batched_inp
+ content_annotations = [
+ emb_dropout(emb(x), batch_idxs)
+ for x, emb, emb_dropout in zip(xs, self.embs, self.emb_dropouts)
+ ]
+ if self.hparams.use_cat:
+ content_annotations = torch.cat(content_annotations, dim=-1)
+ else:
+ content_annotations = sum(content_annotations)
+ if self.pretrain_dim != 0:
+ content_annotations = torch.cat(
+ [content_annotations, self.pretrain_emb_dropout(self.pretrain_emb(pre_words_idxs))], dim=1)
+
+ if extra_content_annotations is not None:
+ if self.extra_content_dropout is not None:
+ extra_content_annotations = self.extra_content_dropout(extra_content_annotations)
+
+ if self.hparams.use_cat:
+ content_annotations = torch.cat(
+ [content_annotations, extra_content_annotations], dim=-1)
+ else:
+ content_annotations += extra_content_annotations
+
+ mask = batch['mask']
+ timing_signal = self.position_table[:T, :].unsqueeze(0).expand_as(batched_inp)[mask]
+ timing_signal = self.timing_dropout(timing_signal, batch_idxs)
+
+ # Combine the content and timing signals
+ if self.partitioned:
+ annotations = torch.cat([content_annotations, timing_signal], 1)
+ else:
+ annotations = content_annotations + timing_signal
+
+ # print(annotations.shape)
+ annotations = self.layer_norm(self.dropout(annotations, batch_idxs))
+ content_annotations = self.dropout(content_annotations, batch_idxs)
+
+ return annotations, content_annotations, timing_signal, batch_idxs
+
+
+#
+class BiLinear(nn.Module):
+ """Bi-linear layer"""
+
+ def __init__(self, left_features, right_features, out_features, bias=True):
+ '''
+
+ Args:
+ left_features: size of left input
+ right_features: size of right input
+ out_features: size of output
+ bias: If set to False, the layer will not learn an additive bias.
+ Default: True
+ '''
+ super(BiLinear, self).__init__()
+ self.left_features = left_features
+ self.right_features = right_features
+ self.out_features = out_features
+
+ self.U = nn.Parameter(torch.Tensor(self.out_features, self.left_features, self.right_features))
+ self.W_l = nn.Parameter(torch.Tensor(self.out_features, self.left_features))
+ self.W_r = nn.Parameter(torch.Tensor(self.out_features, self.left_features))
+
+ if bias:
+ self.bias = nn.Parameter(torch.Tensor(out_features))
+ else:
+ self.register_parameter('bias', None)
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.xavier_uniform_(self.W_l)
+ nn.init.xavier_uniform_(self.W_r)
+ nn.init.constant_(self.bias, 0.)
+ nn.init.xavier_uniform_(self.U)
+
+ def forward(self, input_left, input_right):
+ """
+
+ Args:
+ input_left: Tensor
+ the left input tensor with shape = [batch1, batch2, ..., left_features]
+ input_right: Tensor
+ the right input tensor with shape = [batch1, batch2, ..., right_features]
+
+ Returns:
+
+ """
+ # convert left and right input to matrices [batch, left_features], [batch, right_features]
+ input_left = input_left.view(-1, self.left_features)
+ input_right = input_right.view(-1, self.right_features)
+
+ # output [batch, out_features]
+ output = nn.functional.bilinear(input_left, input_right, self.U, self.bias)
+ output = output + nn.functional.linear(input_left, self.W_l, None) + nn.functional.linear(input_right, self.W_r,
+ None)
+ # convert back to [batch1, batch2, ..., out_features]
+ return output
+
+
+#
+class BiAAttention(nn.Module):
+ """Bi-Affine attention layer."""
+
+ def __init__(self, hparams):
+ super(BiAAttention, self).__init__()
+ self.hparams = hparams
+
+ self.dep_weight = nn.Parameter(torch.FloatTensor(hparams.d_biaffine + 1, hparams.d_biaffine + 1))
+ nn.init.xavier_uniform_(self.dep_weight)
+
+ def forward(self, input_d, input_e, input_s=None):
+ device = input_d.device
+ score = torch.matmul(torch.cat(
+ [input_d, torch.FloatTensor(input_d.size(0), 1).to(device).fill_(1).requires_grad_(False)],
+ dim=1), self.dep_weight)
+ score1 = torch.matmul(score, torch.transpose(torch.cat(
+ [input_e, torch.FloatTensor(input_e.size(0), 1).to(device).fill_(1).requires_grad_(False)],
+ dim=1), 0, 1))
+
+ return score1
+
+
+class Dep_score(nn.Module):
+ def __init__(self, hparams, num_labels):
+ super(Dep_score, self).__init__()
+
+ self.dropout_out = nn.Dropout2d(p=0.33)
+ self.hparams = hparams
+ out_dim = hparams.d_biaffine # d_biaffine
+ self.arc_h = nn.Linear(hparams.annotation_dim, hparams.d_biaffine)
+ self.arc_c = nn.Linear(hparams.annotation_dim, hparams.d_biaffine)
+
+ self.attention = BiAAttention(hparams)
+
+ self.type_h = nn.Linear(hparams.annotation_dim, hparams.d_label_hidden)
+ self.type_c = nn.Linear(hparams.annotation_dim, hparams.d_label_hidden)
+ self.bilinear = BiLinear(hparams.d_label_hidden, hparams.d_label_hidden, num_labels)
+
+ def forward(self, outputs, outpute):
+ # output from rnn [batch, length, hidden_size]
+
+ # apply dropout for output
+ # [batch, length, hidden_size] --> [batch, hidden_size, length] --> [batch, length, hidden_size]
+ outpute = self.dropout_out(outpute.transpose(1, 0)).transpose(1, 0)
+ outputs = self.dropout_out(outputs.transpose(1, 0)).transpose(1, 0)
+
+ # output size [batch, length, arc_space]
+ arc_h = nn.functional.relu(self.arc_h(outputs))
+ arc_c = nn.functional.relu(self.arc_c(outpute))
+
+ # output size [batch, length, type_space]
+ type_h = nn.functional.relu(self.type_h(outputs))
+ type_c = nn.functional.relu(self.type_c(outpute))
+
+ # apply dropout
+ # [batch, length, dim] --> [batch, 2 * length, dim]
+ arc = torch.cat([arc_h, arc_c], dim=0)
+ type = torch.cat([type_h, type_c], dim=0)
+
+ arc = self.dropout_out(arc.transpose(1, 0)).transpose(1, 0)
+ arc_h, arc_c = arc.chunk(2, 0)
+
+ type = self.dropout_out(type.transpose(1, 0)).transpose(1, 0)
+ type_h, type_c = type.chunk(2, 0)
+ type_h = type_h.contiguous()
+ type_c = type_c.contiguous()
+
+ out_arc = self.attention(arc_h, arc_c)
+ out_type = self.bilinear(type_h, type_c)
+
+ return out_arc, out_type
+
+
+class LabelAttention(nn.Module):
+ """Single-head Attention layer for label-specific representations"""
+
+ def __init__(self, hparams, d_model, d_k, d_v, d_l, d_proj, use_resdrop=True, q_as_matrix=False,
+ residual_dropout=0.1, attention_dropout=0.1, d_positional=None):
+ super(LabelAttention, self).__init__()
+ self.hparams = hparams
+ self.d_k = d_k
+ self.d_v = d_v
+ self.d_l = d_l # Number of Labels
+ self.d_model = d_model # Model Dimensionality
+ self.d_proj = d_proj # Projection dimension of each label output
+ self.use_resdrop = use_resdrop # Using Residual Dropout?
+ self.q_as_matrix = q_as_matrix # Using a Matrix of Q to be multiplied with input instead of learned q vectors
+ self.combine_as_self = hparams.lal_combine_as_self # Using the Combination Method of Self-Attention
+
+ if d_positional is None:
+ self.partitioned = False
+ else:
+ self.partitioned = True
+
+ if self.partitioned:
+ self.d_content = d_model - d_positional
+ self.d_positional = d_positional
+
+ if self.q_as_matrix:
+ self.w_qs1 = nn.Parameter(torch.FloatTensor(self.d_l, self.d_content, d_k // 2), requires_grad=True)
+ else:
+ self.w_qs1 = nn.Parameter(torch.FloatTensor(self.d_l, d_k // 2), requires_grad=True)
+ self.w_ks1 = nn.Parameter(torch.FloatTensor(self.d_l, self.d_content, d_k // 2), requires_grad=True)
+ self.w_vs1 = nn.Parameter(torch.FloatTensor(self.d_l, self.d_content, d_v // 2), requires_grad=True)
+
+ if self.q_as_matrix:
+ self.w_qs2 = nn.Parameter(torch.FloatTensor(self.d_l, self.d_positional, d_k // 2),
+ requires_grad=True)
+ else:
+ self.w_qs2 = nn.Parameter(torch.FloatTensor(self.d_l, d_k // 2), requires_grad=True)
+ self.w_ks2 = nn.Parameter(torch.FloatTensor(self.d_l, self.d_positional, d_k // 2), requires_grad=True)
+ self.w_vs2 = nn.Parameter(torch.FloatTensor(self.d_l, self.d_positional, d_v // 2), requires_grad=True)
+
+ init.xavier_normal_(self.w_qs1)
+ init.xavier_normal_(self.w_ks1)
+ init.xavier_normal_(self.w_vs1)
+
+ init.xavier_normal_(self.w_qs2)
+ init.xavier_normal_(self.w_ks2)
+ init.xavier_normal_(self.w_vs2)
+ else:
+ if self.q_as_matrix:
+ self.w_qs = nn.Parameter(torch.FloatTensor(self.d_l, d_model, d_k), requires_grad=True)
+ else:
+ self.w_qs = nn.Parameter(torch.FloatTensor(self.d_l, d_k), requires_grad=True)
+ self.w_ks = nn.Parameter(torch.FloatTensor(self.d_l, d_model, d_k), requires_grad=True)
+ self.w_vs = nn.Parameter(torch.FloatTensor(self.d_l, d_model, d_v), requires_grad=True)
+
+ init.xavier_normal_(self.w_qs)
+ init.xavier_normal_(self.w_ks)
+ init.xavier_normal_(self.w_vs)
+
+ self.attention = ScaledDotProductAttention(d_model, attention_dropout=attention_dropout)
+ if self.combine_as_self:
+ self.layer_norm = LayerNormalization(d_model)
+ else:
+ self.layer_norm = LayerNormalization(self.d_proj)
+
+ if not self.partitioned:
+ # The lack of a bias term here is consistent with the t2t code, though
+ # in my experiments I have never observed this making a difference.
+ if self.combine_as_self:
+ self.proj = nn.Linear(self.d_l * d_v, d_model, bias=False)
+ else:
+ self.proj = nn.Linear(d_v, d_model, bias=False) # input dimension does not match, should be d_l * d_v
+ else:
+ if self.combine_as_self:
+ self.proj1 = nn.Linear(self.d_l * (d_v // 2), self.d_content, bias=False)
+ self.proj2 = nn.Linear(self.d_l * (d_v // 2), self.d_positional, bias=False)
+ else:
+ self.proj1 = nn.Linear(d_v // 2, self.d_content, bias=False)
+ self.proj2 = nn.Linear(d_v // 2, self.d_positional, bias=False)
+ if not self.combine_as_self:
+ self.reduce_proj = nn.Linear(d_model, self.d_proj, bias=False)
+
+ self.residual_dropout = FeatureDropout(residual_dropout)
+
+ def split_qkv_packed(self, inp, k_inp=None):
+ len_inp = inp.size(0)
+ v_inp_repeated = inp.repeat(self.d_l, 1).view(self.d_l, -1, inp.size(-1)) # d_l x len_inp x d_model
+ if k_inp is None:
+ k_inp_repeated = v_inp_repeated
+ else:
+ k_inp_repeated = k_inp.repeat(self.d_l, 1).view(self.d_l, -1, k_inp.size(-1)) # d_l x len_inp x d_model
+
+ if not self.partitioned:
+ if self.q_as_matrix:
+ q_s = torch.bmm(k_inp_repeated, self.w_qs) # d_l x len_inp x d_k
+ else:
+ q_s = self.w_qs.unsqueeze(1) # d_l x 1 x d_k
+ k_s = torch.bmm(k_inp_repeated, self.w_ks) # d_l x len_inp x d_k
+ v_s = torch.bmm(v_inp_repeated, self.w_vs) # d_l x len_inp x d_v
+ else:
+ if self.q_as_matrix:
+ q_s = torch.cat([
+ torch.bmm(k_inp_repeated[:, :, :self.d_content], self.w_qs1),
+ torch.bmm(k_inp_repeated[:, :, self.d_content:], self.w_qs2),
+ ], -1)
+ else:
+ q_s = torch.cat([
+ self.w_qs1.unsqueeze(1),
+ self.w_qs2.unsqueeze(1),
+ ], -1)
+ k_s = torch.cat([
+ torch.bmm(k_inp_repeated[:, :, :self.d_content], self.w_ks1),
+ torch.bmm(k_inp_repeated[:, :, self.d_content:], self.w_ks2),
+ ], -1)
+ v_s = torch.cat([
+ torch.bmm(v_inp_repeated[:, :, :self.d_content], self.w_vs1),
+ torch.bmm(v_inp_repeated[:, :, self.d_content:], self.w_vs2),
+ ], -1)
+ return q_s, k_s, v_s
+
+ def pad_and_rearrange(self, q_s, k_s, v_s, batch_idxs, mask):
+ # Input is padded representation: n_head x len_inp x d
+ # Output is packed representation: (n_head * B) x T x d
+ # (along with masks for the attention and output)
+ n_head = self.d_l
+ d_k, d_v = self.d_k, self.d_v
+
+ T = batch_idxs.max_len
+ B = batch_idxs.batch_size
+ if self.q_as_matrix:
+ q_padded = q_s.new_zeros((n_head, B, T, d_k))
+ for i, (start, end) in enumerate(zip(batch_idxs.boundaries_np[:-1], batch_idxs.boundaries_np[1:])):
+ q_padded[:, i, :end - start, :] = q_s[:, start:end, :]
+ else:
+ q_padded = q_s.repeat(B, 1, 1) # (d_l * B) x 1 x d_k
+ k_padded = MultiHeadAttention.pad_seuence(k_s, batch_idxs)
+ v_padded = MultiHeadAttention.pad_seuence(v_s, batch_idxs)
+
+ if self.q_as_matrix:
+ q_padded = q_padded.view(-1, T, d_k)
+ attn_mask = ~mask.unsqueeze(1).expand(B, T, T).repeat(n_head, 1, 1)
+ else:
+ attn_mask = ~mask.unsqueeze(1).repeat(n_head, 1, 1)
+
+ output_mask = mask.repeat(n_head, 1)
+
+ return (
+ q_padded,
+ k_padded.view(-1, T, d_k),
+ v_padded.view(-1, T, d_v),
+ attn_mask,
+ output_mask,
+ )
+
+ def combine_v(self, outputs):
+ # Combine attention information from the different labels
+ d_l = self.d_l
+ outputs = outputs.view(d_l, -1, self.d_v) # d_l x len_inp x d_v
+
+ if not self.partitioned:
+ # Switch from d_l x len_inp x d_v to len_inp x d_l x d_v
+ if self.combine_as_self:
+ outputs = torch.transpose(outputs, 0, 1).contiguous().view(-1, d_l * self.d_v)
+ else:
+ outputs = torch.transpose(outputs, 0, 1) # .contiguous() #.view(-1, d_l * self.d_v)
+ # Project back to residual size
+ outputs = self.proj(outputs) # Becomes len_inp x d_l x d_model
+ else:
+ d_v1 = self.d_v // 2
+ outputs1 = outputs[:, :, :d_v1]
+ outputs2 = outputs[:, :, d_v1:]
+ if self.combine_as_self:
+ outputs1 = torch.transpose(outputs1, 0, 1).contiguous().view(-1, d_l * d_v1)
+ outputs2 = torch.transpose(outputs2, 0, 1).contiguous().view(-1, d_l * d_v1)
+ else:
+ outputs1 = torch.transpose(outputs1, 0, 1) # .contiguous() #.view(-1, d_l * d_v1)
+ outputs2 = torch.transpose(outputs2, 0, 1) # .contiguous() #.view(-1, d_l * d_v1)
+ outputs = torch.cat([
+ self.proj1(outputs1),
+ self.proj2(outputs2),
+ ], -1) # .contiguous()
+
+ return outputs
+
+ def forward(self, inp, batch_idxs, k_inp=None, batch=None, batched_inp=None, **kwargs):
+ mask = batch['mask']
+ residual = inp # len_inp x d_model
+ len_inp = inp.size(0)
+
+ # While still using a packed representation, project to obtain the
+ # query/key/value for each head
+ q_s, k_s, v_s = self.split_qkv_packed(inp, k_inp=k_inp)
+ # d_l x len_inp x d_k
+ # q_s is d_l x 1 x d_k
+
+ # Switch to padded representation, perform attention, then switch back
+ q_padded, k_padded, v_padded, attn_mask, output_mask = self.pad_and_rearrange(q_s, k_s, v_s, batch_idxs, mask)
+ # q_padded, k_padded, v_padded: (d_l * batch_size) x max_len x d_kv
+ # q_s is (d_l * batch_size) x 1 x d_kv
+
+ outputs_padded, attns_padded = self.attention(
+ q_padded, k_padded, v_padded,
+ attn_mask=attn_mask,
+ )
+ # outputs_padded: (d_l * batch_size) x max_len x d_kv
+ # in LAL: (d_l * batch_size) x 1 x d_kv
+ # on the best model, this is one value vector per label that is repeated max_len times
+ if not self.q_as_matrix:
+ outputs_padded = outputs_padded.repeat(1, output_mask.size(-1), 1)
+ outputs = outputs_padded[output_mask]
+ # outputs: (d_l * len_inp) x d_kv or LAL: (d_l * len_inp) x d_kv
+ # output_mask: (d_l * batch_size) x max_len
+ # torch.cuda.empty_cache()
+ outputs = self.combine_v(outputs)
+ # outputs: len_inp x d_l x d_model, whereas a normal self-attention layer gets len_inp x d_model
+ if self.use_resdrop:
+ if self.combine_as_self:
+ outputs = self.residual_dropout(outputs, batch_idxs)
+ else:
+ outputs = torch.cat(
+ [self.residual_dropout(outputs[:, i, :], batch_idxs).unsqueeze(1) for i in range(self.d_l)], 1)
+ if self.combine_as_self:
+ outputs = self.layer_norm(outputs + inp)
+ else:
+ outputs = outputs + inp.unsqueeze(1)
+ outputs = self.reduce_proj(outputs) # len_inp x d_l x d_proj
+ outputs = self.layer_norm(outputs) # len_inp x d_l x d_proj
+ outputs = outputs.view(len_inp, -1).contiguous() # len_inp x (d_l * d_proj)
+
+ return outputs, attns_padded
+
+
+class Encoder(nn.Module):
+ def __init__(self, hparams, embedding,
+ num_layers=1, num_heads=2, d_kv=32, d_ff=1024, d_l=112,
+ d_positional=None,
+ num_layers_position_only=0,
+ relu_dropout=0.1, residual_dropout=0.1, attention_dropout=0.1,
+ use_lal=True,
+ lal_d_kv=128,
+ lal_d_proj=128,
+ lal_resdrop=True,
+ lal_pwff=True,
+ lal_q_as_matrix=False,
+ lal_partitioned=True):
+ super().__init__()
+ self.embedding_container = [embedding]
+ d_model = embedding.d_embedding
+ self.hparams = hparams
+
+ d_k = d_v = d_kv
+
+ self.stacks = []
+
+ for i in range(hparams.num_layers):
+ attn = MultiHeadAttention(hparams, num_heads, d_model, d_k, d_v, residual_dropout=residual_dropout,
+ attention_dropout=attention_dropout, d_positional=d_positional)
+ if d_positional is None:
+ ff = PositionwiseFeedForward(d_model, d_ff, relu_dropout=relu_dropout,
+ residual_dropout=residual_dropout)
+ else:
+ ff = PartitionedPositionwiseFeedForward(d_model, d_ff, d_positional, relu_dropout=relu_dropout,
+ residual_dropout=residual_dropout)
+
+ self.add_module(f"attn_{i}", attn)
+ self.add_module(f"ff_{i}", ff)
+
+ self.stacks.append((attn, ff))
+
+ if use_lal:
+ lal_d_positional = d_positional if lal_partitioned else None
+ attn = LabelAttention(hparams, d_model, lal_d_kv, lal_d_kv, d_l, lal_d_proj, use_resdrop=lal_resdrop,
+ q_as_matrix=lal_q_as_matrix,
+ residual_dropout=residual_dropout, attention_dropout=attention_dropout,
+ d_positional=lal_d_positional)
+ ff_dim = lal_d_proj * d_l
+ if hparams.lal_combine_as_self:
+ ff_dim = d_model
+ if lal_pwff:
+ if d_positional is None or not lal_partitioned:
+ ff = PositionwiseFeedForward(ff_dim, d_ff, relu_dropout=relu_dropout,
+ residual_dropout=residual_dropout)
+ else:
+ ff = PartitionedPositionwiseFeedForward(ff_dim, d_ff, d_positional, relu_dropout=relu_dropout,
+ residual_dropout=residual_dropout)
+ else:
+ ff = None
+
+ self.add_module(f"attn_{num_layers}", attn)
+ self.add_module(f"ff_{num_layers}", ff)
+ self.stacks.append((attn, ff))
+
+ self.num_layers_position_only = num_layers_position_only
+ if self.num_layers_position_only > 0:
+ assert d_positional is None, "num_layers_position_only and partitioned are incompatible"
+
+ def forward(self, xs, pre_words_idxs, batch_idxs, extra_content_annotations=None, batch=None, **kwargs):
+ emb = self.embedding_container[0]
+ res, res_c, timing_signal, batch_idxs = emb(xs, pre_words_idxs, batch_idxs,
+ extra_content_annotations=extra_content_annotations,
+ batch=batch, **kwargs)
+
+ for i, (attn, ff) in enumerate(self.stacks):
+ res, current_attns = attn(res, batch_idxs, batch=batch, **kwargs)
+ if ff is not None:
+ res = ff(res, batch_idxs)
+
+ return res, current_attns # batch_idxs
+
+
+class ChartParser(nn.Module):
+ def __init__(
+ self,
+ embed: nn.Module,
+ tag_vocab,
+ label_vocab,
+ type_vocab,
+ config,
+ ):
+ super().__init__()
+ self.embed = embed
+ self.tag_vocab = tag_vocab
+ self.label_vocab = label_vocab
+ self.label_vocab_size = len(label_vocab)
+ self.type_vocab = type_vocab
+
+ self.hparams = config
+ self.d_model = config.d_model
+ self.partitioned = config.partitioned
+ self.d_content = (self.d_model // 2) if self.partitioned else self.d_model
+ self.d_positional = (config.d_model // 2) if self.partitioned else None
+
+ self.use_lal = config.use_lal
+ if self.use_lal:
+ self.lal_d_kv = config.lal_d_kv
+ self.lal_d_proj = config.lal_d_proj
+ self.lal_resdrop = config.lal_resdrop
+ self.lal_pwff = config.lal_pwff
+ self.lal_q_as_matrix = config.lal_q_as_matrix
+ self.lal_partitioned = config.lal_partitioned
+ self.lal_combine_as_self = config.lal_combine_as_self
+
+ self.contributions = False
+
+ num_embeddings_map = {
+ 'tags': len(tag_vocab),
+ }
+ emb_dropouts_map = {
+ 'tags': config.tag_emb_dropout,
+ }
+
+ self.emb_types = []
+ if config.use_tags:
+ self.emb_types.append('tags')
+ if config.use_words:
+ self.emb_types.append('words')
+
+ self.use_tags = config.use_tags
+
+ self.morpho_emb_dropout = None
+
+ self.char_encoder = None
+ self.elmo = None
+ self.bert = None
+ self.xlnet = None
+ self.pad_left = config.pad_left
+ self.roberta = None
+ ex_dim = self.d_content
+ if self.hparams.use_cat:
+ cun = 0
+ if config.use_words or config.use_tags:
+ ex_dim = ex_dim // 2 # word dim = self.d_content/2
+ if config.use_chars_lstm:
+ cun = cun + 1
+ if config.use_elmo or config.use_bert or config.use_xlnet:
+ cun = cun + 1
+ if cun > 0:
+ ex_dim = ex_dim // cun
+
+ self.project_xlnet = nn.Linear(embed.get_output_dim(), ex_dim, bias=False)
+
+ if not config.dont_use_encoder:
+ word_table_np = None
+
+ self.embedding = MultiLevelEmbedding(
+ [num_embeddings_map[emb_type] for emb_type in self.emb_types],
+ config.d_model,
+ hparams=config,
+ d_positional=self.d_positional,
+ dropout=config.embedding_dropout,
+ timing_dropout=config.timing_dropout,
+ emb_dropouts_list=[emb_dropouts_map[emb_type] for emb_type in self.emb_types],
+ extra_content_dropout=self.morpho_emb_dropout,
+ max_len=config.sentence_max_len,
+ word_table_np=word_table_np,
+ )
+
+ self.encoder = Encoder(
+ config,
+ self.embedding,
+ d_l=len(label_vocab) - 1,
+ num_layers=config.num_layers,
+ num_heads=config.num_heads,
+ d_kv=config.d_kv,
+ d_ff=config.d_ff,
+ d_positional=self.d_positional,
+ relu_dropout=config.relu_dropout,
+ residual_dropout=config.residual_dropout,
+ attention_dropout=config.attention_dropout,
+ use_lal=config.use_lal,
+ lal_d_kv=config.lal_d_kv,
+ lal_d_proj=config.lal_d_proj,
+ lal_resdrop=config.lal_resdrop,
+ lal_pwff=config.lal_pwff,
+ lal_q_as_matrix=config.lal_q_as_matrix,
+ lal_partitioned=config.lal_partitioned,
+ )
+ else:
+ self.embedding = None
+ self.encoder = None
+
+ label_vocab_size = len(label_vocab)
+ annotation_dim = ((label_vocab_size - 1) * self.lal_d_proj) if (
+ self.use_lal and not self.lal_combine_as_self) else config.d_model
+ # annotation_dim = self.encoder.stacks[-1][1].w_2c.out_features + self.encoder.stacks[-1][1].w_2p.out_features
+ # annotation_dim = min((self.label_vocab_size - 1) * self.lal_d_proj, self.encoder.stacks[-1][1].w_2c.out_features + self.encoder.stacks[-1][1].w_2p.out_features)
+ config.annotation_dim = annotation_dim
+
+ self.f_label = nn.Sequential(
+ nn.Linear(annotation_dim, config.d_label_hidden),
+ LayerNormalization(config.d_label_hidden),
+ nn.ReLU(),
+ nn.Linear(config.d_label_hidden, label_vocab_size - 1),
+ )
+ self.dep_score = Dep_score(config, len(type_vocab))
+ self.loss_func = torch.nn.CrossEntropyLoss(reduction='sum')
+ self.loss_funt = torch.nn.CrossEntropyLoss(reduction='sum')
+
+ if not config.use_tags and hasattr(config, 'd_tag_hidden'):
+ self.f_tag = nn.Sequential(
+ nn.Linear(annotation_dim, config.d_tag_hidden),
+ LayerNormalization(config.d_tag_hidden),
+ nn.ReLU(),
+ nn.Linear(config.d_tag_hidden, tag_vocab.size),
+ )
+ self.tag_loss_scale = config.tag_loss_scale
+ else:
+ self.f_tag = None
+
+ def forward(self, batch: dict):
+ # sentences = batch['token']
+ token_length: torch.LongTensor = batch['token_length']
+ batch['mask'] = mask = util.lengths_to_mask(token_length)
+ B, T = mask.size()
+ golds: List[InternalParseNode] = batch.get('hpsg', None) if self.training else None
+ if golds:
+ sentences = [[(leaf.tag, leaf.word) for leaf in tree.leaves()] for tree in golds]
+ else:
+ sentences = [list(zip(t, w)) for w, t in zip(batch['FORM'], batch['CPOS'])]
+ is_train = golds is not None
+
+ packed_len = sum(token_length)
+ i = 0
+ batch_idxs = torch.arange(B).unsqueeze(1).expand_as(mask)[mask]
+ batch_idxs = BatchIndices(batch_idxs)
+
+ self.train(is_train)
+ torch.set_grad_enabled(is_train)
+ self.current_attns = None
+
+ if golds is None:
+ golds = [None] * len(sentences)
+
+ extra_content_annotations_list = []
+
+ features_packed = self.embed(batch)
+ # For now, just project the features from the last word piece in each word
+ extra_content_annotations = self.project_xlnet(features_packed)
+
+ if self.encoder is not None:
+ if len(extra_content_annotations_list) > 1:
+ if self.hparams.use_cat:
+ extra_content_annotations = torch.cat(extra_content_annotations_list, dim=-1)
+ else:
+ extra_content_annotations = sum(extra_content_annotations_list)
+ elif len(extra_content_annotations_list) == 1:
+ extra_content_annotations = extra_content_annotations_list[0]
+
+ annotations, self.current_attns = self.encoder([batch['pos_id'][mask]], None, batch_idxs,
+ extra_content_annotations=extra_content_annotations[mask],
+ batch=batch,
+ batched_inp=extra_content_annotations)
+
+ if self.partitioned and not self.use_lal:
+ annotations = torch.cat([
+ annotations[:, 0::2],
+ annotations[:, 1::2],
+ ], 1)
+
+ if self.use_lal and not self.lal_combine_as_self:
+ half_dim = self.lal_d_proj // 2
+ annotations_3d = annotations.view(annotations.size(0), -1, half_dim)
+ fencepost_annotations = torch.cat(
+ [annotations_3d[:-1, 0::2, :].flatten(1), annotations_3d[1:, 1::2, :].flatten(1)], dim=-1)
+ else:
+ fencepost_annotations = torch.cat([
+ annotations[:-1, :self.d_model // 2],
+ annotations[1:, self.d_model // 2:],
+ ], 1)
+
+ fencepost_annotations_start = fencepost_annotations
+ fencepost_annotations_end = fencepost_annotations
+
+ else:
+ raise NotImplementedError()
+
+ fp_startpoints = batch_idxs.boundaries_np[:-1]
+ fp_endpoints = batch_idxs.boundaries_np[1:] - 1
+
+ if not is_train:
+ trees = []
+ scores = []
+ for i, (start, end) in enumerate(zip(fp_startpoints, fp_endpoints)):
+ tree, score = self.parse_from_annotations(fencepost_annotations_start[start:end, :],
+ fencepost_annotations_end[start:end, :], sentences[i], i)
+ trees.append(tree)
+ scores.append(score)
+
+ return trees, scores
+
+ pis = []
+ pjs = []
+ plabels = []
+ paugment_total = 0.0
+ cun = 0
+ num_p = 0
+ gis = []
+ gjs = []
+ glabels = []
+ with torch.no_grad():
+ for i, (start, end) in enumerate(zip(fp_startpoints, fp_endpoints)):
+ p_i, p_j, p_label, p_augment, g_i, g_j, g_label \
+ = self.parse_from_annotations(fencepost_annotations_start[start:end, :],
+ fencepost_annotations_end[start:end, :], sentences[i], i,
+ gold=golds[i])
+
+ paugment_total += p_augment
+ num_p += p_i.shape[0]
+ pis.append(p_i + start)
+ pjs.append(p_j + start)
+ gis.append(g_i + start)
+ gjs.append(g_j + start)
+ plabels.append(p_label)
+ glabels.append(g_label)
+
+ device = annotations.device
+ cells_i = torch.tensor(np.concatenate(pis + gis), device=device)
+ cells_j = torch.tensor(np.concatenate(pjs + gjs), device=device)
+ cells_label = torch.tensor(np.concatenate(plabels + glabels), device=device)
+
+ cells_label_scores = self.f_label(fencepost_annotations_end[cells_j] - fencepost_annotations_start[cells_i])
+ cells_label_scores = torch.cat([
+ cells_label_scores.new_zeros((cells_label_scores.size(0), 1)),
+ cells_label_scores
+ ], 1)
+ cells_label_scores = torch.gather(cells_label_scores, 1, cells_label[:, None])
+ loss = cells_label_scores[:num_p].sum() - cells_label_scores[num_p:].sum() + paugment_total
+
+ cun = 0
+ for snum, (start, end) in enumerate(zip(fp_startpoints, fp_endpoints)):
+ # [start,....,end-1]->[,1, 2,...,n]
+ leng = end - start
+ arc_score, type_score = self.dep_score(fencepost_annotations_start[start:end, :],
+ fencepost_annotations_end[start:end, :])
+ # arc_gather = gfather[cun] - start
+ arc_gather = [leaf.father for leaf in golds[snum].leaves()]
+ type_gather = [self.type_vocab.get_idx(leaf.type) for leaf in golds[snum].leaves()]
+ cun += 1
+ assert len(arc_gather) == leng - 1
+ arc_score = torch.transpose(arc_score, 0, 1)
+ loss = loss + 0.5 * self.loss_func(arc_score[1:, :], torch.tensor(arc_gather, device=device)) \
+ + 0.5 * self.loss_funt(type_score[1:, :], torch.tensor(type_gather, device=device))
+
+ return None, loss
+
+ def label_scores_from_annotations(self, fencepost_annotations_start, fencepost_annotations_end):
+
+ span_features = (torch.unsqueeze(fencepost_annotations_end, 0)
+ - torch.unsqueeze(fencepost_annotations_start, 1))
+
+ if self.contributions and self.use_lal:
+ contributions = np.zeros(
+ (span_features.shape[0], span_features.shape[1], span_features.shape[2] // self.lal_d_proj))
+ half_vector = span_features.shape[-1] // 2
+ half_dim = self.lal_d_proj // 2
+ for i in range(contributions.shape[0]):
+ for j in range(contributions.shape[1]):
+ for l in range(contributions.shape[-1]):
+ contributions[i, j, l] = span_features[i, j,
+ l * half_dim:(l + 1) * half_dim].sum() + span_features[i, j,
+ half_vector + l * half_dim:half_vector + (
+ l + 1) * half_dim].sum()
+ contributions[i, j, :] = (contributions[i, j, :] - np.min(contributions[i, j, :]))
+ contributions[i, j, :] = (contributions[i, j, :]) / (
+ np.max(contributions[i, j, :]) - np.min(contributions[i, j, :]))
+ # contributions[i,j,:] = contributions[i,j,:]/np.sum(contributions[i,j,:])
+ contributions = torch.softmax(torch.Tensor(contributions), -1)
+
+ label_scores_chart = self.f_label(span_features)
+ label_scores_chart = torch.cat([
+ label_scores_chart.new_zeros((label_scores_chart.size(0), label_scores_chart.size(1), 1)),
+ label_scores_chart
+ ], 2)
+ if self.contributions and self.use_lal:
+ return label_scores_chart, contributions
+ return label_scores_chart
+
+ def parse_from_annotations(self, fencepost_annotations_start, fencepost_annotations_end, sentence, sentence_idx,
+ gold=None):
+ is_train = gold is not None
+ contributions = None
+ if self.contributions and self.use_lal:
+ label_scores_chart, contributions = self.label_scores_from_annotations(fencepost_annotations_start,
+ fencepost_annotations_end)
+ else:
+ label_scores_chart = self.label_scores_from_annotations(fencepost_annotations_start,
+ fencepost_annotations_end)
+ label_scores_chart_np = label_scores_chart.cpu().data.numpy()
+
+ if is_train:
+ decoder_args = dict(
+ sentence_len=len(sentence),
+ label_scores_chart=label_scores_chart_np,
+ gold=gold,
+ label_vocab=self.label_vocab,
+ is_train=is_train)
+
+ p_score, p_i, p_j, p_label, p_augment = const_decoder.decode(False, **decoder_args)
+ g_score, g_i, g_j, g_label, g_augment = const_decoder.decode(True, **decoder_args)
+ return p_i, p_j, p_label, p_augment, g_i, g_j, g_label
+ else:
+ arc_score, type_score = self.dep_score(fencepost_annotations_start, fencepost_annotations_end)
+
+ arc_score_dc = torch.transpose(arc_score, 0, 1)
+ arc_dc_np = arc_score_dc.cpu().data.numpy()
+
+ type_np = type_score.cpu().data.numpy()
+ type_np = type_np[1:, :] # remove root
+ type = type_np.argmax(axis=1)
+ return self.decode_from_chart(sentence, label_scores_chart_np, arc_dc_np, type, sentence_idx=sentence_idx,
+ contributions=contributions)
+
+ def decode_from_chart_batch(self, sentences, charts_np, golds=None):
+ trees = []
+ scores = []
+ if golds is None:
+ golds = [None] * len(sentences)
+ for sentence, chart_np, gold in zip(sentences, charts_np, golds):
+ tree, score = self.decode_from_chart(sentence, chart_np, gold)
+ trees.append(tree)
+ scores.append(score)
+ return trees, scores
+
+ def decode_from_chart(self, sentence, label_scores_chart_np, arc_dc_np, type, sentence_idx=None, gold=None,
+ contributions=None):
+
+ decoder_args = dict(
+ sentence_len=len(sentence),
+ label_scores_chart=label_scores_chart_np * self.hparams.const_lada,
+ type_scores_chart=arc_dc_np * (1.0 - self.hparams.const_lada),
+ gold=gold,
+ label_vocab=self.label_vocab,
+ type_vocab=self.type_vocab,
+ is_train=False)
+
+ force_gold = (gold is not None)
+
+ # The optimized cython decoder implementation doesn't actually
+ # generate trees, only scores and span indices. When converting to a
+ # tree, we assume that the indices follow a preorder traversal.
+
+ score, p_i, p_j, p_label, p_father, p_type, _ = hpsg_decoder.decode(force_gold, **decoder_args)
+ if contributions is not None:
+ d_l = (self.label_vocab_size - 2)
+ mb_size = (self.current_attns.shape[0] // d_l)
+ print('SENTENCE', sentence)
+
+ idx = -1
+ type_idx_to_token = self.type_vocab.idx_to_token
+ label_idx_to_token = self.label_vocab.idx_to_token
+
+ def get_label(index):
+ label = label_idx_to_token[index]
+ if not label:
+ return ()
+ return tuple(label.split('\t'))
+
+ def make_tree():
+ nonlocal idx
+ idx += 1
+ i, j, label_idx = p_i[idx], p_j[idx], p_label[idx]
+ label = get_label(label_idx)
+ if contributions is not None:
+ if label_idx > 0:
+ print(i, sentence[i], j, sentence[j - 1], label, label_idx, contributions[i, j, label_idx - 1])
+ print("CONTRIBUTIONS")
+ print(list(enumerate(contributions[i, j])))
+ print("ATTENTION DIST")
+ print(torch.softmax(self.current_attns[sentence_idx::mb_size, 0, i:j + 1], -1))
+ if (i + 1) >= j:
+ tag, word = sentence[i]
+ if type is not None:
+ tree = trees.LeafParseNode(int(i), tag, word, p_father[i], type_idx_to_token[type[i]])
+ else:
+ tree = trees.LeafParseNode(int(i), tag, word, p_father[i], type_idx_to_token[p_type[i]])
+ if label:
+ assert label[0] != Sub_Head
+ tree = trees.InternalParseNode(label, [tree])
+ return [tree]
+ else:
+ left_trees = make_tree()
+ right_trees = make_tree()
+ children = left_trees + right_trees
+ if label and label[0] != Sub_Head:
+ return [trees.InternalParseNode(label, children)]
+ else:
+ return children
+
+ tree_list = make_tree()
+ assert len(tree_list) == 1
+ tree = tree_list[0]
+ return tree, score
diff --git a/hanlp/components/parsers/hpsg/trees.py b/hanlp/components/parsers/hpsg/trees.py
new file mode 100755
index 000000000..0850a94e4
--- /dev/null
+++ b/hanlp/components/parsers/hpsg/trees.py
@@ -0,0 +1,330 @@
+import collections.abc
+
+Sub_Head = ""
+No_Head = ""
+Htype = 1
+Ntype = 0
+
+
+class TreebankNode(object):
+ pass
+
+
+class InternalTreebankNode(TreebankNode):
+ def __init__(self, label, children):
+ assert isinstance(label, str)
+ self.label = label
+ assert isinstance(children, collections.abc.Sequence)
+ assert all(isinstance(child, TreebankNode) for child in children)
+ assert children
+ self.children = tuple(children)
+ self.father = self.children[0].father
+ self.type = self.children[0].type
+ self.head = self.children[0].head
+ self.left = self.children[0].left
+ self.right = self.children[-1].right
+ self.cun = 0
+ flag = 0
+ for child in self.children:
+ if child.father < self.left + 1 or child.father > self.right:
+ self.father = child.father
+ self.type = child.type
+ self.head = child.head
+ flag = 1
+
+ for child in self.children:
+ if child.head != self.head:
+ if child.father != self.head:
+ self.cun += 1
+
+ def linearize(self):
+ return "({} {})".format(
+ self.label, " ".join(child.linearize() for child in self.children))
+
+ def leaves(self):
+ for child in self.children:
+ yield from child.leaves()
+
+ def convert(self, index=0, nocache=False):
+ tree = self
+ sublabels = [self.label]
+
+ while len(tree.children) == 1 and isinstance(
+ tree.children[0], InternalTreebankNode):
+ tree = tree.children[0]
+ sublabels.append(tree.label)
+
+ pre_children = []
+ children = []
+ sub_father = set()
+ sub_head = set()
+ al_make = set()
+
+ for child in tree.children:
+ sub_head |= set([child.head])
+ sub_father |= set([child.father])
+
+ for child in tree.children:
+ # not in sub tree
+ if (child.father in sub_head and child.father != self.head) or (
+ child.head in sub_father and child.head != self.head):
+ sub_r = child.father
+ if child.head in sub_father:
+ sub_r = child.head
+ if sub_r not in al_make:
+ al_make |= set([sub_r])
+ else:
+ continue
+ sub_children = []
+ for sub_child in tree.children:
+ if sub_child.father == sub_r or sub_child.head == sub_r:
+ if len(sub_children) > 0:
+ assert sub_children[-1].right == sub_child.left # contiune span
+ sub_children.append(sub_child.convert(index=index))
+ index = sub_children[-1].right
+
+ assert len(sub_children) > 1
+
+ sub_node = InternalParseNode(tuple([Sub_Head]), sub_children, nocache=nocache)
+ if len(children) > 0:
+ assert children[-1].right == sub_node.left # contiune span
+ children.append(sub_node)
+ else:
+ if len(children) > 0:
+ assert children[-1].right == child.left # contiune span
+ children.append(child.convert(index=index))
+ index = children[-1].right
+
+ return InternalParseNode(tuple(sublabels), children, nocache=nocache)
+
+
+class LeafTreebankNode(TreebankNode):
+ def __init__(self, tag, word, head, father, type):
+ assert isinstance(tag, str)
+ self.tag = tag
+ self.father = father
+ self.type = type
+ self.head = head
+ assert isinstance(word, str)
+ self.word = word
+ self.left = self.head - 1
+ self.right = self.head
+
+ def linearize(self):
+ return "({} {})".format(self.tag, self.word)
+
+ def leaves(self):
+ yield self
+
+ def convert(self, index=0):
+ return LeafParseNode(index, self.tag, self.word, self.father, self.type)
+
+
+class ParseNode(object):
+ pass
+
+
+class InternalParseNode(ParseNode):
+ def __init__(self, label, children, nocache=False):
+ assert isinstance(label, tuple)
+ assert all(isinstance(sublabel, str) for sublabel in label)
+ assert label
+ self.label = label
+
+ assert isinstance(children, collections.abc.Sequence)
+ assert all(isinstance(child, ParseNode) for child in children)
+ assert children
+ assert len(children) > 1 or isinstance(children[0], LeafParseNode)
+ assert all(
+ left.right == right.left
+ for left, right in zip(children, children[1:]))
+ self.children = tuple(children)
+
+ self.left = children[0].left
+ self.right = children[-1].right
+
+ self.father = self.children[0].father
+ self.type = self.children[0].type
+ self.head = self.children[0].head
+ flag = 0
+ for child in self.children:
+ if child.father - 1 < self.left or child.father > self.right:
+ self.father = child.father
+ self.type = child.type
+ self.head = child.head
+ flag = 1
+
+ self.cun_w = 0
+ for child in self.children:
+ if self.head != child.head:
+ if child.father != self.head:
+ # child.father = self.head
+ self.cun_w += 1
+
+ self.nocache = nocache
+
+ def leaves(self):
+ for child in self.children:
+ yield from child.leaves()
+
+ def convert(self):
+ children = [child.convert() for child in self.children]
+ tree = InternalTreebankNode(self.label[-1], children)
+ for sublabel in reversed(self.label[:-1]):
+ tree = InternalTreebankNode(sublabel, [tree])
+ return tree
+
+ def enclosing(self, left, right):
+ assert self.left <= left < right <= self.right
+ for child in self.children:
+ if isinstance(child, LeafParseNode):
+ continue
+ if child.left <= left < right <= child.right:
+ return child.enclosing(left, right)
+ return self
+
+ def chil_enclosing(self, left, right):
+ assert self.left <= left < right <= self.right
+ for child in self.children:
+ if child.left <= left < right <= child.right:
+ return child.chil_enclosing(left, right)
+ return self
+
+ def oracle_label(self, left, right):
+ enclosing = self.enclosing(left, right)
+ if enclosing.left == left and enclosing.right == right:
+ return enclosing.label
+ return ()
+
+ def oracle_type(self, left, right):
+ enclosing = self.chil_enclosing(left, right)
+ return enclosing.type
+
+ def oracle_head(self, left, right):
+ enclosing = self.chil_enclosing(left, right)
+ return enclosing.head
+
+ def oracle_splits(self, left, right):
+ return [
+ child.left
+ for child in self.enclosing(left, right).children
+ if left < child.left < right
+ ]
+
+
+class LeafParseNode(ParseNode):
+ def __init__(self, index, tag, word, father, type):
+ assert isinstance(index, int)
+ assert index >= 0
+ self.left = index
+ self.right = index + 1
+
+ assert isinstance(tag, str)
+ self.tag = tag
+ self.head = index + 1
+ self.father = father
+ self.type = type
+
+ assert isinstance(word, str)
+ self.word = word
+
+ def leaves(self):
+ yield self
+
+ def chil_enclosing(self, left, right):
+ assert self.left <= left < right <= self.right
+ return self
+
+ def convert(self):
+ return LeafTreebankNode(self.tag, self.word, self.head, self.father, self.type)
+
+
+def load_trees(path, heads=None, types=None, wordss=None, strip_top=True):
+ with open(path) as infile:
+ treebank = infile.read()
+
+ return load_trees_from_str(treebank, heads, types, wordss, strip_top)
+
+
+def load_trees_from_str(treebank, heads=None, types=None, wordss=None, strip_top=True):
+ tokens = treebank.replace("(", " ( ").replace(")", " ) ").split()
+ cun_word = 0 # without root
+ cun_sent = 0
+
+ def helper(index, flag_sent):
+ nonlocal cun_sent
+ nonlocal cun_word
+ trees = []
+
+ while index < len(tokens) and tokens[index] == "(":
+ paren_count = 0
+ while tokens[index] == "(":
+ index += 1
+ paren_count += 1
+
+ label = tokens[index]
+
+ index += 1
+
+ if tokens[index] == "(":
+ children, index = helper(index, flag_sent=0)
+ if len(children) > 0:
+ tr = InternalTreebankNode(label, children)
+ trees.append(tr)
+ else:
+ word = tokens[index]
+ index += 1
+ if label != '-NONE-':
+ trees.append(LeafTreebankNode(label, word, head=cun_word + 1, father=heads[cun_sent][cun_word],
+ type=types[cun_sent][cun_word]))
+ if cun_sent < 0:
+ print(cun_sent, cun_word + 1, word, heads[cun_sent][cun_word], types[cun_sent][cun_word])
+ cun_word += 1
+
+ while paren_count > 0:
+ assert tokens[index] == ")"
+ index += 1
+ paren_count -= 1
+
+ if flag_sent == 1:
+ cun_sent += 1
+ cun_word = 0
+
+ return trees, index
+
+ trees, index = helper(0, flag_sent=1)
+ assert index == len(tokens)
+ assert len(trees) == cun_sent
+ if strip_top:
+ for i, tree in enumerate(trees):
+ if tree.label in ("TOP", "ROOT"):
+ assert len(tree.children) == 1
+ trees[i] = tree.children[0]
+
+ def process_NONE(tree):
+
+ if isinstance(tree, LeafTreebankNode):
+ label = tree.tag
+ if label == '-NONE-':
+ return None
+ else:
+ return tree
+
+ tr = []
+ label = tree.label
+ if label == '-NONE-':
+ return None
+ for node in tree.children:
+ new_node = process_NONE(node)
+ if new_node is not None:
+ tr.append(new_node)
+ if tr == []:
+ return None
+ else:
+ return InternalTreebankNode(label, tr)
+
+ new_trees = []
+ for i, tree in enumerate(trees):
+ new_tree = process_NONE(tree)
+ new_trees.append(new_tree)
+ return new_trees
diff --git a/hanlp/components/parsers/parse_alg.py b/hanlp/components/parsers/parse_alg.py
new file mode 100644
index 000000000..ef7d6eb2e
--- /dev/null
+++ b/hanlp/components/parsers/parse_alg.py
@@ -0,0 +1,310 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-04-02 23:20
+from collections import defaultdict
+from hanlp.components.parsers.chu_liu_edmonds import decode_mst
+import numpy as np
+
+
+class Tarjan:
+ """Computes Tarjan's algorithm for finding strongly connected components (cycles) of a graph"""
+
+ def __init__(self, prediction, tokens):
+ """
+
+ Parameters
+ ----------
+ prediction : numpy.ndarray
+ a predicted dependency tree where prediction[dep_idx] = head_idx
+ tokens : numpy.ndarray
+ the tokens we care about (i.e. exclude _GO, _EOS, and _PAD)
+ """
+ self._edges = defaultdict(set)
+ self._vertices = set((0,))
+ for dep, head in enumerate(prediction[tokens]):
+ self._vertices.add(dep + 1)
+ self._edges[head].add(dep + 1)
+ self._indices = {}
+ self._lowlinks = {}
+ self._onstack = defaultdict(lambda: False)
+ self._SCCs = []
+
+ index = 0
+ stack = []
+ for v in self.vertices:
+ if v not in self.indices:
+ self.strongconnect(v, index, stack)
+
+ # =============================================================
+ def strongconnect(self, v, index, stack):
+ """
+
+ Args:
+ v:
+ index:
+ stack:
+
+ Returns:
+
+ """
+
+ self._indices[v] = index
+ self._lowlinks[v] = index
+ index += 1
+ stack.append(v)
+ self._onstack[v] = True
+ for w in self.edges[v]:
+ if w not in self.indices:
+ self.strongconnect(w, index, stack)
+ self._lowlinks[v] = min(self._lowlinks[v], self._lowlinks[w])
+ elif self._onstack[w]:
+ self._lowlinks[v] = min(self._lowlinks[v], self._indices[w])
+
+ if self._lowlinks[v] == self._indices[v]:
+ self._SCCs.append(set())
+ while stack[-1] != v:
+ w = stack.pop()
+ self._onstack[w] = False
+ self._SCCs[-1].add(w)
+ w = stack.pop()
+ self._onstack[w] = False
+ self._SCCs[-1].add(w)
+ return
+
+ # ======================
+ @property
+ def edges(self):
+ return self._edges
+
+ @property
+ def vertices(self):
+ return self._vertices
+
+ @property
+ def indices(self):
+ return self._indices
+
+ @property
+ def SCCs(self):
+ return self._SCCs
+
+
+class UnionFind(object):
+
+ def __init__(self, n) -> None:
+ super().__init__()
+ self.parent = [x for x in range(n)]
+ self.height = [0] * n
+
+ def find(self, x):
+ if self.parent[x] == x:
+ return x
+ self.parent[x] = self.find(self.parent[x])
+ return self.parent[x]
+
+ def unite(self, x, y):
+ x = self.find(x)
+ y = self.find(y)
+ if x == y:
+ return
+ if self.height[x] < self.height[y]:
+ self.parent[x] = y
+ else:
+ self.parent[y] = x
+ if self.height[x] == self.height[y]:
+ self.height[x] += 1
+
+ def same(self, x, y):
+ return self.find(x) == self.find(y)
+
+
+def tarjan(parse_probs, length, tokens_to_keep, ensure_tree=True):
+ """Adopted from Timothy Dozat https://github.com/tdozat/Parser/blob/master/lib/models/nn.py
+
+ Args:
+ parse_probs(NDArray): seq_len x seq_len, the probability of arcs
+ length(NDArray): sentence length including ROOT
+ tokens_to_keep(NDArray): mask matrix
+ ensure_tree: (Default value = True)
+
+ Returns:
+
+
+ """
+ if ensure_tree:
+ parse_preds, parse_probs, tokens = unique_root(parse_probs, tokens_to_keep, length)
+ # remove cycles
+ tarjan = Tarjan(parse_preds, tokens)
+ for SCC in tarjan.SCCs:
+ if len(SCC) > 1:
+ dependents = set()
+ to_visit = set(SCC)
+ while len(to_visit) > 0:
+ node = to_visit.pop()
+ if not node in dependents:
+ dependents.add(node)
+ to_visit.update(tarjan.edges[node])
+ # The indices of the nodes that participate in the cycle
+ cycle = np.array(list(SCC))
+ # The probabilities of the current heads
+ old_heads = parse_preds[cycle]
+ old_head_probs = parse_probs[cycle, old_heads]
+ # Set the probability of depending on a non-head to zero
+ non_heads = np.array(list(dependents))
+ parse_probs[np.repeat(cycle, len(non_heads)), np.repeat([non_heads], len(cycle), axis=0).flatten()] = 0
+ # Get new potential heads and their probabilities
+ new_heads = np.argmax(parse_probs[cycle][:, tokens], axis=1) + 1
+ new_head_probs = parse_probs[cycle, new_heads] / old_head_probs
+ # Select the most probable change
+ change = np.argmax(new_head_probs)
+ changed_cycle = cycle[change]
+ old_head = old_heads[change]
+ new_head = new_heads[change]
+ # Make the change
+ parse_preds[changed_cycle] = new_head
+ tarjan.edges[new_head].add(changed_cycle)
+ tarjan.edges[old_head].remove(changed_cycle)
+ return parse_preds
+ else:
+ # block and pad heads
+ parse_probs = parse_probs * tokens_to_keep
+ parse_preds = np.argmax(parse_probs, axis=1)
+ return parse_preds
+
+
+def chu_liu_edmonds(parse_probs, length):
+ tree = decode_mst(parse_probs.T, length, False)[0]
+ tree[0] = 0
+ return tree
+
+
+def unique_root(parse_probs, tokens_to_keep: np.ndarray, length):
+ I = np.eye(len(tokens_to_keep))
+ # block loops and pad heads
+ if tokens_to_keep.ndim == 1:
+ tokens_to_keep = np.expand_dims(tokens_to_keep, -1)
+ parse_probs = parse_probs * tokens_to_keep * (1 - I)
+ parse_preds = np.argmax(parse_probs, axis=1)
+ tokens = np.arange(1, length)
+ roots = np.where(parse_preds[tokens] == 0)[0] + 1
+ # ensure at least one root
+ if len(roots) < 1:
+ # The current root probabilities
+ root_probs = parse_probs[tokens, 0]
+ # The current head probabilities
+ old_head_probs = parse_probs[tokens, parse_preds[tokens]]
+ # Get new potential root probabilities
+ new_root_probs = root_probs / old_head_probs
+ # Select the most probable root
+ new_root = tokens[np.argmax(new_root_probs)]
+ # Make the change
+ parse_preds[new_root] = 0
+ # ensure at most one root
+ elif len(roots) > 1:
+ # The probabilities of the current heads
+ root_probs = parse_probs[roots, 0]
+ # Set the probability of depending on the root zero
+ parse_probs[roots, 0] = 0
+ # Get new potential heads and their probabilities
+ new_heads = np.argmax(parse_probs[roots][:, tokens], axis=1) + 1
+ new_head_probs = parse_probs[roots, new_heads] / root_probs
+ # Select the most probable root
+ new_root = roots[np.argmin(new_head_probs)]
+ # Make the change
+ parse_preds[roots] = new_heads
+ parse_preds[new_root] = 0
+ return parse_preds, parse_probs, tokens
+
+
+def dfs(graph, start, end):
+ fringe = [(start, [])]
+ while fringe:
+ state, path = fringe.pop()
+ if path and state == end:
+ yield path
+ continue
+ for next_state in graph[state]:
+ if next_state in path:
+ continue
+ fringe.append((next_state, path + [next_state]))
+
+
+def mst_then_greedy(arc_scores, rel_scores, mask, root_rel_idx, rel_idx=None):
+ from scipy.special import softmax
+ from scipy.special import expit as sigmoid
+ length = sum(mask) + 1
+ mask = mask[:length]
+ arc_scores = arc_scores[:length, :length]
+ arc_pred = arc_scores > 0
+ arc_probs = sigmoid(arc_scores)
+ rel_scores = rel_scores[:length, :length, :]
+ rel_probs = softmax(rel_scores, -1)
+ if not any(arc_pred[:, 0][1:]): # no root
+ root = np.argmax(rel_probs[1:, 0, root_rel_idx]) + 1
+ arc_probs[root, 0] = 1
+ parse_preds, parse_probs, tokens = unique_root(arc_probs, mask, length)
+ root = adjust_root_score(arc_scores, parse_preds, root_rel_idx, rel_scores)
+ tree = chu_liu_edmonds(arc_scores, length)
+ if rel_idx is not None: # Unknown DEPREL label: 'ref'
+ rel_scores[np.arange(len(tree)), tree, rel_idx] = -float('inf')
+ return tree, add_secondary_arcs_by_scores(arc_scores, rel_scores, tree, root_rel_idx)
+
+
+def adjust_root_score(arc_scores, parse_preds, root_rel_idx, rel_scores=None):
+ root = np.where(parse_preds[1:] == 0)[0] + 1
+ arc_scores[:, 0] = min(np.min(arc_scores), -1000)
+ arc_scores[root, 0] = max(np.max(arc_scores), 1000)
+ if rel_scores is not None:
+ rel_scores[:, :, root_rel_idx] = -float('inf')
+ rel_scores[root, 0, root_rel_idx] = float('inf')
+ return root
+
+
+def add_secondary_arcs_by_scores(arc_scores, rel_scores, tree, root_rel_idx, arc_preds=None):
+ if not isinstance(tree, np.ndarray):
+ tree = np.array(tree)
+ if arc_preds is None:
+ arc_preds = arc_scores > 0
+ rel_pred = np.argmax(rel_scores, axis=-1)
+
+ return add_secondary_arcs_by_preds(arc_scores, arc_preds, rel_pred, tree, root_rel_idx)
+
+
+def add_secondary_arcs_by_preds(arc_scores, arc_preds, rel_preds, tree, root_rel_idx=None):
+ dh = np.argwhere(arc_preds)
+ sdh = sorted([(arc_scores[x[0], x[1]], list(x)) for x in dh], reverse=True)
+ graph = [[] for _ in range(len(tree))]
+ for d, h in enumerate(tree):
+ if d:
+ graph[h].append(d)
+ for s, (d, h) in sdh:
+ if not d or not h or d in graph[h]:
+ continue
+ try:
+ path = next(dfs(graph, d, h))
+ except StopIteration:
+ # no path from d to h
+ graph[h].append(d)
+ parse_graph = [[] for _ in range(len(tree))]
+ num_root = 0
+ for h in range(len(tree)):
+ for d in graph[h]:
+ rel = rel_preds[d, h]
+ if h == 0 and root_rel_idx is not None:
+ rel = root_rel_idx
+ assert num_root == 0
+ num_root += 1
+ parse_graph[d].append((h, rel))
+ parse_graph[d] = sorted(parse_graph[d])
+ return parse_graph
+
+
+def adjust_root_score_then_add_secondary_arcs(arc_scores, rel_scores, tree, root_rel_idx):
+ if len(arc_scores) != tree:
+ arc_scores = arc_scores[:len(tree), :len(tree)]
+ rel_scores = rel_scores[:len(tree), :len(tree), :]
+ parse_preds = arc_scores > 0
+ # adjust_root_score(arc_scores, parse_preds, rel_scores)
+ parse_preds[:, 0] = False # set heads to False
+ rel_scores[:, :, root_rel_idx] = -float('inf')
+ return add_secondary_arcs_by_scores(arc_scores, rel_scores, tree, root_rel_idx, parse_preds)
diff --git a/hanlp/components/parsers/second_order/__init__.py b/hanlp/components/parsers/second_order/__init__.py
new file mode 100644
index 000000000..ca29b2ab2
--- /dev/null
+++ b/hanlp/components/parsers/second_order/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-01 13:44
diff --git a/hanlp/components/parsers/second_order/affine.py b/hanlp/components/parsers/second_order/affine.py
new file mode 100644
index 000000000..e58fe1480
--- /dev/null
+++ b/hanlp/components/parsers/second_order/affine.py
@@ -0,0 +1,171 @@
+# MIT License
+#
+# Copyright (c) 2020 Yu Zhang
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+
+import torch
+import torch.nn as nn
+
+
+class Biaffine(nn.Module):
+ r"""
+ Biaffine layer for first-order scoring.
+
+ This function has a tensor of weights :math:`W` and bias terms if needed.
+ The score :math:`s(x, y)` of the vector pair :math:`(x, y)` is computed as :math:`x^T W y`,
+ in which :math:`x` and :math:`y` can be concatenated with bias terms.
+
+ References:
+ - Timothy Dozat and Christopher D. Manning. 2017.
+ `Deep Biaffine Attention for Neural Dependency Parsing`_.
+
+ Args:
+ n_in (int):
+ The size of the input feature.
+ n_out (int):
+ The number of output channels.
+ bias_x (bool):
+ If ``True``, adds a bias term for tensor :math:`x`. Default: ``True``.
+ bias_y (bool):
+ If ``True``, adds a bias term for tensor :math:`y`. Default: ``True``.
+
+ .. _Deep Biaffine Attention for Neural Dependency Parsing:
+ https://openreview.net/forum?id=Hk95PK9le
+ """
+
+ def __init__(self, n_in, n_out=1, bias_x=True, bias_y=True):
+ super().__init__()
+
+ self.n_in = n_in
+ self.n_out = n_out
+ self.bias_x = bias_x
+ self.bias_y = bias_y
+ self.weight = nn.Parameter(torch.Tensor(n_out, n_in+bias_x, n_in+bias_y))
+
+ self.reset_parameters()
+
+ def __repr__(self):
+ s = f"n_in={self.n_in}, n_out={self.n_out}"
+ if self.bias_x:
+ s += f", bias_x={self.bias_x}"
+ if self.bias_y:
+ s += f", bias_y={self.bias_y}"
+
+ return f"{self.__class__.__name__}({s})"
+
+ def reset_parameters(self):
+ nn.init.zeros_(self.weight)
+
+ def forward(self, x, y):
+ r"""
+ Args:
+ x (torch.Tensor): ``[batch_size, seq_len, n_in]``.
+ y (torch.Tensor): ``[batch_size, seq_len, n_in]``.
+
+ Returns:
+ ~torch.Tensor:
+ A scoring tensor of shape ``[batch_size, n_out, seq_len, seq_len]``.
+ If ``n_out=1``, the dimension for ``n_out`` will be squeezed automatically.
+ """
+
+ if self.bias_x:
+ x = torch.cat((x, torch.ones_like(x[..., :1])), -1)
+ if self.bias_y:
+ y = torch.cat((y, torch.ones_like(y[..., :1])), -1)
+ # [batch_size, n_out, seq_len, seq_len]
+ s = torch.einsum('bxi,oij,byj->boxy', x, self.weight, y)
+ # remove dim 1 if n_out == 1
+ s = s.squeeze(1)
+
+ return s
+
+
+class Triaffine(nn.Module):
+ r"""
+ Triaffine layer for second-order scoring.
+
+ This function has a tensor of weights :math:`W` and bias terms if needed.
+ The score :math:`s(x, y, z)` of the vector triple :math:`(x, y, z)` is computed as :math:`x^T z^T W y`.
+ Usually, :math:`x` and :math:`y` can be concatenated with bias terms.
+
+ References:
+ - Yu Zhang, Zhenghua Li and Min Zhang. 2020.
+ `Efficient Second-Order TreeCRF for Neural Dependency Parsing`_.
+ - Xinyu Wang, Jingxian Huang, and Kewei Tu. 2019.
+ `Second-Order Semantic Dependency Parsing with End-to-End Neural Networks`_.
+
+ Args:
+ n_in (int):
+ The size of the input feature.
+ bias_x (bool):
+ If ``True``, adds a bias term for tensor :math:`x`. Default: ``False``.
+ bias_y (bool):
+ If ``True``, adds a bias term for tensor :math:`y`. Default: ``False``.
+
+ .. _Efficient Second-Order TreeCRF for Neural Dependency Parsing:
+ https://www.aclweb.org/anthology/2020.acl-main.302/
+ .. _Second-Order Semantic Dependency Parsing with End-to-End Neural Networks:
+ https://www.aclweb.org/anthology/P19-1454/
+ """
+
+ def __init__(self, n_in, bias_x=False, bias_y=False):
+ super().__init__()
+
+ self.n_in = n_in
+ self.bias_x = bias_x
+ self.bias_y = bias_y
+ self.weight = nn.Parameter(torch.Tensor(n_in+bias_x, n_in, n_in+bias_y))
+
+ self.reset_parameters()
+
+ def __repr__(self):
+ s = f"n_in={self.n_in}"
+ if self.bias_x:
+ s += f", bias_x={self.bias_x}"
+ if self.bias_y:
+ s += f", bias_y={self.bias_y}"
+
+ return f"{self.__class__.__name__}({s})"
+
+ def reset_parameters(self):
+ nn.init.zeros_(self.weight)
+
+ def forward(self, x, y, z):
+ r"""
+ Args:
+ x (torch.Tensor): ``[batch_size, seq_len, n_in]``.
+ y (torch.Tensor): ``[batch_size, seq_len, n_in]``.
+ z (torch.Tensor): ``[batch_size, seq_len, n_in]``.
+
+ Returns:
+ ~torch.Tensor:
+ A scoring tensor of shape ``[batch_size, seq_len, seq_len, seq_len]``.
+ """
+
+ if self.bias_x:
+ x = torch.cat((x, torch.ones_like(x[..., :1])), -1)
+ if self.bias_y:
+ y = torch.cat((y, torch.ones_like(y[..., :1])), -1)
+ w = torch.einsum('bzk,ikj->bzij', z, self.weight)
+ # [batch_size, seq_len, seq_len, seq_len]
+ s = torch.einsum('bxi,bzij,byj->bzxy', x, w, y)
+
+ return s
diff --git a/hanlp/components/parsers/second_order/model.py b/hanlp/components/parsers/second_order/model.py
new file mode 100644
index 000000000..d47a51406
--- /dev/null
+++ b/hanlp/components/parsers/second_order/model.py
@@ -0,0 +1,18 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-01 15:28
+from torch import nn
+
+
+# noinspection PyAbstractClass
+class DependencyModel(nn.Module):
+ def __init__(self, embed: nn.Module, encoder: nn.Module, decoder: nn.Module):
+ super().__init__()
+ self.embed = embed
+ self.encoder = encoder
+ self.decoder = decoder
+
+ def forward(self, batch, mask):
+ x = self.embed(batch, mask=mask)
+ x = self.encoder(x, mask)
+ return self.decoder(x, mask=mask)
diff --git a/hanlp/components/parsers/second_order/tree_crf_dependency_parser.py b/hanlp/components/parsers/second_order/tree_crf_dependency_parser.py
new file mode 100644
index 000000000..0882b349e
--- /dev/null
+++ b/hanlp/components/parsers/second_order/tree_crf_dependency_parser.py
@@ -0,0 +1,541 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-05-08 20:51
+import functools
+import os
+from typing import Union, Any, List
+
+import torch
+from alnlp.modules.util import lengths_to_mask
+from torch import nn
+from torch.optim import Adam
+from torch.optim.lr_scheduler import ExponentialLR
+from torch.utils.data import DataLoader
+
+from hanlp_common.constant import UNK, IDX
+from hanlp.common.dataset import PadSequenceDataLoader
+from hanlp.common.structure import History
+from hanlp.common.torch_component import TorchComponent
+from hanlp.common.transform import LowerCase, FieldLength, PunctuationMask, TransformList
+from hanlp.common.vocab import Vocab, VocabCounter
+from hanlp_common.conll import CoNLLWord, CoNLLSentence
+from hanlp.components.parsers.constituency.treecrf import CRF2oDependency
+from hanlp.components.parsers.second_order.model import DependencyModel
+from hanlp.components.parsers.second_order.treecrf_decoder import TreeCRFDecoder
+from hanlp.datasets.parsing.conll_dataset import CoNLLParsingDataset, append_bos, get_sibs
+from hanlp.layers.embeddings.contextual_word_embedding import ContextualWordEmbedding, ContextualWordEmbeddingModule
+from hanlp.layers.embeddings.embedding import Embedding, EmbeddingList, ConcatModuleList
+from hanlp.layers.embeddings.util import index_word2vec_with_vocab
+from hanlp.layers.transformers.pt_imports import AutoModel_
+from hanlp.layers.transformers.utils import build_optimizer_scheduler_with_transformer
+from hanlp.metrics.parsing.attachmentscore import AttachmentScore
+from hanlp.transform.transformer_tokenizer import TransformerSequenceTokenizer
+from hanlp.utils.time_util import CountdownTimer
+from hanlp_common.util import merge_locals_kwargs, merge_dict, reorder
+
+
+class TreeConditionalRandomFieldDependencyParser(TorchComponent):
+ def __init__(self) -> None:
+ super().__init__()
+ self.model: DependencyModel = self.model
+ self._transformer_transform = None
+
+ def predict(self, data: Any, batch_size=None, batch_max_tokens=None, output_format='conllx', **kwargs):
+ if not data:
+ return []
+ use_pos = self.use_pos
+ flat = self.input_is_flat(data, use_pos)
+ if flat:
+ data = [data]
+ samples = self.build_samples(data, use_pos)
+ if not batch_max_tokens:
+ batch_max_tokens = self.config.batch_max_tokens
+ if not batch_size:
+ batch_size = self.config.batch_size
+ dataloader = self.build_dataloader(samples,
+ device=self.devices[0], shuffle=False,
+ **merge_dict(self.config,
+ batch_size=batch_size,
+ batch_max_tokens=batch_max_tokens,
+ overwrite=True,
+ **kwargs))
+ predictions, build_data, data, order = self.before_outputs(data)
+ for batch in dataloader:
+ arc_scores, rel_scores, mask, puncts = self.feed_batch(batch)
+ self.collect_outputs(arc_scores, rel_scores, mask, batch, predictions, order, data, use_pos,
+ build_data)
+ outputs = self.post_outputs(predictions, data, order, use_pos, build_data)
+ if flat:
+ return outputs[0]
+ return outputs
+
+ def build_samples(self, data, use_pos=None):
+ samples = []
+ for idx, each in enumerate(data):
+ sample = {IDX: idx}
+ if use_pos:
+ token, pos = zip(*each)
+ sample.update({'FORM': list(token), 'CPOS': list(pos)})
+ else:
+ token = each
+ sample.update({'FORM': list(token)})
+ samples.append(sample)
+ return samples
+
+ def input_is_flat(self, data, use_pos=None):
+ if use_pos:
+ flat = isinstance(data[0], (list, tuple)) and isinstance(data[0][0], str)
+ else:
+ flat = isinstance(data[0], str)
+ return flat
+
+ def before_outputs(self, data):
+ predictions, order = [], []
+ build_data = data is None
+ if build_data:
+ data = []
+ return predictions, build_data, data, order
+
+ def post_outputs(self, predictions, data, order, use_pos, build_data):
+ predictions = reorder(predictions, order)
+ if build_data:
+ data = reorder(data, order)
+ outputs = []
+ self.predictions_to_human(predictions, outputs, data, use_pos)
+ return outputs
+
+ def predictions_to_human(self, predictions, outputs, data, use_pos):
+ for d, (arcs, rels) in zip(data, predictions):
+ sent = CoNLLSentence()
+ for idx, (cell, a, r) in enumerate(zip(d, arcs, rels)):
+ if use_pos:
+ token, pos = cell
+ else:
+ token, pos = cell, None
+ sent.append(CoNLLWord(idx + 1, token, cpos=pos, head=a, deprel=self.vocabs['rel'][r]))
+ outputs.append(sent)
+
+ def collect_outputs(self, arc_scores, rel_scores, mask, batch, predictions, order, data, use_pos,
+ build_data):
+ lens = [len(token) - 1 for token in batch['token']]
+ arc_preds, rel_preds = self.decode(arc_scores, rel_scores, mask, batch)
+ self.collect_outputs_extend(predictions, arc_preds, rel_preds, lens, mask)
+ order.extend(batch[IDX])
+ if build_data:
+ if use_pos:
+ data.extend(zip(batch['FORM'], batch['CPOS']))
+ else:
+ data.extend(batch['FORM'])
+
+ def collect_outputs_extend(self, predictions: list, arc_preds, rel_preds, lens, mask):
+ predictions.extend(zip([seq.tolist() for seq in arc_preds[mask].split(lens)],
+ [seq.tolist() for seq in rel_preds[mask].split(lens)]))
+
+ def fit(self,
+ trn_data,
+ dev_data,
+ save_dir,
+ embed,
+ n_mlp_arc=500,
+ n_mlp_rel=100,
+ n_mlp_sib=100,
+ mlp_dropout=.33,
+ lr=2e-3,
+ transformer_lr=5e-5,
+ mu=.9,
+ nu=.9,
+ epsilon=1e-12,
+ grad_norm=5.0,
+ decay=.75,
+ decay_steps=5000,
+ weight_decay=0,
+ warmup_steps=0.1,
+ separate_optimizer=True,
+ patience=100,
+ lowercase=False,
+ epochs=50000,
+ tree=False,
+ proj=True,
+ mbr=True,
+ partial=False,
+ punct=False,
+ min_freq=2,
+ logger=None,
+ verbose=True,
+ unk=UNK,
+ max_sequence_length=512,
+ batch_size=None,
+ sampler_builder=None,
+ gradient_accumulation=1,
+ devices: Union[float, int, List[int]] = None,
+ transform=None,
+ eval_trn=False,
+ bos='\0',
+ **kwargs):
+ return super().fit(**merge_locals_kwargs(locals(), kwargs))
+
+ def execute_training_loop(self, trn, dev, devices, epochs, logger, patience, save_dir, optimizer,
+ gradient_accumulation, **kwargs):
+ optimizer, scheduler, transformer_optimizer, transformer_scheduler = optimizer
+ criterion = self.build_criterion()
+ best_e, best_metric = 0, self.build_metric()
+ timer = CountdownTimer(epochs)
+ history = History()
+ ratio_width = len(f'{len(trn) // gradient_accumulation}/{len(trn) // gradient_accumulation}')
+ for epoch in range(1, epochs + 1):
+ # train one epoch and update the parameters
+ logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
+ self.fit_dataloader(trn, optimizer, scheduler, criterion, epoch, logger, history,
+ transformer_optimizer, transformer_scheduler,
+ gradient_accumulation=gradient_accumulation, eval_trn=self.config.eval_trn)
+ loss, dev_metric = self.evaluate_dataloader(dev, criterion, ratio_width=ratio_width, logger=logger)
+ timer.update()
+ # logger.info(f"{'Dev' + ' ' * ratio_width} loss: {loss:.4f} {dev_metric}")
+ # save the model if it is the best so far
+ report = f"{timer.elapsed_human} / {timer.total_time_human} ETA: {timer.eta_human}"
+ if dev_metric > best_metric:
+ best_e, best_metric = epoch, dev_metric
+ self.save_weights(save_dir)
+ report += ' ([red]saved[/red])'
+ else:
+ if patience != epochs:
+ report += f' ({epoch - best_e}/{patience})'
+ else:
+ report += f' ({epoch - best_e})'
+ logger.info(report)
+ if patience is not None and epoch - best_e >= patience:
+ logger.info(f'LAS has stopped improving for {patience} epochs, early stop.')
+ break
+ timer.stop()
+ if not best_e:
+ self.save_weights(save_dir)
+ elif best_e != epoch:
+ self.load_weights(save_dir)
+ logger.info(f"Max score of dev is {best_metric.score:.2%} at epoch {best_e}")
+ logger.info(f"Average time of each epoch is {timer.elapsed_average_human}")
+ logger.info(f"{timer.elapsed_human} elapsed")
+
+ def build_optimizer(self, epochs, trn, gradient_accumulation, **kwargs):
+ config = self.config
+ model = self.model
+ if isinstance(model, nn.DataParallel):
+ model = model.module
+ transformer = self._get_transformer_builder()
+ if transformer and transformer.trainable:
+ transformer = self._get_transformer()
+ optimizer = Adam(set(model.parameters()) - set(transformer.parameters()),
+ config.lr,
+ (config.mu, config.nu),
+ config.epsilon)
+ if self.config.transformer_lr:
+ num_training_steps = len(trn) * epochs // gradient_accumulation
+ if not self.config.separate_optimizer:
+ optimizer, scheduler = build_optimizer_scheduler_with_transformer(model,
+ transformer,
+ config.lr,
+ config.transformer_lr,
+ num_training_steps,
+ config.warmup_steps,
+ config.weight_decay,
+ config.epsilon)
+ transformer_optimizer, transformer_scheduler = None, None
+ else:
+ transformer_optimizer, transformer_scheduler = \
+ build_optimizer_scheduler_with_transformer(transformer,
+ transformer,
+ config.lr,
+ config.transformer_lr,
+ num_training_steps,
+ config.warmup_steps,
+ config.weight_decay,
+ config.epsilon)
+ else:
+ transformer.requires_grad_(False)
+ transformer_optimizer, transformer_scheduler = None, None
+ else:
+ optimizer = Adam(model.parameters(),
+ config.lr,
+ (config.mu, config.nu),
+ config.epsilon)
+ transformer_optimizer, transformer_scheduler = None, None
+ if self.config.separate_optimizer:
+ scheduler = ExponentialLR(optimizer, config.decay ** (1 / config.decay_steps))
+ # noinspection PyUnboundLocalVariable
+ optimizer = Adam(model.parameters(), **{'lr': 0.002, 'betas': (0.9, 0.9), 'eps': 1e-12})
+ scheduler = ExponentialLR(optimizer, **{'gamma': 0.9999424652406974})
+ return optimizer, scheduler, transformer_optimizer, transformer_scheduler
+
+ # noinspection PyMethodOverriding
+ def build_dataloader(self,
+ data,
+ shuffle,
+ device,
+ embed: Embedding,
+ training=False,
+ logger=None,
+ gradient_accumulation=1,
+ sampler_builder=None,
+ batch_size=None,
+ bos='\0',
+ **kwargs) -> DataLoader:
+ first_transform = TransformList(functools.partial(append_bos, bos=bos))
+ embed_transform = embed.transform(vocabs=self.vocabs)
+ transformer_transform = self._get_transformer_transform_from_transforms(embed_transform)
+ if embed_transform:
+ if transformer_transform and isinstance(embed_transform, TransformList):
+ embed_transform.remove(transformer_transform)
+
+ first_transform.append(embed_transform)
+ dataset = self.build_dataset(data, first_transform=first_transform)
+ if self.config.get('transform', None):
+ dataset.append_transform(self.config.transform)
+
+ if self.vocabs.mutable:
+ self.build_vocabs(dataset, logger, self._transformer_trainable())
+ if transformer_transform and isinstance(embed_transform, TransformList):
+ embed_transform.append(transformer_transform)
+
+ dataset.append_transform(FieldLength('token', 'sent_length'))
+ if isinstance(data, str):
+ dataset.purge_cache()
+ if len(dataset) > 1000 and isinstance(data, str):
+ timer = CountdownTimer(len(dataset))
+ self.cache_dataset(dataset, timer, training, logger)
+ if sampler_builder:
+ lens = [sample['sent_length'] for sample in dataset]
+ sampler = sampler_builder.build(lens, shuffle, gradient_accumulation)
+ else:
+ sampler = None
+ loader = PadSequenceDataLoader(dataset=dataset,
+ batch_sampler=sampler,
+ batch_size=batch_size,
+ pad=self.get_pad_dict(),
+ device=device,
+ vocabs=self.vocabs)
+ return loader
+
+ def cache_dataset(self, dataset, timer, training=False, logger=None):
+ for each in dataset:
+ timer.log('Preprocessing and caching samples [blink][yellow]...[/yellow][/blink]')
+
+ def get_pad_dict(self):
+ return {'arc': 0}
+
+ def build_dataset(self, data, first_transform=None):
+ if not first_transform:
+ first_transform = append_bos
+ transform = [first_transform, get_sibs]
+ if self.config.get('lowercase', False):
+ transform.append(LowerCase('token'))
+ transform.append(self.vocabs)
+ if not self.config.punct:
+ transform.append(PunctuationMask('token', 'punct_mask'))
+ return CoNLLParsingDataset(data, transform=transform)
+
+ def build_tokenizer_transform(self):
+ return TransformerSequenceTokenizer(self.transformer_tokenizer, 'token', '',
+ ret_token_span=True, cls_is_bos=True,
+ max_seq_length=self.config.get('max_sequence_length',
+ 512),
+ truncate_long_sequences=False)
+
+ def build_vocabs(self, dataset, logger=None, transformer=False):
+ rel_vocab = self.vocabs.get('rel', None)
+ if rel_vocab is None:
+ rel_vocab = Vocab(unk_token=None, pad_token=self.config.get('pad_rel', None))
+ self.vocabs.put(rel=rel_vocab)
+
+ timer = CountdownTimer(len(dataset))
+ if transformer:
+ token_vocab = None
+ else:
+ self.vocabs.token = token_vocab = VocabCounter(unk_token=self.config.get('unk', UNK))
+ for i, sample in enumerate(dataset):
+ timer.log('Building vocab [blink][yellow]...[/yellow][/blink]', ratio_percentage=True)
+ min_freq = self.config.get('min_freq', None)
+ if min_freq:
+ token_vocab.trim(min_freq)
+ rel_vocab.set_unk_as_safe_unk() # Some relation in dev set is OOV
+ self.vocabs.lock()
+ self.vocabs.summary(logger=logger)
+ if token_vocab:
+ self.config.n_words = len(self.vocabs['token'])
+ self.config.n_rels = len(self.vocabs['rel'])
+ if token_vocab:
+ self.config.pad_index = self.vocabs['token'].pad_idx
+ self.config.unk_index = self.vocabs['token'].unk_idx
+
+ # noinspection PyMethodOverriding
+ def build_model(self, embed: Embedding, encoder, n_mlp_arc, n_mlp_rel, mlp_dropout, n_mlp_sib, training=True,
+ **kwargs) -> torch.nn.Module:
+ model = DependencyModel(
+ embed=embed.module(vocabs=self.vocabs),
+ encoder=encoder,
+ decoder=TreeCRFDecoder(encoder.get_output_dim(), n_mlp_arc, n_mlp_sib, n_mlp_rel, mlp_dropout,
+ len(self.vocabs['rel']))
+ )
+ return model
+
+ def build_embeddings(self, training=True):
+ pretrained_embed = None
+ if self.config.get('pretrained_embed', None):
+ pretrained_embed = index_word2vec_with_vocab(self.config.pretrained_embed, self.vocabs['token'],
+ init='zeros', normalize=True)
+ transformer = self.config.transformer
+ if transformer:
+ transformer = AutoModel_.from_pretrained(transformer, training=training)
+ return pretrained_embed, transformer
+
+ # noinspection PyMethodOverriding
+ def fit_dataloader(self,
+ trn,
+ optimizer,
+ scheduler,
+ criterion,
+ epoch,
+ logger,
+ history: History,
+ transformer_optimizer=None,
+ transformer_scheduler=None,
+ gradient_accumulation=1,
+ eval_trn=False,
+ **kwargs):
+ self.model.train()
+
+ timer = CountdownTimer(history.num_training_steps(len(trn), gradient_accumulation))
+ metric = self.build_metric(training=True)
+ total_loss = 0
+ for idx, batch in enumerate(trn):
+ optimizer.zero_grad()
+ (s_arc, s_sib, s_rel), mask, puncts = self.feed_batch(batch)
+ arcs, sibs, rels = batch['arc'], batch['sib_id'], batch['rel_id']
+
+ loss, s_arc = self.compute_loss(s_arc, s_sib, s_rel, arcs, sibs, rels, mask)
+ if gradient_accumulation > 1:
+ loss /= gradient_accumulation
+ loss.backward()
+ total_loss += loss.item()
+ if eval_trn:
+ arc_preds, rel_preds = self.decode(s_arc, s_sib, s_rel, mask)
+ self.update_metric(arc_preds, rel_preds, arcs, rels, mask, puncts, metric)
+ if history.step(gradient_accumulation):
+ self._step(optimizer, scheduler, transformer_optimizer, transformer_scheduler)
+ report = self._report(total_loss / (timer.current + 1), metric if eval_trn else None)
+ lr = scheduler.get_last_lr()[0]
+ report += f' lr: {lr:.4e}'
+ timer.log(report, ratio_percentage=False, logger=logger)
+ del loss
+
+ def _step(self, optimizer, scheduler, transformer_optimizer, transformer_scheduler):
+ if self.config.get('grad_norm', None):
+ nn.utils.clip_grad_norm_(self.model.parameters(),
+ self.config.grad_norm)
+ optimizer.step()
+ scheduler.step()
+ if self._transformer_transform and self.config.transformer_lr and transformer_optimizer:
+ transformer_optimizer.step()
+ transformer_optimizer.zero_grad()
+ transformer_scheduler.step()
+
+ def feed_batch(self, batch):
+ words, feats, lens, puncts = batch.get('token_id', None), batch.get('pos_id', None), batch['sent_length'], \
+ batch.get('punct_mask', None)
+ mask = lengths_to_mask(lens)
+ logits = self.model(batch, mask)
+ if self.model.training:
+ mask = mask.clone()
+ # ignore the first token of each sentence
+ mask[:, 0] = 0
+ return logits, mask, puncts
+
+ def _report(self, loss, metric: AttachmentScore = None):
+ return f'loss: {loss:.4f} {metric}' if metric else f'loss: {loss:.4f}'
+
+ def compute_loss(self, s_arc, s_sib, s_rel, arcs, sibs, rels, mask):
+ crf: CRF2oDependency = self.model.decoder.crf
+ return crf.loss(s_arc, s_sib, s_rel, arcs, sibs, rels, mask, self.config.mbr, self.config.partial)
+
+ # noinspection PyUnboundLocalVariable
+ @torch.no_grad()
+ def evaluate_dataloader(self, loader: PadSequenceDataLoader, criterion, logger=None, filename=None, output=False,
+ ratio_width=None,
+ metric=None,
+ **kwargs):
+ self.model.eval()
+
+ total_loss = 0
+ if not metric:
+ metric = self.build_metric()
+
+ timer = CountdownTimer(len(loader))
+ for batch in loader:
+ (s_arc, s_sib, s_rel), mask, puncts = self.feed_batch(batch)
+ arcs, sibs, rels = batch['arc'], batch['sib_id'], batch['rel_id']
+ loss, s_arc = self.compute_loss(s_arc, s_sib, s_rel, arcs, sibs, rels, mask)
+ total_loss += float(loss)
+ arc_preds, rel_preds = self.decode(s_arc, s_sib, s_rel, mask)
+ self.update_metric(arc_preds, rel_preds, arcs, rels, mask, puncts, metric)
+ report = self._report(total_loss / (timer.current + 1), metric)
+ if filename:
+ report = f'{os.path.basename(filename)} ' + report
+ timer.log(report, ratio_percentage=False, logger=logger, ratio_width=ratio_width)
+ total_loss /= len(loader)
+
+ return total_loss, metric
+
+ def update_metric(self, arc_preds, rel_preds, arcs, rels, mask, puncts, metric):
+ # ignore all punctuation if not specified
+ if not self.config.punct:
+ mask &= puncts
+ metric(arc_preds, rel_preds, arcs, rels, mask)
+
+ def decode(self, s_arc, s_sib, s_rel, mask):
+ crf: CRF2oDependency = self.model.decoder.crf
+ return crf.decode(s_arc, s_sib, s_rel, mask, self.config.tree and not self.model.training, self.config.mbr,
+ self.config.proj)
+
+ def build_criterion(self, **kwargs):
+ return None
+
+ def build_metric(self, **kwargs):
+ return AttachmentScore()
+
+ def _get_transformer_transform_from_transforms(self, transform: Union[
+ TransformList, TransformerSequenceTokenizer]) -> TransformerSequenceTokenizer:
+ def _get():
+ if isinstance(transform, TransformerSequenceTokenizer):
+ # noinspection PyTypeChecker
+ return transform
+ elif isinstance(transform, TransformList):
+ # noinspection PyTypeChecker,PyArgumentList
+ for each in transform:
+ if isinstance(each, TransformerSequenceTokenizer):
+ return each
+
+ if self._transformer_transform is None:
+ self._transformer_transform = _get()
+ return self._transformer_transform
+
+ def _get_transformer(self):
+ embed = self.model.embed
+ if isinstance(embed, ContextualWordEmbeddingModule):
+ return embed
+ if isinstance(embed, ConcatModuleList):
+ for each in embed:
+ if isinstance(each, ContextualWordEmbeddingModule):
+ return each
+
+ def _get_transformer_builder(self):
+ embed: Embedding = self.config.embed
+ if isinstance(embed, ContextualWordEmbedding):
+ return embed
+ if isinstance(embed, EmbeddingList):
+ for each in embed.to_list():
+ if isinstance(embed, ContextualWordEmbedding):
+ return each
+
+ def _transformer_trainable(self):
+ builder = self._get_transformer_builder()
+ if not builder:
+ return False
+ return builder.trainable
diff --git a/hanlp/components/parsers/second_order/treecrf_decoder.py b/hanlp/components/parsers/second_order/treecrf_decoder.py
new file mode 100644
index 000000000..f36a99c0b
--- /dev/null
+++ b/hanlp/components/parsers/second_order/treecrf_decoder.py
@@ -0,0 +1,31 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-01 16:51
+from typing import Any, Tuple
+
+import torch
+
+from hanlp.components.parsers.biaffine.biaffine_model import BiaffineDecoder
+from hanlp.components.parsers.biaffine.mlp import MLP
+from hanlp.components.parsers.constituency.treecrf import CRF2oDependency
+from hanlp.components.parsers.second_order.affine import Triaffine
+
+
+class TreeCRFDecoder(BiaffineDecoder):
+ def __init__(self, hidden_size, n_mlp_arc, n_mlp_sib, n_mlp_rel, mlp_dropout, n_rels) -> None:
+ super().__init__(hidden_size, n_mlp_arc, n_mlp_rel, mlp_dropout, n_rels)
+ self.mlp_sib_s = MLP(hidden_size, n_mlp_sib, dropout=mlp_dropout)
+ self.mlp_sib_d = MLP(hidden_size, n_mlp_sib, dropout=mlp_dropout)
+ self.mlp_sib_h = MLP(hidden_size, n_mlp_sib, dropout=mlp_dropout)
+
+ self.sib_attn = Triaffine(n_in=n_mlp_sib, bias_x=True, bias_y=True)
+ self.crf = CRF2oDependency()
+
+ def forward(self, x, mask=None, **kwargs: Any) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ s_arc, s_rel = super(TreeCRFDecoder, self).forward(x, mask)
+ sib_s = self.mlp_sib_s(x)
+ sib_d = self.mlp_sib_d(x)
+ sib_h = self.mlp_sib_h(x)
+ # [batch_size, seq_len, seq_len, seq_len]
+ s_sib = self.sib_attn(sib_s, sib_d, sib_h).permute(0, 3, 1, 2)
+ return s_arc, s_sib, s_rel
diff --git a/hanlp/components/parsers/ud/__init__.py b/hanlp/components/parsers/ud/__init__.py
new file mode 100644
index 000000000..231b31fc3
--- /dev/null
+++ b/hanlp/components/parsers/ud/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-14 20:34
diff --git a/hanlp/components/parsers/ud/lemma_edit.py b/hanlp/components/parsers/ud/lemma_edit.py
new file mode 100644
index 000000000..3967d907d
--- /dev/null
+++ b/hanlp/components/parsers/ud/lemma_edit.py
@@ -0,0 +1,131 @@
+"""
+Utilities for processing lemmas
+
+Adopted from UDPipe Future
+https://github.com/CoNLL-UD-2018/UDPipe-Future
+"""
+
+
+def min_edit_script(source, target, allow_copy=False):
+ """Finds the minimum edit script to transform the source to the target
+
+ Args:
+ source:
+ target:
+ allow_copy: (Default value = False)
+
+ Returns:
+
+ """
+ a = [[(len(source) + len(target) + 1, None)] * (len(target) + 1) for _ in range(len(source) + 1)]
+ for i in range(0, len(source) + 1):
+ for j in range(0, len(target) + 1):
+ if i == 0 and j == 0:
+ a[i][j] = (0, "")
+ else:
+ if allow_copy and i and j and source[i - 1] == target[j - 1] and a[i-1][j-1][0] < a[i][j][0]:
+ a[i][j] = (a[i-1][j-1][0], a[i-1][j-1][1] + "→")
+ if i and a[i-1][j][0] < a[i][j][0]:
+ a[i][j] = (a[i-1][j][0] + 1, a[i-1][j][1] + "-")
+ if j and a[i][j-1][0] < a[i][j][0]:
+ a[i][j] = (a[i][j-1][0] + 1, a[i][j-1][1] + "+" + target[j - 1])
+ return a[-1][-1][1]
+
+
+def gen_lemma_rule(form, lemma, allow_copy=False):
+ """Generates a lemma rule to transform the source to the target
+
+ Args:
+ form:
+ lemma:
+ allow_copy: (Default value = False)
+
+ Returns:
+
+ """
+ form = form.lower()
+
+ previous_case = -1
+ lemma_casing = ""
+ for i, c in enumerate(lemma):
+ case = "↑" if c.lower() != c else "↓"
+ if case != previous_case:
+ lemma_casing += "{}{}{}".format("¦" if lemma_casing else "", case, i if i <= len(lemma) // 2 else i - len(lemma))
+ previous_case = case
+ lemma = lemma.lower()
+
+ best, best_form, best_lemma = 0, 0, 0
+ for l in range(len(lemma)):
+ for f in range(len(form)):
+ cpl = 0
+ while f + cpl < len(form) and l + cpl < len(lemma) and form[f + cpl] == lemma[l + cpl]: cpl += 1
+ if cpl > best:
+ best = cpl
+ best_form = f
+ best_lemma = l
+
+ rule = lemma_casing + ";"
+ if not best:
+ rule += "a" + lemma
+ else:
+ rule += "d{}¦{}".format(
+ min_edit_script(form[:best_form], lemma[:best_lemma], allow_copy),
+ min_edit_script(form[best_form + best:], lemma[best_lemma + best:], allow_copy),
+ )
+ return rule
+
+
+def apply_lemma_rule(form, lemma_rule):
+ """Applies the lemma rule to the form to generate the lemma
+
+ Args:
+ form:
+ lemma_rule:
+
+ Returns:
+
+ """
+ casing, rule = lemma_rule.split(";", 1)
+ if rule.startswith("a"):
+ lemma = rule[1:]
+ else:
+ form = form.lower()
+ rules, rule_sources = rule[1:].split("¦"), []
+ assert len(rules) == 2
+ for rule in rules:
+ source, i = 0, 0
+ while i < len(rule):
+ if rule[i] == "→" or rule[i] == "-":
+ source += 1
+ else:
+ assert rule[i] == "+"
+ i += 1
+ i += 1
+ rule_sources.append(source)
+
+ try:
+ lemma, form_offset = "", 0
+ for i in range(2):
+ j, offset = 0, (0 if i == 0 else len(form) - rule_sources[1])
+ while j < len(rules[i]):
+ if rules[i][j] == "→":
+ lemma += form[offset]
+ offset += 1
+ elif rules[i][j] == "-":
+ offset += 1
+ else:
+ assert(rules[i][j] == "+")
+ lemma += rules[i][j + 1]
+ j += 1
+ j += 1
+ if i == 0:
+ lemma += form[rule_sources[0]: len(form) - rule_sources[1]]
+ except:
+ lemma = form
+
+ for rule in casing.split("¦"):
+ if rule == "↓0": continue # The lemma is lowercased initially
+ case, offset = rule[0], int(rule[1:])
+ lemma = lemma[:offset] + (lemma[offset:].upper() if case == "↑" else lemma[offset:].lower())
+
+ return lemma
diff --git a/hanlp/components/parsers/ud/tag_decoder.py b/hanlp/components/parsers/ud/tag_decoder.py
new file mode 100644
index 000000000..d9fc8100f
--- /dev/null
+++ b/hanlp/components/parsers/ud/tag_decoder.py
@@ -0,0 +1,137 @@
+# This file is modified from udify, which is licensed under the MIT license:
+# MIT License
+#
+# Copyright (c) 2019 Dan Kondratyuk
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+"""
+Decodes sequences of tags, e.g., POS tags, given a list of contextualized word embeddings
+"""
+
+from typing import Dict
+
+import numpy
+import torch
+import torch.nn.functional as F
+from alnlp.metrics import CategoricalAccuracy
+from torch.nn.modules.adaptive import AdaptiveLogSoftmaxWithLoss
+from torch.nn.modules.linear import Linear
+
+from hanlp.components.parsers.ud.lemma_edit import apply_lemma_rule
+from hanlp.components.parsers.ud.udify_util import sequence_cross_entropy, sequence_cross_entropy_with_logits
+
+
+class TagDecoder(torch.nn.Module):
+ """A basic sequence tagger that decodes from inputs of word embeddings"""
+
+ def __init__(self,
+ input_dim,
+ num_classes,
+ label_smoothing: float = 0.03,
+ adaptive: bool = False) -> None:
+ super(TagDecoder, self).__init__()
+
+ self.label_smoothing = label_smoothing
+ self.num_classes = num_classes
+ self.adaptive = adaptive
+
+ if self.adaptive:
+ adaptive_cutoffs = [round(self.num_classes / 15), 3 * round(self.num_classes / 15)]
+ self.task_output = AdaptiveLogSoftmaxWithLoss(input_dim,
+ self.num_classes,
+ cutoffs=adaptive_cutoffs,
+ div_value=4.0)
+ else:
+ self.task_output = Linear(self.output_dim, self.num_classes)
+
+ def forward(self,
+ encoded_text: torch.FloatTensor,
+ mask: torch.LongTensor,
+ gold_tags: torch.LongTensor,
+ ) -> Dict[str, torch.Tensor]:
+ hidden = encoded_text
+
+ batch_size, sequence_length, _ = hidden.size()
+ output_dim = [batch_size, sequence_length, self.num_classes]
+
+ loss_fn = self._adaptive_loss if self.adaptive else self._loss
+
+ output_dict = loss_fn(hidden, mask, gold_tags, output_dim)
+
+ return output_dict
+
+ def _adaptive_loss(self, hidden, mask, gold_tags, output_dim):
+ logits = hidden
+ reshaped_log_probs = logits.reshape(-1, logits.size(2))
+
+ class_probabilities = self.task_output.log_prob(reshaped_log_probs).view(output_dim)
+
+ output_dict = {"logits": logits, "class_probabilities": class_probabilities}
+
+ if gold_tags is not None:
+ output_dict["loss"] = sequence_cross_entropy(class_probabilities,
+ gold_tags,
+ mask,
+ label_smoothing=self.label_smoothing)
+
+ return output_dict
+
+ def _loss(self, hidden, mask, gold_tags, output_dim):
+ logits = self.task_output(hidden)
+ reshaped_log_probs = logits.view(-1, self.num_classes)
+ class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view(output_dim)
+
+ output_dict = {"logits": logits, "class_probabilities": class_probabilities}
+
+ if gold_tags is not None:
+ output_dict["loss"] = sequence_cross_entropy_with_logits(logits,
+ gold_tags,
+ mask,
+ label_smoothing=self.label_smoothing)
+ return output_dict
+
+ def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ all_words = output_dict["words"]
+
+ all_predictions = output_dict["class_probabilities"][self.task].cpu().data.numpy()
+ if all_predictions.ndim == 3:
+ predictions_list = [all_predictions[i] for i in range(all_predictions.shape[0])]
+ else:
+ predictions_list = [all_predictions]
+ all_tags = []
+ for predictions, words in zip(predictions_list, all_words):
+ argmax_indices = numpy.argmax(predictions, axis=-1)
+ tags = [self.vocab.get_token_from_index(x, namespace=self.task)
+ for x in argmax_indices]
+
+ if self.task == "lemmas":
+ def decode_lemma(word, rule):
+ if rule == "_":
+ return "_"
+ if rule == "@@UNKNOWN@@":
+ return word
+ return apply_lemma_rule(word, rule)
+
+ tags = [decode_lemma(word, rule) for word, rule in zip(words, tags)]
+
+ all_tags.append(tags)
+ output_dict[self.task] = all_tags
+
+ return output_dict
diff --git a/hanlp/components/parsers/ud/ud_model.py b/hanlp/components/parsers/ud/ud_model.py
new file mode 100644
index 000000000..729f49c1a
--- /dev/null
+++ b/hanlp/components/parsers/ud/ud_model.py
@@ -0,0 +1,139 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-15 14:21
+
+from typing import Dict, Any
+
+import torch
+
+from hanlp.components.parsers.biaffine.biaffine_dep import BiaffineDependencyParser
+from hanlp.components.parsers.biaffine.biaffine_model import BiaffineDecoder
+from hanlp.components.parsers.ud.tag_decoder import TagDecoder
+from hanlp.layers.embeddings.contextual_word_embedding import ContextualWordEmbeddingModule
+from hanlp.layers.scalar_mix import ScalarMixWithDropout
+
+
+class UniversalDependenciesModel(torch.nn.Module):
+ def __init__(self,
+ encoder: ContextualWordEmbeddingModule,
+ n_mlp_arc,
+ n_mlp_rel,
+ mlp_dropout,
+ num_rels,
+ num_lemmas,
+ num_upos,
+ num_feats,
+ mix_embedding: int = 13,
+ layer_dropout: int = 0.0):
+ super().__init__()
+ self.encoder = encoder
+ self.decoder = UniversalDependenciesDecoder(
+ encoder.get_output_dim(),
+ n_mlp_arc,
+ n_mlp_rel,
+ mlp_dropout,
+ num_rels,
+ num_lemmas,
+ num_upos,
+ num_feats,
+ mix_embedding,
+ layer_dropout
+ )
+
+ def forward(self,
+ batch: Dict[str, torch.Tensor],
+ mask,
+ ):
+ hidden = self.encoder(batch)
+ return self.decoder(hidden, batch=batch, mask=mask)
+
+
+class UniversalDependenciesDecoder(torch.nn.Module):
+ def __init__(self,
+ hidden_size,
+ n_mlp_arc,
+ n_mlp_rel,
+ mlp_dropout,
+ num_rels,
+ num_lemmas,
+ num_upos,
+ num_feats,
+ mix_embedding: int = 13,
+ layer_dropout: int = 0.0,
+ ) -> None:
+ super(UniversalDependenciesDecoder, self).__init__()
+
+ # decoders
+ self.decoders = torch.nn.ModuleDict({
+ 'lemmas': TagDecoder(hidden_size, num_lemmas, label_smoothing=0.03, adaptive=True),
+ 'upos': TagDecoder(hidden_size, num_upos, label_smoothing=0.03, adaptive=True),
+ 'deps': BiaffineDecoder(hidden_size, n_mlp_arc, n_mlp_rel, mlp_dropout, num_rels),
+ 'feats': TagDecoder(hidden_size, num_feats, label_smoothing=0.03, adaptive=True),
+ })
+ self.gold_keys = {
+ 'lemmas': 'lemma_id',
+ 'upos': 'pos_id',
+ 'feats': 'feat_id',
+ }
+
+ if mix_embedding:
+ self.scalar_mix = torch.nn.ModuleDict({
+ task: ScalarMixWithDropout((1, mix_embedding),
+ do_layer_norm=False,
+ dropout=layer_dropout)
+ for task in self.decoders
+ })
+ else:
+ self.scalar_mix = None
+
+ def forward(self,
+ hidden,
+ batch: Dict[str, torch.Tensor],
+ mask) -> Dict[str, Any]:
+ mask_without_root = mask.clone()
+ mask_without_root[:, 0] = False
+
+ logits = {}
+ class_probabilities = {}
+ output_dict = {"logits": logits,
+ "class_probabilities": class_probabilities}
+ loss = 0
+
+ arc = batch.get('arc', None)
+ # Run through each of the tasks on the shared encoder and save predictions
+ for task in self.decoders:
+ if self.scalar_mix:
+ decoder_input = self.scalar_mix[task](hidden, mask)
+ else:
+ decoder_input = hidden
+
+ if task == "deps":
+ s_arc, s_rel = self.decoders[task](decoder_input, mask)
+ pred_output = {'class_probabilities': {'s_arc': s_arc, 's_rel': s_rel}}
+ if arc is not None:
+ # noinspection PyTypeChecker
+ pred_output['loss'] = BiaffineDependencyParser.compute_loss(None, s_arc, s_rel, arc,
+ batch['rel_id'],
+ mask_without_root,
+ torch.nn.functional.cross_entropy)
+ else:
+ pred_output = self.decoders[task](decoder_input, mask_without_root,
+ batch.get(self.gold_keys[task], None))
+ if 'logits' in pred_output:
+ logits[task] = pred_output["logits"]
+ if 'class_probabilities' in pred_output:
+ class_probabilities[task] = pred_output["class_probabilities"]
+ if 'loss' in pred_output:
+ # Keep track of the loss if we have the gold tags available
+ loss += pred_output["loss"]
+
+ if arc is not None:
+ output_dict["loss"] = loss
+
+ return output_dict
+
+ def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ for task in self.tasks:
+ self.decoders[task].decode(output_dict)
+
+ return output_dict
diff --git a/hanlp/components/parsers/ud/ud_parser.py b/hanlp/components/parsers/ud/ud_parser.py
new file mode 100644
index 000000000..074dd6a96
--- /dev/null
+++ b/hanlp/components/parsers/ud/ud_parser.py
@@ -0,0 +1,347 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-14 20:34
+import logging
+from copy import deepcopy
+from typing import Union, List, Callable
+
+import torch
+from alnlp.modules.util import lengths_to_mask
+from torch.utils.data import DataLoader
+
+from hanlp_common.constant import IDX
+from hanlp.common.dataset import PadSequenceDataLoader, SortingSamplerBuilder
+from hanlp.common.structure import History
+from hanlp.common.torch_component import TorchComponent
+from hanlp.common.transform import FieldLength, PunctuationMask
+from hanlp.common.vocab import Vocab
+from hanlp.components.classifiers.transformer_classifier import TransformerComponent
+from hanlp.components.parsers.biaffine.biaffine_dep import BiaffineDependencyParser
+from hanlp_common.conll import CoNLLUWord, CoNLLSentence
+from hanlp.components.parsers.ud.ud_model import UniversalDependenciesModel
+from hanlp.components.parsers.ud.util import generate_lemma_rule, append_bos, sample_form_missing
+from hanlp.components.parsers.ud.lemma_edit import apply_lemma_rule
+from hanlp.datasets.parsing.conll_dataset import CoNLLParsingDataset
+from hanlp.layers.embeddings.contextual_word_embedding import ContextualWordEmbedding
+from hanlp.metrics.accuracy import CategoricalAccuracy
+from hanlp.metrics.metric import Metric
+from hanlp.metrics.mtl import MetricDict
+from hanlp.metrics.parsing.attachmentscore import AttachmentScore
+from hanlp.utils.time_util import CountdownTimer
+from hanlp.utils.torch_util import clip_grad_norm
+from hanlp_common.util import merge_locals_kwargs, merge_dict, reorder
+
+
+class UniversalDependenciesParser(TorchComponent):
+
+ def __init__(self, **kwargs) -> None:
+ """Universal Dependencies Parsing (lemmatization, features, PoS tagging and dependency parsing) implementation
+ of "75 Languages, 1 Model: Parsing Universal Dependencies Universally" (:cite:`kondratyuk-straka-2019-75`).
+
+ Args:
+ **kwargs: Predefined config.
+ """
+ super().__init__(**kwargs)
+ self.model: UniversalDependenciesModel = self.model
+
+ def build_dataloader(self,
+ data,
+ batch_size,
+ shuffle=False,
+ device=None,
+ logger: logging.Logger = None,
+ sampler_builder=None,
+ gradient_accumulation=1,
+ transformer: ContextualWordEmbedding = None,
+ **kwargs) -> DataLoader:
+ transform = [generate_lemma_rule, append_bos, self.vocabs, transformer.transform(), FieldLength('token')]
+ if not self.config.punct:
+ transform.append(PunctuationMask('token', 'punct_mask'))
+ dataset = self.build_dataset(data, transform)
+ if self.vocabs.mutable:
+ # noinspection PyTypeChecker
+ self.build_vocabs(dataset, logger)
+ lens = [len(x['token_input_ids']) for x in dataset]
+ if sampler_builder:
+ sampler = sampler_builder.build(lens, shuffle, gradient_accumulation)
+ else:
+ sampler = SortingSamplerBuilder(batch_size).build(lens, shuffle, gradient_accumulation)
+ return PadSequenceDataLoader(dataset, batch_size, shuffle, device=device, batch_sampler=sampler,
+ pad={'arc': 0}, )
+
+ def build_vocabs(self, trn, logger, **kwargs):
+ self.vocabs.pos = Vocab(unk_token=None, pad_token=None)
+ self.vocabs.rel = Vocab(unk_token=None, pad_token=None)
+ self.vocabs.lemma = Vocab(unk_token=None, pad_token=None)
+ self.vocabs.feat = Vocab(unk_token=None, pad_token=None)
+ timer = CountdownTimer(len(trn))
+ max_seq_len = 0
+ for each in trn:
+ max_seq_len = max(max_seq_len, len(each['token']))
+ timer.log(f'Building vocab [blink][yellow]...[/yellow][/blink] (longest sequence: {max_seq_len})')
+ for v in self.vocabs.values():
+ v.set_unk_as_safe_unk()
+ self.vocabs.lock()
+ self.vocabs.summary(logger)
+
+ def build_dataset(self, data, transform):
+ dataset = CoNLLParsingDataset(data, transform=transform, prune=sample_form_missing, cache=isinstance(data, str))
+ return dataset
+
+ def build_optimizer(self, trn, **kwargs):
+ # noinspection PyCallByClass,PyTypeChecker
+ return TransformerComponent.build_optimizer(self, trn, **kwargs)
+
+ def build_criterion(self, **kwargs):
+ pass
+
+ def build_metric(self, **kwargs):
+ return MetricDict({
+ 'lemmas': CategoricalAccuracy(),
+ 'upos': CategoricalAccuracy(),
+ 'deps': AttachmentScore(),
+ 'feats': CategoricalAccuracy(),
+ })
+
+ def evaluate_dataloader(self,
+ data: DataLoader,
+ criterion: Callable,
+ metric: MetricDict = None,
+ output=False,
+ logger=None,
+ ratio_width=None,
+ **kwargs):
+
+ metric.reset()
+ self.model.eval()
+ timer = CountdownTimer(len(data))
+ total_loss = 0
+ for idx, batch in enumerate(data):
+ out, mask = self.feed_batch(batch)
+ loss = out['loss']
+ total_loss += loss.item()
+ self.decode_output(out, mask, batch)
+ self.update_metrics(metric, batch, out, mask)
+ report = f'loss: {total_loss / (idx + 1):.4f} {metric.cstr()}'
+ timer.log(report, logger=logger, ratio_percentage=False, ratio_width=ratio_width)
+ del loss
+ del out
+ del mask
+ return total_loss / len(data), metric
+
+ # noinspection PyMethodOverriding
+ def build_model(self,
+ transformer: ContextualWordEmbedding,
+ n_mlp_arc,
+ n_mlp_rel,
+ mlp_dropout,
+ mix_embedding,
+ layer_dropout,
+ training=True,
+ **kwargs) -> torch.nn.Module:
+ assert bool(transformer.scalar_mix) == bool(mix_embedding), 'transformer.scalar_mix has to be 1 ' \
+ 'when mix_embedding is non-zero.'
+ # noinspection PyTypeChecker
+ return UniversalDependenciesModel(transformer.module(training=training),
+ n_mlp_arc,
+ n_mlp_rel,
+ mlp_dropout,
+ len(self.vocabs.rel),
+ len(self.vocabs.lemma),
+ len(self.vocabs.pos),
+ len(self.vocabs.feat),
+ mix_embedding,
+ layer_dropout)
+
+ def predict(self, data: Union[List[str], List[List[str]]], batch_size: int = None, **kwargs):
+ if not data:
+ return []
+ flat = self.input_is_flat(data)
+ if flat:
+ data = [data]
+ samples = self.build_samples(data)
+ if not batch_size:
+ batch_size = self.config.batch_size
+ dataloader = self.build_dataloader(samples,
+ device=self.devices[0], shuffle=False,
+ **merge_dict(self.config,
+ batch_size=batch_size,
+ overwrite=True,
+ **kwargs))
+ order = []
+ outputs = []
+ for batch in dataloader:
+ out, mask = self.feed_batch(batch)
+ self.decode_output(out, mask, batch)
+ outputs.extend(self.prediction_to_human(out, batch))
+ order.extend(batch[IDX])
+ outputs = reorder(outputs, order)
+ if flat:
+ return outputs[0]
+ return outputs
+
+ def build_samples(self, data: List[List[str]]):
+ return [{'FORM': x} for x in data]
+
+ def fit(self,
+ trn_data,
+ dev_data,
+ save_dir,
+ transformer: ContextualWordEmbedding,
+ sampler_builder=None,
+ mix_embedding: int = 13,
+ layer_dropout: int = 0.1,
+ n_mlp_arc=768,
+ n_mlp_rel=256,
+ mlp_dropout=.33,
+ lr=1e-3,
+ transformer_lr=2.5e-5,
+ patience=0.1,
+ batch_size=32,
+ epochs=30,
+ gradient_accumulation=1,
+ adam_epsilon=1e-8,
+ weight_decay=0,
+ warmup_steps=0.1,
+ grad_norm=1.0,
+ tree=False,
+ proj=False,
+ punct=False,
+ logger=None,
+ verbose=True,
+ devices: Union[float, int, List[int]] = None, **kwargs):
+ return super().fit(**merge_locals_kwargs(locals(), kwargs))
+
+ def execute_training_loop(self, trn: DataLoader, dev: DataLoader, epochs, criterion, optimizer, metric, save_dir,
+ logger: logging.Logger, devices, ratio_width=None, patience=0.5, eval_trn=True, **kwargs):
+ if isinstance(patience, float):
+ patience = int(patience * epochs)
+ best_epoch, best_metric = 0, -1
+ timer = CountdownTimer(epochs)
+ history = History()
+ for epoch in range(1, epochs + 1):
+ logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
+ self.fit_dataloader(trn, criterion, optimizer, metric, logger, history=history, ratio_width=ratio_width,
+ eval_trn=eval_trn, **self.config)
+ loss, dev_metric = self.evaluate_dataloader(dev, criterion, metric, logger=logger, ratio_width=ratio_width)
+ timer.update()
+ report = f"{timer.elapsed_human} / {timer.total_time_human} ETA: {timer.eta_human}"
+ if dev_metric > best_metric:
+ best_epoch, best_metric = epoch, deepcopy(dev_metric)
+ self.save_weights(save_dir)
+ report += ' [red](saved)[/red]'
+ else:
+ report += f' ({epoch - best_epoch})'
+ if epoch - best_epoch >= patience:
+ report += ' early stop'
+ logger.info(report)
+ if epoch - best_epoch >= patience:
+ break
+ if not best_epoch:
+ self.save_weights(save_dir)
+ elif best_epoch != epoch:
+ self.load_weights(save_dir)
+ logger.info(f"Max score of dev is {best_metric.cstr()} at epoch {best_epoch}")
+ logger.info(f"Average time of each epoch is {timer.elapsed_average_human}")
+ logger.info(f"{timer.elapsed_human} elapsed")
+
+ # noinspection PyMethodOverriding
+ def fit_dataloader(self,
+ trn: DataLoader,
+ criterion,
+ optimizer,
+ metric: MetricDict,
+ logger: logging.Logger,
+ history: History,
+ gradient_accumulation=1,
+ grad_norm=None,
+ ratio_width=None,
+ eval_trn=True,
+ **kwargs):
+ optimizer, scheduler = optimizer
+ metric.reset()
+ self.model.train()
+ timer = CountdownTimer(history.num_training_steps(len(trn), gradient_accumulation=gradient_accumulation))
+ total_loss = 0
+ for idx, batch in enumerate(trn):
+ out, mask = self.feed_batch(batch)
+ loss = out['loss']
+ if gradient_accumulation and gradient_accumulation > 1:
+ loss /= gradient_accumulation
+ loss.backward()
+ total_loss += loss.item()
+ if eval_trn:
+ self.decode_output(out, mask, batch)
+ self.update_metrics(metric, batch, out, mask)
+ if history.step(gradient_accumulation):
+ self._step(optimizer, scheduler, grad_norm)
+ report = f'loss: {total_loss / (idx + 1):.4f} {metric.cstr()}' if eval_trn \
+ else f'loss: {total_loss / (idx + 1):.4f}'
+ timer.log(report, logger=logger, ratio_percentage=False, ratio_width=ratio_width)
+ del loss
+ del out
+ del mask
+
+ def decode_output(self, outputs, mask, batch):
+ arc_scores, rel_scores = outputs['class_probabilities']['deps']['s_arc'], \
+ outputs['class_probabilities']['deps']['s_rel']
+ arc_preds, rel_preds = BiaffineDependencyParser.decode(self, arc_scores, rel_scores, mask, batch)
+ outputs['arc_preds'], outputs['rel_preds'] = arc_preds, rel_preds
+ return outputs
+
+ def update_metrics(self, metrics, batch, outputs, mask):
+ arc_preds, rel_preds, puncts = outputs['arc_preds'], outputs['rel_preds'], batch.get('punct_mask', None)
+ BiaffineDependencyParser.update_metric(self, arc_preds, rel_preds, batch['arc'], batch['rel_id'], mask, puncts,
+ metrics['deps'], batch)
+ for task, key in zip(['lemmas', 'upos', 'feats'], ['lemma_id', 'pos_id', 'feat_id']):
+ metric: Metric = metrics[task]
+ pred = outputs['class_probabilities'][task]
+ gold = batch[key]
+ metric(pred.detach(), gold, mask=mask)
+ return metrics
+
+ def feed_batch(self, batch: dict):
+ mask = self.compute_mask(batch)
+ output_dict = self.model(batch, mask)
+ if self.model.training:
+ mask = mask.clone()
+ mask[:, 0] = 0
+ return output_dict, mask
+
+ def compute_mask(self, batch):
+ lens = batch['token_length']
+ mask = lengths_to_mask(lens)
+ return mask
+
+ def _step(self, optimizer, scheduler, grad_norm):
+ clip_grad_norm(self.model, grad_norm)
+ optimizer.step()
+ scheduler.step()
+ optimizer.zero_grad()
+
+ def input_is_flat(self, data):
+ # noinspection PyCallByClass,PyTypeChecker
+ return BiaffineDependencyParser.input_is_flat(self, data, False)
+
+ def prediction_to_human(self, outputs: dict, batch):
+ arcs, rels = outputs['arc_preds'], outputs['rel_preds']
+ upos = outputs['class_probabilities']['upos'][:, 1:, :].argmax(-1).tolist()
+ feats = outputs['class_probabilities']['feats'][:, 1:, :].argmax(-1).tolist()
+ lemmas = outputs['class_probabilities']['lemmas'][:, 1:, :].argmax(-1).tolist()
+ lem_vocab = self.vocabs['lemma'].idx_to_token
+ pos_vocab = self.vocabs['pos'].idx_to_token
+ feat_vocab = self.vocabs['feat'].idx_to_token
+ # noinspection PyCallByClass,PyTypeChecker
+ for tree, form, lemma, pos, feat in zip(BiaffineDependencyParser.prediction_to_head_rel(
+ self, arcs, rels, batch), batch['token'], lemmas, upos, feats):
+ form = form[1:]
+ assert len(form) == len(tree)
+ lemma = [apply_lemma_rule(t, lem_vocab[r]) for t, r in zip(form, lemma)]
+ pos = [pos_vocab[x] for x in pos]
+ feat = [feat_vocab[x] for x in feat]
+ yield CoNLLSentence(
+ [CoNLLUWord(id=i + 1, form=fo, lemma=l, upos=p, feats=fe, head=a, deprel=r) for
+ i, (fo, (a, r), l, p, fe) in enumerate(zip(form, tree, lemma, pos, feat))])
+
+ def __call__(self, data, batch_size=None, **kwargs) -> Union[CoNLLSentence, List[CoNLLSentence]]:
+ return super().__call__(data, batch_size, **kwargs)
diff --git a/hanlp/components/parsers/ud/udify_util.py b/hanlp/components/parsers/ud/udify_util.py
new file mode 100644
index 000000000..9dbbb7a7c
--- /dev/null
+++ b/hanlp/components/parsers/ud/udify_util.py
@@ -0,0 +1,369 @@
+# This file is modified from udify and allennlp, which are licensed under the MIT license:
+# MIT License
+#
+# Copyright (c) 2019 Dan Kondratyuk and allennlp
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import os
+from typing import List, Dict, Tuple, Union
+
+import numpy
+import torch
+
+
+def get_ud_treebank_files(dataset_dir: str, treebanks: List[str] = None) -> Dict[str, Tuple[str, str, str]]:
+ """Retrieves all treebank data paths in the given directory.
+ Adopted from https://github.com/Hyperparticle/udify
+ MIT Licence
+
+ Args:
+ dataset_dir:
+ treebanks:
+ dataset_dir: str:
+ treebanks: List[str]: (Default value = None)
+
+ Returns:
+
+
+ """
+ datasets = {}
+ treebanks = os.listdir(dataset_dir) if not treebanks else treebanks
+ for treebank in treebanks:
+ treebank_path = os.path.join(dataset_dir, treebank)
+ conllu_files = [file for file in sorted(os.listdir(treebank_path)) if file.endswith(".conllu")]
+
+ train_file = [file for file in conllu_files if file.endswith("train.conllu")]
+ dev_file = [file for file in conllu_files if file.endswith("dev.conllu")]
+ test_file = [file for file in conllu_files if file.endswith("test.conllu")]
+
+ train_file = os.path.join(treebank_path, train_file[0]) if train_file else None
+ dev_file = os.path.join(treebank_path, dev_file[0]) if dev_file else None
+ test_file = os.path.join(treebank_path, test_file[0]) if test_file else None
+
+ datasets[treebank] = (train_file, dev_file, test_file)
+ return datasets
+
+
+def sequence_cross_entropy(log_probs: torch.FloatTensor,
+ targets: torch.LongTensor,
+ weights: torch.FloatTensor,
+ average: str = "batch",
+ label_smoothing: float = None) -> torch.FloatTensor:
+ if average not in {None, "token", "batch"}:
+ raise ValueError("Got average f{average}, expected one of "
+ "None, 'token', or 'batch'")
+ # shape : (batch * sequence_length, num_classes)
+ log_probs_flat = log_probs.view(-1, log_probs.size(2))
+ # shape : (batch * max_len, 1)
+ targets_flat = targets.view(-1, 1).long()
+
+ if label_smoothing is not None and label_smoothing > 0.0:
+ num_classes = log_probs.size(-1)
+ smoothing_value = label_smoothing / num_classes
+ # Fill all the correct indices with 1 - smoothing value.
+ one_hot_targets = torch.zeros_like(log_probs_flat).scatter_(-1, targets_flat, 1.0 - label_smoothing)
+ smoothed_targets = one_hot_targets + smoothing_value
+ negative_log_likelihood_flat = - log_probs_flat * smoothed_targets
+ negative_log_likelihood_flat = negative_log_likelihood_flat.sum(-1, keepdim=True)
+ else:
+ # Contribution to the negative log likelihood only comes from the exact indices
+ # of the targets, as the target distributions are one-hot. Here we use torch.gather
+ # to extract the indices of the num_classes dimension which contribute to the loss.
+ # shape : (batch * sequence_length, 1)
+ negative_log_likelihood_flat = - torch.gather(log_probs_flat, dim=1, index=targets_flat)
+ # shape : (batch, sequence_length)
+ negative_log_likelihood = negative_log_likelihood_flat.view(*targets.size())
+ # shape : (batch, sequence_length)
+ negative_log_likelihood = negative_log_likelihood * weights.float()
+
+ if average == "batch":
+ # shape : (batch_size,)
+ per_batch_loss = negative_log_likelihood.sum(1) / (weights.sum(1).float() + 1e-13)
+ num_non_empty_sequences = ((weights.sum(1) > 0).float().sum() + 1e-13)
+ return per_batch_loss.sum() / num_non_empty_sequences
+ elif average == "token":
+ return negative_log_likelihood.sum() / (weights.sum().float() + 1e-13)
+ else:
+ # shape : (batch_size,)
+ per_batch_loss = negative_log_likelihood.sum(1) / (weights.sum(1).float() + 1e-13)
+ return per_batch_loss
+
+
+def sequence_cross_entropy_with_logits(
+ logits: torch.FloatTensor,
+ targets: torch.LongTensor,
+ weights: Union[torch.FloatTensor, torch.BoolTensor],
+ average: str = "batch",
+ label_smoothing: float = None,
+ gamma: float = None,
+ alpha: Union[float, List[float], torch.FloatTensor] = None,
+) -> torch.FloatTensor:
+ """Computes the cross entropy loss of a sequence, weighted with respect to
+ some user provided weights. Note that the weighting here is not the same as
+ in the `torch.nn.CrossEntropyLoss()` criterion, which is weighting
+ classes; here we are weighting the loss contribution from particular elements
+ in the sequence. This allows loss computations for models which use padding.
+
+ # Parameters
+
+ logits : `torch.FloatTensor`, required.
+ A `torch.FloatTensor` of size (batch_size, sequence_length, num_classes)
+ which contains the unnormalized probability for each class.
+ targets : `torch.LongTensor`, required.
+ A `torch.LongTensor` of size (batch, sequence_length) which contains the
+ index of the true class for each corresponding step.
+ weights : `Union[torch.FloatTensor, torch.BoolTensor]`, required.
+ A `torch.FloatTensor` of size (batch, sequence_length)
+ average: `str`, optional (default = `"batch"`)
+ If "batch", average the loss across the batches. If "token", average
+ the loss across each item in the input. If `None`, return a vector
+ of losses per batch element.
+ label_smoothing : `float`, optional (default = `None`)
+ Whether or not to apply label smoothing to the cross-entropy loss.
+ For example, with a label smoothing value of 0.2, a 4 class classification
+ target would look like `[0.05, 0.05, 0.85, 0.05]` if the 3rd class was
+ the correct label.
+ gamma : `float`, optional (default = `None`)
+ Focal loss[*] focusing parameter `gamma` to reduces the relative loss for
+ well-classified examples and put more focus on hard. The greater value
+ `gamma` is, the more focus on hard examples.
+ alpha : `Union[float, List[float]]`, optional (default = `None`)
+ Focal loss[*] weighting factor `alpha` to balance between classes. Can be
+ used independently with `gamma`. If a single `float` is provided, it
+ is assumed binary case using `alpha` and `1 - alpha` for positive and
+ negative respectively. If a list of `float` is provided, with the same
+ length as the number of classes, the weights will match the classes.
+ [*] T. Lin, P. Goyal, R. Girshick, K. He and P. Dollár, "Focal Loss for
+ Dense Object Detection," 2017 IEEE International Conference on Computer
+ Vision (ICCV), Venice, 2017, pp. 2999-3007.
+
+ # Returns
+
+ `torch.FloatTensor`
+ A torch.FloatTensor representing the cross entropy loss.
+ If `average=="batch"` or `average=="token"`, the returned loss is a scalar.
+ If `average is None`, the returned loss is a vector of shape (batch_size,).
+
+ Args:
+ logits: torch.FloatTensor:
+ targets: torch.LongTensor:
+ weights: Union[torch.FloatTensor:
+ torch.BoolTensor]:
+ average: str: (Default value = "batch")
+ label_smoothing: float: (Default value = None)
+ gamma: float: (Default value = None)
+ alpha: Union[float:
+ List[float]:
+ torch.FloatTensor]: (Default value = None)
+
+ Returns:
+
+ """
+ if average not in {None, "token", "batch"}:
+ raise ValueError("Got average f{average}, expected one of None, 'token', or 'batch'")
+
+ # make sure weights are float
+ weights = weights.to(logits.dtype)
+ # sum all dim except batch
+ non_batch_dims = tuple(range(1, len(weights.shape)))
+ # shape : (batch_size,)
+ weights_batch_sum = weights.sum(dim=non_batch_dims)
+ # shape : (batch * sequence_length, num_classes)
+ logits_flat = logits.view(-1, logits.size(-1))
+ # shape : (batch * sequence_length, num_classes)
+ log_probs_flat = torch.nn.functional.log_softmax(logits_flat, dim=-1)
+ # shape : (batch * max_len, 1)
+ targets_flat = targets.view(-1, 1).long()
+ # focal loss coefficient
+ if gamma:
+ # shape : (batch * sequence_length, num_classes)
+ probs_flat = log_probs_flat.exp()
+ # shape : (batch * sequence_length,)
+ probs_flat = torch.gather(probs_flat, dim=1, index=targets_flat)
+ # shape : (batch * sequence_length,)
+ focal_factor = (1.0 - probs_flat) ** gamma
+ # shape : (batch, sequence_length)
+ focal_factor = focal_factor.view(*targets.size())
+ weights = weights * focal_factor
+
+ if alpha is not None:
+ # shape : () / (num_classes,)
+ if isinstance(alpha, (float, int)):
+
+ # shape : (2,)
+ alpha_factor = torch.tensor(
+ [1.0 - float(alpha), float(alpha)], dtype=weights.dtype, device=weights.device
+ )
+
+ elif isinstance(alpha, (list, numpy.ndarray, torch.Tensor)):
+
+ # shape : (c,)
+ alpha_factor = torch.tensor(alpha, dtype=weights.dtype, device=weights.device)
+
+ if not alpha_factor.size():
+ # shape : (1,)
+ alpha_factor = alpha_factor.view(1)
+ # shape : (2,)
+ alpha_factor = torch.cat([1 - alpha_factor, alpha_factor])
+ else:
+ raise TypeError(
+ ("alpha must be float, list of float, or torch.FloatTensor, {} provided.").format(
+ type(alpha)
+ )
+ )
+ # shape : (batch, max_len)
+ alpha_factor = torch.gather(alpha_factor, dim=0, index=targets_flat.view(-1)).view(
+ *targets.size()
+ )
+ weights = weights * alpha_factor
+
+ if label_smoothing is not None and label_smoothing > 0.0:
+ num_classes = logits.size(-1)
+ smoothing_value = label_smoothing / num_classes
+ # Fill all the correct indices with 1 - smoothing value.
+ one_hot_targets = torch.zeros_like(log_probs_flat).scatter_(
+ -1, targets_flat, 1.0 - label_smoothing
+ )
+ smoothed_targets = one_hot_targets + smoothing_value
+ negative_log_likelihood_flat = -log_probs_flat * smoothed_targets
+ negative_log_likelihood_flat = negative_log_likelihood_flat.sum(-1, keepdim=True)
+ else:
+ # Contribution to the negative log likelihood only comes from the exact indices
+ # of the targets, as the target distributions are one-hot. Here we use torch.gather
+ # to extract the indices of the num_classes dimension which contribute to the loss.
+ # shape : (batch * sequence_length, 1)
+ negative_log_likelihood_flat = -torch.gather(log_probs_flat, dim=1, index=targets_flat)
+ # shape : (batch, sequence_length)
+ negative_log_likelihood = negative_log_likelihood_flat.view(*targets.size())
+ # shape : (batch, sequence_length)
+ negative_log_likelihood = negative_log_likelihood * weights
+
+ if average == "batch":
+ # shape : (batch_size,)
+ per_batch_loss = negative_log_likelihood.sum(non_batch_dims) / (
+ weights_batch_sum + tiny_value_of_dtype(negative_log_likelihood.dtype)
+ )
+ num_non_empty_sequences = (weights_batch_sum > 0).sum() + tiny_value_of_dtype(
+ negative_log_likelihood.dtype
+ )
+ return per_batch_loss.sum() / num_non_empty_sequences
+ elif average == "token":
+ return negative_log_likelihood.sum() / (
+ weights_batch_sum.sum() + tiny_value_of_dtype(negative_log_likelihood.dtype)
+ )
+ else:
+ # shape : (batch_size,)
+ per_batch_loss = negative_log_likelihood.sum(non_batch_dims) / (
+ weights_batch_sum + tiny_value_of_dtype(negative_log_likelihood.dtype)
+ )
+ return per_batch_loss
+
+
+def tiny_value_of_dtype(dtype: torch.dtype):
+ """Returns a moderately tiny value for a given PyTorch data type that is used to avoid numerical
+ issues such as division by zero.
+ This is different from `info_value_of_dtype(dtype).tiny` because it causes some NaN bugs.
+ Only supports floating point dtypes.
+
+ Args:
+ dtype: torch.dtype:
+
+ Returns:
+
+ """
+ if not dtype.is_floating_point:
+ raise TypeError("Only supports floating point dtypes.")
+ if dtype == torch.float or dtype == torch.double:
+ return 1e-13
+ elif dtype == torch.half:
+ return 1e-4
+ else:
+ raise TypeError("Does not support dtype " + str(dtype))
+
+
+def combine_initial_dims_to_1d_or_2d(tensor: torch.Tensor) -> torch.Tensor:
+ """Given a (possibly higher order) tensor of ids with shape
+ (d1, ..., dn, sequence_length)
+
+ Args:
+ tensor: torch.Tensor:
+
+ Returns:
+ If original tensor is 1-d or 2-d, return it as is.
+
+ """
+ if tensor.dim() <= 2:
+ return tensor
+ else:
+ return tensor.view(-1, tensor.size(-1))
+
+
+def uncombine_initial_dims(tensor: torch.Tensor, original_size: torch.Size) -> torch.Tensor:
+ """Given a tensor of embeddings with shape
+ (d1 * ... * dn, sequence_length, embedding_dim)
+ and the original shape
+ (d1, ..., dn, sequence_length),
+
+ Args:
+ tensor: torch.Tensor:
+ original_size: torch.Size:
+
+ Returns:
+ (d1, ..., dn, sequence_length, embedding_dim).
+ If original size is 1-d or 2-d, return it as is.
+
+ """
+ if len(original_size) <= 2:
+ return tensor
+ else:
+ view_args = list(original_size) + [tensor.size(-1)]
+ return tensor.view(*view_args)
+
+
+def get_range_vector(size: int, device: int) -> torch.Tensor:
+ """Returns a range vector with the desired size, starting at 0. The CUDA implementation
+ is meant to avoid copy data from CPU to GPU.
+
+ Args:
+ size: int:
+ device: int:
+
+ Returns:
+
+ """
+ if device > -1:
+ return torch.cuda.LongTensor(size, device=device).fill_(1).cumsum(0) - 1
+ else:
+ return torch.arange(0, size, dtype=torch.long)
+
+
+def get_device_of(tensor: torch.Tensor) -> int:
+ """Returns the device of the tensor.
+
+ Args:
+ tensor: torch.Tensor:
+
+ Returns:
+
+ """
+ if not tensor.is_cuda:
+ return -1
+ else:
+ return tensor.get_device()
diff --git a/hanlp/components/parsers/ud/util.py b/hanlp/components/parsers/ud/util.py
new file mode 100644
index 000000000..d7420f234
--- /dev/null
+++ b/hanlp/components/parsers/ud/util.py
@@ -0,0 +1,28 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-14 20:44
+from hanlp_common.constant import ROOT
+from hanlp.components.parsers.ud.lemma_edit import gen_lemma_rule
+
+
+def generate_lemma_rule(sample: dict):
+ if 'LEMMA' in sample:
+ sample['lemma'] = [gen_lemma_rule(word, lemma) if lemma != "_" else "_" for word, lemma in
+ zip(sample['FORM'], sample['LEMMA'])]
+ return sample
+
+
+def append_bos(sample: dict):
+ if 'FORM' in sample:
+ sample['token'] = [ROOT] + sample['FORM']
+ if 'UPOS' in sample:
+ sample['pos'] = sample['UPOS'][:1] + sample['UPOS']
+ sample['arc'] = [0] + sample['HEAD']
+ sample['rel'] = sample['DEPREL'][:1] + sample['DEPREL']
+ sample['lemma'] = sample['lemma'][:1] + sample['lemma']
+ sample['feat'] = sample['FEATS'][:1] + sample['FEATS']
+ return sample
+
+
+def sample_form_missing(sample: dict):
+ return all(t == '_' for t in sample['FORM'])
diff --git a/hanlp/components/pipeline.py b/hanlp/components/pipeline.py
index a836f0fb3..a919ddfd3 100644
--- a/hanlp/components/pipeline.py
+++ b/hanlp/components/pipeline.py
@@ -2,13 +2,13 @@
# Author: hankcs
# Date: 2019-12-31 00:22
import types
-from typing import Callable, List, Generator, Union, Any, Tuple, Iterable
+from typing import Callable, Union, Iterable
from hanlp.components.lambda_wrapper import LambdaComponent
from hanlp.common.component import Component
-from hanlp.common.document import Document
+from hanlp_common.document import Document
from hanlp.utils.component_util import load_from_meta
-from hanlp.utils.io_util import save_json, load_json
-from hanlp.utils.reflection import module_path_of, str_to_type, class_path_of
+from hanlp_common.io import save_json, load_json
+from hanlp_common.reflection import str_to_type, classpath_of
import hanlp
@@ -16,12 +16,14 @@ class Pipe(Component):
def __init__(self, component: Component, input_key: str = None, output_key: str = None, **kwargs) -> None:
super().__init__()
+ if not hasattr(self, 'config'):
+ self.config = {'classpath': classpath_of(self)}
self.output_key = output_key
self.input_key = input_key
self.component = component
self.kwargs = kwargs
- self.meta.update({
- 'component': component.meta,
+ self.config.update({
+ 'component': component.config,
'input_key': self.input_key,
'output_key': self.output_key,
'kwargs': self.kwargs
@@ -65,8 +67,8 @@ def __repr__(self):
return f'{self.input_key}->{self.component.__class__.__name__}->{self.output_key}'
@staticmethod
- def from_meta(meta: dict, **kwargs):
- cls = str_to_type(meta['class_path'])
+ def from_config(meta: dict, **kwargs):
+ cls = str_to_type(meta['classpath'])
component = load_from_meta(meta['component'])
return cls(component, meta['input_key'], meta['output_key'], **meta['kwargs'])
@@ -74,6 +76,8 @@ def from_meta(meta: dict, **kwargs):
class Pipeline(Component, list):
def __init__(self, *pipes: Pipe) -> None:
super().__init__()
+ if not hasattr(self, 'config'):
+ self.config = {'classpath': classpath_of(self)}
if pipes:
self.extend(pipes)
@@ -100,9 +104,9 @@ def __call__(self, doc: Document, **kwargs) -> Document:
@property
def meta(self):
return {
- 'class_path': class_path_of(self),
+ 'classpath': classpath_of(self),
'hanlp_version': hanlp.version.__version__,
- 'pipes': [pipe.meta for pipe in self]
+ 'pipes': [pipe.config for pipe in self]
}
@meta.setter
@@ -115,10 +119,10 @@ def save(self, filepath):
def load(self, filepath):
meta = load_json(filepath)
self.clear()
- self.extend(Pipeline.from_meta(meta))
+ self.extend(Pipeline.from_config(meta))
@staticmethod
- def from_meta(meta: Union[dict, str], **kwargs):
+ def from_config(meta: Union[dict, str], **kwargs):
if isinstance(meta, str):
meta = load_json(meta)
return Pipeline(*[load_from_meta(pipe) for pipe in meta['pipes']])
diff --git a/hanlp/components/pos.py b/hanlp/components/pos.py
deleted file mode 100644
index b6791e0bd..000000000
--- a/hanlp/components/pos.py
+++ /dev/null
@@ -1,13 +0,0 @@
-# -*- coding:utf-8 -*-
-# Author: hankcs
-# Date: 2019-12-05 23:05
-from hanlp.components.taggers.cnn_tagger import CNNTagger
-from hanlp.components.taggers.rnn_tagger import RNNTagger
-
-
-class CNNPartOfSpeechTagger(CNNTagger):
- pass
-
-
-class RNNPartOfSpeechTagger(RNNTagger):
- pass
diff --git a/hanlp/components/pos_tf.py b/hanlp/components/pos_tf.py
new file mode 100644
index 000000000..ab9f17117
--- /dev/null
+++ b/hanlp/components/pos_tf.py
@@ -0,0 +1,13 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-12-05 23:05
+from hanlp.components.taggers.cnn_tagger_tf import CNNTaggerTF
+from hanlp.components.taggers.rnn_tagger_tf import RNNTaggerTF
+
+
+class CNNPartOfSpeechTaggerTF(CNNTaggerTF):
+ pass
+
+
+class RNNPartOfSpeechTaggerTF(RNNTaggerTF):
+ pass
diff --git a/hanlp/components/rnn_language_model.py b/hanlp/components/rnn_language_model.py
index 6bb165098..e626bb003 100644
--- a/hanlp/components/rnn_language_model.py
+++ b/hanlp/components/rnn_language_model.py
@@ -5,7 +5,7 @@
import tensorflow as tf
-from hanlp.common.component import KerasComponent
+from hanlp.common.keras_component import KerasComponent
from hanlp.transform.text import TextTransform
diff --git a/hanlp/components/srl/__init__.py b/hanlp/components/srl/__init__.py
new file mode 100644
index 000000000..9aa83b12b
--- /dev/null
+++ b/hanlp/components/srl/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-06-22 20:50
\ No newline at end of file
diff --git a/hanlp/components/srl/span_bio/__init__.py b/hanlp/components/srl/span_bio/__init__.py
new file mode 100644
index 000000000..347fc1872
--- /dev/null
+++ b/hanlp/components/srl/span_bio/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-04 13:59
diff --git a/hanlp/components/srl/span_bio/baffine_tagging.py b/hanlp/components/srl/span_bio/baffine_tagging.py
new file mode 100644
index 000000000..9d2a88770
--- /dev/null
+++ b/hanlp/components/srl/span_bio/baffine_tagging.py
@@ -0,0 +1,75 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-04 13:59
+import math
+
+import torch
+from torch import nn
+
+from hanlp.components.parsers.biaffine.biaffine import Biaffine
+from hanlp.components.parsers.biaffine.mlp import MLP
+from hanlp.layers.crf.crf import CRF
+
+
+class BiaffineTaggingDecoder(nn.Module):
+
+ def __init__(self,
+ n_rels,
+ hidden_size,
+ n_mlp_rel=300,
+ mlp_dropout=0.2,
+ crf=False) -> None:
+ super().__init__()
+ self.mlp_rel_h = MLP(n_in=hidden_size,
+ n_out=n_mlp_rel,
+ dropout=mlp_dropout)
+ self.mlp_rel_d = MLP(n_in=hidden_size,
+ n_out=n_mlp_rel,
+ dropout=mlp_dropout)
+ self.rel_attn = Biaffine(n_in=n_mlp_rel,
+ n_out=n_rels,
+ bias_x=True,
+ bias_y=True)
+ bias = 1 / math.sqrt(self.rel_attn.weight.size(1))
+ nn.init.uniform_(self.rel_attn.weight, -bias, bias)
+ self.crf = CRF(n_rels) if crf else None
+
+ # noinspection PyUnusedLocal
+ def forward(self, x: torch.Tensor, **kwargs):
+ rel_h = self.mlp_rel_h(x)
+ rel_d = self.mlp_rel_d(x)
+
+ # get arc and rel scores from the bilinear attention
+ # [batch_size, seq_len, seq_len, n_rels]
+ s_rel = self.rel_attn(rel_d, rel_h).permute(0, 2, 3, 1)
+ return s_rel
+
+
+class SpanBIOSemanticRoleLabelingModel(nn.Module):
+
+ def __init__(self,
+ embed,
+ encoder,
+ num_labels: int,
+ n_mlp_rel,
+ mlp_dropout,
+ crf=False,
+ ) -> None:
+ super().__init__()
+ self.embed = embed
+ self.encoder = encoder
+ hidden_size = encoder.get_output_dim() if encoder else embed.get_output_dim()
+ self.decoder = BiaffineTaggingDecoder(
+ num_labels,
+ hidden_size,
+ n_mlp_rel,
+ mlp_dropout,
+ crf,
+ )
+
+ def forward(self, batch, mask):
+ x = self.embed(batch)
+ if self.encoder:
+ x = self.encoder(x, mask=mask)
+ x = self.decoder(x)
+ return x
diff --git a/hanlp/components/srl/span_bio/span_bio.py b/hanlp/components/srl/span_bio/span_bio.py
new file mode 100644
index 000000000..c6a4dc25e
--- /dev/null
+++ b/hanlp/components/srl/span_bio/span_bio.py
@@ -0,0 +1,380 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-06-22 20:54
+import logging
+from copy import copy
+from typing import Union, List, Callable, Dict, Any
+from bisect import bisect
+import torch
+import torch.nn.functional as F
+from alnlp.modules.util import lengths_to_mask
+from torch import nn
+from torch.utils.data import DataLoader
+
+from hanlp_common.constant import IDX, PRED
+from hanlp.common.dataset import PadSequenceDataLoader, SamplerBuilder, TransformableDataset
+from hanlp.common.structure import History
+from hanlp.common.torch_component import TorchComponent
+from hanlp.common.transform import FieldLength
+from hanlp.common.vocab import Vocab
+from hanlp.components.srl.span_bio.baffine_tagging import SpanBIOSemanticRoleLabelingModel
+from hanlp.datasets.srl.conll2012 import CoNLL2012SRLBIODataset
+from hanlp.layers.crf.crf import CRF
+from hanlp.layers.embeddings.contextual_word_embedding import find_transformer
+from hanlp.layers.embeddings.embedding import Embedding
+from hanlp.layers.transformers.utils import build_optimizer_scheduler_with_transformer
+from hanlp.metrics.chunking.sequence_labeling import get_entities
+from hanlp.metrics.f1 import F1
+from hanlp.utils.string_util import guess_delimiter
+from hanlp.utils.time_util import CountdownTimer
+from hanlp.utils.torch_util import clip_grad_norm
+from hanlp_common.util import merge_locals_kwargs, reorder
+
+
+class SpanBIOSemanticRoleLabeler(TorchComponent):
+
+ def __init__(self, **kwargs) -> None:
+ """A span based Semantic Role Labeling task using BIO scheme for tagging the role of each token. Given a
+ predicate and a token, it uses biaffine (:cite:`dozat:17a`) to predict their relations as one of BIO-ROLE.
+
+ Args:
+ **kwargs: Predefined config.
+ """
+ super().__init__(**kwargs)
+ self.model: SpanBIOSemanticRoleLabelingModel = None
+
+ def build_optimizer(self,
+ trn,
+ epochs,
+ lr,
+ adam_epsilon,
+ weight_decay,
+ warmup_steps,
+ transformer_lr=None,
+ gradient_accumulation=1,
+ **kwargs):
+ num_training_steps = len(trn) * epochs // gradient_accumulation
+ if transformer_lr is None:
+ transformer_lr = lr
+ transformer = find_transformer(self.model.embed)
+ optimizer, scheduler = build_optimizer_scheduler_with_transformer(self.model, transformer,
+ lr, transformer_lr,
+ num_training_steps, warmup_steps,
+ weight_decay, adam_epsilon)
+ return optimizer, scheduler
+
+ def build_criterion(self, decoder=None, **kwargs):
+ if self.config.crf:
+ if not decoder:
+ decoder = self.model.decoder
+ if isinstance(decoder, torch.nn.DataParallel):
+ decoder = decoder.module
+ return decoder.crf
+ else:
+ return nn.CrossEntropyLoss(reduction=self.config.loss_reduction)
+
+ def build_metric(self, **kwargs):
+ return F1()
+
+ def execute_training_loop(self,
+ trn: DataLoader,
+ dev: DataLoader,
+ epochs,
+ criterion,
+ optimizer,
+ metric,
+ save_dir,
+ logger: logging.Logger,
+ devices,
+ ratio_width=None,
+ patience=0.5,
+ **kwargs):
+ if isinstance(patience, float):
+ patience = int(patience * epochs)
+ best_epoch, best_metric = 0, -1
+ timer = CountdownTimer(epochs)
+ history = History()
+ for epoch in range(1, epochs + 1):
+ logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
+ self.fit_dataloader(trn, criterion, optimizer, metric, logger, history=history, ratio_width=ratio_width,
+ **self.config)
+ loss, dev_metric = self.evaluate_dataloader(dev, criterion, metric, logger=logger, ratio_width=ratio_width)
+ timer.update()
+ report = f"{timer.elapsed_human} / {timer.total_time_human} ETA: {timer.eta_human}"
+ if dev_metric > best_metric:
+ best_epoch, best_metric = epoch, copy(dev_metric)
+ self.save_weights(save_dir)
+ report += ' [red](saved)[/red]'
+ else:
+ report += f' ({epoch - best_epoch})'
+ if epoch - best_epoch >= patience:
+ report += ' early stop'
+ logger.info(report)
+ if epoch - best_epoch >= patience:
+ break
+ if not best_epoch:
+ self.save_weights(save_dir)
+ elif best_epoch != epoch:
+ self.load_weights(save_dir)
+ logger.info(f"Max score of dev is {best_metric} at epoch {best_epoch}")
+ logger.info(f"Average time of each epoch is {timer.elapsed_average_human}")
+ logger.info(f"{timer.elapsed_human} elapsed")
+
+ # noinspection PyMethodOverriding
+ def fit_dataloader(self,
+ trn: DataLoader,
+ criterion,
+ optimizer,
+ metric,
+ logger: logging.Logger,
+ history: History,
+ gradient_accumulation=1,
+ grad_norm=None,
+ ratio_width=None,
+ eval_trn=False,
+ **kwargs):
+ optimizer, scheduler = optimizer
+ self.model.train()
+ timer = CountdownTimer(history.num_training_steps(len(trn), gradient_accumulation=gradient_accumulation))
+ total_loss = 0
+ for idx, batch in enumerate(trn):
+ pred, mask = self.feed_batch(batch)
+ loss = self.compute_loss(criterion, pred, batch['srl_id'], mask)
+ if gradient_accumulation and gradient_accumulation > 1:
+ loss /= gradient_accumulation
+ loss.backward()
+ total_loss += loss.item()
+ if eval_trn:
+ prediction = self.decode_output(pred, mask, batch)
+ self.update_metrics(metric, prediction, batch)
+ if history.step(gradient_accumulation):
+ self._step(optimizer, scheduler, grad_norm)
+ report = f'loss: {total_loss / (idx + 1):.4f} {metric}' if eval_trn else f'loss: {total_loss / (idx + 1):.4f}'
+ timer.log(report, logger=logger, ratio_percentage=False, ratio_width=ratio_width)
+ del loss
+ del pred
+ del mask
+
+ def naive_decode(self, pred, mask, batch, decoder=None):
+ vocab = self.vocabs['srl'].idx_to_token
+ results = []
+ for sent, matrix in zip(batch['token'], pred.argmax(-1).tolist()):
+ results.append([])
+ for token, tags_per_token in zip(sent, matrix):
+ tags_per_token = [vocab[x] for x in tags_per_token][:len(sent)]
+ srl_per_token = get_entities(tags_per_token)
+ results[-1].append(srl_per_token)
+ return results
+
+ def decode_output(self, pred, mask, batch, decoder=None):
+ # naive = self.naive_decode(pred, mask, batch, decoder)
+ vocab = self.vocabs['srl'].idx_to_token
+ if self.config.crf:
+ if not decoder:
+ decoder = self.model.decoder
+ crf: CRF = decoder.crf
+ token_index, mask = mask
+ pred = crf.decode(pred, mask)
+ pred = sum(pred, [])
+ else:
+ pred = pred[mask].argmax(-1)
+ pred = pred.tolist()
+ pred = [vocab[x] for x in pred]
+ results = []
+ offset = 0
+ for sent in batch['token']:
+ results.append([])
+ for token in sent:
+ tags_per_token = pred[offset:offset + len(sent)]
+ srl_per_token = get_entities(tags_per_token)
+ results[-1].append(srl_per_token)
+ offset += len(sent)
+ assert offset == len(pred)
+ # assert results == naive
+ return results
+
+ def update_metrics(self, metric, prediction, batch):
+ for p, g in zip(prediction, batch['srl_set']):
+ srl = set()
+ for i, args in enumerate(p):
+ srl.update((i, start, end, label) for (label, start, end) in args)
+ metric(srl, g)
+ return metric
+
+ def feed_batch(self, batch: dict):
+ lens = batch['token_length']
+ mask2d = lengths_to_mask(lens)
+ pred = self.model(batch, mask=mask2d)
+ mask3d = self.compute_mask(mask2d)
+ if self.config.crf:
+ token_index = mask3d[0]
+ pred = pred.flatten(end_dim=1)[token_index]
+ pred = F.log_softmax(pred, dim=-1)
+ return pred, mask3d
+
+ def compute_mask(self, mask2d):
+ mask3d = mask2d.unsqueeze_(-1).expand(-1, -1, mask2d.size(1))
+ mask3d = mask3d & mask3d.transpose(1, 2)
+ if self.config.crf:
+ mask3d = mask3d.flatten(end_dim=1)
+ token_index = mask3d[:, 0]
+ mask3d = mask3d[token_index]
+ return token_index, mask3d
+ else:
+ return mask3d
+
+ def _step(self, optimizer, scheduler, grad_norm):
+ clip_grad_norm(self.model, grad_norm)
+ optimizer.step()
+ scheduler.step()
+ optimizer.zero_grad()
+
+ # noinspection PyMethodOverriding
+ def build_model(self, embed: Embedding, encoder, training, **kwargs) -> torch.nn.Module:
+ # noinspection PyCallByClass
+ model = SpanBIOSemanticRoleLabelingModel(
+ embed.module(training=training, vocabs=self.vocabs),
+ encoder,
+ len(self.vocabs.srl),
+ self.config.n_mlp_rel,
+ self.config.mlp_dropout,
+ self.config.crf,
+ )
+ return model
+
+ # noinspection PyMethodOverriding
+ def build_dataloader(self, data, batch_size,
+ sampler_builder: SamplerBuilder = None,
+ gradient_accumulation=1,
+ shuffle=False, device=None, logger: logging.Logger = None,
+ **kwargs) -> DataLoader:
+ if isinstance(data, TransformableDataset):
+ dataset = data
+ else:
+ dataset = self.build_dataset(data, [self.config.embed.transform(vocabs=self.vocabs), self.vocabs,
+ FieldLength('token')])
+ if self.vocabs.mutable:
+ # noinspection PyTypeChecker
+ self.build_vocabs(dataset, logger)
+ lens = [len(x['token_input_ids']) for x in dataset]
+ if sampler_builder:
+ sampler = sampler_builder.build(lens, shuffle, gradient_accumulation)
+ else:
+ sampler = None
+ return PadSequenceDataLoader(dataset, batch_size, shuffle, device=device, batch_sampler=sampler)
+
+ def build_dataset(self, data, transform):
+ dataset = CoNLL2012SRLBIODataset(data,
+ transform=transform,
+ doc_level_offset=self.config.get('doc_level_offset', True),
+ cache=isinstance(data, str))
+ return dataset
+
+ def build_vocabs(self, dataset, logger, **kwargs):
+ self.vocabs.srl = Vocab(pad_token=None, unk_token=None)
+ timer = CountdownTimer(len(dataset))
+ max_seq_len = 0
+ for sample in dataset:
+ max_seq_len = max(max_seq_len, len(sample['token_input_ids']))
+ timer.log(f'Building vocab [blink][yellow]...[/yellow][/blink] (longest sequence: {max_seq_len})')
+ self.vocabs['srl'].set_unk_as_safe_unk() # C-ARGM-FRQ appears only in test set
+ self.vocabs.lock()
+ self.vocabs.summary(logger)
+ if self.config.get('delimiter') is None:
+ tokens = dataset[0]['token']
+ self.config.delimiter = guess_delimiter(tokens)
+ logger.info(f'Guess the delimiter between tokens could be [blue]"{self.config.delimiter}"[/blue]. '
+ f'If not, specify `delimiter` in `fit()`')
+
+ def predict(self, data: Union[str, List[str]], batch_size: int = None, **kwargs):
+ if not data:
+ return []
+ flat = self.input_is_flat(data)
+ if flat:
+ data = [data]
+ dataloader = self.build_dataloader(self.build_samples(data), batch_size, device=self.device, **kwargs)
+ results = []
+ order = []
+ for batch in dataloader:
+ pred, mask = self.feed_batch(batch)
+ prediction = self.decode_output(pred, mask, batch)
+ results.extend(self.prediction_to_result(prediction, batch))
+ order.extend(batch[IDX])
+ results = reorder(results, order)
+ if flat:
+ return results[0]
+ return results
+
+ def build_samples(self, data):
+ return [{'token': token} for token in data]
+
+ # noinspection PyMethodOverriding
+ def fit(self,
+ trn_data,
+ dev_data,
+ save_dir,
+ embed,
+ encoder=None,
+ lr=1e-3,
+ transformer_lr=1e-4,
+ adam_epsilon=1e-8,
+ warmup_steps=0.1,
+ weight_decay=0,
+ crf=False,
+ n_mlp_rel=300,
+ mlp_dropout=0.2,
+ batch_size=32,
+ gradient_accumulation=1,
+ grad_norm=1,
+ loss_reduction='mean',
+ epochs=30,
+ delimiter=None,
+ doc_level_offset=True,
+ eval_trn=False,
+ logger=None,
+ devices: Union[float, int, List[int]] = None,
+ **kwargs):
+ return super().fit(**merge_locals_kwargs(locals(), kwargs))
+
+ def compute_loss(self, criterion, pred, srl, mask):
+ if self.config.crf:
+ token_index, mask = mask
+ criterion: CRF = criterion
+ loss = -criterion.forward(pred, srl.flatten(end_dim=1)[token_index], mask,
+ reduction=self.config.loss_reduction)
+ else:
+ loss = criterion(pred[mask], srl[mask])
+ return loss
+
+ # noinspection PyMethodOverriding
+ @torch.no_grad()
+ def evaluate_dataloader(self, data: DataLoader, criterion: Callable, metric, logger, ratio_width=None,
+ filename=None, **kwargs):
+ self.model.eval()
+ timer = CountdownTimer(len(data))
+ total_loss = 0
+ metric.reset()
+ for idx, batch in enumerate(data):
+ pred, mask = self.feed_batch(batch)
+ loss = self.compute_loss(criterion, pred, batch['srl_id'], mask)
+ total_loss += loss.item()
+ prediction = self.decode_output(pred, mask, batch)
+ self.update_metrics(metric, prediction, batch)
+ report = f'loss: {total_loss / (idx + 1):.4f} {metric}'
+ timer.log(report, logger=logger, ratio_percentage=False, ratio_width=ratio_width)
+ return total_loss / timer.total, metric
+
+ def input_is_flat(self, data) -> bool:
+ return isinstance(data[0], str)
+
+ def prediction_to_result(self, prediction: List, batch: Dict[str, Any], delimiter=None) -> List:
+ if delimiter is None:
+ delimiter = self.config.delimiter
+ for matrix, tokens in zip(prediction, batch['token']):
+ result = []
+ for i, arguments in enumerate(matrix):
+ if arguments:
+ pas = [(delimiter.join(tokens[x[1]:x[2]]),) + x for x in arguments]
+ pas.insert(bisect([a[1] for a in arguments], i), (tokens[i], PRED, i, i + 1))
+ result.append(pas)
+ yield result
diff --git a/hanlp/components/srl/span_rank/__init__.py b/hanlp/components/srl/span_rank/__init__.py
new file mode 100644
index 000000000..2f00eebb0
--- /dev/null
+++ b/hanlp/components/srl/span_rank/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-07-19 22:22
\ No newline at end of file
diff --git a/hanlp/components/srl/span_rank/highway_variational_lstm.py b/hanlp/components/srl/span_rank/highway_variational_lstm.py
new file mode 100644
index 000000000..0e6d41326
--- /dev/null
+++ b/hanlp/components/srl/span_rank/highway_variational_lstm.py
@@ -0,0 +1,250 @@
+# Adopted from https://github.com/KiroSummer/A_Syntax-aware_MTL_Framework_for_Chinese_SRL
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.nn.init as init
+from torch.autograd import Variable
+
+from .layer import DropoutLayer, HighwayLSTMCell, VariationalLSTMCell
+
+
+def initializer_1d(input_tensor, initializer):
+ assert len(input_tensor.size()) == 1
+ input_tensor = input_tensor.view(-1, 1)
+ input_tensor = initializer(input_tensor)
+ return input_tensor.view(-1)
+
+
+class HighwayBiLSTM(nn.Module):
+ """A module that runs multiple steps of HighwayBiLSTM."""
+
+ def __init__(self, input_size, hidden_size, num_layers=1, batch_first=False, bidirectional=False, dropout_in=0,
+ dropout_out=0):
+ super(HighwayBiLSTM, self).__init__()
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+ self.num_layers = num_layers
+ self.batch_first = batch_first
+ self.bidirectional = bidirectional
+ self.dropout_in = dropout_in
+ self.dropout_out = dropout_out
+ self.num_directions = 2 if bidirectional else 1
+
+ self.fcells, self.f_dropout, self.f_hidden_dropout = [], [], []
+ self.bcells, self.b_dropout, self.b_hidden_dropout = [], [], []
+ for layer in range(num_layers):
+ layer_input_size = input_size if layer == 0 else hidden_size
+ self.fcells.append(HighwayLSTMCell(input_size=layer_input_size, hidden_size=hidden_size))
+ self.f_dropout.append(DropoutLayer(hidden_size, self.dropout_out))
+ self.f_hidden_dropout.append(DropoutLayer(hidden_size, self.dropout_out))
+ if self.bidirectional:
+ self.bcells.append(HighwayLSTMCell(input_size=hidden_size, hidden_size=hidden_size))
+ self.b_dropout.append(DropoutLayer(hidden_size, self.dropout_out))
+ self.b_hidden_dropout.append(DropoutLayer(hidden_size, self.dropout_out))
+ self.fcells, self.bcells = nn.ModuleList(self.fcells), nn.ModuleList(self.bcells)
+ self.f_dropout, self.b_dropout = nn.ModuleList(self.f_dropout), nn.ModuleList(self.b_dropout)
+
+ def reset_dropout_layer(self, batch_size):
+ for layer in range(self.num_layers):
+ self.f_dropout[layer].reset_dropout_mask(batch_size)
+ if self.bidirectional:
+ self.b_dropout[layer].reset_dropout_mask(batch_size)
+
+ @staticmethod
+ def _forward_rnn(cell, gate, input, masks, initial, drop_masks=None, hidden_drop=None):
+ max_time = input.size(0)
+ output = []
+ hx = initial
+ for time in range(max_time):
+ h_next, c_next = cell(input[time], mask=masks[time], hx=hx, dropout=drop_masks)
+ hx = (h_next, c_next)
+ output.append(h_next)
+ output = torch.stack(output, 0)
+ return output, hx
+
+ @staticmethod
+ def _forward_brnn(cell, gate, input, masks, initial, drop_masks=None, hidden_drop=None):
+ max_time = input.size(0)
+ output = []
+ hx = initial
+ for time in reversed(list(range(max_time))):
+ h_next, c_next = cell(input[time], mask=masks[time], hx=hx, dropout=drop_masks)
+ hx = (h_next, c_next)
+ output.append(h_next)
+ output.reverse()
+ output = torch.stack(output, 0)
+ return output, hx
+
+ def forward(self, input, masks, initial=None):
+ if self.batch_first:
+ input = input.transpose(0, 1) # transpose: return the transpose matrix
+ masks = torch.unsqueeze(masks.transpose(0, 1), dim=2)
+ max_time, batch_size, _ = input.size()
+
+ self.reset_dropout_layer(batch_size) # reset the dropout each batch forward
+
+ masks = masks.expand(-1, -1, self.hidden_size) # expand: -1 means not expand that dimension
+ if initial is None:
+ initial = Variable(input.data.new(batch_size, self.hidden_size).zero_())
+ initial = (initial, initial) # h0, c0
+
+ h_n, c_n = [], []
+ for layer in range(self.num_layers):
+ # hidden_mask, hidden_drop = None, None
+ hidden_mask, hidden_drop = self.f_dropout[layer], self.f_hidden_dropout[layer]
+ layer_output, (layer_h_n, layer_c_n) = HighwayBiLSTM._forward_rnn(cell=self.fcells[layer], \
+ gate=None, input=input, masks=masks,
+ initial=initial, \
+ drop_masks=hidden_mask,
+ hidden_drop=hidden_drop)
+ h_n.append(layer_h_n)
+ c_n.append(layer_c_n)
+ if self.bidirectional:
+ hidden_mask, hidden_drop = self.b_dropout[layer], self.b_hidden_dropout[layer]
+ blayer_output, (blayer_h_n, blayer_c_n) = HighwayBiLSTM._forward_brnn(cell=self.bcells[layer], \
+ gate=None, input=layer_output,
+ masks=masks, initial=initial, \
+ drop_masks=hidden_mask,
+ hidden_drop=hidden_drop)
+ h_n.append(blayer_h_n)
+ c_n.append(blayer_c_n)
+
+ input = blayer_output if self.bidirectional else layer_output
+
+ h_n, c_n = torch.stack(h_n, 0), torch.stack(c_n, 0)
+ if self.batch_first:
+ input = input.transpose(1, 0) # transpose: return the transpose matrix
+ return input, (h_n, c_n)
+
+
+class StackedHighwayBiLSTM(nn.Module):
+ """A module that runs multiple steps of HighwayBiLSTM."""
+
+ def __init__(self, input_size, hidden_size, num_layers=1, batch_first=False, \
+ bidirectional=False, dropout_in=0, dropout_out=0):
+ super(StackedHighwayBiLSTM, self).__init__()
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+ self.num_layers = num_layers
+ self.batch_first = batch_first
+ self.bidirectional = bidirectional
+ self.dropout_in = dropout_in
+ self.dropout_out = dropout_out
+ self.num_directions = 2 if bidirectional else 1
+
+ self.fcells, self.f_dropout, self.f_hidden_dropout = [], [], []
+ self.bcells, self.b_dropout, self.b_hidden_dropout = [], [], []
+ self.f_initial, self.b_initial = [], []
+ for layer in range(num_layers):
+ layer_input_size = input_size if layer == 0 else 2 * hidden_size if self.bidirectional else hidden_size
+ self.fcells.append(VariationalLSTMCell(input_size=layer_input_size, hidden_size=hidden_size))
+ self.f_dropout.append(DropoutLayer(hidden_size, self.dropout_out))
+ self.f_hidden_dropout.append(DropoutLayer(hidden_size, self.dropout_out))
+ self.f_initial.append(nn.Parameter(torch.Tensor(2, self.hidden_size)))
+ assert self.bidirectional is True
+ self.bcells.append(VariationalLSTMCell(input_size=layer_input_size, hidden_size=hidden_size))
+ self.b_dropout.append(DropoutLayer(hidden_size, self.dropout_out))
+ self.b_hidden_dropout.append(DropoutLayer(hidden_size, self.dropout_out))
+ self.b_initial.append(nn.Parameter(torch.Tensor(2, self.hidden_size)))
+ self.lstm_project_layer = nn.ModuleList([nn.Linear(2 * self.hidden_size, 2 * self.hidden_size)
+ for _ in range(num_layers - 1)])
+ self.fcells, self.bcells = nn.ModuleList(self.fcells), nn.ModuleList(self.bcells)
+ self.f_dropout, self.b_dropout = nn.ModuleList(self.f_dropout), nn.ModuleList(self.b_dropout)
+ self.f_hidden_dropout, self.b_hidden_dropout = \
+ nn.ModuleList(self.f_hidden_dropout), nn.ModuleList(self.b_hidden_dropout)
+ self.f_initial, self.b_initial = nn.ParameterList(self.f_initial), nn.ParameterList(self.b_initial)
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ for layer_initial in [self.f_initial, self.b_initial]:
+ for initial in layer_initial:
+ init.xavier_uniform_(initial)
+ for layer in self.lstm_project_layer:
+ init.xavier_uniform_(layer.weight)
+ initializer_1d(layer.bias, init.xavier_uniform_)
+
+ def reset_dropout_layer(self, batch_size):
+ for layer in range(self.num_layers):
+ self.f_dropout[layer].reset_dropout_mask(batch_size)
+ self.f_hidden_dropout[layer].reset_dropout_mask(batch_size)
+ if self.bidirectional:
+ self.b_dropout[layer].reset_dropout_mask(batch_size)
+ self.b_hidden_dropout[layer].reset_dropout_mask(batch_size)
+
+ def reset_state(self, batch_size):
+ f_states, b_states = [], []
+ for f_layer_initial, b_layer_initial in zip(self.f_initial, self.b_initial):
+ f_states.append([f_layer_initial[0].expand(batch_size, -1), f_layer_initial[1].expand(batch_size, -1)])
+ b_states.append([b_layer_initial[0].expand(batch_size, -1), b_layer_initial[1].expand(batch_size, -1)])
+ return f_states, b_states
+
+ @staticmethod
+ def _forward_rnn(cell, gate, input, masks, initial, drop_masks=None, hidden_drop=None):
+ max_time = input.size(0)
+ output = []
+ hx = initial
+ for time in range(max_time):
+ h_next, c_next = cell(input[time], mask=masks[time], hx=hx, dropout=drop_masks)
+ hx = (h_next, c_next)
+ output.append(h_next)
+ output = torch.stack(output, 0)
+ return output, hx
+
+ @staticmethod
+ def _forward_brnn(cell, gate, input, masks, initial, drop_masks=None, hidden_drop=None):
+ max_time = input.size(0)
+ output = []
+ hx = initial
+ for time in reversed(list(range(max_time))):
+ h_next, c_next = cell(input[time], mask=masks[time], hx=hx, dropout=drop_masks)
+ hx = (h_next, c_next)
+ output.append(h_next)
+ output.reverse()
+ output = torch.stack(output, 0)
+ return output, hx
+
+ def forward(self, input, masks, initial=None):
+ if self.batch_first:
+ input = input.transpose(0, 1) # transpose: return the transpose matrix
+ masks = torch.unsqueeze(masks.transpose(0, 1), dim=2)
+ max_time, batch_size, _ = input.size()
+
+ self.reset_dropout_layer(batch_size) # reset the dropout each batch forward
+ f_states, b_states = self.reset_state(batch_size)
+
+ masks = masks.expand(-1, -1, self.hidden_size) # expand: -1 means not expand that dimension
+
+ h_n, c_n = [], []
+ outputs = []
+ for layer in range(self.num_layers):
+ hidden_mask, hidden_drop = self.f_dropout[layer], self.f_hidden_dropout[layer]
+ layer_output, (layer_h_n, layer_c_n) = \
+ StackedHighwayBiLSTM._forward_rnn(cell=self.fcells[layer],
+ gate=None, input=input, masks=masks, initial=f_states[layer],
+ drop_masks=hidden_mask, hidden_drop=hidden_drop)
+ h_n.append(layer_h_n)
+ c_n.append(layer_c_n)
+ assert self.bidirectional is True
+ hidden_mask, hidden_drop = self.b_dropout[layer], self.b_hidden_dropout[layer]
+ blayer_output, (blayer_h_n, blayer_c_n) = \
+ StackedHighwayBiLSTM._forward_brnn(cell=self.bcells[layer],
+ gate=None, input=input, masks=masks, initial=b_states[layer],
+ drop_masks=hidden_mask, hidden_drop=hidden_drop)
+ h_n.append(blayer_h_n)
+ c_n.append(blayer_c_n)
+
+ output = torch.cat([layer_output, blayer_output], 2) if self.bidirectional else layer_output
+ output = F.dropout(output, self.dropout_out, self.training)
+ if layer > 0: # Highway
+ highway_gates = torch.sigmoid(self.lstm_project_layer[layer - 1].forward(output))
+ output = highway_gates * output + (1 - highway_gates) * input
+ if self.batch_first:
+ outputs.append(output.transpose(1, 0))
+ else:
+ outputs.append(output)
+ input = output
+
+ h_n, c_n = torch.stack(h_n, 0), torch.stack(c_n, 0)
+ if self.batch_first:
+ output = output.transpose(1, 0) # transpose: return the transpose matrix
+ return output, (h_n, c_n), outputs
diff --git a/hanlp/components/srl/span_rank/inference_utils.py b/hanlp/components/srl/span_rank/inference_utils.py
new file mode 100644
index 000000000..4a3734673
--- /dev/null
+++ b/hanlp/components/srl/span_rank/inference_utils.py
@@ -0,0 +1,243 @@
+# Adopted from https://github.com/KiroSummer/A_Syntax-aware_MTL_Framework_for_Chinese_SRL
+
+# Inference functions for the SRL model.
+import numpy as np
+
+
+def decode_spans(span_starts, span_ends, span_scores, labels_inv):
+ """
+
+ Args:
+ span_starts: [num_candidates,]
+ span_scores: [num_candidates, num_labels]
+ span_ends:
+ labels_inv:
+
+ Returns:
+
+
+ """
+ pred_spans = []
+ span_labels = np.argmax(span_scores, axis=1) # [num_candidates]
+ spans_list = list(zip(span_starts, span_ends, span_labels, span_scores))
+ spans_list = sorted(spans_list, key=lambda x: x[3][x[2]], reverse=True)
+ predicted_spans = {}
+ for start, end, label, _ in spans_list:
+ # Skip invalid span.
+ if label == 0 or (start, end) in predicted_spans:
+ continue
+ pred_spans.append((start, end, labels_inv[label]))
+ predicted_spans[(start, end)] = label
+ return pred_spans
+
+
+def greedy_decode(predict_dict, srl_labels_inv):
+ """Greedy decoding for SRL predicate-argument structures.
+
+ Args:
+ predict_dict: Dictionary of name to numpy arrays.
+ srl_labels_inv: SRL label id to string name.
+ suppress_overlap: Whether to greedily suppress overlapping arguments for the same predicate.
+
+ Returns:
+
+
+ """
+ arg_starts = predict_dict["arg_starts"]
+ arg_ends = predict_dict["arg_ends"]
+ predicates = predict_dict["predicates"]
+ arg_labels = predict_dict["arg_labels"]
+ scores = predict_dict["srl_scores"]
+
+ num_suppressed_args = 0
+
+ # Map from predicates to a list of labeled spans.
+ pred_to_args = {}
+ if len(arg_ends) > 0 and len(predicates) > 0:
+ max_len = max(np.max(arg_ends), np.max(predicates)) + 1
+ else:
+ max_len = 1
+
+ for j, pred_id in enumerate(predicates):
+ args_list = []
+ for i, (arg_start, arg_end) in enumerate(zip(arg_starts, arg_ends)):
+ # If label is not null.
+ if arg_labels[i][j] == 0:
+ continue
+ label = srl_labels_inv[arg_labels[i][j]]
+ # if label not in ["V", "C-V"]:
+ args_list.append((arg_start, arg_end, label, scores[i][j][arg_labels[i][j]]))
+
+ # Sort arguments by highest score first.
+ args_list = sorted(args_list, key=lambda x: x[3], reverse=True)
+ new_args_list = []
+
+ flags = [False for _ in range(max_len)]
+ # Predicate will not overlap with arguments either.
+ flags[pred_id] = True
+
+ for (arg_start, arg_end, label, score) in args_list:
+ # If none of the tokens has been covered:
+ if not max(flags[arg_start:arg_end + 1]):
+ new_args_list.append((arg_start, arg_end, label))
+ for k in range(arg_start, arg_end + 1):
+ flags[k] = True
+
+ # Only add predicate if it has any argument.
+ if new_args_list:
+ pred_to_args[pred_id] = new_args_list
+
+ num_suppressed_args += len(args_list) - len(new_args_list)
+
+ return pred_to_args, num_suppressed_args
+
+
+_CORE_ARGS = {"ARG0": 1, "ARG1": 2, "ARG2": 4, "ARG3": 8, "ARG4": 16, "ARG5": 32, "ARGA": 64,
+ "A0": 1, "A1": 2, "A2": 4, "A3": 8, "A4": 16, "A5": 32, "AA": 64}
+
+
+def get_predicted_clusters(top_span_starts, top_span_ends, predicted_antecedents):
+ mention_to_predicted = {}
+ predicted_clusters = []
+ for i, predicted_index in enumerate(predicted_antecedents):
+ if predicted_index < 0:
+ continue
+ assert i > predicted_index
+ predicted_antecedent = (int(top_span_starts[predicted_index]), int(top_span_ends[predicted_index]))
+ if predicted_antecedent in mention_to_predicted:
+ predicted_cluster = mention_to_predicted[predicted_antecedent]
+ else:
+ predicted_cluster = len(predicted_clusters)
+ predicted_clusters.append([predicted_antecedent])
+ mention_to_predicted[predicted_antecedent] = predicted_cluster
+
+ mention = (int(top_span_starts[i]), int(top_span_ends[i]))
+ predicted_clusters[predicted_cluster].append(mention)
+ mention_to_predicted[mention] = predicted_cluster
+
+ predicted_clusters = [tuple(pc) for pc in predicted_clusters]
+ mention_to_predicted = {m: predicted_clusters[i] for m, i in list(mention_to_predicted.items())}
+
+ return predicted_clusters, mention_to_predicted
+
+
+def _decode_non_overlapping_spans(starts, ends, scores, max_len, labels_inv, pred_id):
+ labels = np.argmax(scores, axis=1)
+ spans = []
+ for i, (start, end, label) in enumerate(zip(starts, ends, labels)):
+ if label <= 0:
+ continue
+ label_str = labels_inv[label]
+ if pred_id is not None and label_str == "V":
+ continue
+ spans.append((start, end, label_str, scores[i][label]))
+ spans = sorted(spans, key=lambda x: x[3], reverse=True)
+ flags = np.zeros([max_len], dtype=bool)
+ if pred_id is not None:
+ flags[pred_id] = True
+ new_spans = []
+ for start, end, label_str, score in spans:
+ if not max(flags[start:end + 1]):
+ new_spans.append((start, end, label_str)) # , score))
+ for k in range(start, end + 1):
+ flags[k] = True
+ return new_spans
+
+
+def _dp_decode_non_overlapping_spans(starts, ends, scores, max_len, labels_inv, pred_id, u_constraint=False):
+ num_roles = scores.shape[1] # [num_arg, num_roles]
+ labels = np.argmax(scores, axis=1).astype(np.int64)
+ spans = list(zip(starts, ends, list(range(len(starts)))))
+ spans = sorted(spans, key=lambda x: (x[0], x[1])) # sort according to the span start index
+
+ if u_constraint:
+ f = np.zeros([max_len + 1, 128], dtype=float) - 0.1
+ else: # This one
+ f = np.zeros([max_len + 1, 1], dtype=float) - 0.1
+
+ f[0, 0] = 0
+ states = {0: set([0])} # A dictionary from id to list of binary core-arg states.
+ pointers = {} # A dictionary from states to (arg_id, role, prev_t, prev_rs)
+ best_state = [(0, 0)]
+
+ def _update_state(t0, rs0, t1, rs1, delta, arg_id, role):
+ if f[t0][rs0] + delta > f[t1][rs1]:
+ f[t1][rs1] = f[t0][rs0] + delta
+ if t1 not in states:
+ states[t1] = set()
+ states[t1].update([rs1])
+ pointers[(t1, rs1)] = (arg_id, role, t0, rs0) # the pointers store
+ if f[t1][rs1] > f[best_state[0][0]][best_state[0][1]]:
+ best_state[0] = (t1, rs1)
+
+ for start, end, i in spans: # [arg_start, arg_end, arg_span_id]
+ assert scores[i][0] == 0 # dummy score
+ # The extra dummy score should be same for all states, so we can safely skip arguments overlap
+ # with the predicate.
+ if pred_id is not None and start <= pred_id and pred_id <= end: # skip the span contains the predicate
+ continue
+ r0 = labels[i] # Locally best role assignment.
+ # Strictly better to incorporate a dummy span if it has the highest local score.
+ if r0 == 0: # labels_inv[r0] == "O"
+ continue
+ r0_str = labels_inv[r0]
+ # Enumerate explored states.
+ t_states = [t for t in list(states.keys()) if t <= start] # collect the state which is before the current span
+ for t in t_states: # for each state
+ role_states = states[t]
+ # Update states if best role is not a core arg.
+ if not u_constraint or r0_str not in _CORE_ARGS: # True; this one
+ for rs in role_states: # the set type in the value in the state dict
+ _update_state(t, rs, end + 1, rs, scores[i][r0], i, r0) # update the state
+ else:
+ for rs in role_states:
+ for r in range(1, num_roles):
+ if scores[i][r] > 0:
+ r_str = labels_inv[r]
+ core_state = _CORE_ARGS.get(r_str, 0)
+ # print start, end, i, r_str, core_state, rs
+ if core_state & rs == 0:
+ _update_state(t, rs, end + 1, rs | core_state, scores[i][r], i, r)
+ # Backtrack to decode.
+ new_spans = []
+ t, rs = best_state[0]
+ while (t, rs) in pointers:
+ i, r, t0, rs0 = pointers[(t, rs)]
+ new_spans.append((int(starts[i]), int(ends[i]), labels_inv[r]))
+ t = t0
+ rs = rs0
+ return new_spans[::-1]
+
+
+def srl_decode(sentence_lengths, predict_dict, srl_labels_inv, config): # decode the predictions.
+ # Decode sentence-level tasks.
+ num_sentences = len(sentence_lengths)
+ predictions = [{} for _ in range(num_sentences)]
+ # Sentence-level predictions.
+ for i in range(num_sentences): # for each sentences
+ # if predict_dict["No_arg"] is True:
+ # predictions["srl"][i][predict_dict["predicates"][i]] = []
+ # continue
+ predict_dict_num_args_ = predict_dict["num_args"].cpu().numpy()
+ predict_dict_num_preds_ = predict_dict["num_preds"].cpu().numpy()
+ predict_dict_predicates_ = predict_dict["predicates"].cpu().numpy()
+ predict_dict_arg_starts_ = predict_dict["arg_starts"].cpu().numpy()
+ predict_dict_arg_ends_ = predict_dict["arg_ends"].cpu().numpy()
+ predict_dict_srl_scores_ = predict_dict["srl_scores"].detach().cpu().numpy()
+ num_args = predict_dict_num_args_[i] # the number of the candidate argument spans
+ num_preds = predict_dict_num_preds_[i] # the number of the candidate predicates
+ # for each predicate id, exec the decode process
+ for j, pred_id in enumerate(predict_dict_predicates_[i][:num_preds]):
+ # sorted arg_starts and arg_ends and srl_scores ? should be??? enforce_srl_constraint = False
+ arg_spans = _dp_decode_non_overlapping_spans(
+ predict_dict_arg_starts_[i][:num_args],
+ predict_dict_arg_ends_[i][:num_args],
+ predict_dict_srl_scores_[i, :num_args, j, :],
+ sentence_lengths[i], srl_labels_inv, pred_id, config.enforce_srl_constraint)
+ # To avoid warnings in the eval script.
+ if config.use_gold_predicates: # false
+ arg_spans.append((pred_id, pred_id, "V"))
+ if arg_spans:
+ predictions[i][int(pred_id)] = sorted(arg_spans, key=lambda x: (x[0], x[1]))
+
+ return predictions
diff --git a/hanlp/components/srl/span_rank/layer.py b/hanlp/components/srl/span_rank/layer.py
new file mode 100644
index 000000000..3250b6312
--- /dev/null
+++ b/hanlp/components/srl/span_rank/layer.py
@@ -0,0 +1,388 @@
+# Adopted from https://github.com/KiroSummer/A_Syntax-aware_MTL_Framework_for_Chinese_SRL
+
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+import numpy as np
+import torch.nn.functional as F
+
+from hanlp.components.srl.span_rank.util import block_orth_normal_initializer
+
+
+def get_tensor_np(t):
+ return t.data.cpu().numpy()
+
+
+def orthonormal_initializer(output_size, input_size):
+ """adopted from Timothy Dozat https://github.com/tdozat/Parser/blob/master/lib/linalg.py
+
+ Args:
+ output_size:
+ input_size:
+
+ Returns:
+
+
+ """
+ print((output_size, input_size))
+ I = np.eye(output_size)
+ lr = .1
+ eps = .05 / (output_size + input_size)
+ success = False
+ tries = 0
+ while not success and tries < 10:
+ Q = np.random.randn(input_size, output_size) / np.sqrt(output_size)
+ for i in range(100):
+ QTQmI = Q.T.dot(Q) - I
+ loss = np.sum(QTQmI ** 2 / 2)
+ Q2 = Q ** 2
+ Q -= lr * Q.dot(QTQmI) / (
+ np.abs(Q2 + Q2.sum(axis=0, keepdims=True) + Q2.sum(axis=1, keepdims=True) - 1) + eps)
+ if np.max(Q) > 1e6 or loss > 1e6 or not np.isfinite(loss):
+ tries += 1
+ lr /= 2
+ break
+ success = True
+ if success:
+ print(('Orthogonal pretrainer loss: %.2e' % loss))
+ else:
+ print('Orthogonal pretrainer failed, using non-orthogonal random matrix')
+ Q = np.random.randn(input_size, output_size) / np.sqrt(output_size)
+ return np.transpose(Q.astype(np.float32))
+
+
+class LayerNorm(nn.Module):
+ def __init__(self, features, eps=1e-8):
+ super(LayerNorm, self).__init__()
+ self.gamma = nn.Parameter(torch.ones(features))
+ self.beta = nn.Parameter(torch.zeros(features))
+ self.eps = eps
+
+ def forward(self, x):
+ mean = x.mean(-1, keepdim=True)
+ std = x.std(-1, keepdim=True)
+ return self.gamma * (x - mean) / (std + self.eps) + self.beta
+
+
+class DropoutLayer3D(nn.Module):
+ def __init__(self, input_size, dropout_rate=0.0):
+ super(DropoutLayer3D, self).__init__()
+ self.dropout_rate = dropout_rate
+ self.input_size = input_size
+ self.drop_mask = torch.FloatTensor(self.input_size).fill_(1 - self.dropout_rate)
+ self.drop_mask = Variable(torch.bernoulli(self.drop_mask), requires_grad=False)
+ if torch.cuda.is_available():
+ self.drop_mask = self.drop_mask.cuda()
+
+ def reset_dropout_mask(self, batch_size, length):
+ self.drop_mask = torch.FloatTensor(batch_size, length, self.input_size).fill_(1 - self.dropout_rate)
+ self.drop_mask = Variable(torch.bernoulli(self.drop_mask), requires_grad=False)
+ if torch.cuda.is_available():
+ self.drop_mask = self.drop_mask.cuda()
+
+ def forward(self, x):
+ if self.training:
+ return torch.mul(x, self.drop_mask)
+ else: # eval
+ return x * (1.0 - self.dropout_rate)
+
+
+class DropoutLayer(nn.Module):
+ def __init__(self, input_size, dropout_rate=0.0):
+ super(DropoutLayer, self).__init__()
+ self.dropout_rate = dropout_rate
+ self.input_size = input_size
+ self.drop_mask = torch.Tensor(self.input_size).fill_(1 - self.dropout_rate)
+ self.drop_mask = torch.bernoulli(self.drop_mask)
+
+ def reset_dropout_mask(self, batch_size):
+ self.drop_mask = torch.Tensor(batch_size, self.input_size).fill_(1 - self.dropout_rate)
+ self.drop_mask = torch.bernoulli(self.drop_mask)
+
+ def forward(self, x):
+ if self.training:
+ return torch.mul(x, self.drop_mask.to(x.device))
+ else: # eval
+ return x * (1.0 - self.dropout_rate)
+
+
+class NonLinear(nn.Module):
+ def __init__(self, input_size, hidden_size, activation=None):
+ super(NonLinear, self).__init__()
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+ self.linear = nn.Linear(in_features=input_size, out_features=hidden_size)
+ if activation is None:
+ self._activate = lambda x: x
+ else:
+ if not callable(activation):
+ raise ValueError("activation must be callable: type={}".format(type(activation)))
+ self._activate = activation
+
+ self.reset_parameters()
+
+ def forward(self, x):
+ y = self.linear(x)
+ return self._activate(y)
+
+ def reset_parameters(self):
+ nn.init.xavier_uniform_(self.linear.weight)
+ nn.init.zeros_(self.linear.bias)
+
+
+class Biaffine(nn.Module):
+ def __init__(self, in1_features, in2_features, out_features,
+ bias=(True, True)):
+ super(Biaffine, self).__init__()
+ self.in1_features = in1_features
+ self.in2_features = in2_features
+ self.out_features = out_features
+ self.bias = bias
+ self.linear_input_size = in1_features + int(bias[0])
+ self.linear_output_size = out_features * (in2_features + int(bias[1]))
+ self.linear = nn.Linear(in_features=self.linear_input_size,
+ out_features=self.linear_output_size,
+ bias=False)
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ torch.nn.init.xavier_uniform_(self.linear.weight)
+
+ def forward(self, input1, input2):
+ batch_size, len1, dim1 = input1.size()
+ batch_size, len2, dim2 = input2.size()
+ if self.bias[0]:
+ ones = input1.data.new(batch_size, len1, 1).zero_().fill_(1) # this kind of implementation is too tedious
+ input1 = torch.cat((input1, Variable(ones)), dim=2)
+ dim1 += 1
+ if self.bias[1]:
+ ones = input2.data.new(batch_size, len2, 1).zero_().fill_(1)
+ input2 = torch.cat((input2, Variable(ones)), dim=2)
+ dim2 += 1
+
+ affine = self.linear(input1)
+
+ affine = affine.view(batch_size, len1 * self.out_features, dim2)
+ input2 = torch.transpose(input2, 1, 2)
+ # torch.bmm: Performs a batch matrix-matrix product of matrices stored in batch1 and batch2.
+ biaffine = torch.transpose(torch.bmm(affine, input2), 1, 2)
+ # view: Returns a new tensor with the same data as the self tensor but of a different size.
+ biaffine = biaffine.contiguous().view(batch_size, len2, len1, self.out_features)
+
+ return biaffine
+
+ def __repr__(self):
+ return self.__class__.__name__ + ' (' \
+ + 'in1_features=' + str(self.in1_features) \
+ + ', in2_features=' + str(self.in2_features) \
+ + ', out_features=' + str(self.out_features) + ')'
+
+
+class HighwayLSTMCell(nn.Module):
+ def __init__(self, input_size, hidden_size):
+ super(HighwayLSTMCell, self).__init__()
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+ self.linear_ih = nn.Linear(in_features=input_size,
+ out_features=6 * hidden_size)
+ self.linear_hh = nn.Linear(in_features=hidden_size,
+ out_features=5 * hidden_size,
+ bias=False)
+ self.reset_parameters() # reset all the param in the MyLSTMCell
+
+ def reset_parameters(self):
+ weight_ih = block_orth_normal_initializer([self.input_size, ], [self.hidden_size] * 6)
+ self.linear_ih.weight.data.copy_(weight_ih)
+
+ weight_hh = block_orth_normal_initializer([self.hidden_size, ], [self.hidden_size] * 5)
+ self.linear_hh.weight.data.copy_(weight_hh)
+ # nn.init.constant(self.linear_hh.weight, 1.0)
+ # nn.init.constant(self.linear_ih.weight, 1.0)
+
+ nn.init.constant(self.linear_ih.bias, 0.0)
+
+ def forward(self, x, mask=None, hx=None, dropout=None):
+ assert mask is not None and hx is not None
+ _h, _c = hx
+ _x = self.linear_ih(x) # compute the x
+ preact = self.linear_hh(_h) + _x[:, :self.hidden_size * 5]
+
+ i, f, o, t, j = preact.chunk(chunks=5, dim=1)
+ i, f, o, t, j = F.sigmoid(i), F.sigmoid(f + 1.0), F.sigmoid(o), F.sigmoid(t), F.tanh(j)
+ k = _x[:, self.hidden_size * 5:]
+
+ c = f * _c + i * j
+ c = mask * c + (1.0 - mask) * _c
+
+ h = t * o * F.tanh(c) + (1.0 - t) * k
+ if dropout is not None:
+ h = dropout(h)
+ h = mask * h + (1.0 - mask) * _h
+ return h, c
+
+
+class VariationalLSTMCell(nn.Module):
+ def __init__(self, input_size, hidden_size):
+ super(VariationalLSTMCell, self).__init__()
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+ self.linear = nn.Linear(in_features=input_size + self.hidden_size, out_features=3 * hidden_size)
+ self.reset_parameters() # reset all the param in the MyLSTMCell
+
+ def reset_parameters(self):
+ weight = block_orth_normal_initializer([self.input_size + self.hidden_size, ], [self.hidden_size] * 3)
+ self.linear.weight.data.copy_(weight)
+ nn.init.constant_(self.linear.bias, 0.0)
+
+ def forward(self, x, mask=None, hx=None, dropout=None):
+ assert mask is not None and hx is not None
+ _h, _c = hx
+ _h = dropout(_h)
+ _x = self.linear(torch.cat([x, _h], 1)) # compute the x
+ i, j, o = _x.chunk(3, dim=1)
+ i = torch.sigmoid(i)
+ c = (1.0 - i) * _c + i * torch.tanh(j)
+ c = mask * c # + (1.0 - mask) * _c
+ h = torch.tanh(c) * torch.sigmoid(o)
+ h = mask * h # + (1.0 - mask) * _h
+
+ return h, c
+
+
+class VariationalLSTM(nn.Module):
+ """A module that runs multiple steps of LSTM."""
+
+ def __init__(self, input_size, hidden_size, num_layers=1, batch_first=False, \
+ bidirectional=False, dropout_in=0, dropout_out=0):
+ super(VariationalLSTM, self).__init__()
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+ self.num_layers = num_layers
+ self.batch_first = batch_first
+ self.bidirectional = bidirectional
+ self.dropout_in = dropout_in
+ self.dropout_out = dropout_out
+ self.num_directions = 2 if bidirectional else 1
+
+ self.fcells = []
+ self.bcells = []
+ for layer in range(num_layers):
+ layer_input_size = input_size if layer == 0 else hidden_size * self.num_directions
+ self.fcells.append(nn.LSTMCell(input_size=layer_input_size, hidden_size=hidden_size))
+ if self.bidirectional:
+ self.bcells.append(nn.LSTMCell(input_size=layer_input_size, hidden_size=hidden_size))
+
+ self._all_weights = []
+ for layer in range(num_layers):
+ layer_params = (self.fcells[layer].weight_ih, self.fcells[layer].weight_hh, \
+ self.fcells[layer].bias_ih, self.fcells[layer].bias_hh)
+ suffix = ''
+ param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}']
+ param_names += ['bias_ih_l{}{}', 'bias_hh_l{}{}']
+ param_names = [x.format(layer, suffix) for x in param_names]
+ for name, param in zip(param_names, layer_params):
+ setattr(self, name, param)
+ self._all_weights.append(param_names)
+
+ if self.bidirectional:
+ layer_params = (self.bcells[layer].weight_ih, self.bcells[layer].weight_hh, \
+ self.bcells[layer].bias_ih, self.bcells[layer].bias_hh)
+ suffix = '_reverse'
+ param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}']
+ param_names += ['bias_ih_l{}{}', 'bias_hh_l{}{}']
+ param_names = [x.format(layer, suffix) for x in param_names]
+ for name, param in zip(param_names, layer_params):
+ setattr(self, name, param)
+ self._all_weights.append(param_names)
+
+ self.reset_parameters()
+
+ def reset_parameters(self): # modified by kiro
+ for name, param in self.named_parameters():
+ print(name)
+ if "weight" in name:
+ # for i in range(4):
+ # nn.init.orthogonal(self.__getattr__(name)[self.hidden_size*i:self.hidden_size*(i+1),:])
+ nn.init.orthogonal(self.__getattr__(name))
+ if "bias" in name:
+ nn.init.normal(self.__getattr__(name), 0.0, 0.01)
+ # nn.init.constant(self.__getattr__(name), 1.0) # different from zhang's 0
+
+ @staticmethod
+ def _forward_rnn(cell, input, masks, initial, drop_masks):
+ max_time = input.size(0)
+ output = []
+ hx = initial
+ for time in range(max_time):
+ h_next, c_next = cell(input=input[time], hx=hx)
+ h_next = h_next * masks[time] + initial[0] * (1 - masks[time])
+ c_next = c_next * masks[time] + initial[1] * (1 - masks[time])
+ output.append(h_next)
+ if drop_masks is not None: h_next = h_next * drop_masks
+ hx = (h_next, c_next)
+ output = torch.stack(output, 0)
+ return output, hx
+
+ @staticmethod
+ def _forward_brnn(cell, input, masks, initial, drop_masks):
+ max_time = input.size(0)
+ output = []
+ hx = initial
+ for time in reversed(list(range(max_time))):
+ h_next, c_next = cell(input=input[time], hx=hx)
+ h_next = h_next * masks[time] + initial[0] * (1 - masks[time])
+ c_next = c_next * masks[time] + initial[1] * (1 - masks[time])
+ output.append(h_next)
+ if drop_masks is not None: h_next = h_next * drop_masks
+ hx = (h_next, c_next)
+ output.reverse()
+ output = torch.stack(output, 0)
+ return output, hx
+
+ def forward(self, input, masks, initial=None):
+ if self.batch_first:
+ input = input.transpose(0, 1) # transpose: return the transpose matrix
+ masks = torch.unsqueeze(masks.transpose(0, 1), dim=2)
+ max_time, batch_size, _ = input.size()
+ masks = masks.expand(-1, -1, self.hidden_size) # expand: -1 means not expand that dimension
+ if initial is None:
+ initial = Variable(input.data.new(batch_size, self.hidden_size).zero_())
+ initial = (initial, initial) # h0, c0
+ h_n = []
+ c_n = []
+
+ for layer in range(self.num_layers):
+ max_time, batch_size, input_size = input.size()
+ input_mask, hidden_mask = None, None
+ if self.training: # when training, use the dropout
+ input_mask = input.data.new(batch_size, input_size).fill_(1 - self.dropout_in)
+ input_mask = Variable(torch.bernoulli(input_mask), requires_grad=False)
+ input_mask = input_mask / (1 - self.dropout_in)
+ # permute: exchange the dimension
+ input_mask = torch.unsqueeze(input_mask, dim=2).expand(-1, -1, max_time).permute(2, 0, 1)
+ input = input * input_mask
+
+ hidden_mask = input.data.new(batch_size, self.hidden_size).fill_(1 - self.dropout_out)
+ hidden_mask = Variable(torch.bernoulli(hidden_mask), requires_grad=False)
+ hidden_mask = hidden_mask / (1 - self.dropout_out)
+
+ layer_output, (layer_h_n, layer_c_n) = VariationalLSTM._forward_rnn(cell=self.fcells[layer], \
+ input=input, masks=masks,
+ initial=initial,
+ drop_masks=hidden_mask)
+ if self.bidirectional:
+ blayer_output, (blayer_h_n, blayer_c_n) = VariationalLSTM._forward_brnn(cell=self.bcells[layer], \
+ input=input, masks=masks,
+ initial=initial,
+ drop_masks=hidden_mask)
+
+ h_n.append(torch.cat([layer_h_n, blayer_h_n], 1) if self.bidirectional else layer_h_n)
+ c_n.append(torch.cat([layer_c_n, blayer_c_n], 1) if self.bidirectional else layer_c_n)
+ input = torch.cat([layer_output, blayer_output], 2) if self.bidirectional else layer_output
+
+ h_n = torch.stack(h_n, 0)
+ c_n = torch.stack(c_n, 0)
+ if self.batch_first:
+ input = input.transpose(1, 0) # transpose: return the transpose matrix
+ return input, (h_n, c_n)
diff --git a/hanlp/components/srl/span_rank/span_rank.py b/hanlp/components/srl/span_rank/span_rank.py
new file mode 100644
index 000000000..7ced536c7
--- /dev/null
+++ b/hanlp/components/srl/span_rank/span_rank.py
@@ -0,0 +1,411 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-07-09 18:13
+import logging
+from bisect import bisect
+from typing import Union, List, Callable, Tuple, Dict, Any
+
+from hanlp_common.constant import IDX
+from hanlp.layers.transformers.utils import build_optimizer_scheduler_with_transformer
+import torch
+from torch.utils.data import DataLoader
+from hanlp.common.dataset import PadSequenceDataLoader, SortingSampler
+from hanlp.common.torch_component import TorchComponent
+from hanlp.common.transform import FieldLength
+from hanlp.common.vocab import Vocab
+from hanlp.components.srl.span_rank.inference_utils import srl_decode
+from hanlp.components.srl.span_rank.span_ranking_srl_model import SpanRankingSRLModel
+from hanlp.components.srl.span_rank.srl_eval_utils import compute_srl_f1
+from hanlp.datasets.srl.conll2012 import CoNLL2012SRLDataset, filter_v_args, unpack_srl, \
+ group_pa_by_p
+from hanlp.layers.embeddings.embedding import Embedding
+from hanlp.metrics.f1 import F1
+from hanlp_common.visualization import markdown_table
+from hanlp.utils.time_util import CountdownTimer
+from hanlp_common.util import merge_locals_kwargs, reorder
+
+
+class SpanRankingSemanticRoleLabeler(TorchComponent):
+ def __init__(self, **kwargs) -> None:
+ """An implementation of "Jointly Predicting Predicates and Arguments in Neural Semantic Role Labeling"
+ (:cite:`he-etal-2018-jointly`). It generates candidates triples of (predicate, arg_start, arg_end) and rank them.
+
+ Args:
+ **kwargs: Predefined config.
+ """
+ super().__init__(**kwargs)
+ self.model: SpanRankingSRLModel = None
+
+ def build_optimizer(self,
+ trn,
+ epochs,
+ lr,
+ adam_epsilon,
+ weight_decay,
+ warmup_steps,
+ transformer_lr,
+ **kwargs):
+ # noinspection PyProtectedMember
+ transformer = self._get_transformer()
+ if transformer:
+ num_training_steps = len(trn) * epochs // self.config.get('gradient_accumulation', 1)
+ optimizer, scheduler = build_optimizer_scheduler_with_transformer(self.model,
+ transformer,
+ lr, transformer_lr,
+ num_training_steps, warmup_steps,
+ weight_decay, adam_epsilon)
+ else:
+ optimizer = torch.optim.Adam(self.model.parameters(), self.config.lr)
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+ optimizer=optimizer,
+ mode='max',
+ factor=0.5,
+ patience=2,
+ verbose=True,
+ )
+ return optimizer, scheduler
+
+ def _get_transformer(self):
+ return getattr(self.model_.embed, 'transformer', None)
+
+ def build_criterion(self, **kwargs):
+ pass
+
+ # noinspection PyProtectedMember
+ def build_metric(self, **kwargs) -> Tuple[F1, F1]:
+ predicate_f1 = F1()
+ end_to_end_f1 = F1()
+ return predicate_f1, end_to_end_f1
+
+ def execute_training_loop(self,
+ trn: DataLoader,
+ dev: DataLoader,
+ epochs,
+ criterion,
+ optimizer,
+ metric,
+ save_dir,
+ logger: logging.Logger,
+ devices,
+ **kwargs):
+ best_epoch, best_metric = 0, -1
+ predicate, end_to_end = metric
+ optimizer, scheduler = optimizer
+ timer = CountdownTimer(epochs)
+ ratio_width = len(f'{len(trn)}/{len(trn)}')
+ for epoch in range(1, epochs + 1):
+ logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
+ self.fit_dataloader(trn, criterion, optimizer, metric, logger,
+ linear_scheduler=scheduler if self._get_transformer() else None)
+ if dev:
+ self.evaluate_dataloader(dev, criterion, metric, logger, ratio_width=ratio_width)
+ report = f'{timer.elapsed_human}/{timer.total_time_human}'
+ dev_score = end_to_end.score
+ if not self._get_transformer():
+ scheduler.step(dev_score)
+ if dev_score > best_metric:
+ self.save_weights(save_dir)
+ best_metric = dev_score
+ report += ' [red]saved[/red]'
+ timer.log(report, ratio_percentage=False, newline=True, ratio=False)
+
+ def fit_dataloader(self,
+ trn: DataLoader,
+ criterion,
+ optimizer,
+ metric,
+ logger: logging.Logger,
+ linear_scheduler=None,
+ gradient_accumulation=1,
+ **kwargs):
+ self.model.train()
+ timer = CountdownTimer(len(trn) // gradient_accumulation)
+ total_loss = 0
+ self.reset_metrics(metric)
+ for idx, batch in enumerate(trn):
+ output_dict = self.feed_batch(batch)
+ self.update_metrics(batch, output_dict, metric)
+ loss = output_dict['loss']
+ loss = loss.sum() # For data parallel
+ loss.backward()
+ if gradient_accumulation and gradient_accumulation > 1:
+ loss /= gradient_accumulation
+ if self.config.grad_norm:
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_norm)
+ if (idx + 1) % gradient_accumulation == 0:
+ self._step(optimizer, linear_scheduler)
+ timer.log(self.report_metrics(total_loss / (timer.current + 1), metric), ratio_percentage=None,
+ logger=logger)
+ total_loss += loss.item()
+ del loss
+ if len(trn) % gradient_accumulation:
+ self._step(optimizer, linear_scheduler)
+ return total_loss / timer.total
+
+ def _step(self, optimizer, linear_scheduler):
+ optimizer.step()
+ optimizer.zero_grad()
+ if linear_scheduler:
+ linear_scheduler.step()
+
+ # noinspection PyMethodOverriding
+ @torch.no_grad()
+ def evaluate_dataloader(self,
+ data: DataLoader,
+ criterion: Callable,
+ metric,
+ logger,
+ ratio_width=None,
+ output=False,
+ official=False,
+ confusion_matrix=False,
+ **kwargs):
+ self.model.eval()
+ self.reset_metrics(metric)
+ timer = CountdownTimer(len(data))
+ total_loss = 0
+ if official:
+ sentences = []
+ gold = []
+ pred = []
+ for batch in data:
+ output_dict = self.feed_batch(batch)
+ if official:
+ sentences += batch['token']
+ gold += batch['srl']
+ pred += output_dict['prediction']
+ self.update_metrics(batch, output_dict, metric)
+ loss = output_dict['loss']
+ total_loss += loss.item()
+ timer.log(self.report_metrics(total_loss / (timer.current + 1), metric), ratio_percentage=None,
+ logger=logger,
+ ratio_width=ratio_width)
+ del loss
+ if official:
+ scores = compute_srl_f1(sentences, gold, pred)
+ if logger:
+ if confusion_matrix:
+ labels = sorted(set(y for x in scores.label_confusions.keys() for y in x))
+ headings = ['GOLD↓PRED→'] + labels
+ matrix = []
+ for i, gold in enumerate(labels):
+ row = [gold]
+ matrix.append(row)
+ for j, pred in enumerate(labels):
+ row.append(scores.label_confusions.get((gold, pred), 0))
+ matrix = markdown_table(headings, matrix)
+ logger.info(f'{"Confusion Matrix": ^{len(matrix.splitlines()[0])}}')
+ logger.info(matrix)
+ headings = ['Settings', 'Precision', 'Recall', 'F1']
+ data = []
+ for h, (p, r, f) in zip(['Unlabeled', 'Labeled', 'Official'], [
+ [scores.unlabeled_precision, scores.unlabeled_recall, scores.unlabeled_f1],
+ [scores.precision, scores.recall, scores.f1],
+ [scores.conll_precision, scores.conll_recall, scores.conll_f1],
+ ]):
+ data.append([h] + [f'{x:.2%}' for x in [p, r, f]])
+ table = markdown_table(headings, data)
+ logger.info(f'{"Scores": ^{len(table.splitlines()[0])}}')
+ logger.info(table)
+ else:
+ scores = metric
+ return total_loss / timer.total, scores
+
+ def build_model(self,
+ training=True,
+ **kwargs) -> torch.nn.Module:
+ # noinspection PyTypeChecker
+ # embed: torch.nn.Embedding = self.config.embed.module(vocabs=self.vocabs)[0].embed
+ model = SpanRankingSRLModel(self.config,
+ self.config.embed.module(vocabs=self.vocabs, training=training),
+ self.config.context_layer,
+ len(self.vocabs.srl_label))
+ return model
+
+ # noinspection PyMethodOverriding
+ def build_dataloader(self, data, batch_size, shuffle, device, logger: logging.Logger,
+ generate_idx=False, **kwargs) -> DataLoader:
+ batch_max_tokens = self.config.batch_max_tokens
+ gradient_accumulation = self.config.get('gradient_accumulation', 1)
+ if batch_size:
+ batch_size //= gradient_accumulation
+ if batch_max_tokens:
+ batch_max_tokens //= gradient_accumulation
+ dataset = self.build_dataset(data, generate_idx, logger)
+
+ sampler = SortingSampler([x['token_length'] for x in dataset],
+ batch_size=batch_size,
+ batch_max_tokens=batch_max_tokens,
+ shuffle=shuffle)
+ return PadSequenceDataLoader(batch_sampler=sampler,
+ device=device,
+ dataset=dataset)
+
+ def build_dataset(self, data, generate_idx, logger, transform=None):
+ dataset = CoNLL2012SRLDataset(data, transform=[filter_v_args, unpack_srl, group_pa_by_p],
+ doc_level_offset=self.config.doc_level_offset, generate_idx=generate_idx)
+ if transform:
+ dataset.append_transform(transform)
+ if isinstance(self.config.get('embed', None), Embedding):
+ transform = self.config.embed.transform(vocabs=self.vocabs)
+ if transform:
+ dataset.append_transform(transform)
+ dataset.append_transform(self.vocabs)
+ dataset.append_transform(FieldLength('token'))
+ if isinstance(data, str):
+ dataset.purge_cache() # Enable cache
+ if self.vocabs.mutable:
+ self.build_vocabs(dataset, logger)
+ return dataset
+
+ def predict(self, data: Union[str, List[str]], batch_size: int = None, fmt='dict', **kwargs):
+ if not data:
+ return []
+ flat = self.input_is_flat(data)
+ if flat:
+ data = [data]
+ samples = []
+ for token in data:
+ sample = dict()
+ sample['token'] = token
+ samples.append(sample)
+ batch_size = batch_size or self.config.batch_size
+ dataloader = self.build_dataloader(samples, batch_size, False, self.device, None, generate_idx=True)
+ outputs = []
+ order = []
+ for batch in dataloader:
+ output_dict = self.feed_batch(batch)
+ outputs.extend(output_dict['prediction'])
+ order.extend(batch[IDX])
+ outputs = reorder(outputs, order)
+ if fmt == 'list':
+ outputs = self.format_dict_to_results(data, outputs)
+ if flat:
+ return outputs[0]
+ return outputs
+
+ @staticmethod
+ def format_dict_to_results(data, outputs, exclusive_offset=False, with_predicate=False, with_argument=False,
+ label_first=False):
+ results = []
+ for i in range(len(outputs)):
+ tokens = data[i]
+ output = []
+ for p, a in outputs[i].items():
+ # a: [(0, 0, 'ARG0')]
+ if with_predicate:
+ a.insert(bisect([x[0] for x in a], p), (p, p, 'PRED'))
+ if with_argument is not False:
+ a = [x + (tokens[x[0]:x[1] + 1],) for x in a]
+ if isinstance(with_argument, str):
+ a = [x[:-1] + (with_argument.join(x[-1]),) for x in a]
+ if exclusive_offset:
+ a = [(x[0], x[1] + 1) + x[2:] for x in a]
+ if label_first:
+ a = [tuple(reversed(x[2:])) + x[:2] for x in a]
+ output.append(a)
+ results.append(output)
+ return results
+
+ def input_is_flat(self, data):
+ return isinstance(data[0], str)
+
+ # noinspection PyMethodOverriding
+ def fit(self,
+ trn_data,
+ dev_data,
+ save_dir,
+ embed,
+ context_layer,
+ batch_size=40,
+ batch_max_tokens=700,
+ lexical_dropout=0.5,
+ dropout=0.2,
+ span_width_feature_size=20,
+ ffnn_size=150,
+ ffnn_depth=2,
+ argument_ratio=0.8,
+ predicate_ratio=0.4,
+ max_arg_width=30,
+ mlp_label_size=100,
+ enforce_srl_constraint=False,
+ use_gold_predicates=False,
+ doc_level_offset=True,
+ use_biaffine=False,
+ lr=1e-3,
+ transformer_lr=1e-5,
+ adam_epsilon=1e-6,
+ weight_decay=0.01,
+ warmup_steps=0.1,
+ grad_norm=5.0,
+ gradient_accumulation=1,
+ loss_reduction='sum',
+ devices=None,
+ logger=None,
+ seed=None,
+ **kwargs
+ ):
+
+ return super().fit(**merge_locals_kwargs(locals(), kwargs))
+
+ def build_vocabs(self, dataset, logger, **kwargs):
+ self.vocabs.srl_label = Vocab(pad_token=None, unk_token=None)
+ # Use null to indicate no relationship
+ self.vocabs.srl_label.add('')
+ timer = CountdownTimer(len(dataset))
+ max_seq_len = 0
+ for each in dataset:
+ max_seq_len = max(max_seq_len, len(each['token_input_ids']))
+ timer.log(f'Building vocabs (max sequence length {max_seq_len}) [blink][yellow]...[/yellow][/blink]')
+ pass
+ timer.stop()
+ timer.erase()
+ self.vocabs['srl_label'].set_unk_as_safe_unk()
+ self.vocabs.lock()
+ self.vocabs.summary(logger)
+
+ def reset_metrics(self, metrics):
+ for each in metrics:
+ each.reset()
+
+ def report_metrics(self, loss, metrics):
+ predicate, end_to_end = metrics
+ return f'loss: {loss:.4f} predicate: {predicate.score:.2%} end_to_end: {end_to_end.score:.2%}'
+
+ def feed_batch(self, batch) -> Dict[str, Any]:
+ output_dict = self.model(batch)
+ prediction = self.decode_output(output_dict, batch, self.model.training)
+ output_dict['prediction'] = prediction
+ return output_dict
+
+ def decode_output(self, output_dict, batch, training=False):
+ idx_to_label = self.vocabs['srl_label'].idx_to_token
+ if training:
+ # Use fast decoding during training,
+ prediction = []
+ top_predicate_indices = output_dict['predicates'].tolist()
+ top_spans = torch.stack([output_dict['arg_starts'], output_dict['arg_ends']], dim=-1).tolist()
+ srl_mask = output_dict['srl_mask'].tolist()
+ for n, (pal, predicate_indices, argument_spans) in enumerate(
+ zip(output_dict['srl_scores'].argmax(-1).tolist(), top_predicate_indices, top_spans)):
+ srl_per_sentence = {}
+ for p, (al, predicate_index) in enumerate(zip(pal, predicate_indices)):
+ for a, (l, argument_span) in enumerate(zip(al, argument_spans)):
+ if l and srl_mask[n][p][a]:
+ args = srl_per_sentence.get(p, None)
+ if args is None:
+ args = srl_per_sentence[p] = []
+ args.append((*argument_span, idx_to_label[l]))
+ prediction.append(srl_per_sentence)
+ else:
+ prediction = srl_decode(batch['token_length'], output_dict, idx_to_label, self.config)
+ return prediction
+
+ def update_metrics(self, batch: dict, output_dict: dict, metrics):
+ def unpack(y: dict):
+ return set((p, bel) for p, a in y.items() for bel in a)
+
+ predicate, end_to_end = metrics
+ for pred, gold in zip(output_dict['prediction'], batch['srl']):
+ predicate(pred.keys(), gold.keys())
+ end_to_end(unpack(pred), unpack(gold))
diff --git a/hanlp/components/srl/span_rank/span_ranking_srl_model.py b/hanlp/components/srl/span_rank/span_ranking_srl_model.py
new file mode 100644
index 000000000..256655774
--- /dev/null
+++ b/hanlp/components/srl/span_rank/span_ranking_srl_model.py
@@ -0,0 +1,500 @@
+from typing import Dict
+
+from alnlp.modules.feedforward import FeedForward
+from alnlp.modules.time_distributed import TimeDistributed
+
+from .highway_variational_lstm import *
+import torch
+from alnlp.modules import util
+
+from ...parsers.biaffine.biaffine import Biaffine
+
+
+def initializer_1d(input_tensor, initializer):
+ assert len(input_tensor.size()) == 1
+ input_tensor = input_tensor.view(-1, 1)
+ input_tensor = initializer(input_tensor)
+ return input_tensor.view(-1)
+
+
+class SpanRankingSRLDecoder(nn.Module):
+
+ def __init__(self, context_layer_output_dim, label_space_size, config) -> None:
+ super().__init__()
+ self.config = config
+ self.label_space_size = label_space_size
+ self.dropout = float(config.dropout)
+ self.use_gold_predicates = config.use_gold_predicates
+ # span width feature embedding
+ self.span_width_embedding = nn.Embedding(self.config.max_arg_width, self.config.span_width_feature_size)
+ # self.context_projective_layer = nn.Linear(2 * self.lstm_hidden_size, self.config.num_attention_heads)
+ # span scores
+ self.span_emb_size = 3 * context_layer_output_dim + self.config.span_width_feature_size
+ self.arg_unary_score_layers = nn.ModuleList([nn.Linear(self.span_emb_size, self.config.ffnn_size) if i == 0
+ else nn.Linear(self.config.ffnn_size, self.config.ffnn_size) for i
+ in range(self.config.ffnn_depth)]) # [,150]
+ self.arg_dropout_layers = nn.ModuleList([nn.Dropout(self.dropout) for _ in range(self.config.ffnn_depth)])
+ self.arg_unary_score_projection = nn.Linear(self.config.ffnn_size, 1)
+ # predicate scores
+ self.pred_unary_score_layers = nn.ModuleList(
+ [nn.Linear(context_layer_output_dim, self.config.ffnn_size) if i == 0
+ else nn.Linear(self.config.ffnn_size, self.config.ffnn_size) for i
+ in range(self.config.ffnn_depth)]) # [,150]
+ self.pred_dropout_layers = nn.ModuleList([nn.Dropout(self.dropout) for _ in range(self.config.ffnn_depth)])
+ self.pred_unary_score_projection = nn.Linear(self.config.ffnn_size, 1)
+ # srl scores
+ self.srl_unary_score_input_size = self.span_emb_size + context_layer_output_dim
+ self.srl_unary_score_layers = nn.ModuleList([nn.Linear(self.srl_unary_score_input_size, self.config.ffnn_size)
+ if i == 0 else nn.Linear(self.config.ffnn_size,
+ self.config.ffnn_size)
+ for i in range(self.config.ffnn_depth)])
+ self.srl_dropout_layers = nn.ModuleList([nn.Dropout(self.dropout) for _ in range(self.config.ffnn_depth)])
+ self.srl_unary_score_projection = nn.Linear(self.config.ffnn_size, self.label_space_size - 1)
+ if config.use_biaffine:
+ self.predicate_scale = TimeDistributed(FeedForward(context_layer_output_dim, 1, self.span_emb_size, 'ReLU'))
+ self.biaffine = Biaffine(self.span_emb_size, self.label_space_size - 1)
+ self.loss_reduction = config.loss_reduction
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ init.xavier_uniform_(self.span_width_embedding.weight)
+ # init.xavier_uniform_(self.context_projective_layer.weight)
+ # initializer_1d(self.context_projective_layer.bias, init.xavier_uniform_)
+
+ for layer in self.arg_unary_score_layers:
+ init.xavier_uniform_(layer.weight)
+ initializer_1d(layer.bias, init.xavier_uniform_)
+ init.xavier_uniform_(self.arg_unary_score_projection.weight)
+ initializer_1d(self.arg_unary_score_projection.bias, init.xavier_uniform_)
+
+ for layer in self.pred_unary_score_layers:
+ init.xavier_uniform_(layer.weight)
+ initializer_1d(layer.bias, init.xavier_uniform_)
+ init.xavier_uniform_(self.pred_unary_score_projection.weight)
+ initializer_1d(self.pred_unary_score_projection.bias, init.xavier_uniform_)
+
+ for layer in self.srl_unary_score_layers:
+ init.xavier_uniform_(layer.weight)
+ initializer_1d(layer.bias, init.xavier_uniform_)
+ init.xavier_uniform_(self.srl_unary_score_projection.weight)
+ initializer_1d(self.srl_unary_score_projection.bias, init.xavier_uniform_)
+ return None
+
+ def forward(self, hidden_states, batch, mask=None):
+ gold_arg_ends, gold_arg_labels, gold_arg_starts, gold_predicates, masks, sent_lengths = SpanRankingSRLModel.unpack(
+ batch, mask=mask, training=self.training)
+ return self.decode(hidden_states, sent_lengths, masks, gold_arg_starts, gold_arg_ends, gold_arg_labels,
+ gold_predicates)
+
+ @staticmethod
+ def get_candidate_spans(sent_lengths: torch.Tensor, max_sent_length, max_arg_width):
+ num_sentences = len(sent_lengths)
+ device = sent_lengths.device
+ candidate_starts = torch.arange(0, max_sent_length, device=device).expand(num_sentences, max_arg_width, -1)
+ candidate_width = torch.arange(0, max_arg_width, device=device).view(1, -1, 1)
+ candidate_ends = candidate_starts + candidate_width
+
+ candidate_starts = candidate_starts.contiguous().view(num_sentences, max_sent_length * max_arg_width)
+ candidate_ends = candidate_ends.contiguous().view(num_sentences, max_sent_length * max_arg_width)
+ actual_sent_lengths = sent_lengths.view(-1, 1).expand(-1, max_sent_length * max_arg_width)
+ candidate_mask = candidate_ends < actual_sent_lengths
+
+ candidate_starts = candidate_starts * candidate_mask
+ candidate_ends = candidate_ends * candidate_mask
+ return candidate_starts, candidate_ends, candidate_mask
+
+ @staticmethod
+ def exclusive_cumsum(input: torch.Tensor, exclusive=True):
+ """
+
+ Args:
+ input: input is the sentence lengths tensor.
+ exclusive: exclude the last sentence length (Default value = True)
+ input(torch.Tensor :):
+ input: torch.Tensor:
+
+ Returns:
+
+
+ """
+ assert exclusive is True
+ if exclusive is True:
+ exclusive_sent_lengths = input.new_zeros(1, dtype=torch.long)
+ result = torch.cumsum(torch.cat([exclusive_sent_lengths, input], 0)[:-1], 0).view(-1, 1)
+ else:
+ result = torch.cumsum(input, 0).view(-1, 1)
+ return result
+
+ def flatten_emb(self, emb):
+ num_sentences, max_sentence_length = emb.size()[0], emb.size()[1]
+ assert len(emb.size()) == 3
+ flatted_emb = emb.contiguous().view(num_sentences * max_sentence_length, -1)
+ return flatted_emb
+
+ def flatten_emb_in_sentence(self, emb, batch_sentences_mask):
+ num_sentences, max_sentence_length = emb.size()[0], emb.size()[1]
+ flatted_emb = self.flatten_emb(emb)
+ return flatted_emb[batch_sentences_mask.reshape(num_sentences * max_sentence_length)]
+
+ def get_span_emb(self, flatted_context_emb, flatted_candidate_starts, flatted_candidate_ends,
+ config, dropout=0.0):
+ batch_word_num = flatted_context_emb.size()[0]
+ # gather slices from embeddings according to indices
+ span_start_emb = flatted_context_emb[flatted_candidate_starts]
+ span_end_emb = flatted_context_emb[flatted_candidate_ends]
+ span_emb_feature_list = [span_start_emb, span_end_emb] # store the span vector representations for span rep.
+
+ span_width = 1 + flatted_candidate_ends - flatted_candidate_starts # [num_spans], generate the span width
+ max_arg_width = config.max_arg_width
+
+ # get the span width feature emb
+ span_width_index = span_width - 1
+ span_width_emb = self.span_width_embedding(span_width_index)
+ span_width_emb = F.dropout(span_width_emb, dropout, self.training)
+ span_emb_feature_list.append(span_width_emb)
+
+ """head features"""
+ cpu_flatted_candidte_starts = flatted_candidate_starts
+ span_indices = torch.arange(0, max_arg_width, device=flatted_context_emb.device).view(1, -1) + \
+ cpu_flatted_candidte_starts.view(-1, 1) # For all the i, where i in [begin, ..i, end] for span
+ # reset the position index to the batch_word_num index with index - 1
+ span_indices = torch.clamp(span_indices, max=batch_word_num - 1)
+ num_spans, spans_width = span_indices.size()[0], span_indices.size()[1]
+ flatted_span_indices = span_indices.view(-1) # so Huge!!!, column is the span?
+ # if torch.cuda.is_available():
+ flatted_span_indices = flatted_span_indices
+ span_text_emb = flatted_context_emb.index_select(0, flatted_span_indices).view(num_spans, spans_width, -1)
+ span_indices_mask = util.lengths_to_mask(span_width, max_len=max_arg_width)
+ # project context output to num head
+ # head_scores = self.context_projective_layer.forward(flatted_context_emb)
+ # get span attention
+ # span_attention = head_scores.index_select(0, flatted_span_indices).view(num_spans, spans_width)
+ # span_attention = torch.add(span_attention, expanded_span_indices_log_mask).unsqueeze(2) # control the span len
+ # span_attention = F.softmax(span_attention, dim=1)
+ span_text_emb = span_text_emb * span_indices_mask.unsqueeze(2).expand(-1, -1, span_text_emb.size()[-1])
+ span_head_emb = torch.mean(span_text_emb, 1)
+ span_emb_feature_list.append(span_head_emb)
+
+ span_emb = torch.cat(span_emb_feature_list, 1)
+ return span_emb, None, span_text_emb, span_indices, span_indices_mask
+
+ def get_arg_unary_scores(self, span_emb):
+ """Compute span score with FFNN(span embedding)
+
+ Args:
+ span_emb: tensor of [num_sentences, num_spans, emb_size]
+ config: param dropout:
+ num_labels: param name:
+
+ Returns:
+
+
+ """
+ input = span_emb
+ for i, ffnn in enumerate(self.arg_unary_score_layers):
+ input = F.relu(ffnn.forward(input))
+ input = self.arg_dropout_layers[i].forward(input)
+ output = self.arg_unary_score_projection.forward(input)
+ return output
+
+ def get_pred_unary_scores(self, span_emb):
+ input = span_emb
+ for i, ffnn in enumerate(self.pred_unary_score_layers):
+ input = F.relu(ffnn.forward(input))
+ input = self.pred_dropout_layers[i].forward(input)
+ output = self.pred_unary_score_projection.forward(input)
+ return output
+
+ def extract_spans(self, candidate_scores, candidate_starts, candidate_ends, topk, max_sentence_length,
+ sort_spans, enforce_non_crossing):
+ """extract the topk span indices
+
+ Args:
+ candidate_scores: param candidate_starts:
+ candidate_ends: param topk: [num_sentences]
+ max_sentence_length: param sort_spans:
+ enforce_non_crossing: return: indices [num_sentences, max_num_predictions]
+ candidate_starts:
+ topk:
+ sort_spans:
+
+ Returns:
+
+
+ """
+ # num_sentences = candidate_scores.size()[0]
+ # num_input_spans = candidate_scores.size()[1]
+ max_num_output_spans = int(torch.max(topk))
+ indices = [score.topk(k)[1] for score, k in zip(candidate_scores, topk)]
+ output_span_indices_tensor = [F.pad(item, [0, max_num_output_spans - item.size()[0]], value=item[-1])
+ for item in indices]
+ output_span_indices_tensor = torch.stack(output_span_indices_tensor)
+ return output_span_indices_tensor
+
+ def batch_index_select(self, emb, indices):
+ num_sentences = emb.size()[0]
+ max_sent_length = emb.size()[1]
+ flatten_emb = self.flatten_emb(emb)
+ offset = (torch.arange(0, num_sentences, device=emb.device) * max_sent_length).unsqueeze(1)
+ return torch.index_select(flatten_emb, 0, (indices + offset).view(-1)) \
+ .view(indices.size()[0], indices.size()[1], -1)
+
+ def get_batch_topk(self, candidate_starts: torch.Tensor, candidate_ends, candidate_scores, topk_ratio, text_len,
+ max_sentence_length, sort_spans=False, enforce_non_crossing=True):
+ num_sentences = candidate_starts.size()[0]
+ max_sentence_length = candidate_starts.size()[1]
+
+ topk = torch.floor(text_len.to(torch.float) * topk_ratio).to(torch.long)
+ topk = torch.max(topk, torch.ones(num_sentences, device=candidate_starts.device, dtype=torch.long))
+
+ # this part should be implemented with C++
+ predicted_indices = self.extract_spans(candidate_scores, candidate_starts, candidate_ends, topk,
+ max_sentence_length, sort_spans, enforce_non_crossing)
+ predicted_starts = torch.gather(candidate_starts, 1, predicted_indices)
+ predicted_ends = torch.gather(candidate_ends, 1, predicted_indices)
+ predicted_scores = torch.gather(candidate_scores, 1, predicted_indices)
+ return predicted_starts, predicted_ends, predicted_scores, topk, predicted_indices
+
+ def get_dense_span_labels(self, span_starts, span_ends, span_labels, max_sentence_length,
+ span_parents=None):
+ num_sentences = span_starts.size()[0]
+ max_spans_num = span_starts.size()[1]
+
+ # span_starts = span_starts + 1 - (span_labels > 0).to(torch.long)
+ span_starts[(span_labels == 0) & (span_starts < max_sentence_length - 1)] += 1 # make start > end
+ sentence_indices = torch.arange(0, num_sentences, device=span_starts.device).unsqueeze(1).expand(-1,
+ max_spans_num)
+
+ sparse_indices = torch.cat([sentence_indices.unsqueeze(2), span_starts.unsqueeze(2), span_ends.unsqueeze(2)],
+ dim=2)
+ if span_parents is not None: # semantic span predicate offset
+ sparse_indices = torch.cat([sparse_indices, span_parents.unsqueeze(2)], 2)
+
+ rank = 3 if span_parents is None else 4
+ dense_labels = torch.sparse.LongTensor(sparse_indices.view(num_sentences * max_spans_num, rank).t(),
+ span_labels.view(-1),
+ torch.Size([num_sentences] + [max_sentence_length] * (rank - 1))) \
+ .to_dense()
+ return dense_labels
+
+ @staticmethod
+ def gather_4d(params, indices):
+ assert len(params.size()) == 4 and len(indices) == 4
+ indices_a, indices_b, indices_c, indices_d = indices
+ result = params[indices_a, indices_b, indices_c, indices_d]
+ return result
+
+ def get_srl_labels(self,
+ arg_starts,
+ arg_ends,
+ predicates,
+ gold_predicates,
+ gold_arg_starts,
+ gold_arg_ends,
+ gold_arg_labels,
+ max_sentence_length
+ ):
+ num_sentences = arg_starts.size()[0]
+ max_arg_num = arg_starts.size()[1]
+ max_pred_num = predicates.size()[1]
+
+ sentence_indices_2d = torch.arange(0, num_sentences, device=arg_starts.device).unsqueeze(1).unsqueeze(2).expand(
+ -1, max_arg_num, max_pred_num)
+ expanded_arg_starts = arg_starts.unsqueeze(2).expand(-1, -1, max_pred_num)
+ expanded_arg_ends = arg_ends.unsqueeze(2).expand(-1, -1, max_pred_num)
+ expanded_predicates = predicates.unsqueeze(1).expand(-1, max_arg_num, -1)
+
+ dense_srl_labels = self.get_dense_span_labels(gold_arg_starts,
+ gold_arg_ends,
+ gold_arg_labels,
+ max_sentence_length, span_parents=gold_predicates) # ans
+ srl_labels = self.gather_4d(dense_srl_labels,
+ [sentence_indices_2d, expanded_arg_starts, expanded_arg_ends, expanded_predicates])
+ return srl_labels
+
+ def get_srl_unary_scores(self, span_emb):
+ input = span_emb
+ for i, ffnn in enumerate(self.srl_unary_score_layers):
+ input = F.relu(ffnn.forward(input))
+ input = self.srl_dropout_layers[i].forward(input)
+ output = self.srl_unary_score_projection.forward(input)
+ return output
+
+ def get_srl_scores(self, arg_emb, pred_emb, arg_scores, pred_scores, num_labels, config, dropout):
+ num_sentences = arg_emb.size()[0]
+ num_args = arg_emb.size()[1] # [batch_size, max_arg_num, arg_emb_size]
+ num_preds = pred_emb.size()[1] # [batch_size, max_pred_num, pred_emb_size]
+
+ unsqueezed_arg_emb = arg_emb.unsqueeze(2)
+ unsqueezed_pred_emb = pred_emb.unsqueeze(1)
+ expanded_arg_emb = unsqueezed_arg_emb.expand(-1, -1, num_preds, -1)
+ expanded_pred_emb = unsqueezed_pred_emb.expand(-1, num_args, -1, -1)
+ pair_emb_list = [expanded_arg_emb, expanded_pred_emb]
+ pair_emb = torch.cat(pair_emb_list, 3) # concatenate the argument emb and pre emb
+ pair_emb_size = pair_emb.size()[3]
+ flat_pair_emb = pair_emb.view(num_sentences * num_args * num_preds, pair_emb_size)
+ # get unary scores
+ flat_srl_scores = self.get_srl_unary_scores(flat_pair_emb)
+ srl_scores = flat_srl_scores.view(num_sentences, num_args, num_preds, -1)
+ if self.config.use_biaffine:
+ srl_scores += self.biaffine(arg_emb, self.predicate_scale(pred_emb)).permute([0, 2, 3, 1])
+ unsqueezed_arg_scores, unsqueezed_pred_scores = \
+ arg_scores.unsqueeze(2).unsqueeze(3), pred_scores.unsqueeze(1).unsqueeze(3)
+ srl_scores = srl_scores + unsqueezed_arg_scores + unsqueezed_pred_scores
+ dummy_scores = torch.zeros([num_sentences, num_args, num_preds, 1], device=arg_emb.device)
+ srl_scores = torch.cat([dummy_scores, srl_scores], 3)
+ return srl_scores
+
+ def get_srl_softmax_loss(self, srl_scores, srl_labels, num_predicted_args, num_predicted_preds):
+ srl_loss_mask = self.get_srl_loss_mask(srl_scores, num_predicted_args, num_predicted_preds)
+
+ loss = torch.nn.functional.cross_entropy(srl_scores[srl_loss_mask], srl_labels[srl_loss_mask],
+ reduction=self.loss_reduction)
+ return loss, srl_loss_mask
+
+ def get_srl_loss_mask(self, srl_scores, num_predicted_args, num_predicted_preds):
+ max_num_arg = srl_scores.size()[1]
+ max_num_pred = srl_scores.size()[2]
+ # num_predicted_args, 1D tensor; max_num_arg: a int variable means the gold ans's max arg number
+ args_mask = util.lengths_to_mask(num_predicted_args, max_num_arg)
+ pred_mask = util.lengths_to_mask(num_predicted_preds, max_num_pred)
+ srl_loss_mask = args_mask.unsqueeze(2) & pred_mask.unsqueeze(1)
+ return srl_loss_mask
+
+ def decode(self, contextualized_embeddings, sent_lengths, masks, gold_arg_starts, gold_arg_ends, gold_arg_labels,
+ gold_predicates):
+ num_sentences, max_sent_length = masks.size()
+ device = sent_lengths.device
+ """generate candidate spans with argument pruning"""
+ # candidate_starts [num_sentences, max_sent_length * max_arg_width]
+ candidate_starts, candidate_ends, candidate_mask = self.get_candidate_spans(
+ sent_lengths, max_sent_length, self.config.max_arg_width)
+ flatted_candidate_mask = candidate_mask.view(-1)
+ batch_word_offset = self.exclusive_cumsum(sent_lengths) # get the word offset in a batch
+ # choose the flatted_candidate_starts with the actual existing positions, i.e. exclude the illegal starts
+ flatted_candidate_starts = candidate_starts + batch_word_offset
+ flatted_candidate_starts = flatted_candidate_starts.view(-1)[flatted_candidate_mask].to(torch.long)
+ flatted_candidate_ends = candidate_ends + batch_word_offset
+ flatted_candidate_ends = flatted_candidate_ends.view(-1)[flatted_candidate_mask].to(torch.long)
+ # flatten the lstm output according to the sentence mask, i.e. exclude the illegal (padding) lstm output
+ flatted_context_output = self.flatten_emb_in_sentence(contextualized_embeddings, masks)
+ """generate the span embedding"""
+ candidate_span_emb, head_scores, span_head_emb, head_indices, head_indices_log_mask = self.get_span_emb(
+ flatted_context_output, flatted_candidate_starts, flatted_candidate_ends,
+ self.config, dropout=self.dropout)
+ """Get the span ids"""
+ candidate_span_number = candidate_span_emb.size()[0]
+ max_candidate_spans_num_per_sentence = candidate_mask.size()[1]
+ sparse_indices = candidate_mask.nonzero(as_tuple=False)
+ sparse_values = torch.arange(0, candidate_span_number, device=device)
+ candidate_span_ids = torch.sparse.FloatTensor(sparse_indices.t(), sparse_values,
+ torch.Size([num_sentences,
+ max_candidate_spans_num_per_sentence])).to_dense()
+ spans_log_mask = torch.log(candidate_mask.to(torch.float))
+ predict_dict = {"candidate_starts": candidate_starts, "candidate_ends": candidate_ends,
+ "head_scores": head_scores}
+ """Get unary scores and topk of candidate argument spans."""
+ flatted_candidate_arg_scores = self.get_arg_unary_scores(candidate_span_emb)
+ candidate_arg_scores = flatted_candidate_arg_scores.index_select(0, candidate_span_ids.view(-1)) \
+ .view(candidate_span_ids.size()[0], candidate_span_ids.size()[1])
+ candidate_arg_scores = candidate_arg_scores + spans_log_mask
+ arg_starts, arg_ends, arg_scores, num_args, top_arg_indices = \
+ self.get_batch_topk(candidate_starts, candidate_ends, candidate_arg_scores,
+ self.config.argument_ratio, sent_lengths, max_sent_length,
+ sort_spans=False, enforce_non_crossing=False)
+ """Get the candidate predicate"""
+ candidate_pred_ids = torch.arange(0, max_sent_length, device=device).unsqueeze(0).expand(num_sentences, -1)
+ candidate_pred_emb = contextualized_embeddings
+ candidate_pred_scores = self.get_pred_unary_scores(candidate_pred_emb)
+ candidate_pred_scores = candidate_pred_scores + torch.log(masks.to(torch.float).unsqueeze(2))
+ candidate_pred_scores = candidate_pred_scores.squeeze(2)
+ if self.use_gold_predicates is True:
+ predicates = gold_predicates[0]
+ num_preds = gold_predicates[1]
+ pred_scores = torch.zeros_like(predicates)
+ top_pred_indices = predicates
+ else:
+ predicates, _, pred_scores, num_preds, top_pred_indices = self.get_batch_topk(
+ candidate_pred_ids, candidate_pred_ids, candidate_pred_scores, self.config.predicate_ratio,
+ sent_lengths, max_sent_length,
+ sort_spans=False, enforce_non_crossing=False)
+ """Get top arg embeddings"""
+ arg_span_indices = torch.gather(candidate_span_ids, 1, top_arg_indices) # [num_sentences, max_num_args]
+ arg_emb = candidate_span_emb.index_select(0, arg_span_indices.view(-1)).view(
+ arg_span_indices.size()[0], arg_span_indices.size()[1], -1
+ ) # [num_sentences, max_num_args, emb]
+ """Get top predicate embeddings"""
+ pred_emb = self.batch_index_select(candidate_pred_emb,
+ top_pred_indices) # [num_sentences, max_num_preds, emb]
+ """Get the srl scores according to the arg emb and pre emb."""
+ srl_scores = self.get_srl_scores(arg_emb, pred_emb, arg_scores, pred_scores, self.label_space_size, self.config,
+ self.dropout) # [num_sentences, max_num_args, max_num_preds, num_labels]
+ if gold_arg_labels is not None:
+ """Get the answers according to the labels"""
+ srl_labels = self.get_srl_labels(arg_starts, arg_ends, predicates, gold_predicates, gold_arg_starts,
+ gold_arg_ends, gold_arg_labels, max_sent_length)
+
+ """Compute the srl loss"""
+ srl_loss, srl_mask = self.get_srl_softmax_loss(srl_scores, srl_labels, num_args, num_preds)
+ predict_dict.update({
+ 'srl_mask': srl_mask,
+ 'loss': srl_loss
+ })
+ else:
+ predict_dict['srl_mask'] = self.get_srl_loss_mask(srl_scores, num_args, num_preds)
+ predict_dict.update({
+ "candidate_arg_scores": candidate_arg_scores,
+ "candidate_pred_scores": candidate_pred_scores,
+ "predicates": predicates,
+ "arg_starts": arg_starts,
+ "arg_ends": arg_ends,
+ "arg_scores": arg_scores,
+ "pred_scores": pred_scores,
+ "num_args": num_args,
+ "num_preds": num_preds,
+ "arg_labels": torch.max(srl_scores, 1)[1], # [num_sentences, num_args, num_preds]
+ "srl_scores": srl_scores,
+ })
+ return predict_dict
+
+
+class SpanRankingSRLModel(nn.Module):
+
+ def __init__(self, config, embed: torch.nn.Module, context_layer: torch.nn.Module, label_space_size):
+ super(SpanRankingSRLModel, self).__init__()
+ self.config = config
+ self.dropout = float(config.dropout)
+ self.lexical_dropout = float(self.config.lexical_dropout)
+ self.label_space_size = label_space_size
+
+ # Initialize layers and parameters
+ self.word_embedding_dim = embed.get_output_dim() # get the embedding dim
+ self.embed = embed
+ # Initialize context layer
+ self.context_layer = context_layer
+ context_layer_output_dim = context_layer.get_output_dim()
+ self.decoder = SpanRankingSRLDecoder(context_layer_output_dim, label_space_size, config)
+
+ def forward(self,
+ batch: Dict[str, torch.Tensor]
+ ):
+ gold_arg_ends, gold_arg_labels, gold_arg_starts, gold_predicates, masks, sent_lengths = \
+ self.unpack(batch, training=self.training)
+
+ context_embeddings = self.embed(batch)
+ context_embeddings = F.dropout(context_embeddings, self.lexical_dropout, self.training)
+ contextualized_embeddings = self.context_layer(context_embeddings, masks)
+
+ return self.decoder.decode(contextualized_embeddings, sent_lengths, masks, gold_arg_starts, gold_arg_ends,
+ gold_arg_labels, gold_predicates)
+
+ @staticmethod
+ def unpack(batch, mask=None, training=False):
+ keys = 'token_length', 'predicate_offset', 'argument_begin_offset', 'argument_end_offset', 'srl_label_id'
+ sent_lengths, gold_predicates, gold_arg_starts, gold_arg_ends, gold_arg_labels = [batch.get(k, None) for k in
+ keys]
+ if mask is None:
+ mask = util.lengths_to_mask(sent_lengths)
+ # elif not training:
+ # sent_lengths = mask.sum(dim=1)
+ return gold_arg_ends, gold_arg_labels, gold_arg_starts, gold_predicates, mask, sent_lengths
diff --git a/hanlp/components/srl/span_rank/srl_eval_utils.py b/hanlp/components/srl/span_rank/srl_eval_utils.py
new file mode 100644
index 000000000..5f663460c
--- /dev/null
+++ b/hanlp/components/srl/span_rank/srl_eval_utils.py
@@ -0,0 +1,287 @@
+# Evaluation util functions for PropBank SRL.
+
+import codecs
+import collections
+import operator
+import tempfile
+from collections import Counter
+
+from hanlp.metrics.srl.srlconll import official_conll_05_evaluate
+
+_SRL_CONLL_EVAL_SCRIPT = "../run_eval.sh"
+
+
+def split_example_for_eval(example):
+ """Split document-based samples into sentence-based samples for evaluation.
+
+ Args:
+ example:
+
+ Returns:
+
+
+ """
+ sentences = example["sentences"]
+ num_words = sum(len(s) for s in sentences)
+ word_offset = 0
+ samples = []
+ # assert len(sentences) == 1
+ for i, sentence in enumerate(sentences):
+ # assert i == 0 # For CoNLL-2005, there are always document == sentence.
+ srl_rels = {}
+ ner_spans = [] # Unused.
+ for r in example["srl"][i]:
+ pred_id = r[0] - word_offset
+ if pred_id not in srl_rels:
+ srl_rels[pred_id] = []
+ srl_rels[pred_id].append((r[1] - word_offset, r[2] - word_offset, r[3]))
+ samples.append((sentence, srl_rels, ner_spans))
+ word_offset += len(sentence)
+ return samples
+
+
+def evaluate_retrieval(span_starts, span_ends, span_scores, pred_starts, pred_ends, gold_spans,
+ text_length, evaluators, debugging=False):
+ """Evaluation for unlabeled retrieval.
+
+ Args:
+ gold_spans: Set of tuples of (start, end).
+ span_starts:
+ span_ends:
+ span_scores:
+ pred_starts:
+ pred_ends:
+ text_length:
+ evaluators:
+ debugging: (Default value = False)
+
+ Returns:
+
+
+ """
+ if len(span_starts) > 0:
+ sorted_starts, sorted_ends, sorted_scores = list(zip(*sorted(
+ zip(span_starts, span_ends, span_scores),
+ key=operator.itemgetter(2), reverse=True)))
+ else:
+ sorted_starts = []
+ sorted_ends = []
+ for k, evaluator in list(evaluators.items()):
+ if k == -3:
+ predicted_spans = set(zip(span_starts, span_ends)) & gold_spans
+ else:
+ if k == -2:
+ predicted_starts = pred_starts
+ predicted_ends = pred_ends
+ if debugging:
+ print("Predicted", list(zip(sorted_starts, sorted_ends, sorted_scores))[:len(gold_spans)])
+ print("Gold", gold_spans)
+ # FIXME: scalar index error
+ elif k == 0:
+ is_predicted = span_scores > 0
+ predicted_starts = span_starts[is_predicted]
+ predicted_ends = span_ends[is_predicted]
+ else:
+ if k == -1:
+ num_predictions = len(gold_spans)
+ else:
+ num_predictions = (k * text_length) / 100
+ predicted_starts = sorted_starts[:num_predictions]
+ predicted_ends = sorted_ends[:num_predictions]
+ predicted_spans = set(zip(predicted_starts, predicted_ends))
+ evaluator.update(gold_set=gold_spans, predicted_set=predicted_spans)
+
+
+def _calc_f1(total_gold, total_predicted, total_matched, message=None):
+ precision = total_matched / total_predicted if total_predicted > 0 else 0
+ recall = total_matched / total_gold if total_gold > 0 else 0
+ f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0
+ if message:
+ print(("{}: Precision: {:.2%} Recall: {:.2%} F1: {:.2%}".format(message, precision, recall, f1)))
+ return precision, recall, f1
+
+
+def compute_span_f1(gold_data, predictions, task_name):
+ assert len(gold_data) == len(predictions)
+ total_gold = 0
+ total_predicted = 0
+ total_matched = 0
+ total_unlabeled_matched = 0
+ label_confusions = Counter() # Counter of (gold, pred) label pairs.
+
+ for i in range(len(gold_data)):
+ gold = gold_data[i]
+ pred = predictions[i]
+ total_gold += len(gold)
+ total_predicted += len(pred)
+ for a0 in gold:
+ for a1 in pred:
+ if a0[0] == a1[0] and a0[1] == a1[1]:
+ total_unlabeled_matched += 1
+ label_confusions.update([(a0[2], a1[2]), ])
+ if a0[2] == a1[2]:
+ total_matched += 1
+ prec, recall, f1 = _calc_f1(total_gold, total_predicted, total_matched, task_name)
+ ul_prec, ul_recall, ul_f1 = _calc_f1(total_gold, total_predicted, total_unlabeled_matched,
+ "Unlabeled " + task_name)
+ return prec, recall, f1, ul_prec, ul_recall, ul_f1, label_confusions
+
+
+def compute_unlabeled_span_f1(gold_data, predictions, task_name):
+ assert len(gold_data) == len(predictions)
+ total_gold = 0
+ total_predicted = 0
+ total_matched = 0
+ total_unlabeled_matched = 0
+ label_confusions = Counter() # Counter of (gold, pred) label pairs.
+
+ for i in range(len(gold_data)):
+ gold = gold_data[i]
+ pred = predictions[i]
+ total_gold += len(gold)
+ total_predicted += len(pred)
+ for a0 in gold:
+ for a1 in pred:
+ if a0[0] == a1[0] and a0[1] == a1[1]:
+ total_unlabeled_matched += 1
+ label_confusions.update([(a0[2], a1[2]), ])
+ if a0[2] == a1[2]:
+ total_matched += 1
+ prec, recall, f1 = _calc_f1(total_gold, total_predicted, total_matched, task_name)
+ ul_prec, ul_recall, ul_f1 = _calc_f1(total_gold, total_predicted, total_unlabeled_matched,
+ "Unlabeled " + task_name)
+ return prec, recall, f1, ul_prec, ul_recall, ul_f1, label_confusions
+
+
+SRLScores = collections.namedtuple('SRLScores',
+ ['unlabeled_precision', 'unlabeled_recall', 'unlabeled_f1', 'precision', 'recall',
+ 'f1', 'conll_precision', 'conll_recall', 'conll_f1', 'label_confusions',
+ 'num_sents'])
+
+
+def compute_srl_f1(sentences, gold_srl, predictions, gold_path=None) -> SRLScores:
+ assert len(gold_srl) == len(predictions)
+ total_gold = 0
+ total_predicted = 0
+ total_matched = 0
+ total_unlabeled_matched = 0
+ num_sents = 0
+ label_confusions = Counter()
+
+ # Compute unofficial F1 of SRL relations.
+ for gold, prediction in zip(gold_srl, predictions):
+ gold_rels = 0
+ pred_rels = 0
+ matched = 0
+ for pred_id, gold_args in gold.items():
+ filtered_gold_args = [a for a in gold_args if a[2] not in ["V", "C-V"]]
+ total_gold += len(filtered_gold_args)
+ gold_rels += len(filtered_gold_args)
+ if pred_id not in prediction:
+ continue
+ for a0 in filtered_gold_args:
+ for a1 in prediction[pred_id]:
+ if a0[0] == a1[0] and a0[1] == a1[1]:
+ total_unlabeled_matched += 1
+ label_confusions.update([(a0[2], a1[2]), ])
+ if a0[2] == a1[2]:
+ total_matched += 1
+ matched += 1
+ for pred_id, args in prediction.items():
+ filtered_args = [a for a in args if a[2] not in ["V"]] # "C-V"]]
+ total_predicted += len(filtered_args)
+ pred_rels += len(filtered_args)
+
+ if gold_rels == matched and pred_rels == matched:
+ num_sents += 1
+
+ precision, recall, f1 = _calc_f1(total_gold, total_predicted, total_matched,
+ # "SRL (unofficial)"
+ )
+ unlabeled_precision, unlabeled_recall, unlabeled_f1 = _calc_f1(total_gold, total_predicted,
+ total_unlabeled_matched,
+ # "Unlabeled SRL (unofficial)"
+ )
+
+ # Prepare to compute official F1.
+ if not gold_path:
+ # print("No gold conll_eval data provided. Recreating ...")
+ gold_path = tempfile.NamedTemporaryFile().name
+ print_to_conll(sentences, gold_srl, gold_path, None)
+ gold_predicates = None
+ else:
+ gold_predicates = read_gold_predicates(gold_path)
+
+ temp_output = tempfile.NamedTemporaryFile().name
+ # print(("Output temp outoput {}".format(temp_output)))
+ print_to_conll(sentences, predictions, temp_output, gold_predicates)
+
+ # Evaluate twice with official script.
+ conll_recall, conll_precision, conll_f1 = official_conll_05_evaluate(temp_output, gold_path)
+ return SRLScores(unlabeled_precision, unlabeled_recall, unlabeled_f1, precision, recall, f1, conll_precision,
+ conll_recall, conll_f1, label_confusions, num_sents)
+
+
+def print_sentence_to_conll(fout, tokens, labels):
+ """Print a labeled sentence into CoNLL format.
+
+ Args:
+ fout:
+ tokens:
+ labels:
+
+ Returns:
+
+
+ """
+ for label_column in labels:
+ assert len(label_column) == len(tokens)
+ for i in range(len(tokens)):
+ fout.write(tokens[i].ljust(15))
+ for label_column in labels:
+ fout.write(label_column[i].rjust(15))
+ fout.write("\n")
+ fout.write("\n")
+
+
+def read_gold_predicates(gold_path):
+ print("gold path", gold_path)
+ fin = codecs.open(gold_path, "r", "utf-8")
+ gold_predicates = [[], ]
+ for line in fin:
+ line = line.strip()
+ if not line:
+ gold_predicates.append([])
+ else:
+ info = line.split()
+ gold_predicates[-1].append(info[0])
+ fin.close()
+ return gold_predicates
+
+
+def print_to_conll(sentences, srl_labels, output_filename, gold_predicates=None):
+ fout = codecs.open(output_filename, "w", "utf-8")
+ for sent_id, words in enumerate(sentences):
+ if gold_predicates:
+ assert len(gold_predicates[sent_id]) == len(words)
+ pred_to_args = srl_labels[sent_id]
+ props = ["-" for _ in words]
+ col_labels = [["*" for _ in words] for _ in range(len(pred_to_args))]
+ for i, pred_id in enumerate(sorted(pred_to_args.keys())):
+ # To make sure CoNLL-eval script count matching predicates as correct.
+ if gold_predicates and gold_predicates[sent_id][pred_id] != "-":
+ props[pred_id] = gold_predicates[sent_id][pred_id]
+ else:
+ props[pred_id] = "P" + words[pred_id]
+ flags = [False for _ in words]
+ for start, end, label in pred_to_args[pred_id]:
+ if not max(flags[start:end + 1]):
+ col_labels[i][start] = "(" + label + col_labels[i][start]
+ col_labels[i][end] = col_labels[i][end] + ")"
+ for j in range(start, end + 1):
+ flags[j] = True
+ # Add unpredicted verb (for predicted SRL).
+ if not flags[pred_id]: # if the predicate id is False
+ col_labels[i][pred_id] = "(V*)"
+ print_sentence_to_conll(fout, props, col_labels)
+ fout.close()
diff --git a/hanlp/components/srl/span_rank/util.py b/hanlp/components/srl/span_rank/util.py
new file mode 100644
index 000000000..2443932ab
--- /dev/null
+++ b/hanlp/components/srl/span_rank/util.py
@@ -0,0 +1,12 @@
+# Adopted from https://github.com/KiroSummer/A_Syntax-aware_MTL_Framework_for_Chinese_SRL
+import torch
+
+
+def block_orth_normal_initializer(input_size, output_size):
+ weight = []
+ for o in output_size:
+ for i in input_size:
+ param = torch.FloatTensor(o, i)
+ torch.nn.init.orthogonal_(param)
+ weight.append(param)
+ return torch.cat(weight)
diff --git a/hanlp/components/taggers/cnn_tagger.py b/hanlp/components/taggers/cnn_tagger_tf.py
similarity index 93%
rename from hanlp/components/taggers/cnn_tagger.py
rename to hanlp/components/taggers/cnn_tagger_tf.py
index 0be1901a4..e5e6422c0 100644
--- a/hanlp/components/taggers/cnn_tagger.py
+++ b/hanlp/components/taggers/cnn_tagger_tf.py
@@ -6,17 +6,17 @@
import tensorflow as tf
-from hanlp.components.taggers.tagger import TaggerComponent
+from hanlp.components.taggers.tagger_tf import TaggerComponent
from hanlp.transform.tsv import TSVTaggingTransform
-from hanlp.common.vocab import Vocab
-from hanlp.layers.embeddings import build_embedding
+from hanlp.common.vocab_tf import VocabTF
+from hanlp.layers.embeddings.util_tf import build_embedding
class WindowTokenTransform(TSVTaggingTransform):
def fit(self, trn_path: str, **kwargs):
- self.word_vocab = Vocab()
- self.tag_vocab = Vocab(pad_token=None, unk_token=None)
+ self.word_vocab = VocabTF()
+ self.tag_vocab = VocabTF(pad_token=None, unk_token=None)
for ngrams, tags in self.file_to_samples(trn_path):
for words in ngrams:
self.word_vocab.update(words)
@@ -91,7 +91,7 @@ def call(self, inputs, **kwargs):
return o
-class CNNTagger(TaggerComponent, ABC):
+class CNNTaggerTF(TaggerComponent, ABC):
def __init__(self, transform: WindowTokenTransform = None) -> None:
if not transform:
transform = WindowTokenTransform()
diff --git a/hanlp/components/taggers/ngram_conv/ngram_conv_tagger.py b/hanlp/components/taggers/ngram_conv/ngram_conv_tagger.py
index 64ea4d447..e26bec3d4 100644
--- a/hanlp/components/taggers/ngram_conv/ngram_conv_tagger.py
+++ b/hanlp/components/taggers/ngram_conv/ngram_conv_tagger.py
@@ -6,22 +6,22 @@
import tensorflow as tf
-from hanlp.common.structure import SerializableDict
-from hanlp.components.taggers.tagger import TaggerComponent
+from hanlp_common.structure import SerializableDict
+from hanlp.components.taggers.tagger_tf import TaggerComponent
from hanlp.transform.tsv import TSVTaggingTransform
-from hanlp.transform.txt import extract_ngram_features, bmes_to_words
-from hanlp.common.vocab import Vocab
-from hanlp.layers.embeddings import build_embedding
+from hanlp.transform.txt import bmes_to_words, extract_ngram_features
+from hanlp.common.vocab_tf import VocabTF
+from hanlp.layers.embeddings.util_tf import build_embedding
from hanlp.layers.weight_normalization import WeightNormalization
-from hanlp.utils.util import merge_locals_kwargs
+from hanlp_common.util import merge_locals_kwargs
class NgramTransform(TSVTaggingTransform):
def __init__(self, config: SerializableDict = None, map_x=True, map_y=True, **kwargs) -> None:
super().__init__(config, map_x, map_y, **kwargs)
- self.ngram_vocab: Optional[Vocab] = None
- self.tag_vocab: Optional[Vocab] = None
+ self.ngram_vocab: Optional[VocabTF] = None
+ self.tag_vocab: Optional[VocabTF] = None
def inputs_to_samples(self, inputs, gold=False):
for data in inputs:
@@ -57,7 +57,7 @@ def create_types_shapes_values(self) -> Tuple[Tuple, Tuple, Tuple]:
return types, shapes, defaults
def fit(self, trn_path: str, **kwargs):
- word_vocab, ngram_vocab, tag_vocab = Vocab(), Vocab(), Vocab(pad_token=None, unk_token=None)
+ word_vocab, ngram_vocab, tag_vocab = VocabTF(), VocabTF(), VocabTF(pad_token=None, unk_token=None)
num_samples = 0
for X, Y in self.file_to_samples(trn_path, gold=True):
num_samples += 1
@@ -138,7 +138,7 @@ def call(self, inputs, **kwargs):
return logits
-class NgramConvTagger(TaggerComponent):
+class NgramConvTaggerTF(TaggerComponent):
def __init__(self, transform: NgramTransform = None) -> None:
if not transform:
diff --git a/hanlp/components/taggers/rnn/__init__.py b/hanlp/components/taggers/rnn/__init__.py
new file mode 100644
index 000000000..05278e816
--- /dev/null
+++ b/hanlp/components/taggers/rnn/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-05-19 15:41
\ No newline at end of file
diff --git a/hanlp/components/taggers/rnn/rnntaggingmodel.py b/hanlp/components/taggers/rnn/rnntaggingmodel.py
new file mode 100644
index 000000000..dd7d53d03
--- /dev/null
+++ b/hanlp/components/taggers/rnn/rnntaggingmodel.py
@@ -0,0 +1,100 @@
+# MIT License
+#
+# Copyright (c) 2020 Yu Zhang
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+from typing import Union
+
+import torch
+import torch.nn as nn
+from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence
+from hanlp.layers.crf.crf import CRF
+
+
+class RNNTaggingModel(nn.Module):
+
+ def __init__(self,
+ embed: Union[nn.Embedding, int],
+ rnn_input,
+ rnn_hidden,
+ n_out,
+ drop=0.5,
+ crf=True,
+ crf_constraints=None):
+ super(RNNTaggingModel, self).__init__()
+
+ # the embedding layer
+ if isinstance(embed, nn.Module):
+ self.embed = embed
+ n_embed = embed.embedding_dim
+ else:
+ self.embed = None
+ n_embed = embed
+
+ if rnn_input:
+ self.embed_to_rnn = nn.Linear(n_embed, rnn_input)
+ else:
+ self.embed_to_rnn = None
+ rnn_input = n_embed
+
+ # the word-lstm layer
+ self.word_lstm = nn.LSTM(input_size=rnn_input,
+ hidden_size=rnn_hidden,
+ batch_first=True,
+ bidirectional=True)
+
+ # the output layer
+ self.out = nn.Linear(rnn_hidden * 2, n_out)
+ # the CRF layer
+ self.crf = CRF(n_out, crf_constraints) if crf else None
+
+ self.drop = nn.Dropout(drop)
+ # self.drop = SharedDropout(drop)
+ # self.drop = LockedDropout(drop)
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ # init Linear
+ nn.init.xavier_uniform_(self.out.weight)
+
+ def forward(self,
+ x: torch.Tensor,
+ batch=None,
+ **kwargs):
+ # get the mask and lengths of given batch
+ mask = x.gt(0)
+ lens = mask.sum(dim=1)
+ # get outputs from embedding layers
+ if isinstance(self.embed, nn.Embedding):
+ x = self.embed(x[mask])
+ else:
+ x = self.embed(batch, mask=mask)
+ if x.dim() == 3:
+ x = x[mask]
+ x = self.drop(x)
+ if self.embed_to_rnn:
+ x = self.embed_to_rnn(x)
+ x = pack_sequence(torch.split(x, lens.tolist()), True)
+ x, _ = self.word_lstm(x)
+ x, _ = pad_packed_sequence(x, True)
+ x = self.drop(x)
+
+ return self.out(x), mask
diff --git a/hanlp/components/taggers/rnn_tagger.py b/hanlp/components/taggers/rnn_tagger.py
index 8ed8a3c8f..12dbbf9cd 100644
--- a/hanlp/components/taggers/rnn_tagger.py
+++ b/hanlp/components/taggers/rnn_tagger.py
@@ -1,88 +1,195 @@
# -*- coding:utf-8 -*-
# Author: hankcs
-# Date: 2019-09-14 20:30
-from typing import Union, List
+# Date: 2020-05-20 13:12
+import logging
-import tensorflow as tf
+import torch
+from torch import nn
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+from torch.utils.data import DataLoader
-from hanlp.common.transform import Transform
-from hanlp.components.taggers.tagger import TaggerComponent
-from hanlp.transform.tsv import TSVTaggingTransform
+from hanlp.common.dataset import PadSequenceDataLoader, SortingSampler, TransformableDataset
+from hanlp_common.configurable import Configurable
+from hanlp.common.transform import EmbeddingNamedTransform
from hanlp.common.vocab import Vocab
-from hanlp.layers.embeddings import build_embedding, embeddings_require_string_input, embeddings_require_char_input
-from hanlp.utils.util import merge_locals_kwargs
+from hanlp.components.taggers.rnn.rnntaggingmodel import RNNTaggingModel
+from hanlp.components.taggers.tagger import Tagger
+from hanlp.datasets.ner.tsv import TSVTaggingDataset
+from hanlp.layers.embeddings.embedding import Embedding
+from hanlp.layers.embeddings.util import build_word2vec_with_vocab
+from hanlp.utils.time_util import CountdownTimer
+from hanlp_common.util import merge_locals_kwargs, merge_dict
-class RNNTagger(TaggerComponent):
+class RNNTagger(Tagger):
- def __init__(self, transform: Transform = None) -> None:
- if not transform:
- self.transform = transform = TSVTaggingTransform()
- super().__init__(transform)
+ def __init__(self, **kwargs) -> None:
+ """An old-school tagger using non-contextualized embeddings and RNNs as context layer.
- def fit(self, trn_data: str, dev_data: str = None, save_dir: str = None, embeddings=100, embedding_trainable=False,
- rnn_input_dropout=0.2, rnn_units=100, rnn_output_dropout=0.2, epochs=20, lower=False, logger=None,
- loss: Union[tf.keras.losses.Loss, str] = None,
- optimizer: Union[str, tf.keras.optimizers.Optimizer] = 'adam', metrics='accuracy',
- batch_size=32, dev_batch_size=32, lr_decay_per_epoch=None, verbose=True, **kwargs):
- return super().fit(**merge_locals_kwargs(locals(), kwargs))
+ Args:
+ **kwargs: Predefined config.
+ """
+ super().__init__(**kwargs)
+ self.model: RNNTaggingModel = None
- def build_model(self, embeddings, embedding_trainable, rnn_input_dropout, rnn_output_dropout, rnn_units,
- loss,
- **kwargs) -> tf.keras.Model:
- model = tf.keras.Sequential()
- embeddings = build_embedding(embeddings, self.transform.word_vocab, self.transform)
- model.add(embeddings)
- if rnn_input_dropout:
- model.add(tf.keras.layers.Dropout(rnn_input_dropout, name='rnn_input_dropout'))
- model.add(
- tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(units=rnn_units, return_sequences=True), name='bilstm'))
- if rnn_output_dropout:
- model.add(tf.keras.layers.Dropout(rnn_output_dropout, name='rnn_output_dropout'))
- model.add(tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(len(self.transform.tag_vocab)), name='dense'))
- return model
+ # noinspection PyMethodOverriding
+ def execute_training_loop(self, trn: DataLoader, dev: DataLoader, epochs, criterion,
+ optimizer,
+ metric,
+ save_dir,
+ logger,
+ patience,
+ **kwargs):
+ max_e, max_metric = 0, -1
+
+ criterion = self.build_criterion()
+ timer = CountdownTimer(epochs)
+ ratio_width = len(f'{len(trn)}/{len(trn)}')
+ scheduler = self.build_scheduler(**merge_dict(self.config, optimizer=optimizer, overwrite=True))
+ if not patience:
+ patience = epochs
+ for epoch in range(1, epochs + 1):
+ logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
+ self.fit_dataloader(trn, criterion, optimizer, metric, logger, ratio_width=ratio_width)
+ loss, dev_metric = self.evaluate_dataloader(dev, criterion, logger)
+ if scheduler:
+ if isinstance(scheduler, ReduceLROnPlateau):
+ scheduler.step(dev_metric.score)
+ else:
+ scheduler.step(epoch)
+ report_patience = f'Patience: {epoch - max_e}/{patience}'
+ # save the model if it is the best so far
+ if dev_metric > max_metric:
+ self.save_weights(save_dir)
+ max_e, max_metric = epoch, dev_metric
+ report_patience = '[red]Saved[/red] '
+ stop = epoch - max_e >= patience
+ if stop:
+ timer.stop()
+ timer.log(f'{report_patience} lr: {optimizer.param_groups[0]["lr"]:.4f}',
+ ratio_percentage=False, newline=True, ratio=False)
+ if stop:
+ break
+ timer.stop()
+ if max_e != epoch:
+ self.load_weights(save_dir)
+ logger.info(f"Max score of dev is {max_metric.score:.2%} at epoch {max_e}")
+ logger.info(f"{timer.elapsed_human} elapsed, average time of each epoch is {timer.elapsed_average_human}")
+
+ def build_scheduler(self, optimizer, anneal_factor, anneal_patience, **kwargs):
+ scheduler: ReduceLROnPlateau = ReduceLROnPlateau(optimizer,
+ factor=anneal_factor,
+ patience=anneal_patience,
+ mode='max') if anneal_factor and anneal_patience else None
+ return scheduler
- def predict(self, sents: Union[List[str], List[List[str]]], batch_size=32, **kwargs) -> Union[
- List[str], List[List[str]]]:
- return super().predict(sents, batch_size)
+ def fit_dataloader(self, trn: DataLoader, criterion, optimizer, metric, logger: logging.Logger, ratio_width=None,
+ **kwargs):
+ self.model.train()
+ timer = CountdownTimer(len(trn))
+ total_loss = 0
+ for idx, batch in enumerate(trn):
+ optimizer.zero_grad()
+ out, mask = self.feed_batch(batch)
+ y = batch['tag_id']
+ loss = self.compute_loss(criterion, out, y, mask)
+ loss.backward()
+ nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
+ optimizer.step()
+ total_loss += loss.item()
+ prediction = self.decode_output(out, mask, batch)
+ self.update_metrics(metric, out, y, mask, batch, prediction)
+ timer.log(f'loss: {loss / (idx + 1):.4f} {metric}', ratio_percentage=False, logger=logger,
+ ratio_width=ratio_width)
+ del loss
+ del out
+ del mask
- def save_weights(self, save_dir, filename='model.h5'):
- # remove the pre-trained embedding
- embedding_layer: tf.keras.layers.Embedding = self.model.get_layer(index=0)
- if embedding_layer.trainable:
- super().save_weights(save_dir, filename)
+ def feed_batch(self, batch):
+ x = batch[f'{self.config.token_key}_id']
+ out, mask = self.model(x, **batch, batch=batch)
+ return out, mask
+
+ # noinspection PyMethodOverriding
+ def build_model(self, rnn_input, rnn_hidden, drop, crf, **kwargs) -> torch.nn.Module:
+ vocabs = self.vocabs
+ token_embed = self._convert_embed()
+ if isinstance(token_embed, EmbeddingNamedTransform):
+ token_embed = token_embed.output_dim
+ elif isinstance(token_embed, Embedding):
+ token_embed = token_embed.module(vocabs=vocabs)
else:
- truncated_model = tf.keras.Sequential(layers=self.model.layers[1:])
- truncated_model.build(input_shape=embedding_layer.output_shape)
- truncated_model.save_weights(save_dir)
-
- def build_loss(self, loss, **kwargs):
- if not loss:
- loss = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.SUM,
- from_logits=True)
- return loss
- return super().build_loss(loss, **kwargs)
-
- @property
- def tag_vocab(self) -> Vocab:
- return self.transform.tag_vocab
-
- def build_transform(self, embeddings, **kwargs):
- if embeddings_require_string_input(embeddings):
- self.transform.map_x = False
- if embeddings_require_char_input(embeddings):
- self.transform.char_vocab = Vocab()
- return super().build_transform(**kwargs)
-
- @property
- def sample_data(self):
- if self.transform.char_vocab:
- # You cannot build your model by calling `build` if your layers do not support float type inputs.
- # Instead, in order to instantiate and build your model, `call` your model on real tensor data (of the
- # correct dtype).
- sample = tf.constant([
- ['hello', 'world', self.transform.word_vocab.pad_token],
- ['hello', 'this', 'world'],
- ])
- sample._keras_mask = tf.not_equal(sample, self.transform.word_vocab.pad_token)
- return sample
+ token_embed = build_word2vec_with_vocab(token_embed, vocabs[self.config.token_key])
+ model = RNNTaggingModel(token_embed, rnn_input, rnn_hidden, len(vocabs['tag']), drop, crf)
+ return model
+
+ def _convert_embed(self):
+ embed = self.config['embed']
+ if isinstance(embed, dict):
+ self.config['embed'] = embed = Configurable.from_config(embed)
+ return embed
+
+ def build_dataloader(self, data, batch_size, shuffle, device, logger=None, **kwargs) -> DataLoader:
+ vocabs = self.vocabs
+ token_embed = self._convert_embed()
+ dataset = data if isinstance(data, TransformableDataset) else self.build_dataset(data, transform=[vocabs])
+ if vocabs.mutable:
+ # Before building vocabs, let embeddings submit their vocabs, some embeddings will possibly opt out as their
+ # transforms are not relevant to vocabs
+ if isinstance(token_embed, Embedding):
+ transform = token_embed.transform(vocabs=vocabs)
+ if transform:
+ dataset.transform.insert(-1, transform)
+ self.build_vocabs(dataset, logger)
+ if isinstance(token_embed, Embedding):
+ # Vocabs built, now add all transforms to the pipeline. Be careful about redundant ones.
+ transform = token_embed.transform(vocabs=vocabs)
+ if transform and transform not in dataset.transform:
+ dataset.transform.insert(-1, transform)
+ sampler = SortingSampler([len(sample[self.config.token_key]) for sample in dataset], batch_size,
+ shuffle=shuffle)
+ return PadSequenceDataLoader(dataset,
+ device=device,
+ batch_sampler=sampler,
+ vocabs=vocabs)
+
+ def build_dataset(self, data, transform):
+ return TSVTaggingDataset(data, transform)
+
+ def build_vocabs(self, dataset, logger):
+ self.vocabs.tag = Vocab(unk_token=None, pad_token=None)
+ self.vocabs[self.config.token_key] = Vocab()
+ for each in dataset:
+ pass
+ self.vocabs.lock()
+ self.vocabs.summary(logger)
+
+ def fit(self, trn_data, dev_data, save_dir,
+ batch_size=50,
+ epochs=100,
+ embed=100,
+ rnn_input=None,
+ rnn_hidden=256,
+ drop=0.5,
+ lr=0.001,
+ patience=10,
+ crf=True,
+ optimizer='adam',
+ token_key='token',
+ tagging_scheme=None,
+ anneal_factor: float = 0.5,
+ anneal_patience=2,
+ devices=None, logger=None, verbose=True, **kwargs):
+ return super().fit(**merge_locals_kwargs(locals(), kwargs))
+
+ def _id_to_tags(self, ids):
+ batch = []
+ vocab = self.vocabs['tag'].idx_to_token
+ for b in ids:
+ batch.append([])
+ for i in b:
+ batch[-1].append(vocab[i])
+ return batch
+
+ def write_output(self, yhat, y, mask, batch, prediction, output):
+ pass
diff --git a/hanlp/components/taggers/rnn_tagger_tf.py b/hanlp/components/taggers/rnn_tagger_tf.py
new file mode 100644
index 000000000..8f992877a
--- /dev/null
+++ b/hanlp/components/taggers/rnn_tagger_tf.py
@@ -0,0 +1,89 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-09-14 20:30
+from typing import Union, List
+
+import tensorflow as tf
+
+from hanlp.common.transform_tf import Transform
+from hanlp.components.taggers.tagger_tf import TaggerComponent
+from hanlp.transform.tsv import TSVTaggingTransform
+from hanlp.common.vocab_tf import VocabTF
+from hanlp.layers.embeddings.util_tf import build_embedding, embeddings_require_string_input, \
+ embeddings_require_char_input
+from hanlp_common.util import merge_locals_kwargs
+
+
+class RNNTaggerTF(TaggerComponent):
+
+ def __init__(self, transform: Transform = None) -> None:
+ if not transform:
+ self.transform = transform = TSVTaggingTransform()
+ super().__init__(transform)
+
+ def fit(self, trn_data: str, dev_data: str = None, save_dir: str = None, embeddings=100, embedding_trainable=False,
+ rnn_input_dropout=0.2, rnn_units=100, rnn_output_dropout=0.2, epochs=20, lower=False, logger=None,
+ loss: Union[tf.keras.losses.Loss, str] = None,
+ optimizer: Union[str, tf.keras.optimizers.Optimizer] = 'adam', metrics='accuracy',
+ batch_size=32, dev_batch_size=32, lr_decay_per_epoch=None, verbose=True, **kwargs):
+ return super().fit(**merge_locals_kwargs(locals(), kwargs))
+
+ def build_model(self, embeddings, embedding_trainable, rnn_input_dropout, rnn_output_dropout, rnn_units,
+ loss,
+ **kwargs) -> tf.keras.Model:
+ model = tf.keras.Sequential()
+ embeddings = build_embedding(embeddings, self.transform.word_vocab, self.transform)
+ model.add(embeddings)
+ if rnn_input_dropout:
+ model.add(tf.keras.layers.Dropout(rnn_input_dropout, name='rnn_input_dropout'))
+ model.add(
+ tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(units=rnn_units, return_sequences=True), name='bilstm'))
+ if rnn_output_dropout:
+ model.add(tf.keras.layers.Dropout(rnn_output_dropout, name='rnn_output_dropout'))
+ model.add(tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(len(self.transform.tag_vocab)), name='dense'))
+ return model
+
+ def predict(self, sents: Union[List[str], List[List[str]]], batch_size=32, **kwargs) -> Union[
+ List[str], List[List[str]]]:
+ return super().predict(sents, batch_size)
+
+ def save_weights(self, save_dir, filename='model.h5'):
+ # remove the pre-trained embedding
+ embedding_layer: tf.keras.layers.Embedding = self.model.get_layer(index=0)
+ if embedding_layer.trainable:
+ super().save_weights(save_dir, filename)
+ else:
+ truncated_model = tf.keras.Sequential(layers=self.model.layers[1:])
+ truncated_model.build(input_shape=embedding_layer.output_shape)
+ truncated_model.save_weights(save_dir)
+
+ def build_loss(self, loss, **kwargs):
+ if not loss:
+ loss = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.SUM,
+ from_logits=True)
+ return loss
+ return super().build_loss(loss, **kwargs)
+
+ @property
+ def tag_vocab(self) -> VocabTF:
+ return self.transform.tag_vocab
+
+ def build_transform(self, embeddings, **kwargs):
+ if embeddings_require_string_input(embeddings):
+ self.transform.map_x = False
+ if embeddings_require_char_input(embeddings):
+ self.transform.char_vocab = VocabTF()
+ return super().build_transform(**kwargs)
+
+ @property
+ def sample_data(self):
+ if self.transform.char_vocab:
+ # You cannot build your model by calling `build` if your layers do not support float type inputs.
+ # Instead, in order to instantiate and build your model, `call` your model on real tensor data (of the
+ # correct dtype).
+ sample = tf.constant([
+ ['hello', 'world', self.transform.word_vocab.pad_token],
+ ['hello', 'this', 'world'],
+ ])
+ sample._keras_mask = tf.not_equal(sample, self.transform.word_vocab.pad_token)
+ return sample
diff --git a/hanlp/components/taggers/tagger.py b/hanlp/components/taggers/tagger.py
index 7dc996c50..9508aa85b 100644
--- a/hanlp/components/taggers/tagger.py
+++ b/hanlp/components/taggers/tagger.py
@@ -1,38 +1,197 @@
# -*- coding:utf-8 -*-
# Author: hankcs
-# Date: 2019-10-25 21:49
+# Date: 2020-08-11 12:19
import logging
-from abc import ABC
-
-import tensorflow as tf
-
-from hanlp.common.component import KerasComponent
-from hanlp.layers.crf.crf_layer import CRFWrapper, CRFLoss, CRF
-from hanlp.metrics.chunking.iobes import IOBES_F1
-
-
-class TaggerComponent(KerasComponent, ABC):
-
- def build_metrics(self, metrics, logger: logging.Logger, **kwargs):
- if metrics == 'f1':
- assert hasattr(self.transform, 'tag_vocab'), 'Name your tag vocab tag_vocab in your transform ' \
- 'or override build_metrics'
- if not self.config.get('run_eagerly', None):
- logger.debug('ChunkingF1 runs only under eager mode, '
- 'set run_eagerly=True to remove this warning')
- self.config.run_eagerly = True
- return IOBES_F1(self.transform.tag_vocab)
- return super().build_metrics(metrics, logger, **kwargs)
-
- def build_loss(self, loss, **kwargs):
- assert self.model is not None, 'should create model before build loss'
- if loss == 'crf':
- if isinstance(self.model, tf.keras.models.Sequential):
- crf = CRF(len(self.transform.tag_vocab))
- self.model.add(crf)
- loss = CRFLoss(crf, self.model.dtype)
+import warnings
+from abc import ABC, abstractmethod
+from typing import List, TextIO, Any
+
+import torch
+from torch import optim, nn
+from torch.utils.data import DataLoader
+
+from hanlp_common.constant import IDX
+from hanlp.common.structure import History
+from hanlp.components.distillation.distillable_component import DistillableComponent
+from hanlp.components.taggers.util import guess_tagging_scheme
+from hanlp.layers.crf.crf import CRF
+from hanlp.metrics.accuracy import CategoricalAccuracy
+from hanlp.utils.time_util import CountdownTimer
+from hanlp_common.util import reorder
+
+
+class Tagger(DistillableComponent, ABC):
+ def build_optimizer(self, optimizer, lr, **kwargs):
+ if optimizer == 'adam':
+ return optim.Adam(params=self.model.parameters(), lr=lr)
+ elif optimizer == 'sgd':
+ return torch.optim.SGD(self.model.parameters(), lr=lr)
+
+ def build_criterion(self, model=None, reduction='mean', **kwargs):
+ if self.config.get('crf', False):
+ if not model:
+ model = self.model
+ if isinstance(model, nn.DataParallel):
+ raise ValueError('DataParallel not supported when CRF is used')
+ return self.model_from_config.module.crf
+ return model.crf
+ else:
+ return nn.CrossEntropyLoss(reduction=reduction)
+
+ def build_metric(self, **kwargs):
+ return CategoricalAccuracy()
+
+ @abstractmethod
+ def feed_batch(self, batch):
+ pass
+
+ def compute_loss(self, criterion, out, y, mask):
+ if self.config.get('crf', False):
+ criterion: CRF = criterion
+ loss = -criterion.forward(out, y, mask)
+ else:
+ loss = criterion(out[mask], y[mask])
+ return loss
+
+ def decode_output(self, logits, mask, batch, model=None):
+ if self.config.get('crf', False):
+ if model is None:
+ model = self.model
+ crf: CRF = model.crf
+ outputs = crf.decode(logits, mask)
+ return [y[0] for y in outputs]
+ else:
+ return logits.argmax(-1)
+
+ def execute_training_loop(self, trn: DataLoader, dev: DataLoader, epochs, criterion, optimizer, metric, save_dir,
+ logger: logging.Logger, devices, ratio_width=None, patience=5, teacher=None,
+ kd_criterion=None,
+ **kwargs):
+ best_epoch, best_metric = 0, -1
+ timer = CountdownTimer(epochs)
+ history = History()
+ for epoch in range(1, epochs + 1):
+ logger.info(f"[yellow]Epoch {epoch} / {epochs}:[/yellow]")
+ self.fit_dataloader(trn, criterion, optimizer, metric, logger, history=history, ratio_width=ratio_width,
+ **self.config)
+ loss, dev_metric = self.evaluate_dataloader(dev, criterion, logger=logger, ratio_width=ratio_width)
+ timer.update()
+ report = f"{timer.elapsed_human} / {timer.total_time_human} ETA: {timer.eta_human}"
+ if dev_metric > best_metric:
+ best_epoch, best_metric = epoch, dev_metric
+ self.save_weights(save_dir)
+ report += ' [red](saved)[/red]'
else:
- self.model = CRFWrapper(self.model, len(self.transform.tag_vocab))
- loss = CRFLoss(self.model.crf, self.model.dtype)
- return loss
- return super().build_loss(loss, **kwargs)
+ report += f' ({epoch - best_epoch})'
+ if epoch - best_epoch >= patience:
+ report += ' early stop'
+ logger.info(report)
+ if epoch - best_epoch >= patience:
+ break
+ if not best_epoch:
+ self.save_weights(save_dir)
+ elif best_epoch != epoch:
+ self.load_weights(save_dir)
+ logger.info(f"Max score of dev is {best_metric} at epoch {best_epoch}")
+ logger.info(f"Average time of each epoch is {timer.elapsed_average_human}")
+ logger.info(f"{timer.elapsed_human} elapsed")
+ return best_metric
+
+ def id_to_tags(self, ids: torch.LongTensor, lens: List[int]):
+ batch = []
+ vocab = self.vocabs['tag'].idx_to_token
+ for b, l in zip(ids, lens):
+ batch.append([])
+ for i in b[:l]:
+ batch[-1].append(vocab[i])
+ return batch
+
+ def update_metrics(self, metric, logits, y, mask, batch=None, prediction=None):
+ metric(logits, y, mask)
+
+ @torch.no_grad()
+ def evaluate_dataloader(self, data, criterion, logger=None, ratio_width=None, metric=None, output=None, **kwargs):
+ self.model.eval()
+ if isinstance(output, str):
+ output = open(output, 'w')
+
+ loss = 0
+ if not metric:
+ metric = self.build_metric()
+ else:
+ metric.reset()
+ timer = CountdownTimer(len(data))
+ for idx, batch in enumerate(data):
+ logits, mask = self.feed_batch(batch)
+ y = batch['tag_id']
+ loss += self.compute_loss(criterion, logits, y, mask).item()
+ prediction = self.decode_output(logits, mask, batch)
+ self.update_metrics(metric, logits, y, mask, batch, prediction)
+ if output:
+ self.write_prediction(prediction, batch, output)
+ timer.log(f'loss: {loss / (idx + 1):.4f} {metric}', ratio_percentage=False, logger=logger,
+ ratio_width=ratio_width)
+ loss /= len(data)
+ if output:
+ output.close()
+ return float(loss), metric
+
+ def write_prediction(self, prediction, batch, output: TextIO):
+ for tokens, ps, gs in zip(batch[self.config.token_key], prediction, batch['tag']):
+ output.write('\n'.join('\t'.join([t, p, g]) for t, p, g in zip(tokens, ps, gs)))
+ output.write('\n')
+
+ def predict(self, tokens: Any, batch_size: int = None, **kwargs):
+ if not tokens:
+ return []
+ flat = self.input_is_flat(tokens)
+ if flat:
+ tokens = [tokens]
+ outputs = self.predict_data(tokens, batch_size, **kwargs)
+ if flat:
+ return outputs[0]
+ return outputs
+
+ def input_is_flat(self, tokens):
+ return isinstance(tokens, list) and isinstance(tokens[0], str)
+
+ def predict_data(self, data, batch_size, **kwargs):
+ samples = self.build_samples(data, **kwargs)
+ if not batch_size:
+ batch_size = self.config.get('batch_size', 32)
+ dataloader = self.build_dataloader(samples, batch_size, False, self.device)
+ outputs = []
+ orders = []
+ vocab = self.vocabs['tag'].idx_to_token
+ for batch in dataloader:
+ out, mask = self.feed_batch(batch)
+ pred = self.decode_output(out, mask, batch)
+ if isinstance(pred, torch.Tensor):
+ pred = pred.tolist()
+ outputs.extend(self.prediction_to_human(pred, vocab, batch))
+ orders.extend(batch[IDX])
+ outputs = reorder(outputs, orders)
+ return outputs
+
+ def build_samples(self, data: List[str], **kwargs):
+ return [{self.config.token_key: sent} for sent in data]
+
+ def prediction_to_human(self, pred, vocab: List[str], batch):
+ lengths = batch.get(f'{self.config.token_key}_length', None)
+ if lengths is None:
+ lengths = torch.sum(batch['mask'], dim=1)
+ if isinstance(lengths, torch.Tensor):
+ lengths = lengths.tolist()
+ for each, l in zip(pred, lengths):
+ yield [vocab[id] for id in each[:l]]
+
+ @property
+ def tagging_scheme(self):
+ tagging_scheme = self.config.tagging_scheme
+ if not tagging_scheme:
+ self.config.tagging_scheme = tagging_scheme = guess_tagging_scheme(self.vocabs.tag.idx_to_token)
+ if tagging_scheme == 'BIO':
+ warnings.warn(f'The tag scheme for {self.vocabs.tag.idx_to_token} might be IOB1 or IOB2 '
+ f'but we are using IOB2 by default. Please set tagging_scheme="IOB1" or tagging_scheme="BIO" '
+ f'to get rid of this warning.')
+ return tagging_scheme
diff --git a/hanlp/components/taggers/tagger_tf.py b/hanlp/components/taggers/tagger_tf.py
new file mode 100644
index 000000000..e92b97347
--- /dev/null
+++ b/hanlp/components/taggers/tagger_tf.py
@@ -0,0 +1,38 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-10-25 21:49
+import logging
+from abc import ABC
+
+import tensorflow as tf
+
+from hanlp.common.keras_component import KerasComponent
+from hanlp.layers.crf.crf_layer_tf import CRF, CRFLoss, CRFWrapper
+from hanlp.metrics.chunking.iobes_tf import IOBES_F1_TF
+
+
+class TaggerComponent(KerasComponent, ABC):
+
+ def build_metrics(self, metrics, logger: logging.Logger, **kwargs):
+ if metrics == 'f1':
+ assert hasattr(self.transform, 'tag_vocab'), 'Name your tag vocab tag_vocab in your transform ' \
+ 'or override build_metrics'
+ if not self.config.get('run_eagerly', None):
+ logger.debug('ChunkingF1 runs only under eager mode, '
+ 'set run_eagerly=True to remove this warning')
+ self.config.run_eagerly = True
+ return IOBES_F1_TF(self.transform.tag_vocab)
+ return super().build_metrics(metrics, logger, **kwargs)
+
+ def build_loss(self, loss, **kwargs):
+ assert self.model is not None, 'should create model before build loss'
+ if loss == 'crf':
+ if isinstance(self.model, tf.keras.models.Sequential):
+ crf = CRF(len(self.transform.tag_vocab))
+ self.model.add(crf)
+ loss = CRFLoss(crf, self.model.dtype)
+ else:
+ self.model = CRFWrapper(self.model, len(self.transform.tag_vocab))
+ loss = CRFLoss(self.model.crf, self.model.dtype)
+ return loss
+ return super().build_loss(loss, **kwargs)
diff --git a/hanlp/components/taggers/transformers/metrics.py b/hanlp/components/taggers/transformers/metrics_tf.py
similarity index 100%
rename from hanlp/components/taggers/transformers/metrics.py
rename to hanlp/components/taggers/transformers/metrics_tf.py
diff --git a/hanlp/components/taggers/transformers/transformer_tagger.py b/hanlp/components/taggers/transformers/transformer_tagger.py
index 4ba04a775..3d7dd0c9c 100644
--- a/hanlp/components/taggers/transformers/transformer_tagger.py
+++ b/hanlp/components/taggers/transformers/transformer_tagger.py
@@ -1,112 +1,255 @@
# -*- coding:utf-8 -*-
# Author: hankcs
-# Date: 2019-12-29 13:55
+# Date: 2020-06-15 20:55
import logging
-import math
+from typing import Union, List
-import tensorflow as tf
+import torch
+from torch import nn
+from torch.utils.data import DataLoader
-from hanlp.common.transform import Transform
-from hanlp.components.taggers.tagger import TaggerComponent
-from hanlp.components.taggers.transformers.metrics import Accuracy
-from hanlp.components.taggers.transformers.transformer_transform import TransformerTransform
-from hanlp.layers.transformers.loader import build_transformer
-from hanlp.losses.sparse_categorical_crossentropy import MaskedSparseCategoricalCrossentropyOverBatchFirstDim
-from hanlp.optimizers.adamw import create_optimizer
-from hanlp.utils.util import merge_locals_kwargs
+from hanlp.common.dataset import PadSequenceDataLoader, SamplerBuilder, TransformableDataset
+from hanlp.common.structure import History
+from hanlp.common.transform import FieldLength, TransformList
+from hanlp.common.vocab import Vocab
+from hanlp.components.classifiers.transformer_classifier import TransformerComponent
+from hanlp.components.taggers.tagger import Tagger
+from hanlp.datasets.ner.tsv import TSVTaggingDataset
+from hanlp.layers.crf.crf import CRF
+from hanlp.layers.transformers.encoder import TransformerEncoder
+from hanlp.transform.transformer_tokenizer import TransformerSequenceTokenizer
+from hanlp.utils.time_util import CountdownTimer
+from hanlp.utils.torch_util import clip_grad_norm
+from hanlp_common.util import merge_locals_kwargs
+from alnlp.modules.util import lengths_to_mask
-class TransformerTaggingModel(tf.keras.Model):
- def __init__(self, transformer: tf.keras.Model, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.transformer = transformer
+# noinspection PyAbstractClass
+class TransformerTaggingModel(nn.Module):
+ def __init__(self,
+ encoder: TransformerEncoder,
+ num_labels,
+ crf=False,
+ secondary_encoder=None) -> None:
+ """
+ A shallow tagging model use transformer as decoder.
+ Args:
+ encoder: A pretrained transformer.
+ num_labels: Size of tagset.
+ crf: True to enable CRF.
+ crf_constraints: The allowed transitions (from_label_id, to_label_id).
+ """
+ super().__init__()
+ self.encoder = encoder
+ self.secondary_encoder = secondary_encoder
+ # noinspection PyUnresolvedReferences
+ self.classifier = nn.Linear(encoder.transformer.config.hidden_size, num_labels)
+ self.crf = CRF(num_labels) if crf else None
- def call(self, inputs, training=None, mask=None):
- return super().call(inputs, training, mask)
+ def forward(self, lens: torch.LongTensor, input_ids, token_span, token_type_ids=None):
+ mask = lengths_to_mask(lens)
+ x = self.encoder(input_ids, token_span=token_span, token_type_ids=token_type_ids)
+ if self.secondary_encoder:
+ x = self.secondary_encoder(x, mask=mask)
+ x = self.classifier(x)
+ return x, mask
-class TransformerTagger(TaggerComponent):
- def __init__(self, transform: TransformerTransform = None) -> None:
- if transform is None:
- transform = TransformerTransform()
- super().__init__(transform)
- self.transform: TransformerTransform = transform
+class TransformerTagger(TransformerComponent, Tagger):
- def build_model(self, transformer, max_seq_length, **kwargs) -> tf.keras.Model:
- model, tokenizer = build_transformer(transformer, max_seq_length, len(self.transform.tag_vocab), tagging=True)
- self.transform.tokenizer = tokenizer
+ def __init__(self, **kwargs) -> None:
+ """A simple tagger using a linear layer with an optional CRF (:cite:`lafferty2001conditional`) layer for
+ any tagging tasks including PoS tagging and many others.
+
+ Args:
+ **kwargs: Not used.
+ """
+ super().__init__(**kwargs)
+ self._tokenizer_transform = None
+ self.model: TransformerTaggingModel = None
+
+ # noinspection PyMethodOverriding
+ def fit_dataloader(self,
+ trn: DataLoader,
+ criterion,
+ optimizer,
+ metric,
+ logger: logging.Logger,
+ history: History,
+ gradient_accumulation=1,
+ grad_norm=None,
+ transformer_grad_norm=None,
+ teacher: Tagger = None,
+ kd_criterion=None,
+ temperature_scheduler=None,
+ ratio_width=None,
+ **kwargs):
+ optimizer, scheduler = optimizer
+ if teacher:
+ scheduler, lambda_scheduler = scheduler
+ else:
+ lambda_scheduler = None
+ self.model.train()
+ timer = CountdownTimer(history.num_training_steps(len(trn), gradient_accumulation=gradient_accumulation))
+ total_loss = 0
+ for idx, batch in enumerate(trn):
+ out, mask = self.feed_batch(batch)
+ y = batch['tag_id']
+ loss = self.compute_loss(criterion, out, y, mask)
+ if gradient_accumulation and gradient_accumulation > 1:
+ loss /= gradient_accumulation
+ if teacher:
+ with torch.no_grad():
+ out_T, _ = teacher.feed_batch(batch)
+ # noinspection PyNoneFunctionAssignment
+ kd_loss = self.compute_distill_loss(kd_criterion, out, out_T, mask, temperature_scheduler)
+ _lambda = float(lambda_scheduler)
+ loss = _lambda * loss + (1 - _lambda) * kd_loss
+ loss.backward()
+ total_loss += loss.item()
+ prediction = self.decode_output(out, mask, batch)
+ self.update_metrics(metric, out, y, mask, batch, prediction)
+ if history.step(gradient_accumulation):
+ self._step(optimizer, scheduler, grad_norm, transformer_grad_norm, lambda_scheduler)
+ report = f'loss: {total_loss / (idx + 1):.4f} {metric}'
+ timer.log(report, logger=logger, ratio_percentage=False, ratio_width=ratio_width)
+ del loss
+ del out
+ del mask
+
+ def _step(self, optimizer, scheduler, grad_norm, transformer_grad_norm, lambda_scheduler):
+ clip_grad_norm(self.model, grad_norm, self.model.encoder.transformer, transformer_grad_norm)
+ optimizer.step()
+ scheduler.step()
+ if lambda_scheduler:
+ lambda_scheduler.step()
+ optimizer.zero_grad()
+
+ def compute_distill_loss(self, kd_criterion, out_S, out_T, mask, temperature_scheduler):
+ logits_S = out_S[mask]
+ logits_T = out_T[mask]
+ temperature = temperature_scheduler(logits_S, logits_T)
+ return kd_criterion(logits_S, logits_T, temperature)
+
+ def build_model(self, **kwargs) -> torch.nn.Module:
+ model = TransformerTaggingModel(self.build_transformer(),
+ len(self.vocabs.tag),
+ self.config.crf,
+ self.config.get('secondary_encoder', None),
+ )
return model
- def fit(self, trn_data, dev_data, save_dir,
+ # noinspection PyMethodOverriding
+ def build_dataloader(self, data, batch_size, shuffle, device, logger: logging.Logger = None,
+ sampler_builder: SamplerBuilder = None, gradient_accumulation=1, **kwargs) -> DataLoader:
+ if isinstance(data, TransformableDataset):
+ dataset = data
+ else:
+ args = dict((k, self.config.get(k, None)) for k in
+ ['delimiter', 'max_seq_len', 'sent_delimiter', 'char_level', 'hard_constraint'])
+ dataset = self.build_dataset(data, **args)
+ if self.config.token_key is None:
+ self.config.token_key = next(iter(dataset[0]))
+ logger.info(
+ f'Guess [bold][blue]token_key={self.config.token_key}[/blue][/bold] according to the '
+ f'training dataset: [blue]{dataset}[/blue]')
+ dataset.append_transform(self.tokenizer_transform)
+ dataset.append_transform(self.last_transform())
+ if not isinstance(data, list):
+ dataset.purge_cache()
+ if self.vocabs.mutable:
+ self.build_vocabs(dataset, logger)
+ if sampler_builder is not None:
+ sampler = sampler_builder.build([len(x[f'{self.config.token_key}_input_ids']) for x in dataset], shuffle,
+ gradient_accumulation=gradient_accumulation if shuffle else 1)
+ else:
+ sampler = None
+ return PadSequenceDataLoader(dataset, batch_size, shuffle, device=device, batch_sampler=sampler)
+
+ def build_dataset(self, data, transform=None, **kwargs):
+ return TSVTaggingDataset(data, transform=transform, **kwargs)
+
+ def last_transform(self):
+ return TransformList(self.vocabs, FieldLength(self.config.token_key))
+
+ @property
+ def tokenizer_transform(self) -> TransformerSequenceTokenizer:
+ if not self._tokenizer_transform:
+ self._tokenizer_transform = TransformerSequenceTokenizer(self.transformer_tokenizer,
+ self.config.token_key,
+ ret_token_span=True)
+ return self._tokenizer_transform
+
+ def build_vocabs(self, trn, logger, **kwargs):
+ self.vocabs.tag = Vocab(pad_token=None, unk_token=None)
+ timer = CountdownTimer(len(trn))
+ max_seq_len = 0
+ token_key = self.config.token_key
+ for each in trn:
+ max_seq_len = max(max_seq_len, len(each[token_key]))
+ timer.log(f'Building vocab [blink][yellow]...[/yellow][/blink] (longest sequence: {max_seq_len})')
+ self.vocabs.tag.set_unk_as_safe_unk()
+ self.vocabs.lock()
+ self.vocabs.summary(logger)
+
+ # noinspection PyMethodOverriding
+ def fit(self,
+ trn_data,
+ dev_data,
+ save_dir,
transformer,
- optimizer='adamw',
- learning_rate=5e-5,
- weight_decay_rate=0,
- epsilon=1e-8,
- clipnorm=1.0,
- warmup_steps_ratio=0,
- use_amp=False,
- max_seq_length=128,
+ average_subwords=False,
+ word_dropout: float = 0.2,
+ hidden_dropout=None,
+ layer_dropout=0,
+ scalar_mix=None,
+ mix_embedding: int = 0,
+ grad_norm=5.0,
+ transformer_grad_norm=None,
+ lr=5e-5,
+ transformer_lr=None,
+ transformer_layers=None,
+ gradient_accumulation=1,
+ adam_epsilon=1e-6,
+ weight_decay=0,
+ warmup_steps=0.1,
+ secondary_encoder=None,
+ crf=False,
+ reduction='sum',
batch_size=32,
+ sampler_builder: SamplerBuilder = None,
epochs=3,
- metrics='accuracy',
- run_eagerly=False,
+ patience=5,
+ token_key=None,
+ max_seq_len=None, sent_delimiter=None, char_level=False, hard_constraint=False,
+ transform=None,
logger=None,
- verbose=True,
+ devices: Union[float, int, List[int]] = None,
**kwargs):
return super().fit(**merge_locals_kwargs(locals(), kwargs))
- # noinspection PyMethodOverriding
- def build_optimizer(self, optimizer, learning_rate, epsilon, weight_decay_rate, clipnorm, use_amp, train_steps,
- warmup_steps, **kwargs):
- if optimizer == 'adamw':
- opt = create_optimizer(init_lr=learning_rate,
- epsilon=epsilon,
- weight_decay_rate=weight_decay_rate,
- clipnorm=clipnorm,
- num_train_steps=train_steps, num_warmup_steps=warmup_steps)
- # opt = tfa.optimizers.AdamW(learning_rate=3e-5, epsilon=1e-08, weight_decay=0.01)
- # opt = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08)
- self.config.optimizer = tf.keras.utils.serialize_keras_object(opt)
- lr_config = self.config.optimizer['config']['learning_rate']['config']
- if 'decay_schedule_fn' in lr_config:
- lr_config['decay_schedule_fn'] = dict(
- (k, v) for k, v in lr_config['decay_schedule_fn'].items() if not k.startswith('_'))
+ def feed_batch(self, batch: dict):
+ features = [batch[k] for k in self.tokenizer_transform.output_key]
+ if len(features) == 2:
+ input_ids, token_span = features
else:
- opt = super().build_optimizer(optimizer)
- if use_amp:
- # loss scaling is currently required when using mixed precision
- opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, 'dynamic')
- return opt
-
- def build_vocab(self, trn_data, logger):
- train_examples = super().build_vocab(trn_data, logger)
- warmup_steps_per_epoch = math.ceil(train_examples * self.config.warmup_steps_ratio / self.config.batch_size)
- self.config.warmup_steps = warmup_steps_per_epoch * self.config.epochs
- return train_examples
-
- def train_loop(self, trn_data, dev_data, epochs, num_examples, train_steps_per_epoch, dev_steps, model, optimizer,
- loss, metrics, callbacks, logger, **kwargs):
- history = self.model.fit(trn_data, epochs=epochs, steps_per_epoch=train_steps_per_epoch,
- validation_data=dev_data,
- callbacks=callbacks,
- validation_steps=dev_steps,
- # mask out padding labels
- # class_weight=dict(
- # (i, 0 if i == 0 else 1) for i in range(len(self.transform.tag_vocab)))
- ) # type:tf.keras.callbacks.History
- return history
-
- def build_loss(self, loss, **kwargs):
- if not loss:
- return MaskedSparseCategoricalCrossentropyOverBatchFirstDim(self.transform.tag_vocab.pad_idx)
- return super().build_loss(loss, **kwargs)
-
- def load_transform(self, save_dir) -> Transform:
- super().load_transform(save_dir)
- self.transform.tokenizer = build_transformer(self.config.transformer, self.config.max_seq_length,
- len(self.transform.tag_vocab), tagging=True, tokenizer_only=True)
- return self.transform
-
- def build_metrics(self, metrics, logger: logging.Logger, **kwargs):
- return Accuracy(self.transform.tag_vocab.pad_idx)
+ input_ids, token_span = features[0], None
+ lens = batch[f'{self.config.token_key}_length']
+ x, mask = self.model(lens, input_ids, token_span, batch.get(f'{self.config.token_key}_token_type_ids'))
+ return x, mask
+
+ # noinspection PyMethodOverriding
+ def distill(self,
+ teacher: str,
+ trn_data,
+ dev_data,
+ save_dir,
+ transformer: str,
+ batch_size=None,
+ temperature_scheduler='flsw',
+ epochs=None,
+ devices=None,
+ logger=None,
+ seed=None,
+ **kwargs):
+ return super().distill(**merge_locals_kwargs(locals(), kwargs))
diff --git a/hanlp/components/taggers/transformers/transformer_tagger_tf.py b/hanlp/components/taggers/transformers/transformer_tagger_tf.py
new file mode 100644
index 000000000..8742425a6
--- /dev/null
+++ b/hanlp/components/taggers/transformers/transformer_tagger_tf.py
@@ -0,0 +1,94 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-12-29 13:55
+import math
+
+import tensorflow as tf
+
+from hanlp.common.transform_tf import Transform
+from hanlp.components.taggers.tagger_tf import TaggerComponent
+from hanlp.components.taggers.transformers.transformer_transform_tf import TransformerTransform
+from hanlp.layers.transformers.loader_tf import build_transformer
+from hanlp.layers.transformers.utils_tf import build_adamw_optimizer
+from hanlp.losses.sparse_categorical_crossentropy import SparseCategoricalCrossentropyOverBatchFirstDim
+from hanlp_common.util import merge_locals_kwargs
+
+
+class TransformerTaggingModel(tf.keras.Model):
+ def __init__(self, transformer: tf.keras.Model, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.transformer = transformer
+
+ def call(self, inputs, training=None, mask=None):
+ return super().call(inputs, training, mask)
+
+
+class TransformerTaggerTF(TaggerComponent):
+ def __init__(self, transform: TransformerTransform = None) -> None:
+ if transform is None:
+ transform = TransformerTransform()
+ super().__init__(transform)
+ self.transform: TransformerTransform = transform
+
+ def build_model(self, transformer, max_seq_length, **kwargs) -> tf.keras.Model:
+ model, tokenizer = build_transformer(transformer, max_seq_length, len(self.transform.tag_vocab), tagging=True)
+ self.transform.tokenizer = tokenizer
+ return model
+
+ def fit(self, trn_data, dev_data, save_dir,
+ transformer,
+ optimizer='adamw',
+ learning_rate=5e-5,
+ weight_decay_rate=0,
+ epsilon=1e-8,
+ clipnorm=1.0,
+ warmup_steps_ratio=0,
+ use_amp=False,
+ max_seq_length=128,
+ batch_size=32,
+ epochs=3,
+ metrics='accuracy',
+ run_eagerly=False,
+ logger=None,
+ verbose=True,
+ **kwargs):
+ return super().fit(**merge_locals_kwargs(locals(), kwargs))
+
+ # noinspection PyMethodOverriding
+ def build_optimizer(self, optimizer, learning_rate, epsilon, weight_decay_rate, clipnorm, use_amp, train_steps,
+ warmup_steps, **kwargs):
+ if optimizer == 'adamw':
+ opt = build_adamw_optimizer(self.config, learning_rate, epsilon, clipnorm, train_steps, use_amp,
+ warmup_steps, weight_decay_rate)
+ else:
+ opt = super().build_optimizer(optimizer)
+ return opt
+
+ def build_vocab(self, trn_data, logger):
+ train_examples = super().build_vocab(trn_data, logger)
+ warmup_steps_per_epoch = math.ceil(train_examples * self.config.warmup_steps_ratio / self.config.batch_size)
+ self.config.warmup_steps = warmup_steps_per_epoch * self.config.epochs
+ return train_examples
+
+ def train_loop(self, trn_data, dev_data, epochs, num_examples, train_steps_per_epoch, dev_steps, model, optimizer,
+ loss, metrics, callbacks, logger, **kwargs):
+ history = self.model.fit(trn_data, epochs=epochs, steps_per_epoch=train_steps_per_epoch,
+ validation_data=dev_data,
+ callbacks=callbacks,
+ validation_steps=dev_steps,
+ # mask out padding labels
+ # class_weight=dict(
+ # (i, 0 if i == 0 else 1) for i in range(len(self.transform.tag_vocab)))
+ ) # type:tf.keras.callbacks.History
+ return history
+
+ def build_loss(self, loss, **kwargs):
+ if not loss:
+ return SparseCategoricalCrossentropyOverBatchFirstDim()
+ return super().build_loss(loss, **kwargs)
+
+ def load_transform(self, save_dir) -> Transform:
+ super().load_transform(save_dir)
+ self.transform.tokenizer = build_transformer(self.config.transformer, self.config.max_seq_length,
+ len(self.transform.tag_vocab), tagging=True, tokenizer_only=True)
+ return self.transform
diff --git a/hanlp/components/taggers/transformers/transformer_transform.py b/hanlp/components/taggers/transformers/transformer_transform_tf.py
similarity index 88%
rename from hanlp/components/taggers/transformers/transformer_transform.py
rename to hanlp/components/taggers/transformers/transformer_transform_tf.py
index 0aaad2f1b..09ba3d0d3 100644
--- a/hanlp/components/taggers/transformers/transformer_transform.py
+++ b/hanlp/components/taggers/transformers/transformer_transform_tf.py
@@ -5,10 +5,10 @@
import tensorflow as tf
-from hanlp.common.structure import SerializableDict
-from hanlp.common.transform import Transform
-from hanlp.common.vocab import Vocab
-from hanlp.components.taggers.transformers.utils import convert_examples_to_features, config_is
+from hanlp_common.structure import SerializableDict
+from hanlp.common.transform_tf import Transform
+from hanlp.common.vocab_tf import VocabTF
+from hanlp.layers.transformers.utils_tf import convert_examples_to_features
from hanlp.transform.tsv import TsvTaggingFormat
@@ -19,7 +19,7 @@ def __init__(self,
map_x=False, map_y=False, **kwargs) -> None:
super().__init__(config, map_x, map_y, **kwargs)
self._tokenizer = tokenizer
- self.tag_vocab: Vocab = None
+ self.tag_vocab: VocabTF = None
self.special_token_ids = None
self.pad = '[PAD]'
self.unk = '[UNK]'
@@ -36,16 +36,17 @@ def tokenizer(self):
@tokenizer.setter
def tokenizer(self, tokenizer):
self._tokenizer = tokenizer
- if self.pad not in tokenizer.vocab:
+ vocab = tokenizer._vocab if hasattr(tokenizer, '_vocab') else tokenizer.vocab
+ if self.pad not in vocab:
# English albert use instead of [PAD]
self.pad = ''
- if self.unk not in tokenizer.vocab:
+ if self.unk not in vocab:
self.unk = ''
- self.special_token_ids = tf.constant([tokenizer.vocab[token] for token in [self.pad, '[CLS]', '[SEP]']],
+ self.special_token_ids = tf.constant([vocab[token] for token in [self.pad, '[CLS]', '[SEP]']],
dtype=tf.int32)
def fit(self, trn_path: str, **kwargs) -> int:
- self.tag_vocab = Vocab(unk_token=None)
+ self.tag_vocab = VocabTF(unk_token=None)
num_samples = 0
for words, tags in self.file_to_inputs(trn_path, gold=True):
num_samples += 1
@@ -81,9 +82,10 @@ def inputs_to_samples(self, inputs, gold=False):
else:
words, tags = sample, [self.tag_vocab.idx_to_token[1]] * len(sample)
- input_ids, input_mask, segment_ids, label_ids = convert_examples_to_features(words, tags,
- self.tag_vocab.token_to_idx,
+ input_ids, input_mask, segment_ids, label_ids = convert_examples_to_features(words,
max_seq_length, tokenizer,
+ tags,
+ self.tag_vocab.token_to_idx,
cls_token_at_end=xlnet,
# xlnet has a cls token at the end
cls_token=cls_token,
@@ -93,7 +95,7 @@ def inputs_to_samples(self, inputs, gold=False):
# roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
pad_on_left=xlnet,
# pad on the left for xlnet
- pad_token=pad_token,
+ pad_token_id=pad_token,
pad_token_segment_id=4 if xlnet else 0,
pad_token_label_id=pad_label_idx,
unk_token=unk_token)
diff --git a/hanlp/components/taggers/transformers/utils.py b/hanlp/components/taggers/transformers/utils.py
deleted file mode 100644
index cc70985c0..000000000
--- a/hanlp/components/taggers/transformers/utils.py
+++ /dev/null
@@ -1,119 +0,0 @@
-# -*- coding:utf-8 -*-
-# Author: hankcs
-# Date: 2019-12-29 15:32
-from hanlp.utils.log_util import logger
-
-
-def config_is(config, model='bert'):
- return model in type(config).__name__.lower()
-
-
-def convert_examples_to_features(
- words,
- labels,
- label_map,
- max_seq_length,
- tokenizer,
- cls_token_at_end=False,
- cls_token="[CLS]",
- cls_token_segment_id=1,
- sep_token="[SEP]",
- sep_token_extra=False,
- pad_on_left=False,
- pad_token=0,
- pad_token_segment_id=0,
- pad_token_label_id=0,
- sequence_a_segment_id=0,
- mask_padding_with_zero=True,
- unk_token='[UNK]'
-):
- """ Loads a data file into a list of `InputBatch`s
- `cls_token_at_end` define the location of the CLS token:
- - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
- - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
- `cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet)
- """
- args = locals()
- assert label_map, 'label_map is required'
-
- tokens = []
- label_ids = []
- for word, label in zip(words, labels):
- word_tokens = tokenizer.tokenize(word)
- if not word_tokens:
- # some wired chars cause the tagger to return empty list
- word_tokens = [unk_token] * len(word)
- tokens.extend(word_tokens)
- # Use the real label id for the first token of the word, and padding ids for the remaining tokens
- label_ids.extend([label_map[label]] + [pad_token_label_id] * (len(word_tokens) - 1))
-
- # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa.
- special_tokens_count = 3 if sep_token_extra else 2
- if len(tokens) > max_seq_length - special_tokens_count:
- logger.warning(
- f'Input tokens {words} exceed the max sequence length of {max_seq_length - special_tokens_count}. '
- f'The exceeded part will be truncated and ignored. '
- f'You are recommended to split your long text into several sentences within '
- f'{max_seq_length - special_tokens_count} tokens beforehand.')
- tokens = tokens[: (max_seq_length - special_tokens_count)]
- label_ids = label_ids[: (max_seq_length - special_tokens_count)]
-
- # The convention in BERT is:
- # (a) For sequence pairs:
- # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
- # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
- # (b) For single sequences:
- # tokens: [CLS] the dog is hairy . [SEP]
- # type_ids: 0 0 0 0 0 0 0
- #
- # Where "type_ids" are used to indicate whether this is the first
- # sequence or the second sequence. The embedding vectors for `type=0` and
- # `type=1` were learned during pre-training and are added to the wordpiece
- # embedding vector (and position vector). This is not *strictly* necessary
- # since the [SEP] token unambiguously separates the sequences, but it makes
- # it easier for the model to learn the concept of sequences.
- #
- # For classification tasks, the first vector (corresponding to [CLS]) is
- # used as as the "sentence vector". Note that this only makes sense because
- # the entire model is fine-tuned.
- tokens += [sep_token]
- label_ids += [pad_token_label_id]
- if sep_token_extra:
- # roberta uses an extra separator b/w pairs of sentences
- tokens += [sep_token]
- label_ids += [pad_token_label_id]
- segment_ids = [sequence_a_segment_id] * len(tokens)
-
- if cls_token_at_end:
- tokens += [cls_token]
- label_ids += [pad_token_label_id]
- segment_ids += [cls_token_segment_id]
- else:
- tokens = [cls_token] + tokens
- label_ids = [pad_token_label_id] + label_ids
- segment_ids = [cls_token_segment_id] + segment_ids
-
- input_ids = tokenizer.convert_tokens_to_ids(tokens)
-
- # The mask has 1 for real tokens and 0 for padding tokens. Only real
- # tokens are attended to.
- input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
-
- # Zero-pad up to the sequence length.
- padding_length = max_seq_length - len(input_ids)
- if pad_on_left:
- input_ids = ([pad_token] * padding_length) + input_ids
- input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
- segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
- label_ids = ([pad_token_label_id] * padding_length) + label_ids
- else:
- input_ids += [pad_token] * padding_length
- input_mask += [0 if mask_padding_with_zero else 1] * padding_length
- segment_ids += [pad_token_segment_id] * padding_length
- label_ids += [pad_token_label_id] * padding_length
-
- assert len(input_ids) == max_seq_length
- assert len(input_mask) == max_seq_length
- assert len(segment_ids) == max_seq_length
- assert len(label_ids) == max_seq_length, f'failed for:\n {args}'
- return input_ids, input_mask, segment_ids, label_ids
diff --git a/hanlp/components/taggers/util.py b/hanlp/components/taggers/util.py
new file mode 100644
index 000000000..7efa27525
--- /dev/null
+++ b/hanlp/components/taggers/util.py
@@ -0,0 +1,22 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-06-01 00:31
+from typing import List, Tuple
+from alnlp.modules.conditional_random_field import allowed_transitions
+
+
+def guess_tagging_scheme(labels: List[str]) -> str:
+ tagset = set(y.split('-')[0] for y in labels)
+ for scheme in "BIO", "BIOUL", "BMES", 'IOBES':
+ if tagset == set(list(scheme)):
+ return scheme
+
+
+def guess_allowed_transitions(labels) -> List[Tuple[int, int]]:
+ scheme = guess_tagging_scheme(labels)
+ if not scheme:
+ return None
+ if scheme == 'IOBES':
+ scheme = 'BIOUL'
+ labels = [y.replace('E-', 'L-').replace('S-', 'U-') for y in labels]
+ return allowed_transitions(scheme, dict(enumerate(labels)))
diff --git a/hanlp/components/tok.py b/hanlp/components/tok.py
index f55c561a0..2f6449684 100644
--- a/hanlp/components/tok.py
+++ b/hanlp/components/tok.py
@@ -1,110 +1,54 @@
# -*- coding:utf-8 -*-
# Author: hankcs
-# Date: 2019-10-27 14:30
-import logging
-from typing import Union, Any, List, Tuple, Iterable
+# Date: 2020-06-12 13:08
+from typing import Any, Callable
-import tensorflow as tf
-
-from hanlp.common.component import KerasComponent
-from hanlp.components.taggers.ngram_conv.ngram_conv_tagger import NgramTransform, NgramConvTagger
from hanlp.components.taggers.rnn_tagger import RNNTagger
-from hanlp.components.taggers.transformers.transformer_tagger import TransformerTagger
-from hanlp.components.taggers.transformers.transformer_transform import TransformerTransform
-from hanlp.losses.sparse_categorical_crossentropy import SparseCategoricalCrossentropyOverBatchFirstDim
-from hanlp.metrics.chunking.bmes import BMES_F1
-from hanlp.transform.tsv import TSVTaggingTransform
-from hanlp.transform.txt import extract_ngram_features_and_tags, bmes_to_words, TxtFormat, TxtBMESFormat
-from hanlp.utils.util import merge_locals_kwargs
-
-
-class BMESTokenizer(KerasComponent):
-
- def build_metrics(self, metrics, logger: logging.Logger, **kwargs):
- if metrics == 'f1':
- self.config.run_eagerly = True
- return BMES_F1(self.transform.tag_vocab)
- return super().build_metrics(metrics, logger, **kwargs)
-
-
-class NgramConvTokenizerTransform(TxtFormat, NgramTransform):
-
- def inputs_to_samples(self, inputs, gold=False):
- if self.input_is_single_sample(inputs):
- inputs = [inputs]
- for sent in inputs:
- # bigram_only = false
- yield extract_ngram_features_and_tags(sent, False, self.config.window_size, gold)
-
- def input_is_single_sample(self, input: Union[List[str], List[List[str]]]) -> bool:
- if not input:
- return True
- return isinstance(input, str)
-
- def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, inputs=None, X=None,
- **kwargs) -> Iterable:
- yield from TxtBMESFormat.Y_to_tokens(self, self.tag_vocab, Y, gold, inputs)
-
-
-class NgramConvTokenizer(BMESTokenizer, NgramConvTagger):
-
- def __init__(self) -> None:
- super().__init__(NgramConvTokenizerTransform())
-
- def fit(self, trn_data: Any, dev_data: Any, save_dir: str, word_embed: Union[str, int, dict] = 200,
- ngram_embed: Union[str, int, dict] = 50, embedding_trainable=True, window_size=4, kernel_size=3,
- filters=(200, 200, 200, 200, 200), dropout_embed=0.2, dropout_hidden=0.2, weight_norm=True,
- loss: Union[tf.keras.losses.Loss, str] = None,
- optimizer: Union[str, tf.keras.optimizers.Optimizer] = 'adam', metrics='f1', batch_size=100,
- epochs=100, logger=None, verbose=True, **kwargs):
+from hanlp.datasets.cws.chunking_dataset import ChunkingDataset
+from hanlp.metrics.chunking.chunking_f1 import ChunkingF1
+from hanlp.utils.span_util import bmes_to_words
+from hanlp_common.util import merge_locals_kwargs
+
+
+class RNNTokenizer(RNNTagger):
+
+ def predict(self, sentence: Any, batch_size: int = None, **kwargs):
+ flat = isinstance(sentence, str)
+ if flat:
+ sentence = [sentence]
+ for i, s in enumerate(sentence):
+ sentence[i] = list(s)
+ outputs = RNNTagger.predict(self, sentence, batch_size, **kwargs)
+ if flat:
+ return outputs[0]
+ return outputs
+
+ def predict_data(self, data, batch_size, **kwargs):
+ tags = RNNTagger.predict_data(self, data, batch_size, **kwargs)
+ words = [bmes_to_words(c, t) for c, t in zip(data, tags)]
+ return words
+
+ def build_dataset(self, data, transform=None):
+ dataset = ChunkingDataset(data)
+ if 'transform' in self.config:
+ dataset.append_transform(self.config.transform)
+ if transform:
+ dataset.append_transform(transform)
+ return dataset
+
+ def build_metric(self, **kwargs):
+ return ChunkingF1()
+
+ def update_metrics(self, metric, logits, y, mask, batch):
+ pred = self.decode_output(logits, mask, batch)
+ pred = self._id_to_tags(pred)
+ gold = batch['tag']
+ metric(pred, gold)
+
+ def fit(self, trn_data, dev_data, save_dir, batch_size=50, epochs=100, embed=100, rnn_input=None, rnn_hidden=256,
+ drop=0.5, lr=0.001, patience=10, crf=True, optimizer='adam', token_key='char', tagging_scheme=None,
+ anneal_factor: float = 0.5, anneal_patience=2, devices=None, logger=None,
+ verbose=True, transform: Callable = None, **kwargs):
return super().fit(**merge_locals_kwargs(locals(), kwargs))
- def evaluate_output_to_file(self, batch, outputs, out):
- for x, y_pred in zip(self.transform.X_to_inputs(batch[0]),
- self.transform.Y_to_outputs(outputs, gold=False)):
- out.write(self.transform.input_truth_output_to_str(x, None, y_pred))
- out.write('\n')
- def build_loss(self, loss, **kwargs):
- if loss is None:
- return SparseCategoricalCrossentropyOverBatchFirstDim()
- return super().build_loss(loss, **kwargs)
-
-
-class TransformerTokenizerTransform(TxtBMESFormat, TransformerTransform):
-
- def inputs_to_samples(self, inputs, gold=False):
- yield from TransformerTransform.inputs_to_samples(self, TxtBMESFormat.inputs_to_samples(self, inputs, gold),
- True)
-
- def Y_to_tokens(self, tag_vocab, Y, gold, inputs):
- if not gold:
- Y = tf.argmax(Y, axis=2)
- for text, ys in zip(inputs, Y):
- tags = [tag_vocab.idx_to_token[int(y)] for y in ys[1:len(text) + 1]]
- yield bmes_to_words(list(text), tags)
-
-
-class TransformerTokenizer(BMESTokenizer, TransformerTagger):
- def __init__(self, transform: TransformerTokenizerTransform = None) -> None:
- if transform is None:
- transform = TransformerTokenizerTransform()
- super().__init__(transform)
-
-
-class RNNTokenizerTransform(TxtBMESFormat, TSVTaggingTransform):
- pass
-
-
-class RNNTokenizer(BMESTokenizer, RNNTagger):
- def __init__(self, transform: RNNTokenizerTransform = None) -> None:
- if not transform:
- transform = RNNTokenizerTransform()
- super().__init__(transform)
-
- def fit(self, trn_data: str, dev_data: str = None, save_dir: str = None, embeddings=100, embedding_trainable=False,
- rnn_input_dropout=0.2, rnn_units=100, rnn_output_dropout=0.2, epochs=20, lower=False, max_seq_len=50,
- logger=None, loss: Union[tf.keras.losses.Loss, str] = None,
- optimizer: Union[str, tf.keras.optimizers.Optimizer] = 'adam', metrics='f1', batch_size=32,
- dev_batch_size=32, lr_decay_per_epoch=None, verbose=True, **kwargs):
- return super().fit(**merge_locals_kwargs(locals(), kwargs))
diff --git a/hanlp/components/tok_tf.py b/hanlp/components/tok_tf.py
new file mode 100644
index 000000000..32a4332aa
--- /dev/null
+++ b/hanlp/components/tok_tf.py
@@ -0,0 +1,110 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-10-27 14:30
+import logging
+from typing import Union, Any, List, Tuple, Iterable
+
+import tensorflow as tf
+
+from hanlp.common.keras_component import KerasComponent
+from hanlp.components.taggers.ngram_conv.ngram_conv_tagger import NgramTransform, NgramConvTaggerTF
+from hanlp.components.taggers.rnn_tagger_tf import RNNTaggerTF
+from hanlp.components.taggers.transformers.transformer_tagger_tf import TransformerTaggerTF
+from hanlp.components.taggers.transformers.transformer_transform_tf import TransformerTransform
+from hanlp.losses.sparse_categorical_crossentropy import SparseCategoricalCrossentropyOverBatchFirstDim
+from hanlp.metrics.chunking.bmes import BMES_F1_TF
+from hanlp.transform.tsv import TSVTaggingTransform
+from hanlp.transform.txt import TxtFormat, TxtBMESFormat, extract_ngram_features_and_tags, bmes_to_words
+from hanlp_common.util import merge_locals_kwargs
+
+
+class BMESTokenizerTF(KerasComponent):
+
+ def build_metrics(self, metrics, logger: logging.Logger, **kwargs):
+ if metrics == 'f1':
+ self.config.run_eagerly = True
+ return BMES_F1_TF(self.transform.tag_vocab)
+ return super().build_metrics(metrics, logger, **kwargs)
+
+
+class NgramConvTokenizerTransform(TxtFormat, NgramTransform):
+
+ def inputs_to_samples(self, inputs, gold=False):
+ if self.input_is_single_sample(inputs):
+ inputs = [inputs]
+ for sent in inputs:
+ # bigram_only = false
+ yield extract_ngram_features_and_tags(sent, False, self.config.window_size, gold)
+
+ def input_is_single_sample(self, input: Union[List[str], List[List[str]]]) -> bool:
+ if not input:
+ return True
+ return isinstance(input, str)
+
+ def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, inputs=None, X=None,
+ **kwargs) -> Iterable:
+ yield from TxtBMESFormat.Y_to_tokens(self, self.tag_vocab, Y, gold, inputs)
+
+
+class NgramConvTokenizerTF(BMESTokenizerTF, NgramConvTaggerTF):
+
+ def __init__(self) -> None:
+ super().__init__(NgramConvTokenizerTransform())
+
+ def fit(self, trn_data: Any, dev_data: Any, save_dir: str, word_embed: Union[str, int, dict] = 200,
+ ngram_embed: Union[str, int, dict] = 50, embedding_trainable=True, window_size=4, kernel_size=3,
+ filters=(200, 200, 200, 200, 200), dropout_embed=0.2, dropout_hidden=0.2, weight_norm=True,
+ loss: Union[tf.keras.losses.Loss, str] = None,
+ optimizer: Union[str, tf.keras.optimizers.Optimizer] = 'adam', metrics='f1', batch_size=100,
+ epochs=100, logger=None, verbose=True, **kwargs):
+ return super().fit(**merge_locals_kwargs(locals(), kwargs))
+
+ def evaluate_output_to_file(self, batch, outputs, out):
+ for x, y_pred in zip(self.transform.X_to_inputs(batch[0]),
+ self.transform.Y_to_outputs(outputs, gold=False)):
+ out.write(self.transform.input_truth_output_to_str(x, None, y_pred))
+ out.write('\n')
+
+ def build_loss(self, loss, **kwargs):
+ if loss is None:
+ return SparseCategoricalCrossentropyOverBatchFirstDim()
+ return super().build_loss(loss, **kwargs)
+
+
+class TransformerTokenizerTransform(TxtBMESFormat, TransformerTransform):
+
+ def inputs_to_samples(self, inputs, gold=False):
+ yield from TransformerTransform.inputs_to_samples(self, TxtBMESFormat.inputs_to_samples(self, inputs, gold),
+ True)
+
+ def Y_to_tokens(self, tag_vocab, Y, gold, inputs):
+ if not gold:
+ Y = tf.argmax(Y, axis=2)
+ for text, ys in zip(inputs, Y):
+ tags = [tag_vocab.idx_to_token[int(y)] for y in ys[1:len(text) + 1]]
+ yield bmes_to_words(list(text), tags)
+
+
+class TransformerTokenizerTF(BMESTokenizerTF, TransformerTaggerTF):
+ def __init__(self, transform: TransformerTokenizerTransform = None) -> None:
+ if transform is None:
+ transform = TransformerTokenizerTransform()
+ super().__init__(transform)
+
+
+class RNNTokenizerTransform(TxtBMESFormat, TSVTaggingTransform):
+ pass
+
+
+class RNNTokenizerTF(BMESTokenizerTF, RNNTaggerTF):
+ def __init__(self, transform: RNNTokenizerTransform = None) -> None:
+ if not transform:
+ transform = RNNTokenizerTransform()
+ super().__init__(transform)
+
+ def fit(self, trn_data: str, dev_data: str = None, save_dir: str = None, embeddings=100, embedding_trainable=False,
+ rnn_input_dropout=0.2, rnn_units=100, rnn_output_dropout=0.2, epochs=20, lower=False, max_seq_len=50,
+ logger=None, loss: Union[tf.keras.losses.Loss, str] = None,
+ optimizer: Union[str, tf.keras.optimizers.Optimizer] = 'adam', metrics='f1', batch_size=32,
+ dev_batch_size=32, lr_decay_per_epoch=None, verbose=True, **kwargs):
+ return super().fit(**merge_locals_kwargs(locals(), kwargs))
diff --git a/hanlp/components/tokenizers/__init__.py b/hanlp/components/tokenizers/__init__.py
new file mode 100644
index 000000000..f407a32d2
--- /dev/null
+++ b/hanlp/components/tokenizers/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-08-11 02:48
\ No newline at end of file
diff --git a/hanlp/components/tokenizers/multi_criteria_cws_transformer.py b/hanlp/components/tokenizers/multi_criteria_cws_transformer.py
new file mode 100644
index 000000000..d1f2688a5
--- /dev/null
+++ b/hanlp/components/tokenizers/multi_criteria_cws_transformer.py
@@ -0,0 +1,89 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-10-21 19:55
+from typing import List, Union
+
+from hanlp.common.dataset import SamplerBuilder
+from hanlp.components.taggers.transformers.transformer_tagger import TransformerTagger
+from hanlp.components.tokenizers.transformer import TransformerTaggingTokenizer
+from hanlp.datasets.cws.multi_criteria_cws.mcws_dataset import MultiCriteriaTextTokenizingDataset, append_criteria_token
+import functools
+
+from hanlp.metrics.f1 import F1
+from hanlp.metrics.mtl import MetricDict
+from hanlp_common.util import merge_locals_kwargs
+
+
+class MultiCriteriaTransformerTaggingTokenizer(TransformerTaggingTokenizer):
+ def __init__(self, **kwargs) -> None:
+ r"""Transformer based implementation of "Effective Neural Solution for Multi-Criteria Word Segmentation"
+ (:cite:`he2019effective`). It uses an artificial token ``[unused_i]`` instead of ``[SEP]`` in the input_ids to
+ mark the i-th segmentation criteria.
+
+ Args:
+ **kwargs: Not used.
+ """
+ super().__init__(**kwargs)
+
+ def build_dataset(self, data, **kwargs):
+ return MultiCriteriaTextTokenizingDataset(data, **kwargs)
+
+ def on_config_ready(self, **kwargs):
+ super().on_config_ready(**kwargs)
+ # noinspection PyAttributeOutsideInit
+ if 'criteria_token_map' not in self.config:
+ unused_tokens = [f'[unused{i}]' for i in range(1, 100)]
+ ids = self.transformer_tokenizer.convert_tokens_to_ids(unused_tokens)
+ self.config.unused_tokens = dict((x, ids[i]) for i, x in enumerate(unused_tokens) if
+ ids[i] != self.transformer_tokenizer.unk_token_id)
+ self.config.criteria_token_map = dict()
+
+ def last_transform(self):
+ transforms = super().last_transform()
+ transforms.append(functools.partial(append_criteria_token,
+ criteria_tokens=self.config.unused_tokens,
+ criteria_token_map=self.config.criteria_token_map))
+ return transforms
+
+ def build_vocabs(self, trn, logger, **kwargs):
+ super().build_vocabs(trn, logger, **kwargs)
+ logger.info(f'criteria[{len(self.config.criteria_token_map)}] = {list(self.config.criteria_token_map)}')
+
+ def feed_batch(self, batch: dict):
+ x, mask = TransformerTagger.feed_batch(self, batch)
+ # strip [CLS], [SEP] and [unused_i]
+ return x[:, 1:-2, :], mask
+
+ def build_samples(self, data: List[str], criteria=None, **kwargs):
+ if not criteria:
+ criteria = next(iter(self.config.criteria_token_map.keys()))
+ else:
+ assert criteria in self.config.criteria_token_map, \
+ f'Unsupported criteria {criteria}. Choose one from {list(self.config.criteria_token_map.keys())}'
+ samples = super().build_samples(data, **kwargs)
+ for sample in samples:
+ sample['criteria'] = criteria
+ return samples
+
+ def build_metric(self, **kwargs):
+ metrics = MetricDict()
+ for criteria in self.config.criteria_token_map:
+ metrics[criteria] = F1()
+ return metrics
+
+ def update_metrics(self, metric, logits, y, mask, batch, prediction):
+ for p, g, c in zip(prediction, self.tag_to_span(batch['tag']), batch['criteria']):
+ pred = set(p)
+ gold = set(g)
+ metric[c](pred, gold)
+
+ def fit(self, trn_data, dev_data, save_dir, transformer, average_subwords=False, word_dropout: float = 0.2,
+ hidden_dropout=None, layer_dropout=0, scalar_mix=None, mix_embedding: int = 0, grad_norm=5.0,
+ transformer_grad_norm=None, lr=5e-5,
+ transformer_lr=None, transformer_layers=None, gradient_accumulation=1,
+ adam_epsilon=1e-8, weight_decay=0, warmup_steps=0.1, crf=False, reduction='sum',
+ batch_size=32, sampler_builder: SamplerBuilder = None, epochs=30, patience=5, token_key=None,
+ tagging_scheme='BMES', delimiter=None,
+ max_seq_len=None, sent_delimiter=None, char_level=False, hard_constraint=False, transform=None, logger=None,
+ devices: Union[float, int, List[int]] = None, **kwargs):
+ return super().fit(**merge_locals_kwargs(locals(), kwargs))
diff --git a/hanlp/components/tokenizers/transformer.py b/hanlp/components/tokenizers/transformer.py
new file mode 100644
index 000000000..c1fd83be0
--- /dev/null
+++ b/hanlp/components/tokenizers/transformer.py
@@ -0,0 +1,256 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-08-11 02:48
+import functools
+from typing import TextIO, Union, List, Dict, Any, Set
+from bisect import bisect
+
+import torch
+
+from hanlp.common.dataset import SamplerBuilder
+from hanlp.common.transform import TransformList
+from hanlp.components.taggers.transformers.transformer_tagger import TransformerTagger
+from hanlp.datasets.tokenization.txt import TextTokenizingDataset, generate_tags_for_subtokens
+from hanlp.metrics.f1 import F1
+from hanlp.transform.transformer_tokenizer import TransformerSequenceTokenizer
+from hanlp.utils.span_util import bmes_to_spans
+from hanlp_common.util import merge_locals_kwargs
+from hanlp_trie import DictInterface, TrieDict
+
+
+class TransformerTaggingTokenizer(TransformerTagger):
+
+ def __init__(self, **kwargs) -> None:
+ """ A tokenizer using transformer tagger for span prediction. It features with 2 high performance dictionaries
+ to handle edge cases in real application.
+
+ - High priority dictionary: Perform longest-prefix-matching on input text which take higher priority over model predictions.
+ - Low priority dictionary: Perform longest-prefix-matching on model predictions and combing them.
+
+ .. Note:: For algorithm beginners, longest-prefix-matching is the prerequisite to understand what dictionary can
+ do and what it can't do. The tutorial in `this book `_ can be very helpful.
+
+ Args:
+ **kwargs: Predefined config.
+ """
+ super().__init__(**kwargs)
+
+ @property
+ def dict_force(self) -> DictInterface:
+ r""" The high priority dictionary which perform longest-prefix-matching on inputs to split them into two subsets:
+
+ 1. spans containing no keywords
+ 2. keywords
+
+ These spans are then fed into tokenizer for further tokenization.
+
+ .. Caution::
+ Longest-prefix-matching **NEVER** guarantee the presence of any keywords. Abuse of
+ ``dict_force`` can lead to low quality results. For more details, refer to
+ `this book `_.
+
+ Examples:
+ >>> tok.dict_force = {'和服', '服务行业'} # Force '和服' and '服务行业' by longest-prefix-matching
+ >>> tok("商品和服务行业")
+ ['商品', '和服', '务行业']
+ >>> tok.dict_force = {'和服务': ['和', '服务']} # Force '和服务' to be tokenized as ['和', '服务']
+ >>> tok("商品和服务行业")
+ ['商品', '和', '服务', '行业']
+ """
+ return self.config.get('dict_force', None)
+
+ @dict_force.setter
+ def dict_force(self, dictionary: Union[DictInterface, Union[Dict[str, Any], Set[str]]]):
+ if not isinstance(dictionary, DictInterface):
+ dictionary = TrieDict(dictionary)
+ self.config.dict_force = dictionary
+ self.tokenizer_transform.dict = dictionary
+
+ @property
+ def dict_combine(self) -> DictInterface:
+ """ The low priority dictionary which perform longest-prefix-matching on model predictions and combing them.
+
+ Examples:
+ >>> tok.dict_combine = {'和服', '服务行业'}
+ >>> tok("商品和服务行业")
+ ['商品', '和', '服务行业']
+
+ """
+ return self.config.get('dict_combine', None)
+
+ @dict_combine.setter
+ def dict_combine(self, dictionary: Union[DictInterface, Union[Dict[str, Any], Set[str]]]):
+ if not isinstance(dictionary, DictInterface):
+ dictionary = TrieDict(dictionary)
+ self.config.dict_combine = dictionary
+
+ def build_metric(self, **kwargs):
+ return F1()
+
+ # noinspection PyMethodOverriding
+ def update_metrics(self, metric, logits, y, mask, batch, prediction):
+ for p, g in zip(prediction, self.tag_to_span(batch['tag'], batch)):
+ pred = set(p)
+ gold = set(g)
+ metric(pred, gold)
+
+ def decode_output(self, logits, mask, batch, model=None):
+ output = super().decode_output(logits, mask, batch, model)
+ if isinstance(output, torch.Tensor):
+ output = output.tolist()
+ prediction = self.id_to_tags(output, [len(x) for x in batch['token']])
+ return self.tag_to_span(prediction, batch)
+
+ def tag_to_span(self, batch_tags, batch: dict):
+ spans = []
+ if 'custom_words' in batch:
+ if self.config.tagging_scheme == 'BMES':
+ S = 'S'
+ M = 'M'
+ E = 'E'
+ else:
+ S = 'B'
+ M = 'I'
+ E = 'I'
+ for tags, subwords, custom_words in zip(batch_tags, batch['token_subtoken_offsets'], batch['custom_words']):
+ assert len(tags) == len(subwords)
+ # [batch['raw_token'][0][x[0]:x[1]] for x in subwords]
+ if custom_words:
+ for start, end, label in custom_words:
+ if end - start == 1:
+ tags[start] = S
+ else:
+ tags[start] = 'B'
+ tags[end - 1] = E
+ for i in range(start + 1, end - 1):
+ tags[i] = M
+ if end < len(tags):
+ tags[end] = 'B'
+ spans.append(bmes_to_spans(tags))
+ else:
+ for tags in batch_tags:
+ spans.append(bmes_to_spans(tags))
+ return spans
+
+ def write_prediction(self, prediction, batch, output: TextIO):
+ batch_tokens = self.spans_to_tokens(prediction, batch)
+ for tokens in batch_tokens:
+ output.write(' '.join(tokens))
+ output.write('\n')
+
+ @property
+ def tokenizer_transform(self):
+ if not self._tokenizer_transform:
+ self._tokenizer_transform = TransformerSequenceTokenizer(self.transformer_tokenizer,
+ self.config.token_key,
+ ret_subtokens=True,
+ ret_subtokens_group=True,
+ ret_token_span=False)
+ return self._tokenizer_transform
+
+ def spans_to_tokens(self, spans, batch, rebuild_span=False):
+ batch_tokens = []
+ dict_combine = self.dict_combine
+ for spans_per_sent, sub_tokens in zip(spans, batch[self.config.token_key]):
+ tokens = [''.join(sub_tokens[span[0]:span[1]]) for span in spans_per_sent]
+ if dict_combine:
+ if rebuild_span:
+ char_to_span = []
+ offset = 0
+ for start, end in spans_per_sent:
+ char_to_span.append(offset)
+ offset += sum(len(x) for x in sub_tokens[start:end])
+ buffer = []
+ offset = 0
+ for start, end, label in dict_combine.tokenize(tokens):
+ # batch['raw_token'][0][start:end]
+ if offset < start:
+ buffer.extend(tokens[offset:start])
+ buffer.append(''.join(tokens[start:end]))
+ offset = end
+ if rebuild_span:
+ combined_span = (spans_per_sent[start][0], spans_per_sent[end - 1][1])
+ del spans_per_sent[start:end]
+ spans_per_sent.insert(start, combined_span)
+ if offset < len(tokens):
+ buffer.extend(tokens[offset:])
+ tokens = buffer
+ batch_tokens.append(tokens)
+ return batch_tokens
+
+ def generate_prediction_filename(self, tst_data, save_dir):
+ return super().generate_prediction_filename(tst_data.replace('.tsv', '.txt'), save_dir)
+
+ def prediction_to_human(self, pred, vocab, batch, rebuild_span=False):
+ return self.spans_to_tokens(pred, batch, rebuild_span)
+
+ def input_is_flat(self, tokens):
+ return isinstance(tokens, str)
+
+ def build_dataset(self, data, **kwargs):
+ return TextTokenizingDataset(data, **kwargs)
+
+ def last_transform(self):
+ return TransformList(functools.partial(generate_tags_for_subtokens, tagging_scheme=self.config.tagging_scheme),
+ super().last_transform())
+
+ def fit(self, trn_data, dev_data, save_dir, transformer, average_subwords=False, word_dropout: float = 0.2,
+ hidden_dropout=None, layer_dropout=0, scalar_mix=None, grad_norm=5.0,
+ transformer_grad_norm=None, lr=5e-5,
+ transformer_lr=None, transformer_layers=None, gradient_accumulation=1,
+ adam_epsilon=1e-8, weight_decay=0, warmup_steps=0.1, crf=False, reduction='sum',
+ batch_size=32, sampler_builder: SamplerBuilder = None, epochs=30, patience=5, token_key=None,
+ tagging_scheme='BMES', delimiter=None,
+ max_seq_len=None, sent_delimiter=None, char_level=False, hard_constraint=False, transform=None, logger=None,
+ devices: Union[float, int, List[int]] = None, **kwargs):
+ """
+
+ Args:
+ trn_data: Training set.
+ dev_data: Development set.
+ save_dir: The directory to save trained component.
+ transformer: An identifier of a pre-trained transformer.
+ average_subwords: ``True`` to average subword representations.
+ word_dropout: Dropout rate to randomly replace a subword with MASK.
+ hidden_dropout: Dropout rate applied to hidden states.
+ layer_dropout: Randomly zero out hidden states of a transformer layer.
+ scalar_mix: Layer attention.
+ grad_norm: Gradient norm for clipping.
+ transformer_grad_norm: Gradient norm for clipping transformer gradient.
+ lr: Learning rate for decoder.
+ transformer_lr: Learning for encoder.
+ transformer_layers: The number of bottom layers to use.
+ gradient_accumulation: Number of batches per update.
+ adam_epsilon: The epsilon to use in Adam.
+ weight_decay: The weight decay to use.
+ warmup_steps: The number of warmup steps.
+ crf: ``True`` to enable CRF (:cite:`lafferty2001conditional`).
+ reduction: The loss reduction used in aggregating losses.
+ batch_size: The number of samples in a batch.
+ sampler_builder: The builder to build sampler, which will override batch_size.
+ epochs: The number of epochs to train.
+ patience: The number of patience epochs before early stopping.
+ token_key: The key to tokens in dataset.
+ tagging_scheme: Either ``BMES`` or ``BI``.
+ delimiter: Delimiter between tokens used to split a line in the corpus.
+ max_seq_len: Sentences longer than ``max_seq_len`` will be split into shorter ones if possible.
+ sent_delimiter: Delimiter between sentences, like period or comma, which indicates a long sentence can
+ be split here.
+ char_level: Whether the sequence length is measured at char level.
+ hard_constraint: Whether to enforce hard length constraint on sentences. If there is no ``sent_delimiter``
+ in a sentence, it will be split at a token anyway.
+ transform: An optional transform to be applied to samples. Usually a character normalization transform is
+ passed in.
+ devices: Devices this component will live on.
+ logger: Any :class:`logging.Logger` instance.
+ seed: Random seed to reproduce this training.
+ **kwargs: Not used.
+
+ Returns:
+ Best metrics on dev set.
+ """
+ return super().fit(**merge_locals_kwargs(locals(), kwargs))
+
+ def feed_batch(self, batch: dict):
+ x, mask = super().feed_batch(batch)
+ return x[:, 1:-1, :], mask
diff --git a/hanlp/datasets/__init__.py b/hanlp/datasets/__init__.py
index 24ac73095..24338737f 100644
--- a/hanlp/datasets/__init__.py
+++ b/hanlp/datasets/__init__.py
@@ -1,8 +1,3 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-06-13 18:15
-from . import classification
-from . import cws
-from . import parsing
-from . import pos
-from . import glue
diff --git a/hanlp/datasets/classification/sentiment.py b/hanlp/datasets/classification/sentiment.py
index b08bb0171..084f10469 100644
--- a/hanlp/datasets/classification/sentiment.py
+++ b/hanlp/datasets/classification/sentiment.py
@@ -4,5 +4,5 @@
_ERNIE_TASK_DATA = 'https://ernie.bj.bcebos.com/task_data_zh.tgz#'
CHNSENTICORP_ERNIE_TRAIN = _ERNIE_TASK_DATA + 'chnsenticorp/train.tsv'
-CHNSENTICORP_ERNIE_VALID = _ERNIE_TASK_DATA + 'chnsenticorp/dev.tsv'
+CHNSENTICORP_ERNIE_DEV = _ERNIE_TASK_DATA + 'chnsenticorp/dev.tsv'
CHNSENTICORP_ERNIE_TEST = _ERNIE_TASK_DATA + 'chnsenticorp/test.tsv'
diff --git a/hanlp/datasets/coref/__init__.py b/hanlp/datasets/coref/__init__.py
new file mode 100644
index 000000000..07956f7d4
--- /dev/null
+++ b/hanlp/datasets/coref/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-07-04 13:39
\ No newline at end of file
diff --git a/hanlp/datasets/coref/conll12coref.py b/hanlp/datasets/coref/conll12coref.py
new file mode 100644
index 000000000..7e14c3436
--- /dev/null
+++ b/hanlp/datasets/coref/conll12coref.py
@@ -0,0 +1,88 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-07-04 15:33
+import collections
+import os
+from typing import Union, List, Callable, DefaultDict, Tuple, Optional, Iterator
+
+from alnlp.data.ontonotes import Ontonotes as _Ontonotes, OntonotesSentence
+from alnlp.data.util import make_coref_instance
+
+from hanlp.common.dataset import TransformableDataset
+from hanlp.utils.io_util import TimingFileIterator
+
+
+class Ontonotes(_Ontonotes):
+ def dataset_document_iterator(self, file_path: str) -> Iterator[List[OntonotesSentence]]:
+ """An iterator over CONLL formatted files which yields documents, regardless
+ of the number of document annotations in a particular file. This is useful
+ for conll data which has been preprocessed, such as the preprocessing which
+ takes place for the 2012 CONLL Coreference Resolution task.
+
+ Args:
+ file_path: str:
+
+ Returns:
+
+ """
+ open_file = TimingFileIterator(file_path)
+ conll_rows = []
+ document: List[OntonotesSentence] = []
+ for line in open_file:
+ open_file.log(f'Loading {os.path.basename(file_path)}')
+ line = line.strip()
+ if line != "" and not line.startswith("#"):
+ # Non-empty line. Collect the annotation.
+ conll_rows.append(line)
+ else:
+ if conll_rows:
+ document.append(self._conll_rows_to_sentence(conll_rows))
+ conll_rows = []
+ if line.startswith("#end document"):
+ yield document
+ document = []
+ open_file.erase()
+ if document:
+ # Collect any stragglers or files which might not
+ # have the '#end document' format for the end of the file.
+ yield document
+
+
+class CONLL12CorefDataset(TransformableDataset):
+
+ def __init__(self, data: Union[str, List], transform: Union[Callable, List] = None, cache=None,
+ max_span_width=10, max_sentences=None, remove_singleton_clusters=False) -> None:
+ self.remove_singleton_clusters = remove_singleton_clusters
+ self.max_sentences = max_sentences
+ self.max_span_width = max_span_width
+ super().__init__(data, transform, cache)
+
+ def load_file(self, filepath: str):
+ ontonotes_reader = Ontonotes()
+ for sentences in ontonotes_reader.dataset_document_iterator(filepath):
+ clusters: DefaultDict[int, List[Tuple[int, int]]] = collections.defaultdict(list)
+
+ total_tokens = 0
+ for sentence in sentences:
+ for typed_span in sentence.coref_spans:
+ # Coref annotations are on a _per sentence_
+ # basis, so we need to adjust them to be relative
+ # to the length of the document.
+ span_id, (start, end) = typed_span
+ clusters[span_id].append((start + total_tokens, end + total_tokens))
+ total_tokens += len(sentence.words)
+
+ yield self.text_to_instance([s.words for s in sentences], list(clusters.values()))
+
+ def text_to_instance(
+ self, # type: ignore
+ sentences: List[List[str]],
+ gold_clusters: Optional[List[List[Tuple[int, int]]]] = None,
+ ) -> dict:
+ return make_coref_instance(
+ sentences,
+ self.max_span_width,
+ gold_clusters,
+ self.max_sentences,
+ self.remove_singleton_clusters,
+ )
diff --git a/hanlp/datasets/cws/chunking_dataset.py b/hanlp/datasets/cws/chunking_dataset.py
new file mode 100644
index 000000000..dd497bd5a
--- /dev/null
+++ b/hanlp/datasets/cws/chunking_dataset.py
@@ -0,0 +1,48 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-06-03 18:50
+from typing import Union, List, Callable
+
+from hanlp.common.dataset import TransformableDataset
+from hanlp.utils.io_util import get_resource
+from hanlp.utils.span_util import bmes_of
+from hanlp.utils.string_util import ispunct
+
+
+class ChunkingDataset(TransformableDataset):
+
+ def __init__(self, data: Union[str, List], transform: Union[Callable, List] = None, cache=None,
+ generate_idx=None, max_seq_len=None, sent_delimiter=None) -> None:
+ if not sent_delimiter:
+ sent_delimiter = lambda x: ispunct(x)
+ elif isinstance(sent_delimiter, str):
+ sent_delimiter = set(list(sent_delimiter))
+ sent_delimiter = lambda x: x in sent_delimiter
+ self.sent_delimiter = sent_delimiter
+ self.max_seq_len = max_seq_len
+ super().__init__(data, transform, cache, generate_idx)
+
+ def load_file(self, filepath):
+ max_seq_len = self.max_seq_len
+ delimiter = self.sent_delimiter
+ for chars, tags in self._generate_chars_tags(filepath, delimiter, max_seq_len):
+ yield {'char': chars, 'tag': tags}
+
+ @staticmethod
+ def _generate_chars_tags(filepath, delimiter, max_seq_len):
+ filepath = get_resource(filepath)
+ with open(filepath, encoding='utf8') as src:
+ for text in src:
+ chars, tags = bmes_of(text, True)
+ if max_seq_len and delimiter and len(chars) > max_seq_len:
+ short_chars, short_tags = [], []
+ for idx, (char, tag) in enumerate(zip(chars, tags)):
+ short_chars.append(char)
+ short_tags.append(tag)
+ if len(short_chars) >= max_seq_len and delimiter(char):
+ yield short_chars, short_tags
+ short_chars, short_tags = [], []
+ if short_chars:
+ yield short_chars, short_tags
+ else:
+ yield chars, tags
diff --git a/hanlp/datasets/cws/ctb.py b/hanlp/datasets/cws/ctb.py
deleted file mode 100644
index e4dd1a95b..000000000
--- a/hanlp/datasets/cws/ctb.py
+++ /dev/null
@@ -1,9 +0,0 @@
-# -*- coding:utf-8 -*-
-# Author: hankcs
-# Date: 2019-12-28 22:19
-
-CTB6_CWS_HOME = 'http://file.hankcs.com/corpus/ctb6_cws.zip'
-
-CTB6_CWS_TRAIN = 'http://file.hankcs.com/corpus/ctb6_cws.zip#train.txt'
-CTB6_CWS_VALID = 'http://file.hankcs.com/corpus/ctb6_cws.zip#dev.txt'
-CTB6_CWS_TEST = 'http://file.hankcs.com/corpus/ctb6_cws.zip#test.txt'
diff --git a/hanlp/datasets/cws/ctb6.py b/hanlp/datasets/cws/ctb6.py
new file mode 100644
index 000000000..52f0a46a6
--- /dev/null
+++ b/hanlp/datasets/cws/ctb6.py
@@ -0,0 +1,12 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-12-28 22:19
+
+_CTB6_CWS_HOME = 'http://file.hankcs.com/corpus/ctb6_cws.zip'
+
+CTB6_CWS_TRAIN = _CTB6_CWS_HOME + '#train.txt'
+'''CTB6 training set.'''
+CTB6_CWS_DEV = _CTB6_CWS_HOME + '#dev.txt'
+'''CTB6 dev set.'''
+CTB6_CWS_TEST = _CTB6_CWS_HOME + '#test.txt'
+'''CTB6 test set.'''
diff --git a/hanlp/datasets/cws/multi_criteria_cws/__init__.py b/hanlp/datasets/cws/multi_criteria_cws/__init__.py
new file mode 100644
index 000000000..b5972df42
--- /dev/null
+++ b/hanlp/datasets/cws/multi_criteria_cws/__init__.py
@@ -0,0 +1,35 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-08-11 20:35
+
+_HOME = 'https://github.com/hankcs/multi-criteria-cws/archive/naive-mix.zip#data/raw/'
+
+CNC_TRAIN_ALL = _HOME + 'cnc/train-all.txt'
+CNC_TRAIN = _HOME + 'cnc/train.txt'
+CNC_DEV = _HOME + 'cnc/dev.txt'
+CNC_TEST = _HOME + 'cnc/test.txt'
+
+CTB_TRAIN_ALL = _HOME + 'ctb/train-all.txt'
+CTB_TRAIN = _HOME + 'ctb/train.txt'
+CTB_DEV = _HOME + 'ctb/dev.txt'
+CTB_TEST = _HOME + 'ctb/test.txt'
+
+SXU_TRAIN_ALL = _HOME + 'sxu/train-all.txt'
+SXU_TRAIN = _HOME + 'sxu/train.txt'
+SXU_DEV = _HOME + 'sxu/dev.txt'
+SXU_TEST = _HOME + 'sxu/test.txt'
+
+UDC_TRAIN_ALL = _HOME + 'udc/train-all.txt'
+UDC_TRAIN = _HOME + 'udc/train.txt'
+UDC_DEV = _HOME + 'udc/dev.txt'
+UDC_TEST = _HOME + 'udc/test.txt'
+
+WTB_TRAIN_ALL = _HOME + 'wtb/train-all.txt'
+WTB_TRAIN = _HOME + 'wtb/train.txt'
+WTB_DEV = _HOME + 'wtb/dev.txt'
+WTB_TEST = _HOME + 'wtb/test.txt'
+
+ZX_TRAIN_ALL = _HOME + 'zx/train-all.txt'
+ZX_TRAIN = _HOME + 'zx/train.txt'
+ZX_DEV = _HOME + 'zx/dev.txt'
+ZX_TEST = _HOME + 'zx/test.txt'
diff --git a/hanlp/datasets/cws/multi_criteria_cws/mcws_dataset.py b/hanlp/datasets/cws/multi_criteria_cws/mcws_dataset.py
new file mode 100644
index 000000000..cf0f48e48
--- /dev/null
+++ b/hanlp/datasets/cws/multi_criteria_cws/mcws_dataset.py
@@ -0,0 +1,98 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-10-21 19:11
+import os
+from typing import Union, List, Callable, Dict, Iterable
+
+from hanlp.datasets.tokenization.txt import TextTokenizingDataset
+from hanlp.utils.io_util import get_resource
+
+
+class MultiCriteriaTextTokenizingDataset(TextTokenizingDataset):
+ def __init__(self,
+ data: Union[str, List],
+ transform: Union[Callable, List] = None,
+ cache=None,
+ generate_idx=None,
+ delimiter=None,
+ max_seq_len=None,
+ sent_delimiter=None,
+ char_level=False,
+ hard_constraint=False) -> None:
+ super().__init__(data, transform, cache, generate_idx, delimiter, max_seq_len, sent_delimiter, char_level,
+ hard_constraint)
+
+ def should_load_file(self, data) -> bool:
+ return isinstance(data, (tuple, dict))
+
+ def load_file(self, filepath: Union[Iterable[str], Dict[str, str]]):
+ """Load multi-criteria corpora specified in filepath.
+
+ Args:
+ filepath: A list of files where filename is its criterion. Or a dict of filename-criterion pairs.
+
+ .. highlight:: bash
+ .. code-block:: bash
+
+ $ tree -L 2 .
+ .
+ ├── cnc
+ │ ├── dev.txt
+ │ ├── test.txt
+ │ ├── train-all.txt
+ │ └── train.txt
+ ├── ctb
+ │ ├── dev.txt
+ │ ├── test.txt
+ │ ├── train-all.txt
+ │ └── train.txt
+ ├── sxu
+ │ ├── dev.txt
+ │ ├── test.txt
+ │ ├── train-all.txt
+ │ └── train.txt
+ ├── udc
+ │ ├── dev.txt
+ │ ├── test.txt
+ │ ├── train-all.txt
+ │ └── train.txt
+ ├── wtb
+ │ ├── dev.txt
+ │ ├── test.txt
+ │ ├── train-all.txt
+ │ └── train.txt
+ └── zx
+ ├── dev.txt
+ ├── test.txt
+ ├── train-all.txt
+ └── train.txt
+
+ $ head -n 2 ctb/dev.txt
+ 上海 浦东 开发 与 法制 建设 同步
+ 新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )
+
+ """
+ for eachpath in (filepath.items() if isinstance(filepath, dict) else filepath):
+ if isinstance(eachpath, tuple):
+ criteria, eachpath = eachpath
+ eachpath = get_resource(eachpath)
+ else:
+ eachpath = get_resource(eachpath)
+ criteria = os.path.basename(os.path.dirname(eachpath))
+ for sample in super().load_file(eachpath):
+ sample['criteria'] = criteria
+ yield sample
+
+
+def append_criteria_token(sample: dict, criteria_tokens: Dict[str, int], criteria_token_map: dict) -> dict:
+ criteria = sample['criteria']
+ token = criteria_token_map.get(criteria, None)
+ if not token:
+ unused_tokens = list(criteria_tokens.keys())
+ size = len(criteria_token_map)
+ assert size + 1 < len(unused_tokens), f'No unused token available for criteria {criteria}. ' \
+ f'Current criteria_token_map = {criteria_token_map}'
+ token = criteria_token_map[criteria] = unused_tokens[size]
+ sample['token_token_type_ids'] = [0] * len(sample['token_input_ids']) + [1]
+ sample['token_input_ids'] = sample['token_input_ids'] + [criteria_tokens[token]]
+ return sample
diff --git a/hanlp/datasets/cws/sighan2005/__init__.py b/hanlp/datasets/cws/sighan2005/__init__.py
index 8af5e264a..396805809 100644
--- a/hanlp/datasets/cws/sighan2005/__init__.py
+++ b/hanlp/datasets/cws/sighan2005/__init__.py
@@ -16,7 +16,7 @@ def make(train):
full = train.replace('_90.txt', '.utf8')
logger.info(f'Splitting {full} into training set and valid set with 9:1 proportion')
valid = train.replace('90.txt', '10.txt')
- split_file(full, train=0.9, valid=0.1, test=0, names={'train': train, 'valid': valid})
+ split_file(full, train=0.9, dev=0.1, test=0, names={'train': train, 'dev': valid})
assert os.path.isfile(train), f'Failed to make {train}'
assert os.path.isfile(valid), f'Failed to make {valid}'
logger.info(f'Successfully made {train} {valid}')
diff --git a/hanlp/datasets/cws/sighan2005/as_.py b/hanlp/datasets/cws/sighan2005/as_.py
new file mode 100644
index 000000000..f1a683c97
--- /dev/null
+++ b/hanlp/datasets/cws/sighan2005/as_.py
@@ -0,0 +1,19 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-12-21 15:42
+from hanlp.datasets.cws.sighan2005 import SIGHAN2005, make
+
+SIGHAN2005_AS_DICT = SIGHAN2005 + "#" + "gold/as_training_words.utf8"
+'''Dictionary built on trainings set.'''
+SIGHAN2005_AS_TRAIN_ALL = SIGHAN2005 + "#" + "training/as_training.utf8"
+'''Full training set.'''
+SIGHAN2005_AS_TRAIN = SIGHAN2005 + "#" + "training/as_training_90.txt"
+'''Training set (first 90% of the full official training set).'''
+SIGHAN2005_AS_DEV = SIGHAN2005 + "#" + "training/as_training_10.txt"
+'''Dev set (last 10% of full official training set).'''
+SIGHAN2005_AS_TEST_INPUT = SIGHAN2005 + "#" + "testing/as_testing.utf8"
+'''Test input.'''
+SIGHAN2005_AS_TEST = SIGHAN2005 + "#" + "gold/as_testing_gold.utf8"
+'''Test set.'''
+
+make(SIGHAN2005_AS_TRAIN)
diff --git a/hanlp/datasets/cws/sighan2005/cityu.py b/hanlp/datasets/cws/sighan2005/cityu.py
new file mode 100644
index 000000000..d5429d2f8
--- /dev/null
+++ b/hanlp/datasets/cws/sighan2005/cityu.py
@@ -0,0 +1,19 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-12-21 15:42
+from hanlp.datasets.cws.sighan2005 import SIGHAN2005, make
+
+SIGHAN2005_CITYU_DICT = SIGHAN2005 + "#" + "gold/cityu_training_words.utf8"
+'''Dictionary built on trainings set.'''
+SIGHAN2005_CITYU_TRAIN_ALL = SIGHAN2005 + "#" + "training/cityu_training.utf8"
+'''Full training set.'''
+SIGHAN2005_CITYU_TRAIN = SIGHAN2005 + "#" + "training/cityu_training_90.txt"
+'''Training set (first 90% of the full official training set).'''
+SIGHAN2005_CITYU_DEV = SIGHAN2005 + "#" + "training/cityu_training_10.txt"
+'''Dev set (last 10% of full official training set).'''
+SIGHAN2005_CITYU_TEST_INPUT = SIGHAN2005 + "#" + "testing/cityu_test.utf8"
+'''Test input.'''
+SIGHAN2005_CITYU_TEST = SIGHAN2005 + "#" + "gold/cityu_test_gold.utf8"
+'''Test set.'''
+
+make(SIGHAN2005_CITYU_TRAIN)
diff --git a/hanlp/datasets/cws/sighan2005/msr.py b/hanlp/datasets/cws/sighan2005/msr.py
index 43d69b949..f3db54a72 100644
--- a/hanlp/datasets/cws/sighan2005/msr.py
+++ b/hanlp/datasets/cws/sighan2005/msr.py
@@ -4,10 +4,16 @@
from hanlp.datasets.cws.sighan2005 import SIGHAN2005, make
SIGHAN2005_MSR_DICT = SIGHAN2005 + "#" + "gold/msr_training_words.utf8"
-SIGHAN2005_MSR_TRAIN_FULL = SIGHAN2005 + "#" + "training/msr_training.utf8"
+'''Dictionary built on trainings set.'''
+SIGHAN2005_MSR_TRAIN_ALL = SIGHAN2005 + "#" + "training/msr_training.utf8"
+'''Full training set.'''
SIGHAN2005_MSR_TRAIN = SIGHAN2005 + "#" + "training/msr_training_90.txt"
-SIGHAN2005_MSR_VALID = SIGHAN2005 + "#" + "training/msr_training_10.txt"
+'''Training set (first 90% of the full official training set).'''
+SIGHAN2005_MSR_DEV = SIGHAN2005 + "#" + "training/msr_training_10.txt"
+'''Dev set (last 10% of full official training set).'''
SIGHAN2005_MSR_TEST_INPUT = SIGHAN2005 + "#" + "testing/msr_test.utf8"
+'''Test input.'''
SIGHAN2005_MSR_TEST = SIGHAN2005 + "#" + "gold/msr_test_gold.utf8"
+'''Test set.'''
make(SIGHAN2005_MSR_TRAIN)
diff --git a/hanlp/datasets/cws/sighan2005/pku.py b/hanlp/datasets/cws/sighan2005/pku.py
index 59fd8f34e..0ecfc55cb 100644
--- a/hanlp/datasets/cws/sighan2005/pku.py
+++ b/hanlp/datasets/cws/sighan2005/pku.py
@@ -4,10 +4,16 @@
from hanlp.datasets.cws.sighan2005 import SIGHAN2005, make
SIGHAN2005_PKU_DICT = SIGHAN2005 + "#" + "gold/pku_training_words.utf8"
-SIGHAN2005_PKU_TRAIN_FULL = SIGHAN2005 + "#" + "training/pku_training.utf8"
+'''Dictionary built on trainings set.'''
+SIGHAN2005_PKU_TRAIN_ALL = SIGHAN2005 + "#" + "training/pku_training.utf8"
+'''Full training set.'''
SIGHAN2005_PKU_TRAIN = SIGHAN2005 + "#" + "training/pku_training_90.txt"
-SIGHAN2005_PKU_VALID = SIGHAN2005 + "#" + "training/pku_training_10.txt"
+'''Training set (first 90% of the full official training set).'''
+SIGHAN2005_PKU_DEV = SIGHAN2005 + "#" + "training/pku_training_10.txt"
+'''Dev set (last 10% of full official training set).'''
SIGHAN2005_PKU_TEST_INPUT = SIGHAN2005 + "#" + "testing/pku_test.utf8"
+'''Test input.'''
SIGHAN2005_PKU_TEST = SIGHAN2005 + "#" + "gold/pku_test_gold.utf8"
+'''Test set.'''
make(SIGHAN2005_PKU_TRAIN)
diff --git a/hanlp/datasets/eos/__init__.py b/hanlp/datasets/eos/__init__.py
new file mode 100644
index 000000000..108e6aa85
--- /dev/null
+++ b/hanlp/datasets/eos/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-07-26 18:11
\ No newline at end of file
diff --git a/hanlp/datasets/eos/eos.py b/hanlp/datasets/eos/eos.py
new file mode 100644
index 000000000..28de02e81
--- /dev/null
+++ b/hanlp/datasets/eos/eos.py
@@ -0,0 +1,101 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-07-26 18:12
+import itertools
+from collections import Counter
+from typing import Union, List, Callable
+
+from hanlp.common.dataset import TransformableDataset
+from hanlp.utils.io_util import TimingFileIterator
+from hanlp.utils.log_util import cprint
+from hanlp.utils.string_util import ispunct
+
+
+class SentenceBoundaryDetectionDataset(TransformableDataset):
+
+ def __init__(self,
+ data: Union[str, List],
+ transform: Union[Callable, List] = None,
+ cache=None,
+ append_after_sentence=None,
+ eos_chars=None,
+ eos_char_min_freq=200,
+ eos_char_is_punct=True,
+ window_size=5,
+ **kwargs,
+ ) -> None:
+ """Dataset for sentence boundary detection (eos).
+
+ Args:
+ data: The local or remote path to a dataset, or a list of samples where each sample is a dict.
+ transform: Predefined transform(s).
+ cache: ``True`` to enable caching, so that transforms won't be called twice.
+ append_after_sentence: A :class:`str` to insert at the tail of each sentence. For example, English always
+ have a space between sentences.
+ eos_chars: Punctuations at the tail of sentences. If ``None``, then it will built from training samples.
+ eos_char_min_freq: Minimal frequency to keep a eos char.
+ eos_char_is_punct: Limit eos chars to punctuations.
+ window_size: Window size to extract ngram features.
+ kwargs: Not used.
+ """
+ self.eos_char_is_punct = eos_char_is_punct
+ self.append_after_sentence = append_after_sentence
+ self.window_size = window_size
+ self.eos_chars = eos_chars
+ self.eos_char_min_freq = eos_char_min_freq
+ super().__init__(data, transform, cache)
+
+ def load_file(self, filepath: str):
+ """Load eos corpus.
+
+ Args:
+ filepath: Path to the corpus.
+
+ .. highlight:: bash
+ .. code-block:: bash
+
+ $ head -n 2 ctb8.txt
+ 中国经济简讯
+ 新华社北京十月二十九日电中国经济简讯
+
+ """
+ f = TimingFileIterator(filepath)
+ sents = []
+ eos_offsets = []
+ offset = 0
+ for line in f:
+ if not line.strip():
+ continue
+ line = line.rstrip('\n')
+ eos_offsets.append(offset + len(line.rstrip()) - 1)
+ offset += len(line)
+ if self.append_after_sentence:
+ line += self.append_after_sentence
+ offset += len(self.append_after_sentence)
+ f.log(line)
+ sents.append(line)
+ f.erase()
+ corpus = list(itertools.chain.from_iterable(sents))
+
+ if self.eos_chars:
+ if not isinstance(self.eos_chars, set):
+ self.eos_chars = set(self.eos_chars)
+ else:
+ eos_chars = Counter()
+ for i in eos_offsets:
+ eos_chars[corpus[i]] += 1
+ self.eos_chars = set(k for (k, v) in eos_chars.most_common() if
+ v >= self.eos_char_min_freq and (not self.eos_char_is_punct or ispunct(k)))
+ cprint(f'eos_chars = [yellow]{self.eos_chars}[/yellow]')
+
+ eos_index = 0
+ eos_offsets = [i for i in eos_offsets if corpus[i] in self.eos_chars]
+ window_size = self.window_size
+ for i, c in enumerate(corpus):
+ if c in self.eos_chars:
+ window = corpus[i - window_size: i + window_size + 1]
+ label_id = 1. if eos_offsets[eos_index] == i else 0.
+ if label_id > 0:
+ eos_index += 1
+ yield {'char': window, 'label_id': label_id}
+ assert eos_index == len(eos_offsets), f'{eos_index} != {len(eos_offsets)}'
diff --git a/hanlp/datasets/eos/nn_eos.py b/hanlp/datasets/eos/nn_eos.py
new file mode 100644
index 000000000..1ca43cc59
--- /dev/null
+++ b/hanlp/datasets/eos/nn_eos.py
@@ -0,0 +1,17 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-24 22:51
+_SETIMES2_EN_HR_SENTENCES_HOME = 'https://schweter.eu/cloud/nn_eos/SETIMES2.en-hr.sentences.tar.xz'
+SETIMES2_EN_HR_HR_SENTENCES_TRAIN = _SETIMES2_EN_HR_SENTENCES_HOME + '#SETIMES2.en-hr.hr.sentences.train'
+'''Training set of SETimes corpus.'''
+SETIMES2_EN_HR_HR_SENTENCES_DEV = _SETIMES2_EN_HR_SENTENCES_HOME + '#SETIMES2.en-hr.hr.sentences.dev'
+'''Dev set of SETimes corpus.'''
+SETIMES2_EN_HR_HR_SENTENCES_TEST = _SETIMES2_EN_HR_SENTENCES_HOME + '#SETIMES2.en-hr.hr.sentences.test'
+'''Test set of SETimes corpus.'''
+_EUROPARL_V7_DE_EN_EN_SENTENCES_HOME = 'http://schweter.eu/cloud/nn_eos/europarl-v7.de-en.en.sentences.tar.xz'
+EUROPARL_V7_DE_EN_EN_SENTENCES_TRAIN = _EUROPARL_V7_DE_EN_EN_SENTENCES_HOME + '#europarl-v7.de-en.en.sentences.train'
+'''Training set of Europarl corpus (:cite:`koehn2005europarl`).'''
+EUROPARL_V7_DE_EN_EN_SENTENCES_DEV = _EUROPARL_V7_DE_EN_EN_SENTENCES_HOME + '#europarl-v7.de-en.en.sentences.dev'
+'''Dev set of Europarl corpus (:cite:`koehn2005europarl`).'''
+EUROPARL_V7_DE_EN_EN_SENTENCES_TEST = _EUROPARL_V7_DE_EN_EN_SENTENCES_HOME + '#europarl-v7.de-en.en.sentences.test'
+'''Test set of Europarl corpus (:cite:`koehn2005europarl`).'''
diff --git a/hanlp/datasets/glue.py b/hanlp/datasets/glue.py
index e53060f10..0a6d7c160 100644
--- a/hanlp/datasets/glue.py
+++ b/hanlp/datasets/glue.py
@@ -1,53 +1,25 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-11-10 11:47
-from hanlp.common.structure import SerializableDict
-from hanlp.transform.table import TableTransform
+from hanlp.common.dataset import TableDataset
STANFORD_SENTIMENT_TREEBANK_2_TRAIN = 'http://file.hankcs.com/corpus/SST2.zip#train.tsv'
-STANFORD_SENTIMENT_TREEBANK_2_VALID = 'http://file.hankcs.com/corpus/SST2.zip#dev.tsv'
+STANFORD_SENTIMENT_TREEBANK_2_DEV = 'http://file.hankcs.com/corpus/SST2.zip#dev.tsv'
STANFORD_SENTIMENT_TREEBANK_2_TEST = 'http://file.hankcs.com/corpus/SST2.zip#test.tsv'
MICROSOFT_RESEARCH_PARAPHRASE_CORPUS_TRAIN = 'http://file.hankcs.com/corpus/mrpc.zip#train.tsv'
-MICROSOFT_RESEARCH_PARAPHRASE_CORPUS_VALID = 'http://file.hankcs.com/corpus/mrpc.zip#dev.tsv'
+MICROSOFT_RESEARCH_PARAPHRASE_CORPUS_DEV = 'http://file.hankcs.com/corpus/mrpc.zip#dev.tsv'
MICROSOFT_RESEARCH_PARAPHRASE_CORPUS_TEST = 'http://file.hankcs.com/corpus/mrpc.zip#test.tsv'
-class StanfordSentimentTreebank2Transorm(TableTransform):
+class SST2Dataset(TableDataset):
pass
-class MicrosoftResearchParaphraseCorpus(TableTransform):
-
- def __init__(self, config: SerializableDict = None, map_x=False, map_y=True, x_columns=(3, 4),
- y_column=0, skip_header=True, delimiter='auto', **kwargs) -> None:
- super().__init__(config, map_x, map_y, x_columns, y_column, skip_header, delimiter, **kwargs)
-
-
def main():
- # _test_sst2()
- _test_mrpc()
-
-
-def _test_sst2():
- transform = StanfordSentimentTreebank2Transorm()
- transform.fit(STANFORD_SENTIMENT_TREEBANK_2_TRAIN)
- transform.lock_vocabs()
- transform.label_vocab.summary()
- transform.build_config()
- dataset = transform.file_to_dataset(STANFORD_SENTIMENT_TREEBANK_2_TRAIN)
- for batch in dataset.take(1):
- print(batch)
-
-def _test_mrpc():
- transform = MicrosoftResearchParaphraseCorpus()
- transform.fit(MICROSOFT_RESEARCH_PARAPHRASE_CORPUS_VALID)
- transform.lock_vocabs()
- transform.label_vocab.summary()
- transform.build_config()
- dataset = transform.file_to_dataset(MICROSOFT_RESEARCH_PARAPHRASE_CORPUS_VALID)
- for batch in dataset.take(1):
- print(batch)
+ dataset = SST2Dataset(STANFORD_SENTIMENT_TREEBANK_2_TEST)
+ print(dataset)
+
if __name__ == '__main__':
main()
diff --git a/hanlp/datasets/lm/__init__.py b/hanlp/datasets/lm/__init__.py
new file mode 100644
index 000000000..db1bd80df
--- /dev/null
+++ b/hanlp/datasets/lm/__init__.py
@@ -0,0 +1,12 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-06-05 21:41
+
+_PTB_HOME = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz#'
+PTB_TOKEN_TRAIN = _PTB_HOME + 'data/ptb.train.txt'
+PTB_TOKEN_DEV = _PTB_HOME + 'data/ptb.valid.txt'
+PTB_TOKEN_TEST = _PTB_HOME + 'data/ptb.test.txt'
+
+PTB_CHAR_TRAIN = _PTB_HOME + 'data/ptb.char.train.txt'
+PTB_CHAR_DEV = _PTB_HOME + 'data/ptb.char.valid.txt'
+PTB_CHAR_TEST = _PTB_HOME + 'data/ptb.char.test.txt'
diff --git a/hanlp/datasets/lm/lm_dataset.py b/hanlp/datasets/lm/lm_dataset.py
new file mode 100644
index 000000000..0a325d5ac
--- /dev/null
+++ b/hanlp/datasets/lm/lm_dataset.py
@@ -0,0 +1,143 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-06-05 21:42
+import os
+from typing import Union, Callable, List
+
+import torch
+
+import hanlp_common.io
+from hanlp.common.dataset import TransformSequentialDataset
+from hanlp.common.transform import ToChar, WhitespaceTokenizer, AppendEOS, FieldToIndex
+from hanlp.common.vocab import Vocab
+from hanlp.utils.io_util import file_cache, get_resource, TimingFileIterator
+from hanlp.utils.log_util import flash, ErasablePrinter
+
+
+class LanguageModelDataset(TransformSequentialDataset):
+
+ def __init__(self,
+ data: str,
+ batch_size,
+ seq_len,
+ tokenizer='char',
+ eos='\n',
+ strip=True,
+ vocab=None,
+ cache=False,
+ transform: Union[Callable, List] = None) -> None:
+ self.cache = cache
+ self.eos = eos
+ self.strip = strip
+ super().__init__(transform)
+ if isinstance(tokenizer, str):
+ available_tokenizers = {
+ 'char': ToChar('text', 'token'),
+ 'whitespace': WhitespaceTokenizer('text', 'token')
+ }
+ assert tokenizer in available_tokenizers, f'{tokenizer} not supported, available options: {available_tokenizers.keys()} '
+ self.append_transform(available_tokenizers[tokenizer])
+
+ if vocab is None:
+ vocab = Vocab()
+ self.training = True
+ else:
+ self.training = vocab.mutable
+ self.append_transform(AppendEOS('token', eos=eos))
+ self.append_transform(FieldToIndex('token', vocab))
+ self.batch_size = batch_size
+ data = get_resource(data)
+ self.data = data
+ self.num_tokens = None
+ self.load_file(data)
+ self._fp = None
+ if isinstance(seq_len, int):
+ self.seq_len = lambda: seq_len
+ else:
+ self.seq_len = seq_len
+
+ @property
+ def vocab(self):
+ return self.transform[-1].vocab
+
+ @property
+ def vocab_path(self):
+ return os.path.splitext(self.data)[0] + '.vocab.json'
+
+ def load_file(self, filepath):
+ cache, valid = file_cache(filepath, not self.cache)
+ if not valid or (self.vocab.mutable and not os.path.isfile(self.vocab_path)):
+ with open(cache, 'wb') as out:
+ tokens, lines = 0, 0
+ f = TimingFileIterator(filepath)
+ for line in f:
+ if self.strip:
+ line = line.strip()
+ if not line:
+ continue
+ sample = {'text': line}
+ sample = self.transform_sample(sample, inplace=True)
+ for id in sample['token_id']:
+ out.write((id).to_bytes(4, 'little'))
+ tokens += len(sample['token_id'])
+ lines += 1
+ f.log(f'{tokens // 1000000}M tokens, {lines // 1000000}M lines\n'
+ f'{sample["token"][:10]}')
+ f.erase()
+ if self.vocab.mutable:
+ self.vocab.lock()
+ hanlp_common.io.save_json(self.vocab_path)
+ self.num_tokens = tokens
+ else:
+ self.num_tokens = int(os.path.getsize(self.filecache) / 4)
+ if self.vocab.mutable:
+ hanlp_common.io.load_json(self.vocab_path)
+
+ def __iter__(self):
+ batch_size = self.batch_size
+ max_seq_len = self.max_seq_len
+ i = 0
+ safety = 2 if self.training else 1
+ with open(self.filecache, 'rb') as fp:
+ while i < max_seq_len - safety:
+ seq_len = self.seq_len()
+ seq_len = min(seq_len, max_seq_len - 1 - i)
+ data = []
+ for j in range(batch_size):
+ data.append(self._read_chunk(fp, max_seq_len * j + i, seq_len + 1))
+ data = torch.LongTensor(data)
+ data.transpose_(0, 1)
+ data, targets = data[:seq_len, :], data[1:, :]
+ yield data, targets.contiguous().view(-1)
+ i += seq_len
+
+ def estimate_num_batches(self, seq_len=None):
+ if not seq_len:
+ seq_len = self.seq_len()
+ return self.max_seq_len // seq_len
+
+ @property
+ def max_seq_len(self):
+ max_seq_len = self.num_tokens // self.batch_size
+ return max_seq_len
+
+ @staticmethod
+ def _read_chunk(fp, offset, length):
+ data = []
+ fp.seek(offset * 4)
+ for i in range(length):
+ id = int.from_bytes(fp.read(4), 'little')
+ data.append(id)
+ return data
+
+ def _debug_load_cache(self):
+ with open(self.filecache, 'rb') as src:
+ ids = []
+ for i in range(self.num_tokens):
+ id = int.from_bytes(src.read(4), 'little')
+ ids.append(id)
+ return torch.LongTensor(ids)
+
+ @property
+ def filecache(self):
+ return file_cache(self.data)[0]
diff --git a/hanlp/datasets/ner/conll03.py b/hanlp/datasets/ner/conll03.py
index 2927c3f85..0e0eef536 100644
--- a/hanlp/datasets/ner/conll03.py
+++ b/hanlp/datasets/ner/conll03.py
@@ -2,6 +2,10 @@
# Author: hankcs
# Date: 2019-12-06 15:31
+
CONLL03_EN_TRAIN = 'https://file.hankcs.com/corpus/conll03_en_iobes.zip#eng.train.tsv'
-CONLL03_EN_VALID = 'https://file.hankcs.com/corpus/conll03_en_iobes.zip#eng.dev.tsv'
+'''Training set of CoNLL03 (:cite:`tjong-kim-sang-de-meulder-2003-introduction`)'''
+CONLL03_EN_DEV = 'https://file.hankcs.com/corpus/conll03_en_iobes.zip#eng.dev.tsv'
+'''Dev set of CoNLL03 (:cite:`tjong-kim-sang-de-meulder-2003-introduction`)'''
CONLL03_EN_TEST = 'https://file.hankcs.com/corpus/conll03_en_iobes.zip#eng.test.tsv'
+'''Test set of CoNLL03 (:cite:`tjong-kim-sang-de-meulder-2003-introduction`)'''
diff --git a/hanlp/datasets/ner/json_ner.py b/hanlp/datasets/ner/json_ner.py
new file mode 100644
index 000000000..9428bec3f
--- /dev/null
+++ b/hanlp/datasets/ner/json_ner.py
@@ -0,0 +1,151 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-07-21 16:26
+import os
+from typing import Union, List, Callable, Dict
+
+from hanlp_common.constant import NULL
+from hanlp.common.dataset import TransformableDataset
+import json
+from alnlp.metrics import span_utils
+from hanlp.utils.io_util import TimingFileIterator, read_tsv_as_sents
+
+
+class JsonNERDataset(TransformableDataset):
+
+ def __init__(self, data: Union[str, List], transform: Union[Callable, List] = None, cache=None,
+ generate_idx=None, doc_level_offset=True, tagset=None) -> None:
+ """A dataset for ``.jsonlines`` format NER corpora.
+
+ Args:
+ data: The local or remote path to a dataset, or a list of samples where each sample is a dict.
+ transform: Predefined transform(s).
+ cache: ``True`` to enable caching, so that transforms won't be called twice.
+ generate_idx: Create a :const:`~hanlp_common.constants.IDX` field for each sample to store its order in dataset. Useful for prediction when
+ samples are re-ordered by a sampler.
+ doc_level_offset: ``True`` to indicate the offsets in ``jsonlines`` are of document level.
+ tagset: Optional tagset to prune entities outside of this tagset from datasets.
+ """
+ self.tagset = tagset
+ self.doc_level_offset = doc_level_offset
+ super().__init__(data, transform, cache, generate_idx)
+
+ def load_file(self, filepath: str):
+ """Load ``.jsonlines`` NER corpus. Samples of this corpus can be found using the following scripts.
+
+ .. highlight:: python
+ .. code-block:: python
+
+ import json
+ from hanlp_common.document import Document
+ from hanlp.datasets.srl.ontonotes5.chinese import ONTONOTES5_CONLL12_CHINESE_DEV
+ from hanlp.utils.io_util import get_resource
+
+ with open(get_resource(ONTONOTES5_CONLL12_CHINESE_DEV)) as src:
+ for line in src:
+ doc = json.loads(line)
+ print(Document(doc))
+ break
+
+ Args:
+ filepath: ``.jsonlines`` NER corpus.
+ """
+ filename = os.path.basename(filepath)
+ reader = TimingFileIterator(filepath)
+ num_docs, num_sentences = 0, 0
+ for line in reader:
+ doc = json.loads(line)
+ num_docs += 1
+ num_tokens_in_doc = 0
+ for sentence, ner in zip(doc['sentences'], doc['ner']):
+ if self.doc_level_offset:
+ ner = [(x[0] - num_tokens_in_doc, x[1] - num_tokens_in_doc, x[2]) for x in ner]
+ else:
+ ner = [(x[0], x[1], x[2]) for x in ner]
+ if self.tagset:
+ ner = [x for x in ner if x[2] in self.tagset]
+ if isinstance(self.tagset, dict):
+ ner = [(x[0], x[1], self.tagset[x[2]]) for x in ner]
+ deduplicated_srl = []
+ be_set = set()
+ for b, e, l in ner:
+ be = (b, e)
+ if be in be_set:
+ continue
+ be_set.add(be)
+ deduplicated_srl.append((b, e, l))
+ yield {
+ 'token': sentence,
+ 'ner': deduplicated_srl
+ }
+ num_sentences += 1
+ num_tokens_in_doc += len(sentence)
+ reader.log(
+ f'{filename} {num_docs} documents, {num_sentences} sentences [blink][yellow]...[/yellow][/blink]')
+ reader.erase()
+
+
+def convert_conll03_to_json(file_path):
+ dataset = []
+ num_docs = [0]
+
+ def new_doc():
+ doc_key = num_docs[0]
+ num_docs[0] += 1
+ return {
+ 'doc_key': doc_key,
+ 'sentences': [],
+ 'ner': [],
+ }
+
+ doc = new_doc()
+ offset = 0
+ for cells in read_tsv_as_sents(file_path):
+ if cells[0][0] == '-DOCSTART-' and doc['ner']:
+ dataset.append(doc)
+ doc = new_doc()
+ offset = 0
+ sentence = [x[0] for x in cells]
+ ner = [x[-1] for x in cells]
+ ner = span_utils.iobes_tags_to_spans(ner)
+ adjusted_ner = []
+ for label, (span_start, span_end) in ner:
+ adjusted_ner.append([span_start + offset, span_end + offset, label])
+ doc['sentences'].append(sentence)
+ doc['ner'].append(adjusted_ner)
+ offset += len(sentence)
+ if doc['ner']:
+ dataset.append(doc)
+ output_path = os.path.splitext(file_path)[0] + '.json'
+ with open(output_path, 'w') as out:
+ for each in dataset:
+ json.dump(each, out)
+ out.write('\n')
+
+
+def unpack_ner(sample: dict) -> dict:
+ ner: list = sample.get('ner', None)
+ if ner is not None:
+ if ner:
+ sample['begin_offset'], sample['end_offset'], sample['label'] = zip(*ner)
+ else:
+ # It's necessary to create a null label when there is no NER in the sentence for the sake of padding.
+ sample['begin_offset'], sample['end_offset'], sample['label'] = [0], [0], [NULL]
+ return sample
+
+
+def prune_ner_tagset(sample: dict, tagset: Union[set, Dict[str, str]]):
+ if 'tag' in sample:
+ pruned_tag = []
+ for tag in sample['tag']:
+ cells = tag.split('-', 1)
+ if len(cells) == 2:
+ role, ner_type = cells
+ if ner_type in tagset:
+ if isinstance(tagset, dict):
+ tag = role + '-' + tagset[ner_type]
+ else:
+ tag = 'O'
+ pruned_tag.append(tag)
+ sample['tag'] = pruned_tag
+ return sample
\ No newline at end of file
diff --git a/hanlp/datasets/ner/msra.py b/hanlp/datasets/ner/msra.py
index b5b7fe317..6019b2ba9 100644
--- a/hanlp/datasets/ner/msra.py
+++ b/hanlp/datasets/ner/msra.py
@@ -2,8 +2,33 @@
# Author: hankcs
# Date: 2019-12-28 23:13
-MSRA_NER_HOME = 'http://file.hankcs.com/corpus/msra_ner.zip'
+_MSRA_NER_HOME = 'http://file.hankcs.com/corpus/msra_ner.zip'
+_MSRA_NER_TOKEN_LEVEL_HOME = 'http://file.hankcs.com/corpus/msra_ner_token_level.zip'
-MSRA_NER_TRAIN = f'{MSRA_NER_HOME}#train.tsv'
-MSRA_NER_VALID = f'{MSRA_NER_HOME}#dev.tsv'
-MSRA_NER_TEST = f'{MSRA_NER_HOME}#test.tsv'
+MSRA_NER_CHAR_LEVEL_TRAIN = f'{_MSRA_NER_HOME}#train.tsv'
+'''Training set of MSRA (:cite:`levow-2006-third`) in character level.'''
+MSRA_NER_CHAR_LEVEL_DEV = f'{_MSRA_NER_HOME}#dev.tsv'
+'''Dev set of MSRA (:cite:`levow-2006-third`) in character level.'''
+MSRA_NER_CHAR_LEVEL_TEST = f'{_MSRA_NER_HOME}#test.tsv'
+'''Test set of MSRA (:cite:`levow-2006-third`) in character level.'''
+
+MSRA_NER_TOKEN_LEVEL_IOBES_TRAIN = f'{_MSRA_NER_TOKEN_LEVEL_HOME}#word_level.train.tsv'
+'''Training set of MSRA (:cite:`levow-2006-third`) in token level.'''
+MSRA_NER_TOKEN_LEVEL_IOBES_DEV = f'{_MSRA_NER_TOKEN_LEVEL_HOME}#word_level.dev.tsv'
+'''Dev set of MSRA (:cite:`levow-2006-third`) in token level.'''
+MSRA_NER_TOKEN_LEVEL_IOBES_TEST = f'{_MSRA_NER_TOKEN_LEVEL_HOME}#word_level.test.tsv'
+'''Test set of MSRA (:cite:`levow-2006-third`) in token level.'''
+
+MSRA_NER_TOKEN_LEVEL_SHORT_IOBES_TRAIN = f'{_MSRA_NER_TOKEN_LEVEL_HOME}#word_level.train.short.tsv'
+'''Training set of shorten (<= 128 tokens) MSRA (:cite:`levow-2006-third`) in token level.'''
+MSRA_NER_TOKEN_LEVEL_SHORT_IOBES_DEV = f'{_MSRA_NER_TOKEN_LEVEL_HOME}#word_level.dev.short.tsv'
+'''Dev set of shorten (<= 128 tokens) MSRA (:cite:`levow-2006-third`) in token level.'''
+MSRA_NER_TOKEN_LEVEL_SHORT_IOBES_TEST = f'{_MSRA_NER_TOKEN_LEVEL_HOME}#word_level.test.short.tsv'
+'''Test set of shorten (<= 128 tokens) MSRA (:cite:`levow-2006-third`) in token level.'''
+
+MSRA_NER_TOKEN_LEVEL_SHORT_JSON_TRAIN = f'{_MSRA_NER_TOKEN_LEVEL_HOME}#word_level.train.short.jsonlines'
+'''Training set of shorten (<= 128 tokens) MSRA (:cite:`levow-2006-third`) in token level and jsonlines format.'''
+MSRA_NER_TOKEN_LEVEL_SHORT_JSON_DEV = f'{_MSRA_NER_TOKEN_LEVEL_HOME}#word_level.dev.short.jsonlines'
+'''Dev set of shorten (<= 128 tokens) MSRA (:cite:`levow-2006-third`) in token level and jsonlines format.'''
+MSRA_NER_TOKEN_LEVEL_SHORT_JSON_TEST = f'{_MSRA_NER_TOKEN_LEVEL_HOME}#word_level.test.short.jsonlines'
+'''Test set of shorten (<= 128 tokens) MSRA (:cite:`levow-2006-third`) in token level and jsonlines format.'''
diff --git a/hanlp/datasets/ner/resume.py b/hanlp/datasets/ner/resume.py
new file mode 100644
index 000000000..ac7a6c6f2
--- /dev/null
+++ b/hanlp/datasets/ner/resume.py
@@ -0,0 +1,16 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-06-08 12:10
+from hanlp.common.dataset import TransformableDataset
+
+from hanlp.utils.io_util import get_resource, generate_words_tags_from_tsv
+
+_RESUME_NER_HOME = 'https://github.com/jiesutd/LatticeLSTM/archive/master.zip#'
+
+RESUME_NER_TRAIN = _RESUME_NER_HOME + 'ResumeNER/train.char.bmes'
+'''Training set of Resume in char level.'''
+RESUME_NER_DEV = _RESUME_NER_HOME + 'ResumeNER/dev.char.bmes'
+'''Dev set of Resume in char level.'''
+RESUME_NER_TEST = _RESUME_NER_HOME + 'ResumeNER/test.char.bmes'
+'''Test set of Resume in char level.'''
+
diff --git a/hanlp/datasets/ner/tsv.py b/hanlp/datasets/ner/tsv.py
new file mode 100644
index 000000000..398c8affd
--- /dev/null
+++ b/hanlp/datasets/ner/tsv.py
@@ -0,0 +1,86 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-24 23:09
+from typing import Union, List, Callable
+
+from hanlp.common.dataset import TransformableDataset
+from hanlp.utils.io_util import get_resource, generate_words_tags_from_tsv
+from hanlp.utils.string_util import split_long_sentence_into
+
+
+class TSVTaggingDataset(TransformableDataset):
+
+ def __init__(self,
+ data: Union[str, List],
+ transform: Union[Callable, List] = None,
+ cache=None,
+ generate_idx=None,
+ max_seq_len=None,
+ sent_delimiter=None,
+ char_level=False,
+ hard_constraint=False,
+ **kwargs
+ ) -> None:
+ """
+
+ Args:
+ data: The local or remote path to a dataset, or a list of samples where each sample is a dict.
+ transform: Predefined transform(s).
+ cache: ``True`` to enable caching, so that transforms won't be called twice.
+ generate_idx: Create a :const:`~hanlp_common.constants.IDX` field for each sample to store its order in dataset. Useful for prediction when
+ samples are re-ordered by a sampler.
+ max_seq_len: Sentences longer than ``max_seq_len`` will be split into shorter ones if possible.
+ sent_delimiter: Delimiter between sentences, like period or comma, which indicates a long sentence can
+ be split here.
+ char_level: Whether the sequence length is measured at char level, which is never the case for
+ lemmatization.
+ hard_constraint: Whether to enforce hard length constraint on sentences. If there is no ``sent_delimiter``
+ in a sentence, it will be split at a token anyway.
+ kwargs: Not used.
+ """
+ self.char_level = char_level
+ self.hard_constraint = hard_constraint
+ self.sent_delimiter = sent_delimiter
+ self.max_seq_len = max_seq_len
+ super().__init__(data, transform, cache, generate_idx)
+
+ def load_file(self, filepath):
+ """Load a ``.tsv`` file. A ``.tsv`` file for tagging is defined as a tab separated text file, where non-empty
+ lines have two columns for token and tag respectively, empty lines mark the end of sentences.
+
+ Args:
+ filepath: Path to a ``.tsv`` tagging file.
+
+ .. highlight:: bash
+ .. code-block:: bash
+
+ $ head eng.train.tsv
+ -DOCSTART- O
+
+ EU S-ORG
+ rejects O
+ German S-MISC
+ call O
+ to O
+ boycott O
+ British S-MISC
+ lamb O
+
+ """
+ filepath = get_resource(filepath)
+ # idx = 0
+ for words, tags in generate_words_tags_from_tsv(filepath, lower=False):
+ # idx += 1
+ # if idx % 1000 == 0:
+ # print(f'\rRead instances {idx // 1000}k', end='')
+ if self.max_seq_len:
+ start = 0
+ for short_sents in split_long_sentence_into(words, self.max_seq_len, self.sent_delimiter,
+ char_level=self.char_level,
+ hard_constraint=self.hard_constraint):
+ end = start + len(short_sents)
+ yield {'token': short_sents, 'tag': tags[start:end]}
+ start = end
+ else:
+ yield {'token': words, 'tag': tags}
+ # print('\r', end='')
diff --git a/hanlp/datasets/ner/weibo.py b/hanlp/datasets/ner/weibo.py
new file mode 100644
index 000000000..5c18e62b8
--- /dev/null
+++ b/hanlp/datasets/ner/weibo.py
@@ -0,0 +1,15 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-06-03 23:33
+from hanlp.common.dataset import TransformableDataset
+
+from hanlp.utils.io_util import get_resource, generate_words_tags_from_tsv
+
+_WEIBO_NER_HOME = 'https://github.com/hltcoe/golden-horse/archive/master.zip#data/'
+
+WEIBO_NER_TRAIN = _WEIBO_NER_HOME + 'weiboNER_2nd_conll.train'
+'''Training set of Weibo in char level.'''
+WEIBO_NER_DEV = _WEIBO_NER_HOME + 'weiboNER_2nd_conll.dev'
+'''Dev set of Weibo in char level.'''
+WEIBO_NER_TEST = _WEIBO_NER_HOME + 'weiboNER_2nd_conll.test'
+'''Test set of Weibo in char level.'''
diff --git a/hanlp/datasets/parsing/__init__.py b/hanlp/datasets/parsing/__init__.py
index 9fc22973f..49d49520b 100644
--- a/hanlp/datasets/parsing/__init__.py
+++ b/hanlp/datasets/parsing/__init__.py
@@ -1,5 +1,3 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-12-28 00:51
-from . import ctb
-from . import semeval2016
diff --git a/hanlp/datasets/parsing/_ctb_utils.py b/hanlp/datasets/parsing/_ctb_utils.py
new file mode 100644
index 000000000..dc783f3ed
--- /dev/null
+++ b/hanlp/datasets/parsing/_ctb_utils.py
@@ -0,0 +1,306 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-11-25 16:14
+import os
+import shutil
+import sys
+from collections import defaultdict
+from os import listdir
+from os.path import join, isfile
+from typing import List
+
+from phrasetree.tree import Tree
+
+from hanlp.components.parsers.conll import read_conll
+from hanlp.utils.io_util import get_resource, get_exitcode_stdout_stderr, read_tsv_as_sents, run_cmd, pushd
+from hanlp.utils.log_util import cprint
+from hanlp.utils.time_util import CountdownTimer
+
+
+# See Shao et al., 2017
+# CTB9_ACADEMIA_SPLITS = {
+# 'train': '''
+# 0044-0143, 0170-0270, 0400-0899,
+# 1001-1017, 1019, 1021-1035, 1037-
+# 1043, 1045-1059, 1062-1071, 1073-
+# 1117, 1120-1131, 1133-1140, 1143-
+# 1147, 1149-1151, 2000-2915, 4051-
+# 4099, 4112-4180, 4198-4368, 5000-
+# 5446, 6000-6560, 7000-7013
+# ''',
+# 'dev': '''
+# 0301-0326, 2916-3030, 4100-4106,
+# 4181-4189, 4369-4390, 5447-5492,
+# 6561-6630, 7013-7014
+# ''',
+# 'test': '''
+# 0001-0043, 0144-0169, 0271-0301,
+# 0900-0931, 1018, 1020, 1036, 1044,
+# 1060, 1061, 1072, 1118, 1119, 1132,
+# 1141, 1142, 1148, 3031-3145, 4107-
+# 4111, 4190-4197, 4391-4411, 5493-
+# 5558, 6631-6700, 7015-7017
+# '''
+# }
+#
+#
+# def _make_splits(splits: Dict[str, str]):
+# total = set()
+# for part, text in list(splits.items()):
+# if not isinstance(text, str):
+# continue
+# lines = text.replace('\n', '').split()
+# cids = set()
+# for line in lines:
+# for each in line.split(','):
+# each = each.strip()
+# if not each:
+# continue
+# if '-' in each:
+# start, end = each.split('-')
+# start, end = map(lambda x: int(x), [start, end])
+# cids.update(range(start, end + 1))
+# # cids.update(map(lambda x: f'{x:04d}', range(start, end)))
+# else:
+# cids.add(int(each))
+# cids = set(f'{x:04d}' for x in cids)
+# assert len(cids & total) == 0, f'Overlap found in {part}'
+# splits[part] = cids
+#
+# return splits
+#
+#
+# _make_splits(CTB9_ACADEMIA_SPLITS)
+
+
+def convert_to_stanford_dependency_330(src, dst):
+ cprint(f'Converting {os.path.basename(src)} to {os.path.basename(dst)} using Stanford Parser Version 3.3.0. '
+ f'It might take a while [blink][yellow]...[/yellow][/blink]')
+ sp_home = 'https://nlp.stanford.edu/software/stanford-parser-full-2013-11-12.zip'
+ sp_home = get_resource(sp_home)
+ # jar_path = get_resource(f'{sp_home}#stanford-parser.jar')
+ code, out, err = get_exitcode_stdout_stderr(
+ f'java -cp {sp_home}/* edu.stanford.nlp.trees.international.pennchinese.ChineseGrammaticalStructure '
+ f'-basic -keepPunct -conllx '
+ f'-treeFile {src}')
+ with open(dst, 'w') as f:
+ f.write(out)
+ if code:
+ raise RuntimeError(f'Conversion failed with code {code} for {src}. The err message is:\n {err}\n'
+ f'Do you have java installed? Do you have enough memory?')
+
+
+def clean_ctb_bracketed(ctb_root, out_root):
+ os.makedirs(out_root, exist_ok=True)
+ ctb_root = join(ctb_root, 'bracketed')
+ chtbs = _list_treebank_root(ctb_root)
+ timer = CountdownTimer(len(chtbs))
+ for f in chtbs:
+ with open(join(ctb_root, f), encoding='utf-8') as src, open(join(out_root, f + '.txt'), 'w',
+ encoding='utf-8') as out:
+ for line in src:
+ if not line.strip().startswith('<'):
+ out.write(line)
+ timer.log('Cleaning up CTB [blink][yellow]...[/yellow][/blink]', erase=False)
+
+
+def _list_treebank_root(ctb_root):
+ chtbs = [f for f in listdir(ctb_root) if isfile(join(ctb_root, f)) and f.startswith('chtb')]
+ return sorted(chtbs)
+
+
+def list_treebank(ctb_home):
+ ctb_home = get_resource(ctb_home)
+ cleaned_root = join(ctb_home, 'cleaned_bracket')
+ return _list_treebank_root(cleaned_root)
+
+
+def load_bracketed_trees(chtbs) -> List[Tree]:
+ trees = []
+ for f in chtbs:
+ with open(f, encoding='utf-8') as src:
+ content = src.read()
+ trees = [x for x in content.split('\n\n') if x.strip()]
+ for tree in trees:
+ tree = Tree.fromstring(tree)
+ trees.append(tree)
+ return trees
+
+
+def split_str_to_trees(text: str):
+ trees = []
+ buffer = []
+ for line in text.split('\n'):
+ if not line.strip():
+ continue
+ if line.startswith('('):
+ if buffer:
+ trees.append('\n'.join(buffer).strip())
+ buffer = []
+ buffer.append(line)
+ if buffer:
+ trees.append('\n'.join(buffer).strip())
+ return trees
+
+
+def make_ctb_tasks(chtbs, out_root, part):
+ for task in ['cws', 'pos', 'par', 'dep']:
+ os.makedirs(join(out_root, task), exist_ok=True)
+ timer = CountdownTimer(len(chtbs))
+ par_path = join(out_root, 'par', f'{part}.txt')
+ with open(join(out_root, 'cws', f'{part}.txt'), 'w', encoding='utf-8') as cws, \
+ open(join(out_root, 'pos', f'{part}.tsv'), 'w', encoding='utf-8') as pos, \
+ open(par_path, 'w', encoding='utf-8') as par:
+ for f in chtbs:
+ with open(f, encoding='utf-8') as src:
+ content = src.read()
+ trees = split_str_to_trees(content)
+ for tree in trees:
+ try:
+ tree = Tree.fromstring(tree)
+ except ValueError:
+ print(tree)
+ exit(1)
+ words = []
+ for word, tag in tree.pos():
+ if tag == '-NONE-' or not tag:
+ continue
+ tag = tag.split('-')[0]
+ if tag == 'X': # 铜_NN 30_CD x_X 25_CD x_X 14_CD cm_NT 1999_NT
+ tag = 'FW'
+ pos.write('{}\t{}\n'.format(word, tag))
+ words.append(word)
+ cws.write(' '.join(words))
+ par.write(tree.pformat(margin=sys.maxsize))
+ for fp in cws, pos, par:
+ fp.write('\n')
+ timer.log(f'Preprocesing the [blue]{part}[/blue] set of CTB [blink][yellow]...[/yellow][/blink]',
+ erase=False)
+ remove_all_ec(par_path)
+ dep_path = join(out_root, 'dep', f'{part}.conllx')
+ convert_to_stanford_dependency_330(par_path, dep_path)
+ sents = list(read_conll(dep_path))
+ with open(dep_path, 'w') as out:
+ for sent in sents:
+ for i, cells in enumerate(sent):
+ tag = cells[3]
+ tag = tag.split('-')[0] # NT-SHORT ---> NT
+ if tag == 'X': # 铜_NN 30_CD x_X 25_CD x_X 14_CD cm_NT 1999_NT
+ tag = 'FW'
+ cells[3] = cells[4] = tag
+ out.write('\t'.join(str(x) for x in cells))
+ out.write('\n')
+ out.write('\n')
+
+
+def reverse_splits(splits):
+ cid_domain = dict()
+ for domain, cids in splits.items():
+ for each in cids:
+ cid_domain[each] = domain
+ return cid_domain
+
+
+def split_chtb(chtbs: List[str], splits=None):
+ train, dev, test = [], [], []
+ unused = []
+ for each in chtbs:
+ name, domain, ext = each.split('.', 2)
+ _, cid = name.split('_')
+ if splits:
+ if cid in splits['train']:
+ bin = train
+ elif cid in splits['dev']:
+ bin = dev
+ elif cid in splits['test']:
+ bin = test
+ else:
+ bin = unused
+ # raise IOError(f'{name} not in any splits')
+ else:
+ bin = train
+ if name.endswith('8'):
+ bin = dev
+ elif name.endswith('9'):
+ bin = test
+ bin.append(each)
+ return train, dev, test
+
+
+def id_of_chtb(each: str):
+ return int(each.split('.')[0].split('_')[-1])
+
+
+def make_ctb(ctb_home):
+ ctb_home = get_resource(ctb_home)
+ cleaned_root = join(ctb_home, 'cleaned_bracket')
+ if not os.path.isdir(cleaned_root):
+ clean_ctb_bracketed(ctb_home, cleaned_root)
+ tasks_root = join(ctb_home, 'tasks')
+ if not os.path.isdir(tasks_root):
+ try:
+ chtbs = _list_treebank_root(cleaned_root)
+ print(f'For the {len(chtbs)} files in CTB, we apply the following splits:')
+ train, dev, test = split_chtb(chtbs)
+ for part, name in zip([train, dev, test], ['train', 'dev', 'test']):
+ print(f'{name} = {[id_of_chtb(x) for x in part]}')
+ cprint('[yellow]Each file id ending with 8/9 is put into '
+ 'dev/test respectively, the rest are put into train. '
+ 'Our splits ensure files are evenly split across each genre, which is recommended '
+ 'for production systems.[/yellow]')
+ for part, name in zip([train, dev, test], ['train', 'dev', 'test']):
+ make_ctb_tasks([join(cleaned_root, x) for x in part], tasks_root, name)
+ cprint('Done pre-processing CTB. Enjoy your research with [blue]HanLP[/blue]!')
+ except Exception as e:
+ shutil.rmtree(tasks_root, ignore_errors=True)
+ raise e
+
+
+def load_domains(ctb_home):
+ """
+ Load file ids from a Chinese treebank grouped by domains.
+
+ Args:
+ ctb_home: Root path to CTB.
+
+ Returns:
+ A dict of sets, each represents a domain.
+ """
+ ctb_home = get_resource(ctb_home)
+ ctb_root = join(ctb_home, 'bracketed')
+ chtbs = _list_treebank_root(ctb_root)
+ domains = defaultdict(set)
+ for each in chtbs:
+ name, domain = each.split('.')
+ _, fid = name.split('_')
+ domains[domain].add(fid)
+ return domains
+
+
+def ctb_pos_to_text_format(path, delimiter='_'):
+ """
+ Convert ctb pos tagging corpus from tsv format to text format, where each word is followed by
+ its pos tag.
+ Args:
+ path: File to be converted.
+ delimiter: Delimiter between word and tag.
+ """
+ path = get_resource(path)
+ name, ext = os.path.splitext(path)
+ with open(f'{name}.txt', 'w', encoding='utf-8') as out:
+ for sent in read_tsv_as_sents(path):
+ out.write(' '.join([delimiter.join(x) for x in sent]))
+ out.write('\n')
+
+
+def remove_all_ec(path):
+ """
+ Remove empty categories for all trees in this file and save them into a "noempty" file.
+
+ Args:
+ path: File path.
+ """
+ script = get_resource('https://file.hankcs.com/bin/remove_ec.zip')
+ with pushd(script):
+ run_cmd(f'java -cp elit-ddr-0.0.5-SNAPSHOT.jar:elit-sdk-0.0.5-SNAPSHOT.jar:hanlp-1.7.8.jar:'
+ f'fastutil-8.1.1.jar:. demo.RemoveEmptyCategoriesTreebank {path}')
diff --git a/hanlp/datasets/parsing/amr.py b/hanlp/datasets/parsing/amr.py
new file mode 100644
index 000000000..8110c4a2a
--- /dev/null
+++ b/hanlp/datasets/parsing/amr.py
@@ -0,0 +1,390 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-08-18 17:47
+from collections import defaultdict
+from copy import copy
+from typing import List
+
+import numpy as np
+import torch
+
+
+from hanlp_common.constant import CLS
+from hanlp.common.dataset import TransformableDataset, PadSequenceDataLoader
+from hanlp.common.transform import VocabDict
+from hanlp.common.vocab import VocabWithFrequency
+from hanlp.components.amr.amr_parser.amrio import AMRIO
+from hanlp.components.amr.amr_parser.data import END, DUM, list_to_tensor, lists_of_string_to_tensor, NIL, REL
+from hanlp.components.amr.amr_parser.transformer import SelfAttentionMask
+from hanlp.transform.transformer_tokenizer import TransformerSequenceTokenizer
+from hanlp_common.util import merge_list_of_dict
+
+
+class AbstractMeaningRepresentationDataset(TransformableDataset):
+ def load_file(self, filepath: str):
+ for tok, lem, pos, ner, amr in AMRIO.read(filepath):
+ yield {'token': tok, 'lemma': lem, 'pos': pos, 'ner': ner, 'amr': amr}
+
+
+def generate_oracle(sample: dict):
+ amr = sample.get('amr', None)
+ if amr:
+ concept, edge, _ = amr.root_centered_sort()
+ sample['concept'] = concept
+ sample['edge'] = edge
+ return sample
+
+
+def chars_for_tok(sample: dict, max_string_len=20):
+ token = sample['token']
+ chars = []
+ for each in token:
+ each = each[:max_string_len]
+ chars.append([CLS] + list(each) + [END])
+ sample['word_char'] = chars
+ return sample
+
+
+def append_bos(sample: dict):
+ for key in ['token', 'lemma', 'pos', 'ner']:
+ if key in sample:
+ sample[key] = [CLS] + sample[key]
+ return sample
+
+
+def get_concepts(sample: dict, vocab: VocabWithFrequency = None, rel_vocab: VocabWithFrequency = None):
+ lem, tok = sample['lemma'], sample['token']
+ cp_seq, mp_seq = [], []
+ new_tokens = set()
+ for le, to in zip(lem, tok):
+ cp_seq.append(le + '_')
+ mp_seq.append(le)
+
+ for cp, mp in zip(cp_seq, mp_seq):
+ if vocab.get_idx(cp) == vocab.unk_idx:
+ new_tokens.add(cp)
+ if vocab.get_idx(mp) == vocab.unk_idx:
+ new_tokens.add(mp)
+ nxt = len(vocab)
+ token2idx, idx2token = dict(), dict()
+ if rel_vocab:
+ new_tokens = rel_vocab.idx_to_token + sorted(new_tokens)
+ else:
+ new_tokens = sorted(new_tokens)
+ for x in new_tokens:
+ token2idx[x] = nxt
+ idx2token[nxt] = x
+ nxt += 1
+ for k, v in zip(['cp_seq', 'mp_seq', 'token2idx', 'idx2token'], [cp_seq, mp_seq, token2idx, idx2token]):
+ sample[k] = v
+ return sample
+
+
+def batchify(data, vocabs: VocabDict, unk_rate=0., device=None, squeeze=False,
+ tokenizer: TransformerSequenceTokenizer = None, shuffle_sibling=True,
+ levi_graph=False, extra_arc=False, bart=False):
+ rel_vocab: VocabWithFrequency = vocabs.rel
+ _tok = list_to_tensor(data['token'], vocabs['token'], unk_rate=unk_rate) if 'token' in vocabs else None
+ _lem = list_to_tensor(data['lemma'], vocabs['lemma'], unk_rate=unk_rate)
+ _pos = list_to_tensor(data['pos'], vocabs['pos'], unk_rate=unk_rate) if 'pos' in vocabs else None
+ _ner = list_to_tensor(data['ner'], vocabs['ner'], unk_rate=unk_rate) if 'ner' in vocabs else None
+ _word_char = lists_of_string_to_tensor(data['token'], vocabs['word_char']) if 'word_char' in vocabs else None
+
+ local_token2idx = data['token2idx']
+ local_idx2token = data['idx2token']
+ _cp_seq = list_to_tensor(data['cp_seq'], vocabs['predictable_concept'], local_token2idx)
+ _mp_seq = list_to_tensor(data['mp_seq'], vocabs['predictable_concept'], local_token2idx)
+
+ ret = copy(data)
+ if 'amr' in data:
+ concept, edge = [], []
+ for amr in data['amr']:
+ if levi_graph == 'kahn':
+ concept_i, edge_i = amr.to_levi(rel_vocab.get_frequency, shuffle=shuffle_sibling)
+ else:
+ concept_i, edge_i, _ = amr.root_centered_sort(rel_vocab.get_frequency, shuffle=shuffle_sibling)
+ concept.append(concept_i)
+ edge.append(edge_i)
+ if levi_graph is True:
+ concept_with_rel, edge_with_rel = levi_amr(concept, edge, extra_arc=extra_arc)
+ concept = concept_with_rel
+ edge = edge_with_rel
+
+ augmented_concept = [[DUM] + x + [END] for x in concept]
+
+ _concept_in = list_to_tensor(augmented_concept, vocabs.get('concept_and_rel', vocabs['concept']),
+ unk_rate=unk_rate)[:-1]
+ _concept_char_in = lists_of_string_to_tensor(augmented_concept, vocabs['concept_char'])[:-1]
+ _concept_out = list_to_tensor(augmented_concept, vocabs['predictable_concept'], local_token2idx)[1:]
+
+ out_conc_len, bsz = _concept_out.shape
+ _rel = np.full((1 + out_conc_len, bsz, out_conc_len), rel_vocab.pad_idx)
+ # v: [, concept_0, ..., concept_l, ..., concept_{n-1}, ] u: [, concept_0, ..., concept_l, ..., concept_{n-1}]
+
+ for bidx, (x, y) in enumerate(zip(edge, concept)):
+ for l, _ in enumerate(y):
+ if l > 0:
+ # l=1 => pos=l+1=2
+ _rel[l + 1, bidx, 1:l + 1] = rel_vocab.get_idx(NIL)
+ for v, u, r in x:
+ if levi_graph:
+ r = 1
+ else:
+ r = rel_vocab.get_idx(r)
+ assert v > u, 'Invalid typological order'
+ _rel[v + 1, bidx, u + 1] = r
+ ret.update(
+ {'concept_in': _concept_in, 'concept_char_in': _concept_char_in, 'concept_out': _concept_out, 'rel': _rel})
+ else:
+ augmented_concept = None
+
+ token_length = ret.get('token_length', None)
+ if token_length is not None and not isinstance(token_length, torch.Tensor):
+ ret['token_length'] = torch.tensor(token_length, dtype=torch.long, device=device if (
+ isinstance(device, torch.device) or device >= 0) else 'cpu:0')
+ ret.update({'lem': _lem, 'tok': _tok, 'pos': _pos, 'ner': _ner, 'word_char': _word_char,
+ 'copy_seq': np.stack([_cp_seq, _mp_seq], -1), 'local_token2idx': local_token2idx,
+ 'local_idx2token': local_idx2token})
+ if squeeze:
+ token_field = make_batch_for_squeeze(data, augmented_concept, tokenizer, device, ret)
+ else:
+ token_field = 'token'
+ subtoken_to_tensor(token_field, ret)
+ if bart:
+ make_batch_for_bart(augmented_concept, ret, tokenizer, device)
+ move_dict_to_device(ret, device)
+
+ return ret
+
+
+def make_batch_for_bart(augmented_concept, ret, tokenizer, device, training=True):
+ token_field = 'concept'
+ tokenizer = TransformerSequenceTokenizer(tokenizer.tokenizer, token_field, cls_is_bos=True, sep_is_eos=None)
+ encodings = [tokenizer({token_field: x[:-1] if training else x}) for x in augmented_concept]
+ ret.update(merge_list_of_dict(encodings))
+ decoder_mask = []
+ max_seq_len = len(max(ret['concept_input_ids'], key=len))
+ last_concept_offset = []
+ for spans, concepts in zip(ret['concept_token_span'], augmented_concept):
+ mask = ~SelfAttentionMask.get_mask(max_seq_len, device, ret_parameter=False)
+ for group in spans:
+ for i in range(len(group)):
+ for j in range(i + 1, len(group)):
+ mask[group[i], group[j]] = True
+ decoder_mask.append(mask)
+ last_concept_offset.append(len(concepts) - 1)
+ ret['decoder_mask'] = torch.stack(decoder_mask)
+ if not training:
+ ret['last_concept_offset'] = torch.tensor(last_concept_offset, device=device, dtype=torch.long)
+ subtoken_to_tensor(token_field, ret)
+
+
+def levi_amr(concept, edge, extra_arc=False):
+ concept_with_rel = []
+ edge_with_rel = []
+ for bidx, (edge_i, concept_i) in enumerate(zip(edge, concept)):
+ concept_i, edge_i = linearize(concept_i, edge_i, NIL, prefix=REL, extra_arc=extra_arc)
+ # This is a undirectional graph, so we can safely reverse edge
+ edge_i = [tuple(reversed(sorted(x[:2]))) + x[2:] for x in edge_i]
+ concept_with_rel.append(concept_i)
+ edge_with_rel.append(edge_i)
+ return concept_with_rel, edge_with_rel
+
+
+def move_dict_to_device(ret, device):
+ if device == -1:
+ device = 'cpu:0'
+ for k, v in ret.items():
+ if isinstance(v, np.ndarray):
+ ret[k] = torch.tensor(v, device=device).contiguous()
+ elif isinstance(v, torch.Tensor):
+ ret[k] = v.to(device).contiguous()
+
+
+def subtoken_to_tensor(token_field, ret):
+ token_input_ids = PadSequenceDataLoader.pad_data(ret[f'{token_field}_input_ids'], 0, torch.long)
+ token_token_span = PadSequenceDataLoader.pad_data(ret[f'{token_field}_token_span'], 0, torch.long)
+ ret.update({f'{token_field}_token_span': token_token_span, f'{token_field}_input_ids': token_input_ids})
+
+
+def make_batch_for_squeeze(data, augmented_concept, tokenizer, device, ret):
+ token_field = 'token_and_concept'
+ attention_mask = []
+ token_and_concept = [t + [tokenizer.sep_token] + c for t, c in zip(data['token'], augmented_concept)]
+ encodings = [tokenizer({token_field: x}) for x in token_and_concept]
+ ret.update(merge_list_of_dict(encodings))
+ max_input_len = len(max(ret[f'{token_field}_input_ids'], key=len))
+ concept_mask = []
+ token_mask = []
+ token_type_ids = []
+ snt_len = []
+ last_concept_offset = []
+ for tokens, concepts, input_ids, spans in zip(data['token'], augmented_concept,
+ ret['token_and_concept_input_ids'],
+ ret['token_and_concept_token_span']):
+ raw_sent_len = len(tokens) + 1 # for [SEP]
+ raw_concept_len = len(concepts)
+ if concepts[-1] == END:
+ concept_mask.append([False] * raw_sent_len + [True] * (raw_concept_len - 1) + [False]) # skip END concept
+ else:
+ concept_mask.append([False] * raw_sent_len + [True] * raw_concept_len)
+ token_mask.append([False] + [True] * (raw_sent_len - 2) + [False] * (raw_concept_len + 1))
+ assert len(concept_mask) == len(token_mask)
+ snt_len.append(raw_sent_len - 2) # skip [CLS] and [SEP]
+ sent_len = input_ids.index(tokenizer.tokenizer.sep_token_id) + 1
+ concept_len = len(input_ids) - sent_len
+ mask = torch.zeros((max_input_len, max_input_len), dtype=torch.bool)
+ mask[:sent_len + concept_len, :sent_len] = True
+ bottom_right = ~SelfAttentionMask.get_mask(concept_len, device, ret_parameter=False)
+ mask[sent_len:sent_len + concept_len, sent_len:sent_len + concept_len] = bottom_right
+ for group in spans:
+ if group[0] >= sent_len:
+ for i in range(len(group)):
+ for j in range(i + 1, len(group)):
+ mask[group[i], group[j]] = True
+ attention_mask.append(mask)
+ _token_type_ids = [0] * sent_len + [1] * concept_len
+ token_type_ids.append(_token_type_ids)
+ assert len(input_ids) == len(_token_type_ids)
+ last_concept_offset.append(raw_concept_len - 1)
+ ret['attention_mask'] = torch.stack(attention_mask)
+ ret['concept_mask'] = PadSequenceDataLoader.pad_data(concept_mask, 0, torch.bool)
+ ret['token_mask'] = PadSequenceDataLoader.pad_data(token_mask, 0, torch.bool)
+ ret['token_type_ids'] = PadSequenceDataLoader.pad_data(token_type_ids, 0, torch.long)
+ ret['snt_len'] = PadSequenceDataLoader.pad_data(snt_len, 0, torch.long)
+ ret['last_concept_offset'] = PadSequenceDataLoader.pad_data(last_concept_offset, 0, torch.long)
+ return token_field
+
+
+def linearize(concept: List, edge: List, label='', prefix=REL, extra_arc=False):
+ vur = defaultdict(dict)
+ for v, u, r in edge:
+ vur[v][u] = r
+ concept_with_rel = []
+ edge_with_rel = []
+ reorder = dict()
+ for v, c in enumerate(concept):
+ reorder[v] = len(concept_with_rel)
+ concept_with_rel.append(c)
+ ur = vur[v]
+ for u, r in ur.items():
+ if u < v:
+ concept_with_rel.append(prefix + r)
+ for k, v in reorder.items():
+ assert concept[k] == concept_with_rel[v]
+ for v, c in enumerate(concept):
+ ur = vur[v]
+ for i, (u, r) in enumerate(ur.items()):
+ if u < v:
+ _v = reorder[v]
+ _u = reorder[u]
+ _m = _v + i + 1
+ edge_with_rel.append((_v, _m, label))
+ edge_with_rel.append((_m, _u, label))
+ if extra_arc:
+ edge_with_rel.append((_v, _u, label))
+ return concept_with_rel, edge_with_rel
+
+
+def unlinearize(concept: List, edge: List, prefix=REL, extra_arc=False):
+ real_concept, reorder = separate_concept_rel(concept, prefix)
+ if extra_arc:
+ edge = [x for x in edge if concept[x[0]].startswith(REL) or concept[x[1]].startswith(REL)]
+ real_edge = []
+ for f, b in zip(edge[::2], edge[1::2]):
+ if b[1] not in reorder:
+ continue
+ u = reorder[b[1]]
+ if f[0] not in reorder:
+ continue
+ v = reorder[f[0]]
+ r = concept[f[1]][len(prefix):]
+ real_edge.append((v, u, r))
+ return real_concept, real_edge
+
+
+def separate_concept_rel(concept, prefix=REL):
+ reorder = dict()
+ real_concept = []
+ for i, c in enumerate(concept):
+ if not c.startswith(prefix):
+ reorder[i] = len(real_concept)
+ real_concept.append(c)
+ return real_concept, reorder
+
+
+def remove_unconnected_components(concept: List, edge: List):
+ from scipy.sparse import csr_matrix
+ from scipy.sparse.csgraph._traversal import connected_components
+ row = np.array([x[0] for x in edge], dtype=np.int)
+ col = np.array([x[1] for x in edge], dtype=np.int)
+ data = np.ones(len(row), dtype=np.int)
+ graph = csr_matrix((data, (row, col)), shape=(len(concept), len(concept)))
+ n_components, labels = connected_components(csgraph=graph, directed=True, return_labels=True)
+ if n_components > 1:
+ unique, counts = np.unique(labels, return_counts=True)
+ largest_component = max(zip(counts, unique))[-1]
+ connected_nodes = set(np.where(labels == largest_component)[0])
+ reorder = dict()
+ good_concept = []
+ good_edge = []
+ for i, c in enumerate(concept):
+ if i in connected_nodes:
+ reorder[i] = len(good_concept)
+ good_concept.append(c)
+ for v, u, r in edge:
+ if v in connected_nodes and u in connected_nodes:
+ good_edge.append((reorder[v], reorder[u], r))
+ concept, edge = good_concept, good_edge
+ return concept, edge
+
+
+def largest_connected_component(triples: List):
+ node_to_id = dict()
+ concept = []
+ edge = []
+ for u, r, v in triples:
+ if u not in node_to_id:
+ node_to_id[u] = len(node_to_id)
+ concept.append(u)
+ if v not in node_to_id:
+ node_to_id[v] = len(node_to_id)
+ concept.append(v)
+ edge.append((node_to_id[u], node_to_id[v], r))
+ concept, edge = remove_unconnected_components(concept, edge)
+ return concept, edge
+
+
+def to_triples(concept: List, edge: List):
+ return [(concept[u], r, concept[v]) for u, v, r in edge]
+
+
+def reverse_edge_for_levi_bfs(concept, edge):
+ for v, u, r in edge:
+ if r == '_reverse_':
+ for x in v, u:
+ if concept[x].startswith(REL) and not concept[x].endswith('_reverse_'):
+ concept[x] += '_reverse_'
+
+
+def un_kahn(concept, edge):
+ # (['want', 'rel=ARG1', 'rel=ARG0', 'believe', 'rel=ARG1', 'rel=ARG0', 'boy', 'girl'],
+ # [(0, 1, 0.9999417066574097), (0, 2, 0.9999995231628418), (1, 3, 0.9999992847442627), (3, 4, 1.0), (3, 5, 0.9999996423721313), (2, 6, 0.9996106624603271), (4, 6, 0.9999767541885376), (5, 7, 0.9999860525131226)])
+ real_concept, reorder = separate_concept_rel(concept)
+ tri_edge = dict()
+ for m, (a, b, p1) in enumerate(edge):
+ if concept[a].startswith(REL):
+ continue
+ for n, (c, d, p2) in enumerate(edge[m + 1:]):
+ if b == c:
+ key = (a, d)
+ _, p = tri_edge.get(key, (None, 0))
+ if p1 * p2 > p:
+ tri_edge[key] = (b, p1 * p2)
+ real_edge = []
+ for (a, d), (r, p) in tri_edge.items():
+ u = reorder[a]
+ r = concept[r][len(REL):]
+ v = reorder[d]
+ real_edge.append((v, u, r))
+ return real_concept, real_edge
diff --git a/hanlp/datasets/parsing/conll_dataset.py b/hanlp/datasets/parsing/conll_dataset.py
new file mode 100644
index 000000000..a7ece3be5
--- /dev/null
+++ b/hanlp/datasets/parsing/conll_dataset.py
@@ -0,0 +1,106 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-05-08 16:10
+from typing import Union, List, Callable, Dict
+
+from hanlp_common.constant import ROOT, EOS, BOS
+from hanlp.common.dataset import TransformableDataset
+from hanlp.components.parsers.conll import read_conll
+from hanlp.utils.io_util import TimingFileIterator
+
+
+class CoNLLParsingDataset(TransformableDataset):
+
+ def __init__(self,
+ data: Union[str, List],
+ transform: Union[Callable, List] = None,
+ cache=None,
+ generate_idx=None,
+ prune: Callable[[Dict[str, List[str]]], bool] = None) -> None:
+ """General class for CoNLL style dependency parsing datasets.
+
+ Args:
+ data: The local or remote path to a dataset, or a list of samples where each sample is a dict.
+ transform: Predefined transform(s).
+ cache: ``True`` to enable caching, so that transforms won't be called twice.
+ generate_idx: Create a :const:`~hanlp_common.constants.IDX` field for each sample to store its order in dataset. Useful for prediction when
+ samples are re-ordered by a sampler.
+ prune: A filter to prune unwanted samples.
+ """
+ self._prune = prune
+ super().__init__(data, transform, cache, generate_idx)
+
+ def load_file(self, filepath):
+ """Both ``.conllx`` and ``.conllu`` are supported. Their descriptions can be found in
+ :class:`hanlp_common.conll.CoNLLWord` and :class:`hanlp_common.conll.CoNLLUWord` respectively.
+
+ Args:
+ filepath: ``.conllx`` or ``.conllu`` file path.
+ """
+ if filepath.endswith('.conllu'):
+ # See https://universaldependencies.org/format.html
+ field_names = ['ID', 'FORM', 'LEMMA', 'UPOS', 'XPOS',
+ 'FEATS', 'HEAD', 'DEPREL', 'DEPS', 'MISC']
+ else:
+ field_names = ['ID', 'FORM', 'LEMMA', 'CPOS', 'POS',
+ 'FEATS', 'HEAD', 'DEPREL', 'PHEAD', 'PDEPREL']
+ fp = TimingFileIterator(filepath)
+ for idx, sent in enumerate(read_conll(fp)):
+ sample = {}
+ for i, field in enumerate(field_names):
+ sample[field] = [cell[i] for cell in sent]
+ if not self._prune or not self._prune(sample):
+ yield sample
+ fp.log(f'{idx + 1} samples [blink][yellow]...[/yellow][/blink]')
+
+ def __len__(self) -> int:
+ return len(self.data)
+
+
+def append_bos(sample: dict, pos_key='CPOS', bos=ROOT) -> dict:
+ """
+
+ Args:
+ sample:
+ pos_key:
+ bos: A special token inserted to the head of tokens.
+
+ Returns:
+
+ """
+ sample['token'] = [bos] + sample['FORM']
+ if pos_key in sample:
+ sample['pos'] = [ROOT] + sample[pos_key]
+ if 'HEAD' in sample:
+ sample['arc'] = [0] + sample['HEAD']
+ sample['rel'] = sample['DEPREL'][:1] + sample['DEPREL']
+ return sample
+
+
+def append_bos_eos(sample: dict) -> dict:
+ sample['token'] = [BOS] + sample['FORM'] + [EOS]
+ if 'CPOS' in sample:
+ sample['pos'] = [BOS] + sample['CPOS'] + [EOS]
+ if 'HEAD' in sample:
+ sample['arc'] = [0] + sample['HEAD'] + [0]
+ sample['rel'] = sample['DEPREL'][:1] + sample['DEPREL'] + sample['DEPREL'][:1]
+ return sample
+
+
+def get_sibs(sample: dict) -> dict:
+ heads = sample.get('arc', None)
+ if heads:
+ sibs = [-1] * len(heads)
+ for i in range(1, len(heads)):
+ hi = heads[i]
+ for j in range(i + 1, len(heads)):
+ hj = heads[j]
+ di, dj = hi - i, hj - j
+ if hi >= 0 and hj >= 0 and hi == hj and di * dj > 0:
+ if abs(di) > abs(dj):
+ sibs[i] = j
+ else:
+ sibs[j] = i
+ break
+ sample['sib_id'] = [0] + sibs[1:]
+ return sample
diff --git a/hanlp/datasets/parsing/ctb.py b/hanlp/datasets/parsing/ctb.py
deleted file mode 100644
index 2f6adcc4b..000000000
--- a/hanlp/datasets/parsing/ctb.py
+++ /dev/null
@@ -1,20 +0,0 @@
-# -*- coding:utf-8 -*-
-# Author: hankcs
-# Date: 2019-12-28 18:44
-from hanlp.common.constant import HANLP_URL
-
-CTB_HOME = HANLP_URL + 'embeddings/SUDA-LA-CIP_20200109_021624.zip#'
-
-CTB5_DEP_HOME = CTB_HOME + 'BPNN/data/ctb5/'
-
-CTB5_DEP_TRAIN = CTB5_DEP_HOME + 'train.conll'
-CTB5_DEP_VALID = CTB5_DEP_HOME + 'dev.conll'
-CTB5_DEP_TEST = CTB5_DEP_HOME + 'test.conll'
-
-CTB7_HOME = CTB_HOME + 'BPNN/data/ctb7/'
-
-CTB7_DEP_TRAIN = CTB7_HOME + 'train.conll'
-CTB7_DEP_VALID = CTB7_HOME + 'dev.conll'
-CTB7_DEP_TEST = CTB7_HOME + 'test.conll'
-
-CIP_W2V_100_CN = CTB_HOME + 'BPNN/data/embed.txt'
diff --git a/hanlp/datasets/parsing/ctb5.py b/hanlp/datasets/parsing/ctb5.py
new file mode 100644
index 000000000..f7fa89a81
--- /dev/null
+++ b/hanlp/datasets/parsing/ctb5.py
@@ -0,0 +1,17 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-12-28 18:44
+from hanlp_common.constant import HANLP_URL
+
+_CTB_HOME = HANLP_URL + 'embeddings/SUDA-LA-CIP_20200109_021624.zip#'
+
+_CTB5_DEP_HOME = _CTB_HOME + 'BPNN/data/ctb5/'
+
+CTB5_DEP_TRAIN = _CTB5_DEP_HOME + 'train.conll'
+'''Training set for ctb5 dependency parsing.'''
+CTB5_DEP_DEV = _CTB5_DEP_HOME + 'dev.conll'
+'''Dev set for ctb5 dependency parsing.'''
+CTB5_DEP_TEST = _CTB5_DEP_HOME + 'test.conll'
+'''Test set for ctb5 dependency parsing.'''
+
+CIP_W2V_100_CN = _CTB_HOME + 'BPNN/data/embed.txt'
diff --git a/hanlp/datasets/parsing/ctb7.py b/hanlp/datasets/parsing/ctb7.py
new file mode 100644
index 000000000..09345b182
--- /dev/null
+++ b/hanlp/datasets/parsing/ctb7.py
@@ -0,0 +1,13 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-12-28 18:44
+from hanlp.datasets.parsing.ctb5 import _CTB_HOME
+
+_CTB7_HOME = _CTB_HOME + 'BPNN/data/ctb7/'
+
+CTB7_DEP_TRAIN = _CTB7_HOME + 'train.conll'
+'''Training set for ctb7 dependency parsing.'''
+CTB7_DEP_DEV = _CTB7_HOME + 'dev.conll'
+'''Dev set for ctb7 dependency parsing.'''
+CTB7_DEP_TEST = _CTB7_HOME + 'test.conll'
+'''Test set for ctb7 dependency parsing.'''
diff --git a/hanlp/datasets/parsing/ctb8.py b/hanlp/datasets/parsing/ctb8.py
new file mode 100644
index 000000000..2868021cc
--- /dev/null
+++ b/hanlp/datasets/parsing/ctb8.py
@@ -0,0 +1,44 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-10-14 20:54
+
+from hanlp.datasets.parsing._ctb_utils import make_ctb
+
+_CTB8_HOME = 'https://wakespace.lib.wfu.edu/bitstream/handle/10339/39379/LDC2013T21.tgz#data/'
+
+CTB8_CWS_TRAIN = _CTB8_HOME + 'tasks/cws/train.txt'
+'''Training set for ctb8 Chinese word segmentation.'''
+CTB8_CWS_DEV = _CTB8_HOME + 'tasks/cws/dev.txt'
+'''Dev set for ctb8 Chinese word segmentation.'''
+CTB8_CWS_TEST = _CTB8_HOME + 'tasks/cws/test.txt'
+'''Test set for ctb8 Chinese word segmentation.'''
+
+CTB8_POS_TRAIN = _CTB8_HOME + 'tasks/pos/train.tsv'
+'''Training set for ctb8 PoS tagging.'''
+CTB8_POS_DEV = _CTB8_HOME + 'tasks/pos/dev.tsv'
+'''Dev set for ctb8 PoS tagging.'''
+CTB8_POS_TEST = _CTB8_HOME + 'tasks/pos/test.tsv'
+'''Test set for ctb8 PoS tagging.'''
+
+CTB8_BRACKET_LINE_TRAIN = _CTB8_HOME + 'tasks/par/train.txt'
+'''Training set for ctb8 constituency parsing with empty categories.'''
+CTB8_BRACKET_LINE_DEV = _CTB8_HOME + 'tasks/par/dev.txt'
+'''Dev set for ctb8 constituency parsing with empty categories.'''
+CTB8_BRACKET_LINE_TEST = _CTB8_HOME + 'tasks/par/test.txt'
+'''Test set for ctb8 constituency parsing with empty categories.'''
+
+CTB8_BRACKET_LINE_NOEC_TRAIN = _CTB8_HOME + 'tasks/par/train.noempty.txt'
+'''Training set for ctb8 constituency parsing without empty categories.'''
+CTB8_BRACKET_LINE_NOEC_DEV = _CTB8_HOME + 'tasks/par/dev.noempty.txt'
+'''Dev set for ctb8 constituency parsing without empty categories.'''
+CTB8_BRACKET_LINE_NOEC_TEST = _CTB8_HOME + 'tasks/par/test.noempty.txt'
+'''Test set for ctb8 constituency parsing without empty categories.'''
+
+CTB8_SD330_TRAIN = _CTB8_HOME + 'tasks/dep/train.conllx'
+'''Training set for ctb8 in Stanford Dependencies 3.3.0 standard.'''
+CTB8_SD330_DEV = _CTB8_HOME + 'tasks/dep/dev.conllx'
+'''Dev set for ctb8 in Stanford Dependencies 3.3.0 standard.'''
+CTB8_SD330_TEST = _CTB8_HOME + 'tasks/dep/test.conllx'
+'''Test set for ctb8 in Stanford Dependencies 3.3.0 standard.'''
+
+make_ctb(_CTB8_HOME)
diff --git a/hanlp/datasets/parsing/ctb9.py b/hanlp/datasets/parsing/ctb9.py
new file mode 100644
index 000000000..15c8c36a3
--- /dev/null
+++ b/hanlp/datasets/parsing/ctb9.py
@@ -0,0 +1,55 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-10-14 20:54
+from urllib.error import HTTPError
+
+from hanlp.datasets.parsing._ctb_utils import make_ctb
+from hanlp.utils.io_util import get_resource, path_from_url
+
+_CTB9_HOME = 'https://catalog.ldc.upenn.edu/LDC2016T13/ctb9.0_LDC2016T13.tgz#data/'
+
+CTB9_CWS_TRAIN = _CTB9_HOME + 'tasks/cws/train.txt'
+'''Training set for ctb9 Chinese word segmentation.'''
+CTB9_CWS_DEV = _CTB9_HOME + 'tasks/cws/dev.txt'
+'''Dev set for ctb9 Chinese word segmentation.'''
+CTB9_CWS_TEST = _CTB9_HOME + 'tasks/cws/test.txt'
+'''Test set for ctb9 Chinese word segmentation.'''
+
+CTB9_POS_TRAIN = _CTB9_HOME + 'tasks/pos/train.tsv'
+'''Training set for ctb9 PoS tagging.'''
+CTB9_POS_DEV = _CTB9_HOME + 'tasks/pos/dev.tsv'
+'''Dev set for ctb9 PoS tagging.'''
+CTB9_POS_TEST = _CTB9_HOME + 'tasks/pos/test.tsv'
+'''Test set for ctb9 PoS tagging.'''
+
+CTB9_BRACKET_LINE_TRAIN = _CTB9_HOME + 'tasks/par/train.txt'
+'''Training set for ctb9 constituency parsing with empty categories.'''
+CTB9_BRACKET_LINE_DEV = _CTB9_HOME + 'tasks/par/dev.txt'
+'''Dev set for ctb9 constituency parsing with empty categories.'''
+CTB9_BRACKET_LINE_TEST = _CTB9_HOME + 'tasks/par/test.txt'
+'''Test set for ctb9 constituency parsing with empty categories.'''
+
+CTB9_BRACKET_LINE_NOEC_TRAIN = _CTB9_HOME + 'tasks/par/train.noempty.txt'
+'''Training set for ctb9 constituency parsing without empty categories.'''
+CTB9_BRACKET_LINE_NOEC_DEV = _CTB9_HOME + 'tasks/par/dev.noempty.txt'
+'''Dev set for ctb9 constituency parsing without empty categories.'''
+CTB9_BRACKET_LINE_NOEC_TEST = _CTB9_HOME + 'tasks/par/test.noempty.txt'
+'''Test set for ctb9 constituency parsing without empty categories.'''
+
+CTB9_SD330_TRAIN = _CTB9_HOME + 'tasks/dep/train.conllx'
+'''Training set for ctb9 in Stanford Dependencies 3.3.0 standard.'''
+CTB9_SD330_DEV = _CTB9_HOME + 'tasks/dep/dev.conllx'
+'''Dev set for ctb9 in Stanford Dependencies 3.3.0 standard.'''
+CTB9_SD330_TEST = _CTB9_HOME + 'tasks/dep/test.conllx'
+'''Test set for ctb9 in Stanford Dependencies 3.3.0 standard.'''
+
+try:
+ get_resource(_CTB9_HOME)
+except HTTPError:
+ raise FileNotFoundError(
+ 'Chinese Treebank 9.0 is a copyright dataset owned by LDC which we cannot re-distribute. '
+ f'Please apply for a licence from LDC (https://catalog.ldc.upenn.edu/LDC2016T13) '
+ f'then download it to {path_from_url(_CTB9_HOME)}'
+ )
+
+make_ctb(_CTB9_HOME)
diff --git a/hanlp/datasets/parsing/ptb.py b/hanlp/datasets/parsing/ptb.py
new file mode 100644
index 000000000..2f7636926
--- /dev/null
+++ b/hanlp/datasets/parsing/ptb.py
@@ -0,0 +1,48 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-02-17 15:46
+
+_PTB_HOME = 'https://github.com/KhalilMrini/LAL-Parser/archive/master.zip#data/'
+
+PTB_TRAIN = _PTB_HOME + '02-21.10way.clean'
+'''Training set of PTB without empty categories. PoS tags are automatically predicted using 10-fold
+jackknifing (:cite:`collins-koo-2005-discriminative`).'''
+PTB_DEV = _PTB_HOME + '22.auto.clean'
+'''Dev set of PTB without empty categories. PoS tags are automatically predicted using 10-fold
+jackknifing (:cite:`collins-koo-2005-discriminative`).'''
+PTB_TEST = _PTB_HOME + '23.auto.clean'
+'''Test set of PTB without empty categories. PoS tags are automatically predicted using 10-fold
+jackknifing (:cite:`collins-koo-2005-discriminative`).'''
+
+PTB_SD330_TRAIN = _PTB_HOME + 'ptb_train_3.3.0.sd.clean'
+'''Training set of PTB in Stanford Dependencies 3.3.0 format. PoS tags are automatically predicted using 10-fold
+jackknifing (:cite:`collins-koo-2005-discriminative`).'''
+PTB_SD330_DEV = _PTB_HOME + 'ptb_dev_3.3.0.sd.clean'
+'''Dev set of PTB in Stanford Dependencies 3.3.0 format. PoS tags are automatically predicted using 10-fold
+jackknifing (:cite:`collins-koo-2005-discriminative`).'''
+PTB_SD330_TEST = _PTB_HOME + 'ptb_test_3.3.0.sd.clean'
+'''Test set of PTB in Stanford Dependencies 3.3.0 format. PoS tags are automatically predicted using 10-fold
+jackknifing (:cite:`collins-koo-2005-discriminative`).'''
+
+PTB_TOKEN_MAPPING = {
+ "-LRB-": "(",
+ "-RRB-": ")",
+ "-LCB-": "{",
+ "-RCB-": "}",
+ "-LSB-": "[",
+ "-RSB-": "]",
+ "``": '"',
+ "''": '"',
+ "`": "'",
+ '«': '"',
+ '»': '"',
+ '‘': "'",
+ '’': "'",
+ '“': '"',
+ '”': '"',
+ '„': '"',
+ '‹': "'",
+ '›': "'",
+ "\u2013": "--", # en dash
+ "\u2014": "--", # em dash
+}
diff --git a/hanlp/datasets/parsing/semeval15.py b/hanlp/datasets/parsing/semeval15.py
new file mode 100644
index 000000000..f8e74551e
--- /dev/null
+++ b/hanlp/datasets/parsing/semeval15.py
@@ -0,0 +1,61 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-07-28 14:40
+# from hanlp.datasets.parsing.conll_dataset import CoNLLParsingDataset
+#
+#
+# class SemEval15Dataset(CoNLLParsingDataset):
+# def load_file(self, filepath: str):
+# pass
+import warnings
+
+from hanlp_common.constant import ROOT, PAD
+from hanlp_common.conll import CoNLLSentence
+
+
+def unpack_deps_to_head_deprel(sample: dict, pad_rel=None, arc_key='arc', rel_key='rel'):
+ if 'DEPS' in sample:
+ deps = ['_'] + sample['DEPS']
+ sample[arc_key] = arc = []
+ sample[rel_key] = rel = []
+ for each in deps:
+ arc_per_token = [False] * len(deps)
+ rel_per_token = [None] * len(deps)
+ if each != '_':
+ for ar in each.split('|'):
+ a, r = ar.split(':')
+ a = int(a)
+ arc_per_token[a] = True
+ rel_per_token[a] = r
+ if not pad_rel:
+ pad_rel = r
+ arc.append(arc_per_token)
+ rel.append(rel_per_token)
+ if not pad_rel:
+ pad_rel = PAD
+ for i in range(len(rel)):
+ rel[i] = [r if r else pad_rel for r in rel[i]]
+ return sample
+
+
+def append_bos_to_form_pos(sample, pos_key='CPOS'):
+ sample['token'] = [ROOT] + sample['FORM']
+ if pos_key in sample:
+ sample['pos'] = [ROOT] + sample[pos_key]
+ return sample
+
+
+def merge_head_deprel_with_2nd(sample: dict):
+ if 'arc' in sample:
+ arc_2nd = sample['arc_2nd']
+ rel_2nd = sample['rel_2nd']
+ for i, (arc, rel) in enumerate(zip(sample['arc'], sample['rel'])):
+ if i:
+ if arc_2nd[i][arc] and rel_2nd[i][arc] != rel:
+ sample_str = CoNLLSentence.from_dict(sample, conllu=True).to_markdown()
+ warnings.warn(f'The main dependency conflicts with 2nd dependency at ID={i}, ' \
+ 'which means joint mode might not be suitable. ' \
+ f'The sample is\n{sample_str}')
+ arc_2nd[i][arc] = True
+ rel_2nd[i][arc] = rel
+ return sample
diff --git a/hanlp/datasets/parsing/semeval16.py b/hanlp/datasets/parsing/semeval16.py
new file mode 100644
index 000000000..a7d9b046b
--- /dev/null
+++ b/hanlp/datasets/parsing/semeval16.py
@@ -0,0 +1,64 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-12-28 00:51
+from hanlp_common.conll import CoNLLSentence
+import os
+
+from hanlp.utils.io_util import get_resource, merge_files
+from hanlp_common.io import eprint
+
+_SEMEVAL2016_HOME = 'https://github.com/HIT-SCIR/SemEval-2016/archive/master.zip'
+
+SEMEVAL2016_NEWS_TRAIN = _SEMEVAL2016_HOME + '#train/news.train.conll'
+SEMEVAL2016_NEWS_DEV = _SEMEVAL2016_HOME + '#validation/news.valid.conll'
+SEMEVAL2016_NEWS_TEST = _SEMEVAL2016_HOME + '#test/news.test.conll'
+
+SEMEVAL2016_NEWS_TRAIN_CONLLU = _SEMEVAL2016_HOME + '#train/news.train.conllu'
+SEMEVAL2016_NEWS_DEV_CONLLU = _SEMEVAL2016_HOME + '#validation/news.valid.conllu'
+SEMEVAL2016_NEWS_TEST_CONLLU = _SEMEVAL2016_HOME + '#test/news.test.conllu'
+
+SEMEVAL2016_TEXT_TRAIN = _SEMEVAL2016_HOME + '#train/text.train.conll'
+SEMEVAL2016_TEXT_DEV = _SEMEVAL2016_HOME + '#validation/text.valid.conll'
+SEMEVAL2016_TEXT_TEST = _SEMEVAL2016_HOME + '#test/text.test.conll'
+
+SEMEVAL2016_TEXT_TRAIN_CONLLU = _SEMEVAL2016_HOME + '#train/text.train.conllu'
+SEMEVAL2016_TEXT_DEV_CONLLU = _SEMEVAL2016_HOME + '#validation/text.valid.conllu'
+SEMEVAL2016_TEXT_TEST_CONLLU = _SEMEVAL2016_HOME + '#test/text.test.conllu'
+
+SEMEVAL2016_FULL_TRAIN_CONLLU = _SEMEVAL2016_HOME + '#train/full.train.conllu'
+SEMEVAL2016_FULL_DEV_CONLLU = _SEMEVAL2016_HOME + '#validation/full.valid.conllu'
+SEMEVAL2016_FULL_TEST_CONLLU = _SEMEVAL2016_HOME + '#test/full.test.conllu'
+
+
+def convert_conll_to_conllu(path):
+ sents = CoNLLSentence.from_file(path, conllu=True)
+ with open(os.path.splitext(path)[0] + '.conllu', 'w') as out:
+ for sent in sents:
+ for word in sent:
+ if not word.deps:
+ word.deps = [(word.head, word.deprel)]
+ word.head = None
+ word.deprel = None
+ out.write(str(sent))
+ out.write('\n\n')
+
+
+for file in [SEMEVAL2016_NEWS_TRAIN, SEMEVAL2016_NEWS_DEV, SEMEVAL2016_NEWS_TEST,
+ SEMEVAL2016_TEXT_TRAIN, SEMEVAL2016_TEXT_DEV, SEMEVAL2016_TEXT_TEST]:
+ file = get_resource(file)
+ conllu = os.path.splitext(file)[0] + '.conllu'
+ if not os.path.isfile(conllu):
+ eprint(f'Converting {os.path.basename(file)} to {os.path.basename(conllu)} ...')
+ convert_conll_to_conllu(file)
+
+for group, part in zip([[SEMEVAL2016_NEWS_TRAIN_CONLLU, SEMEVAL2016_TEXT_TRAIN_CONLLU],
+ [SEMEVAL2016_NEWS_DEV_CONLLU, SEMEVAL2016_TEXT_DEV_CONLLU],
+ [SEMEVAL2016_NEWS_TEST_CONLLU, SEMEVAL2016_TEXT_TEST_CONLLU]],
+ ['train', 'valid', 'test']):
+ root = get_resource(_SEMEVAL2016_HOME)
+ dst = f'{root}/train/full.{part}.conllu'
+ if not os.path.isfile(dst):
+ group = [get_resource(x) for x in group]
+ eprint(f'Concatenating {os.path.basename(group[0])} and {os.path.basename(group[1])} '
+ f'into full dataset {os.path.basename(dst)} ...')
+ merge_files(group, dst)
diff --git a/hanlp/datasets/parsing/semeval2016.py b/hanlp/datasets/parsing/semeval2016.py
deleted file mode 100644
index d3e4c2835..000000000
--- a/hanlp/datasets/parsing/semeval2016.py
+++ /dev/null
@@ -1,13 +0,0 @@
-# -*- coding:utf-8 -*-
-# Author: hankcs
-# Date: 2019-12-28 00:51
-
-SEMEVAL2016 = 'https://github.com/HIT-SCIR/SemEval-2016/archive/master.zip'
-
-SEMEVAL2016_NEWS_TRAIN = SEMEVAL2016 + '#train/news.train.conll'
-SEMEVAL2016_NEWS_VALID = SEMEVAL2016 + '#validation/news.valid.conll'
-SEMEVAL2016_NEWS_TEST = SEMEVAL2016 + '#test/news.test.conll'
-
-SEMEVAL2016_TEXT_TRAIN = SEMEVAL2016 + '#train/text.train.conll'
-SEMEVAL2016_TEXT_VALID = SEMEVAL2016 + '#validation/text.valid.conll'
-SEMEVAL2016_TEXT_TEST = SEMEVAL2016 + '#test/text.test.conll'
diff --git a/hanlp/datasets/parsing/ud/__init__.py b/hanlp/datasets/parsing/ud/__init__.py
new file mode 100644
index 000000000..bf6f057cc
--- /dev/null
+++ b/hanlp/datasets/parsing/ud/__init__.py
@@ -0,0 +1,30 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-07 21:45
+import os
+import shutil
+
+from hanlp.components.parsers.ud.udify_util import get_ud_treebank_files
+from hanlp.utils.io_util import get_resource
+from hanlp.utils.log_util import flash
+
+
+def concat_treebanks(home, version):
+ ud_home = get_resource(home)
+ treebanks = get_ud_treebank_files(ud_home)
+ output_dir = os.path.abspath(os.path.join(ud_home, os.path.pardir, os.path.pardir, f'ud-multilingual-v{version}'))
+ if os.path.isdir(output_dir):
+ return output_dir
+ os.makedirs(output_dir)
+ train, dev, test = list(zip(*[treebanks[k] for k in treebanks]))
+
+ for treebank, name in zip([train, dev, test], ["train.conllu", "dev.conllu", "test.conllu"]):
+ flash(f'Concatenating {len(train)} treebanks into {name} [blink][yellow]...[/yellow][/blink]')
+ with open(os.path.join(output_dir, name), 'w') as write:
+ for t in treebank:
+ if not t:
+ continue
+ with open(t, 'r') as read:
+ shutil.copyfileobj(read, write)
+ flash('')
+ return output_dir
diff --git a/hanlp/datasets/parsing/ud/ud23.py b/hanlp/datasets/parsing/ud/ud23.py
new file mode 100644
index 000000000..dbfbde9a1
--- /dev/null
+++ b/hanlp/datasets/parsing/ud/ud23.py
@@ -0,0 +1,341 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-05-21 20:26
+
+_UD_23_HOME = "https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-2895/ud-treebanks-v2.3.tgz?sequence=1&isAllowed=y"
+_UD_24_HOME = "https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-2988/ud-treebanks-v2.4.tgz?sequence=4&isAllowed=y"
+
+
+def _list_dir(path, home):
+ prefix = home.lstrip('_').replace('_HOME', '')
+
+ from hanlp.utils.io_util import get_resource
+ import glob
+ import os
+ path = get_resource(path)
+ with open('ud23.py', 'a') as out:
+ for f in sorted(glob.glob(path + '/UD_*')):
+ basename = os.path.basename(f)
+ name = basename[len('UD_'):]
+ name = name.upper().replace('-', '_')
+ for split in 'train', 'dev', 'test':
+ sp = glob.glob(f + f'/*{split}.conllu')
+ if not sp:
+ continue
+ sp = os.path.basename(sp[0])
+ out.write(f'{prefix}_{name}_{split.upper()} = {home} + "#{basename}/{sp}"\n')
+
+
+def main():
+ _list_dir(_UD_23_HOME, '_UD_23_HOME')
+ pass
+
+
+if __name__ == '__main__':
+ main()
+
+UD_23_AFRIKAANS_AFRIBOOMS_TRAIN = _UD_23_HOME + "#UD_Afrikaans-AfriBooms/af_afribooms-ud-train.conllu"
+UD_23_AFRIKAANS_AFRIBOOMS_DEV = _UD_23_HOME + "#UD_Afrikaans-AfriBooms/af_afribooms-ud-dev.conllu"
+UD_23_AFRIKAANS_AFRIBOOMS_TEST = _UD_23_HOME + "#UD_Afrikaans-AfriBooms/af_afribooms-ud-test.conllu"
+UD_23_AKKADIAN_PISANDUB_TEST = _UD_23_HOME + "#UD_Akkadian-PISANDUB/akk_pisandub-ud-test.conllu"
+UD_23_AMHARIC_ATT_TEST = _UD_23_HOME + "#UD_Amharic-ATT/am_att-ud-test.conllu"
+UD_23_ANCIENT_GREEK_PROIEL_TRAIN = _UD_23_HOME + "#UD_Ancient_Greek-PROIEL/grc_proiel-ud-train.conllu"
+UD_23_ANCIENT_GREEK_PROIEL_DEV = _UD_23_HOME + "#UD_Ancient_Greek-PROIEL/grc_proiel-ud-dev.conllu"
+UD_23_ANCIENT_GREEK_PROIEL_TEST = _UD_23_HOME + "#UD_Ancient_Greek-PROIEL/grc_proiel-ud-test.conllu"
+UD_23_ANCIENT_GREEK_PERSEUS_TRAIN = _UD_23_HOME + "#UD_Ancient_Greek-Perseus/grc_perseus-ud-train.conllu"
+UD_23_ANCIENT_GREEK_PERSEUS_DEV = _UD_23_HOME + "#UD_Ancient_Greek-Perseus/grc_perseus-ud-dev.conllu"
+UD_23_ANCIENT_GREEK_PERSEUS_TEST = _UD_23_HOME + "#UD_Ancient_Greek-Perseus/grc_perseus-ud-test.conllu"
+UD_23_ARABIC_NYUAD_TRAIN = _UD_23_HOME + "#UD_Arabic-NYUAD/ar_nyuad-ud-train.conllu"
+UD_23_ARABIC_NYUAD_DEV = _UD_23_HOME + "#UD_Arabic-NYUAD/ar_nyuad-ud-dev.conllu"
+UD_23_ARABIC_NYUAD_TEST = _UD_23_HOME + "#UD_Arabic-NYUAD/ar_nyuad-ud-test.conllu"
+UD_23_ARABIC_PADT_TRAIN = _UD_23_HOME + "#UD_Arabic-PADT/ar_padt-ud-train.conllu"
+UD_23_ARABIC_PADT_DEV = _UD_23_HOME + "#UD_Arabic-PADT/ar_padt-ud-dev.conllu"
+UD_23_ARABIC_PADT_TEST = _UD_23_HOME + "#UD_Arabic-PADT/ar_padt-ud-test.conllu"
+UD_23_ARABIC_PUD_TEST = _UD_23_HOME + "#UD_Arabic-PUD/ar_pud-ud-test.conllu"
+UD_23_ARMENIAN_ARMTDP_TRAIN = _UD_23_HOME + "#UD_Armenian-ArmTDP/hy_armtdp-ud-train.conllu"
+UD_23_ARMENIAN_ARMTDP_TEST = _UD_23_HOME + "#UD_Armenian-ArmTDP/hy_armtdp-ud-test.conllu"
+UD_23_BAMBARA_CRB_TEST = _UD_23_HOME + "#UD_Bambara-CRB/bm_crb-ud-test.conllu"
+UD_23_BASQUE_BDT_TRAIN = _UD_23_HOME + "#UD_Basque-BDT/eu_bdt-ud-train.conllu"
+UD_23_BASQUE_BDT_DEV = _UD_23_HOME + "#UD_Basque-BDT/eu_bdt-ud-dev.conllu"
+UD_23_BASQUE_BDT_TEST = _UD_23_HOME + "#UD_Basque-BDT/eu_bdt-ud-test.conllu"
+UD_23_BELARUSIAN_HSE_TRAIN = _UD_23_HOME + "#UD_Belarusian-HSE/be_hse-ud-train.conllu"
+UD_23_BELARUSIAN_HSE_DEV = _UD_23_HOME + "#UD_Belarusian-HSE/be_hse-ud-dev.conllu"
+UD_23_BELARUSIAN_HSE_TEST = _UD_23_HOME + "#UD_Belarusian-HSE/be_hse-ud-test.conllu"
+UD_23_BRETON_KEB_TEST = _UD_23_HOME + "#UD_Breton-KEB/br_keb-ud-test.conllu"
+UD_23_BULGARIAN_BTB_TRAIN = _UD_23_HOME + "#UD_Bulgarian-BTB/bg_btb-ud-train.conllu"
+UD_23_BULGARIAN_BTB_DEV = _UD_23_HOME + "#UD_Bulgarian-BTB/bg_btb-ud-dev.conllu"
+UD_23_BULGARIAN_BTB_TEST = _UD_23_HOME + "#UD_Bulgarian-BTB/bg_btb-ud-test.conllu"
+UD_23_BURYAT_BDT_TRAIN = _UD_23_HOME + "#UD_Buryat-BDT/bxr_bdt-ud-train.conllu"
+UD_23_BURYAT_BDT_TEST = _UD_23_HOME + "#UD_Buryat-BDT/bxr_bdt-ud-test.conllu"
+UD_23_CANTONESE_HK_TEST = _UD_23_HOME + "#UD_Cantonese-HK/yue_hk-ud-test.conllu"
+UD_23_CATALAN_ANCORA_TRAIN = _UD_23_HOME + "#UD_Catalan-AnCora/ca_ancora-ud-train.conllu"
+UD_23_CATALAN_ANCORA_DEV = _UD_23_HOME + "#UD_Catalan-AnCora/ca_ancora-ud-dev.conllu"
+UD_23_CATALAN_ANCORA_TEST = _UD_23_HOME + "#UD_Catalan-AnCora/ca_ancora-ud-test.conllu"
+UD_23_CHINESE_CFL_TEST = _UD_23_HOME + "#UD_Chinese-CFL/zh_cfl-ud-test.conllu"
+UD_23_CHINESE_GSD_TRAIN = _UD_23_HOME + "#UD_Chinese-GSD/zh_gsd-ud-train.conllu"
+UD_23_CHINESE_GSD_DEV = _UD_23_HOME + "#UD_Chinese-GSD/zh_gsd-ud-dev.conllu"
+UD_23_CHINESE_GSD_TEST = _UD_23_HOME + "#UD_Chinese-GSD/zh_gsd-ud-test.conllu"
+UD_23_CHINESE_HK_TEST = _UD_23_HOME + "#UD_Chinese-HK/zh_hk-ud-test.conllu"
+UD_23_CHINESE_PUD_TEST = _UD_23_HOME + "#UD_Chinese-PUD/zh_pud-ud-test.conllu"
+UD_23_COPTIC_SCRIPTORIUM_TRAIN = _UD_23_HOME + "#UD_Coptic-Scriptorium/cop_scriptorium-ud-train.conllu"
+UD_23_COPTIC_SCRIPTORIUM_DEV = _UD_23_HOME + "#UD_Coptic-Scriptorium/cop_scriptorium-ud-dev.conllu"
+UD_23_COPTIC_SCRIPTORIUM_TEST = _UD_23_HOME + "#UD_Coptic-Scriptorium/cop_scriptorium-ud-test.conllu"
+UD_23_CROATIAN_SET_TRAIN = _UD_23_HOME + "#UD_Croatian-SET/hr_set-ud-train.conllu"
+UD_23_CROATIAN_SET_DEV = _UD_23_HOME + "#UD_Croatian-SET/hr_set-ud-dev.conllu"
+UD_23_CROATIAN_SET_TEST = _UD_23_HOME + "#UD_Croatian-SET/hr_set-ud-test.conllu"
+UD_23_CZECH_CAC_TRAIN = _UD_23_HOME + "#UD_Czech-CAC/cs_cac-ud-train.conllu"
+UD_23_CZECH_CAC_DEV = _UD_23_HOME + "#UD_Czech-CAC/cs_cac-ud-dev.conllu"
+UD_23_CZECH_CAC_TEST = _UD_23_HOME + "#UD_Czech-CAC/cs_cac-ud-test.conllu"
+UD_23_CZECH_CLTT_TRAIN = _UD_23_HOME + "#UD_Czech-CLTT/cs_cltt-ud-train.conllu"
+UD_23_CZECH_CLTT_DEV = _UD_23_HOME + "#UD_Czech-CLTT/cs_cltt-ud-dev.conllu"
+UD_23_CZECH_CLTT_TEST = _UD_23_HOME + "#UD_Czech-CLTT/cs_cltt-ud-test.conllu"
+UD_23_CZECH_FICTREE_TRAIN = _UD_23_HOME + "#UD_Czech-FicTree/cs_fictree-ud-train.conllu"
+UD_23_CZECH_FICTREE_DEV = _UD_23_HOME + "#UD_Czech-FicTree/cs_fictree-ud-dev.conllu"
+UD_23_CZECH_FICTREE_TEST = _UD_23_HOME + "#UD_Czech-FicTree/cs_fictree-ud-test.conllu"
+UD_23_CZECH_PDT_TRAIN = _UD_23_HOME + "#UD_Czech-PDT/cs_pdt-ud-train.conllu"
+UD_23_CZECH_PDT_DEV = _UD_23_HOME + "#UD_Czech-PDT/cs_pdt-ud-dev.conllu"
+UD_23_CZECH_PDT_TEST = _UD_23_HOME + "#UD_Czech-PDT/cs_pdt-ud-test.conllu"
+UD_23_CZECH_PUD_TEST = _UD_23_HOME + "#UD_Czech-PUD/cs_pud-ud-test.conllu"
+UD_23_DANISH_DDT_TRAIN = _UD_23_HOME + "#UD_Danish-DDT/da_ddt-ud-train.conllu"
+UD_23_DANISH_DDT_DEV = _UD_23_HOME + "#UD_Danish-DDT/da_ddt-ud-dev.conllu"
+UD_23_DANISH_DDT_TEST = _UD_23_HOME + "#UD_Danish-DDT/da_ddt-ud-test.conllu"
+UD_23_DUTCH_ALPINO_TRAIN = _UD_23_HOME + "#UD_Dutch-Alpino/nl_alpino-ud-train.conllu"
+UD_23_DUTCH_ALPINO_DEV = _UD_23_HOME + "#UD_Dutch-Alpino/nl_alpino-ud-dev.conllu"
+UD_23_DUTCH_ALPINO_TEST = _UD_23_HOME + "#UD_Dutch-Alpino/nl_alpino-ud-test.conllu"
+UD_23_DUTCH_LASSYSMALL_TRAIN = _UD_23_HOME + "#UD_Dutch-LassySmall/nl_lassysmall-ud-train.conllu"
+UD_23_DUTCH_LASSYSMALL_DEV = _UD_23_HOME + "#UD_Dutch-LassySmall/nl_lassysmall-ud-dev.conllu"
+UD_23_DUTCH_LASSYSMALL_TEST = _UD_23_HOME + "#UD_Dutch-LassySmall/nl_lassysmall-ud-test.conllu"
+UD_23_ENGLISH_ESL_TRAIN = _UD_23_HOME + "#UD_English-ESL/en_esl-ud-train.conllu"
+UD_23_ENGLISH_ESL_DEV = _UD_23_HOME + "#UD_English-ESL/en_esl-ud-dev.conllu"
+UD_23_ENGLISH_ESL_TEST = _UD_23_HOME + "#UD_English-ESL/en_esl-ud-test.conllu"
+UD_23_ENGLISH_EWT_TRAIN = _UD_23_HOME + "#UD_English-EWT/en_ewt-ud-train.conllu"
+UD_23_ENGLISH_EWT_DEV = _UD_23_HOME + "#UD_English-EWT/en_ewt-ud-dev.conllu"
+UD_23_ENGLISH_EWT_TEST = _UD_23_HOME + "#UD_English-EWT/en_ewt-ud-test.conllu"
+UD_23_ENGLISH_GUM_TRAIN = _UD_23_HOME + "#UD_English-GUM/en_gum-ud-train.conllu"
+UD_23_ENGLISH_GUM_DEV = _UD_23_HOME + "#UD_English-GUM/en_gum-ud-dev.conllu"
+UD_23_ENGLISH_GUM_TEST = _UD_23_HOME + "#UD_English-GUM/en_gum-ud-test.conllu"
+UD_23_ENGLISH_LINES_TRAIN = _UD_23_HOME + "#UD_English-LinES/en_lines-ud-train.conllu"
+UD_23_ENGLISH_LINES_DEV = _UD_23_HOME + "#UD_English-LinES/en_lines-ud-dev.conllu"
+UD_23_ENGLISH_LINES_TEST = _UD_23_HOME + "#UD_English-LinES/en_lines-ud-test.conllu"
+UD_23_ENGLISH_PUD_TEST = _UD_23_HOME + "#UD_English-PUD/en_pud-ud-test.conllu"
+UD_23_ENGLISH_PARTUT_TRAIN = _UD_23_HOME + "#UD_English-ParTUT/en_partut-ud-train.conllu"
+UD_23_ENGLISH_PARTUT_DEV = _UD_23_HOME + "#UD_English-ParTUT/en_partut-ud-dev.conllu"
+UD_23_ENGLISH_PARTUT_TEST = _UD_23_HOME + "#UD_English-ParTUT/en_partut-ud-test.conllu"
+UD_23_ERZYA_JR_TEST = _UD_23_HOME + "#UD_Erzya-JR/myv_jr-ud-test.conllu"
+UD_23_ESTONIAN_EDT_TRAIN = _UD_23_HOME + "#UD_Estonian-EDT/et_edt-ud-train.conllu"
+UD_23_ESTONIAN_EDT_DEV = _UD_23_HOME + "#UD_Estonian-EDT/et_edt-ud-dev.conllu"
+UD_23_ESTONIAN_EDT_TEST = _UD_23_HOME + "#UD_Estonian-EDT/et_edt-ud-test.conllu"
+UD_23_FAROESE_OFT_TEST = _UD_23_HOME + "#UD_Faroese-OFT/fo_oft-ud-test.conllu"
+UD_23_FINNISH_FTB_TRAIN = _UD_23_HOME + "#UD_Finnish-FTB/fi_ftb-ud-train.conllu"
+UD_23_FINNISH_FTB_DEV = _UD_23_HOME + "#UD_Finnish-FTB/fi_ftb-ud-dev.conllu"
+UD_23_FINNISH_FTB_TEST = _UD_23_HOME + "#UD_Finnish-FTB/fi_ftb-ud-test.conllu"
+UD_23_FINNISH_PUD_TEST = _UD_23_HOME + "#UD_Finnish-PUD/fi_pud-ud-test.conllu"
+UD_23_FINNISH_TDT_TRAIN = _UD_23_HOME + "#UD_Finnish-TDT/fi_tdt-ud-train.conllu"
+UD_23_FINNISH_TDT_DEV = _UD_23_HOME + "#UD_Finnish-TDT/fi_tdt-ud-dev.conllu"
+UD_23_FINNISH_TDT_TEST = _UD_23_HOME + "#UD_Finnish-TDT/fi_tdt-ud-test.conllu"
+UD_23_FRENCH_FTB_TRAIN = _UD_23_HOME + "#UD_French-FTB/fr_ftb-ud-train.conllu"
+UD_23_FRENCH_FTB_DEV = _UD_23_HOME + "#UD_French-FTB/fr_ftb-ud-dev.conllu"
+UD_23_FRENCH_FTB_TEST = _UD_23_HOME + "#UD_French-FTB/fr_ftb-ud-test.conllu"
+UD_23_FRENCH_GSD_TRAIN = _UD_23_HOME + "#UD_French-GSD/fr_gsd-ud-train.conllu"
+UD_23_FRENCH_GSD_DEV = _UD_23_HOME + "#UD_French-GSD/fr_gsd-ud-dev.conllu"
+UD_23_FRENCH_GSD_TEST = _UD_23_HOME + "#UD_French-GSD/fr_gsd-ud-test.conllu"
+UD_23_FRENCH_PUD_TEST = _UD_23_HOME + "#UD_French-PUD/fr_pud-ud-test.conllu"
+UD_23_FRENCH_PARTUT_TRAIN = _UD_23_HOME + "#UD_French-ParTUT/fr_partut-ud-train.conllu"
+UD_23_FRENCH_PARTUT_DEV = _UD_23_HOME + "#UD_French-ParTUT/fr_partut-ud-dev.conllu"
+UD_23_FRENCH_PARTUT_TEST = _UD_23_HOME + "#UD_French-ParTUT/fr_partut-ud-test.conllu"
+UD_23_FRENCH_SEQUOIA_TRAIN = _UD_23_HOME + "#UD_French-Sequoia/fr_sequoia-ud-train.conllu"
+UD_23_FRENCH_SEQUOIA_DEV = _UD_23_HOME + "#UD_French-Sequoia/fr_sequoia-ud-dev.conllu"
+UD_23_FRENCH_SEQUOIA_TEST = _UD_23_HOME + "#UD_French-Sequoia/fr_sequoia-ud-test.conllu"
+UD_23_FRENCH_SPOKEN_TRAIN = _UD_23_HOME + "#UD_French-Spoken/fr_spoken-ud-train.conllu"
+UD_23_FRENCH_SPOKEN_DEV = _UD_23_HOME + "#UD_French-Spoken/fr_spoken-ud-dev.conllu"
+UD_23_FRENCH_SPOKEN_TEST = _UD_23_HOME + "#UD_French-Spoken/fr_spoken-ud-test.conllu"
+UD_23_GALICIAN_CTG_TRAIN = _UD_23_HOME + "#UD_Galician-CTG/gl_ctg-ud-train.conllu"
+UD_23_GALICIAN_CTG_DEV = _UD_23_HOME + "#UD_Galician-CTG/gl_ctg-ud-dev.conllu"
+UD_23_GALICIAN_CTG_TEST = _UD_23_HOME + "#UD_Galician-CTG/gl_ctg-ud-test.conllu"
+UD_23_GALICIAN_TREEGAL_TRAIN = _UD_23_HOME + "#UD_Galician-TreeGal/gl_treegal-ud-train.conllu"
+UD_23_GALICIAN_TREEGAL_TEST = _UD_23_HOME + "#UD_Galician-TreeGal/gl_treegal-ud-test.conllu"
+UD_23_GERMAN_GSD_TRAIN = _UD_23_HOME + "#UD_German-GSD/de_gsd-ud-train.conllu"
+UD_23_GERMAN_GSD_DEV = _UD_23_HOME + "#UD_German-GSD/de_gsd-ud-dev.conllu"
+UD_23_GERMAN_GSD_TEST = _UD_23_HOME + "#UD_German-GSD/de_gsd-ud-test.conllu"
+UD_23_GERMAN_PUD_TEST = _UD_23_HOME + "#UD_German-PUD/de_pud-ud-test.conllu"
+UD_23_GOTHIC_PROIEL_TRAIN = _UD_23_HOME + "#UD_Gothic-PROIEL/got_proiel-ud-train.conllu"
+UD_23_GOTHIC_PROIEL_DEV = _UD_23_HOME + "#UD_Gothic-PROIEL/got_proiel-ud-dev.conllu"
+UD_23_GOTHIC_PROIEL_TEST = _UD_23_HOME + "#UD_Gothic-PROIEL/got_proiel-ud-test.conllu"
+UD_23_GREEK_GDT_TRAIN = _UD_23_HOME + "#UD_Greek-GDT/el_gdt-ud-train.conllu"
+UD_23_GREEK_GDT_DEV = _UD_23_HOME + "#UD_Greek-GDT/el_gdt-ud-dev.conllu"
+UD_23_GREEK_GDT_TEST = _UD_23_HOME + "#UD_Greek-GDT/el_gdt-ud-test.conllu"
+UD_23_HEBREW_HTB_TRAIN = _UD_23_HOME + "#UD_Hebrew-HTB/he_htb-ud-train.conllu"
+UD_23_HEBREW_HTB_DEV = _UD_23_HOME + "#UD_Hebrew-HTB/he_htb-ud-dev.conllu"
+UD_23_HEBREW_HTB_TEST = _UD_23_HOME + "#UD_Hebrew-HTB/he_htb-ud-test.conllu"
+UD_23_HINDI_HDTB_TRAIN = _UD_23_HOME + "#UD_Hindi-HDTB/hi_hdtb-ud-train.conllu"
+UD_23_HINDI_HDTB_DEV = _UD_23_HOME + "#UD_Hindi-HDTB/hi_hdtb-ud-dev.conllu"
+UD_23_HINDI_HDTB_TEST = _UD_23_HOME + "#UD_Hindi-HDTB/hi_hdtb-ud-test.conllu"
+UD_23_HINDI_PUD_TEST = _UD_23_HOME + "#UD_Hindi-PUD/hi_pud-ud-test.conllu"
+UD_23_HINDI_ENGLISH_HIENCS_TRAIN = _UD_23_HOME + "#UD_Hindi_English-HIENCS/qhe_hiencs-ud-train.conllu"
+UD_23_HINDI_ENGLISH_HIENCS_DEV = _UD_23_HOME + "#UD_Hindi_English-HIENCS/qhe_hiencs-ud-dev.conllu"
+UD_23_HINDI_ENGLISH_HIENCS_TEST = _UD_23_HOME + "#UD_Hindi_English-HIENCS/qhe_hiencs-ud-test.conllu"
+UD_23_HUNGARIAN_SZEGED_TRAIN = _UD_23_HOME + "#UD_Hungarian-Szeged/hu_szeged-ud-train.conllu"
+UD_23_HUNGARIAN_SZEGED_DEV = _UD_23_HOME + "#UD_Hungarian-Szeged/hu_szeged-ud-dev.conllu"
+UD_23_HUNGARIAN_SZEGED_TEST = _UD_23_HOME + "#UD_Hungarian-Szeged/hu_szeged-ud-test.conllu"
+UD_23_INDONESIAN_GSD_TRAIN = _UD_23_HOME + "#UD_Indonesian-GSD/id_gsd-ud-train.conllu"
+UD_23_INDONESIAN_GSD_DEV = _UD_23_HOME + "#UD_Indonesian-GSD/id_gsd-ud-dev.conllu"
+UD_23_INDONESIAN_GSD_TEST = _UD_23_HOME + "#UD_Indonesian-GSD/id_gsd-ud-test.conllu"
+UD_23_INDONESIAN_PUD_TEST = _UD_23_HOME + "#UD_Indonesian-PUD/id_pud-ud-test.conllu"
+UD_23_IRISH_IDT_TRAIN = _UD_23_HOME + "#UD_Irish-IDT/ga_idt-ud-train.conllu"
+UD_23_IRISH_IDT_TEST = _UD_23_HOME + "#UD_Irish-IDT/ga_idt-ud-test.conllu"
+UD_23_ITALIAN_ISDT_TRAIN = _UD_23_HOME + "#UD_Italian-ISDT/it_isdt-ud-train.conllu"
+UD_23_ITALIAN_ISDT_DEV = _UD_23_HOME + "#UD_Italian-ISDT/it_isdt-ud-dev.conllu"
+UD_23_ITALIAN_ISDT_TEST = _UD_23_HOME + "#UD_Italian-ISDT/it_isdt-ud-test.conllu"
+UD_23_ITALIAN_PUD_TEST = _UD_23_HOME + "#UD_Italian-PUD/it_pud-ud-test.conllu"
+UD_23_ITALIAN_PARTUT_TRAIN = _UD_23_HOME + "#UD_Italian-ParTUT/it_partut-ud-train.conllu"
+UD_23_ITALIAN_PARTUT_DEV = _UD_23_HOME + "#UD_Italian-ParTUT/it_partut-ud-dev.conllu"
+UD_23_ITALIAN_PARTUT_TEST = _UD_23_HOME + "#UD_Italian-ParTUT/it_partut-ud-test.conllu"
+UD_23_ITALIAN_POSTWITA_TRAIN = _UD_23_HOME + "#UD_Italian-PoSTWITA/it_postwita-ud-train.conllu"
+UD_23_ITALIAN_POSTWITA_DEV = _UD_23_HOME + "#UD_Italian-PoSTWITA/it_postwita-ud-dev.conllu"
+UD_23_ITALIAN_POSTWITA_TEST = _UD_23_HOME + "#UD_Italian-PoSTWITA/it_postwita-ud-test.conllu"
+UD_23_JAPANESE_BCCWJ_TRAIN = _UD_23_HOME + "#UD_Japanese-BCCWJ/ja_bccwj-ud-train.conllu"
+UD_23_JAPANESE_BCCWJ_DEV = _UD_23_HOME + "#UD_Japanese-BCCWJ/ja_bccwj-ud-dev.conllu"
+UD_23_JAPANESE_BCCWJ_TEST = _UD_23_HOME + "#UD_Japanese-BCCWJ/ja_bccwj-ud-test.conllu"
+UD_23_JAPANESE_GSD_TRAIN = _UD_23_HOME + "#UD_Japanese-GSD/ja_gsd-ud-train.conllu"
+UD_23_JAPANESE_GSD_DEV = _UD_23_HOME + "#UD_Japanese-GSD/ja_gsd-ud-dev.conllu"
+UD_23_JAPANESE_GSD_TEST = _UD_23_HOME + "#UD_Japanese-GSD/ja_gsd-ud-test.conllu"
+UD_23_JAPANESE_MODERN_TEST = _UD_23_HOME + "#UD_Japanese-Modern/ja_modern-ud-test.conllu"
+UD_23_JAPANESE_PUD_TEST = _UD_23_HOME + "#UD_Japanese-PUD/ja_pud-ud-test.conllu"
+UD_23_KAZAKH_KTB_TRAIN = _UD_23_HOME + "#UD_Kazakh-KTB/kk_ktb-ud-train.conllu"
+UD_23_KAZAKH_KTB_TEST = _UD_23_HOME + "#UD_Kazakh-KTB/kk_ktb-ud-test.conllu"
+UD_23_KOMI_ZYRIAN_IKDP_TEST = _UD_23_HOME + "#UD_Komi_Zyrian-IKDP/kpv_ikdp-ud-test.conllu"
+UD_23_KOMI_ZYRIAN_LATTICE_TEST = _UD_23_HOME + "#UD_Komi_Zyrian-Lattice/kpv_lattice-ud-test.conllu"
+UD_23_KOREAN_GSD_TRAIN = _UD_23_HOME + "#UD_Korean-GSD/ko_gsd-ud-train.conllu"
+UD_23_KOREAN_GSD_DEV = _UD_23_HOME + "#UD_Korean-GSD/ko_gsd-ud-dev.conllu"
+UD_23_KOREAN_GSD_TEST = _UD_23_HOME + "#UD_Korean-GSD/ko_gsd-ud-test.conllu"
+UD_23_KOREAN_KAIST_TRAIN = _UD_23_HOME + "#UD_Korean-Kaist/ko_kaist-ud-train.conllu"
+UD_23_KOREAN_KAIST_DEV = _UD_23_HOME + "#UD_Korean-Kaist/ko_kaist-ud-dev.conllu"
+UD_23_KOREAN_KAIST_TEST = _UD_23_HOME + "#UD_Korean-Kaist/ko_kaist-ud-test.conllu"
+UD_23_KOREAN_PUD_TEST = _UD_23_HOME + "#UD_Korean-PUD/ko_pud-ud-test.conllu"
+UD_23_KURMANJI_MG_TRAIN = _UD_23_HOME + "#UD_Kurmanji-MG/kmr_mg-ud-train.conllu"
+UD_23_KURMANJI_MG_TEST = _UD_23_HOME + "#UD_Kurmanji-MG/kmr_mg-ud-test.conllu"
+UD_23_LATIN_ITTB_TRAIN = _UD_23_HOME + "#UD_Latin-ITTB/la_ittb-ud-train.conllu"
+UD_23_LATIN_ITTB_DEV = _UD_23_HOME + "#UD_Latin-ITTB/la_ittb-ud-dev.conllu"
+UD_23_LATIN_ITTB_TEST = _UD_23_HOME + "#UD_Latin-ITTB/la_ittb-ud-test.conllu"
+UD_23_LATIN_PROIEL_TRAIN = _UD_23_HOME + "#UD_Latin-PROIEL/la_proiel-ud-train.conllu"
+UD_23_LATIN_PROIEL_DEV = _UD_23_HOME + "#UD_Latin-PROIEL/la_proiel-ud-dev.conllu"
+UD_23_LATIN_PROIEL_TEST = _UD_23_HOME + "#UD_Latin-PROIEL/la_proiel-ud-test.conllu"
+UD_23_LATIN_PERSEUS_TRAIN = _UD_23_HOME + "#UD_Latin-Perseus/la_perseus-ud-train.conllu"
+UD_23_LATIN_PERSEUS_TEST = _UD_23_HOME + "#UD_Latin-Perseus/la_perseus-ud-test.conllu"
+UD_23_LATVIAN_LVTB_TRAIN = _UD_23_HOME + "#UD_Latvian-LVTB/lv_lvtb-ud-train.conllu"
+UD_23_LATVIAN_LVTB_DEV = _UD_23_HOME + "#UD_Latvian-LVTB/lv_lvtb-ud-dev.conllu"
+UD_23_LATVIAN_LVTB_TEST = _UD_23_HOME + "#UD_Latvian-LVTB/lv_lvtb-ud-test.conllu"
+UD_23_LITHUANIAN_HSE_TRAIN = _UD_23_HOME + "#UD_Lithuanian-HSE/lt_hse-ud-train.conllu"
+UD_23_LITHUANIAN_HSE_DEV = _UD_23_HOME + "#UD_Lithuanian-HSE/lt_hse-ud-dev.conllu"
+UD_23_LITHUANIAN_HSE_TEST = _UD_23_HOME + "#UD_Lithuanian-HSE/lt_hse-ud-test.conllu"
+UD_23_MALTESE_MUDT_TRAIN = _UD_23_HOME + "#UD_Maltese-MUDT/mt_mudt-ud-train.conllu"
+UD_23_MALTESE_MUDT_DEV = _UD_23_HOME + "#UD_Maltese-MUDT/mt_mudt-ud-dev.conllu"
+UD_23_MALTESE_MUDT_TEST = _UD_23_HOME + "#UD_Maltese-MUDT/mt_mudt-ud-test.conllu"
+UD_23_MARATHI_UFAL_TRAIN = _UD_23_HOME + "#UD_Marathi-UFAL/mr_ufal-ud-train.conllu"
+UD_23_MARATHI_UFAL_DEV = _UD_23_HOME + "#UD_Marathi-UFAL/mr_ufal-ud-dev.conllu"
+UD_23_MARATHI_UFAL_TEST = _UD_23_HOME + "#UD_Marathi-UFAL/mr_ufal-ud-test.conllu"
+UD_23_NAIJA_NSC_TEST = _UD_23_HOME + "#UD_Naija-NSC/pcm_nsc-ud-test.conllu"
+UD_23_NORTH_SAMI_GIELLA_TRAIN = _UD_23_HOME + "#UD_North_Sami-Giella/sme_giella-ud-train.conllu"
+UD_23_NORTH_SAMI_GIELLA_TEST = _UD_23_HOME + "#UD_North_Sami-Giella/sme_giella-ud-test.conllu"
+UD_23_NORWEGIAN_BOKMAAL_TRAIN = _UD_23_HOME + "#UD_Norwegian-Bokmaal/no_bokmaal-ud-train.conllu"
+UD_23_NORWEGIAN_BOKMAAL_DEV = _UD_23_HOME + "#UD_Norwegian-Bokmaal/no_bokmaal-ud-dev.conllu"
+UD_23_NORWEGIAN_BOKMAAL_TEST = _UD_23_HOME + "#UD_Norwegian-Bokmaal/no_bokmaal-ud-test.conllu"
+UD_23_NORWEGIAN_NYNORSK_TRAIN = _UD_23_HOME + "#UD_Norwegian-Nynorsk/no_nynorsk-ud-train.conllu"
+UD_23_NORWEGIAN_NYNORSK_DEV = _UD_23_HOME + "#UD_Norwegian-Nynorsk/no_nynorsk-ud-dev.conllu"
+UD_23_NORWEGIAN_NYNORSK_TEST = _UD_23_HOME + "#UD_Norwegian-Nynorsk/no_nynorsk-ud-test.conllu"
+UD_23_NORWEGIAN_NYNORSKLIA_TRAIN = _UD_23_HOME + "#UD_Norwegian-NynorskLIA/no_nynorsklia-ud-train.conllu"
+UD_23_NORWEGIAN_NYNORSKLIA_TEST = _UD_23_HOME + "#UD_Norwegian-NynorskLIA/no_nynorsklia-ud-test.conllu"
+UD_23_OLD_CHURCH_SLAVONIC_PROIEL_TRAIN = _UD_23_HOME + "#UD_Old_Church_Slavonic-PROIEL/cu_proiel-ud-train.conllu"
+UD_23_OLD_CHURCH_SLAVONIC_PROIEL_DEV = _UD_23_HOME + "#UD_Old_Church_Slavonic-PROIEL/cu_proiel-ud-dev.conllu"
+UD_23_OLD_CHURCH_SLAVONIC_PROIEL_TEST = _UD_23_HOME + "#UD_Old_Church_Slavonic-PROIEL/cu_proiel-ud-test.conllu"
+UD_23_OLD_FRENCH_SRCMF_TRAIN = _UD_23_HOME + "#UD_Old_French-SRCMF/fro_srcmf-ud-train.conllu"
+UD_23_OLD_FRENCH_SRCMF_DEV = _UD_23_HOME + "#UD_Old_French-SRCMF/fro_srcmf-ud-dev.conllu"
+UD_23_OLD_FRENCH_SRCMF_TEST = _UD_23_HOME + "#UD_Old_French-SRCMF/fro_srcmf-ud-test.conllu"
+UD_23_PERSIAN_SERAJI_TRAIN = _UD_23_HOME + "#UD_Persian-Seraji/fa_seraji-ud-train.conllu"
+UD_23_PERSIAN_SERAJI_DEV = _UD_23_HOME + "#UD_Persian-Seraji/fa_seraji-ud-dev.conllu"
+UD_23_PERSIAN_SERAJI_TEST = _UD_23_HOME + "#UD_Persian-Seraji/fa_seraji-ud-test.conllu"
+UD_23_POLISH_LFG_TRAIN = _UD_23_HOME + "#UD_Polish-LFG/pl_lfg-ud-train.conllu"
+UD_23_POLISH_LFG_DEV = _UD_23_HOME + "#UD_Polish-LFG/pl_lfg-ud-dev.conllu"
+UD_23_POLISH_LFG_TEST = _UD_23_HOME + "#UD_Polish-LFG/pl_lfg-ud-test.conllu"
+UD_23_POLISH_SZ_TRAIN = _UD_23_HOME + "#UD_Polish-SZ/pl_sz-ud-train.conllu"
+UD_23_POLISH_SZ_DEV = _UD_23_HOME + "#UD_Polish-SZ/pl_sz-ud-dev.conllu"
+UD_23_POLISH_SZ_TEST = _UD_23_HOME + "#UD_Polish-SZ/pl_sz-ud-test.conllu"
+UD_23_PORTUGUESE_BOSQUE_TRAIN = _UD_23_HOME + "#UD_Portuguese-Bosque/pt_bosque-ud-train.conllu"
+UD_23_PORTUGUESE_BOSQUE_DEV = _UD_23_HOME + "#UD_Portuguese-Bosque/pt_bosque-ud-dev.conllu"
+UD_23_PORTUGUESE_BOSQUE_TEST = _UD_23_HOME + "#UD_Portuguese-Bosque/pt_bosque-ud-test.conllu"
+UD_23_PORTUGUESE_GSD_TRAIN = _UD_23_HOME + "#UD_Portuguese-GSD/pt_gsd-ud-train.conllu"
+UD_23_PORTUGUESE_GSD_DEV = _UD_23_HOME + "#UD_Portuguese-GSD/pt_gsd-ud-dev.conllu"
+UD_23_PORTUGUESE_GSD_TEST = _UD_23_HOME + "#UD_Portuguese-GSD/pt_gsd-ud-test.conllu"
+UD_23_PORTUGUESE_PUD_TEST = _UD_23_HOME + "#UD_Portuguese-PUD/pt_pud-ud-test.conllu"
+UD_23_ROMANIAN_NONSTANDARD_TRAIN = _UD_23_HOME + "#UD_Romanian-Nonstandard/ro_nonstandard-ud-train.conllu"
+UD_23_ROMANIAN_NONSTANDARD_DEV = _UD_23_HOME + "#UD_Romanian-Nonstandard/ro_nonstandard-ud-dev.conllu"
+UD_23_ROMANIAN_NONSTANDARD_TEST = _UD_23_HOME + "#UD_Romanian-Nonstandard/ro_nonstandard-ud-test.conllu"
+UD_23_ROMANIAN_RRT_TRAIN = _UD_23_HOME + "#UD_Romanian-RRT/ro_rrt-ud-train.conllu"
+UD_23_ROMANIAN_RRT_DEV = _UD_23_HOME + "#UD_Romanian-RRT/ro_rrt-ud-dev.conllu"
+UD_23_ROMANIAN_RRT_TEST = _UD_23_HOME + "#UD_Romanian-RRT/ro_rrt-ud-test.conllu"
+UD_23_RUSSIAN_GSD_TRAIN = _UD_23_HOME + "#UD_Russian-GSD/ru_gsd-ud-train.conllu"
+UD_23_RUSSIAN_GSD_DEV = _UD_23_HOME + "#UD_Russian-GSD/ru_gsd-ud-dev.conllu"
+UD_23_RUSSIAN_GSD_TEST = _UD_23_HOME + "#UD_Russian-GSD/ru_gsd-ud-test.conllu"
+UD_23_RUSSIAN_PUD_TEST = _UD_23_HOME + "#UD_Russian-PUD/ru_pud-ud-test.conllu"
+UD_23_RUSSIAN_SYNTAGRUS_TRAIN = _UD_23_HOME + "#UD_Russian-SynTagRus/ru_syntagrus-ud-train.conllu"
+UD_23_RUSSIAN_SYNTAGRUS_DEV = _UD_23_HOME + "#UD_Russian-SynTagRus/ru_syntagrus-ud-dev.conllu"
+UD_23_RUSSIAN_SYNTAGRUS_TEST = _UD_23_HOME + "#UD_Russian-SynTagRus/ru_syntagrus-ud-test.conllu"
+UD_23_RUSSIAN_TAIGA_TRAIN = _UD_23_HOME + "#UD_Russian-Taiga/ru_taiga-ud-train.conllu"
+UD_23_RUSSIAN_TAIGA_TEST = _UD_23_HOME + "#UD_Russian-Taiga/ru_taiga-ud-test.conllu"
+UD_23_SANSKRIT_UFAL_TEST = _UD_23_HOME + "#UD_Sanskrit-UFAL/sa_ufal-ud-test.conllu"
+UD_23_SERBIAN_SET_TRAIN = _UD_23_HOME + "#UD_Serbian-SET/sr_set-ud-train.conllu"
+UD_23_SERBIAN_SET_DEV = _UD_23_HOME + "#UD_Serbian-SET/sr_set-ud-dev.conllu"
+UD_23_SERBIAN_SET_TEST = _UD_23_HOME + "#UD_Serbian-SET/sr_set-ud-test.conllu"
+UD_23_SLOVAK_SNK_TRAIN = _UD_23_HOME + "#UD_Slovak-SNK/sk_snk-ud-train.conllu"
+UD_23_SLOVAK_SNK_DEV = _UD_23_HOME + "#UD_Slovak-SNK/sk_snk-ud-dev.conllu"
+UD_23_SLOVAK_SNK_TEST = _UD_23_HOME + "#UD_Slovak-SNK/sk_snk-ud-test.conllu"
+UD_23_SLOVENIAN_SSJ_TRAIN = _UD_23_HOME + "#UD_Slovenian-SSJ/sl_ssj-ud-train.conllu"
+UD_23_SLOVENIAN_SSJ_DEV = _UD_23_HOME + "#UD_Slovenian-SSJ/sl_ssj-ud-dev.conllu"
+UD_23_SLOVENIAN_SSJ_TEST = _UD_23_HOME + "#UD_Slovenian-SSJ/sl_ssj-ud-test.conllu"
+UD_23_SLOVENIAN_SST_TRAIN = _UD_23_HOME + "#UD_Slovenian-SST/sl_sst-ud-train.conllu"
+UD_23_SLOVENIAN_SST_TEST = _UD_23_HOME + "#UD_Slovenian-SST/sl_sst-ud-test.conllu"
+UD_23_SPANISH_ANCORA_TRAIN = _UD_23_HOME + "#UD_Spanish-AnCora/es_ancora-ud-train.conllu"
+UD_23_SPANISH_ANCORA_DEV = _UD_23_HOME + "#UD_Spanish-AnCora/es_ancora-ud-dev.conllu"
+UD_23_SPANISH_ANCORA_TEST = _UD_23_HOME + "#UD_Spanish-AnCora/es_ancora-ud-test.conllu"
+UD_23_SPANISH_GSD_TRAIN = _UD_23_HOME + "#UD_Spanish-GSD/es_gsd-ud-train.conllu"
+UD_23_SPANISH_GSD_DEV = _UD_23_HOME + "#UD_Spanish-GSD/es_gsd-ud-dev.conllu"
+UD_23_SPANISH_GSD_TEST = _UD_23_HOME + "#UD_Spanish-GSD/es_gsd-ud-test.conllu"
+UD_23_SPANISH_PUD_TEST = _UD_23_HOME + "#UD_Spanish-PUD/es_pud-ud-test.conllu"
+UD_23_SWEDISH_LINES_TRAIN = _UD_23_HOME + "#UD_Swedish-LinES/sv_lines-ud-train.conllu"
+UD_23_SWEDISH_LINES_DEV = _UD_23_HOME + "#UD_Swedish-LinES/sv_lines-ud-dev.conllu"
+UD_23_SWEDISH_LINES_TEST = _UD_23_HOME + "#UD_Swedish-LinES/sv_lines-ud-test.conllu"
+UD_23_SWEDISH_PUD_TEST = _UD_23_HOME + "#UD_Swedish-PUD/sv_pud-ud-test.conllu"
+UD_23_SWEDISH_TALBANKEN_TRAIN = _UD_23_HOME + "#UD_Swedish-Talbanken/sv_talbanken-ud-train.conllu"
+UD_23_SWEDISH_TALBANKEN_DEV = _UD_23_HOME + "#UD_Swedish-Talbanken/sv_talbanken-ud-dev.conllu"
+UD_23_SWEDISH_TALBANKEN_TEST = _UD_23_HOME + "#UD_Swedish-Talbanken/sv_talbanken-ud-test.conllu"
+UD_23_SWEDISH_SIGN_LANGUAGE_SSLC_TRAIN = _UD_23_HOME + "#UD_Swedish_Sign_Language-SSLC/swl_sslc-ud-train.conllu"
+UD_23_SWEDISH_SIGN_LANGUAGE_SSLC_DEV = _UD_23_HOME + "#UD_Swedish_Sign_Language-SSLC/swl_sslc-ud-dev.conllu"
+UD_23_SWEDISH_SIGN_LANGUAGE_SSLC_TEST = _UD_23_HOME + "#UD_Swedish_Sign_Language-SSLC/swl_sslc-ud-test.conllu"
+UD_23_TAGALOG_TRG_TEST = _UD_23_HOME + "#UD_Tagalog-TRG/tl_trg-ud-test.conllu"
+UD_23_TAMIL_TTB_TRAIN = _UD_23_HOME + "#UD_Tamil-TTB/ta_ttb-ud-train.conllu"
+UD_23_TAMIL_TTB_DEV = _UD_23_HOME + "#UD_Tamil-TTB/ta_ttb-ud-dev.conllu"
+UD_23_TAMIL_TTB_TEST = _UD_23_HOME + "#UD_Tamil-TTB/ta_ttb-ud-test.conllu"
+UD_23_TELUGU_MTG_TRAIN = _UD_23_HOME + "#UD_Telugu-MTG/te_mtg-ud-train.conllu"
+UD_23_TELUGU_MTG_DEV = _UD_23_HOME + "#UD_Telugu-MTG/te_mtg-ud-dev.conllu"
+UD_23_TELUGU_MTG_TEST = _UD_23_HOME + "#UD_Telugu-MTG/te_mtg-ud-test.conllu"
+UD_23_THAI_PUD_TEST = _UD_23_HOME + "#UD_Thai-PUD/th_pud-ud-test.conllu"
+UD_23_TURKISH_IMST_TRAIN = _UD_23_HOME + "#UD_Turkish-IMST/tr_imst-ud-train.conllu"
+UD_23_TURKISH_IMST_DEV = _UD_23_HOME + "#UD_Turkish-IMST/tr_imst-ud-dev.conllu"
+UD_23_TURKISH_IMST_TEST = _UD_23_HOME + "#UD_Turkish-IMST/tr_imst-ud-test.conllu"
+UD_23_TURKISH_PUD_TEST = _UD_23_HOME + "#UD_Turkish-PUD/tr_pud-ud-test.conllu"
+UD_23_UKRAINIAN_IU_TRAIN = _UD_23_HOME + "#UD_Ukrainian-IU/uk_iu-ud-train.conllu"
+UD_23_UKRAINIAN_IU_DEV = _UD_23_HOME + "#UD_Ukrainian-IU/uk_iu-ud-dev.conllu"
+UD_23_UKRAINIAN_IU_TEST = _UD_23_HOME + "#UD_Ukrainian-IU/uk_iu-ud-test.conllu"
+UD_23_UPPER_SORBIAN_UFAL_TRAIN = _UD_23_HOME + "#UD_Upper_Sorbian-UFAL/hsb_ufal-ud-train.conllu"
+UD_23_UPPER_SORBIAN_UFAL_TEST = _UD_23_HOME + "#UD_Upper_Sorbian-UFAL/hsb_ufal-ud-test.conllu"
+UD_23_URDU_UDTB_TRAIN = _UD_23_HOME + "#UD_Urdu-UDTB/ur_udtb-ud-train.conllu"
+UD_23_URDU_UDTB_DEV = _UD_23_HOME + "#UD_Urdu-UDTB/ur_udtb-ud-dev.conllu"
+UD_23_URDU_UDTB_TEST = _UD_23_HOME + "#UD_Urdu-UDTB/ur_udtb-ud-test.conllu"
+UD_23_UYGHUR_UDT_TRAIN = _UD_23_HOME + "#UD_Uyghur-UDT/ug_udt-ud-train.conllu"
+UD_23_UYGHUR_UDT_DEV = _UD_23_HOME + "#UD_Uyghur-UDT/ug_udt-ud-dev.conllu"
+UD_23_UYGHUR_UDT_TEST = _UD_23_HOME + "#UD_Uyghur-UDT/ug_udt-ud-test.conllu"
+UD_23_VIETNAMESE_VTB_TRAIN = _UD_23_HOME + "#UD_Vietnamese-VTB/vi_vtb-ud-train.conllu"
+UD_23_VIETNAMESE_VTB_DEV = _UD_23_HOME + "#UD_Vietnamese-VTB/vi_vtb-ud-dev.conllu"
+UD_23_VIETNAMESE_VTB_TEST = _UD_23_HOME + "#UD_Vietnamese-VTB/vi_vtb-ud-test.conllu"
+UD_23_WARLPIRI_UFAL_TEST = _UD_23_HOME + "#UD_Warlpiri-UFAL/wbp_ufal-ud-test.conllu"
+UD_23_YORUBA_YTB_TEST = _UD_23_HOME + "#UD_Yoruba-YTB/yo_ytb-ud-test.conllu"
diff --git a/hanlp/datasets/parsing/ud/ud23m.py b/hanlp/datasets/parsing/ud/ud23m.py
new file mode 100644
index 000000000..b5ea067cf
--- /dev/null
+++ b/hanlp/datasets/parsing/ud/ud23m.py
@@ -0,0 +1,12 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-05-21 20:39
+import os
+
+from hanlp.datasets.parsing.ud import concat_treebanks
+from .ud23 import _UD_23_HOME
+
+_UD_23_MULTILINGUAL_HOME = concat_treebanks(_UD_23_HOME, '2.3')
+UD_23_MULTILINGUAL_TRAIN = os.path.join(_UD_23_MULTILINGUAL_HOME, 'train.conllu')
+UD_23_MULTILINGUAL_DEV = os.path.join(_UD_23_MULTILINGUAL_HOME, 'dev.conllu')
+UD_23_MULTILINGUAL_TEST = os.path.join(_UD_23_MULTILINGUAL_HOME, 'test.conllu')
diff --git a/hanlp/datasets/parsing/ud/ud27.py b/hanlp/datasets/parsing/ud/ud27.py
new file mode 100644
index 000000000..d1e13fd64
--- /dev/null
+++ b/hanlp/datasets/parsing/ud/ud27.py
@@ -0,0 +1,855 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-07 21:03
+import glob
+import os
+
+from hanlp.utils.io_util import uncompress, get_resource
+
+_UD_27_URL = "https://lindat.mff.cuni.cz/repository/xmlui/handle/11234/1-3424/allzip"
+_UD_27_HOME = _UD_27_URL + '#ud-treebanks-v2.7/'
+_path = get_resource(_UD_27_URL)
+if os.path.isfile(_path):
+ os.rename(_path, _path + '.zip')
+ uncompress(_path + '.zip')
+ uncompress(os.path.join(_path, 'ud-treebanks-v2.7.tgz'))
+
+
+# noinspection PyShadowingNames
+def _list_dir(path, home):
+ prefix = home.lstrip('_').replace('_HOME', '')
+
+ path = get_resource(path)
+ with open('ud27.py', 'a') as out:
+ for f in sorted(glob.glob(path + '/ud-treebanks-v2.7/UD_*')):
+ basename = os.path.basename(f)
+ name = basename[len('UD_'):]
+ name = name.upper().replace('-', '_')
+ for split in 'train', 'dev', 'test':
+ sp = glob.glob(f + f'/*{split}.conllu')
+ if not sp:
+ continue
+ sp = os.path.basename(sp[0])
+ out.write(f'{prefix}_{name}_{split.upper()} = {home} + "{basename}/{sp}"\n')
+ out.write(f'"{prefix} {split} set of {name}."\n')
+
+
+def main():
+ _list_dir(_UD_27_URL, '_UD_27_HOME')
+ pass
+
+
+if __name__ == '__main__':
+ main()
+UD_27_AFRIKAANS_AFRIBOOMS_TRAIN = _UD_27_HOME + "UD_Afrikaans-AfriBooms/af_afribooms-ud-train.conllu"
+"UD_27 train set of AFRIKAANS_AFRIBOOMS."
+UD_27_AFRIKAANS_AFRIBOOMS_DEV = _UD_27_HOME + "UD_Afrikaans-AfriBooms/af_afribooms-ud-dev.conllu"
+"UD_27 dev set of AFRIKAANS_AFRIBOOMS."
+UD_27_AFRIKAANS_AFRIBOOMS_TEST = _UD_27_HOME + "UD_Afrikaans-AfriBooms/af_afribooms-ud-test.conllu"
+"UD_27 test set of AFRIKAANS_AFRIBOOMS."
+UD_27_AKKADIAN_PISANDUB_TEST = _UD_27_HOME + "UD_Akkadian-PISANDUB/akk_pisandub-ud-test.conllu"
+"UD_27 test set of AKKADIAN_PISANDUB."
+UD_27_AKKADIAN_RIAO_TEST = _UD_27_HOME + "UD_Akkadian-RIAO/akk_riao-ud-test.conllu"
+"UD_27 test set of AKKADIAN_RIAO."
+UD_27_AKUNTSU_TUDET_TEST = _UD_27_HOME + "UD_Akuntsu-TuDeT/aqz_tudet-ud-test.conllu"
+"UD_27 test set of AKUNTSU_TUDET."
+UD_27_ALBANIAN_TSA_TEST = _UD_27_HOME + "UD_Albanian-TSA/sq_tsa-ud-test.conllu"
+"UD_27 test set of ALBANIAN_TSA."
+UD_27_AMHARIC_ATT_TEST = _UD_27_HOME + "UD_Amharic-ATT/am_att-ud-test.conllu"
+"UD_27 test set of AMHARIC_ATT."
+UD_27_ANCIENT_GREEK_PROIEL_TRAIN = _UD_27_HOME + "UD_Ancient_Greek-PROIEL/grc_proiel-ud-train.conllu"
+"UD_27 train set of ANCIENT_GREEK_PROIEL."
+UD_27_ANCIENT_GREEK_PROIEL_DEV = _UD_27_HOME + "UD_Ancient_Greek-PROIEL/grc_proiel-ud-dev.conllu"
+"UD_27 dev set of ANCIENT_GREEK_PROIEL."
+UD_27_ANCIENT_GREEK_PROIEL_TEST = _UD_27_HOME + "UD_Ancient_Greek-PROIEL/grc_proiel-ud-test.conllu"
+"UD_27 test set of ANCIENT_GREEK_PROIEL."
+UD_27_ANCIENT_GREEK_PERSEUS_TRAIN = _UD_27_HOME + "UD_Ancient_Greek-Perseus/grc_perseus-ud-train.conllu"
+"UD_27 train set of ANCIENT_GREEK_PERSEUS."
+UD_27_ANCIENT_GREEK_PERSEUS_DEV = _UD_27_HOME + "UD_Ancient_Greek-Perseus/grc_perseus-ud-dev.conllu"
+"UD_27 dev set of ANCIENT_GREEK_PERSEUS."
+UD_27_ANCIENT_GREEK_PERSEUS_TEST = _UD_27_HOME + "UD_Ancient_Greek-Perseus/grc_perseus-ud-test.conllu"
+"UD_27 test set of ANCIENT_GREEK_PERSEUS."
+UD_27_APURINA_UFPA_TEST = _UD_27_HOME + "UD_Apurina-UFPA/apu_ufpa-ud-test.conllu"
+"UD_27 test set of APURINA_UFPA."
+UD_27_ARABIC_NYUAD_TRAIN = _UD_27_HOME + "UD_Arabic-NYUAD/ar_nyuad-ud-train.conllu"
+"UD_27 train set of ARABIC_NYUAD."
+UD_27_ARABIC_NYUAD_DEV = _UD_27_HOME + "UD_Arabic-NYUAD/ar_nyuad-ud-dev.conllu"
+"UD_27 dev set of ARABIC_NYUAD."
+UD_27_ARABIC_NYUAD_TEST = _UD_27_HOME + "UD_Arabic-NYUAD/ar_nyuad-ud-test.conllu"
+"UD_27 test set of ARABIC_NYUAD."
+UD_27_ARABIC_PADT_TRAIN = _UD_27_HOME + "UD_Arabic-PADT/ar_padt-ud-train.conllu"
+"UD_27 train set of ARABIC_PADT."
+UD_27_ARABIC_PADT_DEV = _UD_27_HOME + "UD_Arabic-PADT/ar_padt-ud-dev.conllu"
+"UD_27 dev set of ARABIC_PADT."
+UD_27_ARABIC_PADT_TEST = _UD_27_HOME + "UD_Arabic-PADT/ar_padt-ud-test.conllu"
+"UD_27 test set of ARABIC_PADT."
+UD_27_ARABIC_PUD_TEST = _UD_27_HOME + "UD_Arabic-PUD/ar_pud-ud-test.conllu"
+"UD_27 test set of ARABIC_PUD."
+UD_27_ARMENIAN_ARMTDP_TRAIN = _UD_27_HOME + "UD_Armenian-ArmTDP/hy_armtdp-ud-train.conllu"
+"UD_27 train set of ARMENIAN_ARMTDP."
+UD_27_ARMENIAN_ARMTDP_DEV = _UD_27_HOME + "UD_Armenian-ArmTDP/hy_armtdp-ud-dev.conllu"
+"UD_27 dev set of ARMENIAN_ARMTDP."
+UD_27_ARMENIAN_ARMTDP_TEST = _UD_27_HOME + "UD_Armenian-ArmTDP/hy_armtdp-ud-test.conllu"
+"UD_27 test set of ARMENIAN_ARMTDP."
+UD_27_ASSYRIAN_AS_TEST = _UD_27_HOME + "UD_Assyrian-AS/aii_as-ud-test.conllu"
+"UD_27 test set of ASSYRIAN_AS."
+UD_27_BAMBARA_CRB_TEST = _UD_27_HOME + "UD_Bambara-CRB/bm_crb-ud-test.conllu"
+"UD_27 test set of BAMBARA_CRB."
+UD_27_BASQUE_BDT_TRAIN = _UD_27_HOME + "UD_Basque-BDT/eu_bdt-ud-train.conllu"
+"UD_27 train set of BASQUE_BDT."
+UD_27_BASQUE_BDT_DEV = _UD_27_HOME + "UD_Basque-BDT/eu_bdt-ud-dev.conllu"
+"UD_27 dev set of BASQUE_BDT."
+UD_27_BASQUE_BDT_TEST = _UD_27_HOME + "UD_Basque-BDT/eu_bdt-ud-test.conllu"
+"UD_27 test set of BASQUE_BDT."
+UD_27_BELARUSIAN_HSE_TRAIN = _UD_27_HOME + "UD_Belarusian-HSE/be_hse-ud-train.conllu"
+"UD_27 train set of BELARUSIAN_HSE."
+UD_27_BELARUSIAN_HSE_DEV = _UD_27_HOME + "UD_Belarusian-HSE/be_hse-ud-dev.conllu"
+"UD_27 dev set of BELARUSIAN_HSE."
+UD_27_BELARUSIAN_HSE_TEST = _UD_27_HOME + "UD_Belarusian-HSE/be_hse-ud-test.conllu"
+"UD_27 test set of BELARUSIAN_HSE."
+UD_27_BHOJPURI_BHTB_TEST = _UD_27_HOME + "UD_Bhojpuri-BHTB/bho_bhtb-ud-test.conllu"
+"UD_27 test set of BHOJPURI_BHTB."
+UD_27_BRETON_KEB_TEST = _UD_27_HOME + "UD_Breton-KEB/br_keb-ud-test.conllu"
+"UD_27 test set of BRETON_KEB."
+UD_27_BULGARIAN_BTB_TRAIN = _UD_27_HOME + "UD_Bulgarian-BTB/bg_btb-ud-train.conllu"
+"UD_27 train set of BULGARIAN_BTB."
+UD_27_BULGARIAN_BTB_DEV = _UD_27_HOME + "UD_Bulgarian-BTB/bg_btb-ud-dev.conllu"
+"UD_27 dev set of BULGARIAN_BTB."
+UD_27_BULGARIAN_BTB_TEST = _UD_27_HOME + "UD_Bulgarian-BTB/bg_btb-ud-test.conllu"
+"UD_27 test set of BULGARIAN_BTB."
+UD_27_BURYAT_BDT_TRAIN = _UD_27_HOME + "UD_Buryat-BDT/bxr_bdt-ud-train.conllu"
+"UD_27 train set of BURYAT_BDT."
+UD_27_BURYAT_BDT_TEST = _UD_27_HOME + "UD_Buryat-BDT/bxr_bdt-ud-test.conllu"
+"UD_27 test set of BURYAT_BDT."
+UD_27_CANTONESE_HK_TEST = _UD_27_HOME + "UD_Cantonese-HK/yue_hk-ud-test.conllu"
+"UD_27 test set of CANTONESE_HK."
+UD_27_CATALAN_ANCORA_TRAIN = _UD_27_HOME + "UD_Catalan-AnCora/ca_ancora-ud-train.conllu"
+"UD_27 train set of CATALAN_ANCORA."
+UD_27_CATALAN_ANCORA_DEV = _UD_27_HOME + "UD_Catalan-AnCora/ca_ancora-ud-dev.conllu"
+"UD_27 dev set of CATALAN_ANCORA."
+UD_27_CATALAN_ANCORA_TEST = _UD_27_HOME + "UD_Catalan-AnCora/ca_ancora-ud-test.conllu"
+"UD_27 test set of CATALAN_ANCORA."
+UD_27_CHINESE_CFL_TEST = _UD_27_HOME + "UD_Chinese-CFL/zh_cfl-ud-test.conllu"
+"UD_27 test set of CHINESE_CFL."
+UD_27_CHINESE_GSD_TRAIN = _UD_27_HOME + "UD_Chinese-GSD/zh_gsd-ud-train.conllu"
+"UD_27 train set of CHINESE_GSD."
+UD_27_CHINESE_GSD_DEV = _UD_27_HOME + "UD_Chinese-GSD/zh_gsd-ud-dev.conllu"
+"UD_27 dev set of CHINESE_GSD."
+UD_27_CHINESE_GSD_TEST = _UD_27_HOME + "UD_Chinese-GSD/zh_gsd-ud-test.conllu"
+"UD_27 test set of CHINESE_GSD."
+UD_27_CHINESE_GSDSIMP_TRAIN = _UD_27_HOME + "UD_Chinese-GSDSimp/zh_gsdsimp-ud-train.conllu"
+"UD_27 train set of CHINESE_GSDSIMP."
+UD_27_CHINESE_GSDSIMP_DEV = _UD_27_HOME + "UD_Chinese-GSDSimp/zh_gsdsimp-ud-dev.conllu"
+"UD_27 dev set of CHINESE_GSDSIMP."
+UD_27_CHINESE_GSDSIMP_TEST = _UD_27_HOME + "UD_Chinese-GSDSimp/zh_gsdsimp-ud-test.conllu"
+"UD_27 test set of CHINESE_GSDSIMP."
+UD_27_CHINESE_HK_TEST = _UD_27_HOME + "UD_Chinese-HK/zh_hk-ud-test.conllu"
+"UD_27 test set of CHINESE_HK."
+UD_27_CHINESE_PUD_TEST = _UD_27_HOME + "UD_Chinese-PUD/zh_pud-ud-test.conllu"
+"UD_27 test set of CHINESE_PUD."
+UD_27_CHUKCHI_HSE_TEST = _UD_27_HOME + "UD_Chukchi-HSE/ckt_hse-ud-test.conllu"
+"UD_27 test set of CHUKCHI_HSE."
+UD_27_CLASSICAL_CHINESE_KYOTO_TRAIN = _UD_27_HOME + "UD_Classical_Chinese-Kyoto/lzh_kyoto-ud-train.conllu"
+"UD_27 train set of CLASSICAL_CHINESE_KYOTO."
+UD_27_CLASSICAL_CHINESE_KYOTO_DEV = _UD_27_HOME + "UD_Classical_Chinese-Kyoto/lzh_kyoto-ud-dev.conllu"
+"UD_27 dev set of CLASSICAL_CHINESE_KYOTO."
+UD_27_CLASSICAL_CHINESE_KYOTO_TEST = _UD_27_HOME + "UD_Classical_Chinese-Kyoto/lzh_kyoto-ud-test.conllu"
+"UD_27 test set of CLASSICAL_CHINESE_KYOTO."
+UD_27_COPTIC_SCRIPTORIUM_TRAIN = _UD_27_HOME + "UD_Coptic-Scriptorium/cop_scriptorium-ud-train.conllu"
+"UD_27 train set of COPTIC_SCRIPTORIUM."
+UD_27_COPTIC_SCRIPTORIUM_DEV = _UD_27_HOME + "UD_Coptic-Scriptorium/cop_scriptorium-ud-dev.conllu"
+"UD_27 dev set of COPTIC_SCRIPTORIUM."
+UD_27_COPTIC_SCRIPTORIUM_TEST = _UD_27_HOME + "UD_Coptic-Scriptorium/cop_scriptorium-ud-test.conllu"
+"UD_27 test set of COPTIC_SCRIPTORIUM."
+UD_27_CROATIAN_SET_TRAIN = _UD_27_HOME + "UD_Croatian-SET/hr_set-ud-train.conllu"
+"UD_27 train set of CROATIAN_SET."
+UD_27_CROATIAN_SET_DEV = _UD_27_HOME + "UD_Croatian-SET/hr_set-ud-dev.conllu"
+"UD_27 dev set of CROATIAN_SET."
+UD_27_CROATIAN_SET_TEST = _UD_27_HOME + "UD_Croatian-SET/hr_set-ud-test.conllu"
+"UD_27 test set of CROATIAN_SET."
+UD_27_CZECH_CAC_TRAIN = _UD_27_HOME + "UD_Czech-CAC/cs_cac-ud-train.conllu"
+"UD_27 train set of CZECH_CAC."
+UD_27_CZECH_CAC_DEV = _UD_27_HOME + "UD_Czech-CAC/cs_cac-ud-dev.conllu"
+"UD_27 dev set of CZECH_CAC."
+UD_27_CZECH_CAC_TEST = _UD_27_HOME + "UD_Czech-CAC/cs_cac-ud-test.conllu"
+"UD_27 test set of CZECH_CAC."
+UD_27_CZECH_CLTT_TRAIN = _UD_27_HOME + "UD_Czech-CLTT/cs_cltt-ud-train.conllu"
+"UD_27 train set of CZECH_CLTT."
+UD_27_CZECH_CLTT_DEV = _UD_27_HOME + "UD_Czech-CLTT/cs_cltt-ud-dev.conllu"
+"UD_27 dev set of CZECH_CLTT."
+UD_27_CZECH_CLTT_TEST = _UD_27_HOME + "UD_Czech-CLTT/cs_cltt-ud-test.conllu"
+"UD_27 test set of CZECH_CLTT."
+UD_27_CZECH_FICTREE_TRAIN = _UD_27_HOME + "UD_Czech-FicTree/cs_fictree-ud-train.conllu"
+"UD_27 train set of CZECH_FICTREE."
+UD_27_CZECH_FICTREE_DEV = _UD_27_HOME + "UD_Czech-FicTree/cs_fictree-ud-dev.conllu"
+"UD_27 dev set of CZECH_FICTREE."
+UD_27_CZECH_FICTREE_TEST = _UD_27_HOME + "UD_Czech-FicTree/cs_fictree-ud-test.conllu"
+"UD_27 test set of CZECH_FICTREE."
+UD_27_CZECH_PDT_TRAIN = _UD_27_HOME + "UD_Czech-PDT/cs_pdt-ud-train.conllu"
+"UD_27 train set of CZECH_PDT."
+UD_27_CZECH_PDT_DEV = _UD_27_HOME + "UD_Czech-PDT/cs_pdt-ud-dev.conllu"
+"UD_27 dev set of CZECH_PDT."
+UD_27_CZECH_PDT_TEST = _UD_27_HOME + "UD_Czech-PDT/cs_pdt-ud-test.conllu"
+"UD_27 test set of CZECH_PDT."
+UD_27_CZECH_PUD_TEST = _UD_27_HOME + "UD_Czech-PUD/cs_pud-ud-test.conllu"
+"UD_27 test set of CZECH_PUD."
+UD_27_DANISH_DDT_TRAIN = _UD_27_HOME + "UD_Danish-DDT/da_ddt-ud-train.conllu"
+"UD_27 train set of DANISH_DDT."
+UD_27_DANISH_DDT_DEV = _UD_27_HOME + "UD_Danish-DDT/da_ddt-ud-dev.conllu"
+"UD_27 dev set of DANISH_DDT."
+UD_27_DANISH_DDT_TEST = _UD_27_HOME + "UD_Danish-DDT/da_ddt-ud-test.conllu"
+"UD_27 test set of DANISH_DDT."
+UD_27_DUTCH_ALPINO_TRAIN = _UD_27_HOME + "UD_Dutch-Alpino/nl_alpino-ud-train.conllu"
+"UD_27 train set of DUTCH_ALPINO."
+UD_27_DUTCH_ALPINO_DEV = _UD_27_HOME + "UD_Dutch-Alpino/nl_alpino-ud-dev.conllu"
+"UD_27 dev set of DUTCH_ALPINO."
+UD_27_DUTCH_ALPINO_TEST = _UD_27_HOME + "UD_Dutch-Alpino/nl_alpino-ud-test.conllu"
+"UD_27 test set of DUTCH_ALPINO."
+UD_27_DUTCH_LASSYSMALL_TRAIN = _UD_27_HOME + "UD_Dutch-LassySmall/nl_lassysmall-ud-train.conllu"
+"UD_27 train set of DUTCH_LASSYSMALL."
+UD_27_DUTCH_LASSYSMALL_DEV = _UD_27_HOME + "UD_Dutch-LassySmall/nl_lassysmall-ud-dev.conllu"
+"UD_27 dev set of DUTCH_LASSYSMALL."
+UD_27_DUTCH_LASSYSMALL_TEST = _UD_27_HOME + "UD_Dutch-LassySmall/nl_lassysmall-ud-test.conllu"
+"UD_27 test set of DUTCH_LASSYSMALL."
+UD_27_ENGLISH_ESL_TRAIN = _UD_27_HOME + "UD_English-ESL/en_esl-ud-train.conllu"
+"UD_27 train set of ENGLISH_ESL."
+UD_27_ENGLISH_ESL_DEV = _UD_27_HOME + "UD_English-ESL/en_esl-ud-dev.conllu"
+"UD_27 dev set of ENGLISH_ESL."
+UD_27_ENGLISH_ESL_TEST = _UD_27_HOME + "UD_English-ESL/en_esl-ud-test.conllu"
+"UD_27 test set of ENGLISH_ESL."
+UD_27_ENGLISH_EWT_TRAIN = _UD_27_HOME + "UD_English-EWT/en_ewt-ud-train.conllu"
+"UD_27 train set of ENGLISH_EWT."
+UD_27_ENGLISH_EWT_DEV = _UD_27_HOME + "UD_English-EWT/en_ewt-ud-dev.conllu"
+"UD_27 dev set of ENGLISH_EWT."
+UD_27_ENGLISH_EWT_TEST = _UD_27_HOME + "UD_English-EWT/en_ewt-ud-test.conllu"
+"UD_27 test set of ENGLISH_EWT."
+UD_27_ENGLISH_GUM_TRAIN = _UD_27_HOME + "UD_English-GUM/en_gum-ud-train.conllu"
+"UD_27 train set of ENGLISH_GUM."
+UD_27_ENGLISH_GUM_DEV = _UD_27_HOME + "UD_English-GUM/en_gum-ud-dev.conllu"
+"UD_27 dev set of ENGLISH_GUM."
+UD_27_ENGLISH_GUM_TEST = _UD_27_HOME + "UD_English-GUM/en_gum-ud-test.conllu"
+"UD_27 test set of ENGLISH_GUM."
+UD_27_ENGLISH_GUMREDDIT_TRAIN = _UD_27_HOME + "UD_English-GUMReddit/en_gumreddit-ud-train.conllu"
+"UD_27 train set of ENGLISH_GUMREDDIT."
+UD_27_ENGLISH_GUMREDDIT_DEV = _UD_27_HOME + "UD_English-GUMReddit/en_gumreddit-ud-dev.conllu"
+"UD_27 dev set of ENGLISH_GUMREDDIT."
+UD_27_ENGLISH_GUMREDDIT_TEST = _UD_27_HOME + "UD_English-GUMReddit/en_gumreddit-ud-test.conllu"
+"UD_27 test set of ENGLISH_GUMREDDIT."
+UD_27_ENGLISH_LINES_TRAIN = _UD_27_HOME + "UD_English-LinES/en_lines-ud-train.conllu"
+"UD_27 train set of ENGLISH_LINES."
+UD_27_ENGLISH_LINES_DEV = _UD_27_HOME + "UD_English-LinES/en_lines-ud-dev.conllu"
+"UD_27 dev set of ENGLISH_LINES."
+UD_27_ENGLISH_LINES_TEST = _UD_27_HOME + "UD_English-LinES/en_lines-ud-test.conllu"
+"UD_27 test set of ENGLISH_LINES."
+UD_27_ENGLISH_PUD_TEST = _UD_27_HOME + "UD_English-PUD/en_pud-ud-test.conllu"
+"UD_27 test set of ENGLISH_PUD."
+UD_27_ENGLISH_PARTUT_TRAIN = _UD_27_HOME + "UD_English-ParTUT/en_partut-ud-train.conllu"
+"UD_27 train set of ENGLISH_PARTUT."
+UD_27_ENGLISH_PARTUT_DEV = _UD_27_HOME + "UD_English-ParTUT/en_partut-ud-dev.conllu"
+"UD_27 dev set of ENGLISH_PARTUT."
+UD_27_ENGLISH_PARTUT_TEST = _UD_27_HOME + "UD_English-ParTUT/en_partut-ud-test.conllu"
+"UD_27 test set of ENGLISH_PARTUT."
+UD_27_ENGLISH_PRONOUNS_TEST = _UD_27_HOME + "UD_English-Pronouns/en_pronouns-ud-test.conllu"
+"UD_27 test set of ENGLISH_PRONOUNS."
+UD_27_ERZYA_JR_TEST = _UD_27_HOME + "UD_Erzya-JR/myv_jr-ud-test.conllu"
+"UD_27 test set of ERZYA_JR."
+UD_27_ESTONIAN_EDT_TRAIN = _UD_27_HOME + "UD_Estonian-EDT/et_edt-ud-train.conllu"
+"UD_27 train set of ESTONIAN_EDT."
+UD_27_ESTONIAN_EDT_DEV = _UD_27_HOME + "UD_Estonian-EDT/et_edt-ud-dev.conllu"
+"UD_27 dev set of ESTONIAN_EDT."
+UD_27_ESTONIAN_EDT_TEST = _UD_27_HOME + "UD_Estonian-EDT/et_edt-ud-test.conllu"
+"UD_27 test set of ESTONIAN_EDT."
+UD_27_ESTONIAN_EWT_TRAIN = _UD_27_HOME + "UD_Estonian-EWT/et_ewt-ud-train.conllu"
+"UD_27 train set of ESTONIAN_EWT."
+UD_27_ESTONIAN_EWT_DEV = _UD_27_HOME + "UD_Estonian-EWT/et_ewt-ud-dev.conllu"
+"UD_27 dev set of ESTONIAN_EWT."
+UD_27_ESTONIAN_EWT_TEST = _UD_27_HOME + "UD_Estonian-EWT/et_ewt-ud-test.conllu"
+"UD_27 test set of ESTONIAN_EWT."
+UD_27_FAROESE_FARPAHC_TRAIN = _UD_27_HOME + "UD_Faroese-FarPaHC/fo_farpahc-ud-train.conllu"
+"UD_27 train set of FAROESE_FARPAHC."
+UD_27_FAROESE_FARPAHC_DEV = _UD_27_HOME + "UD_Faroese-FarPaHC/fo_farpahc-ud-dev.conllu"
+"UD_27 dev set of FAROESE_FARPAHC."
+UD_27_FAROESE_FARPAHC_TEST = _UD_27_HOME + "UD_Faroese-FarPaHC/fo_farpahc-ud-test.conllu"
+"UD_27 test set of FAROESE_FARPAHC."
+UD_27_FAROESE_OFT_TEST = _UD_27_HOME + "UD_Faroese-OFT/fo_oft-ud-test.conllu"
+"UD_27 test set of FAROESE_OFT."
+UD_27_FINNISH_FTB_TRAIN = _UD_27_HOME + "UD_Finnish-FTB/fi_ftb-ud-train.conllu"
+"UD_27 train set of FINNISH_FTB."
+UD_27_FINNISH_FTB_DEV = _UD_27_HOME + "UD_Finnish-FTB/fi_ftb-ud-dev.conllu"
+"UD_27 dev set of FINNISH_FTB."
+UD_27_FINNISH_FTB_TEST = _UD_27_HOME + "UD_Finnish-FTB/fi_ftb-ud-test.conllu"
+"UD_27 test set of FINNISH_FTB."
+UD_27_FINNISH_OOD_TEST = _UD_27_HOME + "UD_Finnish-OOD/fi_ood-ud-test.conllu"
+"UD_27 test set of FINNISH_OOD."
+UD_27_FINNISH_PUD_TEST = _UD_27_HOME + "UD_Finnish-PUD/fi_pud-ud-test.conllu"
+"UD_27 test set of FINNISH_PUD."
+UD_27_FINNISH_TDT_TRAIN = _UD_27_HOME + "UD_Finnish-TDT/fi_tdt-ud-train.conllu"
+"UD_27 train set of FINNISH_TDT."
+UD_27_FINNISH_TDT_DEV = _UD_27_HOME + "UD_Finnish-TDT/fi_tdt-ud-dev.conllu"
+"UD_27 dev set of FINNISH_TDT."
+UD_27_FINNISH_TDT_TEST = _UD_27_HOME + "UD_Finnish-TDT/fi_tdt-ud-test.conllu"
+"UD_27 test set of FINNISH_TDT."
+UD_27_FRENCH_FQB_TEST = _UD_27_HOME + "UD_French-FQB/fr_fqb-ud-test.conllu"
+"UD_27 test set of FRENCH_FQB."
+UD_27_FRENCH_FTB_TRAIN = _UD_27_HOME + "UD_French-FTB/fr_ftb-ud-train.conllu"
+"UD_27 train set of FRENCH_FTB."
+UD_27_FRENCH_FTB_DEV = _UD_27_HOME + "UD_French-FTB/fr_ftb-ud-dev.conllu"
+"UD_27 dev set of FRENCH_FTB."
+UD_27_FRENCH_FTB_TEST = _UD_27_HOME + "UD_French-FTB/fr_ftb-ud-test.conllu"
+"UD_27 test set of FRENCH_FTB."
+UD_27_FRENCH_GSD_TRAIN = _UD_27_HOME + "UD_French-GSD/fr_gsd-ud-train.conllu"
+"UD_27 train set of FRENCH_GSD."
+UD_27_FRENCH_GSD_DEV = _UD_27_HOME + "UD_French-GSD/fr_gsd-ud-dev.conllu"
+"UD_27 dev set of FRENCH_GSD."
+UD_27_FRENCH_GSD_TEST = _UD_27_HOME + "UD_French-GSD/fr_gsd-ud-test.conllu"
+"UD_27 test set of FRENCH_GSD."
+UD_27_FRENCH_PUD_TEST = _UD_27_HOME + "UD_French-PUD/fr_pud-ud-test.conllu"
+"UD_27 test set of FRENCH_PUD."
+UD_27_FRENCH_PARTUT_TRAIN = _UD_27_HOME + "UD_French-ParTUT/fr_partut-ud-train.conllu"
+"UD_27 train set of FRENCH_PARTUT."
+UD_27_FRENCH_PARTUT_DEV = _UD_27_HOME + "UD_French-ParTUT/fr_partut-ud-dev.conllu"
+"UD_27 dev set of FRENCH_PARTUT."
+UD_27_FRENCH_PARTUT_TEST = _UD_27_HOME + "UD_French-ParTUT/fr_partut-ud-test.conllu"
+"UD_27 test set of FRENCH_PARTUT."
+UD_27_FRENCH_SEQUOIA_TRAIN = _UD_27_HOME + "UD_French-Sequoia/fr_sequoia-ud-train.conllu"
+"UD_27 train set of FRENCH_SEQUOIA."
+UD_27_FRENCH_SEQUOIA_DEV = _UD_27_HOME + "UD_French-Sequoia/fr_sequoia-ud-dev.conllu"
+"UD_27 dev set of FRENCH_SEQUOIA."
+UD_27_FRENCH_SEQUOIA_TEST = _UD_27_HOME + "UD_French-Sequoia/fr_sequoia-ud-test.conllu"
+"UD_27 test set of FRENCH_SEQUOIA."
+UD_27_FRENCH_SPOKEN_TRAIN = _UD_27_HOME + "UD_French-Spoken/fr_spoken-ud-train.conllu"
+"UD_27 train set of FRENCH_SPOKEN."
+UD_27_FRENCH_SPOKEN_DEV = _UD_27_HOME + "UD_French-Spoken/fr_spoken-ud-dev.conllu"
+"UD_27 dev set of FRENCH_SPOKEN."
+UD_27_FRENCH_SPOKEN_TEST = _UD_27_HOME + "UD_French-Spoken/fr_spoken-ud-test.conllu"
+"UD_27 test set of FRENCH_SPOKEN."
+UD_27_GALICIAN_CTG_TRAIN = _UD_27_HOME + "UD_Galician-CTG/gl_ctg-ud-train.conllu"
+"UD_27 train set of GALICIAN_CTG."
+UD_27_GALICIAN_CTG_DEV = _UD_27_HOME + "UD_Galician-CTG/gl_ctg-ud-dev.conllu"
+"UD_27 dev set of GALICIAN_CTG."
+UD_27_GALICIAN_CTG_TEST = _UD_27_HOME + "UD_Galician-CTG/gl_ctg-ud-test.conllu"
+"UD_27 test set of GALICIAN_CTG."
+UD_27_GALICIAN_TREEGAL_TRAIN = _UD_27_HOME + "UD_Galician-TreeGal/gl_treegal-ud-train.conllu"
+"UD_27 train set of GALICIAN_TREEGAL."
+UD_27_GALICIAN_TREEGAL_TEST = _UD_27_HOME + "UD_Galician-TreeGal/gl_treegal-ud-test.conllu"
+"UD_27 test set of GALICIAN_TREEGAL."
+UD_27_GERMAN_GSD_TRAIN = _UD_27_HOME + "UD_German-GSD/de_gsd-ud-train.conllu"
+"UD_27 train set of GERMAN_GSD."
+UD_27_GERMAN_GSD_DEV = _UD_27_HOME + "UD_German-GSD/de_gsd-ud-dev.conllu"
+"UD_27 dev set of GERMAN_GSD."
+UD_27_GERMAN_GSD_TEST = _UD_27_HOME + "UD_German-GSD/de_gsd-ud-test.conllu"
+"UD_27 test set of GERMAN_GSD."
+UD_27_GERMAN_HDT_TRAIN = _UD_27_HOME + "UD_German-HDT/de_hdt-ud-train.conllu"
+"UD_27 train set of GERMAN_HDT."
+UD_27_GERMAN_HDT_DEV = _UD_27_HOME + "UD_German-HDT/de_hdt-ud-dev.conllu"
+"UD_27 dev set of GERMAN_HDT."
+UD_27_GERMAN_HDT_TEST = _UD_27_HOME + "UD_German-HDT/de_hdt-ud-test.conllu"
+"UD_27 test set of GERMAN_HDT."
+UD_27_GERMAN_LIT_TEST = _UD_27_HOME + "UD_German-LIT/de_lit-ud-test.conllu"
+"UD_27 test set of GERMAN_LIT."
+UD_27_GERMAN_PUD_TEST = _UD_27_HOME + "UD_German-PUD/de_pud-ud-test.conllu"
+"UD_27 test set of GERMAN_PUD."
+UD_27_GOTHIC_PROIEL_TRAIN = _UD_27_HOME + "UD_Gothic-PROIEL/got_proiel-ud-train.conllu"
+"UD_27 train set of GOTHIC_PROIEL."
+UD_27_GOTHIC_PROIEL_DEV = _UD_27_HOME + "UD_Gothic-PROIEL/got_proiel-ud-dev.conllu"
+"UD_27 dev set of GOTHIC_PROIEL."
+UD_27_GOTHIC_PROIEL_TEST = _UD_27_HOME + "UD_Gothic-PROIEL/got_proiel-ud-test.conllu"
+"UD_27 test set of GOTHIC_PROIEL."
+UD_27_GREEK_GDT_TRAIN = _UD_27_HOME + "UD_Greek-GDT/el_gdt-ud-train.conllu"
+"UD_27 train set of GREEK_GDT."
+UD_27_GREEK_GDT_DEV = _UD_27_HOME + "UD_Greek-GDT/el_gdt-ud-dev.conllu"
+"UD_27 dev set of GREEK_GDT."
+UD_27_GREEK_GDT_TEST = _UD_27_HOME + "UD_Greek-GDT/el_gdt-ud-test.conllu"
+"UD_27 test set of GREEK_GDT."
+UD_27_HEBREW_HTB_TRAIN = _UD_27_HOME + "UD_Hebrew-HTB/he_htb-ud-train.conllu"
+"UD_27 train set of HEBREW_HTB."
+UD_27_HEBREW_HTB_DEV = _UD_27_HOME + "UD_Hebrew-HTB/he_htb-ud-dev.conllu"
+"UD_27 dev set of HEBREW_HTB."
+UD_27_HEBREW_HTB_TEST = _UD_27_HOME + "UD_Hebrew-HTB/he_htb-ud-test.conllu"
+"UD_27 test set of HEBREW_HTB."
+UD_27_HINDI_HDTB_TRAIN = _UD_27_HOME + "UD_Hindi-HDTB/hi_hdtb-ud-train.conllu"
+"UD_27 train set of HINDI_HDTB."
+UD_27_HINDI_HDTB_DEV = _UD_27_HOME + "UD_Hindi-HDTB/hi_hdtb-ud-dev.conllu"
+"UD_27 dev set of HINDI_HDTB."
+UD_27_HINDI_HDTB_TEST = _UD_27_HOME + "UD_Hindi-HDTB/hi_hdtb-ud-test.conllu"
+"UD_27 test set of HINDI_HDTB."
+UD_27_HINDI_PUD_TEST = _UD_27_HOME + "UD_Hindi-PUD/hi_pud-ud-test.conllu"
+"UD_27 test set of HINDI_PUD."
+UD_27_HINDI_ENGLISH_HIENCS_TRAIN = _UD_27_HOME + "UD_Hindi_English-HIENCS/qhe_hiencs-ud-train.conllu"
+"UD_27 train set of HINDI_ENGLISH_HIENCS."
+UD_27_HINDI_ENGLISH_HIENCS_DEV = _UD_27_HOME + "UD_Hindi_English-HIENCS/qhe_hiencs-ud-dev.conllu"
+"UD_27 dev set of HINDI_ENGLISH_HIENCS."
+UD_27_HINDI_ENGLISH_HIENCS_TEST = _UD_27_HOME + "UD_Hindi_English-HIENCS/qhe_hiencs-ud-test.conllu"
+"UD_27 test set of HINDI_ENGLISH_HIENCS."
+UD_27_HUNGARIAN_SZEGED_TRAIN = _UD_27_HOME + "UD_Hungarian-Szeged/hu_szeged-ud-train.conllu"
+"UD_27 train set of HUNGARIAN_SZEGED."
+UD_27_HUNGARIAN_SZEGED_DEV = _UD_27_HOME + "UD_Hungarian-Szeged/hu_szeged-ud-dev.conllu"
+"UD_27 dev set of HUNGARIAN_SZEGED."
+UD_27_HUNGARIAN_SZEGED_TEST = _UD_27_HOME + "UD_Hungarian-Szeged/hu_szeged-ud-test.conllu"
+"UD_27 test set of HUNGARIAN_SZEGED."
+UD_27_ICELANDIC_ICEPAHC_TRAIN = _UD_27_HOME + "UD_Icelandic-IcePaHC/is_icepahc-ud-train.conllu"
+"UD_27 train set of ICELANDIC_ICEPAHC."
+UD_27_ICELANDIC_ICEPAHC_DEV = _UD_27_HOME + "UD_Icelandic-IcePaHC/is_icepahc-ud-dev.conllu"
+"UD_27 dev set of ICELANDIC_ICEPAHC."
+UD_27_ICELANDIC_ICEPAHC_TEST = _UD_27_HOME + "UD_Icelandic-IcePaHC/is_icepahc-ud-test.conllu"
+"UD_27 test set of ICELANDIC_ICEPAHC."
+UD_27_ICELANDIC_PUD_TEST = _UD_27_HOME + "UD_Icelandic-PUD/is_pud-ud-test.conllu"
+"UD_27 test set of ICELANDIC_PUD."
+UD_27_INDONESIAN_CSUI_TRAIN = _UD_27_HOME + "UD_Indonesian-CSUI/id_csui-ud-train.conllu"
+"UD_27 train set of INDONESIAN_CSUI."
+UD_27_INDONESIAN_CSUI_TEST = _UD_27_HOME + "UD_Indonesian-CSUI/id_csui-ud-test.conllu"
+"UD_27 test set of INDONESIAN_CSUI."
+UD_27_INDONESIAN_GSD_TRAIN = _UD_27_HOME + "UD_Indonesian-GSD/id_gsd-ud-train.conllu"
+"UD_27 train set of INDONESIAN_GSD."
+UD_27_INDONESIAN_GSD_DEV = _UD_27_HOME + "UD_Indonesian-GSD/id_gsd-ud-dev.conllu"
+"UD_27 dev set of INDONESIAN_GSD."
+UD_27_INDONESIAN_GSD_TEST = _UD_27_HOME + "UD_Indonesian-GSD/id_gsd-ud-test.conllu"
+"UD_27 test set of INDONESIAN_GSD."
+UD_27_INDONESIAN_PUD_TEST = _UD_27_HOME + "UD_Indonesian-PUD/id_pud-ud-test.conllu"
+"UD_27 test set of INDONESIAN_PUD."
+UD_27_IRISH_IDT_TRAIN = _UD_27_HOME + "UD_Irish-IDT/ga_idt-ud-train.conllu"
+"UD_27 train set of IRISH_IDT."
+UD_27_IRISH_IDT_DEV = _UD_27_HOME + "UD_Irish-IDT/ga_idt-ud-dev.conllu"
+"UD_27 dev set of IRISH_IDT."
+UD_27_IRISH_IDT_TEST = _UD_27_HOME + "UD_Irish-IDT/ga_idt-ud-test.conllu"
+"UD_27 test set of IRISH_IDT."
+UD_27_ITALIAN_ISDT_TRAIN = _UD_27_HOME + "UD_Italian-ISDT/it_isdt-ud-train.conllu"
+"UD_27 train set of ITALIAN_ISDT."
+UD_27_ITALIAN_ISDT_DEV = _UD_27_HOME + "UD_Italian-ISDT/it_isdt-ud-dev.conllu"
+"UD_27 dev set of ITALIAN_ISDT."
+UD_27_ITALIAN_ISDT_TEST = _UD_27_HOME + "UD_Italian-ISDT/it_isdt-ud-test.conllu"
+"UD_27 test set of ITALIAN_ISDT."
+UD_27_ITALIAN_PUD_TEST = _UD_27_HOME + "UD_Italian-PUD/it_pud-ud-test.conllu"
+"UD_27 test set of ITALIAN_PUD."
+UD_27_ITALIAN_PARTUT_TRAIN = _UD_27_HOME + "UD_Italian-ParTUT/it_partut-ud-train.conllu"
+"UD_27 train set of ITALIAN_PARTUT."
+UD_27_ITALIAN_PARTUT_DEV = _UD_27_HOME + "UD_Italian-ParTUT/it_partut-ud-dev.conllu"
+"UD_27 dev set of ITALIAN_PARTUT."
+UD_27_ITALIAN_PARTUT_TEST = _UD_27_HOME + "UD_Italian-ParTUT/it_partut-ud-test.conllu"
+"UD_27 test set of ITALIAN_PARTUT."
+UD_27_ITALIAN_POSTWITA_TRAIN = _UD_27_HOME + "UD_Italian-PoSTWITA/it_postwita-ud-train.conllu"
+"UD_27 train set of ITALIAN_POSTWITA."
+UD_27_ITALIAN_POSTWITA_DEV = _UD_27_HOME + "UD_Italian-PoSTWITA/it_postwita-ud-dev.conllu"
+"UD_27 dev set of ITALIAN_POSTWITA."
+UD_27_ITALIAN_POSTWITA_TEST = _UD_27_HOME + "UD_Italian-PoSTWITA/it_postwita-ud-test.conllu"
+"UD_27 test set of ITALIAN_POSTWITA."
+UD_27_ITALIAN_TWITTIRO_TRAIN = _UD_27_HOME + "UD_Italian-TWITTIRO/it_twittiro-ud-train.conllu"
+"UD_27 train set of ITALIAN_TWITTIRO."
+UD_27_ITALIAN_TWITTIRO_DEV = _UD_27_HOME + "UD_Italian-TWITTIRO/it_twittiro-ud-dev.conllu"
+"UD_27 dev set of ITALIAN_TWITTIRO."
+UD_27_ITALIAN_TWITTIRO_TEST = _UD_27_HOME + "UD_Italian-TWITTIRO/it_twittiro-ud-test.conllu"
+"UD_27 test set of ITALIAN_TWITTIRO."
+UD_27_ITALIAN_VIT_TRAIN = _UD_27_HOME + "UD_Italian-VIT/it_vit-ud-train.conllu"
+"UD_27 train set of ITALIAN_VIT."
+UD_27_ITALIAN_VIT_DEV = _UD_27_HOME + "UD_Italian-VIT/it_vit-ud-dev.conllu"
+"UD_27 dev set of ITALIAN_VIT."
+UD_27_ITALIAN_VIT_TEST = _UD_27_HOME + "UD_Italian-VIT/it_vit-ud-test.conllu"
+"UD_27 test set of ITALIAN_VIT."
+UD_27_JAPANESE_BCCWJ_TRAIN = _UD_27_HOME + "UD_Japanese-BCCWJ/ja_bccwj-ud-train.conllu"
+"UD_27 train set of JAPANESE_BCCWJ."
+UD_27_JAPANESE_BCCWJ_DEV = _UD_27_HOME + "UD_Japanese-BCCWJ/ja_bccwj-ud-dev.conllu"
+"UD_27 dev set of JAPANESE_BCCWJ."
+UD_27_JAPANESE_BCCWJ_TEST = _UD_27_HOME + "UD_Japanese-BCCWJ/ja_bccwj-ud-test.conllu"
+"UD_27 test set of JAPANESE_BCCWJ."
+UD_27_JAPANESE_GSD_TRAIN = _UD_27_HOME + "UD_Japanese-GSD/ja_gsd-ud-train.conllu"
+"UD_27 train set of JAPANESE_GSD."
+UD_27_JAPANESE_GSD_DEV = _UD_27_HOME + "UD_Japanese-GSD/ja_gsd-ud-dev.conllu"
+"UD_27 dev set of JAPANESE_GSD."
+UD_27_JAPANESE_GSD_TEST = _UD_27_HOME + "UD_Japanese-GSD/ja_gsd-ud-test.conllu"
+"UD_27 test set of JAPANESE_GSD."
+UD_27_JAPANESE_MODERN_TEST = _UD_27_HOME + "UD_Japanese-Modern/ja_modern-ud-test.conllu"
+"UD_27 test set of JAPANESE_MODERN."
+UD_27_JAPANESE_PUD_TEST = _UD_27_HOME + "UD_Japanese-PUD/ja_pud-ud-test.conllu"
+"UD_27 test set of JAPANESE_PUD."
+UD_27_KARELIAN_KKPP_TEST = _UD_27_HOME + "UD_Karelian-KKPP/krl_kkpp-ud-test.conllu"
+"UD_27 test set of KARELIAN_KKPP."
+UD_27_KAZAKH_KTB_TRAIN = _UD_27_HOME + "UD_Kazakh-KTB/kk_ktb-ud-train.conllu"
+"UD_27 train set of KAZAKH_KTB."
+UD_27_KAZAKH_KTB_TEST = _UD_27_HOME + "UD_Kazakh-KTB/kk_ktb-ud-test.conllu"
+"UD_27 test set of KAZAKH_KTB."
+UD_27_KHUNSARI_AHA_TEST = _UD_27_HOME + "UD_Khunsari-AHA/kfm_aha-ud-test.conllu"
+"UD_27 test set of KHUNSARI_AHA."
+UD_27_KOMI_PERMYAK_UH_TEST = _UD_27_HOME + "UD_Komi_Permyak-UH/koi_uh-ud-test.conllu"
+"UD_27 test set of KOMI_PERMYAK_UH."
+UD_27_KOMI_ZYRIAN_IKDP_TEST = _UD_27_HOME + "UD_Komi_Zyrian-IKDP/kpv_ikdp-ud-test.conllu"
+"UD_27 test set of KOMI_ZYRIAN_IKDP."
+UD_27_KOMI_ZYRIAN_LATTICE_TEST = _UD_27_HOME + "UD_Komi_Zyrian-Lattice/kpv_lattice-ud-test.conllu"
+"UD_27 test set of KOMI_ZYRIAN_LATTICE."
+UD_27_KOREAN_GSD_TRAIN = _UD_27_HOME + "UD_Korean-GSD/ko_gsd-ud-train.conllu"
+"UD_27 train set of KOREAN_GSD."
+UD_27_KOREAN_GSD_DEV = _UD_27_HOME + "UD_Korean-GSD/ko_gsd-ud-dev.conllu"
+"UD_27 dev set of KOREAN_GSD."
+UD_27_KOREAN_GSD_TEST = _UD_27_HOME + "UD_Korean-GSD/ko_gsd-ud-test.conllu"
+"UD_27 test set of KOREAN_GSD."
+UD_27_KOREAN_KAIST_TRAIN = _UD_27_HOME + "UD_Korean-Kaist/ko_kaist-ud-train.conllu"
+"UD_27 train set of KOREAN_KAIST."
+UD_27_KOREAN_KAIST_DEV = _UD_27_HOME + "UD_Korean-Kaist/ko_kaist-ud-dev.conllu"
+"UD_27 dev set of KOREAN_KAIST."
+UD_27_KOREAN_KAIST_TEST = _UD_27_HOME + "UD_Korean-Kaist/ko_kaist-ud-test.conllu"
+"UD_27 test set of KOREAN_KAIST."
+UD_27_KOREAN_PUD_TEST = _UD_27_HOME + "UD_Korean-PUD/ko_pud-ud-test.conllu"
+"UD_27 test set of KOREAN_PUD."
+UD_27_KURMANJI_MG_TRAIN = _UD_27_HOME + "UD_Kurmanji-MG/kmr_mg-ud-train.conllu"
+"UD_27 train set of KURMANJI_MG."
+UD_27_KURMANJI_MG_TEST = _UD_27_HOME + "UD_Kurmanji-MG/kmr_mg-ud-test.conllu"
+"UD_27 test set of KURMANJI_MG."
+UD_27_LATIN_ITTB_TRAIN = _UD_27_HOME + "UD_Latin-ITTB/la_ittb-ud-train.conllu"
+"UD_27 train set of LATIN_ITTB."
+UD_27_LATIN_ITTB_DEV = _UD_27_HOME + "UD_Latin-ITTB/la_ittb-ud-dev.conllu"
+"UD_27 dev set of LATIN_ITTB."
+UD_27_LATIN_ITTB_TEST = _UD_27_HOME + "UD_Latin-ITTB/la_ittb-ud-test.conllu"
+"UD_27 test set of LATIN_ITTB."
+UD_27_LATIN_LLCT_TRAIN = _UD_27_HOME + "UD_Latin-LLCT/la_llct-ud-train.conllu"
+"UD_27 train set of LATIN_LLCT."
+UD_27_LATIN_LLCT_DEV = _UD_27_HOME + "UD_Latin-LLCT/la_llct-ud-dev.conllu"
+"UD_27 dev set of LATIN_LLCT."
+UD_27_LATIN_LLCT_TEST = _UD_27_HOME + "UD_Latin-LLCT/la_llct-ud-test.conllu"
+"UD_27 test set of LATIN_LLCT."
+UD_27_LATIN_PROIEL_TRAIN = _UD_27_HOME + "UD_Latin-PROIEL/la_proiel-ud-train.conllu"
+"UD_27 train set of LATIN_PROIEL."
+UD_27_LATIN_PROIEL_DEV = _UD_27_HOME + "UD_Latin-PROIEL/la_proiel-ud-dev.conllu"
+"UD_27 dev set of LATIN_PROIEL."
+UD_27_LATIN_PROIEL_TEST = _UD_27_HOME + "UD_Latin-PROIEL/la_proiel-ud-test.conllu"
+"UD_27 test set of LATIN_PROIEL."
+UD_27_LATIN_PERSEUS_TRAIN = _UD_27_HOME + "UD_Latin-Perseus/la_perseus-ud-train.conllu"
+"UD_27 train set of LATIN_PERSEUS."
+UD_27_LATIN_PERSEUS_TEST = _UD_27_HOME + "UD_Latin-Perseus/la_perseus-ud-test.conllu"
+"UD_27 test set of LATIN_PERSEUS."
+UD_27_LATVIAN_LVTB_TRAIN = _UD_27_HOME + "UD_Latvian-LVTB/lv_lvtb-ud-train.conllu"
+"UD_27 train set of LATVIAN_LVTB."
+UD_27_LATVIAN_LVTB_DEV = _UD_27_HOME + "UD_Latvian-LVTB/lv_lvtb-ud-dev.conllu"
+"UD_27 dev set of LATVIAN_LVTB."
+UD_27_LATVIAN_LVTB_TEST = _UD_27_HOME + "UD_Latvian-LVTB/lv_lvtb-ud-test.conllu"
+"UD_27 test set of LATVIAN_LVTB."
+UD_27_LITHUANIAN_ALKSNIS_TRAIN = _UD_27_HOME + "UD_Lithuanian-ALKSNIS/lt_alksnis-ud-train.conllu"
+"UD_27 train set of LITHUANIAN_ALKSNIS."
+UD_27_LITHUANIAN_ALKSNIS_DEV = _UD_27_HOME + "UD_Lithuanian-ALKSNIS/lt_alksnis-ud-dev.conllu"
+"UD_27 dev set of LITHUANIAN_ALKSNIS."
+UD_27_LITHUANIAN_ALKSNIS_TEST = _UD_27_HOME + "UD_Lithuanian-ALKSNIS/lt_alksnis-ud-test.conllu"
+"UD_27 test set of LITHUANIAN_ALKSNIS."
+UD_27_LITHUANIAN_HSE_TRAIN = _UD_27_HOME + "UD_Lithuanian-HSE/lt_hse-ud-train.conllu"
+"UD_27 train set of LITHUANIAN_HSE."
+UD_27_LITHUANIAN_HSE_DEV = _UD_27_HOME + "UD_Lithuanian-HSE/lt_hse-ud-dev.conllu"
+"UD_27 dev set of LITHUANIAN_HSE."
+UD_27_LITHUANIAN_HSE_TEST = _UD_27_HOME + "UD_Lithuanian-HSE/lt_hse-ud-test.conllu"
+"UD_27 test set of LITHUANIAN_HSE."
+UD_27_LIVVI_KKPP_TRAIN = _UD_27_HOME + "UD_Livvi-KKPP/olo_kkpp-ud-train.conllu"
+"UD_27 train set of LIVVI_KKPP."
+UD_27_LIVVI_KKPP_TEST = _UD_27_HOME + "UD_Livvi-KKPP/olo_kkpp-ud-test.conllu"
+"UD_27 test set of LIVVI_KKPP."
+UD_27_MALTESE_MUDT_TRAIN = _UD_27_HOME + "UD_Maltese-MUDT/mt_mudt-ud-train.conllu"
+"UD_27 train set of MALTESE_MUDT."
+UD_27_MALTESE_MUDT_DEV = _UD_27_HOME + "UD_Maltese-MUDT/mt_mudt-ud-dev.conllu"
+"UD_27 dev set of MALTESE_MUDT."
+UD_27_MALTESE_MUDT_TEST = _UD_27_HOME + "UD_Maltese-MUDT/mt_mudt-ud-test.conllu"
+"UD_27 test set of MALTESE_MUDT."
+UD_27_MANX_CADHAN_TEST = _UD_27_HOME + "UD_Manx-Cadhan/gv_cadhan-ud-test.conllu"
+"UD_27 test set of MANX_CADHAN."
+UD_27_MARATHI_UFAL_TRAIN = _UD_27_HOME + "UD_Marathi-UFAL/mr_ufal-ud-train.conllu"
+"UD_27 train set of MARATHI_UFAL."
+UD_27_MARATHI_UFAL_DEV = _UD_27_HOME + "UD_Marathi-UFAL/mr_ufal-ud-dev.conllu"
+"UD_27 dev set of MARATHI_UFAL."
+UD_27_MARATHI_UFAL_TEST = _UD_27_HOME + "UD_Marathi-UFAL/mr_ufal-ud-test.conllu"
+"UD_27 test set of MARATHI_UFAL."
+UD_27_MBYA_GUARANI_DOOLEY_TEST = _UD_27_HOME + "UD_Mbya_Guarani-Dooley/gun_dooley-ud-test.conllu"
+"UD_27 test set of MBYA_GUARANI_DOOLEY."
+UD_27_MBYA_GUARANI_THOMAS_TEST = _UD_27_HOME + "UD_Mbya_Guarani-Thomas/gun_thomas-ud-test.conllu"
+"UD_27 test set of MBYA_GUARANI_THOMAS."
+UD_27_MOKSHA_JR_TEST = _UD_27_HOME + "UD_Moksha-JR/mdf_jr-ud-test.conllu"
+"UD_27 test set of MOKSHA_JR."
+UD_27_MUNDURUKU_TUDET_TEST = _UD_27_HOME + "UD_Munduruku-TuDeT/myu_tudet-ud-test.conllu"
+"UD_27 test set of MUNDURUKU_TUDET."
+UD_27_NAIJA_NSC_TRAIN = _UD_27_HOME + "UD_Naija-NSC/pcm_nsc-ud-train.conllu"
+"UD_27 train set of NAIJA_NSC."
+UD_27_NAIJA_NSC_DEV = _UD_27_HOME + "UD_Naija-NSC/pcm_nsc-ud-dev.conllu"
+"UD_27 dev set of NAIJA_NSC."
+UD_27_NAIJA_NSC_TEST = _UD_27_HOME + "UD_Naija-NSC/pcm_nsc-ud-test.conllu"
+"UD_27 test set of NAIJA_NSC."
+UD_27_NAYINI_AHA_TEST = _UD_27_HOME + "UD_Nayini-AHA/nyq_aha-ud-test.conllu"
+"UD_27 test set of NAYINI_AHA."
+UD_27_NORTH_SAMI_GIELLA_TRAIN = _UD_27_HOME + "UD_North_Sami-Giella/sme_giella-ud-train.conllu"
+"UD_27 train set of NORTH_SAMI_GIELLA."
+UD_27_NORTH_SAMI_GIELLA_TEST = _UD_27_HOME + "UD_North_Sami-Giella/sme_giella-ud-test.conllu"
+"UD_27 test set of NORTH_SAMI_GIELLA."
+UD_27_NORWEGIAN_BOKMAAL_TRAIN = _UD_27_HOME + "UD_Norwegian-Bokmaal/no_bokmaal-ud-train.conllu"
+"UD_27 train set of NORWEGIAN_BOKMAAL."
+UD_27_NORWEGIAN_BOKMAAL_DEV = _UD_27_HOME + "UD_Norwegian-Bokmaal/no_bokmaal-ud-dev.conllu"
+"UD_27 dev set of NORWEGIAN_BOKMAAL."
+UD_27_NORWEGIAN_BOKMAAL_TEST = _UD_27_HOME + "UD_Norwegian-Bokmaal/no_bokmaal-ud-test.conllu"
+"UD_27 test set of NORWEGIAN_BOKMAAL."
+UD_27_NORWEGIAN_NYNORSK_TRAIN = _UD_27_HOME + "UD_Norwegian-Nynorsk/no_nynorsk-ud-train.conllu"
+"UD_27 train set of NORWEGIAN_NYNORSK."
+UD_27_NORWEGIAN_NYNORSK_DEV = _UD_27_HOME + "UD_Norwegian-Nynorsk/no_nynorsk-ud-dev.conllu"
+"UD_27 dev set of NORWEGIAN_NYNORSK."
+UD_27_NORWEGIAN_NYNORSK_TEST = _UD_27_HOME + "UD_Norwegian-Nynorsk/no_nynorsk-ud-test.conllu"
+"UD_27 test set of NORWEGIAN_NYNORSK."
+UD_27_NORWEGIAN_NYNORSKLIA_TRAIN = _UD_27_HOME + "UD_Norwegian-NynorskLIA/no_nynorsklia-ud-train.conllu"
+"UD_27 train set of NORWEGIAN_NYNORSKLIA."
+UD_27_NORWEGIAN_NYNORSKLIA_DEV = _UD_27_HOME + "UD_Norwegian-NynorskLIA/no_nynorsklia-ud-dev.conllu"
+"UD_27 dev set of NORWEGIAN_NYNORSKLIA."
+UD_27_NORWEGIAN_NYNORSKLIA_TEST = _UD_27_HOME + "UD_Norwegian-NynorskLIA/no_nynorsklia-ud-test.conllu"
+"UD_27 test set of NORWEGIAN_NYNORSKLIA."
+UD_27_OLD_CHURCH_SLAVONIC_PROIEL_TRAIN = _UD_27_HOME + "UD_Old_Church_Slavonic-PROIEL/cu_proiel-ud-train.conllu"
+"UD_27 train set of OLD_CHURCH_SLAVONIC_PROIEL."
+UD_27_OLD_CHURCH_SLAVONIC_PROIEL_DEV = _UD_27_HOME + "UD_Old_Church_Slavonic-PROIEL/cu_proiel-ud-dev.conllu"
+"UD_27 dev set of OLD_CHURCH_SLAVONIC_PROIEL."
+UD_27_OLD_CHURCH_SLAVONIC_PROIEL_TEST = _UD_27_HOME + "UD_Old_Church_Slavonic-PROIEL/cu_proiel-ud-test.conllu"
+"UD_27 test set of OLD_CHURCH_SLAVONIC_PROIEL."
+UD_27_OLD_FRENCH_SRCMF_TRAIN = _UD_27_HOME + "UD_Old_French-SRCMF/fro_srcmf-ud-train.conllu"
+"UD_27 train set of OLD_FRENCH_SRCMF."
+UD_27_OLD_FRENCH_SRCMF_DEV = _UD_27_HOME + "UD_Old_French-SRCMF/fro_srcmf-ud-dev.conllu"
+"UD_27 dev set of OLD_FRENCH_SRCMF."
+UD_27_OLD_FRENCH_SRCMF_TEST = _UD_27_HOME + "UD_Old_French-SRCMF/fro_srcmf-ud-test.conllu"
+"UD_27 test set of OLD_FRENCH_SRCMF."
+UD_27_OLD_RUSSIAN_RNC_TRAIN = _UD_27_HOME + "UD_Old_Russian-RNC/orv_rnc-ud-train.conllu"
+"UD_27 train set of OLD_RUSSIAN_RNC."
+UD_27_OLD_RUSSIAN_RNC_TEST = _UD_27_HOME + "UD_Old_Russian-RNC/orv_rnc-ud-test.conllu"
+"UD_27 test set of OLD_RUSSIAN_RNC."
+UD_27_OLD_RUSSIAN_TOROT_TRAIN = _UD_27_HOME + "UD_Old_Russian-TOROT/orv_torot-ud-train.conllu"
+"UD_27 train set of OLD_RUSSIAN_TOROT."
+UD_27_OLD_RUSSIAN_TOROT_DEV = _UD_27_HOME + "UD_Old_Russian-TOROT/orv_torot-ud-dev.conllu"
+"UD_27 dev set of OLD_RUSSIAN_TOROT."
+UD_27_OLD_RUSSIAN_TOROT_TEST = _UD_27_HOME + "UD_Old_Russian-TOROT/orv_torot-ud-test.conllu"
+"UD_27 test set of OLD_RUSSIAN_TOROT."
+UD_27_OLD_TURKISH_TONQQ_TEST = _UD_27_HOME + "UD_Old_Turkish-Tonqq/otk_tonqq-ud-test.conllu"
+"UD_27 test set of OLD_TURKISH_TONQQ."
+UD_27_PERSIAN_PERDT_TRAIN = _UD_27_HOME + "UD_Persian-PerDT/fa_perdt-ud-train.conllu"
+"UD_27 train set of PERSIAN_PERDT."
+UD_27_PERSIAN_PERDT_DEV = _UD_27_HOME + "UD_Persian-PerDT/fa_perdt-ud-dev.conllu"
+"UD_27 dev set of PERSIAN_PERDT."
+UD_27_PERSIAN_PERDT_TEST = _UD_27_HOME + "UD_Persian-PerDT/fa_perdt-ud-test.conllu"
+"UD_27 test set of PERSIAN_PERDT."
+UD_27_PERSIAN_SERAJI_TRAIN = _UD_27_HOME + "UD_Persian-Seraji/fa_seraji-ud-train.conllu"
+"UD_27 train set of PERSIAN_SERAJI."
+UD_27_PERSIAN_SERAJI_DEV = _UD_27_HOME + "UD_Persian-Seraji/fa_seraji-ud-dev.conllu"
+"UD_27 dev set of PERSIAN_SERAJI."
+UD_27_PERSIAN_SERAJI_TEST = _UD_27_HOME + "UD_Persian-Seraji/fa_seraji-ud-test.conllu"
+"UD_27 test set of PERSIAN_SERAJI."
+UD_27_POLISH_LFG_TRAIN = _UD_27_HOME + "UD_Polish-LFG/pl_lfg-ud-train.conllu"
+"UD_27 train set of POLISH_LFG."
+UD_27_POLISH_LFG_DEV = _UD_27_HOME + "UD_Polish-LFG/pl_lfg-ud-dev.conllu"
+"UD_27 dev set of POLISH_LFG."
+UD_27_POLISH_LFG_TEST = _UD_27_HOME + "UD_Polish-LFG/pl_lfg-ud-test.conllu"
+"UD_27 test set of POLISH_LFG."
+UD_27_POLISH_PDB_TRAIN = _UD_27_HOME + "UD_Polish-PDB/pl_pdb-ud-train.conllu"
+"UD_27 train set of POLISH_PDB."
+UD_27_POLISH_PDB_DEV = _UD_27_HOME + "UD_Polish-PDB/pl_pdb-ud-dev.conllu"
+"UD_27 dev set of POLISH_PDB."
+UD_27_POLISH_PDB_TEST = _UD_27_HOME + "UD_Polish-PDB/pl_pdb-ud-test.conllu"
+"UD_27 test set of POLISH_PDB."
+UD_27_POLISH_PUD_TEST = _UD_27_HOME + "UD_Polish-PUD/pl_pud-ud-test.conllu"
+"UD_27 test set of POLISH_PUD."
+UD_27_PORTUGUESE_BOSQUE_TRAIN = _UD_27_HOME + "UD_Portuguese-Bosque/pt_bosque-ud-train.conllu"
+"UD_27 train set of PORTUGUESE_BOSQUE."
+UD_27_PORTUGUESE_BOSQUE_DEV = _UD_27_HOME + "UD_Portuguese-Bosque/pt_bosque-ud-dev.conllu"
+"UD_27 dev set of PORTUGUESE_BOSQUE."
+UD_27_PORTUGUESE_BOSQUE_TEST = _UD_27_HOME + "UD_Portuguese-Bosque/pt_bosque-ud-test.conllu"
+"UD_27 test set of PORTUGUESE_BOSQUE."
+UD_27_PORTUGUESE_GSD_TRAIN = _UD_27_HOME + "UD_Portuguese-GSD/pt_gsd-ud-train.conllu"
+"UD_27 train set of PORTUGUESE_GSD."
+UD_27_PORTUGUESE_GSD_DEV = _UD_27_HOME + "UD_Portuguese-GSD/pt_gsd-ud-dev.conllu"
+"UD_27 dev set of PORTUGUESE_GSD."
+UD_27_PORTUGUESE_GSD_TEST = _UD_27_HOME + "UD_Portuguese-GSD/pt_gsd-ud-test.conllu"
+"UD_27 test set of PORTUGUESE_GSD."
+UD_27_PORTUGUESE_PUD_TEST = _UD_27_HOME + "UD_Portuguese-PUD/pt_pud-ud-test.conllu"
+"UD_27 test set of PORTUGUESE_PUD."
+UD_27_ROMANIAN_NONSTANDARD_TRAIN = _UD_27_HOME + "UD_Romanian-Nonstandard/ro_nonstandard-ud-train.conllu"
+"UD_27 train set of ROMANIAN_NONSTANDARD."
+UD_27_ROMANIAN_NONSTANDARD_DEV = _UD_27_HOME + "UD_Romanian-Nonstandard/ro_nonstandard-ud-dev.conllu"
+"UD_27 dev set of ROMANIAN_NONSTANDARD."
+UD_27_ROMANIAN_NONSTANDARD_TEST = _UD_27_HOME + "UD_Romanian-Nonstandard/ro_nonstandard-ud-test.conllu"
+"UD_27 test set of ROMANIAN_NONSTANDARD."
+UD_27_ROMANIAN_RRT_TRAIN = _UD_27_HOME + "UD_Romanian-RRT/ro_rrt-ud-train.conllu"
+"UD_27 train set of ROMANIAN_RRT."
+UD_27_ROMANIAN_RRT_DEV = _UD_27_HOME + "UD_Romanian-RRT/ro_rrt-ud-dev.conllu"
+"UD_27 dev set of ROMANIAN_RRT."
+UD_27_ROMANIAN_RRT_TEST = _UD_27_HOME + "UD_Romanian-RRT/ro_rrt-ud-test.conllu"
+"UD_27 test set of ROMANIAN_RRT."
+UD_27_ROMANIAN_SIMONERO_TRAIN = _UD_27_HOME + "UD_Romanian-SiMoNERo/ro_simonero-ud-train.conllu"
+"UD_27 train set of ROMANIAN_SIMONERO."
+UD_27_ROMANIAN_SIMONERO_DEV = _UD_27_HOME + "UD_Romanian-SiMoNERo/ro_simonero-ud-dev.conllu"
+"UD_27 dev set of ROMANIAN_SIMONERO."
+UD_27_ROMANIAN_SIMONERO_TEST = _UD_27_HOME + "UD_Romanian-SiMoNERo/ro_simonero-ud-test.conllu"
+"UD_27 test set of ROMANIAN_SIMONERO."
+UD_27_RUSSIAN_GSD_TRAIN = _UD_27_HOME + "UD_Russian-GSD/ru_gsd-ud-train.conllu"
+"UD_27 train set of RUSSIAN_GSD."
+UD_27_RUSSIAN_GSD_DEV = _UD_27_HOME + "UD_Russian-GSD/ru_gsd-ud-dev.conllu"
+"UD_27 dev set of RUSSIAN_GSD."
+UD_27_RUSSIAN_GSD_TEST = _UD_27_HOME + "UD_Russian-GSD/ru_gsd-ud-test.conllu"
+"UD_27 test set of RUSSIAN_GSD."
+UD_27_RUSSIAN_PUD_TEST = _UD_27_HOME + "UD_Russian-PUD/ru_pud-ud-test.conllu"
+"UD_27 test set of RUSSIAN_PUD."
+UD_27_RUSSIAN_SYNTAGRUS_TRAIN = _UD_27_HOME + "UD_Russian-SynTagRus/ru_syntagrus-ud-train.conllu"
+"UD_27 train set of RUSSIAN_SYNTAGRUS."
+UD_27_RUSSIAN_SYNTAGRUS_DEV = _UD_27_HOME + "UD_Russian-SynTagRus/ru_syntagrus-ud-dev.conllu"
+"UD_27 dev set of RUSSIAN_SYNTAGRUS."
+UD_27_RUSSIAN_SYNTAGRUS_TEST = _UD_27_HOME + "UD_Russian-SynTagRus/ru_syntagrus-ud-test.conllu"
+"UD_27 test set of RUSSIAN_SYNTAGRUS."
+UD_27_RUSSIAN_TAIGA_TRAIN = _UD_27_HOME + "UD_Russian-Taiga/ru_taiga-ud-train.conllu"
+"UD_27 train set of RUSSIAN_TAIGA."
+UD_27_RUSSIAN_TAIGA_DEV = _UD_27_HOME + "UD_Russian-Taiga/ru_taiga-ud-dev.conllu"
+"UD_27 dev set of RUSSIAN_TAIGA."
+UD_27_RUSSIAN_TAIGA_TEST = _UD_27_HOME + "UD_Russian-Taiga/ru_taiga-ud-test.conllu"
+"UD_27 test set of RUSSIAN_TAIGA."
+UD_27_SANSKRIT_UFAL_TEST = _UD_27_HOME + "UD_Sanskrit-UFAL/sa_ufal-ud-test.conllu"
+"UD_27 test set of SANSKRIT_UFAL."
+UD_27_SANSKRIT_VEDIC_TRAIN = _UD_27_HOME + "UD_Sanskrit-Vedic/sa_vedic-ud-train.conllu"
+"UD_27 train set of SANSKRIT_VEDIC."
+UD_27_SANSKRIT_VEDIC_TEST = _UD_27_HOME + "UD_Sanskrit-Vedic/sa_vedic-ud-test.conllu"
+"UD_27 test set of SANSKRIT_VEDIC."
+UD_27_SCOTTISH_GAELIC_ARCOSG_TRAIN = _UD_27_HOME + "UD_Scottish_Gaelic-ARCOSG/gd_arcosg-ud-train.conllu"
+"UD_27 train set of SCOTTISH_GAELIC_ARCOSG."
+UD_27_SCOTTISH_GAELIC_ARCOSG_DEV = _UD_27_HOME + "UD_Scottish_Gaelic-ARCOSG/gd_arcosg-ud-dev.conllu"
+"UD_27 dev set of SCOTTISH_GAELIC_ARCOSG."
+UD_27_SCOTTISH_GAELIC_ARCOSG_TEST = _UD_27_HOME + "UD_Scottish_Gaelic-ARCOSG/gd_arcosg-ud-test.conllu"
+"UD_27 test set of SCOTTISH_GAELIC_ARCOSG."
+UD_27_SERBIAN_SET_TRAIN = _UD_27_HOME + "UD_Serbian-SET/sr_set-ud-train.conllu"
+"UD_27 train set of SERBIAN_SET."
+UD_27_SERBIAN_SET_DEV = _UD_27_HOME + "UD_Serbian-SET/sr_set-ud-dev.conllu"
+"UD_27 dev set of SERBIAN_SET."
+UD_27_SERBIAN_SET_TEST = _UD_27_HOME + "UD_Serbian-SET/sr_set-ud-test.conllu"
+"UD_27 test set of SERBIAN_SET."
+UD_27_SKOLT_SAMI_GIELLAGAS_TEST = _UD_27_HOME + "UD_Skolt_Sami-Giellagas/sms_giellagas-ud-test.conllu"
+"UD_27 test set of SKOLT_SAMI_GIELLAGAS."
+UD_27_SLOVAK_SNK_TRAIN = _UD_27_HOME + "UD_Slovak-SNK/sk_snk-ud-train.conllu"
+"UD_27 train set of SLOVAK_SNK."
+UD_27_SLOVAK_SNK_DEV = _UD_27_HOME + "UD_Slovak-SNK/sk_snk-ud-dev.conllu"
+"UD_27 dev set of SLOVAK_SNK."
+UD_27_SLOVAK_SNK_TEST = _UD_27_HOME + "UD_Slovak-SNK/sk_snk-ud-test.conllu"
+"UD_27 test set of SLOVAK_SNK."
+UD_27_SLOVENIAN_SSJ_TRAIN = _UD_27_HOME + "UD_Slovenian-SSJ/sl_ssj-ud-train.conllu"
+"UD_27 train set of SLOVENIAN_SSJ."
+UD_27_SLOVENIAN_SSJ_DEV = _UD_27_HOME + "UD_Slovenian-SSJ/sl_ssj-ud-dev.conllu"
+"UD_27 dev set of SLOVENIAN_SSJ."
+UD_27_SLOVENIAN_SSJ_TEST = _UD_27_HOME + "UD_Slovenian-SSJ/sl_ssj-ud-test.conllu"
+"UD_27 test set of SLOVENIAN_SSJ."
+UD_27_SLOVENIAN_SST_TRAIN = _UD_27_HOME + "UD_Slovenian-SST/sl_sst-ud-train.conllu"
+"UD_27 train set of SLOVENIAN_SST."
+UD_27_SLOVENIAN_SST_TEST = _UD_27_HOME + "UD_Slovenian-SST/sl_sst-ud-test.conllu"
+"UD_27 test set of SLOVENIAN_SST."
+UD_27_SOI_AHA_TEST = _UD_27_HOME + "UD_Soi-AHA/soj_aha-ud-test.conllu"
+"UD_27 test set of SOI_AHA."
+UD_27_SOUTH_LEVANTINE_ARABIC_MADAR_TEST = _UD_27_HOME + "UD_South_Levantine_Arabic-MADAR/ajp_madar-ud-test.conllu"
+"UD_27 test set of SOUTH_LEVANTINE_ARABIC_MADAR."
+UD_27_SPANISH_ANCORA_TRAIN = _UD_27_HOME + "UD_Spanish-AnCora/es_ancora-ud-train.conllu"
+"UD_27 train set of SPANISH_ANCORA."
+UD_27_SPANISH_ANCORA_DEV = _UD_27_HOME + "UD_Spanish-AnCora/es_ancora-ud-dev.conllu"
+"UD_27 dev set of SPANISH_ANCORA."
+UD_27_SPANISH_ANCORA_TEST = _UD_27_HOME + "UD_Spanish-AnCora/es_ancora-ud-test.conllu"
+"UD_27 test set of SPANISH_ANCORA."
+UD_27_SPANISH_GSD_TRAIN = _UD_27_HOME + "UD_Spanish-GSD/es_gsd-ud-train.conllu"
+"UD_27 train set of SPANISH_GSD."
+UD_27_SPANISH_GSD_DEV = _UD_27_HOME + "UD_Spanish-GSD/es_gsd-ud-dev.conllu"
+"UD_27 dev set of SPANISH_GSD."
+UD_27_SPANISH_GSD_TEST = _UD_27_HOME + "UD_Spanish-GSD/es_gsd-ud-test.conllu"
+"UD_27 test set of SPANISH_GSD."
+UD_27_SPANISH_PUD_TEST = _UD_27_HOME + "UD_Spanish-PUD/es_pud-ud-test.conllu"
+"UD_27 test set of SPANISH_PUD."
+UD_27_SWEDISH_LINES_TRAIN = _UD_27_HOME + "UD_Swedish-LinES/sv_lines-ud-train.conllu"
+"UD_27 train set of SWEDISH_LINES."
+UD_27_SWEDISH_LINES_DEV = _UD_27_HOME + "UD_Swedish-LinES/sv_lines-ud-dev.conllu"
+"UD_27 dev set of SWEDISH_LINES."
+UD_27_SWEDISH_LINES_TEST = _UD_27_HOME + "UD_Swedish-LinES/sv_lines-ud-test.conllu"
+"UD_27 test set of SWEDISH_LINES."
+UD_27_SWEDISH_PUD_TEST = _UD_27_HOME + "UD_Swedish-PUD/sv_pud-ud-test.conllu"
+"UD_27 test set of SWEDISH_PUD."
+UD_27_SWEDISH_TALBANKEN_TRAIN = _UD_27_HOME + "UD_Swedish-Talbanken/sv_talbanken-ud-train.conllu"
+"UD_27 train set of SWEDISH_TALBANKEN."
+UD_27_SWEDISH_TALBANKEN_DEV = _UD_27_HOME + "UD_Swedish-Talbanken/sv_talbanken-ud-dev.conllu"
+"UD_27 dev set of SWEDISH_TALBANKEN."
+UD_27_SWEDISH_TALBANKEN_TEST = _UD_27_HOME + "UD_Swedish-Talbanken/sv_talbanken-ud-test.conllu"
+"UD_27 test set of SWEDISH_TALBANKEN."
+UD_27_SWEDISH_SIGN_LANGUAGE_SSLC_TRAIN = _UD_27_HOME + "UD_Swedish_Sign_Language-SSLC/swl_sslc-ud-train.conllu"
+"UD_27 train set of SWEDISH_SIGN_LANGUAGE_SSLC."
+UD_27_SWEDISH_SIGN_LANGUAGE_SSLC_DEV = _UD_27_HOME + "UD_Swedish_Sign_Language-SSLC/swl_sslc-ud-dev.conllu"
+"UD_27 dev set of SWEDISH_SIGN_LANGUAGE_SSLC."
+UD_27_SWEDISH_SIGN_LANGUAGE_SSLC_TEST = _UD_27_HOME + "UD_Swedish_Sign_Language-SSLC/swl_sslc-ud-test.conllu"
+"UD_27 test set of SWEDISH_SIGN_LANGUAGE_SSLC."
+UD_27_SWISS_GERMAN_UZH_TEST = _UD_27_HOME + "UD_Swiss_German-UZH/gsw_uzh-ud-test.conllu"
+"UD_27 test set of SWISS_GERMAN_UZH."
+UD_27_TAGALOG_TRG_TEST = _UD_27_HOME + "UD_Tagalog-TRG/tl_trg-ud-test.conllu"
+"UD_27 test set of TAGALOG_TRG."
+UD_27_TAGALOG_UGNAYAN_TEST = _UD_27_HOME + "UD_Tagalog-Ugnayan/tl_ugnayan-ud-test.conllu"
+"UD_27 test set of TAGALOG_UGNAYAN."
+UD_27_TAMIL_MWTT_TEST = _UD_27_HOME + "UD_Tamil-MWTT/ta_mwtt-ud-test.conllu"
+"UD_27 test set of TAMIL_MWTT."
+UD_27_TAMIL_TTB_TRAIN = _UD_27_HOME + "UD_Tamil-TTB/ta_ttb-ud-train.conllu"
+"UD_27 train set of TAMIL_TTB."
+UD_27_TAMIL_TTB_DEV = _UD_27_HOME + "UD_Tamil-TTB/ta_ttb-ud-dev.conllu"
+"UD_27 dev set of TAMIL_TTB."
+UD_27_TAMIL_TTB_TEST = _UD_27_HOME + "UD_Tamil-TTB/ta_ttb-ud-test.conllu"
+"UD_27 test set of TAMIL_TTB."
+UD_27_TELUGU_MTG_TRAIN = _UD_27_HOME + "UD_Telugu-MTG/te_mtg-ud-train.conllu"
+"UD_27 train set of TELUGU_MTG."
+UD_27_TELUGU_MTG_DEV = _UD_27_HOME + "UD_Telugu-MTG/te_mtg-ud-dev.conllu"
+"UD_27 dev set of TELUGU_MTG."
+UD_27_TELUGU_MTG_TEST = _UD_27_HOME + "UD_Telugu-MTG/te_mtg-ud-test.conllu"
+"UD_27 test set of TELUGU_MTG."
+UD_27_THAI_PUD_TEST = _UD_27_HOME + "UD_Thai-PUD/th_pud-ud-test.conllu"
+"UD_27 test set of THAI_PUD."
+UD_27_TUPINAMBA_TUDET_TEST = _UD_27_HOME + "UD_Tupinamba-TuDeT/tpn_tudet-ud-test.conllu"
+"UD_27 test set of TUPINAMBA_TUDET."
+UD_27_TURKISH_BOUN_TRAIN = _UD_27_HOME + "UD_Turkish-BOUN/tr_boun-ud-train.conllu"
+"UD_27 train set of TURKISH_BOUN."
+UD_27_TURKISH_BOUN_DEV = _UD_27_HOME + "UD_Turkish-BOUN/tr_boun-ud-dev.conllu"
+"UD_27 dev set of TURKISH_BOUN."
+UD_27_TURKISH_BOUN_TEST = _UD_27_HOME + "UD_Turkish-BOUN/tr_boun-ud-test.conllu"
+"UD_27 test set of TURKISH_BOUN."
+UD_27_TURKISH_GB_TEST = _UD_27_HOME + "UD_Turkish-GB/tr_gb-ud-test.conllu"
+"UD_27 test set of TURKISH_GB."
+UD_27_TURKISH_IMST_TRAIN = _UD_27_HOME + "UD_Turkish-IMST/tr_imst-ud-train.conllu"
+"UD_27 train set of TURKISH_IMST."
+UD_27_TURKISH_IMST_DEV = _UD_27_HOME + "UD_Turkish-IMST/tr_imst-ud-dev.conllu"
+"UD_27 dev set of TURKISH_IMST."
+UD_27_TURKISH_IMST_TEST = _UD_27_HOME + "UD_Turkish-IMST/tr_imst-ud-test.conllu"
+"UD_27 test set of TURKISH_IMST."
+UD_27_TURKISH_PUD_TEST = _UD_27_HOME + "UD_Turkish-PUD/tr_pud-ud-test.conllu"
+"UD_27 test set of TURKISH_PUD."
+UD_27_TURKISH_GERMAN_SAGT_TRAIN = _UD_27_HOME + "UD_Turkish_German-SAGT/qtd_sagt-ud-train.conllu"
+"UD_27 train set of TURKISH_GERMAN_SAGT."
+UD_27_TURKISH_GERMAN_SAGT_DEV = _UD_27_HOME + "UD_Turkish_German-SAGT/qtd_sagt-ud-dev.conllu"
+"UD_27 dev set of TURKISH_GERMAN_SAGT."
+UD_27_TURKISH_GERMAN_SAGT_TEST = _UD_27_HOME + "UD_Turkish_German-SAGT/qtd_sagt-ud-test.conllu"
+"UD_27 test set of TURKISH_GERMAN_SAGT."
+UD_27_UKRAINIAN_IU_TRAIN = _UD_27_HOME + "UD_Ukrainian-IU/uk_iu-ud-train.conllu"
+"UD_27 train set of UKRAINIAN_IU."
+UD_27_UKRAINIAN_IU_DEV = _UD_27_HOME + "UD_Ukrainian-IU/uk_iu-ud-dev.conllu"
+"UD_27 dev set of UKRAINIAN_IU."
+UD_27_UKRAINIAN_IU_TEST = _UD_27_HOME + "UD_Ukrainian-IU/uk_iu-ud-test.conllu"
+"UD_27 test set of UKRAINIAN_IU."
+UD_27_UPPER_SORBIAN_UFAL_TRAIN = _UD_27_HOME + "UD_Upper_Sorbian-UFAL/hsb_ufal-ud-train.conllu"
+"UD_27 train set of UPPER_SORBIAN_UFAL."
+UD_27_UPPER_SORBIAN_UFAL_TEST = _UD_27_HOME + "UD_Upper_Sorbian-UFAL/hsb_ufal-ud-test.conllu"
+"UD_27 test set of UPPER_SORBIAN_UFAL."
+UD_27_URDU_UDTB_TRAIN = _UD_27_HOME + "UD_Urdu-UDTB/ur_udtb-ud-train.conllu"
+"UD_27 train set of URDU_UDTB."
+UD_27_URDU_UDTB_DEV = _UD_27_HOME + "UD_Urdu-UDTB/ur_udtb-ud-dev.conllu"
+"UD_27 dev set of URDU_UDTB."
+UD_27_URDU_UDTB_TEST = _UD_27_HOME + "UD_Urdu-UDTB/ur_udtb-ud-test.conllu"
+"UD_27 test set of URDU_UDTB."
+UD_27_UYGHUR_UDT_TRAIN = _UD_27_HOME + "UD_Uyghur-UDT/ug_udt-ud-train.conllu"
+"UD_27 train set of UYGHUR_UDT."
+UD_27_UYGHUR_UDT_DEV = _UD_27_HOME + "UD_Uyghur-UDT/ug_udt-ud-dev.conllu"
+"UD_27 dev set of UYGHUR_UDT."
+UD_27_UYGHUR_UDT_TEST = _UD_27_HOME + "UD_Uyghur-UDT/ug_udt-ud-test.conllu"
+"UD_27 test set of UYGHUR_UDT."
+UD_27_VIETNAMESE_VTB_TRAIN = _UD_27_HOME + "UD_Vietnamese-VTB/vi_vtb-ud-train.conllu"
+"UD_27 train set of VIETNAMESE_VTB."
+UD_27_VIETNAMESE_VTB_DEV = _UD_27_HOME + "UD_Vietnamese-VTB/vi_vtb-ud-dev.conllu"
+"UD_27 dev set of VIETNAMESE_VTB."
+UD_27_VIETNAMESE_VTB_TEST = _UD_27_HOME + "UD_Vietnamese-VTB/vi_vtb-ud-test.conllu"
+"UD_27 test set of VIETNAMESE_VTB."
+UD_27_WARLPIRI_UFAL_TEST = _UD_27_HOME + "UD_Warlpiri-UFAL/wbp_ufal-ud-test.conllu"
+"UD_27 test set of WARLPIRI_UFAL."
+UD_27_WELSH_CCG_TRAIN = _UD_27_HOME + "UD_Welsh-CCG/cy_ccg-ud-train.conllu"
+"UD_27 train set of WELSH_CCG."
+UD_27_WELSH_CCG_TEST = _UD_27_HOME + "UD_Welsh-CCG/cy_ccg-ud-test.conllu"
+"UD_27 test set of WELSH_CCG."
+UD_27_WOLOF_WTB_TRAIN = _UD_27_HOME + "UD_Wolof-WTB/wo_wtb-ud-train.conllu"
+"UD_27 train set of WOLOF_WTB."
+UD_27_WOLOF_WTB_DEV = _UD_27_HOME + "UD_Wolof-WTB/wo_wtb-ud-dev.conllu"
+"UD_27 dev set of WOLOF_WTB."
+UD_27_WOLOF_WTB_TEST = _UD_27_HOME + "UD_Wolof-WTB/wo_wtb-ud-test.conllu"
+"UD_27 test set of WOLOF_WTB."
+UD_27_YORUBA_YTB_TEST = _UD_27_HOME + "UD_Yoruba-YTB/yo_ytb-ud-test.conllu"
+"UD_27 test set of YORUBA_YTB."
diff --git a/hanlp/datasets/parsing/ud/ud27m.py b/hanlp/datasets/parsing/ud/ud27m.py
new file mode 100644
index 000000000..5b5b984e4
--- /dev/null
+++ b/hanlp/datasets/parsing/ud/ud27m.py
@@ -0,0 +1,15 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-05-21 20:39
+import os
+
+from hanlp.datasets.parsing.ud import concat_treebanks
+from hanlp.datasets.parsing.ud.ud27 import _UD_27_HOME
+
+_UD_27_MULTILINGUAL_HOME = concat_treebanks(_UD_27_HOME, '2.7')
+UD_27_MULTILINGUAL_TRAIN = os.path.join(_UD_27_MULTILINGUAL_HOME, 'train.conllu')
+"Training set of multilingual UD_27 obtained by concatenating all training sets."
+UD_27_MULTILINGUAL_DEV = os.path.join(_UD_27_MULTILINGUAL_HOME, 'dev.conllu')
+"Dev set of multilingual UD_27 obtained by concatenating all dev sets."
+UD_27_MULTILINGUAL_TEST = os.path.join(_UD_27_MULTILINGUAL_HOME, 'test.conllu')
+"Test set of multilingual UD_27 obtained by concatenating all test sets."
diff --git a/hanlp/datasets/pos/ctb.py b/hanlp/datasets/pos/ctb.py
deleted file mode 100644
index 50a865871..000000000
--- a/hanlp/datasets/pos/ctb.py
+++ /dev/null
@@ -1,9 +0,0 @@
-# -*- coding:utf-8 -*-
-# Author: hankcs
-# Date: 2019-12-28 22:51
-
-CTB5_POS_HOME = 'http://file.hankcs.com/corpus/ctb5.1-pos.zip'
-
-CTB5_POS_TRAIN = f'{CTB5_POS_HOME}#train.tsv'
-CTB5_POS_VALID = f'{CTB5_POS_HOME}#dev.tsv'
-CTB5_POS_TEST = f'{CTB5_POS_HOME}#test.tsv'
diff --git a/hanlp/datasets/pos/ctb5.py b/hanlp/datasets/pos/ctb5.py
new file mode 100644
index 000000000..8eb1d450d
--- /dev/null
+++ b/hanlp/datasets/pos/ctb5.py
@@ -0,0 +1,12 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-12-28 22:51
+
+_CTB5_POS_HOME = 'http://file.hankcs.com/corpus/ctb5.1-pos.zip'
+
+CTB5_POS_TRAIN = f'{_CTB5_POS_HOME}#train.tsv'
+'''PoS training set for CTB5.'''
+CTB5_POS_DEV = f'{_CTB5_POS_HOME}#dev.tsv'
+'''PoS dev set for CTB5.'''
+CTB5_POS_TEST = f'{_CTB5_POS_HOME}#test.tsv'
+'''PoS test set for CTB5.'''
diff --git a/hanlp/datasets/qa/__init__.py b/hanlp/datasets/qa/__init__.py
new file mode 100644
index 000000000..d4de204f5
--- /dev/null
+++ b/hanlp/datasets/qa/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-03-20 19:17
\ No newline at end of file
diff --git a/hanlp/datasets/qa/hotpotqa.py b/hanlp/datasets/qa/hotpotqa.py
new file mode 100644
index 000000000..9cbeb10f9
--- /dev/null
+++ b/hanlp/datasets/qa/hotpotqa.py
@@ -0,0 +1,194 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-03-20 19:46
+from enum import Enum, auto
+
+import torch
+import ujson
+from torch.nn.utils.rnn import pad_sequence
+
+from hanlp.common.dataset import TransformableDataset
+from hanlp_common.util import merge_list_of_dict
+
+HOTPOT_QA_TRAIN = 'http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_train_v1.1.json'
+HOTPOT_QA_DISTRACTOR_DEV = 'http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_distractor_v1.json'
+HOTPOT_QA_FULLWIKI_DEV = 'http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_fullwiki_v1.json'
+
+
+class HotpotQADataset(TransformableDataset):
+
+ def load_file(self, filepath):
+ with open(filepath) as fd:
+ return ujson.load(fd)
+
+
+class BuildGraph(object):
+
+ def __init__(self, dst='graph') -> None:
+ super().__init__()
+ self.dst = dst
+
+ def __call__(self, sample: dict):
+ sample[self.dst] = build_graph(sample)
+ return sample
+
+
+def hotpotqa_collate_fn(samples):
+ batch = merge_list_of_dict(samples)
+ max_seq_len = len(max([x['graph'] for x in samples], key=len))
+ arc = torch.zeros([len(samples), max_seq_len, max_seq_len])
+ token_offset = torch.zeros([len(samples), max_seq_len], dtype=torch.long)
+ src_mask = torch.zeros([len(samples), max_seq_len], dtype=torch.bool)
+ sp_candidate_mask = torch.zeros([len(samples), max_seq_len], dtype=torch.bool)
+ sp_label = torch.zeros([len(samples), max_seq_len], dtype=torch.float)
+ # sp = torch.zeros([len(samples), max_seq_len], dtype=torch.bool)
+ tokens = []
+ offset = 0
+ for i, sample in enumerate(samples):
+ graph = sample['graph']
+ for j, u in enumerate(graph):
+ u: Vertex = u
+ for v in u.to:
+ v: Vertex = v
+ arc[i, v.id, u.id] = 1
+ arc[i, u.id, v.id] = 1
+ # record each vertex's token offset
+ token_offset[i, u.id] = offset
+ src_mask[i, u.id] = True
+ sp_candidate_mask[i, u.id] = u.is_sp_root_candidate()
+ sp_label[i, u.id] = u.is_sp_root()
+ offset += 1
+ tokens.extend(sample['token_id'])
+ seq_lengths = torch.LongTensor(list(map(len, tokens)))
+ tokens = [torch.LongTensor(x) for x in tokens]
+ tokens = pad_sequence(tokens, batch_first=True)
+ batch['adj'] = arc
+ batch['tokens'] = tokens
+ batch['src_mask'] = src_mask
+ batch['seq_lengths'] = seq_lengths
+ batch['token_offset'] = token_offset
+ batch['sp_candidate_mask'] = sp_candidate_mask
+ batch['sp_label'] = sp_label
+ return batch
+
+
+def flat_sentence(sample: dict) -> dict:
+ sample['token'] = token = []
+ for sent in sample['parsed_sentences']:
+ token.append(['bos'] + [x.lower() for x in sent[0]])
+ return sample
+
+
+def create_sp_label(sample: dict) -> dict:
+ sample['sp_label'] = sp_label = []
+
+ def label(title_, index_):
+ for t, i in sample['supporting_facts']:
+ if t == title_ and i == index_:
+ return 1
+ return 0
+
+ for context in sample['context']:
+ title, sents = context
+ for idx, sent in enumerate(sents):
+ sp_label.append(label(title, idx))
+ assert len(sample['supporting_facts']) == sum(sp_label)
+ return sample
+
+
+class Type(Enum):
+ Q_ROOT = auto()
+ Q_WORD = auto()
+ SP_ROOT = auto()
+ SP_WORD = auto()
+ NON_SP_ROOT = auto()
+ NON_SP_WORD = auto()
+ DOCUMENT_TITLE = auto()
+
+
+class Vertex(object):
+
+ def __init__(self, id, type: Type, text=None) -> None:
+ super().__init__()
+ self.id = id
+ self.type = type
+ if not text:
+ text = str(type).split('.')[-1]
+ self.text = text
+ self.to = []
+ self.rel = []
+
+ def connect(self, to, rel):
+ self.to.append(to)
+ self.rel.append(rel)
+
+ def __str__(self) -> str:
+ return f'{self.text} {self.id}'
+
+ def __hash__(self) -> int:
+ return self.id
+
+ def is_word(self):
+ return self.type in {Type.SP_WORD, Type.Q_WORD, Type.NON_SP_WORD}
+
+ def is_question(self):
+ return self.type in {Type.Q_ROOT, Type.Q_WORD}
+
+ def is_sp(self):
+ return self.type in {Type.SP_ROOT, Type.SP_WORD}
+
+ def is_sp_root(self):
+ return self.type in {Type.SP_ROOT}
+
+ def is_sp_root_candidate(self):
+ return self.type in {Type.SP_ROOT, Type.NON_SP_ROOT}
+
+
+def build_graph(each: dict, debug=False):
+ raw_sents = []
+ raw_sents.append(each['question'])
+ sp_idx = set()
+ sp_sents = {}
+ for sp in each['supporting_facts']:
+ title, offset = sp
+ ids = sp_sents.get(title, None)
+ if ids is None:
+ sp_sents[title] = ids = set()
+ ids.add(offset)
+ idx = 1
+ for document in each['context']:
+ title, sents = document
+ raw_sents += sents
+ for i, s in enumerate(sents):
+ if title in sp_sents and i in sp_sents[title]:
+ sp_idx.add(idx)
+ idx += 1
+ assert idx == len(raw_sents)
+ parsed_sents = each['parsed_sentences']
+ assert len(raw_sents) == len(parsed_sents)
+ graph = []
+ for idx, (raw, sent) in enumerate(zip(raw_sents, parsed_sents)):
+ if debug:
+ if idx > 1 and idx not in sp_idx:
+ continue
+ offset = len(graph)
+ if idx == 0:
+ if debug:
+ print(f'Question: {raw}')
+ graph.append(Vertex(len(graph), Type.Q_ROOT))
+ else:
+ if debug:
+ if idx in sp_idx:
+ print(f'Supporting Fact: {raw}')
+ graph.append(Vertex(len(graph), Type.SP_ROOT if idx in sp_idx else Type.NON_SP_ROOT))
+ tokens, heads, deprels = sent
+ for t, h, d in zip(tokens, heads, deprels):
+ graph.append(
+ Vertex(len(graph), (Type.SP_WORD if idx in sp_idx else Type.NON_SP_WORD) if idx else Type.Q_WORD, t))
+ for i, (h, d) in enumerate(zip(heads, deprels)):
+ graph[offset + h].connect(graph[offset + i + 1], d)
+ q_root = graph[0]
+ for u in graph:
+ if u.type == Type.SP_ROOT or u.type == Type.NON_SP_ROOT:
+ q_root.connect(u, 'supporting fact?')
+ return graph
diff --git a/hanlp/datasets/srl/__init__.py b/hanlp/datasets/srl/__init__.py
new file mode 100644
index 000000000..107252885
--- /dev/null
+++ b/hanlp/datasets/srl/__init__.py
@@ -0,0 +1,5 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-06-22 19:15
+
+
diff --git a/hanlp/datasets/srl/conll2012.py b/hanlp/datasets/srl/conll2012.py
new file mode 100644
index 000000000..9d7206e83
--- /dev/null
+++ b/hanlp/datasets/srl/conll2012.py
@@ -0,0 +1,227 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-06-22 19:15
+import glob
+import json
+import os
+from typing import Union, List, Callable
+
+from alnlp.metrics.span_utils import enumerate_spans
+
+from hanlp.common.dataset import TransformableDataset
+from hanlp.common.transform import NamedTransform
+from hanlp.utils.io_util import read_tsv_as_sents, get_resource, TimingFileIterator
+from hanlp.utils.time_util import CountdownTimer
+
+
+class CoNLL2012BIOSRLDataset(TransformableDataset):
+ def load_file(self, filepath: str):
+ filepath = get_resource(filepath)
+ if os.path.isfile(filepath):
+ files = [filepath]
+ else:
+ assert os.path.isdir(filepath), f'{filepath} has to be a directory of CoNLL 2012'
+ files = sorted(glob.glob(f'{filepath}/**/*gold_conll', recursive=True))
+ timer = CountdownTimer(len(files))
+ for fid, f in enumerate(files):
+ timer.log(f'files loading[blink][yellow]...[/yellow][/blink]')
+ # 0:DOCUMENT 1:PART 2:INDEX 3:WORD 4:POS 5:PARSE 6:LEMMA 7:FRAME 8:SENSE 9:SPEAKER 10:NE 11-N:ARGS N:COREF
+ for sent in read_tsv_as_sents(f, ignore_prefix='#'):
+ sense = [cell[7] for cell in sent]
+ props = [cell[11:-1] for cell in sent]
+ props = map(lambda p: p, zip(*props))
+ prd_bio_labels = [self._make_bio_labels(prop) for prop in props]
+ prd_bio_labels = [self._remove_B_V(x) for x in prd_bio_labels]
+ prd_indices = [i for i, x in enumerate(sense) if x != '-']
+ token = [x[3] for x in sent]
+ srl = [None for x in token]
+ for idx, labels in zip(prd_indices, prd_bio_labels):
+ srl[idx] = labels
+ srl = [x if x else ['O'] * len(token) for x in srl]
+ yield {'token': token, 'srl': srl}
+
+ @staticmethod
+ def _make_bio_labels(prop):
+ """Copied from https://github.com/hiroki13/span-based-srl/blob/2c8b677c4e00b6c607e09ef4f9fe3d54961e4f2e/src/utils/sent.py#L42
+
+ Args:
+ prop: 1D: n_words; elem=bracket label
+
+ Returns:
+ 1D: n_words; elem=BIO label
+
+ """
+ labels = []
+ prev = None
+ for arg in prop:
+ if arg.startswith('('):
+ if arg.endswith(')'):
+ prev = arg.split("*")[0][1:]
+ label = 'B-' + prev
+ prev = None
+ else:
+ prev = arg[1:-1]
+ label = 'B-' + prev
+ else:
+ if prev:
+ label = 'I-' + prev
+ if arg.endswith(')'):
+ prev = None
+ else:
+ label = 'O'
+ labels.append(label)
+ return labels
+
+ @staticmethod
+ def _remove_B_V(labels):
+ return ['O' if x == 'B-V' else x for x in labels]
+
+
+class CoNLL2012SRLDataset(TransformableDataset):
+
+ def __init__(self,
+ data: Union[str, List],
+ transform: Union[Callable, List] = None,
+ cache=None,
+ doc_level_offset=True,
+ generate_idx=None) -> None:
+ self.doc_level_offset = doc_level_offset
+ super().__init__(data, transform, cache, generate_idx=generate_idx)
+
+ def load_file(self, filepath: str):
+ """Load ``.jsonlines`` CoNLL12-style corpus. Samples of this corpus can be found using the following scripts.
+
+ .. highlight:: python
+ .. code-block:: python
+
+ import json
+ from hanlp_common.document import Document
+ from hanlp.datasets.srl.ontonotes5.chinese import ONTONOTES5_CONLL12_CHINESE_DEV
+ from hanlp.utils.io_util import get_resource
+
+ with open(get_resource(ONTONOTES5_CONLL12_CHINESE_DEV)) as src:
+ for line in src:
+ doc = json.loads(line)
+ print(Document(doc))
+ break
+
+ Args:
+ filepath: ``.jsonlines`` CoNLL12 corpus.
+ """
+ filename = os.path.basename(filepath)
+ reader = TimingFileIterator(filepath)
+ num_docs, num_sentences = 0, 0
+ for line in reader:
+ doc = json.loads(line)
+ num_docs += 1
+ num_tokens_in_doc = 0
+ for sid, (sentence, srl) in enumerate(zip(doc['sentences'], doc['srl'])):
+ if self.doc_level_offset:
+ srl = [(x[0] - num_tokens_in_doc, x[1] - num_tokens_in_doc, x[2] - num_tokens_in_doc, x[3]) for x in
+ srl]
+ else:
+ srl = [(x[0], x[1], x[2], x[3]) for x in srl]
+ for x in srl:
+ if any([o < 0 for o in x[:3]]):
+ raise ValueError(f'Negative offset occurred, maybe doc_level_offset=False')
+ if any([o >= len(sentence) for o in x[:3]]):
+ raise ValueError('Offset exceeds sentence length, maybe doc_level_offset=True')
+ deduplicated_srl = set()
+ pa_set = set()
+ for p, b, e, l in srl:
+ pa = (p, b, e)
+ if pa in pa_set:
+ continue
+ pa_set.add(pa)
+ deduplicated_srl.add((p, b, e, l))
+ yield self.build_sample(sentence, deduplicated_srl, doc, sid)
+ num_sentences += 1
+ num_tokens_in_doc += len(sentence)
+ reader.log(
+ f'{filename} {num_docs} documents, {num_sentences} sentences [blink][yellow]...[/yellow][/blink]')
+ reader.erase()
+
+ # noinspection PyMethodMayBeStatic
+ def build_sample(self, sentence, deduplicated_srl, doc, sid):
+ return {
+ 'token': sentence,
+ 'srl': deduplicated_srl
+ }
+
+
+def group_pa_by_p(sample: dict) -> dict:
+ if 'srl' in sample:
+ srl: list = sample['srl']
+ grouped_srl = group_pa_by_p_(srl)
+ sample['srl'] = grouped_srl
+ return sample
+
+
+def group_pa_by_p_(srl):
+ grouped_srl = {}
+ for p, b, e, l in srl:
+ bel = grouped_srl.get(p, None)
+ if not bel:
+ bel = grouped_srl[p] = set()
+ bel.add((b, e, l))
+ return grouped_srl
+
+
+def filter_v_args(sample: dict) -> dict:
+ if 'srl' in sample:
+ sample['srl'] = [t for t in sample['srl'] if t[-1] not in ["V", "C-V"]]
+ return sample
+
+
+def unpack_srl(sample: dict) -> dict:
+ if 'srl' in sample:
+ srl = sample['srl']
+ predicate_offset = [x[0] for x in srl]
+ argument_begin_offset = [x[1] for x in srl]
+ argument_end_offset = [x[2] for x in srl]
+ srl_label = [x[-1] for x in srl]
+ sample.update({
+ 'predicate_offset': predicate_offset,
+ 'argument_begin_offset': argument_begin_offset,
+ 'argument_end_offset': argument_end_offset,
+ 'srl_label': srl_label, # We can obtain mask by srl_label > 0
+ # 'srl_mask': len(srl_label),
+ })
+ return sample
+
+
+class SpanCandidatesGenerator(NamedTransform):
+
+ def __init__(self, src: str, dst: str = None, max_span_width=None) -> None:
+ if not dst:
+ dst = f'{src}_span'
+ super().__init__(src, dst)
+ self.max_span_width = max_span_width
+
+ def __call__(self, sample: dict) -> dict:
+ sample[self.dst] = list(enumerate_spans(sample[self.src], max_span_width=self.max_span_width))
+ return sample
+
+
+class CoNLL2012SRLBIODataset(CoNLL2012SRLDataset):
+ def build_sample(self, tokens, deduplicated_srl, doc, sid):
+ # Convert srl to exclusive format
+ deduplicated_srl = set((x[0], x[1], x[2] + 1, x[3]) for x in deduplicated_srl if x[3] != 'V')
+ labels = [['O'] * len(tokens) for _ in range(len(tokens))]
+ srl = group_pa_by_p_(deduplicated_srl)
+ for p, args in sorted(srl.items()):
+ labels_per_p = labels[p]
+ for start, end, label in args:
+ assert end > start
+ assert label != 'V' # We don't predict predicate
+ labels_per_p[start] = 'B-' + label
+ for j in range(start + 1, end):
+ labels_per_p[j] = 'I-' + label
+ sample = {
+ 'token': tokens,
+ 'srl': labels,
+ 'srl_set': deduplicated_srl,
+ }
+ if 'pos' in doc:
+ sample['pos'] = doc['pos'][sid]
+ return sample
diff --git a/hanlp/datasets/srl/ontonotes5/__init__.py b/hanlp/datasets/srl/ontonotes5/__init__.py
new file mode 100644
index 000000000..bb9fb0a49
--- /dev/null
+++ b/hanlp/datasets/srl/ontonotes5/__init__.py
@@ -0,0 +1,5 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-11-26 16:07
+ONTONOTES5_HOME = 'https://catalog.ldc.upenn.edu/LDC2013T19/ontonotes-release-5.0.tgz#data/'
+CONLL12_HOME = ONTONOTES5_HOME + '../conll-2012/'
diff --git a/hanlp/datasets/srl/ontonotes5/_utils.py b/hanlp/datasets/srl/ontonotes5/_utils.py
new file mode 100644
index 000000000..afc9547a5
--- /dev/null
+++ b/hanlp/datasets/srl/ontonotes5/_utils.py
@@ -0,0 +1,450 @@
+#!/usr/bin/env python
+import codecs
+import collections
+import glob
+import json
+import os
+import re
+import sys
+import shutil
+from hanlp.utils.io_util import merge_files, get_resource, pushd, run_cmd
+from hanlp_common.io import eprint
+from pprint import pprint
+
+from hanlp.utils.log_util import flash
+
+BEGIN_DOCUMENT_REGEX = re.compile(r"#begin document \((.*)\); part (\d+)")
+
+
+def flatten(l):
+ return [item for sublist in l for item in sublist]
+
+
+def get_doc_key(doc_id, part):
+ return "{}_{}".format(doc_id, int(part))
+
+
+class DocumentState(object):
+ def __init__(self):
+ self.doc_key = None
+ self.text = []
+ self.text_speakers = []
+ self.speakers = []
+ self.sentences = []
+ self.pos = []
+ self.lemma = []
+ self.pos_buffer = []
+ self.lemma_buffer = []
+ self.constituents = [] # {}
+ self.const_stack = []
+ self.const_buffer = []
+ self.ner = []
+ self.ner_stack = []
+ self.ner_buffer = []
+ self.srl = []
+ self.argument_stacks = []
+ self.argument_buffers = []
+ self.predicate_buffer = []
+ self.clusters = collections.defaultdict(list)
+ self.coref_stacks = collections.defaultdict(list)
+
+ def assert_empty(self):
+ assert self.doc_key is None
+ assert len(self.text) == 0
+ assert len(self.text_speakers) == 0
+ assert len(self.speakers) == 0
+ assert len(self.sentences) == 0
+ assert len(self.srl) == 0
+ assert len(self.predicate_buffer) == 0
+ assert len(self.argument_buffers) == 0
+ assert len(self.argument_stacks) == 0
+ assert len(self.constituents) == 0
+ assert len(self.const_stack) == 0
+ assert len(self.const_buffer) == 0
+ assert len(self.ner) == 0
+ assert len(self.lemma_buffer) == 0
+ assert len(self.pos_buffer) == 0
+ assert len(self.ner_stack) == 0
+ assert len(self.ner_buffer) == 0
+ assert len(self.coref_stacks) == 0
+ assert len(self.clusters) == 0
+
+ def assert_finalizable(self):
+ assert self.doc_key is not None
+ assert len(self.text) == 0
+ assert len(self.text_speakers) == 0
+ assert len(self.speakers) > 0
+ assert len(self.sentences) > 0
+ assert len(self.constituents) > 0
+ assert len(self.const_stack) == 0
+ assert len(self.ner_stack) == 0
+ assert len(self.predicate_buffer) == 0
+ assert all(len(s) == 0 for s in list(self.coref_stacks.values()))
+
+ def finalize_sentence(self):
+ self.sentences.append(tuple(self.text))
+ del self.text[:]
+ self.lemma.append(tuple(self.lemma_buffer))
+ del self.lemma_buffer[:]
+ self.pos.append(tuple(self.pos_buffer))
+ del self.pos_buffer[:]
+ self.speakers.append(tuple(self.text_speakers))
+ del self.text_speakers[:]
+
+ assert len(self.predicate_buffer) == len(self.argument_buffers)
+ self.srl.append([])
+ for pred, args in zip(self.predicate_buffer, self.argument_buffers):
+ for start, end, label in args:
+ self.srl[-1].append((pred, start, end, label))
+ self.predicate_buffer = []
+ self.argument_buffers = []
+ self.argument_stacks = []
+ self.constituents.append([c for c in self.const_buffer])
+ self.const_buffer = []
+ self.ner.append([c for c in self.ner_buffer])
+ self.ner_buffer = []
+
+ def finalize(self):
+ merged_clusters = []
+ for c1 in list(self.clusters.values()):
+ existing = None
+ for m in c1:
+ for c2 in merged_clusters:
+ if m in c2:
+ existing = c2
+ break
+ if existing is not None:
+ break
+ if existing is not None:
+ print("Merging clusters (shouldn't happen very often.)")
+ existing.update(c1)
+ else:
+ merged_clusters.append(set(c1))
+ merged_clusters = [list(c) for c in merged_clusters]
+ all_mentions = flatten(merged_clusters)
+ assert len(all_mentions) == len(set(all_mentions))
+ assert len(self.sentences) == len(self.srl)
+ assert len(self.sentences) == len(self.constituents)
+ assert len(self.sentences) == len(self.ner)
+ return {
+ "doc_key": self.doc_key,
+ "sentences": self.sentences,
+ "lemma": self.lemma,
+ "pos": self.pos,
+ "speakers": self.speakers,
+ "srl": self.srl,
+ "constituents": self.constituents,
+ "ner": self.ner,
+ "clusters": merged_clusters
+ }
+
+
+def filter_data(v5_input_file, doc_ids_file, output_file):
+ """Filter OntoNotes5 data based on CoNLL2012 (coref) doc ids.
+ https://github.com/bcmi220/unisrl/blob/master/scripts/filter_conll2012_data.py
+
+ Args:
+ v5_input_file: param doc_ids_file:
+ output_file:
+ doc_ids_file:
+
+ Returns:
+
+ """
+ doc_count = 0
+ sentence_count = 0
+ srl_count = 0
+ ner_count = 0
+ cluster_count = 0
+ word_count = 0
+ doc_ids = []
+ doc_ids_to_keys = {}
+ filtered_examples = {}
+
+ with open(doc_ids_file, "r") as f:
+ for line in f:
+ doc_id = line.strip().split("annotations/")[1]
+ doc_ids.append(doc_id)
+ doc_ids_to_keys[doc_id] = []
+ f.close()
+
+ with codecs.open(v5_input_file, "r", "utf8") as f:
+ for jsonline in f:
+ example = json.loads(jsonline)
+ doc_key = example["doc_key"]
+ dk_prefix = "_".join(doc_key.split("_")[:-1])
+ if dk_prefix not in doc_ids_to_keys:
+ continue
+ doc_ids_to_keys[dk_prefix].append(doc_key)
+ filtered_examples[doc_key] = example
+
+ sentences = example["sentences"]
+ word_count += sum([len(s) for s in sentences])
+ sentence_count += len(sentences)
+ srl_count += sum([len(srl) for srl in example["srl"]])
+ ner_count += sum([len(ner) for ner in example["ner"]])
+ coref = example["clusters"]
+ cluster_count += len(coref)
+ doc_count += 1
+ f.close()
+
+ print(("Documents: {}\nSentences: {}\nWords: {}\nNER: {}, PAS: {}, Clusters: {}".format(
+ doc_count, sentence_count, word_count, ner_count, srl_count, cluster_count)))
+
+ with codecs.open(output_file, "w", "utf8") as f:
+ for doc_id in doc_ids:
+ for key in doc_ids_to_keys[doc_id]:
+ f.write(json.dumps(filtered_examples[key], ensure_ascii=False))
+ f.write("\n")
+ f.close()
+
+
+def normalize_word(word, language):
+ if language == "arabic":
+ word = word[:word.find("#")]
+ if word == "/." or word == "/?":
+ return word[1:]
+ else:
+ return word
+
+
+def handle_bit(word_index, bit, stack, spans, label_set):
+ asterisk_idx = bit.find("*")
+ if asterisk_idx >= 0:
+ open_parens = bit[:asterisk_idx]
+ close_parens = bit[asterisk_idx + 1:]
+ else:
+ open_parens = bit[:-1]
+ close_parens = bit[-1]
+
+ current_idx = open_parens.find("(")
+ while current_idx >= 0:
+ next_idx = open_parens.find("(", current_idx + 1)
+ if next_idx >= 0:
+ label = open_parens[current_idx + 1:next_idx]
+ else:
+ label = open_parens[current_idx + 1:]
+ label_set.add(label)
+ stack.append((word_index, label))
+ current_idx = next_idx
+
+ for c in close_parens:
+ try:
+ assert c == ")"
+ except AssertionError:
+ print(word_index, bit, spans, stack)
+ continue
+ open_index, label = stack.pop()
+ spans.append((open_index, word_index, label))
+ ''' current_span = (open_index, word_index)
+ if current_span in spans:
+ spans[current_span] += "_" + label
+ else:
+ spans[current_span] = label
+ spans[current_span] = label '''
+
+
+def handle_line(line, document_state: DocumentState, language, labels, stats):
+ begin_document_match = re.match(BEGIN_DOCUMENT_REGEX, line)
+ if begin_document_match:
+ document_state.assert_empty()
+ document_state.doc_key = get_doc_key(begin_document_match.group(1), begin_document_match.group(2))
+ return None
+ elif line.startswith("#end document"):
+ document_state.assert_finalizable()
+ finalized_state = document_state.finalize()
+ stats["num_clusters"] += len(finalized_state["clusters"])
+ stats["num_mentions"] += sum(len(c) for c in finalized_state["clusters"])
+ # labels["{}_const_labels".format(language)].update(l for _, _, l in finalized_state["constituents"])
+ # labels["ner"].update(l for _, _, l in finalized_state["ner"])
+ return finalized_state
+ else:
+ row = line.split()
+ # Starting a new sentence.
+ if len(row) == 0:
+ stats["max_sent_len_{}".format(language)] = max(len(document_state.text),
+ stats["max_sent_len_{}".format(language)])
+ stats["num_sents_{}".format(language)] += 1
+ document_state.finalize_sentence()
+ return None
+ assert len(row) >= 12
+
+ doc_key = get_doc_key(row[0], row[1])
+ word = normalize_word(row[3], language)
+ pos = row[4]
+ parse = row[5]
+ lemma = row[6]
+ predicate_sense = row[7]
+ speaker = row[9]
+ ner = row[10]
+ args = row[11:-1]
+ coref = row[-1]
+
+ word_index = len(document_state.text) + sum(len(s) for s in document_state.sentences)
+ document_state.text.append(word)
+ document_state.text_speakers.append(speaker)
+ document_state.pos_buffer.append(pos)
+ document_state.lemma_buffer.append(lemma)
+
+ handle_bit(word_index, parse, document_state.const_stack, document_state.const_buffer, labels["categories"])
+ handle_bit(word_index, ner, document_state.ner_stack, document_state.ner_buffer, labels["ner"])
+
+ if len(document_state.argument_stacks) < len(args):
+ document_state.argument_stacks = [[] for _ in args]
+ document_state.argument_buffers = [[] for _ in args]
+
+ for i, arg in enumerate(args):
+ handle_bit(word_index, arg, document_state.argument_stacks[i], document_state.argument_buffers[i],
+ labels["srl"])
+ if predicate_sense != "-":
+ document_state.predicate_buffer.append(word_index)
+ if coref != "-":
+ for segment in coref.split("|"):
+ if segment[0] == "(":
+ if segment[-1] == ")":
+ cluster_id = int(segment[1:-1])
+ document_state.clusters[cluster_id].append((word_index, word_index))
+ else:
+ cluster_id = int(segment[1:])
+ document_state.coref_stacks[cluster_id].append(word_index)
+ else:
+ cluster_id = int(segment[:-1])
+ start = document_state.coref_stacks[cluster_id].pop()
+ document_state.clusters[cluster_id].append((start, word_index))
+ return None
+
+
+def ontonotes_document_generator(input_path, language, labels, stats):
+ with open(input_path, "r") as input_file:
+ document_state = DocumentState()
+ for line in input_file.readlines():
+ document = handle_line(line, document_state, language, labels, stats)
+ if document is not None:
+ yield document
+ document_state = DocumentState()
+
+
+def convert_to_jsonlines(input_path, output_path, language, labels=None, stats=None):
+ if labels is None:
+ labels = collections.defaultdict(set)
+ if stats is None:
+ stats = collections.defaultdict(int)
+ count = 0
+ with open(output_path, "w") as output_file:
+ for document in ontonotes_document_generator(input_path, language, labels, stats):
+ output_file.write(json.dumps(document, ensure_ascii=False))
+ output_file.write("\n")
+ count += 1
+
+ return labels, stats
+
+
+def make_ontonotes_jsonlines(conll12_ontonotes_path, output_path, languages=None):
+ if languages is None:
+ languages = ['english', 'chinese', 'arabic']
+ for language in languages:
+ make_ontonotes_language_jsonlines(conll12_ontonotes_path, output_path, language)
+
+
+def make_ontonotes_language_jsonlines(conll12_ontonotes_path, output_path=None, language='english'):
+ conll12_ontonotes_path = get_resource(conll12_ontonotes_path)
+ if output_path is None:
+ output_path = os.path.dirname(conll12_ontonotes_path)
+ for split in ['train', 'development', 'test']:
+ pattern = f'{conll12_ontonotes_path}/data/{split}/data/{language}/annotations/*/*/*/*gold_conll'
+ files = sorted(glob.glob(pattern, recursive=True))
+ assert files, f'No gold_conll files found in {pattern}'
+ version = os.path.basename(files[0]).split('.')[-1].split('_')[0]
+ if version.startswith('v'):
+ assert all([version in os.path.basename(f) for f in files])
+ else:
+ version = 'v5'
+ lang_dir = f'{output_path}/{language}'
+ if split == 'conll-2012-test':
+ split = 'test'
+ full_file = f'{lang_dir}/{split}.{language}.{version}_gold_conll'
+ os.makedirs(lang_dir, exist_ok=True)
+ print(f'Merging {len(files)} files to {full_file}')
+ merge_files(files, full_file)
+ v5_json_file = full_file.replace(f'.{version}_gold_conll', f'.{version}.jsonlines')
+ print(f'Converting CoNLL file {full_file} to json file {v5_json_file}')
+ labels, stats = convert_to_jsonlines(full_file, v5_json_file, language)
+ print('Labels:')
+ pprint(labels)
+ print('Statistics:')
+ pprint(stats)
+ conll12_json_file = f'{lang_dir}/{split}.{language}.conll12.jsonlines'
+ print(f'Applying CoNLL 12 official splits on {v5_json_file} to {conll12_json_file}')
+ id_file = get_resource(f'http://conll.cemantix.org/2012/download/ids/'
+ f'{language}/coref/{split}.id')
+ filter_data(v5_json_file, id_file, conll12_json_file)
+
+
+def make_gold_conll(ontonotes_path, language):
+ ontonotes_path = os.path.abspath(get_resource(ontonotes_path))
+ to_conll = get_resource(
+ 'https://gist.githubusercontent.com/hankcs/46b9137016c769e4b6137104daf43a92/raw/66369de6c24b5ec47696ae307591f0d72c6f3f02/ontonotes_to_conll.sh')
+ to_conll = os.path.abspath(to_conll)
+ # shutil.rmtree(os.path.join(ontonotes_path, 'conll-2012'), ignore_errors=True)
+ with pushd(ontonotes_path):
+ try:
+ flash(f'Converting [blue]{language}[/blue] to CoNLL format, '
+ f'this might take half an hour [blink][yellow]...[/yellow][/blink]')
+ run_cmd(f'bash {to_conll} {ontonotes_path} {language}')
+ flash('')
+ except RuntimeError as e:
+ flash(f'[red]Failed[/red] to convert {language} of {ontonotes_path} to CoNLL. See exceptions for detail')
+ raise e
+
+
+def convert_jsonlines_to_IOBES(json_file, output_file=None, doc_level_offset=True):
+ json_file = get_resource(json_file)
+ if not output_file:
+ output_file = os.path.splitext(json_file)[0] + '.ner.tsv'
+ with open(json_file) as src, open(output_file, 'w', encoding='utf-8') as out:
+ for line in src:
+ doc = json.loads(line)
+ offset = 0
+ for sent, ner in zip(doc['sentences'], doc['ner']):
+ tags = ['O'] * len(sent)
+ for start, end, label in ner:
+ if doc_level_offset:
+ start -= offset
+ end -= offset
+ if start == end:
+ tags[start] = 'S-' + label
+ else:
+ tags[start] = 'B-' + label
+ for i in range(start + 1, end + 1):
+ tags[i] = 'I-' + label
+ tags[end] = 'E-' + label
+ offset += len(sent)
+ for token, tag in zip(sent, tags):
+ out.write(f'{token}\t{tag}\n')
+ out.write('\n')
+
+
+def make_ner_tsv_if_necessary(json_file):
+ json_file = get_resource(json_file)
+ output_file = os.path.splitext(json_file)[0] + '.ner.tsv'
+ if not os.path.isfile(output_file):
+ convert_jsonlines_to_IOBES(json_file, output_file)
+ return output_file
+
+
+def batch_make_ner_tsv_if_necessary(json_files):
+ for each in json_files:
+ make_ner_tsv_if_necessary(each)
+
+
+def main():
+ if len(sys.argv) != 3:
+ eprint('2 arguments required: ontonotes_path output_path')
+ exit(1)
+ ontonotes_path = sys.argv[1]
+ output_path = sys.argv[2]
+ make_ontonotes_jsonlines(ontonotes_path, output_path)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/hanlp/datasets/srl/ontonotes5/chinese.py b/hanlp/datasets/srl/ontonotes5/chinese.py
new file mode 100644
index 000000000..fd0c56d04
--- /dev/null
+++ b/hanlp/datasets/srl/ontonotes5/chinese.py
@@ -0,0 +1,57 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-11-26 16:07
+import os
+from urllib.error import HTTPError
+import shutil
+
+from hanlp.datasets.srl.ontonotes5 import ONTONOTES5_HOME, CONLL12_HOME
+from hanlp.datasets.srl.ontonotes5._utils import make_gold_conll, make_ontonotes_language_jsonlines, \
+ batch_make_ner_tsv_if_necessary
+from hanlp.utils.io_util import get_resource, path_from_url
+from hanlp.utils.log_util import cprint, flash
+
+_ONTONOTES5_CHINESE_HOME = ONTONOTES5_HOME + 'files/data/chinese/'
+_ONTONOTES5_CONLL12_CHINESE_HOME = CONLL12_HOME + 'chinese/'
+ONTONOTES5_CONLL12_CHINESE_TRAIN = _ONTONOTES5_CONLL12_CHINESE_HOME + 'train.chinese.conll12.jsonlines'
+'''Training set of OntoNotes5 used in CoNLL12 (:cite:`pradhan-etal-2012-conll`).'''
+ONTONOTES5_CONLL12_CHINESE_DEV = _ONTONOTES5_CONLL12_CHINESE_HOME + 'development.chinese.conll12.jsonlines'
+'''Dev set of OntoNotes5 used in CoNLL12 (:cite:`pradhan-etal-2012-conll`).'''
+ONTONOTES5_CONLL12_CHINESE_TEST = _ONTONOTES5_CONLL12_CHINESE_HOME + 'test.chinese.conll12.jsonlines'
+'''Test set of OntoNotes5 used in CoNLL12 (:cite:`pradhan-etal-2012-conll`).'''
+
+ONTONOTES5_CONLL12_NER_CHINESE_TRAIN = _ONTONOTES5_CONLL12_CHINESE_HOME + 'train.chinese.conll12.ner.tsv'
+'''Training set of OntoNotes5 used in CoNLL12 (:cite:`pradhan-etal-2012-conll`).'''
+ONTONOTES5_CONLL12_NER_CHINESE_DEV = _ONTONOTES5_CONLL12_CHINESE_HOME + 'development.chinese.conll12.ner.tsv'
+'''Dev set of OntoNotes5 used in CoNLL12 (:cite:`pradhan-etal-2012-conll`).'''
+ONTONOTES5_CONLL12_NER_CHINESE_TEST = _ONTONOTES5_CONLL12_CHINESE_HOME + 'test.chinese.conll12.ner.tsv'
+'''Test set of OntoNotes5 used in CoNLL12 (:cite:`pradhan-etal-2012-conll`).'''
+
+try:
+ get_resource(ONTONOTES5_HOME, verbose=False)
+except HTTPError:
+ intended_file_path = path_from_url(ONTONOTES5_HOME)
+ cprint('Ontonotes 5.0 is a [red][bold]copyright[/bold][/red] dataset owned by LDC which we cannot re-distribute. '
+ f'Please apply for a licence from LDC (https://catalog.ldc.upenn.edu/LDC2016T13) '
+ f'then download it to {intended_file_path}')
+ cprint('Luckily, an [red]unofficial[/red] Chinese version is provided on GitHub '
+ 'which will be used for demonstration purpose.')
+ unofficial_chinese = get_resource('https://github.com/GuocaiL/Coref_Resolution/archive/master.zip#data/')
+ intended_home, _ = os.path.splitext(intended_file_path)
+ intended_chinese = f'{intended_home}/data/files/data/chinese/'
+ # print(os.path.dirname(intended_chinese))
+ # print(unofficial_chinese)
+ # print(intended_chinese)
+ for folder in ['annotations', 'metadata']:
+ flash(f'Copying {unofficial_chinese}{folder} to {intended_chinese}{folder} [blink][yellow]...[/yellow][/blink]')
+ shutil.copytree(f'{unofficial_chinese}{folder}', f'{intended_chinese}{folder}')
+ flash('')
+
+try:
+ get_resource(ONTONOTES5_CONLL12_CHINESE_TRAIN, verbose=False)
+except HTTPError:
+ make_gold_conll(ONTONOTES5_HOME + '..', 'chinese')
+ make_ontonotes_language_jsonlines(CONLL12_HOME + 'v4', language='chinese')
+
+batch_make_ner_tsv_if_necessary(
+ [ONTONOTES5_CONLL12_CHINESE_TRAIN, ONTONOTES5_CONLL12_CHINESE_DEV, ONTONOTES5_CONLL12_CHINESE_TEST])
diff --git a/hanlp/datasets/srl/ontonotes5/english.py b/hanlp/datasets/srl/ontonotes5/english.py
new file mode 100644
index 000000000..a206cf70a
--- /dev/null
+++ b/hanlp/datasets/srl/ontonotes5/english.py
@@ -0,0 +1,28 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-25 18:48
+import glob
+import os
+
+from hanlp.utils.io_util import get_resource, merge_files
+
+_CONLL2012_EN_HOME = 'https://github.com/yuchenlin/OntoNotes-5.0-NER-BIO/archive/master.zip#conll-formatted-ontonotes-5.0/data'
+# These are v4 of OntoNotes, in .conll format
+CONLL2012_EN_TRAIN = _CONLL2012_EN_HOME + '/train/data/english/annotations'
+CONLL2012_EN_DEV = _CONLL2012_EN_HOME + '/development/data/english/annotations'
+CONLL2012_EN_TEST = _CONLL2012_EN_HOME + '/conll-2012-test/data/english/annotations'
+
+
+def conll_2012_en_combined():
+ home = get_resource(_CONLL2012_EN_HOME)
+ outputs = ['train', 'dev', 'test']
+ for i in range(len(outputs)):
+ outputs[i] = f'{home}/conll12_en/{outputs[i]}.conll'
+ if all(os.path.isfile(x) for x in outputs):
+ return outputs
+ os.makedirs(os.path.dirname(outputs[0]), exist_ok=True)
+ for in_path, out_path in zip([CONLL2012_EN_TRAIN, CONLL2012_EN_DEV, CONLL2012_EN_TEST], outputs):
+ in_path = get_resource(in_path)
+ files = sorted(glob.glob(f'{in_path}/**/*gold_conll', recursive=True))
+ merge_files(files, out_path)
+ return outputs
diff --git a/hanlp/datasets/tokenization/__init__.py b/hanlp/datasets/tokenization/__init__.py
new file mode 100644
index 000000000..35d0150a5
--- /dev/null
+++ b/hanlp/datasets/tokenization/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-08-01 12:33
\ No newline at end of file
diff --git a/hanlp/datasets/tokenization/txt.py b/hanlp/datasets/tokenization/txt.py
new file mode 100644
index 000000000..885b7ef32
--- /dev/null
+++ b/hanlp/datasets/tokenization/txt.py
@@ -0,0 +1,110 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-08-01 12:35
+from typing import Union, List, Callable
+
+from hanlp.common.dataset import TransformableDataset
+from hanlp.utils.io_util import TimingFileIterator
+from hanlp.utils.span_util import words_to_bmes, words_to_bi
+from hanlp.utils.string_util import split_long_sentence_into
+
+
+class TextTokenizingDataset(TransformableDataset):
+ def __init__(self,
+ data: Union[str, List],
+ transform: Union[Callable, List] = None,
+ cache=None,
+ generate_idx=None,
+ delimiter=None,
+ max_seq_len=None,
+ sent_delimiter=None,
+ char_level=False,
+ hard_constraint=False,
+ ) -> None:
+ """A dataset for tagging tokenization tasks.
+
+ Args:
+ data: The local or remote path to a dataset, or a list of samples where each sample is a dict.
+ transform: Predefined transform(s).
+ cache: ``True`` to enable caching, so that transforms won't be called twice.
+ generate_idx: Create a :const:`~hanlp_common.constants.IDX` field for each sample to store its order in dataset. Useful for prediction when
+ samples are re-ordered by a sampler.
+ delimiter: Delimiter between tokens used to split a line in the corpus.
+ max_seq_len: Sentences longer than ``max_seq_len`` will be split into shorter ones if possible.
+ sent_delimiter: Delimiter between sentences, like period or comma, which indicates a long sentence can
+ be split here.
+ char_level: Whether the sequence length is measured at char level.
+ hard_constraint: Whether to enforce hard length constraint on sentences. If there is no ``sent_delimiter``
+ in a sentence, it will be split at a token anyway.
+ """
+ self.hard_constraint = hard_constraint
+ self.char_level = char_level
+ self.sent_delimiter = sent_delimiter
+ self.max_seq_len = max_seq_len
+ self.delimiter = delimiter
+ super().__init__(data, transform, cache, generate_idx)
+
+ def load_file(self, filepath: str):
+ """Load tokenized corpus. The format is one sentence per line, where each line consisits of tokens seperated
+ by a delimiter (usually space).
+
+ .. highlight:: bash
+ .. code-block:: bash
+
+ $ head train.txt
+ 上海 浦东 开发 与 法制 建设 同步
+ 新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )
+
+ Args:
+ filepath: The path to the corpus.
+ """
+ f = TimingFileIterator(filepath)
+ # longest_sent = 0
+ for line in f:
+ line = line.rstrip('\n')
+ tokens = line.split(self.delimiter)
+ if not tokens:
+ continue
+ if self.max_seq_len and sum(len(t) for t in tokens) > self.max_seq_len:
+ # debug = []
+ for short_sents in split_long_sentence_into(tokens, self.max_seq_len, self.sent_delimiter,
+ char_level=self.char_level,
+ hard_constraint=self.hard_constraint):
+ # debug.extend(short_sents)
+ # longest_sent = max(longest_sent, len(''.join(short_sents)))
+ yield {'token': short_sents}
+ # assert debug == tokens
+ else:
+ # longest_sent = max(longest_sent, len(''.join(tokens)))
+ yield {'token': tokens}
+ f.log(line[:20])
+ f.erase()
+ # print(f'Longest sent: {longest_sent} in {filepath}')
+
+
+def generate_tags_for_subtokens(sample: dict, tagging_scheme='BMES'):
+ # We could use token_token_span but we don't want token_token_span in the batch
+ subtokens_group = sample.get('token_subtoken_offsets_group', None)
+ sample['raw_token'] = sample['token']
+ sample['token'] = offsets_to_subtokens(sample.get('token_') or sample['token'], sample['token_subtoken_offsets'],
+ subtokens_group)
+ if subtokens_group:
+ if tagging_scheme == 'BMES':
+ sample['tag'] = words_to_bmes(subtokens_group)
+ elif tagging_scheme == 'BI':
+ sample['tag'] = words_to_bi(subtokens_group)
+ else:
+ raise NotImplementedError(f'Unsupported tagging scheme {tagging_scheme}.')
+ return sample
+
+
+def offsets_to_subtokens(tokens, token_subtoken_offsets, token_input_tokens_group):
+ results = []
+ if token_input_tokens_group:
+ for subtokens, token in zip(token_input_tokens_group, tokens):
+ for b, e in subtokens:
+ results.append(token[b:e])
+ else:
+ for b, e in token_subtoken_offsets:
+ results.append(tokens[b:e])
+ return results
diff --git a/hanlp/layers/context_layer.py b/hanlp/layers/context_layer.py
new file mode 100644
index 000000000..b0c178144
--- /dev/null
+++ b/hanlp/layers/context_layer.py
@@ -0,0 +1,53 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-07-05 19:34
+from alnlp.modules.pytorch_seq2seq_wrapper import LstmSeq2SeqEncoder
+from torch import nn
+from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
+
+from hanlp.common.structure import ConfigTracker
+
+
+class _LSTMSeq2Seq(nn.Module):
+ def __init__(
+ self,
+ input_size: int,
+ hidden_size: int,
+ num_layers: int = 1,
+ bias: bool = True,
+ dropout: float = 0.0,
+ bidirectional: bool = False,
+ ):
+ """
+ Under construction, not ready for production
+ :param input_size:
+ :param hidden_size:
+ :param num_layers:
+ :param bias:
+ :param dropout:
+ :param bidirectional:
+ """
+ self.rnn = nn.LSTM(
+ input_size=input_size,
+ hidden_size=hidden_size,
+ num_layers=num_layers,
+ bias=bias,
+ batch_first=True,
+ dropout=dropout,
+ bidirectional=bidirectional,
+ )
+
+ def forward(self, embed, lens, max_len):
+ x = pack_padded_sequence(embed, lens, True, False)
+ x, _ = self.rnn(x)
+ x, _ = pad_packed_sequence(x, True, total_length=max_len)
+ return x
+
+
+# We might update this to support yaml based configuration
+class LSTMContextualEncoder(LstmSeq2SeqEncoder, ConfigTracker):
+
+ def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1, bias: bool = True, dropout: float = 0.0,
+ bidirectional: bool = False, stateful: bool = False):
+ super().__init__(input_size, hidden_size, num_layers, bias, dropout, bidirectional, stateful)
+ ConfigTracker.__init__(self, locals())
diff --git a/hanlp/layers/crf/crf.py b/hanlp/layers/crf/crf.py
index 70762c017..ccf63ad15 100644
--- a/hanlp/layers/crf/crf.py
+++ b/hanlp/layers/crf/crf.py
@@ -1,503 +1,353 @@
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+# Copied from https://github.com/kmkurn/pytorch-crf
+# Copyright 2017 Kemal Kurniawan
#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal in
+# the Software without restriction, including without limitation the rights to
+# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
+# of the Software, and to permit persons to whom the Software is furnished to do
+# so, subject to the following conditions:
#
-# http://www.apache.org/licenses/LICENSE-2.0
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
+# PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+__version__ = '0.7.2'
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+from typing import List, Optional
-import numpy as np
-import tensorflow as tf
+import torch
+import torch.nn as nn
-# TODO: Wrap functions in @tf.function once
-# https://github.com/tensorflow/tensorflow/issues/29075 is resolved
+class CRF(nn.Module):
+ """Conditional random field.
-def crf_sequence_score(inputs, tag_indices, sequence_lengths,
- transition_params):
- """Computes the unnormalized score for a tag sequence.
+ This module implements a conditional random field [LMP01]_. The forward computation
+ of this class computes the log likelihood of the given sequence of tags and
+ emission score tensor. This class also has `~CRF.decode` method which finds
+ the best tag sequence given an emission score tensor using `Viterbi algorithm`_.
Args:
- inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
- to use as input to the CRF layer.
- tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which
- we compute the unnormalized score.
- sequence_lengths: A [batch_size] vector of true sequence lengths.
- transition_params: A [num_tags, num_tags] transition matrix.
- Returns:
- sequence_scores: A [batch_size] vector of unnormalized sequence scores.
- """
- tag_indices = tf.cast(tag_indices, dtype=tf.int32)
- sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
-
- # If max_seq_len is 1, we skip the score calculation and simply gather the
- # unary potentials of the single tag.
- def _single_seq_fn():
- batch_size = tf.shape(inputs, out_type=tag_indices.dtype)[0]
-
- example_inds = tf.reshape(
- tf.range(batch_size, dtype=tag_indices.dtype), [-1, 1])
- sequence_scores = tf.gather_nd(
- tf.squeeze(inputs, [1]),
- tf.concat([example_inds, tag_indices], axis=1))
- sequence_scores = tf.where(
- tf.less_equal(sequence_lengths, 0), tf.zeros_like(sequence_scores),
- sequence_scores)
- return sequence_scores
-
- def _multi_seq_fn():
- # Compute the scores of the given tag sequence.
- unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs)
- binary_scores = crf_binary_score(tag_indices, sequence_lengths,
- transition_params)
- sequence_scores = unary_scores + binary_scores
- return sequence_scores
-
- if inputs.shape[1] == 1:
- return _single_seq_fn()
- else:
- return _multi_seq_fn()
-
-
-def crf_multitag_sequence_score(inputs, tag_bitmap, sequence_lengths,
- transition_params):
- """Computes the unnormalized score of all tag sequences matching
- tag_bitmap.
-
- tag_bitmap enables more than one tag to be considered correct at each time
- step. This is useful when an observed output at a given time step is
- consistent with more than one tag, and thus the log likelihood of that
- observation must take into account all possible consistent tags.
-
- Using one-hot vectors in tag_bitmap gives results identical to
- crf_sequence_score.
-
- Args:
- inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
- to use as input to the CRF layer.
- tag_bitmap: A [batch_size, max_seq_len, num_tags] boolean tensor
- representing all active tags at each index for which to calculate the
- unnormalized score.
- sequence_lengths: A [batch_size] vector of true sequence lengths.
- transition_params: A [num_tags, num_tags] transition matrix.
- Returns:
- sequence_scores: A [batch_size] vector of unnormalized sequence scores.
- """
- tag_bitmap = tf.cast(tag_bitmap, dtype=tf.bool)
- sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
- filtered_inputs = tf.where(tag_bitmap, inputs,
- tf.fill(tf.shape(inputs), float("-inf")))
-
- # If max_seq_len is 1, we skip the score calculation and simply gather the
- # unary potentials of all active tags.
- def _single_seq_fn():
- return tf.reduce_logsumexp(
- filtered_inputs, axis=[1, 2], keepdims=False)
-
- def _multi_seq_fn():
- # Compute the logsumexp of all scores of sequences matching the given tags.
- return crf_log_norm(
- inputs=filtered_inputs,
- sequence_lengths=sequence_lengths,
- transition_params=transition_params)
-
- if inputs.shape[1] == 1:
- return _single_seq_fn()
- else:
- return _multi_seq_fn()
-
-
-def crf_log_norm(inputs, sequence_lengths, transition_params):
- """Computes the normalization for a CRF.
-
- Args:
- inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
- to use as input to the CRF layer.
- sequence_lengths: A [batch_size] vector of true sequence lengths.
- transition_params: A [num_tags, num_tags] transition matrix.
- Returns:
- log_norm: A [batch_size] vector of normalizers for a CRF.
- """
- sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
- # Split up the first and rest of the inputs in preparation for the forward
- # algorithm.
- first_input = tf.slice(inputs, [0, 0, 0], [-1, 1, -1])
- first_input = tf.squeeze(first_input, [1])
-
- # If max_seq_len is 1, we skip the algorithm and simply reduce_logsumexp over
- # the "initial state" (the unary potentials).
- def _single_seq_fn():
- log_norm = tf.reduce_logsumexp(first_input, [1])
- # Mask `log_norm` of the sequences with length <= zero.
- log_norm = tf.where(
- tf.less_equal(sequence_lengths, 0), tf.zeros_like(log_norm),
- log_norm)
- return log_norm
-
- def _multi_seq_fn():
- """Forward computation of alpha values."""
- rest_of_input = tf.slice(inputs, [0, 1, 0], [-1, -1, -1])
- # Compute the alpha values in the forward algorithm in order to get the
- # partition function.
-
- alphas = crf_forward(rest_of_input, first_input, transition_params,
- sequence_lengths)
- log_norm = tf.reduce_logsumexp(alphas, [1])
- # Mask `log_norm` of the sequences with length <= zero.
- log_norm = tf.where(
- tf.less_equal(sequence_lengths, 0), tf.zeros_like(log_norm),
- log_norm)
- return log_norm
-
- if inputs.shape[1] == 1:
- return _single_seq_fn()
- else:
- return _multi_seq_fn()
-
-
-def crf_log_likelihood(inputs,
- tag_indices,
- sequence_lengths,
- transition_params=None):
- """Computes the log-likelihood of tag sequences in a CRF.
-
- Args:
- inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
- to use as input to the CRF layer.
- tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which
- we compute the log-likelihood.
- sequence_lengths: A [batch_size] vector of true sequence lengths.
- transition_params: A [num_tags, num_tags] transition matrix,
- if available.
- Returns:
- log_likelihood: A [batch_size] `Tensor` containing the log-likelihood of
- each example, given the sequence of tag indices.
- transition_params: A [num_tags, num_tags] transition matrix. This is
- either provided by the caller or created in this function.
- """
- num_tags = inputs.shape[2]
-
- # cast type to handle different types
- tag_indices = tf.cast(tag_indices, dtype=tf.int32)
- sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
-
- if transition_params is None:
- initializer = tf.keras.initializers.GlorotUniform()
- transition_params = tf.Variable(
- initializer([num_tags, num_tags]), "transitions")
-
- sequence_scores = crf_sequence_score(inputs, tag_indices, sequence_lengths,
- transition_params)
- log_norm = crf_log_norm(inputs, sequence_lengths, transition_params)
-
- # Normalize the scores to get the log-likelihood per example.
- log_likelihood = sequence_scores - log_norm
- return log_likelihood, transition_params
-
+ num_tags: Number of tags.
+ batch_first: Whether the first dimension corresponds to the size of a minibatch.
-def crf_unary_score(tag_indices, sequence_lengths, inputs):
- """Computes the unary scores of tag sequences.
+ Attributes:
+ start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size
+ ``(num_tags,)``.
+ end_transitions (`~torch.nn.Parameter`): End transition score tensor of size
+ ``(num_tags,)``.
+ transitions (`~torch.nn.Parameter`): Transition score tensor of size
+ ``(num_tags, num_tags)``.
- Args:
- tag_indices: A [batch_size, max_seq_len] matrix of tag indices.
- sequence_lengths: A [batch_size] vector of true sequence lengths.
- inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials.
- Returns:
- unary_scores: A [batch_size] vector of unary scores.
- """
- assert len(tag_indices.shape) == 2, 'tag_indices: A [batch_size, max_seq_len] matrix of tag indices.'
- tag_indices = tf.cast(tag_indices, dtype=tf.int32)
- sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
-
- batch_size = tf.shape(inputs)[0]
- max_seq_len = tf.shape(inputs)[1]
- num_tags = tf.shape(inputs)[2]
-
- flattened_inputs = tf.reshape(inputs, [-1])
-
- offsets = tf.expand_dims(tf.range(batch_size) * max_seq_len * num_tags, 1)
- offsets += tf.expand_dims(tf.range(max_seq_len) * num_tags, 0)
- # Use int32 or int64 based on tag_indices' dtype.
- if tag_indices.dtype == tf.int64:
- offsets = tf.cast(offsets, tf.int64)
- flattened_tag_indices = tf.reshape(offsets + tag_indices, [-1])
- unary_scores = tf.reshape(
- tf.gather(flattened_inputs, flattened_tag_indices),
- [batch_size, max_seq_len])
+ .. [LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001).
+ "Conditional random fields: Probabilistic models for segmenting and
+ labeling sequence data". *Proc. 18th International Conf. on Machine
+ Learning*. Morgan Kaufmann. pp. 282–289.
- masks = tf.sequence_mask(
- sequence_lengths, maxlen=tf.shape(tag_indices)[1], dtype=tf.float32)
-
- unary_scores = tf.reduce_sum(unary_scores * masks, 1)
- return unary_scores
-
-
-def crf_binary_score(tag_indices, sequence_lengths, transition_params):
- """Computes the binary scores of tag sequences.
-
- Args:
- tag_indices: A [batch_size, max_seq_len] matrix of tag indices.
- sequence_lengths: A [batch_size] vector of true sequence lengths.
- transition_params: A [num_tags, num_tags] matrix of binary potentials.
- Returns:
- binary_scores: A [batch_size] vector of binary scores.
+ .. _Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm
"""
- tag_indices = tf.cast(tag_indices, dtype=tf.int32)
- sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
- num_tags = tf.shape(transition_params)[0]
- num_transitions = tf.shape(tag_indices)[1] - 1
-
- # Truncate by one on each side of the sequence to get the start and end
- # indices of each transition.
- start_tag_indices = tf.slice(tag_indices, [0, 0], [-1, num_transitions])
- end_tag_indices = tf.slice(tag_indices, [0, 1], [-1, num_transitions])
-
- # Encode the indices in a flattened representation.
- flattened_transition_indices = start_tag_indices * \
- num_tags + end_tag_indices
- flattened_transition_params = tf.reshape(transition_params, [-1])
-
- # Get the binary scores based on the flattened representation.
- binary_scores = tf.gather(flattened_transition_params,
- flattened_transition_indices)
-
- masks = tf.sequence_mask(
- sequence_lengths, maxlen=tf.shape(tag_indices)[1], dtype=tf.float32)
- truncated_masks = tf.slice(masks, [0, 1], [-1, -1])
- binary_scores = tf.reduce_sum(binary_scores * truncated_masks, 1)
- return binary_scores
-
-
-def crf_forward(inputs, state, transition_params, sequence_lengths):
- """Computes the alpha values in a linear-chain CRF.
-
- See http://www.cs.columbia.edu/~mcollins/fb.pdf for reference.
-
- Args:
- inputs: A [batch_size, num_tags] matrix of unary potentials.
- state: A [batch_size, num_tags] matrix containing the previous alpha
- values.
- transition_params: A [num_tags, num_tags] matrix of binary potentials.
- This matrix is expanded into a [1, num_tags, num_tags] in preparation
- for the broadcast summation occurring within the cell.
- sequence_lengths: A [batch_size] vector of true sequence lengths.
-
- Returns:
- new_alphas: A [batch_size, num_tags] matrix containing the
- new alpha values.
- """
- sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
+ def __init__(self, num_tags: int, batch_first: bool = True) -> None:
+ if num_tags <= 0:
+ raise ValueError(f'invalid number of tags: {num_tags}')
+ super().__init__()
+ self.num_tags = num_tags
+ self.batch_first = batch_first
+ self.start_transitions = nn.Parameter(torch.empty(num_tags))
+ self.end_transitions = nn.Parameter(torch.empty(num_tags))
+ self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))
- sequence_lengths = tf.maximum(
- tf.constant(0, dtype=sequence_lengths.dtype), sequence_lengths - 2)
- inputs = tf.transpose(inputs, [1, 0, 2])
- transition_params = tf.expand_dims(transition_params, 0)
+ self.reset_parameters()
- def _scan_fn(state, inputs):
- state = tf.expand_dims(state, 2)
- transition_scores = state + transition_params
- new_alphas = inputs + tf.reduce_logsumexp(transition_scores, [1])
- return new_alphas
+ def reset_parameters(self) -> None:
+ """Initialize the transition parameters.
- all_alphas = tf.transpose(tf.scan(_scan_fn, inputs, state), [1, 0, 2])
- idxs = tf.stack(
- [tf.range(tf.shape(sequence_lengths)[0]), sequence_lengths], axis=1)
- return tf.gather_nd(all_alphas, idxs)
-
-
-def viterbi_decode(score, transition_params):
- """Decode the highest scoring sequence of tags outside of TensorFlow.
-
- This should only be used at test time.
-
- Args:
- score: A [seq_len, num_tags] matrix of unary potentials.
- transition_params: A [num_tags, num_tags] matrix of binary potentials.
-
- Returns:
- viterbi: A [seq_len] list of integers containing the highest scoring tag
- indices.
- viterbi_score: A float containing the score for the Viterbi sequence.
- """
- trellis = np.zeros_like(score)
- backpointers = np.zeros_like(score, dtype=np.int32)
- trellis[0] = score[0]
-
- for t in range(1, score.shape[0]):
- v = np.expand_dims(trellis[t - 1], 1) + transition_params
- trellis[t] = score[t] + np.max(v, 0)
- backpointers[t] = np.argmax(v, 0)
-
- viterbi = [np.argmax(trellis[-1])]
- for bp in reversed(backpointers[1:]):
- viterbi.append(bp[viterbi[-1]])
- viterbi.reverse()
-
- viterbi_score = np.max(trellis[-1])
- return viterbi, viterbi_score
-
-
-class CrfDecodeForwardRnnCell(tf.keras.layers.AbstractRNNCell):
- """Computes the forward decoding in a linear-chain CRF."""
-
- def __init__(self, transition_params, **kwargs):
- """Initialize the CrfDecodeForwardRnnCell.
-
- Args:
- transition_params: A [num_tags, num_tags] matrix of binary
- potentials. This matrix is expanded into a
- [1, num_tags, num_tags] in preparation for the broadcast
- summation occurring within the cell.
+ The parameters will be initialized randomly from a uniform distribution
+ between -0.1 and 0.1.
"""
- super(CrfDecodeForwardRnnCell, self).__init__(**kwargs)
- self._transition_params = tf.expand_dims(transition_params, 0)
- self._num_tags = transition_params.shape[0]
-
- @property
- def state_size(self):
- return self._num_tags
-
- @property
- def output_size(self):
- return self._num_tags
-
- def build(self, input_shape):
- super(CrfDecodeForwardRnnCell, self).build(input_shape)
-
- def call(self, inputs, state):
- """Build the CrfDecodeForwardRnnCell.
+ nn.init.uniform_(self.start_transitions, -0.1, 0.1)
+ nn.init.uniform_(self.end_transitions, -0.1, 0.1)
+ nn.init.uniform_(self.transitions, -0.1, 0.1)
+
+ def __repr__(self) -> str:
+ return f'{self.__class__.__name__}(num_tags={self.num_tags})'
+
+ def forward(
+ self,
+ emissions: torch.Tensor,
+ tags: torch.LongTensor,
+ mask: Optional[torch.ByteTensor] = None,
+ reduction: str = 'sum',
+ ) -> torch.Tensor:
+ """Compute the conditional log likelihood of a sequence of tags given emission scores.
Args:
- inputs: A [batch_size, num_tags] matrix of unary potentials.
- state: A [batch_size, num_tags] matrix containing the previous step's
- score values.
+ emissions (`~torch.Tensor`): Emission score tensor of size
+ ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
+ ``(batch_size, seq_length, num_tags)`` otherwise.
+ tags (`~torch.LongTensor`): Sequence of tags tensor of size
+ ``(seq_length, batch_size)`` if ``batch_first`` is ``False``,
+ ``(batch_size, seq_length)`` otherwise.
+ mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
+ if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
+ reduction: Specifies the reduction to apply to the output:
+ ``none|sum|mean|token_mean``. ``none``: no reduction will be applied.
+ ``sum``: the output will be summed over batches. ``mean``: the output will be
+ averaged over batches. ``token_mean``: the output will be averaged over tokens.
Returns:
- backpointers: A [batch_size, num_tags] matrix of backpointers.
- new_state: A [batch_size, num_tags] matrix of new score values.
+ `~torch.Tensor`: The log likelihood. This will have size ``(batch_size,)`` if
+ reduction is ``none``, ``()`` otherwise.
"""
- state = tf.expand_dims(state[0], 2)
- transition_scores = state + self._transition_params
- new_state = inputs + tf.reduce_max(transition_scores, [1])
- backpointers = tf.argmax(transition_scores, 1)
- backpointers = tf.cast(backpointers, dtype=tf.int32)
- return backpointers, new_state
-
+ self._validate(emissions, tags=tags, mask=mask)
+ if reduction not in ('none', 'sum', 'mean', 'token_mean'):
+ raise ValueError(f'invalid reduction: {reduction}')
+ if mask is None:
+ mask = torch.ones_like(tags, dtype=torch.uint8)
+
+ if self.batch_first:
+ emissions = emissions.transpose(0, 1)
+ tags = tags.transpose(0, 1)
+ mask = mask.transpose(0, 1)
+
+ # shape: (batch_size,)
+ numerator = self._compute_score(emissions, tags, mask)
+ # shape: (batch_size,)
+ denominator = self._compute_normalizer(emissions, mask)
+ # shape: (batch_size,)
+ llh = numerator - denominator
+
+ if reduction == 'none':
+ return llh
+ if reduction == 'sum':
+ return llh.sum()
+ if reduction == 'mean':
+ return llh.mean()
+ assert reduction == 'token_mean'
+ return llh.sum() / mask.type_as(emissions).sum()
+
+ def decode(self, emissions: torch.Tensor,
+ mask: Optional[torch.ByteTensor] = None) -> List[List[int]]:
+ """Find the most likely tag sequence using Viterbi algorithm.
-def crf_decode_forward(inputs, state, transition_params, sequence_lengths):
- """Computes forward decoding in a linear-chain CRF.
-
- Args:
- inputs: A [batch_size, num_tags] matrix of unary potentials.
- state: A [batch_size, num_tags] matrix containing the previous step's
- score values.
- transition_params: A [num_tags, num_tags] matrix of binary potentials.
- sequence_lengths: A [batch_size] vector of true sequence lengths.
-
- Returns:
- backpointers: A [batch_size, num_tags] matrix of backpointers.
- new_state: A [batch_size, num_tags] matrix of new score values.
- """
- sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
- mask = tf.sequence_mask(sequence_lengths, tf.shape(inputs)[1])
- crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params)
- crf_fwd_layer = tf.keras.layers.RNN(
- crf_fwd_cell, return_sequences=True, return_state=True)
- return crf_fwd_layer(inputs, state, mask=mask)
-
-
-def crf_decode_backward(inputs, state):
- """Computes backward decoding in a linear-chain CRF.
-
- Args:
- inputs: A [batch_size, num_tags] matrix of
- backpointer of next step (in time order).
- state: A [batch_size, 1] matrix of tag index of next step.
-
- Returns:
- new_tags: A [batch_size, num_tags]
- tensor containing the new tag indices.
- """
- inputs = tf.transpose(inputs, [1, 0, 2])
-
- def _scan_fn(state, inputs):
- state = tf.squeeze(state, axis=[1])
- idxs = tf.stack([tf.range(tf.shape(inputs)[0]), state], axis=1)
- new_tags = tf.expand_dims(tf.gather_nd(inputs, idxs), axis=-1)
- return new_tags
-
- return tf.transpose(tf.scan(_scan_fn, inputs, state), [1, 0, 2])
-
-
-def crf_decode(potentials, transition_params, sequence_length):
- """Decode the highest scoring sequence of tags in TensorFlow.
-
- This is a function for tensor.
+ Args:
+ emissions (`~torch.Tensor`): Emission score tensor of size
+ ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
+ ``(batch_size, seq_length, num_tags)`` otherwise.
+ mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
+ if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
- Args:
- potentials: A [batch_size, max_seq_len, num_tags] tensor of
- unary potentials.
- transition_params: A [num_tags, num_tags] matrix of
- binary potentials.
- sequence_length: A [batch_size] vector of true sequence lengths.
-
- Returns:
- decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`.
- Contains the highest scoring tag indices.
- best_score: A [batch_size] vector, containing the score of `decode_tags`.
- """
- sequence_length = tf.cast(sequence_length, dtype=tf.int32)
-
- # If max_seq_len is 1, we skip the algorithm and simply return the argmax tag
- # and the max activation.
- def _single_seq_fn():
- squeezed_potentials = tf.squeeze(potentials, [1])
- decode_tags = tf.expand_dims(tf.argmax(squeezed_potentials, axis=1), 1)
- best_score = tf.reduce_max(squeezed_potentials, axis=1)
- return tf.cast(decode_tags, dtype=tf.int32), best_score
-
- def _multi_seq_fn():
- """Decoding of highest scoring sequence."""
- # Computes forward decoding. Get last score and backpointers.
- initial_state = tf.slice(potentials, [0, 0, 0], [-1, 1, -1])
- initial_state = tf.squeeze(initial_state, axis=[1])
- inputs = tf.slice(potentials, [0, 1, 0], [-1, -1, -1])
-
- sequence_length_less_one = tf.maximum(
- tf.constant(0, dtype=sequence_length.dtype), sequence_length - 1)
-
- backpointers, last_score = crf_decode_forward(
- inputs, initial_state, transition_params, sequence_length_less_one)
-
- backpointers = tf.reverse_sequence(
- backpointers, sequence_length_less_one, seq_axis=1)
-
- initial_state = tf.cast(tf.argmax(last_score, axis=1), dtype=tf.int32)
- initial_state = tf.expand_dims(initial_state, axis=-1)
-
- decode_tags = crf_decode_backward(backpointers, initial_state)
- decode_tags = tf.squeeze(decode_tags, axis=[2])
- decode_tags = tf.concat([initial_state, decode_tags], axis=1)
- decode_tags = tf.reverse_sequence(
- decode_tags, sequence_length, seq_axis=1)
-
- best_score = tf.reduce_max(last_score, axis=1)
- return decode_tags, best_score
-
- if potentials.shape[1] == 1:
- return _single_seq_fn()
- else:
- return _multi_seq_fn()
+ Returns:
+ List of list containing the best tag sequence for each batch.
+ """
+ self._validate(emissions, mask=mask)
+ if mask is None:
+ mask = emissions.new_ones(emissions.shape[:2], dtype=torch.uint8)
+
+ if self.batch_first:
+ emissions = emissions.transpose(0, 1)
+ mask = mask.transpose(0, 1)
+
+ return self._viterbi_decode(emissions, mask)
+
+ def _validate(
+ self,
+ emissions: torch.Tensor,
+ tags: Optional[torch.LongTensor] = None,
+ mask: Optional[torch.ByteTensor] = None) -> None:
+ if emissions.dim() != 3:
+ raise ValueError(f'emissions must have dimension of 3, got {emissions.dim()}')
+ if emissions.size(2) != self.num_tags:
+ raise ValueError(
+ f'expected last dimension of emissions is {self.num_tags}, '
+ f'got {emissions.size(2)}')
+
+ if tags is not None:
+ if emissions.shape[:2] != tags.shape:
+ raise ValueError(
+ 'the first two dimensions of emissions and tags must match, '
+ f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}')
+
+ if mask is not None:
+ if emissions.shape[:2] != mask.shape:
+ raise ValueError(
+ 'the first two dimensions of emissions and mask must match, '
+ f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}')
+ no_empty_seq = not self.batch_first and mask[0].all()
+ no_empty_seq_bf = self.batch_first and mask[:, 0].all()
+ if not no_empty_seq and not no_empty_seq_bf:
+ raise ValueError('mask of the first timestep must all be on')
+
+ def _compute_score(
+ self, emissions: torch.Tensor, tags: torch.LongTensor,
+ mask: torch.ByteTensor) -> torch.Tensor:
+ # emissions: (seq_length, batch_size, num_tags)
+ # tags: (seq_length, batch_size)
+ # mask: (seq_length, batch_size)
+ assert emissions.dim() == 3 and tags.dim() == 2
+ assert emissions.shape[:2] == tags.shape
+ assert emissions.size(2) == self.num_tags
+ assert mask.shape == tags.shape
+ assert mask[0].all()
+
+ seq_length, batch_size = tags.shape
+ mask = mask.type_as(emissions)
+
+ # Start transition score and first emission
+ # shape: (batch_size,)
+ score = self.start_transitions[tags[0]]
+ score += emissions[0, torch.arange(batch_size), tags[0]]
+
+ for i in range(1, seq_length):
+ # Transition score to next tag, only added if next timestep is valid (mask == 1)
+ # shape: (batch_size,)
+ score += self.transitions[tags[i - 1], tags[i]] * mask[i]
+
+ # Emission score for next tag, only added if next timestep is valid (mask == 1)
+ # shape: (batch_size,)
+ score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i]
+
+ # End transition score
+ # shape: (batch_size,)
+ seq_ends = mask.long().sum(dim=0) - 1
+ # shape: (batch_size,)
+ last_tags = tags[seq_ends, torch.arange(batch_size)]
+ # shape: (batch_size,)
+ score += self.end_transitions[last_tags]
+
+ return score
+
+ def _compute_normalizer(
+ self, emissions: torch.Tensor, mask: torch.ByteTensor) -> torch.Tensor:
+ # emissions: (seq_length, batch_size, num_tags)
+ # mask: (seq_length, batch_size)
+ assert emissions.dim() == 3 and mask.dim() == 2
+ assert emissions.shape[:2] == mask.shape
+ assert emissions.size(2) == self.num_tags
+ assert mask[0].all()
+
+ seq_length = emissions.size(0)
+
+ # Start transition score and first emission; score has size of
+ # (batch_size, num_tags) where for each batch, the j-th column stores
+ # the score that the first timestep has tag j
+ # shape: (batch_size, num_tags)
+ score = self.start_transitions + emissions[0]
+
+ for i in range(1, seq_length):
+ # Broadcast score for every possible next tag
+ # shape: (batch_size, num_tags, 1)
+ broadcast_score = score.unsqueeze(2)
+
+ # Broadcast emission score for every possible current tag
+ # shape: (batch_size, 1, num_tags)
+ broadcast_emissions = emissions[i].unsqueeze(1)
+
+ # Compute the score tensor of size (batch_size, num_tags, num_tags) where
+ # for each sample, entry at row i and column j stores the sum of scores of all
+ # possible tag sequences so far that end with transitioning from tag i to tag j
+ # and emitting
+ # shape: (batch_size, num_tags, num_tags)
+ next_score = broadcast_score + self.transitions + broadcast_emissions
+
+ # Sum over all possible current tags, but we're in score space, so a sum
+ # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of
+ # all possible tag sequences so far, that end in tag i
+ # shape: (batch_size, num_tags)
+ next_score = torch.logsumexp(next_score, dim=1)
+
+ # Set score to the next score if this timestep is valid (mask == 1)
+ # shape: (batch_size, num_tags)
+ score = torch.where(mask[i].unsqueeze(1), next_score, score)
+
+ # End transition score
+ # shape: (batch_size, num_tags)
+ score += self.end_transitions
+
+ # Sum (log-sum-exp) over all possible tags
+ # shape: (batch_size,)
+ return torch.logsumexp(score, dim=1)
+
+ def _viterbi_decode(self, emissions: torch.FloatTensor,
+ mask: torch.ByteTensor) -> List[List[int]]:
+ # emissions: (seq_length, batch_size, num_tags)
+ # mask: (seq_length, batch_size)
+ assert emissions.dim() == 3 and mask.dim() == 2
+ assert emissions.shape[:2] == mask.shape
+ assert emissions.size(2) == self.num_tags
+ assert mask[0].all()
+
+ seq_length, batch_size = mask.shape
+
+ # Start transition and first emission
+ # shape: (batch_size, num_tags)
+ score = self.start_transitions + emissions[0]
+ history = []
+
+ # score is a tensor of size (batch_size, num_tags) where for every batch,
+ # value at column j stores the score of the best tag sequence so far that ends
+ # with tag j
+ # history saves where the best tags candidate transitioned from; this is used
+ # when we trace back the best tag sequence
+
+ # Viterbi algorithm recursive case: we compute the score of the best tag sequence
+ # for every possible next tag
+ for i in range(1, seq_length):
+ # Broadcast viterbi score for every possible next tag
+ # shape: (batch_size, num_tags, 1)
+ broadcast_score = score.unsqueeze(2)
+
+ # Broadcast emission score for every possible current tag
+ # shape: (batch_size, 1, num_tags)
+ broadcast_emission = emissions[i].unsqueeze(1)
+
+ # Compute the score tensor of size (batch_size, num_tags, num_tags) where
+ # for each sample, entry at row i and column j stores the score of the best
+ # tag sequence so far that ends with transitioning from tag i to tag j and emitting
+ # shape: (batch_size, num_tags, num_tags)
+ next_score = broadcast_score + self.transitions + broadcast_emission
+
+ # Find the maximum score over all possible current tag
+ # shape: (batch_size, num_tags)
+ next_score, indices = next_score.max(dim=1)
+
+ # Set score to the next score if this timestep is valid (mask == 1)
+ # and save the index that produces the next score
+ # shape: (batch_size, num_tags)
+ score = torch.where(mask[i].unsqueeze(1), next_score, score)
+ history.append(indices)
+
+ # End transition score
+ # shape: (batch_size, num_tags)
+ score += self.end_transitions
+
+ # Now, compute the best path for each sample
+
+ # shape: (batch_size,)
+ seq_ends = mask.long().sum(dim=0) - 1
+ best_tags_list = []
+
+ for idx in range(batch_size):
+ # Find the tag which maximizes the score at the last timestep; this is our best tag
+ # for the last timestep
+ _, best_last_tag = score[idx].max(dim=0)
+ best_tags = [best_last_tag.item()]
+
+ # We trace back where the best last tag comes from, append that to our best tag
+ # sequence, and trace it back again, and so on
+ for hist in reversed(history[:seq_ends[idx]]):
+ best_last_tag = hist[idx][best_tags[-1]]
+ best_tags.append(best_last_tag.item())
+
+ # Reverse the order because we start from the last timestep
+ best_tags.reverse()
+ best_tags_list.append(best_tags)
+
+ return best_tags_list
diff --git a/hanlp/layers/crf/crf_layer.py b/hanlp/layers/crf/crf_layer_tf.py
similarity index 93%
rename from hanlp/layers/crf/crf_layer.py
rename to hanlp/layers/crf/crf_layer_tf.py
index 57507558f..7190b965e 100644
--- a/hanlp/layers/crf/crf_layer.py
+++ b/hanlp/layers/crf/crf_layer_tf.py
@@ -15,30 +15,32 @@
# ******************************************************************************
import tensorflow as tf
-from hanlp.layers.crf.crf import crf_decode, crf_log_likelihood
+from hanlp.layers.crf.crf_tf import crf_decode, crf_log_likelihood
class CRF(tf.keras.layers.Layer):
- """
- Conditional Random Field layer (tf.keras)
+ """Conditional Random Field layer (tf.keras)
`CRF` can be used as the last layer in a network (as a classifier). Input shape (features)
must be equal to the number of classes the CRF can predict (a linear layer is recommended).
-
+
Note: the loss and accuracy functions of networks using `CRF` must
use the provided loss and accuracy functions (denoted as loss and viterbi_accuracy)
as the classification of sequences are used with the layers internal weights.
-
+
Copyright: this is a modified version of
https://github.com/NervanaSystems/nlp-architect/blob/master/nlp_architect/nn/tensorflow/python/keras/layers/crf.py
Args:
- num_labels (int): the number of labels to tag each temporal input.
-
+ num_labels(int): the number of labels to tag each temporal input.
Input shape:
- nD tensor with shape `(batch_size, sentence length, num_classes)`.
-
+ num_labels(int): the number of labels to tag each temporal input.
+ Input shape:
+ nD tensor with shape `(batch_size, sentence length, num_classes)`.
Output shape:
- nD tensor with shape: `(batch_size, sentence length, num_classes)`.
+ nD tensor with shape: `(batch_size, sentence length, num_classes)`.
+
+ Returns:
+
"""
def __init__(self, num_classes, **kwargs):
diff --git a/hanlp/layers/crf/crf_tf.py b/hanlp/layers/crf/crf_tf.py
new file mode 100644
index 000000000..253a099b2
--- /dev/null
+++ b/hanlp/layers/crf/crf_tf.py
@@ -0,0 +1,520 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+# TODO: Wrap functions in @tf.function once
+# https://github.com/tensorflow/tensorflow/issues/29075 is resolved
+
+
+def crf_sequence_score(inputs, tag_indices, sequence_lengths,
+ transition_params):
+ """Computes the unnormalized score for a tag sequence.
+
+ Args:
+ inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
+ to use as input to the CRF layer.
+ tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which
+ we compute the unnormalized score.
+ sequence_lengths: A [batch_size] vector of true sequence lengths.
+ transition_params:
+
+ Returns:
+ sequence_scores: A [batch_size] vector of unnormalized sequence scores.
+
+ """
+ tag_indices = tf.cast(tag_indices, dtype=tf.int32)
+ sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
+
+ # If max_seq_len is 1, we skip the score calculation and simply gather the
+ # unary potentials of the single tag.
+ def _single_seq_fn():
+ batch_size = tf.shape(inputs, out_type=tag_indices.dtype)[0]
+
+ example_inds = tf.reshape(
+ tf.range(batch_size, dtype=tag_indices.dtype), [-1, 1])
+ sequence_scores = tf.gather_nd(
+ tf.squeeze(inputs, [1]),
+ tf.concat([example_inds, tag_indices], axis=1))
+ sequence_scores = tf.where(
+ tf.less_equal(sequence_lengths, 0), tf.zeros_like(sequence_scores),
+ sequence_scores)
+ return sequence_scores
+
+ def _multi_seq_fn():
+ # Compute the scores of the given tag sequence.
+ unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs)
+ binary_scores = crf_binary_score(tag_indices, sequence_lengths,
+ transition_params)
+ sequence_scores = unary_scores + binary_scores
+ return sequence_scores
+
+ if inputs.shape[1] == 1:
+ return _single_seq_fn()
+ else:
+ return _multi_seq_fn()
+
+
+def crf_multitag_sequence_score(inputs, tag_bitmap, sequence_lengths,
+ transition_params):
+ """Computes the unnormalized score of all tag sequences matching
+ tag_bitmap.
+
+ tag_bitmap enables more than one tag to be considered correct at each time
+ step. This is useful when an observed output at a given time step is
+ consistent with more than one tag, and thus the log likelihood of that
+ observation must take into account all possible consistent tags.
+
+ Using one-hot vectors in tag_bitmap gives results identical to
+ crf_sequence_score.
+
+ Args:
+ inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
+ to use as input to the CRF layer.
+ tag_bitmap: A [batch_size, max_seq_len, num_tags] boolean tensor
+ representing all active tags at each index for which to calculate the
+ unnormalized score.
+ sequence_lengths: A [batch_size] vector of true sequence lengths.
+ transition_params:
+
+ Returns:
+ sequence_scores: A [batch_size] vector of unnormalized sequence scores.
+
+ """
+ tag_bitmap = tf.cast(tag_bitmap, dtype=tf.bool)
+ sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
+ filtered_inputs = tf.where(tag_bitmap, inputs,
+ tf.fill(tf.shape(inputs), float("-inf")))
+
+ # If max_seq_len is 1, we skip the score calculation and simply gather the
+ # unary potentials of all active tags.
+ def _single_seq_fn():
+ return tf.reduce_logsumexp(
+ filtered_inputs, axis=[1, 2], keepdims=False)
+
+ def _multi_seq_fn():
+ # Compute the logsumexp of all scores of sequences matching the given tags.
+ return crf_log_norm(
+ inputs=filtered_inputs,
+ sequence_lengths=sequence_lengths,
+ transition_params=transition_params)
+
+ if inputs.shape[1] == 1:
+ return _single_seq_fn()
+ else:
+ return _multi_seq_fn()
+
+
+def crf_log_norm(inputs, sequence_lengths, transition_params):
+ """Computes the normalization for a CRF.
+
+ Args:
+ inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
+ to use as input to the CRF layer.
+ sequence_lengths: A [batch_size] vector of true sequence lengths.
+ transition_params:
+
+ Returns:
+ log_norm: A [batch_size] vector of normalizers for a CRF.
+
+ """
+ sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
+ # Split up the first and rest of the inputs in preparation for the forward
+ # algorithm.
+ first_input = tf.slice(inputs, [0, 0, 0], [-1, 1, -1])
+ first_input = tf.squeeze(first_input, [1])
+
+ # If max_seq_len is 1, we skip the algorithm and simply reduce_logsumexp over
+ # the "initial state" (the unary potentials).
+ def _single_seq_fn():
+ log_norm = tf.reduce_logsumexp(first_input, [1])
+ # Mask `log_norm` of the sequences with length <= zero.
+ log_norm = tf.where(
+ tf.less_equal(sequence_lengths, 0), tf.zeros_like(log_norm),
+ log_norm)
+ return log_norm
+
+ def _multi_seq_fn():
+ """Forward computation of alpha values."""
+ rest_of_input = tf.slice(inputs, [0, 1, 0], [-1, -1, -1])
+ # Compute the alpha values in the forward algorithm in order to get the
+ # partition function.
+
+ alphas = crf_forward(rest_of_input, first_input, transition_params,
+ sequence_lengths)
+ log_norm = tf.reduce_logsumexp(alphas, [1])
+ # Mask `log_norm` of the sequences with length <= zero.
+ log_norm = tf.where(
+ tf.less_equal(sequence_lengths, 0), tf.zeros_like(log_norm),
+ log_norm)
+ return log_norm
+
+ if inputs.shape[1] == 1:
+ return _single_seq_fn()
+ else:
+ return _multi_seq_fn()
+
+
+def crf_log_likelihood(inputs,
+ tag_indices,
+ sequence_lengths,
+ transition_params=None):
+ """Computes the log-likelihood of tag sequences in a CRF.
+
+ Args:
+ inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
+ to use as input to the CRF layer.
+ tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which
+ we compute the log-likelihood.
+ sequence_lengths: A [batch_size] vector of true sequence lengths.
+ transition_params: A [num_tags, num_tags] transition matrix, (Default value = None)
+
+ Returns:
+ log_likelihood: A [batch_size] `Tensor` containing the log-likelihood of
+ each example, given the sequence of tag indices.
+ transition_params: A [num_tags, num_tags] transition matrix. This is
+ either provided by the caller or created in this function.
+
+ """
+ num_tags = inputs.shape[2]
+
+ # cast type to handle different types
+ tag_indices = tf.cast(tag_indices, dtype=tf.int32)
+ sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
+
+ if transition_params is None:
+ initializer = tf.keras.initializers.GlorotUniform()
+ transition_params = tf.Variable(
+ initializer([num_tags, num_tags]), "transitions")
+
+ sequence_scores = crf_sequence_score(inputs, tag_indices, sequence_lengths,
+ transition_params)
+ log_norm = crf_log_norm(inputs, sequence_lengths, transition_params)
+
+ # Normalize the scores to get the log-likelihood per example.
+ log_likelihood = sequence_scores - log_norm
+ return log_likelihood, transition_params
+
+
+def crf_unary_score(tag_indices, sequence_lengths, inputs):
+ """Computes the unary scores of tag sequences.
+
+ Args:
+ tag_indices: A [batch_size, max_seq_len] matrix of tag indices.
+ sequence_lengths: A [batch_size] vector of true sequence lengths.
+ inputs:
+
+ Returns:
+ unary_scores: A [batch_size] vector of unary scores.
+
+ """
+ assert len(tag_indices.shape) == 2, 'tag_indices: A [batch_size, max_seq_len] matrix of tag indices.'
+ tag_indices = tf.cast(tag_indices, dtype=tf.int32)
+ sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
+
+ batch_size = tf.shape(inputs)[0]
+ max_seq_len = tf.shape(inputs)[1]
+ num_tags = tf.shape(inputs)[2]
+
+ flattened_inputs = tf.reshape(inputs, [-1])
+
+ offsets = tf.expand_dims(tf.range(batch_size) * max_seq_len * num_tags, 1)
+ offsets += tf.expand_dims(tf.range(max_seq_len) * num_tags, 0)
+ # Use int32 or int64 based on tag_indices' dtype.
+ if tag_indices.dtype == tf.int64:
+ offsets = tf.cast(offsets, tf.int64)
+ flattened_tag_indices = tf.reshape(offsets + tag_indices, [-1])
+
+ unary_scores = tf.reshape(
+ tf.gather(flattened_inputs, flattened_tag_indices),
+ [batch_size, max_seq_len])
+
+ masks = tf.sequence_mask(
+ sequence_lengths, maxlen=tf.shape(tag_indices)[1], dtype=tf.float32)
+
+ unary_scores = tf.reduce_sum(unary_scores * masks, 1)
+ return unary_scores
+
+
+def crf_binary_score(tag_indices, sequence_lengths, transition_params):
+ """Computes the binary scores of tag sequences.
+
+ Args:
+ tag_indices: A [batch_size, max_seq_len] matrix of tag indices.
+ sequence_lengths: A [batch_size] vector of true sequence lengths.
+ transition_params:
+
+ Returns:
+ binary_scores: A [batch_size] vector of binary scores.
+
+ """
+ tag_indices = tf.cast(tag_indices, dtype=tf.int32)
+ sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
+
+ num_tags = tf.shape(transition_params)[0]
+ num_transitions = tf.shape(tag_indices)[1] - 1
+
+ # Truncate by one on each side of the sequence to get the start and end
+ # indices of each transition.
+ start_tag_indices = tf.slice(tag_indices, [0, 0], [-1, num_transitions])
+ end_tag_indices = tf.slice(tag_indices, [0, 1], [-1, num_transitions])
+
+ # Encode the indices in a flattened representation.
+ flattened_transition_indices = start_tag_indices * \
+ num_tags + end_tag_indices
+ flattened_transition_params = tf.reshape(transition_params, [-1])
+
+ # Get the binary scores based on the flattened representation.
+ binary_scores = tf.gather(flattened_transition_params,
+ flattened_transition_indices)
+
+ masks = tf.sequence_mask(
+ sequence_lengths, maxlen=tf.shape(tag_indices)[1], dtype=tf.float32)
+ truncated_masks = tf.slice(masks, [0, 1], [-1, -1])
+ binary_scores = tf.reduce_sum(binary_scores * truncated_masks, 1)
+ return binary_scores
+
+
+def crf_forward(inputs, state, transition_params, sequence_lengths):
+ """Computes the alpha values in a linear-chain CRF.
+
+ See http://www.cs.columbia.edu/~mcollins/fb.pdf for reference.
+
+ Args:
+ inputs: A [batch_size, num_tags] matrix of unary potentials.
+ state: A [batch_size, num_tags] matrix containing the previous alpha
+ values.
+ transition_params: A [num_tags, num_tags] matrix of binary potentials.
+ This matrix is expanded into a [1, num_tags, num_tags] in preparation
+ for the broadcast summation occurring within the cell.
+ sequence_lengths: A [batch_size] vector of true sequence lengths.
+
+ Returns:
+ new_alphas: A [batch_size, num_tags] matrix containing the
+ new alpha values.
+
+ """
+ sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
+
+ sequence_lengths = tf.maximum(
+ tf.constant(0, dtype=sequence_lengths.dtype), sequence_lengths - 2)
+ inputs = tf.transpose(inputs, [1, 0, 2])
+ transition_params = tf.expand_dims(transition_params, 0)
+
+ def _scan_fn(state, inputs):
+ state = tf.expand_dims(state, 2)
+ transition_scores = state + transition_params
+ new_alphas = inputs + tf.reduce_logsumexp(transition_scores, [1])
+ return new_alphas
+
+ all_alphas = tf.transpose(tf.scan(_scan_fn, inputs, state), [1, 0, 2])
+ idxs = tf.stack(
+ [tf.range(tf.shape(sequence_lengths)[0]), sequence_lengths], axis=1)
+ return tf.gather_nd(all_alphas, idxs)
+
+
+def viterbi_decode(score, transition_params):
+ """Decode the highest scoring sequence of tags outside of TensorFlow.
+
+ This should only be used at test time.
+
+ Args:
+ score: A [seq_len, num_tags] matrix of unary potentials.
+ transition_params: A [num_tags, num_tags] matrix of binary potentials.
+
+ Returns:
+ viterbi: A [seq_len] list of integers containing the highest scoring tag
+ indices.
+ viterbi_score: A float containing the score for the Viterbi sequence.
+
+ """
+ trellis = np.zeros_like(score)
+ backpointers = np.zeros_like(score, dtype=np.int32)
+ trellis[0] = score[0]
+
+ for t in range(1, score.shape[0]):
+ v = np.expand_dims(trellis[t - 1], 1) + transition_params
+ trellis[t] = score[t] + np.max(v, 0)
+ backpointers[t] = np.argmax(v, 0)
+
+ viterbi = [np.argmax(trellis[-1])]
+ for bp in reversed(backpointers[1:]):
+ viterbi.append(bp[viterbi[-1]])
+ viterbi.reverse()
+
+ viterbi_score = np.max(trellis[-1])
+ return viterbi, viterbi_score
+
+
+class CrfDecodeForwardRnnCell(tf.keras.layers.AbstractRNNCell):
+ """Computes the forward decoding in a linear-chain CRF."""
+
+ def __init__(self, transition_params, **kwargs):
+ """Initialize the CrfDecodeForwardRnnCell.
+
+ Args:
+ transition_params: A [num_tags, num_tags] matrix of binary
+ potentials. This matrix is expanded into a
+ [1, num_tags, num_tags] in preparation for the broadcast
+ summation occurring within the cell.
+ """
+ super(CrfDecodeForwardRnnCell, self).__init__(**kwargs)
+ self._transition_params = tf.expand_dims(transition_params, 0)
+ self._num_tags = transition_params.shape[0]
+
+ @property
+ def state_size(self):
+ return self._num_tags
+
+ @property
+ def output_size(self):
+ return self._num_tags
+
+ def build(self, input_shape):
+ super(CrfDecodeForwardRnnCell, self).build(input_shape)
+
+ def call(self, inputs, state):
+ """Build the CrfDecodeForwardRnnCell.
+
+ Args:
+ inputs: A [batch_size, num_tags] matrix of unary potentials.
+ state: A [batch_size, num_tags] matrix containing the previous step's
+ score values.
+
+ Returns:
+ backpointers: A [batch_size, num_tags] matrix of backpointers.
+ new_state: A [batch_size, num_tags] matrix of new score values.
+
+ """
+ state = tf.expand_dims(state[0], 2)
+ transition_scores = state + self._transition_params
+ new_state = inputs + tf.reduce_max(transition_scores, [1])
+ backpointers = tf.argmax(transition_scores, 1)
+ backpointers = tf.cast(backpointers, dtype=tf.int32)
+ return backpointers, new_state
+
+
+def crf_decode_forward(inputs, state, transition_params, sequence_lengths):
+ """Computes forward decoding in a linear-chain CRF.
+
+ Args:
+ inputs: A [batch_size, num_tags] matrix of unary potentials.
+ state: A [batch_size, num_tags] matrix containing the previous step's
+ score values.
+ transition_params: A [num_tags, num_tags] matrix of binary potentials.
+ sequence_lengths: A [batch_size] vector of true sequence lengths.
+
+ Returns:
+ backpointers: A [batch_size, num_tags] matrix of backpointers.
+ new_state: A [batch_size, num_tags] matrix of new score values.
+
+ """
+ sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
+ mask = tf.sequence_mask(sequence_lengths, tf.shape(inputs)[1])
+ crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params)
+ crf_fwd_layer = tf.keras.layers.RNN(
+ crf_fwd_cell, return_sequences=True, return_state=True)
+ return crf_fwd_layer(inputs, state, mask=mask)
+
+
+def crf_decode_backward(inputs, state):
+ """Computes backward decoding in a linear-chain CRF.
+
+ Args:
+ inputs: A [batch_size, num_tags] matrix of
+ backpointer of next step (in time order).
+ state: A [batch_size, 1] matrix of tag index of next step.
+
+ Returns:
+ new_tags: A [batch_size, num_tags]
+ tensor containing the new tag indices.
+
+ """
+ inputs = tf.transpose(inputs, [1, 0, 2])
+
+ def _scan_fn(state, inputs):
+ state = tf.squeeze(state, axis=[1])
+ idxs = tf.stack([tf.range(tf.shape(inputs)[0]), state], axis=1)
+ new_tags = tf.expand_dims(tf.gather_nd(inputs, idxs), axis=-1)
+ return new_tags
+
+ return tf.transpose(tf.scan(_scan_fn, inputs, state), [1, 0, 2])
+
+
+def crf_decode(potentials, transition_params, sequence_length):
+ """Decode the highest scoring sequence of tags in TensorFlow.
+
+ This is a function for tensor.
+
+ Args:
+ potentials: A [batch_size, max_seq_len, num_tags] tensor of
+ unary potentials.
+ transition_params: A [num_tags, num_tags] matrix of
+ binary potentials.
+ sequence_length: A [batch_size] vector of true sequence lengths.
+
+ Returns:
+ decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`.
+ Contains the highest scoring tag indices.
+ best_score: A [batch_size] vector, containing the score of `decode_tags`.
+
+ """
+ sequence_length = tf.cast(sequence_length, dtype=tf.int32)
+
+ # If max_seq_len is 1, we skip the algorithm and simply return the argmax tag
+ # and the max activation.
+ def _single_seq_fn():
+ squeezed_potentials = tf.squeeze(potentials, [1])
+ decode_tags = tf.expand_dims(tf.argmax(squeezed_potentials, axis=1), 1)
+ best_score = tf.reduce_max(squeezed_potentials, axis=1)
+ return tf.cast(decode_tags, dtype=tf.int32), best_score
+
+ def _multi_seq_fn():
+ """Decoding of highest scoring sequence."""
+ # Computes forward decoding. Get last score and backpointers.
+ initial_state = tf.slice(potentials, [0, 0, 0], [-1, 1, -1])
+ initial_state = tf.squeeze(initial_state, axis=[1])
+ inputs = tf.slice(potentials, [0, 1, 0], [-1, -1, -1])
+
+ sequence_length_less_one = tf.maximum(
+ tf.constant(0, dtype=sequence_length.dtype), sequence_length - 1)
+
+ backpointers, last_score = crf_decode_forward(
+ inputs, initial_state, transition_params, sequence_length_less_one)
+
+ backpointers = tf.reverse_sequence(
+ backpointers, sequence_length_less_one, seq_axis=1)
+
+ initial_state = tf.cast(tf.argmax(last_score, axis=1), dtype=tf.int32)
+ initial_state = tf.expand_dims(initial_state, axis=-1)
+
+ decode_tags = crf_decode_backward(backpointers, initial_state)
+ decode_tags = tf.squeeze(decode_tags, axis=[2])
+ decode_tags = tf.concat([initial_state, decode_tags], axis=1)
+ decode_tags = tf.reverse_sequence(
+ decode_tags, sequence_length, seq_axis=1)
+
+ best_score = tf.reduce_max(last_score, axis=1)
+ return decode_tags, best_score
+
+ if potentials.shape[1] == 1:
+ return _single_seq_fn()
+ else:
+ return _multi_seq_fn()
diff --git a/hanlp/layers/dropout.py b/hanlp/layers/dropout.py
new file mode 100644
index 000000000..fdd4d16de
--- /dev/null
+++ b/hanlp/layers/dropout.py
@@ -0,0 +1,159 @@
+# -*- coding:utf-8 -*-
+# Date: 2020-06-05 17:47
+from typing import List
+
+import torch
+import torch.nn as nn
+
+
+class WordDropout(nn.Module):
+ def __init__(self, p: float, oov_token: int, exclude_tokens: List[int] = None) -> None:
+ super().__init__()
+ self.oov_token = oov_token
+ self.p = p
+ if not exclude_tokens:
+ exclude_tokens = [0]
+ self.exclude = exclude_tokens
+
+ @staticmethod
+ def token_dropout(tokens: torch.LongTensor,
+ oov_token: int,
+ exclude_tokens: List[int],
+ p: float = 0.2,
+ training: float = True) -> torch.LongTensor:
+ """During training, randomly replaces some of the non-padding tokens to a mask token with probability ``p``
+
+ Adopted from https://github.com/Hyperparticle/udify
+
+ Args:
+ tokens: The current batch of padded sentences with word ids
+ oov_token: The mask token
+ exclude_tokens: The tokens for padding the input batch
+ p: The probability a word gets mapped to the unknown token
+ training: Applies the dropout if set to ``True``
+ tokens: torch.LongTensor:
+ oov_token: int:
+ exclude_tokens: List[int]:
+ p: float: (Default value = 0.2)
+ training: float: (Default value = True)
+
+ Returns:
+ A copy of the input batch with token dropout applied
+
+ """
+ if training and p > 0:
+ # This creates a mask that only considers unpadded tokens for mapping to oov
+ padding_mask = tokens.new_ones(tokens.size(), dtype=torch.bool)
+ for pad in exclude_tokens:
+ padding_mask &= (tokens != pad)
+
+ # Create a uniformly random mask selecting either the original words or OOV tokens
+ dropout_mask = (tokens.new_empty(tokens.size(), dtype=torch.float).uniform_() < p)
+ oov_mask = dropout_mask & padding_mask
+
+ oov_fill = tokens.new_empty(tokens.size(), dtype=torch.long).fill_(oov_token)
+
+ result = torch.where(oov_mask, oov_fill, tokens)
+
+ return result
+ else:
+ return tokens
+
+ def forward(self, tokens: torch.LongTensor) -> torch.LongTensor:
+ return self.token_dropout(tokens, self.oov_token, self.exclude, self.p, self.training)
+
+
+class SharedDropout(nn.Module):
+
+ def __init__(self, p=0.5, batch_first=True):
+ super(SharedDropout, self).__init__()
+
+ self.p = p
+ self.batch_first = batch_first
+
+ def extra_repr(self):
+ s = f"p={self.p}"
+ if self.batch_first:
+ s += f", batch_first={self.batch_first}"
+
+ return s
+
+ def forward(self, x):
+ if self.training:
+ if self.batch_first:
+ mask = self.get_mask(x[:, 0], self.p)
+ else:
+ mask = self.get_mask(x[0], self.p)
+ x *= mask.unsqueeze(1) if self.batch_first else mask
+
+ return x
+
+ @staticmethod
+ def get_mask(x, p):
+ mask = x.new_empty(x.shape).bernoulli_(1 - p)
+ mask = mask / (1 - p)
+
+ return mask
+
+
+class IndependentDropout(nn.Module):
+
+ def __init__(self, p=0.5):
+ r"""
+ For :math:`N` tensors, they use different dropout masks respectively.
+ When :math:`N-M` of them are dropped, the remaining :math:`M` ones are scaled by a factor of :math:`N/M` to compensate,
+ and when all of them are dropped together, zeros are returned.
+ Copied from https://github.com/yzhangcs/parser/master/supar/modules/dropout.py.
+
+ Args:
+ p (float):
+ The probability of an element to be zeroed. Default: 0.5.
+
+ Examples:
+ >>> x, y = torch.ones(1, 3, 5), torch.ones(1, 3, 5)
+ >>> x, y = IndependentDropout()(x, y)
+ >>> x
+ tensor([[[1., 1., 1., 1., 1.],
+ [0., 0., 0., 0., 0.],
+ [2., 2., 2., 2., 2.]]])
+ >>> y
+ tensor([[[1., 1., 1., 1., 1.],
+ [2., 2., 2., 2., 2.],
+ [0., 0., 0., 0., 0.]]])
+ """
+ super(IndependentDropout, self).__init__()
+ self.p = p
+
+ def extra_repr(self):
+ return f"p={self.p}"
+
+ def forward(self, *items):
+ if self.training:
+ masks = [x.new_empty(x.shape[:2]).bernoulli_(1 - self.p)
+ for x in items]
+ total = sum(masks)
+ scale = len(items) / total.max(torch.ones_like(total))
+ masks = [mask * scale for mask in masks]
+ items = [item * mask.unsqueeze(dim=-1)
+ for item, mask in zip(items, masks)]
+
+ return items
+
+
+class LockedDropout(nn.Module):
+ def __init__(self, dropout_rate=0.5):
+ super(LockedDropout, self).__init__()
+ self.dropout_rate = dropout_rate
+
+ def forward(self, x):
+ if not self.training or not self.dropout_rate:
+ return x
+
+ if x.dim() == 3:
+ mask = x.new(x.size(0), 1, x.size(2)).bernoulli_(1 - self.dropout_rate) / (1 - self.dropout_rate)
+ mask = mask.expand_as(x)
+ elif x.dim() == 2:
+ mask = torch.empty_like(x).bernoulli_(1 - self.dropout_rate) / (1 - self.dropout_rate)
+ else:
+ raise ValueError(f'Unsupported dim: {x.dim()}. Only 2d (T,C) or 3d (B,T,C) is supported')
+ return mask * x
diff --git a/hanlp/layers/embeddings/__init__.py b/hanlp/layers/embeddings/__init__.py
index d253f6260..0aa659895 100644
--- a/hanlp/layers/embeddings/__init__.py
+++ b/hanlp/layers/embeddings/__init__.py
@@ -1,78 +1,3 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-08-24 21:48
-from typing import Union
-
-import tensorflow as tf
-
-from hanlp.common.transform import Transform
-from hanlp.common.vocab import Vocab
-from hanlp.layers.embeddings.char_cnn import CharCNNEmbedding
-from hanlp.layers.embeddings.char_rnn import CharRNNEmbedding
-from hanlp.layers.embeddings.concat_embedding import ConcatEmbedding
-from hanlp.layers.embeddings.contextual_string_embedding import ContextualStringEmbedding
-from hanlp.layers.embeddings.word2vec import Word2VecEmbeddingV1, Word2VecEmbedding, StringWord2VecEmbedding
-
-
-def build_embedding(embeddings: Union[str, int, dict], word_vocab: Vocab, transform: Transform):
- config = transform.config
- if isinstance(embeddings, int):
- embeddings = tf.keras.layers.Embedding(input_dim=len(word_vocab), output_dim=embeddings,
- trainable=True, mask_zero=True)
- config.embedding_trainable = True
- elif isinstance(embeddings, dict):
- # Embeddings need vocab
- if embeddings['class_name'].split('>')[-1] in (Word2VecEmbedding.__name__, StringWord2VecEmbedding.__name__):
- # Vocab won't present in the dict
- embeddings['config']['vocab'] = word_vocab
- elif embeddings['class_name'].split('>')[-1] in (CharRNNEmbedding.__name__, CharCNNEmbedding.__name__):
- embeddings['config']['word_vocab'] = word_vocab
- embeddings['config']['char_vocab'] = transform.char_vocab
- transform.map_x = False
- elif embeddings['class_name'].split('>')[-1] == 'FastTextEmbedding':
- from hanlp.layers.embeddings.fast_text import FastTextEmbedding
- layer: tf.keras.layers.Embedding = tf.keras.utils.deserialize_keras_object(embeddings,
- custom_objects=tf.keras.utils.get_custom_objects())
- # Embedding specific configuration
- if layer.__class__.__name__ == 'FastTextEmbedding':
- config.run_eagerly = True # fasttext can only run in eager mode
- config.embedding_trainable = False
- transform.map_x = False # fasttext accept string instead of int
- return layer
- elif isinstance(embeddings, list):
- if embeddings_require_string_input(embeddings):
- # those embeddings require string as input
- transform.map_x = False
- # use the string version of Word2VecEmbedding instead
- for embed in embeddings:
- if embed['class_name'].split('>')[-1] == Word2VecEmbedding.__name__:
- embed['class_name'] = 'HanLP>' + StringWord2VecEmbedding.__name__
- return ConcatEmbedding(*[build_embedding(embed, word_vocab, transform) for embed in embeddings])
- else:
- assert isinstance(embeddings, str), 'embedding should be str or int or dict'
- # word_vocab.unlock()
- embeddings = Word2VecEmbeddingV1(path=embeddings, vocab=word_vocab,
- trainable=config.get('embedding_trainable', False))
- embeddings = embeddings.array_ks
- return embeddings
-
-
-def any_embedding_in(embeddings, *cls):
- names = set(x if isinstance(x, str) else x.__name__ for x in cls)
- for embed in embeddings:
- if isinstance(embed, dict) and embed['class_name'].split('>')[-1] in names:
- return True
- return False
-
-
-def embeddings_require_string_input(embeddings):
- if not isinstance(embeddings, list):
- embeddings = [embeddings]
- return any_embedding_in(embeddings, CharRNNEmbedding, CharCNNEmbedding, 'FastTextEmbedding',
- ContextualStringEmbedding)
-
-
-def embeddings_require_char_input(embeddings):
- if not isinstance(embeddings, list):
- embeddings = [embeddings]
- return any_embedding_in(embeddings, CharRNNEmbedding, CharCNNEmbedding, ContextualStringEmbedding)
diff --git a/hanlp/layers/embeddings/char_cnn.py b/hanlp/layers/embeddings/char_cnn.py
index 05462819a..830595115 100644
--- a/hanlp/layers/embeddings/char_cnn.py
+++ b/hanlp/layers/embeddings/char_cnn.py
@@ -1,109 +1,147 @@
-# -*- coding:utf-8 -*-
-# Author: hankcs
-# Date: 2019-12-20 21:15
-from functools import reduce
-
-import tensorflow as tf
-
+# Adopted from https://github.com/allenai/allennlp under Apache Licence 2.0.
+# Changed the packaging and created a subclass CharCNNEmbedding
+
+from typing import Union, Tuple, Optional, Callable
+import torch
+from torch import nn
+from alnlp.modules.cnn_encoder import CnnEncoder
+from alnlp.modules.time_distributed import TimeDistributed
+from hanlp_common.configurable import AutoConfigurable
+from hanlp.common.transform import VocabDict, ToChar
from hanlp.common.vocab import Vocab
-from hanlp.utils.tf_util import hanlp_register
-
-
-@hanlp_register
-class CharCNNEmbedding(tf.keras.layers.Layer):
- def __init__(self, word_vocab: Vocab, char_vocab: Vocab,
- char_embedding=100,
- kernel_size=3,
- filters=50,
- dropout=0.5,
- trainable=True, name=None, dtype=None, dynamic=False,
- **kwargs):
- super().__init__(trainable, name, dtype, dynamic, **kwargs)
- self.char_embedding = char_embedding
- self.filters = filters
- self.kernel_size = kernel_size
- self.char_vocab = char_vocab
- self.word_vocab = word_vocab
- self.embedding = tf.keras.layers.Embedding(input_dim=len(self.char_vocab), output_dim=char_embedding,
- trainable=True, mask_zero=True)
- self.dropout = tf.keras.layers.Dropout(dropout)
- self.cnn = tf.keras.layers.Conv1D(filters, kernel_size, padding='same')
-
- def call(self, inputs: tf.Tensor, **kwargs):
- mask = tf.not_equal(inputs, self.word_vocab.pad_token)
- inputs = tf.ragged.boolean_mask(inputs, mask)
- chars = tf.strings.unicode_split(inputs, input_encoding='UTF-8')
- chars = chars.to_tensor(default_value=self.char_vocab.pad_token)
- chars = self.char_vocab.lookup(chars)
- embed = self.embedding(chars)
- weights = embed._keras_mask
- embed = self.dropout(embed)
- features = masked_conv1d_and_max(embed, weights, self.cnn)
- features._keras_mask = mask
- return features
-
- def compute_output_shape(self, input_shape):
- return super().compute_output_shape(input_shape)
-
- def get_config(self):
- config = {
- 'char_embedding': self.char_embedding,
- 'kernel_size': self.kernel_size,
- 'filters': self.filters,
- 'dropout': self.dropout.rate,
- }
- base_config = super(CharCNNEmbedding, self).get_config()
- return dict(list(base_config.items()) + list(config.items()))
-
-
-def masked_conv1d_and_max(t, weights, conv1d):
- """Applies 1d convolution and a masked max-pooling
-
- https://github.com/guillaumegenthial/tf_ner/blob/master/models/chars_conv_lstm_crf/masked_conv.py
-
- Parameters
- ----------
- t : tf.Tensor
- A tensor with at least 3 dimensions [d1, d2, ..., dn-1, dn]
- weights : tf.Tensor of tf.bool
- A Tensor of shape [d1, d2, dn-1]
- filters : int
- number of filters
- kernel_size : int
- kernel size for the temporal convolution
-
- Returns
- -------
- tf.Tensor
- A tensor of shape [d1, d2, dn-1, filters]
-
- """
- # Get shape and parameters
- shape = tf.shape(t)
- ndims = t.shape.ndims
- dim1 = reduce(lambda x, y: x * y, [shape[i] for i in range(ndims - 2)])
- dim2 = shape[-2]
- dim3 = t.shape[-1]
-
- # Reshape weights
- weights = tf.reshape(weights, shape=[dim1, dim2, 1])
- weights = tf.cast(weights, tf.float32)
-
- # Reshape input and apply weights
- flat_shape = [dim1, dim2, dim3]
- t = tf.reshape(t, shape=flat_shape)
- t *= weights
-
- # Apply convolution
- t_conv = conv1d(t)
- t_conv *= weights
-
- # Reduce max -- set to zero if all padded
- t_conv += (1. - weights) * tf.reduce_min(t_conv, axis=-2, keepdims=True)
- t_max = tf.reduce_max(t_conv, axis=-2)
-
- # Reshape the output
- final_shape = [shape[i] for i in range(ndims - 2)] + [conv1d.filters]
- t_max = tf.reshape(t_max, shape=final_shape)
-
- return t_max
+from hanlp.layers.embeddings.embedding import EmbeddingDim, Embedding
+
+
+class CharCNN(nn.Module):
+ def __init__(self,
+ field: str,
+ embed: Union[int, Embedding], num_filters: int,
+ ngram_filter_sizes: Tuple[int, ...] = (2, 3, 4, 5),
+ conv_layer_activation: str = 'ReLU',
+ output_dim: Optional[int] = None,
+ vocab_size=None) -> None:
+ """A `CnnEncoder` is a combination of multiple convolution layers and max pooling layers.
+ The input to this module is of shape `(batch_size, num_tokens,
+ input_dim)`, and the output is of shape `(batch_size, output_dim)`.
+
+ The CNN has one convolution layer for each ngram filter size. Each convolution operation gives
+ out a vector of size num_filters. The number of times a convolution layer will be used
+ is `num_tokens - ngram_size + 1`. The corresponding maxpooling layer aggregates all these
+ outputs from the convolution layer and outputs the max.
+
+ This operation is repeated for every ngram size passed, and consequently the dimensionality of
+ the output after maxpooling is `len(ngram_filter_sizes) * num_filters`. This then gets
+ (optionally) projected down to a lower dimensional output, specified by `output_dim`.
+
+ We then use a fully connected layer to project in back to the desired output_dim. For more
+ details, refer to "A Sensitivity Analysis of (and Practitioners’ Guide to) Convolutional Neural
+ Networks for Sentence Classification", Zhang and Wallace 2016, particularly Figure 1.
+
+ See allennlp.modules.seq2vec_encoders.cnn_encoder.CnnEncoder, Apache 2.0
+
+ Args:
+ field: The field in samples this encoder will work on.
+ embed: An ``Embedding`` object or the feature size to create an ``Embedding`` object.
+ num_filters: This is the output dim for each convolutional layer, which is the number of "filters"
+ learned by that layer.
+ ngram_filter_sizes: This specifies both the number of convolutional layers we will create and their sizes. The
+ default of `(2, 3, 4, 5)` will have four convolutional layers, corresponding to encoding
+ ngrams of size 2 to 5 with some number of filters.
+ conv_layer_activation: `Activation`, optional (default=`torch.nn.ReLU`)
+ Activation to use after the convolution layers.
+ output_dim: After doing convolutions and pooling, we'll project the collected features into a vector of
+ this size. If this value is `None`, we will just return the result of the max pooling,
+ giving an output of shape `len(ngram_filter_sizes) * num_filters`.
+ vocab_size: The size of character vocab.
+
+ Returns:
+ A tensor of shape `(batch_size, output_dim)`.
+ """
+ super().__init__()
+ EmbeddingDim.__init__(self)
+ # the embedding layer
+ if isinstance(embed, int):
+ embed = nn.Embedding(num_embeddings=vocab_size,
+ embedding_dim=embed)
+ else:
+ raise ValueError(f'Unrecognized type for {embed}')
+ self.field = field
+ self.embed = TimeDistributed(embed)
+ self.encoder = TimeDistributed(
+ CnnEncoder(embed.embedding_dim, num_filters, ngram_filter_sizes, conv_layer_activation, output_dim))
+ self.embedding_dim = output_dim or num_filters * len(ngram_filter_sizes)
+
+ def forward(self, batch: dict, **kwargs):
+ tokens: torch.Tensor = batch[f'{self.field}_char_id']
+ mask = tokens.ge(0)
+ x = self.embed(tokens)
+ return self.encoder(x, mask)
+
+ def get_output_dim(self) -> int:
+ return self.embedding_dim
+
+
+class CharCNNEmbedding(Embedding, AutoConfigurable):
+ def __init__(self,
+ field,
+ embed: Union[int, Embedding],
+ num_filters: int,
+ ngram_filter_sizes: Tuple[int, ...] = (2, 3, 4, 5),
+ conv_layer_activation: str = 'ReLU',
+ output_dim: Optional[int] = None,
+ min_word_length=None
+ ) -> None:
+ """
+
+ Args:
+ field: The character field in samples this encoder will work on.
+ embed: An ``Embedding`` object or the feature size to create an ``Embedding`` object.
+ num_filters: This is the output dim for each convolutional layer, which is the number of "filters"
+ learned by that layer.
+ ngram_filter_sizes: This specifies both the number of convolutional layers we will create and their sizes. The
+ default of `(2, 3, 4, 5)` will have four convolutional layers, corresponding to encoding
+ ngrams of size 2 to 5 with some number of filters.
+ conv_layer_activation: `Activation`, optional (default=`torch.nn.ReLU`)
+ Activation to use after the convolution layers.
+ output_dim: After doing convolutions and pooling, we'll project the collected features into a vector of
+ this size. If this value is `None`, we will just return the result of the max pooling,
+ giving an output of shape `len(ngram_filter_sizes) * num_filters`.
+ min_word_length: For ngram filter with max size, the input (chars) is required to have at least max size
+ chars.
+ """
+ super().__init__()
+ if min_word_length is None:
+ min_word_length = max(ngram_filter_sizes)
+ self.min_word_length = min_word_length
+ self.output_dim = output_dim
+ self.conv_layer_activation = conv_layer_activation
+ self.ngram_filter_sizes = ngram_filter_sizes
+ self.num_filters = num_filters
+ self.embed = embed
+ self.field = field
+
+ def transform(self, vocabs: VocabDict, **kwargs) -> Optional[Callable]:
+ if isinstance(self.embed, Embedding):
+ self.embed.transform(vocabs=vocabs)
+ vocab_name = self.vocab_name
+ if vocab_name not in vocabs:
+ vocabs[vocab_name] = Vocab()
+ return ToChar(self.field, vocab_name, min_word_length=self.min_word_length,
+ pad=vocabs[vocab_name].safe_pad_token)
+
+ @property
+ def vocab_name(self):
+ vocab_name = f'{self.field}_char'
+ return vocab_name
+
+ def module(self, vocabs: VocabDict, **kwargs) -> Optional[nn.Module]:
+ embed = self.embed
+ if isinstance(embed, Embedding):
+ embed = embed.module(vocabs=vocabs)
+ return CharCNN(self.field,
+ embed,
+ self.num_filters,
+ self.ngram_filter_sizes,
+ self.conv_layer_activation,
+ self.output_dim,
+ vocab_size=len(vocabs[self.vocab_name]))
diff --git a/hanlp/layers/embeddings/char_cnn_tf.py b/hanlp/layers/embeddings/char_cnn_tf.py
new file mode 100644
index 000000000..e36420a6b
--- /dev/null
+++ b/hanlp/layers/embeddings/char_cnn_tf.py
@@ -0,0 +1,103 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-12-20 21:15
+from functools import reduce
+
+import tensorflow as tf
+
+from hanlp.common.vocab_tf import VocabTF
+from hanlp.utils.tf_util import hanlp_register
+
+
+@hanlp_register
+class CharCNNEmbeddingTF(tf.keras.layers.Layer):
+ def __init__(self, word_vocab: VocabTF, char_vocab: VocabTF,
+ char_embedding=100,
+ kernel_size=3,
+ filters=50,
+ dropout=0.5,
+ trainable=True, name=None, dtype=None, dynamic=False,
+ **kwargs):
+ super().__init__(trainable, name, dtype, dynamic, **kwargs)
+ self.char_embedding = char_embedding
+ self.filters = filters
+ self.kernel_size = kernel_size
+ self.char_vocab = char_vocab
+ self.word_vocab = word_vocab
+ self.embedding = tf.keras.layers.Embedding(input_dim=len(self.char_vocab), output_dim=char_embedding,
+ trainable=True, mask_zero=True)
+ self.dropout = tf.keras.layers.Dropout(dropout)
+ self.cnn = tf.keras.layers.Conv1D(filters, kernel_size, padding='same')
+
+ def call(self, inputs: tf.Tensor, **kwargs):
+ mask = tf.not_equal(inputs, self.word_vocab.pad_token)
+ inputs = tf.ragged.boolean_mask(inputs, mask)
+ chars = tf.strings.unicode_split(inputs, input_encoding='UTF-8')
+ chars = chars.to_tensor(default_value=self.char_vocab.pad_token)
+ chars = self.char_vocab.lookup(chars)
+ embed = self.embedding(chars)
+ weights = embed._keras_mask
+ embed = self.dropout(embed)
+ features = masked_conv1d_and_max(embed, weights, self.cnn)
+ features._keras_mask = mask
+ return features
+
+ def compute_output_shape(self, input_shape):
+ return super().compute_output_shape(input_shape)
+
+ def get_config(self):
+ config = {
+ 'char_embedding': self.char_embedding,
+ 'kernel_size': self.kernel_size,
+ 'filters': self.filters,
+ 'dropout': self.dropout.rate,
+ }
+ base_config = super(CharCNNEmbeddingTF, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
+
+def masked_conv1d_and_max(t, weights, conv1d):
+ """Applies 1d convolution and a masked max-pooling
+
+ https://github.com/guillaumegenthial/tf_ner/blob/master/models/chars_conv_lstm_crf/masked_conv.py
+
+ Args:
+ t(tf.Tensor): A tensor with at least 3 dimensions [d1, d2, ..., dn-1, dn]
+ weights(tf.Tensor of tf.bool): A Tensor of shape [d1, d2, dn-1]
+ filters(int): number of filters
+ kernel_size(int): kernel size for the temporal convolution
+ conv1d:
+
+ Returns:
+
+
+ """
+ # Get shape and parameters
+ shape = tf.shape(t)
+ ndims = t.shape.ndims
+ dim1 = reduce(lambda x, y: x * y, [shape[i] for i in range(ndims - 2)])
+ dim2 = shape[-2]
+ dim3 = t.shape[-1]
+
+ # Reshape weights
+ weights = tf.reshape(weights, shape=[dim1, dim2, 1])
+ weights = tf.cast(weights, tf.float32)
+
+ # Reshape input and apply weights
+ flat_shape = [dim1, dim2, dim3]
+ t = tf.reshape(t, shape=flat_shape)
+ t *= weights
+
+ # Apply convolution
+ t_conv = conv1d(t)
+ t_conv *= weights
+
+ # Reduce max -- set to zero if all padded
+ t_conv += (1. - weights) * tf.reduce_min(t_conv, axis=-2, keepdims=True)
+ t_max = tf.reduce_max(t_conv, axis=-2)
+
+ # Reshape the output
+ final_shape = [shape[i] for i in range(ndims - 2)] + [conv1d.filters]
+ t_max = tf.reshape(t_max, shape=final_shape)
+
+ return t_max
diff --git a/hanlp/layers/embeddings/char_rnn.py b/hanlp/layers/embeddings/char_rnn.py
index 984688ba4..aa5f362a9 100644
--- a/hanlp/layers/embeddings/char_rnn.py
+++ b/hanlp/layers/embeddings/char_rnn.py
@@ -1,61 +1,109 @@
# -*- coding:utf-8 -*-
# Author: hankcs
-# Date: 2019-12-20 17:02
-import tensorflow as tf
+# Date: 2020-06-02 23:49
+from typing import Optional, Callable, Union
+import torch
+import torch.nn as nn
+from torch.nn.utils.rnn import pack_padded_sequence
+
+from hanlp_common.configurable import AutoConfigurable
+from hanlp.common.transform import VocabDict, ToChar
from hanlp.common.vocab import Vocab
-from hanlp.utils.tf_util import hanlp_register
-
-
-@hanlp_register
-class CharRNNEmbedding(tf.keras.layers.Layer):
- def __init__(self, word_vocab: Vocab, char_vocab: Vocab,
- char_embedding=100,
- char_rnn_units=25,
- dropout=0.5,
- trainable=True, name=None, dtype=None, dynamic=False,
- **kwargs):
- super().__init__(trainable, name, dtype, dynamic, **kwargs)
- self.char_embedding = char_embedding
- self.char_rnn_units = char_rnn_units
- self.char_vocab = char_vocab
- self.word_vocab = word_vocab
- self.embedding = tf.keras.layers.Embedding(input_dim=len(self.char_vocab), output_dim=char_embedding,
- trainable=True, mask_zero=True)
- self.dropout = tf.keras.layers.Dropout(dropout)
- self.rnn = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(units=char_rnn_units,
- return_state=True), name='bilstm')
-
- def call(self, inputs: tf.Tensor, **kwargs):
- mask = tf.not_equal(inputs, self.word_vocab.pad_token)
- inputs = tf.ragged.boolean_mask(inputs, mask)
- chars = tf.strings.unicode_split(inputs, input_encoding='UTF-8')
- chars = chars.to_tensor(default_value=self.char_vocab.pad_token)
- chars = self.char_vocab.lookup(chars)
- embed = self.embedding(chars)
- char_mask = embed._keras_mask
- embed = self.dropout(embed)
- embed_shape = tf.shape(embed)
- embed = tf.reshape(embed, [-1, embed_shape[2], embed_shape[3]])
- char_mask = tf.reshape(char_mask, [-1, embed_shape[2]])
- all_zeros = tf.reduce_sum(tf.cast(char_mask, tf.int32), axis=1) == 0
- char_mask_shape = tf.shape(char_mask)
- hole = tf.zeros(shape=(char_mask_shape[0], char_mask_shape[1] - 1), dtype=tf.bool)
- all_zeros = tf.expand_dims(all_zeros, -1)
- non_all_zeros = tf.concat([all_zeros, hole], axis=1)
- char_mask = tf.logical_or(char_mask, non_all_zeros)
- output, h_fw, c_fw, h_bw, c_bw = self.rnn(embed, mask=char_mask)
- hidden = tf.concat([h_fw, h_bw], axis=-1)
- # hidden = output
- hidden = tf.reshape(hidden, [embed_shape[0], embed_shape[1], -1])
- hidden._keras_mask = mask
- return hidden
-
- def get_config(self):
- config = {
- 'char_embedding': self.char_embedding,
- 'char_rnn_units': self.char_rnn_units,
- 'dropout': self.dropout.rate,
- }
- base_config = super(CharRNNEmbedding, self).get_config()
- return dict(list(base_config.items()) + list(config.items()))
+from hanlp.layers.embeddings.embedding import Embedding, EmbeddingDim
+
+
+class CharRNN(nn.Module, EmbeddingDim):
+ def __init__(self,
+ field,
+ vocab_size,
+ embed: Union[int, nn.Embedding],
+ hidden_size):
+ """Character level RNN embedding module.
+
+ Args:
+ field: The field in samples this encoder will work on.
+ vocab_size: The size of character vocab.
+ embed: An ``Embedding`` object or the feature size to create an ``Embedding`` object.
+ hidden_size: The hidden size of RNNs.
+ """
+ super(CharRNN, self).__init__()
+ self.field = field
+ # the embedding layer
+ if isinstance(embed, int):
+ self.embed = nn.Embedding(num_embeddings=vocab_size,
+ embedding_dim=embed)
+ elif isinstance(embed, nn.Module):
+ self.embed = embed
+ embed = embed.embedding_dim
+ else:
+ raise ValueError(f'Unrecognized type for {embed}')
+ # the lstm layer
+ self.lstm = nn.LSTM(input_size=embed,
+ hidden_size=hidden_size,
+ batch_first=True,
+ bidirectional=True)
+
+ def forward(self, batch, mask, **kwargs):
+ x = batch[f'{self.field}_char_id']
+ # [batch_size, seq_len, fix_len]
+ mask = x.ne(0)
+ # [batch_size, seq_len]
+ lens = mask.sum(-1)
+ char_mask = lens.gt(0)
+
+ # [n, fix_len, n_embed]
+ x = self.embed(x[char_mask])
+ x = pack_padded_sequence(x, lens[char_mask], True, False)
+ x, (h, _) = self.lstm(x)
+ # [n, fix_len, n_out]
+ h = torch.cat(torch.unbind(h), -1)
+ # [batch_size, seq_len, n_out]
+ embed = h.new_zeros(*lens.shape, h.size(-1))
+ embed = embed.masked_scatter_(char_mask.unsqueeze(-1), h)
+
+ return embed
+
+ @property
+ def embedding_dim(self) -> int:
+ return self.lstm.hidden_size * 2
+
+
+class CharRNNEmbedding(Embedding, AutoConfigurable):
+ def __init__(self,
+ field,
+ embed,
+ hidden_size,
+ max_word_length=None) -> None:
+ """Character level RNN embedding module builder.
+
+ Args:
+ field: The field in samples this encoder will work on.
+ embed: An ``Embedding`` object or the feature size to create an ``Embedding`` object.
+ hidden_size: The hidden size of RNNs.
+ max_word_length: Character sequence longer than ``max_word_length`` will be truncated.
+ """
+ super().__init__()
+ self.field = field
+ self.hidden_size = hidden_size
+ self.embed = embed
+ self.max_word_length = max_word_length
+
+ def transform(self, vocabs: VocabDict, **kwargs) -> Optional[Callable]:
+ if isinstance(self.embed, Embedding):
+ self.embed.transform(vocabs=vocabs)
+ vocab_name = self.vocab_name
+ if vocab_name not in vocabs:
+ vocabs[vocab_name] = Vocab()
+ return ToChar(self.field, vocab_name, max_word_length=self.max_word_length)
+
+ @property
+ def vocab_name(self):
+ vocab_name = f'{self.field}_char'
+ return vocab_name
+
+ def module(self, vocabs: VocabDict, **kwargs) -> Optional[nn.Module]:
+ embed = self.embed
+ if isinstance(self.embed, Embedding):
+ embed = self.embed.module(vocabs=vocabs)
+ return CharRNN(self.field, len(vocabs[self.vocab_name]), embed, self.hidden_size)
diff --git a/hanlp/layers/embeddings/char_rnn_tf.py b/hanlp/layers/embeddings/char_rnn_tf.py
new file mode 100644
index 000000000..e27ab486c
--- /dev/null
+++ b/hanlp/layers/embeddings/char_rnn_tf.py
@@ -0,0 +1,61 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-12-20 17:02
+import tensorflow as tf
+
+from hanlp.common.vocab_tf import VocabTF
+from hanlp.utils.tf_util import hanlp_register
+
+
+@hanlp_register
+class CharRNNEmbeddingTF(tf.keras.layers.Layer):
+ def __init__(self, word_vocab: VocabTF, char_vocab: VocabTF,
+ char_embedding=100,
+ char_rnn_units=25,
+ dropout=0.5,
+ trainable=True, name=None, dtype=None, dynamic=False,
+ **kwargs):
+ super().__init__(trainable, name, dtype, dynamic, **kwargs)
+ self.char_embedding = char_embedding
+ self.char_rnn_units = char_rnn_units
+ self.char_vocab = char_vocab
+ self.word_vocab = word_vocab
+ self.embedding = tf.keras.layers.Embedding(input_dim=len(self.char_vocab), output_dim=char_embedding,
+ trainable=True, mask_zero=True)
+ self.dropout = tf.keras.layers.Dropout(dropout)
+ self.rnn = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(units=char_rnn_units,
+ return_state=True), name='bilstm')
+
+ def call(self, inputs: tf.Tensor, **kwargs):
+ mask = tf.not_equal(inputs, self.word_vocab.pad_token)
+ inputs = tf.ragged.boolean_mask(inputs, mask)
+ chars = tf.strings.unicode_split(inputs, input_encoding='UTF-8')
+ chars = chars.to_tensor(default_value=self.char_vocab.pad_token)
+ chars = self.char_vocab.lookup(chars)
+ embed = self.embedding(chars)
+ char_mask = embed._keras_mask
+ embed = self.dropout(embed)
+ embed_shape = tf.shape(embed)
+ embed = tf.reshape(embed, [-1, embed_shape[2], embed_shape[3]])
+ char_mask = tf.reshape(char_mask, [-1, embed_shape[2]])
+ all_zeros = tf.reduce_sum(tf.cast(char_mask, tf.int32), axis=1) == 0
+ char_mask_shape = tf.shape(char_mask)
+ hole = tf.zeros(shape=(char_mask_shape[0], char_mask_shape[1] - 1), dtype=tf.bool)
+ all_zeros = tf.expand_dims(all_zeros, -1)
+ non_all_zeros = tf.concat([all_zeros, hole], axis=1)
+ char_mask = tf.logical_or(char_mask, non_all_zeros)
+ output, h_fw, c_fw, h_bw, c_bw = self.rnn(embed, mask=char_mask)
+ hidden = tf.concat([h_fw, h_bw], axis=-1)
+ # hidden = output
+ hidden = tf.reshape(hidden, [embed_shape[0], embed_shape[1], -1])
+ hidden._keras_mask = mask
+ return hidden
+
+ def get_config(self):
+ config = {
+ 'char_embedding': self.char_embedding,
+ 'char_rnn_units': self.char_rnn_units,
+ 'dropout': self.dropout.rate,
+ }
+ base_config = super(CharRNNEmbeddingTF, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
diff --git a/hanlp/layers/embeddings/contextual_string_embedding.py b/hanlp/layers/embeddings/contextual_string_embedding.py
index 079d48073..d415b87b7 100644
--- a/hanlp/layers/embeddings/contextual_string_embedding.py
+++ b/hanlp/layers/embeddings/contextual_string_embedding.py
@@ -1,138 +1,216 @@
-# -*- coding:utf-8 -*-
-# Author: hankcs
-# Date: 2019-12-19 03:24
-from typing import List
-
-import tensorflow as tf
-import numpy as np
-from hanlp.components.rnn_language_model import RNNLanguageModel
-from hanlp.common.constant import PAD
+# Most codes of this file is adopted from flair, which is licenced under:
+#
+# The MIT License (MIT)
+#
+# Flair is licensed under the following MIT License (MIT) Copyright © 2018 Zalando SE, https://tech.zalando.com
+# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+import os
+from typing import List, Dict, Callable
+
+import torch
+import torch.nn as nn
+from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
+
+from hanlp_common.configurable import Configurable
+from hanlp.common.transform import TransformList, FieldToIndex
+from hanlp.common.vocab import Vocab
+from hanlp.layers.embeddings.embedding import Embedding, EmbeddingDim
from hanlp.utils.io_util import get_resource
-from hanlp.utils.tf_util import copy_mask, hanlp_register, str_tensor_2d_to_list
-from hanlp.utils.util import infer_space_after
-
-
-@hanlp_register
-class ContextualStringEmbedding(tf.keras.layers.Layer):
-
- def __init__(self, forward_model_path=None, backward_model_path=None, max_word_len=10,
- trainable=False, name=None, dtype=None,
- dynamic=True, **kwargs):
- assert dynamic, 'ContextualStringEmbedding works only in eager mode'
- super().__init__(trainable, name, dtype, dynamic, **kwargs)
- assert any([forward_model_path, backward_model_path]), 'At least one model is required'
- self.forward_model_path = forward_model_path
- self.backward_model_path = backward_model_path
- self.forward_model = self._load_lm(forward_model_path) if forward_model_path else None
- self.backward_model = self._load_lm(backward_model_path) if backward_model_path else None
- if trainable:
- self._fw = self.forward_model.model
- self._bw = self.backward_model.model
- for m in self._fw, self._bw:
- m.trainable = True
- self.supports_masking = True
- self.max_word_len = max_word_len
-
- def call(self, inputs, **kwargs):
- str_inputs = str_tensor_2d_to_list(inputs)
- outputs = self.embed(str_inputs)
- copy_mask(inputs, outputs)
- return outputs
-
- def _load_lm(self, filepath):
- filepath = get_resource(filepath)
- lm = RNNLanguageModel()
- lm.load(filepath)
- model: tf.keras.Sequential = lm.model
- for idx, layer in enumerate(model.layers):
- if isinstance(layer, tf.keras.layers.LSTM):
- lm.model = tf.keras.Sequential(model.layers[:idx + 1]) # discard dense layer
- return lm
-
- def embed(self, texts: List[List[str]]):
- """
- Embedding sentences (list of words) with contextualized string embedding
-
- Parameters
- ----------
- texts :
- List of words, not chars
-
- Returns
- -------
- tf.Tensor
- A 3d tensor of (batch, num_words, hidden)
- """
- fw = None
- if self.forward_model:
- fw = self._run_rnn(texts, model=self.forward_model)
- bw = None
- if self.backward_model:
- bw = self._run_rnn(texts, model=self.backward_model)
- if not all(x is not None for x in [fw, bw]):
- return fw if fw is not None else bw
- else:
- return tf.concat([fw, bw], axis=-1)
-
- def _run_rnn(self, texts, model):
- embeddings = []
- inputs = []
- offsets = []
- tokenizer = model.transform.tokenize_func()
- backward = not model.config['forward']
- for sent in texts:
- raw, off = self._get_raw_string(sent, tokenizer)
- inputs.append(raw)
- offsets.append(off)
- outputs = model.model.predict(model.transform.inputs_to_dataset(inputs))
- if backward:
- outputs = tf.reverse(outputs, axis=[1])
- maxlen = len(max(texts, key=len))
- for hidden, off, sent in zip(outputs, offsets, texts):
- embed = []
- for (start, end), word in zip(off, sent):
- embed.append(hidden[end - 1, :])
- if len(embed) < maxlen:
- embed += [np.zeros_like(embed[-1])] * (maxlen - len(embed))
- embeddings.append(np.stack(embed))
- return tf.stack(embeddings)
-
- def _get_raw_string(self, sent: List[str], tokenizer):
- raw_string = []
- offsets = []
- whitespace_after = infer_space_after(sent)
- start = 0
- for word, space in zip(sent, whitespace_after):
- chars = tokenizer(word)
- chars = chars[:self.max_word_len]
- if space:
- chars += [' ']
- end = start + len(chars)
- offsets.append((start, end))
- start = end
- raw_string += chars
- return raw_string, offsets
-
- def get_config(self):
- config = {
- 'forward_model_path': self.forward_model_path,
- 'backward_model_path': self.backward_model_path,
- 'max_word_len': self.max_word_len,
+from hanlp.utils.torch_util import pad_lists, batched_index_select
+from tests import cdroot
+
+
+class RNNLanguageModel(nn.Module):
+ """Container module with an encoder, a recurrent module, and a decoder."""
+
+ def __init__(self,
+ n_tokens,
+ is_forward_lm: bool,
+ hidden_size: int,
+ embedding_size: int = 100):
+ super(RNNLanguageModel, self).__init__()
+
+ self.is_forward_lm: bool = is_forward_lm
+ self.n_tokens = n_tokens
+ self.hidden_size = hidden_size
+ self.embedding_size = embedding_size
+
+ self.encoder = nn.Embedding(n_tokens, embedding_size)
+ self.rnn = nn.LSTM(embedding_size, hidden_size, batch_first=True)
+
+ def forward(self, ids: torch.LongTensor, lens: torch.LongTensor):
+ emb = self.encoder(ids)
+ x = pack_padded_sequence(emb, lens, True, False)
+ x, _ = self.rnn(x)
+ x, _ = pad_packed_sequence(x, True)
+ return x
+
+ @classmethod
+ def load_language_model(cls, model_file):
+ model_file = get_resource(model_file)
+ state = torch.load(model_file)
+ model = RNNLanguageModel(state['n_tokens'],
+ state['is_forward_lm'],
+ state['hidden_size'],
+ state['embedding_size'])
+ model.load_state_dict(state['state_dict'], strict=False)
+ return model
+
+ def save(self, file):
+ model_state = {
+ 'state_dict': self.state_dict(),
+ 'n_tokens': self.n_tokens,
+ 'is_forward_lm': self.is_forward_lm,
+ 'hidden_size': self.hidden_size,
+ 'embedding_size': self.embedding_size,
}
- base_config = super(ContextualStringEmbedding, self).get_config()
- return dict(list(base_config.items()) + list(config.items()))
+ torch.save(model_state, file, pickle_protocol=4)
- @property
- def output_dim(self):
- dim = 0
- for model in self.forward_model, self.backward_model:
- if model:
- dim += model.config['rnn_units']
- return dim
- def compute_output_shape(self, input_shape):
- return input_shape + self.output_dim
+class ContextualStringEmbeddingModule(nn.Module, EmbeddingDim):
+
+ def __init__(self, field: str, path: str, trainable=False) -> None:
+ super().__init__()
+ self.field = field
+ path = get_resource(path)
+ f = os.path.join(path, 'forward.pt')
+ b = os.path.join(path, 'backward.pt')
+ self.f: RNNLanguageModel = RNNLanguageModel.load_language_model(f)
+ self.b: RNNLanguageModel = RNNLanguageModel.load_language_model(b)
+ if not trainable:
+ for p in self.parameters():
+ p.requires_grad_(False)
- def compute_mask(self, inputs, mask=None):
+ def __call__(self, batch: dict, **kwargs):
+ args = ['f_char_id', 'f_offset', 'b_char_id', 'b_offset']
+ keys = [f'{self.field}_{key}' for key in args]
+ args = [batch[key] for key in keys]
+ return super().__call__(*args, **kwargs)
- return tf.not_equal(inputs, PAD)
+ @property
+ def embedding_dim(self):
+ return self.f.rnn.hidden_size + self.b.rnn.hidden_size
+
+ def run_lm(self, lm, ids: torch.Tensor, offsets: torch.LongTensor):
+ lens = offsets.max(-1)[0] + 1
+ rnn_output = lm(ids, lens)
+ return batched_index_select(rnn_output, offsets)
+
+ def forward(self,
+ f_chars_id: torch.Tensor,
+ f_offset: torch.LongTensor,
+ b_chars_id: torch.Tensor,
+ b_offset: torch.LongTensor, **kwargs):
+ f = self.run_lm(self.f, f_chars_id, f_offset)
+ b = self.run_lm(self.b, b_chars_id, b_offset)
+ return torch.cat([f, b], dim=-1)
+
+ def embed(self, sents: List[List[str]], vocab: Dict[str, int]):
+ f_chars, f_offsets = [], []
+ b_chars, b_offsets = [], []
+
+ transform = ContextualStringEmbeddingTransform('token')
+ for tokens in sents:
+ sample = transform({'token': tokens})
+ for each, name in zip([f_chars, b_chars, f_offsets, b_offsets],
+ 'f_chars, b_chars, f_offsets, b_offsets'.split(', ')):
+ each.append(sample[f'token_{name}'])
+ f_ids = []
+ for cs in f_chars:
+ f_ids.append([vocab[c] for c in cs])
+ f_ids = pad_lists(f_ids)
+ f_offsets = pad_lists(f_offsets)
+
+ b_ids = []
+ for cs in b_chars:
+ b_ids.append([vocab[c] for c in cs])
+ b_ids = pad_lists(b_ids)
+ b_offsets = pad_lists(b_offsets)
+ return self.forward(f_ids, f_offsets, b_ids, b_offsets)
+
+
+class ContextualStringEmbeddingTransform(Configurable):
+
+ def __init__(self, src: str) -> None:
+ self.src = src
+
+ def __call__(self, sample: dict):
+ tokens = sample[self.src]
+ f_o = []
+ b_o = []
+ sentence_text = ' '.join(tokens)
+ end_marker = ' '
+ extra_offset = 1
+ # f
+ input_text = '\n' + sentence_text + end_marker
+ f_chars = list(input_text)
+ # b
+ sentence_text = sentence_text[::-1]
+ input_text = '\n' + sentence_text + end_marker
+ b_chars = list(input_text)
+ offset_forward: int = extra_offset
+ offset_backward: int = len(sentence_text) + extra_offset
+ for token in tokens:
+ offset_forward += len(token)
+
+ f_o.append(offset_forward)
+ b_o.append(offset_backward)
+
+ # This language model is tokenized
+ offset_forward += 1
+ offset_backward -= 1
+
+ offset_backward -= len(token)
+ sample[f'{self.src}_f_char'] = f_chars
+ sample[f'{self.src}_b_char'] = b_chars
+ sample[f'{self.src}_f_offset'] = f_o
+ sample[f'{self.src}_b_offset'] = b_o
+ return sample
+
+
+class ContextualStringEmbedding(Embedding):
+ def __init__(self, field, path, trainable=False) -> None:
+ super().__init__()
+ self.trainable = trainable
+ self.path = path
+ self.field = field
+
+ def transform(self, **kwargs) -> Callable:
+ vocab = Vocab()
+ vocab.load(os.path.join(get_resource(self.path), 'vocab.json'))
+ return TransformList(ContextualStringEmbeddingTransform(self.field),
+ FieldToIndex(f'{self.field}_f_char', vocab),
+ FieldToIndex(f'{self.field}_b_char', vocab))
+
+ def module(self, **kwargs) -> nn.Module:
+ return ContextualStringEmbeddingModule(self.field, self.path, self.trainable)
+
+
+def main():
+ # _validate()
+ flair = ContextualStringEmbedding('token', 'FASTTEXT_DEBUG_EMBEDDING_EN')
+ print(flair.config)
+
+
+def _validate():
+ cdroot()
+ flair = ContextualStringEmbeddingModule('token', 'FLAIR_LM_WMT11_EN')
+ vocab = torch.load('/home/hhe43/flair/item2idx.pt')
+ vocab = dict((x.decode(), y) for x, y in vocab.items())
+ # vocab = Vocab(token_to_idx=vocab, pad_token='')
+ # vocab.lock()
+ # vocab.summary()
+ # vocab.save('vocab.json')
+ tokens = 'I love Berlin .'.split()
+ sent = ' '.join(tokens)
+ embed = flair.embed([tokens, tokens], vocab)
+ gold = torch.load('/home/hhe43/flair/gold.pt')
+ print(torch.allclose(embed[1, :, :2048], gold, atol=1e-6))
+ # print(torch.all(torch.eq(embed[1, :, :], gold)))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/hanlp/layers/embeddings/contextual_string_embedding_tf.py b/hanlp/layers/embeddings/contextual_string_embedding_tf.py
new file mode 100644
index 000000000..124487e3d
--- /dev/null
+++ b/hanlp/layers/embeddings/contextual_string_embedding_tf.py
@@ -0,0 +1,135 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-12-19 03:24
+from typing import List
+
+import tensorflow as tf
+import numpy as np
+from hanlp.components.rnn_language_model import RNNLanguageModel
+from hanlp_common.constant import PAD
+from hanlp.utils.io_util import get_resource
+from hanlp.utils.tf_util import copy_mask, hanlp_register, str_tensor_2d_to_list
+from hanlp_common.util import infer_space_after
+
+
+@hanlp_register
+class ContextualStringEmbeddingTF(tf.keras.layers.Layer):
+
+ def __init__(self, forward_model_path=None, backward_model_path=None, max_word_len=10,
+ trainable=False, name=None, dtype=None,
+ dynamic=True, **kwargs):
+ assert dynamic, 'ContextualStringEmbedding works only in eager mode'
+ super().__init__(trainable, name, dtype, dynamic, **kwargs)
+ assert any([forward_model_path, backward_model_path]), 'At least one model is required'
+ self.forward_model_path = forward_model_path
+ self.backward_model_path = backward_model_path
+ self.forward_model = self._load_lm(forward_model_path) if forward_model_path else None
+ self.backward_model = self._load_lm(backward_model_path) if backward_model_path else None
+ if trainable:
+ self._fw = self.forward_model.model
+ self._bw = self.backward_model.model
+ for m in self._fw, self._bw:
+ m.trainable = True
+ self.supports_masking = True
+ self.max_word_len = max_word_len
+
+ def call(self, inputs, **kwargs):
+ str_inputs = str_tensor_2d_to_list(inputs)
+ outputs = self.embed(str_inputs)
+ copy_mask(inputs, outputs)
+ return outputs
+
+ def _load_lm(self, filepath):
+ filepath = get_resource(filepath)
+ lm = RNNLanguageModel()
+ lm.load(filepath)
+ model: tf.keras.Sequential = lm.model
+ for idx, layer in enumerate(model.layers):
+ if isinstance(layer, tf.keras.layers.LSTM):
+ lm.model = tf.keras.Sequential(model.layers[:idx + 1]) # discard dense layer
+ return lm
+
+ def embed(self, texts: List[List[str]]):
+ """Embedding sentences (list of words) with contextualized string embedding
+
+ Args:
+ texts: List of words, not chars
+ texts: List[List[str]]:
+
+ Returns:
+
+
+ """
+ fw = None
+ if self.forward_model:
+ fw = self._run_rnn(texts, model=self.forward_model)
+ bw = None
+ if self.backward_model:
+ bw = self._run_rnn(texts, model=self.backward_model)
+ if not all(x is not None for x in [fw, bw]):
+ return fw if fw is not None else bw
+ else:
+ return tf.concat([fw, bw], axis=-1)
+
+ def _run_rnn(self, texts, model):
+ embeddings = []
+ inputs = []
+ offsets = []
+ tokenizer = model.transform.tokenize_func()
+ backward = not model.config['forward']
+ for sent in texts:
+ raw, off = self._get_raw_string(sent, tokenizer)
+ inputs.append(raw)
+ offsets.append(off)
+ outputs = model.model_from_config.predict(model.transform.inputs_to_dataset(inputs))
+ if backward:
+ outputs = tf.reverse(outputs, axis=[1])
+ maxlen = len(max(texts, key=len))
+ for hidden, off, sent in zip(outputs, offsets, texts):
+ embed = []
+ for (start, end), word in zip(off, sent):
+ embed.append(hidden[end - 1, :])
+ if len(embed) < maxlen:
+ embed += [np.zeros_like(embed[-1])] * (maxlen - len(embed))
+ embeddings.append(np.stack(embed))
+ return tf.stack(embeddings)
+
+ def _get_raw_string(self, sent: List[str], tokenizer):
+ raw_string = []
+ offsets = []
+ whitespace_after = infer_space_after(sent)
+ start = 0
+ for word, space in zip(sent, whitespace_after):
+ chars = tokenizer(word)
+ chars = chars[:self.max_word_len]
+ if space:
+ chars += [' ']
+ end = start + len(chars)
+ offsets.append((start, end))
+ start = end
+ raw_string += chars
+ return raw_string, offsets
+
+ def get_config(self):
+ config = {
+ 'forward_model_path': self.forward_model_path,
+ 'backward_model_path': self.backward_model_path,
+ 'max_word_len': self.max_word_len,
+ }
+ base_config = super(ContextualStringEmbeddingTF, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
+ @property
+ def output_dim(self):
+ dim = 0
+ for model in self.forward_model, self.backward_model:
+ if model:
+ dim += model.config['rnn_units']
+ return dim
+
+ def compute_output_shape(self, input_shape):
+ return input_shape + self.output_dim
+
+ def compute_mask(self, inputs, mask=None):
+
+ return tf.not_equal(inputs, PAD)
diff --git a/hanlp/layers/embeddings/contextual_word_embedding.py b/hanlp/layers/embeddings/contextual_word_embedding.py
new file mode 100644
index 000000000..e05614804
--- /dev/null
+++ b/hanlp/layers/embeddings/contextual_word_embedding.py
@@ -0,0 +1,187 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-07-05 13:50
+from typing import Optional, Union, List, Any, Dict, Tuple
+
+import torch
+from torch import nn
+
+from hanlp_common.configurable import AutoConfigurable
+from hanlp.layers.embeddings.embedding import Embedding
+from hanlp.layers.scalar_mix import ScalarMixWithDropoutBuilder
+from hanlp.layers.transformers.encoder import TransformerEncoder
+from hanlp.layers.transformers.pt_imports import PreTrainedTokenizer, AutoConfig
+from hanlp.transform.transformer_tokenizer import TransformerSequenceTokenizer
+
+
+class ContextualWordEmbeddingModule(TransformerEncoder):
+ def __init__(self,
+ field: str,
+ transformer: str,
+ transformer_tokenizer: PreTrainedTokenizer,
+ average_subwords=False,
+ scalar_mix: Union[ScalarMixWithDropoutBuilder, int] = None,
+ word_dropout=None,
+ max_sequence_length=None,
+ ret_raw_hidden_states=False,
+ transformer_args: Dict[str, Any] = None,
+ trainable=True,
+ training=True) -> None:
+ """A contextualized word embedding module.
+
+ Args:
+ field: The field to work on. Usually some token fields.
+ transformer: An identifier of a ``PreTrainedModel``.
+ transformer_tokenizer:
+ average_subwords: ``True`` to average subword representations.
+ scalar_mix: Layer attention.
+ word_dropout: Dropout rate of randomly replacing a subword with MASK.
+ max_sequence_length: The maximum sequence length. Sequence longer than this will be handled by sliding
+ window.
+ ret_raw_hidden_states: ``True`` to return hidden states of each layer.
+ transformer_args: Extra arguments passed to the transformer.
+ trainable: ``False`` to use static embeddings.
+ training: ``False`` to skip loading weights from pre-trained transformers.
+ """
+ super().__init__(transformer, transformer_tokenizer, average_subwords, scalar_mix, word_dropout,
+ max_sequence_length, ret_raw_hidden_states, transformer_args, trainable,
+ training)
+ self.field = field
+
+ # noinspection PyMethodOverriding
+ # noinspection PyTypeChecker
+ def forward(self, batch: dict, mask=None, **kwargs):
+ input_ids: torch.LongTensor = batch[f'{self.field}_input_ids']
+ token_span: torch.LongTensor = batch.get(f'{self.field}_token_span', None)
+ # input_device = input_ids.device
+ # this_device = self.get_device()
+ # if input_device != this_device:
+ # input_ids = input_ids.to(this_device)
+ # token_span = token_span.to(this_device)
+ # We might want to apply mask here
+ output: Union[torch.Tensor, List[torch.Tensor]] = super().forward(input_ids, token_span=token_span, **kwargs)
+ # if input_device != this_device:
+ # if isinstance(output, torch.Tensor):
+ # output = output.to(input_device)
+ # else:
+ # output = [x.to(input_device) for x in output]
+ return output
+
+ def get_output_dim(self):
+ return self.transformer.config.hidden_size
+
+ def get_device(self):
+ device: torch.device = next(self.parameters()).device
+ return device
+
+
+class ContextualWordEmbedding(Embedding, AutoConfigurable):
+ def __init__(self, field: str,
+ transformer: str,
+ average_subwords=False,
+ scalar_mix: Union[ScalarMixWithDropoutBuilder, int] = None,
+ word_dropout: Optional[Union[float, Tuple[float, str]]] = None,
+ max_sequence_length=None,
+ truncate_long_sequences=False,
+ cls_is_bos=False,
+ sep_is_eos=False,
+ ret_token_span=True,
+ ret_subtokens=False,
+ ret_subtokens_group=False,
+ ret_prefix_mask=False,
+ ret_raw_hidden_states=False,
+ transformer_args: Dict[str, Any] = None,
+ use_fast=True,
+ do_basic_tokenize=True,
+ trainable=True) -> None:
+ """A contextual word embedding builder which builds a
+ :class:`~hanlp.layers.embeddings.contextual_word_embedding.ContextualWordEmbeddingModule` and a
+ :class:`~hanlp.transform.transformer_tokenizer.TransformerSequenceTokenizer`.
+
+ Args:
+ field: The field to work on. Usually some token fields.
+ transformer: An identifier of a ``PreTrainedModel``.
+ average_subwords: ``True`` to average subword representations.
+ scalar_mix: Layer attention.
+ word_dropout: Dropout rate of randomly replacing a subword with MASK.
+ max_sequence_length: The maximum sequence length. Sequence longer than this will be handled by sliding
+ window.
+ truncate_long_sequences: ``True`` to return hidden states of each layer.
+ cls_is_bos: ``True`` means the first token of input is treated as [CLS] no matter what its surface form is.
+ ``False`` (default) means the first token is not [CLS], it will have its own embedding other than
+ the embedding of [CLS].
+ sep_is_eos: ``True`` means the last token of input is [SEP].
+ ``False`` means it's not but [SEP] will be appended,
+ ``None`` means it dependents on `input[-1] == [EOS]`.
+ ret_token_span: ``True`` to return span of each token measured by subtoken offsets.
+ ret_subtokens: ``True`` to return list of subtokens belonging to each token.
+ ret_subtokens_group: ``True`` to return list of offsets of subtokens belonging to each token.
+ ret_prefix_mask: ``True`` to generate a mask where each non-zero element corresponds to a prefix of a token.
+ ret_raw_hidden_states: ``True`` to return hidden states of each layer.
+ transformer_args: Extra arguments passed to the transformer.
+ use_fast: Whether or not to try to load the fast version of the tokenizer.
+ do_basic_tokenize: Whether to do basic tokenization before wordpiece.
+ trainable: ``False`` to use static embeddings.
+ """
+ super().__init__()
+ self.truncate_long_sequences = truncate_long_sequences
+ self.transformer_args = transformer_args
+ self.trainable = trainable
+ self.ret_subtokens_group = ret_subtokens_group
+ self.ret_subtokens = ret_subtokens
+ self.ret_raw_hidden_states = ret_raw_hidden_states
+ self.sep_is_eos = sep_is_eos
+ self.cls_is_bos = cls_is_bos
+ self.max_sequence_length = max_sequence_length
+ self.word_dropout = word_dropout
+ self.scalar_mix = scalar_mix
+ self.average_subwords = average_subwords
+ self.transformer = transformer
+ self.field = field
+ self._transformer_tokenizer = TransformerEncoder.build_transformer_tokenizer(self.transformer,
+ use_fast=use_fast,
+ do_basic_tokenize=do_basic_tokenize)
+ self._tokenizer_transform = TransformerSequenceTokenizer(self._transformer_tokenizer,
+ field,
+ truncate_long_sequences=truncate_long_sequences,
+ ret_prefix_mask=ret_prefix_mask,
+ ret_token_span=ret_token_span,
+ cls_is_bos=cls_is_bos,
+ sep_is_eos=sep_is_eos,
+ ret_subtokens=ret_subtokens,
+ ret_subtokens_group=ret_subtokens_group,
+ max_seq_length=self.max_sequence_length
+ )
+
+ def transform(self, **kwargs) -> TransformerSequenceTokenizer:
+ return self._tokenizer_transform
+
+ def module(self, training=True, **kwargs) -> Optional[nn.Module]:
+ return ContextualWordEmbeddingModule(self.field,
+ self.transformer,
+ self._transformer_tokenizer,
+ self.average_subwords,
+ self.scalar_mix,
+ self.word_dropout,
+ self.max_sequence_length,
+ self.ret_raw_hidden_states,
+ self.transformer_args,
+ self.trainable,
+ training=training)
+
+ def get_output_dim(self):
+ config = AutoConfig.from_pretrained(self.transformer)
+ return config.hidden_size
+
+ def get_tokenizer(self):
+ return self._transformer_tokenizer
+
+
+def find_transformer(embed: nn.Module):
+ if isinstance(embed, ContextualWordEmbeddingModule):
+ return embed
+ if isinstance(embed, nn.ModuleList):
+ for child in embed:
+ found = find_transformer(child)
+ if found:
+ return found
diff --git a/hanlp/layers/embeddings/embedding.py b/hanlp/layers/embeddings/embedding.py
new file mode 100644
index 000000000..9997fe0bd
--- /dev/null
+++ b/hanlp/layers/embeddings/embedding.py
@@ -0,0 +1,133 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-06-02 13:04
+from abc import ABC, abstractmethod
+from typing import Callable, List, Optional, Iterable
+
+import torch
+from torch import nn
+from torch.nn import Module
+
+from hanlp_common.configurable import AutoConfigurable
+from hanlp.common.transform import TransformList
+from hanlp.layers.dropout import IndependentDropout
+
+
+class EmbeddingDim(ABC):
+ @property
+ @abstractmethod
+ def embedding_dim(self) -> int:
+ return -1
+
+ def get_output_dim(self) -> int:
+ return self.embedding_dim
+
+
+class Embedding(AutoConfigurable, ABC):
+
+ def __init__(self) -> None:
+ """
+ Base class for embedding builders.
+ """
+ super().__init__()
+
+ def transform(self, **kwargs) -> Optional[Callable]:
+ """Build a transform function for this embedding.
+
+ Args:
+ **kwargs: Containing vocabs, training etc. Not finalized for now.
+
+ Returns:
+ A transform function.
+ """
+ return None
+
+ def module(self, **kwargs) -> Optional[nn.Module]:
+ """Build a module for this embedding.
+
+ Args:
+ **kwargs: Containing vocabs, training etc. Not finalized for now.
+
+ Returns:
+ A module.
+ """
+ return None
+
+
+class ConcatModuleList(nn.ModuleList, EmbeddingDim):
+
+ def __init__(self, *modules: Optional[Iterable[Module]], dropout=None) -> None:
+ """A ``nn.ModuleList`` to bundle several embeddings modules.
+
+ Args:
+ *modules: Embedding layers.
+ dropout: Dropout applied on the concatenated embedding.
+ """
+ super().__init__(*modules)
+ if dropout:
+ dropout = IndependentDropout(p=dropout)
+ self.dropout = dropout
+
+ @property
+ def embedding_dim(self) -> int:
+ return sum(embed.embedding_dim for embed in self)
+
+ def get_output_dim(self) -> int:
+ return sum(embed.get_output_dim() for embed in self)
+
+ # noinspection PyMethodOverriding
+ def forward(self, batch: dict, **kwargs):
+ embeds = [embed(batch, **kwargs) for embed in self.embeddings]
+ if self.dropout:
+ embeds = self.dropout(*embeds)
+ return torch.cat(embeds, -1)
+
+ @property
+ def embeddings(self):
+ embeddings = [x for x in self]
+ if self.dropout:
+ embeddings.remove(self.dropout)
+ return embeddings
+
+
+class EmbeddingList(Embedding):
+ def __init__(self, *embeddings_, embeddings: dict = None, dropout=None) -> None:
+ """An embedding builder to bundle several embedding builders.
+
+ Args:
+ *embeddings_: A list of embedding builders.
+ embeddings: Deserialization for a dict of embedding builders.
+ dropout: Dropout applied on the concatenated embedding.
+ """
+ # noinspection PyTypeChecker
+ self.dropout = dropout
+ self._embeddings: List[Embedding] = list(embeddings_)
+ if embeddings:
+ for each in embeddings:
+ if isinstance(each, dict):
+ each = AutoConfigurable.from_config(each)
+ self._embeddings.append(each)
+ self.embeddings = [e.config for e in self._embeddings]
+
+ def transform(self, **kwargs):
+ transforms = [e.transform(**kwargs) for e in self._embeddings]
+ transforms = [t for t in transforms if t]
+ return TransformList(*transforms)
+
+ def module(self, **kwargs):
+ modules = [e.module(**kwargs) for e in self._embeddings]
+ modules = [m for m in modules if m]
+ return ConcatModuleList(modules, dropout=self.dropout)
+
+ def to_list(self):
+ return self._embeddings
+
+
+def find_embedding_by_class(embed: Embedding, cls):
+ if isinstance(embed, cls):
+ return embed
+ if isinstance(embed, EmbeddingList):
+ for child in embed.to_list():
+ found = find_embedding_by_class(child, cls)
+ if found:
+ return found
diff --git a/hanlp/layers/embeddings/fast_text.py b/hanlp/layers/embeddings/fast_text.py
index 2e87be210..7e75798b6 100644
--- a/hanlp/layers/embeddings/fast_text.py
+++ b/hanlp/layers/embeddings/fast_text.py
@@ -1,111 +1,99 @@
# -*- coding:utf-8 -*-
# Author: hankcs
-# Date: 2019-10-29 13:14
+# Date: 2020-05-27 15:06
import os
import sys
+from typing import Optional, Callable
import fasttext
-import numpy as np
-import tensorflow as tf
-from tensorflow.python.keras.utils import tf_utils
+import torch
+from torch import nn
+from torch.nn.utils.rnn import pad_sequence
-from hanlp.common.constant import PAD
-from hanlp.utils import global_cache
+from hanlp_common.configurable import AutoConfigurable
+from hanlp.common.transform import EmbeddingNamedTransform
+from hanlp.layers.embeddings.embedding import Embedding
from hanlp.utils.io_util import get_resource, stdout_redirected
-from hanlp.utils.log_util import logger
-from hanlp.utils.tf_util import hanlp_register
+from hanlp.utils.log_util import flash
-@hanlp_register
-class FastTextEmbedding(tf.keras.layers.Embedding):
-
- def __init__(self, filepath: str, padding=PAD, name=None, **kwargs):
- self.padding = padding.encode('utf-8')
+class FastTextTransform(EmbeddingNamedTransform):
+ def __init__(self, filepath: str, src, dst=None, **kwargs) -> None:
+ if not dst:
+ dst = src + '_fasttext'
self.filepath = filepath
+ flash(f'Loading fasttext model {filepath} [blink][yellow]...[/yellow][/blink]')
filepath = get_resource(filepath)
- assert os.path.isfile(filepath), f'Resolved path {filepath} is not a file'
- existed = global_cache.get(filepath, None)
- if existed:
- logger.debug('Use cached fasttext model [{}].'.format(filepath))
- self.model = existed
- else:
- logger.debug('Loading fasttext model from [{}].'.format(filepath))
- # fasttext print a blank line here
- with stdout_redirected(to=os.devnull, stdout=sys.stderr):
- self.model = fasttext.load_model(filepath)
- global_cache[filepath] = self.model
- kwargs.pop('input_dim', None)
- kwargs.pop('output_dim', None)
- kwargs.pop('mask_zero', None)
- if not name:
- name = os.path.splitext(os.path.basename(filepath))[0]
- super().__init__(input_dim=len(self.model.words), output_dim=self.model['king'].size,
- mask_zero=padding is not None, trainable=False, dtype=tf.string, name=name, **kwargs)
- embed_fn = np.frompyfunc(self.embed, 1, 1)
- # vf = np.vectorize(self.embed, otypes=[np.ndarray])
- self._embed_np = embed_fn
-
- def embed(self, word):
- return self.model[word]
-
- def embed_np(self, words: np.ndarray):
- output = self._embed_np(words)
- if self.mask_zero:
- mask = words != self.padding
- output *= mask
- output = np.stack(output.reshape(-1)).reshape(list(words.shape) + [self.output_dim])
- return output, tf.constant(mask)
- else:
- output = np.stack(output.reshape(-1)).reshape(list(words.shape) + [self.output_dim])
- return output
-
- @tf_utils.shape_type_conversion
- def build(self, input_shape):
- self.built = True
-
- @tf_utils.shape_type_conversion
- def compute_output_shape(self, input_shape):
- return input_shape + (self.output_dim,)
-
- def call(self, inputs: tf.Tensor):
- if isinstance(inputs, list):
- inputs = inputs[0]
- if not hasattr(inputs, 'numpy'): # placeholder tensor
- inputs = tf.expand_dims(inputs, axis=-1)
- inputs = tf.tile(inputs, [1] * (len(inputs.shape) - 1) + [self.output_dim])
- inputs = tf.zeros_like(inputs, dtype=tf.float32)
- return inputs
- # seq_len = inputs.shape[-1]
- # if not seq_len:
- # seq_len = 1
- # return tf.zeros([1, seq_len, self.output_dim])
- if self.mask_zero:
- outputs, masks = self.embed_np(inputs.numpy())
- outputs = tf.constant(outputs)
- outputs._keras_mask = masks
+ with stdout_redirected(to=os.devnull, stdout=sys.stderr):
+ self._model = fasttext.load_model(filepath)
+ flash('')
+ output_dim = self._model['king'].size
+ super().__init__(output_dim, src, dst)
+
+ def __call__(self, sample: dict):
+ word = sample[self.src]
+ if isinstance(word, str):
+ vector = self.embed(word)
else:
- outputs = self.embed_np(inputs.numpy())
- outputs = tf.constant(outputs)
+ vector = torch.stack([self.embed(each) for each in word])
+ sample[self.dst] = vector
+ return sample
+
+ def embed(self, word: str):
+ return torch.tensor(self._model[word])
+
+
+class PassThroughModule(torch.nn.Module):
+ def __init__(self, key) -> None:
+ super().__init__()
+ self.key = key
+
+ def __call__(self, batch: dict, mask=None, **kwargs):
+ return batch[self.key]
+
+
+class FastTextEmbeddingModule(PassThroughModule):
+
+ def __init__(self, key, embedding_dim: int) -> None:
+ """An embedding layer for fastText (:cite:`bojanowski2017enriching`).
+
+ Args:
+ key: Field name.
+ embedding_dim: Size of this embedding layer
+ """
+ super().__init__(key)
+ self.embedding_dim = embedding_dim
+
+ def __call__(self, batch: dict, mask=None, **kwargs):
+ outputs = super().__call__(batch, **kwargs)
+ outputs = pad_sequence(outputs, True, 0).to(mask.device)
return outputs
- def compute_mask(self, inputs, mask=None):
- if not self.mask_zero:
- return None
- return tf.not_equal(inputs, self.padding)
-
- def get_config(self):
- config = {
- 'filepath': self.filepath,
- 'padding': self.padding.decode('utf-8')
- }
- base_config = super(FastTextEmbedding, self).get_config()
- for junk in 'embeddings_initializer' \
- , 'batch_input_shape' \
- , 'embeddings_regularizer' \
- , 'embeddings_constraint' \
- , 'activity_regularizer' \
- , 'trainable' \
- , 'input_length' \
- :
- base_config.pop(junk)
- return dict(list(base_config.items()) + list(config.items()))
+ def __repr__(self):
+ s = self.__class__.__name__ + '('
+ s += f'key={self.key}, embedding_dim={self.embedding_dim}'
+ s += ')'
+ return s
+
+ def get_output_dim(self):
+ return self.embedding_dim
+
+
+class FastTextEmbedding(Embedding, AutoConfigurable):
+ def __init__(self, src: str, filepath: str) -> None:
+ """An embedding layer builder for fastText (:cite:`bojanowski2017enriching`).
+
+ Args:
+ src: Field name.
+ filepath: Filepath to pretrained fastText embeddings.
+ """
+ super().__init__()
+ self.src = src
+ self.filepath = filepath
+ self._fasttext = FastTextTransform(self.filepath, self.src)
+
+ def transform(self, **kwargs) -> Optional[Callable]:
+ return self._fasttext
+
+ def module(self, **kwargs) -> Optional[nn.Module]:
+ return FastTextEmbeddingModule(self._fasttext.dst, self._fasttext.output_dim)
diff --git a/hanlp/layers/embeddings/fast_text_tf.py b/hanlp/layers/embeddings/fast_text_tf.py
new file mode 100644
index 000000000..3fe06c8fe
--- /dev/null
+++ b/hanlp/layers/embeddings/fast_text_tf.py
@@ -0,0 +1,104 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-10-29 13:14
+import os
+import sys
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.python.keras.utils import tf_utils
+
+from hanlp_common.constant import PAD
+from hanlp.utils.io_util import get_resource, stdout_redirected
+from hanlp.utils.log_util import logger
+from hanlp.utils.tf_util import hanlp_register
+
+
+@hanlp_register
+class FastTextEmbeddingTF(tf.keras.layers.Embedding):
+
+ def __init__(self, filepath: str, padding=PAD, name=None, **kwargs):
+ import fasttext
+ self.padding = padding.encode('utf-8')
+ self.filepath = filepath
+ filepath = get_resource(filepath)
+ assert os.path.isfile(filepath), f'Resolved path {filepath} is not a file'
+ logger.debug('Loading fasttext model from [{}].'.format(filepath))
+ # fasttext print a blank line here
+ with stdout_redirected(to=os.devnull, stdout=sys.stderr):
+ self.model = fasttext.load_model(filepath)
+ kwargs.pop('input_dim', None)
+ kwargs.pop('output_dim', None)
+ kwargs.pop('mask_zero', None)
+ if not name:
+ name = os.path.splitext(os.path.basename(filepath))[0]
+ super().__init__(input_dim=len(self.model.words), output_dim=self.model['king'].size,
+ mask_zero=padding is not None, trainable=False, dtype=tf.string, name=name, **kwargs)
+ embed_fn = np.frompyfunc(self.embed, 1, 1)
+ # vf = np.vectorize(self.embed, otypes=[np.ndarray])
+ self._embed_np = embed_fn
+
+ def embed(self, word):
+ return self.model[word]
+
+ def embed_np(self, words: np.ndarray):
+ output = self._embed_np(words)
+ if self.mask_zero:
+ mask = words != self.padding
+ output *= mask
+ output = np.stack(output.reshape(-1)).reshape(list(words.shape) + [self.output_dim])
+ return output, tf.constant(mask)
+ else:
+ output = np.stack(output.reshape(-1)).reshape(list(words.shape) + [self.output_dim])
+ return output
+
+ @tf_utils.shape_type_conversion
+ def build(self, input_shape):
+ self.built = True
+
+ @tf_utils.shape_type_conversion
+ def compute_output_shape(self, input_shape):
+ return input_shape + (self.output_dim,)
+
+ def call(self, inputs: tf.Tensor):
+ if isinstance(inputs, list):
+ inputs = inputs[0]
+ if not hasattr(inputs, 'numpy'): # placeholder tensor
+ inputs = tf.expand_dims(inputs, axis=-1)
+ inputs = tf.tile(inputs, [1] * (len(inputs.shape) - 1) + [self.output_dim])
+ inputs = tf.zeros_like(inputs, dtype=tf.float32)
+ return inputs
+ # seq_len = inputs.shape[-1]
+ # if not seq_len:
+ # seq_len = 1
+ # return tf.zeros([1, seq_len, self.output_dim])
+ if self.mask_zero:
+ outputs, masks = self.embed_np(inputs.numpy())
+ outputs = tf.constant(outputs)
+ outputs._keras_mask = masks
+ else:
+ outputs = self.embed_np(inputs.numpy())
+ outputs = tf.constant(outputs)
+ return outputs
+
+ def compute_mask(self, inputs, mask=None):
+ if not self.mask_zero:
+ return None
+ return tf.not_equal(inputs, self.padding)
+
+ def get_config(self):
+ config = {
+ 'filepath': self.filepath,
+ 'padding': self.padding.decode('utf-8')
+ }
+ base_config = super(FastTextEmbeddingTF, self).get_config()
+ for junk in 'embeddings_initializer' \
+ , 'batch_input_shape' \
+ , 'embeddings_regularizer' \
+ , 'embeddings_constraint' \
+ , 'activity_regularizer' \
+ , 'trainable' \
+ , 'input_length' \
+ :
+ base_config.pop(junk)
+ return dict(list(base_config.items()) + list(config.items()))
diff --git a/hanlp/layers/embeddings/util.py b/hanlp/layers/embeddings/util.py
new file mode 100644
index 000000000..8975e3810
--- /dev/null
+++ b/hanlp/layers/embeddings/util.py
@@ -0,0 +1,106 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-05-09 15:45
+from typing import Union
+
+import torch
+from torch import nn
+
+from hanlp.common.vocab import Vocab
+from hanlp.utils.init_util import embedding_uniform
+from hanlp.utils.io_util import load_word2vec, load_word2vec_as_vocab_tensor
+
+
+def index_word2vec_with_vocab(filepath: str,
+ vocab: Vocab,
+ extend_vocab=True,
+ unk=None,
+ lowercase=False,
+ init='uniform',
+ normalize=None) -> torch.Tensor:
+ """
+
+ Args:
+ filepath: The path to pretrained embedding.
+ vocab: The vocabulary from training set.
+ extend_vocab: Unlock vocabulary of training set to add those tokens in pretrained embedding file.
+ unk: UNK token.
+ lowercase: Convert words in pretrained embeddings into lowercase.
+ init: Indicate which initialization to use for oov tokens.
+ normalize: ``True`` or a method to normalize the embedding matrix.
+
+ Returns:
+ An embedding matrix.
+
+ """
+ pret_vocab, pret_matrix = load_word2vec_as_vocab_tensor(filepath)
+ if unk and unk in pret_vocab:
+ pret_vocab[vocab.safe_unk_token] = pret_vocab.pop(unk)
+ if extend_vocab:
+ vocab.unlock()
+ for word in pret_vocab:
+ vocab.get_idx(word.lower() if lowercase else word)
+ vocab.lock()
+ ids = []
+
+ unk_id_offset = 0
+ for word, idx in vocab.token_to_idx.items():
+ word_id = pret_vocab.get(word, None)
+ # Retry lower case
+ if word_id is None:
+ word_id = pret_vocab.get(word.lower(), None)
+ if word_id is None:
+ word_id = len(pret_vocab) + unk_id_offset
+ unk_id_offset += 1
+ ids.append(word_id)
+ if unk_id_offset:
+ unk_embeds = torch.zeros(unk_id_offset, pret_matrix.size(1))
+ if init and init != 'zeros':
+ if init == 'uniform':
+ init = embedding_uniform
+ else:
+ raise ValueError(f'Unsupported init {init}')
+ unk_embeds = init(unk_embeds)
+ pret_matrix = torch.cat([pret_matrix, unk_embeds])
+ ids = torch.LongTensor(ids)
+ embedding = pret_matrix.index_select(0, ids)
+ if normalize == 'norm':
+ embedding /= (torch.norm(embedding, dim=1, keepdim=True) + 1e-12)
+ elif normalize == 'std':
+ embedding /= torch.std(embedding)
+ return embedding
+
+
+def build_word2vec_with_vocab(embed: Union[str, int],
+ vocab: Vocab,
+ extend_vocab=True,
+ unk=None,
+ lowercase=False,
+ trainable=False,
+ init='zeros',
+ normalize=None) -> nn.Embedding:
+ """Build word2vec embedding and a vocab.
+
+ Args:
+ embed:
+ vocab: The vocabulary from training set.
+ extend_vocab: Unlock vocabulary of training set to add those tokens in pretrained embedding file.
+ unk: UNK token.
+ lowercase: Convert words in pretrained embeddings into lowercase.
+ trainable: ``False`` to use static embeddings.
+ init: Indicate which initialization to use for oov tokens.
+ normalize: ``True`` or a method to normalize the embedding matrix.
+
+ Returns:
+ An embedding matrix.
+
+ """
+ if isinstance(embed, str):
+ embed = index_word2vec_with_vocab(embed, vocab, extend_vocab, unk, lowercase, init, normalize)
+ embed = nn.Embedding.from_pretrained(embed, freeze=not trainable, padding_idx=vocab.pad_idx)
+ return embed
+ elif isinstance(embed, int):
+ embed = nn.Embedding(len(vocab), embed, padding_idx=vocab.pad_idx)
+ return embed
+ else:
+ raise ValueError(f'Unsupported parameter type: {embed}')
diff --git a/hanlp/layers/embeddings/util_tf.py b/hanlp/layers/embeddings/util_tf.py
new file mode 100644
index 000000000..6ff82f7bd
--- /dev/null
+++ b/hanlp/layers/embeddings/util_tf.py
@@ -0,0 +1,88 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-05-09 15:46
+from typing import Union
+
+import tensorflow as tf
+
+from hanlp.common.transform_tf import Transform
+from hanlp.common.vocab_tf import VocabTF
+from hanlp.layers.embeddings.char_cnn_tf import CharCNNEmbeddingTF
+from hanlp.layers.embeddings.char_rnn_tf import CharRNNEmbeddingTF
+from hanlp.layers.embeddings.concat_embedding import ConcatEmbedding
+from hanlp.layers.embeddings.contextual_string_embedding_tf import ContextualStringEmbeddingTF
+from hanlp.layers.embeddings.fast_text_tf import FastTextEmbeddingTF
+from hanlp.layers.embeddings.word2vec_tf import Word2VecEmbeddingTF, StringWord2VecEmbeddingTF, Word2VecEmbeddingV1
+
+_upgrade = tf.keras.utils.get_custom_objects()
+for k, v in list(_upgrade.items()):
+ if k.startswith('HanLP>') and k.endswith('TF'):
+ _upgrade[k[:-2]] = v
+
+
+def build_embedding(embeddings: Union[str, int, dict], word_vocab: VocabTF, transform: Transform):
+ if not embeddings:
+ return None
+ config = transform.config
+ if isinstance(embeddings, int):
+ embeddings = tf.keras.layers.Embedding(input_dim=len(word_vocab), output_dim=embeddings,
+ trainable=True, mask_zero=True)
+ config.embedding_trainable = True
+ elif isinstance(embeddings, dict):
+ # Upgrade to 2.1
+ embed_name = embeddings['class_name'].split('>')[-1]
+ if embeddings['class_name'].startswith('HanLP>') and not embeddings['class_name'].endswith('TF'):
+ embed_name += 'TF'
+ # Embeddings need vocab
+ if embed_name in (Word2VecEmbeddingTF.__name__, StringWord2VecEmbeddingTF.__name__):
+ # Vocab won't present in the dict
+ embeddings['config']['vocab'] = word_vocab
+ elif embed_name in (CharRNNEmbeddingTF.__name__, CharCNNEmbeddingTF.__name__):
+ embeddings['config']['word_vocab'] = word_vocab
+ embeddings['config']['char_vocab'] = transform.char_vocab
+ transform.map_x = False
+ layer: tf.keras.layers.Embedding = tf.keras.utils.deserialize_keras_object(embeddings)
+ # Embedding specific configuration
+ if layer.__class__.__name__ in ('FastTextEmbedding', 'FastTextEmbeddingTF'):
+ config.run_eagerly = True # fasttext can only run in eager mode
+ config.embedding_trainable = False
+ transform.map_x = False # fasttext accept string instead of int
+ return layer
+ elif isinstance(embeddings, list):
+ if embeddings_require_string_input(embeddings):
+ # those embeddings require string as input
+ transform.map_x = False
+ # use the string version of Word2VecEmbedding instead
+ for embed in embeddings:
+ if embed['class_name'].split('>')[-1] == Word2VecEmbeddingTF.__name__:
+ embed['class_name'] = 'HanLP>' + StringWord2VecEmbeddingTF.__name__
+ return ConcatEmbedding(*[build_embedding(embed, word_vocab, transform) for embed in embeddings])
+ else:
+ assert isinstance(embeddings, str), 'embedding should be str or int or dict'
+ # word_vocab.unlock()
+ embeddings = Word2VecEmbeddingV1(path=embeddings, vocab=word_vocab,
+ trainable=config.get('embedding_trainable', False))
+ embeddings = embeddings.array_ks
+ return embeddings
+
+
+def any_embedding_in(embeddings, *cls):
+ names = set(x.__name__ for x in cls)
+ names.update(list(x[:-2] for x in names if x.endswith('TF')))
+ for embed in embeddings:
+ if isinstance(embed, dict) and embed['class_name'].split('>')[-1] in names:
+ return True
+ return False
+
+
+def embeddings_require_string_input(embeddings):
+ if not isinstance(embeddings, list):
+ embeddings = [embeddings]
+ return any_embedding_in(embeddings, CharRNNEmbeddingTF, CharCNNEmbeddingTF, FastTextEmbeddingTF,
+ ContextualStringEmbeddingTF)
+
+
+def embeddings_require_char_input(embeddings):
+ if not isinstance(embeddings, list):
+ embeddings = [embeddings]
+ return any_embedding_in(embeddings, CharRNNEmbeddingTF, CharCNNEmbeddingTF, ContextualStringEmbeddingTF)
diff --git a/hanlp/layers/embeddings/word2vec.py b/hanlp/layers/embeddings/word2vec.py
index 5621395b2..5539d8d14 100644
--- a/hanlp/layers/embeddings/word2vec.py
+++ b/hanlp/layers/embeddings/word2vec.py
@@ -1,195 +1,212 @@
# -*- coding:utf-8 -*-
# Author: hankcs
-# Date: 2019-08-24 21:49
-import os
-from typing import Tuple, Union, List
+# Date: 2020-05-09 13:38
+from typing import Optional, Callable, Union
-import numpy as np
-import tensorflow as tf
-from tensorflow.python.ops import math_ops
+import torch
+from torch import nn
+from hanlp_common.configurable import AutoConfigurable
+from hanlp.common.transform import VocabDict
from hanlp.common.vocab import Vocab
-from hanlp.utils.io_util import load_word2vec, get_resource
-from hanlp.utils.tf_util import hanlp_register
-
-
-class Word2VecEmbeddingV1(tf.keras.layers.Layer):
- def __init__(self, path: str = None, vocab: Vocab = None, normalize: bool = False, load_all=True, mask_zero=True,
- trainable=False, name=None, dtype=None, dynamic=False, **kwargs):
- super().__init__(trainable, name, dtype, dynamic, **kwargs)
- if load_all and vocab and vocab.locked:
- vocab.unlock()
- self.vocab, self.array_np = self._load(path, vocab, normalize)
- self.vocab.lock()
- self.array_ks = tf.keras.layers.Embedding(input_dim=len(self.vocab), output_dim=self.dim, trainable=trainable,
- embeddings_initializer=tf.keras.initializers.Constant(self.array_np),
- mask_zero=mask_zero)
- self.mask_zero = mask_zero
- self.supports_masking = mask_zero
-
- def compute_mask(self, inputs, mask=None):
- if not self.mask_zero:
- return None
-
- return math_ops.not_equal(inputs, self.vocab.pad_idx)
-
- def call(self, inputs, **kwargs):
- return self.array_ks(inputs, **kwargs)
-
- def compute_output_shape(self, input_shape):
- return input_shape[0], self.dim
-
- @staticmethod
- def _load(path, vocab, normalize=False) -> Tuple[Vocab, Union[np.ndarray, None]]:
- if not vocab:
- vocab = Vocab()
- if not path:
- return vocab, None
- assert vocab.unk_idx is not None
-
- word2vec, dim = load_word2vec(path)
- for word in word2vec:
- vocab.get_idx(word)
-
- pret_embs = np.zeros(shape=(len(vocab), dim), dtype=np.float32)
- state = np.random.get_state()
- np.random.seed(0)
- bias = np.random.uniform(low=-0.001, high=0.001, size=dim).astype(dtype=np.float32)
- scale = np.sqrt(3.0 / dim)
- for word, idx in vocab.token_to_idx.items():
- vec = word2vec.get(word, None)
- if vec is None:
- vec = word2vec.get(word.lower(), None)
- # if vec is not None:
- # vec += bias
- if vec is None:
- # vec = np.random.uniform(-scale, scale, [dim])
- vec = np.zeros([dim], dtype=np.float32)
- pret_embs[idx] = vec
- # noinspection PyTypeChecker
- np.random.set_state(state)
- return vocab, pret_embs
-
- @property
- def size(self):
- if self.array_np is not None:
- return self.array_np.shape[0]
+from hanlp.layers.dropout import WordDropout
+from hanlp.layers.embeddings.embedding import Embedding, EmbeddingDim
+from hanlp.layers.embeddings.util import build_word2vec_with_vocab
+from hanlp.utils.io_util import load_word2vec_as_vocab_tensor
+from hanlp_trie.trie import Trie
+
+
+class Word2VecEmbeddingModule(nn.Module, EmbeddingDim):
+ def __init__(self, field: str, embed: nn.Embedding, word_dropout: WordDropout = None, cpu=False,
+ second_channel=False, num_tokens_in_trn=None, unk_idx=1) -> None:
+ """A word2vec style embedding module which maps a token to its embedding through looking up a pre-defined table.
+
+ Args:
+ field: The field to work on. Usually some token fields.
+ embed: An ``Embedding`` layer.
+ word_dropout: The probability of randomly replacing a token with ``UNK``.
+ cpu: Reside on CPU instead of GPU.
+ second_channel: A trainable second channel for each token, which will be added to pretrained embeddings.
+ num_tokens_in_trn: The number of tokens in training set.
+ unk_idx: The index of ``UNK``.
+ """
+ super().__init__()
+ self.cpu = cpu
+ self.field = field
+ self.embed = embed
+ self.word_dropout = word_dropout
+ self.num_tokens_in_trn = num_tokens_in_trn
+ self.unk_idx = unk_idx
+ if second_channel:
+ n_words, n_embed = embed.weight.size()
+ if num_tokens_in_trn:
+ n_words = num_tokens_in_trn
+ second_channel = nn.Embedding(num_embeddings=n_words,
+ embedding_dim=n_embed)
+ nn.init.zeros_(second_channel.weight)
+ self.second_channel = second_channel
+
+ def forward(self, batch: dict, **kwargs):
+ x: torch.Tensor = batch[f'{self.field}_id']
+ if self.cpu:
+ device = x.device
+ x = x.cpu()
+ if self.word_dropout:
+ x = self.word_dropout(x)
+ if self.second_channel:
+ ext_mask = x.ge(self.second_channel.num_embeddings)
+ ext_words = x.masked_fill(ext_mask, self.unk_idx)
+ x = self.embed(x) + self.second_channel(ext_words)
+ else:
+ x = self.embed(x)
+ if self.cpu:
+ # noinspection PyUnboundLocalVariable
+ x = x.to(device)
+ return x
@property
- def dim(self):
- if self.array_np is not None:
- return self.array_np.shape[1]
-
- @property
- def shape(self):
- if self.array_np is None:
- return None
- return self.array_np.shape
-
- def get_vector(self, word: str) -> np.ndarray:
- assert self.array_np is not None
- return self.array_np[self.vocab.get_idx_without_add(word)]
-
- def __getitem__(self, word: Union[str, List, tf.Tensor]) -> np.ndarray:
- if isinstance(word, str):
- return self.get_vector(word)
- elif isinstance(word, list):
- vectors = np.zeros(shape=(len(word), self.dim))
- for idx, token in enumerate(word):
- vectors[idx] = self.get_vector(token)
- return vectors
- elif isinstance(word, tf.Tensor):
- if word.dtype == tf.string:
- word_ids = self.vocab.token_to_idx_table.lookup(word)
- return tf.nn.embedding_lookup(self.array_tf, word_ids)
- elif word.dtype == tf.int32 or word.dtype == tf.int64:
- return tf.nn.embedding_lookup(self.array_tf, word)
-
-
-@hanlp_register
-class Word2VecEmbedding(tf.keras.layers.Embedding):
-
- def __init__(self, filepath: str = None, vocab: Vocab = None, expand_vocab=True, lowercase=True,
- input_dim=None, output_dim=None, unk=None, normalize=False,
- embeddings_initializer='VarianceScaling',
- embeddings_regularizer=None,
- activity_regularizer=None, embeddings_constraint=None, mask_zero=True, input_length=None,
- name=None, **kwargs):
- filepath = get_resource(filepath)
- word2vec, _output_dim = load_word2vec(filepath)
- if output_dim:
- assert output_dim == _output_dim, f'output_dim = {output_dim} does not match {filepath}'
- output_dim = _output_dim
- # if the `unk` token exists in the pretrained,
- # then replace it with a self-defined one, usually the one in word vocab
- if unk and unk in word2vec:
- word2vec[vocab.safe_unk_token] = word2vec.pop(unk)
- if vocab is None:
- vocab = Vocab()
- vocab.update(word2vec.keys())
- if expand_vocab and vocab.mutable:
- for word in word2vec:
- vocab.get_idx(word.lower() if lowercase else word)
- if input_dim:
- assert input_dim == len(vocab), f'input_dim = {input_dim} does not match {filepath}'
- input_dim = len(vocab)
- # init matrix
- self._embeddings_initializer = embeddings_initializer
- embeddings_initializer = tf.keras.initializers.get(embeddings_initializer)
- with tf.device('cpu:0'):
- pret_embs = embeddings_initializer(shape=[input_dim, output_dim]).numpy()
- # insert to pret_embs
- for word, idx in vocab.token_to_idx.items():
- vec = word2vec.get(word, None)
- # Retry lower case
- if vec is None and lowercase:
- vec = word2vec.get(word.lower(), None)
- if vec is not None:
- pret_embs[idx] = vec
- if normalize:
- pret_embs /= np.std(pret_embs)
- if not name:
- name = os.path.splitext(os.path.basename(filepath))[0]
- super().__init__(input_dim, output_dim, tf.keras.initializers.Constant(pret_embs), embeddings_regularizer,
- activity_regularizer, embeddings_constraint, mask_zero, input_length, name=name, **kwargs)
- self.filepath = filepath
- self.expand_vocab = expand_vocab
+ def embedding_dim(self) -> int:
+ return self.embed.embedding_dim
+
+ # noinspection PyMethodOverriding
+ # def to(self, device, **kwargs):
+ # print(self.cpu)
+ # exit(1)
+ # if self.cpu:
+ # return super(Word2VecEmbeddingModule, self).to(-1, **kwargs)
+ # return super(Word2VecEmbeddingModule, self).to(device, **kwargs)
+
+ def _apply(self, fn):
+
+ if not self.cpu: # This might block all fn not limiting to moving between devices.
+ return super(Word2VecEmbeddingModule, self)._apply(fn)
+
+
+class Word2VecEmbedding(Embedding, AutoConfigurable):
+ def __init__(self,
+ field,
+ embed: Union[int, str],
+ extend_vocab=True,
+ pad=None,
+ unk=None,
+ lowercase=False,
+ trainable=False,
+ second_channel=False,
+ word_dropout: float = 0,
+ normalize=False,
+ cpu=False,
+ init='zeros') -> None:
+ """A word2vec style embedding builder which maps a token to its embedding through looking up a pre-defined
+ table.
+
+ Args:
+ field: The field to work on. Usually some token fields.
+ embed: A path to pre-trained embedding file or an integer defining the size of randomly initialized
+ embedding.
+ extend_vocab: Unlock vocabulary of training set to add those tokens in pre-trained embedding file.
+ pad: The padding token.
+ unk: The unknown token.
+ lowercase: Convert words in pretrained embeddings into lowercase.
+ trainable: ``False`` to use static embeddings.
+ second_channel: A trainable second channel for each token, which will be added to pretrained embeddings.
+ word_dropout: The probability of randomly replacing a token with ``UNK``.
+ normalize: ``True`` or a method to normalize the embedding matrix.
+ cpu: Reside on CPU instead of GPU.
+ init: Indicate which initialization to use for oov tokens.
+ """
+ super().__init__()
+ self.pad = pad
+ self.second_channel = second_channel
+ self.cpu = cpu
+ self.normalize = normalize
+ self.word_dropout = word_dropout
+ self.init = init
self.lowercase = lowercase
+ self.unk = unk
+ self.extend_vocab = extend_vocab
+ self.trainable = trainable
+ self.embed = embed
+ self.field = field
+
+ def module(self, vocabs: VocabDict, **kwargs) -> Optional[nn.Module]:
+ vocab = vocabs[self.field]
+ num_tokens_in_trn = len(vocab)
+ embed = build_word2vec_with_vocab(self.embed,
+ vocab,
+ self.extend_vocab,
+ self.unk,
+ self.lowercase,
+ self.trainable,
+ normalize=self.normalize)
+ if self.word_dropout:
+ assert vocab.unk_token, f'unk_token of vocab {self.field} has to be set in order to ' \
+ f'make use of word_dropout'
+ padding = []
+ if vocab.pad_token:
+ padding.append(vocab.pad_idx)
+ word_dropout = WordDropout(self.word_dropout, vocab.unk_idx, exclude_tokens=padding)
+ else:
+ word_dropout = None
+ return Word2VecEmbeddingModule(self.field, embed, word_dropout=word_dropout, cpu=self.cpu,
+ second_channel=self.second_channel, num_tokens_in_trn=num_tokens_in_trn,
+ unk_idx=vocab.unk_idx)
+
+ def transform(self, vocabs: VocabDict = None, **kwargs) -> Optional[Callable]:
+ assert vocabs is not None
+ if self.field not in vocabs:
+ vocabs[self.field] = Vocab(pad_token=self.pad, unk_token=self.unk)
+ return super().transform(**kwargs)
+
+
+class GazetterTransform(object):
+ def __init__(self, field, words: dict) -> None:
+ super().__init__()
+ self.field = field
+ self.trie = Trie()
+ for word, idx in words.items():
+ self.trie[word] = idx
+
+ def __call__(self, sample: dict) -> dict:
+ tokens = sample[self.field]
+ lexicons = self.trie.parse(tokens)
+ skips_l2r = [[] for _ in range(len(tokens))]
+ skips_r2l = [[] for _ in range(len(tokens))]
+ for w, i, s, e in lexicons:
+ e = e - 1
+ skips_l2r[e].append((s, w, i))
+ skips_r2l[s].append((e, w, i))
+ for direction, value in zip(['skips_l2r', 'skips_r2l'], [skips_l2r, skips_r2l]):
+ sample[f'{self.field}_{direction}_offset'] = [list(map(lambda x: x[0], p)) for p in value]
+ sample[f'{self.field}_{direction}_id'] = [list(map(lambda x: x[-1], p)) for p in value]
+ sample[f'{self.field}_{direction}_count'] = list(map(len, value))
+ return sample
+
+
+class GazetteerEmbedding(Embedding, AutoConfigurable):
+ def __init__(self, embed: str, field='char', trainable=False) -> None:
+ self.trainable = trainable
+ self.embed = embed
+ self.field = field
+ vocab, matrix = load_word2vec_as_vocab_tensor(self.embed)
+ ids = []
+ _vocab = {}
+ for word, idx in vocab.items():
+ if len(word) > 1:
+ ids.append(idx)
+ _vocab[word] = len(_vocab)
+ ids = torch.tensor(ids)
+ _matrix = matrix.index_select(0, ids)
+ self._vocab = _vocab
+ self._matrix = _matrix
+
+ def transform(self, **kwargs) -> Optional[Callable]:
+ return GazetterTransform(self.field, self._vocab)
+
+ def module(self, **kwargs) -> Optional[nn.Module]:
+ embed = nn.Embedding.from_pretrained(self._matrix, freeze=not self.trainable)
+ return embed
- def get_config(self):
- config = {
- 'filepath': self.filepath,
- 'expand_vocab': self.expand_vocab,
- 'lowercase': self.lowercase,
- }
- base_config = super(Word2VecEmbedding, self).get_config()
- base_config['embeddings_initializer'] = self._embeddings_initializer
- return dict(list(base_config.items()) + list(config.items()))
-
-
-@hanlp_register
-class StringWord2VecEmbedding(Word2VecEmbedding):
-
- def __init__(self, filepath: str = None, vocab: Vocab = None, expand_vocab=True, lowercase=False, input_dim=None,
- output_dim=None, unk=None, normalize=False, embeddings_initializer='VarianceScaling',
- embeddings_regularizer=None, activity_regularizer=None, embeddings_constraint=None, mask_zero=True,
- input_length=None, name=None, **kwargs):
- if vocab is None:
- vocab = Vocab()
- self.vocab = vocab
- super().__init__(filepath, vocab, expand_vocab, lowercase, input_dim, output_dim, unk, normalize,
- embeddings_initializer, embeddings_regularizer, activity_regularizer, embeddings_constraint,
- mask_zero, input_length, name, **kwargs)
-
- def call(self, inputs):
- assert inputs.dtype == tf.string, \
- f'Expect tf.string but got tf.{inputs.dtype.name}. {inputs}' \
- f'Please pass tf.{inputs.dtype.name} in.'
- inputs = self.vocab.lookup(inputs)
- # inputs._keras_mask = tf.not_equal(inputs, self.vocab.pad_idx)
- return super().call(inputs)
-
- def compute_mask(self, inputs, mask=None):
- if not self.mask_zero:
- return None
- return tf.not_equal(inputs, self.vocab.pad_token)
+ @staticmethod
+ def _remove_short_tokens(word2vec):
+ word2vec = dict((w, v) for w, v in word2vec.items() if len(w) > 1)
+ return word2vec
diff --git a/hanlp/layers/embeddings/word2vec_tf.py b/hanlp/layers/embeddings/word2vec_tf.py
new file mode 100644
index 000000000..567016d9c
--- /dev/null
+++ b/hanlp/layers/embeddings/word2vec_tf.py
@@ -0,0 +1,196 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-08-24 21:49
+import os
+from typing import Tuple, Union, List
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.python.ops import math_ops
+
+from hanlp.common.vocab_tf import VocabTF
+from hanlp.utils.io_util import load_word2vec, get_resource
+from hanlp.utils.tf_util import hanlp_register
+from hanlp_common.util import DummyContext
+
+
+class Word2VecEmbeddingV1(tf.keras.layers.Layer):
+ def __init__(self, path: str = None, vocab: VocabTF = None, normalize: bool = False, load_all=True, mask_zero=True,
+ trainable=False, name=None, dtype=None, dynamic=False, **kwargs):
+ super().__init__(trainable, name, dtype, dynamic, **kwargs)
+ if load_all and vocab and vocab.locked:
+ vocab.unlock()
+ self.vocab, self.array_np = self._load(path, vocab, normalize)
+ self.vocab.lock()
+ self.array_ks = tf.keras.layers.Embedding(input_dim=len(self.vocab), output_dim=self.dim, trainable=trainable,
+ embeddings_initializer=tf.keras.initializers.Constant(self.array_np),
+ mask_zero=mask_zero)
+ self.mask_zero = mask_zero
+ self.supports_masking = mask_zero
+
+ def compute_mask(self, inputs, mask=None):
+ if not self.mask_zero:
+ return None
+
+ return math_ops.not_equal(inputs, self.vocab.pad_idx)
+
+ def call(self, inputs, **kwargs):
+ return self.array_ks(inputs, **kwargs)
+
+ def compute_output_shape(self, input_shape):
+ return input_shape[0], self.dim
+
+ @staticmethod
+ def _load(path, vocab, normalize=False) -> Tuple[VocabTF, Union[np.ndarray, None]]:
+ if not vocab:
+ vocab = VocabTF()
+ if not path:
+ return vocab, None
+ assert vocab.unk_idx is not None
+
+ word2vec, dim = load_word2vec(path)
+ for word in word2vec:
+ vocab.get_idx(word)
+
+ pret_embs = np.zeros(shape=(len(vocab), dim), dtype=np.float32)
+ state = np.random.get_state()
+ np.random.seed(0)
+ bias = np.random.uniform(low=-0.001, high=0.001, size=dim).astype(dtype=np.float32)
+ scale = np.sqrt(3.0 / dim)
+ for word, idx in vocab.token_to_idx.items():
+ vec = word2vec.get(word, None)
+ if vec is None:
+ vec = word2vec.get(word.lower(), None)
+ # if vec is not None:
+ # vec += bias
+ if vec is None:
+ # vec = np.random.uniform(-scale, scale, [dim])
+ vec = np.zeros([dim], dtype=np.float32)
+ pret_embs[idx] = vec
+ # noinspection PyTypeChecker
+ np.random.set_state(state)
+ return vocab, pret_embs
+
+ @property
+ def size(self):
+ if self.array_np is not None:
+ return self.array_np.shape[0]
+
+ @property
+ def dim(self):
+ if self.array_np is not None:
+ return self.array_np.shape[1]
+
+ @property
+ def shape(self):
+ if self.array_np is None:
+ return None
+ return self.array_np.shape
+
+ def get_vector(self, word: str) -> np.ndarray:
+ assert self.array_np is not None
+ return self.array_np[self.vocab.get_idx_without_add(word)]
+
+ def __getitem__(self, word: Union[str, List, tf.Tensor]) -> np.ndarray:
+ if isinstance(word, str):
+ return self.get_vector(word)
+ elif isinstance(word, list):
+ vectors = np.zeros(shape=(len(word), self.dim))
+ for idx, token in enumerate(word):
+ vectors[idx] = self.get_vector(token)
+ return vectors
+ elif isinstance(word, tf.Tensor):
+ if word.dtype == tf.string:
+ word_ids = self.vocab.token_to_idx_table.lookup(word)
+ return tf.nn.embedding_lookup(self.array_tf, word_ids)
+ elif word.dtype == tf.int32 or word.dtype == tf.int64:
+ return tf.nn.embedding_lookup(self.array_tf, word)
+
+
+@hanlp_register
+class Word2VecEmbeddingTF(tf.keras.layers.Embedding):
+
+ def __init__(self, filepath: str = None, vocab: VocabTF = None, expand_vocab=True, lowercase=True,
+ input_dim=None, output_dim=None, unk=None, normalize=False,
+ embeddings_initializer='VarianceScaling',
+ embeddings_regularizer=None,
+ activity_regularizer=None, embeddings_constraint=None, mask_zero=True, input_length=None,
+ name=None, cpu=True, **kwargs):
+ filepath = get_resource(filepath)
+ word2vec, _output_dim = load_word2vec(filepath)
+ if output_dim:
+ assert output_dim == _output_dim, f'output_dim = {output_dim} does not match {filepath}'
+ output_dim = _output_dim
+ # if the `unk` token exists in the pretrained,
+ # then replace it with a self-defined one, usually the one in word vocab
+ if unk and unk in word2vec:
+ word2vec[vocab.safe_unk_token] = word2vec.pop(unk)
+ if vocab is None:
+ vocab = VocabTF()
+ vocab.update(word2vec.keys())
+ if expand_vocab and vocab.mutable:
+ for word in word2vec:
+ vocab.get_idx(word.lower() if lowercase else word)
+ if input_dim:
+ assert input_dim == len(vocab), f'input_dim = {input_dim} does not match {filepath}'
+ input_dim = len(vocab)
+ # init matrix
+ self._embeddings_initializer = embeddings_initializer
+ embeddings_initializer = tf.keras.initializers.get(embeddings_initializer)
+ with tf.device('cpu:0') if cpu else DummyContext():
+ pret_embs = embeddings_initializer(shape=[input_dim, output_dim]).numpy()
+ # insert to pret_embs
+ for word, idx in vocab.token_to_idx.items():
+ vec = word2vec.get(word, None)
+ # Retry lower case
+ if vec is None and lowercase:
+ vec = word2vec.get(word.lower(), None)
+ if vec is not None:
+ pret_embs[idx] = vec
+ if normalize:
+ pret_embs /= np.std(pret_embs)
+ if not name:
+ name = os.path.splitext(os.path.basename(filepath))[0]
+ super().__init__(input_dim, output_dim, tf.keras.initializers.Constant(pret_embs), embeddings_regularizer,
+ activity_regularizer, embeddings_constraint, mask_zero, input_length, name=name, **kwargs)
+ self.filepath = filepath
+ self.expand_vocab = expand_vocab
+ self.lowercase = lowercase
+
+ def get_config(self):
+ config = {
+ 'filepath': self.filepath,
+ 'expand_vocab': self.expand_vocab,
+ 'lowercase': self.lowercase,
+ }
+ base_config = super(Word2VecEmbeddingTF, self).get_config()
+ base_config['embeddings_initializer'] = self._embeddings_initializer
+ return dict(list(base_config.items()) + list(config.items()))
+
+
+@hanlp_register
+class StringWord2VecEmbeddingTF(Word2VecEmbeddingTF):
+
+ def __init__(self, filepath: str = None, vocab: VocabTF = None, expand_vocab=True, lowercase=False, input_dim=None,
+ output_dim=None, unk=None, normalize=False, embeddings_initializer='VarianceScaling',
+ embeddings_regularizer=None, activity_regularizer=None, embeddings_constraint=None, mask_zero=True,
+ input_length=None, name=None, **kwargs):
+ if vocab is None:
+ vocab = VocabTF()
+ self.vocab = vocab
+ super().__init__(filepath, vocab, expand_vocab, lowercase, input_dim, output_dim, unk, normalize,
+ embeddings_initializer, embeddings_regularizer, activity_regularizer, embeddings_constraint,
+ mask_zero, input_length, name, **kwargs)
+
+ def call(self, inputs):
+ assert inputs.dtype == tf.string, \
+ f'Expect tf.string but got tf.{inputs.dtype.name}. {inputs}' \
+ f'Please pass tf.{inputs.dtype.name} in.'
+ inputs = self.vocab.lookup(inputs)
+ # inputs._keras_mask = tf.not_equal(inputs, self.vocab.pad_idx)
+ return super().call(inputs)
+
+ def compute_mask(self, inputs, mask=None):
+ if not self.mask_zero:
+ return None
+ return tf.not_equal(inputs, self.vocab.pad_token)
diff --git a/hanlp/layers/feed_forward.py b/hanlp/layers/feed_forward.py
new file mode 100644
index 000000000..fa95696c4
--- /dev/null
+++ b/hanlp/layers/feed_forward.py
@@ -0,0 +1,15 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-07-06 14:37
+from typing import Union, List
+
+from alnlp.modules import feedforward
+
+from hanlp.common.structure import ConfigTracker
+
+
+class FeedForward(feedforward.FeedForward, ConfigTracker):
+ def __init__(self, input_dim: int, num_layers: int, hidden_dims: Union[int, List[int]],
+ activations: Union[str, List[str]], dropout: Union[float, List[float]] = 0.0) -> None:
+ super().__init__(input_dim, num_layers, hidden_dims, activations, dropout)
+ ConfigTracker.__init__(self, locals())
diff --git a/hanlp/layers/pass_through_encoder.py b/hanlp/layers/pass_through_encoder.py
new file mode 100644
index 000000000..80fa356f9
--- /dev/null
+++ b/hanlp/layers/pass_through_encoder.py
@@ -0,0 +1,12 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-07-08 17:56
+from alnlp.modules.pass_through_encoder import PassThroughEncoder as _PassThroughEncoder
+
+from hanlp.common.structure import ConfigTracker
+
+
+class PassThroughEncoder(_PassThroughEncoder, ConfigTracker):
+ def __init__(self, input_dim: int) -> None:
+ super().__init__(input_dim)
+ ConfigTracker.__init__(self, locals())
diff --git a/hanlp/layers/scalar_mix.py b/hanlp/layers/scalar_mix.py
new file mode 100644
index 000000000..a887ecf9b
--- /dev/null
+++ b/hanlp/layers/scalar_mix.py
@@ -0,0 +1,156 @@
+# This file is modified from udify, which is licensed under the MIT license:
+# MIT License
+#
+# Copyright (c) 2019 Dan Kondratyuk
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+"""
+The dot-product "Layer Attention" that is applied to the layers of BERT, along with layer dropout to reduce overfitting
+"""
+
+from typing import List, Tuple
+
+import torch
+from torch.nn import ParameterList, Parameter
+
+from hanlp.common.structure import ConfigTracker
+
+
+class ScalarMixWithDropout(torch.nn.Module):
+ """Computes a parameterised scalar mixture of N tensors, ``mixture = gamma * sum(s_k * tensor_k)``
+ where ``s = softmax(w)``, with ``w`` and ``gamma`` scalar parameters.
+
+ If ``do_layer_norm=True`` then apply layer normalization to each tensor before weighting.
+
+ If ``dropout > 0``, then for each scalar weight, adjust its softmax weight mass to 0 with
+ the dropout probability (i.e., setting the unnormalized weight to -inf). This effectively
+ should redistribute dropped probability mass to all other weights.
+
+ Args:
+
+ Returns:
+
+ """
+
+ def __init__(self,
+ mixture_range: Tuple[int, int],
+ do_layer_norm: bool = False,
+ initial_scalar_parameters: List[float] = None,
+ trainable: bool = True,
+ dropout: float = None,
+ dropout_value: float = -1e20,
+ **kwargs) -> None:
+ super(ScalarMixWithDropout, self).__init__()
+ self.mixture_range = mixture_range
+ mixture_size = mixture_range[1] - mixture_range[0]
+ self.mixture_size = mixture_size
+ self.do_layer_norm = do_layer_norm
+ self.dropout = dropout
+
+ if initial_scalar_parameters is None:
+ initial_scalar_parameters = [0.0] * mixture_size
+ elif len(initial_scalar_parameters) != mixture_size:
+ raise ValueError("Length of initial_scalar_parameters {} differs "
+ "from mixture_size {}".format(
+ initial_scalar_parameters, mixture_size))
+
+ # self.scalar_parameters = ParameterList(
+ # [Parameter(torch.FloatTensor([initial_scalar_parameters[i]]),
+ # requires_grad=trainable) for i
+ # in range(mixture_size)])
+ self.scalar_parameters = Parameter(torch.FloatTensor(initial_scalar_parameters), requires_grad=True)
+ self.gamma = Parameter(torch.FloatTensor([1.0]), requires_grad=trainable)
+
+ if self.dropout:
+ dropout_mask = torch.zeros(len(self.scalar_parameters))
+ dropout_fill = torch.empty(len(self.scalar_parameters)).fill_(dropout_value)
+ self.register_buffer("dropout_mask", dropout_mask)
+ self.register_buffer("dropout_fill", dropout_fill)
+
+ def forward(self, tensors: List[torch.Tensor], # pylint: disable=arguments-differ
+ mask: torch.Tensor = None) -> torch.Tensor:
+ """Compute a weighted average of the ``tensors``. The input tensors an be any shape
+ with at least two dimensions, but must all be the same shape.
+
+ When ``do_layer_norm=True``, the ``mask`` is required input. If the ``tensors`` are
+ dimensioned ``(dim_0, ..., dim_{n-1}, dim_n)``, then the ``mask`` is dimensioned
+ ``(dim_0, ..., dim_{n-1})``, as in the typical case with ``tensors`` of shape
+ ``(batch_size, timesteps, dim)`` and ``mask`` of shape ``(batch_size, timesteps)``.
+
+ When ``do_layer_norm=False`` the ``mask`` is ignored.
+
+ Args:
+ tensors: List[torch.Tensor]:
+ # pylint: disable: (Default value = arguments-differmask: torch.Tensor = None)
+
+ Returns:
+
+ """
+ if len(tensors) != self.mixture_size:
+ tensors = tensors[self.mixture_range[0]:self.mixture_range[1]]
+ if len(tensors) != self.mixture_size:
+ raise ValueError("{} tensors were passed, but the module was initialized to "
+ "mix {} tensors.".format(len(tensors), self.mixture_size))
+
+ def _do_layer_norm(tensor, broadcast_mask, num_elements_not_masked):
+ tensor_masked = tensor * broadcast_mask
+ mean = torch.sum(tensor_masked) / num_elements_not_masked
+ variance = torch.sum(((tensor_masked - mean) * broadcast_mask) ** 2) / num_elements_not_masked
+ return (tensor - mean) / torch.sqrt(variance + 1E-12)
+
+ weights = self.scalar_parameters
+
+ if self.dropout:
+ weights = torch.where(self.dropout_mask.uniform_() > self.dropout, weights, self.dropout_fill)
+
+ normed_weights = torch.nn.functional.softmax(weights, dim=0)
+ normed_weights = torch.split(normed_weights, split_size_or_sections=1)
+
+ if not self.do_layer_norm:
+ pieces = []
+ for weight, tensor in zip(normed_weights, tensors):
+ pieces.append(weight * tensor)
+ return self.gamma * sum(pieces)
+
+ else:
+ mask_float = mask.float()
+ broadcast_mask = mask_float.unsqueeze(-1)
+ input_dim = tensors[0].size(-1)
+ num_elements_not_masked = torch.sum(mask_float) * input_dim
+
+ pieces = []
+ for weight, tensor in zip(normed_weights, tensors):
+ pieces.append(weight * _do_layer_norm(tensor,
+ broadcast_mask, num_elements_not_masked))
+ return self.gamma * sum(pieces)
+
+
+class ScalarMixWithDropoutBuilder(ConfigTracker, ScalarMixWithDropout):
+
+ def __init__(self,
+ mixture_range: Tuple[int, int],
+ do_layer_norm: bool = False,
+ initial_scalar_parameters: List[float] = None,
+ trainable: bool = True,
+ dropout: float = None,
+ dropout_value: float = -1e20) -> None:
+ super().__init__(locals())
+
+ def build(self):
+ return ScalarMixWithDropout(**self.config)
diff --git a/hanlp/layers/transformers/__init__.py b/hanlp/layers/transformers/__init__.py
index 8760bfd27..4f5e1a4d9 100644
--- a/hanlp/layers/transformers/__init__.py
+++ b/hanlp/layers/transformers/__init__.py
@@ -1,14 +1,13 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-12-29 15:17
-from bert import bert_models_google
+# mute transformers
+import logging
-from hanlp.common.constant import HANLP_URL
-
-zh_albert_models_google = {
- 'albert_base_zh': HANLP_URL + 'embeddings/albert_base_zh.tar.gz', # Provide mirroring
- 'albert_large_zh': 'https://storage.googleapis.com/albert_models/albert_large_zh.tar.gz',
- 'albert_xlarge_zh': 'https://storage.googleapis.com/albert_models/albert_xlarge_zh.tar.gz',
- 'albert_xxlarge_zh': 'https://storage.googleapis.com/albert_models/albert_xxlarge_zh.tar.gz',
-}
-bert_models_google['chinese_L-12_H-768_A-12'] = HANLP_URL + 'embeddings/chinese_L-12_H-768_A-12.zip'
+logging.getLogger('transformers.file_utils').setLevel(logging.ERROR)
+logging.getLogger('transformers.filelock').setLevel(logging.ERROR)
+logging.getLogger('transformers.tokenization_utils').setLevel(logging.ERROR)
+logging.getLogger('transformers.configuration_utils').setLevel(logging.ERROR)
+logging.getLogger('transformers.modeling_tf_utils').setLevel(logging.ERROR)
+logging.getLogger('transformers.modeling_utils').setLevel(logging.ERROR)
+logging.getLogger('transformers.tokenization_utils_base').setLevel(logging.ERROR)
diff --git a/hanlp/layers/transformers/encoder.py b/hanlp/layers/transformers/encoder.py
new file mode 100644
index 000000000..a7e5e46ad
--- /dev/null
+++ b/hanlp/layers/transformers/encoder.py
@@ -0,0 +1,124 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-06-22 21:06
+import warnings
+from typing import Union, Dict, Any, Sequence
+
+import torch
+from torch import nn
+
+from hanlp.layers.dropout import WordDropout
+from hanlp.layers.scalar_mix import ScalarMixWithDropout, ScalarMixWithDropoutBuilder
+from hanlp.layers.transformers.pt_imports import PreTrainedModel, PreTrainedTokenizer, AutoTokenizer, AutoModel_
+from hanlp.layers.transformers.utils import transformer_encode
+
+
+# noinspection PyAbstractClass
+class TransformerEncoder(nn.Module):
+ def __init__(self,
+ transformer: Union[PreTrainedModel, str],
+ transformer_tokenizer: PreTrainedTokenizer,
+ average_subwords=False,
+ scalar_mix: Union[ScalarMixWithDropoutBuilder, int] = None,
+ word_dropout=None,
+ max_sequence_length=None,
+ ret_raw_hidden_states=False,
+ transformer_args: Dict[str, Any] = None,
+ trainable=True,
+ training=True) -> None:
+ """A pre-trained transformer encoder.
+
+ Args:
+ transformer: A ``PreTrainedModel`` or an identifier of a ``PreTrainedModel``.
+ transformer_tokenizer: A ``PreTrainedTokenizer``.
+ average_subwords: ``True`` to average subword representations.
+ scalar_mix: Layer attention.
+ word_dropout: Dropout rate of randomly replacing a subword with MASK.
+ max_sequence_length: The maximum sequence length. Sequence longer than this will be handled by sliding
+ window.
+ ret_raw_hidden_states: ``True`` to return hidden states of each layer.
+ transformer_args: Extra arguments passed to the transformer.
+ trainable: ``False`` to use static embeddings.
+ training: ``False`` to skip loading weights from pre-trained transformers.
+ """
+ super().__init__()
+ self.ret_raw_hidden_states = ret_raw_hidden_states
+ self.max_sequence_length = max_sequence_length
+ self.average_subwords = average_subwords
+ if word_dropout:
+ oov = transformer_tokenizer.mask_token_id
+ if isinstance(word_dropout, Sequence):
+ word_dropout, replacement = word_dropout
+ if replacement == 'unk':
+ # Electra English has to use unk
+ oov = transformer_tokenizer.unk_token_id
+ elif replacement == 'mask':
+ # UDify uses [MASK]
+ oov = transformer_tokenizer.mask_token_id
+ else:
+ oov = replacement
+ pad = transformer_tokenizer.pad_token_id
+ cls = transformer_tokenizer.cls_token_id
+ sep = transformer_tokenizer.sep_token_id
+ excludes = [pad, cls, sep]
+ self.word_dropout = WordDropout(p=word_dropout, oov_token=oov, exclude_tokens=excludes)
+ else:
+ self.word_dropout = None
+ if isinstance(transformer, str):
+ output_hidden_states = scalar_mix is not None
+ if transformer_args is None:
+ transformer_args = dict()
+ transformer_args['output_hidden_states'] = output_hidden_states
+ transformer = AutoModel_.from_pretrained(transformer, training=training or not trainable,
+ **transformer_args)
+ if hasattr(transformer, 'encoder') and hasattr(transformer, 'decoder'):
+ # For seq2seq model, use its encoder
+ transformer = transformer.encoder
+ self.transformer = transformer
+ if not trainable:
+ transformer.requires_grad_(False)
+
+ if isinstance(scalar_mix, ScalarMixWithDropoutBuilder):
+ self.scalar_mix: ScalarMixWithDropout = scalar_mix.build()
+ else:
+ self.scalar_mix = None
+
+ def forward(self, input_ids: torch.LongTensor, attention_mask=None, token_type_ids=None, token_span=None, **kwargs):
+ if self.word_dropout:
+ input_ids = self.word_dropout(input_ids)
+
+ x = transformer_encode(self.transformer,
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ token_span,
+ layer_range=self.scalar_mix.mixture_range if self.scalar_mix else 0,
+ max_sequence_length=self.max_sequence_length,
+ average_subwords=self.average_subwords,
+ ret_raw_hidden_states=self.ret_raw_hidden_states)
+ if self.ret_raw_hidden_states:
+ x, raw_hidden_states = x
+ if self.scalar_mix:
+ x = self.scalar_mix(x)
+ if self.ret_raw_hidden_states:
+ # noinspection PyUnboundLocalVariable
+ return x, raw_hidden_states
+ return x
+
+ @staticmethod
+ def build_transformer(config, training=True) -> PreTrainedModel:
+ kwargs = {}
+ if config.scalar_mix and config.scalar_mix > 0:
+ kwargs['output_hidden_states'] = True
+ transformer = AutoModel_.from_pretrained(config.transformer, training=training, **kwargs)
+ return transformer
+
+ @staticmethod
+ def build_transformer_tokenizer(config_or_str, use_fast=True, do_basic_tokenize=True) -> PreTrainedTokenizer:
+ if isinstance(config_or_str, str):
+ transformer = config_or_str
+ else:
+ transformer = config_or_str.transformer
+ if use_fast and not do_basic_tokenize:
+ warnings.warn('`do_basic_tokenize=False` might not work when `use_fast=True`')
+ return AutoTokenizer.from_pretrained(transformer, use_fast=use_fast, do_basic_tokenize=do_basic_tokenize)
diff --git a/hanlp/layers/transformers/loader.py b/hanlp/layers/transformers/loader_tf.py
similarity index 97%
rename from hanlp/layers/transformers/loader.py
rename to hanlp/layers/transformers/loader_tf.py
index 8cea1c08c..351797f49 100644
--- a/hanlp/layers/transformers/loader.py
+++ b/hanlp/layers/transformers/loader_tf.py
@@ -9,7 +9,7 @@
from bert import albert_models_tfhub, fetch_tfhub_albert_model, load_stock_weights
from bert.loader_albert import albert_params
-from hanlp.layers.transformers import zh_albert_models_google, bert_models_google
+from hanlp.layers.transformers.tf_imports import zh_albert_models_google, bert_models_google
from hanlp.utils.io_util import get_resource, stdout_redirected, hanlp_home
diff --git a/hanlp/layers/transformers/pt_imports.py b/hanlp/layers/transformers/pt_imports.py
new file mode 100644
index 000000000..dec1c523e
--- /dev/null
+++ b/hanlp/layers/transformers/pt_imports.py
@@ -0,0 +1,26 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-05-09 11:25
+import os
+
+if os.environ.get('USE_TF', None) is None:
+ os.environ["USE_TF"] = 'NO' # saves time loading transformers
+if os.environ.get('TOKENIZERS_PARALLELISM', None) is None:
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+from transformers import BertTokenizer, BertConfig, PretrainedConfig, \
+ AutoConfig, AutoTokenizer, PreTrainedTokenizer, BertTokenizerFast, AlbertConfig, BertModel, AutoModel, \
+ PreTrainedModel, get_linear_schedule_with_warmup, AdamW, AutoModelForSequenceClassification, \
+ AutoModelForTokenClassification, optimization, BartModel
+
+
+class AutoModel_(AutoModel):
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, training=True, **kwargs):
+ if training:
+ return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+ else:
+ if isinstance(pretrained_model_name_or_path, str):
+ return super().from_config(AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs))
+ else:
+ assert not kwargs
+ return super().from_config(pretrained_model_name_or_path)
diff --git a/hanlp/layers/transformers/relative_transformer.py b/hanlp/layers/transformers/relative_transformer.py
new file mode 100644
index 000000000..6a37773d9
--- /dev/null
+++ b/hanlp/layers/transformers/relative_transformer.py
@@ -0,0 +1,329 @@
+# A modified version of the implementation from the following paper:
+# TENER: Adapting Transformer Encoder for Named Entity Recognition
+# Hang Yan, Bocao Deng, Xiaonan Li, Xipeng Qiu
+
+import math
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+from hanlp.common.structure import ConfigTracker
+
+
+class RelativeSinusoidalPositionalEmbedding(nn.Module):
+ """This module produces sinusoidal positional embeddings of any length.
+ Padding symbols are ignored.
+
+ Args:
+ embedding_dim: embedding size of each position
+ padding_idx:
+ Returns:
+
+ """
+
+ def __init__(self, embedding_dim, padding_idx, init_size=1024):
+ super().__init__()
+ self.embedding_dim = embedding_dim
+ self.padding_idx = padding_idx
+ assert init_size % 2 == 0
+ weights = self.get_embedding(
+ init_size + 1,
+ embedding_dim,
+ padding_idx,
+ )
+ self.register_buffer('weights', weights)
+ self.register_buffer('_float_tensor', torch.as_tensor(1))
+
+ def get_embedding(self, num_embeddings, embedding_dim, padding_idx=None):
+ """Build sinusoidal embeddings.
+ This matches the implementation in tensor2tensor, but differs slightly
+ from the description in Section 3.5 of "Attention Is All You Need".
+
+ Args:
+ num_embeddings:
+ embedding_dim:
+ padding_idx: (Default value = None)
+
+ Returns:
+
+ """
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
+ emb = torch.arange(-num_embeddings // 2, num_embeddings // 2, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
+ if embedding_dim % 2 == 1:
+ # zero pad
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
+ if padding_idx is not None:
+ emb[padding_idx, :] = 0
+ self.origin_shift = num_embeddings // 2 + 1
+ return emb
+
+ def forward(self, inputs: Tensor):
+ """Input is expected to be of size [bsz x seqlen].
+
+ Args:
+ inputs: Tensor:
+
+ Returns:
+
+ """
+ bsz, seq_len = inputs.size()
+ max_pos = self.padding_idx + seq_len
+ if max_pos > self.origin_shift:
+ # recompute/expand embeddings if needed
+ weights = self.get_embedding(
+ max_pos * 2,
+ self.embedding_dim,
+ self.padding_idx,
+ )
+ weights = weights.to(self._float_tensor)
+ del self.weights
+ self.origin_shift = weights.size(0) // 2
+ self.register_buffer('weights', weights)
+
+ positions = torch.arange(-seq_len, seq_len).to(inputs.device).long() + self.origin_shift # 2*seq_len
+ embed = self.weights.index_select(0, positions.long()).detach()
+ return embed
+
+
+class RelativeMultiHeadAttn(nn.Module):
+ def __init__(self, in_features, num_heads, dropout, r_w_bias=None, r_r_bias=None, init_seq_length=1024,
+ k_as_x=True):
+ """
+ Args:
+ in_features:
+ num_heads:
+ dropout:
+ r_w_bias: n_head x head_dim or None
+ r_r_bias: n_head x head_dim or None
+ init_seq_length:
+ k_as_x:
+ """
+ super().__init__()
+ self.k_as_x = k_as_x
+ if k_as_x:
+ self.qv_linear = nn.Linear(in_features, in_features * 2, bias=False)
+ else:
+ self.qkv_linear = nn.Linear(in_features, in_features * 3, bias=False)
+ self.n_head = num_heads
+ self.head_dim = in_features // num_heads
+ self.dropout_layer = nn.Dropout(dropout)
+ self.pos_embed = RelativeSinusoidalPositionalEmbedding(self.head_dim, 0, init_seq_length)
+ if r_r_bias is None or r_w_bias is None: # Biases are not shared
+ self.r_r_bias = nn.Parameter(nn.init.xavier_normal_(torch.zeros(num_heads, in_features // num_heads)))
+ self.r_w_bias = nn.Parameter(nn.init.xavier_normal_(torch.zeros(num_heads, in_features // num_heads)))
+ else:
+ self.r_r_bias = r_r_bias # r_r_bias就是v
+ self.r_w_bias = r_w_bias # r_w_bias就是u
+
+ def forward(self, x, mask):
+ """
+
+ Args:
+ x: batch_size x max_len x d_model
+ mask: batch_size x max_len
+
+ Returns:
+
+ """
+
+ batch_size, max_len, d_model = x.size()
+ pos_embed = self.pos_embed(mask) # l x head_dim
+
+ if self.k_as_x:
+ qv = self.qv_linear(x) # batch_size x max_len x d_model2
+ q, v = torch.chunk(qv, chunks=2, dim=-1)
+ k = x.view(batch_size, max_len, self.n_head, -1).transpose(1, 2)
+ else:
+ qkv = self.qkv_linear(x) # batch_size x max_len x d_model3
+ q, k, v = torch.chunk(qkv, chunks=3, dim=-1)
+ k = k.view(batch_size, max_len, self.n_head, -1).transpose(1, 2)
+
+ q = q.view(batch_size, max_len, self.n_head, -1).transpose(1, 2)
+ v = v.view(batch_size, max_len, self.n_head, -1).transpose(1, 2) # b x n x l x d
+
+ rw_head_q = q + self.r_r_bias[:, None]
+ AC = torch.einsum('bnqd,bnkd->bnqk', [rw_head_q, k]) # b x n x l x d, n是head
+
+ D_ = torch.einsum('nd,ld->nl', self.r_w_bias, pos_embed)[None, :, None] # head x 2max_len, 每个head对位置的bias
+ B_ = torch.einsum('bnqd,ld->bnql', q, pos_embed) # bsz x head x max_len x 2max_len,每个query对每个shift的偏移
+ E_ = torch.einsum('bnqd,ld->bnql', k, pos_embed) # bsz x head x max_len x 2max_len, key对relative的bias
+ BD = B_ + D_ # bsz x head x max_len x 2max_len, 要转换为bsz x head x max_len x max_len
+ if self.k_as_x:
+ BD = self._shift(BD)
+ attn = AC + BD
+ else:
+ BDE = self._shift(BD) + self._transpose_shift(E_)
+ attn = AC + BDE
+
+ attn = attn.masked_fill(mask[:, None, None, :].eq(0), float('-inf'))
+
+ attn = F.softmax(attn, dim=-1)
+ attn = self.dropout_layer(attn)
+ v = torch.matmul(attn, v).transpose(1, 2).reshape(batch_size, max_len, d_model) # b x n x l x d
+
+ return v
+
+ def _shift(self, BD):
+ """类似
+ -3 -2 -1 0 1 2
+ -3 -2 -1 0 1 2
+ -3 -2 -1 0 1 2
+ 转换为
+ 0 1 2
+ -1 0 1
+ -2 -1 0
+
+ Args:
+ BD: batch_size x n_head x max_len x 2max_len
+
+ Returns:
+ batch_size x n_head x max_len x max_len
+
+ """
+ bsz, n_head, max_len, _ = BD.size()
+ zero_pad = BD.new_zeros(bsz, n_head, max_len, 1)
+ BD = torch.cat([BD, zero_pad], dim=-1).view(bsz, n_head, -1, max_len) # bsz x n_head x (2max_len+1) x max_len
+ BD = BD.narrow(dim=2, start=0, length=2 * max_len) \
+ .view(bsz, n_head, max_len, -1) # bsz x n_head x 2max_len x max_len
+ BD = BD.narrow(dim=-1, start=max_len, length=max_len)
+ return BD
+
+ def _transpose_shift(self, E):
+ """类似
+ -3 -2 -1 0 1 2
+ -30 -20 -10 00 10 20
+ -300 -200 -100 000 100 200
+
+ 转换为
+ 0 -10 -200
+ 1 00 -100
+ 2 10 000
+
+ Args:
+ E: batch_size x n_head x max_len x 2max_len
+
+ Returns:
+ batch_size x n_head x max_len x max_len
+
+ """
+ bsz, n_head, max_len, _ = E.size()
+ zero_pad = E.new_zeros(bsz, n_head, max_len, 1)
+ # bsz x n_head x -1 x (max_len+1)
+ E = torch.cat([E, zero_pad], dim=-1).view(bsz, n_head, -1, max_len)
+ indice = (torch.arange(max_len) * 2 + 1).to(E.device)
+ E = E.index_select(index=indice, dim=-2).transpose(-1, -2) # bsz x n_head x max_len x max_len
+
+ return E
+
+
+class RelativeTransformerLayer(nn.Module):
+ def __init__(self,
+ in_features,
+ num_heads=4,
+ feedforward_dim=256,
+ dropout=0.2,
+ dropout_attn=None,
+ after_norm=True,
+ k_as_x=True,
+ init_seq_length=1024):
+ super().__init__()
+ if dropout_attn is None:
+ dropout_attn = dropout
+ self.after_norm = after_norm
+ self.norm1 = nn.LayerNorm(in_features)
+ self.norm2 = nn.LayerNorm(in_features)
+ self.self_attn = RelativeMultiHeadAttn(in_features,
+ num_heads,
+ dropout=dropout_attn,
+ init_seq_length=init_seq_length,
+ k_as_x=k_as_x)
+ self.ffn = nn.Sequential(nn.Linear(in_features, feedforward_dim),
+ nn.LeakyReLU(),
+ nn.Dropout(dropout, inplace=True),
+ nn.Linear(feedforward_dim, in_features),
+ nn.Dropout(dropout, inplace=True))
+
+ def forward(self, x, mask):
+ """
+
+ Args:
+ x: batch_size x max_len x hidden_size
+ mask: batch_size x max_len, 为0的地方为pad
+
+ Returns:
+ batch_size x max_len x hidden_size
+
+ """
+ residual = x
+ if not self.after_norm:
+ x = self.norm1(x)
+
+ x = self.self_attn(x, mask)
+ x = x + residual
+ if self.after_norm:
+ x = self.norm1(x)
+ residual = x
+ if not self.after_norm:
+ x = self.norm2(x)
+ x = self.ffn(x)
+ x = residual + x
+ if self.after_norm:
+ x = self.norm2(x)
+ return x
+
+
+class RelativeTransformer(nn.Module):
+ def __init__(self,
+ in_features,
+ num_layers,
+ feedforward_dim,
+ num_heads,
+ dropout,
+ dropout_attn=None,
+ after_norm=True,
+ init_seq_length=1024,
+ k_as_x=True):
+ super().__init__()
+ self.layers = nn.ModuleList([
+ RelativeTransformerLayer(in_features, feedforward_dim, num_heads, dropout, dropout_attn, after_norm,
+ init_seq_length=init_seq_length, k_as_x=k_as_x)
+ for _ in range(num_layers)
+ ])
+
+ def forward(self, x: Tensor, mask: Tensor):
+ """
+
+ Args:
+ x: batch_size x max_len
+ mask: batch_size x max_len. 有value的地方为1
+ x: Tensor:
+ mask: Tensor:
+
+ Returns:
+
+ """
+
+ for layer in self.layers:
+ x = layer(x, mask)
+ return x
+
+
+class RelativeTransformerEncoder(RelativeTransformer, ConfigTracker):
+ def __init__(self,
+ in_features,
+ num_layers=2,
+ num_heads=4,
+ feedforward_dim=256,
+ dropout=0.1,
+ dropout_attn=0.1,
+ after_norm=True,
+ k_as_x=True,
+ ):
+ super().__init__(in_features, num_layers, num_heads, feedforward_dim, dropout, dropout_attn, after_norm)
+ ConfigTracker.__init__(self, locals())
+
+ def get_output_dim(self):
+ return self.config['in_features']
diff --git a/hanlp/layers/transformers/tf_imports.py b/hanlp/layers/transformers/tf_imports.py
new file mode 100644
index 000000000..cd04d22fc
--- /dev/null
+++ b/hanlp/layers/transformers/tf_imports.py
@@ -0,0 +1,16 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-05-08 21:57
+from bert import bert_models_google
+from transformers import BertTokenizer, BertConfig, PretrainedConfig, TFAutoModel, \
+ AutoConfig, AutoTokenizer, PreTrainedTokenizer, TFPreTrainedModel, TFAlbertModel, TFAutoModelWithLMHead, BertTokenizerFast, TFAlbertForMaskedLM, AlbertConfig, TFBertModel
+
+from hanlp_common.constant import HANLP_URL
+
+zh_albert_models_google = {
+ 'albert_base_zh': HANLP_URL + 'embeddings/albert_base_zh.tar.gz', # Provide mirroring
+ 'albert_large_zh': 'https://storage.googleapis.com/albert_models/albert_large_zh.tar.gz',
+ 'albert_xlarge_zh': 'https://storage.googleapis.com/albert_models/albert_xlarge_zh.tar.gz',
+ 'albert_xxlarge_zh': 'https://storage.googleapis.com/albert_models/albert_xxlarge_zh.tar.gz',
+}
+bert_models_google['chinese_L-12_H-768_A-12'] = HANLP_URL + 'embeddings/chinese_L-12_H-768_A-12.zip'
\ No newline at end of file
diff --git a/hanlp/layers/transformers/utils.py b/hanlp/layers/transformers/utils.py
new file mode 100644
index 000000000..462bb625e
--- /dev/null
+++ b/hanlp/layers/transformers/utils.py
@@ -0,0 +1,366 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-06-15 21:22
+from collections import defaultdict
+from typing import Tuple, Union
+
+import torch
+from torch.nn import functional as F
+
+from hanlp.components.parsers.ud import udify_util as util
+from hanlp.layers.transformers.pt_imports import PreTrainedModel, optimization, AdamW, \
+ get_linear_schedule_with_warmup
+
+
+def transformer_encode(transformer: PreTrainedModel,
+ input_ids,
+ attention_mask=None,
+ token_type_ids=None,
+ token_span=None,
+ layer_range: Union[int, Tuple[int, int]] = 0,
+ max_sequence_length=None,
+ average_subwords=False,
+ ret_raw_hidden_states=False):
+ """Run transformer and pool its outputs.
+
+ Args:
+ transformer: A transformer model.
+ input_ids: Indices of subwords.
+ attention_mask: Mask for these subwords.
+ token_type_ids: Type ids for each subword.
+ token_span: The spans of tokens.
+ layer_range: The range of layers to use. Note that the 0-th layer means embedding layer, so the last 3 layers
+ of a 12-layer BERT will be (10, 13).
+ max_sequence_length: The maximum sequence length. Sequence longer than this will be handled by sliding
+ window.
+ average_subwords: ``True`` to average subword representations.
+ ret_raw_hidden_states: ``True`` to return hidden states of each layer.
+
+ Returns:
+ Pooled outputs.
+
+ """
+ if max_sequence_length and input_ids.size(-1) > max_sequence_length:
+ # TODO: split token type ids in transformer_sliding_window if token type ids are not always 1
+ outputs = transformer_sliding_window(transformer, input_ids, max_pieces=max_sequence_length)
+ else:
+ if attention_mask is None:
+ attention_mask = input_ids.ne(0)
+ if transformer.config.output_hidden_states:
+ outputs = transformer(input_ids, attention_mask, token_type_ids)[-1]
+ else:
+ outputs = transformer(input_ids, attention_mask, token_type_ids)[0]
+ if transformer.config.output_hidden_states:
+ if isinstance(layer_range, int):
+ outputs = outputs[layer_range:]
+ else:
+ outputs = outputs[layer_range[0], layer_range[1]]
+ # Slow pick
+ # hs = []
+ # for h in outputs:
+ # hs.append(pick_tensor_for_each_token(h, token_span, average_subwords))
+ # Fast pick
+ if not isinstance(outputs, torch.Tensor):
+ x = torch.stack(outputs)
+ else:
+ x = outputs
+ L, B, T, F = x.size()
+ x = x.flatten(end_dim=1)
+ # tile token_span as x
+ if token_span is not None:
+ token_span = token_span.repeat(L, 1, 1)
+ hs = pick_tensor_for_each_token(x, token_span, average_subwords).view(L, B, -1, F)
+ if ret_raw_hidden_states:
+ return hs, outputs
+ return hs
+ else:
+ if ret_raw_hidden_states:
+ return pick_tensor_for_each_token(outputs, token_span, average_subwords), outputs
+ return pick_tensor_for_each_token(outputs, token_span, average_subwords)
+
+
+def pick_tensor_for_each_token(h, token_span, average_subwords):
+ if token_span is None:
+ return h
+ if average_subwords and token_span.size(-1) > 1:
+ batch_size = h.size(0)
+ h_span = h.gather(1, token_span.view(batch_size, -1).unsqueeze(-1).expand(-1, -1, h.shape[-1]))
+ h_span = h_span.view(batch_size, *token_span.shape[1:], -1)
+ n_sub_tokens = token_span.ne(0)
+ n_sub_tokens[:, 0, 0] = True
+ h_span = (h_span * n_sub_tokens.unsqueeze(-1)).sum(2)
+ n_sub_tokens = n_sub_tokens.sum(-1).unsqueeze(-1)
+ zero_mask = n_sub_tokens == 0
+ if torch.any(zero_mask):
+ n_sub_tokens[zero_mask] = 1 # avoid dividing by zero
+ embed = h_span / n_sub_tokens
+ else:
+ embed = h.gather(1, token_span[:, :, 0].unsqueeze(-1).expand(-1, -1, h.size(-1)))
+ return embed
+
+
+def transformer_sliding_window(transformer: PreTrainedModel,
+ input_ids: torch.LongTensor,
+ input_mask=None,
+ offsets: torch.LongTensor = None,
+ token_type_ids: torch.LongTensor = None,
+ max_pieces=512,
+ start_tokens: int = 1,
+ end_tokens: int = 1,
+ ret_cls=None,
+ ) -> torch.Tensor:
+ """
+
+ Args:
+ transformer:
+ input_ids: torch.LongTensor:
+ input_mask: (Default value = None)
+ offsets: torch.LongTensor: (Default value = None)
+ token_type_ids: torch.LongTensor: (Default value = None)
+ max_pieces: (Default value = 512)
+ start_tokens: int: (Default value = 1)
+ end_tokens: int: (Default value = 1)
+ ret_cls: (Default value = None)
+
+ Returns:
+
+
+ """
+ # pylint: disable=arguments-differ
+ batch_size, full_seq_len = input_ids.size(0), input_ids.size(-1)
+ initial_dims = list(input_ids.shape[:-1])
+
+ # The embedder may receive an input tensor that has a sequence length longer than can
+ # be fit. In that case, we should expect the wordpiece indexer to create padded windows
+ # of length `max_pieces` for us, and have them concatenated into one long sequence.
+ # E.g., "[CLS] I went to the [SEP] [CLS] to the store to [SEP] ..."
+ # We can then split the sequence into sub-sequences of that length, and concatenate them
+ # along the batch dimension so we effectively have one huge batch of partial sentences.
+ # This can then be fed into BERT without any sentence length issues. Keep in mind
+ # that the memory consumption can dramatically increase for large batches with extremely
+ # long sentences.
+ needs_split = full_seq_len > max_pieces
+ if needs_split:
+ input_ids = split_to_sliding_window(input_ids, max_pieces)
+
+ # if token_type_ids is None:
+ # token_type_ids = torch.zeros_like(input_ids)
+ if input_mask is None:
+ input_mask = (input_ids != 0).long()
+
+ # input_ids may have extra dimensions, so we reshape down to 2-d
+ # before calling the BERT model and then reshape back at the end.
+ outputs = transformer(input_ids=util.combine_initial_dims_to_1d_or_2d(input_ids),
+ # token_type_ids=util.combine_initial_dims_to_1d_or_2d(token_type_ids),
+ attention_mask=util.combine_initial_dims_to_1d_or_2d(input_mask))
+ if len(outputs) == 3:
+ all_encoder_layers = outputs.hidden_states
+ all_encoder_layers = torch.stack(all_encoder_layers)
+ elif len(outputs) == 2:
+ all_encoder_layers, _ = outputs[:2]
+ else:
+ all_encoder_layers = outputs[0]
+
+ if needs_split:
+ if ret_cls is not None:
+ cls_mask = input_ids[:, 0] == input_ids[0][0]
+ cls_hidden = all_encoder_layers[:, 0, :]
+ if ret_cls == 'max':
+ cls_hidden[~cls_mask] = -1e20
+ else:
+ cls_hidden[~cls_mask] = 0
+ cls_mask = cls_mask.view(-1, batch_size).transpose(0, 1)
+ cls_hidden = cls_hidden.reshape(cls_mask.size(1), batch_size, -1).transpose(0, 1)
+ if ret_cls == 'max':
+ cls_hidden = cls_hidden.max(1)[0]
+ elif ret_cls == 'raw':
+ return cls_hidden, cls_mask
+ else:
+ cls_hidden = torch.sum(cls_hidden, dim=1)
+ cls_hidden /= torch.sum(cls_mask, dim=1, keepdim=True)
+ return cls_hidden
+ else:
+ recombined_embeddings, select_indices = restore_from_sliding_window(all_encoder_layers, batch_size,
+ max_pieces, full_seq_len, start_tokens,
+ end_tokens)
+
+ initial_dims.append(len(select_indices))
+ else:
+ recombined_embeddings = all_encoder_layers
+
+ # Recombine the outputs of all layers
+ # (layers, batch_size * d1 * ... * dn, sequence_length, embedding_dim)
+ # recombined = torch.cat(combined, dim=2)
+ # input_mask = (recombined_embeddings != 0).long()
+
+ # At this point, mix is (batch_size * d1 * ... * dn, sequence_length, embedding_dim)
+
+ if offsets is None:
+ # Resize to (batch_size, d1, ..., dn, sequence_length, embedding_dim)
+ dims = initial_dims if needs_split else input_ids.size()
+ layers = util.uncombine_initial_dims(recombined_embeddings, dims)
+ else:
+ # offsets is (batch_size, d1, ..., dn, orig_sequence_length)
+ offsets2d = util.combine_initial_dims_to_1d_or_2d(offsets)
+ # now offsets is (batch_size * d1 * ... * dn, orig_sequence_length)
+ range_vector = util.get_range_vector(offsets2d.size(0),
+ device=util.get_device_of(recombined_embeddings)).unsqueeze(1)
+ # selected embeddings is also (batch_size * d1 * ... * dn, orig_sequence_length)
+ selected_embeddings = recombined_embeddings[:, range_vector, offsets2d]
+
+ layers = util.uncombine_initial_dims(selected_embeddings, offsets.size())
+
+ return layers
+
+
+def split_to_sliding_window(input_ids, max_pieces):
+ # Split the flattened list by the window size, `max_pieces`
+ split_input_ids = list(input_ids.split(max_pieces, dim=-1))
+ # We want all sequences to be the same length, so pad the last sequence
+ last_window_size = split_input_ids[-1].size(-1)
+ padding_amount = max_pieces - last_window_size
+ split_input_ids[-1] = F.pad(split_input_ids[-1], pad=[0, padding_amount], value=0)
+ # Now combine the sequences along the batch dimension
+ input_ids = torch.cat(split_input_ids, dim=0)
+ return input_ids
+
+
+def restore_from_sliding_window(all_encoder_layers, batch_size, max_pieces, full_seq_len, start_tokens, end_tokens):
+ # First, unpack the output embeddings into one long sequence again
+ unpacked_embeddings = torch.split(all_encoder_layers, batch_size, dim=-3)
+ unpacked_embeddings = torch.cat(unpacked_embeddings, dim=-2)
+ # Next, select indices of the sequence such that it will result in embeddings representing the original
+ # sentence. To capture maximal context, the indices will be the middle part of each embedded window
+ # sub-sequence (plus any leftover start and final edge windows), e.g.,
+ # 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
+ # "[CLS] I went to the very fine [SEP] [CLS] the very fine store to eat [SEP]"
+ # with max_pieces = 8 should produce max context indices [2, 3, 4, 10, 11, 12] with additional start
+ # and final windows with indices [0, 1] and [14, 15] respectively.
+ # Find the stride as half the max pieces, ignoring the special start and end tokens
+ # Calculate an offset to extract the centermost embeddings of each window
+ stride = (max_pieces - start_tokens - end_tokens) // 2
+ stride_offset = stride // 2 + start_tokens
+ first_window = list(range(stride_offset))
+ max_context_windows = [i for i in range(full_seq_len)
+ if stride_offset - 1 < i % max_pieces < stride_offset + stride]
+ final_window_start = max_context_windows[-1] + 1
+ final_window = list(range(final_window_start, full_seq_len))
+ select_indices = first_window + max_context_windows + final_window
+ select_indices = torch.LongTensor(select_indices).to(unpacked_embeddings.device)
+ recombined_embeddings = unpacked_embeddings.index_select(-2, select_indices)
+ return recombined_embeddings, select_indices
+
+
+def build_optimizer_for_pretrained(model: torch.nn.Module,
+ pretrained: torch.nn.Module,
+ lr=1e-5,
+ weight_decay=0.01,
+ eps=1e-8,
+ transformer_lr=None,
+ transformer_weight_decay=None,
+ no_decay=('bias', 'LayerNorm.bias', 'LayerNorm.weight'),
+ **kwargs):
+ if transformer_lr is None:
+ transformer_lr = lr
+ if transformer_weight_decay is None:
+ transformer_weight_decay = weight_decay
+ params = defaultdict(lambda: defaultdict(list))
+ pretrained = set(pretrained.parameters())
+ if isinstance(no_decay, tuple):
+ def no_decay_fn(name):
+ return any(nd in name for nd in no_decay)
+ else:
+ assert callable(no_decay), 'no_decay has to be callable or a tuple of str'
+ no_decay_fn = no_decay
+ for n, p in model.named_parameters():
+ is_pretrained = 'pretrained' if p in pretrained else 'non_pretrained'
+ is_no_decay = 'no_decay' if no_decay_fn(n) else 'decay'
+ params[is_pretrained][is_no_decay].append(p)
+
+ grouped_parameters = [
+ {'params': params['pretrained']['decay'], 'weight_decay': transformer_weight_decay, 'lr': transformer_lr},
+ {'params': params['pretrained']['no_decay'], 'weight_decay': 0.0, 'lr': transformer_lr},
+ {'params': params['non_pretrained']['decay'], 'weight_decay': weight_decay, 'lr': lr},
+ {'params': params['non_pretrained']['no_decay'], 'weight_decay': 0.0, 'lr': lr},
+ ]
+
+ return optimization.AdamW(
+ grouped_parameters,
+ lr=lr,
+ weight_decay=weight_decay,
+ eps=eps,
+ **kwargs)
+
+
+def build_optimizer_scheduler_with_transformer(model: torch.nn.Module,
+ transformer: torch.nn.Module,
+ lr: float,
+ transformer_lr: float,
+ num_training_steps: int,
+ warmup_steps: Union[float, int],
+ weight_decay: float,
+ adam_epsilon: float,
+ no_decay=('bias', 'LayerNorm.bias', 'LayerNorm.weight')):
+ optimizer = build_optimizer_for_pretrained(model,
+ transformer,
+ lr,
+ weight_decay,
+ eps=adam_epsilon,
+ transformer_lr=transformer_lr,
+ no_decay=no_decay)
+ if isinstance(warmup_steps, float):
+ assert 0 < warmup_steps < 1, 'warmup_steps has to fall in range (0, 1) when it is float.'
+ warmup_steps = num_training_steps * warmup_steps
+ scheduler = optimization.get_linear_schedule_with_warmup(optimizer, warmup_steps, num_training_steps)
+ return optimizer, scheduler
+
+
+def get_optimizers(
+ model: torch.nn.Module,
+ num_training_steps: int,
+ learning_rate=5e-5,
+ adam_epsilon=1e-8,
+ weight_decay=0.0,
+ warmup_steps=0.1,
+) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]:
+ """
+ Modified from https://github.com/huggingface/transformers/blob/7b75aa9fa55bee577e2c7403301ed31103125a35/src/transformers/trainer.py#L232
+ Setup the optimizer and the learning rate scheduler.
+
+ We provide a reasonable default that works well.
+ """
+ if isinstance(warmup_steps, float):
+ assert 0 < warmup_steps < 1
+ warmup_steps = int(num_training_steps * warmup_steps)
+ # Prepare optimizer and schedule (linear warmup and decay)
+ no_decay = ["bias", "LayerNorm.weight"]
+ optimizer_grouped_parameters = [
+ {
+ "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+ "weight_decay": weight_decay,
+ },
+ {
+ "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+ "weight_decay": 0.0,
+ },
+ ]
+ optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=adam_epsilon)
+ scheduler = get_linear_schedule_with_warmup(
+ optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps
+ )
+ return optimizer, scheduler
+
+
+def collect_decay_params(model, weight_decay):
+ no_decay = ["bias", "LayerNorm.weight"]
+ optimizer_grouped_parameters = [
+ {
+ "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+ "weight_decay": weight_decay,
+ },
+ {
+ "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+ "weight_decay": 0.0,
+ },
+ ]
+ return optimizer_grouped_parameters
diff --git a/hanlp/layers/transformers/utils_tf.py b/hanlp/layers/transformers/utils_tf.py
new file mode 100644
index 000000000..e44a9d3d8
--- /dev/null
+++ b/hanlp/layers/transformers/utils_tf.py
@@ -0,0 +1,191 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-12-29 15:32
+import tensorflow as tf
+from hanlp.optimizers.adamw import create_optimizer
+from hanlp.utils.log_util import logger
+
+
+def config_is(config, model='bert'):
+ return model in type(config).__name__.lower()
+
+
+def convert_examples_to_features(
+ words,
+ max_seq_length,
+ tokenizer,
+ labels=None,
+ label_map=None,
+ cls_token_at_end=False,
+ cls_token="[CLS]",
+ cls_token_segment_id=1,
+ sep_token="[SEP]",
+ sep_token_extra=False,
+ pad_on_left=False,
+ pad_token_id=0,
+ pad_token_segment_id=0,
+ pad_token_label_id=0,
+ sequence_a_segment_id=0,
+ mask_padding_with_zero=True,
+ unk_token='[UNK]',
+ do_padding=True
+):
+ """Loads a data file into a list of `InputBatch`s
+ `cls_token_at_end` define the location of the CLS token:
+ - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
+ - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
+ `cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet)
+
+ Args:
+ words:
+ max_seq_length:
+ tokenizer:
+ labels: (Default value = None)
+ label_map: (Default value = None)
+ cls_token_at_end: (Default value = False)
+ cls_token: (Default value = "[CLS]")
+ cls_token_segment_id: (Default value = 1)
+ sep_token: (Default value = "[SEP]")
+ sep_token_extra: (Default value = False)
+ pad_on_left: (Default value = False)
+ pad_token_id: (Default value = 0)
+ pad_token_segment_id: (Default value = 0)
+ pad_token_label_id: (Default value = 0)
+ sequence_a_segment_id: (Default value = 0)
+ mask_padding_with_zero: (Default value = True)
+ unk_token: (Default value = '[UNK]')
+ do_padding: (Default value = True)
+
+ Returns:
+
+ """
+ args = locals()
+ if not labels:
+ labels = words
+ pad_token_label_id = False
+
+ tokens = []
+ label_ids = []
+ for word, label in zip(words, labels):
+ word_tokens = tokenizer.tokenize(word)
+ if not word_tokens:
+ # some wired chars cause the tagger to return empty list
+ word_tokens = [unk_token] * len(word)
+ tokens.extend(word_tokens)
+ # Use the real label id for the first token of the word, and padding ids for the remaining tokens
+ label_ids.extend([label_map[label] if label_map else True] + [pad_token_label_id] * (len(word_tokens) - 1))
+
+ # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa.
+ special_tokens_count = 3 if sep_token_extra else 2
+ if len(tokens) > max_seq_length - special_tokens_count:
+ logger.warning(
+ f'Input tokens {words} exceed the max sequence length of {max_seq_length - special_tokens_count}. '
+ f'The exceeded part will be truncated and ignored. '
+ f'You are recommended to split your long text into several sentences within '
+ f'{max_seq_length - special_tokens_count} tokens beforehand.')
+ tokens = tokens[: (max_seq_length - special_tokens_count)]
+ label_ids = label_ids[: (max_seq_length - special_tokens_count)]
+
+ # The convention in BERT is:
+ # (a) For sequence pairs:
+ # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
+ # token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
+ # (b) For single sequences:
+ # tokens: [CLS] the dog is hairy . [SEP]
+ # token_type_ids: 0 0 0 0 0 0 0
+ #
+ # Where "token_type_ids" are used to indicate whether this is the first
+ # sequence or the second sequence. The embedding vectors for `type=0` and
+ # `type=1` were learned during pre-training and are added to the wordpiece
+ # embedding vector (and position vector). This is not *strictly* necessary
+ # since the [SEP] token unambiguously separates the sequences, but it makes
+ # it easier for the model to learn the concept of sequences.
+ #
+ # For classification tasks, the first vector (corresponding to [CLS]) is
+ # used as as the "sentence vector". Note that this only makes sense because
+ # the entire model is fine-tuned.
+ tokens += [sep_token]
+ label_ids += [pad_token_label_id]
+ if sep_token_extra:
+ # roberta uses an extra separator b/w pairs of sentences
+ tokens += [sep_token]
+ label_ids += [pad_token_label_id]
+ segment_ids = [sequence_a_segment_id] * len(tokens)
+
+ if cls_token_at_end:
+ tokens += [cls_token]
+ label_ids += [pad_token_label_id]
+ segment_ids += [cls_token_segment_id]
+ else:
+ tokens = [cls_token] + tokens
+ label_ids = [pad_token_label_id] + label_ids
+ segment_ids = [cls_token_segment_id] + segment_ids
+
+ input_ids = tokenizer.convert_tokens_to_ids(tokens)
+
+ # The mask has 1 for real tokens and 0 for padding tokens. Only real
+ # tokens are attended to.
+ input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
+
+ if do_padding:
+ # Zero-pad up to the sequence length.
+ padding_length = max_seq_length - len(input_ids)
+ if pad_on_left:
+ input_ids = ([pad_token_id] * padding_length) + input_ids
+ input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
+ segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
+ label_ids = ([pad_token_label_id] * padding_length) + label_ids
+ else:
+ input_ids += [pad_token_id] * padding_length
+ input_mask += [0 if mask_padding_with_zero else 1] * padding_length
+ segment_ids += [pad_token_segment_id] * padding_length
+ label_ids += [pad_token_label_id] * padding_length
+
+ assert len(input_ids) == max_seq_length
+ assert len(input_mask) == max_seq_length
+ assert len(segment_ids) == max_seq_length
+ assert len(label_ids) == max_seq_length, f'failed for:\n {args}'
+ else:
+ assert len(set(len(x) for x in [input_ids, input_mask, segment_ids, label_ids])) == 1
+ return input_ids, input_mask, segment_ids, label_ids
+
+
+def build_adamw_optimizer(config, learning_rate, epsilon, clipnorm, train_steps, use_amp, warmup_steps,
+ weight_decay_rate):
+ opt = create_optimizer(init_lr=learning_rate,
+ epsilon=epsilon,
+ weight_decay_rate=weight_decay_rate,
+ clipnorm=clipnorm,
+ num_train_steps=train_steps, num_warmup_steps=warmup_steps)
+ # opt = tfa.optimizers.AdamW(learning_rate=3e-5, epsilon=1e-08, weight_decay=0.01)
+ # opt = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08)
+ config.optimizer = tf.keras.utils.serialize_keras_object(opt)
+ lr_config = config.optimizer['config']['learning_rate']['config']
+ if 'decay_schedule_fn' in lr_config:
+ lr_config['decay_schedule_fn'] = dict(
+ (k, v) for k, v in lr_config['decay_schedule_fn'].items() if not k.startswith('_'))
+ if use_amp:
+ # loss scaling is currently required when using mixed precision
+ opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, 'dynamic')
+ return opt
+
+
+def adjust_tokens_for_transformers(sentence):
+ """Adjust tokens for BERT
+ See https://github.com/DoodleJZ/HPSG-Neural-Parser/blob/master/src_joint/Zparser.py#L1204
+
+ Args:
+ sentence:
+
+ Returns:
+
+
+ """
+ cleaned_words = []
+ for word in sentence:
+ # word = BERT_TOKEN_MAPPING.get(word, word)
+ if word == "n't" and cleaned_words:
+ cleaned_words[-1] = cleaned_words[-1] + "n"
+ word = "'t"
+ cleaned_words.append(word)
+ return cleaned_words
diff --git a/hanlp/layers/weight_normalization.py b/hanlp/layers/weight_normalization.py
index c96917780..90eeb5ed0 100644
--- a/hanlp/layers/weight_normalization.py
+++ b/hanlp/layers/weight_normalization.py
@@ -25,7 +25,7 @@
class WeightNormalization(tf.keras.layers.Wrapper):
"""This wrapper reparameterizes a layer by decoupling the weight's
magnitude and direction.
-
+
This speeds up convergence by improving the
conditioning of the optimization problem.
Weight Normalization: A Simple Reparameterization to Accelerate
@@ -47,13 +47,18 @@ class WeightNormalization(tf.keras.layers.Wrapper):
tf.keras.layers.Dense(n_classes),
data_init=True)(net)
```
- Arguments:
- layer: a layer instance.
- data_init: If `True` use data dependent variable initialization
+
+ Args:
+ layer: a layer instance
+ data_init: If
+
+ Returns:
+
Raises:
- ValueError: If not initialized with a `Layer` instance.
- ValueError: If `Layer` does not contain a `kernel` of weights
- NotImplementedError: If `data_init` is True and running graph execution
+ ValueError: If not initialized with a
+ ValueError: If
+ NotImplementedError: If
+
"""
def __init__(self, layer, data_init=True, **kwargs):
@@ -64,7 +69,14 @@ def __init__(self, layer, data_init=True, **kwargs):
self.is_rnn = isinstance(self.layer, tf.keras.layers.RNN)
def build(self, input_shape):
- """Build `Layer`"""
+ """Build `Layer`
+
+ Args:
+ input_shape:
+
+ Returns:
+
+ """
input_shape = tf.TensorShape(input_shape)
self.input_spec = tf.keras.layers.InputSpec(
shape=[None] + input_shape[1:])
@@ -114,7 +126,14 @@ def build(self, input_shape):
self.built = True
def call(self, inputs):
- """Call `Layer`"""
+ """Call `Layer`
+
+ Args:
+ inputs:
+
+ Returns:
+
+ """
def _do_nothing():
return tf.identity(self.g)
@@ -144,9 +163,15 @@ def compute_output_shape(self, input_shape):
def _initialize_weights(self, inputs):
"""Initialize weight g.
-
+
The initial value of g could either from the initial value in v,
or by the input value if self.data_init is True.
+
+ Args:
+ inputs:
+
+ Returns:
+
"""
with tf.control_dependencies([
tf.debugging.assert_equal( # pylint: disable=bad-continuation
@@ -170,7 +195,14 @@ def _init_norm(self):
return [g_tensor]
def _data_dep_init(self, inputs):
- """Data dependent initialization."""
+ """Data dependent initialization.
+
+ Args:
+ inputs:
+
+ Returns:
+
+ """
with tf.name_scope('data_dep_init'):
# Generate data dependent init values
x_init = self._naked_clone_layer(inputs)
diff --git a/hanlp/metrics/accuracy.py b/hanlp/metrics/accuracy.py
new file mode 100644
index 000000000..5c839a224
--- /dev/null
+++ b/hanlp/metrics/accuracy.py
@@ -0,0 +1,18 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-06-12 17:56
+from alnlp import metrics
+from hanlp.metrics.metric import Metric
+
+
+class CategoricalAccuracy(metrics.CategoricalAccuracy, Metric):
+ @property
+ def score(self):
+ return self.get_metric()
+
+ def __repr__(self) -> str:
+ return f'Accuracy:{self.score:.2%}'
+
+
+class BooleanAccuracy(metrics.BooleanAccuracy, CategoricalAccuracy):
+ pass
diff --git a/hanlp/metrics/amr/__init__.py b/hanlp/metrics/amr/__init__.py
new file mode 100644
index 000000000..926d85e05
--- /dev/null
+++ b/hanlp/metrics/amr/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-08-24 12:47
\ No newline at end of file
diff --git a/hanlp/metrics/amr/smatch_eval.py b/hanlp/metrics/amr/smatch_eval.py
new file mode 100644
index 000000000..b8096619c
--- /dev/null
+++ b/hanlp/metrics/amr/smatch_eval.py
@@ -0,0 +1,101 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-08-24 12:47
+import os
+import warnings
+from typing import Union
+
+from hanlp.metrics.f1 import F1_
+from hanlp.metrics.mtl import MetricDict
+from hanlp.utils.io_util import get_resource, run_cmd, pushd
+from hanlp.utils.log_util import flash
+
+_SMATCH_SCRIPT = 'https://github.com/ChunchuanLv/amr-evaluation-tool-enhanced/archive/master.zip#evaluation.sh'
+_FAST_SMATCH_SCRIPT = 'https://github.com/jcyk/AMR-gs/archive/master.zip#tools/fast_smatch/compute_smatch.sh'
+
+
+class SmatchScores(MetricDict):
+ @property
+ def score(self):
+ return self['Smatch'].score
+
+
+def smatch_eval(pred, gold, use_fast=False) -> Union[SmatchScores, F1_]:
+ script = get_resource(_FAST_SMATCH_SCRIPT if use_fast else _SMATCH_SCRIPT)
+ home = os.path.dirname(script)
+ pred = os.path.realpath(pred)
+ gold = os.path.realpath(gold)
+ with pushd(home):
+ flash('Running evaluation script [blink][yellow]...[/yellow][/blink]')
+ cmd = f'bash {script} {pred} {gold}'
+ text = run_cmd(cmd)
+ flash('')
+ return format_fast_scores(text) if use_fast else format_official_scores(text)
+
+
+def post_process(pred, amr_version):
+ pred = os.path.realpath(pred)
+ utils_tar_gz = get_amr_utils(amr_version)
+ util_dir = get_resource(utils_tar_gz)
+ stog_home = get_resource('https://github.com/jcyk/AMR-gs/archive/master.zip')
+ with pushd(stog_home):
+ run_cmd(
+ f'python3 -u -m stog.data.dataset_readers.amr_parsing.postprocess.postprocess '
+ f'--amr_path {pred} --util_dir {util_dir} --v 2')
+ return pred + '.post'
+
+
+def get_amr_utils(amr_version):
+ if amr_version == '1.0':
+ utils_tar_gz = 'https://www.cs.jhu.edu/~s.zhang/data/AMR/amr_1.0_utils.tar.gz'
+ elif amr_version == '2.0':
+ utils_tar_gz = 'https://www.cs.jhu.edu/~s.zhang/data/AMR/amr_2.0_utils.tar.gz'
+ elif amr_version == '3.0':
+ utils_tar_gz = 'https://od.hankcs.com/research/amr2020/amr_3.0_utils.tgz'
+ else:
+ raise ValueError(f'Unsupported AMR version {amr_version}')
+ return utils_tar_gz
+
+
+def format_official_scores(text: str):
+ # Smatch -> P: 0.136, R: 0.107, F: 0.120
+ # Unlabeled -> P: 0.229, R: 0.180, F: 0.202
+ # No WSD -> P: 0.137, R: 0.108, F: 0.120
+ # Non_sense_frames -> P: 0.008, R: 0.008, F: 0.008
+ # Wikification -> P: 0.000, R: 0.000, F: 0.000
+ # Named Ent. -> P: 0.222, R: 0.092, F: 0.130
+ # Negations -> P: 0.000, R: 0.000, F: 0.000
+ # IgnoreVars -> P: 0.005, R: 0.003, F: 0.003
+ # Concepts -> P: 0.075, R: 0.036, F: 0.049
+ # Frames -> P: 0.007, R: 0.007, F: 0.007
+ # Reentrancies -> P: 0.113, R: 0.060, F: 0.079
+ # SRL -> P: 0.145, R: 0.104, F: 0.121
+ scores = SmatchScores()
+ for line in text.split('\n'):
+ line = line.strip()
+ if not line:
+ continue
+ name, vs = line.split(' -> ')
+ try:
+ p, r, f = [float(x.split(': ')[-1]) for x in vs.split(', ')]
+ except ValueError:
+ warnings.warn(f'Failed to parse results from smatch: {line}')
+ p, r, f = float("nan"), float("nan"), float("nan")
+ scores[name] = F1_(p, r, f)
+ return scores
+
+
+def format_fast_scores(text: str):
+ # using fast smatch
+ # Precision: 0.137
+ # Recall: 0.108
+ # Document F-score: 0.121
+ scores = []
+ for line in text.split('\n'):
+ line = line.strip()
+ if not line or ':' not in line:
+ continue
+ name, score = line.split(': ')
+ scores.append(float(score))
+ assert len(scores) == 3
+ return F1_(*scores)
diff --git a/hanlp/metrics/chunking/binary_chunking_f1.py b/hanlp/metrics/chunking/binary_chunking_f1.py
new file mode 100644
index 000000000..7c708b082
--- /dev/null
+++ b/hanlp/metrics/chunking/binary_chunking_f1.py
@@ -0,0 +1,33 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-08-02 14:27
+from collections import defaultdict
+from typing import List, Union
+
+import torch
+
+from hanlp.metrics.f1 import F1
+
+
+class BinaryChunkingF1(F1):
+ def __call__(self, pred_tags: torch.LongTensor, gold_tags: torch.LongTensor, lens: List[int] = None):
+ if lens is None:
+ lens = [gold_tags.size(1)] * gold_tags.size(0)
+ self.update(self.decode_spans(pred_tags, lens), self.decode_spans(gold_tags, lens))
+
+ def update(self, pred_tags, gold_tags):
+ for pred, gold in zip(pred_tags, gold_tags):
+ super().__call__(set(pred), set(gold))
+
+ @staticmethod
+ def decode_spans(pred_tags: torch.LongTensor, lens: Union[List[int], torch.LongTensor]):
+ if isinstance(lens, torch.Tensor):
+ lens = lens.tolist()
+ batch_pred = defaultdict(list)
+ for batch, offset in pred_tags.nonzero(as_tuple=False).tolist():
+ batch_pred[batch].append(offset)
+ batch_pred_spans = [[(0, l)] for l in lens]
+ for batch, offsets in batch_pred.items():
+ l = lens[batch]
+ batch_pred_spans[batch] = list(zip(offsets, offsets[1:] + [l]))
+ return batch_pred_spans
diff --git a/hanlp/metrics/chunking/bmes.py b/hanlp/metrics/chunking/bmes.py
index 3b31b0207..9b67b93d4 100644
--- a/hanlp/metrics/chunking/bmes.py
+++ b/hanlp/metrics/chunking/bmes.py
@@ -2,14 +2,14 @@
# Author: hankcs
# Date: 2019-09-14 21:55
-from hanlp.common.vocab import Vocab
-from hanlp.metrics.chunking.f1 import ChunkingF1
+from hanlp.common.vocab_tf import VocabTF
+from hanlp.metrics.chunking.chunking_f1_tf import ChunkingF1_TF
from hanlp.metrics.chunking.sequence_labeling import get_entities
-class BMES_F1(ChunkingF1):
+class BMES_F1_TF(ChunkingF1_TF):
- def __init__(self, tag_vocab: Vocab, from_logits=True, suffix=False, name='f1', dtype=None, **kwargs):
+ def __init__(self, tag_vocab: VocabTF, from_logits=True, suffix=False, name='f1', dtype=None, **kwargs):
super().__init__(tag_vocab, from_logits, name, dtype, **kwargs)
self.nb_correct = 0
self.nb_pred = 0
diff --git a/hanlp/metrics/chunking/chunking_f1.py b/hanlp/metrics/chunking/chunking_f1.py
new file mode 100644
index 000000000..75003b265
--- /dev/null
+++ b/hanlp/metrics/chunking/chunking_f1.py
@@ -0,0 +1,19 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-06-11 22:14
+from typing import List
+
+from hanlp.metrics.chunking.sequence_labeling import get_entities
+from hanlp.metrics.f1 import F1
+from hanlp.metrics.metric import Metric
+
+
+class ChunkingF1(F1):
+
+ def __call__(self, pred_tags: List[List[str]], gold_tags: List[List[str]]):
+ for p, g in zip(pred_tags, gold_tags):
+ pred = set(get_entities(p))
+ gold = set(get_entities(g))
+ self.nb_pred += len(pred)
+ self.nb_true += len(gold)
+ self.nb_correct += len(pred & gold)
diff --git a/hanlp/metrics/chunking/f1.py b/hanlp/metrics/chunking/chunking_f1_tf.py
similarity index 92%
rename from hanlp/metrics/chunking/f1.py
rename to hanlp/metrics/chunking/chunking_f1_tf.py
index 202303a00..dfbdda80e 100644
--- a/hanlp/metrics/chunking/f1.py
+++ b/hanlp/metrics/chunking/chunking_f1_tf.py
@@ -5,12 +5,12 @@
import tensorflow as tf
-from hanlp.common.vocab import Vocab
+from hanlp.common.vocab_tf import VocabTF
-class ChunkingF1(tf.keras.metrics.Metric, ABC):
+class ChunkingF1_TF(tf.keras.metrics.Metric, ABC):
- def __init__(self, tag_vocab: Vocab, from_logits=True, name='f1', dtype=None, **kwargs):
+ def __init__(self, tag_vocab: VocabTF, from_logits=True, name='f1', dtype=None, **kwargs):
super().__init__(name, dtype, dynamic=True, **kwargs)
self.tag_vocab = tag_vocab
self.from_logits = from_logits
diff --git a/hanlp/metrics/chunking/conlleval.py b/hanlp/metrics/chunking/conlleval.py
old mode 100755
new mode 100644
index 8a5299306..9b63fe282
--- a/hanlp/metrics/chunking/conlleval.py
+++ b/hanlp/metrics/chunking/conlleval.py
@@ -10,10 +10,13 @@
# - raw tags (-r argument) not supported
import io
import sys
-import re
from collections import defaultdict, namedtuple
-from typing import Tuple, Union
+from typing import Tuple, Union, List
+
+from alnlp.metrics.span_utils import bio_tags_to_spans
+
+from hanlp.metrics.metric import Metric
ANY_SPACE = ''
@@ -52,17 +55,30 @@ def reset_state(self):
state.clear()
-class CoNLLEval(object):
+class SpanF1(Metric):
- def __init__(self) -> None:
+ def __init__(self, label_encoding='IOBES') -> None:
super().__init__()
+ self.label_encoding = label_encoding
+ self.count = EvalCounts()
+
+ def reset(self):
self.count = EvalCounts()
+ @property
+ def score(self):
+ return self.result(False, False).fscore
+
def reset_state(self):
self.count.reset_state()
- def update_state(self, true_seqs, pred_seqs):
- count = evaluate(true_seqs, pred_seqs)
+ def update_state(self, true_seqs: List[str], pred_seqs: List[str]):
+ if self.label_encoding == 'IOBES':
+ count = evaluate_iobes(true_seqs, pred_seqs)
+ elif self.label_encoding in ['IOB2', 'BIO']:
+ count = evaluate_iob2(true_seqs, pred_seqs)
+ else:
+ raise ValueError(f'Unrecognized label encoding {self.label_encoding}')
self.count.correct_chunk += count.correct_chunk
self.count.correct_tags += count.correct_tags
self.count.found_correct += count.found_correct
@@ -72,6 +88,10 @@ def update_state(self, true_seqs, pred_seqs):
for k, v in n.items():
s[k] = s.get(k, 0) + v
+ def batch_update_state(self, true_seqs: List[List[str]], pred_seqs: List[List[str]]):
+ for t, p in zip(true_seqs, pred_seqs):
+ self.update_state(t, p)
+
def result(self, full=True, verbose=True) -> Union[Tuple[Metrics, dict, str], Metrics]:
if full:
out = io.StringIO()
@@ -85,6 +105,14 @@ def result(self, full=True, verbose=True) -> Union[Tuple[Metrics, dict, str], Me
overall, _ = metrics(self.count)
return overall
+ # torch convention: put pred before gold
+ def __call__(self, pred_seqs: List[List[str]], true_seqs: List[List[str]]):
+ return self.batch_update_state(true_seqs, pred_seqs)
+
+ def __repr__(self) -> str:
+ result = self.result(False, False)
+ return f"P: {result.prec:.2%} R: {result.rec:.2%} F: {result.fscore:.2%}"
+
def parse_args(argv):
import argparse
@@ -104,18 +132,23 @@ def parse_args(argv):
def split_tag(chunk_tag):
- """
- split chunk tag into IOBES prefix and chunk_type
+ """split chunk tag into IOBES prefix and chunk_type
e.g.
B-PER -> (B, PER)
O -> (O, None)
+
+ Args:
+ chunk_tag:
+
+ Returns:
+
"""
if chunk_tag == 'O':
return ('O', None)
return chunk_tag.split('-', maxsplit=1)
-def evaluate(true_seqs, pred_seqs):
+def evaluate_iobes(true_seqs, pred_seqs):
counts = EvalCounts()
in_correct = False # currently processed chunks is correct until now
last_correct = 'O' # previous chunk tag in corpus
@@ -171,6 +204,16 @@ def evaluate(true_seqs, pred_seqs):
return counts
+def evaluate_iob2(true_seqs, pred_seqs):
+ counts = EvalCounts()
+ gold = set(bio_tags_to_spans(true_seqs))
+ pred = set(bio_tags_to_spans(pred_seqs))
+ counts.correct_chunk = len(gold & pred)
+ counts.found_guessed = len(pred)
+ counts.found_correct = len(gold)
+ return counts
+
+
def uniq(iterable):
seen = set()
return [i for i in iterable if not (i in seen or seen.add(i))]
@@ -178,16 +221,24 @@ def uniq(iterable):
def calculate_metrics(correct, guessed, total):
tp, fp, fn = correct, guessed - correct, total - correct
- p = 0 if tp + fp == 0 else 1. * tp / (tp + fp)
- r = 0 if tp + fn == 0 else 1. * tp / (tp + fn)
- f = 0 if p + r == 0 else 2 * p * r / (p + r)
+ p = 0. if tp + fp == 0 else 1. * tp / (tp + fp)
+ r = 0. if tp + fn == 0 else 1. * tp / (tp + fn)
+ f = 0. if p + r == 0 else 2 * p * r / (p + r)
return Metrics(tp, fp, fn, p, r, f)
def calc_metrics(tp, p, t, percent=True):
- """
- compute overall precision, recall and FB1 (default values are 0.0)
+ """compute overall precision, recall and FB1 (default values are 0.0)
if percent is True, return 100 * original decimal value
+
+ Args:
+ tp:
+ p:
+ t:
+ percent: (Default value = True)
+
+ Returns:
+
"""
precision = tp / p if p else 0
recall = tp / t if t else 0
@@ -282,10 +333,10 @@ def main(argv):
args = parse_args(argv[1:])
if args.file is None:
- counts = evaluate(sys.stdin, args)
+ counts = evaluate_iobes(sys.stdin, args)
else:
with open(args.file, encoding='utf-8') as f:
- counts = evaluate(f, args)
+ counts = evaluate_iobes(f, args)
report(counts)
diff --git a/hanlp/metrics/chunking/cws_eval.py b/hanlp/metrics/chunking/cws_eval.py
new file mode 100644
index 000000000..1b610f7f0
--- /dev/null
+++ b/hanlp/metrics/chunking/cws_eval.py
@@ -0,0 +1,81 @@
+# -*- coding:utf-8 -*-
+# Author:hankcs
+# Date: 2018-06-02 22:53
+# 《自然语言处理入门》2.9 准确率评测
+# 配套书籍:http://nlp.hankcs.com/book.php
+# 讨论答疑:https://bbs.hankcs.com/
+import re
+
+
+def to_region(segmentation: str) -> list:
+ """将分词结果转换为区间
+
+ Args:
+ segmentation: 商品 和 服务
+ segmentation: str:
+
+ Returns:
+ 0, 2), (2, 3), (3, 5)]
+
+ """
+ region = []
+ start = 0
+ for word in re.compile("\\s+").split(segmentation.strip()):
+ end = start + len(word)
+ region.append((start, end))
+ start = end
+ return region
+
+
+def evaluate(gold: str, pred: str, dic: dict = None) -> tuple:
+ """计算P、R、F1
+
+ Args:
+ gold: 标准答案文件,比如“商品 和 服务”
+ pred: 分词结果文件,比如“商品 和服 务”
+ dic: 词典
+ gold: str:
+ pred: str:
+ dic: dict: (Default value = None)
+
+ Returns:
+ P, R, F1, OOV_R, IV_R)
+
+ """
+ A_size, B_size, A_cap_B_size, OOV, IV, OOV_R, IV_R = 0, 0, 0, 0, 0, 0, 0
+ with open(gold, encoding='utf-8') as gd, open(pred, encoding='utf-8') as pd:
+ for g, p in zip(gd, pd):
+ A, B = set(to_region(g)), set(to_region(p))
+ A_size += len(A)
+ B_size += len(B)
+ A_cap_B_size += len(A & B)
+ text = re.sub("\\s+", "", g)
+ if dic:
+ for (start, end) in A:
+ word = text[start: end]
+ if word in dic:
+ IV += 1
+ else:
+ OOV += 1
+
+ for (start, end) in A & B:
+ word = text[start: end]
+ if word in dic:
+ IV_R += 1
+ else:
+ OOV_R += 1
+ p, r = safe_division(A_cap_B_size, B_size), safe_division(A_cap_B_size, A_size)
+ return p, r, safe_division(2 * p * r, (p + r)), safe_division(OOV_R, OOV), safe_division(IV_R, IV)
+
+
+def build_dic_from_file(path):
+ dic = set()
+ with open(path, encoding='utf-8') as gd:
+ for g in gd:
+ for word in re.compile("\\s+").split(g.strip()):
+ dic.add(word)
+ return dic
+
+
+def safe_division(n, d):
+ return n / d if d else float('nan') if n else 0.
diff --git a/hanlp/metrics/chunking/iobes.py b/hanlp/metrics/chunking/iobes_tf.py
similarity index 67%
rename from hanlp/metrics/chunking/iobes.py
rename to hanlp/metrics/chunking/iobes_tf.py
index 48edb0b6b..a091ebb0e 100644
--- a/hanlp/metrics/chunking/iobes.py
+++ b/hanlp/metrics/chunking/iobes_tf.py
@@ -2,16 +2,16 @@
# Author: hankcs
# Date: 2019-09-14 21:55
-from hanlp.common.vocab import Vocab
-from hanlp.metrics.chunking.conlleval import CoNLLEval
-from hanlp.metrics.chunking.f1 import ChunkingF1
+from hanlp.common.vocab_tf import VocabTF
+from hanlp.metrics.chunking.conlleval import SpanF1
+from hanlp.metrics.chunking.chunking_f1_tf import ChunkingF1_TF
-class IOBES_F1(ChunkingF1):
+class IOBES_F1_TF(ChunkingF1_TF):
- def __init__(self, tag_vocab: Vocab, from_logits=True, name='f1', dtype=None, **kwargs):
+ def __init__(self, tag_vocab: VocabTF, from_logits=True, name='f1', dtype=None, **kwargs):
super().__init__(tag_vocab, from_logits, name, dtype, **kwargs)
- self.state = CoNLLEval()
+ self.state = SpanF1()
def update_tags(self, true_tags, pred_tags):
# true_tags = list(itertools.chain.from_iterable(true_tags))
diff --git a/hanlp/metrics/chunking/sequence_labeling.py b/hanlp/metrics/chunking/sequence_labeling.py
index 1bbbe5314..6beae5100 100644
--- a/hanlp/metrics/chunking/sequence_labeling.py
+++ b/hanlp/metrics/chunking/sequence_labeling.py
@@ -1,3 +1,24 @@
+# MIT License
+#
+# Copyright (c) 2018 chakki
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
"""Metrics to assess performance on sequence labeling task given prediction
Functions named as ``*_score`` return a scalar value to maximize: the higher
the better
@@ -6,8 +27,6 @@
from collections import defaultdict
import numpy as np
-from hanlp.metrics.chunking import conlleval
-
def iobes_to_span(words, tags):
delimiter = ' '
@@ -23,13 +42,14 @@ def get_entities(seq, suffix=False):
"""Gets entities from sequence.
Args:
- seq (list): sequence of labels.
+ seq(list): sequence of labels.
+ suffix: (Default value = False)
Returns:
- list: list of (chunk_type, chunk_start, chunk_end).
+ list: list of (chunk_type, chunk_start, chunk_end).
+ Example:
- Example:
- >>> from seqeval.metrics.sequence_labeling import get_entities
+ >>> from seqeval.metrics.sequence_labeling import get_entities
>>> seq = ['B-PER', 'I-PER', 'O', 'B-LOC']
>>> get_entities(seq)
[('PER', 0, 2), ('LOC', 3, 4)]
@@ -43,15 +63,12 @@ def get_entities(seq, suffix=False):
begin_offset = 0
chunks = []
for i, chunk in enumerate(seq + ['O']):
- cells = chunk.split('-')
if suffix:
tag = chunk[-1]
- type_ = cells[0]
+ type_ = chunk[:-2]
else:
tag = chunk[0]
- type_ = cells[-1]
- if len(cells) == 1:
- type_ = ''
+ type_ = chunk[2:]
if end_of_chunk(prev_tag, tag, prev_type, type_):
chunks.append((prev_type, begin_offset, i))
@@ -67,13 +84,14 @@ def end_of_chunk(prev_tag, tag, prev_type, type_):
"""Checks if a chunk ended between the previous and current word.
Args:
- prev_tag: previous chunk tag.
- tag: current chunk tag.
- prev_type: previous type.
- type_: current type.
+ prev_tag: previous chunk tag.
+ tag: current chunk tag.
+ prev_type: previous type.
+ type_: current type.
Returns:
- chunk_end: boolean.
+ chunk_end: boolean.
+
"""
chunk_end = False
@@ -97,13 +115,14 @@ def start_of_chunk(prev_tag, tag, prev_type, type_):
"""Checks if a chunk started between the previous and current word.
Args:
- prev_tag: previous chunk tag.
- tag: current chunk tag.
- prev_type: previous type.
- type_: current type.
+ prev_tag: previous chunk tag.
+ tag: current chunk tag.
+ prev_type: previous type.
+ type_: current type.
Returns:
- chunk_start: boolean.
+ chunk_start: boolean.
+
"""
chunk_start = False
@@ -125,23 +144,25 @@ def start_of_chunk(prev_tag, tag, prev_type, type_):
def f1_score(y_true, y_pred, average='micro', suffix=False):
"""Compute the F1 score.
-
+
The F1 score can be interpreted as a weighted average of the precision and
recall, where an F1 score reaches its best value at 1 and worst score at 0.
The relative contribution of precision and recall to the F1 score are
equal. The formula for the F1 score is::
-
+
F1 = 2 * (precision * recall) / (precision + recall)
Args:
- y_true : 2d array. Ground truth (correct) target values.
- y_pred : 2d array. Estimated targets as returned by a tagger.
+ y_true: 2d array. Ground truth (correct) target values.
+ y_pred: 2d array. Estimated targets as returned by a tagger.
+ average: (Default value = 'micro')
+ suffix: (Default value = False)
Returns:
- score : float.
+ score: float.
+ Example:
- Example:
- >>> from seqeval.metrics import f1_score
+ >>> from seqeval.metrics import f1_score
>>> y_true = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
>>> y_pred = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
>>> f1_score(y_true, y_pred)
@@ -163,20 +184,20 @@ def f1_score(y_true, y_pred, average='micro', suffix=False):
def accuracy_score(y_true, y_pred):
"""Accuracy classification score.
-
+
In multilabel classification, this function computes subset accuracy:
the set of labels predicted for a sample must *exactly* match the
corresponding set of labels in y_true.
Args:
- y_true : 2d array. Ground truth (correct) target values.
- y_pred : 2d array. Estimated targets as returned by a tagger.
+ y_true: 2d array. Ground truth (correct) target values.
+ y_pred: 2d array. Estimated targets as returned by a tagger.
Returns:
- score : float.
+ score: float.
+ Example:
- Example:
- >>> from seqeval.metrics import accuracy_score
+ >>> from seqeval.metrics import accuracy_score
>>> y_true = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
>>> y_pred = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
>>> accuracy_score(y_true, y_pred)
@@ -196,22 +217,24 @@ def accuracy_score(y_true, y_pred):
def precision_score(y_true, y_pred, average='micro', suffix=False):
"""Compute the precision.
-
+
The precision is the ratio ``tp / (tp + fp)`` where ``tp`` is the number of
true positives and ``fp`` the number of false positives. The precision is
intuitively the ability of the classifier not to label as positive a sample.
-
+
The best value is 1 and the worst value is 0.
Args:
- y_true : 2d array. Ground truth (correct) target values.
- y_pred : 2d array. Estimated targets as returned by a tagger.
+ y_true: 2d array. Ground truth (correct) target values.
+ y_pred: 2d array. Estimated targets as returned by a tagger.
+ average: (Default value = 'micro')
+ suffix: (Default value = False)
Returns:
- score : float.
+ score: float.
+ Example:
- Example:
- >>> from seqeval.metrics import precision_score
+ >>> from seqeval.metrics import precision_score
>>> y_true = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
>>> y_pred = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
>>> precision_score(y_true, y_pred)
@@ -230,22 +253,24 @@ def precision_score(y_true, y_pred, average='micro', suffix=False):
def recall_score(y_true, y_pred, average='micro', suffix=False):
"""Compute the recall.
-
+
The recall is the ratio ``tp / (tp + fn)`` where ``tp`` is the number of
true positives and ``fn`` the number of false negatives. The recall is
intuitively the ability of the classifier to find all the positive samples.
-
+
The best value is 1 and the worst value is 0.
Args:
- y_true : 2d array. Ground truth (correct) target values.
- y_pred : 2d array. Estimated targets as returned by a tagger.
+ y_true: 2d array. Ground truth (correct) target values.
+ y_pred: 2d array. Estimated targets as returned by a tagger.
+ average: (Default value = 'micro')
+ suffix: (Default value = False)
Returns:
- score : float.
+ score: float.
+ Example:
- Example:
- >>> from seqeval.metrics import recall_score
+ >>> from seqeval.metrics import recall_score
>>> y_true = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
>>> y_pred = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
>>> recall_score(y_true, y_pred)
@@ -263,18 +288,17 @@ def recall_score(y_true, y_pred, average='micro', suffix=False):
def performance_measure(y_true, y_pred):
- """
- Compute the performance metrics: TP, FP, FN, TN
+ """Compute the performance metrics: TP, FP, FN, TN
Args:
- y_true : 2d array. Ground truth (correct) target values.
- y_pred : 2d array. Estimated targets as returned by a tagger.
+ y_true: 2d array. Ground truth (correct) target values.
+ y_pred: 2d array. Estimated targets as returned by a tagger.
Returns:
- performance_dict : dict
+ performance_dict: dict
+ Example:
- Example:
- >>> from seqeval.metrics import performance_measure
+ >>> from seqeval.metrics import performance_measure
>>> y_true = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'O', 'B-ORG'], ['B-PER', 'I-PER', 'O']]
>>> y_pred = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O', 'O'], ['B-PER', 'I-PER', 'O']]
>>> performance_measure(y_true, y_pred)
@@ -299,15 +323,16 @@ def classification_report(y_true, y_pred, digits=2, suffix=False):
"""Build a text report showing the main classification metrics.
Args:
- y_true : 2d array. Ground truth (correct) target values.
- y_pred : 2d array. Estimated targets as returned by a classifier.
- digits : int. Number of digits for formatting output floating point values.
+ y_true: 2d array. Ground truth (correct) target values.
+ y_pred: 2d array. Estimated targets as returned by a classifier.
+ digits: int. Number of digits for formatting output floating point values. (Default value = 2)
+ suffix: (Default value = False)
Returns:
- report : string. Text summary of the precision, recall, F1 score for each class.
+ report: string. Text summary of the precision, recall, F1 score for each class.
+ Examples:
- Examples:
- >>> from seqeval.metrics import classification_report
+ >>> from seqeval.metrics import classification_report
>>> y_true = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
>>> y_pred = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
>>> print(classification_report(y_true, y_pred))
diff --git a/hanlp/metrics/f1.py b/hanlp/metrics/f1.py
new file mode 100644
index 000000000..06f5fa373
--- /dev/null
+++ b/hanlp/metrics/f1.py
@@ -0,0 +1,64 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-07-10 14:55
+from abc import ABC
+
+from hanlp.metrics.metric import Metric
+
+
+class F1(Metric, ABC):
+ def __init__(self, nb_pred=0, nb_true=0, nb_correct=0) -> None:
+ super().__init__()
+ self.nb_correct = nb_correct
+ self.nb_pred = nb_pred
+ self.nb_true = nb_true
+
+ def __repr__(self) -> str:
+ p, r, f = self.prf
+ return f"P: {p:.2%} R: {r:.2%} F1: {f:.2%}"
+
+ @property
+ def prf(self):
+ nb_correct = self.nb_correct
+ nb_pred = self.nb_pred
+ nb_true = self.nb_true
+ p = nb_correct / nb_pred if nb_pred > 0 else .0
+ r = nb_correct / nb_true if nb_true > 0 else .0
+ f = 2 * p * r / (p + r) if p + r > 0 else .0
+ return p, r, f
+
+ @property
+ def score(self):
+ return self.prf[-1]
+
+ def reset(self):
+ self.nb_correct = 0
+ self.nb_pred = 0
+ self.nb_true = 0
+
+ def __call__(self, pred: set, gold: set):
+ self.nb_correct += len(pred & gold)
+ self.nb_pred += len(pred)
+ self.nb_true += len(gold)
+
+
+class F1_(Metric):
+ def __init__(self, p, r, f) -> None:
+ super().__init__()
+ self.f = f
+ self.r = r
+ self.p = p
+
+ @property
+ def score(self):
+ return self.f
+
+ def __call__(self, pred, gold):
+ raise NotImplementedError()
+
+ def reset(self):
+ self.f = self.r = self.p = 0
+
+ def __repr__(self) -> str:
+ p, r, f = self.p, self.r, self.f
+ return f"P: {p:.2%} R: {r:.2%} F1: {f:.2%}"
diff --git a/hanlp/metrics/metric.py b/hanlp/metrics/metric.py
new file mode 100644
index 000000000..e90e5cf5b
--- /dev/null
+++ b/hanlp/metrics/metric.py
@@ -0,0 +1,44 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-06-03 11:35
+from abc import ABC, abstractmethod
+
+
+class Metric(ABC):
+
+ def __lt__(self, other):
+ return self.score < other
+
+ def __le__(self, other):
+ return self.score <= other
+
+ def __eq__(self, other):
+ return self.score == other
+
+ def __ge__(self, other):
+ return self.score >= other
+
+ def __gt__(self, other):
+ return self.score > other
+
+ def __ne__(self, other):
+ return self.score != other
+
+ @property
+ @abstractmethod
+ def score(self):
+ pass
+
+ @abstractmethod
+ def __call__(self, pred, gold, mask=None):
+ pass
+
+ def __repr__(self) -> str:
+ return f'{self.score}:.4f'
+
+ def __float__(self):
+ return self.score
+
+ @abstractmethod
+ def reset(self):
+ pass
diff --git a/hanlp/metrics/mtl.py b/hanlp/metrics/mtl.py
new file mode 100644
index 000000000..93a03d100
--- /dev/null
+++ b/hanlp/metrics/mtl.py
@@ -0,0 +1,45 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-08-03 00:16
+from hanlp.metrics.metric import Metric
+
+
+class MetricDict(Metric, dict):
+ _COLORS = ["magenta", "cyan", "green", "yellow"]
+
+ @property
+ def score(self):
+ return sum(float(x) for x in self.values()) / len(self)
+
+ def __call__(self, pred, gold):
+ for metric in self.values():
+ metric(pred, gold)
+
+ def reset(self):
+ for metric in self.values():
+ metric.reset()
+
+ def __repr__(self) -> str:
+ return ' '.join(f'({k} {v})' for k, v in self.items())
+
+ def cstr(self, idx=None, level=0) -> str:
+ if idx is None:
+ idx = [0]
+ prefix = ''
+ for _, (k, v) in enumerate(self.items()):
+ color = self._COLORS[idx[0] % len(self._COLORS)]
+ idx[0] += 1
+ child_is_dict = isinstance(v, MetricDict)
+ _level = min(level, 2)
+ # if level != 0 and not child_is_dict:
+ # _level = 2
+ lb = '{[('
+ rb = '}])'
+ k = f'[bold][underline]{k}[/underline][/bold]'
+ prefix += f'[{color}]{lb[_level]}{k} [/{color}]'
+ if child_is_dict:
+ prefix += v.cstr(idx, level + 1)
+ else:
+ prefix += f'[{color}]{v}[/{color}]'
+ prefix += f'[{color}]{rb[_level]}[/{color}]'
+ return prefix
diff --git a/hanlp/metrics/parsing/attachmentscore.py b/hanlp/metrics/parsing/attachmentscore.py
new file mode 100644
index 000000000..85dd85a05
--- /dev/null
+++ b/hanlp/metrics/parsing/attachmentscore.py
@@ -0,0 +1,75 @@
+# MIT License
+#
+# Copyright (c) 2020 Yu Zhang
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+from hanlp.metrics.metric import Metric
+
+
+class AttachmentScore(Metric):
+
+ def __init__(self, eps=1e-12):
+ super(AttachmentScore, self).__init__()
+
+ self.eps = eps
+ self.total = 0.0
+ self.correct_arcs = 0.0
+ self.correct_rels = 0.0
+
+ def __repr__(self):
+ return f"UAS: {self.uas:.2%} LAS: {self.las:.2%}"
+
+ # noinspection PyMethodOverriding
+ def __call__(self, arc_preds, rel_preds, arc_golds, rel_golds, mask):
+ arc_mask = arc_preds.eq(arc_golds)[mask]
+ rel_mask = rel_preds.eq(rel_golds)[mask] & arc_mask
+
+ self.total += len(arc_mask)
+ self.correct_arcs += arc_mask.sum().item()
+ self.correct_rels += rel_mask.sum().item()
+
+ def __lt__(self, other):
+ return self.score < other
+
+ def __le__(self, other):
+ return self.score <= other
+
+ def __ge__(self, other):
+ return self.score >= other
+
+ def __gt__(self, other):
+ return self.score > other
+
+ @property
+ def score(self):
+ return self.las
+
+ @property
+ def uas(self):
+ return self.correct_arcs / (self.total + self.eps)
+
+ @property
+ def las(self):
+ return self.correct_rels / (self.total + self.eps)
+
+ def reset(self):
+ self.total = 0.0
+ self.correct_arcs = 0.0
+ self.correct_rels = 0.0
diff --git a/hanlp/metrics/parsing/conllx_eval.py b/hanlp/metrics/parsing/conllx_eval.py
new file mode 100644
index 000000000..41c46dab6
--- /dev/null
+++ b/hanlp/metrics/parsing/conllx_eval.py
@@ -0,0 +1,68 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-03-08 22:35
+import tempfile
+
+from hanlp.components.parsers.conll import read_conll
+from hanlp.utils.io_util import get_resource, get_exitcode_stdout_stderr
+
+CONLLX_EVAL = get_resource(
+ 'https://github.com/elikip/bist-parser/archive/master.zip' + '#bmstparser/src/utils/eval.pl')
+
+
+def evaluate(gold_file, pred_file):
+ """Evaluate using official CoNLL-X evaluation script (Yuval Krymolowski)
+
+ Args:
+ gold_file(str): The gold conllx file
+ pred_file(str): The pred conllx file
+
+ Returns:
+
+
+ """
+ gold_file = get_resource(gold_file)
+ fixed_pred_file = tempfile.NamedTemporaryFile().name
+ copy_cols(gold_file, pred_file, fixed_pred_file, keep_comments=False)
+ if gold_file.endswith('.conllu'):
+ fixed_gold_file = tempfile.NamedTemporaryFile().name
+ copy_cols(gold_file, gold_file, fixed_gold_file, keep_comments=False)
+ gold_file = fixed_gold_file
+
+ exitcode, out, err = get_exitcode_stdout_stderr(f'perl {CONLLX_EVAL} -q -b -g {gold_file} -s {fixed_pred_file}')
+ if exitcode:
+ raise RuntimeError(f'eval.pl exited with error code {exitcode} and error message {err} and output {out}.')
+ lines = out.split('\n')[-4:]
+ las = int(lines[0].split()[3]) / int(lines[0].split()[5])
+ uas = int(lines[1].split()[3]) / int(lines[1].split()[5])
+ return uas, las
+
+
+def copy_cols(gold_file, pred_file, copied_pred_file, keep_comments=True):
+ """Copy the first 6 columns from gold file to pred file
+
+ Args:
+ gold_file:
+ pred_file:
+ copied_pred_file:
+ keep_comments: (Default value = True)
+
+ Returns:
+
+
+ """
+ with open(copied_pred_file, 'w') as to_out, open(pred_file) as pred_file, open(gold_file) as gold_file:
+ for idx, (p, g) in enumerate(zip(pred_file, gold_file)):
+ while p.startswith('#'):
+ p = next(pred_file)
+ if not g.strip():
+ if p.strip():
+ raise ValueError(
+ f'Prediction file {pred_file.name} does not end a sentence at line {idx + 1}\n{p.strip()}')
+ to_out.write('\n')
+ continue
+ while g.startswith('#') or '-' in g.split('\t')[0]:
+ if keep_comments or g.startswith('-'):
+ to_out.write(g)
+ g = next(gold_file)
+ to_out.write('\t'.join(str(x) for x in g.split('\t')[:6] + p.split('\t')[6:]))
diff --git a/hanlp/metrics/parsing/iwpt20_eval.py b/hanlp/metrics/parsing/iwpt20_eval.py
new file mode 100644
index 000000000..0b322f4b1
--- /dev/null
+++ b/hanlp/metrics/parsing/iwpt20_eval.py
@@ -0,0 +1,154 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-03-25 16:04
+
+import os
+import tempfile
+from typing import List
+
+from hanlp.metrics.parsing.conllx_eval import copy_cols
+
+from hanlp_common.structure import SerializableDict
+from hanlp.metrics.parsing import iwpt20_xud_eval
+from hanlp.metrics.parsing.iwpt20_xud_eval import load_conllu_file
+from hanlp.utils.io_util import get_resource, get_exitcode_stdout_stderr
+
+UD_TOOLS_ROOT = get_resource(
+ 'https://github.com/UniversalDependencies/tools/archive/1650bd354bd158c75836cff6650ea35cc9928fc8.zip')
+
+ENHANCED_COLLAPSE_EMPTY_NODES = os.path.join(UD_TOOLS_ROOT, 'enhanced_collapse_empty_nodes.pl')
+CONLLU_QUICK_FIX = os.path.join(UD_TOOLS_ROOT, 'conllu-quick-fix.pl')
+
+
+def run_perl(script, src, dst=None):
+ if not dst:
+ dst = tempfile.NamedTemporaryFile().name
+ exitcode, out, err = get_exitcode_stdout_stderr(
+ f'perl -I{os.path.expanduser("~/.local/lib/perl5")} {script} {src}')
+ if exitcode:
+ # cpanm -l ~/.local namespace::autoclean
+ # cpanm -l ~/.local Moose
+ # cpanm -l ~/.local MooseX::SemiAffordanceAccessor module
+ raise RuntimeError(err)
+ with open(dst, 'w') as ofile:
+ ofile.write(out)
+ return dst
+
+
+def enhanced_collapse_empty_nodes(src, dst=None):
+ return run_perl(ENHANCED_COLLAPSE_EMPTY_NODES, src, dst)
+
+
+def conllu_quick_fix(src, dst=None):
+ return run_perl(CONLLU_QUICK_FIX, src, dst)
+
+
+def load_conll_to_str(path) -> List[str]:
+ with open(path) as src:
+ text = src.read()
+ sents = text.split('\n\n')
+ sents = [x for x in sents if x.strip()]
+ return sents
+
+
+def remove_complete_edges(src, dst):
+ sents = load_conll_to_str(src)
+ with open(dst, 'w') as out:
+ for each in sents:
+ for line in each.split('\n'):
+ if line.startswith('#'):
+ out.write(line)
+ else:
+ cells = line.split('\t')
+ cells[7] = cells[7].split(':')[0]
+ out.write('\t'.join(cells))
+ out.write('\n')
+ out.write('\n')
+
+
+def remove_collapse_edges(src, dst):
+ sents = load_conll_to_str(src)
+ with open(dst, 'w') as out:
+ for each in sents:
+ for line in each.split('\n'):
+ if line.startswith('#'):
+ out.write(line)
+ else:
+ cells = line.split('\t')
+ deps = cells[8].split('|')
+ deps = [x.split('>')[0] for x in deps]
+ cells[8] = '|'.join(deps)
+ out.write('\t'.join(cells))
+ out.write('\n')
+ out.write('\n')
+
+
+def restore_collapse_edges(src, dst):
+ sents = load_conll_to_str(src)
+ with open(dst, 'w') as out:
+ for each in sents:
+ empty_nodes = {} # head to deps
+ lines = each.split('\n')
+ tokens = [x for x in lines if not x.startswith('#') and x.split()[0].isdigit()]
+ for line in lines:
+ line = line.strip()
+ if not line:
+ continue
+ if line.startswith('#'):
+ out.write(line)
+ else:
+ cells = line.split('\t')
+ deps = cells[8].split('|')
+ for i, d in enumerate(deps):
+ if '>' in d:
+ head, rel = d.split(':', 1)
+ ehead = f'{len(tokens)}.{len(empty_nodes) + 1}'
+ par, cur = rel.split('>', 1)
+ cur = cur.split('>')[0]
+ deps[i] = f'{ehead}:{cur}'
+ empty_nodes[ehead] = f'{head}:{par}'
+ cells[8] = '|'.join(deps)
+ out.write('\t'.join(cells))
+ out.write('\n')
+ num_tokens = int(line.split('\t')[0])
+ assert num_tokens == len(tokens)
+ for idx, (ehead, deps) in enumerate(empty_nodes.items()):
+ out.write(f'{num_tokens}.{idx + 1}\t' + '_\t' * 7 + deps + '\t_\n')
+ out.write('\n')
+
+
+def evaluate(gold_file, pred_file, do_enhanced_collapse_empty_nodes=False, do_copy_cols=True):
+ """Evaluate using official CoNLL-X evaluation script (Yuval Krymolowski)
+
+ Args:
+ gold_file(str): The gold conllx file
+ pred_file(str): The pred conllx file
+ do_enhanced_collapse_empty_nodes: (Default value = False)
+ do_copy_cols: (Default value = True)
+
+ Returns:
+
+
+ """
+ if do_enhanced_collapse_empty_nodes:
+ gold_file = enhanced_collapse_empty_nodes(gold_file)
+ pred_file = enhanced_collapse_empty_nodes(pred_file)
+ if do_copy_cols:
+ fixed_pred_file = pred_file.replace('.conllu', '.fixed.conllu')
+ copy_cols(gold_file, pred_file, fixed_pred_file)
+ else:
+ fixed_pred_file = pred_file
+ args = SerializableDict()
+ args.enhancements = '0'
+ args.gold_file = gold_file
+ args.system_file = fixed_pred_file
+ return iwpt20_xud_eval.evaluate_wrapper(args)
+
+
+def main():
+ print(evaluate('data/iwpt2020/iwpt2020-test-gold/cs.conllu',
+ 'data/model/iwpt2020/bert/ens/cs.conllu', do_enhanced_collapse_empty_nodes=True))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/hanlp/metrics/parsing/iwpt20_xud_eval.py b/hanlp/metrics/parsing/iwpt20_xud_eval.py
new file mode 100644
index 000000000..1b59ac1b8
--- /dev/null
+++ b/hanlp/metrics/parsing/iwpt20_xud_eval.py
@@ -0,0 +1,766 @@
+#!/usr/bin/env python3
+
+# updated code from conll 2018 ud shared task for evaluation of enhanced dependencies in iwpt 2020 shared task
+# -- read DEPS, split on '|', compute overlap
+# Gosse Bouma
+
+# Compatible with Python 2.7 and 3.2+, can be used either as a module
+# or a standalone executable.
+#
+# Copyright 2017, 2018 Institute of Formal and Applied Linguistics (UFAL),
+# Faculty of Mathematics and Physics, Charles University, Czech Republic.
+#
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+#
+# Authors: Milan Straka, Martin Popel
+#
+# Changelog:
+# - [12 Apr 2018] Version 0.9: Initial release.
+# - [19 Apr 2018] Version 1.0: Fix bug in MLAS (duplicate entries in functional_children).
+# Add --counts option.
+# - [02 May 2018] Version 1.1: When removing spaces to match gold and system characters,
+# consider all Unicode characters of category Zs instead of
+# just ASCII space.
+# - [25 Jun 2018] Version 1.2: Use python3 in the she-bang (instead of python).
+# In Python2, make the whole computation use `unicode` strings.
+
+# Command line usage
+# ------------------
+# iwpt20_eud_eval.py3 [-v] [-c] gold_conllu_file system_conllu_file
+#
+# - if no -v is given, only the official IWPT 2020 Shared Task evaluation metrics
+# are printed
+# - if -v is given, more metrics are printed (as precision, recall, F1 score,
+# and in case the metric is computed on aligned words also accuracy on these):
+# - Tokens: how well do the gold tokens match system tokens
+# - Sentences: how well do the gold sentences match system sentences
+# - Words: how well can the gold words be aligned to system words
+# - UPOS: using aligned words, how well does UPOS match
+# - XPOS: using aligned words, how well does XPOS match
+# - UFeats: using aligned words, how well does universal FEATS match
+# - AllTags: using aligned words, how well does UPOS+XPOS+FEATS match
+# - Lemmas: using aligned words, how well does LEMMA match
+# - UAS: using aligned words, how well does HEAD match
+# - LAS: using aligned words, how well does HEAD+DEPREL(ignoring subtypes) match
+# - CLAS: using aligned words with content DEPREL, how well does
+# HEAD+DEPREL(ignoring subtypes) match
+# - MLAS: using aligned words with content DEPREL, how well does
+# HEAD+DEPREL(ignoring subtypes)+UPOS+UFEATS+FunctionalChildren(DEPREL+UPOS+UFEATS) match
+# - BLEX: using aligned words with content DEPREL, how well does
+# HEAD+DEPREL(ignoring subtypes)+LEMMAS match
+# - if -c is given, raw counts of correct/gold_total/system_total/aligned words are printed
+# instead of precision/recall/F1/AlignedAccuracy for all metrics.
+
+# API usage
+# ---------
+# - load_conllu(file)
+# - loads CoNLL-U file from given file object to an internal representation
+# - the file object should return str in both Python 2 and Python 3
+# - raises UDError exception if the given file cannot be loaded
+# - evaluate(gold_ud, system_ud)
+# - evaluate the given gold and system CoNLL-U files (loaded with load_conllu)
+# - raises UDError if the concatenated tokens of gold and system file do not match
+# - returns a dictionary with the metrics described above, each metric having
+# three fields: precision, recall and f1
+
+# Description of token matching
+# -----------------------------
+# In order to match tokens of gold file and system file, we consider the text
+# resulting from concatenation of gold tokens and text resulting from
+# concatenation of system tokens. These texts should match -- if they do not,
+# the evaluation fails.
+#
+# If the texts do match, every token is represented as a range in this original
+# text, and tokens are equal only if their range is the same.
+
+# Description of word matching
+# ----------------------------
+# When matching words of gold file and system file, we first match the tokens.
+# The words which are also tokens are matched as tokens, but words in multi-word
+# tokens have to be handled differently.
+#
+# To handle multi-word tokens, we start by finding "multi-word spans".
+# Multi-word span is a span in the original text such that
+# - it contains at least one multi-word token
+# - all multi-word tokens in the span (considering both gold and system ones)
+# are completely inside the span (i.e., they do not "stick out")
+# - the multi-word span is as small as possible
+#
+# For every multi-word span, we align the gold and system words completely
+# inside this span using LCS on their FORMs. The words not intersecting
+# (even partially) any multi-word span are then aligned as tokens.
+
+
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import io
+import sys
+import unicodedata
+import unittest
+
+# CoNLL-U column names
+ID, FORM, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL, DEPS, MISC = range(10)
+
+# Content and functional relations
+CONTENT_DEPRELS = {
+ "nsubj", "obj", "iobj", "csubj", "ccomp", "xcomp", "obl", "vocative",
+ "expl", "dislocated", "advcl", "advmod", "discourse", "nmod", "appos",
+ "nummod", "acl", "amod", "conj", "fixed", "flat", "compound", "list",
+ "parataxis", "orphan", "goeswith", "reparandum", "root", "dep"
+}
+
+FUNCTIONAL_DEPRELS = {
+ "aux", "cop", "mark", "det", "clf", "case", "cc"
+}
+
+UNIVERSAL_FEATURES = {
+ "PronType", "NumType", "Poss", "Reflex", "Foreign", "Abbr", "Gender",
+ "Animacy", "Number", "Case", "Definite", "Degree", "VerbForm", "Mood",
+ "Tense", "Aspect", "Voice", "Evident", "Polarity", "Person", "Polite"
+}
+
+# UD Error is used when raising exceptions in this module
+class UDError(Exception):
+ pass
+
+# Conversion methods handling `str` <-> `unicode` conversions in Python2
+def _decode(text):
+ return text if sys.version_info[0] >= 3 or not isinstance(text, str) else text.decode("utf-8")
+
+def _encode(text):
+ return text if sys.version_info[0] >= 3 or not isinstance(text, unicode) else text.encode("utf-8")
+
+CASE_DEPRELS = {'obl','nmod','conj','advcl'}
+UNIVERSAL_DEPREL_EXTENSIONS = {'pass','relcl','xsubj'}
+
+# modify the set of deps produced by system to be in accordance with gold treebank type
+# return a (filtered) list of (hd,dependency_path) tuples. -- GB
+def process_enhanced_deps(deps) :
+ edeps = []
+ for edep in deps.split('|') :
+ (hd,path) = edep.split(':',1)
+ steps = path.split('>') # collapsing empty nodes gives rise to paths like this : 3:conj:en>obl:voor
+ edeps.append((hd,steps)) # (3,['conj:en','obj:voor'])
+ return edeps
+
+# Load given CoNLL-U file into internal representation
+def load_conllu(file,treebank_type):
+ # Internal representation classes
+ class UDRepresentation:
+ def __init__(self):
+ # Characters of all the tokens in the whole file.
+ # Whitespace between tokens is not included.
+ self.characters = []
+ # List of UDSpan instances with start&end indices into `characters`.
+ self.tokens = []
+ # List of UDWord instances.
+ self.words = []
+ # List of UDSpan instances with start&end indices into `characters`.
+ self.sentences = []
+ class UDSpan:
+ def __init__(self, start, end):
+ self.start = start
+ # Note that self.end marks the first position **after the end** of span,
+ # so we can use characters[start:end] or range(start, end).
+ self.end = end
+ class UDWord:
+ def __init__(self, span, columns, is_multiword):
+ # Span of this word (or MWT, see below) within ud_representation.characters.
+ self.span = span
+ # 10 columns of the CoNLL-U file: ID, FORM, LEMMA,...
+ self.columns = columns
+ # is_multiword==True means that this word is part of a multi-word token.
+ # In that case, self.span marks the span of the whole multi-word token.
+ self.is_multiword = is_multiword
+ # Reference to the UDWord instance representing the HEAD (or None if root).
+ self.parent = None
+ # List of references to UDWord instances representing functional-deprel children.
+ self.functional_children = []
+ # Only consider universal FEATS.
+ self.columns[FEATS] = "|".join(sorted(feat for feat in columns[FEATS].split("|")
+ if feat.split("=", 1)[0] in UNIVERSAL_FEATURES))
+ # Let's ignore language-specific deprel subtypes.
+ self.columns[DEPREL] = columns[DEPREL].split(":")[0]
+ # Precompute which deprels are CONTENT_DEPRELS and which FUNCTIONAL_DEPRELS
+ self.is_content_deprel = self.columns[DEPREL] in CONTENT_DEPRELS
+ self.is_functional_deprel = self.columns[DEPREL] in FUNCTIONAL_DEPRELS
+ # store enhanced deps --GB
+ # split string positions and enhanced labels as well?
+ self.columns[DEPS] = process_enhanced_deps(columns[DEPS])
+
+ ud = UDRepresentation()
+
+ # Load the CoNLL-U file
+ index, sentence_start = 0, None
+
+ modified_deprels = 0
+
+ while True:
+ line = file.readline()
+ if not line:
+ break
+ line = _decode(line.rstrip("\r\n"))
+
+ # Handle sentence start boundaries
+ if sentence_start is None:
+ # Skip comments
+ if line.startswith("#"):
+ continue
+ # Start a new sentence
+ ud.sentences.append(UDSpan(index, 0))
+ sentence_start = len(ud.words)
+ if not line:
+ # Add parent and children UDWord links and check there are no cycles
+ def process_word(word):
+ if word.parent == "remapping":
+ raise UDError("There is a cycle in a sentence")
+ if word.parent is None:
+ head = int(word.columns[HEAD])
+ if head < 0 or head > len(ud.words) - sentence_start:
+ raise UDError("HEAD '{}' points outside of the sentence".format(_encode(word.columns[HEAD])))
+ if head:
+ parent = ud.words[sentence_start + head - 1]
+ word.parent = "remapping"
+ process_word(parent)
+ word.parent = parent
+
+
+ position = sentence_start # need to incrementally keep track of current position for loop detection in relcl
+ for word in ud.words[sentence_start:]:
+ process_word(word)
+ enhanced_deps = word.columns[DEPS]
+ # replace head positions of enhanced dependencies with parent word object -- GB
+ processed_deps = []
+ for (head,steps) in word.columns[DEPS] : # (3,['conj:en','obj:voor'])
+ hd = int(head)
+ parent = ud.words[sentence_start + hd -1] if hd else hd # just assign '0' to parent for root cases
+ processed_deps.append((parent,steps))
+ enhanced_deps = processed_deps
+
+ # make the evaluation script ignore various types of enhancements -- GB
+
+ # ignore rel>rel dependencies, and instead append the original hd/rel edge
+ # note that this also ignores other extensions (like adding lemma's)
+ # note that this sometimes introduces duplicates (if orig hd/rel was already included in DEPS)
+ if (treebank_type['no_gapping']) : # enhancement 1
+ processed_deps = []
+ for (parent,steps) in enhanced_deps :
+ if len(steps) > 1 :
+ #print("replaced {} by {}".format(steps,word.columns[DEPREL]))
+ (parent,steps) = (word.parent,[word.columns[DEPREL]])
+ modified_deprels += 1
+ if not((parent,steps) in processed_deps) :
+ processed_deps.append((parent,steps))
+ enhanced_deps = processed_deps
+
+ # for a given conj node, any rel other than conj in DEPS can be ignored
+ if treebank_type['no_shared_parents_in_coordination'] : # enhancement 2
+ for (parent,steps) in enhanced_deps :
+ if len(steps) == 1 and steps[0].startswith('conj') :
+ enhanced_deps = [(parent,steps)]
+ modified_deprels += 1
+
+ # duplicate deprels not matching ud_hd/ud_dep are spurious.
+ # czech/pud estonian/ewt syntagrus finnish/pud
+ # NB: treebanks that do not mark xcomp and relcl subjects: we now preserve duplicate nsubj if parent is xcomp
+ # but in: the man who walked and talked, we now also preserve nsubj 2x for 'who'
+ # idem in I know that she walked and talked
+ if treebank_type['no_shared_dependents_in_coordination'] : # enhancement 3
+ processed_deps = []
+ for (parent,steps) in enhanced_deps :
+ duplicate = 0
+ ud_hd = word.parent
+ for (p2,s2) in enhanced_deps :
+ if steps == s2 and p2 == ud_hd and parent != p2 :
+ if not (p2.columns[DEPREL] in ('xcomp','acl','acl:relcl') and steps == ['nsubj']) :
+ duplicate = 1
+ modified_deprels += 1
+ if not(duplicate) :
+ processed_deps.append((parent,steps))
+ enhanced_deps = processed_deps
+
+ # if treebank does not have control relations: subjects of xcomp parents in system are to be skipped
+ # note that rel is actually a path sometimes rel1>rel2 in theory rel2 could be subj?
+ # from lassy-small: 7:conj:en>nsubj:pass|7:conj:en>nsubj:xsubj (7,['conj:en','nsubj:xsubj'])
+ if (treebank_type['no_control']) : # enhancement 4
+ processed_deps = []
+ for (parent,steps) in enhanced_deps :
+ include = 1
+ if ( parent and parent.columns[DEPREL] == 'xcomp') :
+ for rel in steps:
+ if rel.startswith('nsubj') :
+ include = 0
+ modified_deprels += 1
+ if include :
+ processed_deps.append((parent,steps))
+ enhanced_deps = processed_deps
+
+ if (treebank_type['no_external_arguments_of_relative_clauses']) : # enhancement 5
+ processed_deps = []
+ for (parent,steps) in enhanced_deps :
+ if (steps[0] == 'ref') :
+ processed_deps.append((word.parent,[word.columns[DEPREL]])) # append the original relation
+ modified_deprels += 1
+ # ignore external argument link
+ # external args are deps of an acl:relcl where that acl also is a dependent of external arg (i.e. ext arg introduces a cycle)
+ elif ( parent and parent.columns[DEPREL].startswith('acl') and int(parent.columns[HEAD]) == position - sentence_start ) :
+ #print('removed external argument')
+ modified_deprels += 1
+ else :
+ processed_deps.append((parent,steps))
+ enhanced_deps = processed_deps
+
+ # treebanks where no lemma info has been added
+ if treebank_type['no_case_info'] : # enhancement number 6
+ processed_deps = []
+ for (hd,steps) in enhanced_deps :
+ processed_steps = []
+ for dep in steps :
+ depparts = dep.split(':')
+ if depparts[0] in CASE_DEPRELS :
+ if (len(depparts) == 2 and not(depparts[1] in UNIVERSAL_DEPREL_EXTENSIONS )) :
+ dep = depparts[0]
+ modified_deprels += 1
+ processed_steps.append(dep)
+ processed_deps.append((hd,processed_steps))
+ enhanced_deps = processed_deps
+
+ position += 1
+ word.columns[DEPS] = enhanced_deps
+
+
+ # func_children cannot be assigned within process_word
+ # because it is called recursively and may result in adding one child twice.
+ for word in ud.words[sentence_start:]:
+ if word.parent and word.is_functional_deprel:
+ word.parent.functional_children.append(word)
+
+ # Check there is a single root node
+ if len([word for word in ud.words[sentence_start:] if word.parent is None]) != 1:
+ raise UDError("There are multiple roots in a sentence")
+
+ # End the sentence
+ ud.sentences[-1].end = index
+ sentence_start = None
+ continue
+
+ # Read next token/word
+ columns = line.split("\t")
+ if len(columns) != 10:
+ raise UDError("The CoNLL-U line does not contain 10 tab-separated columns: '{}'".format(_encode(line)))
+
+ # Skip empty nodes
+ # After collapsing empty nodes into the enhancements, these should not occur --GB
+ if "." in columns[ID]:
+ raise UDError("The collapsed CoNLL-U line still contains empty nodes: {}".format(_encode(line)))
+
+ # Delete spaces from FORM, so gold.characters == system.characters
+ # even if one of them tokenizes the space. Use any Unicode character
+ # with category Zs.
+ columns[FORM] = "".join(filter(lambda c: unicodedata.category(c) != "Zs", columns[FORM]))
+ if not columns[FORM]:
+ raise UDError("There is an empty FORM in the CoNLL-U file")
+
+ # Save token
+ ud.characters.extend(columns[FORM])
+ ud.tokens.append(UDSpan(index, index + len(columns[FORM])))
+ index += len(columns[FORM])
+
+ # Handle multi-word tokens to save word(s)
+ if "-" in columns[ID]:
+ try:
+ start, end = map(int, columns[ID].split("-"))
+ except:
+ raise UDError("Cannot parse multi-word token ID '{}'".format(_encode(columns[ID])))
+
+ for _ in range(start, end + 1):
+ word_line = _decode(file.readline().rstrip("\r\n"))
+ word_columns = word_line.split("\t")
+ if len(word_columns) != 10:
+ raise UDError("The CoNLL-U line does not contain 10 tab-separated columns: '{}'".format(_encode(word_line)))
+ ud.words.append(UDWord(ud.tokens[-1], word_columns, is_multiword=True))
+
+ # Basic tokens/words
+ else:
+ try:
+ word_id = int(columns[ID])
+ except:
+ raise UDError("Cannot parse word ID '{}'".format(_encode(columns[ID])))
+ if word_id != len(ud.words) - sentence_start + 1:
+ raise UDError("Incorrect word ID '{}' for word '{}', expected '{}'".format(
+ _encode(columns[ID]), _encode(columns[FORM]), len(ud.words) - sentence_start + 1))
+
+ try:
+ head_id = int(columns[HEAD])
+ except:
+ raise UDError("Cannot parse HEAD '{}'".format(_encode(columns[HEAD])))
+ if head_id < 0:
+ raise UDError("HEAD cannot be negative")
+
+ ud.words.append(UDWord(ud.tokens[-1], columns, is_multiword=False))
+
+ if modified_deprels :
+ print('modified/deleted {} enhanced DEPRELS in {}'.format(modified_deprels,file.name))
+
+ if sentence_start is not None:
+ raise UDError("The CoNLL-U file does not end with empty line")
+
+ return ud
+
+# Evaluate the gold and system treebanks (loaded using load_conllu).
+def evaluate(gold_ud, system_ud):
+ class Score:
+ def __init__(self, gold_total, system_total, correct, aligned_total=None):
+ self.correct = correct
+ self.gold_total = gold_total
+ self.system_total = system_total
+ self.aligned_total = aligned_total
+ self.precision = correct / system_total if system_total else 0.0
+ self.recall = correct / gold_total if gold_total else 0.0
+ self.f1 = 2 * correct / (system_total + gold_total) if system_total + gold_total else 0.0
+ self.aligned_accuracy = correct / aligned_total if aligned_total else aligned_total
+ class AlignmentWord:
+ def __init__(self, gold_word, system_word):
+ self.gold_word = gold_word
+ self.system_word = system_word
+ class Alignment:
+ def __init__(self, gold_words, system_words):
+ self.gold_words = gold_words
+ self.system_words = system_words
+ self.matched_words = []
+ self.matched_words_map = {}
+ def append_aligned_words(self, gold_word, system_word):
+ self.matched_words.append(AlignmentWord(gold_word, system_word))
+ self.matched_words_map[system_word] = gold_word
+
+ def spans_score(gold_spans, system_spans):
+ correct, gi, si = 0, 0, 0
+ while gi < len(gold_spans) and si < len(system_spans):
+ if system_spans[si].start < gold_spans[gi].start:
+ si += 1
+ elif gold_spans[gi].start < system_spans[si].start:
+ gi += 1
+ else:
+ correct += gold_spans[gi].end == system_spans[si].end
+ si += 1
+ gi += 1
+
+ return Score(len(gold_spans), len(system_spans), correct)
+
+ def alignment_score(alignment, key_fn=None, filter_fn=None):
+ if filter_fn is not None:
+ gold = sum(1 for gold in alignment.gold_words if filter_fn(gold))
+ system = sum(1 for system in alignment.system_words if filter_fn(system))
+ aligned = sum(1 for word in alignment.matched_words if filter_fn(word.gold_word))
+ else:
+ gold = len(alignment.gold_words)
+ system = len(alignment.system_words)
+ aligned = len(alignment.matched_words)
+
+ if key_fn is None:
+ # Return score for whole aligned words
+ return Score(gold, system, aligned)
+
+ def gold_aligned_gold(word):
+ return word
+ def gold_aligned_system(word):
+ return alignment.matched_words_map.get(word, "NotAligned") if word is not None else None
+ correct = 0
+ for words in alignment.matched_words:
+ if filter_fn is None or filter_fn(words.gold_word):
+ if key_fn(words.gold_word, gold_aligned_gold) == key_fn(words.system_word, gold_aligned_system):
+ correct += 1
+
+ return Score(gold, system, correct, aligned)
+
+ def enhanced_alignment_score(alignment):
+ # count all matching enhanced deprels in gold, system GB
+ # gold and system = sum of gold and predicted deps
+ # parents are pointers to word object, make sure to compare system parent with aligned word in gold in cases where
+ # tokenization introduces mismatches in number of words per sentence.
+ gold = 0
+ for gold_word in alignment.gold_words :
+ gold += len(gold_word.columns[DEPS])
+ system = 0
+ for system_word in alignment.system_words :
+ system += len(system_word.columns[DEPS])
+ # NB aligned does not play a role in computing f1 score -- GB
+ aligned = len(alignment.matched_words)
+ correct = 0
+ for words in alignment.matched_words:
+ gold_deps = words.gold_word.columns[DEPS]
+ system_deps = words.system_word.columns[DEPS]
+ for (parent,dep) in gold_deps :
+ for (sparent,sdep) in system_deps :
+ if dep == sdep :
+ if parent == alignment.matched_words_map.get(sparent,"NotAligned") :
+ correct += 1
+ elif (parent == 0 and sparent == 0) : # cases where parent is root
+ correct += 1
+
+ return Score(gold, system, correct, aligned)
+
+
+ def beyond_end(words, i, multiword_span_end):
+ if i >= len(words):
+ return True
+ if words[i].is_multiword:
+ return words[i].span.start >= multiword_span_end
+ return words[i].span.end > multiword_span_end
+
+ def extend_end(word, multiword_span_end):
+ if word.is_multiword and word.span.end > multiword_span_end:
+ return word.span.end
+ return multiword_span_end
+
+ def find_multiword_span(gold_words, system_words, gi, si):
+ # We know gold_words[gi].is_multiword or system_words[si].is_multiword.
+ # Find the start of the multiword span (gs, ss), so the multiword span is minimal.
+ # Initialize multiword_span_end characters index.
+ if gold_words[gi].is_multiword:
+ multiword_span_end = gold_words[gi].span.end
+ if not system_words[si].is_multiword and system_words[si].span.start < gold_words[gi].span.start:
+ si += 1
+ else: # if system_words[si].is_multiword
+ multiword_span_end = system_words[si].span.end
+ if not gold_words[gi].is_multiword and gold_words[gi].span.start < system_words[si].span.start:
+ gi += 1
+ gs, ss = gi, si
+
+ # Find the end of the multiword span
+ # (so both gi and si are pointing to the word following the multiword span end).
+ while not beyond_end(gold_words, gi, multiword_span_end) or \
+ not beyond_end(system_words, si, multiword_span_end):
+ if gi < len(gold_words) and (si >= len(system_words) or
+ gold_words[gi].span.start <= system_words[si].span.start):
+ multiword_span_end = extend_end(gold_words[gi], multiword_span_end)
+ gi += 1
+ else:
+ multiword_span_end = extend_end(system_words[si], multiword_span_end)
+ si += 1
+ return gs, ss, gi, si
+
+ def compute_lcs(gold_words, system_words, gi, si, gs, ss):
+ lcs = [[0] * (si - ss) for i in range(gi - gs)]
+ for g in reversed(range(gi - gs)):
+ for s in reversed(range(si - ss)):
+ if gold_words[gs + g].columns[FORM].lower() == system_words[ss + s].columns[FORM].lower():
+ lcs[g][s] = 1 + (lcs[g+1][s+1] if g+1 < gi-gs and s+1 < si-ss else 0)
+ lcs[g][s] = max(lcs[g][s], lcs[g+1][s] if g+1 < gi-gs else 0)
+ lcs[g][s] = max(lcs[g][s], lcs[g][s+1] if s+1 < si-ss else 0)
+ return lcs
+
+ def align_words(gold_words, system_words):
+ alignment = Alignment(gold_words, system_words)
+
+ gi, si = 0, 0
+ while gi < len(gold_words) and si < len(system_words):
+ if gold_words[gi].is_multiword or system_words[si].is_multiword:
+ # A: Multi-word tokens => align via LCS within the whole "multiword span".
+ gs, ss, gi, si = find_multiword_span(gold_words, system_words, gi, si)
+
+ if si > ss and gi > gs:
+ lcs = compute_lcs(gold_words, system_words, gi, si, gs, ss)
+
+ # Store aligned words
+ s, g = 0, 0
+ while g < gi - gs and s < si - ss:
+ if gold_words[gs + g].columns[FORM].lower() == system_words[ss + s].columns[FORM].lower():
+ alignment.append_aligned_words(gold_words[gs+g], system_words[ss+s])
+ g += 1
+ s += 1
+ elif lcs[g][s] == (lcs[g+1][s] if g+1 < gi-gs else 0):
+ g += 1
+ else:
+ s += 1
+ else:
+ # B: No multi-word token => align according to spans.
+ if (gold_words[gi].span.start, gold_words[gi].span.end) == (system_words[si].span.start, system_words[si].span.end):
+ alignment.append_aligned_words(gold_words[gi], system_words[si])
+ gi += 1
+ si += 1
+ elif gold_words[gi].span.start <= system_words[si].span.start:
+ gi += 1
+ else:
+ si += 1
+
+ return alignment
+
+ # Check that the underlying character sequences do match.
+ if gold_ud.characters != system_ud.characters:
+ index = 0
+ while index < len(gold_ud.characters) and index < len(system_ud.characters) and \
+ gold_ud.characters[index] == system_ud.characters[index]:
+ index += 1
+
+ raise UDError(
+ "The concatenation of tokens in gold file and in system file differ!\n" +
+ "First 20 differing characters in gold file: '{}' and system file: '{}'".format(
+ "".join(map(_encode, gold_ud.characters[index:index + 20])),
+ "".join(map(_encode, system_ud.characters[index:index + 20]))
+ )
+ )
+
+ # Align words
+ alignment = align_words(gold_ud.words, system_ud.words)
+
+ # Compute the F1-scores
+ return {
+ "Tokens": spans_score(gold_ud.tokens, system_ud.tokens),
+ "Sentences": spans_score(gold_ud.sentences, system_ud.sentences),
+ "Words": alignment_score(alignment),
+ "UPOS": alignment_score(alignment, lambda w, _: w.columns[UPOS]),
+ "XPOS": alignment_score(alignment, lambda w, _: w.columns[XPOS]),
+ "UFeats": alignment_score(alignment, lambda w, _: w.columns[FEATS]),
+ "AllTags": alignment_score(alignment, lambda w, _: (w.columns[UPOS], w.columns[XPOS], w.columns[FEATS])),
+ "Lemmas": alignment_score(alignment, lambda w, ga: w.columns[LEMMA] if ga(w).columns[LEMMA] != "_" else "_"),
+ "UAS": alignment_score(alignment, lambda w, ga: ga(w.parent)),
+ "LAS": alignment_score(alignment, lambda w, ga: (ga(w.parent), w.columns[DEPREL])),
+ # include enhanced DEPS score -- GB
+ "ELAS": enhanced_alignment_score(alignment),
+ "CLAS": alignment_score(alignment, lambda w, ga: (ga(w.parent), w.columns[DEPREL]),
+ filter_fn=lambda w: w.is_content_deprel),
+ "MLAS": alignment_score(alignment, lambda w, ga: (ga(w.parent), w.columns[DEPREL], w.columns[UPOS], w.columns[FEATS],
+ [(ga(c), c.columns[DEPREL], c.columns[UPOS], c.columns[FEATS])
+ for c in w.functional_children]),
+ filter_fn=lambda w: w.is_content_deprel),
+ "BLEX": alignment_score(alignment, lambda w, ga: (ga(w.parent), w.columns[DEPREL],
+ w.columns[LEMMA] if ga(w).columns[LEMMA] != "_" else "_"),
+ filter_fn=lambda w: w.is_content_deprel),
+ }
+
+
+def load_conllu_file(path,treebank_type):
+ _file = open(path, mode="r", **({"encoding": "utf-8"} if sys.version_info >= (3, 0) else {}))
+ return load_conllu(_file,treebank_type)
+
+def evaluate_wrapper(args):
+ treebank_type = {}
+ enhancements = list(args.enhancements)
+ treebank_type['no_gapping'] = 1 if '1' in enhancements else 0
+ treebank_type['no_shared_parents_in_coordination'] = 1 if '2' in enhancements else 0
+ treebank_type['no_shared_dependents_in_coordination'] = 1 if '3' in enhancements else 0
+ treebank_type['no_control'] = 1 if '4' in enhancements else 0
+ treebank_type['no_external_arguments_of_relative_clauses'] = 1 if '5' in enhancements else 0
+ treebank_type['no_case_info'] = 1 if '6' in enhancements else 0
+ for key in treebank_type :
+ if treebank_type[key] :
+ print('evaluating with {} enhancements setting'.format(key))
+
+ # Load CoNLL-U files
+ gold_ud = load_conllu_file(args.gold_file,treebank_type)
+ system_ud = load_conllu_file(args.system_file,treebank_type)
+ return evaluate(gold_ud, system_ud)
+
+def main():
+ # Parse arguments
+ parser = argparse.ArgumentParser()
+ parser.add_argument("gold_file", type=str,
+ help="Name of the CoNLL-U file with the gold data.")
+ parser.add_argument("system_file", type=str,
+ help="Name of the CoNLL-U file with the predicted data.")
+ parser.add_argument("--verbose", "-v", default=False, action="store_true",
+ help="Print all metrics.")
+ parser.add_argument("--counts", "-c", default=False, action="store_true",
+ help="Print raw counts of correct/gold/system/aligned words instead of prec/rec/F1 for all metrics.")
+ parser.add_argument("--enhancements", type=str, default='0',
+ help="Level of enhancements in the gold data (see guidelines) 0=all (default), 1=no gapping, 2=no shared parents, 3=no shared dependents 4=no control, 5=no external arguments, 6=no lemma info, 12=both 1 and 2 apply, etc.")
+ args = parser.parse_args()
+
+ # Evaluate
+ evaluation = evaluate_wrapper(args)
+
+ # Print the evaluation
+ if not args.verbose and not args.counts:
+ print("LAS F1 Score: {:.2f}".format(100 * evaluation["LAS"].f1))
+ print("ELAS F1 Score: {:.2f}".format(100 * evaluation["ELAS"].f1))
+
+ print("MLAS Score: {:.2f}".format(100 * evaluation["MLAS"].f1))
+ print("BLEX Score: {:.2f}".format(100 * evaluation["BLEX"].f1))
+ else:
+ if args.counts:
+ print("Metric | Correct | Gold | Predicted | Aligned")
+ else:
+ print("Metric | Precision | Recall | F1 Score | AligndAcc")
+ print("-----------+-----------+-----------+-----------+-----------")
+ for metric in["Tokens", "Sentences", "Words", "UPOS", "XPOS", "UFeats", "AllTags", "Lemmas", "UAS", "LAS", "ELAS", "CLAS", "MLAS", "BLEX"]:
+ if args.counts:
+ print("{:11}|{:10} |{:10} |{:10} |{:10}".format(
+ metric,
+ evaluation[metric].correct,
+ evaluation[metric].gold_total,
+ evaluation[metric].system_total,
+ evaluation[metric].aligned_total or (evaluation[metric].correct if metric == "Words" else "")
+ ))
+ else:
+ print("{:11}|{:10.2f} |{:10.2f} |{:10.2f} |{}".format(
+ metric,
+ 100 * evaluation[metric].precision,
+ 100 * evaluation[metric].recall,
+ 100 * evaluation[metric].f1,
+ "{:10.2f}".format(100 * evaluation[metric].aligned_accuracy) if evaluation[metric].aligned_accuracy is not None else ""
+ ))
+
+if __name__ == "__main__":
+ main()
+
+# Tests, which can be executed with `python -m unittest conll18_ud_eval`.
+class TestAlignment(unittest.TestCase):
+ @staticmethod
+ def _load_words(words):
+ """Prepare fake CoNLL-U files with fake HEAD to prevent multiple roots errors.
+
+ Args:
+ words:
+
+ Returns:
+
+ """
+ lines, num_words = [], 0
+ for w in words:
+ parts = w.split(" ")
+ if len(parts) == 1:
+ num_words += 1
+ lines.append("{}\t{}\t_\t_\t_\t_\t{}\t_\t_\t_".format(num_words, parts[0], int(num_words>1)))
+ else:
+ lines.append("{}-{}\t{}\t_\t_\t_\t_\t_\t_\t_\t_".format(num_words + 1, num_words + len(parts) - 1, parts[0]))
+ for part in parts[1:]:
+ num_words += 1
+ lines.append("{}\t{}\t_\t_\t_\t_\t{}\t_\t_\t_".format(num_words, part, int(num_words>1)))
+ return load_conllu((io.StringIO if sys.version_info >= (3, 0) else io.BytesIO)("\n".join(lines+["\n"])))
+
+ def _test_exception(self, gold, system):
+ self.assertRaises(UDError, evaluate, self._load_words(gold), self._load_words(system))
+
+ def _test_ok(self, gold, system, correct):
+ metrics = evaluate(self._load_words(gold), self._load_words(system))
+ gold_words = sum((max(1, len(word.split(" ")) - 1) for word in gold))
+ system_words = sum((max(1, len(word.split(" ")) - 1) for word in system))
+ self.assertEqual((metrics["Words"].precision, metrics["Words"].recall, metrics["Words"].f1),
+ (correct / system_words, correct / gold_words, 2 * correct / (gold_words + system_words)))
+
+ def test_exception(self):
+ self._test_exception(["a"], ["b"])
+
+ def test_equal(self):
+ self._test_ok(["a"], ["a"], 1)
+ self._test_ok(["a", "b", "c"], ["a", "b", "c"], 3)
+
+ def test_equal_with_multiword(self):
+ self._test_ok(["abc a b c"], ["a", "b", "c"], 3)
+ self._test_ok(["a", "bc b c", "d"], ["a", "b", "c", "d"], 4)
+ self._test_ok(["abcd a b c d"], ["ab a b", "cd c d"], 4)
+ self._test_ok(["abc a b c", "de d e"], ["a", "bcd b c d", "e"], 5)
+
+ def test_alignment(self):
+ self._test_ok(["abcd"], ["a", "b", "c", "d"], 0)
+ self._test_ok(["abc", "d"], ["a", "b", "c", "d"], 1)
+ self._test_ok(["a", "bc", "d"], ["a", "b", "c", "d"], 2)
+ self._test_ok(["a", "bc b c", "d"], ["a", "b", "cd"], 2)
+ self._test_ok(["abc a BX c", "def d EX f"], ["ab a b", "cd c d", "ef e f"], 4)
+ self._test_ok(["ab a b", "cd bc d"], ["a", "bc", "d"], 2)
+ self._test_ok(["a", "bc b c", "d"], ["ab AX BX", "cd CX a"], 1)
diff --git a/hanlp/metrics/parsing/labeled_f1.py b/hanlp/metrics/parsing/labeled_f1.py
index 1ff5e870a..4d4795030 100644
--- a/hanlp/metrics/parsing/labeled_f1.py
+++ b/hanlp/metrics/parsing/labeled_f1.py
@@ -1,10 +1,11 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-12-27 21:42
-import tensorflow as tf
+from hanlp.metrics.metric import Metric
-class LabeledF1(object):
+
+class LabeledF1(Metric):
def __init__(self):
super(LabeledF1, self).__init__()
@@ -15,21 +16,20 @@ def __init__(self):
self.correct_rels_wo_punc = 0.0
def __repr__(self):
- return f"UF: {self.uf:6.2%} LF: {self.lf:6.2%}"
+ return f"UF: {self.uf:4.2%} LF: {self.lf:4.2%}"
def __call__(self, arc_preds, rel_preds, arc_golds, rel_golds, mask):
- mask = tf.tile(tf.expand_dims(mask, -1), [1, 1, tf.shape(mask)[-1]])
- mask &= tf.transpose(mask, [0, 2, 1])
-
mask_gold = mask & arc_golds
mask_pred = mask & arc_preds
- correct_arcs_wo_punc = (arc_preds == arc_golds)[mask_gold & mask_pred]
- correct_rels_wo_punc = (rel_preds == rel_golds)[mask_gold & mask_pred] & correct_arcs_wo_punc
- self.sum_gold_arcs_wo_punc += float(tf.math.count_nonzero(mask_gold))
- self.sum_pred_arcs_wo_punc += float(tf.math.count_nonzero(mask_pred))
- self.correct_arcs_wo_punc += float(tf.math.count_nonzero(correct_arcs_wo_punc))
- self.correct_rels_wo_punc += float(tf.math.count_nonzero(correct_rels_wo_punc))
+ correct_mask = mask_gold & mask_pred
+ correct_arcs_wo_punc = (arc_preds == arc_golds)[correct_mask]
+ correct_rels_wo_punc = (rel_preds == rel_golds)[correct_mask] & correct_arcs_wo_punc
+
+ self.sum_gold_arcs_wo_punc += float(mask_gold.sum())
+ self.sum_pred_arcs_wo_punc += float(mask_pred.sum())
+ self.correct_arcs_wo_punc += float(correct_arcs_wo_punc.sum())
+ self.correct_rels_wo_punc += float(correct_rels_wo_punc.sum())
def __lt__(self, other):
return self.score < other
@@ -58,42 +58,42 @@ def las(self):
@property
def ur(self):
if not self.sum_gold_arcs_wo_punc:
- return 0
+ return .0
return self.correct_arcs_wo_punc / self.sum_gold_arcs_wo_punc
@property
def up(self):
if not self.sum_pred_arcs_wo_punc:
- return 0
+ return .0
return self.correct_arcs_wo_punc / self.sum_pred_arcs_wo_punc
@property
def lr(self):
if not self.sum_gold_arcs_wo_punc:
- return 0
+ return .0
return self.correct_rels_wo_punc / self.sum_gold_arcs_wo_punc
@property
def lp(self):
if not self.sum_pred_arcs_wo_punc:
- return 0
+ return .0
return self.correct_rels_wo_punc / self.sum_pred_arcs_wo_punc
@property
def uf(self):
rp = self.ur + self.up
if not rp:
- return 0
+ return .0
return 2 * self.ur * self.up / rp
@property
def lf(self):
rp = self.lr + self.lp
if not rp:
- return 0
+ return .0
return 2 * self.lr * self.lp / rp
- def reset_states(self):
+ def reset(self):
self.sum_gold_arcs_wo_punc = 0.0
self.sum_pred_arcs_wo_punc = 0.0
self.correct_arcs_wo_punc = 0.0
diff --git a/hanlp/metrics/parsing/labeled_f1_tf.py b/hanlp/metrics/parsing/labeled_f1_tf.py
new file mode 100644
index 000000000..efe39ee24
--- /dev/null
+++ b/hanlp/metrics/parsing/labeled_f1_tf.py
@@ -0,0 +1,103 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-12-27 21:42
+import tensorflow as tf
+
+
+class LabeledF1TF(object):
+
+ def __init__(self):
+ super(LabeledF1TF, self).__init__()
+
+ self.sum_gold_arcs_wo_punc = 0.0
+ self.sum_pred_arcs_wo_punc = 0.0
+ self.correct_arcs_wo_punc = 0.0
+ self.correct_rels_wo_punc = 0.0
+
+ def __repr__(self):
+ return f"UF: {self.uf:6.2%} LF: {self.lf:6.2%}"
+
+ def __call__(self, arc_preds, rel_preds, arc_golds, rel_golds, mask):
+ mask = mask.unsqueeze(-1).expand_as(arc_preds)
+ mask = mask & mask.transpose(1, 2)
+
+ mask_gold = mask & arc_golds
+ mask_pred = mask & arc_preds
+ correct_arcs_wo_punc = (arc_preds == arc_golds)[mask_gold & mask_pred]
+ correct_rels_wo_punc = (rel_preds == rel_golds)[mask_gold & mask_pred] & correct_arcs_wo_punc
+
+ self.sum_gold_arcs_wo_punc += float(tf.math.count_nonzero(mask_gold))
+ self.sum_pred_arcs_wo_punc += float(tf.math.count_nonzero(mask_pred))
+ self.correct_arcs_wo_punc += float(tf.math.count_nonzero(correct_arcs_wo_punc))
+ self.correct_rels_wo_punc += float(tf.math.count_nonzero(correct_rels_wo_punc))
+
+ def __lt__(self, other):
+ return self.score < other
+
+ def __le__(self, other):
+ return self.score <= other
+
+ def __ge__(self, other):
+ return self.score >= other
+
+ def __gt__(self, other):
+ return self.score > other
+
+ @property
+ def score(self):
+ return self.las
+
+ @property
+ def uas(self):
+ return self.uf
+
+ @property
+ def las(self):
+ return self.lf
+
+ @property
+ def ur(self):
+ if not self.sum_gold_arcs_wo_punc:
+ return 0
+ return self.correct_arcs_wo_punc / self.sum_gold_arcs_wo_punc
+
+ @property
+ def up(self):
+ if not self.sum_pred_arcs_wo_punc:
+ return 0
+ return self.correct_arcs_wo_punc / self.sum_pred_arcs_wo_punc
+
+ @property
+ def lr(self):
+ if not self.sum_gold_arcs_wo_punc:
+ return 0
+ return self.correct_rels_wo_punc / self.sum_gold_arcs_wo_punc
+
+ @property
+ def lp(self):
+ if not self.sum_pred_arcs_wo_punc:
+ return 0
+ return self.correct_rels_wo_punc / self.sum_pred_arcs_wo_punc
+
+ @property
+ def uf(self):
+ rp = self.ur + self.up
+ if not rp:
+ return 0
+ return 2 * self.ur * self.up / rp
+
+ @property
+ def lf(self):
+ rp = self.lr + self.lp
+ if not rp:
+ return 0
+ return 2 * self.lr * self.lp / rp
+
+ def reset_states(self):
+ self.sum_gold_arcs_wo_punc = 0.0
+ self.sum_pred_arcs_wo_punc = 0.0
+ self.correct_arcs_wo_punc = 0.0
+ self.correct_rels_wo_punc = 0.0
+
+ def to_dict(self) -> dict:
+ return {'UF': self.uf, 'LF': self.lf}
diff --git a/hanlp/metrics/parsing/semdep_eval.py b/hanlp/metrics/parsing/semdep_eval.py
new file mode 100644
index 000000000..d2cb023e8
--- /dev/null
+++ b/hanlp/metrics/parsing/semdep_eval.py
@@ -0,0 +1,137 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+# Copyright 2017 Timothy Dozat
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import codecs
+import sys
+from collections import namedtuple
+
+
+# ===============================================================
+def sdp_eval(gold_files, sys_files, labeled=False):
+ """Modified from https://github.com/tdozat/Parser-v3/blob/2ff4061373e8aac8c962537a6220e1d5b196abf6/scripts/semdep_eval.py
+ Dozat claimed "I tested it against the official eval script and it reported identical LF1".
+
+ Args:
+ gold_files:
+ sys_files:
+ labeled: (Default value = False)
+
+ Returns:
+
+
+ """
+
+ correct = 0
+ predicted = 0
+ actual = 0
+ n_tokens = 0
+ n_sequences = 0
+ current_seq_correct = False
+ n_correct_sequences = 0
+ current_sent = 0
+ if isinstance(gold_files, str):
+ gold_files = [gold_files]
+ if isinstance(sys_files, str):
+ sys_files = [sys_files]
+
+ for gold_file, sys_file in zip(gold_files, sys_files):
+ with codecs.open(gold_file, encoding='utf-8') as gf, \
+ codecs.open(sys_file, encoding='utf-8') as sf:
+ gold_line = gf.readline()
+ gold_i = 1
+ sys_i = 0
+ while gold_line:
+ while gold_line.startswith('#'):
+ current_sent += 1
+ gold_i += 1
+ n_sequences += 1
+ n_correct_sequences += current_seq_correct
+ current_seq_correct = True
+ gold_line = gf.readline()
+ if gold_line.rstrip() != '':
+ sys_line = sf.readline()
+ sys_i += 1
+ while sys_line.startswith('#') or sys_line.rstrip() == '' or sys_line.split('\t')[0] == '0':
+ sys_line = sf.readline()
+ sys_i += 1
+
+ gold_line = gold_line.rstrip().split('\t')
+ sys_line = sys_line.rstrip().split('\t')
+ # assert sys_line[1] == gold_line[1], 'Files are misaligned at lines {}, {}'.format(gold_i, sys_i)
+
+ # Compute the gold edges
+ gold_node = gold_line[8]
+ if gold_node != '_':
+ gold_node = gold_node.split('|')
+ if labeled:
+ gold_edges = set(tuple(gold_edge.split(':', 1)) for gold_edge in gold_node)
+ else:
+ gold_edges = set(gold_edge.split(':', 1)[0] for gold_edge in gold_node)
+ else:
+ gold_edges = set()
+
+ # Compute the sys edges
+ sys_node = sys_line[8]
+ if sys_node != '_':
+ sys_node = sys_node.split('|')
+ if labeled:
+ sys_edges = set(tuple(sys_edge.split(':', 1)) for sys_edge in sys_node)
+ else:
+ sys_edges = set(sys_edge.split(':', 1)[0] for sys_edge in sys_node)
+ else:
+ sys_edges = set()
+
+ correct_edges = gold_edges & sys_edges
+ if len(correct_edges) != len(gold_edges):
+ current_seq_correct = False
+ correct += len(correct_edges)
+ predicted += len(sys_edges)
+ actual += len(gold_edges)
+ n_tokens += 1
+ # current_fp += len(sys_edges) - len(gold_edges & sys_edges)
+ gold_line = gf.readline()
+ gold_i += 1
+ # print(correct, predicted - correct, actual - correct)
+ Accuracy = namedtuple('Accuracy', ['precision', 'recall', 'F1', 'seq_acc'])
+ precision = correct / (predicted + 1e-12)
+ recall = correct / (actual + 1e-12)
+ F1 = 2 * precision * recall / (precision + recall + 1e-12)
+ seq_acc = n_correct_sequences / n_sequences
+ return Accuracy(precision, recall, F1, seq_acc)
+
+
+# ===============================================================
+def main():
+ """ """
+
+ files = sys.argv[1:]
+ n_files = len(files)
+ assert (n_files % 2) == 0
+ gold_files, sys_files = files[:n_files // 2], files[n_files // 2:]
+ UAS = sdp_eval(gold_files, sys_files, labeled=False)
+ LAS = sdp_eval(gold_files, sys_files, labeled=True)
+ # print(UAS.F1, UAS.seq_acc)
+ print('UAS={:0.1f}'.format(UAS.F1 * 100))
+ print('LAS={:0.1f}'.format(LAS.F1 * 100))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/hanlp/metrics/parsing/span.py b/hanlp/metrics/parsing/span.py
new file mode 100644
index 000000000..160799f67
--- /dev/null
+++ b/hanlp/metrics/parsing/span.py
@@ -0,0 +1,103 @@
+# MIT License
+#
+# Copyright (c) 2020 Yu Zhang
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+from collections import Counter
+
+from hanlp.metrics.metric import Metric
+
+
+class SpanMetric(Metric):
+
+ def __init__(self, eps=1e-12):
+ super().__init__()
+ self.reset(eps)
+
+ # noinspection PyAttributeOutsideInit
+ def reset(self, eps=1e-12):
+ self.n = 0.0
+ self.n_ucm = 0.0
+ self.n_lcm = 0.0
+ self.utp = 0.0
+ self.ltp = 0.0
+ self.pred = 0.0
+ self.gold = 0.0
+ self.eps = eps
+
+ def __call__(self, preds, golds):
+ for pred, gold in zip(preds, golds):
+ upred = Counter([(i, j) for i, j, label in pred])
+ ugold = Counter([(i, j) for i, j, label in gold])
+ utp = list((upred & ugold).elements())
+ lpred = Counter(pred)
+ lgold = Counter(gold)
+ ltp = list((lpred & lgold).elements())
+ self.n += 1
+ self.n_ucm += len(utp) == len(pred) == len(gold)
+ self.n_lcm += len(ltp) == len(pred) == len(gold)
+ self.utp += len(utp)
+ self.ltp += len(ltp)
+ self.pred += len(pred)
+ self.gold += len(gold)
+ return self
+
+ def __repr__(self):
+ s = f"UCM: {self.ucm:.2%} LCM: {self.lcm:.2%} "
+ s += f"UP: {self.up:.2%} UR: {self.ur:.2%} UF: {self.uf:.2%} "
+ s += f"LP: {self.lp:.2%} LR: {self.lr:.2%} LF: {self.lf:.2%}"
+
+ return s
+
+ @property
+ def score(self):
+ return self.lf
+
+ @property
+ def ucm(self):
+ return self.n_ucm / (self.n + self.eps)
+
+ @property
+ def lcm(self):
+ return self.n_lcm / (self.n + self.eps)
+
+ @property
+ def up(self):
+ return self.utp / (self.pred + self.eps)
+
+ @property
+ def ur(self):
+ return self.utp / (self.gold + self.eps)
+
+ @property
+ def uf(self):
+ return 2 * self.utp / (self.pred + self.gold + self.eps)
+
+ @property
+ def lp(self):
+ return self.ltp / (self.pred + self.eps)
+
+ @property
+ def lr(self):
+ return self.ltp / (self.gold + self.eps)
+
+ @property
+ def lf(self):
+ return 2 * self.ltp / (self.pred + self.gold + self.eps)
diff --git a/hanlp/metrics/srl/__init__.py b/hanlp/metrics/srl/__init__.py
new file mode 100644
index 000000000..0a1413ca7
--- /dev/null
+++ b/hanlp/metrics/srl/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-07-16 18:44
\ No newline at end of file
diff --git a/hanlp/metrics/srl/srlconll.py b/hanlp/metrics/srl/srlconll.py
new file mode 100644
index 000000000..0e26a19f8
--- /dev/null
+++ b/hanlp/metrics/srl/srlconll.py
@@ -0,0 +1,39 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-07-16 18:44
+import os
+
+from hanlp.utils.io_util import get_resource, get_exitcode_stdout_stderr, run_cmd
+
+
+def official_conll_05_evaluate(pred_path, gold_path):
+ script_root = get_resource('http://www.lsi.upc.edu/~srlconll/srlconll-1.1.tgz')
+ lib_path = f'{script_root}/lib'
+ if lib_path not in os.environ.get("PERL5LIB", ""):
+ os.environ['PERL5LIB'] = f'{lib_path}:{os.environ.get("PERL5LIB", "")}'
+ bin_path = f'{script_root}/bin'
+ if bin_path not in os.environ.get('PATH', ''):
+ os.environ['PATH'] = f'{bin_path}:{os.environ.get("PATH", "")}'
+ eval_info_gold_pred = run_cmd(f'perl {script_root}/bin/srl-eval.pl {gold_path} {pred_path}')
+ eval_info_pred_gold = run_cmd(f'perl {script_root}/bin/srl-eval.pl {pred_path} {gold_path}')
+ conll_recall = float(eval_info_gold_pred.strip().split("\n")[6].strip().split()[5]) / 100
+ conll_precision = float(eval_info_pred_gold.strip().split("\n")[6].strip().split()[5]) / 100
+ if conll_recall + conll_precision > 0:
+ conll_f1 = 2 * conll_recall * conll_precision / (conll_recall + conll_precision)
+ else:
+ conll_f1 = 0
+ return conll_precision, conll_recall, conll_f1
+
+
+def run_perl(script, src, dst=None):
+ os.environ['PERL5LIB'] = f''
+ exitcode, out, err = get_exitcode_stdout_stderr(
+ f'perl -I{os.path.expanduser("~/.local/lib/perl5")} {script} {src}')
+ if exitcode:
+ # cpanm -l ~/.local namespace::autoclean
+ # cpanm -l ~/.local Moose
+ # cpanm -l ~/.local MooseX::SemiAffordanceAccessor module
+ raise RuntimeError(err)
+ with open(dst, 'w') as ofile:
+ ofile.write(out)
+ return dst
diff --git a/hanlp/optimizers/adamw/__init__.py b/hanlp/optimizers/adamw/__init__.py
index 8070f5438..84b477d00 100644
--- a/hanlp/optimizers/adamw/__init__.py
+++ b/hanlp/optimizers/adamw/__init__.py
@@ -33,7 +33,19 @@
def create_optimizer(init_lr, num_train_steps, num_warmup_steps, weight_decay_rate=0.01, epsilon=1e-6, clipnorm=None):
- """Creates an optimizer with learning rate schedule."""
+ """Creates an optimizer with learning rate schedule.
+
+ Args:
+ init_lr:
+ num_train_steps:
+ num_warmup_steps:
+ weight_decay_rate: (Default value = 0.01)
+ epsilon: (Default value = 1e-6)
+ clipnorm: (Default value = None)
+
+ Returns:
+
+ """
# Implements linear decay of the learning rate.
learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
initial_learning_rate=init_lr,
diff --git a/hanlp/optimizers/adamw/optimization.py b/hanlp/optimizers/adamw/optimization.py
index 53eb61479..acbb90196 100644
--- a/hanlp/optimizers/adamw/optimization.py
+++ b/hanlp/optimizers/adamw/optimization.py
@@ -24,174 +24,211 @@
class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
- """Applys a warmup schedule on a given learning rate decay schedule."""
-
- def __init__(
- self,
- initial_learning_rate,
- decay_schedule_fn,
- warmup_steps,
- power=1.0,
- name=None):
- super(WarmUp, self).__init__()
- self.initial_learning_rate = initial_learning_rate
- self.warmup_steps = warmup_steps
- self.power = power
- self.decay_schedule_fn = decay_schedule_fn
- self.name = name
-
- def __call__(self, step):
- with tf.name_scope(self.name or 'WarmUp') as name:
- # Implements polynomial warmup. i.e., if global_step < warmup_steps, the
- # learning rate will be `global_step/num_warmup_steps * init_lr`.
- global_step_float = tf.cast(step, tf.float32)
- warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
- warmup_percent_done = global_step_float / warmup_steps_float
- warmup_learning_rate = (
- self.initial_learning_rate *
- tf.math.pow(warmup_percent_done, self.power))
- return tf.cond(global_step_float < warmup_steps_float,
- lambda: warmup_learning_rate,
- lambda: self.decay_schedule_fn(step),
- name=name)
-
- def get_config(self):
- return {
- 'initial_learning_rate': self.initial_learning_rate,
- 'decay_schedule_fn': self.decay_schedule_fn,
- 'warmup_steps': self.warmup_steps,
- 'power': self.power,
- 'name': self.name
- }
+ """Applys a warmup schedule on a given learning rate decay schedule."""
+
+ def __init__(
+ self,
+ initial_learning_rate,
+ decay_schedule_fn,
+ warmup_steps,
+ power=1.0,
+ name=None):
+ super(WarmUp, self).__init__()
+ self.initial_learning_rate = initial_learning_rate
+ self.warmup_steps = warmup_steps
+ self.power = power
+ self.decay_schedule_fn = decay_schedule_fn
+ self.name = name
+
+ def __call__(self, step):
+ with tf.name_scope(self.name or 'WarmUp') as name:
+ # Implements polynomial warmup. i.e., if global_step < warmup_steps, the
+ # learning rate will be `global_step/num_warmup_steps * init_lr`.
+ global_step_float = tf.cast(step, tf.float32)
+ warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
+ warmup_percent_done = global_step_float / warmup_steps_float
+ warmup_learning_rate = (
+ self.initial_learning_rate *
+ tf.math.pow(warmup_percent_done, self.power))
+ return tf.cond(global_step_float < warmup_steps_float,
+ lambda: warmup_learning_rate,
+ lambda: self.decay_schedule_fn(step),
+ name=name)
+
+ def get_config(self):
+ return {
+ 'initial_learning_rate': self.initial_learning_rate,
+ 'decay_schedule_fn': self.decay_schedule_fn,
+ 'warmup_steps': self.warmup_steps,
+ 'power': self.power,
+ 'name': self.name
+ }
def create_optimizer(init_lr, num_train_steps, num_warmup_steps):
- """Creates an optimizer with learning rate schedule."""
- # Implements linear decay of the learning rate.
- learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
- initial_learning_rate=init_lr,
- decay_steps=num_train_steps,
- end_learning_rate=0.0)
- if num_warmup_steps:
- learning_rate_fn = WarmUp(initial_learning_rate=init_lr,
- decay_schedule_fn=learning_rate_fn,
- warmup_steps=num_warmup_steps)
- optimizer = AdamWeightDecay(
- learning_rate=learning_rate_fn,
- weight_decay_rate=0.01,
- beta_1=0.9,
- beta_2=0.999,
- epsilon=1e-6,
- exclude_from_weight_decay=['layer_norm', 'bias'])
- return optimizer
+ """Creates an optimizer with learning rate schedule.
+
+ Args:
+ init_lr:
+ num_train_steps:
+ num_warmup_steps:
+
+ Returns:
+
+ """
+ # Implements linear decay of the learning rate.
+ learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
+ initial_learning_rate=init_lr,
+ decay_steps=num_train_steps,
+ end_learning_rate=0.0)
+ if num_warmup_steps:
+ learning_rate_fn = WarmUp(initial_learning_rate=init_lr,
+ decay_schedule_fn=learning_rate_fn,
+ warmup_steps=num_warmup_steps)
+ optimizer = AdamWeightDecay(
+ learning_rate=learning_rate_fn,
+ weight_decay_rate=0.01,
+ beta_1=0.9,
+ beta_2=0.999,
+ epsilon=1e-6,
+ exclude_from_weight_decay=['layer_norm', 'bias'])
+ return optimizer
class AdamWeightDecay(tf.keras.optimizers.Adam):
- """Adam enables L2 weight decay and clip_by_global_norm on gradients.
-
- Just adding the square of the weights to the loss function is *not* the
- correct way of using L2 regularization/weight decay with Adam, since that will
- interact with the m and v parameters in strange ways.
-
- Instead we want ot decay the weights in a manner that doesn't interact with
- the m/v parameters. This is equivalent to adding the square of the weights to
- the loss with plain (non-momentum) SGD.
- """
-
- def __init__(self,
- learning_rate=0.001,
- beta_1=0.9,
- beta_2=0.999,
- epsilon=1e-7,
- amsgrad=False,
- weight_decay_rate=0.0,
- include_in_weight_decay=None,
- exclude_from_weight_decay=None,
- name='AdamWeightDecay',
- **kwargs):
- super(AdamWeightDecay, self).__init__(
- learning_rate, beta_1, beta_2, epsilon, amsgrad, name, **kwargs)
- self.weight_decay_rate = weight_decay_rate
- self._include_in_weight_decay = include_in_weight_decay
- self._exclude_from_weight_decay = exclude_from_weight_decay
-
- @classmethod
- def from_config(cls, config):
- """Creates an optimizer from its config with WarmUp custom object."""
- custom_objects = {'WarmUp': WarmUp}
- return super(AdamWeightDecay, cls).from_config(
- config, custom_objects=custom_objects)
-
- def _prepare_local(self, var_device, var_dtype, apply_state):
- super(AdamWeightDecay, self)._prepare_local(var_device, var_dtype,
- apply_state)
- apply_state['weight_decay_rate'] = tf.constant(
- self.weight_decay_rate, name='adam_weight_decay_rate')
-
- def _decay_weights_op(self, var, learning_rate, apply_state):
- do_decay = self._do_use_weight_decay(var.name)
- if do_decay:
- return var.assign_sub(
- learning_rate * var *
- apply_state['weight_decay_rate'],
- use_locking=self._use_locking)
- return tf.no_op()
-
- def apply_gradients(self, grads_and_vars, name=None):
- grads, tvars = list(zip(*grads_and_vars))
- (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
- return super(AdamWeightDecay, self).apply_gradients(zip(grads, tvars))
-
- def _get_lr(self, var_device, var_dtype, apply_state):
- """Retrieves the learning rate with the given state."""
- if apply_state is None:
- return self._decayed_lr_t[var_dtype], {}
-
- apply_state = apply_state or {}
- coefficients = apply_state.get((var_device, var_dtype))
- if coefficients is None:
- coefficients = self._fallback_apply_state(var_device, var_dtype)
- apply_state[(var_device, var_dtype)] = coefficients
-
- return coefficients['lr_t'], dict(apply_state=apply_state)
-
- def _resource_apply_dense(self, grad, var, apply_state=None):
- lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
- decay = self._decay_weights_op(var, lr_t, apply_state)
- with tf.control_dependencies([decay]):
- return super(AdamWeightDecay, self)._resource_apply_dense(
- grad, var, **kwargs)
-
- def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
- lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
- decay = self._decay_weights_op(var, lr_t, apply_state)
- with tf.control_dependencies([decay]):
- return super(AdamWeightDecay, self)._resource_apply_sparse(
- grad, var, indices, **kwargs)
-
- def get_config(self):
- config = super(AdamWeightDecay, self).get_config()
- config.update({
- 'weight_decay_rate': self.weight_decay_rate,
- })
- return config
-
- def _do_use_weight_decay(self, param_name):
- """Whether to use L2 weight decay for `param_name`."""
- if self.weight_decay_rate == 0:
- return False
-
- if self._include_in_weight_decay:
- for r in self._include_in_weight_decay:
- if re.search(r, param_name) is not None:
- return True
-
- if self._exclude_from_weight_decay:
- for r in self._exclude_from_weight_decay:
- if re.search(r, param_name) is not None:
- return False
- return True
-
- def apply_gradients(self, grads_and_vars, name=None, **kwargs):
- grads, tvars = list(zip(*grads_and_vars))
- return super(AdamWeightDecay, self).apply_gradients(zip(grads, tvars), name=name, **kwargs)
+ """Adam enables L2 weight decay and clip_by_global_norm on gradients.
+
+ Just adding the square of the weights to the loss function is *not* the
+ correct way of using L2 regularization/weight decay with Adam, since that will
+ interact with the m and v parameters in strange ways.
+
+ Instead we want ot decay the weights in a manner that doesn't interact with
+ the m/v parameters. This is equivalent to adding the square of the weights to
+ the loss with plain (non-momentum) SGD.
+
+ Args:
+
+ Returns:
+
+ """
+
+ def __init__(self,
+ learning_rate=0.001,
+ beta_1=0.9,
+ beta_2=0.999,
+ epsilon=1e-7,
+ amsgrad=False,
+ weight_decay_rate=0.0,
+ include_in_weight_decay=None,
+ exclude_from_weight_decay=None,
+ name='AdamWeightDecay',
+ **kwargs):
+ super(AdamWeightDecay, self).__init__(
+ learning_rate, beta_1, beta_2, epsilon, amsgrad, name, **kwargs)
+ self.weight_decay_rate = weight_decay_rate
+ self._include_in_weight_decay = include_in_weight_decay
+ self._exclude_from_weight_decay = exclude_from_weight_decay
+
+ @classmethod
+ def from_config(cls, config):
+ """Creates an optimizer from its config with WarmUp custom object.
+
+ Args:
+ config:
+
+ Returns:
+
+ """
+ custom_objects = {'WarmUp': WarmUp}
+ return super(AdamWeightDecay, cls).from_config(
+ config, custom_objects=custom_objects)
+
+ def _prepare_local(self, var_device, var_dtype, apply_state):
+ super(AdamWeightDecay, self)._prepare_local(var_device, var_dtype,
+ apply_state)
+ apply_state['weight_decay_rate'] = tf.constant(
+ self.weight_decay_rate, name='adam_weight_decay_rate')
+
+ def _decay_weights_op(self, var, learning_rate, apply_state):
+ do_decay = self._do_use_weight_decay(var.name)
+ if do_decay:
+ return var.assign_sub(
+ learning_rate * var *
+ apply_state['weight_decay_rate'],
+ use_locking=self._use_locking)
+ return tf.no_op()
+
+ def apply_gradients(self, grads_and_vars, name=None):
+ grads, tvars = list(zip(*grads_and_vars))
+ (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
+ return super(AdamWeightDecay, self).apply_gradients(zip(grads, tvars))
+
+ def _get_lr(self, var_device, var_dtype, apply_state):
+ """Retrieves the learning rate with the given state.
+
+ Args:
+ var_device:
+ var_dtype:
+ apply_state:
+
+ Returns:
+
+ """
+ if apply_state is None:
+ return self._decayed_lr_t[var_dtype], {}
+
+ apply_state = apply_state or {}
+ coefficients = apply_state.get((var_device, var_dtype))
+ if coefficients is None:
+ coefficients = self._fallback_apply_state(var_device, var_dtype)
+ apply_state[(var_device, var_dtype)] = coefficients
+
+ return coefficients['lr_t'], dict(apply_state=apply_state)
+
+ def _resource_apply_dense(self, grad, var, apply_state=None):
+ lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
+ decay = self._decay_weights_op(var, lr_t, apply_state)
+ with tf.control_dependencies([decay]):
+ return super(AdamWeightDecay, self)._resource_apply_dense(
+ grad, var, **kwargs)
+
+ def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
+ lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
+ decay = self._decay_weights_op(var, lr_t, apply_state)
+ with tf.control_dependencies([decay]):
+ return super(AdamWeightDecay, self)._resource_apply_sparse(
+ grad, var, indices, **kwargs)
+
+ def get_config(self):
+ config = super(AdamWeightDecay, self).get_config()
+ config.update({
+ 'weight_decay_rate': self.weight_decay_rate,
+ })
+ return config
+
+ def _do_use_weight_decay(self, param_name):
+ """Whether to use L2 weight decay for `param_name`.
+
+ Args:
+ param_name:
+
+ Returns:
+
+ """
+ if self.weight_decay_rate == 0:
+ return False
+
+ if self._include_in_weight_decay:
+ for r in self._include_in_weight_decay:
+ if re.search(r, param_name) is not None:
+ return True
+
+ if self._exclude_from_weight_decay:
+ for r in self._exclude_from_weight_decay:
+ if re.search(r, param_name) is not None:
+ return False
+ return True
+
+ def apply_gradients(self, grads_and_vars, name=None, **kwargs):
+ grads, tvars = list(zip(*grads_and_vars))
+ return super(AdamWeightDecay, self).apply_gradients(zip(grads, tvars), name=name, **kwargs)
diff --git a/hanlp/pretrained/__init__.py b/hanlp/pretrained/__init__.py
index f11cd2943..c58eeab4b 100644
--- a/hanlp/pretrained/__init__.py
+++ b/hanlp/pretrained/__init__.py
@@ -1,7 +1,7 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-12-28 19:10
-from hanlp.pretrained import cws
+from hanlp.pretrained import tok
from hanlp.pretrained import dep
from hanlp.pretrained import sdp
from hanlp.pretrained import glove
@@ -10,6 +10,9 @@
from hanlp.pretrained import word2vec
from hanlp.pretrained import ner
from hanlp.pretrained import classifiers
+from hanlp.pretrained import fasttext
+from hanlp.pretrained import mtl
+from hanlp.pretrained import eos
# Will be filled up during runtime
ALL = {}
diff --git a/hanlp/pretrained/classifiers.py b/hanlp/pretrained/classifiers.py
index 066fcc3c0..bdc7a7191 100644
--- a/hanlp/pretrained/classifiers.py
+++ b/hanlp/pretrained/classifiers.py
@@ -1,7 +1,7 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2020-01-01 03:51
-from hanlp.common.constant import HANLP_URL
+from hanlp_common.constant import HANLP_URL
CHNSENTICORP_BERT_BASE_ZH = HANLP_URL + 'classification/chnsenticorp_bert_base_20200104_164655.zip'
SST2_BERT_BASE_EN = HANLP_URL + 'classification/sst2_bert_base_uncased_en_20200210_090240.zip'
diff --git a/hanlp/pretrained/cws.py b/hanlp/pretrained/cws.py
deleted file mode 100644
index 5eb4fc9a8..000000000
--- a/hanlp/pretrained/cws.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# -*- coding:utf-8 -*-
-# Author: hankcs
-# Date: 2019-12-28 21:12
-from hanlp.common.constant import HANLP_URL
-
-SIGHAN2005_PKU_CONVSEG = HANLP_URL + 'cws/sighan2005-pku-convseg_20200110_153722.zip'
-SIGHAN2005_MSR_CONVSEG = HANLP_URL + 'cws/convseg-msr-nocrf-noembed_20200110_153524.zip'
-# SIGHAN2005_MSR_BERT_BASE = HANLP_URL + 'cws/cws_bert_base_msra_20191230_194627.zip'
-CTB6_CONVSEG = HANLP_URL + 'cws/ctb6_convseg_nowe_nocrf_20200110_004046.zip'
-# CTB6_BERT_BASE = HANLP_URL + 'cws/cws_bert_base_ctb6_20191230_185536.zip'
-PKU_NAME_MERGED_SIX_MONTHS_CONVSEG = HANLP_URL + 'cws/pku98_6m_conv_ngram_20200110_134736.zip'
-LARGE_ALBERT_BASE = HANLP_URL + 'cws/large_cws_albert_base_20200828_011451.zip'
-# Will be filled up during runtime
-ALL = {}
diff --git a/hanlp/pretrained/dep.py b/hanlp/pretrained/dep.py
index 82e802663..778772a2a 100644
--- a/hanlp/pretrained/dep.py
+++ b/hanlp/pretrained/dep.py
@@ -1,11 +1,14 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-12-29 02:55
-from hanlp.common.constant import HANLP_URL
+from hanlp_common.constant import HANLP_URL
CTB5_BIAFFINE_DEP_ZH = HANLP_URL + 'dep/biaffine_ctb5_20191229_025833.zip'
+'Biaffine LSTM model (:cite:`dozat:17a`) trained on CTB5.'
CTB7_BIAFFINE_DEP_ZH = HANLP_URL + 'dep/biaffine_ctb7_20200109_022431.zip'
+'Biaffine LSTM model (:cite:`dozat:17a`) trained on CTB7.'
PTB_BIAFFINE_DEP_EN = HANLP_URL + 'dep/ptb_dep_biaffine_20200101_174624.zip'
+'Biaffine LSTM model (:cite:`dozat:17a`) trained on PTB.'
-ALL = {}
\ No newline at end of file
+ALL = {}
diff --git a/hanlp/pretrained/eos.py b/hanlp/pretrained/eos.py
new file mode 100644
index 000000000..d188e96c8
--- /dev/null
+++ b/hanlp/pretrained/eos.py
@@ -0,0 +1,10 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-22 13:22
+from hanlp_common.constant import HANLP_URL
+
+UD_CTB_EOS_MUL = HANLP_URL + 'eos/eos_ud_ctb_mul_20201222_133543.zip'
+'EOS model (:cite:`Schweter:Ahmed:2019`) trained on concatenated UD2.3 and CTB9.'
+
+# Will be filled up during runtime
+ALL = {}
diff --git a/hanlp/pretrained/fasttext.py b/hanlp/pretrained/fasttext.py
index 70eb5e9ca..34b9077a7 100644
--- a/hanlp/pretrained/fasttext.py
+++ b/hanlp/pretrained/fasttext.py
@@ -2,11 +2,15 @@
# Author: hankcs
# Date: 2019-12-30 18:57
FASTTEXT_DEBUG_EMBEDDING_EN = 'https://elit-models.s3-us-west-2.amazonaws.com/fasttext.debug.bin.zip'
-FASTTEXT_CC_300_EN = 'https://elit-models.s3-us-west-2.amazonaws.com/cc.en.300.bin.zip'
+FASTTEXT_CC_300_EN = 'https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.bin.gz'
+'FastText (:cite:`bojanowski2017enriching`) embeddings trained on Common Crawl.'
FASTTEXT_WIKI_NYT_AMAZON_FRIENDS_200_EN \
= 'https://elit-models.s3-us-west-2.amazonaws.com/fasttext-200-wikipedia-nytimes-amazon-friends-20191107.bin'
+'FastText (:cite:`bojanowski2017enriching`) embeddings trained on wikipedia, nytimes and friends.'
FASTTEXT_WIKI_300_ZH = 'https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.zh.zip#wiki.zh.bin'
+'FastText (:cite:`bojanowski2017enriching`) embeddings trained on Chinese Wikipedia.'
FASTTEXT_WIKI_300_ZH_CLASSICAL = 'https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.zh_classical.zip#wiki.zh_classical.bin'
+'FastText (:cite:`bojanowski2017enriching`) embeddings trained on traditional Chinese wikipedia.'
-ALL = {}
\ No newline at end of file
+ALL = {}
diff --git a/hanlp/pretrained/glove.py b/hanlp/pretrained/glove.py
index 495fe0779..c5d1f3b62 100644
--- a/hanlp/pretrained/glove.py
+++ b/hanlp/pretrained/glove.py
@@ -2,13 +2,18 @@
# Author: hankcs
# Date: 2019-08-27 20:42
-GLOVE_6B_ROOT = 'http://downloads.cs.stanford.edu/nlp/data/glove.6B.zip'
+_GLOVE_6B_ROOT = 'http://downloads.cs.stanford.edu/nlp/data/glove.6B.zip'
-GLOVE_6B_50D = GLOVE_6B_ROOT + '#' + 'glove.6B.50d.txt'
-GLOVE_6B_100D = GLOVE_6B_ROOT + '#' + 'glove.6B.100d.txt'
-GLOVE_6B_200D = GLOVE_6B_ROOT + '#' + 'glove.6B.200d.txt'
-GLOVE_6B_300D = GLOVE_6B_ROOT + '#' + 'glove.6B.300d.txt'
+GLOVE_6B_50D = _GLOVE_6B_ROOT + '#' + 'glove.6B.50d.txt'
+'Global Vectors for Word Representation (:cite:`pennington-etal-2014-glove`) 50d trained on 6B tokens.'
+GLOVE_6B_100D = _GLOVE_6B_ROOT + '#' + 'glove.6B.100d.txt'
+'Global Vectors for Word Representation (:cite:`pennington-etal-2014-glove`) 100d trained on 6B tokens.'
+GLOVE_6B_200D = _GLOVE_6B_ROOT + '#' + 'glove.6B.200d.txt'
+'Global Vectors for Word Representation (:cite:`pennington-etal-2014-glove`) 200d trained on 6B tokens.'
+GLOVE_6B_300D = _GLOVE_6B_ROOT + '#' + 'glove.6B.300d.txt'
+'Global Vectors for Word Representation (:cite:`pennington-etal-2014-glove`) 300d trained on 6B tokens.'
GLOVE_840B_300D = 'http://nlp.stanford.edu/data/glove.840B.300d.zip'
+'Global Vectors for Word Representation (:cite:`pennington-etal-2014-glove`) 300d trained on 840B tokens.'
-ALL = {}
\ No newline at end of file
+ALL = {}
diff --git a/hanlp/pretrained/mtl.py b/hanlp/pretrained/mtl.py
new file mode 100644
index 000000000..a1118c85c
--- /dev/null
+++ b/hanlp/pretrained/mtl.py
@@ -0,0 +1,22 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-22 13:16
+from hanlp_common.constant import HANLP_URL
+
+OPEN_TOK_POS_NER_SRL_DEP_SDP_CON_ELECTRA_SMALL_ZH = HANLP_URL + 'mtl/open_tok_pos_ner_srl_dep_sdp_con_electra_small_20201223_035557.zip'
+"Electra small version of joint tok, pos, ner, srl, dep, sdp and con model trained on open-source Chinese corpus."
+OPEN_TOK_POS_NER_SRL_DEP_SDP_CON_ELECTRA_BASE_ZH = HANLP_URL + 'mtl/open_tok_pos_ner_srl_dep_sdp_con_electra_base_20201223_201906.zip'
+"Electra base version of joint tok, pos, ner, srl, dep, sdp and con model trained on open-source Chinese corpus."
+CLOSE_TOK_POS_NER_SRL_DEP_SDP_CON_ELECTRA_SMALL_ZH = HANLP_URL + 'mtl/close_tok_pos_ner_srl_dep_sdp_con_electra_small_zh_20201222_130611.zip'
+"Electra small version of joint tok, pos, ner, srl, dep, sdp and con model trained on private Chinese corpus."
+CLOSE_TOK_POS_NER_SRL_DEP_SDP_CON_ELECTRA_BASE_ZH = HANLP_URL + 'mtl/close_tok_pos_ner_srl_dep_sdp_con_electra_base_20201226_221208.zip'
+"Electra base version of joint tok, pos, ner, srl, dep, sdp and con model trained on private Chinese corpus."
+
+UD_ONTONOTES_TOK_POS_LEM_FEA_NER_SRL_DEP_SDP_CON_MT5_SMALL = HANLP_URL + 'mtl/ud_ontonotes_tok_pos_lem_fea_ner_srl_dep_sdp_con_mt5_small_20201231_211858.zip'
+'mt5 small version of joint tok, pos, lem, fea, ner, srl, dep, sdp and con model trained on UD and OntoNotes5 corpus.'
+
+UD_ONTONOTES_TOK_POS_LEM_FEA_NER_SRL_DEP_SDP_CON_MT5_BASE = HANLP_URL + 'mtl/ud_ontonotes_tok_pos_lem_fea_ner_srl_dep_sdp_con_xlm_base_20201231_232029.zip'
+'XLM-R base version of joint tok, pos, lem, fea, ner, srl, dep, sdp and con model trained on UD and OntoNotes5 corpus.'
+
+# Will be filled up during runtime
+ALL = {}
diff --git a/hanlp/pretrained/ner.py b/hanlp/pretrained/ner.py
index ebae4d9c8..ee9b99e42 100644
--- a/hanlp/pretrained/ner.py
+++ b/hanlp/pretrained/ner.py
@@ -1,10 +1,13 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-12-30 20:07
-from hanlp.common.constant import HANLP_URL
+from hanlp_common.constant import HANLP_URL
MSRA_NER_BERT_BASE_ZH = HANLP_URL + 'ner/ner_bert_base_msra_20200104_185735.zip'
+'BERT model (:cite:`devlin-etal-2019-bert`) trained on MSRA with 3 entity types.'
MSRA_NER_ALBERT_BASE_ZH = HANLP_URL + 'ner/ner_albert_base_zh_msra_20200111_202919.zip'
+'ALBERT model (:cite:`Lan2020ALBERT:`) trained on MSRA with 3 entity types.'
CONLL03_NER_BERT_BASE_UNCASED_EN = HANLP_URL + 'ner/ner_conll03_bert_base_uncased_en_20200104_194352.zip'
+'BERT model (:cite:`devlin-etal-2019-bert`) trained on CoNLL03.'
ALL = {}
diff --git a/hanlp/pretrained/pos.py b/hanlp/pretrained/pos.py
index 29bc6fcaf..69719c771 100644
--- a/hanlp/pretrained/pos.py
+++ b/hanlp/pretrained/pos.py
@@ -1,12 +1,16 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-12-29 01:57
-from hanlp.common.constant import HANLP_URL
+from hanlp_common.constant import HANLP_URL
CTB5_POS_RNN = HANLP_URL + 'pos/ctb5_pos_rnn_20200113_235925.zip'
+'An old school BiLSTM tagging model trained on CTB5.'
CTB5_POS_RNN_FASTTEXT_ZH = HANLP_URL + 'pos/ctb5_pos_rnn_fasttext_20191230_202639.zip'
+'An old school BiLSTM tagging model with FastText (:cite:`bojanowski2017enriching`) embeddings trained on CTB5.'
CTB9_POS_ALBERT_BASE = HANLP_URL + 'pos/ctb9_albert_base_zh_epoch_20_20201011_090522.zip'
+'ALBERT model (:cite:`Lan2020ALBERT:`) trained on CTB9.'
PTB_POS_RNN_FASTTEXT_EN = HANLP_URL + 'pos/ptb_pos_rnn_fasttext_20200103_145337.zip'
+'An old school BiLSTM tagging model with FastText (:cite:`bojanowski2017enriching`) embeddings trained on PTB.'
ALL = {}
\ No newline at end of file
diff --git a/hanlp/pretrained/rnnlm.py b/hanlp/pretrained/rnnlm.py
index b34e185ae..30d9bb4c3 100644
--- a/hanlp/pretrained/rnnlm.py
+++ b/hanlp/pretrained/rnnlm.py
@@ -1,9 +1,10 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-12-19 03:47
-from hanlp.common.constant import HANLP_URL
+from hanlp_common.constant import HANLP_URL
-FLAIR_LM_FW_WMT11_EN = HANLP_URL + 'lm/flair_lm_wmt11_en_20200211_091932.zip#flair_lm_fw_wmt11_en'
-FLAIR_LM_BW_WMT11_EN = HANLP_URL + 'lm/flair_lm_wmt11_en_20200211_091932.zip#flair_lm_bw_wmt11_en'
+FLAIR_LM_FW_WMT11_EN_TF = HANLP_URL + 'lm/flair_lm_wmt11_en_20200211_091932.zip#flair_lm_fw_wmt11_en'
+FLAIR_LM_BW_WMT11_EN_TF = HANLP_URL + 'lm/flair_lm_wmt11_en_20200211_091932.zip#flair_lm_bw_wmt11_en'
+FLAIR_LM_WMT11_EN = HANLP_URL + 'lm/flair_lm_wmt11_en_20200601_205350.zip'
-ALL = {}
\ No newline at end of file
+ALL = {}
diff --git a/hanlp/pretrained/sdp.py b/hanlp/pretrained/sdp.py
index abaf4f7de..9c498267f 100644
--- a/hanlp/pretrained/sdp.py
+++ b/hanlp/pretrained/sdp.py
@@ -1,13 +1,18 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-12-31 23:54
-from hanlp.common.constant import HANLP_URL
+from hanlp_common.constant import HANLP_URL
SEMEVAL16_NEWS_BIAFFINE_ZH = HANLP_URL + 'sdp/semeval16-news-biaffine_20191231_235407.zip'
+'Biaffine SDP (:cite:`bertbaseline`) trained on SemEval16 news data.'
SEMEVAL16_TEXT_BIAFFINE_ZH = HANLP_URL + 'sdp/semeval16-text-biaffine_20200101_002257.zip'
+'Biaffine SDP (:cite:`bertbaseline`) trained on SemEval16 text data.'
SEMEVAL15_PAS_BIAFFINE_EN = HANLP_URL + 'sdp/semeval15_biaffine_pas_20200103_152405.zip'
+'Biaffine SDP (:cite:`bertbaseline`) trained on SemEval15 PAS data.'
SEMEVAL15_PSD_BIAFFINE_EN = HANLP_URL + 'sdp/semeval15_biaffine_psd_20200106_123009.zip'
+'Biaffine SDP (:cite:`bertbaseline`) trained on SemEval15 PSD data.'
SEMEVAL15_DM_BIAFFINE_EN = HANLP_URL + 'sdp/semeval15_biaffine_dm_20200106_122808.zip'
+'Biaffine SDP (:cite:`bertbaseline`) trained on SemEval15 DM data.'
ALL = {}
diff --git a/hanlp/pretrained/tok.py b/hanlp/pretrained/tok.py
new file mode 100644
index 000000000..8cbaf5092
--- /dev/null
+++ b/hanlp/pretrained/tok.py
@@ -0,0 +1,22 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-12-28 21:12
+from hanlp_common.constant import HANLP_URL
+
+SIGHAN2005_PKU_CONVSEG = HANLP_URL + 'tok/sighan2005-pku-convseg_20200110_153722.zip'
+'Conv model (:cite:`wang-xu-2017-convolutional`) trained on sighan2005 pku dataset.'
+SIGHAN2005_MSR_CONVSEG = HANLP_URL + 'tok/convseg-msr-nocrf-noembed_20200110_153524.zip'
+'Conv model (:cite:`wang-xu-2017-convolutional`) trained on sighan2005 msr dataset.'
+# SIGHAN2005_MSR_BERT_BASE = HANLP_URL + 'tok/cws_bert_base_msra_20191230_194627.zip'
+CTB6_CONVSEG = HANLP_URL + 'tok/ctb6_convseg_nowe_nocrf_20200110_004046.zip'
+'Conv model (:cite:`wang-xu-2017-convolutional`) trained on CTB6 dataset.'
+# CTB6_BERT_BASE = HANLP_URL + 'tok/cws_bert_base_ctb6_20191230_185536.zip'
+PKU_NAME_MERGED_SIX_MONTHS_CONVSEG = HANLP_URL + 'tok/pku98_6m_conv_ngram_20200110_134736.zip'
+'Conv model (:cite:`wang-xu-2017-convolutional`) trained on pku98 six months dataset with name merged into one unit.'
+LARGE_ALBERT_BASE = HANLP_URL + 'tok/large_cws_albert_base_20200828_011451.zip'
+'ALBERT model (:cite:`Lan2020ALBERT:`) trained on the largest CWS dataset in the world.'
+SIGHAN2005_PKU_BERT_BASE_ZH = HANLP_URL + 'tok/sighan2005_pku_bert_base_zh_20201231_141130.zip'
+'BERT model (:cite:`devlin-etal-2019-bert`) trained on sighan2005 pku dataset.'
+
+# Will be filled up during runtime
+ALL = {}
diff --git a/hanlp/pretrained/word2vec.py b/hanlp/pretrained/word2vec.py
index ba89a8cc4..2ec217a41 100644
--- a/hanlp/pretrained/word2vec.py
+++ b/hanlp/pretrained/word2vec.py
@@ -1,7 +1,7 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-12-21 18:25
-from hanlp.common.constant import HANLP_URL
+from hanlp_common.constant import HANLP_URL
CONVSEG_W2V_NEWS_TENSITE = HANLP_URL + 'embeddings/convseg_embeddings.zip'
CONVSEG_W2V_NEWS_TENSITE_WORD_PKU = CONVSEG_W2V_NEWS_TENSITE + '#news_tensite.pku.words.w2v50'
@@ -17,5 +17,12 @@
TENCENT_AI_LAB_EMBEDDING = 'https://ai.tencent.com/ailab/nlp/data/Tencent_AILab_ChineseEmbedding.tar.gz#Tencent_AILab_ChineseEmbedding.txt'
RADICAL_CHAR_EMBEDDING_100 = HANLP_URL + 'embeddings/radical_char_vec_20191229_013849.zip#character.vec.txt'
+'Chinese character embedding enhanced with rich radical information (:cite:`he2018dual`).'
+
+_SUBWORD_ENCODING_CWS = HANLP_URL + 'embeddings/subword_encoding_cws_20200524_190636.zip'
+SUBWORD_ENCODING_CWS_ZH_WIKI_BPE_50 = _SUBWORD_ENCODING_CWS + '#zh.wiki.bpe.vs200000.d50.w2v.txt'
+SUBWORD_ENCODING_CWS_GIGAWORD_UNI = _SUBWORD_ENCODING_CWS + '#gigaword_chn.all.a2b.uni.ite50.vec'
+SUBWORD_ENCODING_CWS_GIGAWORD_BI = _SUBWORD_ENCODING_CWS + '#gigaword_chn.all.a2b.bi.ite50.vec'
+SUBWORD_ENCODING_CWS_CTB_GAZETTEER_50 = _SUBWORD_ENCODING_CWS + '#ctb.50d.vec'
ALL = {}
diff --git a/hanlp/transform/conll_tf.py b/hanlp/transform/conll_tf.py
new file mode 100644
index 000000000..4c7366370
--- /dev/null
+++ b/hanlp/transform/conll_tf.py
@@ -0,0 +1,799 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-05-08 15:30
+from abc import abstractmethod
+from collections import Counter
+from typing import Union, Tuple, Iterable, Any, Generator
+
+import numpy as np
+import tensorflow as tf
+from transformers import PreTrainedTokenizer, PretrainedConfig
+
+from hanlp_common.constant import ROOT
+from hanlp_common.structure import SerializableDict
+from hanlp.common.transform_tf import Transform
+from hanlp.common.vocab_tf import VocabTF
+from hanlp.components.parsers.alg_tf import tolist, kmeans, randperm, arange
+from hanlp.components.parsers.conll import read_conll
+from hanlp_common.conll import CoNLLWord, CoNLLUWord, CoNLLSentence
+from hanlp.layers.transformers.utils_tf import config_is, adjust_tokens_for_transformers, convert_examples_to_features
+from hanlp.utils.log_util import logger
+from hanlp.utils.string_util import ispunct
+from hanlp_common.util import merge_locals_kwargs
+
+
+class CoNLLTransform(Transform):
+
+ def __init__(self, config: SerializableDict = None, map_x=True, map_y=True, lower=True, n_buckets=32, min_freq=2,
+ use_pos=True, **kwargs) -> None:
+ super().__init__(**merge_locals_kwargs(locals(), kwargs))
+ self.form_vocab: VocabTF = None
+ if use_pos:
+ self.cpos_vocab: VocabTF = None
+ self.rel_vocab: VocabTF = None
+ self.puncts: tf.Tensor = None
+
+ @property
+ def use_pos(self):
+ return self.config.get('use_pos', True)
+
+ def x_to_idx(self, x) -> Union[tf.Tensor, Tuple]:
+ form, cpos = x
+ return self.form_vocab.token_to_idx_table.lookup(form), self.cpos_vocab.token_to_idx_table.lookup(cpos)
+
+ def y_to_idx(self, y):
+ head, rel = y
+ return head, self.rel_vocab.token_to_idx_table.lookup(rel)
+
+ def X_to_inputs(self, X: Union[tf.Tensor, Tuple[tf.Tensor]]) -> Iterable:
+ if len(X) == 2:
+ form_batch, cposes_batch = X
+ mask = tf.not_equal(form_batch, 0)
+ elif len(X) == 3:
+ form_batch, cposes_batch, mask = X
+ else:
+ raise ValueError(f'Expect X to be 2 or 3 elements but got {repr(X)}')
+ sents = []
+
+ for form_sent, cposes_sent, length in zip(form_batch, cposes_batch,
+ tf.math.count_nonzero(mask, axis=-1)):
+ forms = tolist(form_sent)[1:length + 1]
+ cposes = tolist(cposes_sent)[1:length + 1]
+ sents.append([(self.form_vocab.idx_to_token[f],
+ self.cpos_vocab.idx_to_token[c]) for f, c in zip(forms, cposes)])
+
+ return sents
+
+ def lock_vocabs(self):
+ super().lock_vocabs()
+ self.puncts = tf.constant([i for s, i in self.form_vocab.token_to_idx.items()
+ if ispunct(s)], dtype=tf.int64)
+
+ def file_to_inputs(self, filepath: str, gold=True):
+ assert gold, 'only support gold file for now'
+ use_pos = self.use_pos
+ conllu = filepath.endswith('.conllu')
+ for sent in read_conll(filepath):
+ for i, cell in enumerate(sent):
+ form = cell[1]
+ cpos = cell[3]
+ head = cell[6]
+ deprel = cell[7]
+ # if conllu:
+ # deps = cell[8]
+ # deps = [x.split(':', 1) for x in deps.split('|')]
+ # heads = [int(x[0]) for x in deps if '_' not in x[0] and '.' not in x[0]]
+ # rels = [x[1] for x in deps if '_' not in x[0] and '.' not in x[0]]
+ # if head in heads:
+ # offset = heads.index(head)
+ # if not self.rel_vocab or rels[offset] in self.rel_vocab:
+ # deprel = rels[offset]
+ sent[i] = [form, cpos, head, deprel] if use_pos else [form, head, deprel]
+ yield sent
+
+ @property
+ def bos(self):
+ if self.form_vocab.idx_to_token is None:
+ return ROOT
+ return self.form_vocab.idx_to_token[2]
+
+ def input_is_single_sample(self, input: Any) -> bool:
+ if self.use_pos:
+ return isinstance(input[0][0], str) if len(input[0]) else False
+ else:
+ return isinstance(input[0], str) if len(input[0]) else False
+
+ @abstractmethod
+ def batched_inputs_to_batches(self, corpus, indices, shuffle):
+ pass
+
+ def len_of_sent(self, sent):
+ return 1 + len(sent) # take ROOT into account
+
+ def samples_to_dataset(self, samples: Generator, map_x=None, map_y=None, batch_size=5000, shuffle=None, repeat=None,
+ drop_remainder=False, prefetch=1, cache=True) -> tf.data.Dataset:
+ def generator():
+ # custom bucketing, load corpus into memory
+ corpus = list(x for x in (samples() if callable(samples) else samples))
+ lengths = [self.len_of_sent(i) for i in corpus]
+ if len(corpus) < 32:
+ n_buckets = 1
+ else:
+ n_buckets = min(self.config.n_buckets, len(corpus))
+ buckets = dict(zip(*kmeans(lengths, n_buckets)))
+ # buckets = dict(zip(*kmeans(lengths, n_buckets, 233)))
+ sizes, buckets = zip(*[
+ (size, bucket) for size, bucket in buckets.items()
+ ])
+ # the number of chunks in each bucket, which is clipped by
+ # range [1, len(bucket)]. Thus how many batches of batch_size in each bucket
+ chunks = [min(len(bucket), max(round(size * len(bucket) / batch_size), 1)) for size, bucket in
+ zip(sizes, buckets)]
+ range_fn = randperm if shuffle else arange
+ max_samples_per_batch = self.config.get('max_samples_per_batch', None)
+ # range_fn = arange
+ for i in tolist(range_fn(len(buckets))):
+ split_sizes = [(len(buckets[i]) - j - 1) // chunks[i] + 1
+ for j in range(chunks[i])] # how many sentences in each batch
+ for batch_indices in tf.split(range_fn(len(buckets[i])), split_sizes):
+ indices = [buckets[i][j] for j in tolist(batch_indices)]
+ if max_samples_per_batch:
+ for j in range(0, len(indices), max_samples_per_batch):
+ yield from self.batched_inputs_to_batches(corpus, indices[j:j + max_samples_per_batch],
+ shuffle)
+ else:
+ yield from self.batched_inputs_to_batches(corpus, indices, shuffle)
+
+ # debug for CoNLLTransform
+ # next(generator())
+ return super().samples_to_dataset(generator, False, False, 0, False, repeat, drop_remainder, prefetch,
+ cache)
+
+
+class CoNLL_DEP_Transform(CoNLLTransform):
+
+ def __init__(self, config: SerializableDict = None, map_x=True, map_y=True, lower=True, n_buckets=32,
+ min_freq=2, **kwargs) -> None:
+ super().__init__(config, map_x, map_y, lower, n_buckets, min_freq, **kwargs)
+
+ def batched_inputs_to_batches(self, corpus, indices, shuffle):
+ """Convert batched inputs to batches of samples
+
+ Args:
+ corpus(list): A list of inputs
+ indices(list): A list of indices, each list belongs to a batch
+ shuffle:
+
+ Returns:
+
+
+ """
+ raw_batch = [[], [], [], []]
+ for idx in indices:
+ for b in raw_batch:
+ b.append([])
+ for cells in corpus[idx]:
+ for b, c, v in zip(raw_batch, cells,
+ [self.form_vocab, self.cpos_vocab, None, self.rel_vocab]):
+ b[-1].append(v.get_idx_without_add(c) if v else c)
+ batch = []
+ for b, v in zip(raw_batch, [self.form_vocab, self.cpos_vocab, None, self.rel_vocab]):
+ b = tf.keras.preprocessing.sequence.pad_sequences(b, padding='post',
+ value=v.safe_pad_token_idx if v else 0,
+ dtype='int64')
+ batch.append(b)
+ assert len(batch) == 4
+ yield (batch[0], batch[1]), (batch[2], batch[3])
+
+ def create_types_shapes_values(self) -> Tuple[Tuple, Tuple, Tuple]:
+ types = (tf.int64, tf.int64), (tf.int64, tf.int64)
+ shapes = ([None, None], [None, None]), ([None, None], [None, None])
+ values = (self.form_vocab.safe_pad_token_idx, self.cpos_vocab.safe_pad_token_idx), (
+ 0, self.rel_vocab.safe_pad_token_idx)
+ return types, shapes, values
+
+ def inputs_to_samples(self, inputs, gold=False):
+ token_mapping: dict = self.config.get('token_mapping', None)
+ use_pos = self.config.get('use_pos', True)
+ for sent in inputs:
+ sample = []
+ for i, cell in enumerate(sent):
+ if isinstance(cell, tuple):
+ cell = list(cell)
+ elif isinstance(cell, str):
+ cell = [cell]
+ if token_mapping:
+ cell[0] = token_mapping.get(cell[0], cell[0])
+ if self.config['lower']:
+ cell[0] = cell[0].lower()
+ if not gold:
+ cell += [0, self.rel_vocab.safe_pad_token]
+ sample.append(cell)
+ # insert root word with arbitrary fields, anyway it will be masked
+ # form, cpos, head, deprel = sample[0]
+ sample.insert(0, [self.bos, self.bos, 0, self.bos] if use_pos else [self.bos, 0, self.bos])
+ yield sample
+
+ def XY_to_inputs_outputs(self, X: Union[tf.Tensor, Tuple[tf.Tensor]], Y: Union[tf.Tensor, Tuple[tf.Tensor]],
+ gold=False, inputs=None, conll=True, arc_scores=None, rel_scores=None) -> Iterable:
+ (words, feats, mask), (arc_preds, rel_preds) = X, Y
+ if inputs is None:
+ inputs = self.X_to_inputs(X)
+ ys = self.Y_to_outputs((arc_preds, rel_preds, mask), inputs=inputs)
+ sents = []
+ for x, y in zip(inputs, ys):
+ sent = CoNLLSentence()
+ for idx, (cell, (head, deprel)) in enumerate(zip(x, y)):
+ if self.use_pos and not self.config.get('joint_pos', None):
+ form, cpos = cell
+ else:
+ form, cpos = cell, None
+ if conll:
+ sent.append(
+ CoNLLWord(id=idx + 1, form=form, cpos=cpos, head=head, deprel=deprel) if conll == '.conll'
+ else CoNLLUWord(id=idx + 1, form=form, upos=cpos, head=head, deprel=deprel))
+ else:
+ sent.append([head, deprel])
+ sents.append(sent)
+ return sents
+
+ def fit(self, trn_path: str, **kwargs) -> int:
+ use_pos = self.config.use_pos
+ self.form_vocab = VocabTF()
+ self.form_vocab.add(ROOT) # make root the 2ed elements while 0th is pad, 1st is unk
+ if self.use_pos:
+ self.cpos_vocab = VocabTF(pad_token=None, unk_token=None)
+ self.rel_vocab = VocabTF(pad_token=None, unk_token=None)
+ num_samples = 0
+ counter = Counter()
+ for sent in self.file_to_samples(trn_path, gold=True):
+ num_samples += 1
+ for idx, cell in enumerate(sent):
+ if use_pos:
+ form, cpos, head, deprel = cell
+ else:
+ form, head, deprel = cell
+ if idx == 0:
+ root = form
+ else:
+ counter[form] += 1
+ if use_pos:
+ self.cpos_vocab.add(cpos)
+ self.rel_vocab.add(deprel)
+
+ for token in [token for token, freq in counter.items() if freq >= self.config.min_freq]:
+ self.form_vocab.add(token)
+ return num_samples
+
+ @property
+ def root_rel_idx(self):
+ root_rel_idx = self.config.get('root_rel_idx', None)
+ if root_rel_idx is None:
+ for idx, rel in enumerate(self.rel_vocab.idx_to_token):
+ if 'root' in rel.lower() and rel != self.bos:
+ self.config['root_rel_idx'] = root_rel_idx = idx
+ break
+ return root_rel_idx
+
+ def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, inputs=None, X=None) -> Iterable:
+ arc_preds, rel_preds, mask = Y
+ sents = []
+
+ for arc_sent, rel_sent, length in zip(arc_preds, rel_preds,
+ tf.math.count_nonzero(mask, axis=-1)):
+ arcs = tolist(arc_sent)[1:length + 1]
+ rels = tolist(rel_sent)[1:length + 1]
+ sents.append([(a, self.rel_vocab.idx_to_token[r]) for a, r in zip(arcs, rels)])
+
+ return sents
+
+
+class CoNLL_Transformer_Transform(CoNLL_DEP_Transform):
+
+ def __init__(self, config: SerializableDict = None, map_x=True, map_y=True,
+ lower=True, n_buckets=32, min_freq=0, max_seq_length=256, use_pos=False,
+ mask_p=None, graph=False, topk=None,
+ **kwargs) -> None:
+ super().__init__(**merge_locals_kwargs(locals(), kwargs))
+ self.tokenizer: PreTrainedTokenizer = None
+ self.transformer_config: PretrainedConfig = None
+ if graph:
+ self.orphan_relation = ROOT
+
+ def lock_vocabs(self):
+ super().lock_vocabs()
+ if self.graph:
+ CoNLL_SDP_Transform._find_orphan_relation(self)
+
+ def fit(self, trn_path: str, **kwargs) -> int:
+ if self.config.get('joint_pos', None):
+ self.config.use_pos = True
+ if self.graph:
+ # noinspection PyCallByClass
+ num = CoNLL_SDP_Transform.fit(self, trn_path, **kwargs)
+ else:
+ num = super().fit(trn_path, **kwargs)
+ if self.config.get('topk', None):
+ counter = Counter()
+ for sent in self.file_to_samples(trn_path, gold=True):
+ for idx, cell in enumerate(sent):
+ form, head, deprel = cell
+ counter[form] += 1
+ self.topk_vocab = VocabTF()
+ for k, v in counter.most_common(self.config.topk):
+ self.topk_vocab.add(k)
+ return num
+
+ def inputs_to_samples(self, inputs, gold=False):
+ if self.graph:
+ yield from CoNLL_SDP_Transform.inputs_to_samples(self, inputs, gold)
+ else:
+ yield from super().inputs_to_samples(inputs, gold)
+
+ def file_to_inputs(self, filepath: str, gold=True):
+ if self.graph:
+ yield from CoNLL_SDP_Transform.file_to_inputs(self, filepath, gold)
+ else:
+ yield from super().file_to_inputs(filepath, gold)
+
+ @property
+ def mask_p(self) -> float:
+ return self.config.get('mask_p', None)
+
+ @property
+ def graph(self):
+ return self.config.get('graph', None)
+
+ def create_types_shapes_values(self) -> Tuple[Tuple, Tuple, Tuple]:
+ mask_p = self.mask_p
+ types = (tf.int64, (tf.int64, tf.int64, tf.int64)), (tf.bool if self.graph else tf.int64, tf.int64, tf.int64) if mask_p else (
+ tf.bool if self.graph else tf.int64, tf.int64)
+ if self.graph:
+ shapes = ([None, None], ([None, None], [None, None], [None, None])), (
+ [None, None, None], [None, None, None], [None, None]) if mask_p else (
+ [None, None, None], [None, None, None])
+ else:
+ shapes = ([None, None], ([None, None], [None, None], [None, None])), (
+ [None, None], [None, None], [None, None]) if mask_p else ([None, None], [None, None])
+
+ values = (self.form_vocab.safe_pad_token_idx, (0, 0, 0)), \
+ (0, self.rel_vocab.safe_pad_token_idx, 0) if mask_p else (0, self.rel_vocab.safe_pad_token_idx)
+ types_shapes_values = types, shapes, values
+ if self.use_pos:
+ types_shapes_values = [((shapes[0][0], shapes[0][1] + (shapes[0][0],)), shapes[1]) for shapes in
+ types_shapes_values]
+ return types_shapes_values
+
+ def X_to_inputs(self, X: Union[tf.Tensor, Tuple[tf.Tensor]]) -> Iterable:
+ form_batch, feat, prefix_mask = X
+ sents = []
+
+ for form_sent, length in zip(form_batch, tf.math.count_nonzero(prefix_mask, axis=-1)):
+ forms = tolist(form_sent)[1:length + 1]
+ sents.append([self.form_vocab.idx_to_token[f] for f in forms])
+
+ return sents
+
+ def batched_inputs_to_batches(self, corpus, indices, shuffle):
+ use_pos = self.use_pos
+ if use_pos:
+ raw_batch = [[], [], [], []]
+ else:
+ raw_batch = [[], [], []]
+ if self.graph:
+ max_len = len(max([corpus[i] for i in indices], key=len))
+ for idx in indices:
+ arc = np.zeros((max_len, max_len), dtype=np.bool)
+ rel = np.zeros((max_len, max_len), dtype=np.int64)
+ for b in raw_batch[:2 if use_pos else 1]:
+ b.append([])
+ for m, cells in enumerate(corpus[idx]):
+ if use_pos:
+ for b, c, v in zip(raw_batch, cells, [None, self.cpos_vocab]):
+ b[-1].append(v.get_idx_without_add(c) if v else c)
+ else:
+ for b, c, v in zip(raw_batch, cells, [None]):
+ b[-1].append(c)
+ for n, r in zip(cells[-2], cells[-1]):
+ arc[m, n] = True
+ rid = self.rel_vocab.get_idx_without_add(r)
+ if rid is None:
+ logger.warning(f'Relation OOV: {r} not exists in train')
+ continue
+ rel[m, n] = rid
+ raw_batch[-2].append(arc)
+ raw_batch[-1].append(rel)
+ else:
+ for idx in indices:
+ for s in raw_batch:
+ s.append([])
+ for cells in corpus[idx]:
+ if use_pos:
+ for s, c, v in zip(raw_batch, cells, [None, self.cpos_vocab, None, self.rel_vocab]):
+ s[-1].append(v.get_idx_without_add(c) if v else c)
+ else:
+ for s, c, v in zip(raw_batch, cells, [None, None, self.rel_vocab]):
+ s[-1].append(v.get_idx_without_add(c) if v else c)
+
+ # Transformer tokenizing
+ config = self.transformer_config
+ tokenizer = self.tokenizer
+ xlnet = config_is(config, 'xlnet')
+ roberta = config_is(config, 'roberta')
+ pad_token = tokenizer.pad_token
+ pad_token_id = tokenizer.convert_tokens_to_ids([pad_token])[0]
+ cls_token = tokenizer.cls_token
+ sep_token = tokenizer.sep_token
+ max_seq_length = self.config.max_seq_length
+ batch_forms = []
+ batch_input_ids = []
+ batch_input_mask = []
+ batch_prefix_offset = []
+ mask_p = self.mask_p
+ if mask_p:
+ batch_masked_offsets = []
+ mask_token_id = tokenizer.mask_token_id
+ for sent_idx, sent in enumerate(raw_batch[0]):
+ batch_forms.append([self.form_vocab.get_idx_without_add(token) for token in sent])
+ sent = adjust_tokens_for_transformers(sent)
+ sent = sent[1:] # remove use [CLS] instead
+ pad_label_idx = self.form_vocab.pad_idx
+ input_ids, input_mask, segment_ids, prefix_mask = \
+ convert_examples_to_features(sent,
+ max_seq_length,
+ tokenizer,
+ cls_token_at_end=xlnet,
+ # xlnet has a cls token at the end
+ cls_token=cls_token,
+ cls_token_segment_id=2 if xlnet else 0,
+ sep_token=sep_token,
+ sep_token_extra=roberta,
+ # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
+ pad_on_left=xlnet,
+ # pad on the left for xlnet
+ pad_token_id=pad_token_id,
+ pad_token_segment_id=4 if xlnet else 0,
+ pad_token_label_id=pad_label_idx,
+ do_padding=False)
+ num_masks = sum(prefix_mask)
+ # assert len(sent) == num_masks # each token has a True subtoken
+ if num_masks < len(sent): # long sent gets truncated, +1 for root
+ batch_forms[-1] = batch_forms[-1][:num_masks + 1] # form
+ raw_batch[-1][sent_idx] = raw_batch[-1][sent_idx][:num_masks + 1] # head
+ raw_batch[-2][sent_idx] = raw_batch[-2][sent_idx][:num_masks + 1] # rel
+ raw_batch[-3][sent_idx] = raw_batch[-3][sent_idx][:num_masks + 1] # pos
+ prefix_mask[0] = True # is now [CLS]
+ prefix_offset = [idx for idx, m in enumerate(prefix_mask) if m]
+ batch_input_ids.append(input_ids)
+ batch_input_mask.append(input_mask)
+ batch_prefix_offset.append(prefix_offset)
+ if mask_p:
+ if shuffle:
+ size = int(np.ceil(mask_p * len(prefix_offset[1:]))) # never mask [CLS]
+ mask_offsets = np.random.choice(np.arange(1, len(prefix_offset)), size, replace=False)
+ for offset in sorted(mask_offsets):
+ assert 0 < offset < len(input_ids)
+ # mask_word = raw_batch[0][sent_idx][offset]
+ # mask_prefix = tokenizer.convert_ids_to_tokens([input_ids[prefix_offset[offset]]])[0]
+ # assert mask_word.startswith(mask_prefix) or mask_prefix.startswith(
+ # mask_word) or mask_prefix == "'", \
+ # f'word {mask_word} prefix {mask_prefix} not match' # could vs couldn
+ # mask_offsets.append(input_ids[offset]) # subword token
+ # mask_offsets.append(offset) # form token
+ input_ids[prefix_offset[offset]] = mask_token_id # mask prefix
+ # whole word masking, mask the rest of the word
+ for i in range(prefix_offset[offset] + 1, len(input_ids) - 1):
+ if prefix_mask[i]:
+ break
+ input_ids[i] = mask_token_id
+
+ batch_masked_offsets.append(sorted(mask_offsets))
+ else:
+ batch_masked_offsets.append([0]) # No masking in prediction
+
+ batch_forms = tf.keras.preprocessing.sequence.pad_sequences(batch_forms, padding='post',
+ value=self.form_vocab.safe_pad_token_idx,
+ dtype='int64')
+ batch_input_ids = tf.keras.preprocessing.sequence.pad_sequences(batch_input_ids, padding='post',
+ value=pad_token_id,
+ dtype='int64')
+ batch_input_mask = tf.keras.preprocessing.sequence.pad_sequences(batch_input_mask, padding='post',
+ value=0,
+ dtype='int64')
+ batch_prefix_offset = tf.keras.preprocessing.sequence.pad_sequences(batch_prefix_offset, padding='post',
+ value=0,
+ dtype='int64')
+ batch_heads = tf.keras.preprocessing.sequence.pad_sequences(raw_batch[-2], padding='post',
+ value=0,
+ dtype='int64')
+ batch_rels = tf.keras.preprocessing.sequence.pad_sequences(raw_batch[-1], padding='post',
+ value=self.rel_vocab.safe_pad_token_idx,
+ dtype='int64')
+ if mask_p:
+ batch_masked_offsets = tf.keras.preprocessing.sequence.pad_sequences(batch_masked_offsets, padding='post',
+ value=pad_token_id,
+ dtype='int64')
+ feats = (tf.constant(batch_input_ids, dtype='int64'), tf.constant(batch_input_mask, dtype='int64'),
+ tf.constant(batch_prefix_offset))
+ if use_pos:
+ batch_pos = tf.keras.preprocessing.sequence.pad_sequences(raw_batch[1], padding='post',
+ value=self.cpos_vocab.safe_pad_token_idx,
+ dtype='int64')
+ feats += (batch_pos,)
+ yield (batch_forms, feats), \
+ (batch_heads, batch_rels, batch_masked_offsets) if mask_p else (batch_heads, batch_rels)
+
+ def len_of_sent(self, sent):
+ # Transformer tokenizing
+ config = self.transformer_config
+ tokenizer = self.tokenizer
+ xlnet = config_is(config, 'xlnet')
+ roberta = config_is(config, 'roberta')
+ pad_token = tokenizer.pad_token
+ pad_token_id = tokenizer.convert_tokens_to_ids([pad_token])[0]
+ cls_token = tokenizer.cls_token
+ sep_token = tokenizer.sep_token
+ max_seq_length = self.config.max_seq_length
+ sent = sent[1:] # remove use [CLS] instead
+ pad_label_idx = self.form_vocab.pad_idx
+ sent = [x[0] for x in sent]
+ sent = adjust_tokens_for_transformers(sent)
+ input_ids, input_mask, segment_ids, prefix_mask = \
+ convert_examples_to_features(sent,
+ max_seq_length,
+ tokenizer,
+ cls_token_at_end=xlnet,
+ # xlnet has a cls token at the end
+ cls_token=cls_token,
+ cls_token_segment_id=2 if xlnet else 0,
+ sep_token=sep_token,
+ sep_token_extra=roberta,
+ # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
+ pad_on_left=xlnet,
+ # pad on the left for xlnet
+ pad_token_id=pad_token_id,
+ pad_token_segment_id=4 if xlnet else 0,
+ pad_token_label_id=pad_label_idx,
+ do_padding=False)
+ return len(input_ids)
+
+ def samples_to_dataset(self, samples: Generator, map_x=None, map_y=None, batch_size=5000, shuffle=None, repeat=None,
+ drop_remainder=False, prefetch=1, cache=True) -> tf.data.Dataset:
+ if shuffle:
+ return CoNLL_DEP_Transform.samples_to_dataset(self, samples, map_x, map_y, batch_size, shuffle, repeat,
+ drop_remainder, prefetch, cache)
+
+ def generator():
+ # custom bucketing, load corpus into memory
+ corpus = list(x for x in (samples() if callable(samples) else samples))
+ n_tokens = 0
+ batch = []
+ for idx, sent in enumerate(corpus):
+ sent_len = self.len_of_sent(sent)
+ if n_tokens + sent_len > batch_size and batch:
+ yield from self.batched_inputs_to_batches(corpus, batch, shuffle)
+ n_tokens = 0
+ batch = []
+ n_tokens += sent_len
+ batch.append(idx)
+ if batch:
+ yield from self.batched_inputs_to_batches(corpus, batch, shuffle)
+
+ # debug for transformer
+ # next(generator())
+ return Transform.samples_to_dataset(self, generator, False, False, 0, False, repeat, drop_remainder, prefetch,
+ cache)
+
+ def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, inputs=None, X=None) -> Iterable:
+ if self.graph:
+ ys = CoNLL_SDP_Transform.Y_to_outputs(self, Y, gold, inputs, X)
+ ys = [[([t[0] for t in l], [t[1] for t in l]) for l in y] for y in ys]
+ return ys
+ return super().Y_to_outputs(Y, gold, inputs, X)
+
+
+class CoNLL_SDP_Transform(CoNLLTransform):
+
+ def __init__(self, config: SerializableDict = None, map_x=True, map_y=True, lower=True, n_buckets=32, min_freq=2,
+ use_pos=True, **kwargs) -> None:
+ super().__init__(config, map_x, map_y, lower, n_buckets, min_freq, use_pos, **kwargs)
+ self.orphan_relation = ROOT
+
+ def lock_vocabs(self):
+ super().lock_vocabs()
+ # heuristic to find the orphan relation
+ self._find_orphan_relation()
+
+ def _find_orphan_relation(self):
+ for rel in self.rel_vocab.idx_to_token:
+ if 'root' in rel.lower():
+ self.orphan_relation = rel
+ break
+
+ def file_to_inputs(self, filepath: str, gold=True):
+ assert gold, 'only support gold file for now'
+ use_pos = self.use_pos
+ conllu = filepath.endswith('.conllu')
+ enhanced_only = self.config.get('enhanced_only', None)
+ for i, sent in enumerate(read_conll(filepath)):
+ parsed_sent = []
+ if conllu:
+ for cell in sent:
+ ID = cell[0]
+ form = cell[1]
+ cpos = cell[3]
+ head = cell[6]
+ deprel = cell[7]
+ deps = cell[8]
+ deps = [x.split(':', 1) for x in deps.split('|')]
+ heads = [int(x[0]) for x in deps if x[0].isdigit()]
+ rels = [x[1] for x in deps if x[0].isdigit()]
+ if enhanced_only:
+ if head in heads:
+ offset = heads.index(head)
+ heads.pop(offset)
+ rels.pop(offset)
+ else:
+ if head not in heads:
+ heads.append(head)
+ rels.append(deprel)
+ parsed_sent.append([form, cpos, heads, rels] if use_pos else [form, heads, rels])
+ else:
+ prev_cells = None
+ heads = []
+ rels = []
+ for j, cell in enumerate(sent):
+ ID = cell[0]
+ form = cell[1]
+ cpos = cell[3]
+ head = cell[6]
+ deprel = cell[7]
+ if prev_cells and ID != prev_cells[0]: # found end of token
+ parsed_sent.append(
+ [prev_cells[1], prev_cells[2], heads, rels] if use_pos else [prev_cells[1], heads, rels])
+ heads = []
+ rels = []
+ heads.append(head)
+ rels.append(deprel)
+ prev_cells = [ID, form, cpos, head, deprel] if use_pos else [ID, form, head, deprel]
+ parsed_sent.append(
+ [prev_cells[1], prev_cells[2], heads, rels] if use_pos else [prev_cells[1], heads, rels])
+ yield parsed_sent
+
+ def fit(self, trn_path: str, **kwargs) -> int:
+ self.form_vocab = VocabTF()
+ self.form_vocab.add(ROOT) # make root the 2ed elements while 0th is pad, 1st is unk
+ if self.use_pos:
+ self.cpos_vocab = VocabTF(pad_token=None, unk_token=None)
+ self.rel_vocab = VocabTF(pad_token=None, unk_token=None)
+ num_samples = 0
+ counter = Counter()
+ for sent in self.file_to_samples(trn_path, gold=True):
+ num_samples += 1
+ for idx, cell in enumerate(sent):
+ if len(cell) == 4:
+ form, cpos, head, deprel = cell
+ elif len(cell) == 3:
+ if self.use_pos:
+ form, cpos = cell[0]
+ else:
+ form = cell[0]
+ head, deprel = cell[1:]
+ else:
+ raise ValueError('Unknown data arrangement')
+ if idx == 0:
+ root = form
+ else:
+ counter[form] += 1
+ if self.use_pos:
+ self.cpos_vocab.add(cpos)
+ self.rel_vocab.update(deprel)
+
+ for token in [token for token, freq in counter.items() if freq >= self.config.min_freq]:
+ self.form_vocab.add(token)
+ return num_samples
+
+ def inputs_to_samples(self, inputs, gold=False):
+ use_pos = self.use_pos
+ for sent in inputs:
+ sample = []
+ for i, cell in enumerate(sent):
+ if isinstance(cell, tuple):
+ cell = list(cell)
+ elif isinstance(cell, str):
+ cell = [cell]
+ if self.config['lower']:
+ cell[0] = cell[0].lower()
+ if not gold:
+ cell += [[0], [self.rel_vocab.safe_pad_token]]
+ sample.append(cell)
+ # insert root word with arbitrary fields, anyway it will be masked
+ if use_pos:
+ form, cpos, head, deprel = sample[0]
+ sample.insert(0, [self.bos, self.bos, [0], deprel])
+ else:
+ form, head, deprel = sample[0]
+ sample.insert(0, [self.bos, [0], deprel])
+ yield sample
+
+ def batched_inputs_to_batches(self, corpus, indices, shuffle):
+ use_pos = self.use_pos
+ raw_batch = [[], [], [], []] if use_pos else [[], [], []]
+ max_len = len(max([corpus[i] for i in indices], key=len))
+ for idx in indices:
+ arc = np.zeros((max_len, max_len), dtype=np.bool)
+ rel = np.zeros((max_len, max_len), dtype=np.int64)
+ for b in raw_batch[:2]:
+ b.append([])
+ for m, cells in enumerate(corpus[idx]):
+ if use_pos:
+ for b, c, v in zip(raw_batch, cells,
+ [self.form_vocab, self.cpos_vocab]):
+ b[-1].append(v.get_idx_without_add(c))
+ else:
+ for b, c, v in zip(raw_batch, cells,
+ [self.form_vocab]):
+ b[-1].append(v.get_idx_without_add(c))
+ for n, r in zip(cells[-2], cells[-1]):
+ arc[m, n] = True
+ rid = self.rel_vocab.get_idx_without_add(r)
+ if rid is None:
+ logger.warning(f'Relation OOV: {r} not exists in train')
+ continue
+ rel[m, n] = rid
+ raw_batch[-2].append(arc)
+ raw_batch[-1].append(rel)
+ batch = []
+ for b, v in zip(raw_batch, [self.form_vocab, self.cpos_vocab]):
+ b = tf.keras.preprocessing.sequence.pad_sequences(b, padding='post',
+ value=v.safe_pad_token_idx,
+ dtype='int64')
+ batch.append(b)
+ batch += raw_batch[2:]
+ assert len(batch) == 4
+ yield (batch[0], batch[1]), (batch[2], batch[3])
+
+ def create_types_shapes_values(self) -> Tuple[Tuple, Tuple, Tuple]:
+ types = (tf.int64, tf.int64), (tf.bool, tf.int64)
+ shapes = ([None, None], [None, None]), ([None, None, None], [None, None, None])
+ values = (self.form_vocab.safe_pad_token_idx, self.cpos_vocab.safe_pad_token_idx), (
+ False, self.rel_vocab.safe_pad_token_idx)
+ return types, shapes, values
+
+ def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, inputs=None, X=None) -> Iterable:
+ arc_preds, rel_preds, mask = Y
+ sents = []
+
+ for arc_sent, rel_sent, length in zip(arc_preds, rel_preds,
+ tf.math.count_nonzero(mask, axis=-1)):
+ sent = []
+ for arc, rel in zip(tolist(arc_sent[1:, 1:]), tolist(rel_sent[1:, 1:])):
+ ar = []
+ for idx, (a, r) in enumerate(zip(arc, rel)):
+ if a:
+ ar.append((idx + 1, self.rel_vocab.idx_to_token[r]))
+ if not ar:
+ # orphan
+ ar.append((0, self.orphan_relation))
+ sent.append(ar)
+ sents.append(sent)
+
+ return sents
+
+ def XY_to_inputs_outputs(self, X: Union[tf.Tensor, Tuple[tf.Tensor]], Y: Union[tf.Tensor, Tuple[tf.Tensor]],
+ gold=False, inputs=None, conll=True) -> Iterable:
+ (words, feats, mask), (arc_preds, rel_preds) = X, Y
+ xs = inputs
+ ys = self.Y_to_outputs((arc_preds, rel_preds, mask))
+ sents = []
+ for x, y in zip(xs, ys):
+ sent = CoNLLSentence()
+ for idx, ((form, cpos), pred) in enumerate(zip(x, y)):
+ head = [p[0] for p in pred]
+ deprel = [p[1] for p in pred]
+ if conll:
+ sent.append(CoNLLWord(id=idx + 1, form=form, cpos=cpos, head=head, deprel=deprel))
+ else:
+ sent.append([head, deprel])
+ sents.append(sent)
+ return sents
diff --git a/hanlp/transform/glue_tf.py b/hanlp/transform/glue_tf.py
new file mode 100644
index 000000000..eace50062
--- /dev/null
+++ b/hanlp/transform/glue_tf.py
@@ -0,0 +1,44 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-05-08 16:34
+from hanlp_common.structure import SerializableDict
+from hanlp.datasets.glue import STANFORD_SENTIMENT_TREEBANK_2_TRAIN, MICROSOFT_RESEARCH_PARAPHRASE_CORPUS_DEV
+from hanlp.transform.table import TableTransform
+
+
+class StanfordSentimentTreebank2Transorm(TableTransform):
+ pass
+
+
+class MicrosoftResearchParaphraseCorpus(TableTransform):
+
+ def __init__(self, config: SerializableDict = None, map_x=False, map_y=True, x_columns=(3, 4),
+ y_column=0, skip_header=True, delimiter='auto', **kwargs) -> None:
+ super().__init__(config, map_x, map_y, x_columns, y_column, skip_header, delimiter, **kwargs)
+
+
+def main():
+ # _test_sst2()
+ _test_mrpc()
+
+
+def _test_sst2():
+ transform = StanfordSentimentTreebank2Transorm()
+ transform.fit(STANFORD_SENTIMENT_TREEBANK_2_TRAIN)
+ transform.lock_vocabs()
+ transform.label_vocab.summary()
+ transform.build_config()
+ dataset = transform.file_to_dataset(STANFORD_SENTIMENT_TREEBANK_2_TRAIN)
+ for batch in dataset.take(1):
+ print(batch)
+
+
+def _test_mrpc():
+ transform = MicrosoftResearchParaphraseCorpus()
+ transform.fit(MICROSOFT_RESEARCH_PARAPHRASE_CORPUS_DEV)
+ transform.lock_vocabs()
+ transform.label_vocab.summary()
+ transform.build_config()
+ dataset = transform.file_to_dataset(MICROSOFT_RESEARCH_PARAPHRASE_CORPUS_DEV)
+ for batch in dataset.take(1):
+ print(batch)
\ No newline at end of file
diff --git a/hanlp/transform/table.py b/hanlp/transform/table.py
index ad95fd8f7..e5d7865c6 100644
--- a/hanlp/transform/table.py
+++ b/hanlp/transform/table.py
@@ -6,10 +6,10 @@
import numpy as np
import tensorflow as tf
-from hanlp.common.structure import SerializableDict
-from hanlp.common.transform import Transform
-from hanlp.common.constant import PAD
-from hanlp.common.vocab import create_label_vocab
+from hanlp_common.structure import SerializableDict
+from hanlp.common.transform_tf import Transform
+from hanlp_common.constant import PAD
+from hanlp.common.vocab_tf import create_label_vocab
from hanlp.utils.io_util import read_cells
from hanlp.utils.log_util import logger
@@ -28,8 +28,8 @@ def file_to_inputs(self, filepath: str, gold=True):
y_column = self.config.y_column
num_features = self.config.get('num_features', None)
for cells in read_cells(filepath, skip_header=self.config.skip_header, delimiter=self.config.delimiter):
- #multi-label: Dataset in .tsv format: x_columns: at most 2 columns being a sentence pair while in most
- # cases just one column being the doc content. y_column being the single label, which shall be modified
+ #multi-label: Dataset in .tsv format: x_columns: at most 2 columns being a sentence pair while in most
+ # cases just one column being the doc content. y_column being the single label, which shall be modified
# to load a list of labels.
if x_columns:
inputs = tuple(c for i, c in enumerate(cells) if i in x_columns), cells[y_column]
diff --git a/hanlp/transform/tacred.py b/hanlp/transform/tacred.py
new file mode 100644
index 000000000..a4fc2f200
--- /dev/null
+++ b/hanlp/transform/tacred.py
@@ -0,0 +1,106 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-03-14 17:06
+from typing import Union, Tuple
+
+import tensorflow as tf
+
+from hanlp_common.structure import SerializableDict
+from hanlp.common.transform_tf import Transform
+from hanlp.common.vocab_tf import VocabTF
+from hanlp_common.io import load_json
+from hanlp_common.util import merge_locals_kwargs
+
+
+def get_positions(start_idx, end_idx, length):
+ """Get subj/obj position sequence.
+
+ Args:
+ start_idx:
+ end_idx:
+ length:
+
+ Returns:
+
+ """
+ return list(range(-start_idx, 0)) + [0] * (end_idx - start_idx + 1) + \
+ list(range(1, length - end_idx))
+
+
+class TACREDTransform(Transform):
+ def __init__(self, config: SerializableDict = None, map_x=True, map_y=True, lower=False, **kwargs) -> None:
+ super().__init__(**merge_locals_kwargs(locals(), kwargs))
+ self.token_vocab = VocabTF()
+ self.pos_vocab = VocabTF(pad_token=None, unk_token=None)
+ self.ner_vocab = VocabTF(pad_token=None)
+ self.deprel_vocab = VocabTF(pad_token=None, unk_token=None)
+ self.rel_vocab = VocabTF(pad_token=None, unk_token=None)
+
+ def fit(self, trn_path: str, **kwargs) -> int:
+ count = 0
+ for (tokens, pos, ner, head, deprel, subj_positions, obj_positions, subj_type,
+ obj_type), relation in self.file_to_samples(
+ trn_path, gold=True):
+ count += 1
+ self.token_vocab.update(tokens)
+ self.pos_vocab.update(pos)
+ self.ner_vocab.update(ner)
+ self.deprel_vocab.update(deprel)
+ self.rel_vocab.add(relation)
+ return count
+
+ def file_to_inputs(self, filepath: str, gold=True):
+ data = load_json(filepath)
+ for d in data:
+ tokens = list(d['token'])
+ ss, se = d['subj_start'], d['subj_end']
+ os, oe = d['obj_start'], d['obj_end']
+ pos = d['stanford_pos']
+ ner = d['stanford_ner']
+ deprel = d['stanford_deprel']
+ head = [int(x) for x in d['stanford_head']]
+ assert any([x == 0 for x in head])
+ relation = d['relation']
+ yield (tokens, pos, ner, head, deprel, ss, se, os, oe), relation
+
+ def inputs_to_samples(self, inputs, gold=False):
+ for input in inputs:
+ if gold:
+ (tokens, pos, ner, head, deprel, ss, se, os, oe), relation = input
+ else:
+ tokens, pos, ner, head, deprel, ss, se, os, oe = input
+ relation = self.rel_vocab.safe_pad_token
+ l = len(tokens)
+ subj_positions = get_positions(ss, se, l)
+ obj_positions = get_positions(os, oe, l)
+ subj_type = ner[ss]
+ obj_type = ner[os]
+ # anonymize tokens
+ tokens[ss:se + 1] = ['SUBJ-' + subj_type] * (se - ss + 1)
+ tokens[os:oe + 1] = ['OBJ-' + obj_type] * (oe - os + 1)
+ # min head is 0, but root is not included in tokens, so take 1 off from each head
+ head = [h - 1 for h in head]
+ yield (tokens, pos, ner, head, deprel, subj_positions, obj_positions, subj_type, obj_type), relation
+
+ def create_types_shapes_values(self) -> Tuple[Tuple, Tuple, Tuple]:
+ # (tokens, pos, ner, head, deprel, subj_positions, obj_positions, subj_type, obj_type), relation
+ types = (tf.string, tf.string, tf.string, tf.int32, tf.string, tf.int32, tf.int32, tf.string,
+ tf.string), tf.string
+ shapes = ([None], [None], [None], [None], [None], [None], [None], [], []), []
+ pads = (self.token_vocab.safe_pad_token, self.pos_vocab.safe_pad_token, self.ner_vocab.safe_pad_token, 0,
+ self.deprel_vocab.safe_pad_token,
+ 0, 0, self.ner_vocab.safe_pad_token, self.ner_vocab.safe_pad_token), self.rel_vocab.safe_pad_token
+ return types, shapes, pads
+
+ def x_to_idx(self, x) -> Union[tf.Tensor, Tuple]:
+ tokens, pos, ner, head, deprel, subj_positions, obj_positions, subj_type, obj_type = x
+ tokens = self.token_vocab.lookup(tokens)
+ pos = self.pos_vocab.lookup(pos)
+ ner = self.ner_vocab.lookup(ner)
+ deprel = self.deprel_vocab.lookup(deprel)
+ subj_type = self.ner_vocab.lookup(subj_type)
+ obj_type = self.ner_vocab.lookup(obj_type)
+ return tokens, pos, ner, head, deprel, subj_positions, obj_positions, subj_type, obj_type
+
+ def y_to_idx(self, y) -> tf.Tensor:
+ return self.rel_vocab.lookup(y)
diff --git a/hanlp/transform/text.py b/hanlp/transform/text.py
index 0c3acdf42..3c87f42b1 100644
--- a/hanlp/transform/text.py
+++ b/hanlp/transform/text.py
@@ -5,12 +5,12 @@
import tensorflow as tf
-from hanlp.common.structure import SerializableDict
-from hanlp.common.transform import Transform
-from hanlp.common.vocab import Vocab
+from hanlp_common.structure import SerializableDict
+from hanlp.common.transform_tf import Transform
+from hanlp.common.vocab_tf import VocabTF
from hanlp.metrics.chunking.sequence_labeling import get_entities
from hanlp.utils.file_read_backwards import FileReadBackwards
-from hanlp.utils.io_util import read_tsv
+from hanlp.utils.io_util import read_tsv_as_sents
class TextTransform(Transform):
@@ -21,7 +21,7 @@ def __init__(self,
tokenizer='char',
config: SerializableDict = None, map_x=True, map_y=True, **kwargs) -> None:
super().__init__(config, map_x, map_y, seq_len=seq_len, tokenizer=tokenizer, forward=forward, **kwargs)
- self.vocab: Vocab = None
+ self.vocab: VocabTF = None
def tokenize_func(self):
if self.config.tokenizer == 'char':
@@ -32,7 +32,7 @@ def tokenize_func(self):
return lambda x: x.split(self.config.tokenizer)
def fit(self, trn_path: str, **kwargs) -> int:
- self.vocab = Vocab()
+ self.vocab = VocabTF()
num_samples = 0
for x, y in self.file_to_inputs(trn_path):
self.vocab.update(x)
@@ -90,7 +90,7 @@ def input_is_single_sample(self, input: Any) -> bool:
def bmes_to_flat(inpath, outpath):
with open(outpath, 'w', encoding='utf-8') as out:
- for sent in read_tsv(inpath):
+ for sent in read_tsv_as_sents(inpath):
chunks = get_entities([cells[1] for cells in sent])
chars = [cells[0] for cells in sent]
words = []
diff --git a/hanlp/transform/transformer_tokenizer.py b/hanlp/transform/transformer_tokenizer.py
new file mode 100644
index 000000000..77661abaf
--- /dev/null
+++ b/hanlp/transform/transformer_tokenizer.py
@@ -0,0 +1,631 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-05-03 16:23
+import warnings
+from typing import Union, Optional
+
+from hanlp_common.constant import BOS, EOS
+from hanlp_common.structure import SerializableDict
+from hanlp.layers.transformers.pt_imports import PreTrainedTokenizer, PretrainedConfig, AutoTokenizer
+from hanlp_trie import DictInterface
+
+
+class TransformerTokenizer(object):
+
+ def __init__(self, max_seq_length=512, truncate_long_sequences=True) -> None:
+ self.truncate_long_sequences = truncate_long_sequences
+ self.max_seq_length = max_seq_length
+
+ def sliding_window(self, flat_wordpiece_ids, same_tail=True):
+ if same_tail:
+ start_piece_ids, flat_wordpiece_ids, end_piece_ids = flat_wordpiece_ids[:1], \
+ flat_wordpiece_ids[1:-1], flat_wordpiece_ids[-1:]
+ else:
+ start_piece_ids, flat_wordpiece_ids, end_piece_ids = flat_wordpiece_ids[:1], \
+ flat_wordpiece_ids[1:], []
+ window_length = self.max_seq_length - len(start_piece_ids) - len(end_piece_ids)
+ stride = window_length // 2
+ wordpiece_windows = [start_piece_ids + flat_wordpiece_ids[i:i + window_length] + end_piece_ids
+ for i in range(0, len(flat_wordpiece_ids), stride)]
+
+ # Check for overlap in the last window. Throw it away if it is redundant.
+ last_window = wordpiece_windows[-1][1:]
+ penultimate_window = wordpiece_windows[-2]
+ if last_window == penultimate_window[-len(last_window):]:
+ wordpiece_windows = wordpiece_windows[:-1]
+
+ wordpiece_ids = [wordpiece for sequence in wordpiece_windows for wordpiece in sequence]
+ return wordpiece_ids
+
+
+class TransformerTextTokenizer(TransformerTokenizer):
+ _KEY = ['input_ids', 'attention_mask', 'token_type_ids']
+
+ def __init__(self,
+ tokenizer: Union[PreTrainedTokenizer, str],
+ text_a_key: str,
+ text_b_key: str = None,
+ output_key=None,
+ max_seq_length=512, truncate_long_sequences=True) -> None:
+ super().__init__(max_seq_length, truncate_long_sequences)
+ self.text_b = text_b_key
+ self.text_a = text_a_key
+ if output_key is None:
+ output_key = self.text_a
+ if text_b_key:
+ output_key += '_' + text_b_key
+ if output_key == '':
+ output_key = self._KEY
+ else:
+ output_key = [f'{output_key}_{key}' for key in self._KEY]
+ self.output_key = output_key
+ if isinstance(tokenizer, str):
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+ self.tokenizer = tokenizer
+
+ def __call__(self, sample: dict):
+ text_a = sample[self.text_a]
+ text_b = sample[self.text_b] if self.text_b else None
+ max_seq_length = self.max_seq_length if self.truncate_long_sequences else None
+ encoding = self.tokenizer.encode_plus(text_a, text_b, max_length=max_seq_length)
+ input_ids = encoding.data['input_ids']
+ if not self.truncate_long_sequences and len(input_ids) > self.max_seq_length:
+ input_ids = self.sliding_window(input_ids)
+ encoding.data['input_ids'] = input_ids # TODO: other fields should be properly handled too
+ for k, v in zip(self.output_key, [encoding.data[_] for _ in self._KEY]):
+ sample[k] = v
+ return sample
+
+
+class TransformerSequenceTokenizer(TransformerTokenizer):
+
+ def __init__(self,
+ tokenizer: Union[PreTrainedTokenizer, str],
+ input_key,
+ output_key=None,
+ max_seq_length=512,
+ truncate_long_sequences=False,
+ config: PretrainedConfig = None,
+ cls_token_at_end=False,
+ cls_token_segment_id=0,
+ pad_token_segment_id=0,
+ pad_on_left=False,
+ do_padding=False,
+ sep_token_extra=False,
+ ret_mask_and_type=False,
+ ret_prefix_mask=False,
+ ret_token_span=True,
+ ret_subtokens=False,
+ ret_subtokens_group=False,
+ cls_is_bos=False,
+ sep_is_eos=False,
+ do_basic_tokenize=True,
+ use_fast=True,
+ dict_force=None,
+ strip_cls_sep=True,
+ check_space_before=None,
+ ) -> None:
+ """A transformer tokenizer for token-level tasks. It honors the boundary of tokens and tokenize each token into
+ several subtokens then merge them. The information about each subtoken belongs to which token are kept and
+ returned as a new field in the sample. It also provides out-of-box sliding window trick on long sequences.
+
+ Args:
+ tokenizer: The identifier of a pre-trained tokenizer or a ``PreTrainedTokenizer``.
+ input_key: The token key in samples.
+ output_key: The output keys to store results.
+ max_seq_length: Sentences longer than ``max_seq_len`` will be split into shorter ones if possible.
+ truncate_long_sequences: ``True`` to truncate exceeded parts of long sequences. ``False`` to enable
+ sliding window.
+ config: The ``PretrainedConfig`` to determine the model structure of the transformer, so that special
+ tokenization can be applied.
+ cls_token_at_end: ``True`` to put ``[CLS]`` at the end of input tokens.
+ cls_token_segment_id: The id of ``[CLS]``.
+ pad_token_segment_id: The id of ``[SEP]``.
+ pad_on_left: ``True`` to put ``[PAD]`` at the left side of input tokens.
+ do_padding: ``True`` to pad sequence to the left.
+ sep_token_extra: ``True`` to have two ``[SEP]``.
+ ret_mask_and_type: ``True`` to return masks and type ids.
+ ret_prefix_mask: ``True`` to generate a mask where each non-zero element corresponds to a prefix of a token.
+ ret_token_span: ``True`` to return span of each token measured by subtoken offsets.
+ ret_subtokens: ``True`` to return list of subtokens belonging to each token for tokenization purpose.
+ When enabled, the prefix mask for each subtoken is set to True as each subtoken is a token unit in
+ tokenization task. Similarity, the token span for each token will be a continuous integer sequence.
+ ret_subtokens_group: ``True`` to return list of offsets of subtokens belonging to each token.
+ cls_is_bos: ``True`` means the first token of input is treated as [CLS] no matter what its surface form is.
+ ``False`` (default) means the first token is not [CLS], it will have its own embedding other than
+ the embedding of [CLS].
+ sep_is_eos: ``True`` means the last token of input is [SEP].
+ ``False`` means it's not but [SEP] will be appended,
+ ``None`` means it dependents on `input[-1] == [EOS]`.
+ do_basic_tokenize: Whether to do basic tokenization before wordpiece.
+ use_fast: Whether or not to try to load the fast version of the tokenizer.
+ dict_force: A dictionary doing longest-prefix-match on input text so that the head and tail of each keyword
+ won't be concatenated to other tokens by transformer tokenizers.
+ strip_cls_sep: ``True`` to strip [CLS] and [SEP] off the input tokens.
+ check_space_before: ``True`` to detect the space before each token to handle underline in sentence piece
+ tokenization.
+
+ Examples:
+
+ .. highlight:: python
+ .. code-block:: python
+
+ transform = TransformerSequenceTokenizer('bert-base-uncased', 'token')
+ sample = {'token': 'HanLP good'.split()}
+ print(transform(sample))
+
+ """
+ super().__init__(max_seq_length, truncate_long_sequences)
+ tokenizer_name = tokenizer if isinstance(tokenizer, str) else tokenizer.name_or_path
+ if check_space_before is None:
+ # These tokenizer is BPE-based which appends a space before each token and tokenizes loving into
+ # ['▁lo', 'ving'], tokenize 商品 into ['▁', '商品']. For the later case, the prefix '▁' has to be removed
+ # as there is no space between some languages like Chinese
+ check_space_before = tokenizer_name in ('xlm-roberta-base', 'xlm-roberta-large', 'google/mt5-small',
+ 'google/mt5-base')
+ self.check_space_before = check_space_before
+ self.ret_subtokens_group = ret_subtokens_group
+ self.ret_subtokens = ret_subtokens
+ self.sep_is_eos = sep_is_eos
+ self.ret_prefix_mask = ret_prefix_mask
+ self.ret_mask_and_type = ret_mask_and_type
+ self.cls_is_bos = cls_is_bos
+ self.ret_token_span = ret_token_span
+ if not output_key or isinstance(output_key, str):
+ suffixes = ['input_ids']
+ if ret_mask_and_type:
+ suffixes += 'attention_mask', 'token_type_ids'
+ if ret_prefix_mask:
+ suffixes += ['prefix_mask']
+ if ret_token_span:
+ suffixes.append('token_span')
+ if output_key is None:
+ output_key = [f'{input_key}_{key}' for key in suffixes]
+ elif output_key == '':
+ output_key = suffixes
+ else:
+ output_key = [f'{output_key}_{key}' for key in suffixes]
+
+ self.input_key = input_key
+ self.output_key = output_key
+ if config:
+ xlnet = config_is(config, 'xlnet')
+ pad_token_segment_id = 4 if xlnet else 0
+ cls_token_segment_id = 2 if xlnet else 0
+ cls_token_at_end = xlnet
+ pad_on_left = xlnet
+ if isinstance(tokenizer, str):
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer, use_fast=use_fast, do_basic_tokenize=do_basic_tokenize)
+ pad_token = tokenizer.pad_token
+ self.pad_token_id = tokenizer.convert_tokens_to_ids([pad_token])[0]
+ self.pad_token_segment_id = pad_token_segment_id
+ if tokenizer_name in ('google/mt5-small', 'google/mt5-base'):
+ # mt5 doesn't have cls or sep, but we can use something similar
+ self.has_cls = False
+ self.cls_token = '▁'
+ self.cls_token_id = tokenizer.convert_tokens_to_ids(self.cls_token)
+ self.sep_token = tokenizer.eos_token
+ self.sep_token_id = tokenizer.eos_token_id
+ else:
+ self.has_cls = True
+ self.cls_token = tokenizer.cls_token
+ self.sep_token = tokenizer.sep_token
+ self.cls_token_segment_id = cls_token_segment_id
+ self.cls_token_id = tokenizer.cls_token_id
+ self.sep_token_id = tokenizer.sep_token_id
+
+ self.sep_token_extra = sep_token_extra
+ self.cls_token_at_end = cls_token_at_end
+ self.tokenizer = tokenizer
+ self.pad_on_left = pad_on_left
+ self.do_padding = do_padding
+ if self.ret_token_span or not self.truncate_long_sequences:
+ assert not self.cls_token_at_end
+ assert not self.pad_on_left
+ if self.ret_subtokens:
+ if not use_fast:
+ raise NotImplementedError(
+ 'ret_subtokens is not available when using Python tokenizers. '
+ 'To use this feature, set use_fast = True.')
+ self.dict: Optional[DictInterface] = dict_force # For tokenization of raw text
+ self.strip_cls_sep = strip_cls_sep
+
+ def __call__(self, sample: dict):
+ input_tokens = sample[self.input_key]
+ input_is_str = isinstance(input_tokens, str)
+ tokenizer = self.tokenizer
+ ret_token_span = self.ret_token_span
+ if input_is_str: # This happens in a tokenizer component where the raw sentence is fed.
+
+ # noinspection PyShadowingNames
+ def tokenize_str(input_str, add_special_tokens=True):
+ if tokenizer.is_fast:
+ encoding = tokenizer.encode_plus(input_str,
+ return_offsets_mapping=True,
+ add_special_tokens=add_special_tokens).encodings[0]
+ subtoken_offsets = encoding.offsets
+ if add_special_tokens:
+ subtoken_offsets = subtoken_offsets[1 if self.has_cls else 0:-1]
+ input_tokens = encoding.tokens
+ input_ids = encoding.ids
+ if not self.has_cls:
+ input_tokens = [self.cls_token] + input_tokens
+ input_ids = [self.cls_token_id] + input_ids
+ else:
+ input_tokens = tokenizer.tokenize(input_str)
+ subtoken_offsets = input_tokens
+ if add_special_tokens:
+ input_tokens = [self.cls_token] + input_tokens + [self.sep_token]
+ input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
+ if self.check_space_before:
+ non_blank_offsets = [i for i in range(len(input_tokens)) if input_tokens[i] != '▁']
+ if add_special_tokens and not self.has_cls:
+ non_blank_offsets.insert(0, 0)
+ input_tokens = [input_tokens[i] for i in non_blank_offsets]
+ input_ids = [input_ids[i] for i in non_blank_offsets]
+ if add_special_tokens:
+ non_blank_offsets = non_blank_offsets[1:-1]
+ subtoken_offsets = [subtoken_offsets[i - 1] for i in non_blank_offsets]
+ return input_tokens, input_ids, subtoken_offsets
+
+ if self.dict:
+ chunks = self.dict.split(input_tokens)
+ _input_tokens, _input_ids, _subtoken_offsets = [self.cls_token], [self.cls_token_id], []
+ _offset = 0
+ custom_words = sample['custom_words'] = []
+ for chunk in chunks:
+ if isinstance(chunk, str):
+ tokens, ids, offsets = tokenize_str(chunk, add_special_tokens=False)
+ else:
+ begin, end, label = chunk
+ # chunk offset is in char level
+ # custom_words.append(chunk)
+ if isinstance(label, list):
+ tokens, ids, offsets, delta = [], [], [], 0
+ for token in label:
+ _tokens, _ids, _offsets = tokenize_str(token, add_special_tokens=False)
+ tokens.extend(_tokens)
+ # track the subword offset of this chunk, -1 for [CLS]
+ custom_words.append(
+ (len(_input_ids) + len(ids) - 1, len(_input_ids) + len(ids) - 1 + len(_ids), token))
+ ids.extend(_ids)
+ offsets.extend((x[0] + delta, x[1] + delta) for x in _offsets)
+ delta = offsets[-1][-1]
+ else:
+ tokens, ids, offsets = tokenize_str(input_tokens[begin:end], add_special_tokens=False)
+ # offsets = [(offsets[0][0], offsets[-1][-1])]
+ custom_words.append((len(_input_ids) - 1, len(_input_ids) + len(ids) - 1, label))
+ _input_tokens.extend(tokens)
+ _input_ids.extend(ids)
+ _subtoken_offsets.extend((x[0] + _offset, x[1] + _offset) for x in offsets)
+ _offset = _subtoken_offsets[-1][-1]
+ subtoken_offsets = _subtoken_offsets
+ input_tokens = _input_tokens + [self.sep_token]
+ input_ids = _input_ids + [self.sep_token_id]
+ else:
+ input_tokens, input_ids, subtoken_offsets = tokenize_str(input_tokens, add_special_tokens=True)
+
+ if self.ret_subtokens:
+ sample[f'{self.input_key}_subtoken_offsets'] = subtoken_offsets
+
+ cls_is_bos = self.cls_is_bos
+ if cls_is_bos is None:
+ cls_is_bos = input_tokens[0] == BOS
+ sep_is_eos = self.sep_is_eos
+ if sep_is_eos is None:
+ sep_is_eos = input_tokens[-1] == EOS
+ if self.strip_cls_sep:
+ if cls_is_bos:
+ input_tokens = input_tokens[1:]
+ if sep_is_eos:
+ input_tokens = input_tokens[:-1]
+ if not self.ret_mask_and_type: # only need input_ids and token_span, use a light version
+ if input_is_str:
+ prefix_mask = self._init_prefix_mask(input_ids)
+ else:
+ if input_tokens:
+ return_offsets_mapping = tokenizer.is_fast and self.ret_subtokens
+ encodings = tokenizer.batch_encode_plus(
+ input_tokens,
+ return_offsets_mapping=return_offsets_mapping,
+ add_special_tokens=False
+ )
+ if return_offsets_mapping:
+ offsets_mapping = [encoding.offsets for encoding in encodings.encodings]
+ else:
+ offsets_mapping = [None for encoding in encodings.encodings]
+ else:
+ encodings = SerializableDict()
+ encodings.data = {'input_ids': []}
+ subtoken_ids_per_token = encodings.data['input_ids']
+ if self.check_space_before:
+ # noinspection PyUnboundLocalVariable
+ for token, subtokens, mapping, encoding in zip(input_tokens, subtoken_ids_per_token,
+ offsets_mapping, encodings.encodings):
+ # Remove ▁ generated by spm for 2 reasons:
+ # 1. During decoding, mostly no ▁ will be created unless blanks are placed between tokens (which
+ # is true for English but in English it will likely be concatenated to the token following it)
+ # 2. For T5, '▁' is used as CLS
+ if len(subtokens) > 1 and encoding.tokens[0] == '▁':
+ subtokens.pop(0)
+ if mapping:
+ mapping.pop(0)
+ # Some tokens get stripped out
+ subtoken_ids_per_token = [ids if ids else [tokenizer.unk_token_id] for ids in subtoken_ids_per_token]
+ input_ids = sum(subtoken_ids_per_token, [self.cls_token_id])
+ if self.sep_is_eos is None:
+ # None means to check whether sep is at the tail or between tokens
+ if sep_is_eos:
+ input_ids += [self.sep_token_id]
+ elif self.sep_token_id not in input_ids:
+ input_ids += [self.sep_token_id]
+ else:
+ input_ids += [self.sep_token_id]
+ # else self.sep_is_eos == False means sep is between tokens and don't bother to check
+
+ if self.ret_subtokens:
+ prefix_mask = self._init_prefix_mask(input_ids)
+ # if self.check_space_before:
+ # if offsets_mapping[0] and not input_tokens[0].startswith(' '):
+ # prefix_mask[1] = False
+ else:
+ prefix_mask = [False] * len(input_ids)
+ offset = 1
+ for _subtokens in subtoken_ids_per_token:
+ prefix_mask[offset] = True
+ offset += len(_subtokens)
+ if self.ret_subtokens:
+ subtoken_offsets = []
+ for token, offsets in zip(input_tokens, offsets_mapping):
+ if offsets:
+ subtoken_offsets.append(offsets)
+ else:
+ subtoken_offsets.append([(0, len(token))])
+ if self.ret_subtokens_group:
+ sample[f'{self.input_key}_subtoken_offsets_group'] = subtoken_offsets
+ sample[f'{self.input_key}_subtoken_offsets'] = sum(subtoken_offsets, [])
+ else:
+ input_ids, attention_mask, token_type_ids, prefix_mask = \
+ convert_examples_to_features(input_tokens,
+ None,
+ tokenizer,
+ cls_token_at_end=self.cls_token_at_end,
+ # xlnet has a cls token at the end
+ cls_token=tokenizer.cls_token,
+ cls_token_segment_id=self.cls_token_segment_id,
+ sep_token=self.sep_token,
+ sep_token_extra=self.sep_token_extra,
+ # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
+ pad_on_left=self.pad_on_left,
+ # pad on the left for xlnet
+ pad_token_id=self.pad_token_id,
+ pad_token_segment_id=self.pad_token_segment_id,
+ pad_token_label_id=0,
+ do_padding=self.do_padding)
+ if len(input_ids) > self.max_seq_length:
+ if self.truncate_long_sequences:
+ # raise SequenceTooLong(
+ # f'Input tokens {input_tokens} exceed the max sequence length of {self.max_seq_length - 2}. '
+ # f'For sequence tasks, truncate_long_sequences = True is not supported.'
+ # f'You are recommended to split your long text into several sentences within '
+ # f'{self.max_seq_length - 2} tokens beforehand. '
+ # f'Or simply set truncate_long_sequences = False to enable sliding window.')
+ input_ids = input_ids[:self.max_seq_length]
+ prefix_mask = prefix_mask[:self.max_seq_length]
+ warnings.warn(
+ f'Input tokens {input_tokens} exceed the max sequence length of {self.max_seq_length - 2}. '
+ f'The exceeded part will be truncated and ignored. '
+ f'You are recommended to split your long text into several sentences within '
+ f'{self.max_seq_length - 2} tokens beforehand.'
+ f'Or simply set truncate_long_sequences = False to enable sliding window.'
+ )
+ else:
+ input_ids = self.sliding_window(input_ids, input_ids[-1] == self.sep_token_id)
+ if prefix_mask:
+ if cls_is_bos:
+ prefix_mask[0] = True
+ if sep_is_eos:
+ prefix_mask[-1] = True
+ outputs = [input_ids]
+ if self.ret_mask_and_type:
+ # noinspection PyUnboundLocalVariable
+ outputs += [attention_mask, token_type_ids]
+ if self.ret_prefix_mask:
+ outputs += [prefix_mask]
+ if ret_token_span and prefix_mask:
+ if cls_is_bos:
+ token_span = [[0]]
+ else:
+ token_span = []
+ offset = 1
+ span = []
+ for mask in prefix_mask[1:len(prefix_mask) if sep_is_eos is None else -1]: # skip [CLS] and [SEP]
+ if mask and span:
+ token_span.append(span)
+ span = []
+ span.append(offset)
+ offset += 1
+ if span:
+ token_span.append(span)
+ if sep_is_eos:
+ assert offset == len(prefix_mask) - 1
+ token_span.append([offset])
+ outputs.append(token_span)
+ for k, v in zip(self.output_key, outputs):
+ sample[k] = v
+ return sample
+
+ def _init_prefix_mask(self, input_ids):
+ prefix_mask = [True] * len(input_ids)
+ if not self.cls_is_bos:
+ prefix_mask[0] = False
+ if not self.sep_is_eos:
+ prefix_mask[-1] = False
+ return prefix_mask
+
+
+def config_is(config, model='bert'):
+ return model in type(config).__name__.lower()
+
+
+def convert_examples_to_features(
+ words,
+ max_seq_length: Optional[int],
+ tokenizer,
+ labels=None,
+ label_map=None,
+ cls_token_at_end=False,
+ cls_token="[CLS]",
+ cls_token_segment_id=1,
+ sep_token="[SEP]",
+ sep_token_extra=False,
+ pad_on_left=False,
+ pad_token_id=0,
+ pad_token_segment_id=0,
+ pad_token_label_id=0,
+ sequence_a_segment_id=0,
+ mask_padding_with_zero=True,
+ unk_token='[UNK]',
+ do_padding=True
+):
+ """Loads a data file into a list of `InputBatch`s
+ `cls_token_at_end` define the location of the CLS token:
+ - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
+ - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
+ `cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet)
+
+ Args:
+ words:
+ max_seq_length:
+ tokenizer:
+ labels: (Default value = None)
+ label_map: (Default value = None)
+ cls_token_at_end: (Default value = False)
+ cls_token: (Default value = "[CLS]")
+ cls_token_segment_id: (Default value = 1)
+ sep_token: (Default value = "[SEP]")
+ sep_token_extra: (Default value = False)
+ pad_on_left: (Default value = False)
+ pad_token_id: (Default value = 0)
+ pad_token_segment_id: (Default value = 0)
+ pad_token_label_id: (Default value = 0)
+ sequence_a_segment_id: (Default value = 0)
+ mask_padding_with_zero: (Default value = True)
+ unk_token: (Default value = '[UNK]')
+ do_padding: (Default value = True)
+
+ Returns:
+
+ """
+ args = locals()
+ if not labels:
+ labels = words
+ pad_token_label_id = False
+
+ tokens = []
+ label_ids = []
+ for word, label in zip(words, labels):
+ word_tokens = tokenizer.tokenize(word)
+ if not word_tokens:
+ # some wired chars cause the tagger to return empty list
+ word_tokens = [unk_token] * len(word)
+ tokens.extend(word_tokens)
+ # Use the real label id for the first token of the word, and padding ids for the remaining tokens
+ label_ids.extend([label_map[label] if label_map else True] + [pad_token_label_id] * (len(word_tokens) - 1))
+
+ # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa.
+ special_tokens_count = 3 if sep_token_extra else 2
+ if max_seq_length and len(tokens) > max_seq_length - special_tokens_count:
+ warnings.warn(
+ f'Input tokens {words} exceed the max sequence length of {max_seq_length - special_tokens_count}. '
+ f'The exceeded part will be truncated and ignored. '
+ f'You are recommended to split your long text into several sentences within '
+ f'{max_seq_length - special_tokens_count} tokens beforehand.')
+ tokens = tokens[: (max_seq_length - special_tokens_count)]
+ label_ids = label_ids[: (max_seq_length - special_tokens_count)]
+
+ # The convention in BERT is:
+ # (a) For sequence pairs:
+ # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
+ # token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
+ # (b) For single sequences:
+ # tokens: [CLS] the dog is hairy . [SEP]
+ # token_type_ids: 0 0 0 0 0 0 0
+ #
+ # Where "token_type_ids" are used to indicate whether this is the first
+ # sequence or the second sequence. The embedding vectors for `type=0` and
+ # `type=1` were learned during pre-training and are added to the wordpiece
+ # embedding vector (and position vector). This is not *strictly* necessary
+ # since the [SEP] token unambiguously separates the sequences, but it makes
+ # it easier for the model to learn the concept of sequences.
+ #
+ # For classification tasks, the first vector (corresponding to [CLS]) is
+ # used as as the "sentence vector". Note that this only makes sense because
+ # the entire model is fine-tuned.
+ tokens += [sep_token]
+ label_ids += [pad_token_label_id]
+ if sep_token_extra:
+ # roberta uses an extra separator b/w pairs of sentences
+ tokens += [sep_token]
+ label_ids += [pad_token_label_id]
+ segment_ids = [sequence_a_segment_id] * len(tokens)
+
+ if cls_token_at_end:
+ tokens += [cls_token]
+ label_ids += [pad_token_label_id]
+ segment_ids += [cls_token_segment_id]
+ else:
+ tokens = [cls_token] + tokens
+ label_ids = [pad_token_label_id] + label_ids
+ segment_ids = [cls_token_segment_id] + segment_ids
+
+ input_ids = tokenizer.convert_tokens_to_ids(tokens)
+
+ # The mask has 1 for real tokens and 0 for padding tokens. Only real
+ # tokens are attended to.
+ input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
+
+ if do_padding:
+ # Zero-pad up to the sequence length.
+ padding_length = max_seq_length - len(input_ids)
+ if pad_on_left:
+ input_ids = ([pad_token_id] * padding_length) + input_ids
+ input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
+ segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
+ label_ids = ([pad_token_label_id] * padding_length) + label_ids
+ else:
+ input_ids += [pad_token_id] * padding_length
+ input_mask += [0 if mask_padding_with_zero else 1] * padding_length
+ segment_ids += [pad_token_segment_id] * padding_length
+ label_ids += [pad_token_label_id] * padding_length
+
+ assert len(input_ids) == max_seq_length
+ assert len(input_mask) == max_seq_length
+ assert len(segment_ids) == max_seq_length
+ assert len(label_ids) == max_seq_length, f'failed for:\n {args}'
+ else:
+ assert len(set(len(x) for x in [input_ids, input_mask, segment_ids, label_ids])) == 1
+ return input_ids, input_mask, segment_ids, label_ids
+
+
+def main():
+ transformer = 'bert-base-uncased'
+ tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(transformer)
+ # _test_text_transform(tokenizer)
+ _test_sequence_transform(tokenizer)
+
+
+def _test_text_transform(tokenizer):
+ transform = TransformerTextTokenizer(tokenizer, 'text')
+ sample = {'text': 'HanLP good'}
+ print(transform(sample))
+
+
+def _test_sequence_transform(tokenizer):
+ transform = TransformerSequenceTokenizer(tokenizer, 'token')
+ sample = {'token': 'HanLP good'.split()}
+ print(transform(sample))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/hanlp/transform/tsv.py b/hanlp/transform/tsv.py
index 6a95778be..e4eb236c1 100644
--- a/hanlp/transform/tsv.py
+++ b/hanlp/transform/tsv.py
@@ -7,18 +7,18 @@
import tensorflow as tf
-from hanlp.common.structure import SerializableDict
+from hanlp_common.structure import SerializableDict
-from hanlp.common.transform import Transform
-from hanlp.common.vocab import Vocab
-from hanlp.utils.io_util import generator_words_tags
+from hanlp.common.transform_tf import Transform
+from hanlp.common.vocab_tf import VocabTF
+from hanlp.utils.io_util import generate_words_tags_from_tsv
from hanlp.utils.tf_util import str_tensor_to_str
-from hanlp.utils.util import merge_locals_kwargs
+from hanlp_common.util import merge_locals_kwargs
-def dataset_from_tsv(tsv_file_path, word_vocab: Vocab, char_vocab: Vocab, tag_vocab: Vocab, batch_size=32,
+def dataset_from_tsv(tsv_file_path, word_vocab: VocabTF, char_vocab: VocabTF, tag_vocab: VocabTF, batch_size=32,
shuffle=None, repeat=None, prefetch=1, lower=False, **kwargs):
- generator = functools.partial(generator_words_tags, tsv_file_path, word_vocab, char_vocab, tag_vocab, lower)
+ generator = functools.partial(generate_words_tags_from_tsv, tsv_file_path, word_vocab, char_vocab, tag_vocab, lower)
return dataset_from_generator(generator, word_vocab, tag_vocab, batch_size, shuffle, repeat, prefetch,
**kwargs)
@@ -40,10 +40,10 @@ def dataset_from_generator(generator, word_vocab, tag_vocab, batch_size=32, shuf
def vocab_from_tsv(tsv_file_path, lower=False, lock_word_vocab=False, lock_char_vocab=True, lock_tag_vocab=True) \
- -> Tuple[Vocab, Vocab, Vocab]:
- word_vocab = Vocab()
- char_vocab = Vocab()
- tag_vocab = Vocab(unk_token=None)
+ -> Tuple[VocabTF, VocabTF, VocabTF]:
+ word_vocab = VocabTF()
+ char_vocab = VocabTF()
+ tag_vocab = VocabTF(unk_token=None)
with open(tsv_file_path, encoding='utf-8') as tsv_file:
for line in tsv_file:
cells = line.strip().split()
@@ -67,8 +67,8 @@ def vocab_from_tsv(tsv_file_path, lower=False, lock_word_vocab=False, lock_char_
class TsvTaggingFormat(Transform, ABC):
def file_to_inputs(self, filepath: str, gold=True):
assert gold, 'TsvTaggingFormat does not support reading non-gold files'
- yield from generator_words_tags(filepath, gold=gold, lower=self.config.get('lower', False),
- max_seq_length=self.max_seq_length)
+ yield from generate_words_tags_from_tsv(filepath, gold=gold, lower=self.config.get('lower', False),
+ max_seq_length=self.max_seq_length)
@property
def max_seq_length(self):
@@ -78,20 +78,20 @@ def max_seq_length(self):
class TSVTaggingTransform(TsvTaggingFormat, Transform):
def __init__(self, config: SerializableDict = None, map_x=True, map_y=True, use_char=False, **kwargs) -> None:
super().__init__(**merge_locals_kwargs(locals(), kwargs))
- self.word_vocab: Optional[Vocab] = None
- self.tag_vocab: Optional[Vocab] = None
- self.char_vocab: Optional[Vocab] = None
+ self.word_vocab: Optional[VocabTF] = None
+ self.tag_vocab: Optional[VocabTF] = None
+ self.char_vocab: Optional[VocabTF] = None
def fit(self, trn_path: str, **kwargs) -> int:
- self.word_vocab = Vocab()
- self.tag_vocab = Vocab(pad_token=None, unk_token=None)
+ self.word_vocab = VocabTF()
+ self.tag_vocab = VocabTF(pad_token=None, unk_token=None)
num_samples = 0
for words, tags in self.file_to_inputs(trn_path, True):
self.word_vocab.update(words)
self.tag_vocab.update(tags)
num_samples += 1
if self.char_vocab:
- self.char_vocab = Vocab()
+ self.char_vocab = VocabTF()
for word in self.word_vocab.token_to_idx.keys():
if word in (self.word_vocab.pad_token, self.word_vocab.unk_token):
continue
diff --git a/hanlp/transform/txt.py b/hanlp/transform/txt.py
index f539c36c4..0db529533 100644
--- a/hanlp/transform/txt.py
+++ b/hanlp/transform/txt.py
@@ -7,10 +7,11 @@
import tensorflow as tf
-from hanlp.common.transform import Transform
-from hanlp.common.vocab import Vocab
+from hanlp.common.transform_tf import Transform
+from hanlp.common.vocab_tf import VocabTF
from hanlp.utils.io_util import get_resource
from hanlp.utils.lang.zh.char_table import CharTable
+from hanlp.utils.span_util import bmes_of, bmes_to_words
from hanlp.utils.string_util import split_long_sent
@@ -35,23 +36,6 @@ def words_to_bmes(words):
return tags
-def bmes_to_words(chars, tags):
- result = []
- if len(chars) == 0:
- return result
- word = chars[0]
-
- for c, t in zip(chars[1:], tags[1:]):
- if t == 'B' or t == 'S':
- result.append(word)
- word = ''
- word += c
- if len(word) != 0:
- result.append(word)
-
- return result
-
-
def extract_ngram_features_and_tags(sentence, bigram_only=False, window_size=4, segmented=True):
"""
Feature extraction for windowed approaches
@@ -77,23 +61,6 @@ def extract_ngram_features_and_tags(sentence, bigram_only=False, window_size=4,
return tuple(ret[:-1]), ret[-1] # x, y
-def bmes_of(sentence, segmented):
- if segmented:
- chars = []
- tags = []
- words = sentence.split()
- for w in words:
- chars.extend(list(w))
- if len(w) == 1:
- tags.append('S')
- else:
- tags.extend(['B'] + ['M'] * (len(w) - 2) + ['E'])
- else:
- chars = list(sentence)
- tags = ['S'] * len(chars)
- return chars, tags
-
-
def extract_ngram_features(chars, bigram_only, window_size):
ret = []
if bigram_only:
@@ -141,8 +108,8 @@ def generate_ngram_bmes(file_path, bigram_only=False, window_size=4, gold=True):
yield extract_ngram_features_and_tags(sentence, bigram_only, window_size, gold)
-def vocab_from_txt(txt_file_path, bigram_only=False, window_size=4, **kwargs) -> Tuple[Vocab, Vocab, Vocab]:
- char_vocab, ngram_vocab, tag_vocab = Vocab(), Vocab(), Vocab(pad_token=None, unk_token=None)
+def vocab_from_txt(txt_file_path, bigram_only=False, window_size=4, **kwargs) -> Tuple[VocabTF, VocabTF, VocabTF]:
+ char_vocab, ngram_vocab, tag_vocab = VocabTF(), VocabTF(), VocabTF(pad_token=None, unk_token=None)
for X, Y in generate_ngram_bmes(txt_file_path, bigram_only, window_size, gold=True):
char_vocab.update(X[0])
for ngram in X[1:]:
@@ -151,7 +118,8 @@ def vocab_from_txt(txt_file_path, bigram_only=False, window_size=4, **kwargs) ->
return char_vocab, ngram_vocab, tag_vocab
-def dataset_from_txt(txt_file_path: str, char_vocab: Vocab, ngram_vocab: Vocab, tag_vocab: Vocab, bigram_only=False,
+def dataset_from_txt(txt_file_path: str, char_vocab: VocabTF, ngram_vocab: VocabTF, tag_vocab: VocabTF,
+ bigram_only=False,
window_size=4, segmented=True, batch_size=32, shuffle=None, repeat=None, prefetch=1):
generator = functools.partial(generate_ngram_bmes, txt_file_path, bigram_only, window_size, segmented)
return dataset_from_generator(generator, char_vocab, ngram_vocab, tag_vocab, bigram_only, window_size, batch_size,
diff --git a/hanlp/utils/__init__.py b/hanlp/utils/__init__.py
index 7501252f5..898a5a3fc 100644
--- a/hanlp/utils/__init__.py
+++ b/hanlp/utils/__init__.py
@@ -1,5 +1,19 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-08-24 22:12
-global_cache = {}
-from . import rules
\ No newline at end of file
+from . import rules
+
+
+def ls_resource_in_module(root) -> dict:
+ res = dict()
+ for k, v in root.__dict__.items():
+ if k.startswith('_') or v == root:
+ continue
+ if isinstance(v, str):
+ if v.startswith('http') and not v.endswith('/') and not v.endswith('#') and not v.startswith('_'):
+ res[k] = v
+ elif type(v).__name__ == 'module':
+ res.update(ls_resource_in_module(v))
+ if 'ALL' in root.__dict__ and isinstance(root.__dict__['ALL'], dict):
+ root.__dict__['ALL'].update(res)
+ return res
diff --git a/hanlp/utils/component_util.py b/hanlp/utils/component_util.py
index 87abcdae0..9647c60d1 100644
--- a/hanlp/utils/component_util.py
+++ b/hanlp/utils/component_util.py
@@ -5,15 +5,29 @@
import traceback
from sys import exit
+from hanlp_common.constant import HANLP_VERBOSE
+
from hanlp import pretrained
from hanlp.common.component import Component
-from hanlp.utils.io_util import get_resource, load_json, eprint, get_latest_info_from_pypi
-from hanlp.utils.reflection import object_from_class_path, str_to_type
+from hanlp.utils.io_util import get_resource, get_latest_info_from_pypi
+from hanlp_common.io import load_json, eprint
+from hanlp_common.reflection import object_from_classpath, str_to_type
from hanlp import version
-def load_from_meta_file(save_dir: str, meta_filename='meta.json', transform_only=False, load_kwargs=None,
+def load_from_meta_file(save_dir: str, meta_filename='meta.json', transform_only=False, verbose=HANLP_VERBOSE,
**kwargs) -> Component:
+ """
+
+ Args:
+ save_dir:
+ meta_filename (str): The meta file of that saved component, which stores the classpath and version.
+ transform_only:
+ **kwargs:
+
+ Returns:
+
+ """
identifier = save_dir
load_path = save_dir
save_dir = get_resource(save_dir)
@@ -21,6 +35,11 @@ def load_from_meta_file(save_dir: str, meta_filename='meta.json', transform_only
meta_filename = os.path.basename(save_dir)
save_dir = os.path.dirname(save_dir)
metapath = os.path.join(save_dir, meta_filename)
+ if not os.path.isfile(metapath):
+ tf_model = False
+ metapath = os.path.join(save_dir, 'config.json')
+ else:
+ tf_model = True
if not os.path.isfile(metapath):
tips = ''
if save_dir.isupper():
@@ -33,22 +52,36 @@ def load_from_meta_file(save_dir: str, meta_filename='meta.json', transform_only
f'Tips: it might be one of {similar_keys}'
raise FileNotFoundError(f'The identifier {save_dir} resolves to a non-exist meta file {metapath}. {tips}')
meta: dict = load_json(metapath)
- cls = meta.get('class_path', None)
- assert cls, f'{meta_filename} doesn\'t contain class_path field'
+ cls = meta.get('classpath', None)
+ if not cls:
+ cls = meta.get('class_path', None) # For older version
+ if tf_model:
+ # tf models are trained with version <= 2.0. To migrate them to 2.1, map their classpath to new locations
+ upgrade = {
+ 'hanlp.components.tok.TransformerTokenizer': 'hanlp.components.tok_tf.TransformerTokenizerTF',
+ 'hanlp.components.pos.RNNPartOfSpeechTagger': 'hanlp.components.pos_tf.RNNPartOfSpeechTaggerTF',
+ 'hanlp.components.pos.CNNPartOfSpeechTagger': 'hanlp.components.pos_tf.CNNPartOfSpeechTaggerTF',
+ 'hanlp.components.ner.TransformerNamedEntityRecognizer': 'hanlp.components.ner_tf.TransformerNamedEntityRecognizerTF',
+ 'hanlp.components.parsers.biaffine_parser.BiaffineDependencyParser': 'hanlp.components.parsers.biaffine_parser_tf.BiaffineDependencyParserTF',
+ 'hanlp.components.parsers.biaffine_parser.BiaffineSemanticDependencyParser': 'hanlp.components.parsers.biaffine_parser_tf.BiaffineSemanticDependencyParserTF',
+ 'hanlp.components.tok.NgramConvTokenizer': 'hanlp.components.tok_tf.NgramConvTokenizerTF',
+ 'hanlp.components.classifiers.transformer_classifier.TransformerClassifier': 'hanlp.components.classifiers.transformer_classifier_tf.TransformerClassifierTF',
+ 'hanlp.components.taggers.transformers.transformer_tagger.TransformerTagger': 'hanlp.components.taggers.transformers.transformer_tagger_tf.TransformerTaggerTF',
+ }
+ cls = upgrade.get(cls, cls)
+ assert cls, f'{meta_filename} doesn\'t contain classpath field'
try:
- obj: Component = object_from_class_path(cls, **kwargs)
+ obj: Component = object_from_classpath(cls)
if hasattr(obj, 'load'):
if transform_only:
# noinspection PyUnresolvedReferences
obj.load_transform(save_dir)
else:
- if load_kwargs is None:
- load_kwargs = {}
if os.path.isfile(os.path.join(save_dir, 'config.json')):
- obj.load(save_dir, **load_kwargs)
+ obj.load(save_dir, verbose=verbose, **kwargs)
else:
- obj.load(metapath, **load_kwargs)
- obj.meta['load_path'] = load_path
+ obj.load(metapath, **kwargs)
+ obj.config['load_path'] = load_path
return obj
except Exception as e:
eprint(f'Failed to load {identifier}. See traceback below:')
@@ -82,7 +115,9 @@ def load_from_meta_file(save_dir: str, meta_filename='meta.json', transform_only
def load_from_meta(meta: dict) -> Component:
- cls = meta.get('class_path', None)
- assert cls, f'{meta} doesn\'t contain class_path field'
+ if 'load_path' in meta:
+ return load_from_meta_file(meta['load_path'])
+ cls = meta.get('class_path', None) or meta.get('classpath', None)
+ assert cls, f'{meta} doesn\'t contain classpath field'
cls = str_to_type(cls)
- return cls.from_meta(meta)
+ return cls.from_config(meta)
diff --git a/hanlp/utils/english_tokenizer.py b/hanlp/utils/english_tokenizer.py
old mode 100755
new mode 100644
index 5ec43a27e..504e9c696
--- a/hanlp/utils/english_tokenizer.py
+++ b/hanlp/utils/english_tokenizer.py
@@ -322,4 +322,4 @@ def tokenize_english(sentence):
chunks.append(token)
tokens = chunks
results.append(tokens)
- return results[0] if flat else results
+ return results[0] if flat else results
\ No newline at end of file
diff --git a/hanlp/utils/file_read_backwards/buffer_work_space.py b/hanlp/utils/file_read_backwards/buffer_work_space.py
old mode 100755
new mode 100644
index e604edd0d..9e6b448f2
--- a/hanlp/utils/file_read_backwards/buffer_work_space.py
+++ b/hanlp/utils/file_read_backwards/buffer_work_space.py
@@ -30,8 +30,11 @@ def add_to_buffer(self, content, read_position):
"""Add additional bytes content as read from the read_position.
Args:
- content (bytes): data to be added to buffer working BufferWorkSpac.
- read_position (int): where in the file pointer the data was read from.
+ content(bytes): data to be added to buffer working BufferWorkSpac.
+ read_position(int): where in the file pointer the data was read from.
+
+ Returns:
+
"""
self.read_position = read_position
if self.read_buffer is None:
@@ -40,7 +43,7 @@ def add_to_buffer(self, content, read_position):
self.read_buffer = content + self.read_buffer
def yieldable(self):
- """Return True if there is a line that the buffer can return, False otherwise."""
+ """ """
if self.read_buffer is None:
return False
@@ -55,9 +58,13 @@ def yieldable(self):
return False
def return_line(self):
- """Return a new line if it is available.
+ """
+
+ Args:
+
+ Returns:
+ Precondition: self.yieldable() must be True
- Precondition: self.yieldable() must be True
"""
assert(self.yieldable())
@@ -82,7 +89,7 @@ def read_until_yieldable(self):
self.add_to_buffer(read_content, read_position)
def has_returned_every_line(self):
- """Return True if every single line in the file has been returned, False otherwise."""
+ """ """
if self.read_position == 0 and self.read_buffer is None:
return True
return False
@@ -96,12 +103,13 @@ def _get_next_chunk(fp, previously_read_position, chunk_size):
"""Return next chunk of data that we would from the file pointer.
Args:
- fp: file-like object
- previously_read_position: file pointer position that we have read from
- chunk_size: desired read chunk_size
+ fp: file
+ previously_read_position: file pointer position that we have read from
+ chunk_size: desired read chunk_size
Returns:
- (bytestring, int): data that has been read in, the file pointer position where the data has been read from
+ (bytestring, int): data that has been read in, the file pointer position where the data has been read from
+
"""
seek_position, read_size = _get_what_to_read_next(fp, previously_read_position, chunk_size)
fp.seek(seek_position)
@@ -114,12 +122,14 @@ def _get_what_to_read_next(fp, previously_read_position, chunk_size):
"""Return information on which file pointer position to read from and how many bytes.
Args:
- fp
- past_read_positon (int): The file pointer position that has been read previously
- chunk_size(int): ideal io chunk_size
+ fp:
+ past_read_positon: int
+ chunk_size: int
+ previously_read_position:
Returns:
- (int, int): The next seek position, how many bytes to read next
+ (int, int): The next seek position, how many bytes to read next
+
"""
seek_position = max(previously_read_position - chunk_size, 0)
read_size = chunk_size
@@ -146,8 +156,12 @@ def _get_what_to_read_next(fp, previously_read_position, chunk_size):
def _remove_trailing_new_line(l):
"""Remove a single instance of new line at the end of l if it exists.
+ Args:
+ l:
+
Returns:
- bytestring
+ : bytestring
+
"""
# replace only 1 instance of newline
# match longest line first (hence the reverse=True), we want to match "\r\n" rather than "\n" if we can
@@ -162,10 +176,11 @@ def _find_furthest_new_line(read_buffer):
"""Return -1 if read_buffer does not contain new line otherwise the position of the rightmost newline.
Args:
- read_buffer (bytestring)
+ read_buffer: bytestring
Returns:
- int: The right most position of new line character in read_buffer if found, else -1
+ int: The right most position of new line character in read_buffer if found, else -1
+
"""
new_line_positions = [read_buffer.rfind(n) for n in new_lines_bytes]
return max(new_line_positions)
@@ -175,10 +190,11 @@ def _is_partially_read_new_line(b):
"""Return True when b is part of a new line separator found at index >= 1, False otherwise.
Args:
- b (bytestring)
+ b: bytestring
Returns:
- bool
+ bool
+
"""
for n in new_lines_bytes:
if n.find(b) >= 1:
diff --git a/hanlp/utils/file_read_backwards/file_read_backwards.py b/hanlp/utils/file_read_backwards/file_read_backwards.py
old mode 100755
new mode 100644
index fbda47f89..29bd9ef73
--- a/hanlp/utils/file_read_backwards/file_read_backwards.py
+++ b/hanlp/utils/file_read_backwards/file_read_backwards.py
@@ -14,12 +14,17 @@
class FileReadBackwards:
"""Class definition for `FileReadBackwards`.
-
+
A `FileReadBackwards` will spawn a `FileReadBackwardsIterator` and keep an opened file handler.
-
+
It can be used as a Context Manager. If done so, when exited, it will close its file handler.
-
+
In any mode, `close()` can be called to close the file handler..
+
+ Args:
+
+ Returns:
+
"""
def __init__(self, path, encoding="utf-8", chunk_size=io.DEFAULT_BUFFER_SIZE):
@@ -57,7 +62,7 @@ def close(self):
self.iterator.close()
def readline(self):
- """Return a line content (with a trailing newline) if there are content. Return '' otherwise."""
+ """ """
try:
r = next(self.iterator) + os.linesep
@@ -68,8 +73,13 @@ def readline(self):
class FileReadBackwardsIterator:
"""Iterator for `FileReadBackwards`.
-
+
This will read backwards line by line a file. It holds an opened file handler.
+
+ Args:
+
+ Returns:
+
"""
def __init__(self, fp, encoding, chunk_size):
"""Constructor for FileReadBackwardsIterator
@@ -90,13 +100,18 @@ def __iter__(self):
def next(self):
"""Returns unicode string from the last line until the beginning of file.
-
+
Gets exhausted if::
-
+
* already reached the beginning of the file on previous iteration
* the file got closed
-
+
When it gets exhausted, it closes the file handler.
+
+ Args:
+
+ Returns:
+
"""
# Using binary mode, because some encodings such as "utf-8" use variable number of
# bytes to encode different Unicode points.
@@ -116,8 +131,13 @@ def next(self):
@property
def closed(self):
"""The status of the file handler.
-
+
:return: True if the file handler is still opened. False otherwise.
+
+ Args:
+
+ Returns:
+
"""
return self.__fp.closed
diff --git a/hanlp/utils/init_util.py b/hanlp/utils/init_util.py
new file mode 100644
index 000000000..eb8dbb04c
--- /dev/null
+++ b/hanlp/utils/init_util.py
@@ -0,0 +1,16 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-05-27 13:25
+import math
+
+import torch
+from torch import nn
+import functools
+
+
+def embedding_uniform(tensor:torch.Tensor, seed=233):
+ gen = torch.Generator().manual_seed(seed)
+ with torch.no_grad():
+ fan_out = tensor.size(-1)
+ bound = math.sqrt(3.0 / fan_out)
+ return tensor.uniform_(-bound, bound, generator=gen)
diff --git a/hanlp/utils/io_util.py b/hanlp/utils/io_util.py
index e9c53df97..2d2a66010 100644
--- a/hanlp/utils/io_util.py
+++ b/hanlp/utils/io_util.py
@@ -1,58 +1,52 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-08-26 15:02
+import contextlib
import glob
+import gzip
import json
+import logging
import os
-import pickle
import platform
import random
+import shlex
import shutil
import sys
-from sys import exit
-from contextlib import contextmanager
+import tarfile
import tempfile
import time
import urllib
import zipfile
-import tarfile
-from typing import Dict, Tuple, Optional
+from contextlib import contextmanager
+from pathlib import Path
+from subprocess import Popen, PIPE
+from typing import Dict, Tuple, Optional, Union, List
from urllib.parse import urlparse
from urllib.request import urlretrieve
-from pathlib import Path
+
import numpy as np
+import torch
from pkg_resources import parse_version
+
+import hanlp
+from hanlp_common.constant import HANLP_URL, HANLP_VERBOSE
from hanlp.utils import time_util
-from hanlp.utils.log_util import logger
+from hanlp.utils.log_util import logger, flash
from hanlp.utils.string_util import split_long_sentence_into
-from hanlp.utils.time_util import now_filename
-from hanlp.common.constant import HANLP_URL
+from hanlp.utils.time_util import now_filename, CountdownTimer
from hanlp.version import __version__
+from hanlp_common.io import save_pickle, load_pickle, eprint
-def save_pickle(item, path):
- with open(path, 'wb') as f:
- pickle.dump(item, f)
-
-
-def load_pickle(path):
- with open(path, 'rb') as f:
- return pickle.load(f)
-
-
-def save_json(item: dict, path: str, ensure_ascii=False, cls=None, default=lambda o: repr(o)):
- with open(path, 'w', encoding='utf-8') as out:
- json.dump(item, out, ensure_ascii=ensure_ascii, indent=2, cls=cls, default=default)
-
-
-def load_json(path):
- with open(path, encoding='utf-8') as src:
- return json.load(src)
-
-
-def filename_is_json(filename):
- filename, file_extension = os.path.splitext(filename)
- return file_extension in ['.json', '.jsonl']
+def load_jsonl(path, verbose=False):
+ if verbose:
+ src = TimingFileIterator(path)
+ else:
+ src = open(path, encoding='utf-8')
+ for line in src:
+ yield json.loads(line)
+ if not verbose:
+ src.close()
def make_debug_corpus(path, delimiter=None, percentage=0.1, max_samples=100):
@@ -99,12 +93,16 @@ def tempdir_human():
class NumpyEncoder(json.JSONEncoder):
- """
- Special json encoder for numpy types
- See https://interviewbubble.com/typeerror-object-of-type-float32-is-not-json-serializable/
- """
-
def default(self, obj):
+ """Special json encoder for numpy types
+ See https://interviewbubble.com/typeerror-object-of-type-float32-is-not-json-serializable/
+
+ Args:
+ obj: Object to be json encoded.
+
+ Returns:
+ Json string.
+ """
if isinstance(obj, (np.int_, np.intc, np.intp, np.int8,
np.int16, np.int32, np.int64, np.uint8,
np.uint16, np.uint32, np.uint64)):
@@ -118,10 +116,7 @@ def default(self, obj):
def hanlp_home_default():
- """
-
- :return: default data directory depending on the platform and environment variables
- """
+ """Default data directory depending on the platform and environment variables"""
if windows():
return os.path.join(os.environ.get('APPDATA'), 'hanlp')
else:
@@ -134,9 +129,19 @@ def windows():
def hanlp_home():
- """
+ """ Home directory for HanLP resources.
+
+ Returns:
+ Data directory in the filesystem for storage, for example when downloading models.
+
+ This home directory can be customized with the following shell command or equivalent environment variable on Windows
+ systems.
+
+ .. highlight:: bash
+ .. code-block:: bash
+
+ $ export HANLP_HOME=/data/hanlp
- :return: data directory in the filesystem for storage, for example when downloading models
"""
return os.getenv('HANLP_HOME', hanlp_home_default())
@@ -150,23 +155,21 @@ def remove_file(filename):
os.remove(filename)
-def eprint(*args, **kwargs):
- print(*args, file=sys.stderr, **kwargs)
-
-
def parent_dir(path):
return os.path.normpath(os.path.join(path, os.pardir))
-def download(url, save_path=None, save_dir=hanlp_home(), prefix=HANLP_URL, append_location=True):
+def download(url, save_path=None, save_dir=hanlp_home(), prefix=HANLP_URL, append_location=True, verbose=HANLP_VERBOSE):
if not save_path:
save_path = path_from_url(url, save_dir, prefix, append_location)
if os.path.isfile(save_path):
- eprint('Using local {}, ignore {}'.format(save_path, url))
+ if verbose:
+ eprint('Using local {}, ignore {}'.format(save_path, url))
return save_path
else:
makedirs(parent_dir(save_path))
- eprint('Downloading {} to {}'.format(url, save_path))
+ if verbose:
+ eprint('Downloading {} to {}'.format(url, save_path))
tmp_path = '{}.downloading'.format(save_path)
remove_file(tmp_path)
try:
@@ -188,10 +191,11 @@ def reporthook(count, block_size, total_size):
eta = duration / ratio * (1 - ratio)
speed = human_bytes(speed)
progress_size = human_bytes(progress_size)
- sys.stderr.write("\r%.2f%%, %s/%s, %s/s, ETA %s " %
- (percent, progress_size, human_bytes(total_size), speed,
- time_util.report_time_delta(eta)))
- sys.stderr.flush()
+ if verbose:
+ sys.stderr.write("\r%.2f%%, %s/%s, %s/s, ETA %s " %
+ (percent, progress_size, human_bytes(total_size), speed,
+ time_util.report_time_delta(eta)))
+ sys.stderr.flush()
import socket
socket.setdefaulttimeout(10)
@@ -228,76 +232,100 @@ def parse_url_path(url):
return parsed.netloc, path
-def uncompress(path, dest=None, remove=True):
- """
- uncompress a file
-
- Parameters
- ----------
- path
- The path to a compressed file
- dest
- The dest folder
- remove
- Remove compressed file after unzipping
- Returns
- -------
- The folder which contains the unzipped content if the zip contains multiple files,
- otherwise the path to the unique file
+def uncompress(path, dest=None, remove=True, verbose=HANLP_VERBOSE):
+ """uncompress a file
+
+ Args:
+ path: The path to a compressed file
+ dest: The dest folder.
+ remove: Remove archive file after decompression.
+ verbose: ``True`` to print log message.
+
+ Returns:
+ Destination path.
+
"""
# assert path.endswith('.zip')
- prefix, ext = os.path.splitext(path)
+ prefix, ext = split_if_compressed(path)
folder_name = os.path.basename(prefix)
file_is_zip = ext == '.zip'
root_of_folder = None
- with zipfile.ZipFile(path, "r") if ext == '.zip' else tarfile.open(path, 'r:*') as archive:
- try:
- if not dest:
- namelist = sorted(archive.namelist() if file_is_zip else archive.getnames())
- root_of_folder = namelist[0].strip('/') if len(
- namelist) > 1 else '' # only one file, root_of_folder = ''
- if all(f.split('/')[0] == root_of_folder for f in namelist[1:]) or not root_of_folder:
- dest = os.path.dirname(path) # only one folder, unzip to the same dir
- else:
- root_of_folder = None
- dest = prefix # assume zip contains more than one files or folders
- eprint('Extracting {} to {}'.format(path, dest))
- archive.extractall(dest)
- if root_of_folder:
- if root_of_folder != folder_name:
- # move root to match folder name
- os.rename(path_join(dest, root_of_folder), path_join(dest, folder_name))
- dest = path_join(dest, folder_name)
- elif len(namelist) == 1:
- dest = path_join(dest, namelist[0])
- except (RuntimeError, KeyboardInterrupt) as e:
- remove = False
- if os.path.exists(dest):
- if os.path.isfile(dest):
- os.remove(dest)
- else:
- shutil.rmtree(dest)
- raise e
+ if ext == '.gz':
+ with gzip.open(path, 'rb') as f_in, open(prefix, 'wb') as f_out:
+ shutil.copyfileobj(f_in, f_out)
+ else:
+ with zipfile.ZipFile(path, "r") if ext == '.zip' else tarfile.open(path, 'r:*') as archive:
+ try:
+ if not dest:
+ namelist = sorted(archive.namelist() if file_is_zip else archive.getnames())
+ if namelist[0] == '.':
+ namelist = namelist[1:]
+ namelist = [p[len('./'):] if p.startswith('./') else p for p in namelist]
+ if ext == '.tgz':
+ roots = set(x.split('/')[0] for x in namelist)
+ if len(roots) == 1:
+ root_of_folder = next(iter(roots))
+ else:
+ # only one file, root_of_folder = ''
+ root_of_folder = namelist[0].strip('/') if len(namelist) > 1 else ''
+ if all(f.split('/')[0] == root_of_folder for f in namelist[1:]) or not root_of_folder:
+ dest = os.path.dirname(path) # only one folder, unzip to the same dir
+ else:
+ root_of_folder = None
+ dest = prefix # assume zip contains more than one files or folders
+ if verbose:
+ eprint('Extracting {} to {}'.format(path, dest))
+ archive.extractall(dest)
+ if root_of_folder:
+ if root_of_folder != folder_name:
+ # move root to match folder name
+ os.rename(path_join(dest, root_of_folder), path_join(dest, folder_name))
+ dest = path_join(dest, folder_name)
+ elif len(namelist) == 1:
+ dest = path_join(dest, namelist[0])
+ except (RuntimeError, KeyboardInterrupt) as e:
+ remove = False
+ if os.path.exists(dest):
+ if os.path.isfile(dest):
+ os.remove(dest)
+ else:
+ shutil.rmtree(dest)
+ raise e
if remove:
remove_file(path)
return dest
def split_if_compressed(path: str, compressed_ext=('.zip', '.tgz', '.gz', 'bz2', '.xz')) -> Tuple[str, Optional[str]]:
- root, ext = os.path.splitext(path)
- if ext in compressed_ext:
+ tar_gz = '.tar.gz'
+ if path.endswith(tar_gz):
+ root, ext = path[:-len(tar_gz)], tar_gz
+ else:
+ root, ext = os.path.splitext(path)
+ if ext in compressed_ext or ext == tar_gz:
return root, ext
return path, None
-def get_resource(path: str, save_dir=None, extract=True, prefix=HANLP_URL, append_location=True):
- """
- Fetch real path for a resource (model, corpus, whatever)
- :param path: the general path (can be a url or a real path)
- :param extract: whether to unzip it if it's a zip file
- :param save_dir:
- :return: the real path to the resource
+def get_resource(path: str, save_dir=None, extract=True, prefix=HANLP_URL, append_location=True, verbose=HANLP_VERBOSE):
+ """Fetch real path for a resource (model, corpus, whatever) to :meth:`hanlp.utils.io_util.hanlp_home`.
+
+ Args:
+ path: the general path (can be a url or a real path)
+ extract: whether to unzip it if it's a zip file (Default value = True)
+ save_dir: return: the real path to the resource (Default value = None)
+ path: A local path (which will returned as is) or a remote URL (which will be downloaded, decompressed then
+ returned).
+ prefix: A prefix when matched with an URL (path), then that URL is considered to official. For official resources,
+ they will not go to a folder called ``thirdparty`` under :const:`~hanlp_common.constants.IDX`.
+ append_location: (Default value = True)
+ verbose: Whether print log messages.
+
+ Returns:
+ the real path to the resource
+
"""
+ path = hanlp.pretrained.ALL.get(path, path)
anchor: str = None
compressed = None
if os.path.isdir(path):
@@ -324,24 +352,26 @@ def get_resource(path: str, save_dir=None, extract=True, prefix=HANLP_URL, appen
elif os.path.isdir(realpath) or (os.path.isfile(realpath) and (compressed and extract)):
return realpath
else:
- pattern = realpath + '*'
- files = glob.glob(pattern)
- zip_path = realpath + compressed
- if extract and zip_path in files:
- files.remove(zip_path)
- if files:
- if len(files) > 1:
- logger.debug(f'Found multiple files with {pattern}, will use the first one.')
- return files[0]
+ if compressed:
+ pattern = realpath + '.*'
+ files = glob.glob(pattern)
+ files = list(filter(lambda x: not x.endswith('.downloading'), files))
+ zip_path = realpath + compressed
+ if zip_path in files:
+ files.remove(zip_path)
+ if files:
+ if len(files) > 1:
+ logger.debug(f'Found multiple files with {pattern}, will use the first one.')
+ return files[0]
# realpath is where its path after exaction
if compressed:
realpath += compressed
if not os.path.isfile(realpath):
- path = download(url=path, save_path=realpath)
+ path = download(url=path, save_path=realpath, verbose=verbose)
else:
path = realpath
if extract and compressed:
- path = uncompress(path)
+ path = uncompress(path, verbose=verbose)
if anchor:
path = path_join(path, anchor)
@@ -377,7 +407,7 @@ def human_bytes(file_size: int) -> str:
return '%d KB' % file_size
-def read_cells(filepath: str, delimiter='auto', strip=True, skip_header=False):
+def read_cells(filepath: str, delimiter='auto', strip=True, skip_first_line=False):
filepath = get_resource(filepath)
if delimiter == 'auto':
if filepath.endswith('.tsv'):
@@ -387,7 +417,7 @@ def read_cells(filepath: str, delimiter='auto', strip=True, skip_header=False):
else:
delimiter = None
with open(filepath, encoding='utf-8') as src:
- if skip_header:
+ if skip_first_line:
next(src)
for line in src:
line = line.strip()
@@ -400,16 +430,14 @@ def read_cells(filepath: str, delimiter='auto', strip=True, skip_header=False):
def replace_ext(filepath, ext) -> str:
- """
- Replace the extension of filepath to ext
- Parameters
- ----------
- filepath
- ext
+ """ Replace the extension of filepath to ext.
- Returns
- -------
+ Args:
+ filepath: Filepath to be replaced.
+ ext: Extension to replace.
+ Returns:
+ A new path.
"""
file_prefix, _ = os.path.splitext(filepath)
return file_prefix + ext
@@ -420,33 +448,61 @@ def load_word2vec(path, delimiter=' ', cache=True) -> Tuple[Dict[str, np.ndarray
binpath = replace_ext(realpath, '.pkl')
if cache:
try:
+ flash('Loading word2vec from cache [blink][yellow]...[/yellow][/blink]')
word2vec, dim = load_pickle(binpath)
- logger.debug(f'Loaded {binpath}')
+ flash('')
return word2vec, dim
except IOError:
pass
dim = None
word2vec = dict()
- with open(realpath, encoding='utf-8', errors='ignore') as f:
- for idx, line in enumerate(f):
- line = line.rstrip().split(delimiter)
- if len(line) > 2:
- if dim is None:
- dim = len(line)
- else:
- if len(line) != dim:
- logger.warning('{}#{} length mismatches with {}'.format(path, idx + 1, dim))
- continue
- word, vec = line[0], line[1:]
- word2vec[word] = np.array(vec, dtype=np.float32)
+ f = TimingFileIterator(realpath)
+ for idx, line in enumerate(f):
+ f.log('Loading word2vec from text file [blink][yellow]...[/yellow][/blink]')
+ line = line.rstrip().split(delimiter)
+ if len(line) > 2:
+ if dim is None:
+ dim = len(line)
+ else:
+ if len(line) != dim:
+ logger.warning('{}#{} length mismatches with {}'.format(path, idx + 1, dim))
+ continue
+ word, vec = line[0], line[1:]
+ word2vec[word] = np.array(vec, dtype=np.float32)
dim -= 1
if cache:
+ flash('Caching word2vec [blink][yellow]...[/yellow][/blink]')
save_pickle((word2vec, dim), binpath)
- logger.debug(f'Cached {binpath}')
+ flash('')
return word2vec, dim
+def load_word2vec_as_vocab_tensor(path, delimiter=' ', cache=True) -> Tuple[Dict[str, int], torch.Tensor]:
+ realpath = get_resource(path)
+ vocab_path = replace_ext(realpath, '.vocab')
+ matrix_path = replace_ext(realpath, '.pt')
+ if cache:
+ try:
+ flash('Loading vocab and matrix from cache [blink][yellow]...[/yellow][/blink]')
+ vocab = load_pickle(vocab_path)
+ matrix = torch.load(matrix_path, map_location='cpu')
+ flash('')
+ return vocab, matrix
+ except IOError:
+ pass
+
+ word2vec, dim = load_word2vec(path, delimiter, cache)
+ vocab = dict((k, i) for i, k in enumerate(word2vec.keys()))
+ matrix = torch.Tensor(list(word2vec.values()))
+ if cache:
+ flash('Caching vocab and matrix [blink][yellow]...[/yellow][/blink]')
+ save_pickle(vocab, vocab_path)
+ torch.save(matrix, matrix_path)
+ flash('')
+ return vocab, matrix
+
+
def save_word2vec(word2vec: dict, filepath, delimiter=' '):
with open(filepath, 'w', encoding='utf-8') as out:
for w, v in word2vec.items():
@@ -454,30 +510,33 @@ def save_word2vec(word2vec: dict, filepath, delimiter=' '):
out.write(f'{delimiter.join(str(x) for x in v)}\n')
-def read_tsv(tsv_file_path):
+def read_tsv_as_sents(tsv_file_path, ignore_prefix=None, delimiter=None):
sent = []
tsv_file_path = get_resource(tsv_file_path)
with open(tsv_file_path, encoding='utf-8') as tsv_file:
for line in tsv_file:
- cells = line.strip().split()
- if cells:
- # if len(cells) != 2:
- # print(line)
+ if ignore_prefix and line.startswith(ignore_prefix):
+ continue
+ line = line.strip()
+ cells = line.split(delimiter)
+ if line and cells:
sent.append(cells)
- else:
+ elif sent:
yield sent
sent = []
if sent:
yield sent
-def generator_words_tags(tsv_file_path, lower=True, gold=True, max_seq_length=None):
- for sent in read_tsv(tsv_file_path):
+def generate_words_tags_from_tsv(tsv_file_path, lower=False, gold=True, max_seq_length=None, sent_delimiter=None,
+ char_level=False, hard_constraint=False):
+ for sent in read_tsv_as_sents(tsv_file_path):
words = [cells[0] for cells in sent]
- if max_seq_length and len(words) > max_seq_length:
+ if max_seq_length:
offset = 0
# try to split the sequence to make it fit into max_seq_length
- for shorter_words in split_long_sentence_into(words, max_seq_length):
+ for shorter_words in split_long_sentence_into(words, max_seq_length, sent_delimiter, char_level,
+ hard_constraint):
if gold:
shorter_tags = [cells[1] for cells in sent[offset:offset + len(shorter_words)]]
offset += len(shorter_words)
@@ -488,7 +547,10 @@ def generator_words_tags(tsv_file_path, lower=True, gold=True, max_seq_length=No
yield shorter_words, shorter_tags
else:
if gold:
- tags = [cells[1] for cells in sent]
+ try:
+ tags = [cells[1] for cells in sent]
+ except:
+ raise ValueError(f'Failed to load {tsv_file_path}: {sent}')
else:
tags = None
if lower:
@@ -496,12 +558,16 @@ def generator_words_tags(tsv_file_path, lower=True, gold=True, max_seq_length=No
yield words, tags
-def split_file(filepath, train=0.8, valid=0.1, test=0.1, names=None, shuffle=False):
- num_lines = 0
- with open(filepath, encoding='utf-8') as src:
- for line in src:
- num_lines += 1
- splits = {'train': train, 'valid': valid, 'test': test}
+def split_file(filepath, train=0.8, dev=0.1, test=0.1, names=None, shuffle=False):
+ num_samples = 0
+ if filepath.endswith('.tsv'):
+ for sent in read_tsv_as_sents(filepath):
+ num_samples += 1
+ else:
+ with open(filepath, encoding='utf-8') as src:
+ for sample in src:
+ num_samples += 1
+ splits = {'train': train, 'dev': dev, 'test': test}
splits = dict((k, v) for k, v in splits.items() if v)
splits = dict((k, v / sum(splits.values())) for k, v in splits.items())
accumulated = 0
@@ -514,21 +580,30 @@ def split_file(filepath, train=0.8, valid=0.1, test=0.1, names=None, shuffle=Fal
if names is None:
names = {}
name, ext = os.path.splitext(filepath)
- outs = [open(names.get(split, name + '.' + split + ext), 'w', encoding='utf-8') for split in splits.keys()]
+ filenames = [names.get(split, name + '.' + split + ext) for split in splits.keys()]
+ outs = [open(f, 'w', encoding='utf-8') for f in filenames]
if shuffle:
- shuffle = list(range(num_lines))
+ shuffle = list(range(num_samples))
random.shuffle(shuffle)
- with open(filepath, encoding='utf-8') as src:
- for idx, line in enumerate(src):
- if shuffle:
- idx = shuffle[idx]
- ratio = idx / num_lines
- for sid, out in enumerate(outs):
- if r[2 * sid] <= ratio < r[2 * sid + 1]:
- out.write(line)
- break
+ if filepath.endswith('.tsv'):
+ src = read_tsv_as_sents(filepath)
+ else:
+ src = open(filepath, encoding='utf-8')
+ for idx, sample in enumerate(src):
+ if shuffle:
+ idx = shuffle[idx]
+ ratio = idx / num_samples
+ for sid, out in enumerate(outs):
+ if r[2 * sid] <= ratio < r[2 * sid + 1]:
+ if isinstance(sample, list):
+ sample = '\n'.join('\t'.join(x) for x in sample) + '\n\n'
+ out.write(sample)
+ break
+ if not filepath.endswith('.tsv'):
+ src.close()
for out in outs:
out.close()
+ return filenames
def fileno(file_or_fd):
@@ -543,13 +618,13 @@ def fileno(file_or_fd):
@contextmanager
def stdout_redirected(to=os.devnull, stdout=None):
- """
- Redirect stdout to else where
+ """Redirect stdout to else where.
Copied from https://stackoverflow.com/questions/4675728/redirect-stdout-to-a-file-in-python/22434262#22434262
- Parameters
- ----------
- to
- stdout
+
+ Args:
+ to: Target device.
+ stdout: Source device.
+
"""
if windows(): # This doesn't play well with windows
yield None
@@ -582,9 +657,97 @@ def stdout_redirected(to=os.devnull, stdout=None):
pass
-def check_outdated(package='hanlp', version=__version__, repository_url='https://pypi.python.org/pypi/%s/json'):
+def get_exitcode_stdout_stderr(cmd):
+ """Execute the external command and get its exitcode, stdout and stderr.
+ See https://stackoverflow.com/a/21000308/3730690
+
+ Args:
+ cmd: Command.
+
+ Returns:
+ Exit code, stdout, stderr.
"""
- Given the name of a package on PyPI and a version (both strings), checks
+ args = shlex.split(cmd)
+ proc = Popen(args, stdout=PIPE, stderr=PIPE)
+ out, err = proc.communicate()
+ exitcode = proc.returncode
+ return exitcode, out.decode('utf-8'), err.decode('utf-8')
+
+
+def run_cmd(cmd: str) -> str:
+ exitcode, out, err = get_exitcode_stdout_stderr(cmd)
+ if exitcode:
+ raise RuntimeError(err + '\nThe command is:\n' + cmd)
+ return out
+
+
+@contextlib.contextmanager
+def pushd(new_dir):
+ previous_dir = os.getcwd()
+ os.chdir(new_dir)
+ try:
+ yield
+ finally:
+ os.chdir(previous_dir)
+
+
+def basename_no_ext(path):
+ basename = os.path.basename(path)
+ no_ext, ext = os.path.splitext(basename)
+ return no_ext
+
+
+def file_cache(path: str, purge=False):
+ cache_name = path + '.cache'
+ cache_time = os.path.getmtime(cache_name) if os.path.isfile(cache_name) and not purge else 0
+ file_time = os.path.getmtime(path)
+ cache_valid = cache_time > file_time
+ return cache_name, cache_valid
+
+
+def merge_files(files: List[str], dst: str):
+ with open(dst, 'wb') as write:
+ for f in files:
+ with open(f, 'rb') as read:
+ shutil.copyfileobj(read, write)
+
+
+class TimingFileIterator(CountdownTimer):
+
+ def __init__(self, filepath) -> None:
+ super().__init__(os.path.getsize(filepath))
+ self.filepath = filepath
+
+ def __iter__(self):
+ if not os.path.isfile(self.filepath):
+ raise FileNotFoundError(self.filepath)
+ fp = open(self.filepath, encoding='utf-8', errors='ignore')
+ line = fp.readline()
+ while line:
+ yield line
+ self.current = fp.tell()
+ line = fp.readline()
+ fp.close()
+
+ def log(self, info=None, ratio_percentage=True, ratio=True, step=0, interval=0.5, erase=True,
+ logger: Union[logging.Logger, bool] = None, newline=False, ratio_width=None):
+ assert step == 0
+ super().log(info, ratio_percentage, ratio, step, interval, erase, logger, newline, ratio_width)
+
+ @property
+ def ratio(self) -> str:
+ return f'{human_bytes(self.current)}/{human_bytes(self.total)}'
+
+ @property
+ def ratio_width(self) -> int:
+ return len(f'{human_bytes(self.total)}') * 2 + 1
+
+ def close(self):
+ pass
+
+
+def check_outdated(package='hanlp', version=__version__, repository_url='https://pypi.python.org/pypi/%s/json'):
+ """Given the name of a package on PyPI and a version (both strings), checks
if the given version is the latest version of the package available.
Returns a 2-tuple (installed_version, latest_version)
`repository_url` is a `%` style format string
@@ -592,8 +755,15 @@ def check_outdated(package='hanlp', version=__version__, repository_url='https:/
e.g. test.pypi.org or a private repository.
The string is formatted with the package name.
Adopted from https://github.com/alexmojaki/outdated/blob/master/outdated/__init__.py
- """
+ Args:
+ package: Package name.
+ version: Installed version string.
+ repository_url: URL on pypi.
+
+ Returns:
+ Parsed installed version and latest version.
+ """
installed_version = parse_version(version)
latest_version = get_latest_info_from_pypi(package, repository_url)
return installed_version, latest_version
diff --git a/hanlp/utils/lang/zh/char_table.py b/hanlp/utils/lang/zh/char_table.py
index 397ad11d3..e58e0b474 100644
--- a/hanlp/utils/lang/zh/char_table.py
+++ b/hanlp/utils/lang/zh/char_table.py
@@ -4,8 +4,10 @@
from typing import List
from hanlp.utils.io_util import get_resource
+from hanlp_common.io import load_json
-HANLP_CHAR_TABLE = 'https://file.hankcs.com/corpus/char_table.zip#CharTable.txt'
+HANLP_CHAR_TABLE_TXT = 'https://file.hankcs.com/corpus/char_table.zip#CharTable.txt'
+HANLP_CHAR_TABLE_JSON = 'https://file.hankcs.com/corpus/char_table.json.zip'
class CharTable:
@@ -27,10 +29,25 @@ def normalize_chars(chars: List[str]) -> List[str]:
@staticmethod
def _init():
- with open(get_resource(HANLP_CHAR_TABLE), encoding='utf-8') as src:
+ CharTable.convert = CharTable.load()
+
+ @staticmethod
+ def load():
+ mapper = {}
+ with open(get_resource(HANLP_CHAR_TABLE_TXT), encoding='utf-8') as src:
for line in src:
cells = line.rstrip('\n')
if len(cells) != 3:
continue
a, _, b = cells
- CharTable.convert[a] = b
+ mapper[a] = b
+ return mapper
+
+
+class JsonCharTable(CharTable):
+
+ @staticmethod
+ def load():
+ return load_json(get_resource(HANLP_CHAR_TABLE_JSON))
+
+
diff --git a/hanlp/utils/lang/zh/localization.py b/hanlp/utils/lang/zh/localization.py
new file mode 100644
index 000000000..61e72b384
--- /dev/null
+++ b/hanlp/utils/lang/zh/localization.py
@@ -0,0 +1,35 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-05 02:09
+
+task = {
+ 'dep': '依存句法树',
+ 'token': '单词',
+ 'pos': '词性',
+ 'ner': '命名实体',
+ 'srl': '语义角色'
+}
+
+pos = {
+ 'VA': '表语形容词', 'VC': '系动词', 'VE': '动词有无', 'VV': '其他动词', 'NR': '专有名词', 'NT': '时间名词', 'NN': '其他名词',
+ 'LC': '方位词', 'PN': '代词', 'DT': '限定词', 'CD': '概数词', 'OD': '序数词', 'M': '量词', 'AD': '副词', 'P': '介词',
+ 'CC': '并列连接词', 'CS': '从属连词', 'DEC': '补语成分“的”', 'DEG': '属格“的”', 'DER': '表结果的“得”', 'DEV': '表方式的“地”',
+ 'AS': '动态助词', 'SP': '句末助词', 'ETC': '表示省略', 'MSP': '其他小品词', 'IJ': '句首感叹词', 'ON': '象声词',
+ 'LB': '长句式表被动', 'SB': '短句式表被动', 'BA': '把字句', 'JJ': '其他名词修饰语', 'FW': '外来语', 'PU': '标点符号',
+ 'NOI': '噪声', 'URL': '网址'
+}
+
+ner = {
+ 'NT': '机构团体', 'NS': '地名', 'NR': '人名'
+}
+
+dep = {
+ 'nn': '复合名词修饰', 'punct': '标点符号', 'nsubj': '名词性主语', 'conj': '连接性状语', 'dobj': '直接宾语', 'advmod': '名词性状语',
+ 'prep': '介词性修饰语', 'nummod': '数词修饰语', 'amod': '形容词修饰语', 'pobj': '介词性宾语', 'rcmod': '相关关系', 'cpm': '补语',
+ 'assm': '关联标记', 'assmod': '关联修饰', 'cc': '并列关系', 'elf': '类别修饰', 'ccomp': '从句补充', 'det': '限定语', 'lobj': '时间介词',
+ 'range': '数量词间接宾语', 'asp': '时态标记', 'tmod': '时间修饰语', 'plmod': '介词性地点修饰', 'attr': '属性', 'mmod': '情态动词',
+ 'loc': '位置补语', 'top': '主题', 'pccomp': '介词补语', 'etc': '省略关系', 'lccomp': '位置补语', 'ordmod': '量词修饰',
+ 'xsubj': '控制主语', 'neg': '否定修饰', 'rcomp': '结果补语', 'comod': '并列联合动词', 'vmod': '动词修饰', 'prtmod': '小品词',
+ 'ba': '把字关系', 'dvpm': '地字修饰', 'dvpmod': '地字动词短语', 'prnmod': '插入词修饰', 'cop': '系动词', 'pass': '被动标记',
+ 'nsubjpass': '被动名词主语', 'clf': '类别修饰', 'dep': '依赖关系', 'root': '核心关系'
+}
diff --git a/hanlp/utils/log_util.py b/hanlp/utils/log_util.py
index 801842f08..a5134a69d 100644
--- a/hanlp/utils/log_util.py
+++ b/hanlp/utils/log_util.py
@@ -2,19 +2,38 @@
# Author: hankcs
# Date: 2019-08-24 22:12
import datetime
+import io
import logging
import os
import sys
+from logging import LogRecord
+import termcolor
-def init_logger(name=datetime.datetime.now().strftime("%y-%m-%d_%H.%M.%S"), root_dir=None,
- level=logging.INFO, mode='a') -> logging.Logger:
- logFormatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s", datefmt='%y-%m-%d %H:%M:%S')
- rootLogger = logging.getLogger(name)
+
+class ColoredFormatter(logging.Formatter):
+ def __init__(self, fmt=None, datefmt=None, style='%', enable=True):
+ super().__init__(fmt, datefmt, style)
+ self.enable = enable
+
+ def formatMessage(self, record: LogRecord) -> str:
+ message = super().formatMessage(record)
+ if self.enable:
+ return color_format(message)
+ else:
+ return remove_color_tag(message)
+
+
+def init_logger(name=None, root_dir=None, level=logging.INFO, mode='w',
+ fmt="%(asctime)s %(levelname)s %(message)s",
+ datefmt='%Y-%m-%d %H:%M:%S') -> logging.Logger:
+ if not name:
+ name = datetime.datetime.now().strftime("%Y-%m-%d_%H.%M.%S")
+ rootLogger = logging.getLogger(os.path.join(root_dir, name) if root_dir else name)
rootLogger.propagate = False
- consoleHandler = logging.StreamHandler()
- consoleHandler.setFormatter(logFormatter)
+ consoleHandler = logging.StreamHandler(sys.stdout) # stderr will be rendered as red which is bad
+ consoleHandler.setFormatter(ColoredFormatter(fmt, datefmt=datefmt))
attached_to_std = False
for handler in rootLogger.handlers:
if isinstance(handler, logging.StreamHandler):
@@ -28,44 +47,125 @@ def init_logger(name=datetime.datetime.now().strftime("%y-%m-%d_%H.%M.%S"), root
if root_dir:
os.makedirs(root_dir, exist_ok=True)
- fileHandler = logging.FileHandler("{0}/{1}.log".format(root_dir, name), mode=mode)
- fileHandler.setFormatter(logFormatter)
+ log_path = "{0}/{1}.log".format(root_dir, name)
+ fileHandler = logging.FileHandler(log_path, mode=mode)
+ fileHandler.setFormatter(ColoredFormatter(fmt, datefmt=datefmt, enable=False))
rootLogger.addHandler(fileHandler)
fileHandler.setLevel(level)
return rootLogger
-def set_tf_loglevel(level=logging.ERROR):
- if level >= logging.FATAL:
- os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
- os.environ['TF_CPP_MIN_VLOG_LEVEL'] = '3'
- if level >= logging.ERROR:
- os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
- os.environ['TF_CPP_MIN_VLOG_LEVEL'] = '2'
- if level >= logging.WARNING:
- os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
- os.environ['TF_CPP_MIN_VLOG_LEVEL'] = '1'
- else:
- os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
- os.environ['TF_CPP_MIN_VLOG_LEVEL'] = '0'
- shut_up_python_logging()
- logging.getLogger('tensorflow').setLevel(level)
+logger = init_logger(name='hanlp', level=os.environ.get('HANLP_LOG_LEVEL', 'INFO'))
-def shut_up_python_logging():
- logging.getLogger('tensorflow').setLevel(logging.ERROR)
- import absl.logging
- logging.root.removeHandler(absl.logging._absl_handler)
- absl.logging._warn_preinit_stderr = False
+def enable_debug(debug=True):
+ logger.setLevel(logging.DEBUG if debug else logging.ERROR)
-logger = init_logger(name='hanlp', level=os.environ.get('HANLP_LOG_LEVEL', 'INFO'))
-# shut_up_python_logging()
+class ErasablePrinter(object):
+ def __init__(self):
+ self._last_print_width = 0
-# shut up tensorflow
-# set_tf_loglevel()
+ def erase(self):
+ if self._last_print_width:
+ sys.stdout.write("\b" * self._last_print_width)
+ sys.stdout.write(" " * self._last_print_width)
+ sys.stdout.write("\b" * self._last_print_width)
+ sys.stdout.write("\r") # \r is essential when multi-lines were printed
+ self._last_print_width = 0
+ def print(self, msg: str, color=True):
+ self.erase()
+ if color:
+ msg, _len = color_format_len(msg)
+ self._last_print_width = _len
+ else:
+ self._last_print_width = len(msg)
+ sys.stdout.write(msg)
+ sys.stdout.flush()
-def enable_debug(debug=True):
- logger.setLevel(logging.DEBUG if debug else logging.ERROR)
+
+_printer = ErasablePrinter()
+
+
+def flash(line: str, color=True):
+ _printer.print(line, color)
+
+
+def color_format(msg: str):
+ for tag in termcolor.COLORS, termcolor.HIGHLIGHTS, termcolor.ATTRIBUTES:
+ for c, v in tag.items():
+ start, end = f'[{c}]', f'[/{c}]'
+ msg = msg.replace(start, '\033[%dm' % v).replace(end, termcolor.RESET)
+ return msg
+
+
+def remove_color_tag(msg: str):
+ for tag in termcolor.COLORS, termcolor.HIGHLIGHTS, termcolor.ATTRIBUTES:
+ for c, v in tag.items():
+ start, end = f'[{c}]', f'[/{c}]'
+ msg = msg.replace(start, '').replace(end, '')
+ return msg
+
+
+def color_format_len(msg: str):
+ _len = len(msg)
+ for tag in termcolor.COLORS, termcolor.HIGHLIGHTS, termcolor.ATTRIBUTES:
+ for c, v in tag.items():
+ start, end = f'[{c}]', f'[/{c}]'
+ msg, delta = _replace_color_offset(msg, start, '\033[%dm' % v)
+ _len -= delta
+ msg, delta = _replace_color_offset(msg, end, termcolor.RESET)
+ _len -= delta
+ return msg, _len
+
+
+def _replace_color_offset(msg: str, color: str, ctrl: str):
+ chunks = msg.split(color)
+ delta = (len(chunks) - 1) * len(color)
+ return ctrl.join(chunks), delta
+
+
+def cprint(*args, **kwargs):
+ out = io.StringIO()
+ print(*args, file=out, **kwargs)
+ text = out.getvalue()
+ out.close()
+ c_text = color_format(text)
+ print(c_text, end='')
+
+
+def main():
+ # cprint('[blink][yellow]...[/yellow][/blink]')
+ # show_colors_and_formats()
+ show_colors()
+ # print('previous', end='')
+ # for i in range(10):
+ # flash(f'[red]{i}[/red]')
+
+
+def show_colors_and_formats():
+ msg = ''
+ for c in termcolor.COLORS.keys():
+ for h in termcolor.HIGHLIGHTS.keys():
+ for a in termcolor.ATTRIBUTES.keys():
+ msg += f'[{c}][{h}][{a}] {c}+{h}+{a} [/{a}][/{h}][/{c}]'
+ logger.info(msg)
+
+
+def show_colors():
+ msg = ''
+ for c in termcolor.COLORS.keys():
+ cprint(f'[{c}]"{c}",[/{c}]')
+
+
+# Generates tables for Doxygen flavored Markdown. See the Doxygen
+# documentation for details:
+# http://www.doxygen.nl/manual/markdown.html#md_tables
+
+# Translation dictionaries for table alignment
+
+
+if __name__ == '__main__':
+ main()
diff --git a/hanlp/utils/rules.py b/hanlp/utils/rules.py
index c6ac9715a..4bd0f5642 100644
--- a/hanlp/utils/rules.py
+++ b/hanlp/utils/rules.py
@@ -1,7 +1,5 @@
import re
-from hanlp.utils.english_tokenizer import tokenize_english
-
SEPARATOR = r'@'
RE_SENTENCE = re.compile(r'(\S.+?[.!?])(?=\s+|$)|(\S.+?)(?=[\n]|$)', re.UNICODE)
AB_SENIOR = re.compile(r'([A-Z][a-z]{1,2}\.)\s(\w)', re.UNICODE)
diff --git a/hanlp/utils/span_util.py b/hanlp/utils/span_util.py
new file mode 100644
index 000000000..8f3f9c63a
--- /dev/null
+++ b/hanlp/utils/span_util.py
@@ -0,0 +1,82 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-06-12 20:34
+
+
+def generate_words_per_line(file_path):
+ with open(file_path, encoding='utf-8') as src:
+ for line in src:
+ cells = line.strip().split()
+ if not cells:
+ continue
+ yield cells
+
+
+def words_to_bmes(words):
+ tags = []
+ for w in words:
+ if not w:
+ raise ValueError('{} contains None or zero-length word {}'.format(str(words), w))
+ if len(w) == 1:
+ tags.append('S')
+ else:
+ tags.extend(['B'] + ['M'] * (len(w) - 2) + ['E'])
+ return tags
+
+
+def words_to_bi(words):
+ tags = []
+ for w in words:
+ if not w:
+ raise ValueError('{} contains None or zero-length word {}'.format(str(words), w))
+ tags.extend(['B'] + ['I'] * (len(w) - 1))
+ return tags
+
+
+def bmes_to_words(chars, tags):
+ result = []
+ if len(chars) == 0:
+ return result
+ word = chars[0]
+
+ for c, t in zip(chars[1:], tags[1:]):
+ if t == 'B' or t == 'S':
+ result.append(word)
+ word = ''
+ word += c
+ if len(word) != 0:
+ result.append(word)
+
+ return result
+
+
+def bmes_to_spans(tags):
+ result = []
+ offset = 0
+ pre_offset = 0
+ for t in tags[1:]:
+ offset += 1
+ if t == 'B' or t == 'S':
+ result.append((pre_offset, offset))
+ pre_offset = offset
+ if offset != len(tags):
+ result.append((pre_offset, len(tags)))
+
+ return result
+
+
+def bmes_of(sentence, segmented):
+ if segmented:
+ chars = []
+ tags = []
+ words = sentence.split()
+ for w in words:
+ chars.extend(list(w))
+ if len(w) == 1:
+ tags.append('S')
+ else:
+ tags.extend(['B'] + ['M'] * (len(w) - 2) + ['E'])
+ else:
+ chars = list(sentence)
+ tags = ['S'] * len(chars)
+ return chars, tags
diff --git a/hanlp/utils/string_util.py b/hanlp/utils/string_util.py
index b00fd89db..6046c0499 100644
--- a/hanlp/utils/string_util.py
+++ b/hanlp/utils/string_util.py
@@ -4,12 +4,6 @@
import unicodedata
from typing import List, Dict
-import tensorflow as tf
-
-
-def format_metrics(metrics: List[tf.keras.metrics.Metric]):
- return ' - '.join(f'{m.name}: {m.result():.4f}' for m in metrics)
-
def format_scores(results: Dict[str, float]) -> str:
return ' - '.join(f'{k}: {v:.4f}' for (k, v) in results.items())
@@ -20,19 +14,65 @@ def ispunct(token):
for char in token)
-def split_long_sentence_into(tokens: List[str], max_seq_length):
- punct_offset = [i for i, x in enumerate(tokens) if ispunct(x)]
+def split_long_sentence_into(tokens: List[str], max_seq_length, sent_delimiter=None, char_level=False,
+ hard_constraint=False):
+ punct_offset = [i for i, x in enumerate(tokens) if
+ ((sent_delimiter and x in sent_delimiter) or (not sent_delimiter and ispunct(x)))]
if not punct_offset:
# treat every token as punct
punct_offset = [i for i in range(len(tokens))]
punct_offset += [len(tokens)]
+ token_to_char_offset = []
+ if char_level:
+ offset = 0
+ for token in tokens:
+ token_to_char_offset.append(offset)
+ offset += len(token)
+ token_to_char_offset.append(offset)
+
start = 0
for i, offset in enumerate(punct_offset[:-1]):
- if punct_offset[i + 1] - start >= max_seq_length:
- yield tokens[start: offset + 1]
+ end = punct_offset[i + 1]
+ length_at_next_punct = _len(start, end, token_to_char_offset, char_level)
+ if length_at_next_punct >= max_seq_length:
+ if hard_constraint:
+ yield from _gen_short_sent(tokens, start, offset, max_seq_length, token_to_char_offset, char_level)
+ else:
+ yield tokens[start: offset + 1]
start = offset + 1
- if start < punct_offset[-1]:
- yield tokens[start:]
+ offset = punct_offset[-1]
+ if start < offset:
+ offset -= 1
+ length_at_next_punct = _len(start, offset, token_to_char_offset, char_level)
+ if length_at_next_punct >= max_seq_length and hard_constraint:
+ yield from _gen_short_sent(tokens, start, offset, max_seq_length, token_to_char_offset, char_level)
+ else:
+ yield tokens[start:]
+
+
+def _gen_short_sent(tokens, start, offset, max_seq_length, token_to_char_offset, char_level):
+ while start <= offset:
+ for j in range(offset + 1, start, -1):
+ if _len(start, j, token_to_char_offset, char_level) <= max_seq_length or j == start + 1:
+ yield tokens[start: j]
+ start = j
+ break
+
+
+def _len(start, end, token_to_char_offset, char_level):
+ if char_level:
+ length_at_next_punct = token_to_char_offset[end] - token_to_char_offset[start]
+ else:
+ length_at_next_punct = end - start
+ return length_at_next_punct
+
+
+def guess_delimiter(tokens):
+ if all(ord(c) < 128 for c in ''.join(tokens)):
+ delimiter_in_entity = ' '
+ else:
+ delimiter_in_entity = ''
+ return delimiter_in_entity
def split_long_sent(sent, delimiters, max_seq_length):
@@ -53,4 +93,4 @@ def split_long_sent(sent, delimiters, max_seq_length):
else:
if len(short) + len(parts[idx + 1]) > max_seq_length:
yield short
- short = []
\ No newline at end of file
+ short = []
diff --git a/hanlp/utils/tf_util.py b/hanlp/utils/tf_util.py
index 1ea040ea9..85ade1a02 100644
--- a/hanlp/utils/tf_util.py
+++ b/hanlp/utils/tf_util.py
@@ -1,49 +1,25 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-08-27 01:27
+import logging
+import os
import random
from typing import List
import numpy as np
-import tensorflow as tf
-
-from hanlp.common.constant import PAD
-
-
-def size_of_dataset(dataset: tf.data.Dataset) -> int:
- count = len(list(dataset.unbatch().as_numpy_iterator()))
- return count
-
-def summary_of_model(model: tf.keras.Model):
- """
- https://stackoverflow.com/a/53668338/3730690
- Parameters
- ----------
- model
- """
- if not model.built:
- return 'model structure unknown until calling fit() with some data'
- line_list = []
- model.summary(print_fn=lambda x: line_list.append(x))
- summary = "\n".join(line_list)
- return summary
+from hanlp_common.constant import PAD
-def register_custom_cls(custom_cls, name=None):
- if not name:
- name = custom_cls.__name__
- tf.keras.utils.get_custom_objects()[name] = custom_cls
+def set_gpu(idx=0):
+ """Restrict TensorFlow to only use the GPU of idx
+ Args:
+ idx: (Default value = 0)
-def set_gpu(idx=0):
- """
- Restrict TensorFlow to only use the GPU of idx
+ Returns:
- Parameters
- ----------
- idx : int
- Which GPU to use
+
"""
gpus = get_visible_gpus()
if gpus:
@@ -76,14 +52,75 @@ def set_gpu_memory_growth(growth=True):
def nice_gpu():
- """
- Use GPU nicely.
- """
+ """Use GPU nicely."""
set_gpu_memory_growth()
set_gpu()
-def set_seed(seed=233):
+def shut_up_python_logging():
+ logging.getLogger('tensorflow').setLevel(logging.ERROR)
+ import absl.logging
+ logging.root.removeHandler(absl.logging._absl_handler)
+ absl.logging._warn_preinit_stderr = False
+
+
+def set_tf_loglevel(level=logging.ERROR):
+ if level >= logging.FATAL:
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
+ os.environ['TF_CPP_MIN_VLOG_LEVEL'] = '3'
+ if level >= logging.ERROR:
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
+ os.environ['TF_CPP_MIN_VLOG_LEVEL'] = '2'
+ if level >= logging.WARNING:
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
+ os.environ['TF_CPP_MIN_VLOG_LEVEL'] = '1'
+ else:
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
+ os.environ['TF_CPP_MIN_VLOG_LEVEL'] = '0'
+ shut_up_python_logging()
+ logging.getLogger('tensorflow').setLevel(level)
+
+
+set_tf_loglevel()
+
+shut_up_python_logging()
+import tensorflow as tf
+
+nice_gpu()
+
+
+def size_of_dataset(dataset: tf.data.Dataset) -> int:
+ count = 0
+ for element in dataset.unbatch().batch(1):
+ count += 1
+ return count
+
+
+def summary_of_model(model: tf.keras.Model):
+ """https://stackoverflow.com/a/53668338/3730690
+
+ Args:
+ model: tf.keras.Model:
+
+ Returns:
+
+
+ """
+ if not model.built:
+ return 'model structure unknown until calling fit() with some data'
+ line_list = []
+ model.summary(print_fn=lambda x: line_list.append(x))
+ summary = "\n".join(line_list)
+ return summary
+
+
+def register_custom_cls(custom_cls, name=None):
+ if not name:
+ name = custom_cls.__name__
+ tf.keras.utils.get_custom_objects()[name] = custom_cls
+
+
+def set_seed_tf(seed=233):
tf.random.set_seed(seed)
np.random.seed(seed)
random.seed(seed)
@@ -91,13 +128,18 @@ def set_seed(seed=233):
def nice():
nice_gpu()
- set_seed()
+ set_seed_tf()
+
+def hanlp_register(arg):
+ """Registers a class with the Keras serialization framework.
+ Args:
+ arg:
+ Returns:
-def hanlp_register(arg):
- """Registers a class with the Keras serialization framework."""
+ """
class_name = arg.__name__
registered_name = 'HanLP' + '>' + class_name
@@ -126,8 +168,9 @@ def get_callback_by_class(callbacks: List[tf.keras.callbacks.Callback], cls) ->
if isinstance(callback, cls):
return callback
-def tf_bernoulli(shape, p):
- return tf.keras.backend.random_binomial(shape, p)
+
+def tf_bernoulli(shape, p, dtype=None):
+ return tf.keras.backend.random_binomial(shape, p, dtype)
def str_tensor_to_str(str_tensor: tf.Tensor) -> str:
@@ -148,4 +191,8 @@ def str_tensor_2d_to_list(str_tensor: tf.Tensor, pad=PAD) -> List[List[str]]:
def str_tensor_to_list(pred):
- return [tag.predict('utf-8') for tag in pred]
\ No newline at end of file
+ return [tag.predict('utf-8') for tag in pred]
+
+
+def format_metrics(metrics: List[tf.keras.metrics.Metric]):
+ return ' - '.join(f'{m.name}: {m.result():.4f}' for m in metrics)
diff --git a/hanlp/utils/time_util.py b/hanlp/utils/time_util.py
index 6b11e5d6f..e3723cd2f 100644
--- a/hanlp/utils/time_util.py
+++ b/hanlp/utils/time_util.py
@@ -2,7 +2,12 @@
# Author: hankcs
# Date: 2019-08-27 00:01
import datetime
+import logging
+import sys
import time
+from typing import Union
+
+from hanlp.utils.log_util import ErasablePrinter, color_format, color_format_len
def human_time_delta(days, hours, minutes, seconds, delimiter=' ') -> str:
@@ -56,6 +61,154 @@ def __truediv__(self, scalar):
return HumanTimeDelta(self.delta_seconds / scalar)
+class CountdownTimer(ErasablePrinter):
+
+ def __init__(self, total: int) -> None:
+ super().__init__()
+ self.total = total
+ self.current = 0
+ self.start = time.time()
+ self.finished_in = None
+ self.last_log_time = 0
+
+ def update(self, n=1):
+ self.current += n
+ self.current = min(self.total, self.current)
+ if self.current == self.total:
+ self.finished_in = time.time() - self.start
+
+ @property
+ def ratio(self) -> str:
+ return f'{self.current}/{self.total}'
+
+ @property
+ def ratio_percentage(self) -> str:
+ return f'{self.current / self.total:.2%}'
+
+ @property
+ def eta(self) -> float:
+ elapsed = self.elapsed
+ if self.finished_in:
+ eta = 0
+ else:
+ eta = elapsed / max(self.current, 0.1) * (self.total - self.current)
+
+ return eta
+
+ @property
+ def elapsed(self) -> float:
+ if self.finished_in:
+ elapsed = self.finished_in
+ else:
+ elapsed = time.time() - self.start
+ return elapsed
+
+ @property
+ def elapsed_human(self) -> str:
+ return human_time_delta(*seconds_to_time_delta(self.elapsed))
+
+ @property
+ def elapsed_average(self) -> float:
+ return self.elapsed / self.current
+
+ @property
+ def elapsed_average_human(self) -> str:
+ return human_time_delta(*seconds_to_time_delta(self.elapsed_average))
+
+ @property
+ def eta_human(self) -> str:
+ return human_time_delta(*seconds_to_time_delta(self.eta))
+
+ @property
+ def total_time(self) -> float:
+ elapsed = self.elapsed
+ if self.finished_in:
+ t = self.finished_in
+ else:
+ t = elapsed / max(self.current, 1) * self.total
+
+ return t
+
+ @property
+ def total_time_human(self) -> str:
+ return human_time_delta(*seconds_to_time_delta(self.total_time))
+
+ def stop(self, total=None):
+ if not self.finished_in or total:
+ self.finished_in = time.time() - self.start
+ if not total:
+ self.total = self.current
+ else:
+ self.current = total
+ self.total = total
+
+ @property
+ def et_eta(self):
+ _ = self.elapsed
+ if self.finished_in:
+ return self.elapsed
+ else:
+ return self.eta
+
+ @property
+ def et_eta_human(self):
+ text = human_time_delta(*seconds_to_time_delta(self.et_eta))
+ if self.finished_in:
+ return f'ET: {text}'
+ else:
+ return f'ETA: {text}'
+
+ @property
+ def finished(self):
+ return self.total == self.current
+
+ def log(self, info=None, ratio_percentage=True, ratio=True, step=1, interval=0.5, erase=True,
+ logger: Union[logging.Logger, bool] = None, newline=False, ratio_width=None):
+ self.update(step)
+ now = time.time()
+ if now - self.last_log_time > interval or self.finished:
+ cells = []
+ if ratio_percentage:
+ cells.append(self.ratio_percentage)
+ if ratio:
+ ratio = self.ratio
+ if not ratio_width:
+ ratio_width = self.ratio_width
+ ratio = ratio.rjust(ratio_width)
+ cells.append(ratio)
+ cells += [info, self.et_eta_human]
+ cells = [x for x in cells if x]
+ msg = f'{" ".join(cells)}'
+ self.last_log_time = now
+ self.print(msg, newline, erase, logger)
+
+ @property
+ def ratio_width(self) -> int:
+ return len(f'{self.total}') * 2 + 1
+
+ def print(self, msg, newline=False, erase=True, logger=None):
+ self.erase()
+ msg_len = 0 if newline else len(msg)
+ if self.finished and logger:
+ sys.stdout.flush()
+ if isinstance(logger, logging.Logger):
+ logger.info(msg)
+ else:
+ msg, msg_len = color_format_len(msg)
+ sys.stdout.write(msg)
+ if newline:
+ sys.stdout.write('\n')
+ msg_len = 0
+ self._last_print_width = msg_len
+ if self.finished and not logger:
+ if erase:
+ self.erase()
+ else:
+ sys.stdout.write("\n")
+ self._last_print_width = 0
+ sys.stdout.flush()
+
+
class Timer(object):
def __init__(self) -> None:
self.last = time.time()
@@ -80,11 +233,14 @@ def now_datetime():
def now_filename(fmt="%y%m%d_%H%M%S"):
- """
- Generate filename using current datetime, in 20180102_030405 format
- Returns
- -------
+ """Generate filename using current datetime, in 20180102_030405 format
+
+ Args:
+ fmt: (Default value = "%y%m%d_%H%M%S")
+
+ Returns:
+
"""
now = datetime.datetime.now()
return now.strftime(fmt)
diff --git a/hanlp/utils/torch_util.py b/hanlp/utils/torch_util.py
new file mode 100644
index 000000000..706d8e516
--- /dev/null
+++ b/hanlp/utils/torch_util.py
@@ -0,0 +1,187 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-05-09 15:52
+import os
+import random
+import time
+from typing import List, Union
+
+import numpy as np
+import torch
+from pynvml import nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, nvmlInit, nvmlShutdown, nvmlDeviceGetCount
+from torch import nn
+from torch.nn.utils.rnn import pad_sequence
+
+from hanlp.utils.log_util import logger
+
+
+def gpus_available() -> dict:
+ try:
+ nvmlInit()
+ gpus = {}
+ visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES', None)
+ if visible_devices:
+ visible_devices = {int(x.strip()) for x in visible_devices.split(',')}
+ else:
+ visible_devices = list(range(nvmlDeviceGetCount()))
+ for i, real_id in enumerate(visible_devices):
+ h = nvmlDeviceGetHandleByIndex(real_id)
+ info = nvmlDeviceGetMemoryInfo(h)
+ total = info.total
+ free = info.free
+ ratio = free / total
+ gpus[i] = ratio
+ # print(f'total : {info.total}')
+ # print(f'free : {info.free}')
+ # print(f'used : {info.used}')
+ # t = torch.cuda.get_device_properties(0).total_memory
+ # c = torch.cuda.memory_cached(0)
+ # a = torch.cuda.memory_allocated(0)
+ # print(t, c, a)
+ nvmlShutdown()
+ return dict(sorted(gpus.items(), key=lambda x: x[1], reverse=True))
+ except Exception as e:
+ logger.debug(f'Failed to get gpu info due to {e}')
+ return {}
+
+
+def visuable_devices():
+ visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES', None)
+ if visible_devices:
+ visible_devices = {int(x.strip()) for x in visible_devices.split(',')}
+ else:
+ visible_devices = list(range(torch.cuda.device_count()))
+ return visible_devices
+
+
+def cuda_devices(query=None) -> List[int]:
+ """Decide which GPUs to use
+
+ Args:
+ query: (Default value = None)
+
+ Returns:
+
+
+ """
+ if isinstance(query, list):
+ if len(query) == 0:
+ return [-1]
+ return query
+ if query is None:
+ query = gpus_available()
+ if not query:
+ return []
+ size, idx = max((v, k) for k, v in query.items())
+ # When multiple GPUs have the same size, randomly pick one to avoid conflicting
+ gpus_with_same_size = [k for k, v in query.items() if v == size]
+ query = random.choice(gpus_with_same_size)
+ if isinstance(query, float):
+ gpus = gpus_available()
+ if not query:
+ return []
+ query = [k for k, v in gpus.items() if v > query]
+ elif isinstance(query, int):
+ query = [query]
+ return query
+
+
+def pad_lists(sequences: List[List], dtype=torch.long, padding_value=0):
+ return pad_sequence([torch.tensor(x, dtype=dtype) for x in sequences], True, padding_value)
+
+
+def set_seed(seed=233, dont_care_speed=False):
+ """Copied from https://github.com/huggingface/transformers/blob/7b75aa9fa55bee577e2c7403301ed31103125a35/src/transformers/trainer.py#L76
+
+ Args:
+ seed: (Default value = 233)
+ dont_care_speed: True may have a negative single-run performance impact, but ensures deterministic
+
+ Returns:
+
+
+ """
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ # ^^ safe to call this function even if cuda is not available
+ torch.cuda.manual_seed_all(seed)
+ if dont_care_speed:
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+
+def batched_index_select(input, index, dim=1):
+ """
+
+ Args:
+ input: B x * x ... x *
+ index: B x M
+ dim: (Default value = 1)
+
+ Returns:
+
+
+ """
+ views = [input.shape[0]] + [1 if i != dim else -1 for i in range(1, len(input.shape))]
+ expanse = list(input.shape)
+ expanse[0] = -1
+ expanse[dim] = -1
+ index = index.view(views).expand(expanse)
+ return torch.gather(input, dim, index)
+
+
+def truncated_normal_(tensor, mean=0, std=1):
+ size = tensor.shape
+ tmp = tensor.new_empty(size + (4,)).normal_()
+ valid = (tmp < 2) & (tmp > -2)
+ ind = valid.max(-1, keepdim=True)[1]
+ tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
+ tensor.data.mul_(std).add_(mean)
+ return tensor
+
+
+def dtype_of(e: Union[int, bool, float]):
+ if isinstance(e, bool):
+ return torch.bool
+ if isinstance(e, int):
+ return torch.long
+ if isinstance(e, float):
+ return torch.float
+ raise ValueError(f'Unsupported type of {repr(e)}')
+
+
+def mean_model(model: torch.nn.Module):
+ return float(torch.mean(torch.stack([torch.sum(p) for p in model.parameters() if p.requires_grad])))
+
+
+def main():
+ start = time.time()
+ print(visuable_devices())
+ print(time.time() - start)
+ # print(gpus_available())
+ # print(cuda_devices())
+ # print(cuda_devices(0.1))
+
+
+if __name__ == '__main__':
+ main()
+
+
+def clip_grad_norm(model: nn.Module, grad_norm, transformer: nn.Module = None, transformer_grad_norm=None):
+ if transformer_grad_norm is None:
+ if grad_norm is not None:
+ nn.utils.clip_grad_norm_(filter(lambda p: p.requires_grad, model.parameters()), grad_norm)
+ else:
+ is_transformer = []
+ non_transformer = []
+ transformer = set(transformer.parameters())
+ for p in model.parameters():
+ if not p.requires_grad:
+ continue
+ if p in transformer:
+ is_transformer.append(p)
+ else:
+ non_transformer.append(p)
+ nn.utils.clip_grad_norm_(non_transformer, grad_norm)
+ nn.utils.clip_grad_norm_(is_transformer, transformer_grad_norm)
diff --git a/hanlp/version.py b/hanlp/version.py
index 083a83357..2dcf73220 100644
--- a/hanlp/version.py
+++ b/hanlp/version.py
@@ -2,4 +2,5 @@
# Author: hankcs
# Date: 2019-12-28 19:26
-__version__ = '2.0.0-alpha.69'
+__version__ = '2.1.0-alpha.0'
+"""HanLP version"""
diff --git a/plugins/hanlp_common/README.md b/plugins/hanlp_common/README.md
new file mode 100644
index 000000000..6e60455da
--- /dev/null
+++ b/plugins/hanlp_common/README.md
@@ -0,0 +1,17 @@
+# Common utilities and structures for HanLP
+
+[中文](https://github.com/hankcs/HanLP/tree/doc-zh) | [1.x](https://github.com/hankcs/HanLP/tree/1.x) | [forum](https://bbs.hankcs.com/) | [docker](https://github.com/WalterInSH/hanlp-jupyter-docker)
+
+The multilingual NLP library for researchers and companies, built on PyTorch and TensorFlow 2.x, for advancing state-of-the-art deep learning techniques in both academia and industry. HanLP was designed from day one to be efficient, user friendly and extendable. It comes with pretrained models for various human languages including English, Chinese and many others. Currently, HanLP 2.0 is in alpha stage with more killer features on the roadmap. Discussions are welcomed on our [forum](https://bbs.hankcs.com/), while bug reports and feature requests are reserved for GitHub issues. For Java users, please checkout the [1.x](https://github.com/hankcs/HanLP/tree/1.x) branch.
+
+
+## Installation
+
+```bash
+pip install hanlp
+```
+
+## License
+
+HanLP is licensed under **Apache License 2.0**. You can use HanLP in your commercial products for free. We would appreciate it if you add a link to HanLP on your website.
+
diff --git a/plugins/hanlp_common/__init__.py b/plugins/hanlp_common/__init__.py
new file mode 100644
index 000000000..079c7284b
--- /dev/null
+++ b/plugins/hanlp_common/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-16 22:20
diff --git a/plugins/hanlp_common/hanlp_common/__init__.py b/plugins/hanlp_common/hanlp_common/__init__.py
new file mode 100644
index 000000000..f659353bd
--- /dev/null
+++ b/plugins/hanlp_common/hanlp_common/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-16 22:21
diff --git a/plugins/hanlp_common/hanlp_common/amr.py b/plugins/hanlp_common/hanlp_common/amr.py
new file mode 100644
index 000000000..0146a5334
--- /dev/null
+++ b/plugins/hanlp_common/hanlp_common/amr.py
@@ -0,0 +1,963 @@
+# MIT License
+#
+# Copyright (c) 2019 Sheng Zhang
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import json
+import logging
+import re
+import traceback
+from collections import Counter, defaultdict
+
+from hanlp_common.io import eprint
+
+try:
+ import networkx as nx
+ import penman
+ from penman import Triple
+except ModuleNotFoundError:
+ traceback.print_exc()
+ eprint('AMR support requires the full version which can be installed via:\n'
+ 'pip install hanlp_common[full]')
+ exit(1)
+
+DEFAULT_PADDING_TOKEN = "@@PADDING@@"
+DEFAULT_OOV_TOKEN = "@@UNKNOWN@@"
+logger = logging.getLogger('amr')
+
+# Disable inverting ':mod' relation.
+penman.AMRCodec._inversions.pop('domain')
+penman.AMRCodec._deinversions.pop('mod')
+
+amr_codec = penman.AMRCodec(indent=6)
+
+WORDSENSE_RE = re.compile(r'-\d\d$')
+QUOTED_RE = re.compile(r'^".*"$')
+
+
+def is_abstract_token(token):
+ return re.search(r'^([A-Z]+_)+\d+$', token) or re.search(r'^\d0*$', token)
+
+
+def is_english_punct(c):
+ return re.search(r'^[,.?!:;"\'-(){}\[\]]$', c)
+
+
+def find_similar_token(token, tokens):
+ token = re.sub(r'-\d\d$', '', token) # .lower())
+ for i, t in enumerate(tokens):
+ if token == t:
+ return tokens[i]
+ # t = t.lower()
+ # if (token == t or
+ # (t.startswith(token) and len(token) > 3) or
+ # token + 'd' == t or
+ # token + 'ed' == t or
+ # re.sub('ly$', 'le', t) == token or
+ # re.sub('tive$', 'te', t) == token or
+ # re.sub('tion$', 'te', t) == token or
+ # re.sub('ied$', 'y', t) == token or
+ # re.sub('ly$', '', t) == token
+ # ):
+ # return tokens[i]
+ return None
+
+
+class AMR:
+
+ def __init__(self,
+ id=None,
+ sentence=None,
+ graph=None,
+ tokens=None,
+ lemmas=None,
+ pos_tags=None,
+ ner_tags=None,
+ abstract_map=None,
+ misc=None):
+ self.id = id
+ self.sentence = sentence
+ self.graph = graph
+ self.tokens = tokens
+ self.lemmas = lemmas
+ self.pos_tags = pos_tags
+ self.ner_tags = ner_tags
+ self.abstract_map = abstract_map
+ self.misc = misc
+
+ def is_named_entity(self, index):
+ return self.ner_tags[index] not in ('0', 'O')
+
+ def get_named_entity_span(self, index):
+ if self.ner_tags is None or not self.is_named_entity(index):
+ return []
+ span = [index]
+ tag = self.ner_tags[index]
+ prev = index - 1
+ while prev > 0 and self.ner_tags[prev] == tag:
+ span.append(prev)
+ prev -= 1
+ next = index + 1
+ while next < len(self.ner_tags) and self.ner_tags[next] == tag:
+ span.append(next)
+ next += 1
+ return span
+
+ def find_span_indexes(self, span):
+ for i, token in enumerate(self.tokens):
+ if token == span[0]:
+ _span = self.tokens[i: i + len(span)]
+ if len(_span) == len(span) and all(x == y for x, y in zip(span, _span)):
+ return list(range(i, i + len(span)))
+ return None
+
+ def replace_span(self, indexes, new, pos=None, ner=None):
+ self.tokens = self.tokens[:indexes[0]] + new + self.tokens[indexes[-1] + 1:]
+ self.lemmas = self.lemmas[:indexes[0]] + new + self.lemmas[indexes[-1] + 1:]
+ if pos is None:
+ pos = [self.pos_tags[indexes[0]]]
+ self.pos_tags = self.pos_tags[:indexes[0]] + pos + self.pos_tags[indexes[-1] + 1:]
+ if ner is None:
+ ner = [self.ner_tags[indexes[0]]]
+ self.ner_tags = self.ner_tags[:indexes[0]] + ner + self.ner_tags[indexes[-1] + 1:]
+
+ def remove_span(self, indexes):
+ self.replace_span(indexes, [], [], [])
+
+ def __repr__(self):
+ fields = []
+ for k, v in dict(
+ id=self.id,
+ snt=self.sentence,
+ tokens=self.tokens,
+ lemmas=self.lemmas,
+ pos_tags=self.pos_tags,
+ ner_tags=self.ner_tags,
+ abstract_map=self.abstract_map,
+ misc=self.misc,
+ graph=self.graph
+ ).items():
+ if v is None:
+ continue
+ if k == 'misc':
+ fields += v
+ elif k == 'graph':
+ fields.append(str(v))
+ else:
+ if not isinstance(v, str):
+ v = json.dumps(v)
+ fields.append('# ::{} {}'.format(k, v))
+ return '\n'.join(fields)
+
+ def get_src_tokens(self):
+ return self.lemmas if self.lemmas else self.sentence.split()
+
+
+class AMRNode:
+ attribute_priority = [
+ 'instance', 'quant', 'mode', 'value', 'name', 'li', 'mod', 'frequency',
+ 'month', 'day', 'year', 'time', 'unit', 'decade', 'poss'
+ ]
+
+ def __init__(self, identifier, attributes=None, copy_of=None):
+ self.identifier = identifier
+ if attributes is None:
+ self.attributes = []
+ else:
+ self.attributes = attributes
+ # self._sort_attributes()
+ self._num_copies = 0
+ self.copy_of = copy_of
+
+ def _sort_attributes(self):
+ def get_attr_priority(attr):
+ if attr in self.attribute_priority:
+ return self.attribute_priority.index(attr), attr
+ if not re.search(r'^(ARG|op|snt)', attr):
+ return len(self.attribute_priority), attr
+ else:
+ return len(self.attribute_priority) + 1, attr
+
+ self.attributes.sort(key=lambda x: get_attr_priority(x[0]))
+
+ def __hash__(self):
+ return hash(self.identifier)
+
+ def __eq__(self, other):
+ if not isinstance(other, AMRNode):
+ return False
+ return self.identifier == other.identifier
+
+ def __repr__(self):
+ ret = str(self.identifier)
+ for k, v in self.attributes:
+ if k == 'instance':
+ ret += ' / ' + v
+ break
+ return ret
+
+ def __str__(self):
+ ret = repr(self)
+ for key, value in self.attributes:
+ if key == 'instance':
+ continue
+ ret += '\n\t:{} {}'.format(key, value)
+ return ret
+
+ @property
+ def instance(self):
+ for key, value in self.attributes:
+ if key == 'instance':
+ return value
+ else:
+ return None
+
+ @property
+ def ops(self):
+ ops = []
+ for key, value in self.attributes:
+ if re.search(r'op\d+', key):
+ ops.append((int(key[2:]), value))
+ if len(ops):
+ ops.sort(key=lambda x: x[0])
+ return [v for k, v in ops]
+
+ def copy(self):
+ attributes = None
+ if self.attributes is not None:
+ attributes = self.attributes[:]
+ self._num_copies += 1
+ copy = AMRNode(self.identifier + '_copy_{}'.format(self._num_copies), attributes, self)
+ return copy
+
+ def remove_attribute(self, attr, value):
+ self.attributes.remove((attr, value))
+
+ def add_attribute(self, attr, value):
+ self.attributes.append((attr, value))
+
+ def replace_attribute(self, attr, old, new):
+ index = self.attributes.index((attr, old))
+ self.attributes[index] = (attr, new)
+
+ def get_frame_attributes(self):
+ for k, v in self.attributes:
+ if isinstance(v, str) and re.search(r'-\d\d$', v):
+ yield k, v
+
+ def get_senseless_attributes(self):
+ for k, v in self.attributes:
+ if isinstance(v, str) and not re.search(r'-\d\d$', v):
+ yield k, v
+
+
+class AMRGraph(penman.Graph):
+ edge_label_priority = (
+ 'mod name time location degree poss domain quant manner unit purpose topic condition part-of compared-to '
+ 'duration source ord beneficiary concession direction frequency consist-of example medium location-of '
+ 'manner-of quant-of time-of instrument prep-in destination accompanier prep-with extent instrument-of age '
+ 'path concession-of subevent-of prep-as prep-to prep-against prep-on prep-for degree-of prep-under part '
+ 'condition-of prep-without topic-of season duration-of poss-of prep-from prep-at range purpose-of source-of '
+ 'subevent example-of value path-of scale conj-as-if prep-into prep-by prep-on-behalf-of medium-of prep-among '
+ 'calendar beneficiary-of prep-along-with extent-of age-of frequency-of dayperiod accompanier-of '
+ 'destination-of prep-amid prep-toward prep-in-addition-to ord-of name-of weekday direction-of prep-out-of '
+ 'timezone subset-of'.split())
+
+ def __init__(self, penman_graph):
+ super(AMRGraph, self).__init__()
+ self._triples = penman_graph._triples
+ self._top = penman_graph._top
+ self._build_extras()
+ self._src_tokens = []
+
+ def __str__(self):
+ self._triples = penman.alphanum_order(self._triples)
+ return amr_codec.encode(self)
+
+ def _build_extras(self):
+ G = nx.DiGraph()
+
+ self.variable_to_node = {}
+ for v in self.variables():
+ if type(v) is not str:
+ continue
+ attributes = [(t.relation, t.target) for t in self.attributes(source=v)]
+ node = AMRNode(v, attributes)
+ G.add_node(node)
+ self.variable_to_node[v] = node
+
+ edge_set = set()
+ for edge in self.edges():
+ if type(edge.source) is not str:
+ continue
+ source = self.variable_to_node[edge.source]
+ target = self.variable_to_node[edge.target]
+ relation = edge.relation
+
+ if relation == 'instance':
+ continue
+
+ if source == target:
+ continue
+
+ if edge.inverted:
+ source, target, relation = target, source, amr_codec.invert_relation(edge.relation)
+
+ if (source, target) in edge_set:
+ target = target.copy()
+
+ edge_set.add((source, target))
+ G.add_edge(source, target, label=relation)
+
+ self._G = G
+
+ def attributes(self, source=None, relation=None, target=None):
+ # Refine attributes because there's a bug in penman.attributes()
+ # See https://github.com/goodmami/penman/issues/29
+ attrmatch = lambda a: (
+ (source is None or source == a.source) and
+ (relation is None or relation == a.relation) and
+ (target is None or target == a.target)
+ )
+ variables = self.variables()
+ attrs = [t for t in self.triples() if t.target not in variables or t.relation == 'instance']
+ return list(filter(attrmatch, attrs))
+
+ def _update_penman_graph(self, triples):
+ self._triples = triples
+ if self._top not in self.variables():
+ self._top = None
+
+ def is_name_node(self, node):
+ edges = list(self._G.in_edges(node))
+ return any(self._G[source][target].get('label', None) == 'name' for source, target in edges)
+
+ def get_name_node_type(self, node):
+ edges = list(self._G.in_edges(node))
+ for source, target in edges:
+ if self._G[source][target].get('label', None) == 'name':
+ return source.instance
+ raise KeyError
+
+ def get_name_node_wiki(self, node):
+ edges = list(self._G.in_edges(node))
+ for source, target in edges:
+ if self._G[source][target].get('label', None) == 'name':
+ for attr, value in source.attributes:
+ if attr == 'wiki':
+ if value != '-':
+ value = value[1:-1] # remove quotes
+ return value
+ return None
+
+ def set_name_node_wiki(self, node, wiki):
+ edges = list(self._G.in_edges(node))
+ parent = None
+ for source, target in edges:
+ if self._G[source][target].get('label', None) == 'name':
+ parent = source
+ break
+ if parent:
+ if wiki != '-':
+ wiki = '"{}"'.format(wiki)
+ self.add_node_attribute(parent, 'wiki', wiki)
+
+ def is_date_node(self, node):
+ return node.instance == 'date-entity'
+
+ def add_edge(self, source, target, label):
+ self._G.add_edge(source, target, label=label)
+ t = penman.Triple(source=source.identifier, relation=label, target=target.identifier)
+ triples = self._triples + [t]
+ triples = penman.alphanum_order(triples)
+ self._update_penman_graph(triples)
+
+ def remove_edge(self, x, y):
+ if isinstance(x, AMRNode) and isinstance(y, AMRNode):
+ self._G.remove_edge(x, y)
+ if isinstance(x, AMRNode):
+ x = x.identifier
+ if isinstance(y, AMRNode):
+ y = y.identifier
+ triples = [t for t in self._triples if not (t.source == x and t.target == y)]
+ self._update_penman_graph(triples)
+
+ def update_edge_label(self, x, y, old, new):
+ self._G[x][y]['label'] = new
+ triples = []
+ for t in self._triples:
+ if t.source == x.identifier and t.target == y.identifier and t.relation == old:
+ t = Triple(x.identifier, new, y.identifier)
+ triples.append(t)
+ self._update_penman_graph(triples)
+
+ def add_node(self, instance):
+ identifier = instance[0]
+ assert identifier.isalpha()
+ if identifier in self.variables():
+ i = 2
+ while identifier + str(i) in self.variables():
+ i += 1
+ identifier += str(i)
+ triples = self._triples + [Triple(identifier, 'instance', instance)]
+ self._triples = penman.alphanum_order(triples)
+
+ node = AMRNode(identifier, [('instance', instance)])
+ self._G.add_node(node)
+ return node
+
+ def remove_node(self, node):
+ self._G.remove_node(node)
+ triples = [t for t in self._triples if t.source != node.identifier]
+ self._update_penman_graph(triples)
+
+ def replace_node_attribute(self, node, attr, old, new):
+ node.replace_attribute(attr, old, new)
+ triples = []
+ found = False
+ for t in self._triples:
+ if t.source == node.identifier and t.relation == attr and t.target == old:
+ found = True
+ t = penman.Triple(source=node.identifier, relation=attr, target=new)
+ triples.append(t)
+ if not found:
+ raise KeyError
+ self._triples = penman.alphanum_order(triples)
+
+ def remove_node_attribute(self, node, attr, value):
+ node.remove_attribute(attr, value)
+ triples = [t for t in self._triples if
+ not (t.source == node.identifier and t.relation == attr and t.target == value)]
+ self._update_penman_graph(triples)
+
+ def add_node_attribute(self, node, attr, value):
+ node.add_attribute(attr, value)
+ t = penman.Triple(source=node.identifier, relation=attr, target=value)
+ self._triples = penman.alphanum_order(self._triples + [t])
+
+ def remove_node_ops(self, node):
+ ops = []
+ for attr, value in node.attributes:
+ if re.search(r'^op\d+$', attr):
+ ops.append((attr, value))
+ for attr, value in ops:
+ self.remove_node_attribute(node, attr, value)
+
+ def remove_subtree(self, root):
+ children = []
+ removed_nodes = set()
+ for _, child in list(self._G.edges(root)):
+ self.remove_edge(root, child)
+ children.append(child)
+ for child in children:
+ if len(list(self._G.in_edges(child))) == 0:
+ removed_nodes.update(self.remove_subtree(child))
+ if len(list(self._G.in_edges(root))) == 0:
+ self.remove_node(root)
+ removed_nodes.add(root)
+ return removed_nodes
+
+ def get_subtree(self, root, max_depth):
+ if max_depth == 0:
+ return []
+ nodes = [root]
+ children = [child for _, child in self._G.edges(root)]
+ nodes += children
+ for child in children:
+ if len(list(self._G.in_edges(child))) == 1:
+ nodes = nodes + self.get_subtree(child, max_depth - 1)
+ return nodes
+
+ def get_nodes(self):
+ return self._G.nodes
+
+ def get_edges(self):
+ return self._G.edges
+
+ def set_src_tokens(self, sentence):
+ if type(sentence) is not list:
+ sentence = sentence.split(" ")
+ self._src_tokens = sentence
+
+ def get_src_tokens(self):
+ return self._src_tokens
+
+ def get_list_node(self, replace_copy=True):
+ visited = defaultdict(int)
+ node_list = []
+
+ def dfs(node, relation, parent):
+
+ node_list.append((
+ node if node.copy_of is None or not replace_copy else node.copy_of,
+ relation,
+ parent if parent.copy_of is None or not replace_copy else parent.copy_of))
+
+ if len(self._G[node]) > 0 and visited[node] == 0:
+ visited[node] = 1
+ for child_node, child_relation in self.sort_edges(self._G[node].items()):
+ dfs(child_node, child_relation["label"], node)
+
+ dfs(
+ self.variable_to_node[self._top],
+ 'root',
+ self.variable_to_node[self._top]
+ )
+
+ return node_list
+
+ def sort_edges(self, edges):
+ return edges
+
+ def get_tgt_tokens(self):
+ node_list = self.get_list_node()
+
+ tgt_token = []
+ visited = defaultdict(int)
+
+ for node, relation, parent_node in node_list:
+ instance = [attr[1] for attr in node.attributes if attr[0] == "instance"]
+ assert len(instance) == 1
+ tgt_token.append(str(instance[0]))
+
+ if len(node.attributes) > 1 and visited[node] == 0:
+ for attr in node.attributes:
+ if attr[0] != "instance":
+ tgt_token.append(str(attr[1]))
+
+ visited[node] = 1
+
+ return tgt_token
+
+ def get_list_data(self, amr, bos=None, eos=None, bert_tokenizer=None, max_tgt_length=None):
+ node_list = self.get_list_node()
+
+ tgt_tokens = []
+ head_tags = []
+ head_indices = []
+
+ node_to_idx = defaultdict(list)
+ visited = defaultdict(int)
+
+ def update_info(node, relation, parent, token):
+ head_indices.append(1 + node_to_idx[parent][-1])
+ head_tags.append(relation)
+ tgt_tokens.append(str(token))
+
+ for node, relation, parent_node in node_list:
+
+ node_to_idx[node].append(len(tgt_tokens))
+
+ instance = [attr[1] for attr in node.attributes if attr[0] == "instance"]
+ assert len(instance) == 1
+ instance = instance[0]
+
+ update_info(node, relation, parent_node, instance)
+
+ if len(node.attributes) > 1 and visited[node] == 0:
+ for attr in node.attributes:
+ if attr[0] != "instance":
+ update_info(node, attr[0], node, attr[1])
+
+ visited[node] = 1
+
+ def trim_very_long_tgt_tokens(tgt_tokens, head_tags, head_indices, node_to_idx):
+ tgt_tokens = tgt_tokens[:max_tgt_length]
+ head_tags = head_tags[:max_tgt_length]
+ head_indices = head_indices[:max_tgt_length]
+ for node, indices in node_to_idx.items():
+ invalid_indices = [index for index in indices if index >= max_tgt_length]
+ for index in invalid_indices:
+ indices.remove(index)
+ return tgt_tokens, head_tags, head_indices, node_to_idx
+
+ if max_tgt_length is not None:
+ tgt_tokens, head_tags, head_indices, node_to_idx = trim_very_long_tgt_tokens(
+ tgt_tokens, head_tags, head_indices, node_to_idx)
+
+ copy_offset = 0
+ if bos:
+ tgt_tokens = [bos] + tgt_tokens
+ copy_offset += 1
+ if eos:
+ tgt_tokens = tgt_tokens + [eos]
+
+ head_indices[node_to_idx[self.variable_to_node[self.top]][0]] = 0
+
+ # Target side Coreference
+ tgt_copy_indices = [i for i in range(len(tgt_tokens))]
+
+ for node, indices in node_to_idx.items():
+ if len(indices) > 1:
+ copy_idx = indices[0] + copy_offset
+ for token_idx in indices[1:]:
+ tgt_copy_indices[token_idx + copy_offset] = copy_idx
+
+ tgt_copy_map = [(token_idx, copy_idx) for token_idx, copy_idx in enumerate(tgt_copy_indices)]
+
+ for i, copy_index in enumerate(tgt_copy_indices):
+ # Set the coreferred target to 0 if no coref is available.
+ if i == copy_index:
+ tgt_copy_indices[i] = 0
+
+ tgt_token_counter = Counter(tgt_tokens)
+ tgt_copy_mask = [0] * len(tgt_tokens)
+ for i, token in enumerate(tgt_tokens):
+ if tgt_token_counter[token] > 1:
+ tgt_copy_mask[i] = 1
+
+ def add_source_side_tags_to_target_side(_src_tokens, _src_tags):
+ assert len(_src_tags) == len(_src_tokens)
+ tag_counter = defaultdict(lambda: defaultdict(int))
+ for src_token, src_tag in zip(_src_tokens, _src_tags):
+ tag_counter[src_token][src_tag] += 1
+
+ tag_lut = {DEFAULT_OOV_TOKEN: DEFAULT_OOV_TOKEN,
+ DEFAULT_PADDING_TOKEN: DEFAULT_OOV_TOKEN}
+ for src_token in set(_src_tokens):
+ tag = max(tag_counter[src_token].keys(), key=lambda x: tag_counter[src_token][x])
+ tag_lut[src_token] = tag
+
+ tgt_tags = []
+ for tgt_token in tgt_tokens:
+ sim_token = find_similar_token(tgt_token, _src_tokens)
+ if sim_token is not None:
+ index = _src_tokens.index(sim_token)
+ tag = _src_tags[index]
+ else:
+ tag = DEFAULT_OOV_TOKEN
+ tgt_tags.append(tag)
+
+ return tgt_tags, tag_lut
+
+ # Source Copy
+ src_tokens = self.get_src_tokens()
+ src_token_ids = None
+ src_token_subword_index = None
+ src_pos_tags = amr.pos_tags
+ src_copy_vocab = SourceCopyVocabulary(src_tokens)
+ src_copy_indices = src_copy_vocab.index_sequence(tgt_tokens)
+ src_copy_map = src_copy_vocab.get_copy_map(src_tokens)
+ tgt_pos_tags, pos_tag_lut = add_source_side_tags_to_target_side(src_tokens, src_pos_tags)
+
+ if bert_tokenizer is not None:
+ src_token_ids, src_token_subword_index = bert_tokenizer.tokenize(src_tokens, True)
+
+ src_must_copy_tags = [1 if is_abstract_token(t) else 0 for t in src_tokens]
+ src_copy_invalid_ids = set(src_copy_vocab.index_sequence(
+ [t for t in src_tokens if is_english_punct(t)]))
+
+ return {
+ "tgt_tokens": tgt_tokens,
+ "tgt_pos_tags": tgt_pos_tags,
+ "tgt_copy_indices": tgt_copy_indices,
+ "tgt_copy_map": tgt_copy_map,
+ "tgt_copy_mask": tgt_copy_mask,
+ "src_tokens": src_tokens,
+ "src_token_ids": src_token_ids,
+ "src_token_subword_index": src_token_subword_index,
+ "src_must_copy_tags": src_must_copy_tags,
+ "src_pos_tags": src_pos_tags,
+ "src_copy_vocab": src_copy_vocab,
+ "src_copy_indices": src_copy_indices,
+ "src_copy_map": src_copy_map,
+ "pos_tag_lut": pos_tag_lut,
+ "head_tags": head_tags,
+ "head_indices": head_indices,
+ "src_copy_invalid_ids": src_copy_invalid_ids
+ }
+
+ @classmethod
+ def decode(cls, raw_graph_string):
+ _graph = amr_codec.decode(raw_graph_string)
+ return cls(_graph)
+
+ @classmethod
+ def from_lists(cls, all_list):
+ head_tags = all_list['head_tags']
+ head_indices = all_list['head_indices']
+ tgt_tokens = all_list['tokens']
+
+ tgt_copy_indices = all_list['coref']
+ variables = []
+ variables_count = defaultdict(int)
+ for i, token in enumerate(tgt_tokens):
+ if tgt_copy_indices[i] != i:
+ variables.append(variables[tgt_copy_indices[i]])
+ else:
+ if token[0] in variables_count:
+ variables.append(token[0] + str(variables_count[token[0]]))
+ else:
+ variables.append(token[0])
+
+ variables_count[token[0]] += 1
+
+ Triples = []
+ for variable, token in zip(variables, tgt_tokens):
+ Triples.append(Triple(variable, "instance", token))
+ Triples.append(
+ Triple(
+ head_indices[variable],
+ head_tags[variable],
+ variable
+ )
+ )
+
+ @classmethod
+ def from_prediction(cls, prediction):
+
+ def is_attribute_value(value):
+ return re.search(r'(^".*"$|^[^a-zA-Z]+$)', value) is not None
+
+ def is_attribute_edge(label):
+ return label in ('instance', 'mode', 'li', 'value', 'month', 'year', 'day', 'decade', 'ARG6')
+
+ def normalize_number(text):
+ if re.search(r'^\d+,\d+$', text):
+ text = text.replace(',', '')
+ return text
+
+ def abstract_node(value):
+ return re.search(r'^([A-Z]+|DATE_ATTRS|SCORE_ENTITY|ORDINAL_ENTITY)_\d+$', value)
+
+ def abstract_attribute(value):
+ return re.search(r'^_QUANTITY_\d+$', value)
+
+ def correct_multiroot(heads):
+ for i in range(1, len(heads)):
+ if heads[i] == 0:
+ heads[i] = 1
+ return heads
+
+ nodes = [normalize_number(n) for n in prediction['nodes']]
+ heads = correct_multiroot(prediction['heads'])
+ corefs = [int(x) for x in prediction['corefs']]
+ head_labels = prediction['head_labels']
+
+ triples = []
+ top = None
+ # Build the variable map from variable to instance.
+ variable_map = {}
+ for coref_index in corefs:
+ node = nodes[coref_index - 1]
+ head_label = head_labels[coref_index - 1]
+ if (re.search(r'[/:\\()]', node) or is_attribute_value(node) or
+ is_attribute_edge(head_label) or abstract_attribute(node)):
+ continue
+ variable_map['vv{}'.format(coref_index)] = node
+ for head_index in heads:
+ if head_index == 0:
+ continue
+ node = nodes[head_index - 1]
+ coref_index = corefs[head_index - 1]
+ variable_map['vv{}'.format(coref_index)] = node
+ # Build edge triples and other attribute triples.
+ for i, head_index in enumerate(heads):
+ if head_index == 0:
+ top_variable = 'vv{}'.format(corefs[i])
+ if top_variable not in variable_map:
+ variable_map[top_variable] = nodes[i]
+ top = top_variable
+ continue
+ head_variable = 'vv{}'.format(corefs[head_index - 1])
+ modifier = nodes[i]
+ modifier_variable = 'vv{}'.format(corefs[i])
+ label = head_labels[i]
+ assert head_variable in variable_map
+ if modifier_variable in variable_map:
+ triples.append((head_variable, label, modifier_variable))
+ else:
+ # Add quotes if there's a backslash.
+ if re.search(r'[/:\\()]', modifier) and not re.search(r'^".*"$', modifier):
+ modifier = '"{}"'.format(modifier)
+ triples.append((head_variable, label, modifier))
+
+ for var, node in variable_map.items():
+ if re.search(r'^".*"$', node):
+ node = node[1:-1]
+ if re.search(r'[/:\\()]', node):
+ parts = re.split(r'[/:\\()]', node)
+ for part in parts[::-1]:
+ if len(part):
+ node = part
+ break
+ else:
+ node = re.sub(r'[/:\\()]', '_', node)
+ triples.append((var, 'instance', node))
+
+ if len(triples) == 0:
+ triples.append(('vv1', 'instance', 'string-entity'))
+ top = 'vv1'
+ triples.sort(key=lambda x: int(x[0].replace('vv', '')))
+ graph = penman.Graph()
+ graph._top = top
+ graph._triples = [penman.Triple(*t) for t in triples]
+ graph = cls(graph)
+ try:
+ GraphRepair.do(graph, nodes)
+ amr_codec.encode(graph)
+ except Exception as e:
+ graph._top = top
+ graph._triples = [penman.Triple(*t) for t in triples]
+ graph = cls(graph)
+ return graph
+
+
+class SourceCopyVocabulary:
+ def __init__(self, sentence, pad_token=DEFAULT_PADDING_TOKEN, unk_token=DEFAULT_OOV_TOKEN):
+ if type(sentence) is not list:
+ sentence = sentence.split(" ")
+
+ self.src_tokens = sentence
+ self.pad_token = pad_token
+ self.unk_token = unk_token
+
+ self.token_to_idx = {self.pad_token: 0, self.unk_token: 1}
+ self.idx_to_token = {0: self.pad_token, 1: self.unk_token}
+
+ self.vocab_size = 2
+
+ for token in sentence:
+ if token not in self.token_to_idx:
+ self.token_to_idx[token] = self.vocab_size
+ self.idx_to_token[self.vocab_size] = token
+ self.vocab_size += 1
+
+ def get_token_from_idx(self, idx):
+ return self.idx_to_token[idx]
+
+ def get_token_idx(self, token):
+ return self.token_to_idx.get(token, self.token_to_idx[self.unk_token])
+
+ def index_sequence(self, list_tokens):
+ return [self.get_token_idx(token) for token in list_tokens]
+
+ def get_copy_map(self, list_tokens):
+ src_indices = [self.get_token_idx(self.unk_token)] + self.index_sequence(list_tokens)
+ return [
+ (src_idx, src_token_idx) for src_idx, src_token_idx in enumerate(src_indices)
+ ]
+
+ def get_special_tok_list(self):
+ return [self.pad_token, self.unk_token]
+
+ def __repr__(self):
+ return json.dumps(self.idx_to_token)
+
+
+def is_similar(instances1, instances2):
+ if len(instances1) < len(instances2):
+ small = instances1
+ large = instances2
+ else:
+ small = instances2
+ large = instances1
+ coverage1 = sum(1 for x in small if x in large) / len(small)
+ coverage2 = sum(1 for x in large if x in small) / len(large)
+ return coverage1 > .8 and coverage2 > .8
+
+
+class GraphRepair:
+
+ def __init__(self, graph, nodes):
+ self.graph = graph
+ self.nodes = nodes
+ self.repaired_items = set()
+
+ @staticmethod
+ def do(graph, nodes):
+ gr = GraphRepair(graph, nodes)
+ gr.remove_redundant_edges()
+ gr.remove_unknown_nodes()
+
+ def remove_unknown_nodes(self):
+ graph = self.graph
+ nodes = [node for node in graph.get_nodes()]
+ for node in nodes:
+ for attr, value in node.attributes:
+ if value == '@@UNKNOWN@@' and attr != 'instance':
+ graph.remove_node_attribute(node, attr, value)
+ if node.instance == '@@UNKNOWN@@':
+ if len(list(graph._G.edges(node))) == 0:
+ for source, target in list(graph._G.in_edges(node)):
+ graph.remove_edge(source, target)
+ graph.remove_node(node)
+ self.repaired_items.add('remove-unknown-node')
+
+ def remove_redundant_edges(self):
+ """
+ Edge labels such as ARGx, ARGx-of, and 'opx' should only appear at most once
+ in each node's outgoing edges.
+ """
+ graph = self.graph
+ nodes = [node for node in graph.get_nodes()]
+ removed_nodes = set()
+ for node in nodes:
+ if node in removed_nodes:
+ continue
+ edges = list(graph._G.edges(node))
+ edge_counter = defaultdict(list)
+ for source, target in edges:
+ label = graph._G[source][target]['label']
+ # `name`, `ARGx`, and `ARGx-of` should only appear once.
+ if label == 'name': # or label.startswith('ARG'):
+ edge_counter[label].append(target)
+ # the target of `opx' should only appear once.
+ elif label.startswith('op') or label.startswith('snt'):
+ edge_counter[str(target.instance)].append(target)
+ else:
+ edge_counter[label + str(target.instance)].append(target)
+ for label, children in edge_counter.items():
+ if len(children) == 1:
+ continue
+ if label == 'name':
+ # remove redundant edges.
+ for target in children[1:]:
+ if len(list(graph._G.in_edges(target))) == 1 and len(list(graph._G.edges(target))) == 0:
+ graph.remove_edge(node, target)
+ graph.remove_node(target)
+ removed_nodes.add(target)
+ self.repaired_items.add('remove-redundant-edge')
+ continue
+ visited_children = set()
+ groups = []
+ for i, target in enumerate(children):
+ if target in visited_children:
+ continue
+ subtree_instances1 = [n.instance for n in graph.get_subtree(target, 5)]
+ group = [(target, subtree_instances1)]
+ visited_children.add(target)
+ for _t in children[i + 1:]:
+ if _t in visited_children or target.instance != _t.instance:
+ continue
+ subtree_instances2 = [n.instance for n in graph.get_subtree(_t, 5)]
+ if is_similar(subtree_instances1, subtree_instances2):
+ group.append((_t, subtree_instances2))
+ visited_children.add(_t)
+ groups.append(group)
+ for group in groups:
+ if len(group) == 1:
+ continue
+ kept_target, _ = max(group, key=lambda x: len(x[1]))
+ for target, _ in group:
+ if target == kept_target:
+ continue
+ graph.remove_edge(node, target)
+ removed_nodes.update(graph.remove_subtree(target))
diff --git a/plugins/hanlp_common/hanlp_common/configurable.py b/plugins/hanlp_common/hanlp_common/configurable.py
new file mode 100644
index 000000000..a5a103ec9
--- /dev/null
+++ b/plugins/hanlp_common/hanlp_common/configurable.py
@@ -0,0 +1,48 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-16 22:24
+from hanlp_common.reflection import str_to_type, classpath_of
+
+
+class Configurable(object):
+ @staticmethod
+ def from_config(config: dict, **kwargs):
+ """Build an object from config.
+
+ Args:
+ config: A ``dict`` holding parameters for its constructor. It has to contain a `classpath` key,
+ which has a classpath str as its value. ``classpath`` will determine the type of object
+ being deserialized.
+ kwargs: Arguments not used.
+
+ Returns: A deserialized object.
+
+ """
+ cls = config.get('classpath', None)
+ assert cls, f'{config} doesn\'t contain classpath field'
+ cls = str_to_type(cls)
+ deserialized_config = dict(config)
+ for k, v in config.items():
+ if isinstance(v, dict) and 'classpath' in v:
+ deserialized_config[k] = Configurable.from_config(v)
+ if cls.from_config == Configurable.from_config:
+ deserialized_config.pop('classpath')
+ return cls(**deserialized_config)
+ else:
+ return cls.from_config(deserialized_config)
+
+
+class AutoConfigurable(Configurable):
+ @property
+ def config(self) -> dict:
+ """
+ The config of this object, which are public properties. If any properties needs to be excluded from this config,
+ simply declare it with prefix ``_``.
+ """
+ return dict([('classpath', classpath_of(self))] +
+ [(k, v.config if hasattr(v, 'config') else v)
+ for k, v in self.__dict__.items() if
+ not k.startswith('_')])
+
+ def __repr__(self) -> str:
+ return repr(self.config)
diff --git a/plugins/hanlp_common/hanlp_common/conll.py b/plugins/hanlp_common/hanlp_common/conll.py
new file mode 100644
index 000000000..1e785f2c5
--- /dev/null
+++ b/plugins/hanlp_common/hanlp_common/conll.py
@@ -0,0 +1,369 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-19 20:50
+from typing import Union, List
+
+from hanlp_common.structure import SerializableDict
+from hanlp_common.visualization import pretty_tree_horizontal, make_table, markdown_table
+
+
+class CoNLLWord(SerializableDict):
+ def __init__(self, id, form, lemma=None, cpos=None, pos=None, feats=None, head=None, deprel=None, phead=None,
+ pdeprel=None):
+ """CoNLL (:cite:`buchholz-marsi-2006-conll`) format template, see http://anthology.aclweb.org/W/W06/W06-2920.pdf
+
+ Args:
+ id (int):
+ Token counter, starting at 1 for each new sentence.
+ form (str):
+ Word form or punctuation symbol.
+ lemma (str):
+ Lemma or stem (depending on the particular treebank) of word form, or an underscore if not available.
+ cpos (str):
+ Coarse-grained part-of-speech tag, where the tagset depends on the treebank.
+ pos (str):
+ Fine-grained part-of-speech tag, where the tagset depends on the treebank.
+ feats (str):
+ Unordered set of syntactic and/or morphological features (depending on the particular treebank),
+ or an underscore if not available.
+ head (Union[int, List[int]]):
+ Head of the current token, which is either a value of ID,
+ or zero (’0’) if the token links to the virtual root node of the sentence.
+ deprel (Union[str, List[str]]):
+ Dependency relation to the HEAD.
+ phead (int):
+ Projective head of current token, which is either a value of ID or zero (’0’),
+ or an underscore if not available.
+ pdeprel (str):
+ Dependency relation to the PHEAD, or an underscore if not available.
+ """
+ self.id = sanitize_conll_int_value(id)
+ self.form = form
+ self.cpos = cpos
+ self.pos = pos
+ self.head = sanitize_conll_int_value(head)
+ self.deprel = deprel
+ self.lemma = lemma
+ self.feats = feats
+ self.phead = phead
+ self.pdeprel = pdeprel
+
+ def __str__(self):
+ if isinstance(self.head, list):
+ return '\n'.join('\t'.join(['_' if v is None else v for v in values]) for values in [
+ [str(self.id), self.form, self.lemma, self.cpos, self.pos, self.feats,
+ None if head is None else str(head), deprel, self.phead, self.pdeprel] for head, deprel in
+ zip(self.head, self.deprel)
+ ])
+ values = [str(self.id), self.form, self.lemma, self.cpos, self.pos, self.feats,
+ None if self.head is None else str(self.head), self.deprel, self.phead, self.pdeprel]
+ return '\t'.join(['_' if v is None else v for v in values])
+
+ @property
+ def nonempty_fields(self):
+ """
+ Get the values of nonempty fields as a list.
+ """
+ return list(f for f in
+ [self.form, self.lemma, self.cpos, self.pos, self.feats, self.head, self.deprel, self.phead,
+ self.pdeprel] if f)
+
+ def get_pos(self):
+ """
+ Get the precisest pos for this word.
+
+ Returns: ``self.pos`` or ``self.cpos``.
+
+ """
+ return self.pos or self.cpos
+
+
+class CoNLLUWord(SerializableDict):
+ def __init__(self, id: Union[int, str], form, lemma=None, upos=None, xpos=None, feats=None, head=None, deprel=None,
+ deps=None,
+ misc=None):
+ """CoNLL-U format template, see https://universaldependencies.org/format.html
+
+ Args:
+
+ id (Union[int, str]):
+ Token counter, starting at 1 for each new sentence.
+ form (Union[str, None]):
+ Word form or punctuation symbol.
+ lemma (str):
+ Lemma or stem (depending on the particular treebank) of word form, or an underscore if not available.
+ upos (str):
+ Universal part-of-speech tag.
+ xpos (str):
+ Language-specific part-of-speech tag; underscore if not available.
+ feats (str):
+ List of morphological features from the universal feature inventory or from a defined language-specific extension; underscore if not available.
+ head (int):
+ Head of the current token, which is either a value of ID,
+ or zero (’0’) if the token links to the virtual root node of the sentence.
+ deprel (str):
+ Dependency relation to the HEAD.
+ deps (Union[List[Tuple[int, str], str]):
+ Projective head of current token, which is either a value of ID or zero (’0’),
+ or an underscore if not available.
+ misc (str):
+ Dependency relation to the PHEAD, or an underscore if not available.
+ """
+ self.id = sanitize_conll_int_value(id)
+ self.form = form
+ self.upos = upos
+ self.xpos = xpos
+ if isinstance(head, list):
+ assert deps is None, 'When head is a list, deps has to be None'
+ assert isinstance(deprel, list), 'When head is a list, deprel has to be a list'
+ assert len(deprel) == len(head), 'When head is a list, deprel has to match its length'
+ deps = list(zip(head, deprel))
+ head = None
+ deprel = None
+ self.head = sanitize_conll_int_value(head)
+ self.deprel = deprel
+ self.lemma = lemma
+ self.feats = feats
+ if deps == '_':
+ deps = None
+ if isinstance(deps, str):
+ self.deps = []
+ for pair in deps.split('|'):
+ h, r = pair.split(':')
+ h = int(h)
+ self.deps.append((h, r))
+ else:
+ self.deps = deps
+ self.misc = misc
+
+ def __str__(self):
+ deps = self.deps
+ if not deps:
+ deps = None
+ else:
+ deps = '|'.join(f'{h}:{r}' for h, r in deps)
+ values = [str(self.id), self.form, self.lemma, self.upos, self.xpos, self.feats,
+ str(self.head) if self.head is not None else None, self.deprel, deps, self.misc]
+ return '\t'.join(['_' if v is None else v for v in values])
+
+ @property
+ def nonempty_fields(self):
+ """
+ Get the values of nonempty fields as a list.
+ """
+ return list(f for f in
+ [self.form, self.lemma, self.upos, self.xpos, self.feats, self.head, self.deprel, self.deps,
+ self.misc] if f)
+
+ def get_pos(self):
+ """
+ Get the precisest pos for this word.
+
+ Returns: ``self.xpos`` or ``self.upos``
+
+ """
+ return self.xpos or self.upos
+
+
+class CoNLLSentence(list):
+ def __init__(self, words=None):
+ """
+ Create from a list of :class:`~hanlp_common.conll.CoNLLWord` or :class:`~hanlp_common.conll.CoNLLUWord`
+
+ Args:
+ words (list[Union[CoNLLWord, CoNLLUWord]]): A list of words.
+ """
+ super().__init__()
+ if words:
+ self.extend(words)
+
+ def __str__(self):
+ return '\n'.join([word.__str__() for word in self])
+
+ @staticmethod
+ def from_str(conll: str, conllu=False):
+ """Build a CoNLLSentence from CoNLL-X format str
+
+ Args:
+ conll (str): CoNLL-X or CoNLL-U format string
+ conllu: ``True`` to build :class:`~hanlp_common.conll.CoNLLUWord` for each token.
+
+ Returns:
+ A :class:`~hanlp_common.conll.CoNLLSentence`.
+ """
+ words: List[CoNLLWord] = []
+ prev_id = None
+ for line in conll.strip().split('\n'):
+ if line.startswith('#'):
+ continue
+ cells = line.split('\t')
+ cells = [None if c == '_' else c for c in cells]
+ if '-' in cells[0]:
+ continue
+ cells[0] = int(cells[0])
+ cells[6] = int(cells[6])
+ if cells[0] != prev_id:
+ words.append(CoNLLUWord(*cells) if conllu else CoNLLWord(*cells))
+ else:
+ if isinstance(words[-1].head, list):
+ words[-1].head.append(cells[6])
+ words[-1].deprel.append(cells[7])
+ else:
+ words[-1].head = [words[-1].head] + [cells[6]]
+ words[-1].deprel = [words[-1].deprel] + [cells[7]]
+ prev_id = cells[0]
+ if conllu:
+ for word in words: # type: CoNLLUWord
+ if isinstance(word.head, list):
+ assert not word.deps
+ word.deps = list(zip(word.head, word.deprel))
+ word.head = None
+ word.deprel = None
+ return CoNLLSentence(words)
+
+ @staticmethod
+ def from_file(path: str, conllu=False):
+ """Build a CoNLLSentence from ``.conllx`` or ``.conllu`` file
+
+ Args:
+ path: Path to the file.
+ conllu: ``True`` to build :class:`~hanlp_common.conll.CoNLLUWord` for each token.
+
+ Returns:
+ A :class:`~hanlp_common.conll.CoNLLSentence`.
+ """
+ with open(path) as src:
+ return [CoNLLSentence.from_str(x, conllu) for x in src.read().split('\n\n') if x.strip()]
+
+ @staticmethod
+ def from_dict(d: dict, conllu=False):
+ """Build a CoNLLSentence from a dict.
+
+ Args:
+ d: A dict storing a list for each field, where each index corresponds to a token.
+ conllu: ``True`` to build :class:`~hanlp_common.conll.CoNLLUWord` for each token.
+
+ Returns:
+ A :class:`~hanlp_common.conll.CoNLLSentence`.
+ """
+ if conllu:
+ headings = ['ID', 'FORM', 'LEMMA', 'UPOS', 'XPOS', 'FEATS', 'HEAD', 'DEPREL', 'DEPS', 'MISC']
+ else:
+ headings = ['ID', 'FORM', 'LEMMA', 'CPOS', 'POS', 'FEATS', 'HEAD', 'DEPREL', 'PHEAD', 'PDEPREL']
+ words: List[Union[CoNLLWord, CoNLLUWord]] = []
+ for cells in zip(*list(d[f] for f in headings)):
+ words.append(CoNLLUWord(*cells) if conllu else CoNLLWord(*cells))
+ return CoNLLSentence(words)
+
+ def to_markdown(self, headings: Union[str, List[str]] = 'auto') -> str:
+ r"""Convert into markdown string.
+
+ Args:
+ headings: ``auto`` to automatically detect the word type. When passed a list of string, they are treated as
+ headings for each field.
+
+ Returns:
+ A markdown representation of this sentence.
+ """
+ cells = [str(word).split('\t') for word in self]
+ if headings == 'auto':
+ if isinstance(self[0], CoNLLWord):
+ headings = ['ID', 'FORM', 'LEMMA', 'CPOS', 'POS', 'FEATS', 'HEAD', 'DEPREL', 'PHEAD', 'PDEPREL']
+ else: # conllu
+ headings = ['ID', 'FORM', 'LEMMA', 'UPOS', 'XPOS', 'FEATS', 'HEAD', 'DEPREL', 'DEPS', 'MISC']
+ for each in cells:
+ # if '|' in each[8]:
+ # each[8] = f'`{each[8]}`'
+ each[8] = each[8].replace('|', '⎮')
+ alignment = [('^', '>'), ('^', '<'), ('^', '<'), ('^', '<'), ('^', '<'), ('^', '<'), ('^', '>'), ('^', '<'),
+ ('^', '<'), ('^', '<')]
+ text = markdown_table(headings, cells, alignment=alignment)
+ return text
+
+ def to_tree(self, extras: List[str] = None) -> str:
+ """Convert into a pretty tree string which can be printed to show the tree structure.
+
+ Args:
+ extras: Extra table to be aligned to this tree.
+
+ Returns:
+ A pretty tree string along with extra table if passed any.
+ """
+ arrows = []
+ for word in self: # type: Union[CoNLLWord, CoNLLUWord]
+ if word.head:
+ arrows.append({'from': word.head - 1, 'to': word.id - 1})
+ tree = pretty_tree_horizontal(arrows)
+ rows = [['Dep Tree', 'Token', 'Relation']]
+ has_lem = all(x.lemma for x in self)
+ has_pos = all(x.get_pos() for x in self)
+ if has_lem:
+ rows[0].append('Lemma')
+ if has_pos:
+ rows[0].append('PoS')
+ if extras:
+ rows[0].extend(extras[0])
+ for i, (word, arc) in enumerate(zip(self, tree)):
+ cell_per_word = [arc]
+ cell_per_word.append(word.form)
+ cell_per_word.append(word.deprel)
+ if has_lem:
+ cell_per_word.append(word.lemma)
+ if has_pos:
+ cell_per_word.append(word.get_pos())
+ if extras:
+ cell_per_word.extend(extras[i + 1])
+ rows.append(cell_per_word)
+ return make_table(rows, insert_header=True)
+
+ @property
+ def projective(self):
+ """
+ ``True`` if this tree is projective.
+ """
+ return isprojective([x.head for x in self])
+
+
+def sanitize_conll_int_value(value: Union[str, int]):
+ if value is None or isinstance(value, int):
+ return value
+ if value == '_':
+ return None
+ if isinstance(value, str):
+ return int(value)
+ return value
+
+
+def isprojective(sequence):
+ r"""
+ Checks if a dependency tree is projective.
+ This also works for partial annotation.
+
+ Besides the obvious crossing arcs, the examples below illustrate two non-projective cases
+ which are hard to detect in the scenario of partial annotation.
+
+ Args:
+ sequence (list[int]):
+ A list of head indices.
+
+ Returns:
+ ``True`` if the tree is projective, ``False`` otherwise.
+
+ Examples:
+ >>> isprojective([2, -1, 1]) # -1 denotes un-annotated cases
+ False
+ >>> isprojective([3, -1, 2])
+ False
+ """
+
+ pairs = [(h, d) for d, h in enumerate(sequence, 1) if h >= 0]
+ for i, (hi, di) in enumerate(pairs):
+ for hj, dj in pairs[i + 1:]:
+ (li, ri), (lj, rj) = sorted([hi, di]), sorted([hj, dj])
+ if li <= hj <= ri and hi == dj:
+ return False
+ if lj <= hi <= rj and hj == di:
+ return False
+ if (li < lj < ri or li < rj < ri) and (li - lj) * (ri - rj) > 0:
+ return False
+ return True
diff --git a/plugins/hanlp_common/hanlp_common/constant.py b/plugins/hanlp_common/hanlp_common/constant.py
new file mode 100644
index 000000000..aff99d368
--- /dev/null
+++ b/plugins/hanlp_common/hanlp_common/constant.py
@@ -0,0 +1,21 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-06-13 22:41
+import os
+
+PAD = ''
+'''Padding token.'''
+UNK = ''
+'''Unknown token.'''
+CLS = '[CLS]'
+BOS = ''
+EOS = ''
+ROOT = BOS
+IDX = '_idx_'
+'''Key for index.'''
+HANLP_URL = os.getenv('HANLP_URL', 'https://file.hankcs.com/hanlp/')
+'''Resource URL.'''
+HANLP_VERBOSE = os.environ.get('HANLP_VERBOSE', '1').lower() in ('1', 'true', 'yes')
+'''Enable verbose or not.'''
+NULL = ''
+PRED = 'PRED'
diff --git a/plugins/hanlp_common/hanlp_common/document.py b/plugins/hanlp_common/hanlp_common/document.py
new file mode 100644
index 000000000..81732004a
--- /dev/null
+++ b/plugins/hanlp_common/hanlp_common/document.py
@@ -0,0 +1,400 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-12-31 04:16
+import json
+import re
+import warnings
+from typing import List, Union
+
+from phrasetree.tree import Tree
+
+from hanlp_common.conll import CoNLLUWord, CoNLLSentence
+from hanlp_common.constant import PRED
+from hanlp_common.util import collapse_json, prefix_match
+from hanlp_common.visualization import tree_to_list, list_to_tree, render_labeled_span, make_table
+
+
+class Document(dict):
+ def __init__(self, *args, **kwargs) -> None:
+ """
+ A dict structure holding parsed annotations.
+
+ Args:
+ *args: An iterator of key-value pairs.
+ **kwargs: Arguments from ``**`` operator.
+ """
+ super().__init__(*args, **kwargs)
+ for k, v in list(self.items()):
+ if not v:
+ continue
+ if k == 'con':
+ if isinstance(v, Tree) or isinstance(v[0], Tree):
+ continue
+ flat = isinstance(v[0], str)
+ if flat:
+ v = [v]
+ ls = []
+ for each in v:
+ if not isinstance(each, Tree):
+ ls.append(list_to_tree(each))
+ if flat:
+ ls = ls[0]
+ self[k] = ls
+ elif k == 'amr':
+ from hanlp_common.amr import AMRGraph
+ import penman
+ if isinstance(v, AMRGraph) or isinstance(v[0], AMRGraph):
+ continue
+ flat = isinstance(v[0][0], str)
+ if flat:
+ v = [v]
+ graphs = [AMRGraph(penman.Graph(triples)) for triples in v]
+ if flat:
+ graphs = graphs[0]
+ self[k] = graphs
+
+ def to_json(self, ensure_ascii=False, indent=2) -> str:
+ """Convert to json string.
+
+ Args:
+ ensure_ascii: ``False`` to allow for non-ascii text.
+ indent: Indent per nested structure.
+
+ Returns:
+ A text representation in ``str``.
+
+ """
+ d = self.to_dict()
+ text = json.dumps(d, ensure_ascii=ensure_ascii, indent=indent, default=lambda o: repr(o))
+ text = collapse_json(text, 4)
+ return text
+
+ def to_dict(self):
+ """Convert to a json compatible dict.
+
+ Returns:
+ A dict representation.
+ """
+ d = dict(self)
+ for k, v in self.items():
+ if not v:
+ continue
+ if k == 'con':
+ if not isinstance(v, Tree) and not isinstance(v[0], Tree):
+ continue
+ flat = isinstance(v, Tree)
+ if flat:
+ v = [v]
+ ls = []
+ for each in v:
+ if isinstance(each, Tree):
+ ls.append(tree_to_list(each))
+ if flat:
+ ls = ls[0]
+ d[k] = ls
+ return d
+
+ def __str__(self) -> str:
+ return self.to_json()
+
+ def to_conll(self, tok='tok', lem='lem', pos='pos', dep='dep', sdp='sdp') -> Union[
+ CoNLLSentence, List[CoNLLSentence]]:
+ """
+ Convert to :class:`~hanlp_common.conll.CoNLLSentence`.
+
+ Args:
+ tok (str): Field name for tok.
+ lem (str): Field name for lem.
+ pos (str): Filed name for upos.
+ dep (str): Field name for dependency parsing.
+ sdp (str): Field name for semantic dependency parsing.
+
+ Returns:
+ A :class:`~hanlp_common.conll.CoNLLSentence` representation.
+
+ """
+ results = []
+ if not self[tok]:
+ return results
+ flat = isinstance(self[tok][0], str)
+ if flat:
+ d = Document((k, [v]) for k, v in self.items())
+ else:
+ d = self
+ for sample in [dict(zip(d, t)) for t in zip(*d.values())]:
+ def get(_k, _i):
+ _v = sample.get(_k, None)
+ if not _v:
+ return None
+ return _v[_i]
+
+ sent = CoNLLSentence()
+
+ for i, _tok in enumerate(sample[tok]):
+ _dep = get(dep, i)
+ if not _dep:
+ _dep = (None, None)
+ sent.append(
+ CoNLLUWord(i + 1, form=_tok, lemma=get(lem, i), upos=get(pos, i), head=_dep[0], deprel=_dep[1],
+ deps=None if not get(sdp, i) else '|'.join(f'{x[0]}:{x[1]}' for x in get(sdp, i))))
+ results.append(sent)
+ if flat:
+ return results[0]
+ return results
+
+ def to_pretty(self, tok='tok', lem='lem', pos='pos', dep='dep', sdp='sdp', ner='ner', srl='srl', con='con',
+ show_header=True) -> str:
+ """
+ Convert to a pretty text representation which can be printed to visualize linguistics structures.
+
+ Args:
+ tok: Token key.
+ lem: Lemma key.
+ pos: Part-of-speech key.
+ dep: Dependency parse tree key.
+ sdp: Semantic dependency tree/graph key. SDP visualization has not been implemented yet.
+ ner: Named entity key.
+ srl: Semantic role labeling key.
+ con: Constituency parsing key.
+ show_header: ``True`` to print a header which indicates each field with its name.
+
+ Returns:
+ A pretty string.
+
+ """
+ results = []
+ tok = prefix_match(tok, self)
+ pos = prefix_match(pos, self)
+ ner = prefix_match(ner, self)
+ conlls = self.to_conll(tok, lem, pos, dep, sdp)
+ flat = isinstance(conlls, CoNLLSentence)
+ if flat:
+ conlls: List[CoNLLSentence] = [conlls]
+
+ def condense(block_, extras_=None):
+ text_ = make_table(block_, insert_header=False)
+ text_ = [x.split('\t', 1) for x in text_.split('\n')]
+ text_ = [[x[0], x[1].replace('\t', '')] for x in text_]
+ if extras_:
+ for r, s in zip(extras_, text_):
+ r.extend(s)
+ return text_
+
+ for i, conll in enumerate(conlls):
+ conll: CoNLLSentence = conll
+ tokens = [x.form for x in conll]
+ length = len(conll)
+ extras = [[] for j in range(length + 1)]
+ if ner in self:
+ ner_samples = self[ner]
+ if flat:
+ ner_samples = [ner_samples]
+ ner_per_sample = ner_samples[i]
+ header = ['Tok', 'NER', 'Type']
+ block = [[] for _ in range(length + 1)]
+ _ner = []
+ _type = []
+ offset = 0
+ for ent, label, b, e in ner_per_sample:
+ render_labeled_span(b, e, _ner, _type, label, offset)
+ offset = e
+ if offset != length:
+ _ner.extend([''] * (length - offset))
+ _type.extend([''] * (length - offset))
+ if any(_type):
+ block[0].extend(header)
+ for j, (_s, _t) in enumerate(zip(_ner, _type)):
+ block[j + 1].extend((tokens[j], _s, _t))
+ text = condense(block, extras)
+
+ if srl in self:
+ srl_samples = self[srl]
+ if flat:
+ srl_samples = [srl_samples]
+ srl_per_sample = srl_samples[i]
+ for k, pas in enumerate(srl_per_sample):
+ if not pas:
+ continue
+ block = [[] for _ in range(length + 1)]
+ header = ['Tok', 'SRL', f'PA{k + 1}']
+ _srl = []
+ _type = []
+ offset = 0
+ p_index = None
+ for _, label, b, e in pas:
+ render_labeled_span(b, e, _srl, _type, label, offset)
+ offset = e
+ if label == PRED:
+ p_index = b
+ if len(_srl) != length:
+ _srl.extend([''] * (length - offset))
+ _type.extend([''] * (length - offset))
+ if p_index is not None:
+ _srl[p_index] = '╟──►'
+ # _type[j] = 'V'
+ if len(block) != len(_srl) + 1:
+ warnings.warn(f'Unable to visualize overlapped spans: {pas}')
+ continue
+ block[0].extend(header)
+ for j, (_s, _t) in enumerate(zip(_srl, _type)):
+ block[j + 1].extend((tokens[j], _s, _t))
+ text = condense(block, extras)
+ if con in self:
+ con_samples: Tree = self[con]
+ if flat:
+ con_samples: List[Tree] = [con_samples]
+ tree = con_samples[i]
+ block = [[] for _ in range(length + 1)]
+ block[0].extend(('Tok', 'PoS'))
+ for j, t in enumerate(tree.pos()):
+ block[j + 1].extend(t)
+
+ for height in range(2, tree.height()):
+ offset = 0
+ spans = []
+ labels = []
+ for k, subtree in enumerate(tree.subtrees(lambda x: x.height() == height)):
+ subtree: Tree = subtree
+ b, e = offset, offset + len(subtree.leaves())
+ if height >= 3:
+ b, e = subtree[0].center, subtree[-1].center + 1
+ subtree.center = b + (e - b) // 2
+ render_labeled_span(b, e, spans, labels, subtree.label(), offset, unidirectional=True)
+ offset = e
+ if len(spans) != length:
+ spans.extend([''] * (length - len(spans)))
+ if len(labels) != length:
+ labels.extend([''] * (length - len(labels)))
+ if height < 3:
+ continue
+ block[0].extend(['', f'{height}'])
+ for j, (_s, _t) in enumerate(zip(spans, labels)):
+ block[j + 1].extend((_s, _t))
+ # check short arrows and increase their length
+ for j, arrow in enumerate(spans):
+ if not arrow:
+ # -1 current tag ; -2 arrow to current tag ; -3 = prev tag ; -4 = arrow to prev tag
+ if block[j + 1][-3] or block[j + 1][-4] == '───►':
+ if height > 3:
+ if block[j + 1][-3]:
+ block[j + 1][-1] = block[j + 1][-3]
+ block[j + 1][-2] = '───►'
+ else:
+ block[j + 1][-1] = '────'
+ block[j + 1][-2] = '────'
+ block[j + 1][-3] = '────'
+ if block[j + 1][-4] == '───►':
+ block[j + 1][-4] = '────'
+ else:
+ block[j + 1][-1] = '────'
+ if block[j + 1][-1] == '────':
+ block[j + 1][-2] = '────'
+ if not block[j + 1][-4]:
+ block[j + 1][-4] = '────'
+
+ text = condense(block)
+ # Cosmetic issues
+ for row in text:
+ while ' ─' in row[1]:
+ row[1] = row[1].replace(' ─', ' ──')
+ row[1] = row[1].replace('─ │', '───┤')
+ row[1] = row[1].replace('─ ├', '───┼')
+ row[1] = re.sub(r'►(\w+)(\s+)([│├])', lambda
+ m: f'►{m.group(1)}{"─" * len(m.group(2))}{"┤" if m.group(3) == "│" else "┼"}', row[1])
+ row[1] = re.sub(r'►(─+)►', r'─\1►', row[1])
+ for r, s in zip(extras, text):
+ r.extend(s)
+ # warnings.warn('Unable to visualize non-projective trees.')
+ if dep in self and conll.projective:
+ text = conll.to_tree(extras)
+ if not show_header:
+ text = text.split('\n')
+ text = '\n'.join(text[2:])
+ results.append(text)
+ elif any(extras):
+ results.append(make_table(extras, insert_header=True))
+ else:
+ results.append(' '.join(['/'.join(str(f) for f in x.nonempty_fields) for x in conll]))
+ if flat:
+ return results[0]
+ return results
+
+ def pretty_print(self, tok='tok', lem='lem', pos='pos', dep='dep', sdp='sdp', ner='ner', srl='srl', con='con',
+ show_header=True):
+ """
+ Print a pretty text representation which visualizes linguistics structures.
+
+ Args:
+ tok: Token key.
+ lem: Lemma key.
+ pos: Part-of-speech key.
+ dep: Dependency parse tree key.
+ sdp: Semantic dependency tree/graph key. SDP visualization has not been implemented yet.
+ ner: Named entity key.
+ srl: Semantic role labeling key.
+ con: Constituency parsing key.
+ show_header: ``True`` to print a header which indicates each field with its name.
+
+ """
+ results = self.to_pretty(tok, lem, pos, dep, sdp, ner, srl, con, show_header)
+ if isinstance(results, str):
+ results = [results]
+ sent_new_line = '\n\n' if any('\n' in x for x in results) else '\n'
+ print(sent_new_line.join(results))
+
+ def translate(self, lang, tok='tok', pos='pos', dep='dep', sdp='sdp', ner='ner', srl='srl'):
+ """
+ Translate tags for each annotation. This is an inplace operation.
+
+ .. Attention:: Note that the translated document might not print well in terminal due to non-ASCII characters.
+
+ Args:
+ lang: Target language to be translated to.
+ tok: Token key.
+ pos: Part-of-speech key.
+ dep: Dependency parse tree key.
+ sdp: Semantic dependency tree/graph key. SDP visualization has not been implemented yet.
+ ner: Named entity key.
+ srl: Semantic role labeling key.
+
+ Returns:
+ The translated document.
+
+ """
+ if lang == 'zh':
+ from hanlp.utils.lang.zh import localization
+ else:
+ raise NotImplementedError(f'No translation for {lang}. '
+ f'Please contribute to our translation at https://github.com/hankcs/HanLP')
+ flat = isinstance(self[tok][0], str)
+ for task, name in zip(['pos', 'ner', 'dep', 'sdp', 'srl'], [pos, ner, dep, sdp, srl]):
+ annotations = self.get(name, None)
+ if not annotations:
+ continue
+ if flat:
+ annotations = [annotations]
+ translate: dict = getattr(localization, name, None)
+ if not translate:
+ continue
+ for anno_per_sent in annotations:
+ for i, v in enumerate(anno_per_sent):
+ if task == 'ner' or task == 'dep':
+ v[1] = translate.get(v[1], v[1])
+ else:
+ anno_per_sent[i] = translate.get(v, v)
+ return self
+
+ def squeeze(self):
+ r"""
+ Squeeze the dimension of each field into one. It's intended to convert a nested document like ``[[sent1]]``
+ to ``[sent1]``. When there are multiple sentences, only the first one will be returned. Note this is not a
+ inplace operation.
+
+ Returns:
+ A squeezed document with only one sentence.
+
+ """
+ sq = Document()
+ for k, v in self.items():
+ sq[k] = v[0] if isinstance(v, list) else v
+ return sq
diff --git a/plugins/hanlp_common/hanlp_common/io.py b/plugins/hanlp_common/hanlp_common/io.py
new file mode 100644
index 000000000..6d7862447
--- /dev/null
+++ b/plugins/hanlp_common/hanlp_common/io.py
@@ -0,0 +1,39 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-16 22:38
+import json
+import os
+import pickle
+import sys
+
+
+def save_pickle(item, path):
+ with open(path, 'wb') as f:
+ pickle.dump(item, f)
+
+
+def load_pickle(path):
+ with open(path, 'rb') as f:
+ return pickle.load(f)
+
+
+def save_json(item: dict, path: str, ensure_ascii=False, cls=None, default=lambda o: repr(o), indent=2):
+ dirname = os.path.dirname(path)
+ if dirname:
+ os.makedirs(dirname, exist_ok=True)
+ with open(path, 'w', encoding='utf-8') as out:
+ json.dump(item, out, ensure_ascii=ensure_ascii, indent=indent, cls=cls, default=default)
+
+
+def load_json(path):
+ with open(path, encoding='utf-8') as src:
+ return json.load(src)
+
+
+def filename_is_json(filename):
+ filename, file_extension = os.path.splitext(filename)
+ return file_extension in ['.json', '.jsonl']
+
+
+def eprint(*args, **kwargs):
+ print(*args, file=sys.stderr, **kwargs)
\ No newline at end of file
diff --git a/hanlp/utils/reflection.py b/plugins/hanlp_common/hanlp_common/reflection.py
similarity index 60%
rename from hanlp/utils/reflection.py
rename to plugins/hanlp_common/hanlp_common/reflection.py
index cf54bb086..20d15019f 100644
--- a/hanlp/utils/reflection.py
+++ b/plugins/hanlp_common/hanlp_common/reflection.py
@@ -5,11 +5,14 @@
import inspect
-def class_path_of(obj) -> str:
- """
- get the full class path of object
- :param obj:
- :return:
+def classpath_of(obj) -> str:
+ """get the full class path of object
+
+ Args:
+ obj: return:
+
+ Returns:
+
"""
if inspect.isfunction(obj):
return module_path_of(obj)
@@ -20,18 +23,22 @@ def module_path_of(func) -> str:
return inspect.getmodule(func).__name__ + '.' + func.__name__
-def object_from_class_path(class_path, **kwargs):
- class_path = str_to_type(class_path)
- if inspect.isfunction(class_path):
- return class_path
- return class_path(**kwargs)
+def object_from_classpath(classpath, **kwargs):
+ classpath = str_to_type(classpath)
+ if inspect.isfunction(classpath):
+ return classpath
+ return classpath(**kwargs)
def str_to_type(classpath):
- """
- convert class path in str format to a type
- :param classpath: class path
- :return: type
+ """convert class path in str format to a type
+
+ Args:
+ classpath: class path
+
+ Returns:
+ type
+
"""
module_name, class_name = classpath.rsplit(".", 1)
cls = getattr(importlib.import_module(module_name), class_name)
@@ -39,10 +46,14 @@ def str_to_type(classpath):
def type_to_str(type_object) -> str:
- """
- convert a type object to class path in str format
- :param type_object: type
- :return: class path
+ """convert a type object to class path in str format
+
+ Args:
+ type_object: type
+
+ Returns:
+ class path
+
"""
cls_name = str(type_object)
assert cls_name.startswith(" str:
+ d = self.to_dict()
+ if sort:
+ d = OrderedDict(sorted(d.items()))
+ return json.dumps(d, ensure_ascii=ensure_ascii, indent=indent, default=lambda o: repr(o))
+
+ def to_dict(self) -> dict:
+ return self.__dict__
+
+
+class SerializableDict(Serializable, dict):
+
+ def save_json(self, path):
+ save_json(self, path)
+
+ def copy_from(self, item):
+ if isinstance(item, dict):
+ self.clear()
+ self.update(item)
+
+ def __getattr__(self, key):
+ if key.startswith('__'):
+ return dict.__getattr__(key)
+ return self.__getitem__(key)
+
+ def __setattr__(self, key, value):
+ return self.__setitem__(key, value)
+
+ def to_dict(self) -> dict:
+ return self
\ No newline at end of file
diff --git a/hanlp/utils/util.py b/plugins/hanlp_common/hanlp_common/util.py
similarity index 58%
rename from hanlp/utils/util.py
rename to plugins/hanlp_common/hanlp_common/util.py
index 28dae2afe..aa67604ad 100644
--- a/hanlp/utils/util.py
+++ b/plugins/hanlp_common/hanlp_common/util.py
@@ -1,7 +1,33 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-10-27 19:09
-from typing import Union, Any, List, Optional, Tuple, Iterable
+import math
+from typing import Union, Any, List, Optional, Tuple, Iterable, Dict
+import inspect
+from itertools import chain, combinations
+
+
+def powerset(iterable, descending=False):
+ """
+ powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
+
+ Args:
+ iterable:
+
+ Returns:
+
+ """
+ s = list(iterable)
+ sizes = range(len(s), -1, -1) if descending else range(len(s) + 1)
+ return chain.from_iterable(combinations(s, r) for r in sizes)
+
+
+def isdebugging():
+ """See Also https://stackoverflow.com/questions/333995/how-to-detect-that-python-code-is-being-executed-through-the-debugger"""
+ for frame in inspect.stack():
+ if frame[1].endswith("pydevd.py"):
+ return True
+ return False
def list_is_list_of_lists(sent: Union[Any, List[Any]]) -> Optional[bool]:
@@ -25,16 +51,19 @@ def consume_keys_from_dict(keys: Iterable, d: dict) -> dict:
def merge_dict(d: dict, overwrite=False, inplace=False, **kwargs):
- """
- Merging the provided dict with other kvs
- Parameters
- ----------
- d
- kwargs
+ """Merging the provided dict with other kvs
+
+ Args:
+ d:
+ kwargs:
+ d: dict:
+ overwrite: (Default value = False)
+ inplace: (Default value = False)
+ **kwargs:
- Returns
- -------
+ Returns:
+
"""
nd = dict([(k, v) for k, v in d.items()] + [(k, v) for k, v in kwargs.items() if overwrite or k not in d])
if inplace:
@@ -43,9 +72,11 @@ def merge_dict(d: dict, overwrite=False, inplace=False, **kwargs):
return nd
-def merge_locals_kwargs(locals: dict, kwargs: dict):
+def merge_locals_kwargs(locals: dict, kwargs: dict = None, excludes=('self', 'kwargs', '__class__')):
+ if not kwargs:
+ kwargs = dict()
return merge_dict(dict((k, v) for k, v in list(locals.items())
- if k not in ('self', 'kwargs', '__class__')), **kwargs)
+ if k not in excludes), **kwargs)
def infer_space_after(sent: List[str]):
@@ -76,27 +107,19 @@ def infer_space_after(sent: List[str]):
return whitespace_after
-def ls_resource_in_module(root) -> dict:
- res = dict()
- for k, v in root.__dict__.items():
- if k.startswith('_') or v == root:
- continue
- if isinstance(v, str):
- if v.startswith('http') and not v.endswith('/') and not v.endswith('#'):
- res[k] = v
- elif type(v).__name__ == 'module':
- res.update(ls_resource_in_module(v))
- if 'ALL' in root.__dict__ and isinstance(root.__dict__['ALL'], dict):
- root.__dict__['ALL'].update(res)
- return res
-
-
def collapse_json(text, indent=12):
"""Compacts a string of json data by collapsing whitespace after the
specified indent level
-
+
NOTE: will not produce correct results when indent level is not a multiple
of the json indent level
+
+ Args:
+ text:
+ indent: (Default value = 12)
+
+ Returns:
+
"""
initial = " " * indent
out = [] # final json output
@@ -144,3 +167,73 @@ def collapse_json(text, indent=12):
return "\n".join(out)
+class DummyContext(object):
+ def __enter__(self):
+ pass
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ pass
+
+
+def merge_list_of_dict(samples: List[Dict]) -> dict:
+ batch = {}
+ for each in samples:
+ for k, v in each.items():
+ vs = batch.get(k, None)
+ if vs is None:
+ vs = []
+ batch[k] = vs
+ vs.append(v)
+ return batch
+
+
+def reorder(samples: List, order: List[int]) -> List:
+ return [samples[i] for i in sorted(range(len(order)), key=lambda k: order[k])]
+
+
+def k_fold(k, total, i):
+ trn = math.ceil(i / k * total)
+ tst = math.ceil((i + 1) / k * total)
+ return list(range(0, trn)) + list(range(tst, total)), list(range(trn, tst))
+
+
+def dfs(graph, start):
+ seen = set()
+ path = []
+ q = [start]
+ while q:
+ v = q.pop()
+ if v not in seen:
+ seen.add(v)
+ path.append(v)
+ q.extend(graph[v])
+
+ return path
+
+
+def topological_sort(graph, start):
+ seen = set()
+ stack = []
+ order = []
+ q = [start]
+ while q:
+ v = q.pop()
+ if v not in seen:
+ seen.add(v)
+ q.extend(graph[v])
+
+ while stack and v not in graph[stack[-1]]:
+ order.append(stack.pop())
+ stack.append(v)
+
+ return stack + order[::-1]
+
+
+def prefix_match(target, sources: Iterable[str]):
+ if target is None:
+ return None
+ if target in sources:
+ return target
+ for each in sources:
+ if each.startswith(target):
+ return each
diff --git a/plugins/hanlp_common/hanlp_common/visualization.py b/plugins/hanlp_common/hanlp_common/visualization.py
new file mode 100644
index 000000000..00d83fdc6
--- /dev/null
+++ b/plugins/hanlp_common/hanlp_common/visualization.py
@@ -0,0 +1,306 @@
+# -*- coding:utf-8 -*-
+# Modified from https://github.com/tylerneylon/explacy
+import io
+from collections import defaultdict
+from pprint import pprint
+
+from phrasetree.tree import Tree
+
+
+def make_table(rows, insert_header=False):
+ col_widths = [max(len(s) for s in col) for col in zip(*rows[1:])]
+ rows[0] = [x[:l] for x, l in zip(rows[0], col_widths)]
+ fmt = '\t'.join('%%-%ds' % width for width in col_widths)
+ if insert_header:
+ rows.insert(1, ['─' * width for width in col_widths])
+ return '\n'.join(fmt % tuple(row) for row in rows)
+
+
+def _start_end(arrow):
+ start, end = arrow['from'], arrow['to']
+ mn = min(start, end)
+ mx = max(start, end)
+ return start, end, mn, mx
+
+
+def pretty_tree_horizontal(arrows, _do_print_debug_info=False):
+ """Print the dependency tree horizontally
+
+ Args:
+ arrows:
+ _do_print_debug_info: (Default value = False)
+
+ Returns:
+
+ """
+ # Set the base height; these may increase to allow room for arrowheads after this.
+ arrows_with_deps = defaultdict(set)
+ for i, arrow in enumerate(arrows):
+ arrow['underset'] = set()
+ if _do_print_debug_info:
+ print('Arrow %d: "%s" -> "%s"' % (i, arrow['from'], arrow['to']))
+ num_deps = 0
+ start, end, mn, mx = _start_end(arrow)
+ for j, other in enumerate(arrows):
+ if arrow is other:
+ continue
+ o_start, o_end, o_mn, o_mx = _start_end(other)
+ if ((start == o_start and mn <= o_end <= mx) or
+ (start != o_start and mn <= o_start <= mx)):
+ num_deps += 1
+ if _do_print_debug_info:
+ print('%d is over %d' % (i, j))
+ arrow['underset'].add(j)
+ arrow['num_deps_left'] = arrow['num_deps'] = num_deps
+ arrows_with_deps[num_deps].add(i)
+
+ if _do_print_debug_info:
+ print('')
+ print('arrows:')
+ pprint(arrows)
+
+ print('')
+ print('arrows_with_deps:')
+ pprint(arrows_with_deps)
+
+ # Render the arrows in characters. Some heights will be raised to make room for arrowheads.
+ sent_len = max([max(arrow['from'], arrow['to']) for arrow in arrows]) + 1
+ lines = [[] for i in range(sent_len)]
+ num_arrows_left = len(arrows)
+ while num_arrows_left > 0:
+
+ assert len(arrows_with_deps[0])
+
+ arrow_index = arrows_with_deps[0].pop()
+ arrow = arrows[arrow_index]
+ src, dst, mn, mx = _start_end(arrow)
+
+ # Check the height needed.
+ height = 3
+ if arrow['underset']:
+ height = max(arrows[i]['height'] for i in arrow['underset']) + 1
+ height = max(height, 3, len(lines[dst]) + 3)
+ arrow['height'] = height
+
+ if _do_print_debug_info:
+ print('')
+ print('Rendering arrow %d: "%s" -> "%s"' % (arrow_index,
+ arrow['from'],
+ arrow['to']))
+ print(' height = %d' % height)
+
+ goes_up = src > dst
+
+ # Draw the outgoing src line.
+ if lines[src] and len(lines[src]) < height:
+ lines[src][-1].add('w')
+ while len(lines[src]) < height - 1:
+ lines[src].append(set(['e', 'w']))
+ if len(lines[src]) < height:
+ lines[src].append({'e'})
+ lines[src][height - 1].add('n' if goes_up else 's')
+
+ # Draw the incoming dst line.
+ lines[dst].append(u'►')
+ while len(lines[dst]) < height:
+ lines[dst].append(set(['e', 'w']))
+ lines[dst][-1] = set(['e', 's']) if goes_up else set(['e', 'n'])
+
+ # Draw the adjoining vertical line.
+ for i in range(mn + 1, mx):
+ while len(lines[i]) < height - 1:
+ lines[i].append(' ')
+ lines[i].append(set(['n', 's']))
+
+ # Update arrows_with_deps.
+ for arr_i, arr in enumerate(arrows):
+ if arrow_index in arr['underset']:
+ arrows_with_deps[arr['num_deps_left']].remove(arr_i)
+ arr['num_deps_left'] -= 1
+ arrows_with_deps[arr['num_deps_left']].add(arr_i)
+
+ num_arrows_left -= 1
+
+ return render_arrows(lines)
+
+
+def render_arrows(lines):
+ arr_chars = {'ew': u'─',
+ 'ns': u'│',
+ 'en': u'└',
+ 'es': u'┌',
+ 'enw': u'┴',
+ 'ensw': u'┼',
+ 'ens': u'├',
+ 'esw': u'┬'}
+ # Convert the character lists into strings.
+ max_len = max(len(line) for line in lines)
+ for i in range(len(lines)):
+ lines[i] = [arr_chars[''.join(sorted(ch))] if type(ch) is set else ch for ch in lines[i]]
+ lines[i] = ''.join(reversed(lines[i]))
+ lines[i] = ' ' * (max_len - len(lines[i])) + lines[i]
+ return lines
+
+
+def render_span(begin, end, unidirectional=False):
+ if end - begin == 1:
+ return ['───►']
+ elif end - begin == 2:
+ return [
+ '──┐',
+ '──┴►',
+ ] if unidirectional else [
+ '◄─┐',
+ '◄─┴►',
+ ]
+
+ rows = []
+ for i in range(begin, end):
+ if i == (end - begin) // 2 + begin:
+ rows.append(' ├►')
+ elif i == begin:
+ rows.append('──┐' if unidirectional else '◄─┐')
+ elif i == end - 1:
+ rows.append('──┘' if unidirectional else '◄─┘')
+ else:
+ rows.append(' │')
+ return rows
+
+
+def tree_to_list(T):
+ return [T.label(), [tree_to_list(t) if isinstance(t, Tree) else t for t in T]]
+
+
+def list_to_tree(L):
+ if isinstance(L, str):
+ return L
+ return Tree(L[0], [list_to_tree(child) for child in L[1]])
+
+
+def render_labeled_span(b, e, spans, labels, label, offset, unidirectional=False):
+ spans.extend([''] * (b - offset))
+ spans.extend(render_span(b, e, unidirectional))
+ center = b + (e - b) // 2
+ labels.extend([''] * (center - offset))
+ labels.append(label)
+ labels.extend([''] * (e - center - 1))
+
+
+def main():
+ # arrows = [{'from': 1, 'to': 0}, {'from': 2, 'to': 1}, {'from': 2, 'to': 4}, {'from': 2, 'to': 5},
+ # {'from': 4, 'to': 3}]
+ # lines = pretty_tree_horizontal(arrows)
+ # print('\n'.join(lines))
+ # print('\n'.join([
+ # '◄─┐',
+ # ' │',
+ # ' ├►',
+ # ' │',
+ # '◄─┘',
+ # ]))
+ print('\n'.join(render_span(7, 12)))
+
+
+if __name__ == '__main__':
+ main()
+left_rule = {'<': ':', '^': ':', '>': '-'}
+right_rule = {'<': '-', '^': ':', '>': ':'}
+
+
+def evalute_field(record, field_spec):
+ """Evalute a field of a record using the type of the field_spec as a guide.
+
+ Args:
+ record:
+ field_spec:
+
+ Returns:
+
+ """
+ if type(field_spec) is int:
+ return str(record[field_spec])
+ elif type(field_spec) is str:
+ return str(getattr(record, field_spec))
+ else:
+ return str(field_spec(record))
+
+
+def markdown_table(headings, records, fields=None, alignment=None, file=None):
+ """Generate a Doxygen-flavor Markdown table from records.
+ See https://stackoverflow.com/questions/13394140/generate-markdown-tables
+
+ file -- Any object with a 'write' method that takes a single string
+ parameter.
+ records -- Iterable. Rows will be generated from this.
+ fields -- List of fields for each row. Each entry may be an integer,
+ string or a function. If the entry is an integer, it is assumed to be
+ an index of each record. If the entry is a string, it is assumed to be
+ a field of each record. If the entry is a function, it is called with
+ the record and its return value is taken as the value of the field.
+ headings -- List of column headings.
+ alignment - List of pairs alignment characters. The first of the pair
+ specifies the alignment of the header, (Doxygen won't respect this, but
+ it might look good, the second specifies the alignment of the cells in
+ the column.
+
+ Possible alignment characters are:
+ '<' = Left align
+ '>' = Right align (default for cells)
+ '^' = Center (default for column headings)
+
+ Args:
+ headings:
+ records:
+ fields: (Default value = None)
+ alignment: (Default value = None)
+ file: (Default value = None)
+
+ Returns:
+
+ """
+ if not file:
+ file = io.StringIO()
+ num_columns = len(headings)
+ if not fields:
+ fields = list(range(num_columns))
+ assert len(headings) == num_columns
+
+ # Compute the table cell data
+ columns = [[] for i in range(num_columns)]
+ for record in records:
+ for i, field in enumerate(fields):
+ columns[i].append(evalute_field(record, field))
+
+ # Fill out any missing alignment characters.
+ extended_align = alignment if alignment is not None else [('^', '<')]
+ if len(extended_align) > num_columns:
+ extended_align = extended_align[0:num_columns]
+ elif len(extended_align) < num_columns:
+ extended_align += [('^', '>') for i in range(num_columns - len(extended_align))]
+
+ heading_align, cell_align = [x for x in zip(*extended_align)]
+
+ field_widths = [len(max(column, key=len)) if len(column) > 0 else 0
+ for column in columns]
+ heading_widths = [max(len(head), 2) for head in headings]
+ column_widths = [max(x) for x in zip(field_widths, heading_widths)]
+
+ _ = ' | '.join(['{:' + a + str(w) + '}'
+ for a, w in zip(heading_align, column_widths)])
+ heading_template = '| ' + _ + ' |'
+ _ = ' | '.join(['{:' + a + str(w) + '}'
+ for a, w in zip(cell_align, column_widths)])
+ row_template = '| ' + _ + ' |'
+
+ _ = ' | '.join([left_rule[a] + '-' * (w - 2) + right_rule[a]
+ for a, w in zip(cell_align, column_widths)])
+ ruling = '| ' + _ + ' |'
+
+ file.write(heading_template.format(*headings).rstrip() + '\n')
+ file.write(ruling.rstrip() + '\n')
+ for row in zip(*columns):
+ file.write(row_template.format(*row).rstrip() + '\n')
+ if isinstance(file, io.StringIO):
+ text = file.getvalue()
+ file.close()
+ return text
\ No newline at end of file
diff --git a/plugins/hanlp_common/setup.py b/plugins/hanlp_common/setup.py
new file mode 100644
index 000000000..df2f3f9a2
--- /dev/null
+++ b/plugins/hanlp_common/setup.py
@@ -0,0 +1,45 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-12-28 19:26
+from os.path import abspath, join, dirname
+from setuptools import find_packages, setup
+
+this_dir = abspath(dirname(__file__))
+with open(join(this_dir, 'README.md'), encoding='utf-8') as file:
+ long_description = file.read()
+
+setup(
+ name='hanlp_common',
+ version='0.0.1',
+ description='HanLP: Han Language Processing',
+ long_description=long_description,
+ long_description_content_type="text/markdown",
+ url='https://github.com/hankcs/HanLP',
+ author='hankcs',
+ author_email='hankcshe@gmail.com',
+ license='Apache License 2.0',
+ classifiers=[
+ 'Intended Audience :: Science/Research',
+ 'Intended Audience :: Developers',
+ "Development Status :: 3 - Alpha",
+ 'Operating System :: OS Independent',
+ "License :: OSI Approved :: Apache Software License",
+ 'Programming Language :: Python :: 3 :: Only',
+ 'Topic :: Scientific/Engineering :: Artificial Intelligence',
+ "Topic :: Text Processing :: Linguistic"
+ ],
+ keywords='corpus,machine-learning,NLU,NLP',
+ packages=find_packages(exclude=['docs', 'tests*']),
+ include_package_data=True,
+ install_requires=[
+ 'phrasetree',
+ ],
+ extras_require={
+ # These AMR dependencies might not be necessary for most people.
+ 'full': [
+ 'networkx',
+ 'penman==0.6.2',
+ ],
+ },
+ python_requires='>=3.6',
+)
diff --git a/plugins/hanlp_demo/README.md b/plugins/hanlp_demo/README.md
new file mode 100644
index 000000000..aae8a4fa5
--- /dev/null
+++ b/plugins/hanlp_demo/README.md
@@ -0,0 +1,3 @@
+# Demos and examples for HanLP
+
+This package is intended for demonstration purpose and won't be released to pypi.
\ No newline at end of file
diff --git a/plugins/hanlp_demo/hanlp_demo/__init__.py b/plugins/hanlp_demo/hanlp_demo/__init__.py
new file mode 100644
index 000000000..2725b5824
--- /dev/null
+++ b/plugins/hanlp_demo/hanlp_demo/__init__.py
@@ -0,0 +1,5 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-11-29 17:48
+from .trie import Trie
+from .dictionary import DictInterface, TrieDict
diff --git a/tests/demo/en/__init__.py b/plugins/hanlp_demo/hanlp_demo/en/__init__.py
similarity index 100%
rename from tests/demo/en/__init__.py
rename to plugins/hanlp_demo/hanlp_demo/en/__init__.py
diff --git a/tests/demo/en/demo_dep.py b/plugins/hanlp_demo/hanlp_demo/en/demo_dep.py
similarity index 100%
rename from tests/demo/en/demo_dep.py
rename to plugins/hanlp_demo/hanlp_demo/en/demo_dep.py
diff --git a/tests/demo/en/demo_emotion_prediction.py b/plugins/hanlp_demo/hanlp_demo/en/demo_emotion_prediction.py
similarity index 100%
rename from tests/demo/en/demo_emotion_prediction.py
rename to plugins/hanlp_demo/hanlp_demo/en/demo_emotion_prediction.py
diff --git a/tests/demo/en/demo_lm.py b/plugins/hanlp_demo/hanlp_demo/en/demo_lm.py
similarity index 66%
rename from tests/demo/en/demo_lm.py
rename to plugins/hanlp_demo/hanlp_demo/en/demo_lm.py
index dfe3d8e85..a1207f3dd 100644
--- a/tests/demo/en/demo_lm.py
+++ b/plugins/hanlp_demo/hanlp_demo/en/demo_lm.py
@@ -3,5 +3,5 @@
# Date: 2020-02-11 09:14
import hanlp
-lm = hanlp.load(hanlp.pretrained.rnnlm.FLAIR_LM_FW_WMT11_EN)
+lm = hanlp.load(hanlp.pretrained.rnnlm.FLAIR_LM_FW_WMT11_EN_TF)
print(''.join(lm.generate_text(list('hello'))))
diff --git a/tests/demo/en/demo_ner.py b/plugins/hanlp_demo/hanlp_demo/en/demo_ner.py
similarity index 100%
rename from tests/demo/en/demo_ner.py
rename to plugins/hanlp_demo/hanlp_demo/en/demo_ner.py
diff --git a/tests/demo/en/demo_pipeline.py b/plugins/hanlp_demo/hanlp_demo/en/demo_pipeline.py
similarity index 100%
rename from tests/demo/en/demo_pipeline.py
rename to plugins/hanlp_demo/hanlp_demo/en/demo_pipeline.py
diff --git a/tests/demo/en/demo_pos.py b/plugins/hanlp_demo/hanlp_demo/en/demo_pos.py
similarity index 100%
rename from tests/demo/en/demo_pos.py
rename to plugins/hanlp_demo/hanlp_demo/en/demo_pos.py
diff --git a/tests/demo/en/demo_sdp.py b/plugins/hanlp_demo/hanlp_demo/en/demo_sdp.py
similarity index 91%
rename from tests/demo/en/demo_sdp.py
rename to plugins/hanlp_demo/hanlp_demo/en/demo_sdp.py
index 696f8e55f..20947eb7b 100644
--- a/tests/demo/en/demo_sdp.py
+++ b/plugins/hanlp_demo/hanlp_demo/en/demo_sdp.py
@@ -2,7 +2,7 @@
# Author: hankcs
# Date: 2020-01-03 15:26
import hanlp
-from hanlp.components.parsers.conll import CoNLLSentence
+from hanlp_common.conll import CoNLLSentence
# semeval15 offers three independent annotations over the Penn Treebank (PTB)
semantic_parser = hanlp.load(hanlp.pretrained.sdp.SEMEVAL15_PAS_BIAFFINE_EN)
diff --git a/tests/demo/en/demo_sentiment_analysis.py b/plugins/hanlp_demo/hanlp_demo/en/demo_sentiment_analysis.py
similarity index 100%
rename from tests/demo/en/demo_sentiment_analysis.py
rename to plugins/hanlp_demo/hanlp_demo/en/demo_sentiment_analysis.py
diff --git a/tests/demo/en/demo_tok.py b/plugins/hanlp_demo/hanlp_demo/en/demo_tok.py
similarity index 100%
rename from tests/demo/en/demo_tok.py
rename to plugins/hanlp_demo/hanlp_demo/en/demo_tok.py
diff --git a/plugins/hanlp_demo/hanlp_demo/mul/__init__.py b/plugins/hanlp_demo/hanlp_demo/mul/__init__.py
new file mode 100644
index 000000000..325516506
--- /dev/null
+++ b/plugins/hanlp_demo/hanlp_demo/mul/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-31 22:25
diff --git a/plugins/hanlp_demo/hanlp_demo/mul/demo_mtl.py b/plugins/hanlp_demo/hanlp_demo/mul/demo_mtl.py
new file mode 100644
index 000000000..367bf48ca
--- /dev/null
+++ b/plugins/hanlp_demo/hanlp_demo/mul/demo_mtl.py
@@ -0,0 +1,14 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-31 13:51
+import hanlp
+from hanlp_common.document import Document
+
+HanLP = hanlp.load(hanlp.pretrained.mtl.UD_ONTONOTES_TOK_POS_LEM_FEA_NER_SRL_DEP_SDP_CON_MT5_BASE)
+doc: Document = HanLP([
+ 'In 2021, HanLPv2.1 delivers state-of-the-art multilingual NLP techniques to production environment.',
+ '2021年、HanLPv2.1は次世代の最先端多言語NLP技術を本番環境に導入します。',
+ '2021年 HanLPv2.1为生产环境带来次世代最先进的多语种NLP技术。',
+])
+print(doc)
+doc.pretty_print()
diff --git a/plugins/hanlp_demo/hanlp_demo/sent_split.py b/plugins/hanlp_demo/hanlp_demo/sent_split.py
new file mode 100644
index 000000000..e38410b2a
--- /dev/null
+++ b/plugins/hanlp_demo/hanlp_demo/sent_split.py
@@ -0,0 +1,9 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-31 14:23
+import hanlp
+
+split_sent = hanlp.load(hanlp.pretrained.eos.UD_CTB_EOS_MUL)
+output = split_sent('3.14 is pi. “你好!!!”——他说。劇場版「Fate/stay night [HF]」最終章公開カウントダウン!')
+print('\n'.join(output))
+# See also https://hanlp.hankcs.com/docs/api/hanlp/components/eos.html
diff --git a/plugins/hanlp_demo/hanlp_demo/zh/__init__.py b/plugins/hanlp_demo/hanlp_demo/zh/__init__.py
new file mode 100644
index 000000000..cf552c85e
--- /dev/null
+++ b/plugins/hanlp_demo/hanlp_demo/zh/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-31 13:51
diff --git a/plugins/hanlp_demo/hanlp_demo/zh/demo_custom_dict.py b/plugins/hanlp_demo/hanlp_demo/zh/demo_custom_dict.py
new file mode 100644
index 000000000..d106243f2
--- /dev/null
+++ b/plugins/hanlp_demo/hanlp_demo/zh/demo_custom_dict.py
@@ -0,0 +1,27 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-15 22:26
+import hanlp
+from hanlp.components.mtl.multi_task_learning import MultiTaskLearning
+from hanlp.components.mtl.tasks.tok.tag_tok import TaggingTokenization
+from tests import cdroot
+
+cdroot()
+HanLP: MultiTaskLearning = hanlp.load(hanlp.pretrained.mtl.CLOSE_TOK_POS_NER_SRL_DEP_SDP_CON_ELECTRA_SMALL_ZH)
+tok: TaggingTokenization = HanLP['tok/fine']
+
+# tok.dict_force = tok.dict_combine = None
+# print(f'不挂词典:\n{HanLP("商品和服务行业")["tok/fine"]}')
+#
+# tok.dict_force = {'和服', '服务行业'}
+# print(f'强制模式:\n{HanLP("商品和服务行业")["tok/fine"]}') # 慎用,详见《自然语言处理入门》第二章
+#
+# tok.dict_force = {'和服务': ['和', '服务']}
+# print(f'强制校正:\n{HanLP("正向匹配商品和服务、任何和服务必按上述切分")["tok/fine"]}')
+
+tok.dict_force = None
+tok.dict_combine = {'和服', '服务行业'}
+print(f'合并模式:\n{HanLP("商品和服务行业")["tok/fine"]}')
+
+# 需要算法基础才能理解,初学者可参考 http://nlp.hankcs.com/book.php
+# See also https://hanlp.hankcs.com/docs/api/hanlp/components/tokenizers/transformer.html
diff --git a/plugins/hanlp_demo/hanlp_demo/zh/demo_mtl.py b/plugins/hanlp_demo/hanlp_demo/zh/demo_mtl.py
new file mode 100644
index 000000000..89a16ef02
--- /dev/null
+++ b/plugins/hanlp_demo/hanlp_demo/zh/demo_mtl.py
@@ -0,0 +1,12 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-31 13:51
+import hanlp
+from hanlp_common.document import Document
+
+HanLP = hanlp.load(hanlp.pretrained.mtl.CLOSE_TOK_POS_NER_SRL_DEP_SDP_CON_ELECTRA_BASE_ZH)
+doc: Document = HanLP(['2021年HanLPv2.1为生产环境带来次世代最先进的多语种NLP技术。', '阿婆主来到北京立方庭参观自然语义科技公司。'])
+print(doc)
+doc.pretty_print()
+# Specify which annotation to use
+# doc.pretty_print(ner='ner/ontonotes', pos='pku')
diff --git a/plugins/hanlp_demo/hanlp_demo/zh/tf/__init__.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/__init__.py
new file mode 100644
index 000000000..badaca5de
--- /dev/null
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-31 20:36
diff --git a/tests/train/zh/cws/__init__.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/cws/__init__.py
similarity index 100%
rename from tests/train/zh/cws/__init__.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/cws/__init__.py
diff --git a/plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_ctb6_cws_albert.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_ctb6_cws_albert.py
new file mode 100644
index 000000000..f780082ab
--- /dev/null
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_ctb6_cws_albert.py
@@ -0,0 +1,18 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-12-28 22:22
+
+from hanlp.components.tok_tf import TransformerTokenizerTF
+from hanlp.datasets.cws.ctb import CTB6_CWS_TRAIN, CTB6_CWS_DEV, CTB6_CWS_TEST
+from tests import cdroot
+
+cdroot()
+tokenizer = TransformerTokenizerTF()
+save_dir = 'data/model/cws_bert_albert_ctb6'
+tokenizer.fit(CTB6_CWS_TRAIN, CTB6_CWS_DEV, save_dir,
+ transformer='/home/ubuntu/hankcs/laser/data/transformer/albert_base_tf2',
+ metrics='f1', learning_rate=5e-5, epochs=3)
+tokenizer.load(save_dir)
+print(tokenizer.predict(['中央民族乐团离开北京前往维也纳', '商品和服务']))
+tokenizer.evaluate(CTB6_CWS_TEST, save_dir=save_dir)
+print(f'Model saved in {save_dir}')
diff --git a/plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_ctb6_cws_bert.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_ctb6_cws_bert.py
new file mode 100644
index 000000000..f6df7a34d
--- /dev/null
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_ctb6_cws_bert.py
@@ -0,0 +1,17 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-12-28 22:22
+
+from hanlp.components.tok_tf import TransformerTokenizerTF
+from hanlp.datasets.cws.ctb import CTB6_CWS_TRAIN, CTB6_CWS_DEV, CTB6_CWS_TEST
+from tests import cdroot
+
+cdroot()
+tokenizer = TransformerTokenizerTF()
+save_dir = 'data/model/cws_bert_base_ctb6'
+# tagger.fit(CTB6_CWS_TRAIN, CTB6_CWS_DEV, save_dir, transformer='bert-base-chinese',
+# metrics='f1')
+tokenizer.load(save_dir)
+print(tokenizer.predict(['中央民族乐团离开北京前往维也纳', '商品和服务']))
+tokenizer.evaluate(CTB6_CWS_TEST, save_dir=save_dir)
+print(f'Model saved in {save_dir}')
diff --git a/tests/train/zh/cws/train_ctb6_cws_convseg.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_ctb6_cws_convseg.py
similarity index 84%
rename from tests/train/zh/cws/train_ctb6_cws_convseg.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_ctb6_cws_convseg.py
index eb7087d31..a981ef0a7 100644
--- a/tests/train/zh/cws/train_ctb6_cws_convseg.py
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_ctb6_cws_convseg.py
@@ -3,18 +3,18 @@
# Date: 2019-12-28 22:22
import tensorflow as tf
-from hanlp.components.tok import NgramConvTokenizer
-from hanlp.datasets.cws.ctb import CTB6_CWS_TRAIN, CTB6_CWS_VALID, CTB6_CWS_TEST
+from hanlp.components.tok_tf import NgramConvTokenizerTF
+from hanlp.datasets.cws.ctb6 import CTB6_CWS_TRAIN, CTB6_CWS_DEV, CTB6_CWS_TEST
from hanlp.pretrained.word2vec import CONVSEG_W2V_NEWS_TENSITE_CHAR
from tests import cdroot
cdroot()
-tokenizer = NgramConvTokenizer()
+tokenizer = NgramConvTokenizerTF()
save_dir = 'data/model/cws/ctb6_cws'
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001,
epsilon=1e-8, clipnorm=5)
tokenizer.fit(CTB6_CWS_TRAIN,
- CTB6_CWS_VALID,
+ CTB6_CWS_DEV,
save_dir,
word_embed={'class_name': 'HanLP>Word2VecEmbedding',
'config': {
diff --git a/tests/train/zh/cws/train_large_bert_cws.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_large_bert_cws.py
similarity index 62%
rename from tests/train/zh/cws/train_large_bert_cws.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_large_bert_cws.py
index 7f77cd13b..54f6410a1 100644
--- a/tests/train/zh/cws/train_large_bert_cws.py
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_large_bert_cws.py
@@ -1,14 +1,14 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-12-21 15:39
-from hanlp.components.tok import TransformerTokenizer
-from hanlp.datasets.cws.ctb import CTB6_CWS_VALID, CTB6_CWS_TEST
+from hanlp.components.tok_tf import TransformerTokenizerTF
+from hanlp.datasets.cws.ctb import CTB6_CWS_DEV, CTB6_CWS_TEST
from tests import cdroot
cdroot()
-tokenizer = TransformerTokenizer()
+tokenizer = TransformerTokenizerTF()
save_dir = 'data/model/cws_bert_base_100million'
-tokenizer.fit('data/cws/large/all.txt', CTB6_CWS_VALID, save_dir, transformer='bert-base-chinese',
+tokenizer.fit('data/cws/large/all.txt', CTB6_CWS_DEV, save_dir, transformer='bert-base-chinese',
metrics='accuracy', batch_size=32)
tokenizer.load(save_dir, metrics='f1')
print(tokenizer.predict(['中央民族乐团离开北京前往维也纳', '商品和服务']))
diff --git a/tests/train/zh/cws/train_large_conv_cws.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_large_conv_cws.py
similarity index 84%
rename from tests/train/zh/cws/train_large_conv_cws.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_large_conv_cws.py
index 2c09142e3..af58ef11c 100644
--- a/tests/train/zh/cws/train_large_conv_cws.py
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_large_conv_cws.py
@@ -4,18 +4,18 @@
import tensorflow as tf
-from hanlp.components.tok import NgramConvTokenizer
-from hanlp.datasets.cws.ctb import CTB6_CWS_TRAIN, CTB6_CWS_VALID, CTB6_CWS_TEST
+from hanlp.components.tok_tf import NgramConvTokenizerTF
+from hanlp.datasets.cws.ctb import CTB6_CWS_TRAIN, CTB6_CWS_DEV, CTB6_CWS_TEST
from hanlp.pretrained.word2vec import CONVSEG_W2V_NEWS_TENSITE_CHAR
from tests import cdroot
cdroot()
-tokenizer = NgramConvTokenizer()
+tokenizer = NgramConvTokenizerTF()
save_dir = 'data/model/cws/ctb6_cws'
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001,
epsilon=1e-8, clipnorm=5)
tokenizer.fit(CTB6_CWS_TRAIN,
- CTB6_CWS_VALID,
+ CTB6_CWS_DEV,
save_dir,
word_embed={'class_name': 'HanLP>Word2VecEmbedding',
'config': {
diff --git a/tests/train/zh/cws/train_large_cws_albert.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_large_cws_albert.py
similarity index 100%
rename from tests/train/zh/cws/train_large_cws_albert.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_large_cws_albert.py
diff --git a/tests/train/zh/cws/train_large_rnn_cws.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_large_rnn_cws.py
similarity index 87%
rename from tests/train/zh/cws/train_large_rnn_cws.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_large_rnn_cws.py
index 60974fe86..922cccc9b 100644
--- a/tests/train/zh/cws/train_large_rnn_cws.py
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_large_rnn_cws.py
@@ -3,19 +3,19 @@
# Date: 2019-12-21 15:39
import tensorflow as tf
-from hanlp.components.tok import RNNTokenizer
-from hanlp.datasets.cws.ctb import CTB6_CWS_TEST, CTB6_CWS_VALID
+from hanlp.components.tok_tf import RNNTokenizerTF
+from hanlp.datasets.cws.ctb import CTB6_CWS_TEST, CTB6_CWS_DEV
from hanlp.pretrained.word2vec import RADICAL_CHAR_EMBEDDING_100, CONVSEG_W2V_NEWS_TENSITE_CHAR
from tests import cdroot
cdroot()
-tokenizer = RNNTokenizer()
+tokenizer = RNNTokenizerTF()
save_dir = 'data/model/cws/large_rnn_cws'
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001,
epsilon=1e-8, clipnorm=5)
tokenizer.fit('data/cws/large/all.txt',
- CTB6_CWS_VALID,
+ CTB6_CWS_DEV,
save_dir,
embeddings={'class_name': 'HanLP>Word2VecEmbedding',
'config': {
diff --git a/tests/train/zh/cws/train_msr_cws_albert.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_msr_cws_albert.py
similarity index 100%
rename from tests/train/zh/cws/train_msr_cws_albert.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_msr_cws_albert.py
diff --git a/tests/train/zh/cws/train_msr_cws_bert.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_msr_cws_bert.py
similarity index 61%
rename from tests/train/zh/cws/train_msr_cws_bert.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_msr_cws_bert.py
index 50252eb09..938b540f7 100644
--- a/tests/train/zh/cws/train_msr_cws_bert.py
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_msr_cws_bert.py
@@ -1,16 +1,16 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-12-21 15:39
-from hanlp.components.tok import TransformerTokenizer
-from hanlp.datasets.cws.sighan2005.msr import SIGHAN2005_MSR_TRAIN, SIGHAN2005_MSR_VALID, SIGHAN2005_MSR_TEST
+from hanlp.components.tok_tf import TransformerTokenizerTF
+from hanlp.datasets.cws.sighan2005.msr import SIGHAN2005_MSR_TRAIN, SIGHAN2005_MSR_DEV, SIGHAN2005_MSR_TEST
from tests import cdroot
cdroot()
-tokenizer = TransformerTokenizer()
+tokenizer = TransformerTokenizerTF()
save_dir = 'data/model/cws_bert_base_msra'
-tokenizer.fit(SIGHAN2005_MSR_TRAIN, SIGHAN2005_MSR_VALID, save_dir, transformer='chinese_L-12_H-768_A-12',
+tokenizer.fit(SIGHAN2005_MSR_TRAIN, SIGHAN2005_MSR_DEV, save_dir, transformer='bert-base-chinese',
metrics='f1')
-# tokenizer.load(save_dir)
+# tagger.load(save_dir)
print(tokenizer.predict(['中央民族乐团离开北京前往维也纳', '商品和服务']))
tokenizer.evaluate(SIGHAN2005_MSR_TEST, save_dir=save_dir)
print(f'Model saved in {save_dir}')
diff --git a/tests/train/zh/cws/train_msr_cws_ngram_conv.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_msr_cws_ngram_conv.py
similarity index 87%
rename from tests/train/zh/cws/train_msr_cws_ngram_conv.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_msr_cws_ngram_conv.py
index 0ba1d2ec2..0a6d69573 100644
--- a/tests/train/zh/cws/train_msr_cws_ngram_conv.py
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_msr_cws_ngram_conv.py
@@ -3,16 +3,16 @@
# Date: 2019-12-21 15:39
import tensorflow as tf
-from hanlp.components.tok import NgramConvTokenizer
-from hanlp.datasets.cws.sighan2005.msr import SIGHAN2005_MSR_TRAIN, SIGHAN2005_MSR_VALID, SIGHAN2005_MSR_TEST
+from hanlp.components.tok_tf import NgramConvTokenizerTF
+from hanlp.datasets.cws.sighan2005.msr import SIGHAN2005_MSR_TRAIN, SIGHAN2005_MSR_DEV, SIGHAN2005_MSR_TEST
from hanlp.pretrained.word2vec import CONVSEG_W2V_NEWS_TENSITE_CHAR
from tests import cdroot
cdroot()
-tokenizer = NgramConvTokenizer()
+tokenizer = NgramConvTokenizerTF()
save_dir = 'data/model/cws/convseg-msr-nocrf-noembed'
tokenizer.fit(SIGHAN2005_MSR_TRAIN,
- SIGHAN2005_MSR_VALID,
+ SIGHAN2005_MSR_DEV,
save_dir,
word_embed={'class_name': 'HanLP>Word2VecEmbedding',
'config': {
diff --git a/tests/train/zh/cws/train_msr_cws_ngram_conv_embed.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_msr_cws_ngram_conv_embed.py
similarity index 100%
rename from tests/train/zh/cws/train_msr_cws_ngram_conv_embed.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_msr_cws_ngram_conv_embed.py
diff --git a/tests/train/zh/cws/train_pku980106_conv_cws.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_pku980106_conv_cws.py
similarity index 92%
rename from tests/train/zh/cws/train_pku980106_conv_cws.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_pku980106_conv_cws.py
index a2bb425ed..046b6868c 100644
--- a/tests/train/zh/cws/train_pku980106_conv_cws.py
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_pku980106_conv_cws.py
@@ -3,13 +3,13 @@
# Date: 2019-12-21 15:39
import tensorflow as tf
-from hanlp.components.tok import NgramConvTokenizer
+from hanlp.components.tok_tf import NgramConvTokenizerTF
from hanlp.pretrained.word2vec import RADICAL_CHAR_EMBEDDING_100
from tests import cdroot
cdroot()
-tokenizer = NgramConvTokenizer()
+tokenizer = NgramConvTokenizerTF()
save_dir = 'data/model/cws/pku98_6m_conv_ngram'
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001,
epsilon=1e-8, clipnorm=5)
diff --git a/tests/train/zh/cws/train_pku980106_rnn_cws.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_pku980106_rnn_cws.py
similarity index 93%
rename from tests/train/zh/cws/train_pku980106_rnn_cws.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_pku980106_rnn_cws.py
index 890c5d3aa..94df236d9 100644
--- a/tests/train/zh/cws/train_pku980106_rnn_cws.py
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_pku980106_rnn_cws.py
@@ -3,13 +3,13 @@
# Date: 2019-12-21 15:39
import tensorflow as tf
-from hanlp.components.tok import RNNTokenizer
+from hanlp.components.tok_tf import RNNTokenizerTF
from hanlp.pretrained.word2vec import RADICAL_CHAR_EMBEDDING_100
from tests import cdroot
cdroot()
-tokenizer = RNNTokenizer()
+tokenizer = RNNTokenizerTF()
save_dir = 'data/model/cws/pku_6m_rnn_cws'
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001,
epsilon=1e-8, clipnorm=5)
diff --git a/tests/train/zh/cws/train_pku_conv_cws.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_pku_conv_cws.py
similarity index 88%
rename from tests/train/zh/cws/train_pku_conv_cws.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_pku_conv_cws.py
index d3a3f7e95..ba9606a5a 100644
--- a/tests/train/zh/cws/train_pku_conv_cws.py
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/cws/train_pku_conv_cws.py
@@ -1,7 +1,7 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-12-21 15:39
-from hanlp.datasets.cws.sighan2005.pku import SIGHAN2005_PKU_TRAIN, SIGHAN2005_PKU_VALID, SIGHAN2005_PKU_TEST
+from hanlp.datasets.cws.sighan2005.pku import SIGHAN2005_PKU_TRAIN, SIGHAN2005_PKU_DEV, SIGHAN2005_PKU_TEST
from hanlp.pretrained.word2vec import CONVSEG_W2V_NEWS_TENSITE_CHAR
from hanlp.utils.tf_util import nice
from tests import cdroot
@@ -9,14 +9,14 @@
nice()
cdroot()
-from hanlp.components.tok import NgramConvTokenizer
+from hanlp.components.tok_tf import NgramConvTokenizerTF
-tokenizer = NgramConvTokenizer()
+tokenizer = NgramConvTokenizerTF()
save_dir = 'data/model/cws/sighan2005-pku-convseg'
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001,
epsilon=1e-8, clipnorm=5)
tokenizer.fit(SIGHAN2005_PKU_TRAIN,
- SIGHAN2005_PKU_VALID,
+ SIGHAN2005_PKU_DEV,
save_dir,
word_embed={'class_name': 'HanLP>Word2VecEmbedding',
'config': {
diff --git a/tests/demo/zh/demo_classifier.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/demo_classifier.py
similarity index 100%
rename from tests/demo/zh/demo_classifier.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/demo_classifier.py
diff --git a/tests/demo/zh/demo_client.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/demo_client.py
similarity index 95%
rename from tests/demo/zh/demo_client.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/demo_client.py
index 641773379..7bcf78403 100644
--- a/tests/demo/zh/demo_client.py
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/demo_client.py
@@ -7,7 +7,7 @@
from tensorflow_core.python.framework import tensor_util
from tensorflow_serving.apis import predict_pb2, prediction_service_pb2_grpc
import hanlp
-from hanlp.common.component import KerasComponent
+from hanlp.common.keras_component import KerasComponent
tagger: KerasComponent = hanlp.load(hanlp.pretrained.pos.CTB5_POS_RNN, transform_only=True)
transform = tagger.transform
diff --git a/tests/demo/zh/demo_cws.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/demo_cws.py
similarity index 95%
rename from tests/demo/zh/demo_cws.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/demo_cws.py
index 6a019932d..b4c55ba3b 100644
--- a/tests/demo/zh/demo_cws.py
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/demo_cws.py
@@ -3,7 +3,7 @@
# Date: 2019-12-28 21:25
import hanlp
-tokenizer = hanlp.load(hanlp.pretrained.cws.LARGE_ALBERT_BASE)
+tokenizer = hanlp.load(hanlp.pretrained.tok.LARGE_ALBERT_BASE)
print(tokenizer('商品和服务'))
print(tokenizer(['萨哈夫说,伊拉克将同联合国销毁伊拉克大规模杀伤性武器特别委员会继续保持合作。',
'上海华安工业(集团)公司董事长谭旭光和秘书张晚霞来到美国纽约现代艺术博物馆参观。',
diff --git a/tests/demo/zh/demo_cws_trie.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/demo_cws_trie.py
similarity index 81%
rename from tests/demo/zh/demo_cws_trie.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/demo_cws_trie.py
index 86adf98f7..812b13e17 100644
--- a/tests/demo/zh/demo_cws_trie.py
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/demo_cws_trie.py
@@ -1,7 +1,7 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-12-28 21:25
-from hanlp.common.trie import Trie
+from hanlp_trie.trie import Trie
import hanlp
@@ -18,7 +18,7 @@ def split_sents(text: str, trie: Trie):
sents = []
pre_start = 0
offsets = []
- for word, value, start, end in words:
+ for start, end, value in words:
if pre_start != start:
sents.append(text[pre_start: start])
offsets.append(pre_start)
@@ -34,9 +34,7 @@ def split_sents(text: str, trie: Trie):
def merge_parts(parts, offsets, words):
items = [(i, p) for (i, p) in zip(offsets, parts)]
- items += [(start, [word]) for (word, value, start, end) in words]
- # In case you need the tag, use the following line instead
- # items += [(start, [(word, value)]) for (word, value, start, end) in words]
+ items += [(start, [value]) for (start, end, value) in words]
return [each for x in sorted(items) for each in x[1]]
diff --git a/tests/demo/zh/demo_dep.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/demo_dep.py
similarity index 100%
rename from tests/demo/zh/demo_dep.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/demo_dep.py
diff --git a/tests/demo/zh/demo_multiprocess.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/demo_multiprocess.py
similarity index 92%
rename from tests/demo/zh/demo_multiprocess.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/demo_multiprocess.py
index db959e8ea..c4476f7af 100644
--- a/tests/demo/zh/demo_multiprocess.py
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/demo_multiprocess.py
@@ -4,7 +4,7 @@
import multiprocessing
import hanlp
-tokenizer = hanlp.load(hanlp.pretrained.cws.PKU_NAME_MERGED_SIX_MONTHS_CONVSEG)
+tokenizer = hanlp.load(hanlp.pretrained.tok.PKU_NAME_MERGED_SIX_MONTHS_CONVSEG)
def worker(job):
diff --git a/tests/demo/zh/demo_ner.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/demo_ner.py
similarity index 100%
rename from tests/demo/zh/demo_ner.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/demo_ner.py
diff --git a/tests/demo/zh/demo_pipeline.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/demo_pipeline.py
similarity index 100%
rename from tests/demo/zh/demo_pipeline.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/demo_pipeline.py
diff --git a/tests/demo/zh/demo_pos.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/demo_pos.py
similarity index 100%
rename from tests/demo/zh/demo_pos.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/demo_pos.py
diff --git a/tests/demo/zh/demo_sdp.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/demo_sdp.py
similarity index 100%
rename from tests/demo/zh/demo_sdp.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/demo_sdp.py
diff --git a/tests/demo/zh/demo_serving.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/demo_serving.py
similarity index 78%
rename from tests/demo/zh/demo_serving.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/demo_serving.py
index 07fc16c2e..ebca07aa6 100644
--- a/tests/demo/zh/demo_serving.py
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/demo_serving.py
@@ -2,7 +2,7 @@
# Author: hankcs
# Date: 2020-01-06 20:23
import hanlp
-from hanlp.common.component import KerasComponent
+from hanlp.common.keras_component import KerasComponent
tagger: KerasComponent = hanlp.load(hanlp.pretrained.pos.CTB5_POS_RNN)
print(tagger('商品 和 服务'.split()))
diff --git a/tests/train/zh/train_chnsenticorp_bert.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_chnsenticorp_bert.py
similarity index 64%
rename from tests/train/zh/train_chnsenticorp_bert.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/train_chnsenticorp_bert.py
index cdacdb26a..bf81eb354 100644
--- a/tests/train/zh/train_chnsenticorp_bert.py
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_chnsenticorp_bert.py
@@ -1,15 +1,15 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-12-30 21:01
-from hanlp.components.classifiers.transformer_classifier import TransformerClassifier, TransformerTextTransform
+from hanlp.components.classifiers.transformer_classifier_tf import TransformerClassifierTF, TransformerTextTransform
from hanlp.datasets.classification.sentiment import CHNSENTICORP_ERNIE_TRAIN, CHNSENTICORP_ERNIE_TEST, \
- CHNSENTICORP_ERNIE_VALID
+ CHNSENTICORP_ERNIE_DEV
from tests import cdroot
cdroot()
save_dir = 'data/model/classification/chnsenticorp_bert_base'
-classifier = TransformerClassifier(TransformerTextTransform(y_column=0))
-classifier.fit(CHNSENTICORP_ERNIE_TRAIN, CHNSENTICORP_ERNIE_VALID, save_dir,
+classifier = TransformerClassifierTF(TransformerTextTransform(y_column=0))
+classifier.fit(CHNSENTICORP_ERNIE_TRAIN, CHNSENTICORP_ERNIE_DEV, save_dir,
transformer='chinese_L-12_H-768_A-12')
classifier.load(save_dir)
print(classifier.predict('前台客房服务态度非常好!早餐很丰富,房价很干净。再接再厉!'))
diff --git a/tests/train/en/train_conll03_ner_bert.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_conll03_ner_bert.py
similarity index 67%
rename from tests/train/en/train_conll03_ner_bert.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/train_conll03_ner_bert.py
index 921b54c0f..e834f87a5 100644
--- a/tests/train/en/train_conll03_ner_bert.py
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_conll03_ner_bert.py
@@ -1,14 +1,14 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-10-25 21:34
-from hanlp.components.ner import TransformerNamedEntityRecognizer
-from hanlp.datasets.ner.conll03 import CONLL03_EN_TRAIN, CONLL03_EN_VALID, CONLL03_EN_TEST
+from hanlp.components.ner_tf import TransformerNamedEntityRecognizerTF
+from hanlp.datasets.ner.conll03 import CONLL03_EN_TRAIN, CONLL03_EN_DEV, CONLL03_EN_TEST
from tests import cdroot
cdroot()
-tagger = TransformerNamedEntityRecognizer()
+tagger = TransformerNamedEntityRecognizerTF()
save_dir = 'data/model/ner/ner_conll03_bert_base_uncased_en'
-tagger.fit(CONLL03_EN_TRAIN, CONLL03_EN_VALID, save_dir, transformer='uncased_L-12_H-768_A-12',
+tagger.fit(CONLL03_EN_TRAIN, CONLL03_EN_DEV, save_dir, transformer='uncased_L-12_H-768_A-12',
metrics='accuracy')
tagger.load(save_dir, metrics='f1')
print(tagger.predict('West Indian all-rounder Phil Simmons eats apple .'.split()))
diff --git a/tests/train/en/train_conll03_ner_flair.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_conll03_ner_flair.py
similarity index 84%
rename from tests/train/en/train_conll03_ner_flair.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/train_conll03_ner_flair.py
index 0a06e3110..e3da49bc6 100644
--- a/tests/train/en/train_conll03_ner_flair.py
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_conll03_ner_flair.py
@@ -4,14 +4,14 @@
import tensorflow as tf
-from hanlp.components.ner import RNNNamedEntityRecognizer
+from hanlp.components.ner_tf import RNNNamedEntityRecognizerTF
from hanlp.datasets.ner.conll03 import CONLL03_EN_TRAIN, CONLL03_EN_TEST
from hanlp.pretrained.glove import GLOVE_6B_100D
-from hanlp.pretrained.rnnlm import FLAIR_LM_FW_WMT11_EN, FLAIR_LM_BW_WMT11_EN
+from hanlp.pretrained.rnnlm import FLAIR_LM_FW_WMT11_EN_TF, FLAIR_LM_BW_WMT11_EN_TF
from tests import cdroot
cdroot()
-tagger = RNNNamedEntityRecognizer()
+tagger = RNNNamedEntityRecognizerTF()
save_dir = 'data/model/conll03-ner-rnn-flair'
tagger.fit(CONLL03_EN_TRAIN, CONLL03_EN_TEST, save_dir, epochs=100,
optimizer=tf.keras.optimizers.Adam(learning_rate=0.1,
@@ -32,8 +32,8 @@
{'class_name': 'HanLP>ContextualStringEmbedding',
'config': {
'trainable': False,
- 'forward_model_path': FLAIR_LM_FW_WMT11_EN,
- 'backward_model_path': FLAIR_LM_BW_WMT11_EN
+ 'forward_model_path': FLAIR_LM_FW_WMT11_EN_TF,
+ 'backward_model_path': FLAIR_LM_BW_WMT11_EN_TF
}}
],
rnn_output_dropout=0.5,
diff --git a/plugins/hanlp_demo/hanlp_demo/zh/tf/train_ctb5_dep.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_ctb5_dep.py
new file mode 100644
index 000000000..52e5e25b4
--- /dev/null
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_ctb5_dep.py
@@ -0,0 +1,27 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-12-28 18:33
+from hanlp.components.parsers.biaffine_parser_tf import BiaffineDependencyParserTF
+from hanlp.datasets.parsing.ctb5 import CTB5_DEP_TRAIN, CTB5_DEP_DEV, CTB5_DEP_TEST
+from hanlp.pretrained.word2vec import CTB5_FASTTEXT_300_CN
+from tests import cdroot
+
+cdroot()
+save_dir = 'data/model/dep/biaffine_ctb'
+parser = BiaffineDependencyParserTF()
+# parser.fit(CTB5_DEP_TRAIN, CTB5_DEP_DEV, save_dir,
+# pretrained_embed={'class_name': 'HanLP>Word2VecEmbedding',
+# 'config': {
+# 'trainable': False,
+# 'embeddings_initializer': 'zero',
+# 'filepath': CTB5_FASTTEXT_300_CN,
+# 'expand_vocab': True,
+# 'lowercase': True,
+# 'normalize': True,
+# }},
+# )
+parser.load(save_dir)
+sentence = [('中国', 'NR'), ('批准', 'VV'), ('设立', 'VV'), ('外商', 'NN'), ('投资', 'NN'), ('企业', 'NN'), ('逾', 'VV'),
+ ('三十万', 'CD'), ('家', 'M')]
+print(parser.predict(sentence))
+parser.evaluate(CTB5_DEP_TEST, save_dir)
diff --git a/tests/train/zh/train_ctb5_pos_rnn.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_ctb5_pos_rnn.py
similarity index 56%
rename from tests/train/zh/train_ctb5_pos_rnn.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/train_ctb5_pos_rnn.py
index 103c1b0c1..1e214a5ff 100644
--- a/tests/train/zh/train_ctb5_pos_rnn.py
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_ctb5_pos_rnn.py
@@ -1,16 +1,16 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-12-28 22:46
-from hanlp.components.pos import RNNPartOfSpeechTagger
-from hanlp.datasets.parsing.ctb import CIP_W2V_100_CN
-from hanlp.datasets.pos.ctb import CTB5_POS_TRAIN, CTB5_POS_VALID, CTB5_POS_TEST
+from hanlp.components.pos_tf import RNNPartOfSpeechTaggerTF
+from hanlp.datasets.parsing.ctb5 import CIP_W2V_100_CN
+from hanlp.datasets.pos.ctb5 import CTB5_POS_TRAIN, CTB5_POS_DEV, CTB5_POS_TEST
from hanlp.pretrained.fasttext import FASTTEXT_CC_300_EN, FASTTEXT_WIKI_300_ZH
from tests import cdroot
cdroot()
-tagger = RNNPartOfSpeechTagger()
+tagger = RNNPartOfSpeechTaggerTF()
save_dir = 'data/model/pos/ctb5_pos_rnn_fasttext'
-tagger.fit(CTB5_POS_TRAIN, CTB5_POS_VALID, save_dir, embeddings={'class_name': 'HanLP>FastTextEmbedding',
+tagger.fit(CTB5_POS_TRAIN, CTB5_POS_DEV, save_dir, embeddings={'class_name': 'HanLP>FastTextEmbedding',
'config': {'filepath': FASTTEXT_WIKI_300_ZH}}, )
tagger.evaluate(CTB5_POS_TEST, save_dir=save_dir)
print(f'Model saved in {save_dir}')
diff --git a/plugins/hanlp_demo/hanlp_demo/zh/tf/train_ctb7_dep.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_ctb7_dep.py
new file mode 100644
index 000000000..f2d99238d
--- /dev/null
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_ctb7_dep.py
@@ -0,0 +1,26 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-12-28 18:33
+from hanlp.components.parsers.biaffine_parser_tf import BiaffineDependencyParserTF
+from hanlp.datasets.parsing.ctb5 import CTB7_DEP_TRAIN, CTB7_DEP_DEV, CTB7_DEP_TEST, CIP_W2V_100_CN
+from tests import cdroot
+
+cdroot()
+save_dir = 'data/model/dep/biaffine_ctb7'
+parser = BiaffineDependencyParserTF()
+# parser.fit(CTB7_DEP_TRAIN, CTB7_DEP_DEV, save_dir,
+# pretrained_embed={'class_name': 'HanLP>Word2VecEmbedding',
+# 'config': {
+# 'trainable': False,
+# 'embeddings_initializer': 'zero',
+# 'filepath': CIP_W2V_100_CN,
+# 'expand_vocab': True,
+# 'lowercase': True,
+# 'normalize': True,
+# }},
+# )
+parser.load(save_dir)
+sentence = [('中国', 'NR'), ('批准', 'VV'), ('设立', 'VV'), ('外商', 'NN'), ('投资', 'NN'), ('企业', 'NN'), ('逾', 'VV'),
+ ('三十万', 'CD'), ('家', 'M')]
+print(parser.predict(sentence))
+parser.evaluate(CTB7_DEP_TEST, save_dir)
diff --git a/tests/train/zh/train_ctb9_pos_albert.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_ctb9_pos_albert.py
similarity index 100%
rename from tests/train/zh/train_ctb9_pos_albert.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/train_ctb9_pos_albert.py
diff --git a/tests/train/zh/train_msra_ner_albert.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_msra_ner_albert.py
similarity index 53%
rename from tests/train/zh/train_msra_ner_albert.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/train_msra_ner_albert.py
index 3ea7c560e..82a17f725 100644
--- a/tests/train/zh/train_msra_ner_albert.py
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_msra_ner_albert.py
@@ -1,17 +1,17 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-12-28 23:15
-from hanlp.components.ner import TransformerNamedEntityRecognizer
-from hanlp.datasets.ner.msra import MSRA_NER_TRAIN, MSRA_NER_VALID, MSRA_NER_TEST
+from hanlp.components.ner_tf import TransformerNamedEntityRecognizerTF
+from hanlp.datasets.ner.msra import MSRA_NER_CHAR_LEVEL_TRAIN, MSRA_NER_CHAR_LEVEL_DEV, MSRA_NER_CHAR_LEVEL_TEST
from tests import cdroot
cdroot()
-recognizer = TransformerNamedEntityRecognizer()
+recognizer = TransformerNamedEntityRecognizerTF()
save_dir = 'data/model/ner/ner_albert_base_zh_msra_sparse_categorical_crossentropy'
-recognizer.fit(MSRA_NER_TRAIN, MSRA_NER_VALID, save_dir, transformer='albert_base_zh',
+recognizer.fit(MSRA_NER_CHAR_LEVEL_TRAIN, MSRA_NER_CHAR_LEVEL_DEV, save_dir, transformer='albert_base_zh',
learning_rate=5e-5,
metrics='f1')
recognizer.load(save_dir)
print(recognizer.predict(list('上海华安工业(集团)公司董事长谭旭光和秘书张晚霞来到美国纽约现代艺术博物馆参观。')))
-recognizer.evaluate(MSRA_NER_TEST, save_dir=save_dir)
+recognizer.evaluate(MSRA_NER_CHAR_LEVEL_TEST, save_dir=save_dir)
print(f'Model saved in {save_dir}')
diff --git a/tests/train/zh/train_msra_ner_bert.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_msra_ner_bert.py
similarity index 51%
rename from tests/train/zh/train_msra_ner_bert.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/train_msra_ner_bert.py
index 88e2b4f42..f6aed4ea1 100644
--- a/tests/train/zh/train_msra_ner_bert.py
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_msra_ner_bert.py
@@ -1,16 +1,16 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-12-28 23:15
-from hanlp.components.ner import TransformerNamedEntityRecognizer
-from hanlp.datasets.ner.msra import MSRA_NER_TRAIN, MSRA_NER_VALID, MSRA_NER_TEST
+from hanlp.components.ner_tf import TransformerNamedEntityRecognizerTF
+from hanlp.datasets.ner.msra import MSRA_NER_CHAR_LEVEL_TRAIN, MSRA_NER_CHAR_LEVEL_DEV, MSRA_NER_CHAR_LEVEL_TEST
from tests import cdroot
cdroot()
-recognizer = TransformerNamedEntityRecognizer()
+recognizer = TransformerNamedEntityRecognizerTF()
save_dir = 'data/model/ner/ner_bert_base_msra_2'
-recognizer.fit(MSRA_NER_TRAIN, MSRA_NER_VALID, save_dir, transformer='chinese_L-12_H-768_A-12',
+recognizer.fit(MSRA_NER_CHAR_LEVEL_TRAIN, MSRA_NER_CHAR_LEVEL_DEV, save_dir, transformer='chinese_L-12_H-768_A-12',
metrics='accuracy') # accuracy is faster
recognizer.load(save_dir, metrics='f1')
print(recognizer.predict(list('上海华安工业(集团)公司董事长谭旭光和秘书张晚霞来到美国纽约现代艺术博物馆参观。')))
-recognizer.evaluate(MSRA_NER_TEST, save_dir=save_dir)
+recognizer.evaluate(MSRA_NER_CHAR_LEVEL_TEST, save_dir=save_dir)
print(f'Model saved in {save_dir}')
diff --git a/tests/train/zh/train_msra_ner_ngram_conv.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_msra_ner_ngram_conv.py
similarity index 73%
rename from tests/train/zh/train_msra_ner_ngram_conv.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/train_msra_ner_ngram_conv.py
index 08bfd2267..66683e57f 100644
--- a/tests/train/zh/train_msra_ner_ngram_conv.py
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_msra_ner_ngram_conv.py
@@ -1,16 +1,16 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-12-28 23:15
-from hanlp.components.ner import RNNNamedEntityRecognizer, NgramConvNamedEntityRecognizer
-from hanlp.datasets.ner.msra import MSRA_NER_TRAIN, MSRA_NER_VALID, MSRA_NER_TEST
+from hanlp.components.ner_tf import RNNNamedEntityRecognizerTF, NgramConvNamedEntityRecognizerTF
+from hanlp.datasets.ner.msra import MSRA_NER_CHAR_LEVEL_TRAIN, MSRA_NER_CHAR_LEVEL_DEV, MSRA_NER_CHAR_LEVEL_TEST
from hanlp.pretrained.word2vec import RADICAL_CHAR_EMBEDDING_100, CONVSEG_W2V_NEWS_TENSITE_CHAR, \
CONVSEG_W2V_NEWS_TENSITE_WORD_MSR
from tests import cdroot
cdroot()
-recognizer = NgramConvNamedEntityRecognizer()
+recognizer = NgramConvNamedEntityRecognizerTF()
save_dir = 'data/model/ner/msra_ner_ngram_conv'
-recognizer.fit(MSRA_NER_TRAIN, MSRA_NER_VALID, save_dir,
+recognizer.fit(MSRA_NER_CHAR_LEVEL_TRAIN, MSRA_NER_CHAR_LEVEL_DEV, save_dir,
word_embed={'class_name': 'HanLP>Word2VecEmbedding',
'config': {
'trainable': True,
@@ -26,4 +26,4 @@
'lowercase': False,
}},
weight_norm=True)
-recognizer.evaluate(MSRA_NER_TEST, save_dir)
+recognizer.evaluate(MSRA_NER_CHAR_LEVEL_TEST, save_dir)
diff --git a/plugins/hanlp_demo/hanlp_demo/zh/tf/train_msra_ner_rnn.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_msra_ner_rnn.py
new file mode 100644
index 000000000..a8fd29225
--- /dev/null
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_msra_ner_rnn.py
@@ -0,0 +1,16 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-12-28 23:15
+from hanlp.components.ner_tf import RNNNamedEntityRecognizerTF
+from hanlp.datasets.ner.msra import MSRA_NER_CHAR_LEVEL_TRAIN, MSRA_NER_CHAR_LEVEL_DEV, MSRA_NER_CHAR_LEVEL_TEST
+from hanlp.pretrained.word2vec import RADICAL_CHAR_EMBEDDING_100
+from tests import cdroot
+
+cdroot()
+recognizer = RNNNamedEntityRecognizerTF()
+save_dir = 'data/model/ner/msra_ner_rnn'
+recognizer.fit(MSRA_NER_CHAR_LEVEL_TRAIN, MSRA_NER_CHAR_LEVEL_DEV, save_dir,
+ embeddings=RADICAL_CHAR_EMBEDDING_100,
+ embedding_trainable=True,
+ epochs=100)
+recognizer.evaluate(MSRA_NER_CHAR_LEVEL_TEST, save_dir)
diff --git a/plugins/hanlp_demo/hanlp_demo/zh/tf/train_ptb_dep_biaffine_albert.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_ptb_dep_biaffine_albert.py
new file mode 100644
index 000000000..5913233a6
--- /dev/null
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_ptb_dep_biaffine_albert.py
@@ -0,0 +1,29 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-03-07 23:48
+from hanlp.metrics.parsing import conllx_eval
+
+from hanlp.datasets.parsing.ptb import PTB_SD330_DEV, PTB_SD330_TRAIN, PTB_SD330_TEST, PTB_TOKEN_MAPPING
+from hanlp.components.parsers.biaffine_parser_tf import BiaffineTransformerDependencyParserTF
+from tests import cdroot
+
+cdroot()
+save_dir = 'data/model/dep/ptb_albert3'
+parser = BiaffineTransformerDependencyParserTF()
+parser.fit(PTB_SD330_TRAIN, PTB_SD330_DEV, save_dir,
+ 'albert-xxlarge-v2',
+ batch_size=256,
+ warmup_steps_ratio=.1,
+ token_mapping=PTB_TOKEN_MAPPING,
+ samples_per_batch=150,
+ transformer_dropout=.33,
+ learning_rate=2e-3,
+ learning_rate_transformer=1e-5,
+ # early_stopping_patience=10,
+ )
+parser.load(save_dir)
+# output = f'{save_dir}/test.predict.conll'
+parser.evaluate(PTB_SD330_TEST, save_dir, warm_up=False)
+# uas, las = conllx_eval.evaluate(PTB_SD330_TEST, output)
+# print(f'Official UAS: {uas:.4f} LAS: {las:.4f}')
+print(f'Model saved in {save_dir}')
diff --git a/plugins/hanlp_demo/hanlp_demo/zh/tf/train_ptb_dep_biaffine_bert.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_ptb_dep_biaffine_bert.py
new file mode 100644
index 000000000..5adb43062
--- /dev/null
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_ptb_dep_biaffine_bert.py
@@ -0,0 +1,28 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-03-07 23:48
+from hanlp.metrics.parsing import conllx_eval
+
+from hanlp.datasets.parsing.ptb import PTB_SD330_DEV, PTB_SD330_TRAIN, PTB_SD330_TEST, PTB_TOKEN_MAPPING
+from hanlp.components.parsers.biaffine_parser_tf import BiaffineTransformerDependencyParserTF
+from tests import cdroot
+
+cdroot()
+save_dir = 'data/model/dep/ptb_bert_1e-5'
+parser = BiaffineTransformerDependencyParserTF()
+# parser.fit(PTB_SD330_TRAIN, PTB_SD330_DEV, save_dir, 'bert-base-uncased',
+# batch_size=3000,
+# warmup_steps_ratio=.1,
+# token_mapping=PTB_TOKEN_MAPPING,
+# samples_per_batch=150,
+# transformer_dropout=.33,
+# learning_rate=2e-3,
+# learning_rate_transformer=1e-5,
+# # early_stopping_patience=10,
+# )
+parser.load(save_dir, tree='tarjan')
+# output = f'{save_dir}/test.predict.conll'
+parser.evaluate(PTB_SD330_TEST, save_dir, warm_up=False)
+# uas, las = conllx_eval.evaluate(PTB_SD330_TEST, output)
+# print(f'Official UAS: {uas:.4f} LAS: {las:.4f}')
+print(f'Model saved in {save_dir}')
diff --git a/plugins/hanlp_demo/hanlp_demo/zh/tf/train_ptb_dep_biaffine_bert_96.6.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_ptb_dep_biaffine_bert_96.6.py
new file mode 100644
index 000000000..a702ea043
--- /dev/null
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_ptb_dep_biaffine_bert_96.6.py
@@ -0,0 +1,24 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-03-07 23:48
+
+from hanlp.datasets.parsing.ptb import PTB_SD330_DEV, PTB_SD330_TRAIN, PTB_SD330_TEST, PTB_TOKEN_MAPPING
+from hanlp.components.parsers.biaffine_parser_tf import BiaffineTransformerDependencyParserTF
+from tests import cdroot
+from hanlp.metrics.parsing import conllx_eval
+
+cdroot()
+save_dir = 'data/model/dep/ptb_bert_96.61'
+parser = BiaffineTransformerDependencyParserTF()
+# parser.fit(PTB_SD330_TRAIN, PTB_SD330_DEV, save_dir, 'bert-base-uncased',
+# batch_size=3000,
+# warmup_steps_ratio=.1,
+# token_mapping=PTB_TOKEN_MAPPING,
+# samples_per_batch=150,
+# )
+parser.load(save_dir)
+output = f'{save_dir}/test.predict.conll'
+parser.evaluate(PTB_SD330_TEST, save_dir, warm_up=False, output=output)
+uas, las = conllx_eval.evaluate(PTB_SD330_TEST, output)
+print(f'Official UAS: {uas:.4f} LAS: {las:.4f}')
+print(f'Model saved in {save_dir}')
diff --git a/plugins/hanlp_demo/hanlp_demo/zh/tf/train_ptb_dep_biaffine_bert_positional.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_ptb_dep_biaffine_bert_positional.py
new file mode 100644
index 000000000..0ab597fc9
--- /dev/null
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_ptb_dep_biaffine_bert_positional.py
@@ -0,0 +1,29 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-03-07 23:48
+from hanlp.metrics.parsing import conllx_eval
+
+from hanlp.datasets.parsing.ptb import PTB_SD330_DEV, PTB_SD330_TRAIN, PTB_SD330_TEST, PTB_TOKEN_MAPPING
+from hanlp.components.parsers.biaffine_parser_tf import BiaffineTransformerDependencyParserTF
+from tests import cdroot
+
+cdroot()
+save_dir = 'data/model/dep/ptb_bert_positional_diff_lr'
+parser = BiaffineTransformerDependencyParserTF()
+parser.fit(PTB_SD330_TRAIN, PTB_SD330_DEV, save_dir, 'bert-base-uncased',
+ batch_size=3000,
+ warmup_steps_ratio=.1,
+ token_mapping=PTB_TOKEN_MAPPING,
+ samples_per_batch=150,
+ transformer_dropout=.33,
+ learning_rate=1e-4,
+ learning_rate_transformer=1e-5,
+ d_positional=128,
+ # early_stopping_patience=10,
+ )
+# parser.load(save_dir)
+# output = f'{save_dir}/test.predict.conll'
+parser.evaluate(PTB_SD330_TEST, save_dir, warm_up=False)
+# uas, las = conllx_eval.evaluate(PTB_SD330_TEST, output)
+# print(f'Official UAS: {uas:.4f} LAS: {las:.4f}')
+# print(f'Model saved in {save_dir}')
diff --git a/plugins/hanlp_demo/hanlp_demo/zh/tf/train_ptb_dep_sa_albert.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_ptb_dep_sa_albert.py
new file mode 100644
index 000000000..532c5afd1
--- /dev/null
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_ptb_dep_sa_albert.py
@@ -0,0 +1,42 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-03-07 23:48
+from hanlp.metrics.parsing import conllx_eval
+
+from hanlp.datasets.parsing.ptb import PTB_SD330_DEV, PTB_SD330_TRAIN, PTB_SD330_TEST, PTB_TOKEN_MAPPING
+from hanlp.components.parsers.biaffine_parser_tf import BiaffineTransformerDependencyParserTF, \
+ StructuralAttentionDependencyParserTF
+from hanlp.pretrained.glove import GLOVE_840B_300D
+from tests import cdroot
+
+cdroot()
+save_dir = 'data/model/dep/ptb_sa_glove'
+parser = StructuralAttentionDependencyParserTF()
+# parser.fit(PTB_SD330_TRAIN, PTB_SD330_DEV, save_dir, 'bert-base-uncased',
+# batch_size=3000,
+# warmup_steps_ratio=.1,
+# token_mapping=PTB_TOKEN_MAPPING,
+# samples_per_batch=150,
+# transformer_dropout=.33,
+# masked_lm_dropout=.33,
+# # learning_rate=2e-3,
+# # learning_rate_transformer=1e-5,
+# masked_lm_embed={'class_name': 'HanLP>Word2VecEmbedding',
+# 'config': {
+# 'trainable': False,
+# # 'embeddings_initializer': 'zero',
+# 'filepath': GLOVE_840B_300D,
+# 'expand_vocab': False,
+# 'lowercase': True,
+# 'cpu': False
+# }}
+# # alpha=1,
+# # early_stopping_patience=10,
+# # num_decoder_layers=2,
+# )
+parser.load(save_dir)
+# output = f'{save_dir}/test.predict.conll'
+parser.evaluate(PTB_SD330_TEST, save_dir, warm_up=False)
+# uas, las = conllx_eval.evaluate(PTB_SD330_TEST, output)
+# print(f'Official UAS: {uas:.4f} LAS: {las:.4f}')
+print(f'Model saved in {save_dir}')
diff --git a/plugins/hanlp_demo/hanlp_demo/zh/tf/train_ptb_dep_sa_albert_topk.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_ptb_dep_sa_albert_topk.py
new file mode 100644
index 000000000..54b425d88
--- /dev/null
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_ptb_dep_sa_albert_topk.py
@@ -0,0 +1,34 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-03-07 23:48
+from hanlp.metrics.parsing import conllx_eval
+
+from hanlp.datasets.parsing.ptb import PTB_SD330_DEV, PTB_SD330_TRAIN, PTB_SD330_TEST, PTB_TOKEN_MAPPING
+from hanlp.components.parsers.biaffine_parser_tf import BiaffineTransformerDependencyParserTF, \
+ StructuralAttentionDependencyParserTF
+from hanlp.pretrained.glove import GLOVE_840B_300D
+from tests import cdroot
+
+cdroot()
+save_dir = 'data/model/dep/ptb_sa_topk'
+parser = StructuralAttentionDependencyParserTF()
+parser.fit(PTB_SD330_TRAIN, PTB_SD330_DEV, save_dir, 'bert-base-uncased',
+ batch_size=3000,
+ warmup_steps_ratio=.1,
+ token_mapping=PTB_TOKEN_MAPPING,
+ samples_per_batch=150,
+ transformer_dropout=.33,
+ masked_lm_dropout=.33,
+ learning_rate=2e-3,
+ learning_rate_transformer=1e-5,
+
+ # alpha=1,
+ # early_stopping_patience=10,
+ # num_decoder_layers=2,
+ )
+parser.load(save_dir)
+# output = f'{save_dir}/test.predict.conll'
+parser.evaluate(PTB_SD330_TEST, save_dir, warm_up=False)
+# uas, las = conllx_eval.evaluate(PTB_SD330_TEST, output)
+# print(f'Official UAS: {uas:.4f} LAS: {las:.4f}')
+print(f'Model saved in {save_dir}')
diff --git a/plugins/hanlp_demo/hanlp_demo/zh/tf/train_ptb_dep_sa_bert.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_ptb_dep_sa_bert.py
new file mode 100644
index 000000000..532c5afd1
--- /dev/null
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_ptb_dep_sa_bert.py
@@ -0,0 +1,42 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-03-07 23:48
+from hanlp.metrics.parsing import conllx_eval
+
+from hanlp.datasets.parsing.ptb import PTB_SD330_DEV, PTB_SD330_TRAIN, PTB_SD330_TEST, PTB_TOKEN_MAPPING
+from hanlp.components.parsers.biaffine_parser_tf import BiaffineTransformerDependencyParserTF, \
+ StructuralAttentionDependencyParserTF
+from hanlp.pretrained.glove import GLOVE_840B_300D
+from tests import cdroot
+
+cdroot()
+save_dir = 'data/model/dep/ptb_sa_glove'
+parser = StructuralAttentionDependencyParserTF()
+# parser.fit(PTB_SD330_TRAIN, PTB_SD330_DEV, save_dir, 'bert-base-uncased',
+# batch_size=3000,
+# warmup_steps_ratio=.1,
+# token_mapping=PTB_TOKEN_MAPPING,
+# samples_per_batch=150,
+# transformer_dropout=.33,
+# masked_lm_dropout=.33,
+# # learning_rate=2e-3,
+# # learning_rate_transformer=1e-5,
+# masked_lm_embed={'class_name': 'HanLP>Word2VecEmbedding',
+# 'config': {
+# 'trainable': False,
+# # 'embeddings_initializer': 'zero',
+# 'filepath': GLOVE_840B_300D,
+# 'expand_vocab': False,
+# 'lowercase': True,
+# 'cpu': False
+# }}
+# # alpha=1,
+# # early_stopping_patience=10,
+# # num_decoder_layers=2,
+# )
+parser.load(save_dir)
+# output = f'{save_dir}/test.predict.conll'
+parser.evaluate(PTB_SD330_TEST, save_dir, warm_up=False)
+# uas, las = conllx_eval.evaluate(PTB_SD330_TEST, output)
+# print(f'Official UAS: {uas:.4f} LAS: {las:.4f}')
+print(f'Model saved in {save_dir}')
diff --git a/plugins/hanlp_demo/hanlp_demo/zh/tf/train_ptb_dep_sa_pos_bert.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_ptb_dep_sa_pos_bert.py
new file mode 100644
index 000000000..4b00a2198
--- /dev/null
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_ptb_dep_sa_pos_bert.py
@@ -0,0 +1,34 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-03-07 23:48
+from hanlp.metrics.parsing import conllx_eval
+
+from hanlp.datasets.parsing.ptb import PTB_SD330_DEV, PTB_SD330_TRAIN, PTB_SD330_TEST, PTB_TOKEN_MAPPING
+from hanlp.components.parsers.biaffine_parser_tf import BiaffineTransformerDependencyParserTF, \
+ StructuralAttentionDependencyParserTF
+from hanlp.pretrained.glove import GLOVE_840B_300D
+from tests import cdroot
+
+cdroot()
+save_dir = 'data/model/dep/ptb_sa_bert_joint_pos'
+parser = StructuralAttentionDependencyParserTF()
+parser.fit('data/ptb-dep/train.conllx', 'data/ptb-dep/dev.conllx', save_dir, 'bert-base-uncased',
+ batch_size=256,
+ warmup_steps_ratio=.1,
+ token_mapping=PTB_TOKEN_MAPPING,
+ samples_per_batch=150,
+ transformer_dropout=.33,
+ masked_lm_dropout=.33,
+ learning_rate=2e-3,
+ learning_rate_transformer=1e-5,
+ joint_pos=True
+ # alpha=1,
+ # early_stopping_patience=10,
+ # num_decoder_layers=2,
+ )
+# parser.load(save_dir)
+# output = f'{save_dir}/test.predict.conll'
+parser.evaluate('data/ptb-dep/test.conllx', save_dir, warm_up=False)
+# uas, las = conllx_eval.evaluate(PTB_SD330_TEST, output)
+# print(f'Official UAS: {uas:.4f} LAS: {las:.4f}')
+print(f'Model saved in {save_dir}')
diff --git a/tests/train/en/train_ptb_pos_rnn_fasttext.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_ptb_pos_rnn_fasttext.py
similarity index 91%
rename from tests/train/en/train_ptb_pos_rnn_fasttext.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/train_ptb_pos_rnn_fasttext.py
index 4916d6344..02df1c0c5 100644
--- a/tests/train/en/train_ptb_pos_rnn_fasttext.py
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_ptb_pos_rnn_fasttext.py
@@ -4,12 +4,12 @@
import tensorflow as tf
-from hanlp.components.pos import RNNPartOfSpeechTagger
+from hanlp.components.pos_tf import RNNPartOfSpeechTaggerTF
from hanlp.pretrained.fasttext import FASTTEXT_CC_300_EN
from tests import cdroot
cdroot()
-tagger = RNNPartOfSpeechTagger()
+tagger = RNNPartOfSpeechTaggerTF()
save_dir = 'data/model/pos/ptb_pos_rnn_fasttext'
optimizer = tf.keras.optimizers.SGD(lr=0.015)
# optimizer = 'adam'
diff --git a/tests/train/en/train_semeval15_dm.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_semeval15_dm.py
similarity index 89%
rename from tests/train/en/train_semeval15_dm.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/train_semeval15_dm.py
index 8a930c5f9..4b710d4a0 100644
--- a/tests/train/en/train_semeval15_dm.py
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_semeval15_dm.py
@@ -1,13 +1,13 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2020-01-01 18:26
-from hanlp.components.parsers.biaffine_parser import BiaffineSemanticDependencyParser
+from hanlp.components.parsers.biaffine_parser_tf import BiaffineSemanticDependencyParserTF
from hanlp.pretrained.glove import GLOVE_6B_100D
from tests import cdroot
cdroot()
save_dir = 'data/model/sdp/semeval15_biaffine_dm'
-parser = BiaffineSemanticDependencyParser()
+parser = BiaffineSemanticDependencyParserTF()
parser.fit('data/semeval15/en.dm.train.conll', 'data/semeval15/en.dm.dev.conll', save_dir,
pretrained_embed={'class_name': 'HanLP>Word2VecEmbedding',
'config': {
diff --git a/tests/train/en/train_semeval15_pas.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_semeval15_pas.py
similarity index 89%
rename from tests/train/en/train_semeval15_pas.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/train_semeval15_pas.py
index 3a5ad357b..63b62147d 100644
--- a/tests/train/en/train_semeval15_pas.py
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_semeval15_pas.py
@@ -1,13 +1,13 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2020-01-01 18:26
-from hanlp.components.parsers.biaffine_parser import BiaffineSemanticDependencyParser
+from hanlp.components.parsers.biaffine_parser_tf import BiaffineSemanticDependencyParserTF
from hanlp.pretrained.glove import GLOVE_6B_100D
from tests import cdroot
cdroot()
save_dir = 'data/model/sdp/semeval15_biaffine_pas'
-parser = BiaffineSemanticDependencyParser()
+parser = BiaffineSemanticDependencyParserTF()
parser.fit('data/semeval15/en.pas.train.conll', 'data/semeval15/en.pas.dev.conll', save_dir,
pretrained_embed={'class_name': 'HanLP>Word2VecEmbedding',
'config': {
diff --git a/tests/train/en/train_semeval15_psd.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_semeval15_psd.py
similarity index 89%
rename from tests/train/en/train_semeval15_psd.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/train_semeval15_psd.py
index 9a3375054..b48dbde9a 100644
--- a/tests/train/en/train_semeval15_psd.py
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_semeval15_psd.py
@@ -1,13 +1,13 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2020-01-01 18:26
-from hanlp.components.parsers.biaffine_parser import BiaffineSemanticDependencyParser
+from hanlp.components.parsers.biaffine_parser_tf import BiaffineSemanticDependencyParserTF
from hanlp.pretrained.glove import GLOVE_6B_100D
from tests import cdroot
cdroot()
save_dir = 'data/model/sdp/semeval15_biaffine_psd'
-parser = BiaffineSemanticDependencyParser()
+parser = BiaffineSemanticDependencyParserTF()
parser.fit('data/semeval15/en.psd.train.conll', 'data/semeval15/en.psd.dev.conll', save_dir,
pretrained_embed={'class_name': 'HanLP>Word2VecEmbedding',
'config': {
diff --git a/tests/train/zh/train_semeval16_news.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_semeval16_news.py
similarity index 76%
rename from tests/train/zh/train_semeval16_news.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/train_semeval16_news.py
index 142163737..79dcd11cd 100644
--- a/tests/train/zh/train_semeval16_news.py
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_semeval16_news.py
@@ -1,18 +1,18 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-12-26 23:20
-from hanlp.datasets.parsing.semeval2016 import SEMEVAL2016_NEWS_TRAIN, SEMEVAL2016_NEWS_VALID, SEMEVAL2016_NEWS_TEST
+from hanlp.datasets.parsing.semeval16 import SEMEVAL2016_NEWS_TRAIN, SEMEVAL2016_NEWS_DEV, SEMEVAL2016_NEWS_TEST
from hanlp.pretrained.word2vec import SEMEVAL16_EMBEDDINGS_300_NEWS_CN
from hanlp.utils.tf_util import nice
nice()
-from hanlp.components.parsers.biaffine_parser import BiaffineSemanticDependencyParser
+from hanlp.components.parsers.biaffine_parser_tf import BiaffineSemanticDependencyParserTF
from tests import cdroot
cdroot()
save_dir = 'data/model/sdp/semeval16-news'
-parser = BiaffineSemanticDependencyParser()
-parser.fit(SEMEVAL2016_NEWS_TRAIN, SEMEVAL2016_NEWS_VALID, save_dir,
+parser = BiaffineSemanticDependencyParserTF()
+parser.fit(SEMEVAL2016_NEWS_TRAIN, SEMEVAL2016_NEWS_DEV, save_dir,
pretrained_embed={'class_name': 'HanLP>Word2VecEmbedding',
'config': {
'trainable': False,
diff --git a/tests/train/zh/train_semeval16_text.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_semeval16_text.py
similarity index 76%
rename from tests/train/zh/train_semeval16_text.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/train_semeval16_text.py
index f38bc2532..2e7aed631 100644
--- a/tests/train/zh/train_semeval16_text.py
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_semeval16_text.py
@@ -1,18 +1,18 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-12-26 23:20
-from hanlp.datasets.parsing.semeval2016 import SEMEVAL2016_TEXT_TRAIN, SEMEVAL2016_TEXT_VALID, SEMEVAL2016_TEXT_TEST
+from hanlp.datasets.parsing.semeval16 import SEMEVAL2016_TEXT_TRAIN, SEMEVAL2016_TEXT_DEV, SEMEVAL2016_TEXT_TEST
from hanlp.pretrained.word2vec import SEMEVAL16_EMBEDDINGS_300_TEXT_CN
from hanlp.utils.tf_util import nice
nice()
-from hanlp.components.parsers.biaffine_parser import BiaffineSemanticDependencyParser
+from hanlp.components.parsers.biaffine_parser_tf import BiaffineSemanticDependencyParserTF
from tests import cdroot
cdroot()
save_dir = 'data/model/sdp/semeval16-text'
-parser = BiaffineSemanticDependencyParser()
-parser.fit(SEMEVAL2016_TEXT_TRAIN, SEMEVAL2016_TEXT_VALID, save_dir,
+parser = BiaffineSemanticDependencyParserTF()
+parser.fit(SEMEVAL2016_TEXT_TRAIN, SEMEVAL2016_TEXT_DEV, save_dir,
pretrained_embed={'class_name': 'HanLP>Word2VecEmbedding',
'config': {
'trainable': False,
diff --git a/tests/train/en/train_sst2_albert_base.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_sst2_albert_base.py
similarity index 66%
rename from tests/train/en/train_sst2_albert_base.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/train_sst2_albert_base.py
index b09885564..9c8ab5d4a 100644
--- a/tests/train/en/train_sst2_albert_base.py
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_sst2_albert_base.py
@@ -3,17 +3,17 @@
# Date: 2019-11-10 17:41
import os
-from hanlp.components.classifiers.transformer_classifier import TransformerClassifier
+from hanlp.components.classifiers.transformer_classifier_tf import TransformerClassifierTF
from tests import cdroot
-from hanlp.datasets.glue import STANFORD_SENTIMENT_TREEBANK_2_VALID, STANFORD_SENTIMENT_TREEBANK_2_TRAIN, \
+from hanlp.datasets.glue import STANFORD_SENTIMENT_TREEBANK_2_DEV, STANFORD_SENTIMENT_TREEBANK_2_TRAIN, \
STANFORD_SENTIMENT_TREEBANK_2_TEST
cdroot()
save_dir = os.path.join('data', 'model', 'sst', 'sst2_albert_base')
-classifier = TransformerClassifier()
-classifier.fit(STANFORD_SENTIMENT_TREEBANK_2_TRAIN, STANFORD_SENTIMENT_TREEBANK_2_VALID, save_dir,
+classifier = TransformerClassifierTF()
+classifier.fit(STANFORD_SENTIMENT_TREEBANK_2_TRAIN, STANFORD_SENTIMENT_TREEBANK_2_DEV, save_dir,
transformer='albert_base')
classifier.load(save_dir)
print(classifier('it\' s a charming and often affecting journey'))
diff --git a/tests/train/en/train_sst2_bert_base.py b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_sst2_bert_base_tf.py
similarity index 72%
rename from tests/train/en/train_sst2_bert_base.py
rename to plugins/hanlp_demo/hanlp_demo/zh/tf/train_sst2_bert_base_tf.py
index 9df1dc324..ab750af3c 100644
--- a/tests/train/en/train_sst2_bert_base.py
+++ b/plugins/hanlp_demo/hanlp_demo/zh/tf/train_sst2_bert_base_tf.py
@@ -1,13 +1,13 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2020-01-03 23:00
-from hanlp.components.classifiers.transformer_classifier import TransformerClassifier
+from hanlp.components.classifiers.transformer_classifier_tf import TransformerClassifierTF
from hanlp.datasets.glue import STANFORD_SENTIMENT_TREEBANK_2_TRAIN, STANFORD_SENTIMENT_TREEBANK_2_TEST, \
- STANFORD_SENTIMENT_TREEBANK_2_VALID
+ STANFORD_SENTIMENT_TREEBANK_2_DEV
save_dir = 'data/model/classification/sst2_bert_base_uncased_en'
-classifier = TransformerClassifier()
-classifier.fit(STANFORD_SENTIMENT_TREEBANK_2_TRAIN, STANFORD_SENTIMENT_TREEBANK_2_VALID, save_dir,
+classifier = TransformerClassifierTF()
+classifier.fit(STANFORD_SENTIMENT_TREEBANK_2_TRAIN, STANFORD_SENTIMENT_TREEBANK_2_DEV, save_dir,
transformer='uncased_L-12_H-768_A-12')
classifier.load(save_dir)
print(classifier.predict('it\' s a charming and often affecting journey'))
diff --git a/plugins/hanlp_demo/hanlp_demo/zh/train/__init__.py b/plugins/hanlp_demo/hanlp_demo/zh/train/__init__.py
new file mode 100644
index 000000000..b9ad8c2b1
--- /dev/null
+++ b/plugins/hanlp_demo/hanlp_demo/zh/train/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-31 20:12
diff --git a/plugins/hanlp_demo/hanlp_demo/zh/train/open_base.py b/plugins/hanlp_demo/hanlp_demo/zh/train/open_base.py
new file mode 100644
index 000000000..91a8a5601
--- /dev/null
+++ b/plugins/hanlp_demo/hanlp_demo/zh/train/open_base.py
@@ -0,0 +1,129 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-03 14:24
+
+from hanlp.common.dataset import SortingSamplerBuilder
+from hanlp.common.transform import NormalizeCharacter
+from hanlp.components.mtl.multi_task_learning import MultiTaskLearning
+from hanlp.components.mtl.tasks.constituency import CRFConstituencyParsing
+from hanlp.components.mtl.tasks.dep import BiaffineDependencyParsing
+from hanlp.components.mtl.tasks.ner.tag_ner import TaggingNamedEntityRecognition
+from hanlp.components.mtl.tasks.pos import TransformerTagging
+from hanlp.components.mtl.tasks.sdp import BiaffineSemanticDependencyParsing
+from hanlp.components.mtl.tasks.srl.bio_srl import SpanBIOSemanticRoleLabeling
+from hanlp.components.mtl.tasks.tok.tag_tok import TaggingTokenization
+from hanlp.datasets.ner.msra import MSRA_NER_TOKEN_LEVEL_SHORT_IOBES_TRAIN, MSRA_NER_TOKEN_LEVEL_SHORT_IOBES_DEV, \
+ MSRA_NER_TOKEN_LEVEL_SHORT_IOBES_TEST
+from hanlp.datasets.parsing.ctb8 import CTB8_POS_TRAIN, CTB8_POS_DEV, CTB8_POS_TEST, CTB8_SD330_TEST, CTB8_SD330_DEV, \
+ CTB8_SD330_TRAIN, CTB8_CWS_TRAIN, CTB8_CWS_DEV, CTB8_CWS_TEST, CTB8_BRACKET_LINE_NOEC_TRAIN, \
+ CTB8_BRACKET_LINE_NOEC_DEV, CTB8_BRACKET_LINE_NOEC_TEST
+from hanlp.datasets.parsing.semeval16 import SEMEVAL2016_TEXT_TRAIN_CONLLU, SEMEVAL2016_TEXT_TEST_CONLLU, \
+ SEMEVAL2016_TEXT_DEV_CONLLU
+from hanlp.datasets.srl.ontonotes5.chinese import ONTONOTES5_CONLL12_CHINESE_TEST, ONTONOTES5_CONLL12_CHINESE_DEV, \
+ ONTONOTES5_CONLL12_CHINESE_TRAIN
+from hanlp.layers.embeddings.contextual_word_embedding import ContextualWordEmbedding
+from hanlp.layers.transformers.relative_transformer import RelativeTransformerEncoder
+from hanlp.utils.lang.zh.char_table import HANLP_CHAR_TABLE_JSON
+from hanlp.utils.log_util import cprint
+from tests import cdroot
+
+cdroot()
+tasks = {
+ 'tok': TaggingTokenization(
+ CTB8_CWS_TRAIN,
+ CTB8_CWS_DEV,
+ CTB8_CWS_TEST,
+ SortingSamplerBuilder(batch_size=32),
+ max_seq_len=510,
+ hard_constraint=True,
+ char_level=True,
+ tagging_scheme='BMES',
+ lr=1e-3,
+ transform=NormalizeCharacter(HANLP_CHAR_TABLE_JSON, 'token'),
+ ),
+ 'pos': TransformerTagging(
+ CTB8_POS_TRAIN,
+ CTB8_POS_DEV,
+ CTB8_POS_TEST,
+ SortingSamplerBuilder(batch_size=32),
+ hard_constraint=True,
+ max_seq_len=510,
+ char_level=True,
+ dependencies='tok',
+ lr=1e-3,
+ ),
+ 'ner': TaggingNamedEntityRecognition(
+ MSRA_NER_TOKEN_LEVEL_SHORT_IOBES_TRAIN,
+ MSRA_NER_TOKEN_LEVEL_SHORT_IOBES_DEV,
+ MSRA_NER_TOKEN_LEVEL_SHORT_IOBES_TEST,
+ SortingSamplerBuilder(batch_size=32),
+ lr=1e-3,
+ secondary_encoder=RelativeTransformerEncoder(768, k_as_x=True),
+ dependencies='tok',
+ ),
+ 'srl': SpanBIOSemanticRoleLabeling(
+ ONTONOTES5_CONLL12_CHINESE_TRAIN,
+ ONTONOTES5_CONLL12_CHINESE_DEV,
+ ONTONOTES5_CONLL12_CHINESE_TEST,
+ SortingSamplerBuilder(batch_size=32, batch_max_tokens=2048),
+ lr=1e-3,
+ crf=True,
+ dependencies='tok',
+ ),
+ 'dep': BiaffineDependencyParsing(
+ CTB8_SD330_TRAIN,
+ CTB8_SD330_DEV,
+ CTB8_SD330_TEST,
+ SortingSamplerBuilder(batch_size=32),
+ lr=1e-3,
+ tree=True,
+ punct=True,
+ dependencies='tok',
+ ),
+ 'sdp': BiaffineSemanticDependencyParsing(
+ SEMEVAL2016_TEXT_TRAIN_CONLLU,
+ SEMEVAL2016_TEXT_DEV_CONLLU,
+ SEMEVAL2016_TEXT_TEST_CONLLU,
+ SortingSamplerBuilder(batch_size=32),
+ lr=1e-3,
+ apply_constraint=True,
+ punct=True,
+ dependencies='tok',
+ ),
+ 'con': CRFConstituencyParsing(
+ CTB8_BRACKET_LINE_NOEC_TRAIN,
+ CTB8_BRACKET_LINE_NOEC_DEV,
+ CTB8_BRACKET_LINE_NOEC_TEST,
+ SortingSamplerBuilder(batch_size=32),
+ lr=1e-3,
+ dependencies='tok',
+ )
+}
+mtl = MultiTaskLearning()
+save_dir = 'data/model/mtl/open_tok_pos_ner_srl_dep_sdp_con_electra_base'
+mtl.fit(
+ ContextualWordEmbedding('token',
+ "hfl/chinese-electra-180g-base-discriminator",
+ average_subwords=True,
+ max_sequence_length=512,
+ word_dropout=.1),
+ tasks,
+ save_dir,
+ 30,
+ lr=1e-3,
+ encoder_lr=5e-5,
+ grad_norm=1,
+ gradient_accumulation=2,
+ eval_trn=False,
+)
+cprint(f'Model saved in [cyan]{save_dir}[/cyan]')
+mtl.load(save_dir)
+for k, v in tasks.items():
+ v.trn = tasks[k].trn
+ v.dev = tasks[k].dev
+ v.tst = tasks[k].tst
+metric, *_ = mtl.evaluate(save_dir)
+for k, v in tasks.items():
+ print(metric[k], end=' ')
+print()
+print(mtl('华纳音乐旗下的新垣结衣在12月21日于日本武道馆举办歌手出道活动'))
diff --git a/plugins/hanlp_demo/hanlp_demo/zh/train/open_small.py b/plugins/hanlp_demo/hanlp_demo/zh/train/open_small.py
new file mode 100644
index 000000000..ee20deb99
--- /dev/null
+++ b/plugins/hanlp_demo/hanlp_demo/zh/train/open_small.py
@@ -0,0 +1,127 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-12-03 14:24
+
+from hanlp.common.dataset import SortingSamplerBuilder
+from hanlp.common.transform import NormalizeCharacter
+from hanlp.components.mtl.multi_task_learning import MultiTaskLearning
+from hanlp.components.mtl.tasks.constituency import CRFConstituencyParsing
+from hanlp.components.mtl.tasks.dep import BiaffineDependencyParsing
+from hanlp.components.mtl.tasks.ner.tag_ner import TaggingNamedEntityRecognition
+from hanlp.components.mtl.tasks.pos import TransformerTagging
+from hanlp.components.mtl.tasks.sdp import BiaffineSemanticDependencyParsing
+from hanlp.components.mtl.tasks.srl.bio_srl import SpanBIOSemanticRoleLabeling
+from hanlp.components.mtl.tasks.tok.tag_tok import TaggingTokenization
+from hanlp.datasets.ner.msra import MSRA_NER_TOKEN_LEVEL_SHORT_IOBES_TEST, MSRA_NER_TOKEN_LEVEL_SHORT_IOBES_DEV, \
+ MSRA_NER_TOKEN_LEVEL_SHORT_IOBES_TRAIN
+from hanlp.datasets.parsing.ctb8 import CTB8_POS_TRAIN, CTB8_POS_DEV, CTB8_POS_TEST, CTB8_SD330_TEST, CTB8_SD330_DEV, \
+ CTB8_SD330_TRAIN, CTB8_CWS_TRAIN, CTB8_CWS_DEV, CTB8_CWS_TEST, CTB8_BRACKET_LINE_NOEC_TEST, \
+ CTB8_BRACKET_LINE_NOEC_DEV, CTB8_BRACKET_LINE_NOEC_TRAIN
+from hanlp.datasets.parsing.semeval16 import SEMEVAL2016_TEXT_TRAIN_CONLLU, SEMEVAL2016_TEXT_TEST_CONLLU, \
+ SEMEVAL2016_TEXT_DEV_CONLLU
+from hanlp.datasets.srl.ontonotes5.chinese import ONTONOTES5_CONLL12_CHINESE_TEST, ONTONOTES5_CONLL12_CHINESE_DEV, \
+ ONTONOTES5_CONLL12_CHINESE_TRAIN
+from hanlp.layers.embeddings.contextual_word_embedding import ContextualWordEmbedding
+from hanlp.layers.transformers.relative_transformer import RelativeTransformerEncoder
+from hanlp.utils.lang.zh.char_table import HANLP_CHAR_TABLE_JSON
+from hanlp.utils.log_util import cprint
+from tests import cdroot
+
+cdroot()
+tasks = {
+ 'tok': TaggingTokenization(
+ CTB8_CWS_TRAIN,
+ CTB8_CWS_DEV,
+ CTB8_CWS_TEST,
+ SortingSamplerBuilder(batch_size=32),
+ max_seq_len=510,
+ hard_constraint=True,
+ char_level=True,
+ tagging_scheme='BMES',
+ lr=1e-3,
+ transform=NormalizeCharacter(HANLP_CHAR_TABLE_JSON, 'token'),
+ ),
+ 'pos': TransformerTagging(
+ CTB8_POS_TRAIN,
+ CTB8_POS_DEV,
+ CTB8_POS_TEST,
+ SortingSamplerBuilder(batch_size=32),
+ hard_constraint=True,
+ max_seq_len=510,
+ char_level=True,
+ dependencies='tok',
+ lr=1e-3,
+ ),
+ 'ner': TaggingNamedEntityRecognition(
+ MSRA_NER_TOKEN_LEVEL_SHORT_IOBES_TRAIN,
+ MSRA_NER_TOKEN_LEVEL_SHORT_IOBES_DEV,
+ MSRA_NER_TOKEN_LEVEL_SHORT_IOBES_TEST,
+ SortingSamplerBuilder(batch_size=32),
+ max_seq_len=510,
+ hard_constraint=True,
+ char_level=True,
+ lr=1e-3,
+ secondary_encoder=RelativeTransformerEncoder(256, k_as_x=True, feedforward_dim=128),
+ dependencies='tok',
+ ),
+ 'srl': SpanBIOSemanticRoleLabeling(
+ ONTONOTES5_CONLL12_CHINESE_TRAIN,
+ ONTONOTES5_CONLL12_CHINESE_DEV,
+ ONTONOTES5_CONLL12_CHINESE_TEST,
+ SortingSamplerBuilder(batch_size=32, batch_max_tokens=1280),
+ lr=1e-3,
+ crf=True,
+ dependencies='tok',
+ ),
+ 'dep': BiaffineDependencyParsing(
+ CTB8_SD330_TRAIN,
+ CTB8_SD330_DEV,
+ CTB8_SD330_TEST,
+ SortingSamplerBuilder(batch_size=32),
+ lr=1e-3,
+ tree=True,
+ proj=True,
+ punct=True,
+ dependencies='tok',
+ ),
+ 'sdp': BiaffineSemanticDependencyParsing(
+ SEMEVAL2016_TEXT_TRAIN_CONLLU,
+ SEMEVAL2016_TEXT_DEV_CONLLU,
+ SEMEVAL2016_TEXT_TEST_CONLLU,
+ SortingSamplerBuilder(batch_size=32),
+ lr=1e-3,
+ apply_constraint=True,
+ punct=True,
+ dependencies='tok',
+ ),
+ 'con': CRFConstituencyParsing(
+ CTB8_BRACKET_LINE_NOEC_TRAIN,
+ CTB8_BRACKET_LINE_NOEC_DEV,
+ CTB8_BRACKET_LINE_NOEC_TEST,
+ SortingSamplerBuilder(batch_size=32),
+ lr=1e-3,
+ dependencies='tok',
+ )
+}
+mtl = MultiTaskLearning()
+save_dir = 'data/model/mtl/open_tok_pos_ner_srl_dep_sdp_con_electra_small'
+cprint(f'Model will be saved in [cyan]{save_dir}[/cyan]')
+mtl.fit(
+ ContextualWordEmbedding('token',
+ "hfl/chinese-electra-180g-small-discriminator",
+ average_subwords=True,
+ max_sequence_length=512,
+ word_dropout=.1),
+ tasks,
+ save_dir,
+ 30,
+ lr=1e-3,
+ encoder_lr=5e-5,
+ grad_norm=1,
+ gradient_accumulation=1,
+ eval_trn=False,
+)
+cprint(f'Model saved in [cyan]{save_dir}[/cyan]')
+mtl.evaluate(save_dir)
+mtl.load(save_dir)
+mtl('华纳音乐旗下的新垣结衣在12月21日于日本武道馆举办歌手出道活动').pretty_print()
diff --git a/plugins/hanlp_demo/hanlp_demo/zh/train_sota_bert_pku.py b/plugins/hanlp_demo/hanlp_demo/zh/train_sota_bert_pku.py
new file mode 100644
index 000000000..e750fd535
--- /dev/null
+++ b/plugins/hanlp_demo/hanlp_demo/zh/train_sota_bert_pku.py
@@ -0,0 +1,29 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-08-11 02:47
+from hanlp.common.dataset import SortingSamplerBuilder
+from hanlp.components.tokenizers.transformer import TransformerTaggingTokenizer
+from hanlp.datasets.cws.sighan2005.pku import SIGHAN2005_PKU_TRAIN_ALL, SIGHAN2005_PKU_TEST
+from tests import cdroot
+
+cdroot()
+tokenizer = TransformerTaggingTokenizer()
+save_dir = 'data/model/cws/sighan2005_pku_bert_base_96.66'
+tokenizer.fit(
+ SIGHAN2005_PKU_TRAIN_ALL,
+ SIGHAN2005_PKU_TEST, # Conventionally, no devset is used. See Tian et al. (2020).
+ save_dir,
+ 'bert-base-chinese',
+ max_seq_len=300,
+ char_level=True,
+ hard_constraint=True,
+ sampler_builder=SortingSamplerBuilder(batch_size=32),
+ epochs=3,
+ adam_epsilon=1e-6,
+ warmup_steps=0.1,
+ weight_decay=0.01,
+ word_dropout=0.1,
+ seed=1609422632,
+)
+tokenizer.evaluate(SIGHAN2005_PKU_TEST, save_dir)
+print(f'Model saved in {save_dir}')
diff --git a/plugins/hanlp_demo/setup.py b/plugins/hanlp_demo/setup.py
new file mode 100644
index 000000000..6215ad131
--- /dev/null
+++ b/plugins/hanlp_demo/setup.py
@@ -0,0 +1,38 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-12-28 19:26
+from os.path import abspath, join, dirname
+from setuptools import find_packages, setup
+
+this_dir = abspath(dirname(__file__))
+with open(join(this_dir, 'README.md'), encoding='utf-8') as file:
+ long_description = file.read()
+
+setup(
+ name='hanlp_demo',
+ version='0.0.1',
+ description='HanLP: Han Language Processing',
+ long_description=long_description,
+ long_description_content_type="text/markdown",
+ url='https://github.com/hankcs/HanLP',
+ author='hankcs',
+ author_email='hankcshe@gmail.com',
+ license='Apache License 2.0',
+ classifiers=[
+ 'Intended Audience :: Science/Research',
+ 'Intended Audience :: Developers',
+ "Development Status :: 3 - Alpha",
+ 'Operating System :: OS Independent',
+ "License :: OSI Approved :: Apache Software License",
+ 'Programming Language :: Python :: 3 :: Only',
+ 'Topic :: Scientific/Engineering :: Artificial Intelligence',
+ "Topic :: Text Processing :: Linguistic"
+ ],
+ keywords='corpus,machine-learning,NLU,NLP',
+ packages=find_packages(exclude=['docs', 'tests*']),
+ include_package_data=True,
+ install_requires=[
+ 'hanlp_common'
+ ],
+ python_requires='>=3.6',
+)
diff --git a/plugins/hanlp_restful/README.md b/plugins/hanlp_restful/README.md
new file mode 100644
index 000000000..29cd30d2d
--- /dev/null
+++ b/plugins/hanlp_restful/README.md
@@ -0,0 +1,17 @@
+# RESTFul API Client for HanLP
+
+[中文](https://github.com/hankcs/HanLP/tree/doc-zh) | [1.x](https://github.com/hankcs/HanLP/tree/1.x) | [forum](https://bbs.hankcs.com/) | [docker](https://github.com/WalterInSH/hanlp-jupyter-docker)
+
+The multilingual NLP library for researchers and companies, built on PyTorch and TensorFlow 2.x, for advancing state-of-the-art deep learning techniques in both academia and industry. HanLP was designed from day one to be efficient, user friendly and extendable. It comes with pretrained models for various human languages including English, Chinese and many others. Currently, HanLP 2.0 is in alpha stage with more killer features on the roadmap. Discussions are welcomed on our [forum](https://bbs.hankcs.com/), while bug reports and feature requests are reserved for GitHub issues. For Java users, please checkout the [1.x](https://github.com/hankcs/HanLP/tree/1.x) branch.
+
+
+## Installation
+
+```bash
+pip install hanlp-restful
+```
+
+## License
+
+HanLP is licensed under **Apache License 2.0**. You can use HanLP in your commercial products for free. We would appreciate it if you add a link to HanLP on your website.
+
diff --git a/plugins/hanlp_restful/hanlp_restful/__init__.py b/plugins/hanlp_restful/hanlp_restful/__init__.py
new file mode 100644
index 000000000..1784ca354
--- /dev/null
+++ b/plugins/hanlp_restful/hanlp_restful/__init__.py
@@ -0,0 +1,137 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-11-29 17:48
+import json
+from typing import Union, List, Optional, Dict, Any
+from urllib.error import HTTPError
+from urllib.parse import urlencode
+from urllib.request import Request, urlopen
+from hanlp_common.document import Document
+
+try:
+ # noinspection PyUnresolvedReferences
+ import requests
+
+
+ def _post(url, form: Dict[str, Any], headers: Dict[str, Any], timeout=5) -> str:
+ response = requests.post(url, json=form, headers=headers, timeout=timeout)
+ if response.status_code != 200:
+ raise HTTPError(url, response.status_code, response.text, response.headers, None)
+ return response.text
+except ImportError:
+ def _post(url, form: Dict[str, Any], headers: Dict[str, Any], timeout=5) -> str:
+ request = Request(url, json.dumps(form).encode())
+ for k, v in headers.items():
+ request.add_header(k, v)
+ return urlopen(request, timeout=timeout).read().decode()
+
+
+class HanLPClient(object):
+
+ def __init__(self, url: str, auth: str = None, language=None, timeout=5) -> None:
+ """
+
+ Args:
+ url (str): An API endpoint to a service provider.
+ auth (str): An auth key licenced from a service provider.
+ language (str): The default language for each :func:`~hanlp_restful.HanLPClient.parse` call.
+ Contact the service provider for the list of languages supported.
+ Conventionally, ``zh`` is used for Chinese and ``mul`` for multilingual.
+ Leave ``None`` to use the default language on server.
+ timeout (int): Maximum waiting time in seconds for a request.
+ """
+ super().__init__()
+ self._language = language
+ self._timeout = timeout
+ self._url = url
+ if auth is None:
+ import os
+ auth = os.getenv('HANLP_AUTH', None)
+ self._auth = auth
+
+ def parse(self,
+ text: Union[str, List[str]] = None,
+ tokens: List[List[str]] = None,
+ tasks: Optional[Union[str, List[str]]] = None,
+ skip_tasks: Optional[Union[str, List[str]]] = None,
+ language: str = None,
+ ) -> Document:
+ """
+ Parse a piece of text.
+
+ Args:
+ text: A paragraph (str), or a list of sentences (List[str]).
+ tokens: A list of sentences where each sentence is a list of tokens.
+ tasks: The tasks to predict.
+ skip_tasks: The tasks to skip.
+ language: The language of input text or tokens. ``None`` to use the default language on server.
+
+ Returns:
+ A :class:`~hanlp_common.document.Document`.
+
+ """
+ assert text or tokens, 'At least one of text or tokens has to be specified.'
+ response = self._send_post_json(self._url + '/parse', {
+ 'text': text,
+ 'tokens': tokens,
+ 'tasks': tasks,
+ 'skip_tasks': skip_tasks,
+ 'language': language or self._language
+ })
+ return Document(response)
+
+ def __call__(self,
+ text: Union[str, List[str]] = None,
+ tokens: List[List[str]] = None,
+ tasks: Optional[Union[str, List[str]]] = None,
+ skip_tasks: Optional[Union[str, List[str]]] = None,
+ language: str = None,
+ ) -> Document:
+ """
+ Parse a piece of text. This is a shortcut for :func:`~hanlp_restful.HanLPClient.parse`.
+
+ Args:
+ text: A paragraph (str), or a list of sentences (List[str]).
+ tokens: A list of sentences where each sentence is a list of tokens.
+ tasks: The tasks to predict.
+ skip_tasks: The tasks to skip.
+ language: The language of input text or tokens. ``None`` to use the default language on server.
+
+ Returns:
+ A :class:`~hanlp_common.document.Document`.
+
+ """
+ return self.parse(text, tokens, tasks, skip_tasks)
+
+ def about(self) -> Dict[str, Any]:
+ """Get the information about server and your client.
+
+ Returns:
+ A dict containing your rate limit and server version etc.
+
+ """
+ info = self._send_get_json(self._url + '/about', {})
+ return Document(info)
+
+ def _send_post(self, url, form: Dict[str, Any]):
+ request = Request(url, json.dumps(form).encode())
+ self._add_headers(request)
+ return self._fire_request(request)
+
+ def _fire_request(self, request):
+ return urlopen(request, timeout=self._timeout).read().decode()
+
+ def _send_post_json(self, url, form: Dict[str, Any]):
+ return json.loads(_post(url, form, {'Authorization': f'Basic {self._auth}'}, self._timeout))
+
+ def _send_get(self, url, form: Dict[str, Any]):
+ request = Request(url + '?' + urlencode(form))
+ self._add_headers(request)
+ return self._fire_request(request)
+
+ def _add_headers(self, request):
+ if self._auth:
+ request.add_header('Authorization', f'Basic {self._auth}')
+
+ def _send_get_json(self, url, form: Dict[str, Any]):
+ return json.loads(self._send_get(url, form))
diff --git a/plugins/hanlp_restful/setup.py b/plugins/hanlp_restful/setup.py
new file mode 100644
index 000000000..2c7c329ff
--- /dev/null
+++ b/plugins/hanlp_restful/setup.py
@@ -0,0 +1,38 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-12-28 19:26
+from os.path import abspath, join, dirname
+from setuptools import find_packages, setup
+
+this_dir = abspath(dirname(__file__))
+with open(join(this_dir, 'README.md'), encoding='utf-8') as file:
+ long_description = file.read()
+
+setup(
+ name='hanlp_restful',
+ version='0.0.3',
+ description='HanLP: Han Language Processing',
+ long_description=long_description,
+ long_description_content_type="text/markdown",
+ url='https://github.com/hankcs/HanLP',
+ author='hankcs',
+ author_email='hankcshe@gmail.com',
+ license='Apache License 2.0',
+ classifiers=[
+ 'Intended Audience :: Science/Research',
+ 'Intended Audience :: Developers',
+ "Development Status :: 3 - Alpha",
+ 'Operating System :: OS Independent',
+ "License :: OSI Approved :: Apache Software License",
+ 'Programming Language :: Python :: 3 :: Only',
+ 'Topic :: Scientific/Engineering :: Artificial Intelligence',
+ "Topic :: Text Processing :: Linguistic"
+ ],
+ keywords='corpus,machine-learning,NLU,NLP',
+ packages=find_packages(exclude=['docs', 'tests*']),
+ include_package_data=True,
+ install_requires=[
+ 'hanlp_common'
+ ],
+ python_requires='>=3.6',
+)
diff --git a/plugins/hanlp_restful/tests/__init__.py b/plugins/hanlp_restful/tests/__init__.py
new file mode 100644
index 000000000..7cb9c3ba7
--- /dev/null
+++ b/plugins/hanlp_restful/tests/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-11-29 18:05
diff --git a/plugins/hanlp_restful/tests/test_client.py b/plugins/hanlp_restful/tests/test_client.py
new file mode 100644
index 000000000..14b1200e4
--- /dev/null
+++ b/plugins/hanlp_restful/tests/test_client.py
@@ -0,0 +1,35 @@
+import unittest
+
+from hanlp_restful import HanLPClient
+
+
+class TestClient(unittest.TestCase):
+
+ def setUp(self) -> None:
+ self.HanLP = HanLPClient('https://hanlp.hankcs.com/api', auth=None) # Fill in your auth
+
+ def test_raw_text(self):
+ text = '2021年HanLPv2.1为生产环境带来次世代最先进的多语种NLP技术。阿婆主来到北京立方庭参观自然语义科技公司。'
+ doc = self.HanLP.parse(text)
+
+ def test_sents(self):
+ text = ['2021年HanLPv2.1为生产环境带来次世代最先进的多语种NLP技术。',
+ '阿婆主来到北京立方庭参观自然语义科技公司。']
+ doc = self.HanLP(text)
+
+ def test_tokens(self):
+ tokens = [
+ ["2021年", "HanLPv2.1", "为", "生产", "环境", "带来", "次", "世代", "最", "先进", "的", "多语种", "NLP", "技术", "。"],
+ ["英", "首相", "与", "特朗普", "通", "电话", "讨论", "华为", "与", "苹果", "公司", "。"]
+ ]
+ doc = self.HanLP(tokens=tokens, tasks=['ner*', 'srl', 'dep'])
+
+ def test_sents_mul(self):
+ text = ['In 2021, HanLPv2.1 delivers state-of-the-art multilingual NLP techniques to production environment.',
+ '2021年、HanLPv2.1は次世代の最先端多言語NLP技術を本番環境に導入します。',
+ '2021年 HanLPv2.1为生产环境带来次世代最先进的多语种NLP技术。']
+ doc = self.HanLP.parse(text, language='mul')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/plugins/hanlp_restful_java/pom.xml b/plugins/hanlp_restful_java/pom.xml
new file mode 100644
index 000000000..2a3baf8b9
--- /dev/null
+++ b/plugins/hanlp_restful_java/pom.xml
@@ -0,0 +1,122 @@
+
+
+ 4.0.0
+
+ com.hankcs.hanlp.restful
+ hanlp-restful
+ 0.0.2
+
+ HanLP RESTful Client in Java
+ https://github.com/hankcs/HanLP
+
+ HanLP: Han Language Processing
+
+
+ hankcs
+ http://www.hankcs.com/
+
+
+
+ Apache License Version 2.0
+ https://www.apache.org/licenses/LICENSE-2.0.html
+
+
+ 2020
+
+
+ hankcs
+ cnhankmc@gmail.com
+ http://www.hankcs.com
+
+
+
+ scm:git@github.com:hankcs/HanLP.git
+ scm:git@github.com:hankcs/HanLP.git
+ git@github.com:hankcs/HanLP.git
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-compiler-plugin
+
+ 8
+ 8
+
+
+
+ maven-source-plugin
+ 2.4
+
+
+ attach-sources
+
+ jar
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-javadoc-plugin
+ 2.9.1
+
+
+ package
+
+ jar
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-gpg-plugin
+ 1.6
+
+
+ verify
+
+ sign
+
+
+
+
+ --pinentry-mode
+ loopback
+
+
+
+
+
+
+
+
+
+
+ com.fasterxml.jackson.core
+ jackson-databind
+ 2.12.0
+
+
+ org.junit.jupiter
+ junit-jupiter
+ RELEASE
+ test
+
+
+
+
+
+ maven-repo
+ https://oss.sonatype.org/content/repositories/snapshots/
+
+
+ maven-repo
+ https://oss.sonatype.org/service/local/staging/deploy/maven2/
+
+
+
\ No newline at end of file
diff --git a/plugins/hanlp_restful_java/src/main/java/com/hankcs/hanlp/restful/BaseInput.java b/plugins/hanlp_restful_java/src/main/java/com/hankcs/hanlp/restful/BaseInput.java
new file mode 100644
index 000000000..96f132d93
--- /dev/null
+++ b/plugins/hanlp_restful_java/src/main/java/com/hankcs/hanlp/restful/BaseInput.java
@@ -0,0 +1,28 @@
+/*
+ * Han He
+ * me@hankcs.com
+ * 2020-12-27 12:07 AM
+ *
+ *
+ * Copyright (c) 2020, Han He. All Rights Reserved, http://www.hankcs.com/
+ * See LICENSE file in the project root for full license information.
+ *
+ */
+package com.hankcs.hanlp.restful;
+
+/**
+ * @author hankcs
+ */
+public class BaseInput
+{
+ public String[] tasks;
+ public String[] skipTasks;
+ public String language;
+
+ public BaseInput(String[] tasks, String[] skipTasks, String language)
+ {
+ this.tasks = tasks;
+ this.skipTasks = skipTasks;
+ this.language = language;
+ }
+}
diff --git a/plugins/hanlp_restful_java/src/main/java/com/hankcs/hanlp/restful/DocumentInput.java b/plugins/hanlp_restful_java/src/main/java/com/hankcs/hanlp/restful/DocumentInput.java
new file mode 100644
index 000000000..1804bc335
--- /dev/null
+++ b/plugins/hanlp_restful_java/src/main/java/com/hankcs/hanlp/restful/DocumentInput.java
@@ -0,0 +1,25 @@
+/*
+ * Han He
+ * me@hankcs.com
+ * 2020-12-27 12:09 AM
+ *
+ *
+ * Copyright (c) 2020, Han He. All Rights Reserved, http://www.hankcs.com/
+ * See LICENSE file in the project root for full license information.
+ *
+ */
+package com.hankcs.hanlp.restful;
+
+/**
+ * @author hankcs
+ */
+public class DocumentInput extends BaseInput
+{
+ public String text;
+
+ public DocumentInput(String text, String[] tasks, String[] skipTasks, String language)
+ {
+ super(tasks, skipTasks, language);
+ this.text = text;
+ }
+}
diff --git a/plugins/hanlp_restful_java/src/main/java/com/hankcs/hanlp/restful/HanLPClient.java b/plugins/hanlp_restful_java/src/main/java/com/hankcs/hanlp/restful/HanLPClient.java
new file mode 100644
index 000000000..6a04f71f7
--- /dev/null
+++ b/plugins/hanlp_restful_java/src/main/java/com/hankcs/hanlp/restful/HanLPClient.java
@@ -0,0 +1,196 @@
+/*
+ * Han He
+ * me@hankcs.com
+ * 2020-12-26 11:54 PM
+ *
+ *
+ * Copyright (c) 2020, Han He. All Rights Reserved, http://www.hankcs.com/
+ * See LICENSE file in the project root for full license information.
+ *
+ */
+package com.hankcs.hanlp.restful;
+
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.io.OutputStream;
+import java.net.HttpURLConnection;
+import java.net.URL;
+import java.nio.charset.StandardCharsets;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * A RESTful client implementing the data format specification of HanLP.
+ *
+ * @author hankcs
+ * @see Data Format
+ */
+public class HanLPClient
+{
+ private String url;
+ private String auth;
+ private String language;
+ private int timeout;
+ private ObjectMapper mapper;
+
+ /**
+ * @param url An API endpoint to a service provider.
+ * @param auth An auth key licenced by a service provider.
+ * @param language The language this client will be expecting. Contact the service provider for the list of
+ * languages supported. Conventionally, zh is used for Chinese and mul for multilingual.
+ * Leave null to use the default language on server.
+ * @param timeout Maximum waiting time in seconds for a request.
+ */
+ public HanLPClient(String url, String auth, String language, int timeout)
+ {
+ if (auth == null)
+ {
+ auth = System.getenv().getOrDefault("HANLP_AUTH", null);
+ }
+ this.url = url;
+ this.auth = auth;
+ this.language = language;
+ this.timeout = timeout * 1000;
+ this.mapper = new ObjectMapper();
+ }
+
+ /**
+ * @param url An API endpoint to a service provider.
+ * @param auth An auth key licenced by a service provider.
+ */
+ public HanLPClient(String url, String auth)
+ {
+ this(url, auth, null, 5);
+ }
+
+ /**
+ * Parse a raw document.
+ *
+ * @param text Document content which can have multiple sentences.
+ * @param tasks Tasks to perform.
+ * @param skipTasks Tasks to skip.
+ * @return Parsed annotations.
+ * @throws IOException HTTP exception.
+ * @see Data Format
+ */
+ public Map parse(String text, String[] tasks, String[] skipTasks) throws IOException
+ {
+ //noinspection unchecked
+ return mapper.readValue(post("/parse", new DocumentInput(text, tasks, skipTasks, language)), Map.class);
+ }
+
+ /**
+ * Parse a raw document.
+ *
+ * @param text Document content which can have multiple sentences.
+ * @return Parsed annotations.
+ * @throws IOException HTTP exception.
+ * @see Data Format
+ */
+ public Map parse(String text) throws IOException
+ {
+ return parse(text, null, null);
+ }
+
+ /**
+ * Parse an array of sentences.
+ *
+ * @param sentences Multiple sentences to parse.
+ * @param tasks Tasks to perform.
+ * @param skipTasks Tasks to skip.
+ * @return Parsed annotations.
+ * @throws IOException HTTP exception.
+ * @see Data Format
+ */
+ public Map parse(String[] sentences, String[] tasks, String[] skipTasks) throws IOException
+ {
+ //noinspection unchecked
+ return mapper.readValue(post("/parse", new SentenceInput(sentences, tasks, skipTasks, language)), Map.class);
+ }
+
+ /**
+ * Parse an array of sentences.
+ *
+ * @param sentences Multiple sentences to parse.
+ * @return Parsed annotations.
+ * @throws IOException HTTP exception.
+ * @see Data Format
+ */
+ public Map parse(String[] sentences) throws IOException
+ {
+ return parse(sentences, null, null);
+ }
+
+ /**
+ * Parse an array of pre-tokenized sentences.
+ *
+ * @param tokens Multiple pre-tokenized sentences to parse.
+ * @param tasks Tasks to perform.
+ * @param skipTasks Tasks to skip.
+ * @return Parsed annotations.
+ * @throws IOException HTTP exception.
+ * @see Data Format
+ */
+ public Map parse(String[][] tokens, String[] tasks, String[] skipTasks) throws IOException
+ {
+ //noinspection unchecked
+ return mapper.readValue(post("/parse", new TokenInput(tokens, tasks, skipTasks, language)), Map.class);
+ }
+
+ /**
+ * Parse an array of pre-tokenized sentences.
+ *
+ * @param tokens Multiple pre-tokenized sentences to parse.
+ * @return Parsed annotations.
+ * @throws IOException HTTP exception.
+ * @see Data Format
+ */
+ public Map parse(String[][] tokens) throws IOException
+ {
+ return parse(tokens, null, null);
+ }
+
+ private String post(String api, BaseInput input_) throws IOException
+ {
+ URL url = new URL(this.url + api);
+
+ HttpURLConnection con = (HttpURLConnection) url.openConnection();
+ con.setRequestMethod("POST");
+ con.setRequestProperty("Authorization", "Basic " + auth);
+ con.setRequestProperty("Content-Type", "application/json; utf-8");
+ con.setRequestProperty("Accept", "application/json");
+ con.setDoOutput(true);
+ con.setConnectTimeout(timeout);
+ con.setReadTimeout(timeout);
+
+ String jsonInputString = mapper.writeValueAsString(input_);
+
+ try (OutputStream os = con.getOutputStream())
+ {
+ byte[] input = jsonInputString.getBytes(StandardCharsets.UTF_8);
+ os.write(input, 0, input.length);
+ }
+
+ int code = con.getResponseCode();
+ if (code != 200)
+ {
+ throw new IOException(String.format("Request failed, status code = %d, error = %s", code, con.getResponseMessage()));
+ }
+
+ StringBuilder response = new StringBuilder();
+ try (BufferedReader br = new BufferedReader(new InputStreamReader(con.getInputStream(), StandardCharsets.UTF_8)))
+ {
+ String responseLine;
+ while ((responseLine = br.readLine()) != null)
+ {
+ response.append(responseLine.trim());
+ }
+ }
+ return response.toString();
+ }
+
+}
diff --git a/plugins/hanlp_restful_java/src/main/java/com/hankcs/hanlp/restful/SentenceInput.java b/plugins/hanlp_restful_java/src/main/java/com/hankcs/hanlp/restful/SentenceInput.java
new file mode 100644
index 000000000..0eaab7f6a
--- /dev/null
+++ b/plugins/hanlp_restful_java/src/main/java/com/hankcs/hanlp/restful/SentenceInput.java
@@ -0,0 +1,25 @@
+/*
+ * Han He
+ * me@hankcs.com
+ * 2020-12-27 12:09 AM
+ *
+ *
+ * Copyright (c) 2020, Han He. All Rights Reserved, http://www.hankcs.com/
+ * See LICENSE file in the project root for full license information.
+ *
+ */
+package com.hankcs.hanlp.restful;
+
+/**
+ * @author hankcs
+ */
+public class SentenceInput extends BaseInput
+{
+ public String[] text;
+
+ public SentenceInput(String[] text, String[] tasks, String[] skipTasks, String language)
+ {
+ super(tasks, skipTasks, language);
+ this.text = text;
+ }
+}
diff --git a/plugins/hanlp_restful_java/src/main/java/com/hankcs/hanlp/restful/TokenInput.java b/plugins/hanlp_restful_java/src/main/java/com/hankcs/hanlp/restful/TokenInput.java
new file mode 100644
index 000000000..10479bc4d
--- /dev/null
+++ b/plugins/hanlp_restful_java/src/main/java/com/hankcs/hanlp/restful/TokenInput.java
@@ -0,0 +1,25 @@
+/*
+ * Han He
+ * me@hankcs.com
+ * 2020-12-27 12:09 AM
+ *
+ *
+ * Copyright (c) 2020, Han He. All Rights Reserved, http://www.hankcs.com/
+ * See LICENSE file in the project root for full license information.
+ *
+ */
+package com.hankcs.hanlp.restful;
+
+/**
+ * @author hankcs
+ */
+public class TokenInput extends BaseInput
+{
+ public String[][] tokens;
+
+ public TokenInput(String[][] tokens, String[] tasks, String[] skipTasks, String language)
+ {
+ super(tasks, skipTasks, language);
+ this.tokens = tokens;
+ }
+}
diff --git a/plugins/hanlp_restful_java/src/test/java/com/hankcs/hanlp/restful/HanLPClientTest.java b/plugins/hanlp_restful_java/src/test/java/com/hankcs/hanlp/restful/HanLPClientTest.java
new file mode 100644
index 000000000..06e817449
--- /dev/null
+++ b/plugins/hanlp_restful_java/src/test/java/com/hankcs/hanlp/restful/HanLPClientTest.java
@@ -0,0 +1,53 @@
+package com.hankcs.hanlp.restful;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.junit.jupiter.api.BeforeEach;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
+class HanLPClientTest
+{
+ HanLPClient client;
+
+ @BeforeEach
+ void setUp()
+ {
+ client = new HanLPClient("https://hanlp.hankcs.com/api", null);
+ }
+
+ @org.junit.jupiter.api.Test
+ void parseText() throws IOException
+ {
+ Map doc = client.parse("2021年HanLPv2.1为生产环境带来次世代最先进的多语种NLP技术。英首相与特朗普通电话讨论华为与苹果公司。");
+ prettyPrint(doc);
+ }
+
+ @org.junit.jupiter.api.Test
+ void parseSentences() throws IOException
+ {
+ Map doc = client.parse(new String[]{
+ "2021年HanLPv2.1为生产环境带来次世代最先进的多语种NLP技术。",
+ "英首相与特朗普通电话讨论华为与苹果公司。"
+ });
+ prettyPrint(doc);
+ }
+
+ @org.junit.jupiter.api.Test
+ void parseTokens() throws IOException
+ {
+ Map doc = client.parse(new String[][]{
+ new String[]{"2021年", "HanLPv2.1", "为", "生产", "环境", "带来", "次", "世代", "最", "先进", "的", "多语种", "NLP", "技术", "。"},
+ new String[]{"英", "首相", "与", "特朗普", "通", "电话", "讨论", "华为", "与", "苹果", "公司", "。"},
+ });
+ prettyPrint(doc);
+ }
+
+ void prettyPrint(Object object) throws JsonProcessingException
+ {
+ ObjectMapper mapper = new ObjectMapper();
+ System.out.println(mapper.writerWithDefaultPrettyPrinter().writeValueAsString(object));
+ }
+}
\ No newline at end of file
diff --git a/plugins/hanlp_trie/README.md b/plugins/hanlp_trie/README.md
new file mode 100644
index 000000000..d9519faa0
--- /dev/null
+++ b/plugins/hanlp_trie/README.md
@@ -0,0 +1,17 @@
+# Trie interface and implementation for HanLP
+
+[中文](https://github.com/hankcs/HanLP/tree/doc-zh) | [1.x](https://github.com/hankcs/HanLP/tree/1.x) | [forum](https://bbs.hankcs.com/) | [docker](https://github.com/WalterInSH/hanlp-jupyter-docker)
+
+The multilingual NLP library for researchers and companies, built on PyTorch and TensorFlow 2.x, for advancing state-of-the-art deep learning techniques in both academia and industry. HanLP was designed from day one to be efficient, user friendly and extendable. It comes with pretrained models for various human languages including English, Chinese and many others. Currently, HanLP 2.0 is in alpha stage with more killer features on the roadmap. Discussions are welcomed on our [forum](https://bbs.hankcs.com/), while bug reports and feature requests are reserved for GitHub issues. For Java users, please checkout the [1.x](https://github.com/hankcs/HanLP/tree/1.x) branch.
+
+## Installation
+
+```bash
+pip install hanlp
+```
+
+
+## License
+
+HanLP is licensed under **Apache License 2.0**. You can use HanLP in your commercial products for free. We would appreciate it if you add a link to HanLP on your website.
+
diff --git a/plugins/hanlp_trie/hanlp_trie/__init__.py b/plugins/hanlp_trie/hanlp_trie/__init__.py
new file mode 100644
index 000000000..2725b5824
--- /dev/null
+++ b/plugins/hanlp_trie/hanlp_trie/__init__.py
@@ -0,0 +1,5 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-11-29 17:48
+from .trie import Trie
+from .dictionary import DictInterface, TrieDict
diff --git a/plugins/hanlp_trie/hanlp_trie/dictionary.py b/plugins/hanlp_trie/hanlp_trie/dictionary.py
new file mode 100644
index 000000000..967f29d62
--- /dev/null
+++ b/plugins/hanlp_trie/hanlp_trie/dictionary.py
@@ -0,0 +1,156 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-11-29 17:53
+from abc import ABC, abstractmethod
+from typing import List, Tuple, Any, Dict, Union, Set, Sequence
+
+from hanlp_common.configurable import Configurable
+from hanlp_common.reflection import classpath_of
+from hanlp_trie.trie import Trie
+
+
+class DictInterface(ABC):
+ @abstractmethod
+ def tokenize(self, text: Union[str, Sequence[str]]) -> List[Tuple[int, int, Any]]:
+ """Implement this method to tokenize a piece of text into a list of non-intersect spans, each span is a tuple
+ of ``(begin_offset, end_offset, label)``, where label is some properties related to this span and downstream
+ tasks have the freedom to define what kind of labels they want.
+
+ Args:
+ text: The text to be tokenized.
+
+ Returns:
+ A list of tokens.
+
+ """
+ pass
+
+ def split(self, text: Union[str, Sequence[str]]) -> List[Tuple[int, int, Any]]:
+ """Like the :meth:`str.split`, this method splits a piece of text into chunks by taking the keys in this
+ dictionary as delimiters. It performs longest-prefix-matching on text and split it whenever a longest key is
+ matched. Unlike the :meth:`str.split`, it inserts matched keys into the results list right after where they are
+ found. So that the text can be restored by joining chunks in the results list.
+
+ Args:
+ text: A piece of text.
+
+ Returns:
+ A list of chunks, each chunk is a span of ``(begin_offset, end_offset, label)``, where label is some
+ properties related to this span and downstream tasks.
+ """
+ offset = 0
+ spans = []
+ for begin, end, label in self.tokenize(text):
+ if begin > offset:
+ spans.append(text[offset:begin])
+ spans.append((begin, end, label))
+ offset = end
+ if offset < len(text):
+ spans.append(text[offset:])
+ return spans
+
+
+class TrieDict(Trie, DictInterface, Configurable):
+ def __init__(self, dictionary: Union[Dict[str, Any], Set[str]]) -> None:
+ r"""
+ A dict-like structure for fast custom dictionary strategies in tokenization and tagging. It is built with
+ a dict of key-value pairs or a set of strings. When a set is passed in, it will be turned into a dict where each
+ key is assigned with a boolean value ``True``.
+
+ Args:
+ dictionary: A custom dictionary of string-value pairs.
+ """
+ super().__init__(dictionary)
+
+ def tokenize(self, text: Union[str, Sequence[str]]) -> List[Tuple[int, int, Any]]:
+ return self.parse_longest(text)
+
+ def split_batch(self, data: List[str]) -> Tuple[List[str], List[int], List[List[Tuple[int, int, Any]]]]:
+ """ A handy method to perform longest-prefix-matching on a batch of sentences. It tokenize each sentence, record
+ the chunks being either a key in the dict or a span outside of the dict. The spans are then packed into a new
+ batch and returned along with the following information:
+
+ - which sentence a span belongs to
+ - the matched keys along with their spans and values.
+
+ This method bridges the gap between statistical models and rule-based gazetteers.
+ It's used in conjunction with :meth:`~hanlp_trie.dictionary.TrieDict.merge_batch`.
+
+ Args:
+ data: A batch of sentences.
+
+ Returns:
+ A tuple of the new batch, the belonging information and the keys.
+ """
+ new_data, new_data_belongs, parts = [], [], []
+ for idx, sent in enumerate(data):
+ parts.append([])
+ found = self.tokenize(sent)
+ if found:
+ pre_start = 0
+ for start, end, info in found:
+ if start > pre_start:
+ new_data.append(sent[pre_start:start])
+ new_data_belongs.append(idx)
+ pre_start = end
+ parts[idx].append((start, end, info))
+ if pre_start != len(sent):
+ new_data.append(sent[pre_start:])
+ new_data_belongs.append(idx)
+ else:
+ new_data.append(sent)
+ new_data_belongs.append(idx)
+ return new_data, new_data_belongs, parts
+
+ @staticmethod
+ def merge_batch(data, new_outputs, new_data_belongs, parts):
+ """ A helper method to merge the outputs of split batch back by concatenating the output per span with the key
+ used to split it. It's used in conjunction with :meth:`~hanlp_trie.dictionary.TrieDict.split_batch`.
+
+ Args:
+ data: Split batch.
+ new_outputs: Outputs of the split batch.
+ new_data_belongs: Belonging information.
+ parts: The keys.
+
+ Returns:
+ Merged outputs.
+ """
+ outputs = []
+ segments = []
+ for idx in range(len(data)):
+ segments.append([])
+ for o, b in zip(new_outputs, new_data_belongs):
+ dst = segments[b]
+ dst.append(o)
+ for s, p, sent in zip(segments, parts, data):
+ s: list = s
+ if p:
+ dst = []
+ offset = 0
+ for start, end, info in p:
+ while offset < start:
+ head = s.pop(0)
+ offset += sum(len(token) for token in head)
+ dst += head
+ if isinstance(info, list):
+ dst += info
+ elif isinstance(info, str):
+ dst.append(info)
+ else:
+ dst.append(sent[start:end])
+ offset = end
+ if s:
+ assert len(s) == 1
+ dst += s[0]
+ outputs.append(dst)
+ else:
+ outputs.append(s[0])
+ return outputs
+
+ @property
+ def config(self):
+ return {
+ 'classpath': classpath_of(self),
+ 'dictionary': dict(self.items())
+ }
diff --git a/plugins/hanlp_trie/hanlp_trie/trie.py b/plugins/hanlp_trie/hanlp_trie/trie.py
new file mode 100644
index 000000000..71a0e5d24
--- /dev/null
+++ b/plugins/hanlp_trie/hanlp_trie/trie.py
@@ -0,0 +1,156 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-01-04 23:46
+from typing import Dict, Any, List, Tuple, Sequence, Union, Set
+
+
+class Node(object):
+ def __init__(self, value=None) -> None:
+ """A node in a trie tree.
+
+ Args:
+ value: The value associated with this node.
+ """
+ self._children = {}
+ self._value = value
+
+ def _add_child(self, char, value, overwrite=False):
+ child = self._children.get(char)
+ if child is None:
+ child = Node(value)
+ self._children[char] = child
+ elif overwrite:
+ child._value = value
+ return child
+
+ def transit(self, key):
+ """Transit the state of a Deterministic Finite Automata (DFA) with key.
+
+ Args:
+ key: A sequence of criterion (tokens or characters) used to transit to a new state.
+
+ Returns:
+ A new state if the transition succeeded, otherwise ``None``.
+
+ """
+ state = self
+ for char in key:
+ state = state._children.get(char)
+ if state is None:
+ break
+ return state
+
+ def _walk(self, prefix: str, ordered=False):
+ for char, child in sorted(self._children.items()) if ordered else self._children.items():
+ prefix_new = prefix + char
+ if child._value:
+ yield prefix_new, child._value
+ yield from child._walk(prefix_new)
+
+
+class Trie(Node):
+ def __init__(self, tokens: Union[Dict[str, Any], Set[str]] = None) -> None:
+ """A referential implementation of the trie (:cite:`10.1145/1457838.1457895`) structure. It stores a dict by
+ assigning each key/value pair a :class:`~hanlp_trie.trie.Node` in a trie tree. It provides get/set/del/items
+ methods just like a :class:`dict` does. Additionally, it also provides longest-prefix-matching and keywords
+ lookup against a piece of text, which are very helpful in rule-based Natural Language Processing.
+
+ Args:
+ tokens: A set of keys or a dict mapping.
+ """
+ super().__init__()
+ if tokens:
+ if isinstance(tokens, set):
+ for k in tokens:
+ self[k] = True
+ else:
+ for k, v in tokens.items():
+ self[k] = v
+
+ def __contains__(self, key):
+ return self[key] is not None
+
+ def __getitem__(self, key):
+ state = self.transit(key)
+ if state is None:
+ return None
+ return state._value
+
+ def __setitem__(self, key, value):
+ state = self
+ for i, char in enumerate(key):
+ if i < len(key) - 1:
+ state = state._add_child(char, None, False)
+ else:
+ state = state._add_child(char, value, True)
+
+ def __delitem__(self, key):
+ state = self.transit(key)
+ if state is not None:
+ state._value = None
+
+ def update(self, dic: Dict[str, Any]):
+ for k, v in dic.items():
+ self[k] = v
+ return self
+
+ def parse(self, text: Sequence[str]) -> List[Tuple[int, int, Any]]:
+ """Keywords lookup which takes a piece of text as input, and lookup all occurrences of keywords in it. These
+ occurrences can overlap with each other.
+
+ Args:
+ text: A piece of text. In HanLP's design, it doesn't really matter whether this is a str or a list of str.
+ The trie will transit on either types properly, which means a list of str simply defines a list of
+ transition criteria while a str defines each criterion as a character.
+
+ Returns:
+ A tuple of ``(begin, end, value)``.
+ """
+ found = []
+ for i in range(len(text)):
+ state = self
+ for j in range(i, len(text)):
+ state = state.transit(text[j])
+ if state:
+ if state._value is not None:
+ found.append((i, j + 1, state._value))
+ else:
+ break
+ return found
+
+ def parse_longest(self, text: Sequence[str]) -> List[Tuple[int, int, Any]]:
+ """Longest-prefix-matching which tries to match the longest keyword sequentially from the head of the text till
+ its tail. By definition, the matches won't overlap with each other.
+
+ Args:
+ text: A piece of text. In HanLP's design, it doesn't really matter whether this is a str or a list of str.
+ The trie will transit on either types properly, which means a list of str simply defines a list of
+ transition criteria while a str defines each criterion as a character.
+
+ Returns:
+ A tuple of ``(begin, end, value)``.
+
+ """
+ found = []
+ i = 0
+ while i < len(text):
+ state = self.transit(text[i])
+ if state:
+ to = i + 1
+ end = to
+ value = state._value
+ for to in range(i + 1, len(text)):
+ state = state.transit(text[to])
+ if not state:
+ break
+ if state._value is not None:
+ value = state._value
+ end = to + 1
+ if value is not None:
+ found.append((i, end, value))
+ i = end - 1
+ i += 1
+ return found
+
+ def items(self, ordered=False):
+ yield from self._walk('', ordered)
diff --git a/plugins/hanlp_trie/setup.py b/plugins/hanlp_trie/setup.py
new file mode 100644
index 000000000..1f7171670
--- /dev/null
+++ b/plugins/hanlp_trie/setup.py
@@ -0,0 +1,38 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2019-12-28 19:26
+from os.path import abspath, join, dirname
+from setuptools import find_packages, setup
+
+this_dir = abspath(dirname(__file__))
+with open(join(this_dir, 'README.md'), encoding='utf-8') as file:
+ long_description = file.read()
+
+setup(
+ name='hanlp_trie',
+ version='0.0.1',
+ description='HanLP: Han Language Processing',
+ long_description=long_description,
+ long_description_content_type="text/markdown",
+ url='https://github.com/hankcs/HanLP',
+ author='hankcs',
+ author_email='hankcshe@gmail.com',
+ license='Apache License 2.0',
+ classifiers=[
+ 'Intended Audience :: Science/Research',
+ 'Intended Audience :: Developers',
+ "Development Status :: 3 - Alpha",
+ 'Operating System :: OS Independent',
+ "License :: OSI Approved :: Apache Software License",
+ 'Programming Language :: Python :: 3 :: Only',
+ 'Topic :: Scientific/Engineering :: Artificial Intelligence',
+ "Topic :: Text Processing :: Linguistic"
+ ],
+ keywords='corpus,machine-learning,NLU,NLP',
+ packages=find_packages(exclude=['docs', 'tests*']),
+ include_package_data=True,
+ install_requires=[
+ 'hanlp_common'
+ ],
+ python_requires='>=3.6',
+)
diff --git a/plugins/hanlp_trie/tests/__init__.py b/plugins/hanlp_trie/tests/__init__.py
new file mode 100644
index 000000000..7cb9c3ba7
--- /dev/null
+++ b/plugins/hanlp_trie/tests/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding:utf-8 -*-
+# Author: hankcs
+# Date: 2020-11-29 18:05
diff --git a/plugins/hanlp_trie/tests/test_trie.py b/plugins/hanlp_trie/tests/test_trie.py
new file mode 100644
index 000000000..7e7711d21
--- /dev/null
+++ b/plugins/hanlp_trie/tests/test_trie.py
@@ -0,0 +1,41 @@
+import unittest
+
+from hanlp_trie import Trie
+
+
+class TestTrie(unittest.TestCase):
+ def build_small_trie(self):
+ return Trie({'商品': 'goods', '和': 'and', '和服': 'kimono', '服务': 'service', '务': 'business'})
+
+ def assert_results_valid(self, text, results, trie):
+ for begin, end, value in results:
+ self.assertEqual(value, trie[text[begin:end]])
+
+ def test_parse(self):
+ trie = self.build_small_trie()
+ text = '商品和服务'
+ parse_result = trie.parse(text)
+ self.assert_results_valid(text, parse_result, trie)
+ self.assertEqual([(0, 2, 'goods'),
+ (2, 3, 'and'),
+ (2, 4, 'kimono'),
+ (3, 5, 'service'),
+ (4, 5, 'business')],
+ parse_result)
+
+ def test_parse_longest(self):
+ trie = self.build_small_trie()
+ text = '商品和服务'
+ parse_longest_result = trie.parse_longest(text)
+ self.assert_results_valid(text, parse_longest_result, trie)
+ self.assertEqual([(0, 2, 'goods'), (2, 4, 'kimono'), (4, 5, 'business')],
+ parse_longest_result)
+
+ def test_items(self):
+ trie = self.build_small_trie()
+ items = list(trie.items())
+ self.assertEqual([('商品', 'goods'), ('和', 'and'), ('和服', 'kimono'), ('服务', 'service'), ('务', 'business')], items)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/plugins/hanlp_trie/tests/test_trie_dict.py b/plugins/hanlp_trie/tests/test_trie_dict.py
new file mode 100644
index 000000000..16723c9a9
--- /dev/null
+++ b/plugins/hanlp_trie/tests/test_trie_dict.py
@@ -0,0 +1,31 @@
+import unittest
+
+from hanlp_trie import TrieDict
+
+
+class TestTrieDict(unittest.TestCase):
+
+ def setUp(self) -> None:
+ super().setUp()
+ self.text = '第一个词语很重要,第二个词语也很重要'
+ self.trie_dict = TrieDict({'重要': 'important'})
+
+ def test_tokenize(self):
+ self.assertEqual([(6, 8, 'important'), (16, 18, 'important')], self.trie_dict.tokenize(self.text))
+
+ def test_split_batch(self):
+ data = [self.text]
+ new_data, new_data_belongs, parts = self.trie_dict.split_batch(data)
+ predictions = [list(x) for x in new_data]
+ self.assertSequenceEqual(
+ [['第', '一', '个', '词', '语', '很', 'important', ',', '第', '二', '个', '词', '语', '也', '很', 'important']],
+ self.trie_dict.merge_batch(data, predictions, new_data_belongs, parts))
+
+ def test_tokenize_2(self):
+ t = TrieDict({'次世代', '生产环境'})
+ self.assertSequenceEqual(t.tokenize('2021年HanLPv2.1为生产环境带来次世代最先进的多语种NLP技术。'),
+ [(15, 19, True), (21, 24, True)])
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/setup.py b/setup.py
index c7ac3d0d3..1e908526e 100644
--- a/setup.py
+++ b/setup.py
@@ -34,10 +34,26 @@
keywords='corpus,machine-learning,NLU,NLP',
packages=find_packages(exclude=['docs', 'tests*']),
include_package_data=True,
- install_requires=['tensorflow==2.3.0', 'bert-for-tf2==0.14.6', 'py-params==0.9.7',
- 'params-flow==0.8.2', 'sentencepiece==0.1.91'],
+ install_requires=[
+ 'termcolor',
+ 'pynvml',
+ 'alnlp',
+ 'penman==0.6.2',
+ 'toposort==1.5',
+ 'transformers',
+ 'torch>=1.6.0',
+ 'hanlp-common',
+ 'hanlp-trie',
+ ],
extras_require={
- 'full': ['fasttext==0.9.1'],
+ 'full': [
+ 'fasttext==0.9.1',
+ 'tensorflow==2.3.0',
+ 'bert-for-tf2==0.14.6',
+ 'py-params==0.9.7',
+ 'params-flow==0.8.2',
+ 'sentencepiece==0.1.91'
+ ],
},
python_requires='>=3.6',
# entry_points={
diff --git a/tests/debug/break_long_sents.py b/tests/debug/break_long_sents.py
deleted file mode 100644
index 001046d80..000000000
--- a/tests/debug/break_long_sents.py
+++ /dev/null
@@ -1,21 +0,0 @@
-# -*- coding:utf-8 -*-
-# Author: hankcs
-# Date: 2020-04-03 20:33
-from hanlp.utils.string_util import split_long_sent
-
-
-def main():
- delimiter = set()
- delimiter.update('。!?:;、,,;!?、,')
- print([x for x in split_long_sent(
- ['中', '方', '愿', '与', '东', '盟', '国', '家', '在', '业', '已', '建', '立', '的', '基', '础', '上', ',', '培', '育', '新', '的',
- '合', '作', '点', ',', '增', '强', '优', '势', '互', '补', ',', '即', '在', '深', '化', '经', '贸', '合', '作', '的', '同', '时',
- ',', '将', '合', '作', '领', '域', '拓', '展', '至', '资', '源', '开', '发', ',', '农', '业', ',', '适', '用', '技', '术', ',',
- '医', '药', '卫', '生', ',', '人', '力', '资', '源', '开', '发', ',', '环', '保', '等', '领', '域', ',', '尤', '其', '是', '增',
- '加', '科', '技', '合', '作', '的', '比', '重', ',', '加', '强', '金', '融', ',', '商', '务', '信', '息', '的', '交', '流', ',',
- '以', '促', '进', '各', '自', '结', '构', '调', '整', '和', '增', '长', '方', '式', '的', '转', '变', '。'],
- delimiter, 126)])
-
-
-if __name__ == '__main__':
- main()
diff --git a/tests/debug/debug_bert_ner.py b/tests/debug/debug_bert_ner.py
deleted file mode 100644
index ec7846416..000000000
--- a/tests/debug/debug_bert_ner.py
+++ /dev/null
@@ -1,8 +0,0 @@
-# -*- coding:utf-8 -*-
-# Author: hankcs
-# Date: 2020-01-09 00:06
-import hanlp
-
-recognizer = hanlp.load(hanlp.pretrained.ner.MSRA_NER_BERT_BASE_ZH)
-print(recognizer([list('孽债 (上海话)')]))
-print(recognizer(['超', '长'] * 256))
diff --git a/tests/debug/debug_conll_sent.py b/tests/debug/debug_conll_sent.py
deleted file mode 100644
index 74b2e559b..000000000
--- a/tests/debug/debug_conll_sent.py
+++ /dev/null
@@ -1,16 +0,0 @@
-# -*- coding:utf-8 -*-
-# Author: hankcs
-# Date: 2020-01-29 16:55
-from hanlp.components.parsers.conll import CoNLLSentence
-
-conll = '''\
-1 蜡烛 蜡烛 NN NN _ 3 Poss _ _
-1 蜡烛 蜡烛 NN NN _ 4 Pat _ _
-2 两 两 CD CD _ 3 Quan _ _
-3 头 头 NN NN _ 4 Loc _ _
-4 烧 烧 VV VV _ 0 Root _ _
-'''
-
-sent = CoNLLSentence.from_str(conll)
-print(sent)
-print([(x.form, x.pos) for x in sent])
diff --git a/tests/debug/debug_transformer_transform.py b/tests/debug/debug_transformer_transform.py
deleted file mode 100644
index f9d01a204..000000000
--- a/tests/debug/debug_transformer_transform.py
+++ /dev/null
@@ -1,13 +0,0 @@
-# -*- coding:utf-8 -*-
-# Author: hankcs
-# Date: 2020-01-11 18:37
-from hanlp.datasets.ner.msra import MSRA_NER_TRAIN
-
-from hanlp.components.taggers.transformers.transformer_transform import TransformerTransform
-
-transform = TransformerTransform(max_seq_length=128)
-
-for x, y in transform.file_to_inputs(MSRA_NER_TRAIN):
- assert len(x) == len(y)
- if not len(x) or len(x) > 126:
- print(x)
diff --git a/tests/debug/test_bert_ner.py b/tests/debug/test_bert_ner.py
deleted file mode 100644
index a85e85b4e..000000000
--- a/tests/debug/test_bert_ner.py
+++ /dev/null
@@ -1,17 +0,0 @@
-import unittest
-
-import hanlp
-
-
-class TestTransformerNamedEntityRecognizer(unittest.TestCase):
-
- def setUp(self) -> None:
- super().setUp()
- self.recognizer = hanlp.load(hanlp.pretrained.ner.MSRA_NER_BERT_BASE_ZH)
-
- def test_unk_token(self):
- self.recognizer([list('孽债 (上海话)')])
-
-
-if __name__ == '__main__':
- unittest.main()
diff --git a/tests/debug/test_string_util.py b/tests/debug/test_string_util.py
deleted file mode 100644
index 43fd6cdce..000000000
--- a/tests/debug/test_string_util.py
+++ /dev/null
@@ -1,32 +0,0 @@
-import unittest
-
-from hanlp.utils.string_util import split_long_sentence_into
-
-
-class TestStringUtil(unittest.TestCase):
- def test_split_long_sentence_into(self):
- sent = ['a', 'b', 'c', ',', 'd', 'e', ',', 'f', 'g', ',', 'h']
- self.assertListEqual([['a', 'b', 'c', ','], ['d', 'e', ','], ['f', 'g', ','], ['h']],
- list(split_long_sentence_into(sent, 2)))
-
- def test_split_127(self):
- sent = ['“', '旧', '货', '”', '不', '仅', '仅', '是', '指', '新', '货', '被', '使', '用', '才', '成', '为', '旧', '货', ';', '还',
- '包', '括', '商', '品', '的', '调', '剂', ',', '即', '卖', '出', '旧', '货', '的', '人', '是', '为', '了', '买', '入', '新',
- '货', ',', '买', '入', '旧', '货', '的', '人', '是', '因', '为', '符', '合', '自', '己', '的', '需', '要', ',', '不', '管',
- '新', '旧', ';', '有', '的', '商', '店', '还', '包', '括', '一', '些', '高', '档', '的', '工', '艺', '品', '、', '古', '董',
- '、', '字', '画', '、', '家', '具', '等', '商', '品', ';', '有', '的', '还', '包', '括', '新', '货', '卖', '不', '出', '去',
- ',', '企', '业', '或', '店', '主', '为', '了', '盘', '活', '资', '金', ',', '削', '价', '销', '售', '积', '压', '产', '品',
- '。']
- results = list(split_long_sentence_into(sent, 126))
- self.assertListEqual([['“', '旧', '货', '”', '不', '仅', '仅', '是', '指', '新', '货', '被', '使', '用', '才', '成', '为', '旧',
- '货', ';', '还', '包', '括', '商', '品', '的', '调', '剂', ',', '即', '卖', '出', '旧', '货', '的', '人',
- '是', '为', '了', '买', '入', '新', '货', ',', '买', '入', '旧', '货', '的', '人', '是', '因', '为', '符',
- '合', '自', '己', '的', '需', '要', ',', '不', '管', '新', '旧', ';', '有', '的', '商', '店', '还', '包',
- '括', '一', '些', '高', '档', '的', '工', '艺', '品', '、', '古', '董', '、', '字', '画', '、', '家', '具',
- '等', '商', '品', ';', '有', '的', '还', '包', '括', '新', '货', '卖', '不', '出', '去', ',', '企', '业',
- '或', '店', '主', '为', '了', '盘', '活', '资', '金', ','],
- ['削', '价', '销', '售', '积', '压', '产', '品', '。']], results)
-
-
-if __name__ == '__main__':
- unittest.main()
diff --git a/tests/debug/test_trie.py b/tests/debug/test_trie.py
deleted file mode 100644
index 9383ad8f7..000000000
--- a/tests/debug/test_trie.py
+++ /dev/null
@@ -1,32 +0,0 @@
-# -*- coding:utf-8 -*-
-# Author: hankcs
-# Date: 2020-01-05 22:47
-from unittest import TestCase
-
-from hanlp.common.trie import Trie
-
-
-class TestTrie(TestCase):
-
- def test_transit(self):
- trie = self.create_trie()
- state = trie.transit('自然')
- self.assertEqual(2, len(state._children))
- self.assertTrue('自然' in trie)
- self.assertEqual('nature', trie['自然'])
- del trie['自然']
- self.assertFalse('自然' in trie)
-
- @staticmethod
- def create_trie():
- trie = Trie()
- trie['自然'] = 'nature'
- trie['自然人'] = 'human'
- trie['自然语言'] = 'language'
- trie['自语'] = 'talk to oneself'
- trie['入门'] = 'introduction'
- return trie
-
- def test_parse_longest(self):
- trie = self.create_trie()
- trie.parse_longest('《自然语言处理入门》出版了')
diff --git a/tests/debug/trie/longest.py b/tests/debug/trie/longest.py
deleted file mode 100644
index 24decb0c8..000000000
--- a/tests/debug/trie/longest.py
+++ /dev/null
@@ -1,7 +0,0 @@
-# -*- coding:utf-8 -*-
-# Author: hankcs
-# Date: 2020-11-11 11:08
-from hanlp.common.trie import Trie
-
-trie = Trie({'密码', '码'})
-print(trie.parse_longest('密码设置'))
diff --git a/tests/script/convert_dm_sdp.py b/tests/script/convert_dm_sdp.py
deleted file mode 100644
index f089a7fee..000000000
--- a/tests/script/convert_dm_sdp.py
+++ /dev/null
@@ -1,35 +0,0 @@
-# -*- coding:utf-8 -*-
-# Author: hankcs
-# Date: 2019-12-26 23:20
-from hanlp.pretrained.glove import GLOVE_6B_100D
-from hanlp.components.parsers.biaffine_parser import BiaffineDependencyParser, BiaffineSemanticDependencyParser
-from tests import cdroot
-
-cdroot()
-save_dir = 'data/model/semeval15_biaffine_dm'
-parser = BiaffineSemanticDependencyParser()
-# parser.fit('data/semeval15/en.dm.train.conll', 'data/semeval15/en.dm.dev.conll', save_dir,
-# pretrained_embed={'class_name': 'HanLP>Word2VecEmbedding',
-# 'config': {
-# 'trainable': False,
-# 'embeddings_initializer': 'zero',
-# 'filepath': GLOVE_6B_100D,
-# 'expand_vocab': True,
-# 'lowercase': False,
-# 'unk': 'unk',
-# 'normalize': True,
-# 'name': 'glove.6B.100d'
-# }},
-# # lstm_dropout=0,
-# # mlp_dropout=0,
-# # embed_dropout=0,
-# epochs=1
-# )
-parser.load(save_dir)
-parser.save_meta(save_dir)
-parser.transform.summarize_vocabs()
-sentence = [('Is', 'VBZ'), ('this', 'DT'), ('the', 'DT'), ('future', 'NN'), ('of', 'IN'), ('chamber', 'NN'),
- ('music', 'NN'), ('?', '.')]
-print(parser.predict(sentence))
-parser.evaluate('data/semeval15/en.id.dm.conll', save_dir)
-parser.evaluate('data/semeval15/en.ood.dm.conll', save_dir)
diff --git a/tests/script/convert_psd_sdp.py b/tests/script/convert_psd_sdp.py
deleted file mode 100644
index 49657ee34..000000000
--- a/tests/script/convert_psd_sdp.py
+++ /dev/null
@@ -1,35 +0,0 @@
-# -*- coding:utf-8 -*-
-# Author: hankcs
-# Date: 2019-12-26 23:20
-from hanlp.pretrained.glove import GLOVE_6B_100D
-from hanlp.components.parsers.biaffine_parser import BiaffineDependencyParser, BiaffineSemanticDependencyParser
-from tests import cdroot
-
-cdroot()
-save_dir = 'data/model/semeval15_biaffine_psd'
-parser = BiaffineSemanticDependencyParser()
-# parser.fit('data/semeval15/en.psd.train.conll', 'data/semeval15/en.psd.dev.conll', save_dir,
-# pretrained_embed={'class_name': 'HanLP>Word2VecEmbedding',
-# 'config': {
-# 'trainable': False,
-# 'embeddings_initializer': 'zero',
-# 'filepath': GLOVE_6B_100D,
-# 'expand_vocab': True,
-# 'lowercase': False,
-# 'unk': 'unk',
-# 'normalize': True,
-# 'name': 'glove.6B.100d'
-# }},
-# # lstm_dropout=0,
-# # mlp_dropout=0,
-# # embed_dropout=0,
-# epochs=1
-# )
-parser.load(save_dir)
-parser.save_meta(save_dir)
-parser.transform.summarize_vocabs()
-sentence = [('Is', 'VBZ'), ('this', 'DT'), ('the', 'DT'), ('future', 'NN'), ('of', 'IN'), ('chamber', 'NN'),
- ('music', 'NN'), ('?', '.')]
-print(parser.predict(sentence))
-parser.evaluate('data/semeval15/en.id.psd.conll', save_dir)
-parser.evaluate('data/semeval15/en.ood.psd.conll', save_dir)
diff --git a/tests/script/convert_ptb_dep.py b/tests/script/convert_ptb_dep.py
deleted file mode 100644
index ec47f2c2b..000000000
--- a/tests/script/convert_ptb_dep.py
+++ /dev/null
@@ -1,35 +0,0 @@
-# -*- coding:utf-8 -*-
-# Author: hankcs
-# Date: 2019-12-26 23:20
-from hanlp.pretrained.glove import GLOVE_6B_100D
-from hanlp.components.parsers.biaffine_parser import BiaffineDependencyParser
-from tests import cdroot
-
-cdroot()
-save_dir = 'data/model/ptb-dep-converted'
-parser = BiaffineDependencyParser()
-# parser.fit('data/ptb-dep/train.auto.conllx', 'data/ptb-dep/dev.auto.conllx', save_dir,
-# pretrained_embed={'class_name': 'HanLP>Word2VecEmbedding',
-# 'config': {
-# 'trainable': False,
-# 'embeddings_initializer': 'zero',
-# 'filepath': GLOVE_6B_100D,
-# 'expand_vocab': True,
-# 'lowercase': False,
-# 'unk': 'unk',
-# 'normalize': True,
-# 'name': 'glove.6B.100d'
-# }},
-# # lstm_dropout=0,
-# # mlp_dropout=0,
-# # embed_dropout=0,
-# epochs=1
-# )
-# exit(1)
-parser.load(save_dir)
-parser.save_meta(save_dir)
-parser.transform.summarize_vocabs()
-sentence = [('Is', 'VBZ'), ('this', 'DT'), ('the', 'DT'), ('future', 'NN'), ('of', 'IN'), ('chamber', 'NN'),
- ('music', 'NN'), ('?', '.')]
-print(parser.predict(sentence))
-parser.evaluate('data/ptb-dep/test.auto.conllx', save_dir)
diff --git a/tests/script/evaluate_dep.py b/tests/script/evaluate_dep.py
deleted file mode 100644
index 9f48fe466..000000000
--- a/tests/script/evaluate_dep.py
+++ /dev/null
@@ -1,7 +0,0 @@
-# -*- coding:utf-8 -*-
-# Author: hankcs
-# Date: 2020-01-10 21:27
-import hanlp
-
-syntactic_parser = hanlp.load(hanlp.pretrained.dep.CTB7_BIAFFINE_DEP_ZH)
-syntactic_parser.evaluate(hanlp.datasets.parsing.ctb.CTB7_DEP_TEST)
diff --git a/tests/script/evaluate_sdp.py b/tests/script/evaluate_sdp.py
deleted file mode 100644
index 1f55ee4be..000000000
--- a/tests/script/evaluate_sdp.py
+++ /dev/null
@@ -1,7 +0,0 @@
-# -*- coding:utf-8 -*-
-# Author: hankcs
-# Date: 2020-01-10 21:27
-import hanlp
-
-syntactic_parser = hanlp.load(hanlp.pretrained.sdp.SEMEVAL16_NEWS_BIAFFINE_ZH)
-syntactic_parser.evaluate(hanlp.datasets.parsing.semeval2016.SEMEVAL2016_NEWS_TEST)
diff --git a/tests/test_config_tracker.py b/tests/test_config_tracker.py
new file mode 100644
index 000000000..c5a8f31d8
--- /dev/null
+++ b/tests/test_config_tracker.py
@@ -0,0 +1,18 @@
+import unittest
+
+from hanlp.common.structure import ConfigTracker
+
+
+class MyClass(ConfigTracker):
+ def __init__(self, i_need_this='yes') -> None:
+ super().__init__(locals())
+
+
+class TestConfigTracker(unittest.TestCase):
+ def test_init(self):
+ obj = MyClass()
+ self.assertEqual(obj.config.get('i_need_this', None), 'yes')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/test_mtl.py b/tests/test_mtl.py
new file mode 100644
index 000000000..b94073101
--- /dev/null
+++ b/tests/test_mtl.py
@@ -0,0 +1,33 @@
+import unittest
+import hanlp
+from hanlp_common.document import Document
+
+
+class TestMultiTaskLearning(unittest.TestCase):
+
+ def setUp(self) -> None:
+ super().setUp()
+ self.mtl = hanlp.load(hanlp.pretrained.mtl.OPEN_TOK_POS_NER_SRL_DEP_SDP_CON_ELECTRA_SMALL_ZH, devices=-1)
+
+ def test_mtl_single_sent(self):
+ doc: Document = self.mtl('商品和服务')
+ self.assertSequenceEqual(doc['tok'], ["商品", "和", "服务"])
+
+ def test_mtl_multiple_sents(self):
+ doc: Document = self.mtl(['商品和服务', '研究生命'])
+ self.assertSequenceEqual(doc['tok'], [
+ ["商品", "和", "服务"],
+ ["研究", "生命"]
+ ])
+
+ def test_skip_tok(self):
+ pre_tokenized_sents = [
+ ["商品和服务", '一个', '词'],
+ ["研究", "生命"]
+ ]
+ doc: Document = self.mtl(pre_tokenized_sents, skip_tasks='tok*')
+ self.assertSequenceEqual(doc['tok'], pre_tokenized_sents)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/train/__init__.py b/tests/train/__init__.py
deleted file mode 100644
index 4c331a9d9..000000000
--- a/tests/train/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-# -*- coding:utf-8 -*-
-# Author: hankcs
-# Date: 2019-12-28 22:22
\ No newline at end of file
diff --git a/tests/train/en/__init__.py b/tests/train/en/__init__.py
deleted file mode 100644
index da17fe806..000000000
--- a/tests/train/en/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-# -*- coding:utf-8 -*-
-# Author: hankcs
-# Date: 2020-01-01 18:24
\ No newline at end of file
diff --git a/tests/train/zh/__init__.py b/tests/train/zh/__init__.py
deleted file mode 100644
index da17fe806..000000000
--- a/tests/train/zh/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-# -*- coding:utf-8 -*-
-# Author: hankcs
-# Date: 2020-01-01 18:24
\ No newline at end of file
diff --git a/tests/train/zh/cws/train_ctb6_cws_albert.py b/tests/train/zh/cws/train_ctb6_cws_albert.py
deleted file mode 100644
index 8821b5677..000000000
--- a/tests/train/zh/cws/train_ctb6_cws_albert.py
+++ /dev/null
@@ -1,19 +0,0 @@
-# -*- coding:utf-8 -*-
-# Author: hankcs
-# Date: 2019-12-28 22:22
-
-from hanlp.components.tok import TransformerTokenizer
-from hanlp.datasets.cws.ctb import CTB6_CWS_TRAIN, CTB6_CWS_VALID, CTB6_CWS_TEST
-from tests import cdroot
-
-cdroot()
-tokenizer = TransformerTokenizer()
-save_dir = 'data/model/ctb6_cws_albert_base'
-tokenizer.fit(CTB6_CWS_TRAIN, CTB6_CWS_VALID, save_dir,
- transformer='albert_base_zh',
- max_seq_length=150,
- metrics='f1', learning_rate=5e-5, epochs=10)
-tokenizer.load(save_dir)
-print(tokenizer.predict(['中央民族乐团离开北京前往维也纳', '商品和服务']))
-tokenizer.evaluate(CTB6_CWS_TEST, save_dir=save_dir)
-print(f'Model saved in {save_dir}')
diff --git a/tests/train/zh/cws/train_ctb6_cws_bert.py b/tests/train/zh/cws/train_ctb6_cws_bert.py
deleted file mode 100644
index 382b30d88..000000000
--- a/tests/train/zh/cws/train_ctb6_cws_bert.py
+++ /dev/null
@@ -1,19 +0,0 @@
-# -*- coding:utf-8 -*-
-# Author: hankcs
-# Date: 2019-12-28 22:22
-
-from hanlp.components.tok import TransformerTokenizer
-from hanlp.datasets.cws.ctb import CTB6_CWS_TRAIN, CTB6_CWS_VALID, CTB6_CWS_TEST
-from tests import cdroot
-
-cdroot()
-tokenizer = TransformerTokenizer()
-save_dir = 'data/model/ctb6_cws_bert_base'
-tokenizer.fit(CTB6_CWS_TRAIN, CTB6_CWS_VALID, save_dir, transformer='chinese_L-12_H-768_A-12',
- max_seq_length=150,
- epochs=10,
- metrics='f1')
-tokenizer.load(save_dir)
-print(tokenizer.predict(['中央民族乐团离开北京前往维也纳', '商品和服务']))
-tokenizer.evaluate(CTB6_CWS_TEST, save_dir=save_dir)
-print(f'Model saved in {save_dir}')
diff --git a/tests/train/zh/train_ctb5_dep.py b/tests/train/zh/train_ctb5_dep.py
deleted file mode 100644
index 52c720fd3..000000000
--- a/tests/train/zh/train_ctb5_dep.py
+++ /dev/null
@@ -1,27 +0,0 @@
-# -*- coding:utf-8 -*-
-# Author: hankcs
-# Date: 2019-12-28 18:33
-from hanlp.components.parsers.biaffine_parser import BiaffineDependencyParser
-from hanlp.datasets.parsing.ctb import CTB5_DEP_TRAIN, CTB5_DEP_VALID, CTB5_DEP_TEST
-from hanlp.pretrained.word2vec import CTB5_FASTTEXT_300_CN
-from tests import cdroot
-
-cdroot()
-save_dir = 'data/model/dep/biaffine_ctb'
-parser = BiaffineDependencyParser()
-parser.fit(CTB5_DEP_TRAIN, CTB5_DEP_VALID, save_dir,
- pretrained_embed={'class_name': 'HanLP>Word2VecEmbedding',
- 'config': {
- 'trainable': False,
- 'embeddings_initializer': 'zero',
- 'filepath': CTB5_FASTTEXT_300_CN,
- 'expand_vocab': True,
- 'lowercase': True,
- 'normalize': True,
- }},
- )
-parser.load(save_dir)
-sentence = [('中国', 'NR'), ('批准', 'VV'), ('设立', 'VV'), ('外商', 'NN'), ('投资', 'NN'), ('企业', 'NN'), ('逾', 'VV'),
- ('三十万', 'CD'), ('家', 'M')]
-print(parser.predict(sentence))
-parser.evaluate(CTB5_DEP_TEST, save_dir)
diff --git a/tests/train/zh/train_ctb7_dep.py b/tests/train/zh/train_ctb7_dep.py
deleted file mode 100644
index 9dd71f744..000000000
--- a/tests/train/zh/train_ctb7_dep.py
+++ /dev/null
@@ -1,26 +0,0 @@
-# -*- coding:utf-8 -*-
-# Author: hankcs
-# Date: 2019-12-28 18:33
-from hanlp.components.parsers.biaffine_parser import BiaffineDependencyParser
-from hanlp.datasets.parsing.ctb import CTB7_DEP_TRAIN, CTB7_DEP_VALID, CTB7_DEP_TEST, CIP_W2V_100_CN
-from tests import cdroot
-
-cdroot()
-save_dir = 'data/model/dep/biaffine_ctb7'
-parser = BiaffineDependencyParser()
-parser.fit(CTB7_DEP_TRAIN, CTB7_DEP_VALID, save_dir,
- pretrained_embed={'class_name': 'HanLP>Word2VecEmbedding',
- 'config': {
- 'trainable': False,
- 'embeddings_initializer': 'zero',
- 'filepath': CIP_W2V_100_CN,
- 'expand_vocab': True,
- 'lowercase': True,
- 'normalize': True,
- }},
- )
-parser.load(save_dir)
-sentence = [('中国', 'NR'), ('批准', 'VV'), ('设立', 'VV'), ('外商', 'NN'), ('投资', 'NN'), ('企业', 'NN'), ('逾', 'VV'),
- ('三十万', 'CD'), ('家', 'M')]
-print(parser.predict(sentence))
-parser.evaluate(CTB7_DEP_TEST, save_dir)
diff --git a/tests/train/zh/train_msra_ner_rnn.py b/tests/train/zh/train_msra_ner_rnn.py
deleted file mode 100644
index f125f92b0..000000000
--- a/tests/train/zh/train_msra_ner_rnn.py
+++ /dev/null
@@ -1,16 +0,0 @@
-# -*- coding:utf-8 -*-
-# Author: hankcs
-# Date: 2019-12-28 23:15
-from hanlp.components.ner import RNNNamedEntityRecognizer
-from hanlp.datasets.ner.msra import MSRA_NER_TRAIN, MSRA_NER_VALID, MSRA_NER_TEST
-from hanlp.pretrained.word2vec import RADICAL_CHAR_EMBEDDING_100
-from tests import cdroot
-
-cdroot()
-recognizer = RNNNamedEntityRecognizer()
-save_dir = 'data/model/ner/msra_ner_rnn'
-recognizer.fit(MSRA_NER_TRAIN, MSRA_NER_VALID, save_dir,
- embeddings=RADICAL_CHAR_EMBEDDING_100,
- embedding_trainable=True,
- epochs=100)
-recognizer.evaluate(MSRA_NER_TEST, save_dir)