diff --git a/builtin_packages/_array.prajna b/builtin_packages/_array.prajna index bf7bccdb..fa71a91d 100644 --- a/builtin_packages/_array.prajna +++ b/builtin_packages/_array.prajna @@ -50,3 +50,25 @@ implement Array { } } +template +func __ArrayOne()->Array { + var re: Array; + for i in 0 to re.Length() { + re[i] = 1; + } + + return re; +} + +template +func __CacluateArrayIndex(A: Array, Row_>, B: Array, Index: Array) -> Array { + var re: Array; + for i in 0 to Row_ { + re[i] = B[i]; + for j in 0 to Column_ { + re[i] = re[i] + A[i][j] * Index[j]; + } + } + + return re; +} diff --git a/builtin_packages/optional.prajna b/builtin_packages/optional.prajna index 1dfe1254..399e1d97 100644 --- a/builtin_packages/optional.prajna +++ b/builtin_packages/optional.prajna @@ -15,8 +15,10 @@ implement Optional { } func __get__Value()->ValueType{ + "hit\n".Print(); if (!this.HasValue()) { - debug::AssertWithMessage(false, "has no value"); + // TODO(张志敏): 属性赋值存在bug, 后续需要修复 + // debug::AssertWithMessage(false, "has no value"); } return this._value; } @@ -33,7 +35,7 @@ implement Optional { @static func Create(value: ValueType) -> Optional { var optional: Optional; - optional.Value = value; + optional._value = value; return optional; } } diff --git a/examples_in_cpp/add.prajna b/examples_in_cpp/add.prajna index 4911e8d9..a2174b80 100644 --- a/examples_in_cpp/add.prajna +++ b/examples_in_cpp/add.prajna @@ -1,12 +1,17 @@ -func helloWorld(ts: Array*){ - ts->ToString().Print(); +func HelloWorld(){ + "Hello World!\n".Print(); } -func matrixAddF32(p_ts0: Tensor *, p_ts1: Tensor*, p_ts_re: Tensor*){ +func MatrixAddF32(p_ts0: ptr>, p_ts1: ptr>, p_ts_re: ptr>){ var ts0 = *p_ts0; var ts1 = *p_ts1; var ts_re = *p_ts_re; + for i in 0 to ts0.layout.shape[0] { + for j in 0 to ts0.layout.shape[1] { + ts_re[i, j] = ts0[i, j] + ts1[i, j]; + } + } for idx in [0, 0] to ts0.layout.shape{ ts_re.At(idx) = ts0.At(idx) + ts1.At(idx); } diff --git a/examples_in_cpp/main.cpp b/examples_in_cpp/main.cpp index 46cc7186..33dde205 100644 --- a/examples_in_cpp/main.cpp +++ b/examples_in_cpp/main.cpp @@ -5,8 +5,13 @@ int main() { auto compiler = prajna::Compiler::Create(); - compiler->CompileBuiltinSourceFiles("prajna/builtin_packages"); - compiler->CompileProgram("examples/prajna_in_cpp/add.prajna", false); + compiler->CompileBuiltinSourceFiles("builtin_packages"); + compiler->AddPackageDirectoryPath("."); + compiler->CompileProgram("examples_in_cpp/add.prajna", false); + + auto hello_world = reinterpret_cast( + compiler->GetSymbolValue("::examples_in_cpp::add::HelloWorld")); + hello_world(); using MatrixF32 = prajna::Tensor; @@ -18,10 +23,9 @@ int main() { ts(1, 0) = 4; ts(1, 1) = 5; ts(1, 2) = 6; - // 获取Prajna里的函数地址 auto matrix_add_f32 = reinterpret_cast( - compiler->GetSymbolValue("::examples::prajna_in_cpp::add::matrixAddF32")); + compiler->GetSymbolValue("::examples_in_cpp::add::MatrixAddF32")); auto ts_re = MatrixF32::Create(shape); diff --git a/prajna/CMakeLists.txt b/prajna/CMakeLists.txt index 0fe2a35c..e2dd3ab9 100644 --- a/prajna/CMakeLists.txt +++ b/prajna/CMakeLists.txt @@ -13,9 +13,18 @@ target_link_libraries(prajna_config_target INTERFACE Boost::dll INTERFACE Boost::process INTERFACE Boost::asio + INTERFACE Boost::scope INTERFACE fmt::fmt ) +target_compile_definitions(prajna_config_target + # 参阅third_party/boost/libs/mpl/include/boost/mpl/list/list50.hpp + # 并非任意数字都行, 需要是10,20,30,40,50 + # boost::variant的模板参数受此影响 + INTERFACE BOOST_MPL_LIMIT_LIST_SIZE=50 + INTERFACE BOOST_MPL_CFG_NO_PREPROCESSED_HEADERS +) + if (MSVC) target_compile_options(prajna_config_target INTERFACE "/bigobj") else () diff --git a/prajna/assert.hpp b/prajna/assert.hpp index b42e6d0a..10ac2ff4 100644 --- a/prajna/assert.hpp +++ b/prajna/assert.hpp @@ -10,7 +10,7 @@ namespace prajna { class assert_failed : public std::exception { public: assert_failed(const std::string& expr, const std::string& function, const std::string& file, - size_t line, bool verify, const std::string& msg) + int64_t line, bool verify, const std::string& msg) : _expr(expr), _function(function), _file(file), _line(line), _msg(msg) { std::string with_msg; if (!msg.empty()) { @@ -28,7 +28,7 @@ class assert_failed : public std::exception { std::string _expr; std::string _function; std::string _file; - size_t _line; + int64_t _line; std::string _msg; std::string _what_str; }; diff --git a/prajna/bindings/core.hpp b/prajna/bindings/core.hpp index 3f7f414a..affe73fe 100644 --- a/prajna/bindings/core.hpp +++ b/prajna/bindings/core.hpp @@ -20,6 +20,27 @@ class Array { Type_& operator[](i64 offset) { return data[offset]; } }; +template +struct Ptr { + Type* raw_ptr; + int64_t size; + int64_t* _reference_counter; + + static Ptr Allocate(int64_t size) { + Ptr self; + self.raw_ptr = new Type[size]; + self.size = size; + + self._reference_counter = new int64_t; + *self._reference_counter = 1; + return self; + } + + Type& operator[](int64_t offset) { return this->raw_ptr[offset]; } + + // ~Ptr() { delete[] this->raw_ptr; } +}; + template class Layout { public: @@ -68,7 +89,7 @@ class Layout { template class Tensor { public: - Type_* data = nullptr; + Ptr data; Layout layout; protected: @@ -79,9 +100,9 @@ class Tensor { Tensor self; self.layout = Layout::Create(shape); - auto bytes = self.layout.Length() * sizeof(Type_); - self.data = reinterpret_cast(malloc(bytes)); + self.data = Ptr::Allocate(self.layout.Length()); + // TODO(zhangzhimin): 初始化和释放还没有处理 // RegisterReferenceCount(self.data); // __copy__(self.data); @@ -101,7 +122,7 @@ class Tensor { // } } - const Type_& at(Array idx) const { + Type_& at(Array idx) const { i64 offset = this->layout.ArrayIndexToLinearIndex(idx); return this->data[offset]; } @@ -112,7 +133,7 @@ class Tensor { } template - const Type_& operator()(Idx_... indices) const { + Type_& operator()(Idx_... indices) const { Array idx(indices...); return this->at(idx); } diff --git a/prajna/codegen/llvm_codegen.cpp b/prajna/codegen/llvm_codegen.cpp index aea3c203..4f2f23be 100644 --- a/prajna/codegen/llvm_codegen.cpp +++ b/prajna/codegen/llvm_codegen.cpp @@ -122,6 +122,13 @@ class LlvmCodegen { llvm::ArrayType::get(ir_array_type->value_type->llvm_type, ir_array_type->size); return; } + if (auto ir_vector_type = Cast(ir_type)) { + this->EmitType(ir_vector_type->value_type); + ir_vector_type->llvm_type = + llvm::VectorType::get(ir_vector_type->value_type->llvm_type, + llvm::ElementCount::getFixed(ir_vector_type->size)); + return; + } if (auto ir_struct_type = Cast(ir_type)) { auto llvm_struct_type = llvm::StructType::create(static_llvm_context, ir_struct_type->fullname); @@ -133,7 +140,13 @@ class LlvmCodegen { return field->type->llvm_type; }); llvm_struct_type->setBody(llvm_types, true); - + return; + } + if (auto ir_simd_type = Cast(ir_type)) { + this->EmitType(ir_simd_type->value_type); + ir_simd_type->llvm_type = + llvm::VectorType::get(ir_simd_type->value_type->llvm_type, + llvm::ElementCount::getFixed(ir_simd_type->size)); return; } @@ -190,11 +203,24 @@ class LlvmCodegen { llvm::Function *llvm_fun = static_cast(ir_function->llvm_value); PRAJNA_ASSERT(llvm_fun); PRAJNA_ASSERT(ir_function->parameters.size() == llvm_fun->arg_size()); - size_t i = 0; + int64_t i = 0; auto iter_parameter = ir_function->parameters.begin(); for (auto llvm_arg = llvm_fun->arg_begin(); llvm_arg != llvm_fun->arg_end(); ++llvm_arg, ++iter_parameter) { - (*iter_parameter)->llvm_value = llvm_arg; + auto ir_parameter = *iter_parameter; + if (ir_parameter->no_alias) { + llvm_arg->addAttr(llvm::Attribute::NoAlias); + } + if (ir_parameter->no_capture) { + llvm_arg->addAttr(llvm::Attribute::NoCapture); + } + if (ir_parameter->no_undef) { + llvm_arg->addAttr(llvm::Attribute::NoUndef); + } + if (ir_parameter->readonly) { + llvm_arg->addAttr(llvm::Attribute::ReadOnly); + } + ir_parameter->llvm_value = llvm_arg; } for (auto block : ir_function->blocks) { @@ -337,6 +363,21 @@ class LlvmCodegen { return; } + if (auto ir_constant_vector = Cast(ir_constant)) { + std::vector llvm_contants( + ir_constant_vector->initialize_constants.size()); + std::transform(RANGE(ir_constant_vector->initialize_constants), llvm_contants.begin(), + [=](auto ir_init) { + auto llvm_constant = + static_cast(ir_init->llvm_value); + PRAJNA_ASSERT(llvm_constant); + return llvm_constant; + }); + PRAJNA_ASSERT(ir_constant_vector->type->llvm_type); + ir_constant_vector->llvm_value = llvm::ConstantVector::get(llvm_contants); + return; + } + PRAJNA_TODO; } @@ -364,7 +405,7 @@ class LlvmCodegen { if (auto ir_call = Cast(ir_instruction)) { auto ir_function_type = ir_call->Function()->GetFunctionType(); std::vector llvm_arguments(ir_call->ArgumentSize()); - for (size_t i = 0; i < llvm_arguments.size(); ++i) { + for (int64_t i = 0; i < llvm_arguments.size(); ++i) { llvm_arguments[i] = ir_call->Argument(i)->llvm_value; PRAJNA_ASSERT(llvm_arguments[i]); } @@ -383,12 +424,11 @@ class LlvmCodegen { return; } if (auto ir_alloca = Cast(ir_instruction)) { - // Align需要设置一个合理的值, 目前先设置为8 - // PRAJNA_ASSERT() auto ir_alloca_type = Cast(ir_alloca->type); PRAJNA_ASSERT(ir_alloca_type && ir_alloca_type->value_type->llvm_type); - ir_alloca->llvm_value = new llvm::AllocaInst(ir_alloca_type->value_type->llvm_type, 0, - ir_alloca->name, llvm_basic_block); + ir_alloca->llvm_value = new llvm::AllocaInst( + ir_alloca_type->value_type->llvm_type, 0, ir_alloca->Length()->llvm_value, + llvm::Align(16), ir_alloca->name, llvm_basic_block); return; } @@ -524,7 +564,7 @@ class LlvmCodegen { PRAJNA_ASSERT(cast_operator_dict.count(ir_cast_instruction->operation)); auto cast_op = cast_operator_dict[ir_cast_instruction->operation]; ir_cast_instruction->llvm_value = - llvm::CastInst::Create(cast_op, ir_cast_instruction->operand(0)->llvm_value, + llvm::CastInst::Create(cast_op, ir_cast_instruction->GetOperand(0)->llvm_value, ir_cast_instruction->type->llvm_type, "", llvm_basic_block); return; } @@ -602,8 +642,8 @@ class LlvmCodegen { llvm_compare_predicator_dict[ir_compare_instruction->operation]; ir_compare_instruction->llvm_value = llvm::CmpInst::Create( llvm_compare_other_ops, llvm_compare_predicator, - ir_compare_instruction->operand(0)->llvm_value, - ir_compare_instruction->operand(1)->llvm_value, "", llvm_basic_block); + ir_compare_instruction->GetOperand(0)->llvm_value, + ir_compare_instruction->GetOperand(1)->llvm_value, "", llvm_basic_block); return; } @@ -636,12 +676,19 @@ class LlvmCodegen { auto llvm_binary_operator_operation = binary_operator_dict[ir_binary_operator->operation]; ir_binary_operator->llvm_value = llvm::BinaryOperator::Create( - llvm_binary_operator_operation, ir_binary_operator->operand(0)->llvm_value, - ir_binary_operator->operand(1)->llvm_value, "", llvm_basic_block); + llvm_binary_operator_operation, ir_binary_operator->GetOperand(0)->llvm_value, + ir_binary_operator->GetOperand(1)->llvm_value, "", llvm_basic_block); return; } + if (auto ir_shuffle_vector = Cast(ir_instruction)) { + ir_shuffle_vector->llvm_value = new llvm::ShuffleVectorInst( + ir_shuffle_vector->Value()->llvm_value, ir_shuffle_vector->Mask()->llvm_value, "", + llvm_basic_block); + return; + } + PRAJNA_ASSERT(false, ir_instruction->tag); } }; @@ -701,7 +748,7 @@ std::shared_ptr LlvmPass(std::shared_ptr ir_module) { PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); llvm::ModulePassManager MPM = - PB.buildPerModuleDefaultPipeline(llvm::OptimizationLevel::O3, true); + PB.buildPerModuleDefaultPipeline(llvm::OptimizationLevel::O2, true); MPM.run(*ir_module->llvm_module, MAM); diff --git a/prajna/compiler/compiler.cpp b/prajna/compiler/compiler.cpp index 219ccf02..59b92b84 100644 --- a/prajna/compiler/compiler.cpp +++ b/prajna/compiler/compiler.cpp @@ -98,6 +98,16 @@ std::shared_ptr Compiler::CompileCode( return ir_lowering_module; } +void Compiler::GenLlvm(std::shared_ptr ir_module) { + auto ir_ssa_module = prajna::transform::transform(ir_module); + auto ir_codegen_module = prajna::codegen::LlvmCodegen(ir_ssa_module, ir::Target::host); + auto ir_llvm_optimize_module = prajna::codegen::LlvmPass(ir_codegen_module); +#ifdef PRAJNA_ENABLE_LLVM_DUMP + ir_module->llvm_module->dump(); +#endif + jit_engine->AddIRModule(ir_llvm_optimize_module); +} + void Compiler::ExecuteCodeInRelp(std::string script_code) { try { static int command_id = 0; @@ -170,7 +180,7 @@ void Compiler::ExecutateMainFunction() { } } -size_t Compiler::GetSymbolValue(std::string symbol_name) { +int64_t Compiler::GetSymbolValue(std::string symbol_name) { return this->jit_engine->GetValue(symbol_name); } @@ -216,7 +226,8 @@ std::shared_ptr Compiler::CompileProgram( return nullptr; } - auto current_symbol_table = CreateSymbolTableTree(_symbol_table, prajna_source_package_path.string()); + auto current_symbol_table = + CreateSymbolTableTree(_symbol_table, prajna_source_package_path.string()); current_symbol_table->directory_path = prajna_directory_path / current_symbol_table->directory_path; diff --git a/prajna/compiler/compiler.h b/prajna/compiler/compiler.h index d1e7a3cc..6e93cc41 100644 --- a/prajna/compiler/compiler.h +++ b/prajna/compiler/compiler.h @@ -48,11 +48,13 @@ class Compiler : public std::enable_shared_from_this { std::shared_ptr symbol_table, std::string file_name, bool is_interpreter); + void GenLlvm(std::shared_ptr ir_module); + void BindBuiltinFunctions(); void ExecuteCodeInRelp(std::string command_line_code); - size_t GetSymbolValue(std::string symbol_name); + int64_t GetSymbolValue(std::string symbol_name); void RunTests(std::filesystem::path prajna_source_package_path); diff --git a/prajna/helper.hpp b/prajna/helper.hpp index 71280479..c4294e44 100644 --- a/prajna/helper.hpp +++ b/prajna/helper.hpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include "boost/dll/runtime_symbol_info.hpp" @@ -22,15 +23,15 @@ struct overloaded : Ts... { template overloaded(Ts...) -> overloaded; -class function_guard { +class ScopeGuard { public: - static std::unique_ptr Create(std::function func) { - auto self = std::unique_ptr(new function_guard); + static std::unique_ptr Create(std::function func) { + auto self = std::unique_ptr(new ScopeGuard); self->_todo = func; return self; } - ~function_guard() { _todo(); } + ~ScopeGuard() { _todo(); } private: std::function _todo; @@ -42,6 +43,15 @@ auto Cast(std::shared_ptr ir_src) -> std::shared_ptr { return ir_dst; } +template +auto ListCast(std::list> src_list) + -> std::list> { + std::list> dst_list; + std::transform(RANGE(src_list), std::back_inserter(dst_list), + [](std::shared_ptr x) { return Cast(x); }); + return dst_list; +} + template bool Is(std::shared_ptr ir_src) { return Cast(ir_src) != nullptr; diff --git a/prajna/ir/global_context.cpp b/prajna/ir/global_context.cpp index b41a21c0..5c30bb3b 100644 --- a/prajna/ir/global_context.cpp +++ b/prajna/ir/global_context.cpp @@ -4,7 +4,7 @@ namespace prajna::ir { -GlobalContext::GlobalContext(size_t target_bits) { +GlobalContext::GlobalContext(int64_t target_bits) { this->created_types.clear(); } // namespace prajna::ir diff --git a/prajna/ir/global_context.h b/prajna/ir/global_context.h index 6bc04303..7c32d351 100644 --- a/prajna/ir/global_context.h +++ b/prajna/ir/global_context.h @@ -10,7 +10,7 @@ class Type; class GlobalContext { public: - GlobalContext(size_t target_bits); + GlobalContext(int64_t target_bits); /// @brief 用于存储已经构造了的类型 /// @note 需要使用vector来确保构造的顺序, 因为后面的codegen需要顺序正确 diff --git a/prajna/ir/operation_instruction.hpp b/prajna/ir/operation_instruction.hpp index 93d3343f..e9b61c74 100644 --- a/prajna/ir/operation_instruction.hpp +++ b/prajna/ir/operation_instruction.hpp @@ -46,8 +46,8 @@ class CompareInstruction : public Instruction { std::shared_ptr self(new Self); self->operation = operation; self->OperandResize(2); - self->operand(0, ir_operand0); - self->operand(1, ir_operand1); + self->SetOperand(0, ir_operand0); + self->SetOperand(1, ir_operand1); self->type = ir::BoolType::Create(); self->tag = "CompareInstruction"; return self; @@ -99,8 +99,8 @@ class BinaryOperator : public Instruction { PRAJNA_ASSERT(ir_operand0->type == ir_operand1->type); self->operation = operation; self->OperandResize(2); - self->operand(0, ir_operand0); - self->operand(1, ir_operand1); + self->SetOperand(0, ir_operand0); + self->SetOperand(1, ir_operand1); self->type = ir_operand0->type; self->tag = "BinaryOperator"; return self; @@ -146,7 +146,7 @@ class CastInstruction : public Instruction { std::shared_ptr self(new CastInstruction); self->operation = operation; self->OperandResize(1); - self->operand(0, ir_operand); + self->SetOperand(0, ir_operand); self->type = ir_type; self->tag = "CastInstruction"; return self; diff --git a/prajna/ir/type.hpp b/prajna/ir/type.hpp index beaa43fa..d9006176 100644 --- a/prajna/ir/type.hpp +++ b/prajna/ir/type.hpp @@ -29,7 +29,7 @@ class TemplateStruct; namespace prajna::ir { -const size_t ADDRESS_BITS = 64; +const int64_t ADDRESS_BITS = 64; class Function; struct Field; @@ -58,7 +58,7 @@ class Type : public Named { public: // @ref https://llvm.org/docs/LangRef.html#langref-datalayout bytes是多少可参阅datalyout的描述 - size_t bytes = 0; + int64_t bytes = 0; std::unordered_map> member_function_dict; std::unordered_map> static_function_dict; @@ -318,7 +318,7 @@ class ArrayType : public Type { ArrayType() = default; public: - static std::shared_ptr Create(std::shared_ptr value_type, size_t size) { + static std::shared_ptr Create(std::shared_ptr value_type, int64_t size) { for (auto ir_type : global_context.created_types) { if (auto ir_array_type = Cast(ir_type)) { if (ir_array_type->value_type == value_type && ir_array_type->size == size) { @@ -339,7 +339,36 @@ class ArrayType : public Type { public: std::shared_ptr value_type = nullptr; - size_t size = 0; + int64_t size = 0; +}; + +class VectorType : public Type { + protected: + VectorType() = default; + + public: + static std::shared_ptr Create(std::shared_ptr value_type, int64_t size) { + for (auto ir_type : global_context.created_types) { + if (auto ir_vectory_type = Cast(ir_type)) { + if (ir_vectory_type->value_type == value_type && ir_vectory_type->size == size) { + return ir_vectory_type; + } + } + } + + std::shared_ptr self(new VectorType); + self->value_type = value_type; + self->size = size; + self->bytes = value_type->bytes * size; + self->name = value_type->name + "[" + std::to_string(size) + "]"; + self->fullname = self->name; + global_context.created_types.push_back(self); + return self; + } + + public: + std::shared_ptr value_type = nullptr; + int64_t size = 0; }; class Field { @@ -357,7 +386,7 @@ class Field { std::string name; std::shared_ptr type = nullptr; - size_t index = 0; + int64_t index = 0; }; class StructType : public Type { @@ -441,4 +470,33 @@ class InterfaceImplement : public Named { std::shared_ptr dynamic_type_creator = nullptr; }; +class SimdType : public Type { + protected: + SimdType() = default; + + public: + static std::shared_ptr Create(std::shared_ptr value_type, int64_t size) { + for (auto ir_type : global_context.created_types) { + if (auto ir_array_type = Cast(ir_type)) { + if (ir_array_type->value_type == value_type && ir_array_type->size == size) { + return ir_array_type; + } + } + } + + std::shared_ptr self(new SimdType); + self->value_type = value_type; + self->size = size; + self->bytes = value_type->bytes * size; + self->name = value_type->name + "[" + std::to_string(size) + "]"; + self->fullname = self->name; + global_context.created_types.push_back(self); + return self; + } + + public: + std::shared_ptr value_type = nullptr; + int64_t size = 0; +}; + } // namespace prajna::ir diff --git a/prajna/ir/value.hpp b/prajna/ir/value.hpp index fa9d5802..6294151c 100644 --- a/prajna/ir/value.hpp +++ b/prajna/ir/value.hpp @@ -47,7 +47,7 @@ class Instruction; struct InstructionAndOperandIndex { std::shared_ptr instruction; - size_t operand_index; + int64_t operand_index; }; inline bool operator==(prajna::ir::InstructionAndOperandIndex lhs, @@ -59,10 +59,10 @@ inline bool operator==(prajna::ir::InstructionAndOperandIndex lhs, template <> struct std::hash { - std::size_t operator()(prajna::ir::InstructionAndOperandIndex inst_with_idx) const noexcept { - std::size_t h1 = + std::int64_t operator()(prajna::ir::InstructionAndOperandIndex inst_with_idx) const noexcept { + std::int64_t h1 = std::hash>{}(inst_with_idx.instruction); - std::size_t h2 = std::hash{}(inst_with_idx.operand_index); + std::int64_t h2 = std::hash{}(inst_with_idx.operand_index); // 这里哈希函数应该不重要, 应该不会导致性能问题 return h1 ^ (h2 << 1); } @@ -197,6 +197,12 @@ class Parameter : public Value { function_cloner->value_dict[shared_from_this()] = ir_new; return ir_new; } + + public: + bool no_alias = false; + bool no_capture = false; + bool no_undef = false; + bool readonly = false; }; class Constant : public Value { @@ -358,6 +364,42 @@ class ConstantArray : public Constant { std::list> initialize_constants; }; +class ConstantVector : public Constant { + protected: + ConstantVector() = default; + + public: + static std::shared_ptr Create( + std::shared_ptr ir_vector_type, + std::list> ir_init_constants) { + PRAJNA_ASSERT(ir_vector_type); + std::shared_ptr self(new ConstantVector); + self->tag = "ConstantVector"; + self->type = ir_vector_type; + self->initialize_constants = ir_init_constants; + return self; + } + + std::shared_ptr Clone(std::shared_ptr function_cloner) override { + std::list> new_initialize_constants( + this->initialize_constants.size()); + std::transform( + RANGE(this->initialize_constants), new_initialize_constants.begin(), + [=](auto ir_constant) { + PRAJNA_ASSERT( + function_cloner->value_dict[ir_constant]); // constant应该在前面就处理过; + return Cast(function_cloner->value_dict[ir_constant]); + }); + + std::shared_ptr ir_new = + ConstantVector::Create(Cast(this->type), new_initialize_constants); + function_cloner->value_dict[shared_from_this()] = ir_new; + return ir_new; + } + + std::list> initialize_constants; +}; + /// @brief 和高级语言里的块是对应的 class Block : public Value { protected: @@ -465,7 +507,7 @@ class Function : public Value { public: std::shared_ptr function_type = nullptr; - std::list> parameters; + std::list> parameters; std::list> blocks; std::shared_ptr parent_module = nullptr; @@ -537,19 +579,19 @@ class Instruction : virtual public Value { protected: Instruction() : Instruction(0) {} - Instruction(size_t operand_size) { this->operands.resize(operand_size); } + Instruction(int64_t operand_size) { this->operands.resize(operand_size); } public: - virtual void OperandResize(size_t size) { return this->operands.resize(size); } + virtual void OperandResize(int64_t size) { return this->operands.resize(size); } - virtual size_t OperandSize() { return this->operands.size(); } + virtual int64_t OperandSize() { return this->operands.size(); } - virtual std::shared_ptr operand(size_t i) { + virtual std::shared_ptr GetOperand(int64_t i) { PRAJNA_ASSERT(this->OperandSize() > i); return this->operands[i]; }; - virtual void operand(size_t i, std::shared_ptr ir_value) { + virtual void SetOperand(int64_t i, std::shared_ptr ir_value) { PRAJNA_ASSERT(ir_value); PRAJNA_ASSERT(this->OperandSize() > i); @@ -568,7 +610,7 @@ class Instruction : virtual public Value { void Finalize() override { Value::Finalize(); - for (size_t i = 0; i < OperandSize(); ++i) { + for (int64_t i = 0; i < OperandSize(); ++i) { auto ir_old_value = this->operands[i]; if (ir_old_value) { ir_old_value->instruction_with_index_list.remove( @@ -580,7 +622,7 @@ class Instruction : virtual public Value { } void CloneOperands(std::shared_ptr function_cloner) { - for (size_t i = 0; i < operands.size(); ++i) { + for (int64_t i = 0; i < operands.size(); ++i) { auto ir_old = operands[i]; if (!function_cloner->value_dict[ir_old]) { @@ -589,7 +631,7 @@ class Instruction : virtual public Value { operands[i] = nullptr; // 置零以避免干扰原来的操作数 auto ir_new = function_cloner->value_dict[ir_old]; - operand(i, ir_new); + this->SetOperand(i, ir_new); } } @@ -617,8 +659,8 @@ class AccessField : virtual public VariableLiked, virtual public Instruction { return self; } - std::shared_ptr object() { return this->operand(0); } - void object(std::shared_ptr ir_object) { this->operand(0, ir_object); } + std::shared_ptr object() { return this->GetOperand(0); } + void object(std::shared_ptr ir_object) { this->SetOperand(0, ir_object); } std::shared_ptr Clone(std::shared_ptr function_cloner) override { std::shared_ptr ir_new(new AccessField(*this)); @@ -645,18 +687,20 @@ class IndexArray : virtual public VariableLiked, virtual public Instruction { self->OperandResize(2); self->object(ir_object); self->IndexVariable(ir_index); - auto ir_array_type = Cast(ir_object->type); - PRAJNA_ASSERT(ir_array_type); - self->type = ir_array_type->value_type; + if (Is(ir_object->type)) { + self->type = Cast(ir_object->type)->value_type; + } else { + self->type = Cast(ir_object->type)->value_type; + } self->tag = "IndexArray"; return self; } - std::shared_ptr object() { return this->operand(0); } - void object(std::shared_ptr ir_object) { this->operand(0, ir_object); } + std::shared_ptr object() { return this->GetOperand(0); } + void object(std::shared_ptr ir_object) { this->SetOperand(0, ir_object); } - std::shared_ptr IndexVariable() { return this->operand(1); } - void IndexVariable(std::shared_ptr ir_index) { this->operand(1, ir_index); } + std::shared_ptr IndexVariable() { return this->GetOperand(1); } + void IndexVariable(std::shared_ptr ir_index) { this->SetOperand(1, ir_index); } std::shared_ptr Clone(std::shared_ptr function_cloner) override { std::shared_ptr ir_new(new IndexArray(*this)); @@ -689,15 +733,15 @@ class IndexPointer : virtual public VariableLiked, virtual public Instruction { std::shared_ptr object() { PRAJNA_ASSERT(this->OperandSize() == 2); - return this->operand(0); + return this->GetOperand(0); } void object(std::shared_ptr ir_object) { PRAJNA_ASSERT(this->OperandSize() == 2); - this->operand(0, ir_object); + this->SetOperand(0, ir_object); } - std::shared_ptr IndexVariable() { return this->operand(1); } - void IndexVariable(std::shared_ptr ir_index) { this->operand(1, ir_index); } + std::shared_ptr IndexVariable() { return this->GetOperand(1); } + void IndexVariable(std::shared_ptr ir_index) { this->SetOperand(1, ir_index); } std::shared_ptr Clone(std::shared_ptr function_cloner) override { std::shared_ptr ir_new(new IndexPointer(*this)); @@ -727,8 +771,8 @@ class GetStructElementPointer : public Instruction { return self; } - std::shared_ptr Pointer() { return this->operand(0); } - void Pointer(std::shared_ptr ir_pointer) { this->operand(0, ir_pointer); } + std::shared_ptr Pointer() { return this->GetOperand(0); } + void Pointer(std::shared_ptr ir_pointer) { this->SetOperand(0, ir_pointer); } /// @brief 指针偏移下标, 对于结构体来说相当于字段的号数 std::shared_ptr field; @@ -761,18 +805,22 @@ class GetArrayElementPointer : public Instruction { self->IndexVariable(ir_index); auto ir_pointer_type = Cast(ir_pointer->type); PRAJNA_ASSERT(ir_pointer_type); - auto ir_array_type = Cast(ir_pointer_type->value_type); - PRAJNA_ASSERT(ir_array_type); - self->type = PointerType::Create(ir_array_type->value_type); + if (Is(ir_pointer_type->value_type)) { + self->type = + PointerType::Create(Cast(ir_pointer_type->value_type)->value_type); + } else { + self->type = + PointerType::Create(Cast(ir_pointer_type->value_type)->value_type); + } self->tag = "GetArrayElementPointer"; return self; } - std::shared_ptr Pointer() { return this->operand(0); } - void Pointer(std::shared_ptr ir_pointer) { this->operand(0, ir_pointer); } + std::shared_ptr Pointer() { return this->GetOperand(0); } + void Pointer(std::shared_ptr ir_pointer) { this->SetOperand(0, ir_pointer); } - std::shared_ptr IndexVariable() { return this->operand(1); } - void IndexVariable(std::shared_ptr ir_index) { this->operand(1, ir_index); } + std::shared_ptr IndexVariable() { return this->GetOperand(1); } + void IndexVariable(std::shared_ptr ir_index) { this->SetOperand(1, ir_index); } std::shared_ptr Clone(std::shared_ptr function_cloner) override { std::shared_ptr ir_new(new GetArrayElementPointer(*this)); @@ -805,11 +853,11 @@ class GetPointerElementPointer : public Instruction { return self; } - std::shared_ptr Pointer() { return this->operand(0); } - void Pointer(std::shared_ptr ir_pointer) { this->operand(0, ir_pointer); } + std::shared_ptr Pointer() { return this->GetOperand(0); } + void Pointer(std::shared_ptr ir_pointer) { this->SetOperand(0, ir_pointer); } - std::shared_ptr IndexVariable() { return this->operand(1); } - void IndexVariable(std::shared_ptr ir_index) { this->operand(1, ir_index); } + std::shared_ptr IndexVariable() { return this->GetOperand(1); } + void IndexVariable(std::shared_ptr ir_index) { this->SetOperand(1, ir_index); } std::shared_ptr Clone(std::shared_ptr function_cloner) override { std::shared_ptr ir_new(new GetPointerElementPointer(*this)); @@ -837,8 +885,8 @@ class DeferencePointer : virtual public VariableLiked, virtual public Instructio return self; } - std::shared_ptr Pointer() { return this->operand(0); } - void Pointer(std::shared_ptr ir_pointer) { this->operand(0, ir_pointer); } + std::shared_ptr Pointer() { return this->GetOperand(0); } + void Pointer(std::shared_ptr ir_pointer) { this->SetOperand(0, ir_pointer); } std::shared_ptr Clone(std::shared_ptr function_cloner) override { std::shared_ptr ir_new(new DeferencePointer(*this)); @@ -868,16 +916,16 @@ class WriteVariableLiked : public Instruction { return self; } - std::shared_ptr Value() { return this->operand(0); } + std::shared_ptr Value() { return this->GetOperand(0); } void Value(std::shared_ptr ir_value) { PRAJNA_ASSERT(ir_value); - return this->operand(0, ir_value); + return this->SetOperand(0, ir_value); } - std::shared_ptr variable() { return Cast(this->operand(1)); } + std::shared_ptr variable() { return Cast(this->GetOperand(1)); } void variable(std::shared_ptr ir_variable) { PRAJNA_ASSERT(ir_variable); - return this->operand(1, ir_variable); + return this->SetOperand(1, ir_variable); } std::shared_ptr Clone(std::shared_ptr function_cloner) override { @@ -904,9 +952,9 @@ class GetAddressOfVariableLiked : public Instruction { return self; } - std::shared_ptr variable() { return Cast(this->operand(0)); } + std::shared_ptr variable() { return Cast(this->GetOperand(0)); } void variable(std::shared_ptr ir_variable_liked) { - this->operand(0, ir_variable_liked); + this->SetOperand(0, ir_variable_liked); } std::shared_ptr Clone(std::shared_ptr function_cloner) override { @@ -922,11 +970,13 @@ class Alloca : public Instruction { Alloca() = default; public: - static std::shared_ptr Create(std::shared_ptr type) { + static std::shared_ptr Create(std::shared_ptr type, + std::shared_ptr length) { PRAJNA_ASSERT(type); auto self = Cast(std::shared_ptr(static_cast(new Alloca))); - self->OperandResize(0); + self->OperandResize(1); + self->Length(length); self->type = PointerType::Create(type); self->tag = "Alloca"; return self; @@ -938,6 +988,9 @@ class Alloca : public Instruction { ir_new->CloneOperands(function_cloner); return ir_new; } + + std::shared_ptr Length() { return this->GetOperand(0); } + void Length(std::shared_ptr ir_length) { this->SetOperand(0, ir_length); } }; class GlobalVariable; @@ -992,8 +1045,8 @@ class LoadPointer : public Instruction { return self; } - std::shared_ptr Pointer() { return this->operand(0); } - void Pointer(std::shared_ptr ir_pointer) { return this->operand(0, ir_pointer); } + std::shared_ptr Pointer() { return this->GetOperand(0); } + void Pointer(std::shared_ptr ir_pointer) { return this->SetOperand(0, ir_pointer); } std::shared_ptr Clone(std::shared_ptr function_cloner) override { std::shared_ptr ir_new(new LoadPointer(*this)); @@ -1023,11 +1076,11 @@ class StorePointer : public Instruction { return self; } - std::shared_ptr Value() { return this->operand(0); } - void Value(std::shared_ptr ir_value) { return this->operand(0, ir_value); } + std::shared_ptr Value() { return this->GetOperand(0); } + void Value(std::shared_ptr ir_value) { return this->SetOperand(0, ir_value); } - std::shared_ptr Pointer() { return this->operand(1); } - void Pointer(std::shared_ptr ir_pointer) { return this->operand(1, ir_pointer); } + std::shared_ptr Pointer() { return this->GetOperand(1); } + void Pointer(std::shared_ptr ir_pointer) { return this->SetOperand(1, ir_pointer); } std::shared_ptr Clone(std::shared_ptr function_cloner) override { std::shared_ptr ir_new(new StorePointer(*this)); @@ -1053,8 +1106,8 @@ class Return : public Instruction { return self; } - std::shared_ptr Value() { return this->operand(0); } - void Value(std::shared_ptr ir_value) { this->operand(0, ir_value); } + std::shared_ptr Value() { return this->GetOperand(0); } + void Value(std::shared_ptr ir_value) { this->SetOperand(0, ir_value); } std::shared_ptr Clone(std::shared_ptr function_cloner) override { std::shared_ptr ir_new(new Return(*this)); @@ -1082,8 +1135,8 @@ class BitCast : public Instruction { return self; }; - std::shared_ptr Value() { return this->operand(0); } - void Value(std::shared_ptr ir_value) { this->operand(0, ir_value); } + std::shared_ptr Value() { return this->GetOperand(0); } + void Value(std::shared_ptr ir_value) { this->SetOperand(0, ir_value); } std::shared_ptr Clone(std::shared_ptr function_cloner) override { std::shared_ptr ir_new(new BitCast(*this)); @@ -1107,7 +1160,7 @@ class Call : public Instruction { auto ir_function_type = ir_value->GetFunctionType(); PRAJNA_ASSERT(ir_function_type); PRAJNA_ASSERT(ir_function_type->parameter_types.size() == arguments.size()); - size_t i = 0; + int64_t i = 0; for (auto [ir_argument, ir_parameter_type] : boost::combine(arguments, ir_function_type->parameter_types)) { PRAJNA_ASSERT(ir_argument->type == ir_parameter_type); @@ -1129,15 +1182,15 @@ class Call : public Instruction { return Create(ir_value, std::list>{}); } - std::shared_ptr Function() { return this->operand(0); } - void Function(std::shared_ptr ir_value) { this->operand(0, ir_value); } + std::shared_ptr Function() { return this->GetOperand(0); } + void Function(std::shared_ptr ir_value) { this->SetOperand(0, ir_value); } - std::shared_ptr Argument(size_t i) { return this->operand(1 + i); } - void Argument(size_t i, std::shared_ptr ir_argument) { - this->operand(i + 1, ir_argument); + std::shared_ptr Argument(int64_t i) { return this->GetOperand(1 + i); } + void Argument(int64_t i, std::shared_ptr ir_argument) { + this->SetOperand(i + 1, ir_argument); } - size_t ArgumentSize() { return this->OperandSize() - 1; } + int64_t ArgumentSize() { return this->OperandSize() - 1; } std::shared_ptr Clone(std::shared_ptr function_cloner) override { std::shared_ptr ir_new(new Call(*this)); @@ -1167,14 +1220,14 @@ class ConditionBranch : public Instruction { return self; } - std::shared_ptr Condition() { return this->operand(0); } - void condition(std::shared_ptr ir_condition) { this->operand(0, ir_condition); } + std::shared_ptr Condition() { return this->GetOperand(0); } + void condition(std::shared_ptr ir_condition) { this->SetOperand(0, ir_condition); } - std::shared_ptr TrueBlock() { return Cast(this->operand(1)); } - void TrueBlock(std::shared_ptr ir_true_block) { this->operand(1, ir_true_block); } + std::shared_ptr TrueBlock() { return Cast(this->GetOperand(1)); } + void TrueBlock(std::shared_ptr ir_true_block) { this->SetOperand(1, ir_true_block); } - std::shared_ptr FalseBlock() { return Cast(this->operand(2)); } - void FalseBlock(std::shared_ptr ir_false_block) { this->operand(2, ir_false_block); } + std::shared_ptr FalseBlock() { return Cast(this->GetOperand(2)); } + void FalseBlock(std::shared_ptr ir_false_block) { this->SetOperand(2, ir_false_block); } std::shared_ptr Clone(std::shared_ptr function_cloner) override { std::shared_ptr ir_new(new ConditionBranch(*this)); @@ -1198,8 +1251,8 @@ class JumpBranch : public Instruction { return self; } - std::shared_ptr NextBlock() { return Cast(this->operand(0)); } - void NextBlock(std::shared_ptr ir_next) { this->operand(0, ir_next); } + std::shared_ptr NextBlock() { return Cast(this->GetOperand(0)); } + void NextBlock(std::shared_ptr ir_next) { this->SetOperand(0, ir_next); } std::shared_ptr Clone(std::shared_ptr function_cloner) override { std::shared_ptr ir_new(new JumpBranch(*this)); @@ -1251,14 +1304,14 @@ class If : public Instruction { return self; } - std::shared_ptr Condition() { return this->operand(0); } - void condition(std::shared_ptr ir_condition) { this->operand(0, ir_condition); } + std::shared_ptr Condition() { return this->GetOperand(0); } + void condition(std::shared_ptr ir_condition) { this->SetOperand(0, ir_condition); } - std::shared_ptr TrueBlock() { return Cast(this->operand(1)); } - void TrueBlock(std::shared_ptr ir_true_block) { this->operand(1, ir_true_block); } + std::shared_ptr TrueBlock() { return Cast(this->GetOperand(1)); } + void TrueBlock(std::shared_ptr ir_true_block) { this->SetOperand(1, ir_true_block); } - std::shared_ptr FalseBlock() { return Cast(this->operand(2)); } - void FalseBlock(std::shared_ptr ir_false_block) { this->operand(2, ir_false_block); } + std::shared_ptr FalseBlock() { return Cast(this->GetOperand(2)); } + void FalseBlock(std::shared_ptr ir_false_block) { this->SetOperand(2, ir_false_block); } std::shared_ptr Clone(std::shared_ptr function_cloner) override { std::shared_ptr ir_new(new If(*this)); @@ -1285,17 +1338,17 @@ class While : public Instruction { return self; } - std::shared_ptr Condition() { return this->operand(0); } - void condition(std::shared_ptr ir_condition) { this->operand(0, ir_condition); } + std::shared_ptr Condition() { return this->GetOperand(0); } + void condition(std::shared_ptr ir_condition) { this->SetOperand(0, ir_condition); } /// @brief 用于存放条件表达式的块 - std::shared_ptr ConditionBlock() { return Cast(this->operand(1)); } + std::shared_ptr ConditionBlock() { return Cast(this->GetOperand(1)); } void ConditionBlock(std::shared_ptr ir_condition_block) { - this->operand(1, ir_condition_block); + this->SetOperand(1, ir_condition_block); } - std::shared_ptr LoopBlock() { return Cast(this->operand(2)); } - void LoopBlock(std::shared_ptr ir_true_block) { this->operand(2, ir_true_block); } + std::shared_ptr LoopBlock() { return Cast(this->GetOperand(2)); } + void LoopBlock(std::shared_ptr ir_true_block) { this->SetOperand(2, ir_true_block); } std::shared_ptr Clone(std::shared_ptr function_cloner) override { std::shared_ptr ir_new(new While(*this)); @@ -1324,17 +1377,19 @@ class For : public Instruction { return self; } - std::shared_ptr IndexVariable() { return Cast(this->operand(0)); } - void IndexVariable(std::shared_ptr ir_index) { this->operand(0, ir_index); } + std::shared_ptr IndexVariable() { + return Cast(this->GetOperand(0)); + } + void IndexVariable(std::shared_ptr ir_index) { this->SetOperand(0, ir_index); } - std::shared_ptr First() { return this->operand(1); } - void First(std::shared_ptr ir_first) { this->operand(1, ir_first); } + std::shared_ptr First() { return this->GetOperand(1); } + void First(std::shared_ptr ir_first) { this->SetOperand(1, ir_first); } - std::shared_ptr Last() { return this->operand(2); } - void Last(std::shared_ptr ir_last) { this->operand(2, ir_last); } + std::shared_ptr Last() { return this->GetOperand(2); } + void Last(std::shared_ptr ir_last) { this->SetOperand(2, ir_last); } - std::shared_ptr LoopBlock() { return Cast(this->operand(3)); } - void LoopBlock(std::shared_ptr ir_loop_block) { this->operand(3, ir_loop_block); } + std::shared_ptr LoopBlock() { return Cast(this->GetOperand(3)); } + void LoopBlock(std::shared_ptr ir_loop_block) { this->SetOperand(3, ir_loop_block); } std::shared_ptr Clone(std::shared_ptr function_cloner) override { std::shared_ptr ir_new(new For(*this)); @@ -1358,8 +1413,8 @@ class Break : public Instruction { return self; } - std::shared_ptr Loop() { return this->operand(0); } - void loop(std::shared_ptr ir_loop) { this->operand(0, ir_loop); } + std::shared_ptr Loop() { return this->GetOperand(0); } + void loop(std::shared_ptr ir_loop) { this->SetOperand(0, ir_loop); } std::shared_ptr Clone(std::shared_ptr function_cloner) override { std::shared_ptr ir_new(new Break(*this)); @@ -1383,8 +1438,8 @@ class Continue : public Instruction { return self; } - std::shared_ptr Loop() { return this->operand(0); } - void loop(std::shared_ptr ir_loop) { this->operand(0, ir_loop); } + std::shared_ptr Loop() { return this->GetOperand(0); } + void loop(std::shared_ptr ir_loop) { this->SetOperand(0, ir_loop); } std::shared_ptr Clone(std::shared_ptr function_cloner) override { std::shared_ptr ir_new(new Continue(*this)); @@ -1441,24 +1496,24 @@ class AccessProperty : public WriteReadAble, virtual public Instruction { return self; } - std::shared_ptr ThisPointer() { return this->operand(0); } + std::shared_ptr ThisPointer() { return this->GetOperand(0); } void ThisPointer(std::shared_ptr ir_this_pointer) { - this->operand(0, ir_this_pointer); + this->SetOperand(0, ir_this_pointer); } void Arguments(std::list> ir_arguments) { this->OperandResize(1 + ir_arguments.size()); - size_t i = 1; + int64_t i = 1; for (auto ir_argument : ir_arguments) { - this->operand(i, ir_argument); + this->SetOperand(i, ir_argument); ++i; } } std::list> Arguments() { std::list> ir_arguments; - for (size_t i = 1; i < this->OperandSize(); ++i) { - ir_arguments.push_back(this->operand(i)); + for (int64_t i = 1; i < this->OperandSize(); ++i) { + ir_arguments.push_back(this->GetOperand(i)); } return ir_arguments; @@ -1495,12 +1550,12 @@ class WriteProperty : public Instruction { return self; } - std::shared_ptr Value() { return this->operand(0); } - void Value(std::shared_ptr ir_value) { return this->operand(0, ir_value); } + std::shared_ptr Value() { return this->GetOperand(0); } + void Value(std::shared_ptr ir_value) { return this->SetOperand(0, ir_value); } - std::shared_ptr property() { return Cast(this->operand(1)); } + std::shared_ptr property() { return Cast(this->GetOperand(1)); } void property(std::shared_ptr ir_access_property) { - return this->operand(1, ir_access_property); + return this->SetOperand(1, ir_access_property); } std::shared_ptr Clone(std::shared_ptr function_cloner) override { @@ -1510,6 +1565,33 @@ class WriteProperty : public Instruction { } }; +class ShuffleVector : public Instruction { + protected: + ShuffleVector() = default; + + public: + static std::shared_ptr Create(std::shared_ptr ir_value, + std::shared_ptr ir_mask) { + PRAJNA_ASSERT(Is(ir_value->type)); + PRAJNA_ASSERT(Is(ir_mask->type)); + std::shared_ptr self(new ShuffleVector); + PRAJNA_ASSERT(Cast(ir_value->type)->size == + Cast(ir_mask->type)->size); + self->type = ir_value->type; + self->OperandResize(2); + self->Value(ir_value); + self->Mask(ir_mask); + self->tag = "ShuffleVector"; + return self; + } + + std::shared_ptr Value() { return this->GetOperand(0); } + void Value(std::shared_ptr ir_value) { this->SetOperand(0, ir_value); } + + std::shared_ptr Mask() { return this->GetOperand(1); } + void Mask(std::shared_ptr ir_mask) { this->SetOperand(1, ir_mask); } +}; + class Module : public Named, public std::enable_shared_from_this { protected: Module() = default; @@ -1565,7 +1647,7 @@ class KernelFunctionCall : public Instruction { self->BlockShape(ir_block_shape); auto iter_parameter_type = ir_function_type->parameter_types.begin(); - size_t i = 0; + int64_t i = 0; for (auto iter_argument = arguments.begin(); iter_argument != arguments.end(); ++iter_argument, ++iter_parameter_type, ++i) { PRAJNA_ASSERT(*iter_parameter_type == (*iter_argument)->type); @@ -1577,25 +1659,27 @@ class KernelFunctionCall : public Instruction { return self; } - std::shared_ptr Function() { return this->operand(0); } - void Function(std::shared_ptr ir_value) { this->operand(0, ir_value); } + std::shared_ptr Function() { return this->GetOperand(0); } + void Function(std::shared_ptr ir_value) { this->SetOperand(0, ir_value); } - std::shared_ptr GridShape() { return this->operand(1); } - void GridShape(std::shared_ptr ir_grid_shape) { this->operand(1, ir_grid_shape); } + std::shared_ptr GridShape() { return this->GetOperand(1); } + void GridShape(std::shared_ptr ir_grid_shape) { this->SetOperand(1, ir_grid_shape); } - std::shared_ptr BlockShape() { return this->operand(2); } - void BlockShape(std::shared_ptr ir_block_shape) { this->operand(2, ir_block_shape); } + std::shared_ptr BlockShape() { return this->GetOperand(2); } + void BlockShape(std::shared_ptr ir_block_shape) { + this->SetOperand(2, ir_block_shape); + } - size_t ArgumentSize() { return this->OperandSize() - 3; } + int64_t ArgumentSize() { return this->OperandSize() - 3; } - std::shared_ptr Argument(size_t i) { return this->operand(3 + i); } - void Argument(size_t i, std::shared_ptr ir_argument) { - this->operand(i + 3, ir_argument); + std::shared_ptr Argument(int64_t i) { return this->GetOperand(3 + i); } + void Argument(int64_t i, std::shared_ptr ir_argument) { + this->SetOperand(i + 3, ir_argument); } std::list> Arguments() { std::list> arguments_re; - for (size_t i = 0; i < this->OperandSize(); ++i) { + for (int64_t i = 0; i < this->OperandSize(); ++i) { arguments_re.push_back(this->Argument(i)); } @@ -1732,7 +1816,7 @@ inline std::shared_ptr Function::Clone(std::shared_ptrClone(function_cloner); function_cloner->value_dict[ir_parameter] = ir_new_parameter; - return ir_new_parameter; + return Cast(ir_new_parameter); }); // 需要再开头, 因为函数有可能存在递归 diff --git a/prajna/jit/execution_engine.cpp b/prajna/jit/execution_engine.cpp index 621df117..a37625cb 100644 --- a/prajna/jit/execution_engine.cpp +++ b/prajna/jit/execution_engine.cpp @@ -79,23 +79,29 @@ ExecutionEngine::ExecutionEngine() { LLVMInitializeNativeAsmPrinter(); // LLVMInitializeNativeAsmParser(); + auto lljit_builder = llvm::orc::LLJITBuilder(); + auto JTMB = llvm::orc::JITTargetMachineBuilder::detectHost(); + PRAJNA_VERIFY(JTMB); + JTMB->getOptions().AllowFPOpFusion = llvm::FPOpFusion::Fast; + JTMB->getOptions().UnsafeFPMath = true; + lljit_builder.setJITTargetMachineBuilder(*JTMB); + + lljit_builder.setObjectLinkingLayerCreator( + [=](llvm::orc::ExecutionSession &ES, const llvm::Triple &TT) { + // @note 需要确认机制是做什么用的 + auto ll = std::make_unique( + ES, std::make_unique(64 * 1024)); + ll->setAutoClaimResponsibilityForObjectSymbols(true); + return std::move(ll); + }); + // TODO: 需要确定setObjectLinkingLayerCreator的作用, 现在去除后, 在mac上会报错. #ifdef __APPLE__ - auto expect_up_lljit = - llvm::orc::LLJITBuilder() - .setObjectLinkingLayerCreator( - [=](llvm::orc::ExecutionSession &ES, const llvm::Triple &TT) { - // @note 需要确认机制是做什么用的 - auto ll = std::make_unique( - ES, std::make_unique(64 * 1024)); - ll->setAutoClaimResponsibilityForObjectSymbols(true); - return std::move(ll); - }) - .create(); + auto expect_up_lljit = lljit_builder.create(); #else auto expect_up_lljit = llvm::orc::LLJITBuilder().create(); #endif - PRAJNA_ASSERT(expect_up_lljit); + PRAJNA_VERIFY(expect_up_lljit); _up_lljit = std::move(*expect_up_lljit); _up_lljit->getMainJITDylib().addGenerator( @@ -107,7 +113,7 @@ bool ExecutionEngine::LoadDynamicLib(std::string lib_name) { return llvm::sys::DynamicLibrary::getPermanentLibrary(lib_name.c_str()).isValid(); } -size_t ExecutionEngine::GetValue(std::string name) { +int64_t ExecutionEngine::GetValue(std::string name) { auto expect_symbol = _up_lljit->lookup(name); PRAJNA_VERIFY(expect_symbol); return expect_symbol->getValue(); diff --git a/prajna/jit/execution_engine.h b/prajna/jit/execution_engine.h index ff72bf1a..04494f37 100644 --- a/prajna/jit/execution_engine.h +++ b/prajna/jit/execution_engine.h @@ -16,7 +16,7 @@ class ExecutionEngine { public: ExecutionEngine(); - size_t GetValue(std::string name); + int64_t GetValue(std::string name); void AddIRModule(std::shared_ptr ir_module); diff --git a/prajna/logger.hpp b/prajna/logger.hpp index 2ffb9bfd..7693cc62 100644 --- a/prajna/logger.hpp +++ b/prajna/logger.hpp @@ -88,7 +88,7 @@ class Logger { } std::vector code_lines; - for (size_t i = first_position.line; i <= last_position.line; ++i) { + for (int64_t i = first_position.line; i <= last_position.line; ++i) { code_lines.push_back(_code_lines[i - 1]); } diff --git a/prajna/lowering/expression_lowering_visitor.cpp b/prajna/lowering/expression_lowering_visitor.cpp index 1fab80d3..ade74833 100644 --- a/prajna/lowering/expression_lowering_visitor.cpp +++ b/prajna/lowering/expression_lowering_visitor.cpp @@ -11,7 +11,7 @@ std::shared_ptr ExpressionLoweringVisitor::operator()(ast::Closure as ir_builder->PushSymbolTable(); ir_builder->symbol_table->name = "closure." + std::to_string(ir_builder->closure_id); ++ir_builder->closure_id; - auto guard = function_guard::Create([this]() { + auto guard = ScopeGuard::Create([this]() { this->ir_builder->PopSymbolTable(); this->ir_builder->current_implement_type = nullptr; }); diff --git a/prajna/lowering/expression_lowering_visitor.hpp b/prajna/lowering/expression_lowering_visitor.hpp index 1779f195..3e030aed 100644 --- a/prajna/lowering/expression_lowering_visitor.hpp +++ b/prajna/lowering/expression_lowering_visitor.hpp @@ -50,7 +50,7 @@ class ExpressionLoweringVisitor { } std::shared_ptr operator()(ast::IntLiteral ast_int_literal) { - return ir_builder->GetIndexConstant(ast_int_literal.value); + return ir_builder->GetInt64Constant(ast_int_literal.value); }; std::shared_ptr operator()(ast::IntLiteralPostfix ast_int_literal_postfix) { @@ -187,7 +187,7 @@ class ExpressionLoweringVisitor { *identifier_path.identifiers.front().template_arguments_optional); auto ir_member_function = Cast( - SymbolGet(lowering_member_function_template->instantiate( + SymbolGet(lowering_member_function_template->Instantiate( symbol_template_arguments, ir_builder->module))); PRAJNA_ASSERT(ir_member_function); @@ -237,7 +237,7 @@ class ExpressionLoweringVisitor { logger->Error("index should have one argument at least", ast_binary_operation.operand); } auto ir_index = ir_arguments.front(); - if (ir_index->type != ir_builder->GetI64Type()) { + if (ir_index->type != ir_builder->GetInt64Type()) { logger->Error( fmt::format("the index type must be i64, but it's {}", ir_index->type->fullname), ast_binary_operation.operand); @@ -384,7 +384,7 @@ class ExpressionLoweringVisitor { // 移除, 在末尾插入, 应为参数应该在属性访问的前面 ir_access_property->parent_block->values.remove(ir_access_property); - ir_builder->currentBlock()->values.push_back(ir_access_property); + ir_builder->CurrentBlock()->values.push_back(ir_access_property); ir_access_property->Arguments(ir_arguments); return ir_access_property; } @@ -514,7 +514,7 @@ class ExpressionLoweringVisitor { return this->ApplyIdentifierPath(ast_identifier_path); }, [=](ast::IntLiteral ast_int_literal) -> Symbol { - return ir::ConstantInt::Create(ir_builder->GetI64Type(), + return ir::ConstantInt::Create(ir_builder->GetInt64Type(), ast_int_literal.value); }}, ast_template_argument); @@ -557,7 +557,7 @@ class ExpressionLoweringVisitor { auto symbol_template_argumen_list = this->ApplyTemplateArguments( *iter_ast_identifier->template_arguments_optional); auto ir_function = Cast( - SymbolGet(lowering_member_function_template->instantiate( + SymbolGet(lowering_member_function_template->Instantiate( symbol_template_argumen_list, ir_builder->module))); PRAJNA_ASSERT(ir_function); return ir_function; @@ -616,7 +616,7 @@ class ExpressionLoweringVisitor { } if (auto tempate_ = SymbolGet