Skip to content

Commit

Permalink
Merge pull request #744 from SGSSGene/patch/improved_template_export
Browse files Browse the repository at this point in the history
Improved c++ code generation when template are involved
  • Loading branch information
mergify[bot] authored Sep 27, 2023
2 parents f098da4 + 60818b8 commit 6356268
Showing 1 changed file with 104 additions and 14 deletions.
118 changes: 104 additions & 14 deletions schema_salad/cpp_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,39 +58,55 @@ def safename2(name: Dict[str, str]) -> str:
return safename(name["namespace"]) + "::" + safename(name["classname"])


# Splits names like https://xyz.xyz/blub#cwl/class
# into its class path and non class path
def split_name(s: str) -> Tuple[str, str]:
"""Split url name into its components.
Splits names like https://xyz.xyz/blub#cwl/class
into its class path and non class path
"""
t = s.split("#")
if len(t) != 2:
raise ValueError("Expected field to be formatted as 'https://xyz.xyz/blub#cwl/class'.")
return (t[0], t[1])


# similar to split_name but for field names
def split_field(s: str) -> Tuple[str, str, str]:
"""Split field into its components.
similar to split_name but for field names
"""
(namespace, field) = split_name(s)
t = field.split("/")
if len(t) != 2:
raise ValueError("Expected field to be formatted as 'https://xyz.xyz/blub#cwl/class'.")
return (namespace, t[0], t[1])


# Prototype of a class
class ClassDefinition:
"""Prototype of a class."""

def __init__(self, name: str):
"""Initialize the class definition with a name."""
self.fullName = name
self.extends: List[Dict[str, str]] = []

# List of types from parent classes that have been specialized
self.specializationTypes: List[str] = []

# this includes fields that are also inheritant
self.allfields: List[FieldDefinition] = []
self.fields: List[FieldDefinition] = []
self.abstract = False
(self.namespace, self.classname) = split_name(name)
self.namespace = safename(self.namespace)
self.classname = safename(self.classname)

def writeFwdDeclaration(self, target: IO[str], fullInd: str, ind: str) -> None:
"""Write forward declaration."""
target.write(f"{fullInd}namespace {self.namespace} {{ struct {self.classname}; }}\n")

def writeDefinition(self, target: IO[Any], fullInd: str, ind: str) -> None:
"""Write definition of the class."""
target.write(f"{fullInd}namespace {self.namespace} {{\n")
target.write(f"{fullInd}struct {self.classname}")
extends = list(map(safename2, self.extends))
Expand All @@ -113,6 +129,7 @@ def writeDefinition(self, target: IO[Any], fullInd: str, ind: str) -> None:
target.write(f"{fullInd}}}\n\n")

def writeImplDefinition(self, target: IO[str], fullInd: str, ind: str) -> None:
"""Write definition with implementation."""
extends = list(map(safename2, self.extends))

if self.abstract:
Expand All @@ -131,20 +148,33 @@ def writeImplDefinition(self, target: IO[str], fullInd: str, ind: str) -> None:

for field in self.fields:
fieldname = safename(field.name)
target.write(
f'{fullInd}{ind}addYamlField(n, "{field.name}", toYaml(*{fieldname}));\n' # noqa: B907
)
if field.remap != "":
target.write(
f"""{fullInd}{ind}addYamlField(n, "{field.name}",
convertListToMap(toYaml(*{fieldname}), "{field.remap}"));\n""" # noqa: B907
)
else:
target.write(
f'{fullInd}{ind}addYamlField(n, "{field.name}", toYaml(*{fieldname}));\n' # noqa: B907
)
# target.write(f"{fullInd}{ind}addYamlIfNotEmpty(n, \"{field.name}\", toYaml(*{fieldname}));\n")

target.write(f"{fullInd}{ind}return n;\n{fullInd}}}\n")


# Prototype of a single field of a class
class FieldDefinition:
def __init__(self, name: str, typeStr: str, optional: bool):
"""Prototype of a single field from a class definition."""

def __init__(self, name: str, typeStr: str, optional: bool, remap: str):
"""Initialize field definition.
Creates a new field with name, its type, optional and which field to use to convert
from list to map (or empty if it is not possible)
"""
self.name = name
self.typeStr = typeStr
self.optional = optional
self.remap = remap

