Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add instruction annotations to the basic block representations. #92

Merged
merged 4 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 50 additions & 3 deletions gematria/basic_block/basic_block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -229,27 +229,63 @@ std::ostream& operator<<(std::ostream& os, const InstructionOperand& operand) {
return os;
}

Annotation::Annotation(std::string name, double value)
: name(std::move(name)), value(value) {}

// Two annotations are equal when both their names and their values compare
// equal. The name is compared first; the value is only inspected when the names
// match.
bool Annotation::operator==(const Annotation& other) const {
  return name == other.name && value == other.value;
}

// Returns a human-readable representation of the annotation in the form
// "Annotation(name='...', value=...)". The name is omitted when it is empty and
// the value is omitted when it is equal to -1; when both are omitted, the
// result is "Annotation()".
std::string Annotation::ToString() const {
  std::stringstream out;
  out << "Annotation(";
  // Written in front of every keyword argument except the first one, so that
  // no trailing ", " needs to be removed afterwards.
  const char* separator = "";
  if (!name.empty()) {
    out << separator << "name='" << name << "'";
    separator = ", ";
  }
  if (value != -1) {
    out << separator << "value=" << value;
  }
  out << ')';
  return out.str();
}

// Writes the text produced by Annotation::ToString() to `os` and returns the
// stream so that further insertions can be chained.
std::ostream& operator<<(std::ostream& os, const Annotation& annotation) {
  return os << annotation.ToString();
}

Instruction::Instruction(
std::string mnemonic, std::string llvm_mnemonic,
std::vector<std::string> prefixes,
std::vector<InstructionOperand> input_operands,
std::vector<InstructionOperand> implicit_input_operands,
std::vector<InstructionOperand> output_operands,
std::vector<InstructionOperand> implicit_output_operands)
std::vector<InstructionOperand> implicit_output_operands,
std::vector<Annotation> instruction_annotations)
: mnemonic(std::move(mnemonic)),
llvm_mnemonic(std::move(llvm_mnemonic)),
prefixes(std::move(prefixes)),
input_operands(std::move(input_operands)),
implicit_input_operands(std::move(implicit_input_operands)),
output_operands(std::move(output_operands)),
implicit_output_operands(std::move(implicit_output_operands)) {}
implicit_output_operands(std::move(implicit_output_operands)),
instruction_annotations(std::move(instruction_annotations)) {}

