Skip to content

Commit

Permalink
add global constants
Browse files Browse the repository at this point in the history
  • Loading branch information
drblallo committed Nov 20, 2024
1 parent 8201d2f commit 5279ba0
Show file tree
Hide file tree
Showing 14 changed files with 213 additions and 28 deletions.
6 changes: 3 additions & 3 deletions lib/conversions/src/RLCToPython.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,14 +308,14 @@ class ClassDeclarationToNothing
};

class ConstantGlobalArrayOpToNothing
: public mlir::OpConversionPattern<mlir::rlc::ConstantGlobalArrayOp>
: public mlir::OpConversionPattern<mlir::rlc::ConstantGlobalOp>
{
public:
using mlir::OpConversionPattern<
mlir::rlc::ConstantGlobalArrayOp>::OpConversionPattern;
mlir::rlc::ConstantGlobalOp>::OpConversionPattern;

mlir::LogicalResult matchAndRewrite(
mlir::rlc::ConstantGlobalArrayOp op,
mlir::rlc::ConstantGlobalOp op,
OpAdaptor adaptor,
mlir::ConversionPatternRewriter& rewriter) const final
{
Expand Down
30 changes: 16 additions & 14 deletions lib/dialect/src/Conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1450,18 +1450,18 @@ static void emitGlobalVectorInitialization(
}

class GlobalArrayRewriter
: public mlir::OpConversionPattern<mlir::rlc::ConstantGlobalArrayOp>
: public mlir::OpConversionPattern<mlir::rlc::FlatConstantGlobalOp>
{
using mlir::OpConversionPattern<
mlir::rlc::ConstantGlobalArrayOp>::OpConversionPattern;
mlir::rlc::FlatConstantGlobalOp>::OpConversionPattern;

mlir::LogicalResult matchAndRewrite(
mlir::rlc::ConstantGlobalArrayOp op,
mlir::rlc::FlatConstantGlobalOp op,
OpAdaptor adaptor,
mlir::ConversionPatternRewriter& rewriter) const final
{
auto type = getTypeConverter()->convertType(
mlir::rlc::ProxyType::get(op.getResult()));
mlir::rlc::ProxyType::get(op.getType()));
auto global = rewriter.create<mlir::LLVM::GlobalOp>(
op->getLoc(),
type,
Expand All @@ -1472,18 +1472,20 @@ class GlobalArrayRewriter

auto* block = rewriter.createBlock(&global.getInitializer());
rewriter.setInsertionPoint(block, block->begin());
mlir::Value toReturn;

mlir::Value toReturn =
rewriter.create<mlir::LLVM::UndefOp>(op.getLoc(), type);
if (auto casted = op.getValues().dyn_cast<mlir::ArrayAttr>())
{
toReturn = rewriter.create<mlir::LLVM::UndefOp>(op.getLoc(), type);

llvm::SmallVector<int64_t, 4> indicies;
emitGlobalVectorInitialization(
op.getValues(),
toReturn,
indicies,
rewriter,
typeConverter,
op.getLoc());
llvm::SmallVector<int64_t, 4> indicies;
emitGlobalVectorInitialization(
casted, toReturn, indicies, rewriter, typeConverter, op.getLoc());
}
else
{
toReturn = lowerConstant(rewriter, op.getValues(), op.getLoc());
}

rewriter.create<mlir::LLVM::ReturnOp>(
op->getLoc(), mlir::ValueRange({ toReturn }));
Expand Down
6 changes: 3 additions & 3 deletions lib/dialect/src/LowerInitializerListsPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ namespace mlir::rlc
return variable;
}

static mlir::rlc::ConstantGlobalArrayOp rewriteAsGlobal(
static mlir::rlc::FlatConstantGlobalOp rewriteAsGlobal(
llvm::StringRef name, mlir::rlc::Constant op, mlir::IRRewriter& rewriter)
{
rewriter.setInsertionPoint(op);
auto global = rewriter.create<mlir::rlc::ConstantGlobalArrayOp>(
auto global = rewriter.create<mlir::rlc::FlatConstantGlobalOp>(
op.getLoc(),
op.getResult().getType(),
op.getValue().cast<mlir::ArrayAttr>(),
Expand All @@ -66,7 +66,7 @@ namespace mlir::rlc
{
rewriter.setInsertionPoint(use.getOwner());
use.assign(rewriter.create<mlir::rlc::Reference>(
op.getLoc(), global.getResult(), global.getName()));
op.getLoc(), global.getType(), global.getName()));
}

op.erase();
Expand Down
29 changes: 29 additions & 0 deletions lib/dialect/src/LowerToCf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,35 @@ static mlir::LogicalResult flattenModule(mlir::ModuleOp op)
rewriter.eraseOp(f);
}

// rewrite globals to use references
llvm::SmallVector<mlir::rlc::ConstantGlobalOp, 2> globals(
op.getOps<mlir::rlc::ConstantGlobalOp>());
for (auto global : globals)
{
rewriter.setInsertionPoint(global);
rewriter.create<mlir::rlc::FlatConstantGlobalOp>(
global.getLoc(),
global.getType(),
global.getValues(),
global.getName());
llvm::SmallVector<mlir::OpOperand*> operands;
for (auto& use : global.getResult().getUses())
operands.push_back(&use);

for (auto& use : operands)
{
rewriter.setInsertionPoint(use->getOwner());

auto ref = rewriter.create<mlir::rlc::Reference>(
use->getOwner()->getLoc(),
global.getResult().getType(),
global.getName());
use->set(ref);
}

global.erase();
}

return mlir::LogicalResult::success();
}

Expand Down
26 changes: 26 additions & 0 deletions lib/dialect/src/Operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1608,6 +1608,22 @@ mlir::LogicalResult mlir::rlc::WhileStatement::typeCheck(
return mlir::success();
}

mlir::LogicalResult mlir::rlc::ConstantGlobalOp::typeCheck(
mlir::rlc::ModuleBuilder &builder)
{
builder.getConverter().setErrorLocation(getLoc());
auto deducedType = builder.getConverter().convertType(getType());
auto shugarizedType = builder.getConverter().shugarizedConvertType(getType());
if (deducedType == nullptr or shugarizedType == nullptr)
{
return mlir::failure();
}

getResult().setType(deducedType);
setShugarizedTypeAttr(getShugarizedType()->replaceType(shugarizedType));
return mlir::success();
}

mlir::LogicalResult mlir::rlc::ConstructOp::typeCheck(
mlir::rlc::ModuleBuilder &builder)
{
Expand Down Expand Up @@ -1835,6 +1851,16 @@ mlir::LogicalResult mlir::rlc::FromByteArrayOp::typeCheck(
"are supported");
}

llvm::SmallVector<mlir::rlc::ShugarizedTypeAttr, 2>
mlir::rlc::ConstantGlobalOp::getShugarizedTypes()
{
llvm::SmallVector<mlir::rlc::ShugarizedTypeAttr, 2> toReturn;

if (getShugarizedType() != nullptr)
toReturn.push_back(*getShugarizedType());
return toReturn;
}

llvm::SmallVector<mlir::rlc::ShugarizedTypeAttr, 2>
mlir::rlc::ActionFunction::getShugarizedTypes()
{
Expand Down
23 changes: 20 additions & 3 deletions lib/dialect/src/Operations.td
Original file line number Diff line number Diff line change
Expand Up @@ -326,18 +326,35 @@ def RLC_UsingTypeOp : RLC_Dialect<"using_type", [DeclareOpInterfaceMethods<TypeC
}];
}

def RLC_ConstantGlobalArrayOp : RLC_Dialect<"global_array"> {
def RLC_ConstantGlobalOp : RLC_Dialect<"global", [DeclareOpInterfaceMethods<TypeCheckable>, DeclareOpInterfaceMethods<TypeUser>]> {
let summary = "constant global array";

let description = [{
expression.
}];

let arguments = (ins TypeAttrOf<AnyType>:$result, ArrayAttr:$values, StrAttr:$name);
let arguments = (ins AnyAttr:$values, StrAttr:$name, OptionalAttr<RLC_ShugarizedTypeAttr>:$shugarized_type);
let results = (outs AnyType:$result);

let assemblyFormat = [{
$name $values `:` $result attr-dict
$name $values `->` type($result) (`shugarized_type` `=` $shugarized_type^ )? attr-dict
}];

let builders = [
OpBuilder<(ins "mlir::Type":$type, "mlir::Attribute":$value, "mlir::StringRef":$name), [{
build($_builder, $_state, type, value, name, nullptr);
}]>];
}

def RLC_FlatConstantGlobalOp : RLC_Dialect<"flat_global"> {
let summary = "constant global array";

let description = [{
expression.
}];

let arguments = (ins TypeAttrOf<AnyType>:$type, AnyAttr:$values, StrAttr:$name);

}

def RLC_StringLiteralOp : RLC_Dialect<"string_literal", [DeclareOpInterfaceMethods<TypeCheckable>]> {
Expand Down
16 changes: 16 additions & 0 deletions lib/dialect/src/SymbolTable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,17 @@ mlir::rlc::TypeTable mlir::rlc::makeTypeTable(mlir::ModuleOp mod)
table.add(
"StringLiteral", mlir::rlc::StringLiteralType::get(mod.getContext()));

for (auto constant : mod.getOps<mlir::rlc::ConstantGlobalOp>())
if (constant.getResult().getType().isa<mlir::rlc::IntegerType>())
table.add(
constant.getName(),
mlir::rlc::IntegerLiteralType::get(
constant.getContext(),
constant.getValues()
.cast<mlir::IntegerAttr>()
.getValue()
.getSExtValue()));

for (auto classDecl : mod.getOps<mlir::rlc::ClassDeclaration>())
table.add(classDecl.getName(), classDecl.getType());

Expand Down Expand Up @@ -491,6 +502,11 @@ mlir::rlc::ModuleBuilder::ModuleBuilder(
assert(result.second);
}

for (auto global : op.getOps<mlir::rlc::ConstantGlobalOp>())
{
getSymbolTable().add(global.getName(), global);
}

for (auto fun : getSymbolTable().get(
mlir::rlc::builtinOperatorName<mlir::rlc::InitOp>()))
{
Expand Down
26 changes: 22 additions & 4 deletions lib/lsp/src/LSP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,9 @@ class mlir::rlc::lsp::LSPModuleInfoImpl
if (not op.getIsMemberFunction())
registerArgument(op.getUnmangledName(), op.getType(), op.getInfo());

for (auto op : module.getOps<mlir::rlc::ConstantGlobalOp>())
registerArgument(op.getName(), op.getType(), nullptr);

return mlir::success();
}

Expand All @@ -318,7 +321,9 @@ class mlir::rlc::lsp::LSPModuleInfoImpl
list.items.push_back(toReturn);
}

module.walk([&](mlir::rlc::ClassDeclaration op) {
for (mlir::rlc::ClassDeclaration op :
module.getOps<mlir::rlc::ClassDeclaration>())
{
mlir::lsp::CompletionItem toReturn;
toReturn.kind = mlir::lsp::CompletionItemKind::TypeParameter;
toReturn.label = op.getName().str();
Expand All @@ -333,14 +338,27 @@ class mlir::rlc::lsp::LSPModuleInfoImpl
params[params.size() - 1].cast<mlir::TypeAttr>().getValue());
toReturn.detail += ">";
list.items.push_back(toReturn);
});
}

module.walk([&](mlir::rlc::TypeAliasOp op) {
for (mlir::rlc::TypeAliasOp op : module.getOps<mlir::rlc::TypeAliasOp>())
{
mlir::lsp::CompletionItem toReturn;
toReturn.kind = mlir::lsp::CompletionItemKind::TypeParameter;
toReturn.label = op.getName();
list.items.push_back(toReturn);
});
}
for (mlir::rlc::ConstantGlobalOp op :
module.getOps<mlir::rlc::ConstantGlobalOp>())
{
if (op.getResult().getType() !=
mlir::rlc::IntegerType::getInt64(op.getContext()))
continue;
mlir::lsp::CompletionItem toReturn;
toReturn.kind = mlir::lsp::CompletionItemKind::TypeParameter;
toReturn.label = op.getName();
toReturn.detail = prettyType(op.getType());
list.items.push_back(toReturn);
}
}

bool sameLineAsOp(const mlir::lsp::Position &completePos, mlir::Operation *op)
Expand Down
1 change: 1 addition & 0 deletions lib/parser/include/rlc/parser/Lexer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ namespace rlc
KeywordAssert,
KeywordDestroy,
KeywordContstruct,
KeywordConst,
KeywordCan,
KeywordBreak,
KeywordContinue,
Expand Down
1 change: 1 addition & 0 deletions lib/parser/include/rlc/parser/Parser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ namespace rlc
bool templateFunction = true, bool isMemberFunction = false);
llvm::Expected<mlir::rlc::FunctionOp> externFunctionDeclaration();
llvm::Expected<mlir::Operation*> actionDeclaration(bool actionFunction);
llvm::Expected<mlir::rlc::ConstantGlobalOp> globalConstant();
llvm::Expected<mlir::rlc::ActionFunction> actionDefinition();
llvm::Expected<mlir::rlc::UncheckedTraitDefinition> traitDefinition();

