Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rust: introduce typed labels #17460

Merged
merged 5 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions misc/codegen/generators/rustgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def _get_type(t: str) -> str:
case "int":
return "usize"
case _ if t[0].isupper():
return "trap::Label"
return f"trap::Label<{t}>"
case "boolean":
assert False, "boolean unsupported"
case _:
Expand Down Expand Up @@ -57,6 +57,15 @@ def _get_properties(
yield cls, p


def _get_ancestors(
cls: schema.Class, lookup: dict[str, schema.Class]
) -> typing.Iterable[schema.Class]:
for b in cls.bases:
base = lookup[b]
yield base
yield from _get_ancestors(base, lookup)


class Processor:
def __init__(self, data: schema.Schema):
self._classmap = data.classes
Expand All @@ -69,14 +78,15 @@ def _get_class(self, name: str) -> rust.Class:
_get_field(c, p)
for c, p in _get_properties(cls, self._classmap)
if "rust_skip" not in p.pragmas and not p.synth
],
table_name=inflection.tableize(cls.name),
] if not cls.derived else [],
ancestors=sorted(set(a.name for a in _get_ancestors(cls, self._classmap))),
entry_table=inflection.tableize(cls.name) if not cls.derived else None,
)

def get_classes(self):
ret = {"": []}
for k, cls in self._classmap.items():
if not cls.synth and not cls.derived:
if not cls.synth:
ret.setdefault(cls.group, []).append(self._get_class(cls.name))
return ret

Expand Down
13 changes: 10 additions & 3 deletions misc/codegen/lib/rust.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,19 @@ def is_label(self):
@dataclasses.dataclass
class Class:
name: str
table_name: str
entry_table: str | None = None
fields: list[Field] = dataclasses.field(default_factory=list)
ancestors: list[str] = dataclasses.field(default_factory=list)

@property
def is_entry(self) -> bool:
return bool(self.entry_table)

@property
def single_field_entries(self):
ret = {self.table_name: []}
def single_field_entries(self) -> dict[str, list[dict]]:
ret = {}
if self.is_entry:
ret[self.entry_table] = []
for f in self.fields:
if f.is_single:
ret.setdefault(f.table_name, []).append(f)
Expand Down
47 changes: 35 additions & 12 deletions misc/codegen/templates/rust_classes.mustache
Original file line number Diff line number Diff line change
Expand Up @@ -2,53 +2,76 @@

#![cfg_attr(any(), rustfmt::skip)]

use crate::trap::{TrapId, TrapEntry};
use codeql_extractor::trap;
use crate::trap;
{{#classes}}

{{#is_entry}}
#[derive(Debug)]
pub struct {{name}} {
pub id: TrapId,
pub id: trap::TrapId<{{name}}>,
{{#fields}}
pub {{field_name}}: {{type}},
{{/fields}}
}

impl TrapEntry for {{name}} {
fn extract_id(&mut self) -> TrapId {
std::mem::replace(&mut self.id, TrapId::Star)
impl trap::TrapEntry for {{name}} {
fn extract_id(&mut self) -> trap::TrapId<Self> {
std::mem::replace(&mut self.id, trap::TrapId::Star)
}

fn emit(self, id: trap::Label, out: &mut trap::Writer) {
fn emit(self, id: trap::Label<Self>, out: &mut trap::Writer) {
{{#single_field_entries}}
out.add_tuple("{{table_name}}", vec![trap::Arg::Label(id){{#fields}}, self.{{field_name}}.into(){{/fields}}]);
out.add_tuple("{{entry_table}}", vec![id.into(){{#fields}}, self.{{field_name}}.into(){{/fields}}]);
{{/single_field_entries}}
{{#fields}}
{{#is_predicate}}
if self.{{field_name}} {
out.add_tuple("{{table_name}}", vec![trap::Arg::Label(id)]);
out.add_tuple("{{table_name}}", vec![id.into()]);
}
{{/is_predicate}}
{{#is_optional}}
{{^is_repeated}}
if let Some(v) = self.{{field_name}} {
out.add_tuple("{{table_name}}", vec![trap::Arg::Label(id), v.into()]);
out.add_tuple("{{table_name}}", vec![id.into(), v.into()]);
}
{{/is_repeated}}
{{/is_optional}}
{{#is_repeated}}
for (i, v) in self.{{field_name}}.into_iter().enumerate() {
{{^is_optional}}
out.add_tuple("{{table_name}}", vec![trap::Arg::Label(id){{^is_unordered}}, i.into(){{/is_unordered}}, v.into()]);
out.add_tuple("{{table_name}}", vec![id.into(){{^is_unordered}}, i.into(){{/is_unordered}}, v.into()]);
{{/is_optional}}
{{#is_optional}}
if let Some(v) = v {
out.add_tuple("{{table_name}}", vec![trap::Arg::Label(id){{^is_unordered}}, i.into(){{/is_unordered}}, v.into()]);
out.add_tuple("{{table_name}}", vec![id.into(){{^is_unordered}}, i.into(){{/is_unordered}}, v.into()]);
}
{{/is_optional}}
}
{{/is_repeated}}
{{/fields}}
}
}
{{/is_entry}}
{{^is_entry}}
{{! virtual class, make it unbuildable }}
#[derive(Debug)]
pub struct {{name}} {
_unused: ()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't this just be pub struct {{name}} { } ?

Copy link
Contributor Author

@redsun82 redsun82 Sep 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a pub struct X {} can be built by a user (e.g. let x = X{}), and I'd like to avoid that as we shouldn't be building instances of structs that are not leaves in the hierarchy. By adding one private field we disallow that.

}
{{/is_entry}}

impl trap::TrapClass for {{name}} {
fn class_name() -> &'static str { "{{name}}" }
}
{{#ancestors}}

impl From<trap::Label<{{name}}>> for trap::Label<{{.}}> {
fn from(value: trap::Label<{{name}}>) -> Self {
// SAFETY: this is safe because in the dbscheme {{name}} is a subclass of {{.}}
unsafe {
Self::from_untyped(value.as_untyped())
}
}
}
{{/ancestors}}
{{/classes}}
2 changes: 1 addition & 1 deletion rust/extractor/src/generated/.generated.list

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading