diff --git a/gematria/basic_block/basic_block.cc b/gematria/basic_block/basic_block.cc index c97a29dd..466cab16 100644 --- a/gematria/basic_block/basic_block.cc +++ b/gematria/basic_block/basic_block.cc @@ -229,27 +229,63 @@ std::ostream& operator<<(std::ostream& os, const InstructionOperand& operand) { return os; } +Annotation::Annotation(std::string name, double value) + : name(std::move(name)), value(value) {} + +bool Annotation::operator==(const Annotation& other) const { + const auto as_tuple = [](const Annotation& annotation) { + return std::tie(annotation.name, annotation.value); + }; + return as_tuple(*this) == as_tuple(other); +} + +std::string Annotation::ToString() const { + std::stringstream buffer; + buffer << "Annotation("; + if (!name.empty()) { + buffer << "name='" << name << "', "; + } + if (value != -1) { + buffer << "value=" << value << ", "; + } + // If we added any keyword args to the buffer, drop the last two characters + // (a comma and a space). This is not strictly necessary, but it looks better. + auto msg = buffer.str(); + assert(msg.size() >= 2); + if (msg.back() == ' ') msg.resize(msg.size() - 2); + msg.push_back(')'); + return msg; +} + +std::ostream& operator<<(std::ostream& os, const Annotation& annotation) { + os << annotation.ToString(); + return os; +} + Instruction::Instruction( std::string mnemonic, std::string llvm_mnemonic, std::vector prefixes, std::vector input_operands, std::vector implicit_input_operands, std::vector output_operands, - std::vector implicit_output_operands) + std::vector implicit_output_operands, + std::vector instruction_annotations) : mnemonic(std::move(mnemonic)), llvm_mnemonic(std::move(llvm_mnemonic)), prefixes(std::move(prefixes)), input_operands(std::move(input_operands)), implicit_input_operands(std::move(implicit_input_operands)), output_operands(std::move(output_operands)), - implicit_output_operands(std::move(implicit_output_operands)) {} + implicit_output_operands(std::move(implicit_output_operands)), + instruction_annotations(std::move(instruction_annotations)) {} bool Instruction::operator==(const Instruction& other) const { const auto as_tuple = [](const Instruction& instruction) { return std::tie( instruction.mnemonic, instruction.llvm_mnemonic, instruction.prefixes, instruction.input_operands, instruction.implicit_input_operands, - instruction.output_operands, instruction.implicit_output_operands); + instruction.output_operands, instruction.implicit_output_operands, + instruction.instruction_annotations); }; return as_tuple(*this) == as_tuple(other); } @@ -292,6 +328,17 @@ std::string Instruction::ToString() const { add_operand_list("output_operands", output_operands); add_operand_list("implicit_output_operands", implicit_output_operands); + if (!instruction_annotations.empty()) { + buffer << "instruction_annotations=("; + for (const Annotation& annotation : instruction_annotations) { + buffer << annotation.ToString() << ", "; + } + // Pop only the trailing space. For simplicity, we leave the trailing comma + // which is required in case there is only one element. + buffer.seekp(-1, std::ios_base::end); + buffer << "), "; + } + auto msg = buffer.str(); assert(msg.size() >= 2); if (msg.back() == ' ') msg.resize(msg.size() - 2); diff --git a/gematria/basic_block/basic_block.h b/gematria/basic_block/basic_block.h index 46a8b0b7..43579fb7 100644 --- a/gematria/basic_block/basic_block.h +++ b/gematria/basic_block/basic_block.h @@ -218,6 +218,31 @@ class InstructionOperand { std::ostream& operator<<(std::ostream& os, const InstructionOperand& operand); +// Represents an annotation holding a value such as some measure/statistic +// paired with the instruction. +struct Annotation { + Annotation() : value(-1){}; + + // Initializes all fields of the annotation. + Annotation(std::string name, double value); + + Annotation(const Annotation&) = default; + Annotation(Annotation&&) = default; + + Annotation& operator=(const Annotation&) = default; + Annotation& operator=(Annotation&&) = default; + + bool operator==(const Annotation& other) const; + bool operator!=(const Annotation& other) const { return !(*this == other); } + + std::string ToString() const; + + std::string name; + double value; +}; + +std::ostream& operator<<(std::ostream& os, const Annotation& annotation); + // Represents a single instruction. struct Instruction { Instruction() {} @@ -229,7 +254,9 @@ struct Instruction { std::vector input_operands, std::vector implicit_input_operands, std::vector output_operands, - std::vector implicit_output_operands); + std::vector implicit_output_operands, + std::vector instruction_annotations = + std::vector{}); Instruction(const Instruction&) = default; Instruction(Instruction&&) = default; @@ -280,6 +307,11 @@ struct Instruction { // to the ML models explicitly. std::vector implicit_output_operands; + // The list of instruction level annotations used to supply additional + // information to the model. Currently includes the cache miss frequency of + // the instruction. Used to better model the overhead coming from LLC misses. + std::vector instruction_annotations; + // The address of the instruction. uint64_t address = 0; // The size of the instruction. diff --git a/gematria/basic_block/basic_block_protos.cc b/gematria/basic_block/basic_block_protos.cc index 16abd847..a65f8e3e 100644 --- a/gematria/basic_block/basic_block_protos.cc +++ b/gematria/basic_block/basic_block_protos.cc @@ -20,11 +20,37 @@ #include #include "gematria/basic_block/basic_block.h" +#include "gematria/proto/annotation.pb.h" #include "gematria/proto/canonicalized_instruction.pb.h" #include "google/protobuf/repeated_ptr_field.h" namespace gematria { +namespace { + +template +std::vector ToVector( + const google::protobuf::RepeatedPtrField& protos, + Convertor object_from_proto) { + std::vector result(protos.size()); + std::transform(protos.begin(), protos.end(), result.begin(), + object_from_proto); + return result; +} + +template +void ToRepeatedPtrField( + const std::vector& objects, + google::protobuf::RepeatedPtrField* repeated_field, + Convertor proto_from_object) { + repeated_field->Reserve(objects.size()); + std::transform(objects.begin(), objects.end(), + google::protobuf::RepeatedFieldBackInserter(repeated_field), + proto_from_object); +} + +} // namespace + AddressTuple AddressTupleFromProto( const CanonicalizedOperandProto::AddressTuple& proto) { return AddressTuple( @@ -91,29 +117,19 @@ CanonicalizedOperandProto ProtoFromInstructionOperand( return proto; } -namespace { - -std::vector ToVector( - const google::protobuf::RepeatedPtrField& - protos) { - std::vector result(protos.size()); - std::transform(protos.begin(), protos.end(), result.begin(), - InstructionOperandFromProto); - return result; +Annotation AnnotationFromProto(const AnnotationProto& proto) { + return Annotation( + /* name = */ proto.name(), + /* value = */ proto.value()); } -void ToRepeatedPtrField( - const std::vector& operands, - google::protobuf::RepeatedPtrField* - repeated_field) { - repeated_field->Reserve(operands.size()); - std::transform(operands.begin(), operands.end(), - google::protobuf::RepeatedFieldBackInserter(repeated_field), - ProtoFromInstructionOperand); +AnnotationProto ProtoFromAnnotation(const Annotation& annotation) { + AnnotationProto proto; + proto.set_name(annotation.name); + proto.set_value(annotation.value); + return proto; } -} // namespace - Instruction InstructionFromProto(const CanonicalizedInstructionProto& proto) { return Instruction( /* mnemonic = */ proto.mnemonic(), @@ -121,11 +137,21 @@ Instruction InstructionFromProto(const CanonicalizedInstructionProto& proto) { /* prefixes = */ std::vector(proto.prefixes().begin(), proto.prefixes().end()), - /* input_operands = */ ToVector(proto.input_operands()), - /* implicit_input_operands = */ ToVector(proto.implicit_input_operands()), - /* output_operands = */ ToVector(proto.output_operands()), + /* input_operands = */ + ToVector(proto.input_operands(), + InstructionOperandFromProto), + /* implicit_input_operands = */ + ToVector(proto.implicit_input_operands(), + InstructionOperandFromProto), + /* output_operands = */ + ToVector(proto.output_operands(), + InstructionOperandFromProto), /* implicit_output_operands = */ - ToVector(proto.implicit_output_operands())); + ToVector(proto.implicit_output_operands(), + InstructionOperandFromProto), + /* instruction_annotations = */ + ToVector(proto.instruction_annotations(), + AnnotationFromProto)); } CanonicalizedInstructionProto ProtoFromInstruction( @@ -135,33 +161,27 @@ CanonicalizedInstructionProto ProtoFromInstruction( proto.set_llvm_mnemonic(instruction.llvm_mnemonic); proto.mutable_prefixes()->Assign(instruction.prefixes.begin(), instruction.prefixes.end()); - ToRepeatedPtrField(instruction.input_operands, - proto.mutable_input_operands()); + ToRepeatedPtrField(instruction.input_operands, proto.mutable_input_operands(), + ProtoFromInstructionOperand); ToRepeatedPtrField(instruction.implicit_input_operands, - proto.mutable_implicit_input_operands()); + proto.mutable_implicit_input_operands(), + ProtoFromInstructionOperand); ToRepeatedPtrField(instruction.output_operands, - proto.mutable_output_operands()); + proto.mutable_output_operands(), + ProtoFromInstructionOperand); ToRepeatedPtrField(instruction.implicit_output_operands, - proto.mutable_implicit_output_operands()); + proto.mutable_implicit_output_operands(), + ProtoFromInstructionOperand); + ToRepeatedPtrField(instruction.instruction_annotations, + proto.mutable_instruction_annotations(), + ProtoFromAnnotation); return proto; } -namespace { - -std::vector ToVector( - const google::protobuf::RepeatedPtrField& - protos) { - std::vector result(protos.size()); - std::transform(protos.begin(), protos.end(), result.begin(), - InstructionFromProto); - return result; -} - -} // namespace - BasicBlock BasicBlockFromProto(const BasicBlockProto& proto) { return BasicBlock( - /* instructions = */ ToVector(proto.canonicalized_instructions())); + /* instructions = */ ToVector( + proto.canonicalized_instructions(), InstructionFromProto)); } } // namespace gematria diff --git a/gematria/basic_block/basic_block_protos.h b/gematria/basic_block/basic_block_protos.h index 2d2d9b04..4dc6d4b4 100644 --- a/gematria/basic_block/basic_block_protos.h +++ b/gematria/basic_block/basic_block_protos.h @@ -41,6 +41,12 @@ InstructionOperand InstructionOperandFromProto( CanonicalizedOperandProto ProtoFromInstructionOperand( const InstructionOperand& operand); +// Creates an annotation data structure from a proto. +Annotation AnnotationFromProto(const AnnotationProto& proto); + +// Creates a proto representing the given annotation. +AnnotationProto ProtoFromAnnotation(const Annotation& annotation); + // Creates an instruction data structure from a proto. Instruction InstructionFromProto(const CanonicalizedInstructionProto& proto); diff --git a/gematria/basic_block/basic_block_protos_test.cc b/gematria/basic_block/basic_block_protos_test.cc index b667de88..c23f8881 100644 --- a/gematria/basic_block/basic_block_protos_test.cc +++ b/gematria/basic_block/basic_block_protos_test.cc @@ -144,6 +144,26 @@ TEST(ProtoFromInstructionOperandTest, Memory) { EqualsProto(R"pb(memory { alias_group_id: 123 })pb")); } +TEST(AnnotationFromProtoTest, AllFields) { + const AnnotationProto proto = ParseTextProto(R"pb( + name: 'cache_miss_freq' + value: 0.875 + )pb"); + Annotation annotation = AnnotationFromProto(proto); + EXPECT_EQ(annotation.name, "cache_miss_freq"); + EXPECT_EQ(annotation.value, 0.875); +} + +TEST(ProtoFromAnnotationTest, AllFields) { + const Annotation annotation( + /* name = */ "cache_miss_freq", + /* value = */ 0.875); + EXPECT_THAT(ProtoFromAnnotation(annotation), EqualsProto(R"pb( + name: 'cache_miss_freq' + value: 0.875 + )pb")); +} + TEST(InstructionFromProtoTest, AllFields) { const CanonicalizedInstructionProto proto = ParseTextProto(R"pb( mnemonic: "ADC" @@ -156,6 +176,7 @@ TEST(InstructionFromProtoTest, AllFields) { implicit_output_operands { register_name: "EFLAGS" } implicit_input_operands { register_name: "EFLAGS" } implicit_input_operands { immediate_value: 1 } + instruction_annotations { name: "cache_miss_freq" value: 0.875 } )pb"); Instruction instruction = InstructionFromProto(proto); EXPECT_EQ( @@ -169,7 +190,9 @@ TEST(InstructionFromProtoTest, AllFields) { InstructionOperand::ImmediateValue(1)}, /* output_operands = */ {InstructionOperand::Register("RAX")}, /* implicit_output_operands = */ - {InstructionOperand::Register("EFLAGS")})); + {InstructionOperand::Register("EFLAGS")}, + /* instruction_annotations = */ + {Annotation("cache_miss_freq", 0.875)})); } TEST(ProtoFromInstructionTest, AllFields) { @@ -205,6 +228,7 @@ TEST(BasicBlockFromProtoTest, SomeInstructions) { llvm_mnemonic: "MOV64rr" output_operands: { register_name: "RCX" } input_operands: { register_name: "RAX" } + instruction_annotations: { name: "cache_miss_freq" value: 0.875 } } canonicalized_instructions: { mnemonic: "NOT" @@ -223,7 +247,9 @@ TEST(BasicBlockFromProtoTest, SomeInstructions) { /* input_operands = */ {InstructionOperand::Register("RAX")}, /* implicit_input_operands = */ {}, /* output_operands = */ {InstructionOperand::Register("RCX")}, - /* implicit_output_operands = */ {}), + /* implicit_output_operands = */ {}, + /* instruction_annotations = */ + {Annotation("cache_miss_freq", 0.875)}), Instruction( /* mnemonic = */ "NOT", /* llvm_mnemonic = */ "NOT64r", @@ -231,7 +257,8 @@ TEST(BasicBlockFromProtoTest, SomeInstructions) { /* input_operands = */ {InstructionOperand::Register("RCX")}, /* implicit_input_operands = */ {}, /* output_operands = */ {InstructionOperand::Register("RCX")}, - /* implicit_output_operands = */ {})})); + /* implicit_output_operands = */ {}, + /* instruction_annotations = */ {})})); } } // namespace diff --git a/gematria/basic_block/basic_block_test.cc b/gematria/basic_block/basic_block_test.cc index c77f3aea..3977c560 100644 --- a/gematria/basic_block/basic_block_test.cc +++ b/gematria/basic_block/basic_block_test.cc @@ -318,6 +318,28 @@ TEST(InstructionOperandTest, AsTokenList) { } } +// TODO(virajbshah): Add tests for Annotation. +TEST(AnnotationTest, Constructor) { + constexpr char kName[] = "cache_miss_freq"; + constexpr double kValue = 0.875; + + const Annotation annotation( + /* name = */ kName, + /* value = */ kValue); + EXPECT_EQ(annotation.name, kName); + EXPECT_EQ(annotation.value, kValue); +} + +TEST(AnnotationTest, ToString) { + const Annotation annotation( + /* name = */ "cache_miss_freq", + /* value = */ 0.875); + + constexpr char kExpectedString[] = + "Annotation(name='cache_miss_freq', value=0.875)"; + EXPECT_EQ(annotation.ToString(), kExpectedString); +} + TEST(InstructionTest, Constructor) { constexpr char kMnemonic[] = "MOV"; constexpr char kLlvmMnemonic[] = "MOV32rr"; @@ -330,6 +352,8 @@ TEST(InstructionTest, Constructor) { InstructionOperand::MemoryLocation(3)}; const std::vector kImplicitOutputOperands = { InstructionOperand::Register("EFLAGS")}; + const std::vector kInstructionAnnotations = { + Annotation("cache_miss_freq", 0.875)}; const Instruction instruction( /* mnemonic = */ kMnemonic, @@ -338,7 +362,8 @@ TEST(InstructionTest, Constructor) { /* input_operands = */ kInputOperands, /* implicit_input_operands = */ kImplicitInputOperands, /* output_operands = */ kOutputOperands, - /* implicit_output_operands = */ kImplicitOutputOperands); + /* implicit_output_operands = */ kImplicitOutputOperands, + /* instruction_annotations = */ kInstructionAnnotations); EXPECT_EQ(instruction.mnemonic, kMnemonic); EXPECT_EQ(instruction.llvm_mnemonic, kLlvmMnemonic); EXPECT_EQ(instruction.prefixes, kPrefixes); @@ -346,6 +371,7 @@ TEST(InstructionTest, Constructor) { EXPECT_EQ(instruction.implicit_input_operands, kImplicitInputOperands); EXPECT_EQ(instruction.output_operands, kOutputOperands); EXPECT_EQ(instruction.implicit_output_operands, kImplicitOutputOperands); + EXPECT_EQ(instruction.instruction_annotations, kInstructionAnnotations); } TEST(InstructionTest, AsTokenList) { @@ -360,6 +386,8 @@ TEST(InstructionTest, AsTokenList) { InstructionOperand::MemoryLocation(3)}; const std::vector kImplicitOutputOperands = { InstructionOperand::Register("EFLAGS")}; + const std::vector kInstructionAnnotations = { + Annotation("cache_miss_freq", 0.875)}; const Instruction instruction( /* mnemonic = */ kMnemonic, @@ -368,7 +396,8 @@ TEST(InstructionTest, AsTokenList) { /* input_operands = */ kInputOperands, /* implicit_input_operands = */ kImplicitInputOperands, /* output_operands = */ kOutputOperands, - /* implicit_output_operands = */ kImplicitOutputOperands); + /* implicit_output_operands = */ kImplicitOutputOperands, + /* instruction_annotations = */ kInstructionAnnotations); EXPECT_THAT(instruction.AsTokenList(), ElementsAre(kPrefixes[0], kPrefixes[1], kMnemonic, @@ -387,7 +416,9 @@ TEST(InstructionTest, ToString) { /* implicit_input_operands = */ {InstructionOperand::Register("EFLAGS")}, /* output_operands = */ {InstructionOperand::Register("RAX")}, /* implicit_output_operands = */ - {InstructionOperand::Register("EFLAGS")}); + {InstructionOperand::Register("EFLAGS")}, + /* instruction_annotations = */ + {Annotation("MEM_LOAD_RETIRED:L3_MISS", 0.875)}); constexpr char kExpectedString[] = "Instruction(mnemonic='ADC', llvm_mnemonic='ADC32rr', " "prefixes=('LOCK',), " @@ -395,7 +426,9 @@ TEST(InstructionTest, ToString) { "InstructionOperand.from_register('RBX'),), " "implicit_input_operands=(InstructionOperand.from_register('EFLAGS'),), " "output_operands=(InstructionOperand.from_register('RAX'),), " - "implicit_output_operands=(InstructionOperand.from_register('EFLAGS'),))"; + "implicit_output_operands=(InstructionOperand.from_register('EFLAGS'),), " + "instruction_annotations=(Annotation(name='MEM_LOAD_RETIRED:L3_MISS', " + "value=0.875),))"; EXPECT_EQ(instruction.ToString(), kExpectedString); } @@ -492,7 +525,10 @@ TEST(BasicBlockTest, Constructor) { /* implicit_input_operands = */ {InstructionOperand::Register("EFLAGS")}, /* output_operands = */ {InstructionOperand::Register("RAX")}, /* implicit_output_operands = */ - {InstructionOperand::Register("EFLAGS")}); + {InstructionOperand::Register("EFLAGS")}, + /* instruction_annotations = */ + {Annotation("MEM_LOAD_RETIRED:L3_MISS", 0.875)}); + const BasicBlock block({instruction}); EXPECT_THAT(block.instructions, ElementsAre(instruction)); } @@ -515,7 +551,10 @@ TEST(BasicBlockTest, Equality) { /* implicit_input_operands = */ {InstructionOperand::Register("EFLAGS")}, /* output_operands = */ {InstructionOperand::Register("RAX")}, /* implicit_output_operands = */ - {InstructionOperand::Register("EFLAGS")})); + {InstructionOperand::Register("EFLAGS")}, + /* instruction_annotations = */ + {Annotation("MEM_LOAD_RETIRED:L3_MISS", 0.875)})); + EXPECT_NE(block_1, block_2); EXPECT_FALSE(block_1 == block_2); @@ -530,7 +569,10 @@ TEST(BasicBlockTest, Equality) { /* implicit_input_operands = */ {InstructionOperand::Register("EFLAGS")}, /* output_operands = */ {InstructionOperand::Register("RAX")}, /* implicit_output_operands = */ - {InstructionOperand::Register("EFLAGS")})); + {InstructionOperand::Register("EFLAGS")}, + /* instruction_annotations = */ + {Annotation("MEM_LOAD_RETIRED:L3_MISS", 0.875)})); + EXPECT_EQ(block_1, block_2); EXPECT_FALSE(block_1 != block_2); @@ -551,7 +593,10 @@ TEST(BasicBlockTest, ToString) { /* implicit_input_operands = */ {InstructionOperand::Register("EFLAGS")}, /* output_operands = */ {InstructionOperand::Register("RAX")}, /* implicit_output_operands = */ - {InstructionOperand::Register("EFLAGS")}); + {InstructionOperand::Register("EFLAGS")}, + /* instruction_annotations = */ + {Annotation("MEM_LOAD_RETIRED:L3_MISS", 0.875)}); + BasicBlock block({instruction}); constexpr char kExpectedString[] = "BasicBlock(instructions=InstructionList((Instruction(mnemonic='ADC', " @@ -560,7 +605,9 @@ TEST(BasicBlockTest, ToString) { "InstructionOperand.from_register('RBX'),), " "implicit_input_operands=(InstructionOperand.from_register('EFLAGS'),), " "output_operands=(InstructionOperand.from_register('RAX'),), " - "implicit_output_operands=(InstructionOperand.from_register('EFLAGS'),))," + "implicit_output_operands=(InstructionOperand.from_register('EFLAGS'),), " + "instruction_annotations=(Annotation(name='MEM_LOAD_RETIRED:L3_MISS', " + "value=0.875),))," ")))"; EXPECT_EQ(block.ToString(), kExpectedString); } diff --git a/gematria/basic_block/python/BUILD b/gematria/basic_block/python/BUILD index df6e9c90..832c7288 100644 --- a/gematria/basic_block/python/BUILD +++ b/gematria/basic_block/python/BUILD @@ -46,6 +46,7 @@ gematria_py_test( deps = [ ":basic_block", ":basic_block_protos", + "//gematria/proto:annotation_py_pb2", "//gematria/proto:basic_block_py_pb2", "//gematria/proto:canonicalized_instruction_py_pb2", ], diff --git a/gematria/basic_block/python/basic_block.cc b/gematria/basic_block/python/basic_block.cc index e446b350..1596213b 100644 --- a/gematria/basic_block/python/basic_block.cc +++ b/gematria/basic_block/python/basic_block.cc @@ -30,6 +30,7 @@ PYBIND11_MAKE_OPAQUE(std::vector); PYBIND11_MAKE_OPAQUE(std::vector); +PYBIND11_MAKE_OPAQUE(std::vector); PYBIND11_MAKE_OPAQUE(std::vector); namespace gematria { @@ -185,24 +186,44 @@ PYBIND11_MODULE(basic_block, m) { py::bind_vector>(m, "InstructionOperandList"); - py::class_(m, "Instruction") + py::class_ annotation(m, "Annotation"); + annotation + .def(py::init(), + py::arg("name") = std::string(), py::arg("value")) + .def("__repr__", &Annotation::ToString) + .def("__eq__", &Annotation::operator==) + .def("__copy__", + [](const Annotation& annotation) { return Annotation(annotation); }) .def( - py::init< - std::string /* mnemonic */, std::string /* llvm_mnemonic */, - std::vector /* prefixes */, - std::vector /* input_operands */, - std::vector /* implicit_input_operands */, - std::vector /* output_oeprands */, - std::vector /* implicit_output_operands */>(), - py::arg("mnemonic") = std::string(), - py::arg("llvm_mnemonic") = std::string(), - py::arg("prefixes") = std::vector(), - py::arg("input_operands") = std::vector(), - py::arg("implicit_input_operands") = - std::vector(), - py::arg("output_operands") = std::vector(), - py::arg("implicit_output_operands") = - std::vector()) + "__deepcopy__", + [](const Annotation& annotation, py::dict) { + return Annotation(annotation); + }, + py::arg("memo")) + .def_readonly("name", &Annotation::name) + .def_readonly("value", &Annotation::value); + + py::bind_vector>(m, "AnnotationList"); + + py::class_(m, "Instruction") + .def(py::init< + std::string /* mnemonic */, std::string /* llvm_mnemonic */, + std::vector /* prefixes */, + std::vector /* input_operands */, + std::vector /* implicit_input_operands */, + std::vector /* output_operands */, + std::vector /* implicit_output_operands */, + std::vector /* instruction_annotations */>(), + py::arg("mnemonic") = std::string(), + py::arg("llvm_mnemonic") = std::string(), + py::arg("prefixes") = std::vector(), + py::arg("input_operands") = std::vector(), + py::arg("implicit_input_operands") = + std::vector(), + py::arg("output_operands") = std::vector(), + py::arg("implicit_output_operands") = + std::vector(), + py::arg("instruction_annotations") = std::vector()) .def("__str__", &Instruction::ToString) .def("__repr__", &Instruction::ToString) .def("__eq__", &Instruction::operator==) @@ -225,7 +246,9 @@ PYBIND11_MODULE(basic_block, m) { &Instruction::implicit_input_operands) .def_readwrite("output_operands", &Instruction::output_operands) .def_readwrite("implicit_output_operands", - &Instruction::implicit_output_operands); + &Instruction::implicit_output_operands) + .def_readwrite("instruction_annotations", + &Instruction::instruction_annotations); py::bind_vector>(m, "InstructionList"); diff --git a/gematria/basic_block/python/basic_block_protos.cc b/gematria/basic_block/python/basic_block_protos.cc index 0e2922c0..bc9cbc2b 100644 --- a/gematria/basic_block/python/basic_block_protos.cc +++ b/gematria/basic_block/python/basic_block_protos.cc @@ -34,6 +34,7 @@ PYBIND11_MODULE(basic_block_protos, m) { m.def("instruction_operand_from_proto", InstructionOperandFromProto, py::arg("proto")); m.def("address_tuple_from_proto", AddressTupleFromProto, py::arg("proto")); + m.def("annotation_from_proto", AnnotationFromProto, py::arg("proto")); } } // namespace gematria diff --git a/gematria/basic_block/python/basic_block_protos_test.py b/gematria/basic_block/python/basic_block_protos_test.py index f6115942..dc2b28c4 100644 --- a/gematria/basic_block/python/basic_block_protos_test.py +++ b/gematria/basic_block/python/basic_block_protos_test.py @@ -17,6 +17,7 @@ from absl.testing import absltest from gematria.basic_block.python import basic_block from gematria.basic_block.python import basic_block_protos +from gematria.proto import annotation_pb2 from gematria.proto import basic_block_pb2 from gematria.proto import canonicalized_instruction_pb2 @@ -26,6 +27,7 @@ _CanonicalizedInstructionProto = ( canonicalized_instruction_pb2.CanonicalizedInstructionProto ) +_AnnotationProto = annotation_pb2.AnnotationProto class AddressTupleTest(absltest.TestCase): @@ -116,6 +118,18 @@ def test_memory(self): self.assertIsNone(operand.address) +class AnnotationFromProtoTest(absltest.TestCase): + + def test_annotation_from_proto(self): + proto = _AnnotationProto( + name='cache_miss_freq', + value=0.875, + ) + annotation = basic_block_protos.annotation_from_proto(proto) + self.assertEqual(annotation.name, 'cache_miss_freq') + self.assertEqual(annotation.value, 0.875) + + class InstructionFromProtoTest(absltest.TestCase): def test_instruction_from_proto(self): @@ -134,6 +148,9 @@ def test_instruction_from_proto(self): implicit_input_operands=( _CanonicalizedOperandProto(register_name='EFLAGS'), ), + instruction_annotations=( + _AnnotationProto(name='cache_miss_freq', value=0.875), + ), ) instruction = basic_block_protos.instruction_from_proto(proto) self.assertEqual(instruction.mnemonic, 'ADC') @@ -158,6 +175,10 @@ def test_instruction_from_proto(self): instruction.implicit_output_operands, (basic_block.InstructionOperand.from_register('EFLAGS'),), ) + self.assertSequenceEqual( + instruction.instruction_annotations, + (basic_block.Annotation('cache_miss_freq', 0.875),), + ) class BasicBlockFromProtoTest(absltest.TestCase): @@ -174,6 +195,12 @@ def test_initialize_from_proto(self): output_operands=( _CanonicalizedOperandProto(register_name='RCX'), ), + instruction_annotations=( + _AnnotationProto( + name='cache_miss_freq', + value=0.875, + ), + ), ), _CanonicalizedInstructionProto( mnemonic='MOVSB', @@ -212,6 +239,9 @@ def test_initialize_from_proto(self): output_operands=basic_block.InstructionOperandList(( basic_block.InstructionOperand.from_register('RCX'), )), + instruction_annotations=basic_block.AnnotationList(( + basic_block.Annotation('cache_miss_freq', 0.875), + )), ), basic_block.Instruction( mnemonic='MOVSB', diff --git a/gematria/basic_block/python/basic_block_test.py b/gematria/basic_block/python/basic_block_test.py index 2f11a527..1628d834 100644 --- a/gematria/basic_block/python/basic_block_test.py +++ b/gematria/basic_block/python/basic_block_test.py @@ -201,6 +201,9 @@ def test_initialize_with_keyword_args(self): implicit_output_operands=basic_block.InstructionOperandList(( basic_block.InstructionOperand.from_register('EFLAGS'), )), + instruction_annotations=basic_block.AnnotationList(( + basic_block.Annotation('cache_miss_freq', 0.875), + )), ) self.assertEqual(instruction.mnemonic, 'ADC') self.assertEqual(instruction.llvm_mnemonic, 'ADC32rr') @@ -224,6 +227,10 @@ def test_initialize_with_keyword_args(self): instruction.implicit_output_operands, (basic_block.InstructionOperand.from_register('EFLAGS'),), ) + self.assertSequenceEqual( + instruction.instruction_annotations, + (basic_block.Annotation('cache_miss_freq', 0.875),), + ) instruction = basic_block.Instruction( mnemonic='NOP', prefixes=basic_block.StringList(('LOCK',)) diff --git a/gematria/datasets/BUILD b/gematria/datasets/BUILD index 7ccddfeb..aa3e1a2a 100644 --- a/gematria/datasets/BUILD +++ b/gematria/datasets/BUILD @@ -1,3 +1,8 @@ +load( + "//:python.bzl", + "gematria_py_test", +) + package( default_visibility = ["//visibility:private"], ) @@ -69,6 +74,20 @@ cc_binary( ], ) +gematria_py_test( + name = "convert_bhive_to_llvm_exegesis_input_test", + size = "small", + srcs = ["convert_bhive_to_llvm_exegesis_input_test.py"], + data = [ + "tests/lit.cfg.py", + "tests/lit.site.cfg.py", + ":convert_bhive_to_llvm_exegesis_input", + "@llvm-project//llvm:FileCheck", + "@llvm-project//llvm:not", + "@llvm-project//llvm:split-file", + ] + glob(["tests/*.test"]), +) + cc_library( name = "find_accessed_addrs", srcs = ["find_accessed_addrs.cc"], diff --git a/gematria/datasets/convert_bhive_to_llvm_exegesis_input.cc b/gematria/datasets/convert_bhive_to_llvm_exegesis_input.cc index 1c90584c..19c84f56 100644 --- a/gematria/datasets/convert_bhive_to_llvm_exegesis_input.cc +++ b/gematria/datasets/convert_bhive_to_llvm_exegesis_input.cc @@ -49,10 +49,12 @@ constexpr std::string_view kMemDefPrefix = "# LLVM-EXEGESIS-MEM-DEF "; constexpr std::string_view kMemMapPrefix = "# LLVM-EXEGESIS-MEM-MAP "; constexpr std::string_view kMemNamePrefix = "MEM"; -enum class AnnotatorType { kExegesis, kFast }; +enum class AnnotatorType { kExegesis, kFast, kNone }; constexpr std::pair kAnnotatorTypeNames[] = { - {AnnotatorType::kExegesis, "exegesis"}, {AnnotatorType::kFast, "fast"}}; + {AnnotatorType::kExegesis, "exegesis"}, + {AnnotatorType::kFast, "fast"}, + {AnnotatorType::kNone, "none"}}; bool AbslParseFlag(absl::string_view text, AnnotatorType* type, std::string* error) { @@ -104,6 +106,8 @@ absl::StatusOr GetAccessedAddrs( return gematria::LlvmExpectedToStatusOr( exegesis_annotator->findAccessedAddrs( llvm::ArrayRef(basic_block.begin(), basic_block.end()))); + case AnnotatorType::kNone: + return gematria::AccessedAddrs(); } return absl::InvalidArgumentError("unknown annotator type"); } @@ -244,12 +248,11 @@ int main(int argc, char* argv[]) { // Check for errors. if (!proto.ok()) { - std::cerr << "Failed to disassemble block '" << hex << ": " - << proto.status() << "\n"; + std::cerr << "Failed to disassemble block '" << hex + << "': " << proto.status() << "\n"; continue; } - // This will only get the first segfault address. auto addrs = GetAccessedAddrs(*bytes, exegesis_annotator.get()); if (!addrs.ok()) { @@ -337,13 +340,13 @@ int main(int argc, char* argv[]) { } } - if (file_counter % report_progress_every == 0) + if (file_counter != 0 && file_counter % report_progress_every == 0) std::cerr << "Finished annotating block #" << file_counter << ".\n"; file_counter++; } - if (!json_output_dir.empty()) { + if (!json_output_dir.empty() && processed_snippets.size() != 0) { size_t json_file_number = file_counter / blocks_per_json_file; bool write_successfully = WriteJsonFile(std::move(processed_snippets), json_file_number, json_output_dir); diff --git a/gematria/datasets/convert_bhive_to_llvm_exegesis_input_test.py b/gematria/datasets/convert_bhive_to_llvm_exegesis_input_test.py new file mode 100755 index 00000000..466c9677 --- /dev/null +++ b/gematria/datasets/convert_bhive_to_llvm_exegesis_input_test.py @@ -0,0 +1,11 @@ +from lit.main import main +import sys + +# Lit expects the test folder path to be specifided on the command-line, which +# is usually passed in through CMake. Bazel doesn't support this configuration, +# so we manually add the path here. +sys.argv.append("./gematria/datasets/tests") +sys.argv.append("-vv") + +if __name__ == "__main__": + main() diff --git a/gematria/datasets/tests/blocks_per_json_file.test b/gematria/datasets/tests/blocks_per_json_file.test new file mode 100644 index 00000000..d6b56844 --- /dev/null +++ b/gematria/datasets/tests/blocks_per_json_file.test @@ -0,0 +1,56 @@ +; Test that splitting a dataset among multiple JSON files works as expected. + +; RUN: split-file %s %t +; RUN: mkdir %t.jsondir +; RUN: %convert_bhive_to_llvm_exegesis_input --json_output_dir=%t.jsondir --bhive_csv=%t/test.csv --blocks_per_json_file=1 +; RUN: cat %t.jsondir/0.json | FileCheck --check-prefix FILE1 %s +; RUN: cat %t.jsondir/1.json | FileCheck --check-prefix FILE2 %s + +; Ensure that we don't have any "leftover" files. +; RUN: ls %t.jsondir | FileCheck --check-prefix DIR %s + +; FILE1: [ +; FILE1: { +; FILE1: "Hex": "85c044897c2460", +; FILE1: "MemoryDefinitions": [ +; FILE1: { +; FILE1: "Name": "MEM", +; FILE1: "Size": 4096, +; FILE1: "Value": 305419776 +; FILE1: } +; FILE1: ], +; FILE1: "MemoryMappings": [ +; FILE1: { +; FILE1: "Address": 65536, +; FILE1: "Value": "MEM" +; FILE1: } +; FILE1: ] +; FILE1: } +; FILE1: ] + +; FILE2: [ +; FILE2: { +; FILE2: "Hex": "3b31", +; FILE2: "MemoryDefinitions": [ +; FILE2: { +; FILE2: "Name": "MEM", +; FILE2: "Size": 4096, +; FILE2: "Value": 305419776 +; FILE2: } +; FILE2: ], +; FILE2: "MemoryMappings": [ +; FILE2: { +; FILE2: "Address": 65536, +; FILE2: "Value": "MEM" +; FILE2: } +; FILE2: ] +; FILE2: } +; FILE2: ] + +; DIR: 0.json +; DIR: 1.json +; DIR-NOT: 2.json + +;--- test.csv +85c044897c2460,98.000000 +3b31,45.000000 diff --git a/gematria/datasets/tests/conversion.test b/gematria/datasets/tests/conversion.test new file mode 100644 index 00000000..e2ae831e --- /dev/null +++ b/gematria/datasets/tests/conversion.test @@ -0,0 +1,46 @@ +; Test that converting a single basic block in a CSV with the default settings +; produces a proper llvm-exegesis snippet file. + +; RUN: split-file %s %t +; RUN: mkdir %t.asmdir +; RUN: %convert_bhive_to_llvm_exegesis_input --asm_output_dir=%t.asmdir --bhive_csv=%t/test.csv +; RUN: cat %t.asmdir/0.test | FileCheck %s + +; CHECK: # LLVM-EXEGESIS-DEFREG RAX 12345600 +; CHECK: # LLVM-EXEGESIS-DEFREG RCX 12345600 +; CHECK: # LLVM-EXEGESIS-DEFREG RDX 12345600 +; CHECK: # LLVM-EXEGESIS-DEFREG RSI 12345600 +; CHECK: # LLVM-EXEGESIS-DEFREG RDI 12345600 +; CHECK: # LLVM-EXEGESIS-DEFREG R8 12345600 +; CHECK: # LLVM-EXEGESIS-DEFREG R9 12345600 +; CHECK: # LLVM-EXEGESIS-DEFREG R10 12345600 +; CHECK: # LLVM-EXEGESIS-DEFREG R11 12345600 +; CHECK: # LLVM-EXEGESIS-DEFREG RBX 12345600 +; CHECK: # LLVM-EXEGESIS-DEFREG R14 12345600 +; CHECK: # LLVM-EXEGESIS-DEFREG R15 12345600 +; CHECK: # LLVM-EXEGESIS-DEFREG R12 12345600 +; CHECK: # LLVM-EXEGESIS-DEFREG R13 12345600 +; CHECK: # LLVM-EXEGESIS-DEFREG RBP 12345600 +; CHECK: # LLVM-EXEGESIS-DEFREG RSP 12345600 +; CHECK: # LLVM-EXEGESIS-DEFREG XMM0 12345600 +; CHECK: # LLVM-EXEGESIS-DEFREG XMM1 12345600 +; CHECK: # LLVM-EXEGESIS-DEFREG XMM2 12345600 +; CHECK: # LLVM-EXEGESIS-DEFREG XMM3 12345600 +; CHECK: # LLVM-EXEGESIS-DEFREG XMM4 12345600 +; CHECK: # LLVM-EXEGESIS-DEFREG XMM5 12345600 +; CHECK: # LLVM-EXEGESIS-DEFREG XMM6 12345600 +; CHECK: # LLVM-EXEGESIS-DEFREG XMM7 12345600 +; CHECK: # LLVM-EXEGESIS-DEFREG XMM8 12345600 +; CHECK: # LLVM-EXEGESIS-DEFREG XMM9 12345600 +; CHECK: # LLVM-EXEGESIS-DEFREG XMM10 12345600 +; CHECK: # LLVM-EXEGESIS-DEFREG XMM11 12345600 +; CHECK: # LLVM-EXEGESIS-DEFREG XMM12 12345600 +; CHECK: # LLVM-EXEGESIS-DEFREG XMM13 12345600 +; CHECK: # LLVM-EXEGESIS-DEFREG XMM14 12345600 +; CHECK: # LLVM-EXEGESIS-DEFREG XMM15 12345600 +; CHECK: # LLVM-EXEGESIS-MEM-DEF MEM 4096 0000000012345600 +; CHECK: # LLVM-EXEGESIS-MEM-MAP MEM 65536 +; CHECK: cmpl (%rcx), %esi + +;--- test.csv +3b31,45.000000 diff --git a/gematria/datasets/tests/flag_errors.test b/gematria/datasets/tests/flag_errors.test new file mode 100644 index 00000000..60b336fe --- /dev/null +++ b/gematria/datasets/tests/flag_errors.test @@ -0,0 +1,22 @@ +; Test various flag combinations that should result in an error. + +; Test that not passing in any flags results in an error. +; RUN: %not %convert_bhive_to_llvm_exegesis_input 2>&1 | FileCheck %s --check-prefix=NO-ARGS + +; NO-ARGS: Error: --bhive_csv is required + +; Test that setting a number of blocks per JSON file less than 1 results in +; an error. +; RUN: split-file %s %t +; RUN: mkdir %t.asmdir +; RUN: %not %convert_bhive_to_llvm_exegesis_input --bhive_csv=%t/test.csv --asm_output_dir=%t.asmdir --blocks_per_json_file=0 2>&1 | FileCheck %s --check-prefix=BAD-BLOCK-COUNT + +; BAD-BLOCK-COUNT: Error: --blocks_per_json_file must be greater than 1. + +; Test that specifying an unknown annotator type results in an error. +; RUN: %not %convert_bhive_to_llvm_exegesis_input --bhive_csv=%t/test.csv --asm_output_dir=%t.asmdir --annotator_implementation=doesntexist 2>&1 | FileCheck %s --check-prefix=BAD-ANNOTATOR-TYPE + +; BAD-ANNOTATOR-TYPE: ERROR: Illegal value 'doesntexist' specified for flag 'annotator_implementation'; unknown annotator type + +;--- test.csv +3b31,45.000000 diff --git a/gematria/datasets/tests/lit.cfg.py b/gematria/datasets/tests/lit.cfg.py new file mode 100644 index 00000000..1106898f --- /dev/null +++ b/gematria/datasets/tests/lit.cfg.py @@ -0,0 +1,24 @@ +import lit.formats + +config.name = 'gematria' +config.test_format = lit.formats.ShTest(True) + +config.suffixes = ['.test'] + +config.test_source_root = os.path.dirname(__file__) +config.test_exec_root = os.path.join(config.obj_root, 'test') + +config.substitutions.append( + ('FileCheck', os.path.join(config.llvm_tools_root, 'FileCheck')) +) +config.substitutions.append( + ('split-file', os.path.join(config.llvm_tools_root, 'split-file')) +) +config.substitutions.append( + ('%not', os.path.join(config.llvm_tools_root, 'not')) +) + +config.substitutions.append(( + '%convert_bhive_to_llvm_exegesis_input', + os.path.join(config.tools_root, 'convert_bhive_to_llvm_exegesis_input'), +)) diff --git a/gematria/datasets/tests/lit.site.cfg.py b/gematria/datasets/tests/lit.site.cfg.py new file mode 100644 index 00000000..bc238a12 --- /dev/null +++ b/gematria/datasets/tests/lit.site.cfg.py @@ -0,0 +1,7 @@ +import os + +config.obj_root = os.path.join(os.getcwd(), 'gematria/datasets/tests') +config.tools_root = os.path.join(os.getcwd(), 'gematria/datasets') +config.llvm_tools_root = os.path.join(os.getcwd(), 'external/llvm-project/llvm') + +lit_config.load_config(config, os.path.join(config.obj_root, 'lit.cfg.py')) diff --git a/gematria/datasets/tests/max_bb_count.test b/gematria/datasets/tests/max_bb_count.test new file mode 100644 index 00000000..0a3eb220 --- /dev/null +++ b/gematria/datasets/tests/max_bb_count.test @@ -0,0 +1,13 @@ +; Test that we only annotate up to --max_bb_count BBs. + +; RUN: split-file %s %t +; RUN: mkdir %t.asmdir +; RUN: %convert_bhive_to_llvm_exegesis_input --asm_output_dir=%t.asmdir --bhive_csv=%t/test.csv --max_bb_count=1 +; RUN: ls %t.asmdir | FileCheck %s + +; CHECK: 0.test +; CHECK-NOT: 1.test + +;--- test.csv +3b31,45.000000 +85c044897c2460,98.000000 diff --git a/gematria/datasets/tests/report_progress.test b/gematria/datasets/tests/report_progress.test new file mode 100644 index 00000000..05d5147f --- /dev/null +++ b/gematria/datasets/tests/report_progress.test @@ -0,0 +1,22 @@ +; Test that the --report_progress_every reports progress at the expected +; intervals. + +; RUN: split-file %s %t +; RUN: mkdir %t.asmdir +; RUN: %convert_bhive_to_llvm_exegesis_input --asm_output_dir=%t.asmdir --bhive_csv=%t/test.csv --report_progress_every=2 2>&1 | FileCheck %s + +; CHECK: Finished annotating block #2. +; CHECK: Finished annotating block #4. + +; Test that --report_progress_every doesn't output anything with the default +; value. +; RUN: %convert_bhive_to_llvm_exegesis_input --asm_output_dir=%t.asmdir --bhive_csv=%t/test.csv 2>&1 | FileCheck %s --check-prefix=DEFAULT-VALUE + +; DEFAULT-VALUE-NOT: Finished annotating block + +;--- test.csv +4183ff0119c083e00885c98945c4b8010000000f4fc139c2,298.000000 +4889de4889c24c89ff,93.000000 +48895d1844886520488945004889e84883c4085b5d415c415d,335.000000 +418b4424084d8b3424498d2cc64939ee,98.000000 +85c044897c2460,98.000000 diff --git a/gematria/granite/convert_gb_token_model_to_tflite.sh b/gematria/granite/convert_gb_token_model_to_tflite.sh index 87097469..90460e54 100755 --- a/gematria/granite/convert_gb_token_model_to_tflite.sh +++ b/gematria/granite/convert_gb_token_model_to_tflite.sh @@ -37,6 +37,8 @@ function print_error_and_exit() { # Parse command-line flags. # TODO(ondrasej): Consider using getopt instead of parsing the flags manually. gematria_export_as_seq2seq=0 +gematria_export_with_deltas=0 +gematria_export_with_annotations=0 gematria_input_graphdef="" gematria_output_tflite="" while [[ "$#" -gt 0 ]]; do @@ -58,6 +60,12 @@ while [[ "$#" -gt 0 ]]; do --gematria_export_as_seq2seq) gematria_export_as_seq2seq=1 ;; + --gematria_export_with_deltas) + gematria_export_with_deltas=1 + ;; + --gematria_export_with_annotations) + gematria_export_with_annotations=1 + ;; *) print_error_and_exit "Unexpected command-line argument: $1" esac @@ -78,11 +86,11 @@ function str_join() { } # The list of inputs of the model. This must contain an entry for each -# tf.placeholder tensor used in the Python code. The order of the tensors used -# here must correspond to the input tensor indices defined in -# gematria/granite/graph_builder_model_inference.cc. +# tf.placeholder tensor used in the Python code. readonly INPUT_TENSORS_LIST=( - ModelBase.delta_block_index_tensor + $([[ "${gematria_export_as_seq2seq}" -eq 1 || \ + "${gematria_export_with_deltas}" -eq 1 ]] && \ + echo 'ModelBase.delta_block_index_tensor') GnnModelBase.node_features GnnModelBase.edge_features GnnModelBase.global_features @@ -90,7 +98,11 @@ readonly INPUT_TENSORS_LIST=( GnnModelBase.senders GnnModelBase.num_edges GnnModelBase.num_nodes - GraphBuilderModelBase.instruction_node_mask + $([[ "${gematria_export_with_annotations}" -eq 1 ]] && \ + echo 'GraphBuilderModelBase.instruction_node_mask') + $([[ "${gematria_export_with_deltas}" -eq 1 || \ + "${gematria_export_with_annotations}" -eq 1 ]] && \ + echo 'TokenGraphBuilderModel.instruction_annotations') ) INPUT_TENSORS=$(str_join "${INPUT_TENSORS_LIST[@]}") readonly INPUT_TENSORS @@ -103,20 +115,16 @@ readonly TARGET_OPS_LIST=( TARGET_OPS=$(str_join "${TARGET_OPS_LIST[@]}") readonly TARGET_OPS -if (( FLAGS_gematria_export_as_seq2seq )); then - readonly OUTPUT_TENSOR_LIST=( - ModelBase.output_tensor - ModelBase.output_tensor_deltas - TokenModel.token_list - GraphBuilderModelBase.special_tokens - ) -else - readonly OUTPUT_TENSOR_LIST=( - ModelBase.output_tensor - TokenModel.token_list - GraphBuilderModelBase.special_tokens - ) -fi +readonly OUTPUT_TENSOR_LIST=( + ModelBase.output_tensor + $([[ "${gematria_export_as_seq2seq}" -eq 1 || \ + "${gematria_export_with_deltas}" -eq 1 ]] && \ + echo 'ModelBase.output_tensor_deltas') + TokenModel.token_list + GraphBuilderModelBase.special_tokens + $([[ "${gematria_export_with_annotations}" -eq 1 ]] && \ + echo 'TokenGraphBuilderModel.annotation_names') +) OUTPUT_TENSORS=$(str_join "${OUTPUT_TENSOR_LIST[@]}") readonly OUTPUT_TENSORS diff --git a/gematria/granite/graph_builder.cc b/gematria/granite/graph_builder.cc index 47a24e7e..781a4df2 100644 --- a/gematria/granite/graph_builder.cc +++ b/gematria/granite/graph_builder.cc @@ -149,6 +149,7 @@ BasicBlockGraphBuilder::BasicBlockGraphBuilder( std::vector node_tokens, std::string_view immediate_token, std::string_view fp_immediate_token, std::string_view address_token, std::string_view memory_token, + std::set annotation_names /* = std::set() */, OutOfVocabularyTokenBehavior out_of_vocabulary_behavior /* = ReturnError() */ ) @@ -161,6 +162,7 @@ BasicBlockGraphBuilder::BasicBlockGraphBuilder( FindTokenOrDie(node_tokens_, std::string(fp_immediate_token))), address_token_(FindTokenOrDie(node_tokens_, std::string(address_token))), memory_token_(FindTokenOrDie(node_tokens_, std::string(memory_token))), + annotation_names_(std::move(annotation_names)), out_of_vocabulary_behavior_(out_of_vocabulary_behavior), replacement_token_( out_of_vocabulary_behavior.behavior_type() == @@ -168,7 +170,16 @@ BasicBlockGraphBuilder::BasicBlockGraphBuilder( ? kInvalidTokenIndex : FindTokenOrDie( node_tokens_, - out_of_vocabulary_behavior.replacement_token())) {} + out_of_vocabulary_behavior.replacement_token())) { + instruction_annotations_ = std::vector>(); + + // Store row indices corresponding to specific annotation names. + int annotation_idx = 0; + for (auto& annotation_name : annotation_names_) { + annotation_name_to_idx_[annotation_name] = annotation_idx; + ++annotation_idx; + } +} bool BasicBlockGraphBuilder::AddBasicBlockFromInstructions( const std::vector& instructions) { @@ -191,6 +202,16 @@ bool BasicBlockGraphBuilder::AddBasicBlockFromInstructions( return false; } + // Store the annotations for later use (inclusion in embeddings), using -1 + // as a default value wherever annotations are missing. + std::vector row = std::vector(annotation_names_.size(), -1); + for (const auto& [name, value] : instruction.instruction_annotations) { + const auto annotation_index = annotation_name_to_idx_.find(name); + if (annotation_index == annotation_name_to_idx_.end()) continue; + row[annotation_index->second] = value; + } + instruction_annotations_.push_back(row); + // Add nodes for prefixes of the instruction. for (const std::string& prefix : instruction.prefixes) { const NodeIndex prefix_node = AddNode(NodeType::kPrefix, prefix); @@ -253,6 +274,8 @@ void BasicBlockGraphBuilder::Reset() { edge_types_.clear(); global_features_.clear(); + + instruction_annotations_.clear(); } bool BasicBlockGraphBuilder::AddInputOperand( diff --git a/gematria/granite/graph_builder.h b/gematria/granite/graph_builder.h index cfc65d9d..07e78c6d 100644 --- a/gematria/granite/graph_builder.h +++ b/gematria/granite/graph_builder.h @@ -91,6 +91,7 @@ #include #include +#include #include #include #include @@ -159,6 +160,11 @@ class BasicBlockGraphBuilder { // address computation. // - memory_token: the token associated with nodes that represent memory // accesses. + // - annotation_names: the set of names of annotations to be used. + // Annotations with names belonging to this list will be stored and + // available for use, the rest will be discarded. All instructions need + // not have all annotations corresponding to the elements of this list, + // i.e. missing annotations will be handled. // - unknown_token_behavior and unknown_token: controls for the behavior of // the basic block graph builder when it encounters an unknown token when // adding new basic blocks to the builder. When unknown_node_behavior is @@ -171,6 +177,7 @@ class BasicBlockGraphBuilder { std::vector node_tokens, std::string_view immediate_token, std::string_view fp_immediate_token, std::string_view address_token, std::string_view memory_token, + std::set annotation_names = std::set(), OutOfVocabularyTokenBehavior out_of_vocabulary_behavior = OutOfVocabularyTokenBehavior::ReturnError()); @@ -236,6 +243,18 @@ class BasicBlockGraphBuilder { // corresponding to the nodes). Corresponds to `GraphsTuple.nodes`. const std::vector& node_features() const { return node_features_; } + // Names of types of instruction annotations stored. + const std::set& annotation_names() const { + return annotation_names_; + } + // Values of instruction level runtime annotations. Represents a + // `num_instructions` x `annotation_names.size()` matrix, each entry of which + // represents the value of the annotation of the type corresponding to the + // column for the instruction corresponding to the row. + const std::vector>& instruction_annotations() const { + return instruction_annotations_; + } + // The sender (start) nodes of the edges in the graphs. `edge_senders()[i]` is // the index of the start node of the i-th edge in the graph. Corresponds to // `GraphsTuple.senders`. @@ -374,6 +393,11 @@ class BasicBlockGraphBuilder { const TokenIndex address_token_; const TokenIndex memory_token_; + // Holds valid annotation names in sorted order. Instruction annotations with + // names belonging to this list are stored in `instruction_annotations_` and + // the rest are discarded. + const std::set annotation_names_; + const OutOfVocabularyTokenBehavior out_of_vocabulary_behavior_; const TokenIndex replacement_token_; @@ -383,6 +407,11 @@ class BasicBlockGraphBuilder { std::vector node_types_; std::vector node_features_; + // Mapping from annotation type names to corresponding row index in the + // `instruction_annotations_` matrix. + std::unordered_map annotation_name_to_idx_; + std::vector> instruction_annotations_; + std::vector edge_senders_; std::vector edge_receivers_; std::vector edge_types_; diff --git a/gematria/granite/graph_builder_model_inference.cc b/gematria/granite/graph_builder_model_inference.cc index a1d21ecc..af690d5e 100644 --- a/gematria/granite/graph_builder_model_inference.cc +++ b/gematria/granite/graph_builder_model_inference.cc @@ -15,12 +15,16 @@ #include "gematria/granite/graph_builder_model_inference.h" #include +#include #include #include #include #include +#include +#include #include #include +#include #include #include @@ -46,20 +50,37 @@ namespace { using ::tflite::FlatBufferModel; -// The indices of the input tensors in the compiled TensorFlow Lite model. This -// order of the input tensors must be preserved during the conversion of the -// model to the .tflite format. -constexpr int kDeltaBlockIndexTensor = 0; -constexpr int kGraphNodesTensor = 1; -constexpr int kGraphEdgesTensor = 2; -constexpr int kGraphGlobalsTensor = 3; -constexpr int kGraphReceiversTensor = 4; -constexpr int kGraphSendersTensor = 5; -constexpr int kGraphNEdgeTensor = 6; -constexpr int kGraphNNodeTensor = 7; -constexpr int kInstructionNodeMaskTensor = 8; - -constexpr int kNumInputTensors = 9; +// The names of the valid input tensors in the compiled TensorFlow Lite model. +// Not all tensors are present in all models. +constexpr std::string_view kDeltaBlockIndexTensorName = + "ModelBase.delta_block_index_tensor"; +constexpr std::string_view kGraphNodesTensorName = "GnnModelBase.node_features"; +constexpr std::string_view kGraphEdgesTensorName = "GnnModelBase.edge_features"; +constexpr std::string_view kGraphGlobalsTensorName = + "GnnModelBase.global_features"; +constexpr std::string_view kGraphReceiversTensorName = "GnnModelBase.receivers"; +constexpr std::string_view kGraphSendersTensorName = "GnnModelBase.senders"; +constexpr std::string_view kGraphNEdgeTensorName = "GnnModelBase.num_edges"; +constexpr std::string_view kGraphNNodeTensorName = "GnnModelBase.num_nodes"; +constexpr std::string_view kInstructionNodeMaskTensorName = + "GraphBuilderModelBase.instruction_node_mask"; +constexpr std::string_view kInstructionAnnotationsTensorName = + "TokenGraphBuilderModel.instruction_annotations"; + +// The model must have at least 7 input tensors - it may not include some +// tensors such as the delta block index tensor and the instruction annotations +// tensor depending on the model configuration. +constexpr int kNumRequiredInputTensors = 7; + +// The list of names of all input tensors that every graph builder model must +// have to be valid. Other tensors may be present based on the configuration. +constexpr std::array + kRequiredInputTensorNames{ + kGraphNodesTensorName, kGraphEdgesTensorName, + kGraphGlobalsTensorName, kGraphReceiversTensorName, + kGraphSendersTensorName, kGraphNEdgeTensorName, + kGraphNNodeTensorName, + }; // The indices of special node token indices in the tensor // `GraphBuilderModelBase.special_tokens`. For example the token used for @@ -83,16 +104,19 @@ constexpr int kNumSpecialNodeTokens = 5; constexpr std::string_view kNodeTokensTensorName = "TokenModel.token_list"; constexpr std::string_view kSpecialTokensTensorName = "GraphBuilderModelBase.special_tokens"; +constexpr std::string_view kAnnotationNamesTensorName = + "TokenGraphBuilderModel.annotation_names"; // Checks that: // 1. `tensor` != nullptr, // 2. `tensor` has type `tensor_type`. // 3. `tensor` has the number of dimensions corresponding to the number of -// elements of `sizes`, and the sizes in those dimensions are equal to `sizes`. -// Returns `llvm::Error::success()` when all checks pass, an error otherwise. +// elements of `sizes`, and the sizes in those dimensions are equal to +// `sizes`. Returns `llvm::Error::success()` when all checks pass, an error +// otherwise. // -// TODO(ondrasej): See if we can replace this function and the one below with -// TFModelEvaluatorImpl::checkReportAndInvalidate. +// TODO(ondrasej): See if we can replace this function and the one below +// with TFModelEvaluatorImpl::checkReportAndInvalidate. template llvm::Error CheckTensorTypeAndDimensions(int tensor_index, const TfLiteTensor* tensor, @@ -335,6 +359,63 @@ llvm::Expected GetNodeTokenAtIndex( return node_token_list[token_index]; } +// Extracts the set of annotation names from the model. This should be a Const +// tensor, and as such, it should be readable without providing any inputs. +// Returns an error when the annotation names tensor is not found or it is not +// readable. +llvm::Expected> GetAnnotationNames( + const tflite::Interpreter& interpreter) { + llvm::Expected annotation_names_tensor_index = TensorIndexByName( + interpreter, interpreter.outputs(), kAnnotationNamesTensorName); + if (llvm::Error error = annotation_names_tensor_index.takeError()) { + return error; + } + const TfLiteTensor* const annotation_names_tensor = + interpreter.tensor(*annotation_names_tensor_index); + assert(annotation_names_tensor != nullptr); + + const size_t annotation_names_size_bytes = annotation_names_tensor->bytes; + // The token list tensor is a Const operation, so it should be readable before + // running the inference or providing any inputs. + const char* const annotation_names_raw_data = reinterpret_cast( + interpreter.typed_tensor(*annotation_names_tensor_index)); + if (annotation_names_raw_data == nullptr) { + return llvm::createStringError(llvm::errc::invalid_argument, + "The annotation names could not be read"); + } + const std::string_view annotation_names_data(annotation_names_raw_data, + annotation_names_size_bytes); + std::vector annotation_names = + StrSplitAsCopy(annotation_names_data, '\0'); + return std::set( + std::make_move_iterator(annotation_names.begin()), + std::make_move_iterator(annotation_names.end())); +} + +// Checks whether the model has all input tensors required to have for all +// graph builder models. +llvm::Error CheckHasRequiredInputTensors( + const std::unordered_map input_name_to_idx) { + // Check if all required input tensors are present. + std::vector missing_inputs; + for (const std::string_view input_name : kRequiredInputTensorNames) { + if (!input_name_to_idx.count(input_name)) + missing_inputs.push_back(input_name); + } + if (!missing_inputs.empty()) { + std::stringstream buffer; + buffer << "Model is missing input tensors. "; + for (const std::string_view& missing_input : missing_inputs) { + buffer << missing_input << ", "; + } + buffer.seekp(-2, std::ios_base::end); + buffer << " were expected but not found"; + return llvm::createStringError(llvm::errc::invalid_argument, buffer.str()); + } + + return llvm::Error::success(); +} + } // namespace llvm::Expected> @@ -346,7 +427,51 @@ GraphBuilderModelInference::FromTfLiteModel( } llvm::Expected> interpreter = CreateInterpreter(*tflite_model); - if (auto error = interpreter.takeError()) return error; + if (llvm::Error error = interpreter.takeError()) return error; + + // Get a mapping between the names of the input tensors used in the model and + // their corresponding indices. Does not take ownership of the name strings, + // so the interpreter must stay alive for as long as `input_name_to_idx`. + std::unique_ptr> input_name_to_idx = + std::make_unique>(); + for (int idx : (*interpreter)->inputs()) { + input_name_to_idx->emplace((*interpreter)->GetInputName(idx), idx); + } + if (llvm::Error error = CheckHasRequiredInputTensors(*input_name_to_idx)) { + return error; + } + + // Ensures no unexpected input tensors are present. + const int num_input_tensors = (*interpreter)->inputs().size(); + const bool uses_deltas = input_name_to_idx->count(kDeltaBlockIndexTensorName); + const bool uses_annotations = + input_name_to_idx->count(kInstructionAnnotationsTensorName); + if (uses_deltas && + !input_name_to_idx->count(kInstructionNodeMaskTensorName)) { + return llvm::createStringError( + llvm::errc::invalid_argument, + "Missing input tensor. Models having " + + llvm::Twine(kDeltaBlockIndexTensorName) + " must also have " + + llvm::Twine(kInstructionNodeMaskTensorName) + "."); + } + if (uses_annotations && + !input_name_to_idx->count(kInstructionNodeMaskTensorName)) { + return llvm::createStringError( + llvm::errc::invalid_argument, + "Missing input tensor. Models having " + + llvm::Twine(kInstructionAnnotationsTensorName) + + " must also have " + llvm::Twine(kInstructionNodeMaskTensorName) + + "."); + } + const int num_expected_input_tensors = + kNumRequiredInputTensors + int(uses_deltas) + int(uses_annotations) + + int(uses_deltas || uses_annotations); + if (num_input_tensors != num_expected_input_tensors) { + return llvm::createStringError( + llvm::errc::invalid_argument, + "Unexpected number of input tensors. Expected %d, found %d.", + num_expected_input_tensors, num_input_tensors); + } // Get the list of node tokens used in the model. llvm::Expected> node_token_list = @@ -373,6 +498,16 @@ GraphBuilderModelInference::FromTfLiteModel( llvm::errc::invalid_argument, "The special token index tensor could not be read"); } + + // Get the set of annotation names used by the model, if any. + std::set annotation_names; + if (uses_annotations) { + llvm::Expected> expected_annotation_names = + GetAnnotationNames(**interpreter); + if (llvm::Error error = expected_annotation_names.takeError()) return error; + annotation_names = std::move(*expected_annotation_names); + } + // We'll be std::move()-ing the node list vector in the same function call // where we use the token names. To be safe from any move effects, we make a // copy of all the tokens instead of taking a const reference. @@ -412,19 +547,33 @@ GraphBuilderModelInference::FromTfLiteModel( *std::move(node_token_list), /* immediate_token = */ *immediate_token, /* fp_immediate_token = */ *fp_immediate_token, /* address_token = */ *address_token, /* memory_token = */ *memory_token, + /* annotation_names = */ std::move(annotation_names), /* out_of_vocabulary_behavior = */ out_of_vocabulary_behavior); // We can't use std::make_unique(), because // std::make_unique<>() requires a public constructor. return std::unique_ptr( - new GraphBuilderModelInference(std::move(graph_builder), tflite_model)); + new GraphBuilderModelInference( + std::move(graph_builder), tflite_model, std::move(*interpreter), + std::move(input_name_to_idx), uses_deltas, uses_annotations)); } GraphBuilderModelInference::GraphBuilderModelInference( std::unique_ptr graph_builder, - const FlatBufferModel* tflite_model) - : graph_builder_(std::move(graph_builder)), tflite_model_(*tflite_model) { + const FlatBufferModel* tflite_model, + std::unique_ptr interpreter, + std::unique_ptr> + input_name_to_idx, + bool uses_deltas, bool uses_annotations) + : graph_builder_(std::move(graph_builder)), + tflite_model_(*tflite_model), + interpreter_(std::move(interpreter)), + input_name_to_idx_(std::move(input_name_to_idx)), + uses_deltas_(uses_deltas), + uses_annotations_(uses_annotations) { assert(tflite_model != nullptr); + assert(interpreter_ != nullptr); + assert(input_name_to_idx_ != nullptr); assert(graph_builder_ != nullptr); } @@ -443,99 +592,108 @@ GraphBuilderModelInference::RunInference() { return std::vector(); } - // TODO(ondrasej): Reuse the interpreter across RunInference() calls. The - // graph builder class is already stateful, so this should not be an issue - // and it could save us some loading time. - llvm::Expected> interpreter = - CreateInterpreter(tflite_model_); - if (llvm::Error error = interpreter.takeError()) return error; - - // TODO(ondrasej): Move all the checks of the model format to the - // initialization of the class. - if ((*interpreter)->inputs().size() != kNumInputTensors) { - return llvm::createStringError( - llvm::errc::invalid_argument, - "Unexpected number of input tensors. Expected %d, found %d.", - kNumInputTensors, (*interpreter)->inputs().size()); - } - const std::vector instruction_node_mask = graph_builder_->InstructionNodeMask(); + const std::vector> instruction_annotations = + graph_builder_->instruction_annotations(); const std::vector delta_block_index = graph_builder_->DeltaBlockIndex(); // Resize the input tensors according to the size of the input data. - // TODO(ondrasej): Replace the index-based lookups with name-based lookups. - - GEMATRIA_RETURN_IF_ERROR( - Resize1DTensor(interpreter->get(), kDeltaBlockIndexTensor, - static_cast(delta_block_index.size()))); - GEMATRIA_RETURN_IF_ERROR(Resize1DTensor(interpreter->get(), kGraphNodesTensor, - graph_builder_->num_nodes())); - GEMATRIA_RETURN_IF_ERROR(Resize1DTensor(interpreter->get(), kGraphEdgesTensor, - graph_builder_->num_edges())); GEMATRIA_RETURN_IF_ERROR(Resize1DTensor( - interpreter->get(), kGraphReceiversTensor, + interpreter_.get(), input_name_to_idx_->at(kGraphNodesTensorName), + graph_builder_->num_nodes())); + GEMATRIA_RETURN_IF_ERROR(Resize1DTensor( + interpreter_.get(), input_name_to_idx_->at(kGraphEdgesTensorName), + graph_builder_->num_edges())); + GEMATRIA_RETURN_IF_ERROR(Resize1DTensor( + interpreter_.get(), input_name_to_idx_->at(kGraphReceiversTensorName), static_cast(graph_builder_->edge_receivers().size()))); - GEMATRIA_RETURN_IF_ERROR( - Resize1DTensor(interpreter->get(), kGraphSendersTensor, - static_cast(graph_builder_->edge_senders().size()))); GEMATRIA_RETURN_IF_ERROR(Resize1DTensor( - interpreter->get(), kGraphNEdgeTensor, + interpreter_.get(), input_name_to_idx_->at(kGraphSendersTensorName), + static_cast(graph_builder_->edge_senders().size()))); + GEMATRIA_RETURN_IF_ERROR(Resize1DTensor( + interpreter_.get(), input_name_to_idx_->at(kGraphNEdgeTensorName), static_cast(graph_builder_->num_nodes_per_block().size()))); GEMATRIA_RETURN_IF_ERROR(Resize1DTensor( - interpreter->get(), kGraphNNodeTensor, + interpreter_.get(), input_name_to_idx_->at(kGraphNNodeTensorName), static_cast(graph_builder_->num_edges_per_block().size()))); - GEMATRIA_RETURN_IF_ERROR( - Resize1DTensor(interpreter->get(), kInstructionNodeMaskTensor, - static_cast(instruction_node_mask.size()))); GEMATRIA_RETURN_IF_ERROR(Resize2DTensor( - interpreter->get(), kGraphGlobalsTensor, + interpreter_.get(), input_name_to_idx_->at(kGraphGlobalsTensorName), /* desired_first_dimension_size = */ graph_builder_->num_graphs(), /* expected_second_dimension_size = */ graph_builder_->num_node_tokens())); - - if (const TfLiteStatus status = (*interpreter)->AllocateTensors(); + if (uses_deltas_) { + GEMATRIA_RETURN_IF_ERROR(Resize1DTensor( + interpreter_.get(), input_name_to_idx_->at(kDeltaBlockIndexTensorName), + static_cast(delta_block_index.size()))); + } + if (uses_annotations_) { + GEMATRIA_RETURN_IF_ERROR(Resize2DTensor( + interpreter_.get(), + input_name_to_idx_->at(kInstructionAnnotationsTensorName), + /* desired_first_dimension_size = */ + static_cast(instruction_annotations.size()), + /* expected_second_dimension_size = */ + static_cast(graph_builder_->annotation_names().size()))); + } + if (uses_deltas_ || uses_annotations_) { + GEMATRIA_RETURN_IF_ERROR( + Resize1DTensor(interpreter_.get(), + input_name_to_idx_->at(kInstructionNodeMaskTensorName), + static_cast(instruction_node_mask.size()))); + } + + if (const TfLiteStatus status = interpreter_->AllocateTensors(); status != kTfLiteOk) { return llvm::make_error( "Could not allocate memory for tensors", llvm::errc::not_enough_memory); } // Fill in the input tensors. - if (llvm::Error error = FillTensorFromStdVector( - interpreter->get(), delta_block_index, kDeltaBlockIndexTensor)) { - return error; - } GEMATRIA_RETURN_IF_ERROR(FillTensorFromStdVector( - interpreter->get(), graph_builder_->node_features(), kGraphNodesTensor)); + interpreter_.get(), graph_builder_->node_features(), + input_name_to_idx_->at(kGraphNodesTensorName))); GEMATRIA_RETURN_IF_ERROR(FillTensorFromStdVector( - interpreter->get(), graph_builder_->EdgeFeatures(), kGraphEdgesTensor)); + interpreter_.get(), graph_builder_->EdgeFeatures(), + input_name_to_idx_->at(kGraphEdgesTensorName))); GEMATRIA_RETURN_IF_ERROR(FillTensorFromStdVector( - interpreter->get(), graph_builder_->edge_receivers(), - kGraphReceiversTensor)); + interpreter_.get(), graph_builder_->edge_receivers(), + input_name_to_idx_->at(kGraphReceiversTensorName))); GEMATRIA_RETURN_IF_ERROR(FillTensorFromStdVector( - interpreter->get(), graph_builder_->edge_senders(), kGraphSendersTensor)); + interpreter_.get(), graph_builder_->edge_senders(), + input_name_to_idx_->at(kGraphSendersTensorName))); GEMATRIA_RETURN_IF_ERROR(FillTensorFromStdVector( - interpreter->get(), graph_builder_->num_nodes_per_block(), - kGraphNNodeTensor)); + interpreter_.get(), graph_builder_->num_nodes_per_block(), + input_name_to_idx_->at(kGraphNNodeTensorName))); GEMATRIA_RETURN_IF_ERROR(FillTensorFromStdVector( - interpreter->get(), graph_builder_->num_edges_per_block(), - kGraphNEdgeTensor)); - GEMATRIA_RETURN_IF_ERROR(FillTensorFromStdVector( - interpreter->get(), instruction_node_mask, kInstructionNodeMaskTensor)); - if (auto error = FillTensorFromStdVectorMatrix( - interpreter->get(), graph_builder_->global_features(), - kGraphGlobalsTensor)) { - return error; - } - - if (const TfLiteStatus status = (*interpreter)->Invoke(); - status != kTfLiteOk) { + interpreter_.get(), graph_builder_->num_edges_per_block(), + input_name_to_idx_->at(kGraphNEdgeTensorName))); + GEMATRIA_RETURN_IF_ERROR(FillTensorFromStdVectorMatrix( + interpreter_.get(), graph_builder_->global_features(), + input_name_to_idx_->at(kGraphGlobalsTensorName))); + if (uses_deltas_) { + GEMATRIA_RETURN_IF_ERROR(FillTensorFromStdVector( + interpreter_.get(), delta_block_index, + input_name_to_idx_->at(kDeltaBlockIndexTensorName))); + } + if (uses_annotations_) { + GEMATRIA_RETURN_IF_ERROR(FillTensorFromStdVectorMatrix( + interpreter_.get(), instruction_annotations, + input_name_to_idx_->at(kInstructionAnnotationsTensorName))); + } + if (uses_deltas_ || uses_annotations_) { + GEMATRIA_RETURN_IF_ERROR(FillTensorFromStdVector( + interpreter_.get(), instruction_node_mask, + input_name_to_idx_->at(kInstructionNodeMaskTensorName))); + } + + if (const TfLiteStatus status = interpreter_->Invoke(); status != kTfLiteOk) { return llvm::make_error( "Invoking the TensorFlow Lite interpreter failed", llvm::errc::io_error); } - const TfLiteTensor* const output_tensor = (*interpreter)->output_tensor(0); + const TfLiteTensor* const output_tensor = interpreter_->output_tensor(0); if (output_tensor == nullptr) { return llvm::createStringError(llvm::errc::invalid_argument, "No output tensor at index 0."); @@ -554,8 +712,7 @@ GraphBuilderModelInference::RunInference() { output_tensor->dims->data[0]); } const int num_tasks = output_tensor->dims->data[1]; - auto* const output_tensor_data = - (*interpreter)->typed_output_tensor(0); + auto* const output_tensor_data = interpreter_->typed_output_tensor(0); assert(output_tensor_data != nullptr); std::vector output; diff --git a/gematria/granite/graph_builder_model_inference.h b/gematria/granite/graph_builder_model_inference.h index d5bfe879..a5a41b14 100644 --- a/gematria/granite/graph_builder_model_inference.h +++ b/gematria/granite/graph_builder_model_inference.h @@ -22,6 +22,7 @@ #include "gematria/granite/graph_builder.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Error.h" +#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/model_builder.h" namespace gematria { @@ -87,10 +88,23 @@ class GraphBuilderModelInference { // structure of a model based on the BasicBlockGraphBuilder class. GraphBuilderModelInference( std::unique_ptr graph_builder, - const tflite::FlatBufferModel* tflite_model); + const tflite::FlatBufferModel* tflite_model, + std::unique_ptr interpreter, + std::unique_ptr> + input_name_to_idx, + bool uses_deltas = true, bool uses_annotations = false); std::unique_ptr graph_builder_; const tflite::FlatBufferModel& tflite_model_; + std::unique_ptr interpreter_; + + // The mapping between input tensor names and their indices in the tflite + // model. Allows name-based lookups of input tensors. + std::unique_ptr> input_name_to_idx_; + + // Encodes the configuration of input tensors present in the tflite model. + const bool uses_deltas_; + const bool uses_annotations_; }; } // namespace gematria diff --git a/gematria/granite/graph_builder_test.cc b/gematria/granite/graph_builder_test.cc index bb838cd2..5e8283e0 100644 --- a/gematria/granite/graph_builder_test.cc +++ b/gematria/granite/graph_builder_test.cc @@ -35,6 +35,7 @@ namespace { using ::testing::ElementsAre; using ::testing::IsEmpty; +using ::testing::Pair; // Tokens used in the basic blocks in tests. For simplicity, we do not use the // full set of x86-64 tokens. @@ -51,6 +52,9 @@ constexpr absl::string_view kTokens[] = { // 10 "RBX", "RCX", "RDI", kUnknownToken, "NOP", "LOCK"}; +// Names of Instruction annotations used in tests. +const std::set kAnnotationNames{"cache_miss_freq"}; + int TokenIndex(absl::string_view token) { const auto it = std::find(std::begin(kTokens), std::end(kTokens), token); EXPECT_NE(it, std::end(kTokens)) << "Invalid token: " << token; @@ -66,7 +70,8 @@ class BasicBlockGraphBuilderTest : public testing::Test { /*immediate_token =*/kImmediateToken, /*fp_immediate_token =*/kFpImmediateToken, /*address_token =*/kAddressToken, - /*memory_token =*/kMemoryToken, out_of_vocabulary_behavior); + /*memory_token =*/kMemoryToken, + /*annotation_names=*/kAnnotationNames, out_of_vocabulary_behavior); } std::unique_ptr builder_; }; @@ -160,6 +165,45 @@ TEST_F(BasicBlockGraphBuilderTest, SingleInstructionWithPrefix) { ElementsAre(ElementsAre(0, 0, 1, 2, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1))); } +TEST_F(BasicBlockGraphBuilderTest, SingleInstructionWithAnnotation) { + CreateBuilder(OutOfVocabularyTokenBehavior::ReturnError()); + ASSERT_TRUE(builder_->AddBasicBlock(BasicBlockFromProto(ParseTextProto(R"pb( + canonicalized_instructions: { + mnemonic: "MOV" + llvm_mnemonic: "MOV64rr" + output_operands: { register_name: "RCX" } + input_operands: { register_name: "RAX" } + instruction_annotations: { name: "cache_miss_freq" value: 0.875 } + })pb")))); + EXPECT_EQ(builder_->num_graphs(), 1); + EXPECT_EQ(builder_->num_nodes(), 3); + EXPECT_EQ(builder_->num_edges(), 2); + EXPECT_EQ(builder_->num_node_tokens(), std::size(kTokens)); + EXPECT_THAT(builder_->num_nodes_per_block(), ElementsAre(3)); + EXPECT_THAT(builder_->num_edges_per_block(), ElementsAre(2)); + + EXPECT_THAT(builder_->node_types(), + ElementsAre(NodeType::kInstruction, NodeType::kRegister, + NodeType::kRegister)); + EXPECT_THAT( + builder_->node_features(), + ElementsAre(TokenIndex("MOV"), TokenIndex("RAX"), TokenIndex("RCX"))); + EXPECT_THAT(builder_->InstructionNodeMask(), ElementsAre(true, false, false)); + + EXPECT_THAT(builder_->edge_senders(), ElementsAre(1, 0)); + EXPECT_THAT(builder_->edge_receivers(), ElementsAre(0, 2)); + EXPECT_THAT(builder_->edge_types(), + ElementsAre(EdgeType::kInputOperands, EdgeType::kOutputOperands)); + + EXPECT_THAT( + builder_->global_features(), + ElementsAre(ElementsAre(0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0))); + + EXPECT_THAT(builder_->annotation_names(), ElementsAre("cache_miss_freq")); + EXPECT_THAT(builder_->instruction_annotations(), + ElementsAre(ElementsAre(0.875))); +} + TEST_F(BasicBlockGraphBuilderTest, InvalidMnemonic_ReturnError) { CreateBuilder(OutOfVocabularyTokenBehavior::ReturnError()); EXPECT_FALSE(builder_->AddBasicBlock(BasicBlockFromProto(ParseTextProto(R"pb( @@ -357,7 +401,8 @@ TEST_F(BasicBlockGraphBuilderTest, InvalidAddress_ReplaceToken) { } // Tests that the instruction nodes within the basic block are connected through -// their operands when they refer to the same value. +// their operands when they refer to the same value. Also ensures annotated and +// non-annotated instructions are handled correctly when mixed. TEST_F(BasicBlockGraphBuilderTest, MultipleInstructions) { CreateBuilder(OutOfVocabularyTokenBehavior::ReturnError()); ASSERT_TRUE(builder_->AddBasicBlock(BasicBlockFromProto(ParseTextProto(R"pb( @@ -367,6 +412,7 @@ TEST_F(BasicBlockGraphBuilderTest, MultipleInstructions) { output_operands: { register_name: "R14" } input_operands: { memory: { alias_group_id: 1 } } input_operands: { address: { base_register: "R15" scaling: 1 } } + instruction_annotations: { name: "cache_miss_freq" value: 0.9 } } canonicalized_instructions: { mnemonic: "MOV" @@ -382,6 +428,7 @@ TEST_F(BasicBlockGraphBuilderTest, MultipleInstructions) { llvm_mnemonic: "MOV64rr" output_operands: { register_name: "RCX" } input_operands: { register_name: "RAX" } + instruction_annotations: { name: "cache_miss_freq" value: 0.01 } } canonicalized_instructions: { mnemonic: "NOT" @@ -455,6 +502,10 @@ TEST_F(BasicBlockGraphBuilderTest, MultipleInstructions) { ElementsAre(1, 3, 2, 0, 0, 1, 4, 7, 6, 5, 5, 8, 9, 9, 10, 11)); EXPECT_THAT(builder_->edge_receivers(), ElementsAre(0, 2, 0, 4, 5, 5, 6, 6, 5, 8, 9, 9, 10, 11, 11, 12)); + + EXPECT_THAT(builder_->instruction_annotations(), + ElementsAre(ElementsAre(0.9), ElementsAre(-1), ElementsAre(0.01), + ElementsAre(-1))); } // Tests that nodes in basic blocks added through different AddBasicBlock() diff --git a/gematria/granite/python/graph_builder.cc b/gematria/granite/python/graph_builder.cc index 53ab9565..0b2d1f58 100644 --- a/gematria/granite/python/graph_builder.cc +++ b/gematria/granite/python/graph_builder.cc @@ -14,6 +14,7 @@ #include "gematria/granite/graph_builder.h" +#include #include #include @@ -70,11 +71,13 @@ PYBIND11_MODULE(graph_builder, m) { absl::string_view /* fp_immediate_token */, absl::string_view /* address_token */, absl::string_view /* memory_token */, + std::set /* annotation_names */, OutOfVocabularyTokenBehavior /* out_of_vocabulary_behavior */ >(), py::arg("node_tokens"), py::arg("immediate_token"), py::arg("fp_immediate_token"), py::arg("address_token"), - py::arg("memory_token"), py::arg("out_of_vocabulary_behavior")) + py::arg("memory_token"), py::arg("annotation_names"), + py::arg("out_of_vocabulary_behavior")) .def("add_basic_block", &BasicBlockGraphBuilder::AddBasicBlock, py::arg("block")) .def("add_basic_block_from_instructions", @@ -94,6 +97,10 @@ PYBIND11_MODULE(graph_builder, m) { &BasicBlockGraphBuilder::node_features) .def_property_readonly("instruction_node_mask", &BasicBlockGraphBuilder::InstructionNodeMask) + .def_property_readonly("annotation_names", + &BasicBlockGraphBuilder::annotation_names) + .def_property_readonly("instruction_annotations", + &BasicBlockGraphBuilder::instruction_annotations) .def_property_readonly("edge_senders", &BasicBlockGraphBuilder::edge_senders) .def_property_readonly("edge_receivers", diff --git a/gematria/granite/python/graph_builder_model_base.py b/gematria/granite/python/graph_builder_model_base.py index 21fc46ad..9a50d13b 100644 --- a/gematria/granite/python/graph_builder_model_base.py +++ b/gematria/granite/python/graph_builder_model_base.py @@ -89,6 +89,7 @@ def __init__( fp_immediate_token: str, address_token: str, memory_token: str, + annotation_names: Sequence[str] = [], **kwargs: Any, ) -> None: """Initializes the model with the given feature factory. @@ -108,6 +109,7 @@ def __init__( in the basic block graph. memory_token: The token that is associated with memory value nodes in the basic block graph. + annotation_names: The list of names of annotations to be used. **kwargs: Additional keyword arguments are passed to the constructor of the base class. """ @@ -127,7 +129,13 @@ def __init__( tokens=tokens, **kwargs, ) - self._instruction_node_mask = None + # Definition moved up from _create_readout_network_resources since + # _instruction_node_mask is needed earlier for embedding calculations. + self._instruction_node_mask = tf.placeholder( + dtype=tf.dtypes.bool, + shape=(None,), + name=GraphBuilderModelBase.INSTRUCTION_NODE_MASK_TENSOR_NAME, + ) self._instruction_features = None self._batch_graph_builder = graph_builder.BasicBlockGraphBuilder( node_tokens=self._token_list, @@ -135,6 +143,7 @@ def __init__( fp_immediate_token=fp_immediate_token, address_token=address_token, memory_token=memory_token, + annotation_names=set(annotation_names), out_of_vocabulary_behavior=self._oov_behavior, ) @@ -187,11 +196,6 @@ def _create_tf_graph(self) -> None: # @Override def _create_readout_network_resources(self) -> None: super()._create_readout_network_resources() - self._instruction_node_mask = tf.placeholder( - dtype=tf.dtypes.bool, - shape=(None,), - name=GraphBuilderModelBase.INSTRUCTION_NODE_MASK_TENSOR_NAME, - ) self._instruction_features = tf.boolean_mask( self._graphs_tuple_outputs.nodes, self._instruction_node_mask ) diff --git a/gematria/granite/python/graph_builder_test.py b/gematria/granite/python/graph_builder_test.py index 40147198..9a10dfa9 100644 --- a/gematria/granite/python/graph_builder_test.py +++ b/gematria/granite/python/graph_builder_test.py @@ -85,6 +85,7 @@ def test_single_instruction_basic_block(self): fp_immediate_token=tokens.IMMEDIATE, address_token=tokens.ADDRESS, memory_token=tokens.MEMORY, + annotation_names=set(), out_of_vocabulary_behavior=_OutOfVocabularyTokenBehavior.return_error(), ) @@ -113,6 +114,7 @@ def test_multiple_basic_blocks(self): fp_immediate_token=tokens.IMMEDIATE, address_token=tokens.ADDRESS, memory_token=tokens.MEMORY, + annotation_names=set(), out_of_vocabulary_behavior=_OutOfVocabularyTokenBehavior.return_error(), ) @@ -128,6 +130,7 @@ def test_many_blocks(self): fp_immediate_token=tokens.IMMEDIATE, address_token=tokens.ADDRESS, memory_token=tokens.MEMORY, + annotation_names=set(), out_of_vocabulary_behavior=_OutOfVocabularyTokenBehavior.return_error(), ) @@ -144,6 +147,7 @@ def test_out_of_vocabulary_tokens_return_error(self): fp_immediate_token=tokens.IMMEDIATE, address_token=tokens.ADDRESS, memory_token=tokens.MEMORY, + annotation_names=set(), out_of_vocabulary_behavior=_OutOfVocabularyTokenBehavior.return_error(), ) @@ -162,6 +166,7 @@ def test_out_of_vocabulary_tokens_replace_token(self): fp_immediate_token=tokens.IMMEDIATE, address_token=tokens.ADDRESS, memory_token=tokens.MEMORY, + annotation_names=set(), out_of_vocabulary_behavior=( _OutOfVocabularyTokenBehavior.replace_with_token(tokens.UNKNOWN) ), diff --git a/gematria/granite/python/run_granite_model.py b/gematria/granite/python/run_granite_model.py index 1baad11a..138d07eb 100644 --- a/gematria/granite/python/run_granite_model.py +++ b/gematria/granite/python/run_granite_model.py @@ -44,6 +44,9 @@ def main(argv): model_tokens = token_model_flags.get_tokens_from_command_line_flags( model_tokens=tokens.STRUCTURAL_TOKENS ) + model_annotation_names = ( + token_model_flags.get_annotation_names_from_command_line_flags() + ) main_function.run_gematria_model_from_command_line_flags( model_class, @@ -52,6 +55,7 @@ def main(argv): fp_immediate_token=tokens.IMMEDIATE, address_token=tokens.ADDRESS, memory_token=tokens.MEMORY, + annotation_names=model_annotation_names, dtype=tf.dtypes.float32, node_embedding_size=granite_flags.NODE_EMBEDDING_SIZE.value, edge_embedding_size=granite_flags.EDGE_EMBEDDING_SIZE.value, diff --git a/gematria/granite/python/token_graph_builder_model.py b/gematria/granite/python/token_graph_builder_model.py index 7a4b0b29..cc1e11da 100644 --- a/gematria/granite/python/token_graph_builder_model.py +++ b/gematria/granite/python/token_graph_builder_model.py @@ -22,9 +22,11 @@ from gematria.granite.python import graph_builder from gematria.granite.python import graph_builder_model_base from gematria.model.python import model_blocks +from gematria.model.python import model_base from gematria.model.python import options import graph_nets import sonnet as snt +import numpy as np import tensorflow.compat.v1 as tf @@ -50,6 +52,20 @@ class TokenGraphBuilderModel(graph_builder_model_base.GraphBuilderModelBase): READOUT_VARIABLES = 'TokenGraphBuilderModel.readout' TASK_READOUT_VARIABLES = 'TokenGraphBuilderModel.task_readout' + INSTRUCTION_ANNOTATIONS_TENSOR_NAME = ( + 'TokenGraphBuilderModel.instruction_annotations' + ) + ANNOTATION_NAMES_TENSOR_NAME = 'TokenGraphBuilderModel.annotation_names' + + # A 1D byte tensor that contains the list of annotation names in the order of + # their indices in the graph builder. + _annotation_name_tensor: tf.Tensor + + # The list of annotation names, in the order of their indices in the model. + _annotation_name_list: Sequence[str] + + _instruction_annotations: tf.Tensor + def __init__( self, node_embedding_size: int, @@ -155,6 +171,15 @@ def __init__( self._readout_activation = readout_activation or leaky_relu self._update_activation = update_activation or leaky_relu + self._annotation_name_list = tuple( + self._batch_graph_builder.annotation_names + ) + self._instruction_annotations = tf.placeholder( + dtype=self.dtype, + shape=(None, len(self._annotation_name_list)), + name=TokenGraphBuilderModel.INSTRUCTION_ANNOTATIONS_TENSOR_NAME, + ) + # @Override def _make_model_name(self) -> str: # TODO(ondrasej): Use a string provided by the token feature factory as the @@ -177,6 +202,30 @@ def _make_model_name(self) -> str: f'{self._task_readout_input_layer_normalization}' ) + @property + def annotation_name_tensor(self) -> tf.Tensor: + return self._annotation_name_tensor + + @property + def output_tensor_names(self) -> Sequence[str]: + return ( + *super().output_tensor_names, + TokenGraphBuilderModel.ANNOTATION_NAMES_TENSOR_NAME, + ) + + # @Override + def _create_tf_graph(self) -> None: + super()._create_tf_graph() + + annotation_names_array = np.frombuffer( + b'\0'.join(name.encode('utf-8') for name in self._annotation_name_list), + dtype=np.uint8, + ) + self._annotation_name_tensor = tf.constant( + annotation_names_array, + name=TokenGraphBuilderModel.ANNOTATION_NAMES_TENSOR_NAME, + ) + def _create_dense_readout_network(self, data: tf.Tensor) -> tf.Tensor: """Creates the dense part of the readout network from `data`. @@ -288,10 +337,12 @@ def _create_graph_network_modules( initializers=embedding_initializers, ), node_model_fn=functools.partial( - snt.Embed, + TokenGraphBuilderModelNodeEmbed, vocab_size=len(self._token_list), embed_dim=self._node_embedding_size, initializers=embedding_initializers, + instruction_annotations=self._instruction_annotations, + instruction_node_mask=self._instruction_node_mask, ), global_model_fn=functools.partial( snt.Sequential, @@ -340,3 +391,75 @@ def _create_graph_network_modules( residual_connection=options.EnableFeature.BY_FLAG, ), ) + + def _make_batch_feed_dict(self) -> model_base.FeedDict: + feed_dict = super()._make_batch_feed_dict() + + feed_dict[self._instruction_annotations] = ( + self._batch_graph_builder.instruction_annotations + ) + return feed_dict + + +class TokenGraphBuilderModelNodeEmbed(snt.Embed): + """Extends `snt.Embed` to include instruction annotations in node embeddings. + + Generates node embeddings normally, then replaces the last `num_annotation` + values of the embeddings corresponding to instructions with the annotation + values. The embeddings for other node types remain unchanged. + """ + + def __init__( + self, + instruction_annotations, + instruction_node_mask, + **kwargs, + ) -> None: + """Initializes node embeddings. + + Args: + instruction_annotations: Tensor holding instruction level runtime + annotations as in `BasicBlockGraphBuilder`. + instruction_node_mask: As in `BasicBlockGraphBuilder`. + **kwargs: Additional arguments to be passed to the internal `snt.Embed`. + """ + super().__init__(**kwargs) + + # The number of annotations per instruction. + self.num_annotations = int(instruction_annotations.shape[1]) + + if self.num_annotations > self.embed_dim: + raise ValueError('num_annotations cannot be greater than embed_dim.') + + self.instruction_annotations = instruction_annotations + self.instruction_node_mask = instruction_node_mask + + def __call__(self, inputs): + embeddings = super().__call__(inputs) + # print(embeddings.op.name) + # print(self.num_annotations) + + if self.num_annotations == 0: + return embeddings + + out = tf.concat( + [ + tf.slice( + embeddings, + begin=[0, 0], + size=[-1, self.embed_dim - self.num_annotations], + ), + tf.tensor_scatter_nd_update( + tf.slice( + embeddings, + begin=[0, self.embed_dim - self.num_annotations], + size=[-1, self.num_annotations], + ), + indices=tf.where(self.instruction_node_mask), + updates=self.instruction_annotations, + ), + ], + axis=1, + ) + + return out diff --git a/gematria/model/python/token_model_flags.py b/gematria/model/python/token_model_flags.py index d18efc48..23c9d206 100644 --- a/gematria/model/python/token_model_flags.py +++ b/gematria/model/python/token_model_flags.py @@ -18,7 +18,7 @@ """ from collections.abc import Sequence -from typing import Optional +from typing import Optional, Text from absl import flags from gematria.model.python import oov_token_behavior @@ -38,6 +38,18 @@ ), ) +_ANNOTATION_NAME_FILE = flags.DEFINE_string( + 'gematria_annotation_names_file', + None, + ( + 'The text file that contains the list of annotation names used in the' + ' input basic blocks. Used to incorporate instruction annotations in' + ' the model features. Assumes that the argument is the path of a text' + ' file that contains one annotation name per line. Lines that start' + ' with a hash symbol (#) are considered as comments and ignored.' + ), +) + _OOV_REPLACEMENT_TOKEN = flags.DEFINE_string( 'gematria_out_of_vocabulary_replacement_token', None, @@ -93,6 +105,38 @@ def get_oov_token_behavior_from_command_line_flags() -> ( ) +def get_lines_from_command_line_flag( + flag: flags.FlagHolder[Text | None], +) -> Optional[Sequence[str]]: + """Returns a list of lines from a file passed as a command-line arguments. + + The input file is expected to be a text file. When loading the tokens, the + function: + 1. removes leading and trailing whitespace from each line. + 2. ignores lines starting with a hash character (#). + + Args: + flag: The command-line flag holding the path of the input file. + + Returns: + The list of lines loaded from the file specified by the command-line flags + or None when no such file is specified. The returned list is sorted and + contains each line at most once. + """ + if not flag.value: + return None + + lines = set() + with tf.io.gfile.GFile(flag.value, 'r') as f: + for line in f: + line = line.strip() + if not line or line.startswith('#'): + continue + lines.add(line) + + return sorted(lines) + + def get_tokens_from_command_line_flags( # model_tokens: Sequence[str] = (), ) -> Optional[Sequence[str]]: @@ -114,18 +158,35 @@ def get_tokens_from_command_line_flags( # or None when no such file is specified. The returned list is sorted and contains each token at most once. """ - if not _TOKEN_FILE.value: + tokens_from_file = get_lines_from_command_line_flag(_TOKEN_FILE) + if tokens_from_file is None: return None - tokens = set(model_tokens) - with tf.io.gfile.GFile(_TOKEN_FILE.value, 'r') as f: - for line in f: - line = line.strip() - if not line or line.startswith('#'): - continue - tokens.add(line) + return sorted(set(model_tokens) | set(tokens_from_file)) + + +def get_annotation_names_from_command_line_flags() -> Sequence[str]: + """Returns the list of annotation names used in the model. + + When the command-line flag --gematria_annotation_names_file is used, returns a + sorted list of annotation_names from this file. The input file is expected to + be a text file that contains one annotation name per line. When loading the + annotation names, the function: + 1. removes leading and trailing whitespace from each line. + 2. ignores lines starting with a hash character (#). + + Returns: + The list of annotation_names loaded from the file specified by the + command-line flags or an empty list when no such file is specified. The + returned list is sorted and contains each annotation_name at most once. + """ + annotation_names_from_file = get_lines_from_command_line_flag( + _ANNOTATION_NAME_FILE + ) + if annotation_names_from_file is None: + return [] - return sorted(tokens) + return annotation_names_from_file def mark_token_flags_as_required() -> None: diff --git a/gematria/proto/BUILD b/gematria/proto/BUILD index 88aa97b0..0a366b81 100644 --- a/gematria/proto/BUILD +++ b/gematria/proto/BUILD @@ -15,6 +15,9 @@ gematria_proto_library( gematria_proto_library( name = "canonicalized_instruction_proto", srcs = ["canonicalized_instruction.proto"], + deps = [ + ":annotation_proto", + ], ) gematria_proto_library( @@ -24,3 +27,8 @@ gematria_proto_library( ":basic_block_proto", ], ) + +gematria_proto_library( + name = "annotation_proto", + srcs = ["annotation.proto"], +) diff --git a/gematria/proto/annotation.proto b/gematria/proto/annotation.proto new file mode 100644 index 00000000..93113e34 --- /dev/null +++ b/gematria/proto/annotation.proto @@ -0,0 +1,27 @@ +// Copyright 2023 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package gematria; + +// Contains annotations to supply additional information to the model, such as +// cache-miss frequencies, or branching related statistics. +message AnnotationProto { + // A name or label for the annotation. + string name = 1; + // The annotation value, holding information such as measurements or + // statistics like event frequency or rate. + double value = 2; +} diff --git a/gematria/proto/canonicalized_instruction.proto b/gematria/proto/canonicalized_instruction.proto index b7a9687a..d2ca5c35 100644 --- a/gematria/proto/canonicalized_instruction.proto +++ b/gematria/proto/canonicalized_instruction.proto @@ -18,6 +18,8 @@ syntax = "proto3"; package gematria; +import "gematria/proto/annotation.proto"; + // Contains information about an instruction and all its inputs and outputs. // This proto can be used to create the embedding of the instruction as // described in the Ithemal [1] and Granite [2] papers. @@ -51,6 +53,9 @@ message CanonicalizedInstructionProto { // The list of implicit input operands of the instruction. repeated CanonicalizedOperandProto implicit_input_operands = 7; + + // Runtime related instruction level annotations. + repeated AnnotationProto instruction_annotations = 8; } // Contains information about a single operand in the canonicalized instruction. diff --git a/requirements.in b/requirements.in index 821fc82e..eced7567 100644 --- a/requirements.in +++ b/requirements.in @@ -15,3 +15,4 @@ tensorflow-probability>=0.19.0 tensorflow-ranking; sys_platform!='darwin' tensorflow>=2.11.0; sys_platform=='linux' tensorflow-macos>=2.11.0; sys_platform=='darwin' +lit>=17.0.6