Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stream function/type to process array elements one by one #24148

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@
import io.trino.operator.scalar.SessionFunctions;
import io.trino.operator.scalar.SplitToMapFunction;
import io.trino.operator.scalar.SplitToMultimapFunction;
import io.trino.operator.scalar.StreamFunction;
import io.trino.operator.scalar.StringFunctions;
import io.trino.operator.scalar.TDigestFunctions;
import io.trino.operator.scalar.TryFunction;
Expand Down Expand Up @@ -518,6 +519,7 @@ public static FunctionBundle create(FeaturesConfig featuresConfig, TypeOperators
.scalar(ConcatWsFunction.ConcatArrayWs.class)
.scalar(DynamicFilters.Function.class)
.scalar(DynamicFilters.NullableFunction.class)
.scalar(StreamFunction.class)
.functions(ZIP_WITH_FUNCTION, MAP_ZIP_WITH_FUNCTION)
.functions(ZIP_FUNCTIONS)
.scalars(ArrayJoin.class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
import static io.trino.type.LikePatternType.LIKE_PATTERN;
import static io.trino.type.MapParametricType.MAP;
import static io.trino.type.RowParametricType.ROW;
import static io.trino.type.StreamParametricType.STREAM;
import static io.trino.type.TDigestType.TDIGEST;
import static io.trino.type.UnknownType.UNKNOWN;
import static io.trino.type.setdigest.SetDigestType.SET_DIGEST;
Expand Down Expand Up @@ -162,6 +163,7 @@ public TypeRegistry(TypeOperators typeOperators, FeaturesConfig featuresConfig)
addParametricType(MAP);
addParametricType(FUNCTION);
addParametricType(QDIGEST);
addParametricType(STREAM);
addParametricType(TIMESTAMP);
addParametricType(TIMESTAMP_WITH_TIME_ZONE);
addParametricType(TIME);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.google.common.annotations.VisibleForTesting;
import io.airlift.stats.cardinality.HyperLogLog;
import io.trino.operator.aggregation.state.HyperLogLogState;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.ValueBlock;
import io.trino.spi.function.AggregationFunction;
Expand All @@ -30,9 +31,12 @@
import io.trino.spi.function.SqlType;
import io.trino.spi.function.TypeParameter;
import io.trino.spi.type.StandardTypes;
import io.trino.spi.type.Type;
import io.trino.type.StreamType;

import java.lang.invoke.MethodHandle;

import static com.google.common.base.Preconditions.checkArgument;
import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
Expand Down Expand Up @@ -112,6 +116,7 @@ public static void input(
@InputFunction
@TypeParameter("T")
public static void input(
@TypeParameter("T") Type type,
@OperatorDependency(
operator = XX_HASH_64,
argumentTypes = "T",
Expand All @@ -123,14 +128,22 @@ public static void input(
{
HyperLogLog hll = getOrCreateHyperLogLog(state, maxStandardError);
state.addMemoryUsage(-hll.estimatedInMemorySize());
long hash;
try {
hash = (long) methodHandle.invoke(value);
if (type instanceof StreamType streamType) {
checkArgument(value instanceof Block, "value must be a Block");
for (Block block : streamType.blockValueIterable((Block) value, 0)) {
long hash = (long) methodHandle.invokeExact(block);
hll.addHash(hash);
}
}
else {
long hash = (long) methodHandle.invoke(value);
hll.addHash(hash);
}
}
catch (Throwable t) {
throw internalError(t);
}
hll.addHash(hash);
state.addMemoryUsage(hll.estimatedInMemorySize());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
*/
package io.trino.operator.aggregation;

import com.google.common.collect.Streams;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.MapBlockBuilder;
import io.trino.spi.function.AccumulatorState;
Expand All @@ -27,6 +29,8 @@

import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static io.trino.spi.type.StandardTypes.BIGINT;
import static io.trino.type.StreamType.STREAM_BIGINT;
import static io.trino.type.StreamType.STREAM_INTEGER;
import static io.trino.util.Failures.checkCondition;
import static java.lang.Math.toIntExact;

Expand Down Expand Up @@ -57,23 +61,44 @@ public interface State
void set(ApproximateMostFrequentHistogram<Long> value);
}

@InputFunction
public static void input(@AggregationState State state, @SqlType(BIGINT) long buckets, @SqlType(BIGINT) long value, @SqlType(BIGINT) long capacity)
private static ApproximateMostFrequentHistogram<Long> getHistogram(State state, long buckets, long capacity)
{
ApproximateMostFrequentHistogram<Long> histogram = state.get();
if (histogram == null) {
checkCondition(buckets >= 2, INVALID_FUNCTION_ARGUMENT, "approx_most_frequent bucket count must be greater than one");
histogram = new ApproximateMostFrequentHistogram<Long>(
checkCondition(buckets >= 2, INVALID_FUNCTION_ARGUMENT, "approx_most_frequent bucket count must be greater than one");
if (state.get() == null) {
state.set(new ApproximateMostFrequentHistogram<Long>(
toIntExact(buckets),
toIntExact(capacity),
LongApproximateMostFrequentStateSerializer::serializeBucket,
LongApproximateMostFrequentStateSerializer::deserializeBucket);
state.set(histogram);
LongApproximateMostFrequentStateSerializer::deserializeBucket));
}
return state.get();
}

@InputFunction
public static void input(@AggregationState State state, @SqlType(BIGINT) long buckets, @SqlType(BIGINT) long value, @SqlType(BIGINT) long capacity)
{
ApproximateMostFrequentHistogram<Long> histogram = getHistogram(state, buckets, capacity);
histogram.add(value);
}

@InputFunction
public static void inputStreamInteger(@AggregationState State state, @SqlType(BIGINT) long buckets, @SqlType("stream(array(integer))") Block value, @SqlType(BIGINT) long capacity)
{
ApproximateMostFrequentHistogram<Long> histogram = getHistogram(state, buckets, capacity);
Streams.stream(STREAM_INTEGER.valueIterable(value))
.map(Long.class::cast)
.forEach(histogram::add);
}

@InputFunction
public static void inputStreamLong(@AggregationState State state, @SqlType(BIGINT) long buckets, @SqlType("stream(array(bigint))") Block value, @SqlType(BIGINT) long capacity)
{
ApproximateMostFrequentHistogram<Long> histogram = getHistogram(state, buckets, capacity);
Streams.stream(STREAM_BIGINT.valueIterable(value))
.map(Long.class::cast)
.forEach(histogram::add);
}

@CombineFunction
public static void combine(@AggregationState State state, @AggregationState State otherState)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import io.trino.spi.function.SqlType;
import io.trino.spi.function.TypeParameter;
import io.trino.spi.type.StandardTypes;
import io.trino.spi.type.Type;

import java.lang.invoke.MethodHandle;

Expand Down Expand Up @@ -82,6 +83,7 @@ public static void input(
@InputFunction
@TypeParameter("T")
public static void input(
@TypeParameter("T") Type type,
@OperatorDependency(
operator = XX_HASH_64,
argumentTypes = "T",
Expand All @@ -90,7 +92,7 @@ public static void input(
@AggregationState HyperLogLogState state,
@SqlType("T") Object value)
{
ApproximateCountDistinctAggregation.input(methodHandle, state, value, DEFAULT_STANDARD_ERROR);
ApproximateCountDistinctAggregation.input(type, methodHandle, state, value, DEFAULT_STANDARD_ERROR);
}

@CombineFunction
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,26 @@
*/
package io.trino.operator.aggregation;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Streams;
import io.trino.annotation.UsedByGeneratedCode;
import io.trino.operator.aggregation.state.NullableLongState;
import io.trino.operator.window.InternalWindowIndex;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.AggregationFunction;
import io.trino.spi.function.AggregationState;
import io.trino.spi.function.CombineFunction;
import io.trino.spi.function.InputFunction;
import io.trino.spi.function.OutputFunction;
import io.trino.spi.function.SqlType;
import io.trino.spi.function.TypeParameter;
import io.trino.spi.function.WindowAccumulator;
import io.trino.spi.function.WindowIndex;
import io.trino.spi.type.StandardTypes;
import io.trino.spi.type.Type;
import io.trino.type.BigintOperators;
import io.trino.type.StreamType;

import static io.trino.spi.type.BigintType.BIGINT;

Expand All @@ -35,12 +42,28 @@ public final class LongSumAggregation
private LongSumAggregation() {}

@InputFunction
public static void sum(@AggregationState NullableLongState state, @SqlType(StandardTypes.BIGINT) long value)
@TypeParameter("T")
public static void sum(@AggregationState NullableLongState state, @SqlType("T") long value)
{
state.setNull(false);
state.setValue(BigintOperators.add(state.getValue(), value));
}

@InputFunction
@TypeParameter("T")
public static void sum(@TypeParameter("T") Type type, @AggregationState NullableLongState state, @SqlType("T") Block value)
{
if (type instanceof StreamType streamType) {
Streams.stream(streamType.valueIterable(value))
.map(Long.class::cast)
.reduce(BigintOperators::add)
.ifPresent(v -> sum(state, v));
}
else {
throw new UnsupportedOperationException("Unsupported type: " + type);
}
}

@CombineFunction
public static void combine(@AggregationState NullableLongState state, @AggregationState NullableLongState otherState)
{
Expand Down Expand Up @@ -85,13 +108,25 @@ public WindowAccumulator copy()
return new LongSumWindowAccumulator(count, sum);
}

private Iterable<Object> getValues(WindowIndex index, int position)
{
if (index instanceof InternalWindowIndex internalWindowIndex && internalWindowIndex.getType(0) instanceof StreamType streamType) {
Block block = index.getSingleValueBlock(0, position);
return streamType.valueIterable(block);
}

return ImmutableList.of(index.getLong(0, position));
}

@Override
public void addInput(WindowIndex index, int startPosition, int endPosition)
{
for (int i = startPosition; i <= endPosition; i++) {
if (!index.isNull(0, i)) {
sum += index.getLong(0, i);
count++;
for (Object value : getValues(index, i)) {
sum += (long) value;
count++;
}
}
}
}
Expand All @@ -101,8 +136,10 @@ public boolean removeInput(WindowIndex index, int startPosition, int endPosition
{
for (int i = startPosition; i <= endPosition; i++) {
if (!index.isNull(0, i)) {
sum -= index.getLong(0, i);
count--;
for (Object value : getValues(index, i)) {
sum -= (long) value;
count--;
}
}
}
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
import io.trino.spi.function.OutputFunction;
import io.trino.spi.function.SqlType;
import io.trino.spi.function.TypeParameter;
import io.trino.spi.type.Type;
import io.trino.type.StreamType;

import java.lang.invoke.MethodHandle;

Expand All @@ -45,6 +47,7 @@ private MaxAggregationFunction() {}
@InputFunction
@TypeParameter("T")
public static void input(
@TypeParameter("T") Type type,
@OperatorDependency(
operator = OperatorType.COMPARISON_UNORDERED_FIRST,
argumentTypes = {"T", "T"},
Expand All @@ -55,8 +58,17 @@ public static void input(
@BlockIndex int position)
throws Throwable
{
if (state.isNull() || ((long) compare.invokeExact(block, position, state)) > 0) {
state.set(block, position);
if (type instanceof StreamType streamType) {
for (ValueBlock sub : streamType.arrayIterable(block, position)) {
if (state.isNull() || ((long) compare.invokeExact(sub, 0, state)) > 0) {
state.set(sub, 0);
}
}
}
else {
if (state.isNull() || ((long) compare.invokeExact(block, position, state)) > 0) {
state.set(block, position);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
import io.trino.spi.function.OutputFunction;
import io.trino.spi.function.SqlType;
import io.trino.spi.function.TypeParameter;
import io.trino.spi.type.Type;
import io.trino.type.StreamType;

import java.lang.invoke.MethodHandle;

Expand All @@ -45,6 +47,7 @@ private MinAggregationFunction() {}
@InputFunction
@TypeParameter("T")
public static void input(
@TypeParameter("T") Type type,
@OperatorDependency(
operator = OperatorType.COMPARISON_UNORDERED_LAST,
argumentTypes = {"T", "T"},
Expand All @@ -55,8 +58,17 @@ public static void input(
@BlockIndex int position)
throws Throwable
{
if (state.isNull() || ((long) compare.invokeExact(block, position, state)) < 0) {
state.set(block, position);
if (type instanceof StreamType streamType) {
for (ValueBlock sub : streamType.arrayIterable(block, position)) {
if (state.isNull() || ((long) compare.invokeExact(sub, 0, state)) < 0) {
state.set(sub, 0);
}
}
}
else {
if (state.isNull() || ((long) compare.invokeExact(block, position, state)) < 0) {
state.set(block, position);
}
}
}

Expand Down
Loading
Loading