Skip to content

Commit

Permalink
info: Print readable function signatures
Browse files Browse the repository at this point in the history
  • Loading branch information
sharkdp committed Jul 20, 2024
1 parent 72efff4 commit 0e7fb00
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 62 deletions.
3 changes: 1 addition & 2 deletions numbat/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ use markup::Markup;
use module_importer::{ModuleImporter, NullImporter};
use prefix_transformer::Transformer;

use pretty_print::PrettyPrint;
use resolver::CodeSource;
use resolver::Resolver;
use resolver::ResolverError;
Expand Down Expand Up @@ -437,7 +436,7 @@ impl Context {

help += m::text("Signature: ")
+ m::space()
+ fn_signature.fn_type.pretty_print()
+ fn_signature.pretty_print(self.typechecker.registry())
+ m::nl();

if let Some(description) = &metadata.description {
Expand Down
49 changes: 47 additions & 2 deletions numbat/src/typechecker/environment.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use crate::ast::TypeParameterBound;
use crate::ast::{TypeAnnotation, TypeParameterBound};
use crate::dimension::DimensionRegistry;
use crate::pretty_print::PrettyPrint;
use crate::span::Span;
use crate::type_variable::TypeVariable;
use crate::typed_ast::pretty_print_function_signature;
use crate::Type;

use super::substitutions::{ApplySubstitution, Substitution, SubstitutionError};
Expand All @@ -12,13 +15,55 @@ type Identifier = String;

#[derive(Clone, Debug)]
pub struct FunctionSignature {
pub name: String,
pub definition_span: Span,
#[allow(dead_code)]
pub type_parameters: Vec<(Span, String, Option<TypeParameterBound>)>,
pub parameters: Vec<(Span, String)>,
pub parameters: Vec<(Span, String, Option<TypeAnnotation>)>,
pub return_type_annotation: Option<TypeAnnotation>,
pub fn_type: TypeScheme,
}

impl FunctionSignature {
pub fn pretty_print(&self, registry: &DimensionRegistry) -> crate::markup::Markup {
let (fn_type, type_parameters) = self.fn_type.instantiate_for_printing(Some(
self.type_parameters
.iter()
.map(|(_, name, _)| name.clone())
.collect(),
));

let Type::Fn(ref parameter_types, ref return_type) = fn_type.inner else {
unreachable!()
};

let parameters =
self.parameters
.iter()
.zip(parameter_types)
.map(|((_, name, annotation), type_)| {
let readable_type = match annotation {
Some(annotation) => annotation.pretty_print(),
None => type_.to_readable_type(registry),
};
(name.clone(), readable_type)
});

let readable_return_type = match &self.return_type_annotation {
Some(annotation) => annotation.pretty_print(),
None => return_type.to_readable_type(registry),
};

pretty_print_function_signature(
&self.name,
&fn_type,
&type_parameters,
parameters,
&readable_return_type,
)
}
}

#[derive(Clone, Debug)]
pub struct FunctionMetadata {
pub name: Option<String>,
Expand Down
6 changes: 5 additions & 1 deletion numbat/src/typechecker/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,11 @@ impl TypeChecker {
argument_types: Vec<Type>,
) -> Result<typed_ast::Expression> {
let FunctionSignature {
name: _,
definition_span,
type_parameters: _,
parameters,
return_type_annotation: _,
fn_type,
} = signature;

Expand Down Expand Up @@ -1377,7 +1379,7 @@ impl TypeChecker {

let parameters: Vec<_> = typed_parameters
.iter()
.map(|(span, name, _, _)| (*span, name.clone()))
.map(|(span, name, _, annotation)| (*span, name.clone(), (*annotation).clone()))
.collect();
let parameter_types = typed_parameters
.iter()
Expand All @@ -1390,9 +1392,11 @@ impl TypeChecker {
typechecker_fn.env.add_function(
function_name.clone(),
FunctionSignature {
name: function_name.clone(),
definition_span: *function_name_span,
type_parameters: type_parameters.clone(),
parameters,
return_type_annotation: return_type_annotation.clone(),
fn_type: fn_type.clone(),
},
FunctionMetadata {
Expand Down
30 changes: 20 additions & 10 deletions numbat/src/typechecker/type_scheme.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ impl TypeScheme {
}
}

pub fn instantiate_for_printing(&self) -> (QualifiedType, Vec<TypeVariable>) {
pub fn instantiate_for_printing(
&self,
type_parameters: Option<Vec<String>>,
) -> (QualifiedType, Vec<TypeVariable>) {
match self {
TypeScheme::Concrete(t) => {
// We take this branch when we report errors during constraint solving, where the
Expand All @@ -49,14 +52,21 @@ impl TypeScheme {
}
TypeScheme::Quantified(n_gen, _) => {
// TODO: is this a good idea? we don't take care of name clashes here
let type_parameters = if *n_gen <= 26 {
(0..*n_gen)
.map(|n| TypeVariable::new(format!("{}", (b'A' + n as u8) as char)))
.collect::<Vec<_>>()
} else {
(0..*n_gen)
.map(|n| TypeVariable::new(format!("T{n}")))
.collect::<Vec<_>>()
let type_parameters = match type_parameters {
Some(tp) if tp.len() == *n_gen => {
tp.iter().map(|s| TypeVariable::new(s.clone())).collect()
}
_ => {
if *n_gen <= 26 {
(0..*n_gen)
.map(|n| TypeVariable::new(format!("{}", (b'A' + n as u8) as char)))
.collect::<Vec<_>>()
} else {
(0..*n_gen)
.map(|n| TypeVariable::new(format!("T{n}")))
.collect::<Vec<_>>()
}
}
};

(self.instantiate_with(&type_parameters), type_parameters)
Expand Down Expand Up @@ -100,7 +110,7 @@ impl TypeScheme {
&self,
registry: &crate::dimension::DimensionRegistry,
) -> crate::markup::Markup {
let (instantiated_type, type_parameters) = self.instantiate_for_printing();
let (instantiated_type, type_parameters) = self.instantiate_for_printing(None);

let mut markup = m::empty();
for type_parameter in &type_parameters {
Expand Down
121 changes: 74 additions & 47 deletions numbat/src/typed_ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::dimension::DimensionRegistry;
use crate::pretty_print::escape_numbat_string;
use crate::traversal::{ForAllExpressions, ForAllTypeSchemes};
use crate::type_variable::TypeVariable;
use crate::typechecker::qualified_type::QualifiedType;
use crate::typechecker::type_scheme::TypeScheme;
use crate::typechecker::TypeCheckError;
use crate::{
Expand Down Expand Up @@ -641,14 +642,16 @@ impl Statement {
Statement::DefineFunction(
_,
_,
_,
type_parameters,
parameters,
_,
fn_type,
return_type_annotation,
readable_return_type,
) => {
let (fn_type, _) = fn_type.instantiate_for_printing();
let (fn_type, _) = fn_type.instantiate_for_printing(Some(
type_parameters.iter().map(|(n, _)| n.clone()).collect(),
));

let Type::Fn(parameter_types, return_type) = fn_type.inner else {
unreachable!("Expected a function type")
Expand Down Expand Up @@ -836,6 +839,58 @@ fn decorator_markup(decorators: &Vec<Decorator>) -> Markup {
markup_decorators
}

pub fn pretty_print_function_signature(
function_name: &str,
fn_type: &QualifiedType,
type_parameters: &[TypeVariable],
parameters: impl Iterator<
Item = (
String, // parameter name
Markup, // readable parameter type
),
>,
readable_return_type: &Markup,
) -> Markup {
let markup_type_parameters = if type_parameters.is_empty() {
m::empty()
} else {
m::operator("<")
+ Itertools::intersperse(
type_parameters.iter().map(|tv| {
m::type_identifier(tv.unsafe_name())
+ if fn_type.bounds.is_dtype_bound(tv) {
m::operator(":") + m::space() + m::type_identifier("Dim")
} else {
m::empty()
}
}),
m::operator(", "),
)
.sum()
+ m::operator(">")
};

let markup_parameters = Itertools::intersperse(
parameters.map(|(name, parameter_type)| {
m::identifier(name) + m::operator(":") + m::space() + parameter_type.clone()
}),
m::operator(", "),
)
.sum();

let markup_return_type =
m::space() + m::operator("->") + m::space() + readable_return_type.clone();

m::keyword("fn")
+ m::space()
+ m::identifier(function_name)
+ markup_type_parameters
+ m::operator("(")
+ markup_parameters
+ m::operator(")")
+ markup_return_type
}

impl PrettyPrint for Statement {
fn pretty_print(&self) -> Markup {
match self {
Expand All @@ -861,57 +916,29 @@ impl PrettyPrint for Statement {
Statement::DefineFunction(
function_name,
_decorators,
_type_parameters, // TODO: we ignore user-supplied type parameters here
type_parameters,
parameters,
body,
fn_type,
_return_type_annotation,
readable_return_type,
) => {
let (fn_type, type_parameters) = fn_type.instantiate_for_printing();

let markup_type_parameters = if type_parameters.is_empty() {
m::empty()
} else {
m::operator("<")
+ Itertools::intersperse(
type_parameters.iter().map(|tv| {
m::type_identifier(tv.unsafe_name())
+ if fn_type.bounds.is_dtype_bound(tv) {
m::operator(":") + m::space() + m::type_identifier("Dim")
} else {
m::empty()
}
}),
m::operator(", "),
)
.sum()
+ m::operator(">")
};

let markup_parameters = Itertools::intersperse(
parameters.iter().map(|(_span, name, _, parameter_type)| {
m::identifier(name) + m::operator(":") + m::space() + parameter_type.clone()
}),
m::operator(", "),
)
.sum();

let markup_return_type =
m::space() + m::operator("->") + m::space() + readable_return_type.clone();

m::keyword("fn")
+ m::space()
+ m::identifier(function_name)
+ markup_type_parameters
+ m::operator("(")
+ markup_parameters
+ m::operator(")")
+ markup_return_type
+ body
.as_ref()
.map(|e| m::space() + m::operator("=") + m::space() + e.pretty_print())
.unwrap_or_default()
let (fn_type, type_parameters) = fn_type.instantiate_for_printing(Some(
type_parameters.iter().map(|(n, _)| n.clone()).collect(),
));

pretty_print_function_signature(
function_name,
&fn_type,
&type_parameters,
parameters
.iter()
.map(|(_, name, _, type_)| (name.clone(), type_.clone())),
readable_return_type,
) + body
.as_ref()
.map(|e| m::space() + m::operator("=") + m::space() + e.pretty_print())
.unwrap_or_default()
}
Statement::Expression(expr) => expr.pretty_print(),
Statement::DefineDimension(identifier, dexprs) if dexprs.is_empty() => {
Expand Down

0 comments on commit 0e7fb00

Please sign in to comment.