Skip to content

Commit

Permalink
various fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
drblallo committed Dec 15, 2024
1 parent 89a8fec commit da3c2c8
Show file tree
Hide file tree
Showing 33 changed files with 137 additions and 67 deletions.
2 changes: 1 addition & 1 deletion lib/dialect/src/EmitImplicitAssignPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ namespace mlir::rlc
fun.getLoc(), toAssignType, block->getArgument(0));
auto contructed = rewriter.create<mlir::rlc::ConstructOp>(
fun.getLoc(), casted.getResult().getType());
assert(builder.isTemplateType(contructed.getType()));
assert(!builder.isTemplateType(contructed.getType()));
auto* call = builder.emitCall(
fun,
true,
Expand Down
67 changes: 39 additions & 28 deletions lib/dialect/src/Operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,7 @@ static mlir::LogicalResult declareActionTypes(
mlir::rlc::ActionFunction mlir::rlc::detail::typeCheckAction(
mlir::rlc::ActionFunction fun, mlir::rlc::ValueTable *parentSymbolTable)
{
bool emitClasses = fun->hasAttr("emit_classes");
if (fun.getUnmangledName() == "init")
{
auto _ = logError(
Expand Down Expand Up @@ -544,10 +545,13 @@ mlir::rlc::ActionFunction mlir::rlc::detail::typeCheckAction(
if (not newF)
return nullptr;

mlir::rlc::ModuleBuilder builder2(
newF->getParentOfType<mlir::ModuleOp>(), parentSymbolTable);
if (declareActionTypes(newF, builder2).failed())
return nullptr;
if (emitClasses)
{
mlir::rlc::ModuleBuilder builder2(
newF->getParentOfType<mlir::ModuleOp>(), parentSymbolTable);
if (declareActionTypes(newF, builder2).failed())
return nullptr;
}

if (newF.getActions().empty())
{
Expand Down Expand Up @@ -755,30 +759,18 @@ mlir::LogicalResult mlir::rlc::SubActionStatement::typeCheck(
.getDefiningOp<mlir::rlc::ActionFunction>();

mlir::IRRewriter &rewiter = builder.getRewriter();
rewiter.setInsertionPoint(*this);

mlir::Value decl = nullptr;
if (getName().empty())
{
auto terminator =
mlir::cast<mlir::rlc::Yield>(getBody().front().getTerminator());
while (not getBody().front().empty())
getBody().front().front().moveBefore(getOperation());

decl = terminator.getArguments()[0];
terminator.erase();
}
else
if (not getName().empty())
{
rewiter.setInsertionPoint(*this);
auto varDecl = rewiter.create<mlir::rlc::DeclarationStatement>(
getLoc(), mlir::rlc::FrameType::get(underlyingType), getName());
varDecl.getBody().takeBody(getBody());
builder.getSymbolTable().add(getName(), varDecl);
decl = varDecl;
rewiter.createBlock(&getBody());
rewiter.create<mlir::rlc::Yield>(getLoc(), mlir::ValueRange({ varDecl }));
}

frameVar = builder.getRewriter().create<mlir::rlc::StorageCast>(
getLoc(), underlyingType, decl);
rewiter.setInsertionPoint(*this);

llvm::SmallVector<mlir::Value, 2> actionValues = underlying.getActions();
if (actionValues.empty())
Expand All @@ -787,13 +779,29 @@ mlir::LogicalResult mlir::rlc::SubActionStatement::typeCheck(
return mlir::success();
}

const auto &make_frame_access = [this, &rewiter, underlyingType]() {
mlir::IRMapping mapping;
for (auto &op : llvm::drop_end(getBody().front().getOperations()))
rewiter.clone(op, mapping);

auto yielded =
mlir::cast<mlir::rlc::Yield>(getBody().front().getTerminator())
.getArguments()[0];
mlir::Value frame = mapping.getValueMap().contains(yielded)
? mapping.getValueMap().at(yielded)
: yielded;

return rewiter.create<mlir::rlc::StorageCast>(
getLoc(), underlyingType, frame);
};

if (not getRunOnce())
{
auto loop = rewiter.create<mlir::rlc::WhileStatement>(getLoc());
rewiter.createBlock(&loop.getCondition());

auto *call = builder.emitCall(
*this, true, "is_done", mlir::ValueRange({ frameVar }));
*this, true, "is_done", mlir::ValueRange({ make_frame_access() }));
assert(call);
auto isDone = call->getResult(0);

Expand Down Expand Up @@ -900,7 +908,8 @@ mlir::LogicalResult mlir::rlc::SubActionStatement::typeCheck(

for (auto arg : llvm::drop_begin(newBody->getArguments(), contextArgsCount))
canArgs.push_back(arg);
canArgs.insert(canArgs.begin(), frameVar);

canArgs.insert(canArgs.begin(), make_frame_access());

auto casted = rewiter.create<mlir::rlc::CanOp>(actions.getLoc(), toCall);
auto result = rewiter.create<mlir::rlc::CallOp>(
Expand Down Expand Up @@ -928,7 +937,7 @@ mlir::LogicalResult mlir::rlc::SubActionStatement::typeCheck(
for (auto result : llvm::drop_begin(fixed.getResults(), contextArgsCount))
args.push_back(result);

args.insert(args.begin(), frameVar);
args.insert(args.begin(), make_frame_access());

rewiter.create<mlir::rlc::CallOp>(actions.getLoc(), toCall, false, args);

Expand Down Expand Up @@ -1283,8 +1292,8 @@ mlir::LogicalResult mlir::rlc::UncheckedTraitDefinition::typeCheck(
rewriter.setInsertionPointAfter(*this);
auto op = rewriter.create<mlir::rlc::TraitDefinition>(getLoc(), type);

// replace the template parameters provided by the user with one prefixed with
// TraitType so that it does not clashes with regular names
// replace the template parameters provided by the user with one prefixed
// with TraitType so that it does not clashes with regular names
for (auto templateParameter : getTemplateParameterTypes())
{
mlir::AttrTypeReplacer replacer;
Expand Down Expand Up @@ -1557,8 +1566,10 @@ mlir::LogicalResult mlir::rlc::ForFieldStatement::typeCheck(
return logError(
*this,
"Missmatched count between for induction variables and for "
"arguments. The number of induction varaibles must be exactly the same "
"as the number of expressions, or one more to capture the name of the "
"arguments. The number of induction varaibles must be exactly the "
"same "
"as the number of expressions, or one more to capture the name of "
"the "
"field");
}

Expand Down
2 changes: 2 additions & 0 deletions lib/parser/include/rlc/parser/Lexer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ namespace rlc
KeywordFree,
KeywordRef,
KeywordMangledName,
KeywordActionClass,
KeywordAsPtr,
KeywordToArray,
KeywordFromArray,
Expand Down Expand Up @@ -90,6 +91,7 @@ namespace rlc
Module,
LPar,
RPar,
AnnotationIntroducer,
RSquare,
LSquare,
LBracket,
Expand Down
9 changes: 9 additions & 0 deletions lib/parser/src/Lexer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ llvm::StringRef rlc::tokenToString(Token t)
return "KeywordIs";
case Token::KeywordBitXor:
return "KeywordBitXor";
case Token::KeywordActionClass:
return "KeywordActionClass";
case Token::KeywordBitAnd:
return "KeywordBitAnd";
case Token::KeywordAnd:
Expand Down Expand Up @@ -211,6 +213,8 @@ llvm::StringRef rlc::tokenToString(Token t)
return "Tilde";
case Token::Equal:
return "Equal";
case Token::AnnotationIntroducer:
return "AnnotationIntroducer";
case Token::Identifier:
return "Identifier";
case Token::String:
Expand Down Expand Up @@ -376,6 +380,8 @@ optional<Token> Lexer::eatSymbol()
return Token::Newline;
case '|':
return Token::VerticalPipe;
case '@':
return Token::AnnotationIntroducer;
}
return nullopt;
}
Expand Down Expand Up @@ -468,6 +474,9 @@ Token Lexer::eatIdent()
if (name == "frm")
return Token::KeywordFrame;

if (name == "classes")
return Token::KeywordActionClass;

if (name == "ctx")
return Token::KeywordCtx;

Expand Down
77 changes: 42 additions & 35 deletions lib/parser/src/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1012,54 +1012,52 @@ llvm::Expected<mlir::rlc::SubActionStatement> Parser::subActionStatement()

llvm::SmallVector<mlir::Value, 3> expressions;
auto insertionPoint = builder.saveInsertionPoint();
mlir::Region region;
builder.createBlock(&region);

mlir::Region forwardedArgsRegion;
mlir::Region bodyRegion;
std::string name;

auto onExit = [&, this](bool success) -> mlir::rlc::SubActionStatement {
builder.restoreInsertionPoint(insertionPoint);
if (not success)
return nullptr;
auto operation =
builder.create<mlir::rlc::SubActionStatement>(location, name, runOnce);
operation.getBody().takeBody(bodyRegion);
operation.getForwardedArgs().takeBody(forwardedArgsRegion);
return operation;
};

builder.createBlock(&forwardedArgsRegion);
if (accept<Token::LPar>())
{
TRY(list, argumentExpressionList());
expressions = *list;
EXPECT(Token::RPar);
}
builder.create<mlir::rlc::Yield>(getCurrentSourcePos(), expressions);
builder.restoreInsertionPoint(insertionPoint);

std::string name;
accept(Token::Identifier);
name = lIdent;
// is we don't see a equal it means that the
// user has written `subaction name `
if (not accept(Token::Equal))
builder.createBlock(&bodyRegion);
TRY(exp, expression(), onExit(false));
auto maybeName =
mlir::dyn_cast<mlir::rlc::UnresolvedReference>(*exp->getDefiningOp());
if (!maybeName or not accept(Token::Equal))
{
auto operation =
builder.create<mlir::rlc::SubActionStatement>(location, "", runOnce);
operation.getForwardedArgs().takeBody(region);
builder.createBlock(&operation.getBody());
auto exp = builder.create<mlir::rlc::UnresolvedReference>(location, name);
builder.create<mlir::rlc::Yield>(
getCurrentSourcePos(), mlir::ValueRange(exp));
builder.setInsertionPointAfter(operation);
EXPECT(Token::Newline);
return operation;
builder.create<mlir::rlc::Yield>(
getCurrentSourcePos(), mlir::ValueRange(*exp));
return onExit(true);
}

auto operation =
builder.create<mlir::rlc::SubActionStatement>(location, name, runOnce);
operation.getForwardedArgs().takeBody(region);
builder.createBlock(&operation.getBody());
auto onExit = [&, this](mlir::Value exp) {
if (exp)
builder.create<mlir::rlc::Yield>(
getCurrentSourcePos(), mlir::ValueRange(exp));
builder.setInsertionPointAfter(operation);
if (not exp)
operation.erase();
};
name = maybeName.getName();
(*exp).getDefiningOp()->erase();

TRY(exp, expression(), onExit(nullptr));
EXPECT(Token::Newline, onExit(*exp));
onExit(*exp);
builder.setInsertionPointToStart(&bodyRegion.front());
TRY(body, expression(), onExit(false));
EXPECT(Token::Newline);
builder.create<mlir::rlc::Yield>(
getCurrentSourcePos(), mlir::ValueRange(*body));

return operation;
return onExit(true);
}

llvm::Expected<mlir::rlc::ActionsStatement> Parser::actionsStatement()
Expand Down Expand Up @@ -2205,9 +2203,18 @@ Expected<mlir::ModuleOp> Parser::system(mlir::ModuleOp destination)
continue;
}

bool emitClasses = false;
if (accept(Token::AnnotationIntroducer))
{
EXPECT(Token::KeywordActionClass);
EXPECT(Token::Newline);
emitClasses = true;
}
if (current == Token::KeywordAction)
{
TRY(f, actionDefinition());
if (emitClasses)
f->getOperation()->setAttr("emit_classes", builder.getBoolAttr(true));
setComment(*f, comment);
continue;
}
Expand Down
1 change: 1 addition & 0 deletions stdlib/learn.rl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import action


fun gen_printer_parser():
let state : Game
let any_action : AnyGameAction
Expand Down
1 change: 1 addition & 0 deletions tool/rlc/src/Main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,7 @@ static int run(
const mlir::rlc::TargetInfo &info)
{
mlir::PassManager manager(&context);
manager.enableVerifier(isDebug);
driver.configurePassManager(manager);

if (timing)
Expand Down
2 changes: 2 additions & 0 deletions tool/rlc/test/action_trait.rl
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# RUN: rlc %s -o %t -i %stdlib
# RUN: %t%exeext

@classes
act action() -> Action:
frm to_return = 0
act to_call(Int arg)
to_return = arg


fun main() -> Int:
let x : ActionToCall
if x is ActionAction:
Expand Down
1 change: 1 addition & 0 deletions tool/rlc/test/all_actions_alternative_type.rl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import string

@classes
act play() -> Name:
act first()
act second(Int asd)
Expand Down
1 change: 1 addition & 0 deletions tool/rlc/test/alternative_actions_print.rl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import serialization.print
import string

@classes
act to_run() -> ToRun:
act first(Bool asd)
act second(Int tasd)
Expand Down
1 change: 1 addition & 0 deletions tool/rlc/test/bug_multiple_can_apply_definitions.rl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import action
fun crash_on_five(Int input) -> Int {input != 5}:
return 0

@classes
act play() -> Play:
frm current = 0
while current != 7:
Expand Down
1 change: 1 addition & 0 deletions tool/rlc/test/enumeration_errors.rl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ cls TopLevel:
Bool[3] c
Vector<BInt<0, 3>> f

@classes
act play() -> Game:
# CHECK: GameDoNothing.a is of type Int, which is not enumerable. Replace it instead with a BInt with appropriate bounds or specify yourself how to enumerate it.
act do_nothing(Int a, Bool b)
Expand Down
1 change: 1 addition & 0 deletions tool/rlc/test/examples/2players_texas_holdem.rl
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ fun resolve_winner(Player[2] players):
players[0].chips = players[0].chips - players[1].current_bet
players[1].chips = players[1].chips + players[1].current_bet

@classes
act play() -> Game:
frm players : Player[2]
players[0].chips = 100
Expand Down
1 change: 1 addition & 0 deletions tool/rlc/test/examples/battleship.rl
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ fun make_player_ships() -> BoundedVector<Ship, 5>:

return ships_to_place

@classes
act play() -> Game:
frm current_player : Bool
frm players : HiddenInformation<Player>[2]
Expand Down
1 change: 1 addition & 0 deletions tool/rlc/test/examples/black_jack.rl
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ fun calculate_points(BoundedVector<Card, 20> hand) -> Int:
# must return a Game, otherwise the machine
# learning components will not know what
# to look for.
@classes
act play() -> Game:

# allocates a deck and initializes it
Expand Down
1 change: 1 addition & 0 deletions tool/rlc/test/examples/checkers.rl
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ fun can_move_anything(Int player_id, Board b) -> Bool:
x = x + 1
return false

@classes
act play() -> Game:
frm current_player : Bool
frm board = make_board()
Expand Down
Loading

0 comments on commit da3c2c8

Please sign in to comment.