Expand Down
5 changes: 5 additions & 0 deletions lib/parser/src/Lexer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ llvm::StringRef rlc::tokenToString(Token t)
return "KeywordAnd";
case Token::KeywordOr:
return "KeywordOr";
case Token::KeywordConst:
return "KeywordConst";
case Token::KeywordRule:
return "KeywordRule";
case Token::KeywordContinue:
Expand Down Expand Up @@ -430,6 +432,9 @@ Token Lexer::eatIdent()
if (name == "evn")
return Token::KeywordEvent;

if (name == "const")
return Token::KeywordConst;

if (name == "rul")
return Token::KeywordRule;

Expand Down
44 changes: 43 additions & 1 deletion lib/parser/src/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2085,6 +2085,41 @@ void Parser::setComment(mlir::Operation* op, llvm::StringRef comment)
op->setAttr("comment", builder.getStringAttr(comment));
}

llvm::Expected<mlir::rlc::ConstantGlobalOp> Parser::globalConstant()
{
auto location = getCurrentSourcePos();
EXPECT(Token::KeywordConst);
EXPECT(Token::Identifier);
std::string name = lIdent;
EXPECT(Token::Equal);
if (accept<Token::Double>())
return builder.create<mlir::rlc::ConstantGlobalOp>(
location,
mlir::rlc::FloatType::get(builder.getContext()),
builder.getF64FloatAttr(lDouble),
name);

if (accept<Token::Int64>())
return builder.create<mlir::rlc::ConstantGlobalOp>(
location,
mlir::rlc::IntegerType::getInt64(builder.getContext()),
builder.getIntegerAttr(builder.getIntegerType(64), lInt64),
name);

if (accept<Token::Character>())
return builder.create<mlir::rlc::ConstantGlobalOp>(
location,
mlir::rlc::IntegerType::getInt8(builder.getContext()),
builder.getIntegerAttr(builder.getIntegerType(8), lInt64),
name);

auto location2 = getCurrentSourcePos();
return make_error<RlcError>(
"Expected int, float or char",
RlcErrorCategory::errorCode(RlcErrorCode::unexpectedToken),
location2);
}

Expected<mlir::ModuleOp> Parser::system(mlir::ModuleOp destination)
{
auto location = getCurrentSourcePos();
Expand Down Expand Up @@ -2129,6 +2164,13 @@ Expected<mlir::ModuleOp> Parser::system(mlir::ModuleOp destination)
continue;
}

if (current == Token::KeywordConst)
{
TRY(f, globalConstant());
setComment(*f, comment);
continue;
}

if (current == Token::KeywordFun)
{
TRY(f, functionDefinition());
Expand Down Expand Up @@ -2180,7 +2222,7 @@ Expected<mlir::ModuleOp> Parser::system(mlir::ModuleOp destination)
}
auto location = getCurrentSourcePos();
return make_error<RlcError>(
"Expected function, action or class declaration",
"Expected declaration of function, action, class, constant or alias ",
RlcErrorCategory::errorCode(RlcErrorCode::unexpectedToken),
location);
}
Expand Down
Loading

0 comments on commit 5279ba0

Please sign in to comment.