bool Instruction::operator==(const Instruction& other) const {
const auto as_tuple = [](const Instruction& instruction) {
return std::tie(
instruction.mnemonic, instruction.llvm_mnemonic, instruction.prefixes,
instruction.input_operands, instruction.implicit_input_operands,
instruction.output_operands, instruction.implicit_output_operands);
instruction.output_operands, instruction.implicit_output_operands,
instruction.instruction_annotations);
};
return as_tuple(*this) == as_tuple(other);
}
Expand Down Expand Up @@ -292,6 +328,17 @@ std::string Instruction::ToString() const {
add_operand_list("output_operands", output_operands);
add_operand_list("implicit_output_operands", implicit_output_operands);

if (!instruction_annotations.empty()) {
buffer << "instruction_annotations=(";
for (const Annotation& annotation : instruction_annotations) {
buffer << annotation.ToString() << ", ";
}
// Pop only the trailing space. For simplicity, we leave the trailing comma
// which is required in case there is only one element.
buffer.seekp(-1, std::ios_base::end);
buffer << "), ";
}

auto msg = buffer.str();
assert(msg.size() >= 2);
if (msg.back() == ' ') msg.resize(msg.size() - 2);
Expand Down
34 changes: 33 additions & 1 deletion gematria/basic_block/basic_block.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,31 @@ class InstructionOperand {

std::ostream& operator<<(std::ostream& os, const InstructionOperand& operand);

// Represents an annotation holding a value such as some measure/statistic
// paired with the instruction.
struct Annotation {
  // Creates an annotation with an empty name and the value -1.
  Annotation() : value(-1) {}

  // Initializes all fields of the annotation.
  Annotation(std::string name, double value);

  // Annotations can be copied and moved freely.
  Annotation(const Annotation&) = default;
  Annotation(Annotation&&) = default;

  Annotation& operator=(const Annotation&) = default;
  Annotation& operator=(Annotation&&) = default;

  // Equality comparison; inequality is defined as the negation of equality.
  bool operator==(const Annotation& other) const;
  bool operator!=(const Annotation& other) const { return !operator==(other); }

  // Returns a human-readable representation of the annotation.
  std::string ToString() const;

  // The name of the annotation.
  std::string name;
  // The numeric value attached to the annotation.
  double value;
};

std::ostream& operator<<(std::ostream& os, const Annotation& annotation);

// Represents a single instruction.
struct Instruction {
Instruction() {}
Expand All @@ -229,7 +254,9 @@ struct Instruction {
std::vector<InstructionOperand> input_operands,
std::vector<InstructionOperand> implicit_input_operands,
std::vector<InstructionOperand> output_operands,
std::vector<InstructionOperand> implicit_output_operands);
std::vector<InstructionOperand> implicit_output_operands,
std::vector<Annotation> instruction_annotations =
std::vector<Annotation>{});

Instruction(const Instruction&) = default;
Instruction(Instruction&&) = default;
Expand Down Expand Up @@ -280,6 +307,11 @@ struct Instruction {
// to the ML models explicitly.
std::vector<InstructionOperand> implicit_output_operands;

// The list of instruction level annotations used to supply additional
// information to the model. Currently includes the cache miss frequency of
// the instruction. Used to better model the overhead coming from LLC misses.
std::vector<Annotation> instruction_annotations;

// The address of the instruction.
uint64_t address = 0;
// The size of the instruction.
Expand Down
104 changes: 62 additions & 42 deletions gematria/basic_block/basic_block_protos.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,37 @@
#include <vector>

#include "gematria/basic_block/basic_block.h"
#include "gematria/proto/annotation.pb.h"
#include "gematria/proto/canonicalized_instruction.pb.h"
#include "google/protobuf/repeated_ptr_field.h"

namespace gematria {

namespace {

// Converts each element of `protos` with `object_from_proto` and returns the
// converted objects as a vector, in the same order as in the repeated field.
template <typename Object, typename Proto, typename Convertor>
std::vector<Object> ToVector(
    const google::protobuf::RepeatedPtrField<Proto>& protos,
    Convertor object_from_proto) {
  std::vector<Object> result;
  result.reserve(std::size(protos));
  for (const Proto& proto : protos) {
    result.push_back(object_from_proto(proto));
  }
  return result;
}

// Converts each element of `objects` with `proto_from_object` and appends the
// converted protos to `repeated_field`, preserving the order of the elements.
template <typename Object, typename Proto, typename Convertor>
void ToRepeatedPtrField(
    const std::vector<Object>& objects,
    google::protobuf::RepeatedPtrField<Proto>* repeated_field,
    Convertor proto_from_object) {
  // Reserve up front so that appending does not have to grow the field
  // repeatedly.
  repeated_field->Reserve(std::size(objects));
  for (const Object& object : objects) {
    *repeated_field->Add() = proto_from_object(object);
  }
}

} // namespace

AddressTuple AddressTupleFromProto(
const CanonicalizedOperandProto::AddressTuple& proto) {
return AddressTuple(
Expand Down Expand Up @@ -91,41 +117,41 @@ CanonicalizedOperandProto ProtoFromInstructionOperand(
return proto;
}

namespace {

std::vector<InstructionOperand> ToVector(
const google::protobuf::RepeatedPtrField<CanonicalizedOperandProto>&
protos) {
std::vector<InstructionOperand> result(protos.size());
std::transform(protos.begin(), protos.end(), result.begin(),
InstructionOperandFromProto);
return result;
// Creates an annotation data structure carrying the name and the value stored
// in `proto`.
Annotation AnnotationFromProto(const AnnotationProto& proto) {
  Annotation annotation;
  annotation.name = proto.name();
  annotation.value = proto.value();
  return annotation;
}

void ToRepeatedPtrField(
const std::vector<InstructionOperand>& operands,
google::protobuf::RepeatedPtrField<CanonicalizedOperandProto>*
repeated_field) {
repeated_field->Reserve(operands.size());
std::transform(operands.begin(), operands.end(),
google::protobuf::RepeatedFieldBackInserter(repeated_field),
ProtoFromInstructionOperand);
// Creates a proto whose name and value fields are copied from `annotation`.
AnnotationProto ProtoFromAnnotation(const Annotation& annotation) {
  AnnotationProto result;
  result.set_value(annotation.value);
  result.set_name(annotation.name);
  return result;
}

} // namespace

Instruction InstructionFromProto(const CanonicalizedInstructionProto& proto) {
return Instruction(
/* mnemonic = */ proto.mnemonic(),
/* llvm_mnemonic = */ proto.llvm_mnemonic(),
/* prefixes = */
std::vector<std::string>(proto.prefixes().begin(),
proto.prefixes().end()),
/* input_operands = */ ToVector(proto.input_operands()),
/* implicit_input_operands = */ ToVector(proto.implicit_input_operands()),
/* output_operands = */ ToVector(proto.output_operands()),
/* input_operands = */
ToVector<InstructionOperand>(proto.input_operands(),
InstructionOperandFromProto),
/* implicit_input_operands = */
ToVector<InstructionOperand>(proto.implicit_input_operands(),
InstructionOperandFromProto),
/* output_operands = */
ToVector<InstructionOperand>(proto.output_operands(),
InstructionOperandFromProto),
/* implicit_output_operands = */
ToVector(proto.implicit_output_operands()));
ToVector<InstructionOperand>(proto.implicit_output_operands(),
InstructionOperandFromProto),
/* instruction_annotations = */
ToVector<Annotation>(proto.instruction_annotations(),
AnnotationFromProto));
}

CanonicalizedInstructionProto ProtoFromInstruction(
Expand All @@ -135,33 +161,27 @@ CanonicalizedInstructionProto ProtoFromInstruction(
proto.set_llvm_mnemonic(instruction.llvm_mnemonic);
proto.mutable_prefixes()->Assign(instruction.prefixes.begin(),
instruction.prefixes.end());
ToRepeatedPtrField(instruction.input_operands,
proto.mutable_input_operands());
ToRepeatedPtrField(instruction.input_operands, proto.mutable_input_operands(),
ProtoFromInstructionOperand);
ToRepeatedPtrField(instruction.implicit_input_operands,
proto.mutable_implicit_input_operands());
proto.mutable_implicit_input_operands(),
ProtoFromInstructionOperand);
ToRepeatedPtrField(instruction.output_operands,
proto.mutable_output_operands());
proto.mutable_output_operands(),
ProtoFromInstructionOperand);
ToRepeatedPtrField(instruction.implicit_output_operands,
proto.mutable_implicit_output_operands());
proto.mutable_implicit_output_operands(),
ProtoFromInstructionOperand);
ToRepeatedPtrField(instruction.instruction_annotations,
proto.mutable_instruction_annotations(),
ProtoFromAnnotation);
return proto;
}

namespace {

std::vector<Instruction> ToVector(
const google::protobuf::RepeatedPtrField<CanonicalizedInstructionProto>&
protos) {
std::vector<Instruction> result(protos.size());
std::transform(protos.begin(), protos.end(), result.begin(),
InstructionFromProto);
return result;
}

} // namespace

BasicBlock BasicBlockFromProto(const BasicBlockProto& proto) {
return BasicBlock(
/* instructions = */ ToVector(proto.canonicalized_instructions()));
/* instructions = */ ToVector<Instruction>(
proto.canonicalized_instructions(), InstructionFromProto));
}

} // namespace gematria
6 changes: 6 additions & 0 deletions gematria/basic_block/basic_block_protos.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ InstructionOperand InstructionOperandFromProto(
CanonicalizedOperandProto ProtoFromInstructionOperand(
const InstructionOperand& operand);

// Creates an annotation data structure from a proto.
Annotation AnnotationFromProto(const AnnotationProto& proto);

// Creates a proto representing the given annotation.
AnnotationProto ProtoFromAnnotation(const Annotation& annotation);

// Creates an instruction data structure from a proto.
Instruction InstructionFromProto(const CanonicalizedInstructionProto& proto);

Expand Down
33 changes: 30 additions & 3 deletions gematria/basic_block/basic_block_protos_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,26 @@ TEST(ProtoFromInstructionOperandTest, Memory) {
EqualsProto(R"pb(memory { alias_group_id: 123 })pb"));
}

// Verifies that AnnotationFromProto() copies both the name and the value from
// the proto into the returned Annotation.
TEST(AnnotationFromProtoTest, AllFields) {
const AnnotationProto proto = ParseTextProto(R"pb(
name: 'cache_miss_freq'
value: 0.875
)pb");
Annotation annotation = AnnotationFromProto(proto);
EXPECT_EQ(annotation.name, "cache_miss_freq");
// 0.875 is exactly representable in binary floating point, so comparing with
// exact equality is safe for this value.
EXPECT_EQ(annotation.value, 0.875);
}

// Verifies that ProtoFromAnnotation() produces a proto whose name and value
// fields match the ones of the source Annotation.
TEST(ProtoFromAnnotationTest, AllFields) {
const Annotation annotation(
/* name = */ "cache_miss_freq",
/* value = */ 0.875);
EXPECT_THAT(ProtoFromAnnotation(annotation), EqualsProto(R"pb(
name: 'cache_miss_freq'
value: 0.875
)pb"));
}

TEST(InstructionFromProtoTest, AllFields) {
const CanonicalizedInstructionProto proto = ParseTextProto(R"pb(
mnemonic: "ADC"
Expand All @@ -156,6 +176,7 @@ TEST(InstructionFromProtoTest, AllFields) {
implicit_output_operands { register_name: "EFLAGS" }
implicit_input_operands { register_name: "EFLAGS" }
implicit_input_operands { immediate_value: 1 }
instruction_annotations { name: "cache_miss_freq" value: 0.875 }
)pb");
Instruction instruction = InstructionFromProto(proto);
EXPECT_EQ(
Expand All @@ -169,7 +190,9 @@ TEST(InstructionFromProtoTest, AllFields) {
InstructionOperand::ImmediateValue(1)},
/* output_operands = */ {InstructionOperand::Register("RAX")},
/* implicit_output_operands = */
{InstructionOperand::Register("EFLAGS")}));
{InstructionOperand::Register("EFLAGS")},
/* instruction_annotations = */
{Annotation("cache_miss_freq", 0.875)}));
}

TEST(ProtoFromInstructionTest, AllFields) {
Expand Down Expand Up @@ -205,6 +228,7 @@ TEST(BasicBlockFromProtoTest, SomeInstructions) {
llvm_mnemonic: "MOV64rr"
output_operands: { register_name: "RCX" }
input_operands: { register_name: "RAX" }
instruction_annotations: { name: "cache_miss_freq" value: 0.875 }
}
canonicalized_instructions: {
mnemonic: "NOT"
Expand All @@ -223,15 +247,18 @@ TEST(BasicBlockFromProtoTest, SomeInstructions) {
/* input_operands = */ {InstructionOperand::Register("RAX")},
/* implicit_input_operands = */ {},
/* output_operands = */ {InstructionOperand::Register("RCX")},
/* implicit_output_operands = */ {}),
/* implicit_output_operands = */ {},
/* instruction_annotations = */
{Annotation("cache_miss_freq", 0.875)}),
Instruction(
/* mnemonic = */ "NOT",
/* llvm_mnemonic = */ "NOT64r",
/* prefixes = */ {},
/* input_operands = */ {InstructionOperand::Register("RCX")},
/* implicit_input_operands = */ {},
/* output_operands = */ {InstructionOperand::Register("RCX")},
/* implicit_output_operands = */ {})}));
/* implicit_output_operands = */ {},
/* instruction_annotations = */ {})}));
}

} // namespace
Expand Down
Loading
Loading