Skip to content

Commit

Permalink
Enum for self-hosted compiler (#333)
Browse files Browse the repository at this point in the history
  • Loading branch information
Akuli authored Mar 13, 2023
1 parent d3c7453 commit a3f82bc
Show file tree
Hide file tree
Showing 15 changed files with 540 additions and 59 deletions.
26 changes: 23 additions & 3 deletions self_hosted/ast.jou
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class AstExpression:
kind: AstExpressionKind

union:
enum_member: AstEnumMember* # TODO: a pointer only because compiling the self-hosted compiler takes forever otherwise
enum_member: AstEnumMember
string: byte*
int_value: int
long_value: long
Expand Down Expand Up @@ -158,8 +158,8 @@ class AstExpression:
elif self->kind == AstExpressionKind::GetEnumMember:
printf(
"get member \"%s\" from enum \"%s\"\n",
&self->enum_member->member_name[0],
&self->enum_member->enum_name[0],
&self->enum_member.member_name[0],
&self->enum_member.enum_name[0],
)
elif self->kind == AstExpressionKind::As:
printf("as ")
Expand Down Expand Up @@ -548,13 +548,15 @@ enum AstToplevelStatementKind:
Import
Function
ClassDefinition
Enum
GlobalVariableDeclaration

class AstToplevelStatement:
union:
the_import: AstImport # must be placed in the beginning of the class
function: AstFunction
classdef: AstClassDef
enumdef: AstEnumDef
global_var: AstNameTypeValue

kind: AstToplevelStatementKind
Expand All @@ -577,6 +579,9 @@ class AstToplevelStatement:
elif self->kind == AstToplevelStatementKind::ClassDefinition:
printf("Define a ")
self->classdef.print()
elif self->kind == AstToplevelStatementKind::Enum:
printf("Define ")
self->enumdef.print()
elif self->kind == AstToplevelStatementKind::GlobalVariableDeclaration:
printf("Declare a global variable ")
self->global_var.print(NULL)
Expand All @@ -588,6 +593,8 @@ class AstToplevelStatement:
def free(self) -> void:
if self->kind == AstToplevelStatementKind::Import:
self->the_import.free()
if self->kind == AstToplevelStatementKind::Enum:
self->enumdef.free()

class AstFile:
path: byte* # not owned
Expand Down Expand Up @@ -662,3 +669,16 @@ class AstClassDef:
self->fields[i].free()
for i = 0; i < self->nmethods; i++:
self->methods[i].free()

class AstEnumDef:
name: byte[100]
member_count: int
member_names: byte[100]*

def print(self) -> void:
printf("enum \"%s\" with %d members:\n", &self->name[0], self->member_count)
for i = 0; i < self->member_count; i++:
printf(" %s\n", &self->member_names[i][0])

def free(self) -> void:
free(self->member_names)
204 changes: 196 additions & 8 deletions self_hosted/create_llvm_ir.jou
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,24 @@ import "stdlib/io.jou"
import "stdlib/mem.jou"
import "stdlib/str.jou"


def build_signed_mod(builder: LLVMBuilder*, lhs: LLVMValue*, rhs: LLVMValue*) -> LLVMValue*:
# Jou's % operator ensures that a%b has same sign as b:
# jou_mod(a, b) = llvm_mod(llvm_mod(a, b) + b, b)
llvm_mod = LLVMBuildSRem(builder, lhs, rhs, "smod_tmp")
sum = LLVMBuildAdd(builder, llvm_mod, rhs, "smod_tmp")
return LLVMBuildSRem(builder, sum, rhs, "smod")

def build_signed_div(builder: LLVMBuilder*, lhs: LLVMValue*, rhs: LLVMValue*) -> LLVMValue*:
# LLVM's provides two divisions. One truncates, the other is an "exact div"
# that requires there is no remainder. Jou uses floor division which is
# neither of the two, but is quite easy to implement:
#
# floordiv(a, b) = exact_div(a - jou_mod(a, b), b)
top = LLVMBuildSub(builder, lhs, build_signed_mod(builder, lhs, rhs), "sdiv_tmp")
return LLVMBuildExactSDiv(builder, top, rhs, "sdiv")


class LocalVarNameAndPointer:
name: byte[100]
pointer: LLVMValue*
Expand All @@ -34,6 +52,8 @@ class AstToIR:
return LLVMIntType(type->size_in_bits)
if type->kind == TypeKind::Pointer:
return LLVMPointerType(self->do_type(type->value_type), 0)
if type->kind == TypeKind::Enum:
return LLVMInt32Type()
printf("asd-Asd., %s\n", &type->name)
assert False

Expand Down Expand Up @@ -73,12 +93,132 @@ class AstToIR:
return LLVMBuildBitCast(self->builder, global_var, string_type, "string_ptr")

def do_cast(self, obj: LLVMValue*, from: Type*, to: Type*) -> LLVMValue*:
# TODO: actually do this
assert from == byte_type
assert to == int_type
return LLVMBuildZExt(self->builder, obj, LLVMInt32Type(), "int_cast")
# Treat enums as just integers
if from->kind == TypeKind::Enum:
from = int_type
if to->kind == TypeKind::Enum:
to = int_type

if from == to:
return obj

if from->is_pointer_type() and to->is_pointer_type():
return LLVMBuildBitCast(self->builder, obj, self->do_type(to), "cast_ptr")

if from->is_integer_type() and to->is_integer_type():
# Examples:
# signed 8-bit 0xFF (-1) --> 16-bit 0xFFFF (-1 or max value)
# unsigned 8-bit 0xFF (255) --> 16-bit 0x00FF (255)
return LLVMBuildIntCast2(self->builder, obj, self->do_type(to), from->kind == TypeKind::SignedInteger, "cast_int")

# TODO: float/double <--> integer, float <--> double

if from == &bool_type and to->is_integer_type():
# True --> 1, False --> 0
return LLVMBuildZExt(self->builder, obj, self->do_type(to), "cast_bool_to_int")

def build_assert(self, condition: LLVMValue*) -> void:
printf("unimplemented cast: %s --> %s\n", &from->name[0], &to->name[0])
assert False

def do_binop(
self,
op: AstExpressionKind,
lhs: LLVMValue*,
lhs_type: Type*,
rhs: LLVMValue*,
rhs_type: Type*,
) -> LLVMValue*:
if lhs_type->kind == TypeKind::Enum:
lhs_type = int_type
if rhs_type->kind == TypeKind::Enum:
rhs_type = int_type

got_numbers = lhs_type->is_number_type() and rhs_type->is_number_type()
got_pointers = lhs_type->is_pointer_type() and rhs_type->is_pointer_type()
assert got_numbers or got_pointers

# if lhs_type->kind == TYPE_FLOATING_POINT and rhstype->kind == TYPE_FLOATING_POINT:
# if op == AstExpressionKind::Add:
# return LLVMBuildFAdd(self->builder, lhs, rhs, "add")
# if op == AstExpressionKind::Subtract:
# return LLVMBuildFSub(self->builder, lhs, rhs, "sub")
# if op == AstExpressionKind::Multiply:
# return LLVMBuildFMul(self->builder, lhs, rhs, "mul")
# if op == AstExpressionKind::Divide:
# return LLVMBuildFDiv(self->builder, lhs, rhs, "div")
# if op == AstExpressionKind::MOD:
# return LLVMBuildFRem(self->builder, lhs, rhs, "mod")
# if op == AstExpressionKind::Eq:
# return LLVMBuildFCmp(self->builder, LLVMRealOEQ, lhs, rhs, "eq")
# if op == AstExpressionKind::Ne:
# return LLVMBuildFCmp(self->builder, LLVMRealONE, lhs, rhs, "ne")
# if op == AstExpressionKind::Gt:
# return LLVMBuildFCmp(self->builder, LLVMRealOGT, lhs, rhs, "gt")
# if op == AstExpressionKind::Ge:
# return LLVMBuildFCmp(self->builder, LLVMRealOGE, lhs, rhs, "ge")
# if op == AstExpressionKind::Lt:
# return LLVMBuildFCmp(self->builder, LLVMRealOLT, lhs, rhs, "lt")
# if op == AstExpressionKind::Le:
# return LLVMBuildFCmp(self->builder, LLVMRealOLE, lhs, rhs, "le")
# assert False

if lhs_type->is_integer_type() and rhs_type->is_integer_type():
is_signed = lhs_type->kind == TypeKind::SignedInteger and rhs_type->kind == TypeKind::SignedInteger
if op == AstExpressionKind::Add:
return LLVMBuildAdd(self->builder, lhs, rhs, "add")
if op == AstExpressionKind::Subtract:
return LLVMBuildSub(self->builder, lhs, rhs, "sub")
if op == AstExpressionKind::Multiply:
return LLVMBuildMul(self->builder, lhs, rhs, "mul")
if op == AstExpressionKind::Divide:
if is_signed:
return build_signed_div(self->builder, lhs, rhs)
else:
return LLVMBuildUDiv(self->builder, lhs, rhs, "div")
if op == AstExpressionKind::Modulo:
if is_signed:
return build_signed_mod(self->builder, lhs, rhs)
else:
return LLVMBuildURem(self->builder, lhs, rhs, "mod")
if op == AstExpressionKind::Eq:
return LLVMBuildICmp(self->builder, LLVMIntPredicate::EQ, lhs, rhs, "eq")
if op == AstExpressionKind::Ne:
return LLVMBuildICmp(self->builder, LLVMIntPredicate::NE, lhs, rhs, "ne")
if op == AstExpressionKind::Gt:
if is_signed:
return LLVMBuildICmp(self->builder, LLVMIntPredicate::SGT, lhs, rhs, "gt")
else:
return LLVMBuildICmp(self->builder, LLVMIntPredicate::UGT, lhs, rhs, "gt")
if op == AstExpressionKind::Ge:
if is_signed:
return LLVMBuildICmp(self->builder, LLVMIntPredicate::SGE, lhs, rhs, "ge")
else:
return LLVMBuildICmp(self->builder, LLVMIntPredicate::UGE, lhs, rhs, "ge")
if op == AstExpressionKind::Lt:
if is_signed:
return LLVMBuildICmp(self->builder, LLVMIntPredicate::SLT, lhs, rhs, "lt")
else:
return LLVMBuildICmp(self->builder, LLVMIntPredicate::ULT, lhs, rhs, "lt")
if op == AstExpressionKind::Le:
if is_signed:
return LLVMBuildICmp(self->builder, LLVMIntPredicate::SLE, lhs, rhs, "le")
else:
return LLVMBuildICmp(self->builder, LLVMIntPredicate::ULE, lhs, rhs, "le")
assert False

if lhs_type->is_pointer_type() and rhs_type->is_pointer_type():
lhs_int = LLVMBuildPtrToInt(self->builder, lhs, LLVMInt64Type(), "ptreq_lhs")
rhs_int = LLVMBuildPtrToInt(self->builder, rhs, LLVMInt64Type(), "ptreq_rhs")
if op == AstExpressionKind::Eq:
return LLVMBuildICmp(self->builder, LLVMIntPredicate::EQ, lhs_int, rhs_int, "ptreq")
if op == AstExpressionKind::Ne:
return LLVMBuildICmp(self->builder, LLVMIntPredicate::NE, lhs_int, rhs_int, "ptreq")
assert False

printf("%s %d %s\n", &lhs_type->name[0], op, &rhs_type->name[0])
assert False

def do_assert(self, condition: LLVMValue*) -> void:
true_block = LLVMAppendBasicBlock(self->llvm_function, "assert_true")
false_block = LLVMAppendBasicBlock(self->llvm_function, "assert_false")
LLVMBuildCondBr(self->builder, condition, true_block, false_block)
Expand Down Expand Up @@ -144,8 +284,43 @@ class AstToIR:
elif ast->kind == AstExpressionKind::GetVariable:
pointer = self->do_address_of_expression(ast)
result = LLVMBuildLoad(self->builder, pointer, &ast->varname[0])
elif ast->kind == AstExpressionKind::GetEnumMember:
enum_type = self->file_types->find_type(&ast->enum_member.enum_name[0])
assert enum_type != NULL
i = enum_type->enum_members.find_index(&ast->enum_member.member_name[0])
assert i != -1
result = LLVMConstInt(LLVMInt32Type(), i, False)
elif ast->kind == AstExpressionKind::As:
value = self->do_expression(&ast->as_expression->value)
type_before_cast = self->function_or_method_types->get_expression_types(&ast->as_expression->value)->original_type
type_after_cast = self->function_or_method_types->get_expression_types(ast)->original_type
result = self->do_cast(value, type_before_cast, type_after_cast)

elif ast->kind == AstExpressionKind::Negate:
value = self->do_expression(&ast->operands[0])
result = LLVMBuildNeg(self->builder, value, "negate")

elif (
ast->kind == AstExpressionKind::Add
or ast->kind == AstExpressionKind::Subtract
or ast->kind == AstExpressionKind::Multiply
or ast->kind == AstExpressionKind::Divide
or ast->kind == AstExpressionKind::Modulo
or ast->kind == AstExpressionKind::Eq
or ast->kind == AstExpressionKind::Ne
or ast->kind == AstExpressionKind::Gt
or ast->kind == AstExpressionKind::Ge
or ast->kind == AstExpressionKind::Lt
or ast->kind == AstExpressionKind::Le
):
lhs = self->do_expression(&ast->operands[0])
rhs = self->do_expression(&ast->operands[1])
lhs_type = self->function_or_method_types->get_expression_types(&ast->operands[0])->original_type
rhs_type = self->function_or_method_types->get_expression_types(&ast->operands[1])->original_type
result = self->do_binop(ast->kind, lhs, lhs_type, rhs, rhs_type)

else:
printf("Asd-asd. Unknown expr %d...\n", ast->kind)
printf("create_llvm_ir: unknown expression kind %d...\n", ast->kind)
assert False

types = self->function_or_method_types->get_expression_types(ast)
Expand All @@ -166,7 +341,7 @@ class AstToIR:
self->new_block("after_return")
elif ast->kind == AstStatementKind::Assert:
condition = self->do_expression(&ast->expression)
self->build_assert(condition)
self->do_assert(condition)
elif ast->kind == AstStatementKind::Assign:
target_pointer = self->do_address_of_expression(&ast->assignment.target)
value = self->do_expression(&ast->assignment.value)
Expand All @@ -176,8 +351,21 @@ class AstToIR:
target_pointer = self->get_local_var_pointer(&ast->var_declaration.name[0])
value = self->do_expression(ast->var_declaration.value)
LLVMBuildStore(self->builder, value, target_pointer)
elif ast->kind == AstStatementKind::If:
# TODO: do this properly...
assert ast->if_statement.n_if_and_elifs == 1
assert ast->if_statement.else_body.nstatements == 0

condition = self->do_expression(&ast->if_statement.if_and_elifs[0].condition)
true_block = LLVMAppendBasicBlock(self->llvm_function, "then")
done_block = LLVMAppendBasicBlock(self->llvm_function, "if_done")
LLVMBuildCondBr(self->builder, condition, true_block, done_block)
LLVMPositionBuilderAtEnd(self->builder, true_block)
self->do_body(&ast->if_statement.if_and_elifs[0].body)
LLVMBuildBr(self->builder, done_block)
LLVMPositionBuilderAtEnd(self->builder, done_block)
else:
printf("Asd-asd. Unknown statement %d...\n", ast->kind)
printf("create_llvm_ir: unknown statement kind %d...\n", ast->kind)
assert False

def do_body(self, body: AstBody*) -> void:
Expand Down
5 changes: 5 additions & 0 deletions self_hosted/llvm.jou
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,13 @@ declare LLVMBuildCondBr(Builder: LLVMBuilder*, If: LLVMValue*, Then: LLVMBasicBl
declare LLVMBuildUnreachable(Builder: LLVMBuilder*) -> LLVMValue*
declare LLVMBuildAdd(Builder: LLVMBuilder*, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue*
declare LLVMBuildSub(Builder: LLVMBuilder*, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue*
declare LLVMBuildMul(Builder: LLVMBuilder*, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue*
declare LLVMBuildUDiv(Builder: LLVMBuilder*, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue*
declare LLVMBuildExactSDiv(Builder: LLVMBuilder*, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue*
declare LLVMBuildURem(Builder: LLVMBuilder*, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue*
declare LLVMBuildSRem(Builder: LLVMBuilder*, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue*
declare LLVMBuildXor(Builder: LLVMBuilder*, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue*
declare LLVMBuildNeg(Builder: LLVMBuilder*, V: LLVMValue*, Name: byte*) -> LLVMValue*
declare LLVMBuildMemSet(Builder: LLVMBuilder*, Ptr: LLVMValue*, Val: LLVMValue*, Len: LLVMValue*, Align: int) -> LLVMValue*
declare LLVMBuildAlloca(Builder: LLVMBuilder*, Ty: LLVMType*, Name: byte*) -> LLVMValue*
declare LLVMBuildLoad(Builder: LLVMBuilder*, PointerVal: LLVMValue*, Name: byte*) -> LLVMValue*
Expand All @@ -252,6 +256,7 @@ declare LLVMBuildUIToFP(Builder: LLVMBuilder*, Val: LLVMValue*, DestTy: LLVMType
declare LLVMBuildSIToFP(Builder: LLVMBuilder*, Val: LLVMValue*, DestTy: LLVMType*, Name: byte*) -> LLVMValue*
declare LLVMBuildPtrToInt(Builder: LLVMBuilder*, Val: LLVMValue*, DestTy: LLVMType*, Name: byte*) -> LLVMValue*
declare LLVMBuildBitCast(Builder: LLVMBuilder*, Val: LLVMValue*, DestTy: LLVMType*, Name: byte*) -> LLVMValue*
declare LLVMBuildIntCast2(Builder: LLVMBuilder*, Val: LLVMValue*, DestTy: LLVMType*, IsSigned: bool, Name: byte*) -> LLVMValue*
declare LLVMBuildFPCast(Builder: LLVMBuilder*, Val: LLVMValue*, DestTy: LLVMType*, Name: byte*) -> LLVMValue*
declare LLVMBuildICmp(Builder: LLVMBuilder*, Op: LLVMIntPredicate, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue*
declare LLVMBuildFCmp(Builder: LLVMBuilder*, Op: LLVMRealPredicate, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue*
Expand Down
16 changes: 15 additions & 1 deletion self_hosted/main.jou
Original file line number Diff line number Diff line change
Expand Up @@ -198,13 +198,24 @@ class Compiler:
free(self->files[i].pending_exports)
self->files[i].pending_exports = NULL

def typecheck_stage1_all_files(self) -> void:
for i = 0; i < self->nfiles; i++:
if self->verbosity >= 1:
printf("Type-check stage 1: %s\n", self->files[i].ast.path)

assert self->files[i].pending_exports == NULL
self->files[i].pending_exports = typecheck_stage1_create_types(
&self->files[i].typectx,
&self->files[i].ast,
)

def typecheck_stage2_all_files(self) -> void:
for i = 0; i < self->nfiles; i++:
if self->verbosity >= 1:
printf("Type-check stage 2: %s\n", self->files[i].ast.path)

assert self->files[i].pending_exports == NULL
self->files[i].pending_exports = typecheck_stage2_signatures_globals_structbodies(
self->files[i].pending_exports = typecheck_stage2_populate_types(
&self->files[i].typectx,
&self->files[i].ast,
)
Expand Down Expand Up @@ -378,6 +389,9 @@ def main(argc: int, argv: byte**) -> int:
}
compiler.determine_automagic_files()
compiler.parse_all_files()

compiler.typecheck_stage1_all_files()
compiler.process_imports_and_exports()
compiler.typecheck_stage2_all_files()
compiler.process_imports_and_exports()
compiler.typecheck_stage3_all_files()
Expand Down
Loading

0 comments on commit a3f82bc

Please sign in to comment.