Skip to content

Commit

Permalink
Expose sentencepiece parallel tokenize in tf text sentencepiece wrapper.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 608351975
  • Loading branch information
tf-text-github-robot committed Feb 21, 2024
1 parent daf770b commit 7d29681
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 5 deletions.
26 changes: 22 additions & 4 deletions tensorflow_text/core/kernels/sentencepiece_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
Expand Down Expand Up @@ -266,6 +267,10 @@ class SentencepieceTokenizeOp : public OpKernel {
public:
// Reads op attributes once at kernel construction time. All attrs are
// optional (IgnoreError): if absent, the in-class defaults are kept.
explicit SentencepieceTokenizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
// When true, Compute returns the nbest tokenizations instead of one.
ctx->GetAttr("return_nbest", &return_nbest_).IgnoreError();

// Parallel encode options.
// num_threads > 1 switches Encode -> ParallelEncode; chunk_size is the
// per-thread work unit passed through to the sentencepiece processor.
ctx->GetAttr("num_threads", &num_threads_).IgnoreError();
ctx->GetAttr("chunk_size", &chunk_size_).IgnoreError();
}

void Compute(OpKernelContext* ctx) override {
Expand Down Expand Up @@ -309,6 +314,8 @@ class SentencepieceTokenizeOp : public OpKernel {
nbest_tokens(return_nbest_ ? num_of_input_values : 0);
if (num_of_input_values > 0) {
const bool return_nbest = return_nbest_;
const int32 num_threads = num_threads_;
const int32 chunk_size = chunk_size_;
const auto& worker_threads =
*(ctx->device()->tensorflow_cpu_worker_threads());
::tensorflow::Shard(
Expand All @@ -317,8 +324,8 @@ class SentencepieceTokenizeOp : public OpKernel {
num_of_input_values, // total number of data to process.
kCostPerUnit, // cost per unit
[ctx, sp, &input_values_flat, &tokens, &nbest_tokens,
&nbest_size_tensor, &alpha_tensor,
return_nbest](int64 start, int64 limit) {
&nbest_size_tensor, &alpha_tensor, &chunk_size,
&num_threads, return_nbest](int64 start, int64 limit) {
absl::ReaderMutexLock lock(&sp->mu);
for (int i = start; i < limit; ++i) {
const int32 nbest_size = nbest_size_tensor->dims() == 1
Expand All @@ -329,8 +336,17 @@ class SentencepieceTokenizeOp : public OpKernel {
input_values_flat(i), nbest_size,
&nbest_tokens[i])));
} else if (nbest_size == 0 || nbest_size == 1) {
OP_REQUIRES_OK(ctx, ToTFStatus(sp->processor.Encode(
input_values_flat(i), &tokens[i])));
if (num_threads == 1) {
OP_REQUIRES_OK(
ctx,
ToTFStatus(sp->processor.Encode(
input_values_flat(i), &tokens[i])));
} else {
OP_REQUIRES_OK(
ctx, ToTFStatus(sp->processor.ParallelEncode(
input_values_flat(i), chunk_size, num_threads,
&tokens[i])));
}
} else {
const float alpha = alpha_tensor->dims() == 1
? alpha_tensor->vec<float>()(i)
Expand Down Expand Up @@ -379,6 +395,8 @@ class SentencepieceTokenizeOp : public OpKernel {
}

// If true, emit the nbest_size best tokenizations per input (attr default:
// false).
bool return_nbest_{false};
// Number of threads for ParallelEncode; 1 (the attr default) means plain
// single-threaded Encode.
int32_t num_threads_{1};
// Chunk size used to split each input for ParallelEncode; only consulted
// when num_threads_ > 1 (attr default: 0).
int32_t chunk_size_{0};
};

REGISTER_KERNEL_BUILDER(Name("SentencepieceTokenizeOp")
Expand Down
2 changes: 2 additions & 0 deletions tensorflow_text/core/ops/sentencepiece_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ REGISTER_OP("SentencepieceTokenizeOp")
.Input("add_bos: bool")
.Input("add_eos: bool")
.Input("reverse: bool")
.Attr("num_threads: int = 1")
.Attr("chunk_size: int = 0")
.Attr("out_type: {int32, string} = DT_INT32")
.Attr("Tsplits: {int32, int64} = DT_INT64")
.Attr("return_nbest: bool = false")
Expand Down
11 changes: 10 additions & 1 deletion tensorflow_text/python/ops/sentencepiece_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ def __init__(self,
reverse=False,
add_bos=False,
add_eos=False,
num_threads=1,
chunk_size=0,
return_nbest=False,
name=None):
"""Creates & initializes a Sentencepiece processor.
Expand All @@ -101,6 +103,10 @@ def __init__(self,
add_eos: Add end of sentence token to the result (Default = false). When
`reverse=True` beginning/end of sentence tokens are added after
reversing.
num_threads: If `> 1`, the input is split up into chunks of size
`chunk_size` and tokenized in parallel with this many threads.
chunk_size: Only used if `num_threads > 1`. The input is split into
chunks of this size and tokenized in parallel.
return_nbest: If True requires that `nbest_size` is a scalar and `> 1`.
Returns the `nbest_size` best tokenizations for each sentence instead
of a single one. The returned tensor has shape
Expand All @@ -118,6 +124,8 @@ def __init__(self,
self.reverse = reverse
self.add_bos = add_bos
self.add_eos = add_eos
self.num_threads = num_threads
self.chunk_size = chunk_size
self.return_nbest = return_nbest
self._model_resource = _SentencepieceModelResource(model, name)

Expand Down Expand Up @@ -154,7 +162,8 @@ def tokenize(self, input, name=None): # pylint: disable=redefined-builtin
gen_sentencepiece_tokenizer.sentencepiece_tokenize_op(
self._model_resource.resource_handle, input_tensor,
self.nbest_size, self.alpha, self.add_bos, self.add_eos,
self.reverse, self.out_type, return_nbest=self.return_nbest))
self.reverse, self.num_threads, self.chunk_size,
self.out_type, return_nbest=self.return_nbest))
tokens = RaggedTensor.from_nested_row_splits(
flat_values=output_values,
nested_row_splits=[row_splits],
Expand Down

0 comments on commit 7d29681

Please sign in to comment.