Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose sentencepiece parallel tokenize in tf text sentencepiece wrapper. #1249

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions tensorflow_text/core/kernels/sentencepiece_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
Expand Down Expand Up @@ -266,6 +267,10 @@ class SentencepieceTokenizeOp : public OpKernel {
public:
// Constructs the kernel and reads its attributes into member fields.
//
// The status returned by every GetAttr call is explicitly discarded with
// IgnoreError(), so a failed lookup is not reported to the caller.
// NOTE(review): presumably a failed lookup leaves the member at its in-class
// default (return_nbest_ = false, num_threads_ = 1, chunk_size_ = 0) --
// confirm against GetAttr's contract.
explicit SentencepieceTokenizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
ctx->GetAttr("return_nbest", &return_nbest_).IgnoreError();

// Parallel encode options.
// No range check is applied to either value here; Compute takes the plain
// Encode path only when num_threads == 1 and otherwise forwards both values
// to ParallelEncode unchanged.
ctx->GetAttr("num_threads", &num_threads_).IgnoreError();
ctx->GetAttr("chunk_size", &chunk_size_).IgnoreError();
}

void Compute(OpKernelContext* ctx) override {
Expand Down Expand Up @@ -309,6 +314,8 @@ class SentencepieceTokenizeOp : public OpKernel {
nbest_tokens(return_nbest_ ? num_of_input_values : 0);
if (num_of_input_values > 0) {
const bool return_nbest = return_nbest_;
const int32 num_threads = num_threads_;
const int32 chunk_size = chunk_size_;
const auto& worker_threads =
*(ctx->device()->tensorflow_cpu_worker_threads());
::tensorflow::Shard(
Expand All @@ -317,8 +324,8 @@ class SentencepieceTokenizeOp : public OpKernel {
num_of_input_values, // total number of data to process.
kCostPerUnit, // cost per unit
[ctx, sp, &input_values_flat, &tokens, &nbest_tokens,
&nbest_size_tensor, &alpha_tensor,
return_nbest](int64 start, int64 limit) {
&nbest_size_tensor, &alpha_tensor, &chunk_size,
&num_threads, return_nbest](int64 start, int64 limit) {
absl::ReaderMutexLock lock(&sp->mu);
for (int i = start; i < limit; ++i) {
const int32 nbest_size = nbest_size_tensor->dims() == 1
Expand All @@ -329,8 +336,17 @@ class SentencepieceTokenizeOp : public OpKernel {
input_values_flat(i), nbest_size,
&nbest_tokens[i])));
} else if (nbest_size == 0 || nbest_size == 1) {
OP_REQUIRES_OK(ctx, ToTFStatus(sp->processor.Encode(
input_values_flat(i), &tokens[i])));
if (num_threads == 1) {
OP_REQUIRES_OK(
ctx,
ToTFStatus(sp->processor.Encode(
input_values_flat(i), &tokens[i])));
} else {
OP_REQUIRES_OK(
ctx, ToTFStatus(sp->processor.ParallelEncode(
input_values_flat(i), chunk_size, num_threads,
&tokens[i])));
}
} else {
const float alpha = alpha_tensor->dims() == 1
? alpha_tensor->vec<float>()(i)
Expand Down Expand Up @@ -379,6 +395,8 @@ class SentencepieceTokenizeOp : public OpKernel {
}

bool return_nbest_{false};
int32_t num_threads_{1};
int32_t chunk_size_{0};
};

REGISTER_KERNEL_BUILDER(Name("SentencepieceTokenizeOp")
Expand Down
2 changes: 2 additions & 0 deletions tensorflow_text/core/ops/sentencepiece_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ REGISTER_OP("SentencepieceTokenizeOp")
.Input("add_bos: bool")
.Input("add_eos: bool")
.Input("reverse: bool")
.Attr("num_threads: int = 1")
.Attr("chunk_size: int = 0")
.Attr("out_type: {int32, string} = DT_INT32")
.Attr("Tsplits: {int32, int64} = DT_INT64")
.Attr("return_nbest: bool = false")
Expand Down
11 changes: 10 additions & 1 deletion tensorflow_text/python/ops/sentencepiece_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ def __init__(self,
reverse=False,
add_bos=False,
add_eos=False,
num_threads=1,
chunk_size=0,
return_nbest=False,
name=None):
"""Creates & initializes a Sentencepiece processor.
Expand All @@ -101,6 +103,10 @@ def __init__(self,
add_eos: Add end of sentence token to the result (Default = false). When
`reverse=True` beginning/end of sentence tokens are added after
reversing.
num_threads: If `> 1`, the input is split up into chunks of size
`chunk_size` and tokenized in parallel with this many threads.
chunk_size: Only used if `num_threads > 1`. The input is split into
chunks of this size and tokenized in parallel.
return_nbest: If True requires that `nbest_size` is a scalar and `> 1`.
Returns the `nbest_size` best tokenizations for each sentence instead
of a single one. The returned tensor has shape
Expand All @@ -118,6 +124,8 @@ def __init__(self,
self.reverse = reverse
self.add_bos = add_bos
self.add_eos = add_eos
self.num_threads = num_threads
self.chunk_size = chunk_size
self.return_nbest = return_nbest
self._model_resource = _SentencepieceModelResource(model, name)

Expand Down Expand Up @@ -154,7 +162,8 @@ def tokenize(self, input, name=None): # pylint: disable=redefined-builtin
gen_sentencepiece_tokenizer.sentencepiece_tokenize_op(
self._model_resource.resource_handle, input_tensor,
self.nbest_size, self.alpha, self.add_bos, self.add_eos,
self.reverse, self.out_type, return_nbest=self.return_nbest))
self.reverse, self.num_threads, self.chunk_size,
self.out_type, return_nbest=self.return_nbest))
tokens = RaggedTensor.from_nested_row_splits(
flat_values=output_values,
nested_row_splits=[row_splits],
Expand Down