Skip to content

Commit

Permalink
[sai-gen] Deprecate name attribute for all match keys and action para…
Browse files Browse the repository at this point in the history
…meters, remove type guessing heuristics on object parent names. (#480)
  • Loading branch information
r12f authored Dec 13, 2023
1 parent 01ac82e commit 0841490
Show file tree
Hide file tree
Showing 9 changed files with 132 additions and 139 deletions.
128 changes: 60 additions & 68 deletions dash-pipeline/SAI/sai_api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,9 @@ def get_sai_type(sai_type):
return SAITypeSolver.sai_type_info_registry[sai_type]

@staticmethod
def get_object_sai_type(object_size, object_parent_name, object_name):
def get_object_sai_type(object_size, object_name):
object_name = object_name.lower()

sai_type_name = ""

if object_size == 1:
Expand All @@ -136,13 +138,13 @@ def get_object_sai_type(object_size, object_parent_name, object_name):
sai_type_name = 'sai_uint16_t'
elif object_size == 32 and ('ip_addr_family' in object_name):
sai_type_name = 'sai_ip_addr_family_t'
elif object_size == 32 and ('addr' in object_name or 'ip' in object_parent_name):
elif object_size == 32 and ('ip' in object_name):
sai_type_name = 'sai_ip_address_t'
elif object_size == 32 and ('_id' in object_name):
sai_type_name = 'sai_object_id_t'
elif object_size <= 32:
sai_type_name = 'sai_uint32_t'
elif object_size == 48 and ('addr' in object_name or 'mac' in object_parent_name):
elif object_size == 48 and ('mac' in object_name):
sai_type_name = 'sai_mac_t'
elif object_size <= 64:
sai_type_name = 'sai_uint64_t'
Expand All @@ -154,60 +156,55 @@ def get_object_sai_type(object_size, object_parent_name, object_name):
return SAITypeSolver.get_sai_type(sai_type_name)

@staticmethod
def get_match_key_sai_type(match_type, key_size, key_parent_name, key_name):
def get_match_key_sai_type(match_type, key_size, key_name):
if match_type == 'exact' or match_type == 'optional' or match_type == 'ternary':
return SAITypeSolver.get_object_sai_type(key_size, key_parent_name, key_name)
return SAITypeSolver.get_object_sai_type(key_size, key_name)
elif match_type == 'lpm':
return SAITypeSolver.__get_lpm_match_key_sai_type(key_size, key_parent_name, key_name)
return SAITypeSolver.__get_lpm_match_key_sai_type(key_size)
elif match_type == 'list':
return SAITypeSolver.__get_list_match_key_sai_type(key_size, key_parent_name, key_name)
return SAITypeSolver.__get_list_match_key_sai_type(key_size)
elif match_type == 'range_list':
return SAITypeSolver.__get_range_list_sai_type(key_size, key_parent_name, key_name)
return SAITypeSolver.__get_range_list_sai_type(key_size)
else:
raise ValueError(f"match_type={match_type} is not supported")

@staticmethod
def __get_lpm_match_key_sai_type(key_size, key_parent_name, key_name):
def __get_lpm_match_key_sai_type(key_size):
sai_type_name = ""

if key_size == 32 and ('addr' in key_name or 'ip' in key_parent_name):
# LPM match key should always be converted into IP prefix.
if key_size == 32:
sai_type_name = 'sai_ip_prefix_t'
elif key_size == 128 and ('addr' in key_name or 'ip' in key_parent_name):
elif key_size == 128:
sai_type_name = 'sai_ip_prefix_t'
else:
raise ValueError(f'key_size={key_size}, key_header={key_parent_name}, and key_field={key_name} is not supported')
raise ValueError(f'key_size={key_size} is not supported')

return SAITypeSolver.get_sai_type(sai_type_name)

@staticmethod
def __get_list_match_key_sai_type(key_size, key_header, key_field):
def __get_list_match_key_sai_type(key_size):
sai_type_name = ""

if key_size <= 8:
sai_type_name = 'sai_u8_list_t'
elif key_size <= 16:
sai_type_name = 'sai_u16_list_t'
elif key_size == 32 and ('addr' in key_field or 'ip' in key_header):
sai_type_name = 'sai_ip_prefix_list_t'
elif key_size <= 32:
sai_type_name = 'sai_u32_list_t'
elif key_size == 128 and ('addr' in key_field or 'ip' in key_header):
sai_type_name = 'sai_ip_prefix_list_t'
else:
raise ValueError(f'key_size={key_size} is not supported')

return SAITypeSolver.get_sai_type(sai_type_name)

@staticmethod
def __get_range_list_sai_type(key_size, key_header, key_field):
def __get_range_list_sai_type(key_size):
sai_type_name = ""

if key_size <= 8:
sai_type_name = 'sai_u8_range_list_t'
elif key_size <= 16:
sai_type_name = 'sai_u16_range_list_t'
elif key_size == 32 and ('addr' in key_field or 'ip' in key_header):
sai_type_name = 'sai_ipaddr_range_list_t'
elif key_size <= 32:
sai_type_name = 'sai_u32_range_list_t'
elif key_size <= 64:
Expand Down Expand Up @@ -313,7 +310,9 @@ def _parse_sai_object_annotation(self, p4rt_anno_list):
for anno in p4rt_anno_list[STRUCTURED_ANNOTATIONS_TAG]:
if anno[NAME_TAG] == SAI_VAL_TAG:
for kv in anno[KV_PAIR_LIST_TAG][KV_PAIRS_TAG]:
if kv['key'] == 'type':
if kv['key'] == 'name':
self.name = kv['value']['stringValue']
elif kv['key'] == 'type':
self.type = kv['value']['stringValue']
elif kv['key'] == 'default_value': # "default" is a reserved keyword and cannot be used.
self.default = kv['value']['stringValue']
Expand All @@ -328,10 +327,16 @@ def _parse_sai_object_annotation(self, p4rt_anno_list):
else:
raise ValueError("Unknown attr annotation " + kv['key'])

sai_type_info = SAITypeSolver.get_sai_type(self.type)
self.field = sai_type_info.field_func_prefix
if self.default == None and sai_type_info.is_enum:
self.default = sai_type_info.default
def _link_ip_is_v6_vars(self, vars):
# Link *_is_v6 var to its corresponding var.
ip_is_v6_key_ids = {v.name.replace("_is_v6", ""): v.id for v in vars if '_is_v6' in v.name}

for v in vars:
if v.name in ip_is_v6_key_ids:
v.ip_is_v6_field_id = ip_is_v6_key_ids[v.name]

# Delete all vars with *_is_v6 in their names.
return [v for v in vars if '_is_v6' not in v.name]


@sai_parser_from_p4rt
Expand Down Expand Up @@ -400,12 +405,11 @@ class SAIAPITableKey(SAIObject):
'''
def __init__(self):
super().__init__()
self.sai_key_name = ""
self.match_type = ""
self.bitwidth = 0
self.ip_is_v6_field_id = 0

def parse_p4rt(self, p4rt_table_key, ip_is_v6_key_ids):
def parse_p4rt(self, p4rt_table_key):
'''
This method parses the P4Runtime table key object and populates the SAI API table key object.
Expand All @@ -427,12 +431,10 @@ def parse_p4rt(self, p4rt_table_key, ip_is_v6_key_ids):

self.id = p4rt_table_key['id']
self.name = p4rt_table_key[NAME_TAG]
#print("Parsing table key: " + self.name)

full_key_name, self.sai_key_name, _ = self.parse_sai_annotated_name(self.name, full_name_part_start = -2)
key_header, key_field = full_key_name.split('.')

self.bitwidth = p4rt_table_key[BITWIDTH_TAG]
# print("Parsing table key: " + self.name)

_, self.name, _ = self.parse_sai_annotated_name(self.name, full_name_part_start = -2)

if OTHER_MATCH_TYPE_TAG in p4rt_table_key:
self.match_type = p4rt_table_key[OTHER_MATCH_TYPE_TAG].lower()
Expand All @@ -443,15 +445,17 @@ def parse_p4rt(self, p4rt_table_key, ip_is_v6_key_ids):

if STRUCTURED_ANNOTATIONS_TAG in p4rt_table_key:
self._parse_sai_object_annotation(p4rt_table_key)

# If type is specified, use it. Otherwise, try to find the proper type using default heuristics.
if self.type != None:
sai_type_info = SAITypeSolver.get_sai_type(self.type)
else:
sai_type_info = SAITypeSolver.get_match_key_sai_type(self.match_type, self.bitwidth, key_header, key_field)
sai_type_info = SAITypeSolver.get_match_key_sai_type(self.match_type, self.bitwidth, self.name)
self.type = sai_type_info.name
self.field = sai_type_info.field_func_prefix

# If *_is_v6 key is present, save its id.
ip_is_v6_key_name = self.sai_key_name + "_is_v6"
if ip_is_v6_key_name in ip_is_v6_key_ids:
self.ip_is_v6_field_id = ip_is_v6_key_ids[ip_is_v6_key_name]
self.field = sai_type_info.field_func_prefix
if self.default == None and sai_type_info.is_enum:
self.default = sai_type_info.default

return

Expand Down Expand Up @@ -489,18 +493,13 @@ def parse_action_params(self, p4rt_table_action, sai_enums):
if PARAMS_TAG not in p4rt_table_action:
return

# Save all *_is_v6 param ids.
ip_is_v6_param_ids = dict()
for p4rt_table_action_param in p4rt_table_action[PARAMS_TAG]:
if '_is_v6' in p4rt_table_action_param[NAME_TAG]:
ip_is_v6_param_name = p4rt_table_action_param[NAME_TAG]
ip_is_v6_param_ids[ip_is_v6_param_name] = p4rt_table_action_param['id']

# Parse all params.
for p in p4rt_table_action[PARAMS_TAG]:
param = SAIAPITableActionParam.from_p4rt(p, sai_enums = sai_enums, ip_is_v6_param_ids = ip_is_v6_param_ids)
param = SAIAPITableActionParam.from_p4rt(p)
self.params.append(param)

self.params = self._link_ip_is_v6_vars(self.params)

return


Expand All @@ -512,7 +511,7 @@ def __init__(self):
self.ip_is_v6_field_id = 0
self.param_actions = []

def parse_p4rt(self, p4rt_table_action_param, sai_enums, ip_is_v6_param_ids):
def parse_p4rt(self, p4rt_table_action_param):
'''
This method parses the P4Runtime table action object and populates the SAI API table action object.
Expand All @@ -527,16 +526,17 @@ def parse_p4rt(self, p4rt_table_action_param, sai_enums, ip_is_v6_param_ids):

if STRUCTURED_ANNOTATIONS_TAG in p4rt_table_action_param:
self._parse_sai_object_annotation(p4rt_table_action_param)

# If type is specified, use it. Otherwise, try to find the proper type using default heuristics.
if self.type != None:
sai_type_info = SAITypeSolver.get_sai_type(self.type)
else:
sai_type_info = SAITypeSolver.get_object_sai_type(self.bitwidth, self.name, self.name)
self.type, self.field = sai_type_info.name, sai_type_info.field_func_prefix
if sai_type_info.is_enum:
self.default = sai_type_info.default
sai_type_info = SAITypeSolver.get_object_sai_type(self.bitwidth, self.name)
self.type = sai_type_info.name

# If *_is_v6 key is present, save its id.
ip_is_v6_param_name = self.name + "_is_v6"
if ip_is_v6_param_name in ip_is_v6_param_ids:
self.ip_is_v6_field_id = ip_is_v6_param_ids[ip_is_v6_param_name]
self.field = sai_type_info.field_func_prefix
if self.default == None and sai_type_info.is_enum:
self.default = sai_type_info.default

return

Expand Down Expand Up @@ -610,7 +610,7 @@ def parse_p4rt(self, p4rt_table, program, all_actions, ignore_tables):
self.__parse_table_actions(p4rt_table, all_actions)

if self.is_object == None:
if len(self.keys) == 1 and self.keys[0].sai_key_name.endswith(self.name.split('.')[-1] + '_id'):
if len(self.keys) == 1 and self.keys[0].name.endswith(self.name.split('.')[-1] + '_id'):
self.is_object = 'true'
elif len(self.keys) > 5:
self.is_object = 'true'
Expand Down Expand Up @@ -651,20 +651,12 @@ def __table_with_counters(self, program):
return 'false'

def __parse_table_keys(self, p4rt_table):
ip_is_v6_key_ids = dict()
for p4rt_table_key in p4rt_table[MATCH_FIELDS_TAG]:
if '_is_v6' in p4rt_table_key[NAME_TAG]:
_, ip_is_v6_key_name, _ = self.parse_sai_annotated_name(p4rt_table_key[NAME_TAG])
ip_is_v6_key_ids[ip_is_v6_key_name] = p4rt_table_key['id']

for p4rt_table_key in p4rt_table[MATCH_FIELDS_TAG]:
# Skip all *_is_v6 keys, as they will be linked via table key property.
if '_is_v6' in p4rt_table_key[NAME_TAG]:
continue

table_key = SAIAPITableKey.from_p4rt(p4rt_table_key, ip_is_v6_key_ids)
table_key = SAIAPITableKey.from_p4rt(p4rt_table_key)
self.keys.append(table_key)

self.keys = self._link_ip_is_v6_vars(self.keys)

for p4rt_table_key in self.keys:
if (p4rt_table_key.match_type == 'exact' and p4rt_table_key.type == 'sai_ip_address_t') or \
(p4rt_table_key.match_type == 'ternary' and p4rt_table_key.type == 'sai_ip_address_t') or \
Expand Down Expand Up @@ -791,7 +783,7 @@ def __update_table_param_object_name_reference(self):
for key in table.keys:
if key.type != None:
if key.type == 'sai_object_id_t':
table_ref = key.sai_key_name[:-len("_id")]
table_ref = key.name[:-len("_id")]
for table_name in all_table_names:
if table_ref.endswith(table_name):
key.object_name = table_name
Expand Down
26 changes: 13 additions & 13 deletions dash-pipeline/SAI/templates/saiapi.cpp.j2
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ static sai_status_t dash_sai_create_{{ table.name }}(
switch(attr_list[i].id)
{
{% for key in table['keys'] %}
case SAI_{{ table.name | upper }}_ATTR_{{ key.sai_key_name | upper }}:
case SAI_{{ table.name | upper }}_ATTR_{{ key.name | upper }}:
{
auto mf = matchActionEntry->add_match();
mf->set_field_id({{key.id}});
Expand All @@ -102,8 +102,8 @@ static sai_status_t dash_sai_create_{{ table.name }}(
{% elif key.match_type == 'ternary' %}
auto mf_ternary = mf->mutable_ternary();
{{key.field}}SetVal(attr_list[i].value, mf_ternary, {{key.bitwidth}});
auto mask = getMaskAttr(SAI_{{ table.name | upper }}_ATTR_{{ key.sai_key_name | upper }}_MASK, attr_count, attr_list);
assert(mask && "SAI_{{ table.name | upper }}_ATTR_{{ key.sai_key_name | upper }}_MASK isn't provided");
auto mask = getMaskAttr(SAI_{{ table.name | upper }}_ATTR_{{ key.name | upper }}_MASK, attr_count, attr_list);
assert(mask && "SAI_{{ table.name | upper }}_ATTR_{{ key.name | upper }}_MASK isn't provided");
{{key.field}}SetMask(mask->value, mf_ternary, {{key.bitwidth}});
{% endif %}
{% if key.ip_is_v6_field_id != 0 %}
Expand Down Expand Up @@ -340,17 +340,17 @@ static sai_status_t dash_sai_create_{{ table.name }}(
mf->set_field_id({{key.id}});
{% if key.match_type == 'exact' %}
auto mf_exact = mf->mutable_exact();
//{{key.field}}SetVal(tableEntry->{{ key.sai_key_name | lower }}, mf_exact, {{key.bitwidth}});
//{{key.field}}SetVal(tableEntry->{{ key.name | lower }}, mf_exact, {{key.bitwidth}});
{% set keyfield = key.field %}
{% set bitwidth = key.bitwidth %}
{% if keyfield in ['ipaddr','mac'] or bitwidth in [24] %}
{{key.field}}SetVal(tableEntry->{{ key.sai_key_name | lower }}, mf_exact, {{key.bitwidth}});
{{key.field}}SetVal(tableEntry->{{ key.name | lower }}, mf_exact, {{key.bitwidth}});
{% else %}
{{key.field}}SetVal(static_cast<uint{{key.bitwidth}}_t>(tableEntry->{{ key.sai_key_name | lower }}), mf_exact, {{key.bitwidth}});
{{key.field}}SetVal(static_cast<uint{{key.bitwidth}}_t>(tableEntry->{{ key.name | lower }}), mf_exact, {{key.bitwidth}});
{% endif %}
{% elif key.match_type == 'lpm' %}
auto mf_lpm = mf->mutable_lpm();
{{key.field}}SetVal(tableEntry->{{ key.sai_key_name | lower }}, mf_lpm, {{key.bitwidth}});
{{key.field}}SetVal(tableEntry->{{ key.name | lower }}, mf_lpm, {{key.bitwidth}});
{% elif key.match_type == 'list' %}
assert(0 && "mutable_list is not supported");
goto ErrRet;
Expand All @@ -370,7 +370,7 @@ static sai_status_t dash_sai_create_{{ table.name }}(
auto mf = matchActionEntry->add_match();
mf->set_field_id({{key.ip_is_v6_field_id}});
auto mf_exact = mf->mutable_exact();
booldataSetVal((tableEntry->{{ key.sai_key_name | lower }}.addr_family == SAI_IP_ADDR_FAMILY_IPV4) ? 0 : 1, mf_exact, 1);
booldataSetVal((tableEntry->{{ key.name | lower }}.addr_family == SAI_IP_ADDR_FAMILY_IPV4) ? 0 : 1, mf_exact, 1);
}
{% endif %}
{% endfor %}
Expand Down Expand Up @@ -507,14 +507,14 @@ static sai_status_t dash_sai_remove_{{ table.name }}(
{% set keyfield = key.field %}
{% set bitwidth = key.bitwidth %}
{% if keyfield in ['ipaddr','mac'] or bitwidth in [24] %}
{{key.field}}SetVal(tableEntry->{{ key.sai_key_name | lower }}, mf_exact, {{key.bitwidth}});
{{key.field}}SetVal(tableEntry->{{ key.name | lower }}, mf_exact, {{key.bitwidth}});
{% else %}
{{key.field}}SetVal(static_cast<uint{{key.bitwidth}}_t>(tableEntry->{{ key.sai_key_name | lower }}), mf_exact, {{key.bitwidth}});
{{key.field}}SetVal(static_cast<uint{{key.bitwidth}}_t>(tableEntry->{{ key.name | lower }}), mf_exact, {{key.bitwidth}});
{% endif %}
//{{key.field}}SetVal(tableEntry->{{ key.sai_key_name | lower }}, mf_exact, {{key.bitwidth}});
//{{key.field}}SetVal(tableEntry->{{ key.name | lower }}, mf_exact, {{key.bitwidth}});
{% elif key.match_type == 'lpm' %}
auto mf_lpm = mf->mutable_lpm();
{{key.field}}SetVal(tableEntry->{{ key.sai_key_name | lower }}, mf_lpm, {{key.bitwidth}});
{{key.field}}SetVal(tableEntry->{{ key.name | lower }}, mf_lpm, {{key.bitwidth}});
{% elif key.match_type == 'list' %}
assert(0 && "mutable_list is not supported");
goto ErrRet;
Expand All @@ -533,7 +533,7 @@ static sai_status_t dash_sai_remove_{{ table.name }}(
auto mf = matchActionEntry->add_match();
mf->set_field_id({{key.ip_is_v6_field_id}});
auto mf_exact = mf->mutable_exact();
booldataSetVal((tableEntry->{{ key.sai_key_name | lower }}.addr_family == SAI_IP_ADDR_FAMILY_IPV4) ? 0 : 1, mf_exact, 1);
booldataSetVal((tableEntry->{{ key.name | lower }}.addr_family == SAI_IP_ADDR_FAMILY_IPV4) ? 0 : 1, mf_exact, 1);
}
{% endif %}
{% endfor %}
Expand Down
Loading

0 comments on commit 0841490

Please sign in to comment.