From b45b8cd7f42166544a938cabc37b5aa7e9ce33f7 Mon Sep 17 00:00:00 2001 From: Martin Erhart Date: Tue, 19 Nov 2024 20:12:07 +0000 Subject: [PATCH] [RTG] Add set type and operations --- include/circt/Dialect/RTG/IR/RTGOps.td | 45 ++++++++++++++++++++++++ include/circt/Dialect/RTG/IR/RTGTypes.td | 18 ++++++++++ lib/Dialect/RTG/IR/RTGOps.cpp | 40 +++++++++++++++++++++ test/Dialect/RTG/IR/basic.mlir | 14 ++++++++ 4 files changed, 117 insertions(+) diff --git a/include/circt/Dialect/RTG/IR/RTGOps.td b/include/circt/Dialect/RTG/IR/RTGOps.td index 83dd5220a3c3..4c79e50786e3 100644 --- a/include/circt/Dialect/RTG/IR/RTGOps.td +++ b/include/circt/Dialect/RTG/IR/RTGOps.td @@ -95,3 +95,48 @@ def InvokeSequenceOp : RTGOp<"invoke_sequence", []> { let assemblyFormat = "$sequence attr-dict"; } + +//===- Set Operations ------------------------------------------------------===// + +def SetCreateOp : RTGOp<"set_create", [Pure, SameTypeOperands]> { + let summary = "constructs a set of the given values"; + + let arguments = (ins Variadic:$elements); + let results = (outs SetType:$set); + + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; +} + +def SetSelectRandomOp : RTGOp<"set_select_random", [ + Pure, + TypesMatchWith<"output must be of the element type of input set", + "set", "output", + "llvm::cast($_self).getElementType()"> +]> { + let summary = "selects an element uniformly at random from a set"; + let description = [{ + This operation returns an element from the given set uniformly at random. + Applying this operation to an empty set is undefined behavior. + }]; + + let arguments = (ins SetType:$set); + let results = (outs AnyType:$output); + + let assemblyFormat = "$set `:` qualified(type($set)) attr-dict"; +} + +def SetDifferenceOp : RTGOp<"set_difference", [ + Pure, + AllTypesMatch<["original", "diff", "output"]> +]> { + let summary = "computes the difference of two sets"; + + let arguments = (ins SetType:$original, + SetType:$diff); + let results = (outs SetType:$output); + + let assemblyFormat = [{ + $original `,` $diff `:` qualified(type($output)) attr-dict + }]; +} diff --git a/include/circt/Dialect/RTG/IR/RTGTypes.td b/include/circt/Dialect/RTG/IR/RTGTypes.td index b93340b28679..a76c1dd1fcf8 100644 --- a/include/circt/Dialect/RTG/IR/RTGTypes.td +++ b/include/circt/Dialect/RTG/IR/RTGTypes.td @@ -29,4 +29,22 @@ def SequenceType : RTGTypeDef<"Sequence"> { let assemblyFormat = ""; } +def SetType : RTGTypeDef<"Set"> { + let summary = "a set of values"; + let description = [{ + This type represents a standard set datastructure. It does not make any + assumptions about the underlying implementation. Thus a hash set, tree set, + etc. can be used in a backend. + }]; + + let parameters = (ins "::mlir::Type":$elementType); + + let mnemonic = "set"; + let assemblyFormat = "`<` $elementType `>`"; +} + +class SetTypeOf : ContainerType< + elementType, SetType.predicate, + "llvm::cast($_self).getElementType()", "set">; + #endif // CIRCT_DIALECT_RTG_IR_RTGTYPES_TD diff --git a/lib/Dialect/RTG/IR/RTGOps.cpp b/lib/Dialect/RTG/IR/RTGOps.cpp index f7e105dbe9a3..6aebf52844f9 100644 --- a/lib/Dialect/RTG/IR/RTGOps.cpp +++ b/lib/Dialect/RTG/IR/RTGOps.cpp @@ -38,6 +38,46 @@ SequenceClosureOp::verifySymbolUses(SymbolTableCollection &symbolTable) { return success(); } +//===----------------------------------------------------------------------===// +// SetCreateOp +//===----------------------------------------------------------------------===// + +ParseResult SetCreateOp::parse(OpAsmParser &parser, OperationState &result) { + llvm::SmallVector operands; + Type elemType; + + if (parser.parseOperandList(operands) || + parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || + parser.parseType(elemType)) + return failure(); + + result.addTypes({SetType::get(result.getContext(), elemType)}); + + for (auto operand : operands) + if (parser.resolveOperand(operand, elemType, result.operands)) + return failure(); + + return success(); +} + +void SetCreateOp::print(OpAsmPrinter &p) { + p << " "; + p.printOperands(getElements()); + p.printOptionalAttrDict((*this)->getAttrs()); + p << " : " << getSet().getType().getElementType(); +} + +LogicalResult SetCreateOp::verify() { + if (getElements().size() > 0) { + // We only need to check the first element because of the `SameTypeOperands` + // trait. + if (getElements()[0].getType() != getSet().getType().getElementType()) + return emitOpError() << "operand types must match set element type"; + } + + return success(); +} + //===----------------------------------------------------------------------===// // TableGen generated logic. //===----------------------------------------------------------------------===// diff --git a/test/Dialect/RTG/IR/basic.mlir b/test/Dialect/RTG/IR/basic.mlir index fe59271d0c06..eae6e394a66f 100644 --- a/test/Dialect/RTG/IR/basic.mlir +++ b/test/Dialect/RTG/IR/basic.mlir @@ -24,3 +24,17 @@ rtg.sequence @invocations { rtg.invoke_sequence %0 rtg.invoke_sequence %1 } + +// CHECK-LABEL: @sets +func.func @sets(%arg0: i32, %arg1: i32) { + // CHECK: [[SET:%.+]] = rtg.set_create %arg0, %arg1 : i32 + // CHECK: [[R:%.+]] = rtg.set_select_random [[SET]] : !rtg.set + // CHECK: [[EMPTY:%.+]] = rtg.set_create : i32 + // CHECK: rtg.set_difference [[SET]], [[EMPTY]] : !rtg.set + %set = rtg.set_create %arg0, %arg1 : i32 + %r = rtg.set_select_random %set : !rtg.set + %empty = rtg.set_create : i32 + %diff = rtg.set_difference %set, %empty : !rtg.set + + return +}