def writeDefinition(self, target: IO[Any], fullInd: str, ind: str, namespace: str) -> None:
"""Write a C++ definition for the class field."""
Expand All @@ -153,13 +183,16 @@ def writeDefinition(self, target: IO[Any], fullInd: str, ind: str, namespace: st
target.write(f"{fullInd}heap_object<{typeStr}> {name};\n")


# Prototype of an enum definition
class EnumDefinition:
"""Prototype of a enum."""

def __init__(self, name: str, values: List[str]):
"""Initialize enum definition with a name and possible values."""
self.name = name
self.values = values

def writeDefinition(self, target: IO[str], ind: str) -> None:
"""Write enum definition to output."""
namespace = ""
if len(self.name.split("#")) == 2:
(namespace, classname) = split_name(self.name)
Expand Down Expand Up @@ -201,12 +234,14 @@ def writeDefinition(self, target: IO[str], ind: str) -> None:

# !TODO way tot many functions, most of these shouldn't exists
def isPrimitiveType(v: Any) -> bool:
"""Check if v is a primitve type."""
if not isinstance(v, str):
return False
return v in ["null", "boolean", "int", "long", "float", "double", "string"]


def hasFieldValue(e: Any, f: str, v: Any) -> bool:
"""Check if e has a field f value."""
if not isinstance(e, dict):
return False
if f not in e:
Expand All @@ -215,10 +250,12 @@ def hasFieldValue(e: Any, f: str, v: Any) -> bool:


def isRecordSchema(v: Any) -> bool:
"""Check if v is of type record schema."""
return hasFieldValue(v, "type", "record")


def isEnumSchema(v: Any) -> bool:
"""Check if v is of type enum schema."""
if not hasFieldValue(v, "type", "enum"):
return False
if "symbols" not in v:
Expand All @@ -229,6 +266,7 @@ def isEnumSchema(v: Any) -> bool:


def isArray(v: Any) -> bool:
"""Check if v is of type array."""
if not isinstance(v, list):
return False
for i in v:
Expand All @@ -238,6 +276,7 @@ def isArray(v: Any) -> bool:


def pred(i: Any) -> bool:
"""Check if v is any of the simple types."""
return (
isPrimitiveType(i)
or isRecordSchema(i)
Expand All @@ -248,6 +287,7 @@ def pred(i: Any) -> bool:


def isArraySchema(v: Any) -> bool:
"""Check if v is of type array schema."""
if not hasFieldValue(v, "type", "array"):
return False
if "items" not in v:
Expand All @@ -272,6 +312,7 @@ def __init__(
package: str,
copyright: Optional[str],
) -> None:
"""Initialize the C++ code generator."""
super().__init__()
self.base_uri = base
self.target = target
Expand Down Expand Up @@ -376,8 +417,8 @@ def convertTypeToCpp(self, type_declaration: Union[List[Any], Dict[str, Any], st
type_declaration = ", ".join(type_declaration)
return f"std::variant<{type_declaration}>"

# start of our generated file
def epilogue(self, root_loader: Optional[TypeDef]) -> None:
"""Generate final part of our cpp file."""
self.target.write(
"""#pragma once
Expand Down Expand Up @@ -428,12 +469,23 @@ def epilogue(self, root_loader: Optional[TypeDef]) -> None:
return YAML::Node{v};
}
inline void addYamlField(YAML::Node node, std::string const& key, YAML::Node value) {
inline void addYamlField(YAML::Node& node, std::string const& key, YAML::Node value) {
if (value.IsDefined()) {
node[key] = value;
}
}
inline auto convertListToMap(YAML::Node list, std::string const& key_name) {
if (list.size() == 0) return list;
auto map = YAML::Node{};
for (YAML::Node n : list) {
auto key = n[key_name].as<std::string>();
n.remove(key_name);
map[key] = n;
}
return map;
}
// fwd declaring toYaml
template <typename T>
auto toYaml(std::vector<T> const& v) -> YAML::Node;
Expand Down Expand Up @@ -505,6 +557,27 @@ class heap_object {
for key in self.classDefinitions:
self.classDefinitions[key].writeFwdDeclaration(self.target, "", " ")

# remove parent classes, that are specialized/templated versions
for key in self.classDefinitions:
if len(self.classDefinitions[key].specializationTypes) > 0:
self.classDefinitions[key].extends = []

# remove fields that are available in a parent class
for key in self.classDefinitions:
for field in self.classDefinitions[key].allfields:
found = False
for parent_key in self.classDefinitions[key].extends:
fullKey = parent_key["namespace"] + "#" + parent_key["classname"]
for f in self.classDefinitions[fullKey].allfields:
if f.name == field.name:
found = True
break
if found:
break

if not found:
self.classDefinitions[key].fields.append(field)

for key in self.enumDefinitions:
self.enumDefinitions[key].writeDefinition(self.target, " ")
for key in self.classDefinitions:
Expand Down Expand Up @@ -542,7 +615,13 @@ class heap_object {
)

def parseRecordField(self, field: Dict[str, Any]) -> FieldDefinition:
"""Parse a record field."""
(namespace, classname, fieldname) = split_field(field["name"])
remap = ""
if "jsonldPredicate" in field:
if "mapSubject" in field["jsonldPredicate"]:
remap = field["jsonldPredicate"]["mapSubject"]

if isinstance(field["type"], dict):
if field["type"]["type"] == "enum":
fieldtype = "Enum"
Expand All @@ -553,9 +632,10 @@ def parseRecordField(self, field: Dict[str, Any]) -> FieldDefinition:
fieldtype = field["type"]
fieldtype = self.convertTypeToCpp(fieldtype)

return FieldDefinition(name=fieldname, typeStr=fieldtype, optional=False)
return FieldDefinition(name=fieldname, typeStr=fieldtype, optional=False, remap=remap)

def parseRecordSchema(self, stype: Dict[str, Any]) -> None:
"""Parse a record schema."""
cd = ClassDefinition(name=stype["name"])
cd.abstract = stype.get("abstract", False)

Expand All @@ -565,13 +645,18 @@ def parseRecordSchema(self, stype: Dict[str, Any]) -> None:
ext = {"namespace": base_namespace, "classname": base_classname}
cd.extends.append(ext)

if "specialize" in stype:
for e in aslist(stype["specialize"]):
cd.specializationTypes.append(e["specializeFrom"])

if "fields" in stype:
for field in stype["fields"]:
cd.fields.append(self.parseRecordField(field))
cd.allfields.append(self.parseRecordField(field))

self.classDefinitions[stype["name"]] = cd

def parseEnum(self, stype: Dict[str, Any]) -> str:
"""Parse a schema salad enum."""
name = cast(str, stype["name"])
if name not in self.enumDefinitions:
self.enumDefinitions[name] = EnumDefinition(
Expand All @@ -580,6 +665,11 @@ def parseEnum(self, stype: Dict[str, Any]) -> str:
return name

def parse(self, items: List[Dict[str, Any]]) -> None:
"""Parse sechema salad items.
This function is being called from the outside and drives
the whole code generation.
"""
for stype in items:
if "type" in stype and stype["type"] == "documentation":
continue
Expand Down

0 comments on commit 6356268

Please sign in to comment.