-
Notifications
You must be signed in to change notification settings - Fork 12
/
fusion_options.py
120 lines (102 loc) · 4.89 KB
/
fusion_options.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#-------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#--------------------------------------------------------------------------
from argparse import ArgumentParser
class AttentionMaskFormat:
    """Integer codes for the attention mask format chosen through FusionOptions.

    ``FusionOptions.use_raw_attention_mask`` selects ``AttentionMask`` (the raw
    attention mask, which is also the FusionOptions default) or ``MaskIndexEnd``
    (mask index); ``FusionOptions.disable_attention_mask`` selects ``NoMask``.
    """
    # NOTE(review): MaskIndexEndAndStart is never selected in this file; presumably
    # a mask index that carries both end and start positions -- confirm before relying on it.
    MaskIndexEnd, MaskIndexEndAndStart, AttentionMask, NoMask = range(4)
class FusionOptions:
    """Switches that decide which operator fusions graph optimization applies.

    Attributes set by ``__init__``:
        enable_gelu, enable_layer_norm, enable_attention, enable_skip_layer_norm,
        enable_embed_layer_norm, enable_bias_skip_layer_norm, enable_bias_gelu:
            one switch per fusion; all on by default, except that
            SkipLayerNormalization fusion is off for model_type 'gpt2'.
        enable_gelu_approximation: Gelu/BiasGelu to FastGelu conversion; off by default.
        attention_mask_format: an ``AttentionMaskFormat`` value; the raw attention
            mask (``AttentionMaskFormat.AttentionMask``) by default.
    """

    def __init__(self, model_type):
        """Set every switch to its default for ``model_type``.

        Args:
            model_type: model family name. Only ``'gpt2'`` changes a default
                (it turns SkipLayerNormalization fusion off).
        """
        self.enable_gelu = True
        self.enable_layer_norm = True
        self.enable_attention = True
        # gpt2 opts out of SkipLayerNormalization fusion; every other model keeps it on.
        self.enable_skip_layer_norm = model_type != 'gpt2'
        self.enable_embed_layer_norm = True
        self.enable_bias_skip_layer_norm = True
        self.enable_bias_gelu = True
        self.enable_gelu_approximation = False
        self.attention_mask_format = AttentionMaskFormat.AttentionMask

    def use_raw_attention_mask(self, use_raw_mask=True):
        """Choose which mask input the fused Attention operator uses.

        Args:
            use_raw_mask: truthy selects the raw attention mask
                (``AttentionMaskFormat.AttentionMask``); falsy selects the mask
                index (``AttentionMaskFormat.MaskIndexEnd``).
        """
        self.attention_mask_format = (AttentionMaskFormat.AttentionMask
                                      if use_raw_mask else AttentionMaskFormat.MaskIndexEnd)

    def disable_attention_mask(self):
        """Make the fused Attention operator use no attention mask at all."""
        self.attention_mask_format = AttentionMaskFormat.NoMask

    @staticmethod
    def parse(args):
        """Build a FusionOptions from parsed command-line arguments.

        Args:
            args: namespace carrying the flags registered by ``add_arguments``,
                plus a ``model_type`` attribute that ``add_arguments`` does not
                register (the caller's parser must supply it).

        Returns:
            A new FusionOptions. Flags can only switch fusions off, or switch Gelu
            approximation on. They never re-enable a default that ``__init__`` turned
            off. If both ``use_mask_index`` and ``no_attention_mask`` are set,
            ``no_attention_mask`` wins because it is applied last.
        """
        options = FusionOptions(args.model_type)

        # Each --disable_<fusion> flag clears the matching enable_<fusion> switch.
        for fusion in ('gelu', 'layer_norm', 'attention', 'skip_layer_norm',
                       'embed_layer_norm', 'bias_skip_layer_norm', 'bias_gelu'):
            if getattr(args, 'disable_' + fusion):
                setattr(options, 'enable_' + fusion, False)

        if args.enable_gelu_approximation:
            options.enable_gelu_approximation = True

        # Order matters: no_attention_mask must override use_mask_index.
        if args.use_mask_index:
            options.use_raw_attention_mask(False)
        if args.no_attention_mask:
            options.disable_attention_mask()

        return options

    @staticmethod
    def add_arguments(parser: ArgumentParser):
        """Register the fusion command-line switches on ``parser``.

        Each switch is an optional ``store_true`` flag whose destination defaults
        to False. The switches are added in the order listed below, which is also
        the order they appear in ``--help``.
        """
        switches = (
            ('disable_attention', "disable Attention fusion"),
            ('disable_skip_layer_norm', "disable SkipLayerNormalization fusion"),
            ('disable_embed_layer_norm', "disable EmbedLayerNormalization fusion"),
            ('disable_bias_skip_layer_norm', "disable Add Bias and SkipLayerNormalization fusion"),
            ('disable_bias_gelu', "disable Add Bias and Gelu/FastGelu fusion"),
            ('disable_layer_norm', "disable LayerNormalization fusion"),
            ('disable_gelu', "disable Gelu fusion"),
            ('enable_gelu_approximation', "enable Gelu/BiasGelu to FastGelu conversion"),
            ('use_mask_index', "use mask index instead of raw attention mask in attention operator"),
            ('no_attention_mask', "no attention mask. Only works for model_type=bert"),
        )
        for dest, help_text in switches:
            parser.add_argument('--' + dest, action='store_true', help=help_text)
            parser.set_defaults(**{dest: False})