
Merge branch 'prajna-lang:dev' into dev-workflow
ConvolutedDog authored Sep 4, 2024
2 parents cb54ddc + d729e6b commit 1916d02
Showing 39 changed files with 742 additions and 347 deletions.
22 changes: 22 additions & 0 deletions builtin_packages/_array.prajna
@@ -50,3 +50,25 @@ implement Array<Type, Length_> {
     }
 }
 
+template <Length_>
+func __ArrayOne()->Array<i64, Length_> {
+    var re: Array<i64, Length_>;
+    for i in 0 to re.Length() {
+        re[i] = 1;
+    }
+
+    return re;
+}
+
+template <Row_, Column_>
+func __CacluateArrayIndex(A: Array<Array<i64, Column_>, Row_>, B: Array<i64, Row_>, Index: Array<i64, Column_>) -> Array<i64, Row_> {
+    var re: Array<i64, Row_>;
+    for i in 0 to Row_ {
+        re[i] = B[i];
+        for j in 0 to Column_ {
+            re[i] = re[i] + A[i][j] * Index[j];
+        }
+    }
+
+    return re;
+}
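The new __CacluateArrayIndex helper computes an affine index map, re = B + A * Index, i.e. re[i] = B[i] + the sum over j of A[i][j] * Index[j]. For readers more at home in C++, a standalone transliteration follows (a sketch for illustration only; it is not part of this commit, and the corrected "Calculate" spelling is ours):

    #include <array>
    #include <cstdint>

    // Sketch: the affine index map re = B + A * Index, mirroring __CacluateArrayIndex.
    template <int64_t Row, int64_t Column>
    std::array<int64_t, Row> CalculateArrayIndex(
        const std::array<std::array<int64_t, Column>, Row>& A,
        const std::array<int64_t, Row>& B,
        const std::array<int64_t, Column>& Index) {
        std::array<int64_t, Row> re{};
        for (int64_t i = 0; i < Row; ++i) {
            re[i] = B[i];  // start from the offset term
            for (int64_t j = 0; j < Column; ++j) {
                re[i] += A[i][j] * Index[j];  // accumulate the matrix-vector product
            }
        }
        return re;
    }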
6 changes: 4 additions & 2 deletions builtin_packages/optional.prajna
@@ -15,8 +15,10 @@ implement Optional<ValueType> {
     }
 
     func __get__Value()->ValueType{
+        "hit\n".Print();
         if (!this.HasValue()) {
-            debug::AssertWithMessage(false, "has no value");
+            // TODO(张志敏): property assignment has a bug; needs fixing later
+            // debug::AssertWithMessage(false, "has no value");
         }
         return this._value;
     }
@@ -33,7 +35,7 @@ implement Optional<ValueType> {
     @static
     func Create(value: ValueType) -> Optional<ValueType> {
         var optional: Optional<ValueType>;
-        optional.Value = value;
+        optional._value = value;
         return optional;
     }
 }
11 changes: 8 additions & 3 deletions examples_in_cpp/add.prajna
@@ -1,12 +1,17 @@

-func helloWorld(ts: Array<i64, 8>*){
-    ts->ToString().Print();
+func HelloWorld(){
+    "Hello World!\n".Print();
 }
 
-func matrixAddF32(p_ts0: Tensor<f32, 2> *, p_ts1: Tensor<f32, 2>*, p_ts_re: Tensor<f32, 2>*){
+func MatrixAddF32(p_ts0: ptr<Tensor<f32, 2>>, p_ts1: ptr<Tensor<f32, 2>>, p_ts_re: ptr<Tensor<f32, 2>>){
     var ts0 = *p_ts0;
     var ts1 = *p_ts1;
     var ts_re = *p_ts_re;
+    for i in 0 to ts0.layout.shape[0] {
+        for j in 0 to ts0.layout.shape[1] {
+            ts_re[i, j] = ts0[i, j] + ts1[i, j];
+        }
+    }
     for idx in [0, 0] to ts0.layout.shape{
         ts_re.At(idx) = ts0.At(idx) + ts1.At(idx);
     }
12 changes: 8 additions & 4 deletions examples_in_cpp/main.cpp
@@ -5,8 +5,13 @@
 
 int main() {
     auto compiler = prajna::Compiler::Create();
-    compiler->CompileBuiltinSourceFiles("prajna/builtin_packages");
-    compiler->CompileProgram("examples/prajna_in_cpp/add.prajna", false);
+    compiler->CompileBuiltinSourceFiles("builtin_packages");
+    compiler->AddPackageDirectoryPath(".");
+    compiler->CompileProgram("examples_in_cpp/add.prajna", false);
+
+    auto hello_world = reinterpret_cast<void (*)()>(
+        compiler->GetSymbolValue("::examples_in_cpp::add::HelloWorld"));
+    hello_world();
 
     using MatrixF32 = prajna::Tensor<float, 2>;
 
@@ -18,10 +23,9 @@ int main() {
     ts(1, 0) = 4;
     ts(1, 1) = 5;
     ts(1, 2) = 6;
-
     // Get the address of a function defined in Prajna
     auto matrix_add_f32 = reinterpret_cast<MatrixF32 (*)(MatrixF32 *, MatrixF32 *, MatrixF32 *)>(
-        compiler->GetSymbolValue("::examples::prajna_in_cpp::add::matrixAddF32"));
+        compiler->GetSymbolValue("::examples_in_cpp::add::MatrixAddF32"));
 
     auto ts_re = MatrixF32::Create(shape);
 
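The two files above follow one embedding pattern: compile the builtin packages, register a package search path, compile a program, then resolve a fully qualified symbol to a JIT'd address and cast it to the matching function-pointer type. A condensed, self-contained sketch using only the calls shown in this diff (the include path is an assumption; error handling omitted):

    #include "prajna/compiler/compiler.h"  // assumed header path

    int main() {
        auto compiler = prajna::Compiler::Create();
        compiler->CompileBuiltinSourceFiles("builtin_packages");  // builtins must come first
        compiler->AddPackageDirectoryPath(".");                   // package search root
        compiler->CompileProgram("examples_in_cpp/add.prajna", false);

        // GetSymbolValue returns the JIT'd address as an integer; the caller
        // is responsible for casting it to the correct signature.
        auto hello_world = reinterpret_cast<void (*)()>(
            compiler->GetSymbolValue("::examples_in_cpp::add::HelloWorld"));
        hello_world();
        return 0;
    }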
9 changes: 9 additions & 0 deletions prajna/CMakeLists.txt
@@ -13,9 +13,18 @@ target_link_libraries(prajna_config_target
     INTERFACE Boost::dll
     INTERFACE Boost::process
     INTERFACE Boost::asio
+    INTERFACE Boost::scope
     INTERFACE fmt::fmt
 )
 
+target_compile_definitions(prajna_config_target
+    # See third_party/boost/libs/mpl/include/boost/mpl/list/list50.hpp
+    # Not every number works; it must be one of 10, 20, 30, 40, 50
+    # The template parameters of boost::variant are constrained by this
+    INTERFACE BOOST_MPL_LIMIT_LIST_SIZE=50
+    INTERFACE BOOST_MPL_CFG_NO_PREPROCESSED_HEADERS
+)
+
 if (MSVC)
     target_compile_options(prajna_config_target INTERFACE "/bigobj")
 else ()
4 changes: 2 additions & 2 deletions prajna/assert.hpp
@@ -10,7 +10,7 @@ namespace prajna {
 class assert_failed : public std::exception {
    public:
     assert_failed(const std::string& expr, const std::string& function, const std::string& file,
-                  size_t line, bool verify, const std::string& msg)
+                  int64_t line, bool verify, const std::string& msg)
         : _expr(expr), _function(function), _file(file), _line(line), _msg(msg) {
         std::string with_msg;
         if (!msg.empty()) {
@@ -28,7 +28,7 @@ class assert_failed : public std::exception {
     std::string _expr;
     std::string _function;
     std::string _file;
-    size_t _line;
+    int64_t _line;
     std::string _msg;
     std::string _what_str;
 };
31 changes: 26 additions & 5 deletions prajna/bindings/core.hpp
@@ -20,6 +20,27 @@ class Array {
     Type_& operator[](i64 offset) { return data[offset]; }
 };
 
+template <typename Type>
+struct Ptr {
+    Type* raw_ptr;
+    int64_t size;
+    int64_t* _reference_counter;
+
+    static Ptr<Type> Allocate(int64_t size) {
+        Ptr<Type> self;
+        self.raw_ptr = new Type[size];
+        self.size = size;
+
+        self._reference_counter = new int64_t;
+        *self._reference_counter = 1;
+        return self;
+    }
+
+    Type& operator[](int64_t offset) { return this->raw_ptr[offset]; }
+
+    // ~Ptr() { delete[] this->raw_ptr; }
+};
+
 template <i64 Dim_>
 class Layout {
    public:
@@ -68,7 +89,7 @@ class Layout {
 template <typename Type_, i64 Dim_>
 class Tensor {
    public:
-    Type_* data = nullptr;
+    Ptr<Type_> data;
     Layout<Dim_> layout;
 
    protected:
@@ -79,9 +100,9 @@ class Tensor {
         Tensor<Type_, Dim_> self;
         self.layout = Layout<Dim_>::Create(shape);
 
-        auto bytes = self.layout.Length() * sizeof(Type_);
-        self.data = reinterpret_cast<Type_*>(malloc(bytes));
+        self.data = Ptr<Type_>::Allocate(self.layout.Length());
 
+        // TODO(zhangzhimin): initialization and release are not handled yet
         // RegisterReferenceCount(self.data);
         // __copy__(self.data);
 
@@ -101,7 +122,7 @@ class Tensor {
         // }
     }
 
-    const Type_& at(Array<i64, Dim_> idx) const {
+    Type_& at(Array<i64, Dim_> idx) const {
         i64 offset = this->layout.ArrayIndexToLinearIndex(idx);
         return this->data[offset];
     }
@@ -112,7 +133,7 @@ class Tensor {
     }
 
     template <typename... Idx_>
-    const Type_& operator()(Idx_... indices) const {
+    Type_& operator()(Idx_... indices) const {
         Array<i64, Dim_> idx(indices...);
         return this->at(idx);
     }
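The Ptr<Type> introduced above is a minimal reference-counted buffer: Allocate news both the storage and a counter initialized to 1, while the commented-out destructor (and the TODO in Tensor::Create) show that release is not wired up yet. A small sketch of how it behaves as written; the manual cleanup below is hypothetical and only mirrors what a future destructor would do:

    #include <cstdint>

    // Assumes the Ptr<Type> definition from prajna/bindings/core.hpp above.
    void PtrDemo() {
        auto p = Ptr<float>::Allocate(16);  // news float[16] plus a counter set to 1
        p[0] = 1.0f;                        // operator[] indexes the raw buffer

        // Hypothetical cleanup, mirroring what the commented-out ~Ptr()
        // would eventually do once reference counting is wired up:
        if (--*p._reference_counter == 0) {
            delete[] p.raw_ptr;
            delete p._reference_counter;
        }
    }

Until that destructor lands, copies of a Ptr share raw_ptr and _reference_counter without adjusting the count, so ownership is effectively manual.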
75 changes: 61 additions & 14 deletions prajna/codegen/llvm_codegen.cpp
@@ -122,6 +122,13 @@ class LlvmCodegen {
                 llvm::ArrayType::get(ir_array_type->value_type->llvm_type, ir_array_type->size);
             return;
         }
+        if (auto ir_vector_type = Cast<ir::VectorType>(ir_type)) {
+            this->EmitType(ir_vector_type->value_type);
+            ir_vector_type->llvm_type =
+                llvm::VectorType::get(ir_vector_type->value_type->llvm_type,
+                                      llvm::ElementCount::getFixed(ir_vector_type->size));
+            return;
+        }
         if (auto ir_struct_type = Cast<ir::StructType>(ir_type)) {
             auto llvm_struct_type =
                 llvm::StructType::create(static_llvm_context, ir_struct_type->fullname);
@@ -133,7 +140,13 @@ class LlvmCodegen {
                     return field->type->llvm_type;
                 });
             llvm_struct_type->setBody(llvm_types, true);
-
             return;
         }
+        if (auto ir_simd_type = Cast<ir::SimdType>(ir_type)) {
+            this->EmitType(ir_simd_type->value_type);
+            ir_simd_type->llvm_type =
+                llvm::VectorType::get(ir_simd_type->value_type->llvm_type,
+                                      llvm::ElementCount::getFixed(ir_simd_type->size));
+            return;
+        }

@@ -190,11 +203,24 @@ class LlvmCodegen {
         llvm::Function *llvm_fun = static_cast<llvm::Function *>(ir_function->llvm_value);
         PRAJNA_ASSERT(llvm_fun);
         PRAJNA_ASSERT(ir_function->parameters.size() == llvm_fun->arg_size());
-        size_t i = 0;
+        int64_t i = 0;
         auto iter_parameter = ir_function->parameters.begin();
         for (auto llvm_arg = llvm_fun->arg_begin(); llvm_arg != llvm_fun->arg_end();
              ++llvm_arg, ++iter_parameter) {
-            (*iter_parameter)->llvm_value = llvm_arg;
+            auto ir_parameter = *iter_parameter;
+            if (ir_parameter->no_alias) {
+                llvm_arg->addAttr(llvm::Attribute::NoAlias);
+            }
+            if (ir_parameter->no_capture) {
+                llvm_arg->addAttr(llvm::Attribute::NoCapture);
+            }
+            if (ir_parameter->no_undef) {
+                llvm_arg->addAttr(llvm::Attribute::NoUndef);
+            }
+            if (ir_parameter->readonly) {
+                llvm_arg->addAttr(llvm::Attribute::ReadOnly);
+            }
+            ir_parameter->llvm_value = llvm_arg;
         }
 
         for (auto block : ir_function->blocks) {
@@ -337,6 +363,21 @@ class LlvmCodegen {
             return;
         }
 
+        if (auto ir_constant_vector = Cast<ir::ConstantVector>(ir_constant)) {
+            std::vector<llvm::Constant *> llvm_contants(
+                ir_constant_vector->initialize_constants.size());
+            std::transform(RANGE(ir_constant_vector->initialize_constants), llvm_contants.begin(),
+                           [=](auto ir_init) {
+                               auto llvm_constant =
+                                   static_cast<llvm::Constant *>(ir_init->llvm_value);
+                               PRAJNA_ASSERT(llvm_constant);
+                               return llvm_constant;
+                           });
+            PRAJNA_ASSERT(ir_constant_vector->type->llvm_type);
+            ir_constant_vector->llvm_value = llvm::ConstantVector::get(llvm_contants);
+            return;
+        }
+
         PRAJNA_TODO;
     }

@@ -364,7 +405,7 @@ class LlvmCodegen {
         if (auto ir_call = Cast<ir::Call>(ir_instruction)) {
             auto ir_function_type = ir_call->Function()->GetFunctionType();
             std::vector<llvm::Value *> llvm_arguments(ir_call->ArgumentSize());
-            for (size_t i = 0; i < llvm_arguments.size(); ++i) {
+            for (int64_t i = 0; i < llvm_arguments.size(); ++i) {
                 llvm_arguments[i] = ir_call->Argument(i)->llvm_value;
                 PRAJNA_ASSERT(llvm_arguments[i]);
             }
@@ -383,12 +424,11 @@ class LlvmCodegen {
             return;
         }
         if (auto ir_alloca = Cast<ir::Alloca>(ir_instruction)) {
-            // Align needs to be set to a reasonable value; for now it is 8
-            // PRAJNA_ASSERT()
             auto ir_alloca_type = Cast<ir::PointerType>(ir_alloca->type);
             PRAJNA_ASSERT(ir_alloca_type && ir_alloca_type->value_type->llvm_type);
-            ir_alloca->llvm_value = new llvm::AllocaInst(ir_alloca_type->value_type->llvm_type, 0,
-                                                         ir_alloca->name, llvm_basic_block);
+            ir_alloca->llvm_value = new llvm::AllocaInst(
+                ir_alloca_type->value_type->llvm_type, 0, ir_alloca->Length()->llvm_value,
+                llvm::Align(16), ir_alloca->name, llvm_basic_block);
             return;
         }
 
@@ -524,7 +564,7 @@ class LlvmCodegen {
             PRAJNA_ASSERT(cast_operator_dict.count(ir_cast_instruction->operation));
             auto cast_op = cast_operator_dict[ir_cast_instruction->operation];
             ir_cast_instruction->llvm_value =
-                llvm::CastInst::Create(cast_op, ir_cast_instruction->operand(0)->llvm_value,
+                llvm::CastInst::Create(cast_op, ir_cast_instruction->GetOperand(0)->llvm_value,
                                        ir_cast_instruction->type->llvm_type, "", llvm_basic_block);
             return;
         }
@@ -602,8 +642,8 @@ class LlvmCodegen {
                 llvm_compare_predicator_dict[ir_compare_instruction->operation];
             ir_compare_instruction->llvm_value = llvm::CmpInst::Create(
                 llvm_compare_other_ops, llvm_compare_predicator,
-                ir_compare_instruction->operand(0)->llvm_value,
-                ir_compare_instruction->operand(1)->llvm_value, "", llvm_basic_block);
+                ir_compare_instruction->GetOperand(0)->llvm_value,
+                ir_compare_instruction->GetOperand(1)->llvm_value, "", llvm_basic_block);
             return;
         }
 
@@ -636,12 +676,19 @@ class LlvmCodegen {
             auto llvm_binary_operator_operation =
                 binary_operator_dict[ir_binary_operator->operation];
             ir_binary_operator->llvm_value = llvm::BinaryOperator::Create(
-                llvm_binary_operator_operation, ir_binary_operator->operand(0)->llvm_value,
-                ir_binary_operator->operand(1)->llvm_value, "", llvm_basic_block);
+                llvm_binary_operator_operation, ir_binary_operator->GetOperand(0)->llvm_value,
+                ir_binary_operator->GetOperand(1)->llvm_value, "", llvm_basic_block);
 
             return;
         }
 
+        if (auto ir_shuffle_vector = Cast<ir::ShuffleVector>(ir_instruction)) {
+            ir_shuffle_vector->llvm_value = new llvm::ShuffleVectorInst(
+                ir_shuffle_vector->Value()->llvm_value, ir_shuffle_vector->Mask()->llvm_value, "",
+                llvm_basic_block);
+            return;
+        }
+
         PRAJNA_ASSERT(false, ir_instruction->tag);
     }
 };
@@ -701,7 +748,7 @@ std::shared_ptr<ir::Module> LlvmPass(std::shared_ptr<ir::Module> ir_module) {
     PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
 
     llvm::ModulePassManager MPM =
-        PB.buildPerModuleDefaultPipeline(llvm::OptimizationLevel::O3, true);
+        PB.buildPerModuleDefaultPipeline(llvm::OptimizationLevel::O2, true);
 
     MPM.run(*ir_module->llvm_module, MAM);
 
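The codegen changes above lean on two stock LLVM C++ APIs: fixed-width vector types built with llvm::VectorType::get plus llvm::ElementCount::getFixed, and per-argument attributes attached with llvm::Argument::addAttr. A standalone sketch of the same calls, separate from this commit (includes and LLVM version are assumptions; tested signatures exist in LLVM 14+):

    #include "llvm/IR/DerivedTypes.h"
    #include "llvm/IR/Function.h"
    #include "llvm/IR/Module.h"

    void Sketch() {
        llvm::LLVMContext context;
        llvm::Module module("demo", context);

        // A fixed <8 x i64> vector type, as EmitType now builds for VectorType/SimdType.
        auto *vec_type = llvm::VectorType::get(llvm::Type::getInt64Ty(context),
                                               llvm::ElementCount::getFixed(8));

        // A function whose pointer argument carries the same attributes
        // EmitFunction now attaches to marked parameters.
        auto *fn_type = llvm::FunctionType::get(llvm::Type::getVoidTy(context),
                                                {vec_type->getPointerTo()}, false);
        auto *fn = llvm::Function::Create(fn_type, llvm::Function::ExternalLinkage,
                                          "demo_fn", &module);
        fn->getArg(0)->addAttr(llvm::Attribute::NoAlias);
        fn->getArg(0)->addAttr(llvm::Attribute::ReadOnly);
    }

NoAlias and ReadOnly in particular give the optimizer the same guarantees as C's restrict and const, which is what makes attaching them to Prajna parameters worthwhile.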
15 changes: 13 additions & 2 deletions prajna/compiler/compiler.cpp
@@ -98,6 +98,16 @@ std::shared_ptr<ir::Module> Compiler::CompileCode(
     return ir_lowering_module;
 }
 
+void Compiler::GenLlvm(std::shared_ptr<ir::Module> ir_module) {
+    auto ir_ssa_module = prajna::transform::transform(ir_module);
+    auto ir_codegen_module = prajna::codegen::LlvmCodegen(ir_ssa_module, ir::Target::host);
+    auto ir_llvm_optimize_module = prajna::codegen::LlvmPass(ir_codegen_module);
+#ifdef PRAJNA_ENABLE_LLVM_DUMP
+    ir_module->llvm_module->dump();
+#endif
+    jit_engine->AddIRModule(ir_llvm_optimize_module);
+}
+
 void Compiler::ExecuteCodeInRelp(std::string script_code) {
     try {
         static int command_id = 0;
@@ -170,7 +180,7 @@ void Compiler::ExecutateMainFunction() {
     }
 }
 
-size_t Compiler::GetSymbolValue(std::string symbol_name) {
+int64_t Compiler::GetSymbolValue(std::string symbol_name) {
     return this->jit_engine->GetValue(symbol_name);
 }
 
@@ -216,7 +226,8 @@ std::shared_ptr<ir::Module> Compiler::CompileProgram(
         return nullptr;
     }
 
-    auto current_symbol_table = CreateSymbolTableTree(_symbol_table, prajna_source_package_path.string());
+    auto current_symbol_table =
+        CreateSymbolTableTree(_symbol_table, prajna_source_package_path.string());
     current_symbol_table->directory_path =
         prajna_directory_path / current_symbol_table->directory_path;
 