Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for consuming instruction annotations to TokenGraphBuilderModels. #94

Merged
merged 14 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions gematria/granite/graph_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@

#include <cstddef>
#include <ostream>
#include <set>
#include <string>
#include <string_view>
#include <unordered_map>
Expand Down
14 changes: 14 additions & 0 deletions gematria/granite/python/gnn_model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ def __init__(

# @Override
def _create_tf_graph(self) -> None:
self._create_graph_network_resources()
self._graph_network = self._create_graph_network_modules()
assert self._graph_network is not None
self._graphs_tuple_placeholders = self._create_graphs_placeholders()
Expand Down Expand Up @@ -319,6 +320,19 @@ def _create_graphs_placeholders(self) -> graph_nets.graphs.GraphsTuple:
),
)

def _create_graph_network_resources(self) -> None:
"""Creates resources (like TensorFlow ops) needed by the graph network.

Child classes can override this method to create resources (e.g. TensorFlow
ops) that will be needed during the creation of the graph network, but
that are impractical to create in self._create_graph_network_modules(), e.g.
to make it easy to override in child classes.

This method is called before self._create_graph_network_modules().

By default, this method is a no-op.
"""

def _create_graph_network(self) -> graph_nets.graphs.GraphsTuple:
"""Creates TensorFlow ops for the graph network.

Expand Down
1 change: 1 addition & 0 deletions gematria/granite/python/graph_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "gematria/granite/graph_builder.h"

#include <set>
#include <string>
#include <vector>

Expand Down
56 changes: 53 additions & 3 deletions gematria/granite/python/graph_builder_model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,14 @@ class GraphBuilderModelBase(
'GraphBuilderModelBase.instruction_node_mask'
)

# The name of the input tensor that holds the instruction annotations.
INSTRUCTION_ANNOTATIONS_TENSOR_NAME = (
'GraphBuilderModelBase.instruction_annotations'
)

# The name of the tensor holding ordered annotation names.
ANNOTATION_NAMES_TENSOR_NAME = 'GraphBuilderModelBase.annotation_names'

# A Boolean tensor placeholder that receives a mask for instruction nodes. The
# mask has shape (None,), and it must have the same length as
# self._graphs_tuple_placeholders.nodes along the first dimension. It contains
Expand All @@ -81,6 +89,16 @@ class GraphBuilderModelBase(
# the format of the data.
_special_tokens_tensor: tf.Tensor

# A 1D byte tensor that contains the list of annotation names in the order of
# their indices in the graph builder.
_annotation_name_tensor: tf.Tensor

# The list of annotation names, in the order of their indices in the model.
_annotation_name_list: Sequence[str]

# A 2D float tensor holding instruction annotations.
_instruction_annotations: tf.Tensor

def __init__(
self,
*,
Expand All @@ -89,6 +107,7 @@ def __init__(
fp_immediate_token: str,
address_token: str,
memory_token: str,
annotation_names: Sequence[str] = [],
**kwargs: Any,
) -> None:
"""Initializes the model with the given feature factory.
Expand All @@ -108,6 +127,7 @@ def __init__(
in the basic block graph.
memory_token: The token that is associated with memory value nodes in the
basic block graph.
annotation_names: The list of names of annotations to be used.
**kwargs: Additional keyword arguments are passed to the constructor of
the base class.
"""
Expand All @@ -127,19 +147,24 @@ def __init__(
tokens=tokens,
**kwargs,
)
self._instruction_node_mask = None
self._instruction_features = None
self._batch_graph_builder = graph_builder.BasicBlockGraphBuilder(
node_tokens=self._token_list,
immediate_token=immediate_token,
fp_immediate_token=fp_immediate_token,
address_token=address_token,
memory_token=memory_token,
annotation_names=annotation_names,
out_of_vocabulary_behavior=self._oov_behavior,
)

self._special_tokens_tensor = None

self._annotation_name_list = tuple(
self._batch_graph_builder.annotation_names
)
self._num_annotations = len(self._annotation_name_list)

@property
def special_tokens_tensor(self) -> tf.Tensor:
"""Returns the indices of special node tokens.
Expand All @@ -156,12 +181,17 @@ def special_tokens_tensor(self) -> tf.Tensor:
"""
return self._special_tokens_tensor

@property
def annotation_names_tensor(self) -> tf.Tensor:
  """Returns the tensor that holds the ordered annotation names.

  The tensor is built during TensorFlow graph creation from the annotation
  names encoded as UTF-8 and joined with NUL bytes, stored as uint8 values.
  """
  # The backing attribute is spelled `_annotation_name_tensor` (singular
  # "name") where it is declared and assigned; reading any other spelling
  # raises AttributeError.
  return self._annotation_name_tensor

# @Override
@property
def output_tensor_names(self) -> Sequence[str]:
  """Returns the output tensor names of the base class plus two local ones.

  The names of the special tokens tensor and of the annotation names tensor
  are appended, in that order, after the names reported by the base class.
  """
  inherited_names = super().output_tensor_names
  own_names = (
      GraphBuilderModelBase.SPECIAL_TOKENS_TENSOR_NAME,
      GraphBuilderModelBase.ANNOTATION_NAMES_TENSOR_NAME,
  )
  return (*inherited_names, *own_names)

# @Override
Expand All @@ -183,15 +213,32 @@ def _create_tf_graph(self) -> None:
dtype=tf.dtypes.int32,
name=GraphBuilderModelBase.SPECIAL_TOKENS_TENSOR_NAME,
)
annotation_names_array = np.frombuffer(
b'\0'.join(name.encode('utf-8') for name in self._annotation_name_list),
dtype=np.uint8,
)
self._annotation_name_tensor = tf.constant(
annotation_names_array,
name=GraphBuilderModelBase.ANNOTATION_NAMES_TENSOR_NAME,
)

# @Override
def _create_readout_network_resources(self) -> None:
super()._create_readout_network_resources()
def _create_graph_network_resources(self) -> None:
  """Creates the placeholders for instruction annotations and the node mask.

  Calls the base class implementation first, then creates two placeholders:
    * the instruction annotations, with one column per annotation name and
      an unspecified number of rows,
    * the Boolean instruction node mask, a 1D tensor of unspecified length.
  """
  super()._create_graph_network_resources()

  # One column per annotation name; the number of rows is left open.
  annotations_shape = (None, len(self._annotation_name_list))
  self._instruction_annotations = tf.placeholder(
      name=GraphBuilderModelBase.INSTRUCTION_ANNOTATIONS_TENSOR_NAME,
      dtype=self.dtype,
      shape=annotations_shape,
  )

  self._instruction_node_mask = tf.placeholder(
      name=GraphBuilderModelBase.INSTRUCTION_NODE_MASK_TENSOR_NAME,
      dtype=tf.dtypes.bool,
      shape=(None,),
  )

# @Override
def _create_readout_network_resources(self) -> None:
super()._create_readout_network_resources()
self._instruction_features = tf.boolean_mask(
self._graphs_tuple_outputs.nodes, self._instruction_node_mask
)
Expand All @@ -207,6 +254,9 @@ def _make_batch_feed_dict(self) -> model_base.FeedDict:
feed_dict[self._instruction_node_mask] = np.array(
self._batch_graph_builder.instruction_node_mask, dtype=bool
)
feed_dict[self._instruction_annotations] = (
self._batch_graph_builder.instruction_annotations
)
return feed_dict

# @Override
Expand Down
6 changes: 3 additions & 3 deletions gematria/granite/python/graph_builder_model_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,9 @@ def test_validate_basic_block(self):
# This basic block is invalid - there is no x86-64 instruction `FOOBAR`.
invalid_block = throughput.BasicBlockWithThroughput(
block=basic_block.BasicBlock(
basic_block.InstructionList((
basic_block.Instruction(mnemonic='FOOBAR'),
))
basic_block.InstructionList(
(basic_block.Instruction(mnemonic='FOOBAR'),)
)
)
)
self.assertFalse(model.validate_basic_block_with_throughput(invalid_block))
Expand Down
4 changes: 4 additions & 0 deletions gematria/granite/python/run_granite_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ def main(argv):
model_tokens = token_model_flags.get_tokens_from_command_line_flags(
model_tokens=tokens.STRUCTURAL_TOKENS
)
model_annotation_names = (
token_model_flags.get_annotation_names_from_command_line_flags()
)

main_function.run_gematria_model_from_command_line_flags(
model_class,
Expand All @@ -52,6 +55,7 @@ def main(argv):
fp_immediate_token=tokens.IMMEDIATE,
address_token=tokens.ADDRESS,
memory_token=tokens.MEMORY,
annotation_names=model_annotation_names,
dtype=tf.dtypes.float32,
node_embedding_size=granite_flags.NODE_EMBEDDING_SIZE.value,
edge_embedding_size=granite_flags.EDGE_EMBEDDING_SIZE.value,
Expand Down
92 changes: 90 additions & 2 deletions gematria/granite/python/token_graph_builder_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,17 @@ def __init__(
self._readout_activation = readout_activation or leaky_relu
self._update_activation = update_activation or leaky_relu

if self._num_annotations > self._node_embedding_size:
raise ValueError(
'_num_annotations cannot be greater than _node_embedding_size.'
)

# The length of the learnt part of the instruction node embedding vectors.
# The remaining elements are filled in with instruction annotations.
self._common_node_embedding_size = (
self._node_embedding_size - self._num_annotations
)

# @Override
def _make_model_name(self) -> str:
# TODO(ondrasej): Use a string provided by the token feature factory as the
Expand Down Expand Up @@ -288,9 +299,12 @@ def _create_graph_network_modules(
initializers=embedding_initializers,
),
node_model_fn=functools.partial(
snt.Embed,
TokenGraphBuilderModelNodeEmbed,
vocab_size=len(self._token_list),
embed_dim=self._node_embedding_size,
common_embed_dim=self._common_node_embedding_size,
num_annotations=self._num_annotations,
instruction_annotations=self._instruction_annotations,
instruction_node_mask=self._instruction_node_mask,
initializers=embedding_initializers,
),
global_model_fn=functools.partial(
Expand Down Expand Up @@ -340,3 +354,77 @@ def _create_graph_network_modules(
residual_connection=options.EnableFeature.BY_FLAG,
),
)


class TokenGraphBuilderModelNodeEmbed:
  """Class representing node embeddings with instruction annotations included.

  `snt.Embed`-like class. Generates node embeddings normally, then replaces the
  last `num_annotations` values of the embeddings corresponding to instructions
  with the annotation values. The embeddings for other node types remain
  unchanged.
  """

  def __init__(
      self,
      common_embed_dim: int,
      num_annotations: int,
      instruction_annotations: tf.Tensor,
      instruction_node_mask: tf.Tensor,
      **kwargs,
  ) -> None:
    """Initializes node embeddings.

    Args:
      common_embed_dim: The length of the learnt part of the instruction node
        embedding vectors. The remainder of the vector is filled with
        instruction annotation.
      num_annotations: The number of annotations per instruction. When zero,
        only the common embedding is created and returned by `__call__`.
      instruction_annotations: Tensor holding instruction level runtime
        annotations as in `BasicBlockGraphBuilder`.
      instruction_node_mask: As in `BasicBlockGraphBuilder`.
      **kwargs: Additional arguments to be passed to the internal `snt.Embed`s.
        The same keyword arguments are passed to every internal embedding.
    """
    self._instruction_annotations = instruction_annotations
    self._instruction_node_mask = instruction_node_mask

    # The first `embed_dim - num_annotations` embedding values for all nodes.
    self._common_embed = snt.Embed(
        embed_dim=common_embed_dim,
        **kwargs,
    )

    # `num_annotations` extra learnt embedding values for non-instruction nodes.
    # Instruction nodes will use instruction annotations instead of these learnt
    # embeddings. This is not required when there are no annotations - in that
    # case, we simply return the common embeddings.
    self._extra_embed = None
    if num_annotations:
      self._extra_embed = snt.Embed(
          embed_dim=num_annotations,
          **kwargs,
      )

  def __call__(
      self,
      inputs: tf.Tensor,
  ) -> tf.Tensor:
    """Computes the embedding vectors of the given nodes.

    Args:
      inputs: The tensor passed to the internal `snt.Embed` modules.
        NOTE(review): presumably the node token indices - confirm with caller.

    Returns:
      The common embeddings when there are no annotations. Otherwise, the
      common embeddings concatenated along axis 1 with the extra embeddings, in
      which the rows selected by the instruction node mask are replaced with
      the instruction annotations.
    """
    common_embeddings = self._common_embed(inputs)

    # Compare against None explicitly: the check is about whether the optional
    # module was created, not about the truth value of the module object.
    if self._extra_embed is None:
      return common_embeddings

    extra_embeddings = self._extra_embed(inputs)

    # Rows of the extra embeddings that belong to instruction nodes are
    # overwritten with the annotations; all other rows keep the learnt values.
    # NOTE(review): this needs one annotation row per True entry in the mask -
    # confirm that the feed dict always satisfies this.
    instruction_indices = tf.where(self._instruction_node_mask)
    annotated_embeddings = tf.tensor_scatter_nd_update(
        extra_embeddings,
        indices=instruction_indices,
        updates=self._instruction_annotations,
    )

    return tf.concat([common_embeddings, annotated_embeddings], axis=1)
Loading
Loading