From b97aadb925ac787275abda7efbaf2a3d79b2753d Mon Sep 17 00:00:00 2001 From: Anh Dung Bui Date: Fri, 8 Nov 2024 17:31:28 +0900 Subject: [PATCH 1/2] Move vector search from IndexInput to RandomAccessInput (#13938) --- .../store/EndiannessReverserIndexInput.java | 25 ++++++++++++++-- .../lucene/codecs/lucene95/HasIndexSlice.java | 8 ++--- .../lucene95/OffHeapByteVectorValues.java | 25 +++++++++------- .../lucene95/OffHeapFloatVectorValues.java | 24 ++++++++------- .../Lucene99ScalarQuantizedVectorScorer.java | 4 +-- .../OffHeapQuantizedByteVectorValues.java | 29 +++++++++---------- .../lucene/store/BufferedIndexInput.java | 6 ++++ .../lucene/store/ByteBuffersDataInput.java | 6 ++++ .../lucene/store/ByteBuffersIndexInput.java | 6 ++++ .../org/apache/lucene/store/IndexInput.java | 7 +++++ .../lucene/store/RandomAccessInput.java | 12 +++++++- .../QuantizedByteVectorValues.java | 4 +-- ...stLucene99ScalarQuantizedVectorScorer.java | 9 ++++-- .../codecs/quantization/SampleReader.java | 4 +-- 14 files changed, 119 insertions(+), 50 deletions(-) diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/store/EndiannessReverserIndexInput.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/store/EndiannessReverserIndexInput.java index 5ec1402efc7e..c4f5e4c3f78b 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/store/EndiannessReverserIndexInput.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/store/EndiannessReverserIndexInput.java @@ -63,8 +63,7 @@ public void readInts(int[] dst, int offset, int length) throws IOException { public void readFloats(float[] dst, int offset, int length) throws IOException { in.readFloats(dst, offset, length); for (int i = 0; i < length; ++i) { - dst[offset + i] = - Float.intBitsToFloat(Integer.reverseBytes(Float.floatToRawIntBits(dst[offset + i]))); + dst[offset + i] = revertFloat(dst[offset + i]); } } @@ -106,6 +105,14 @@ public byte readByte(long 
pos) throws IOException { return in.readByte(pos); } + @Override + public void readFloats(long pos, float[] floats, int offset, int length) throws IOException { + in.readFloats(pos, floats, offset, length); + for (int i = 0; i < length; ++i) { + floats[offset + i] = revertFloat(floats[offset + i]); + } + } + @Override public short readShort(long pos) throws IOException { return Short.reverseBytes(in.readShort(pos)); @@ -120,5 +127,19 @@ public int readInt(long pos) throws IOException { public long readLong(long pos) throws IOException { return Long.reverseBytes(in.readLong(pos)); } + + @Override + public Object clone() { + try { + return super.clone(); + } catch (CloneNotSupportedException e) { + throw new Error( + "This cannot happen: Failing to clone EndiannessReverserRandomAccessInput", e); + } + } + } + + private static float revertFloat(float value) { + return Float.intBitsToFloat(Integer.reverseBytes(Float.floatToRawIntBits(value))); } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/HasIndexSlice.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/HasIndexSlice.java index 2bfe72386a05..15efd9da40c9 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/HasIndexSlice.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/HasIndexSlice.java @@ -16,14 +16,14 @@ */ package org.apache.lucene.codecs.lucene95; -import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.RandomAccessInput; /** - * Implementors can return the IndexInput from which their values are read. For use by vector + * Implementors can return the RandomAccessInput from which their values are read. For use by vector * quantizers. */ public interface HasIndexSlice { - /** Returns an IndexInput from which to read this instance's values. */ - IndexInput getSlice(); + /** Returns a RandomAccessInput from which to read this instance's values. 
*/ + RandomAccessInput getSlice(); } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java index 1e78c8ea7aa2..2037f314c24c 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java @@ -37,7 +37,7 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues implement protected final int dimension; protected final int size; - protected final IndexInput slice; + protected final RandomAccessInput slice; protected int lastOrd = -1; protected final byte[] binaryValue; protected final ByteBuffer byteBuffer; @@ -48,7 +48,7 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues implement OffHeapByteVectorValues( int dimension, int size, - IndexInput slice, + RandomAccessInput slice, int byteSize, FlatVectorsScorer flatVectorsScorer, VectorSimilarityFunction similarityFunction) { @@ -82,13 +82,13 @@ public byte[] vectorValue(int targetOrd) throws IOException { } @Override - public IndexInput getSlice() { + public RandomAccessInput getSlice() { return slice; } private void readValue(int targetOrd) throws IOException { - slice.seek((long) targetOrd * byteSize); - slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize); + slice.readBytes( + (long) targetOrd * byteSize, byteBuffer.array(), byteBuffer.arrayOffset(), byteSize); } public static OffHeapByteVectorValues load( @@ -104,7 +104,7 @@ public static OffHeapByteVectorValues load( if (configuration.isEmpty() || vectorEncoding != VectorEncoding.BYTE) { return new EmptyOffHeapVectorValues(dimension, flatVectorsScorer, vectorSimilarityFunction); } - IndexInput bytesSlice = vectorData.slice("vector-data", vectorDataOffset, vectorDataLength); + RandomAccessInput bytesSlice = vectorData.randomAccessSlice(vectorDataOffset, 
vectorDataLength); if (configuration.isDense()) { return new DenseOffHeapVectorValues( dimension, @@ -133,7 +133,7 @@ public static class DenseOffHeapVectorValues extends OffHeapByteVectorValues { public DenseOffHeapVectorValues( int dimension, int size, - IndexInput slice, + RandomAccessInput slice, int byteSize, FlatVectorsScorer flatVectorsScorer, VectorSimilarityFunction vectorSimilarityFunction) { @@ -143,7 +143,12 @@ public DenseOffHeapVectorValues( @Override public DenseOffHeapVectorValues copy() throws IOException { return new DenseOffHeapVectorValues( - dimension, size, slice.clone(), byteSize, flatVectorsScorer, similarityFunction); + dimension, + size, + (RandomAccessInput) slice.clone(), + byteSize, + flatVectorsScorer, + similarityFunction); } @Override @@ -186,7 +191,7 @@ private static class SparseOffHeapVectorValues extends OffHeapByteVectorValues { public SparseOffHeapVectorValues( OrdToDocDISIReaderConfiguration configuration, IndexInput dataIn, - IndexInput slice, + RandomAccessInput slice, int dimension, int byteSize, FlatVectorsScorer flatVectorsScorer, @@ -220,7 +225,7 @@ public SparseOffHeapVectorValues copy() throws IOException { return new SparseOffHeapVectorValues( configuration, dataIn, - slice.clone(), + (RandomAccessInput) slice.clone(), dimension, byteSize, flatVectorsScorer, diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java index 2384657e93e1..2a2b3b3b0892 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java @@ -36,7 +36,7 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues impleme protected final int dimension; protected final int size; - protected final IndexInput slice; + protected final RandomAccessInput slice; protected final int byteSize; 
protected int lastOrd = -1; protected final float[] value; @@ -46,7 +46,7 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues impleme OffHeapFloatVectorValues( int dimension, int size, - IndexInput slice, + RandomAccessInput slice, int byteSize, FlatVectorsScorer flatVectorsScorer, VectorSimilarityFunction similarityFunction) { @@ -70,7 +70,7 @@ public int size() { } @Override - public IndexInput getSlice() { + public RandomAccessInput getSlice() { return slice; } @@ -79,8 +79,7 @@ public float[] vectorValue(int targetOrd) throws IOException { if (lastOrd == targetOrd) { return value; } - slice.seek((long) targetOrd * byteSize); - slice.readFloats(value, 0, value.length); + slice.readFloats((long) targetOrd * byteSize, value, 0, value.length); lastOrd = targetOrd; return value; } @@ -98,7 +97,7 @@ public static OffHeapFloatVectorValues load( if (configuration.docsWithFieldOffset == -2 || vectorEncoding != VectorEncoding.FLOAT32) { return new EmptyOffHeapVectorValues(dimension, flatVectorsScorer, vectorSimilarityFunction); } - IndexInput bytesSlice = vectorData.slice("vector-data", vectorDataOffset, vectorDataLength); + RandomAccessInput bytesSlice = vectorData.randomAccessSlice(vectorDataOffset, vectorDataLength); int byteSize = dimension * Float.BYTES; if (configuration.docsWithFieldOffset == -1) { return new DenseOffHeapVectorValues( @@ -129,7 +128,7 @@ public static class DenseOffHeapVectorValues extends OffHeapFloatVectorValues { public DenseOffHeapVectorValues( int dimension, int size, - IndexInput slice, + RandomAccessInput slice, int byteSize, FlatVectorsScorer flatVectorsScorer, VectorSimilarityFunction similarityFunction) { @@ -139,7 +138,12 @@ public DenseOffHeapVectorValues( @Override public DenseOffHeapVectorValues copy() throws IOException { return new DenseOffHeapVectorValues( - dimension, size, slice.clone(), byteSize, flatVectorsScorer, similarityFunction); + dimension, + size, + (RandomAccessInput) slice.clone(), + byteSize, 
+ flatVectorsScorer, + similarityFunction); } @Override @@ -187,7 +191,7 @@ private static class SparseOffHeapVectorValues extends OffHeapFloatVectorValues public SparseOffHeapVectorValues( OrdToDocDISIReaderConfiguration configuration, IndexInput dataIn, - IndexInput slice, + RandomAccessInput slice, int dimension, int byteSize, FlatVectorsScorer flatVectorsScorer, @@ -215,7 +219,7 @@ public SparseOffHeapVectorValues copy() throws IOException { return new SparseOffHeapVectorValues( configuration, dataIn, - slice.clone(), + (RandomAccessInput) slice.clone(), dimension, byteSize, flatVectorsScorer, diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java index a4770f01f46d..1962dccd15b8 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java @@ -213,8 +213,8 @@ private CompressedInt4DotProduct( public float score(int vectorOrdinal) throws IOException { // get compressed vector, in Lucene99, vector values are stored and have a single value for // offset correction - values.getSlice().seek((long) vectorOrdinal * (values.getVectorByteLength() + Float.BYTES)); - values.getSlice().readBytes(compressedVector, 0, compressedVector.length); + long pos = (long) vectorOrdinal * (values.getVectorByteLength() + Float.BYTES); + values.getSlice().readBytes(pos, compressedVector, 0, compressedVector.length); float vectorOffset = values.getScoreCorrectionConstant(vectorOrdinal); int dotProduct = VectorUtil.int4DotProductPacked(targetBytes, compressedVector); // For the current implementation of scalar quantization, all dotproducts should be >= 0; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java 
b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java index 051c926a679e..72a88fce92aa 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java @@ -26,6 +26,7 @@ import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.RandomAccessInput; import org.apache.lucene.util.Bits; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.packed.DirectMonotonicReader; @@ -46,7 +47,7 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect protected final FlatVectorsScorer vectorsScorer; protected final boolean compress; - protected final IndexInput slice; + protected final RandomAccessInput slice; protected final byte[] binaryValue; protected final ByteBuffer byteBuffer; protected final int byteSize; @@ -93,7 +94,7 @@ static void compressBytes(byte[] raw, byte[] compressed) { VectorSimilarityFunction similarityFunction, FlatVectorsScorer vectorsScorer, boolean compress, - IndexInput slice) { + RandomAccessInput slice) { this.dimension = dimension; this.size = size; this.slice = slice; @@ -131,9 +132,9 @@ public byte[] vectorValue(int targetOrd) throws IOException { if (lastOrd == targetOrd) { return binaryValue; } - slice.seek((long) targetOrd * byteSize); - slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), numBytes); - slice.readFloats(scoreCorrectionConstant, 0, 1); + long pos = (long) targetOrd * byteSize; + slice.readBytes(pos, byteBuffer.array(), byteBuffer.arrayOffset(), numBytes); + slice.readFloats(pos + numBytes, scoreCorrectionConstant, 0, 1); decompressBytes(binaryValue, numBytes); lastOrd = targetOrd; return binaryValue; @@ -144,13 +145,12 @@ public float getScoreCorrectionConstant(int targetOrd) 
throws IOException { if (lastOrd == targetOrd) { return scoreCorrectionConstant[0]; } - slice.seek(((long) targetOrd * byteSize) + numBytes); - slice.readFloats(scoreCorrectionConstant, 0, 1); + slice.readFloats(((long) targetOrd * byteSize) + numBytes, scoreCorrectionConstant, 0, 1); return scoreCorrectionConstant[0]; } @Override - public IndexInput getSlice() { + public RandomAccessInput getSlice() { return slice; } @@ -174,9 +174,8 @@ public static OffHeapQuantizedByteVectorValues load( if (configuration.isEmpty()) { return new EmptyOffHeapVectorValues(dimension, similarityFunction, vectorsScorer); } - IndexInput bytesSlice = - vectorData.slice( - "quantized-vector-data", quantizedVectorDataOffset, quantizedVectorDataLength); + RandomAccessInput bytesSlice = + vectorData.randomAccessSlice(quantizedVectorDataOffset, quantizedVectorDataLength); if (configuration.isDense()) { return new DenseOffHeapVectorValues( dimension, @@ -213,7 +212,7 @@ public DenseOffHeapVectorValues( boolean compress, VectorSimilarityFunction similarityFunction, FlatVectorsScorer vectorsScorer, - IndexInput slice) { + RandomAccessInput slice) { super(dimension, size, scalarQuantizer, similarityFunction, vectorsScorer, compress, slice); } @@ -226,7 +225,7 @@ public DenseOffHeapVectorValues copy() throws IOException { compress, similarityFunction, vectorsScorer, - slice.clone()); + (RandomAccessInput) slice.clone()); } @Override @@ -275,7 +274,7 @@ public SparseOffHeapVectorValues( IndexInput dataIn, VectorSimilarityFunction similarityFunction, FlatVectorsScorer vectorsScorer, - IndexInput slice) + RandomAccessInput slice) throws IOException { super(dimension, size, scalarQuantizer, similarityFunction, vectorsScorer, compress, slice); this.configuration = configuration; @@ -300,7 +299,7 @@ public SparseOffHeapVectorValues copy() throws IOException { dataIn, similarityFunction, vectorsScorer, - slice.clone()); + (RandomAccessInput) slice.clone()); } @Override diff --git 
a/lucene/core/src/java/org/apache/lucene/store/BufferedIndexInput.java b/lucene/core/src/java/org/apache/lucene/store/BufferedIndexInput.java index 7f2aadf54a5b..62fc380bc172 100644 --- a/lucene/core/src/java/org/apache/lucene/store/BufferedIndexInput.java +++ b/lucene/core/src/java/org/apache/lucene/store/BufferedIndexInput.java @@ -258,6 +258,12 @@ public final byte readByte(long pos) throws IOException { return buffer.get((int) index); } + @Override + public void readFloats(long pos, float[] dst, int offset, int len) throws IOException { + seek(pos); + readFloats(dst, offset, len); + } + @Override public void readBytes(long pos, byte[] bytes, int offset, int len) throws IOException { if (len <= bufferSize) { diff --git a/lucene/core/src/java/org/apache/lucene/store/ByteBuffersDataInput.java b/lucene/core/src/java/org/apache/lucene/store/ByteBuffersDataInput.java index a09f78e5f3a6..c0a7d0d6f105 100644 --- a/lucene/core/src/java/org/apache/lucene/store/ByteBuffersDataInput.java +++ b/lucene/core/src/java/org/apache/lucene/store/ByteBuffersDataInput.java @@ -261,6 +261,12 @@ public void readBytes(long pos, byte[] bytes, int offset, int len) throws IOExce } } + @Override + public void readFloats(long pos, float[] floats, int offset, int length) throws IOException { + seek(pos); + readFloats(floats, offset, length); + } + @Override public short readShort(long pos) { long absPos = offset + pos; diff --git a/lucene/core/src/java/org/apache/lucene/store/ByteBuffersIndexInput.java b/lucene/core/src/java/org/apache/lucene/store/ByteBuffersIndexInput.java index 6aebb771b686..48dcd9b58826 100644 --- a/lucene/core/src/java/org/apache/lucene/store/ByteBuffersIndexInput.java +++ b/lucene/core/src/java/org/apache/lucene/store/ByteBuffersIndexInput.java @@ -175,6 +175,12 @@ public void readBytes(long pos, byte[] bytes, int offset, int length) throws IOE in.readBytes(pos, bytes, offset, length); } + @Override + public void readFloats(long pos, float[] floats, int offset, int 
length) throws IOException { + ensureOpen(); + in.readFloats(pos, floats, offset, length); + } + @Override public short readShort(long pos) throws IOException { ensureOpen(); diff --git a/lucene/core/src/java/org/apache/lucene/store/IndexInput.java b/lucene/core/src/java/org/apache/lucene/store/IndexInput.java index 38eb1dcbceeb..5eaeeac62130 100644 --- a/lucene/core/src/java/org/apache/lucene/store/IndexInput.java +++ b/lucene/core/src/java/org/apache/lucene/store/IndexInput.java @@ -184,6 +184,13 @@ public void readBytes(long pos, byte[] bytes, int offset, int length) throws IOE slice.readBytes(bytes, offset, length); } + @Override + public void readFloats(long pos, float[] floats, int offset, int length) + throws IOException { + slice.seek(pos); + slice.readFloats(floats, offset, length); + } + @Override public short readShort(long pos) throws IOException { slice.seek(pos); diff --git a/lucene/core/src/java/org/apache/lucene/store/RandomAccessInput.java b/lucene/core/src/java/org/apache/lucene/store/RandomAccessInput.java index 08b2e83d36d4..f62a2aba13fb 100644 --- a/lucene/core/src/java/org/apache/lucene/store/RandomAccessInput.java +++ b/lucene/core/src/java/org/apache/lucene/store/RandomAccessInput.java @@ -23,7 +23,7 @@ * Random Access Index API. Unlike {@link IndexInput}, this has no concept of file position, all * reads are absolute. However, like IndexInput, it is only intended for use by a single thread. */ -public interface RandomAccessInput { +public interface RandomAccessInput extends Cloneable { /** The number of bytes in the file. */ long length(); @@ -47,6 +47,14 @@ default void readBytes(long pos, byte[] bytes, int offset, int length) throws IO } } + /** + * Reads a specified number of floats starting at a given position into an array at the specified + * offset. 
+ * + * @see DataInput#readFloats + */ + void readFloats(long pos, float[] floats, int offset, int length) throws IOException; + /** * Reads a short (LE byte order) at the given position in the file * @@ -77,4 +85,6 @@ default void readBytes(long pos, byte[] bytes, int offset, int length) throws IO * @see IndexInput#prefetch */ default void prefetch(long offset, long length) throws IOException {} + + Object clone(); } diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/QuantizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/util/quantization/QuantizedByteVectorValues.java index b90ab8276dd1..5adec747c4fd 100644 --- a/lucene/core/src/java/org/apache/lucene/util/quantization/QuantizedByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/util/quantization/QuantizedByteVectorValues.java @@ -20,7 +20,7 @@ import org.apache.lucene.codecs.lucene95.HasIndexSlice; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.search.VectorScorer; -import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.RandomAccessInput; /** * A version of {@link ByteVectorValues}, but additionally retrieving score correction offset for @@ -52,7 +52,7 @@ public QuantizedByteVectorValues copy() throws IOException { } @Override - public IndexInput getSlice() { + public RandomAccessInput getSlice() { return null; } } diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java index 3b758de6ce67..df814d18e732 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java @@ -43,6 +43,7 @@ import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IndexInput; import 
org.apache.lucene.store.IndexOutput; +import org.apache.lucene.store.RandomAccessInput; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.RandomVectorScorer; @@ -133,8 +134,12 @@ public QuantizedByteVectorValues copy() throws IOException { } @Override - public IndexInput getSlice() { - return in; + public RandomAccessInput getSlice() { + try { + return in.randomAccessSlice(0, in.length()); + } catch (IOException e) { + throw new RuntimeException(e); + } } @Override diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/SampleReader.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/SampleReader.java index 684c9fac838f..6ae042e640a6 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/SampleReader.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/SampleReader.java @@ -22,7 +22,7 @@ import java.util.function.IntUnaryOperator; import org.apache.lucene.codecs.lucene95.HasIndexSlice; import org.apache.lucene.index.FloatVectorValues; -import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.RandomAccessInput; import org.apache.lucene.util.Bits; /** A reader of vector values that samples a subset of the vectors. 
*/ @@ -53,7 +53,7 @@ public FloatVectorValues copy() throws IOException { } @Override - public IndexInput getSlice() { + public RandomAccessInput getSlice() { return ((HasIndexSlice) origin).getSlice(); } From 972dbfb17440f8f6a0708b2353913b7c822e94e9 Mon Sep 17 00:00:00 2001 From: Anh Dung Bui Date: Fri, 8 Nov 2024 17:31:28 +0900 Subject: [PATCH 2/2] Fix build error --- .../lucene95/Lucene95HnswVectorsWriter.java | 4 ++-- .../benchmark/jmh/VectorScorerBenchmark.java | 2 +- .../lucene99/Lucene99FlatVectorsWriter.java | 4 ++-- .../Lucene99ScalarQuantizedVectorsWriter.java | 2 +- .../org/apache/lucene/store/IndexInput.java | 17 +++++++++++++++++ .../Lucene99MemorySegmentByteVectorScorer.java | 9 ++++++++- ...99MemorySegmentByteVectorScorerSupplier.java | 6 +++++- .../lucene/store/MemorySegmentIndexInput.java | 6 ++++++ .../codecs/hnsw/TestFlatVectorScorer.java | 9 ++------- .../vectorization/TestVectorScorer.java | 4 ++-- 10 files changed, 46 insertions(+), 17 deletions(-) diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java index c855d8f5e073..67ebed95bcdf 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java @@ -450,7 +450,7 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE new OffHeapByteVectorValues.DenseOffHeapVectorValues( fieldInfo.getVectorDimension(), docsWithField.cardinality(), - vectorDataInput, + vectorDataInput.toRandomAccessInput(), byteSize, defaultFlatVectorScorer, fieldInfo.getVectorSimilarityFunction())); @@ -462,7 +462,7 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE new OffHeapFloatVectorValues.DenseOffHeapVectorValues( 
fieldInfo.getVectorDimension(), docsWithField.cardinality(), - vectorDataInput, + vectorDataInput.toRandomAccessInput(), byteSize, defaultFlatVectorScorer, fieldInfo.getVectorSimilarityFunction())); diff --git a/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorScorerBenchmark.java b/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorScorerBenchmark.java index 0a4da1f48867..512908345c4e 100644 --- a/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorScorerBenchmark.java +++ b/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorScorerBenchmark.java @@ -98,7 +98,7 @@ public float binaryDotProductMemSeg() throws IOException { static KnnVectorValues vectorValues( int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException { return new OffHeapByteVectorValues.DenseOffHeapVectorValues( - dims, size, in.slice("test", 0, in.length()), dims, new ThrowingFlatVectorScorer(), sim); + dims, size, in.toRandomAccessInput(), dims, new ThrowingFlatVectorScorer(), sim); } static final class ThrowingFlatVectorScorer implements FlatVectorsScorer { diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java index b731e758b7a8..9a323bc7ba6f 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java @@ -303,7 +303,7 @@ public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex( new OffHeapByteVectorValues.DenseOffHeapVectorValues( fieldInfo.getVectorDimension(), docsWithField.cardinality(), - finalVectorDataInput, + finalVectorDataInput.toRandomAccessInput(), fieldInfo.getVectorDimension() * Byte.BYTES, vectorsScorer, fieldInfo.getVectorSimilarityFunction())); @@ -313,7 +313,7 @@ public CloseableRandomVectorScorerSupplier 
mergeOneFieldToIndex( new OffHeapFloatVectorValues.DenseOffHeapVectorValues( fieldInfo.getVectorDimension(), docsWithField.cardinality(), - finalVectorDataInput, + finalVectorDataInput.toRandomAccessInput(), fieldInfo.getVectorDimension() * Float.BYTES, vectorsScorer, fieldInfo.getVectorSimilarityFunction())); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java index 1a30b5271cd7..583cafcd8b21 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java @@ -524,7 +524,7 @@ private ScalarQuantizedCloseableRandomVectorScorerSupplier mergeOneFieldToIndex( compress, fieldInfo.getVectorSimilarityFunction(), vectorsScorer, - quantizationDataInput))); + quantizationDataInput.toRandomAccessInput()))); } finally { if (success == false) { IOUtils.closeWhileHandlingException(tempQuantizedVectorData, quantizationDataInput); diff --git a/lucene/core/src/java/org/apache/lucene/store/IndexInput.java b/lucene/core/src/java/org/apache/lucene/store/IndexInput.java index 5eaeeac62130..26947542b60d 100644 --- a/lucene/core/src/java/org/apache/lucene/store/IndexInput.java +++ b/lucene/core/src/java/org/apache/lucene/store/IndexInput.java @@ -152,6 +152,14 @@ protected String getFullSliceDescription(String sliceDescription) { } } + /** Convert this IndexInput to a RandomAccessInput. */ + public RandomAccessInput toRandomAccessInput() throws IOException { + if (this instanceof RandomAccessInput) { + return (RandomAccessInput) this; + } + return randomAccessSlice(0, length()); + } + /** * Creates a random-access slice of this index input, with the given offset and length. 
* @@ -214,6 +222,15 @@ public void prefetch(long offset, long length) throws IOException { slice.prefetch(offset, length); } + @Override + public Object clone() { + try { + return super.clone(); + } catch (CloneNotSupportedException e) { + throw new Error("This cannot happen: Failing to clone RandomAccessInput", e); + } + } + @Override public String toString() { return "RandomAccessInput(" + IndexInput.this.toString() + ")"; diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorer.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorer.java index b65f1e570921..9a8008cd86b9 100644 --- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorer.java +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorer.java @@ -25,6 +25,7 @@ import org.apache.lucene.store.FilterIndexInput; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.MemorySegmentAccessInput; +import org.apache.lucene.store.RandomAccessInput; import org.apache.lucene.util.hnsw.RandomVectorScorer; abstract sealed class Lucene99MemorySegmentByteVectorScorer @@ -40,8 +41,14 @@ abstract sealed class Lucene99MemorySegmentByteVectorScorer * returned. 
*/ public static Optional create( - VectorSimilarityFunction type, IndexInput input, KnnVectorValues values, byte[] queryVector) { + VectorSimilarityFunction type, + RandomAccessInput slice, + KnnVectorValues values, + byte[] queryVector) { assert values instanceof ByteVectorValues; + if (!(slice instanceof IndexInput input)) { + return Optional.empty(); + } input = FilterIndexInput.unwrapOnlyTest(input); if (!(input instanceof MemorySegmentAccessInput msInput)) { return Optional.empty(); diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorerSupplier.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorerSupplier.java index 02c71561122d..22012d89ac14 100644 --- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorerSupplier.java +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorerSupplier.java @@ -25,6 +25,7 @@ import org.apache.lucene.store.FilterIndexInput; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.MemorySegmentAccessInput; +import org.apache.lucene.store.RandomAccessInput; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; @@ -42,8 +43,11 @@ public abstract sealed class Lucene99MemorySegmentByteVectorScorerSupplier * optional is returned. 
*/ static Optional create( - VectorSimilarityFunction type, IndexInput input, KnnVectorValues values) { + VectorSimilarityFunction type, RandomAccessInput slice, KnnVectorValues values) { assert values instanceof ByteVectorValues; + if (!(slice instanceof IndexInput input)) { + return Optional.empty(); + } input = FilterIndexInput.unwrapOnlyTest(input); if (!(input instanceof MemorySegmentAccessInput msInput)) { return Optional.empty(); diff --git a/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentIndexInput.java b/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentIndexInput.java index 832fa5f98e6b..a1e605a51df3 100644 --- a/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentIndexInput.java +++ b/lucene/core/src/java21/org/apache/lucene/store/MemorySegmentIndexInput.java @@ -232,6 +232,12 @@ public void readLongs(long[] dst, int offset, int length) throws IOException { } } + @Override + public void readFloats(long pos, float[] dst, int offset, int len) throws IOException { + seek(pos); + readFloats(dst, offset, len); + } + @Override public void readFloats(float[] dst, int offset, int length) throws IOException { try { diff --git a/lucene/core/src/test/org/apache/lucene/codecs/hnsw/TestFlatVectorScorer.java b/lucene/core/src/test/org/apache/lucene/codecs/hnsw/TestFlatVectorScorer.java index 6fe9a685e1b4..4e951ad7539a 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/hnsw/TestFlatVectorScorer.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/hnsw/TestFlatVectorScorer.java @@ -178,18 +178,13 @@ public void testCheckFloatDimensions() throws IOException { ByteVectorValues byteVectorValues(int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException { return new OffHeapByteVectorValues.DenseOffHeapVectorValues( - dims, size, in.slice("byteValues", 0, in.length()), dims, flatVectorsScorer, sim); + dims, size, in.toRandomAccessInput(), dims, flatVectorsScorer, sim); } FloatVectorValues floatVectorValues( 
int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException { return new OffHeapFloatVectorValues.DenseOffHeapVectorValues( - dims, - size, - in.slice("floatValues", 0, in.length()), - dims * Float.BYTES, - flatVectorsScorer, - sim); + dims, size, in.toRandomAccessInput(), dims * Float.BYTES, flatVectorsScorer, sim); } /** Concatenates float arrays as byte[]. */ diff --git a/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorScorer.java b/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorScorer.java index bc3b6813a5be..122c71afbe54 100644 --- a/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorScorer.java +++ b/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorScorer.java @@ -381,13 +381,13 @@ public void testWithFloatValues() throws IOException { KnnVectorValues vectorValues(int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException { return new OffHeapByteVectorValues.DenseOffHeapVectorValues( - dims, size, in.slice("byteValues", 0, in.length()), dims, MEMSEG_SCORER, sim); + dims, size, in.toRandomAccessInput(), dims, MEMSEG_SCORER, sim); } KnnVectorValues floatVectorValues(int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException { return new OffHeapFloatVectorValues.DenseOffHeapVectorValues( - dims, size, in.slice("floatValues", 0, in.length()), dims, MEMSEG_SCORER, sim); + dims, size, in.toRandomAccessInput(), dims, MEMSEG_SCORER, sim); } // creates the vector based on the given ordinal, which is reproducible given the ord and dims