diff --git a/core/trino-main/src/main/java/io/trino/metadata/SystemFunctionBundle.java b/core/trino-main/src/main/java/io/trino/metadata/SystemFunctionBundle.java index 397c0c8bfc251..cec8bd403fee1 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/SystemFunctionBundle.java +++ b/core/trino-main/src/main/java/io/trino/metadata/SystemFunctionBundle.java @@ -167,6 +167,7 @@ import io.trino.operator.scalar.SessionFunctions; import io.trino.operator.scalar.SplitToMapFunction; import io.trino.operator.scalar.SplitToMultimapFunction; +import io.trino.operator.scalar.StreamFunction; import io.trino.operator.scalar.StringFunctions; import io.trino.operator.scalar.TDigestFunctions; import io.trino.operator.scalar.TryFunction; @@ -518,6 +519,7 @@ public static FunctionBundle create(FeaturesConfig featuresConfig, TypeOperators .scalar(ConcatWsFunction.ConcatArrayWs.class) .scalar(DynamicFilters.Function.class) .scalar(DynamicFilters.NullableFunction.class) + .scalar(StreamFunction.class) .functions(ZIP_WITH_FUNCTION, MAP_ZIP_WITH_FUNCTION) .functions(ZIP_FUNCTIONS) .scalars(ArrayJoin.class) diff --git a/core/trino-main/src/main/java/io/trino/metadata/TypeRegistry.java b/core/trino-main/src/main/java/io/trino/metadata/TypeRegistry.java index 8972b8ce4aefd..ce5a33f5bd3ef 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/TypeRegistry.java +++ b/core/trino-main/src/main/java/io/trino/metadata/TypeRegistry.java @@ -100,6 +100,7 @@ import static io.trino.type.LikePatternType.LIKE_PATTERN; import static io.trino.type.MapParametricType.MAP; import static io.trino.type.RowParametricType.ROW; +import static io.trino.type.StreamParametricType.STREAM; import static io.trino.type.TDigestType.TDIGEST; import static io.trino.type.UnknownType.UNKNOWN; import static io.trino.type.setdigest.SetDigestType.SET_DIGEST; @@ -162,6 +163,7 @@ public TypeRegistry(TypeOperators typeOperators, FeaturesConfig featuresConfig) addParametricType(MAP); addParametricType(FUNCTION); 
addParametricType(QDIGEST); + addParametricType(STREAM); addParametricType(TIMESTAMP); addParametricType(TIMESTAMP_WITH_TIME_ZONE); addParametricType(TIME); diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateCountDistinctAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateCountDistinctAggregation.java index 1679e3ece3ea1..d7cc6a54dfadd 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateCountDistinctAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/ApproximateCountDistinctAggregation.java @@ -16,6 +16,7 @@ import com.google.common.annotations.VisibleForTesting; import io.airlift.stats.cardinality.HyperLogLog; import io.trino.operator.aggregation.state.HyperLogLogState; +import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.ValueBlock; import io.trino.spi.function.AggregationFunction; @@ -30,9 +31,12 @@ import io.trino.spi.function.SqlType; import io.trino.spi.function.TypeParameter; import io.trino.spi.type.StandardTypes; +import io.trino.spi.type.Type; +import io.trino.type.StreamType; import java.lang.invoke.MethodHandle; +import static com.google.common.base.Preconditions.checkArgument; import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; @@ -112,6 +116,7 @@ public static void input( @InputFunction @TypeParameter("T") public static void input( + @TypeParameter("T") Type type, @OperatorDependency( operator = XX_HASH_64, argumentTypes = "T", @@ -123,14 +128,22 @@ public static void input( { HyperLogLog hll = getOrCreateHyperLogLog(state, maxStandardError); state.addMemoryUsage(-hll.estimatedInMemorySize()); - long hash; try { - hash = (long) methodHandle.invoke(value); + if 
(type instanceof StreamType streamType) { + checkArgument(value instanceof Block, "value must be a Block"); + for (Block block : streamType.blockValueIterable((Block) value, 0)) { + long hash = (long) methodHandle.invokeExact(block); + hll.addHash(hash); + } + } + else { + long hash = (long) methodHandle.invoke(value); + hll.addHash(hash); + } } catch (Throwable t) { throw internalError(t); } - hll.addHash(hash); state.addMemoryUsage(hll.estimatedInMemorySize()); } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/BigintApproximateMostFrequent.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/BigintApproximateMostFrequent.java index ecd50b8ac4e0b..fffe273c01a4a 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/BigintApproximateMostFrequent.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/BigintApproximateMostFrequent.java @@ -13,6 +13,8 @@ */ package io.trino.operator.aggregation; +import com.google.common.collect.Streams; +import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.MapBlockBuilder; import io.trino.spi.function.AccumulatorState; @@ -27,6 +29,8 @@ import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; import static io.trino.spi.type.StandardTypes.BIGINT; +import static io.trino.type.StreamType.STREAM_BIGINT; +import static io.trino.type.StreamType.STREAM_INTEGER; import static io.trino.util.Failures.checkCondition; import static java.lang.Math.toIntExact; @@ -57,23 +61,44 @@ public interface State void set(ApproximateMostFrequentHistogram value); } - @InputFunction - public static void input(@AggregationState State state, @SqlType(BIGINT) long buckets, @SqlType(BIGINT) long value, @SqlType(BIGINT) long capacity) + private static ApproximateMostFrequentHistogram getHistogram(State state, long buckets, long capacity) { - ApproximateMostFrequentHistogram histogram = state.get(); - if (histogram == null) { - 
checkCondition(buckets >= 2, INVALID_FUNCTION_ARGUMENT, "approx_most_frequent bucket count must be greater than one"); - histogram = new ApproximateMostFrequentHistogram( + checkCondition(buckets >= 2, INVALID_FUNCTION_ARGUMENT, "approx_most_frequent bucket count must be greater than one"); + if (state.get() == null) { + state.set(new ApproximateMostFrequentHistogram( toIntExact(buckets), toIntExact(capacity), LongApproximateMostFrequentStateSerializer::serializeBucket, - LongApproximateMostFrequentStateSerializer::deserializeBucket); - state.set(histogram); + LongApproximateMostFrequentStateSerializer::deserializeBucket)); } + return state.get(); + } + @InputFunction + public static void input(@AggregationState State state, @SqlType(BIGINT) long buckets, @SqlType(BIGINT) long value, @SqlType(BIGINT) long capacity) + { + ApproximateMostFrequentHistogram histogram = getHistogram(state, buckets, capacity); histogram.add(value); } + @InputFunction + public static void inputStreamInteger(@AggregationState State state, @SqlType(BIGINT) long buckets, @SqlType("stream(array(integer))") Block value, @SqlType(BIGINT) long capacity) + { + ApproximateMostFrequentHistogram histogram = getHistogram(state, buckets, capacity); + Streams.stream(STREAM_INTEGER.valueIterable(value)) + .map(Long.class::cast) + .forEach(histogram::add); + } + + @InputFunction + public static void inputStreamLong(@AggregationState State state, @SqlType(BIGINT) long buckets, @SqlType("stream(array(bigint))") Block value, @SqlType(BIGINT) long capacity) + { + ApproximateMostFrequentHistogram histogram = getHistogram(state, buckets, capacity); + Streams.stream(STREAM_BIGINT.valueIterable(value)) + .map(Long.class::cast) + .forEach(histogram::add); + } + @CombineFunction public static void combine(@AggregationState State state, @AggregationState State otherState) { diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/DefaultApproximateCountDistinctAggregation.java 
b/core/trino-main/src/main/java/io/trino/operator/aggregation/DefaultApproximateCountDistinctAggregation.java index 517fa4df07b40..d926558a1f08f 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/DefaultApproximateCountDistinctAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/DefaultApproximateCountDistinctAggregation.java @@ -28,6 +28,7 @@ import io.trino.spi.function.SqlType; import io.trino.spi.function.TypeParameter; import io.trino.spi.type.StandardTypes; +import io.trino.spi.type.Type; import java.lang.invoke.MethodHandle; @@ -82,6 +83,7 @@ public static void input( @InputFunction @TypeParameter("T") public static void input( + @TypeParameter("T") Type type, @OperatorDependency( operator = XX_HASH_64, argumentTypes = "T", @@ -90,7 +92,7 @@ public static void input( @AggregationState HyperLogLogState state, @SqlType("T") Object value) { - ApproximateCountDistinctAggregation.input(methodHandle, state, value, DEFAULT_STANDARD_ERROR); + ApproximateCountDistinctAggregation.input(type, methodHandle, state, value, DEFAULT_STANDARD_ERROR); } @CombineFunction diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/LongSumAggregation.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/LongSumAggregation.java index f2872e76ff5fd..d7cca8d328034 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/LongSumAggregation.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/LongSumAggregation.java @@ -13,8 +13,12 @@ */ package io.trino.operator.aggregation; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Streams; import io.trino.annotation.UsedByGeneratedCode; import io.trino.operator.aggregation.state.NullableLongState; +import io.trino.operator.window.InternalWindowIndex; +import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.function.AggregationFunction; import 
io.trino.spi.function.AggregationState; @@ -22,10 +26,13 @@ import io.trino.spi.function.InputFunction; import io.trino.spi.function.OutputFunction; import io.trino.spi.function.SqlType; +import io.trino.spi.function.TypeParameter; import io.trino.spi.function.WindowAccumulator; import io.trino.spi.function.WindowIndex; import io.trino.spi.type.StandardTypes; +import io.trino.spi.type.Type; import io.trino.type.BigintOperators; +import io.trino.type.StreamType; import static io.trino.spi.type.BigintType.BIGINT; @@ -35,12 +42,28 @@ public final class LongSumAggregation private LongSumAggregation() {} @InputFunction - public static void sum(@AggregationState NullableLongState state, @SqlType(StandardTypes.BIGINT) long value) + @TypeParameter("T") + public static void sum(@AggregationState NullableLongState state, @SqlType("T") long value) { state.setNull(false); state.setValue(BigintOperators.add(state.getValue(), value)); } + @InputFunction + @TypeParameter("T") + public static void sum(@TypeParameter("T") Type type, @AggregationState NullableLongState state, @SqlType("T") Block value) + { + if (type instanceof StreamType streamType) { + Streams.stream(streamType.valueIterable(value)) + .map(Long.class::cast) + .reduce(BigintOperators::add) + .ifPresent(v -> sum(state, v)); + } + else { + throw new UnsupportedOperationException("Unsupported type: " + type); + } + } + @CombineFunction public static void combine(@AggregationState NullableLongState state, @AggregationState NullableLongState otherState) { @@ -85,13 +108,25 @@ public WindowAccumulator copy() return new LongSumWindowAccumulator(count, sum); } + private Iterable getValues(WindowIndex index, int position) + { + if (index instanceof InternalWindowIndex internalWindowIndex && internalWindowIndex.getType(0) instanceof StreamType streamType) { + Block block = index.getSingleValueBlock(0, position); + return streamType.valueIterable(block); + } + + return ImmutableList.of(index.getLong(0, position)); + } + 
@Override public void addInput(WindowIndex index, int startPosition, int endPosition) { for (int i = startPosition; i <= endPosition; i++) { if (!index.isNull(0, i)) { - sum += index.getLong(0, i); - count++; + for (Object value : getValues(index, i)) { + sum += (long) value; + count++; + } } } } @@ -101,8 +136,10 @@ public boolean removeInput(WindowIndex index, int startPosition, int endPosition { for (int i = startPosition; i <= endPosition; i++) { if (!index.isNull(0, i)) { - sum -= index.getLong(0, i); - count--; + for (Object value : getValues(index, i)) { + sum -= (long) value; + count--; + } } } return true; diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxAggregationFunction.java index 6ec2f540c84f9..ac9afa05ef5bc 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/MaxAggregationFunction.java @@ -29,6 +29,8 @@ import io.trino.spi.function.OutputFunction; import io.trino.spi.function.SqlType; import io.trino.spi.function.TypeParameter; +import io.trino.spi.type.Type; +import io.trino.type.StreamType; import java.lang.invoke.MethodHandle; @@ -45,6 +47,7 @@ private MaxAggregationFunction() {} @InputFunction @TypeParameter("T") public static void input( + @TypeParameter("T") Type type, @OperatorDependency( operator = OperatorType.COMPARISON_UNORDERED_FIRST, argumentTypes = {"T", "T"}, @@ -55,8 +58,17 @@ public static void input( @BlockIndex int position) throws Throwable { - if (state.isNull() || ((long) compare.invokeExact(block, position, state)) > 0) { - state.set(block, position); + if (type instanceof StreamType streamType) { + for (ValueBlock sub : streamType.arrayIterable(block, position)) { + if (state.isNull() || ((long) compare.invokeExact(sub, 0, state)) > 0) { + state.set(sub, 0); + } + } + } + else { + if (state.isNull() 
|| ((long) compare.invokeExact(block, position, state)) > 0) { + state.set(block, position); + } } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/MinAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/MinAggregationFunction.java index 8616b7c2116cf..e8f300f92bfaf 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/MinAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/MinAggregationFunction.java @@ -29,6 +29,8 @@ import io.trino.spi.function.OutputFunction; import io.trino.spi.function.SqlType; import io.trino.spi.function.TypeParameter; +import io.trino.spi.type.Type; +import io.trino.type.StreamType; import java.lang.invoke.MethodHandle; @@ -45,6 +47,7 @@ private MinAggregationFunction() {} @InputFunction @TypeParameter("T") public static void input( + @TypeParameter("T") Type type, @OperatorDependency( operator = OperatorType.COMPARISON_UNORDERED_LAST, argumentTypes = {"T", "T"}, @@ -55,8 +58,17 @@ public static void input( @BlockIndex int position) throws Throwable { - if (state.isNull() || ((long) compare.invokeExact(block, position, state)) < 0) { - state.set(block, position); + if (type instanceof StreamType streamType) { + for (ValueBlock sub : streamType.arrayIterable(block, position)) { + if (state.isNull() || ((long) compare.invokeExact(sub, 0, state)) < 0) { + state.set(sub, 0); + } + } + } + else { + if (state.isNull() || ((long) compare.invokeExact(block, position, state)) < 0) { + state.set(block, position); + } } } diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/VarcharApproximateMostFrequent.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/VarcharApproximateMostFrequent.java index b63a79d5ebffc..378b84beb090f 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/VarcharApproximateMostFrequent.java +++ 
b/core/trino-main/src/main/java/io/trino/operator/aggregation/VarcharApproximateMostFrequent.java @@ -13,7 +13,9 @@ */ package io.trino.operator.aggregation; +import com.google.common.collect.Streams; import io.airlift.slice.Slice; +import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.MapBlockBuilder; import io.trino.spi.function.AccumulatorState; @@ -30,6 +32,7 @@ import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; import static io.trino.spi.type.StandardTypes.BIGINT; import static io.trino.spi.type.StandardTypes.VARCHAR; +import static io.trino.type.StreamType.STREAM_VARCHAR; import static io.trino.util.Failures.checkCondition; import static java.lang.Math.toIntExact; @@ -60,8 +63,7 @@ public interface State void set(ApproximateMostFrequentHistogram value); } - @InputFunction - public static void input(@AggregationState State state, @SqlType(BIGINT) long buckets, @SqlType(VARCHAR) Slice value, @SqlType(BIGINT) long capacity) + private static ApproximateMostFrequentHistogram getHistogram(State state, long buckets, long capacity) { ApproximateMostFrequentHistogram histogram = state.get(); if (histogram == null) { @@ -74,9 +76,26 @@ public static void input(@AggregationState State state, @SqlType(BIGINT) long bu state.set(histogram); } + return histogram; + } + + @InputFunction + public static void input(@AggregationState State state, @SqlType(BIGINT) long buckets, @SqlType(VARCHAR) Slice value, @SqlType(BIGINT) long capacity) + { + ApproximateMostFrequentHistogram histogram = getHistogram(state, buckets, capacity); histogram.add(value); } + @InputFunction + public static void inputStream(@AggregationState State state, @SqlType(BIGINT) long buckets, @SqlType("stream(array(varchar))") Block value, @SqlType(BIGINT) long capacity) + { + ApproximateMostFrequentHistogram histogram = getHistogram(state, buckets, capacity); + + Streams.stream(STREAM_VARCHAR.valueIterable(value)) + .map(Slice.class::cast) + 
.forEach(histogram::add); + } + @CombineFunction public static void combine(@AggregationState State state, @AggregationState State otherState) { diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/StreamFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/StreamFunction.java new file mode 100644 index 0000000000000..7b16272807d0c --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/StreamFunction.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.operator.scalar; + +import io.trino.spi.block.Block; +import io.trino.spi.function.Description; +import io.trino.spi.function.ScalarFunction; +import io.trino.spi.function.SqlType; +import io.trino.spi.function.TypeParameter; + +@ScalarFunction("stream") +@Description("Return stream of array") +public final class StreamFunction +{ + @TypeParameter("E") + public StreamFunction() + { + } + + @TypeParameter("E") + @SqlType("stream(E)") + public Block stream(@SqlType("E") Block block) + { + return block; + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/window/InternalWindowIndex.java b/core/trino-main/src/main/java/io/trino/operator/window/InternalWindowIndex.java index 53fe8103e8674..2f0f562b579ac 100644 --- a/core/trino-main/src/main/java/io/trino/operator/window/InternalWindowIndex.java +++ b/core/trino-main/src/main/java/io/trino/operator/window/InternalWindowIndex.java @@ -16,6 +16,7 @@ import io.trino.annotation.UsedByGeneratedCode; import io.trino.spi.block.Block; import io.trino.spi.function.WindowIndex; +import io.trino.spi.type.Type; public interface InternalWindowIndex extends WindowIndex @@ -25,4 +26,6 @@ public interface InternalWindowIndex @UsedByGeneratedCode int getRawBlockPosition(int position); + + Type getType(int channel); } diff --git a/core/trino-main/src/main/java/io/trino/operator/window/MappedWindowIndex.java b/core/trino-main/src/main/java/io/trino/operator/window/MappedWindowIndex.java index afbdd9039a747..e36753d539336 100644 --- a/core/trino-main/src/main/java/io/trino/operator/window/MappedWindowIndex.java +++ b/core/trino-main/src/main/java/io/trino/operator/window/MappedWindowIndex.java @@ -18,6 +18,7 @@ import io.trino.annotation.UsedByGeneratedCode; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.type.Type; import java.util.List; @@ -107,6 +108,12 @@ public int getRawBlockPosition(int position) return delegate.getRawBlockPosition(position); } + 
@Override + public Type getType(int channel) + { + return delegate.getType(toDelegateChannel(channel)); + } + private int toDelegateChannel(int channel) { return channelMap[channel]; diff --git a/core/trino-main/src/main/java/io/trino/operator/window/PagesWindowIndex.java b/core/trino-main/src/main/java/io/trino/operator/window/PagesWindowIndex.java index 96805bbf29f4e..69f757ebad6c9 100644 --- a/core/trino-main/src/main/java/io/trino/operator/window/PagesWindowIndex.java +++ b/core/trino-main/src/main/java/io/trino/operator/window/PagesWindowIndex.java @@ -17,6 +17,7 @@ import io.trino.operator.PagesIndex; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.type.Type; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; @@ -122,4 +123,10 @@ public String toString() .add("size", size) .toString(); } + + @Override + public Type getType(int channel) + { + return pagesIndex.getType(channel); + } } diff --git a/core/trino-main/src/main/java/io/trino/operator/window/pattern/ProjectingPagesWindowIndex.java b/core/trino-main/src/main/java/io/trino/operator/window/pattern/ProjectingPagesWindowIndex.java index e71f70a2415dc..82645a923a46e 100644 --- a/core/trino-main/src/main/java/io/trino/operator/window/pattern/ProjectingPagesWindowIndex.java +++ b/core/trino-main/src/main/java/io/trino/operator/window/pattern/ProjectingPagesWindowIndex.java @@ -277,4 +277,10 @@ public String toString() .add("projected channels", projectedTypes.size()) .toString(); } + + @Override + public Type getType(int channel) + { + return pagesIndex.getType(channel); + } } diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java index c50341811c981..ff2901f1892ca 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java +++ 
b/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java @@ -152,6 +152,7 @@ import io.trino.sql.tree.WindowOperation; import io.trino.type.FunctionType; import io.trino.type.JsonPath2016Type; +import io.trino.type.StreamType; import io.trino.type.TypeCoercion; import io.trino.type.UnknownType; import jakarta.annotation.Nullable; @@ -1422,6 +1423,9 @@ else if (isAggregation) { } Type type = signature.getReturnType(); + if (isAggregation && type instanceof StreamType streamType) { + type = streamType.getArrayType(); + } return setExpressionType(node, type); } diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java index 969e8d1a8487d..0d7fcea3d4364 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java @@ -271,6 +271,7 @@ import io.trino.sql.tree.With; import io.trino.sql.tree.WithQuery; import io.trino.transaction.TransactionManager; +import io.trino.type.StreamType; import io.trino.type.TypeCoercion; import java.math.RoundingMode; @@ -4866,6 +4867,9 @@ private void analyzeSelectSingleColumn( type, expression); } + if (type instanceof StreamType) { + throw semanticException(TYPE_MISMATCH, node.getSelect(), "Stream type is not allowed in output"); + } } private void analyzeWhere(Node node, Scope scope, Expression predicate) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java index f5a84ed009edd..b8d6096416cf2 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java @@ -275,6 +275,7 @@ import io.trino.sql.relational.SqlToRowExpressionTranslator; import io.trino.type.BlockTypeOperators; import 
io.trino.type.FunctionType; +import io.trino.type.StreamType; import org.objectweb.asm.MethodTooLargeException; import java.util.AbstractMap.SimpleEntry; @@ -3916,6 +3917,9 @@ private AggregatorFactory buildAggregatorFactory( .collect(toImmutableList()); Type intermediateType = (intermediateTypes.size() == 1) ? getOnlyElement(intermediateTypes) : RowType.anonymous(intermediateTypes); Type finalType = resolvedFunction.signature().getReturnType(); + if (finalType instanceof StreamType streamType) { + finalType = streamType.getArrayType(); + } OptionalInt maskChannel = aggregation.getMask().stream() .mapToInt(value -> source.getLayout().get(value)) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/TypeValidator.java b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/TypeValidator.java index 83fbcb2c64721..1b016d9e4acd7 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/TypeValidator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/TypeValidator.java @@ -17,6 +17,7 @@ import io.trino.Session; import io.trino.execution.warnings.WarningCollector; import io.trino.spi.function.BoundSignature; +import io.trino.spi.type.ArrayType; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; import io.trino.sql.PlannerContext; @@ -31,6 +32,7 @@ import io.trino.sql.planner.plan.UnionNode; import io.trino.sql.planner.plan.WindowNode; import io.trino.type.FunctionType; +import io.trino.type.StreamType; import io.trino.type.UnknownType; import java.util.List; @@ -177,6 +179,9 @@ private void verifyTypeSignature(Symbol symbol, Type expected, Type actual) checkArgument(expectedFieldType.equals(actualFieldTypes), "type of symbol '%s' is expected to be %s, but the actual type is %s", symbol.name(), expected, actual); } + else if (actual instanceof StreamType streamType && expected instanceof ArrayType arrayType) { + checkArgument(arrayType.equals(streamType.getArrayType()), "type of symbol '%s' is expected to 
be %s, but the actual type is %s", symbol.name(), expected, actual); + } else if (!(actual instanceof UnknownType)) { // UNKNOWN should be considered as a wildcard type, which matches all the other types checkArgument(expected.equals(actual), "type of symbol '%s' is expected to be %s, but the actual type is %s", symbol.name(), expected, actual); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/VerifyOnlyOneOutputNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/VerifyOnlyOneOutputNode.java index 988ba9a20fb37..26caab7b7bc64 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/VerifyOnlyOneOutputNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/VerifyOnlyOneOutputNode.java @@ -16,8 +16,12 @@ import io.trino.Session; import io.trino.execution.warnings.WarningCollector; import io.trino.sql.PlannerContext; +import io.trino.sql.planner.Symbol; import io.trino.sql.planner.plan.OutputNode; import io.trino.sql.planner.plan.PlanNode; +import io.trino.type.StreamType; + +import java.util.List; import static com.google.common.base.Preconditions.checkState; import static io.trino.sql.planner.optimizations.PlanNodeSearcher.searchFrom; @@ -31,10 +35,14 @@ public void validate(PlanNode plan, PlannerContext plannerContext, WarningCollector warningCollector) { - int outputPlanNodesCount = searchFrom(plan) + List outputNodes = searchFrom(plan) .where(OutputNode.class::isInstance) - .findAll() - .size(); - checkState(outputPlanNodesCount == 1, "Expected plan to have single instance of OutputNode"); + .findAll(); + checkState(outputNodes.size() == 1, "Expected plan to have single instance of OutputNode"); + + boolean containsStream = outputNodes.getFirst().getOutputSymbols().stream() + .map(Symbol::type) + .anyMatch(StreamType.class::isInstance); + checkState(!containsStream, "OutputNode should not contain StreamType"); } } diff --git 
a/core/trino-main/src/main/java/io/trino/type/StreamParametricType.java b/core/trino-main/src/main/java/io/trino/type/StreamParametricType.java new file mode 100644 index 0000000000000..24f46169583f7 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/type/StreamParametricType.java @@ -0,0 +1,51 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.type; + +import io.trino.spi.type.ParameterKind; +import io.trino.spi.type.ParametricType; +import io.trino.spi.type.Type; +import io.trino.spi.type.TypeManager; +import io.trino.spi.type.TypeParameter; + +import java.util.List; + +import static com.google.common.base.Preconditions.checkArgument; + +public final class StreamParametricType + implements ParametricType +{ + public static final StreamParametricType STREAM = new StreamParametricType(); + + private StreamParametricType() + { + } + + @Override + public String getName() + { + return "STREAM"; + } + + @Override + public Type createType(TypeManager typeManager, List parameters) + { + checkArgument(parameters.size() == 1, "Array type expects exactly one type as a parameter, got %s", parameters); + checkArgument( + parameters.getFirst().getKind() == ParameterKind.TYPE, + "Array expects type as a parameter, got %s", + parameters); + return new StreamType(parameters.getFirst().getType()); + } +} diff --git a/core/trino-main/src/main/java/io/trino/type/StreamType.java b/core/trino-main/src/main/java/io/trino/type/StreamType.java new 
file mode 100644
index 0000000000000..b76fba967024d
--- /dev/null
+++ b/core/trino-main/src/main/java/io/trino/type/StreamType.java
@@ -0,0 +1,282 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.type;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Streams;
+import io.trino.spi.block.ArrayBlock;
+import io.trino.spi.block.Block;
+import io.trino.spi.block.BlockBuilder;
+import io.trino.spi.block.BlockBuilderStatus;
+import io.trino.spi.block.DictionaryBlock;
+import io.trino.spi.block.IntArrayBlock;
+import io.trino.spi.block.LongArrayBlock;
+import io.trino.spi.block.RunLengthEncodedBlock;
+import io.trino.spi.block.ValueBlock;
+import io.trino.spi.block.VariableWidthBlock;
+import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.type.AbstractType;
+import io.trino.spi.type.ArrayType;
+import io.trino.spi.type.Type;
+import io.trino.spi.type.TypeOperatorDeclaration;
+import io.trino.spi.type.TypeOperators;
+import io.trino.spi.type.TypeSignature;
+import io.trino.spi.type.TypeSignatureParameter;
+
+import java.util.List;
+import java.util.Optional;
+import java.util.stream.IntStream;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static io.trino.spi.type.BigintType.BIGINT;
+import static io.trino.spi.type.IntegerType.INTEGER;
+import static io.trino.spi.type.VarcharType.VARCHAR;
+import static java.util.Collections.singletonList;
+
+/**
+ * Type {@code stream(T)}, where {@code T} is an array type or another stream type.
+ * <p>
+ * Values are stored exactly like values of the innermost {@link ArrayType}: operators, block
+ * building and flat-layout handling are delegated to it. The iteration helpers expose the
+ * individual elements held by each array value.
+ */
+public class StreamType
+        extends AbstractType
+{
+    public static final StreamType STREAM_INTEGER = new StreamType(new ArrayType(INTEGER));
+    public static final StreamType STREAM_BIGINT = new StreamType(new ArrayType(BIGINT));
+    public static final StreamType STREAM_VARCHAR = new StreamType(new ArrayType(VARCHAR));
+
+    private static final String STREAM = "stream";
+    private volatile TypeOperatorDeclaration operatorDeclaration;
+
+    private final Type elementType;
+    private final ArrayType arrayType;
+
+    /**
+     * @param elementType the wrapped type; must be an {@link ArrayType} or a {@link StreamType}
+     * @throws IllegalArgumentException if {@code elementType} is neither an array nor a stream type
+     */
+    public StreamType(Type elementType)
+    {
+        super(new TypeSignature(STREAM, TypeSignatureParameter.typeParameter(elementType.getTypeSignature())), Block.class, ArrayBlock.class);
+        checkArgument(elementType instanceof ArrayType || elementType instanceof StreamType, "elementType must be an array type or a stream type");
+        this.elementType = elementType;
+        this.arrayType = underlyingArrayType(elementType);
+    }
+
+    // Unwraps nested stream types until the innermost array type is reached.
+    private static ArrayType underlyingArrayType(Type type)
+    {
+        Type arrayType = type;
+        while (arrayType instanceof StreamType streamType) {
+            arrayType = streamType.elementType;
+        }
+        return (ArrayType) arrayType;
+    }
+
+    /**
+     * Returns the innermost array type that storage and operators are delegated to.
+     */
+    public ArrayType getArrayType()
+    {
+        return arrayType;
+    }
+
+    @Override
+    public List<Type> getTypeParameters()
+    {
+        return singletonList(elementType);
+    }
+
+    @Override
+    public String getDisplayName()
+    {
+        return STREAM + "(" + elementType.getDisplayName() + ")";
+    }
+
+    @Override
+    public TypeOperatorDeclaration getTypeOperatorDeclaration(TypeOperators typeOperators)
+    {
+        // Read the volatile field once; racing threads may each compute the declaration, the last write wins.
+        TypeOperatorDeclaration declaration = operatorDeclaration;
+        if (declaration == null) {
+            declaration = arrayType.getTypeOperatorDeclaration(typeOperators);
+            operatorDeclaration = declaration;
+        }
+        return declaration;
+    }
+
+    @Override
+    public Object getObjectValue(ConnectorSession session, Block block, int position)
+    {
+        return arrayType.getObjectValue(session, block, position);
+    }
+
+    @Override
+    public Object getObject(Block block, int position)
+    {
+        return arrayType.getObject(block, position);
+    }
+
+    @Override
+    public void writeObject(BlockBuilder blockBuilder, Object value)
+    {
+        arrayType.writeObject(blockBuilder, value);
+    }
+
+    @Override
+    public void appendTo(Block block, int position, BlockBuilder blockBuilder)
+    {
+        arrayType.appendTo(block, position, blockBuilder);
+    }
+
+    @Override
+    public int getFlatFixedSize()
+    {
+        return 8;
+    }
+
+    @Override
+    public boolean isFlatVariableWidth()
+    {
+        return true;
+    }
+
+    @Override
+    public int getFlatVariableWidthSize(Block block, int position)
+    {
+        return arrayType.getFlatVariableWidthSize(block, position);
+    }
+
+    @Override
+    public int relocateFlatVariableWidthOffsets(byte[] fixedSizeSlice, int fixedSizeOffset, byte[] variableSizeSlice, int variableSizeOffset)
+    {
+        return arrayType.relocateFlatVariableWidthOffsets(fixedSizeSlice, fixedSizeOffset, variableSizeSlice, variableSizeOffset);
+    }
+
+    @Override
+    public boolean isComparable()
+    {
+        return elementType.isComparable();
+    }
+
+    @Override
+    public boolean isOrderable()
+    {
+        return elementType.isOrderable();
+    }
+
+    @Override
+    public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries, int expectedBytesPerEntry)
+    {
+        return arrayType.createBlockBuilder(blockBuilderStatus, expectedEntries, expectedBytesPerEntry);
+    }
+
+    @Override
+    public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, int expectedEntries)
+    {
+        return createBlockBuilder(blockBuilderStatus, expectedEntries, 100);
+    }
+
+    /**
+     * Returns a lazy view over the values of the value block resolved from {@code block}
+     * (for an {@link ArrayBlock}, the elements of the array at position 0).
+     * <p>
+     * Only int, long and variable-width value blocks are supported; any other block type
+     * fails with {@link UnsupportedOperationException} during iteration.
+     */
+    public Iterable<Object> valueIterable(Block block)
+    {
+        ValueBlock valueBlock = getValueBlock(block, 0);
+        // mapToObj avoids boxing every index into an Integer before mapping it
+        return () -> IntStream.range(0, valueBlock.getPositionCount())
+                .mapToObj(i -> getValueObject(valueBlock, i))
+                .iterator();
+    }
+
+    // Reads one value as a Java object; int values are widened to long.
+    private static Object getValueObject(ValueBlock valueBlock, int position)
+    {
+        if (valueBlock instanceof IntArrayBlock intArrayBlock) {
+            return (long) intArrayBlock.getInt(position);
+        }
+        if (valueBlock instanceof LongArrayBlock longArrayBlock) {
+            return longArrayBlock.getLong(position);
+        }
+        if (valueBlock instanceof VariableWidthBlock variableWidthBlock) {
+            return variableWidthBlock.getSlice(position);
+        }
+        throw new UnsupportedOperationException("unsupported block type: " + valueBlock.getClass().getName());
+    }
+
+    // Resolves the value block to iterate over: the elements of the array at the given position
+    // for an ArrayBlock, otherwise the value block backing the given block.
+    // NOTE(review): position is only used for ArrayBlock input; for a dictionary block the whole
+    // dictionary is returned regardless of ids - confirm this is intended for encoded blocks.
+    private static ValueBlock getValueBlock(Block block, int position)
+    {
+        if (block instanceof ArrayBlock arrayBlock) {
+            return arrayBlock.getUnderlyingValueBlock().getArray(arrayBlock.getUnderlyingValuePosition(position)).getUnderlyingValueBlock();
+        }
+        if (block instanceof ValueBlock valueBlock) {
+            return valueBlock.getUnderlyingValueBlock();
+        }
+        if (block instanceof RunLengthEncodedBlock rleBlock) {
+            return rleBlock.getValue();
+        }
+        if (block instanceof DictionaryBlock dictionaryBlock) {
+            return dictionaryBlock.getDictionary();
+        }
+
+        throw new UnsupportedOperationException("unsupported block type: " + block.getClass().getName());
+    }
+
+    // Expands the entry at the given position of streamBlock: it is expanded through the nested
+    // stream type when the element type is a stream, otherwise returned as a single-value block.
+    private Iterable<Block> streamValueBlocks(ValueBlock streamBlock, int position)
+    {
+        ValueBlock block = streamBlock.getSingleValueBlock(position);
+        if (elementType instanceof StreamType subStreamType) {
+            // The single-value block has exactly one entry, so it is addressed at index 0 rather
+            // than at the entry's index within streamBlock (which is out of range past the first entry).
+            return subStreamType.arrayIterable(getValueBlock(block, 0), 0);
+        }
+        return ImmutableList.of(block);
+    }
+
+    /**
+     * Returns a lazy view with the blocks produced for each entry of the value block resolved
+     * from {@code block} at {@code position}; entries of a nested stream type are expanded through it.
+     */
+    public Iterable<Block> blockValueIterable(Block block, int position)
+    {
+        ValueBlock streamBlock = getValueBlock(block, position);
+
+        return () -> IntStream.range(0, streamBlock.getPositionCount())
+                .mapToObj(i -> streamValueBlocks(streamBlock, i))
+                .flatMap(blocks -> Streams.stream(blocks))
+                .iterator();
+    }
+
+    /**
+     * Same as {@link #blockValueIterable}, with every produced block wrapped into a one-entry array block.
+     */
+    public Iterable<Block> arrayIterable(Block block, int position)
+    {
+        return Iterables.transform(blockValueIterable(block, position),
+                b -> ArrayBlock.fromElementBlock(1, Optional.of(new boolean[] {false}), new int[] {0, 1}, b));
+    }
+}
diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java
index 202d33f9bf74e..fc85c2fcaea41
100644
--- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java
+++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestAccumulatorCompiler.java
@@ -35,6 +35,7 @@
 import io.trino.spi.type.LongTimestamp;
 import io.trino.spi.type.RealType;
 import io.trino.spi.type.TimestampType;
+import io.trino.spi.type.Type;
 import io.trino.sql.gen.IsolatedClass;
 import org.junit.jupiter.api.Test;
 
@@ -221,5 +222,12 @@ public int getRawBlockPosition(int position)
         {
             return 0;
         }
+
+        // Test stub: reports BIGINT for every channel.
+        @Override
+        public Type getType(int channel)
+        {
+            return BIGINT;
+        }
     }
 }
diff --git a/testing/trino-tests/src/test/java/io/trino/tests/TestAggregations.java b/testing/trino-tests/src/test/java/io/trino/tests/TestAggregations.java
index 416cf18909239..7d0e301803f87 100644
--- a/testing/trino-tests/src/test/java/io/trino/tests/TestAggregations.java
+++ b/testing/trino-tests/src/test/java/io/trino/tests/TestAggregations.java
@@ -163,6 +163,101 @@ public void testPreAggregateWithFilter()
                 plan -> assertAggregationNodeCount(plan, 4));
     }
 
+    // max/min over stream(a) are expected to return a one-element array holding the extreme element across all rows; the first query is a plain-array baseline.
+    @Test
+    public void testStreamMinMax()
+    {
+        assertQuery("SELECT max(a) FROM (VALUES ARRAY[1], ARRAY[2], ARRAY[3]) t(a)", "VALUES ARRAY[3]");
+        assertQuery("SELECT max(stream(a)) FROM (VALUES ARRAY[1], ARRAY[2, 4], ARRAY[3, 5, 0]) t(a)", "VALUES ARRAY[5]");
+        assertQuery("SELECT max(stream(a)) FROM (VALUES ARRAY['a'], ARRAY['b'], ARRAY['c']) t(a)", "VALUES ARRAY['c']");
+
+        assertQuery("SELECT min(stream(a)) FROM (VALUES ARRAY[1], ARRAY[2], ARRAY[3]) t(a)", "VALUES ARRAY[1]");
+        assertQuery("SELECT min(stream(a)) FROM (VALUES ARRAY[1], ARRAY[2, 4], ARRAY[3, 5, 0]) t(a)", "VALUES ARRAY[0]");
+
+        assertQuery("SELECT max(stream(repeat(nationkey, 3))) FROM nation", "SELECT max(nationkey) FROM nation");
+    }
+
+    // Without stream() the three arrays count as 3 distinct values; with stream() the 6 elements are counted.
+    @Test
+    public void testStreamApproxDistinct()
+    {
+        assertQuery("SELECT approx_distinct(a) FROM (VALUES ARRAY[1, 2, 3], ARRAY[4, 5], ARRAY[6]) t(a)", "VALUES 3");
+        assertQuery("SELECT approx_distinct(stream(a)) FROM (VALUES ARRAY[1, 2, 3], ARRAY[4, 5], ARRAY[6]) t(a)", "VALUES 6");
+    }
+
+    // The first query is a plain window SUM baseline (no stream); the second sums all array elements across rows.
+    @Test
+    public void testSumStream()
+    {
+        assertQuery("SELECT SUM(a) OVER (PARTITION BY b) FROM (VALUES (1, 1), (2, 1), (3, 2), (4, 2)) t(a, b)", "VALUES (3), (3), (7), (7)");
+        assertQuery("SELECT SUM(stream(a)) FROM (VALUES ARRAY[1, 2, 3], ARRAY[4, 5], ARRAY[6]) t(a)", "VALUES 21");
+    }
+
+    // SUM(stream(...)) as a window function: element sums per partition, including over transform() and filter() results.
+    @Test
+    public void testStreamSumWindow()
+    {
+        assertQuery("""
+                SELECT SUM(stream(a)) OVER (PARTITION BY b)
+                FROM (VALUES (ARRAY[CAST(1 AS BIGINT), CAST(2 AS BIGINT), CAST(3 AS BIGINT)], 1), (ARRAY[CAST(4 AS BIGINT), CAST(5 AS BIGINT)], 1), (ARRAY[CAST(6 AS BIGINT)], 2)) t(a, b)""",
+                "VALUES (15), (15), (6)");
+
+        assertQuery("""
+                SELECT SUM(stream(transform(a, x -> x + 1))) OVER (PARTITION BY b)
+                FROM (VALUES (ARRAY[1, 2, 3], 1), (ARRAY[4, 5], 1), (ARRAY[6], 2)) t(a, b)""",
+                "VALUES (20), (20), (7)");
+
+        assertQuery("""
+                SELECT SUM(stream(filter(a, x -> x >= 5))) OVER (PARTITION BY b)
+                FROM (VALUES (ARRAY[1, 2, 3], 1), (ARRAY[4, 5], 1), (ARRAY[6], 2)) t(a, b)""",
+                "VALUES (5), (5), (6)");
+    }
+
+    // Streams built from map_keys()/map_values(): 6 distinct keys, values summing to 22, and approx_most_frequent over the keys.
+    @Test
+    public void testStreamMapKey()
+    {
+        assertQuery("""
+                SELECT
+                    approx_distinct(stream(map_keys(a)))
+                    , SUM(stream(map_values(a)))
+                FROM (VALUES
+                    MAP_FROM_ENTRIES(ARRAY[ROW('a', 1)]),
+                    MAP_FROM_ENTRIES(ARRAY[ROW('b', 2), ROW('x', 21)]),
+                    MAP_FROM_ENTRIES(ARRAY[ROW('c', 3), ROW('d', 4), ROW('0', -9)])
+                ) t(a)""",
+                "VALUES (6, 22)");
+
+        assertQuery("""
+                SELECT
+                    CAST(approx_most_frequent(3, stream(map_keys(a)), 10) AS JSON)
+                FROM (VALUES
+                    MAP_FROM_ENTRIES(ARRAY[ROW(CAST('a' AS VARCHAR), 1)]),
+                    MAP_FROM_ENTRIES(ARRAY[ROW(CAST('b' AS VARCHAR), 2), ROW(CAST('x' AS VARCHAR), 21)]),
+                    MAP_FROM_ENTRIES(ARRAY[ROW(CAST('c' AS VARCHAR), 3), ROW(CAST('d' AS VARCHAR), 4), ROW(CAST('0' AS VARCHAR), -9)])
+                ) t(a)""",
+                "VALUES '{\"a\":1,\"b\":1,\"x\":1}'");
+    }
+
+    // max over streams of nested arrays, including a stream wrapped in another stream.
+    @Test
+    public void testNestedArrayStream()
+    {
+        assertQuery("SELECT CAST(max(stream(a)) AS JSON) FROM (VALUES ARRAY[ARRAY[1, 2]], ARRAY[ARRAY[2, 3]], ARRAY[ARRAY[3, 4]]) t(a)", "VALUES '[[3, 4]]'");
+        assertQuery("SELECT CAST(max(stream(ARRAY[ARRAY[nationkey, nationkey + 1]])) AS JSON) FROM nation", "VALUES '[[24, 25]]'");
+
+        assertQuery("SELECT CAST(max(stream(stream(a))) AS JSON) FROM (VALUES ARRAY[ARRAY[1]], ARRAY[ARRAY[2]], ARRAY[ARRAY[3]]) t(a)", "VALUES '[[3]]'");
+    }
+
+    // approx_most_frequent over stream(a) is expected to match the result over the same values supplied as individual rows.
+    @Test
+    public void testApproxMostFrequent()
+    {
+        assertQuery("SELECT CAST(approx_most_frequent(2, a, 2) AS JSON) FROM (VALUES (1), (2), (3), (4), (5), (6), (6)) t(a)", "VALUES '{\"5\":3,\"6\":4}'");
+        assertQuery("SELECT CAST(approx_most_frequent(2, stream(a), 2) AS JSON) FROM (VALUES ARRAY[1, 2, 3], ARRAY[4, 5, 6], ARRAY[6]) t(a)", "VALUES '{\"5\":3,\"6\":4}'");
+        assertQuery("SELECT CAST(approx_most_frequent(2, stream(a), 2) AS JSON) FROM (VALUES ARRAY[CAST(1 AS BIGINT), CAST(2 AS BIGINT), CAST(3 AS BIGINT)], ARRAY[CAST(4 AS BIGINT), CAST(5 AS BIGINT), CAST(6 AS BIGINT)], ARRAY[CAST(6 AS BIGINT)]) t(a)", "VALUES '{\"5\":3,\"6\":4}'");
+    }
+
     private void assertAggregationNodeCount(Plan plan, int count)
     {
         assertThat(countOfMatchingNodes(plan, AggregationNode.class::isInstance)).isEqualTo(count);