From a3f82bc3d8d5cc228d278a4dfb6feb2f55bb9778 Mon Sep 17 00:00:00 2001 From: Akuli Date: Mon, 13 Mar 2023 20:13:10 +0200 Subject: [PATCH] Enum for self-hosted compiler (#333) --- self_hosted/ast.jou | 26 ++- self_hosted/create_llvm_ir.jou | 204 +++++++++++++++++++++- self_hosted/llvm.jou | 5 + self_hosted/main.jou | 16 +- self_hosted/parser.jou | 30 +++- self_hosted/parses_wrong.txt | 5 - self_hosted/runs_wrong.txt | 4 - self_hosted/typecheck.jou | 227 +++++++++++++++++++++++-- self_hosted/types.jou | 42 +++++ src/codegen.c | 18 +- src/jou_compiler.h | 4 +- src/main.c | 2 +- src/parse.c | 3 + src/typecheck.c | 11 +- tests/syntax_error/bad_enum_member.jou | 2 + 15 files changed, 540 insertions(+), 59 deletions(-) create mode 100644 tests/syntax_error/bad_enum_member.jou diff --git a/self_hosted/ast.jou b/self_hosted/ast.jou index 5d2d2cf9..a467bd9f 100644 --- a/self_hosted/ast.jou +++ b/self_hosted/ast.jou @@ -114,7 +114,7 @@ class AstExpression: kind: AstExpressionKind union: - enum_member: AstEnumMember* # TODO: a pointer only because compiling the self-hosted compiler takes forever otherwise + enum_member: AstEnumMember string: byte* int_value: int long_value: long @@ -158,8 +158,8 @@ class AstExpression: elif self->kind == AstExpressionKind::GetEnumMember: printf( "get member \"%s\" from enum \"%s\"\n", - &self->enum_member->member_name[0], - &self->enum_member->enum_name[0], + &self->enum_member.member_name[0], + &self->enum_member.enum_name[0], ) elif self->kind == AstExpressionKind::As: printf("as ") @@ -548,6 +548,7 @@ enum AstToplevelStatementKind: Import Function ClassDefinition + Enum GlobalVariableDeclaration class AstToplevelStatement: @@ -555,6 +556,7 @@ class AstToplevelStatement: the_import: AstImport # must be placed in the beginning of the class function: AstFunction classdef: AstClassDef + enumdef: AstEnumDef global_var: AstNameTypeValue kind: AstToplevelStatementKind @@ -577,6 +579,9 @@ class AstToplevelStatement: elif self->kind == AstToplevelStatementKind::ClassDefinition: printf("Define a ") self->classdef.print() + elif self->kind == AstToplevelStatementKind::Enum: + printf("Define ") + self->enumdef.print() elif self->kind == AstToplevelStatementKind::GlobalVariableDeclaration: printf("Declare a global variable ") self->global_var.print(NULL) @@ -588,6 +593,8 @@ class AstToplevelStatement: def free(self) -> void: if self->kind == AstToplevelStatementKind::Import: self->the_import.free() + if self->kind == AstToplevelStatementKind::Enum: + self->enumdef.free() class AstFile: path: byte* # not owned @@ -662,3 +669,16 @@ class AstClassDef: self->fields[i].free() for i = 0; i < self->nmethods; i++: self->methods[i].free() + +class AstEnumDef: + name: byte[100] + member_count: int + member_names: byte[100]* + + def print(self) -> void: + printf("enum \"%s\" with %d members:\n", &self->name[0], self->member_count) + for i = 0; i < self->member_count; i++: + printf(" %s\n", &self->member_names[i][0]) + + def free(self) -> void: + free(self->member_names) diff --git a/self_hosted/create_llvm_ir.jou b/self_hosted/create_llvm_ir.jou index d9e889c8..bf836462 100644 --- a/self_hosted/create_llvm_ir.jou +++ b/self_hosted/create_llvm_ir.jou @@ -8,6 +8,24 @@ import "stdlib/io.jou" import "stdlib/mem.jou" import "stdlib/str.jou" + +def build_signed_mod(builder: LLVMBuilder*, lhs: LLVMValue*, rhs: LLVMValue*) -> LLVMValue*: + # Jou's % operator ensures that a%b has same sign as b: + # jou_mod(a, b) = llvm_mod(llvm_mod(a, b) + b, b) + llvm_mod = LLVMBuildSRem(builder, lhs, rhs, "smod_tmp") + sum = LLVMBuildAdd(builder, llvm_mod, rhs, "smod_tmp") + return LLVMBuildSRem(builder, sum, rhs, "smod") + +def build_signed_div(builder: LLVMBuilder*, lhs: LLVMValue*, rhs: LLVMValue*) -> LLVMValue*: + # LLVM's provides two divisions. One truncates, the other is an "exact div" + # that requires there is no remainder. Jou uses floor division which is + # neither of the two, but is quite easy to implement: + # + # floordiv(a, b) = exact_div(a - jou_mod(a, b), b) + top = LLVMBuildSub(builder, lhs, build_signed_mod(builder, lhs, rhs), "sdiv_tmp") + return LLVMBuildExactSDiv(builder, top, rhs, "sdiv") + + class LocalVarNameAndPointer: name: byte[100] pointer: LLVMValue* @@ -34,6 +52,8 @@ class AstToIR: return LLVMIntType(type->size_in_bits) if type->kind == TypeKind::Pointer: return LLVMPointerType(self->do_type(type->value_type), 0) + if type->kind == TypeKind::Enum: + return LLVMInt32Type() printf("asd-Asd., %s\n", &type->name) assert False @@ -73,12 +93,132 @@ class AstToIR: return LLVMBuildBitCast(self->builder, global_var, string_type, "string_ptr") def do_cast(self, obj: LLVMValue*, from: Type*, to: Type*) -> LLVMValue*: - # TODO: actually do this - assert from == byte_type - assert to == int_type - return LLVMBuildZExt(self->builder, obj, LLVMInt32Type(), "int_cast") + # Treat enums as just integers + if from->kind == TypeKind::Enum: + from = int_type + if to->kind == TypeKind::Enum: + to = int_type + + if from == to: + return obj + + if from->is_pointer_type() and to->is_pointer_type(): + return LLVMBuildBitCast(self->builder, obj, self->do_type(to), "cast_ptr") + + if from->is_integer_type() and to->is_integer_type(): + # Examples: + # signed 8-bit 0xFF (-1) --> 16-bit 0xFFFF (-1 or max value) + # unsigned 8-bit 0xFF (255) --> 16-bit 0x00FF (255) + return LLVMBuildIntCast2(self->builder, obj, self->do_type(to), from->kind == TypeKind::SignedInteger, "cast_int") + + # TODO: float/double <--> integer, float <--> double + + if from == &bool_type and to->is_integer_type(): + # True --> 1, False --> 0 + return LLVMBuildZExt(self->builder, obj, self->do_type(to), "cast_bool_to_int") - def build_assert(self, condition: LLVMValue*) -> void: + printf("unimplemented cast: %s --> %s\n", &from->name[0], &to->name[0]) + assert False + + def do_binop( + self, + op: AstExpressionKind, + lhs: LLVMValue*, + lhs_type: Type*, + rhs: LLVMValue*, + rhs_type: Type*, + ) -> LLVMValue*: + if lhs_type->kind == TypeKind::Enum: + lhs_type = int_type + if rhs_type->kind == TypeKind::Enum: + rhs_type = int_type + + got_numbers = lhs_type->is_number_type() and rhs_type->is_number_type() + got_pointers = lhs_type->is_pointer_type() and rhs_type->is_pointer_type() + assert got_numbers or got_pointers + +# if lhs_type->kind == TYPE_FLOATING_POINT and rhstype->kind == TYPE_FLOATING_POINT: +# if op == AstExpressionKind::Add: +# return LLVMBuildFAdd(self->builder, lhs, rhs, "add") +# if op == AstExpressionKind::Subtract: +# return LLVMBuildFSub(self->builder, lhs, rhs, "sub") +# if op == AstExpressionKind::Multiply: +# return LLVMBuildFMul(self->builder, lhs, rhs, "mul") +# if op == AstExpressionKind::Divide: +# return LLVMBuildFDiv(self->builder, lhs, rhs, "div") +# if op == AstExpressionKind::MOD: +# return LLVMBuildFRem(self->builder, lhs, rhs, "mod") +# if op == AstExpressionKind::Eq: +# return LLVMBuildFCmp(self->builder, LLVMRealOEQ, lhs, rhs, "eq") +# if op == AstExpressionKind::Ne: +# return LLVMBuildFCmp(self->builder, LLVMRealONE, lhs, rhs, "ne") +# if op == AstExpressionKind::Gt: +# return LLVMBuildFCmp(self->builder, LLVMRealOGT, lhs, rhs, "gt") +# if op == AstExpressionKind::Ge: +# return LLVMBuildFCmp(self->builder, LLVMRealOGE, lhs, rhs, "ge") +# if op == AstExpressionKind::Lt: +# return LLVMBuildFCmp(self->builder, LLVMRealOLT, lhs, rhs, "lt") +# if op == AstExpressionKind::Le: +# return LLVMBuildFCmp(self->builder, LLVMRealOLE, lhs, rhs, "le") +# assert False + + if lhs_type->is_integer_type() and rhs_type->is_integer_type(): + is_signed = lhs_type->kind == TypeKind::SignedInteger and rhs_type->kind == TypeKind::SignedInteger + if op == AstExpressionKind::Add: + return LLVMBuildAdd(self->builder, lhs, rhs, "add") + if op == AstExpressionKind::Subtract: + return LLVMBuildSub(self->builder, lhs, rhs, "sub") + if op == AstExpressionKind::Multiply: + return LLVMBuildMul(self->builder, lhs, rhs, "mul") + if op == AstExpressionKind::Divide: + if is_signed: + return build_signed_div(self->builder, lhs, rhs) + else: + return LLVMBuildUDiv(self->builder, lhs, rhs, "div") + if op == AstExpressionKind::Modulo: + if is_signed: + return build_signed_mod(self->builder, lhs, rhs) + else: + return LLVMBuildURem(self->builder, lhs, rhs, "mod") + if op == AstExpressionKind::Eq: + return LLVMBuildICmp(self->builder, LLVMIntPredicate::EQ, lhs, rhs, "eq") + if op == AstExpressionKind::Ne: + return LLVMBuildICmp(self->builder, LLVMIntPredicate::NE, lhs, rhs, "ne") + if op == AstExpressionKind::Gt: + if is_signed: + return LLVMBuildICmp(self->builder, LLVMIntPredicate::SGT, lhs, rhs, "gt") + else: + return LLVMBuildICmp(self->builder, LLVMIntPredicate::UGT, lhs, rhs, "gt") + if op == AstExpressionKind::Ge: + if is_signed: + return LLVMBuildICmp(self->builder, LLVMIntPredicate::SGE, lhs, rhs, "ge") + else: + return LLVMBuildICmp(self->builder, LLVMIntPredicate::UGE, lhs, rhs, "ge") + if op == AstExpressionKind::Lt: + if is_signed: + return LLVMBuildICmp(self->builder, LLVMIntPredicate::SLT, lhs, rhs, "lt") + else: + return LLVMBuildICmp(self->builder, LLVMIntPredicate::ULT, lhs, rhs, "lt") + if op == AstExpressionKind::Le: + if is_signed: + return LLVMBuildICmp(self->builder, LLVMIntPredicate::SLE, lhs, rhs, "le") + else: + return LLVMBuildICmp(self->builder, LLVMIntPredicate::ULE, lhs, rhs, "le") + assert False + + if lhs_type->is_pointer_type() and rhs_type->is_pointer_type(): + lhs_int = LLVMBuildPtrToInt(self->builder, lhs, LLVMInt64Type(), "ptreq_lhs") + rhs_int = LLVMBuildPtrToInt(self->builder, rhs, LLVMInt64Type(), "ptreq_rhs") + if op == AstExpressionKind::Eq: + return LLVMBuildICmp(self->builder, LLVMIntPredicate::EQ, lhs_int, rhs_int, "ptreq") + if op == AstExpressionKind::Ne: + return LLVMBuildICmp(self->builder, LLVMIntPredicate::NE, lhs_int, rhs_int, "ptreq") + assert False + + printf("%s %d %s\n", &lhs_type->name[0], op, &rhs_type->name[0]) + assert False + + def do_assert(self, condition: LLVMValue*) -> void: true_block = LLVMAppendBasicBlock(self->llvm_function, "assert_true") false_block = LLVMAppendBasicBlock(self->llvm_function, "assert_false") LLVMBuildCondBr(self->builder, condition, true_block, false_block) @@ -144,8 +284,43 @@ class AstToIR: elif ast->kind == AstExpressionKind::GetVariable: pointer = self->do_address_of_expression(ast) result = LLVMBuildLoad(self->builder, pointer, &ast->varname[0]) + elif ast->kind == AstExpressionKind::GetEnumMember: + enum_type = self->file_types->find_type(&ast->enum_member.enum_name[0]) + assert enum_type != NULL + i = enum_type->enum_members.find_index(&ast->enum_member.member_name[0]) + assert i != -1 + result = LLVMConstInt(LLVMInt32Type(), i, False) + elif ast->kind == AstExpressionKind::As: + value = self->do_expression(&ast->as_expression->value) + type_before_cast = self->function_or_method_types->get_expression_types(&ast->as_expression->value)->original_type + type_after_cast = self->function_or_method_types->get_expression_types(ast)->original_type + result = self->do_cast(value, type_before_cast, type_after_cast) + + elif ast->kind == AstExpressionKind::Negate: + value = self->do_expression(&ast->operands[0]) + result = LLVMBuildNeg(self->builder, value, "negate") + + elif ( + ast->kind == AstExpressionKind::Add + or ast->kind == AstExpressionKind::Subtract + or ast->kind == AstExpressionKind::Multiply + or ast->kind == AstExpressionKind::Divide + or ast->kind == AstExpressionKind::Modulo + or ast->kind == AstExpressionKind::Eq + or ast->kind == AstExpressionKind::Ne + or ast->kind == AstExpressionKind::Gt + or ast->kind == AstExpressionKind::Ge + or ast->kind == AstExpressionKind::Lt + or ast->kind == AstExpressionKind::Le + ): + lhs = self->do_expression(&ast->operands[0]) + rhs = self->do_expression(&ast->operands[1]) + lhs_type = self->function_or_method_types->get_expression_types(&ast->operands[0])->original_type + rhs_type = self->function_or_method_types->get_expression_types(&ast->operands[1])->original_type + result = self->do_binop(ast->kind, lhs, lhs_type, rhs, rhs_type) + else: - printf("Asd-asd. Unknown expr %d...\n", ast->kind) + printf("create_llvm_ir: unknown expression kind %d...\n", ast->kind) assert False types = self->function_or_method_types->get_expression_types(ast) @@ -166,7 +341,7 @@ class AstToIR: self->new_block("after_return") elif ast->kind == AstStatementKind::Assert: condition = self->do_expression(&ast->expression) - self->build_assert(condition) + self->do_assert(condition) elif ast->kind == AstStatementKind::Assign: target_pointer = self->do_address_of_expression(&ast->assignment.target) value = self->do_expression(&ast->assignment.value) @@ -176,8 +351,21 @@ class AstToIR: target_pointer = self->get_local_var_pointer(&ast->var_declaration.name[0]) value = self->do_expression(ast->var_declaration.value) LLVMBuildStore(self->builder, value, target_pointer) + elif ast->kind == AstStatementKind::If: + # TODO: do this properly... + assert ast->if_statement.n_if_and_elifs == 1 + assert ast->if_statement.else_body.nstatements == 0 + + condition = self->do_expression(&ast->if_statement.if_and_elifs[0].condition) + true_block = LLVMAppendBasicBlock(self->llvm_function, "then") + done_block = LLVMAppendBasicBlock(self->llvm_function, "if_done") + LLVMBuildCondBr(self->builder, condition, true_block, done_block) + LLVMPositionBuilderAtEnd(self->builder, true_block) + self->do_body(&ast->if_statement.if_and_elifs[0].body) + LLVMBuildBr(self->builder, done_block) + LLVMPositionBuilderAtEnd(self->builder, done_block) else: - printf("Asd-asd. Unknown statement %d...\n", ast->kind) + printf("create_llvm_ir: unknown statement kind %d...\n", ast->kind) assert False def do_body(self, body: AstBody*) -> void: diff --git a/self_hosted/llvm.jou b/self_hosted/llvm.jou index 8a189b91..0eaf2c85 100644 --- a/self_hosted/llvm.jou +++ b/self_hosted/llvm.jou @@ -234,9 +234,13 @@ declare LLVMBuildCondBr(Builder: LLVMBuilder*, If: LLVMValue*, Then: LLVMBasicBl declare LLVMBuildUnreachable(Builder: LLVMBuilder*) -> LLVMValue* declare LLVMBuildAdd(Builder: LLVMBuilder*, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue* declare LLVMBuildSub(Builder: LLVMBuilder*, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue* +declare LLVMBuildMul(Builder: LLVMBuilder*, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue* +declare LLVMBuildUDiv(Builder: LLVMBuilder*, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue* declare LLVMBuildExactSDiv(Builder: LLVMBuilder*, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue* +declare LLVMBuildURem(Builder: LLVMBuilder*, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue* declare LLVMBuildSRem(Builder: LLVMBuilder*, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue* declare LLVMBuildXor(Builder: LLVMBuilder*, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue* +declare LLVMBuildNeg(Builder: LLVMBuilder*, V: LLVMValue*, Name: byte*) -> LLVMValue* declare LLVMBuildMemSet(Builder: LLVMBuilder*, Ptr: LLVMValue*, Val: LLVMValue*, Len: LLVMValue*, Align: int) -> LLVMValue* declare LLVMBuildAlloca(Builder: LLVMBuilder*, Ty: LLVMType*, Name: byte*) -> LLVMValue* declare LLVMBuildLoad(Builder: LLVMBuilder*, PointerVal: LLVMValue*, Name: byte*) -> LLVMValue* @@ -252,6 +256,7 @@ declare LLVMBuildUIToFP(Builder: LLVMBuilder*, Val: LLVMValue*, DestTy: LLVMType declare LLVMBuildSIToFP(Builder: LLVMBuilder*, Val: LLVMValue*, DestTy: LLVMType*, Name: byte*) -> LLVMValue* declare LLVMBuildPtrToInt(Builder: LLVMBuilder*, Val: LLVMValue*, DestTy: LLVMType*, Name: byte*) -> LLVMValue* declare LLVMBuildBitCast(Builder: LLVMBuilder*, Val: LLVMValue*, DestTy: LLVMType*, Name: byte*) -> LLVMValue* +declare LLVMBuildIntCast2(Builder: LLVMBuilder*, Val: LLVMValue*, DestTy: LLVMType*, IsSigned: bool, Name: byte*) -> LLVMValue* declare LLVMBuildFPCast(Builder: LLVMBuilder*, Val: LLVMValue*, DestTy: LLVMType*, Name: byte*) -> LLVMValue* declare LLVMBuildICmp(Builder: LLVMBuilder*, Op: LLVMIntPredicate, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue* declare LLVMBuildFCmp(Builder: LLVMBuilder*, Op: LLVMRealPredicate, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue* diff --git a/self_hosted/main.jou b/self_hosted/main.jou index cce0f0a7..ea4d2142 100644 --- a/self_hosted/main.jou +++ b/self_hosted/main.jou @@ -198,13 +198,24 @@ class Compiler: free(self->files[i].pending_exports) self->files[i].pending_exports = NULL + def typecheck_stage1_all_files(self) -> void: + for i = 0; i < self->nfiles; i++: + if self->verbosity >= 1: + printf("Type-check stage 1: %s\n", self->files[i].ast.path) + + assert self->files[i].pending_exports == NULL + self->files[i].pending_exports = typecheck_stage1_create_types( + &self->files[i].typectx, + &self->files[i].ast, + ) + def typecheck_stage2_all_files(self) -> void: for i = 0; i < self->nfiles; i++: if self->verbosity >= 1: printf("Type-check stage 2: %s\n", self->files[i].ast.path) assert self->files[i].pending_exports == NULL - self->files[i].pending_exports = typecheck_stage2_signatures_globals_structbodies( + self->files[i].pending_exports = typecheck_stage2_populate_types( &self->files[i].typectx, &self->files[i].ast, ) @@ -378,6 +389,9 @@ def main(argc: int, argv: byte**) -> int: } compiler.determine_automagic_files() compiler.parse_all_files() + + compiler.typecheck_stage1_all_files() + compiler.process_imports_and_exports() compiler.typecheck_stage2_all_files() compiler.process_imports_and_exports() compiler.typecheck_stage3_all_files() diff --git a/self_hosted/parser.jou b/self_hosted/parser.jou index f99cfec8..daa3338b 100644 --- a/self_hosted/parser.jou +++ b/self_hosted/parser.jou @@ -205,8 +205,7 @@ def parse_elementary_expression(tokens: Token**) -> AstExpression: expr.call = parse_call(tokens, "(", ")") elif (*tokens)[1].is_operator("::") and (*tokens)[2].kind == TokenKind::Name: expr.kind = AstExpressionKind::GetEnumMember - expr.enum_member = malloc(sizeof *expr.enum_member) - *expr.enum_member = AstEnumMember{ + expr.enum_member = AstEnumMember{ enum_name = (*tokens)->short_string, member_name = (*tokens)[2].short_string, } @@ -617,7 +616,7 @@ def parse_funcdef(tokens: Token**) -> AstFunction: body = parse_body(tokens), } -def parse_classdef(tokens: Token**) -> AstClassDef: +def parse_class(tokens: Token**) -> AstClassDef: if (*tokens)->kind != TokenKind::Name: (*tokens)->fail_expected_got("a name for the class") @@ -639,6 +638,24 @@ def parse_classdef(tokens: Token**) -> AstClassDef: ++*tokens return result +def parse_enum(tokens: Token**) -> AstEnumDef: + if (*tokens)->kind != TokenKind::Name: + (*tokens)->fail_expected_got("a name for the enum") + + result = AstEnumDef{name = ((*tokens)++)->short_string} + + parse_start_of_body(tokens) + while (*tokens)->kind != TokenKind::Dedent: + if (*tokens)->kind != TokenKind::Name: + (*tokens)->fail_expected_got("a name for an enum member") + result.member_names = realloc(result.member_names, sizeof result.member_names[0] * (result.member_count + 1)) + result.member_names[result.member_count++] = (*tokens)->short_string + ++*tokens + eat_newline(tokens) + + ++*tokens + return result + def parse_toplevel_node(tokens: Token**, stdlib_path: byte*) -> AstToplevelStatement: ts = AstToplevelStatement{location = (*tokens)->location} @@ -672,7 +689,12 @@ def parse_toplevel_node(tokens: Token**, stdlib_path: byte*) -> AstToplevelState elif (*tokens)->is_keyword("class"): ++*tokens ts.kind = AstToplevelStatementKind::ClassDefinition - ts.classdef = parse_classdef(tokens) + ts.classdef = parse_class(tokens) + + elif (*tokens)->is_keyword("enum"): + ++*tokens + ts.kind = AstToplevelStatementKind::Enum + ts.enumdef = parse_enum(tokens) else: (*tokens)->fail_expected_got("a definition or declaration") diff --git a/self_hosted/parses_wrong.txt b/self_hosted/parses_wrong.txt index 28845651..d314238a 100644 --- a/self_hosted/parses_wrong.txt +++ b/self_hosted/parses_wrong.txt @@ -1,12 +1,10 @@ # This is a list of files that are not yet supported by the tokenizer or parser of the self-hosted compiler. -tests/404/enum_member.jou tests/404/method_on_int.jou tests/404/method_on_class.jou tests/404/method_on_class_ptr.jou tests/404/class_field.jou tests/already_exists_error/global_var_import.jou tests/already_exists_error/global_var.jou -tests/already_exists_error/class_and_enum.jou tests/other_errors/address_of_array_indexing.jou tests/other_errors/array0.jou tests/other_errors/duplicate_enum_member.jou @@ -16,7 +14,6 @@ tests/other_errors/var_shadow.jou tests/should_succeed/add_sub_mul_div_mod.jou tests/should_succeed/array.jou tests/should_succeed/as.jou -tests/should_succeed/enum.jou tests/should_succeed/expfloat.jou tests/should_succeed/file.jou tests/should_succeed/global.jou @@ -54,10 +51,8 @@ tests/wrong_type/arrow_operator_not_class.jou tests/wrong_type/brace_init_arg.jou tests/wrong_type/cannot_be_indexed.jou tests/wrong_type/dot_operator.jou -tests/wrong_type/enum_to_int.jou tests/wrong_type/float_and_double.jou tests/wrong_type/index.jou -tests/wrong_type/int_to_enum.jou tests/wrong_type/class_member_assign.jou tests/wrong_type/class_member_init.jou tests/should_succeed/linked_list.jou diff --git a/self_hosted/runs_wrong.txt b/self_hosted/runs_wrong.txt index 32b1e042..0a58b8b1 100644 --- a/self_hosted/runs_wrong.txt +++ b/self_hosted/runs_wrong.txt @@ -10,7 +10,6 @@ stdlib/str.jou stdlib/_windows_startup.jou tests/404/class_field.jou tests/404/enum.jou -tests/404/enum_member.jou tests/404/file.jou tests/404/method_on_class.jou tests/404/method_on_class_ptr.jou @@ -52,7 +51,6 @@ tests/should_succeed/as.jou tests/should_succeed/class.jou tests/should_succeed/compare.jou tests/should_succeed/compiler_cli.jou -tests/should_succeed/enum.jou tests/should_succeed/expfloat.jou tests/should_succeed/file.jou tests/should_succeed/global_bug.jou @@ -115,12 +113,10 @@ tests/wrong_type/enum_member_from_class.jou tests/wrong_type/enum_to_int.jou tests/wrong_type/float_and_double.jou tests/wrong_type/for.jou -tests/wrong_type/if.jou tests/wrong_type/index.jou tests/wrong_type/inplace_add_doesnt_go_back.jou tests/wrong_type/int_to_enum.jou tests/wrong_type/mod.jou -tests/wrong_type/neg.jou tests/wrong_type/not.jou tests/wrong_type/or.jou tests/wrong_type/plusplus.jou diff --git a/self_hosted/typecheck.jou b/self_hosted/typecheck.jou index 4586709a..b3070f3b 100644 --- a/self_hosted/typecheck.jou +++ b/self_hosted/typecheck.jou @@ -78,14 +78,15 @@ def fail_with_implicit_cast_error(location: Location, template: byte*, from: Typ enum ExportSymbolKind: Function - # TODO: exporting types, exporting global variables + Type class ExportSymbol: kind: ExportSymbolKind name: byte[100] - # TODO: union - signature: Signature # ExportSymbolKind::Function + union: + signature: Signature # ExportSymbolKind::Function + type: Type* # ExportSymbolKind::Type class ExpressionTypes: expression: AstExpression* @@ -137,8 +138,13 @@ class FileTypes: defined_functions: FunctionOrMethodTypes* n_defined_functions: int + types: Type** + ntypes: int + def add_imported_symbol(self, symbol: ExportSymbol*) -> void: - assert symbol->kind == ExportSymbolKind::Function + if symbol->kind != ExportSymbolKind::Function: + # TODO + return self->all_functions = realloc(self->all_functions, sizeof self->all_functions[0] * (self->n_all_functions + 1)) self->all_functions[self->n_all_functions++] = symbol->signature.copy() @@ -154,9 +160,33 @@ class FileTypes: return &self->defined_functions[i] return NULL -# TODO: implement -#def typecheck_stage1_create_types(ft: FileTypes*, file: AstFile*) -> ExportSymbol*: -# assert False + def find_type(self, name: byte*) -> Type*: + for i = 0; i < self->ntypes; i++: + if strcmp(&self->types[i]->name[0], name) == 0: + return self->types[i] + return NULL + +def typecheck_stage1_create_types(ft: FileTypes*, file: AstFile*) -> ExportSymbol*: + exports: ExportSymbol* = NULL + nexports = 0 + + for i = 0; i < file->body_len; i++: + if file->body[i].kind == AstToplevelStatementKind::Enum: + enumdef = &file->body[i].enumdef + t = create_enum(&enumdef->name[0], enumdef->member_count, enumdef->member_names) + ft->types = realloc(ft->types, (ft->ntypes + 1) * sizeof ft->types[0]) + ft->types[ft->ntypes++] = t + exports = realloc(exports, (nexports + 1) * sizeof exports[0]) + exports[nexports++] = ExportSymbol{ + kind = ExportSymbolKind::Type, + name = enumdef->name, + type = t, + } + + exports = realloc(exports, sizeof exports[0] * (nexports + 1)) + exports[nexports] = ExportSymbol{} + return exports + def type_from_ast(ft: FileTypes*, ast_type: AstType*) -> Type*: if ast_type->is_void(): @@ -175,6 +205,10 @@ def type_from_ast(ft: FileTypes*, ast_type: AstType*) -> Type*: return &bool_type # TODO: float, double + result = ft->find_type(&ast_type->name[0]) + if result != NULL: + return result + message: byte* = malloc(strlen(&ast_type->name[0]) + 100) sprintf(message, "there is no type named '%s'", &ast_type->name[0]) fail(ast_type->location, message) @@ -211,7 +245,7 @@ def handle_signature(ft: FileTypes*, astsig: AstSignature*) -> Signature: return sig # Returned array is terminated by ExportSymbol with empty name. -def typecheck_stage2_signatures_globals_structbodies(ft: FileTypes*, ast_file: AstFile*) -> ExportSymbol*: +def typecheck_stage2_populate_types(ft: FileTypes*, ast_file: AstFile*) -> ExportSymbol*: exports: ExportSymbol* = NULL nexports = 0 @@ -255,6 +289,128 @@ def nth(n: int) -> byte*: sprintf(result, "%dth", n) return result +def check_explicit_cast(from: Type*, to: Type*, location: Location) -> void: + if ( + from == to # TODO: should probably be error if it's the same type. + or (from->is_pointer_type() and to->is_pointer_type()) + or (from->is_number_type() and to->is_number_type()) + or (from->is_integer_type() and to->kind == TypeKind::Enum) + or (from->kind == TypeKind::Enum and to->is_integer_type()) # TODO: disallow for too small type + or (from == &bool_type and to->is_integer_type()) + ): + return + + message: byte[500] + snprintf( + &message[0], sizeof message, + "cannot cast from type %s to %s", &from->name[0], &to->name[0], + ) + fail(location, &message[0]) + +def very_short_type_description(t: Type*) -> byte*: + if t->kind == TypeKind::OpaqueClass: + assert False + if t->kind == TypeKind::Class: + return "a class" + if t->kind == TypeKind::Enum: + return "an enum" + if t->is_pointer_type(): + return "a pointer type" + if t->is_number_type(): + return "a number type" +# if t->kind == TypeKind::Array: +# return "an array type" + if t == &bool_type: + return "the built-in boolean type" + + assert False + +def max(a: int, b: int) -> int: + if a > b: + return a + return b + +def check_binop( + op: AstExpressionKind, + location: Location, + lhs_types: ExpressionTypes*, + rhs_types: ExpressionTypes*, +) -> Type*: + result_is_bool = False + if op == AstExpressionKind::Add: + do_what = "add" + elif op == AstExpressionKind::Subtract: + do_what = "subtract" + elif op == AstExpressionKind::Multiply: + do_what = "multiply" + elif op == AstExpressionKind::Divide: + do_what = "divide" + elif op == AstExpressionKind::Modulo: + do_what = "take remainder with" + else: + assert ( + op == AstExpressionKind::Eq + or op == AstExpressionKind::Ne + or op == AstExpressionKind::Gt + or op == AstExpressionKind::Ge + or op == AstExpressionKind::Lt + or op == AstExpressionKind::Le + ) + do_what = "compare" + result_is_bool = True + + got_integers = lhs_types->original_type->is_integer_type() and rhs_types->original_type->is_integer_type() + got_numbers = lhs_types->original_type->is_number_type() and rhs_types->original_type->is_number_type() + got_enums = lhs_types->original_type->kind == TypeKind::Enum and rhs_types->original_type->kind == TypeKind::Enum + got_pointers = ( + lhs_types->original_type->is_pointer_type() + and rhs_types->original_type->is_pointer_type() + and ( + # Ban comparisons like int* == byte*, unless one of the two types is void* + lhs_types->original_type == rhs_types->original_type + or lhs_types->original_type == &void_ptr_type + or rhs_types->original_type == &void_ptr_type + ) + ) + + if ( + (not got_numbers and not got_enums and not got_pointers) + or (op != AstExpressionKind::Eq and op != AstExpressionKind::Ne and not got_numbers) + ): + message: byte[500] + snprintf( + &message[0], sizeof message, + "wrong types: cannot %s %s and %s", + do_what, &lhs_types->original_type->name[0], &rhs_types->original_type->name[0], + ) + fail(location, &message[0]) + + if got_integers: + size = max(lhs_types->original_type->size_in_bits, rhs_types->original_type->size_in_bits) + if ( + lhs_types->original_type->kind == TypeKind::SignedInteger + or rhs_types->original_type->kind == TypeKind::SignedInteger + ): + cast_type = &signed_integers[size] + else: + cast_type = &unsigned_integers[size] + elif got_numbers: + assert False # TODO: use float/double + elif got_pointers: + cast_type = &void_ptr_type + elif got_enums: + cast_type = int_type + else: + assert False + + lhs_types->do_implicit_cast(cast_type, Location{}, NULL) + rhs_types->do_implicit_cast(cast_type, Location{}, NULL) + + if result_is_bool: + return &bool_type + else: + return cast_type + class Stage3TypeChecker: file_types: FileTypes* @@ -335,6 +491,7 @@ class Stage3TypeChecker: def do_expression_maybe_void(self, expression: AstExpression*) -> ExpressionTypes*: result: Type* + message: byte[200] if expression->kind == AstExpressionKind::String: result = byte_type->get_pointer_type() @@ -351,9 +508,46 @@ class Stage3TypeChecker: elif expression->kind == AstExpressionKind::GetVariable: result = self->find_var(&expression->varname[0]) if result == NULL: - message: byte[200] snprintf(&message[0], sizeof message, "no variable named '%s'", &expression->varname[0]) fail(expression->location, &message[0]) + elif expression->kind == AstExpressionKind::As: + old_type = self->do_expression(&expression->as_expression->value)->original_type + result = type_from_ast(self->file_types, &expression->as_expression->type) + check_explicit_cast(old_type, result, expression->location) + elif expression->kind == AstExpressionKind::GetEnumMember: + result = self->file_types->find_type(&expression->enum_member.enum_name[0]) + if result == NULL: + printf("find_type(%s) returned NULL\n", &expression->enum_member.enum_name[0]) + snprintf(&message[0], sizeof message, "there is no type named '%s'", &expression->enum_member.enum_name[0]) + fail(expression->location, &message[0]) + if result->kind != TypeKind::Enum: + snprintf( + &message[0], sizeof message, + "the '::' syntax is only for enums, but %s is %s", + &expression->enum_member.enum_name[0], very_short_type_description(result), + ) + fail(expression->location, &message[0]) + if result->enum_members.find_index(&expression->enum_member.member_name[0]) == -1: + snprintf( + &message[0], sizeof message, + "enum %s has no member named '%s'", + &expression->enum_member.enum_name[0], &expression->enum_member.member_name[0], + ) + fail(expression->location, &message[0]) + elif expression->kind == AstExpressionKind::Eq: + lhs_types = self->do_expression(&expression->operands[0]) + rhs_types = self->do_expression(&expression->operands[1]) + result = check_binop(expression->kind, expression->location, lhs_types, rhs_types) + elif expression->kind == AstExpressionKind::Negate: + result = self->do_expression(&expression->operands[0])->original_type + # TODO: check for floats/doubles too + if result->kind != TypeKind::SignedInteger: + snprintf( + &message[0], sizeof message, + "value after '-' must be a float or double or a signed integer, not %s", + &result->name[0], + ) + fail(expression->location, &message[0]) else: printf("*** %d\n", expression->kind as int) assert False @@ -465,10 +659,21 @@ class Stage3TypeChecker: "initial value for variable of type TO cannot be of type FROM", ) + elif statement->kind == AstStatementKind::If: + for i = 0; i < statement->if_statement.n_if_and_elifs; i++: + self->do_expression_and_implicit_cast( + &statement->if_statement.if_and_elifs[i].condition, + &bool_type, + "'if' condition must be a boolean, not ", + ) + self->do_body(&statement->if_statement.if_and_elifs[i].body) + self->do_body(&statement->if_statement.else_body) + else: + printf("*** typecheck: unknown statement kind %d\n", statement->kind) assert False - def typecheck_body(self, body: AstBody*) -> void: + def do_body(self, body: AstBody*) -> void: for i = 0; i < body->nstatements; i++: self->do_statement(&body->statements[i]) @@ -491,5 +696,5 @@ def typecheck_stage3_function_and_method_bodies(file_types: FileTypes*, ast_file for k = 0; k < sig->nargs; k++: checker.add_local_var(&sig->argnames[k][0], sig->argtypes[k]) - checker.typecheck_body(&ts->function.body) + checker.do_body(&ts->function.body) checker.current_function_or_method = NULL diff --git a/self_hosted/types.jou b/self_hosted/types.jou index 13b1547e..96712775 100644 --- a/self_hosted/types.jou +++ b/self_hosted/types.jou @@ -7,6 +7,20 @@ enum TypeKind: UnsignedInteger Pointer VoidPointer + Class + OpaqueClass + Enum + +class EnumMembers: + count: int + names: byte[100]* + + # Returns -1 for not found + def find_index(self, name: byte*) -> int: + for i = 0; i < self->count; i++: + if strcmp(&self->names[i][0], name) == 0: + return i + return -1 class Type: name: byte[100] @@ -15,6 +29,7 @@ class Type: union: size_in_bits: int # SignedInteger, UnsignedInteger value_type: Type* # Pointer + enum_members: EnumMembers # Enum # Pointers and arrays of a given type live as long as the type itself. # To make it possible, we just store them within the type. @@ -27,6 +42,13 @@ class Type: def is_integer_type(self) -> bool: return self->kind == TypeKind::SignedInteger or self->kind == TypeKind::UnsignedInteger + def is_number_type(self) -> bool: + # TODO: accept floats/doubles as numbers + return self->is_integer_type() + + def is_pointer_type(self) -> bool: + return self->kind == TypeKind::Pointer or self->kind == TypeKind::VoidPointer + def get_pointer_type(self) -> Type*: if self->cached_pointer_type == NULL: pointer_name: byte[100] @@ -77,6 +99,26 @@ def init_types() -> void: strcpy(&int_type->name[0], "int") strcpy(&long_type->name[0], "long") +def create_opaque_class(name: byte*) -> Type*: + result: Type* = malloc(sizeof *result) + *result = Type{kind = TypeKind::OpaqueClass} + assert strlen(name) < sizeof result->name + strcpy(&result->name[0], name) + return result + +def create_enum(name: byte*, member_count: int, member_names: byte[100]*) -> Type*: + copied_member_names: byte[100]* = malloc(member_count * sizeof copied_member_names[0]) + memcpy(copied_member_names, member_names, member_count * sizeof copied_member_names[0]) + + result: Type* = malloc(sizeof *result) + *result = Type{ + kind = TypeKind::Enum, + enum_members = EnumMembers{count = member_count, names = copied_member_names}, + } + assert strlen(name) < sizeof result->name + strcpy(&result->name[0], name) + return result + class Signature: name: byte[100] # name of function or method, after "def" keyword diff --git a/src/codegen.c b/src/codegen.c index c8e7446f..a8f35feb 100644 --- a/src/codegen.c +++ b/src/codegen.c @@ -367,20 +367,10 @@ static void codegen_instruction(const struct State *st, const CfInstruction *ins assert(is_number_type(from) && is_number_type(to)); if (is_integer_type(from) && is_integer_type(to)) { - if (from->data.width_in_bits < to->data.width_in_bits) { - if (from->kind == TYPE_SIGNED_INTEGER) { - // example: signed 8-bit 0xFF --> 16-bit 0xFFFF - setdest(LLVMBuildSExt(st->builder, getop(0), codegen_type(to), "int_cast")); - } else { - // example: unsigned 8-bit 0xFF --> 16-bit 0x00FF - setdest(LLVMBuildZExt(st->builder, getop(0), codegen_type(to), "int_cast")); - } - } else if (from->data.width_in_bits > to->data.width_in_bits) { - setdest(LLVMBuildTrunc(st->builder, getop(0), codegen_type(to), "int_cast")); - } else { - // same size, LLVM doesn't distinguish signed and unsigned integer types - setdest(getop(0)); - } + // Examples: + // signed 8-bit 0xFF (-1) --> 16-bit 0xFFFF (-1 or max value) + // unsigned 8-bit 0xFF (255) --> 16-bit 0x00FF (255) + setdest(LLVMBuildIntCast2(st->builder, getop(0), codegen_type(to), from->kind == TYPE_SIGNED_INTEGER, "int_cast")); } else if (is_integer_type(from) && to->kind == TYPE_FLOATING_POINT) { // integer --> double / float if (from->kind == TYPE_SIGNED_INTEGER) diff --git a/src/jou_compiler.h b/src/jou_compiler.h index 02398e78..bac0135c 100644 --- a/src/jou_compiler.h +++ b/src/jou_compiler.h @@ -458,7 +458,7 @@ struct ExpressionTypes { struct ExportSymbol { enum ExportSymbolKind { EXPSYM_FUNCTION, EXPSYM_TYPE, EXPSYM_GLOBAL_VAR } kind; - char name[200]; // For methods this is "StructName.method_name" + char name[200]; union { Signature funcsignature; const Type *type; // EXPSYM_TYPE and EXPSYM_GLOBAL_VAR @@ -503,7 +503,7 @@ The list is terminated with (ExportSymbol){0}, which you can detect by checking if the name of the ExportSymbol is empty. */ ExportSymbol *typecheck_stage1_create_types(FileTypes *ft, const AstToplevelNode *ast); -ExportSymbol *typecheck_stage2_signatures_globals_structbodies(FileTypes *ft, const AstToplevelNode *ast); +ExportSymbol *typecheck_stage2_populate_types(FileTypes *ft, const AstToplevelNode *ast); void typecheck_stage3_function_and_method_bodies(FileTypes *ft, const AstToplevelNode *ast); diff --git a/src/main.c b/src/main.c index 8c8b3149..955cd720 100644 --- a/src/main.c +++ b/src/main.c @@ -456,7 +456,7 @@ int main(int argc, char **argv) for (struct FileState *fs = compst.files.ptr; fs < End(compst.files); fs++) { if (command_line_args.verbosity >= 1) printf(" stage 2: %s\n", fs->path); - fs->pending_exports = typecheck_stage2_signatures_globals_structbodies(&fs->types, fs->ast); + fs->pending_exports = typecheck_stage2_populate_types(&fs->types, fs->ast); } add_imported_symbols(&compst); for (struct FileState *fs = compst.files.ptr; fs < End(compst.files); fs++) { diff --git a/src/parse.c b/src/parse.c index 76427a70..a70f1c0d 100644 --- a/src/parse.c +++ b/src/parse.c @@ -897,6 +897,9 @@ static AstEnumDef parse_enumdef(const Token **tokens) List(const char*) membernames = {0}; while ((*tokens)->type != TOKEN_DEDENT) { + if ((*tokens)->type != TOKEN_NAME) + fail_with_parse_error(*tokens, "a name for an enum member"); + for (const char **old = membernames.ptr; old < End(membernames); old++) if (!strcmp(*old, (*tokens)->data.name)) fail_with_error((*tokens)->location, "the enum has two members named '%s'", (*tokens)->data.name); diff --git a/src/typecheck.c b/src/typecheck.c index 50d7b4c8..829ee55f 100644 --- a/src/typecheck.c +++ b/src/typecheck.c @@ -279,7 +279,7 @@ static const Type *handle_class_members_stage2(FileTypes *ft, const AstClassDef return type; } -ExportSymbol *typecheck_stage2_signatures_globals_structbodies(FileTypes *ft, const AstToplevelNode *ast) +ExportSymbol *typecheck_stage2_populate_types(FileTypes *ft, const AstToplevelNode *ast) { List(ExportSymbol) exports = {0}; @@ -488,11 +488,10 @@ static const Type *check_binop( ) ); - if(!( - got_integers - || got_numbers - || ((got_enums || got_pointers) && (op == AST_EXPR_EQ || op == AST_EXPR_NE)) - )) + if ( + (!got_numbers && !got_enums && !got_pointers) + || (op != AST_EXPR_EQ && op != AST_EXPR_NE && !got_numbers) + ) fail_with_error(location, "wrong types: cannot %s %s and %s", do_what, lhstypes->type->name, rhstypes->type->name); const Type *cast_type = NULL; diff --git a/tests/syntax_error/bad_enum_member.jou b/tests/syntax_error/bad_enum_member.jou new file mode 100644 index 00000000..14520974 --- /dev/null +++ b/tests/syntax_error/bad_enum_member.jou @@ -0,0 +1,2 @@ +enum Foo: + 123 # Error: expected a name for an enum member, got an integer