From 1824a173a19eca3b0f09f23fd6de5c49ac0680f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9lder=20Greg=C3=B3rio?= Date: Sat, 10 Jan 2026 23:22:15 +0000 Subject: [PATCH 01/12] Remove grpc-java reflection for ownership transfers --- .../org/apache/arrow/flight/ArrowMessage.java | 130 +++++++++++- .../arrow/flight/grpc/GetReadableBuffer.java | 99 --------- .../flight/TestArrowMessageZeroCopy.java | 196 ++++++++++++++++++ 3 files changed, 321 insertions(+), 104 deletions(-) delete mode 100644 flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/GetReadableBuffer.java create mode 100644 flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageZeroCopy.java diff --git a/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java b/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java index ab4eab3048..2292a256ca 100644 --- a/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java +++ b/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java @@ -23,7 +23,9 @@ import com.google.protobuf.CodedInputStream; import com.google.protobuf.CodedOutputStream; import com.google.protobuf.WireFormat; +import io.grpc.Detachable; import io.grpc.Drainable; +import io.grpc.HasByteBuffer; import io.grpc.MethodDescriptor.Marshaller; import io.grpc.protobuf.ProtoUtils; import io.netty.buffer.ByteBuf; @@ -41,11 +43,12 @@ import java.util.Collections; import java.util.List; import org.apache.arrow.flight.grpc.AddWritableBuffer; -import org.apache.arrow.flight.grpc.GetReadableBuffer; import org.apache.arrow.flight.impl.Flight.FlightData; import org.apache.arrow.flight.impl.Flight.FlightDescriptor; import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.ForeignAllocation; +import org.apache.arrow.memory.util.MemoryUtil; import org.apache.arrow.util.AutoCloseables; import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch; @@ -55,10 +58,14 @@ import org.apache.arrow.vector.ipc.message.MessageSerializer; import org.apache.arrow.vector.types.MetadataVersion; import org.apache.arrow.vector.types.pojo.Schema; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** The in-memory representation of FlightData used to manage a stream of Arrow messages. */ class ArrowMessage implements AutoCloseable { + private static final Logger LOG = LoggerFactory.getLogger(ArrowMessage.class); + // If true, deserialize Arrow data by giving Arrow a reference to the underlying gRPC buffer // instead of copying the data. Defaults to true. public static final boolean ENABLE_ZERO_COPY_READ; @@ -312,8 +319,7 @@ private static ArrowMessage frame(BufferAllocator allocator, final InputStream s case APP_METADATA_TAG: { int size = readRawVarint32(stream); - appMetadata = allocator.buffer(size); - GetReadableBuffer.readIntoBuffer(stream, appMetadata, size, ENABLE_ZERO_COPY_READ); + appMetadata = readBuffer(allocator, stream, size); break; } case BODY_TAG: @@ -323,8 +329,7 @@ private static ArrowMessage frame(BufferAllocator allocator, final InputStream s body = null; } int size = readRawVarint32(stream); - body = allocator.buffer(size); - GetReadableBuffer.readIntoBuffer(stream, body, size, ENABLE_ZERO_COPY_READ); + body = readBuffer(allocator, stream, size); break; default: @@ -377,6 +382,121 @@ private static int readRawVarint32(int firstByte, InputStream is) throws IOExcep return CodedInputStream.readRawVarint32(firstByte, is); } + /** + * Reads data from the stream into an ArrowBuf, without copying data when possible. + * + *

First attempts to transfer ownership of the gRPC buffer to Arrow via {@link + * #wrapGrpcBuffer}. This avoids any memory copy when the gRPC transport provides a direct + * ByteBuffer (e.g., Netty). + * + *

If not possible (e.g., heap buffer, fragmented data, or unsupported transport), falls back + * to allocating a new buffer and copying data into it. + * + * @param allocator The allocator to use for buffer allocation + * @param stream The input stream to read from + * @param size The number of bytes to read + * @return An ArrowBuf containing the data + * @throws IOException if there is an error reading from the stream + */ + private static ArrowBuf readBuffer(BufferAllocator allocator, InputStream stream, int size) + throws IOException { + if (ENABLE_ZERO_COPY_READ) { + ArrowBuf zeroCopyBuf = wrapGrpcBuffer(stream, allocator, size); + if (zeroCopyBuf != null) { + return zeroCopyBuf; + } + } + + // Fall back to allocating and copying + ArrowBuf buf = allocator.buffer(size); + byte[] heapBytes = new byte[size]; + ByteStreams.readFully(stream, heapBytes); + buf.writeBytes(heapBytes); + buf.writerIndex(size); + return buf; + } + + /** + * Attempts to wrap gRPC's buffer as an ArrowBuf without copying. + * + *

This method takes ownership of gRPC's underlying buffer via {@link Detachable#detach()} and + * wraps it as an ArrowBuf using {@link BufferAllocator#wrapForeignAllocation}. The gRPC buffer + * will be released when the ArrowBuf is closed. + * + * @param stream The gRPC-provided InputStream + * @param allocator The allocator to use for wrapping the foreign allocation + * @param size The number of bytes to wrap + * @return An ArrowBuf wrapping gRPC's buffer, or {@code null} if zero-copy is not possible + */ + static ArrowBuf wrapGrpcBuffer( + final InputStream stream, final BufferAllocator allocator, final int size) { + + if (!(stream instanceof Detachable) || !(stream instanceof HasByteBuffer)) { + return null; + } + + HasByteBuffer hasByteBuffer = (HasByteBuffer) stream; + if (!hasByteBuffer.byteBufferSupported()) { + return null; + } + + ByteBuffer peekBuffer = hasByteBuffer.getByteBuffer(); + if (peekBuffer == null) { + return null; + } + if (!peekBuffer.isDirect()) { + return null; + } + if (peekBuffer.remaining() < size) { + // Data is fragmented across multiple buffers; zero-copy not possible + return null; + } + + // Take ownership + Detachable detachable = (Detachable) stream; + InputStream detachedStream = detachable.detach(); + + // Get buffer from detached stream + HasByteBuffer detachedHasByteBuffer = (HasByteBuffer) detachedStream; + ByteBuffer detachedByteBuffer = detachedHasByteBuffer.getByteBuffer(); + + if (detachedByteBuffer == null || !detachedByteBuffer.isDirect()) { + closeQuietly(detachedStream); + return null; + } + + // Calculate memory address accounting for buffer position + long baseAddress = MemoryUtil.getByteBufferAddress(detachedByteBuffer); + long dataAddress = baseAddress + detachedByteBuffer.position(); + + // Create ForeignAllocation with proper cleanup + ForeignAllocation foreignAllocation = + new ForeignAllocation(size, dataAddress) { + @Override + protected void release0() { + closeQuietly(detachedStream); + } + }; + + try { + return allocator.wrapForeignAllocation(foreignAllocation); + } catch (Throwable t) { + // If it fails, clean up the detached stream and propagate + closeQuietly(detachedStream); + throw t; + } + } + + private static void closeQuietly(InputStream stream) { + if (stream != null) { + try { + stream.close(); + } catch (IOException e) { + LOG.debug("Error closing detached gRPC stream", e); + } + } + } + /** * Convert the ArrowMessage to an InputStream. * diff --git a/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/GetReadableBuffer.java b/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/GetReadableBuffer.java deleted file mode 100644 index fcba88d212..0000000000 --- a/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/GetReadableBuffer.java +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.arrow.flight.grpc; - -import com.google.common.base.Throwables; -import com.google.common.io.ByteStreams; -import io.grpc.internal.ReadableBuffer; -import java.io.IOException; -import java.io.InputStream; -import java.lang.reflect.Field; -import org.apache.arrow.memory.ArrowBuf; - -/** - * Enable access to ReadableBuffer directly to copy data from a BufferInputStream into a target - * ByteBuffer/ByteBuf. - * - *

This could be solved by BufferInputStream exposing Drainable. - */ -public class GetReadableBuffer { - - private static final Field READABLE_BUFFER; - private static final Class BUFFER_INPUT_STREAM; - - static { - Field tmpField = null; - Class tmpClazz = null; - try { - Class clazz = Class.forName("io.grpc.internal.ReadableBuffers$BufferInputStream"); - - Field f = clazz.getDeclaredField("buffer"); - f.setAccessible(true); - // don't set until we've gotten past all exception cases. - tmpField = f; - tmpClazz = clazz; - } catch (Exception e) { - new RuntimeException("Failed to initialize GetReadableBuffer, falling back to slow path", e) - .printStackTrace(); - } - READABLE_BUFFER = tmpField; - BUFFER_INPUT_STREAM = tmpClazz; - } - - /** - * Extracts the ReadableBuffer for the given input stream. - * - * @param is Must be an instance of io.grpc.internal.ReadableBuffers$BufferInputStream or null - * will be returned. - */ - public static ReadableBuffer getReadableBuffer(InputStream is) { - - if (BUFFER_INPUT_STREAM == null || !is.getClass().equals(BUFFER_INPUT_STREAM)) { - return null; - } - - try { - return (ReadableBuffer) READABLE_BUFFER.get(is); - } catch (Exception ex) { - throw Throwables.propagate(ex); - } - } - - /** - * Helper method to read a gRPC-provided InputStream into an ArrowBuf. - * - * @param stream The stream to read from. Should be an instance of {@link #BUFFER_INPUT_STREAM}. - * @param buf The buffer to read into. - * @param size The number of bytes to read. - * @param fastPath Whether to enable the fast path (i.e. detect whether the stream is a {@link - * #BUFFER_INPUT_STREAM}). - * @throws IOException if there is an error reading form the stream - */ - public static void readIntoBuffer( - final InputStream stream, final ArrowBuf buf, final int size, final boolean fastPath) - throws IOException { - ReadableBuffer readableBuffer = fastPath ? getReadableBuffer(stream) : null; - byte[] heapBytes = new byte[size]; - if (readableBuffer != null) { - readableBuffer.readBytes(heapBytes, 0, size); - } else { - ByteStreams.readFully(stream, heapBytes); - } - buf.writeBytes(heapBytes); - buf.writerIndex(size); - } -} diff --git a/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageZeroCopy.java b/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageZeroCopy.java new file mode 100644 index 0000000000..099b1cd3e5 --- /dev/null +++ b/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageZeroCopy.java @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.flight; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.grpc.Detachable; +import io.grpc.HasByteBuffer; +import io.grpc.internal.ReadableBuffer; +import io.grpc.internal.ReadableBuffers; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.util.Random; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class TestArrowMessageZeroCopy { + + private BufferAllocator allocator; + + @BeforeEach + public void setUp() { + allocator = new RootAllocator(Long.MAX_VALUE); + } + + @AfterEach + public void tearDown() { + allocator.close(); + } + + private static InputStream createGrpcStreamWithDirectBuffer(byte[] data) { + ByteBuffer directBuffer = ByteBuffer.allocateDirect(data.length); + directBuffer.put(data); + directBuffer.flip(); + ReadableBuffer readableBuffer = ReadableBuffers.wrap(directBuffer); + return ReadableBuffers.openStream(readableBuffer, true); + } + + @Test + public void testWrapGrpcBufferReturnsNullForRegularInputStream() throws IOException { + byte[] testData = new byte[] {1, 2, 3, 4, 5}; + InputStream stream = new ByteArrayInputStream(testData); + + // ByteArrayInputStream doesn't implement Detachable or HasByteBuffer + ArrowBuf result = ArrowMessage.wrapGrpcBuffer(stream, allocator, testData.length); + assertNull(result, "Should return null for streams not implementing required interfaces"); + } + + @Test + public void testWrapGrpcBufferSucceedsForRealGrpcDirectBuffer() throws IOException { + byte[] testData = new byte[] {11, 22, 33, 44, 55}; + InputStream stream = createGrpcStreamWithDirectBuffer(testData); + + assertInstanceOf(Detachable.class, stream, "Real gRPC stream should implement Detachable"); + assertInstanceOf( + HasByteBuffer.class, stream, "Real gRPC stream should implement HasByteBuffer"); + assertTrue( + ((HasByteBuffer) stream).byteBufferSupported(), + "Direct buffer stream should support ByteBuffer"); + assertTrue( + ((HasByteBuffer) stream).getByteBuffer().isDirect(), + "Should have direct ByteBuffer backing"); + + try (ArrowBuf result = ArrowMessage.wrapGrpcBuffer(stream, allocator, testData.length)) { + assertNotNull(result, "Should succeed for real gRPC stream with direct buffer"); + assertEquals(testData.length, result.capacity()); + + // Check received data is the same + byte[] readData = new byte[testData.length]; + result.getBytes(0, readData); + assertArrayEquals(testData, readData); + } + } + + @Test + public void testWrapGrpcBufferReturnsNullForRealGrpcHeapByteBuffer() throws IOException { + byte[] testData = new byte[] {1, 2, 3, 4, 5}; + ByteBuffer heapBuffer = ByteBuffer.wrap(testData); + ReadableBuffer readableBuffer = ReadableBuffers.wrap(heapBuffer); + + InputStream stream = ReadableBuffers.openStream(readableBuffer, true); + + assertInstanceOf(Detachable.class, stream, "Real gRPC stream should implement Detachable"); + assertInstanceOf( + HasByteBuffer.class, stream, "Real gRPC stream should implement HasByteBuffer"); + assertTrue( + ((HasByteBuffer) stream).byteBufferSupported(), + "Heap ByteBuffer stream should support ByteBuffer"); + assertFalse( + ((HasByteBuffer) stream).getByteBuffer().isDirect(), "Should have heap ByteBuffer backing"); + + // Zero-copy should return null for heap buffer (not direct) + ArrowBuf result = ArrowMessage.wrapGrpcBuffer(stream, allocator, testData.length); + assertNull(result, "Should return null for real gRPC stream with heap buffer"); + } + + @Test + public void testWrapGrpcBufferReturnsNullForRealGrpcByteArrayStream() throws IOException { + byte[] testData = new byte[] {1, 2, 3, 4, 5}; + ReadableBuffer readableBuffer = ReadableBuffers.wrap(testData); + InputStream stream = ReadableBuffers.openStream(readableBuffer, true); + + // Verify the stream has the expected gRPC interfaces + assertInstanceOf(Detachable.class, stream, "Real gRPC stream should implement Detachable"); + assertInstanceOf( + HasByteBuffer.class, stream, "Real gRPC stream should implement HasByteBuffer"); + // Byte array backed streams don't support ByteBuffer access + assertFalse( + ((HasByteBuffer) stream).byteBufferSupported(), + "Byte array stream should not support ByteBuffer"); + + // Zero-copy should return null when byteBufferSupported() is false + ArrowBuf result = ArrowMessage.wrapGrpcBuffer(stream, allocator, testData.length); + assertNull(result, "Should return null for real gRPC stream backed by byte array"); + } + + @Test + public void testWrapGrpcBufferMemoryAccountingWithRealGrpcStream() throws IOException { + byte[] testData = new byte[1024]; + new Random(42).nextBytes(testData); + InputStream stream = createGrpcStreamWithDirectBuffer(testData); + + long memoryBefore = allocator.getAllocatedMemory(); + assertEquals(0, memoryBefore); + + ArrowBuf result = ArrowMessage.wrapGrpcBuffer(stream, allocator, testData.length); + assertNotNull(result, "Should succeed for real gRPC stream with direct buffer"); + + long memoryDuring = allocator.getAllocatedMemory(); + assertEquals(testData.length, memoryDuring); + + byte[] readData = new byte[testData.length]; + result.getBytes(0, readData); + assertArrayEquals(testData, readData); + + result.close(); + + long memoryAfter = allocator.getAllocatedMemory(); + assertEquals(0, memoryAfter); + } + + @Test + public void testWrapGrpcBufferReturnsNullForInsufficientDataWithRealGrpcStream() + throws IOException { + byte[] testData = new byte[] {1, 2, 3}; + InputStream stream = createGrpcStreamWithDirectBuffer(testData); + + // Request more data than available + ArrowBuf result = ArrowMessage.wrapGrpcBuffer(stream, allocator, 10); + assertNull(result, "Should return null when buffer has insufficient data"); + } + + @Test + public void testWrapGrpcBufferLargeDataWithRealGrpcStream() throws IOException { + // Test with larger data (64KB) + byte[] testData = new byte[64 * 1024]; + new Random(42).nextBytes(testData); + InputStream stream = createGrpcStreamWithDirectBuffer(testData); + + try (ArrowBuf result = ArrowMessage.wrapGrpcBuffer(stream, allocator, testData.length)) { + assertNotNull(result, "Should succeed for large data with real gRPC stream"); + assertEquals(testData.length, result.capacity()); + + // Verify data integrity + byte[] readData = new byte[testData.length]; + result.getBytes(0, readData); + assertArrayEquals(testData, readData); + } + } +} From efd58b7436e650c63288b95d4c909df528a25fc0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9lder=20Greg=C3=B3rio?= Date: Mon, 12 Jan 2026 14:19:32 +0000 Subject: [PATCH 02/12] remove redundant check --- .../org/apache/arrow/flight/ArrowMessage.java | 23 +++++++------------ 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java b/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java index 2292a256ca..366277f711 100644 --- a/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java +++ b/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java @@ -408,12 +408,12 @@ private static ArrowBuf readBuffer(BufferAllocator allocator, InputStream stream } // Fall back to allocating and copying - ArrowBuf buf = allocator.buffer(size); - byte[] heapBytes = new byte[size]; - ByteStreams.readFully(stream, heapBytes); - buf.writeBytes(heapBytes); - buf.writerIndex(size); - return buf; + ArrowBuf buf = allocator.buffer(size); + byte[] heapBytes = new byte[size]; + ByteStreams.readFully(stream, heapBytes); + buf.writeBytes(heapBytes); + buf.writerIndex(size); + return buf; } /** @@ -453,17 +453,10 @@ static ArrowBuf wrapGrpcBuffer( } // Take ownership - Detachable detachable = (Detachable) stream; - InputStream detachedStream = detachable.detach(); + InputStream detachedStream = ((Detachable) stream).detach(); // Get buffer from detached stream - HasByteBuffer detachedHasByteBuffer = (HasByteBuffer) detachedStream; - ByteBuffer detachedByteBuffer = detachedHasByteBuffer.getByteBuffer(); - - if (detachedByteBuffer == null || !detachedByteBuffer.isDirect()) { - closeQuietly(detachedStream); - return null; - } + ByteBuffer detachedByteBuffer = ((HasByteBuffer) detachedStream).getByteBuffer(); // Calculate memory address accounting for buffer position long baseAddress = MemoryUtil.getByteBufferAddress(detachedByteBuffer); From 19a9c51eadd6528a95a56c377e0e331c5cffa642 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9lder=20Greg=C3=B3rio?= Date: Wed, 14 Jan 2026 22:24:22 +0000 Subject: [PATCH 03/12] set grpc-core as runtime dependency --- flight/flight-core/pom.xml | 1 + flight/flight-core/src/main/java/module-info.java | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/flight/flight-core/pom.xml b/flight/flight-core/pom.xml index 15b870905d..de06e0518e 100644 --- a/flight/flight-core/pom.xml +++ b/flight/flight-core/pom.xml @@ -62,6 +62,7 @@ under the License. io.grpc grpc-core + runtime io.grpc diff --git a/flight/flight-core/src/main/java/module-info.java b/flight/flight-core/src/main/java/module-info.java index 669797ac93..9bafc5fddf 100644 --- a/flight/flight-core/src/main/java/module-info.java +++ b/flight/flight-core/src/main/java/module-info.java @@ -30,7 +30,6 @@ requires com.google.protobuf; requires com.google.protobuf.util; requires io.grpc; - requires io.grpc.internal; requires io.grpc.netty; requires io.grpc.protobuf; requires io.grpc.stub; From f5e43d85ad7364c56cc9f0cd04a559b8dd1dccc1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9lder=20Greg=C3=B3rio?= Date: Thu, 15 Jan 2026 12:28:02 +0000 Subject: [PATCH 04/12] mock ReadableBuffer to not depend on grpc-core --- .../flight/TestArrowMessageZeroCopy.java | 110 +++++++++++------- 1 file changed, 71 insertions(+), 39 deletions(-) diff --git a/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageZeroCopy.java b/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageZeroCopy.java index 099b1cd3e5..7f79868ad7 100644 --- a/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageZeroCopy.java +++ b/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageZeroCopy.java @@ -26,8 +26,6 @@ import io.grpc.Detachable; import io.grpc.HasByteBuffer; -import io.grpc.internal.ReadableBuffer; -import io.grpc.internal.ReadableBuffers; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; @@ -54,14 +52,6 @@ public void tearDown() { allocator.close(); } - private static InputStream createGrpcStreamWithDirectBuffer(byte[] data) { - ByteBuffer directBuffer = ByteBuffer.allocateDirect(data.length); - directBuffer.put(data); - directBuffer.flip(); - ReadableBuffer readableBuffer = ReadableBuffers.wrap(directBuffer); - return ReadableBuffers.openStream(readableBuffer, true); - } - @Test public void testWrapGrpcBufferReturnsNullForRegularInputStream() throws IOException { byte[] testData = new byte[] {1, 2, 3, 4, 5}; @@ -73,9 +63,9 @@ public void testWrapGrpcBufferReturnsNullForRegularInputStream() throws IOExcept } @Test - public void testWrapGrpcBufferSucceedsForRealGrpcDirectBuffer() throws IOException { + public void testWrapGrpcBufferSucceedsForDirectBuffer() throws IOException { byte[] testData = new byte[] {11, 22, 33, 44, 55}; - InputStream stream = createGrpcStreamWithDirectBuffer(testData); + InputStream stream = MockGrpcInputStream.ofDirectBuffer(testData); assertInstanceOf(Detachable.class, stream, "Real gRPC stream should implement Detachable"); assertInstanceOf( @@ -88,7 +78,7 @@ public void testWrapGrpcBufferSucceedsForRealGrpcDirectBuffer() throws IOExcepti "Should have direct ByteBuffer backing"); try (ArrowBuf result = ArrowMessage.wrapGrpcBuffer(stream, allocator, testData.length)) { - assertNotNull(result, "Should succeed for real gRPC stream with direct buffer"); + assertNotNull(result, "Should succeed for gRPC stream with direct buffer"); assertEquals(testData.length, result.capacity()); // Check received data is the same @@ -101,10 +91,7 @@ public void testWrapGrpcBufferSucceedsForRealGrpcDirectBuffer() throws IOExcepti @Test public void testWrapGrpcBufferReturnsNullForRealGrpcHeapByteBuffer() throws IOException { byte[] testData = new byte[] {1, 2, 3, 4, 5}; - ByteBuffer heapBuffer = ByteBuffer.wrap(testData); - ReadableBuffer readableBuffer = ReadableBuffers.wrap(heapBuffer); - - InputStream stream = ReadableBuffers.openStream(readableBuffer, true); + InputStream stream = MockGrpcInputStream.ofHeapBuffer(testData); assertInstanceOf(Detachable.class, stream, "Real gRPC stream should implement Detachable"); assertInstanceOf( @@ -117,14 +104,13 @@ public void testWrapGrpcBufferReturnsNullForRealGrpcHeapByteBuffer() throws IOEx // Zero-copy should return null for heap buffer (not direct) ArrowBuf result = ArrowMessage.wrapGrpcBuffer(stream, allocator, testData.length); - assertNull(result, "Should return null for real gRPC stream with heap buffer"); + assertNull(result, "Should return null for gRPC stream with heap buffer"); } @Test - public void testWrapGrpcBufferReturnsNullForRealGrpcByteArrayStream() throws IOException { + public void testWrapGrpcBufferReturnsNullWhenByteBufferNotSupported() throws IOException { byte[] testData = new byte[] {1, 2, 3, 4, 5}; - ReadableBuffer readableBuffer = ReadableBuffers.wrap(testData); - InputStream stream = ReadableBuffers.openStream(readableBuffer, true); + InputStream stream = MockGrpcInputStream.withoutByteBufferSupport(testData); // Verify the stream has the expected gRPC interfaces assertInstanceOf(Detachable.class, stream, "Real gRPC stream should implement Detachable"); @@ -137,39 +123,33 @@ public void testWrapGrpcBufferReturnsNullForRealGrpcByteArrayStream() throws IOE // Zero-copy should return null when byteBufferSupported() is false ArrowBuf result = ArrowMessage.wrapGrpcBuffer(stream, allocator, testData.length); - assertNull(result, "Should return null for real gRPC stream backed by byte array"); + assertNull(result, "Should return null for gRPC stream without ByteBuffer support"); } @Test - public void testWrapGrpcBufferMemoryAccountingWithRealGrpcStream() throws IOException { + public void testWrapGrpcBufferMemoryAccounting() throws IOException { byte[] testData = new byte[1024]; new Random(42).nextBytes(testData); - InputStream stream = createGrpcStreamWithDirectBuffer(testData); + InputStream stream = MockGrpcInputStream.ofDirectBuffer(testData); - long memoryBefore = allocator.getAllocatedMemory(); - assertEquals(0, memoryBefore); + assertEquals(0, allocator.getAllocatedMemory()); ArrowBuf result = ArrowMessage.wrapGrpcBuffer(stream, allocator, testData.length); - assertNotNull(result, "Should succeed for real gRPC stream with direct buffer"); - - long memoryDuring = allocator.getAllocatedMemory(); - assertEquals(testData.length, memoryDuring); + assertNotNull(result, "Should succeed for gRPC stream with direct buffer"); + assertEquals(testData.length, allocator.getAllocatedMemory()); byte[] readData = new byte[testData.length]; result.getBytes(0, readData); assertArrayEquals(testData, readData); result.close(); - - long memoryAfter = allocator.getAllocatedMemory(); - assertEquals(0, memoryAfter); + assertEquals(0, allocator.getAllocatedMemory()); } @Test - public void testWrapGrpcBufferReturnsNullForInsufficientDataWithRealGrpcStream() - throws IOException { + public void testWrapGrpcBufferReturnsNullForInsufficientData() throws IOException { byte[] testData = new byte[] {1, 2, 3}; - InputStream stream = createGrpcStreamWithDirectBuffer(testData); + InputStream stream = MockGrpcInputStream.ofDirectBuffer(testData); // Request more data than available ArrowBuf result = ArrowMessage.wrapGrpcBuffer(stream, allocator, 10); @@ -177,11 +157,10 @@ public void testWrapGrpcBufferReturnsNullForInsufficientDataWithRealGrpcStream() } @Test - public void testWrapGrpcBufferLargeDataWithRealGrpcStream() throws IOException { - // Test with larger data (64KB) + public void testWrapGrpcBufferLargeData() throws IOException { byte[] testData = new byte[64 * 1024]; new Random(42).nextBytes(testData); - InputStream stream = createGrpcStreamWithDirectBuffer(testData); + InputStream stream = MockGrpcInputStream.ofDirectBuffer(testData); try (ArrowBuf result = ArrowMessage.wrapGrpcBuffer(stream, allocator, testData.length)) { assertNotNull(result, "Should succeed for large data with real gRPC stream"); @@ -193,4 +172,57 @@ public void testWrapGrpcBufferLargeDataWithRealGrpcStream() throws IOException { assertArrayEquals(testData, readData); } } + + /** Mock InputStream implementing gRPC's Detachable and HasByteBuffer for testing zero-copy. */ + private static class MockGrpcInputStream extends InputStream + implements Detachable, HasByteBuffer { + private ByteBuffer buffer; + private final boolean byteBufferSupported; + + private MockGrpcInputStream(ByteBuffer buffer, boolean byteBufferSupported) { + this.buffer = buffer; + this.byteBufferSupported = byteBufferSupported; + } + + static MockGrpcInputStream ofDirectBuffer(byte[] data) { + ByteBuffer buf = ByteBuffer.allocateDirect(data.length); + buf.put(data).flip(); + return new MockGrpcInputStream(buf, true); + } + + static MockGrpcInputStream ofHeapBuffer(byte[] data) { + return new MockGrpcInputStream(ByteBuffer.wrap(data), true); + } + + static MockGrpcInputStream withoutByteBufferSupport(byte[] data) { + return new MockGrpcInputStream(ByteBuffer.wrap(data), false); + } + + @Override + public boolean byteBufferSupported() { + return byteBufferSupported; + } + + @Override + public ByteBuffer getByteBuffer() { + return byteBufferSupported ? buffer : null; + } + + @Override + public InputStream detach() { + ByteBuffer detached = this.buffer; + this.buffer = null; + return new MockGrpcInputStream(detached, byteBufferSupported); + } + + @Override + public int read() { + return (buffer != null && buffer.hasRemaining()) ? (buffer.get() & 0xFF) : -1; + } + + @Override + public void close() { + buffer = null; + } + } } From 1962322ec982720741f93c7691e11d492b500305 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9lder=20Greg=C3=B3rio?= Date: Thu, 15 Jan 2026 23:52:01 +0000 Subject: [PATCH 05/12] rework frame to wrap inputstream with ArrowBuf before reading --- .../org/apache/arrow/flight/ArrowMessage.java | 225 +--------- .../apache/arrow/flight/FlightDataParser.java | 410 ++++++++++++++++++ .../flight/TestArrowMessageZeroCopy.java | 192 ++++---- 3 files changed, 495 insertions(+), 332 deletions(-) create mode 100644 flight/flight-core/src/main/java/org/apache/arrow/flight/FlightDataParser.java diff --git a/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java b/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java index 366277f711..b768559ca7 100644 --- a/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java +++ b/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java @@ -18,14 +18,10 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; -import com.google.common.io.ByteStreams; import com.google.protobuf.ByteString; -import com.google.protobuf.CodedInputStream; import com.google.protobuf.CodedOutputStream; import com.google.protobuf.WireFormat; -import io.grpc.Detachable; import io.grpc.Drainable; -import io.grpc.HasByteBuffer; import io.grpc.MethodDescriptor.Marshaller; import io.grpc.protobuf.ProtoUtils; import io.netty.buffer.ByteBuf; @@ -42,13 +38,14 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import org.apache.arrow.flight.FlightDataParser.ArrowBufReader; +import org.apache.arrow.flight.FlightDataParser.FlightDataReader; +import org.apache.arrow.flight.FlightDataParser.InputStreamReader; import org.apache.arrow.flight.grpc.AddWritableBuffer; import org.apache.arrow.flight.impl.Flight.FlightData; import org.apache.arrow.flight.impl.Flight.FlightDescriptor; import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.memory.ForeignAllocation; -import org.apache.arrow.memory.util.MemoryUtil; import org.apache.arrow.util.AutoCloseables; import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch; @@ -82,19 +79,10 @@ class ArrowMessage implements AutoCloseable { if (zeroCopyWriteFlag == null) { zeroCopyWriteFlag = System.getenv("ARROW_FLIGHT_ENABLE_ZERO_COPY_WRITE"); } - ENABLE_ZERO_COPY_READ = !"false".equalsIgnoreCase(zeroCopyReadFlag); + ENABLE_ZERO_COPY_READ = true; // !"false".equalsIgnoreCase(zeroCopyReadFlag); ENABLE_ZERO_COPY_WRITE = "true".equalsIgnoreCase(zeroCopyWriteFlag); } - private static final int DESCRIPTOR_TAG = - (FlightData.FLIGHT_DESCRIPTOR_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED; - private static final int BODY_TAG = - (FlightData.DATA_BODY_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED; - private static final int HEADER_TAG = - (FlightData.DATA_HEADER_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED; - private static final int APP_METADATA_TAG = - (FlightData.APP_METADATA_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED; - private static final Marshaller NO_BODY_MARSHALLER = ProtoUtils.marshaller(FlightData.getDefaultInstance()); @@ -219,7 +207,7 @@ public ArrowMessage(FlightDescriptor descriptor) { this.tryZeroCopyWrite = false; } - private ArrowMessage( + ArrowMessage( FlightDescriptor descriptor, MessageMetadataResult message, ArrowBuf appMetadata, @@ -287,207 +275,16 @@ public Iterable getBufs() { } private static ArrowMessage frame(BufferAllocator allocator, final InputStream stream) { - - try { - FlightDescriptor descriptor = null; - MessageMetadataResult header = null; - ArrowBuf body = null; - ArrowBuf appMetadata = null; - while (stream.available() > 0) { - final int tagFirstByte = stream.read(); - if (tagFirstByte == -1) { - break; - } - int tag = readRawVarint32(tagFirstByte, stream); - switch (tag) { - case DESCRIPTOR_TAG: - { - int size = readRawVarint32(stream); - byte[] bytes = new byte[size]; - ByteStreams.readFully(stream, bytes); - descriptor = FlightDescriptor.parseFrom(bytes); - break; - } - case HEADER_TAG: - { - int size = readRawVarint32(stream); - byte[] bytes = new byte[size]; - ByteStreams.readFully(stream, bytes); - header = MessageMetadataResult.create(ByteBuffer.wrap(bytes), size); - break; - } - case APP_METADATA_TAG: - { - int size = readRawVarint32(stream); - appMetadata = readBuffer(allocator, stream, size); - break; - } - case BODY_TAG: - if (body != null) { - // only read last body. - body.getReferenceManager().release(); - body = null; - } - int size = readRawVarint32(stream); - body = readBuffer(allocator, stream, size); - break; - - default: - // ignore unknown fields. - } - } - // Protobuf implementations can omit empty fields, such as body; for some message types, like - // RecordBatch, - // this will fail later as we still expect an empty buffer. In those cases only, fill in an - // empty buffer here - - // in other cases, like Schema, having an unexpected empty buffer will also cause failures. - // We don't fill in defaults for fields like header, for which there is no reasonable default, - // or for appMetadata - // or descriptor, which are intended to be empty in some cases. - if (header != null) { - switch (HeaderType.getHeader(header.headerType())) { - case SCHEMA: - // Ignore 0-length buffers in case a Protobuf implementation wrote it out - if (body != null && body.capacity() == 0) { - body.close(); - body = null; - } - break; - case DICTIONARY_BATCH: - case RECORD_BATCH: - // A Protobuf implementation can skip 0-length bodies, so ensure we fill it in here - if (body == null) { - body = allocator.getEmpty(); - } - break; - case NONE: - case TENSOR: - default: - // Do nothing - break; - } - } - return new ArrowMessage(descriptor, header, appMetadata, body); - } catch (Exception ioe) { - throw new RuntimeException(ioe); - } - } - - private static int readRawVarint32(InputStream is) throws IOException { - int firstByte = is.read(); - return readRawVarint32(firstByte, is); - } - - private static int readRawVarint32(int firstByte, InputStream is) throws IOException { - return CodedInputStream.readRawVarint32(firstByte, is); - } - - /** - * Reads data from the stream into an ArrowBuf, without copying data when possible. - * - *

First attempts to transfer ownership of the gRPC buffer to Arrow via {@link - * #wrapGrpcBuffer}. This avoids any memory copy when the gRPC transport provides a direct - * ByteBuffer (e.g., Netty). - * - *

If not possible (e.g., heap buffer, fragmented data, or unsupported transport), falls back - * to allocating a new buffer and copying data into it. - * - * @param allocator The allocator to use for buffer allocation - * @param stream The input stream to read from - * @param size The number of bytes to read - * @return An ArrowBuf containing the data - * @throws IOException if there is an error reading from the stream - */ - private static ArrowBuf readBuffer(BufferAllocator allocator, InputStream stream, int size) - throws IOException { + FlightDataReader reader; if (ENABLE_ZERO_COPY_READ) { - ArrowBuf zeroCopyBuf = wrapGrpcBuffer(stream, allocator, size); - if (zeroCopyBuf != null) { - return zeroCopyBuf; + reader = ArrowBufReader.tryArrowBufReader(allocator, stream); + if (reader != null) { + return reader.toMessage(); } } - // Fall back to allocating and copying - ArrowBuf buf = allocator.buffer(size); - byte[] heapBytes = new byte[size]; - ByteStreams.readFully(stream, heapBytes); - buf.writeBytes(heapBytes); - buf.writerIndex(size); - return buf; - } - - /** - * Attempts to wrap gRPC's buffer as an ArrowBuf without copying. - * - *

This method takes ownership of gRPC's underlying buffer via {@link Detachable#detach()} and - * wraps it as an ArrowBuf using {@link BufferAllocator#wrapForeignAllocation}. The gRPC buffer - * will be released when the ArrowBuf is closed. - * - * @param stream The gRPC-provided InputStream - * @param allocator The allocator to use for wrapping the foreign allocation - * @param size The number of bytes to wrap - * @return An ArrowBuf wrapping gRPC's buffer, or {@code null} if zero-copy is not possible - */ - static ArrowBuf wrapGrpcBuffer( - final InputStream stream, final BufferAllocator allocator, final int size) { - - if (!(stream instanceof Detachable) || !(stream instanceof HasByteBuffer)) { - return null; - } - - HasByteBuffer hasByteBuffer = (HasByteBuffer) stream; - if (!hasByteBuffer.byteBufferSupported()) { - return null; - } - - ByteBuffer peekBuffer = hasByteBuffer.getByteBuffer(); - if (peekBuffer == null) { - return null; - } - if (!peekBuffer.isDirect()) { - return null; - } - if (peekBuffer.remaining() < size) { - // Data is fragmented across multiple buffers; zero-copy not possible - return null; - } - - // Take ownership - InputStream detachedStream = ((Detachable) stream).detach(); - - // Get buffer from detached stream - ByteBuffer detachedByteBuffer = ((HasByteBuffer) detachedStream).getByteBuffer(); - - // Calculate memory address accounting for buffer position - long baseAddress = MemoryUtil.getByteBufferAddress(detachedByteBuffer); - long dataAddress = baseAddress + detachedByteBuffer.position(); - - // Create ForeignAllocation with proper cleanup - ForeignAllocation foreignAllocation = - new ForeignAllocation(size, dataAddress) { - @Override - protected void release0() { - closeQuietly(detachedStream); - } - }; - - try { - return allocator.wrapForeignAllocation(foreignAllocation); - } catch (Throwable t) { - // If it fails, clean up the detached stream and propagate - closeQuietly(detachedStream); - throw t; - } - } - - private static void closeQuietly(InputStream stream) { - if (stream != null) { - try { - stream.close(); - } catch (IOException e) { - LOG.debug("Error closing detached gRPC stream", e); - } - } + reader = new InputStreamReader(allocator, stream); + return reader.toMessage(); } /** diff --git a/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightDataParser.java b/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightDataParser.java new file mode 100644 index 0000000000..7143f8b08f --- /dev/null +++ b/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightDataParser.java @@ -0,0 +1,410 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.flight; + +import com.google.common.io.ByteStreams; +import com.google.protobuf.CodedInputStream; +import com.google.protobuf.WireFormat; +import io.grpc.Detachable; +import io.grpc.HasByteBuffer; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import org.apache.arrow.flight.impl.Flight.FlightData; +import org.apache.arrow.flight.impl.Flight.FlightDescriptor; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.ForeignAllocation; +import org.apache.arrow.memory.util.MemoryUtil; +import org.apache.arrow.vector.ipc.message.MessageMetadataResult; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Parses FlightData protobuf messages into ArrowMessage objects. + * + *

This class handles parsing from both regular InputStreams (with data copying) and ArrowBuf + * (with zero-copy slicing for large fields like app_metadata and body). + * + *

Small fields (descriptor, header) are always copied. Large fields (app_metadata, body) use + * zero-copy slicing when parsing from ArrowBuf. + */ +final class FlightDataParser { + + // Protobuf wire format tags for FlightData fields + private static final int DESCRIPTOR_TAG = + (FlightData.FLIGHT_DESCRIPTOR_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED; + private static final int HEADER_TAG = + (FlightData.DATA_HEADER_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED; + private static final int BODY_TAG = + (FlightData.DATA_BODY_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED; + private static final int APP_METADATA_TAG = + (FlightData.APP_METADATA_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED; + + /** Base class for FlightData readers with common parsing logic. */ + abstract static class FlightDataReader { + protected final BufferAllocator allocator; + + protected FlightDescriptor descriptor; + protected MessageMetadataResult header; + protected ArrowBuf appMetadata; + protected ArrowBuf body; + + FlightDataReader(BufferAllocator allocator) { + this.allocator = allocator; + } + + /** Parses the FlightData and returns an ArrowMessage. */ + final ArrowMessage toMessage() { + try { + parseFields(); + ArrowBuf adjustedBody = adjustBodyForHeaderType(); + ArrowMessage message = new ArrowMessage(descriptor, header, appMetadata, adjustedBody); + // Ownership transferred to ArrowMessage + appMetadata = null; + body = null; + return message; + } catch (Exception e) { + throw new RuntimeException(e); + } finally { + cleanup(); + } + } + + private ArrowBuf adjustBodyForHeaderType() { + if (header == null) { + return body; + } + switch (ArrowMessage.HeaderType.getHeader(header.headerType())) { + case SCHEMA: + if (body != null && body.capacity() == 0) { + body.close(); + return null; + } + break; + case DICTIONARY_BATCH: + case RECORD_BATCH: + if (body == null) { + return allocator.getEmpty(); + } + break; + case NONE: + case TENSOR: + default: + break; + } + return body; + } + + private void parseFields() throws IOException { + while (hasRemaining()) { + int tag = readTag(); + if (tag == -1) { + break; + } + switch (tag) { + case DESCRIPTOR_TAG: + { + int size = readLength(); + byte[] bytes = readBytes(size); + descriptor = FlightDescriptor.parseFrom(bytes); + break; + } + case HEADER_TAG: + { + int size = readLength(); + byte[] bytes = readBytes(size); + header = MessageMetadataResult.create(ByteBuffer.wrap(bytes), size); + break; + } + case APP_METADATA_TAG: + { + int size = readLength(); + closeAppMetadata(); + appMetadata = readBuffer(size); + break; + } + case BODY_TAG: + { + int size = readLength(); + closeBody(); + body = readBuffer(size); + break; + } + default: + // ignore unknown fields + } + } + } + + /** Returns true if there is more data to read. */ + protected abstract boolean hasRemaining() throws IOException; + + /** Reads the next protobuf tag, or -1 if no more data. */ + protected abstract int readTag() throws IOException; + + /** Reads a varint-encoded length. */ + protected abstract int readLength() throws IOException; + + /** Reads the specified number of bytes into a new byte array. */ + protected abstract byte[] readBytes(int size) throws IOException; + + /** Reads the specified number of bytes into an ArrowBuf. */ + protected abstract ArrowBuf readBuffer(int size) throws IOException; + + /** Called in finally block to clean up resources. Subclasses can override to add cleanup. */ + protected void cleanup() { + closeAppMetadata(); + closeBody(); + } + + private void closeAppMetadata() { + if (appMetadata != null) { + appMetadata.close(); + appMetadata = null; + } + } + + private void closeBody() { + if (body != null) { + body.close(); + body = null; + } + } + } + + /** Parses FlightData from an InputStream, copying data into Arrow-managed buffers. */ + static final class InputStreamReader extends FlightDataReader { + private final InputStream stream; + + InputStreamReader(BufferAllocator allocator, InputStream stream) { + super(allocator); + this.stream = stream; + } + + @Override + protected boolean hasRemaining() throws IOException { + return stream.available() > 0; + } + + @Override + protected int readTag() throws IOException { + int tagFirstByte = stream.read(); + if (tagFirstByte == -1) { + return -1; + } + return CodedInputStream.readRawVarint32(tagFirstByte, stream); + } + + @Override + protected int readLength() throws IOException { + int firstByte = stream.read(); + return CodedInputStream.readRawVarint32(firstByte, stream); + } + + @Override + protected byte[] readBytes(int size) throws IOException { + byte[] bytes = new byte[size]; + ByteStreams.readFully(stream, bytes); + return bytes; + } + + @Override + protected ArrowBuf readBuffer(int size) throws IOException { + ArrowBuf buf = allocator.buffer(size); + byte[] heapBytes = new byte[size]; + ByteStreams.readFully(stream, heapBytes); + buf.writeBytes(heapBytes); + buf.writerIndex(size); + return buf; + } + } + + /** Parses FlightData from an ArrowBuf, using zero-copy slicing for large fields. */ + static final class ArrowBufReader extends FlightDataReader { + private static final Logger LOG = LoggerFactory.getLogger(ArrowBufReader.class); + + private final ArrowBuf backingBuffer; + private final ByteBuffer buffer; + + ArrowBufReader(BufferAllocator allocator, ArrowBuf backingBuffer) { + super(allocator); + this.backingBuffer = backingBuffer; + this.buffer = backingBuffer.nioBuffer(0, (int) backingBuffer.capacity()); + } + + static ArrowBufReader tryArrowBufReader(BufferAllocator allocator, InputStream stream) { + if (!(stream instanceof Detachable) || !(stream instanceof HasByteBuffer)) { + return null; + } + + HasByteBuffer hasByteBuffer = (HasByteBuffer) stream; + if (!hasByteBuffer.byteBufferSupported()) { + return null; + } + + ByteBuffer peekBuffer = hasByteBuffer.getByteBuffer(); + if (peekBuffer == null || !peekBuffer.isDirect()) { + return null; + } + + try { + int available = stream.available(); + if (available > 0 && peekBuffer.remaining() < available) { + return null; + } + } catch (IOException ioe) { + return null; + } + + InputStream detachedStream = ((Detachable) stream).detach(); + if (!(detachedStream instanceof HasByteBuffer)) { + closeQuietly(detachedStream); + return null; + } + + ByteBuffer detachedBuffer = ((HasByteBuffer) detachedStream).getByteBuffer(); + if (detachedBuffer == null || !detachedBuffer.isDirect()) { + closeQuietly(detachedStream); + return null; + } + + long bufferAddress = MemoryUtil.getByteBufferAddress(detachedBuffer); + int bufferSize = detachedBuffer.remaining(); + + ForeignAllocation foreignAllocation = + new ForeignAllocation(bufferSize, bufferAddress + detachedBuffer.position()) { + @Override + protected void release0() { + closeQuietly(detachedStream); + } + }; + + try { + ArrowBuf backingBuffer = allocator.wrapForeignAllocation(foreignAllocation); + return new ArrowBufReader(allocator, backingBuffer); + } catch (Throwable t) { + closeQuietly(detachedStream); + throw t; + } + } + + private static void closeQuietly(InputStream stream) { + if (stream != null) { + try { + stream.close(); + } catch (IOException e) { + LOG.debug("Error closing detached gRPC stream", e); + } + } + } + + @Override + protected void cleanup() { + super.cleanup(); + backingBuffer.close(); + } + + @Override + protected boolean hasRemaining() { + return buffer.hasRemaining(); + } + + @Override + protected int readTag() throws IOException { + if (!buffer.hasRemaining()) { + return -1; + } + int tagFirstByte = buffer.get() & 0xFF; + return readRawVarint32(tagFirstByte); + } + + @Override + protected int readLength() throws IOException { + if (!buffer.hasRemaining()) { + throw new IOException("Unexpected end of buffer"); + } + int firstByte = buffer.get() & 0xFF; + return readRawVarint32(firstByte); + } + + /** + * Decodes a Base 128 Varint from the ByteBuffer. + * + *

This is a manual implementation because CodedInputStream only provides a static helper for + * InputStream, not ByteBuffer. We need direct ByteBuffer access to track positions for + * zero-copy slicing in {@link #readBuffer(int)}. + * + *

Varints are a variable-length encoding for integers used by Protocol Buffers. Each byte + * uses 7 bits for data and 1 bit (MSB) as a continuation flag: + * + *

+ * + *

Bytes are stored in little-endian order (least significant group first). + * + * @see Protocol Buffers + * Encoding: Varints + */ + private int readRawVarint32(int firstByte) throws IOException { + // Check MSB: if 0, this single byte contains the entire value (0-127) + if ((firstByte & 0x80) == 0) { + return firstByte; + } + // Extract lower 7 bits of first byte as the starting result + int result = firstByte & 0x7F; + // Process continuation bytes, shifting each 7-bit group into position + for (int shift = 7; shift < 32; shift += 7) { + if (!buffer.hasRemaining()) { + throw new IOException("Unexpected end of buffer"); + } + int b = buffer.get() & 0xFF; + // OR the 7 data bits into the result at the current shift position + result |= (b & 0x7F) << shift; + // If MSB is 0, we've reached the last byte + if ((b & 0x80) == 0) { + return result; + } + } + // A valid 32-bit varint uses at most 5 bytes (5 * 7 = 35 bits > 32 bits) + throw new IOException("Malformed varint"); + } + + @Override + protected byte[] readBytes(int size) throws IOException { + if (buffer.remaining() < size) { + throw new IOException("Unexpected end of buffer"); + } + byte[] bytes = new byte[size]; + buffer.get(bytes); + return bytes; + } + + @Override + protected ArrowBuf readBuffer(int size) throws IOException { + if (buffer.remaining() < size) { + throw new IOException("Unexpected end of buffer"); + } + int offset = buffer.position(); + buffer.position(offset + size); + backingBuffer.getReferenceManager().retain(); + return backingBuffer.slice(offset, size); + } + } +} diff --git a/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageZeroCopy.java b/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageZeroCopy.java index 7f79868ad7..f96532f744 100644 --- a/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageZeroCopy.java +++ b/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageZeroCopy.java @@ -18,19 +18,18 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertTrue; +import com.google.common.collect.Iterables; +import com.google.common.io.ByteStreams; +import com.google.protobuf.ByteString; import io.grpc.Detachable; import io.grpc.HasByteBuffer; -import java.io.ByteArrayInputStream; -import java.io.IOException; +import io.grpc.protobuf.ProtoUtils; import java.io.InputStream; import java.nio.ByteBuffer; -import java.util.Random; +import org.apache.arrow.flight.impl.Flight.FlightData; +import org.apache.arrow.flight.impl.Flight.FlightDescriptor; import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; @@ -38,6 +37,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +/** Tests for zero-copy buffer handling in ArrowMessage parsing. */ public class TestArrowMessageZeroCopy { private BufferAllocator allocator; @@ -53,123 +53,72 @@ public void tearDown() { } @Test - public void testWrapGrpcBufferReturnsNullForRegularInputStream() throws IOException { - byte[] testData = new byte[] {1, 2, 3, 4, 5}; - InputStream stream = new ByteArrayInputStream(testData); - - // ByteArrayInputStream doesn't implement Detachable or HasByteBuffer - ArrowBuf result = ArrowMessage.wrapGrpcBuffer(stream, allocator, testData.length); - assertNull(result, "Should return null for streams not implementing required interfaces"); - } - - @Test - public void testWrapGrpcBufferSucceedsForDirectBuffer() throws IOException { - byte[] testData = new byte[] {11, 22, 33, 44, 55}; - InputStream stream = MockGrpcInputStream.ofDirectBuffer(testData); - - assertInstanceOf(Detachable.class, stream, "Real gRPC stream should implement Detachable"); - assertInstanceOf( - HasByteBuffer.class, stream, "Real gRPC stream should implement HasByteBuffer"); - assertTrue( - ((HasByteBuffer) stream).byteBufferSupported(), - "Direct buffer stream should support ByteBuffer"); - assertTrue( - ((HasByteBuffer) stream).getByteBuffer().isDirect(), - "Should have direct ByteBuffer backing"); - - try (ArrowBuf result = ArrowMessage.wrapGrpcBuffer(stream, allocator, testData.length)) { - assertNotNull(result, "Should succeed for gRPC stream with direct buffer"); - assertEquals(testData.length, result.capacity()); - - // Check received data is the same - byte[] readData = new byte[testData.length]; - result.getBytes(0, readData); - assertArrayEquals(testData, readData); + public void testParseUsesDetachedBuffer() throws Exception { + byte[] appMetadataBytes = new byte[] {1, 2, 3}; + byte[] bodyBytes = new byte[] {4, 5, 6, 7}; + FlightDescriptor descriptor = + FlightDescriptor.newBuilder() + .setType(FlightDescriptor.DescriptorType.PATH) + .addPath("path") + .build(); + FlightData flightData = + FlightData.newBuilder() + .setFlightDescriptor(descriptor) + .setAppMetadata(ByteString.copyFrom(appMetadataBytes)) + .setDataBody(ByteString.copyFrom(bodyBytes)) + .build(); + + byte[] serialized; + try (InputStream grpcStream = + ProtoUtils.marshaller(FlightData.getDefaultInstance()).stream(flightData)) { + serialized = ByteStreams.toByteArray(grpcStream); } - } - - @Test - public void testWrapGrpcBufferReturnsNullForRealGrpcHeapByteBuffer() throws IOException { - byte[] testData = new byte[] {1, 2, 3, 4, 5}; - InputStream stream = MockGrpcInputStream.ofHeapBuffer(testData); - - assertInstanceOf(Detachable.class, stream, "Real gRPC stream should implement Detachable"); - assertInstanceOf( - HasByteBuffer.class, stream, "Real gRPC stream should implement HasByteBuffer"); - assertTrue( - ((HasByteBuffer) stream).byteBufferSupported(), - "Heap ByteBuffer stream should support ByteBuffer"); - assertFalse( - ((HasByteBuffer) stream).getByteBuffer().isDirect(), "Should have heap ByteBuffer backing"); - - // Zero-copy should return null for heap buffer (not direct) - ArrowBuf result = ArrowMessage.wrapGrpcBuffer(stream, allocator, testData.length); - assertNull(result, "Should return null for gRPC stream with heap buffer"); - } - - @Test - public void testWrapGrpcBufferReturnsNullWhenByteBufferNotSupported() throws IOException { - byte[] testData = new byte[] {1, 2, 3, 4, 5}; - InputStream stream = MockGrpcInputStream.withoutByteBufferSupport(testData); - - // Verify the stream has the expected gRPC interfaces - assertInstanceOf(Detachable.class, stream, "Real gRPC stream should implement Detachable"); - assertInstanceOf( - HasByteBuffer.class, stream, "Real gRPC stream should implement HasByteBuffer"); - // Byte array backed streams don't support ByteBuffer access - assertFalse( - ((HasByteBuffer) stream).byteBufferSupported(), - "Byte array stream should not support ByteBuffer"); - - // Zero-copy should return null when byteBufferSupported() is false - ArrowBuf result = ArrowMessage.wrapGrpcBuffer(stream, allocator, testData.length); - assertNull(result, "Should return null for gRPC stream without ByteBuffer support"); - } - - @Test - public void testWrapGrpcBufferMemoryAccounting() throws IOException { - byte[] testData = new byte[1024]; - new Random(42).nextBytes(testData); - InputStream stream = MockGrpcInputStream.ofDirectBuffer(testData); - assertEquals(0, allocator.getAllocatedMemory()); + InputStream stream = MockGrpcInputStream.ofDirectBuffer(serialized); - ArrowBuf result = ArrowMessage.wrapGrpcBuffer(stream, allocator, testData.length); - assertNotNull(result, "Should succeed for gRPC stream with direct buffer"); - assertEquals(testData.length, allocator.getAllocatedMemory()); + try (ArrowMessage message = ArrowMessage.createMarshaller(allocator).parse(stream)) { + assertEquals(descriptor, message.getDescriptor()); - byte[] readData = new byte[testData.length]; - result.getBytes(0, readData); - assertArrayEquals(testData, readData); + ArrowBuf appMetadata = message.getApplicationMetadata(); + assertNotNull(appMetadata); + byte[] appMetadataRead = new byte[appMetadataBytes.length]; + appMetadata.getBytes(0, appMetadataRead); + assertArrayEquals(appMetadataBytes, appMetadataRead); - result.close(); - assertEquals(0, allocator.getAllocatedMemory()); + ArrowBuf body = Iterables.getOnlyElement(message.getBufs()); + byte[] bodyRead = new byte[bodyBytes.length]; + body.getBytes(0, bodyRead); + assertArrayEquals(bodyBytes, bodyRead); + } } @Test - public void testWrapGrpcBufferReturnsNullForInsufficientData() throws IOException { - byte[] testData = new byte[] {1, 2, 3}; - InputStream stream = MockGrpcInputStream.ofDirectBuffer(testData); + public void testFallbackDoesNotDetachStream() throws Exception { + byte[] appMetadataBytes = new byte[] {8, 9}; + byte[] bodyBytes = new byte[] {10, 11, 12}; + FlightDescriptor descriptor = + FlightDescriptor.newBuilder() + .setType(FlightDescriptor.DescriptorType.PATH) + .addPath("fallback") + .build(); + FlightData flightData = + FlightData.newBuilder() + .setFlightDescriptor(descriptor) + .setAppMetadata(ByteString.copyFrom(appMetadataBytes)) + .setDataBody(ByteString.copyFrom(bodyBytes)) + .build(); + + byte[] serialized; + try (InputStream grpcStream = + ProtoUtils.marshaller(FlightData.getDefaultInstance()).stream(flightData)) { + serialized = ByteStreams.toByteArray(grpcStream); + } - // Request more data than available - ArrowBuf result = ArrowMessage.wrapGrpcBuffer(stream, allocator, 10); - assertNull(result, "Should return null when buffer has insufficient data"); - } + MockGrpcInputStream stream = MockGrpcInputStream.ofHeapBuffer(serialized); - @Test - public void testWrapGrpcBufferLargeData() throws IOException { - byte[] testData = new byte[64 * 1024]; - new Random(42).nextBytes(testData); - InputStream stream = MockGrpcInputStream.ofDirectBuffer(testData); - - try (ArrowBuf result = ArrowMessage.wrapGrpcBuffer(stream, allocator, testData.length)) { - assertNotNull(result, "Should succeed for large data with real gRPC stream"); - assertEquals(testData.length, result.capacity()); - - // Verify data integrity - byte[] readData = new byte[testData.length]; - result.getBytes(0, readData); - assertArrayEquals(testData, readData); + try (ArrowMessage message = ArrowMessage.createMarshaller(allocator).parse(stream)) { + assertEquals(descriptor, message.getDescriptor()); + assertEquals(0, stream.getDetachCount()); } } @@ -178,6 +127,7 @@ private static class MockGrpcInputStream extends InputStream implements Detachable, HasByteBuffer { private ByteBuffer buffer; private final boolean byteBufferSupported; + private int detachCount; private MockGrpcInputStream(ByteBuffer buffer, boolean byteBufferSupported) { this.buffer = buffer; @@ -194,10 +144,6 @@ static MockGrpcInputStream ofHeapBuffer(byte[] data) { return new MockGrpcInputStream(ByteBuffer.wrap(data), true); } - static MockGrpcInputStream withoutByteBufferSupport(byte[] data) { - return new MockGrpcInputStream(ByteBuffer.wrap(data), false); - } - @Override public boolean byteBufferSupported() { return byteBufferSupported; @@ -210,16 +156,26 @@ public ByteBuffer getByteBuffer() { @Override public InputStream detach() { + detachCount++; ByteBuffer detached = this.buffer; this.buffer = null; return new MockGrpcInputStream(detached, byteBufferSupported); } + int getDetachCount() { + return detachCount; + } + @Override public int read() { return (buffer != null && buffer.hasRemaining()) ? (buffer.get() & 0xFF) : -1; } + @Override + public int available() { + return buffer == null ? 0 : buffer.remaining(); + } + @Override public void close() { buffer = null; From abc85bcd263f413a410182e7e2f3e61ff12343dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9lder=20Greg=C3=B3rio?= Date: Fri, 16 Jan 2026 11:12:20 +0000 Subject: [PATCH 06/12] simplify ArrowBufReader decoding logic --- .../apache/arrow/flight/FlightDataParser.java | 91 ++++--------------- 1 file changed, 18 insertions(+), 73 deletions(-) diff --git a/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightDataParser.java b/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightDataParser.java index 7143f8b08f..cb870da7c7 100644 --- a/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightDataParser.java +++ b/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightDataParser.java @@ -116,31 +116,32 @@ private void parseFields() throws IOException { if (tag == -1) { break; } + int size = readLength(); switch (tag) { case DESCRIPTOR_TAG: { - int size = readLength(); byte[] bytes = readBytes(size); descriptor = FlightDescriptor.parseFrom(bytes); break; } case HEADER_TAG: { - int size = readLength(); byte[] bytes = readBytes(size); header = MessageMetadataResult.create(ByteBuffer.wrap(bytes), size); break; } case APP_METADATA_TAG: { - int size = readLength(); + // Called before reading a new value to handle duplicate protobuf fields + // (last occurrence wins per spec) and prevent memory leaks. closeAppMetadata(); appMetadata = readBuffer(size); break; } case BODY_TAG: { - int size = readLength(); + // Called before reading a new value to handle duplicate protobuf fields + // (last occurrence wins per spec) and prevent memory leaks. closeBody(); body = readBuffer(size); break; @@ -239,12 +240,13 @@ static final class ArrowBufReader extends FlightDataReader { private static final Logger LOG = LoggerFactory.getLogger(ArrowBufReader.class); private final ArrowBuf backingBuffer; - private final ByteBuffer buffer; + private final CodedInputStream codedInput; ArrowBufReader(BufferAllocator allocator, ArrowBuf backingBuffer) { super(allocator); this.backingBuffer = backingBuffer; - this.buffer = backingBuffer.nioBuffer(0, (int) backingBuffer.capacity()); + ByteBuffer buffer = backingBuffer.nioBuffer(0, (int) backingBuffer.capacity()); + this.codedInput = CodedInputStream.newInstance(buffer); } static ArrowBufReader tryArrowBufReader(BufferAllocator allocator, InputStream stream) { @@ -320,89 +322,32 @@ protected void cleanup() { } @Override - protected boolean hasRemaining() { - return buffer.hasRemaining(); + protected boolean hasRemaining() throws IOException { + return !codedInput.isAtEnd(); } @Override protected int readTag() throws IOException { - if (!buffer.hasRemaining()) { - return -1; - } - int tagFirstByte = buffer.get() & 0xFF; - return readRawVarint32(tagFirstByte); + int tag = codedInput.readTag(); + return tag == 0 ? -1 : tag; } @Override protected int readLength() throws IOException { - if (!buffer.hasRemaining()) { - throw new IOException("Unexpected end of buffer"); - } - int firstByte = buffer.get() & 0xFF; - return readRawVarint32(firstByte); - } - - /** - * Decodes a Base 128 Varint from the ByteBuffer. - * - *

This is a manual implementation because CodedInputStream only provides a static helper for - * InputStream, not ByteBuffer. We need direct ByteBuffer access to track positions for - * zero-copy slicing in {@link #readBuffer(int)}. - * - *

Varints are a variable-length encoding for integers used by Protocol Buffers. Each byte - * uses 7 bits for data and 1 bit (MSB) as a continuation flag: - * - *

- * - *

Bytes are stored in little-endian order (least significant group first). - * - * @see Protocol Buffers - * Encoding: Varints - */ - private int readRawVarint32(int firstByte) throws IOException { - // Check MSB: if 0, this single byte contains the entire value (0-127) - if ((firstByte & 0x80) == 0) { - return firstByte; - } - // Extract lower 7 bits of first byte as the starting result - int result = firstByte & 0x7F; - // Process continuation bytes, shifting each 7-bit group into position - for (int shift = 7; shift < 32; shift += 7) { - if (!buffer.hasRemaining()) { - throw new IOException("Unexpected end of buffer"); - } - int b = buffer.get() & 0xFF; - // OR the 7 data bits into the result at the current shift position - result |= (b & 0x7F) << shift; - // If MSB is 0, we've reached the last byte - if ((b & 0x80) == 0) { - return result; - } - } - // A valid 32-bit varint uses at most 5 bytes (5 * 7 = 35 bits > 32 bits) - throw new IOException("Malformed varint"); + return codedInput.readRawVarint32(); } @Override protected byte[] readBytes(int size) throws IOException { - if (buffer.remaining() < size) { - throw new IOException("Unexpected end of buffer"); - } - byte[] bytes = new byte[size]; - buffer.get(bytes); - return bytes; + // Reads size bytes and creates a copy + return codedInput.readRawBytes(size); } @Override protected ArrowBuf readBuffer(int size) throws IOException { - if (buffer.remaining() < size) { - throw new IOException("Unexpected end of buffer"); - } - int offset = buffer.position(); - buffer.position(offset + size); + // CodedInputStream advances the shared ByteBuffer; use its read count for zero-copy slicing. + int offset = codedInput.getTotalBytesRead(); + codedInput.skipRawBytes(size); backingBuffer.getReferenceManager().retain(); return backingBuffer.slice(offset, size); } From f5cf5916d2430d9ab3a58d1a23812e103fb5b97a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9lder=20Greg=C3=B3rio?= Date: Fri, 16 Jan 2026 12:43:47 +0000 Subject: [PATCH 07/12] add ArrowMessage.frame tests --- .../TestFlightDataParserDuplicateFields.java | 359 ++++++++++++++++++ 1 file changed, 359 insertions(+) create mode 100644 flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightDataParserDuplicateFields.java diff --git a/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightDataParserDuplicateFields.java b/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightDataParserDuplicateFields.java new file mode 100644 index 0000000000..6677b64116 --- /dev/null +++ b/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightDataParserDuplicateFields.java @@ -0,0 +1,359 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.flight; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +import com.google.common.collect.Iterables; +import com.google.common.io.ByteStreams; +import com.google.protobuf.ByteString; +import com.google.protobuf.CodedOutputStream; +import io.grpc.Detachable; +import io.grpc.HasByteBuffer; +import io.grpc.protobuf.ProtoUtils; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.List; +import org.apache.arrow.flight.impl.Flight.FlightData; +import org.apache.arrow.flight.impl.Flight.FlightDescriptor; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.ipc.message.IpcOption; +import org.apache.arrow.vector.ipc.message.MessageSerializer; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.commons.lang3.tuple.Pair; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +/** + * Tests FlightData parsing for duplicate field handling and well-formed messages. Covers both + * InputStream (with copying) and ArrowBuf (zero-copy) parsing paths. Verifies that duplicate + * protobuf fields use last-occurrence-wins semantics without memory leaks. + */ +public class TestFlightDataParserDuplicateFields { + + private BufferAllocator allocator; + + @BeforeEach + public void setUp() { + allocator = new RootAllocator(Long.MAX_VALUE); + } + + @AfterEach + public void tearDown() { + allocator.close(); + } + + /** Verifies duplicate app_metadata fields via InputStream path use last-occurrence-wins. */ + @Test + public void testDuplicateAppMetadataInputStream() throws Exception { + byte[] firstAppMetadata = new byte[] {1, 2, 3}; + byte[] secondAppMetadata = new byte[] {4, 5, 6, 7, 8}; + + byte[] serialized = + buildFlightDataDescriptors( + List.of( + Pair.of(FlightData.APP_METADATA_FIELD_NUMBER, firstAppMetadata), + Pair.of(FlightData.APP_METADATA_FIELD_NUMBER, secondAppMetadata))); + InputStream stream = new ByteArrayInputStream(serialized); + + try (ArrowMessage message = ArrowMessage.createMarshaller(allocator).parse(stream)) { + ArrowBuf appMetadata = message.getApplicationMetadata(); + assertNotNull(appMetadata); + // Use readableBytes() instead of capacity() since allocator may round up + assertEquals(secondAppMetadata.length, appMetadata.readableBytes()); + + byte[] actual = new byte[secondAppMetadata.length]; + appMetadata.getBytes(0, actual); + assertArrayEquals(secondAppMetadata, actual); + } + assertEquals(0, allocator.getAllocatedMemory()); + } + + /** + * Verifies duplicate app_metadata fields via zero-copy ArrowBuf path use last-occurrence-wins. + */ + @Test + public void testDuplicateAppMetadataArrowBuf() throws Exception { + byte[] firstAppMetadata = new byte[] {1, 2, 3}; + byte[] secondAppMetadata = new byte[] {4, 5, 6, 7, 8}; + + // Verify clean start + assertEquals(0, allocator.getAllocatedMemory()); + + byte[] serialized = + buildFlightDataDescriptors( + List.of( + Pair.of(FlightData.APP_METADATA_FIELD_NUMBER, firstAppMetadata), + Pair.of(FlightData.APP_METADATA_FIELD_NUMBER, secondAppMetadata))); + InputStream stream = new DetachableDirectBufferInputStream(serialized); + + try (ArrowMessage message = ArrowMessage.createMarshaller(allocator).parse(stream)) { + ArrowBuf appMetadata = message.getApplicationMetadata(); + assertNotNull(appMetadata); + assertEquals(secondAppMetadata.length, appMetadata.readableBytes()); + + byte[] actual = new byte[secondAppMetadata.length]; + appMetadata.getBytes(0, actual); + assertArrayEquals(secondAppMetadata, actual); + + // Zero-copy: only the backing buffer (serialized message) should be allocated + assertEquals(serialized.length, allocator.getAllocatedMemory()); + } + assertEquals(0, allocator.getAllocatedMemory()); + } + + /** Verifies duplicate body fields via InputStream path use last-occurrence-wins. */ + @Test + public void testDuplicateBodyInputStream() throws Exception { + byte[] firstBody = new byte[] {10, 20, 30}; + byte[] secondBody = new byte[] {40, 50, 60, 70}; + + byte[] serialized = + buildFlightDataDescriptors( + List.of( + Pair.of(FlightData.DATA_BODY_FIELD_NUMBER, firstBody), + Pair.of(FlightData.DATA_BODY_FIELD_NUMBER, secondBody))); + InputStream stream = new ByteArrayInputStream(serialized); + + try (ArrowMessage message = ArrowMessage.createMarshaller(allocator).parse(stream)) { + ArrowBuf body = Iterables.getOnlyElement(message.getBufs()); + assertNotNull(body); + assertEquals(secondBody.length, body.readableBytes()); + + byte[] actual = new byte[secondBody.length]; + body.getBytes(0, actual); + assertArrayEquals(secondBody, actual); + } + assertEquals(0, allocator.getAllocatedMemory()); + } + + /** Verifies duplicate body fields via zero-copy ArrowBuf path use last-occurrence-wins. */ + @Test + public void testDuplicateBodyArrowBuf() throws Exception { + byte[] firstBody = new byte[] {10, 20, 30}; + byte[] secondBody = new byte[] {40, 50, 60, 70}; + + // Verify clean start + assertEquals(0, allocator.getAllocatedMemory()); + + byte[] serialized = + buildFlightDataDescriptors( + List.of( + Pair.of(FlightData.DATA_BODY_FIELD_NUMBER, firstBody), + Pair.of(FlightData.DATA_BODY_FIELD_NUMBER, secondBody))); + InputStream stream = new DetachableDirectBufferInputStream(serialized); + + try (ArrowMessage message = ArrowMessage.createMarshaller(allocator).parse(stream)) { + ArrowBuf body = Iterables.getOnlyElement(message.getBufs()); + assertNotNull(body); + assertEquals(secondBody.length, body.readableBytes()); + + byte[] actual = new byte[secondBody.length]; + body.getBytes(0, actual); + assertArrayEquals(secondBody, actual); + + // Zero-copy: only the backing buffer (serialized message) should be allocated + assertEquals(serialized.length, allocator.getAllocatedMemory()); + } + assertEquals(0, allocator.getAllocatedMemory()); + } + + /** Verifies well-formed FlightData message parsing via InputStream path. */ + @Test + public void testFieldsInputStream() throws Exception { + byte[] appMetadataBytes = new byte[] {100, 101, 102}; + byte[] bodyBytes = new byte[] {50, 51, 52, 53, 54}; + FlightDescriptor expectedDescriptor = createTestDescriptor(); + + byte[] serialized = buildFlightDataWithBothFields(appMetadataBytes, bodyBytes); + InputStream stream = new ByteArrayInputStream(serialized); + + try (ArrowMessage message = ArrowMessage.createMarshaller(allocator).parse(stream)) { + // Verify descriptor + assertEquals(expectedDescriptor, message.getDescriptor()); + + // Verify header is present (Schema message type) + assertEquals(ArrowMessage.HeaderType.SCHEMA, message.getMessageType()); + + // Verify app metadata + ArrowBuf appMetadata = message.getApplicationMetadata(); + assertNotNull(appMetadata); + assertEquals(appMetadataBytes.length, appMetadata.readableBytes()); + byte[] actualAppMetadata = new byte[appMetadataBytes.length]; + appMetadata.getBytes(0, actualAppMetadata); + assertArrayEquals(appMetadataBytes, actualAppMetadata); + + // Verify body + ArrowBuf body = Iterables.getOnlyElement(message.getBufs()); + assertNotNull(body); + assertEquals(bodyBytes.length, body.readableBytes()); + byte[] actualBody = new byte[bodyBytes.length]; + body.getBytes(0, actualBody); + assertArrayEquals(bodyBytes, actualBody); + } + assertEquals(0, allocator.getAllocatedMemory()); + } + + /** Verifies well-formed FlightData message parsing via zero-copy ArrowBuf path. */ + @Test + public void testFieldsArrowBuf() throws Exception { + byte[] appMetadataBytes = new byte[] {100, 101, 102}; + byte[] bodyBytes = new byte[] {50, 51, 52, 53, 54}; + FlightDescriptor expectedDescriptor = createTestDescriptor(); + + assertEquals(0, allocator.getAllocatedMemory()); + + byte[] serialized = buildFlightDataWithBothFields(appMetadataBytes, bodyBytes); + InputStream stream = new DetachableDirectBufferInputStream(serialized); + + try (ArrowMessage message = ArrowMessage.createMarshaller(allocator).parse(stream)) { + // Verify descriptor + assertEquals(expectedDescriptor, message.getDescriptor()); + + // Verify header is present (Schema message type) + assertEquals(ArrowMessage.HeaderType.SCHEMA, message.getMessageType()); + + // Verify app metadata + ArrowBuf appMetadata = message.getApplicationMetadata(); + assertNotNull(appMetadata); + assertEquals(appMetadataBytes.length, appMetadata.readableBytes()); + byte[] actualAppMetadata = new byte[appMetadataBytes.length]; + appMetadata.getBytes(0, actualAppMetadata); + assertArrayEquals(appMetadataBytes, actualAppMetadata); + + // Verify body + ArrowBuf body = Iterables.getOnlyElement(message.getBufs()); + assertNotNull(body); + assertEquals(bodyBytes.length, body.readableBytes()); + byte[] actualBody = new byte[bodyBytes.length]; + body.getBytes(0, actualBody); + assertArrayEquals(bodyBytes, actualBody); + + // Zero-copy: only the backing buffer (serialized message) should be allocated + assertEquals(serialized.length, allocator.getAllocatedMemory()); + } + assertEquals(0, allocator.getAllocatedMemory()); + } + + // Helper methods to build complete FlightData messages + + private FlightDescriptor createTestDescriptor() { + return FlightDescriptor.newBuilder() + .setType(FlightDescriptor.DescriptorType.PATH) + .addPath("test") + .addPath("path") + .build(); + } + + private byte[] createSchemaHeader() { + Schema schema = + new Schema( + Arrays.asList( + Field.nullable("id", new ArrowType.Int(32, true)), + Field.nullable("name", new ArrowType.Utf8()))); + ByteBuffer headerBuffer = MessageSerializer.serializeMetadata(schema, IpcOption.DEFAULT); + byte[] headerBytes = new byte[headerBuffer.remaining()]; + headerBuffer.get(headerBytes); + return headerBytes; + } + + private byte[] buildFlightDataWithBothFields(byte[] appMetadata, byte[] body) throws IOException { + FlightData flightData = + FlightData.newBuilder() + .setFlightDescriptor(createTestDescriptor()) + .setDataHeader(ByteString.copyFrom(createSchemaHeader())) + .setAppMetadata(ByteString.copyFrom(appMetadata)) + .setDataBody(ByteString.copyFrom(body)) + .build(); + try (InputStream grpcStream = + ProtoUtils.marshaller(FlightData.getDefaultInstance()).stream(flightData)) { + return ByteStreams.toByteArray(grpcStream); + } + } + + // Helper methods to build FlightData messages with duplicate fields + + private byte[] buildFlightDataDescriptors(List> descriptors) throws IOException { + + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + CodedOutputStream cos = CodedOutputStream.newInstance(baos); + + for (Pair descriptor : descriptors) { + cos.writeBytes(descriptor.getKey(), ByteString.copyFrom(descriptor.getValue())); + } + cos.flush(); + return baos.toByteArray(); + } + + /** Mock InputStream implementing gRPC's Detachable and HasByteBuffer for testing zero-copy. */ + private static class DetachableDirectBufferInputStream extends InputStream + implements Detachable, HasByteBuffer { + private ByteBuffer buffer; + + DetachableDirectBufferInputStream(byte[] data) { + this.buffer = ByteBuffer.allocateDirect(data.length); + this.buffer.put(data).flip(); + } + + private DetachableDirectBufferInputStream(ByteBuffer buffer) { + this.buffer = buffer; + } + + @Override + public boolean byteBufferSupported() { + return true; + } + + @Override + public ByteBuffer getByteBuffer() { + return buffer; + } + + @Override + public InputStream detach() { + ByteBuffer detached = this.buffer; + this.buffer = null; + return new DetachableDirectBufferInputStream(detached); + } + + @Override + public int read() { + return (buffer != null && buffer.hasRemaining()) ? (buffer.get() & 0xFF) : -1; + } + + @Override + public int available() { + return buffer == null ? 0 : buffer.remaining(); + } + + @Override + public void close() { + buffer = null; + } + } +} From cb04e7490f84c9635196525a967a68909f098275 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9lder=20Greg=C3=B3rio?= Date: Fri, 16 Jan 2026 12:56:53 +0000 Subject: [PATCH 08/12] merge tests --- flight/flight-core/pom.xml | 6 + ...Fields.java => TestArrowMessageParse.java} | 63 ++++-- .../flight/TestArrowMessageZeroCopy.java | 184 ------------------ 3 files changed, 52 insertions(+), 201 deletions(-) rename flight/flight-core/src/test/java/org/apache/arrow/flight/{TestFlightDataParserDuplicateFields.java => TestArrowMessageParse.java} (86%) delete mode 100644 flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageZeroCopy.java diff --git a/flight/flight-core/pom.xml b/flight/flight-core/pom.xml index de06e0518e..aa82e55b85 100644 --- a/flight/flight-core/pom.xml +++ b/flight/flight-core/pom.xml @@ -146,6 +146,12 @@ under the License. test-jar test + + org.apache.commons + commons-lang3 + 3.20.0 + test + diff --git a/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightDataParserDuplicateFields.java b/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageParse.java similarity index 86% rename from flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightDataParserDuplicateFields.java rename to flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageParse.java index 6677b64116..b4aa6a756e 100644 --- a/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightDataParserDuplicateFields.java +++ b/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageParse.java @@ -50,11 +50,11 @@ import org.junit.jupiter.api.Test; /** - * Tests FlightData parsing for duplicate field handling and well-formed messages. Covers both - * InputStream (with copying) and ArrowBuf (zero-copy) parsing paths. Verifies that duplicate - * protobuf fields use last-occurrence-wins semantics without memory leaks. + * Tests FlightData parsing including duplicate field handling, well-formed messages, and zero-copy + * behavior. Covers both InputStream (with copying) and ArrowBuf (zero-copy) parsing paths. Verifies + * that duplicate protobuf fields use last-occurrence-wins semantics without memory leaks. */ -public class TestFlightDataParserDuplicateFields { +public class TestArrowMessageParse { private BufferAllocator allocator; @@ -110,7 +110,7 @@ public void testDuplicateAppMetadataArrowBuf() throws Exception { List.of( Pair.of(FlightData.APP_METADATA_FIELD_NUMBER, firstAppMetadata), Pair.of(FlightData.APP_METADATA_FIELD_NUMBER, secondAppMetadata))); - InputStream stream = new DetachableDirectBufferInputStream(serialized); + InputStream stream = MockGrpcInputStream.ofDirectBuffer(serialized); try (ArrowMessage message = ArrowMessage.createMarshaller(allocator).parse(stream)) { ArrowBuf appMetadata = message.getApplicationMetadata(); @@ -166,7 +166,7 @@ public void testDuplicateBodyArrowBuf() throws Exception { List.of( Pair.of(FlightData.DATA_BODY_FIELD_NUMBER, firstBody), Pair.of(FlightData.DATA_BODY_FIELD_NUMBER, secondBody))); - InputStream stream = new DetachableDirectBufferInputStream(serialized); + InputStream stream = MockGrpcInputStream.ofDirectBuffer(serialized); try (ArrowMessage message = ArrowMessage.createMarshaller(allocator).parse(stream)) { ArrowBuf body = Iterables.getOnlyElement(message.getBufs()); @@ -229,7 +229,7 @@ public void testFieldsArrowBuf() throws Exception { assertEquals(0, allocator.getAllocatedMemory()); byte[] serialized = buildFlightDataWithBothFields(appMetadataBytes, bodyBytes); - InputStream stream = new DetachableDirectBufferInputStream(serialized); + InputStream stream = MockGrpcInputStream.ofDirectBuffer(serialized); try (ArrowMessage message = ArrowMessage.createMarshaller(allocator).parse(stream)) { // Verify descriptor @@ -260,6 +260,21 @@ public void testFieldsArrowBuf() throws Exception { assertEquals(0, allocator.getAllocatedMemory()); } + /** Verifies that heap buffers fall back to InputStream path without calling detach(). */ + @Test + public void testHeapBufferFallbackDoesNotDetach() throws Exception { + byte[] appMetadataBytes = new byte[] {8, 9}; + byte[] bodyBytes = new byte[] {10, 11, 12}; + + byte[] serialized = buildFlightDataWithBothFields(appMetadataBytes, bodyBytes); + MockGrpcInputStream stream = MockGrpcInputStream.ofHeapBuffer(serialized); + + try (ArrowMessage message = ArrowMessage.createMarshaller(allocator).parse(stream)) { + assertNotNull(message.getDescriptor()); + assertEquals(0, stream.getDetachCount()); + } + } + // Helper methods to build complete FlightData messages private FlightDescriptor createTestDescriptor() { @@ -298,7 +313,8 @@ private byte[] buildFlightDataWithBothFields(byte[] appMetadata, byte[] body) th // Helper methods to build FlightData messages with duplicate fields - private byte[] buildFlightDataDescriptors(List> descriptors) throws IOException { + private byte[] buildFlightDataDescriptors(List> descriptors) + throws IOException { ByteArrayOutputStream baos = new ByteArrayOutputStream(); CodedOutputStream cos = CodedOutputStream.newInstance(baos); @@ -311,34 +327,47 @@ private byte[] buildFlightDataDescriptors(List> descriptor } /** Mock InputStream implementing gRPC's Detachable and HasByteBuffer for testing zero-copy. */ - private static class DetachableDirectBufferInputStream extends InputStream + private static class MockGrpcInputStream extends InputStream implements Detachable, HasByteBuffer { private ByteBuffer buffer; + private final boolean byteBufferSupported; + private int detachCount; - DetachableDirectBufferInputStream(byte[] data) { - this.buffer = ByteBuffer.allocateDirect(data.length); - this.buffer.put(data).flip(); + private MockGrpcInputStream(ByteBuffer buffer, boolean byteBufferSupported) { + this.buffer = buffer; + this.byteBufferSupported = byteBufferSupported; } - private DetachableDirectBufferInputStream(ByteBuffer buffer) { - this.buffer = buffer; + static MockGrpcInputStream ofDirectBuffer(byte[] data) { + ByteBuffer buf = ByteBuffer.allocateDirect(data.length); + buf.put(data).flip(); + return new MockGrpcInputStream(buf, true); + } + + static MockGrpcInputStream ofHeapBuffer(byte[] data) { + return new MockGrpcInputStream(ByteBuffer.wrap(data), true); } @Override public boolean byteBufferSupported() { - return true; + return byteBufferSupported; } @Override public ByteBuffer getByteBuffer() { - return buffer; + return byteBufferSupported ? buffer : null; } @Override public InputStream detach() { + detachCount++; ByteBuffer detached = this.buffer; this.buffer = null; - return new DetachableDirectBufferInputStream(detached); + return new MockGrpcInputStream(detached, byteBufferSupported); + } + + int getDetachCount() { + return detachCount; } @Override diff --git a/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageZeroCopy.java b/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageZeroCopy.java deleted file mode 100644 index f96532f744..0000000000 --- a/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageZeroCopy.java +++ /dev/null @@ -1,184 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.arrow.flight; - -import static org.junit.jupiter.api.Assertions.assertArrayEquals; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; - -import com.google.common.collect.Iterables; -import com.google.common.io.ByteStreams; -import com.google.protobuf.ByteString; -import io.grpc.Detachable; -import io.grpc.HasByteBuffer; -import io.grpc.protobuf.ProtoUtils; -import java.io.InputStream; -import java.nio.ByteBuffer; -import org.apache.arrow.flight.impl.Flight.FlightData; -import org.apache.arrow.flight.impl.Flight.FlightDescriptor; -import org.apache.arrow.memory.ArrowBuf; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.memory.RootAllocator; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - -/** Tests for zero-copy buffer handling in ArrowMessage parsing. */ -public class TestArrowMessageZeroCopy { - - private BufferAllocator allocator; - - @BeforeEach - public void setUp() { - allocator = new RootAllocator(Long.MAX_VALUE); - } - - @AfterEach - public void tearDown() { - allocator.close(); - } - - @Test - public void testParseUsesDetachedBuffer() throws Exception { - byte[] appMetadataBytes = new byte[] {1, 2, 3}; - byte[] bodyBytes = new byte[] {4, 5, 6, 7}; - FlightDescriptor descriptor = - FlightDescriptor.newBuilder() - .setType(FlightDescriptor.DescriptorType.PATH) - .addPath("path") - .build(); - FlightData flightData = - FlightData.newBuilder() - .setFlightDescriptor(descriptor) - .setAppMetadata(ByteString.copyFrom(appMetadataBytes)) - .setDataBody(ByteString.copyFrom(bodyBytes)) - .build(); - - byte[] serialized; - try (InputStream grpcStream = - ProtoUtils.marshaller(FlightData.getDefaultInstance()).stream(flightData)) { - serialized = ByteStreams.toByteArray(grpcStream); - } - - InputStream stream = MockGrpcInputStream.ofDirectBuffer(serialized); - - try (ArrowMessage message = ArrowMessage.createMarshaller(allocator).parse(stream)) { - assertEquals(descriptor, message.getDescriptor()); - - ArrowBuf appMetadata = message.getApplicationMetadata(); - assertNotNull(appMetadata); - byte[] appMetadataRead = new byte[appMetadataBytes.length]; - appMetadata.getBytes(0, appMetadataRead); - assertArrayEquals(appMetadataBytes, appMetadataRead); - - ArrowBuf body = Iterables.getOnlyElement(message.getBufs()); - byte[] bodyRead = new byte[bodyBytes.length]; - body.getBytes(0, bodyRead); - assertArrayEquals(bodyBytes, bodyRead); - } - } - - @Test - public void testFallbackDoesNotDetachStream() throws Exception { - byte[] appMetadataBytes = new byte[] {8, 9}; - byte[] bodyBytes = new byte[] {10, 11, 12}; - FlightDescriptor descriptor = - FlightDescriptor.newBuilder() - .setType(FlightDescriptor.DescriptorType.PATH) - .addPath("fallback") - .build(); - FlightData flightData = - FlightData.newBuilder() - .setFlightDescriptor(descriptor) - .setAppMetadata(ByteString.copyFrom(appMetadataBytes)) - .setDataBody(ByteString.copyFrom(bodyBytes)) - .build(); - - byte[] serialized; - try (InputStream grpcStream = - ProtoUtils.marshaller(FlightData.getDefaultInstance()).stream(flightData)) { - serialized = ByteStreams.toByteArray(grpcStream); - } - - MockGrpcInputStream stream = MockGrpcInputStream.ofHeapBuffer(serialized); - - try (ArrowMessage message = ArrowMessage.createMarshaller(allocator).parse(stream)) { - assertEquals(descriptor, message.getDescriptor()); - assertEquals(0, stream.getDetachCount()); - } - } - - /** Mock InputStream implementing gRPC's Detachable and HasByteBuffer for testing zero-copy. */ - private static class MockGrpcInputStream extends InputStream - implements Detachable, HasByteBuffer { - private ByteBuffer buffer; - private final boolean byteBufferSupported; - private int detachCount; - - private MockGrpcInputStream(ByteBuffer buffer, boolean byteBufferSupported) { - this.buffer = buffer; - this.byteBufferSupported = byteBufferSupported; - } - - static MockGrpcInputStream ofDirectBuffer(byte[] data) { - ByteBuffer buf = ByteBuffer.allocateDirect(data.length); - buf.put(data).flip(); - return new MockGrpcInputStream(buf, true); - } - - static MockGrpcInputStream ofHeapBuffer(byte[] data) { - return new MockGrpcInputStream(ByteBuffer.wrap(data), true); - } - - @Override - public boolean byteBufferSupported() { - return byteBufferSupported; - } - - @Override - public ByteBuffer getByteBuffer() { - return byteBufferSupported ? buffer : null; - } - - @Override - public InputStream detach() { - detachCount++; - ByteBuffer detached = this.buffer; - this.buffer = null; - return new MockGrpcInputStream(detached, byteBufferSupported); - } - - int getDetachCount() { - return detachCount; - } - - @Override - public int read() { - return (buffer != null && buffer.hasRemaining()) ? (buffer.get() & 0xFF) : -1; - } - - @Override - public int available() { - return buffer == null ? 0 : buffer.remaining(); - } - - @Override - public void close() { - buffer = null; - } - } -} From 69d0cc21eb7bd6009b4491332c065a5388dfa282 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9lder=20Greg=C3=B3rio?= Date: Fri, 16 Jan 2026 13:13:54 +0000 Subject: [PATCH 09/12] remove redundant checks. add test to check fallback works --- .../apache/arrow/flight/FlightDataParser.java | 12 ++----- .../arrow/flight/TestArrowMessageParse.java | 33 +++++++++++++++++++ 2 files changed, 35 insertions(+), 10 deletions(-) diff --git a/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightDataParser.java b/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightDataParser.java index cb870da7c7..2de7aafbc0 100644 --- a/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightDataParser.java +++ b/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightDataParser.java @@ -24,6 +24,7 @@ import java.io.IOException; import java.io.InputStream; import java.nio.ByteBuffer; +import java.util.Objects; import org.apache.arrow.flight.impl.Flight.FlightData; import org.apache.arrow.flight.impl.Flight.FlightDescriptor; import org.apache.arrow.memory.ArrowBuf; @@ -274,19 +275,10 @@ static ArrowBufReader tryArrowBufReader(BufferAllocator allocator, InputStream s } InputStream detachedStream = ((Detachable) stream).detach(); - if (!(detachedStream instanceof HasByteBuffer)) { - closeQuietly(detachedStream); - return null; - } - ByteBuffer detachedBuffer = ((HasByteBuffer) detachedStream).getByteBuffer(); - if (detachedBuffer == null || !detachedBuffer.isDirect()) { - closeQuietly(detachedStream); - return null; - } long bufferAddress = MemoryUtil.getByteBufferAddress(detachedBuffer); - int bufferSize = detachedBuffer.remaining(); + int bufferSize = Objects.requireNonNull(detachedBuffer).remaining(); ForeignAllocation foreignAllocation = new ForeignAllocation(bufferSize, bufferAddress + detachedBuffer.position()) { diff --git a/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageParse.java b/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageParse.java index b4aa6a756e..89a56da56e 100644 --- a/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageParse.java +++ b/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageParse.java @@ -275,6 +275,39 @@ public void testHeapBufferFallbackDoesNotDetach() throws Exception { } } + /** Verifies fallback to InputStream path when getByteBuffer() returns null. */ + @Test + public void testNullByteBufferFallbackToInputStream() throws Exception { + byte[] appMetadataBytes = new byte[] {20, 21, 22}; + byte[] bodyBytes = new byte[] {30, 31, 32, 33}; + FlightDescriptor expectedDescriptor = createTestDescriptor(); + + byte[] serialized = buildFlightDataWithBothFields(appMetadataBytes, bodyBytes); + MockGrpcInputStream stream = + new MockGrpcInputStream( + ByteBuffer.wrap(serialized), + false); + + try (ArrowMessage message = ArrowMessage.createMarshaller(allocator).parse(stream)) { + assertEquals(expectedDescriptor, message.getDescriptor()); + + ArrowBuf appMetadata = message.getApplicationMetadata(); + assertNotNull(appMetadata); + byte[] actualAppMetadata = new byte[appMetadataBytes.length]; + appMetadata.getBytes(0, actualAppMetadata); + assertArrayEquals(appMetadataBytes, actualAppMetadata); + + ArrowBuf body = Iterables.getOnlyElement(message.getBufs()); + assertNotNull(body); + byte[] actualBody = new byte[bodyBytes.length]; + body.getBytes(0, actualBody); + assertArrayEquals(bodyBytes, actualBody); + + assertEquals(0, stream.getDetachCount()); + } + assertEquals(0, allocator.getAllocatedMemory()); + } + // Helper methods to build complete FlightData messages private FlightDescriptor createTestDescriptor() { From 68c290e49d64fd58e8220377cec7774cfc9f9824 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9lder=20Greg=C3=B3rio?= Date: Fri, 16 Jan 2026 14:07:14 +0000 Subject: [PATCH 10/12] spotless --- .../java/org/apache/arrow/flight/TestArrowMessageParse.java | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageParse.java b/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageParse.java index 89a56da56e..df9852d02e 100644 --- a/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageParse.java +++ b/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageParse.java @@ -283,10 +283,7 @@ public void testNullByteBufferFallbackToInputStream() throws Exception { FlightDescriptor expectedDescriptor = createTestDescriptor(); byte[] serialized = buildFlightDataWithBothFields(appMetadataBytes, bodyBytes); - MockGrpcInputStream stream = - new MockGrpcInputStream( - ByteBuffer.wrap(serialized), - false); + MockGrpcInputStream stream = new MockGrpcInputStream(ByteBuffer.wrap(serialized), false); try (ArrowMessage message = ArrowMessage.createMarshaller(allocator).parse(stream)) { assertEquals(expectedDescriptor, message.getDescriptor()); From 4f23ac115e2fd00653b76a9ae1a0e144ef79f7a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9lder=20Greg=C3=B3rio?= Date: Wed, 25 Mar 2026 10:25:38 +0000 Subject: [PATCH 11/12] Share detached buffers with downstream consumers Scope detached FlightData buffers to a message-local allocator and retain them into downstream allocators as messages are consumed --- .../org/apache/arrow/flight/ArrowMessage.java | 21 +++- .../apache/arrow/flight/FlightDataParser.java | 45 +++++++- .../org/apache/arrow/flight/FlightStream.java | 83 +++++++------- .../arrow/flight/TestArrowMessageParse.java | 108 ++++++++++++++++++ 4 files changed, 212 insertions(+), 45 deletions(-) diff --git a/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java b/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java index b768559ca7..cd1fe3f316 100644 --- a/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java +++ b/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java @@ -36,7 +36,6 @@ import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; import java.util.List; import org.apache.arrow.flight.FlightDataParser.ArrowBufReader; import org.apache.arrow.flight.FlightDataParser.FlightDataReader; @@ -138,6 +137,8 @@ public static HeaderType getHeader(byte b) { private final ArrowBuf appMetadata; private final List bufs; private final boolean tryZeroCopyWrite; + // For zero-copy reads, this releases the message-scoped allocator after local buffers close. + private final BufferAllocator messageAllocator; public ArrowMessage(FlightDescriptor descriptor, Schema schema, IpcOption option) { this.writeOption = option; @@ -148,6 +149,7 @@ public ArrowMessage(FlightDescriptor descriptor, Schema schema, IpcOption option this.descriptor = descriptor; this.appMetadata = null; this.tryZeroCopyWrite = false; + this.messageAllocator = null; } /** @@ -167,6 +169,7 @@ public ArrowMessage( this.descriptor = null; this.appMetadata = appMetadata; this.tryZeroCopyWrite = tryZeroCopy; + this.messageAllocator = null; } public ArrowMessage(ArrowDictionaryBatch batch, IpcOption option) { @@ -180,6 +183,7 @@ public ArrowMessage(ArrowDictionaryBatch batch, IpcOption option) { this.descriptor = null; this.appMetadata = null; this.tryZeroCopyWrite = false; + this.messageAllocator = null; } /** @@ -195,6 +199,7 @@ public ArrowMessage(ArrowBuf appMetadata) { this.descriptor = null; this.appMetadata = appMetadata; this.tryZeroCopyWrite = false; + this.messageAllocator = null; } public ArrowMessage(FlightDescriptor descriptor) { @@ -205,6 +210,7 @@ public ArrowMessage(FlightDescriptor descriptor) { this.descriptor = descriptor; this.appMetadata = null; this.tryZeroCopyWrite = false; + this.messageAllocator = null; } ArrowMessage( @@ -212,6 +218,15 @@ public ArrowMessage(FlightDescriptor descriptor) { MessageMetadataResult message, ArrowBuf appMetadata, ArrowBuf buf) { + this(descriptor, message, appMetadata, buf, null); + } + + ArrowMessage( + FlightDescriptor descriptor, + MessageMetadataResult message, + ArrowBuf appMetadata, + ArrowBuf buf, + BufferAllocator messageAllocator) { // No need to take IpcOption as this is used for deserialized ArrowMessage coming from the wire. this.writeOption = message != null @@ -224,6 +239,7 @@ public ArrowMessage(FlightDescriptor descriptor) { this.appMetadata = appMetadata; this.bufs = buf == null ? ImmutableList.of() : ImmutableList.of(buf); this.tryZeroCopyWrite = false; + this.messageAllocator = messageAllocator; } public MessageMetadataResult asSchemaMessage() { @@ -264,6 +280,7 @@ public ArrowDictionaryBatch asDictionaryBatch() throws IOException { bufs.size() == 1, "A batch can only be consumed if it contains a single ArrowBuf."); Preconditions.checkArgument(getMessageType() == HeaderType.DICTIONARY_BATCH); ArrowBuf underlying = bufs.get(0); + // Retain a reference to keep the batch alive when the message is closed underlying.getReferenceManager().retain(); // Do not set drained - we still want to release our reference @@ -496,6 +513,6 @@ public ArrowMessage parse(InputStream stream) { @Override public void close() throws Exception { - AutoCloseables.close(Iterables.concat(bufs, Collections.singletonList(appMetadata))); + AutoCloseables.close(Iterables.concat(bufs, AutoCloseables.iter(appMetadata, messageAllocator))); } } diff --git a/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightDataParser.java b/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightDataParser.java index 2de7aafbc0..843b9f0ff8 100644 --- a/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightDataParser.java +++ b/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightDataParser.java @@ -25,6 +25,7 @@ import java.io.InputStream; import java.nio.ByteBuffer; import java.util.Objects; +import java.util.concurrent.atomic.AtomicLong; import org.apache.arrow.flight.impl.Flight.FlightData; import org.apache.arrow.flight.impl.Flight.FlightDescriptor; import org.apache.arrow.memory.ArrowBuf; @@ -45,6 +46,7 @@ * zero-copy slicing when parsing from ArrowBuf. */ final class FlightDataParser { + private static final AtomicLong ALLOCATOR_ID = new AtomicLong(); // Protobuf wire format tags for FlightData fields private static final int DESCRIPTOR_TAG = @@ -74,7 +76,8 @@ final ArrowMessage toMessage() { try { parseFields(); ArrowBuf adjustedBody = adjustBodyForHeaderType(); - ArrowMessage message = new ArrowMessage(descriptor, header, appMetadata, adjustedBody); + ArrowMessage message = + new ArrowMessage(descriptor, header, appMetadata, adjustedBody, getMessageAllocator()); // Ownership transferred to ArrowMessage appMetadata = null; body = null; @@ -168,6 +171,11 @@ private void parseFields() throws IOException { /** Reads the specified number of bytes into an ArrowBuf. */ protected abstract ArrowBuf readBuffer(int size) throws IOException; + /** Additional resources that should be transferred to the parsed ArrowMessage. */ + protected BufferAllocator getMessageAllocator() { + return null; + } + /** Called in finally block to clean up resources. Subclasses can override to add cleanup. */ protected void cleanup() { closeAppMetadata(); @@ -240,11 +248,15 @@ protected ArrowBuf readBuffer(int size) throws IOException { static final class ArrowBufReader extends FlightDataReader { private static final Logger LOG = LoggerFactory.getLogger(ArrowBufReader.class); + private final BufferAllocator messageAllocator; private final ArrowBuf backingBuffer; private final CodedInputStream codedInput; + private boolean transferred; - ArrowBufReader(BufferAllocator allocator, ArrowBuf backingBuffer) { + ArrowBufReader( + BufferAllocator allocator, BufferAllocator messageAllocator, ArrowBuf backingBuffer) { super(allocator); + this.messageAllocator = messageAllocator; this.backingBuffer = backingBuffer; ByteBuffer buffer = backingBuffer.nioBuffer(0, (int) backingBuffer.capacity()); this.codedInput = CodedInputStream.newInstance(buffer); @@ -288,10 +300,16 @@ protected void release0() { } }; + BufferAllocator messageAllocator = + allocator.newChildAllocator( + // Keep detached transport memory scoped to this message until a downstream retain. + "arrow-msg-" + ALLOCATOR_ID.incrementAndGet(), 0, bufferSize); + try { - ArrowBuf backingBuffer = allocator.wrapForeignAllocation(foreignAllocation); - return new ArrowBufReader(allocator, backingBuffer); + ArrowBuf backingBuffer = messageAllocator.wrapForeignAllocation(foreignAllocation); + return new ArrowBufReader(allocator, messageAllocator, backingBuffer); } catch (Throwable t) { + closeQuietly(messageAllocator); closeQuietly(detachedStream); throw t; } @@ -307,10 +325,23 @@ private static void closeQuietly(InputStream stream) { } } + private static void closeQuietly(BufferAllocator allocator) { + if (allocator != null) { + try { + allocator.close(); + } catch (Exception e) { + LOG.debug("Error closing message allocator", e); + } + } + } + @Override protected void cleanup() { super.cleanup(); backingBuffer.close(); + if (!transferred) { + closeQuietly(messageAllocator); + } } @Override @@ -343,5 +374,11 @@ protected ArrowBuf readBuffer(int size) throws IOException { backingBuffer.getReferenceManager().retain(); return backingBuffer.slice(offset, size); } + + @Override + protected BufferAllocator getMessageAllocator() { + transferred = true; + return messageAllocator; + } } } diff --git a/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java b/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java index 15cfd6ba85..478d49766c 100644 --- a/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java +++ b/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java @@ -318,13 +318,18 @@ public boolean next() { /** Update our metadata reference with a new one from this message. */ private void updateMetadata(ArrowMessage msg) { - if (this.applicationMetadata != null) { - this.applicationMetadata.close(); + ArrowBuf retainedMetadata = null; + if (msg.getApplicationMetadata() != null) { + // Re-associate metadata with the stream allocator so it can outlive this message. + retainedMetadata = + msg.getApplicationMetadata() + .getReferenceManager() + .retain(msg.getApplicationMetadata(), allocator); } - this.applicationMetadata = msg.getApplicationMetadata(); if (this.applicationMetadata != null) { - this.applicationMetadata.getReferenceManager().retain(); + this.applicationMetadata.close(); } + this.applicationMetadata = retainedMetadata; } /** Ensure the Arrow metadata version doesn't change mid-stream. */ @@ -424,50 +429,49 @@ public void onNext(ArrowMessage msg) { } if (msg.getApplicationMetadata() != null) { enqueue(msg); + } else { + AutoCloseables.closeNoChecked(msg); } break; } case SCHEMA: { - Schema schema = msg.asSchema(); - - // if there is app metadata in the schema message, make sure - // that we don't leak it. - ArrowBuf meta = msg.getApplicationMetadata(); - if (meta != null) { - meta.close(); - } - - final List fields = new ArrayList<>(); - final Map dictionaryMap = new HashMap<>(); - for (final Field originalField : schema.getFields()) { - final Field updatedField = - DictionaryUtility.toMemoryFormat(originalField, allocator, dictionaryMap); - fields.add(updatedField); - } - for (final Map.Entry entry : dictionaryMap.entrySet()) { - dictionaries.put(entry.getValue()); - } - schema = new Schema(fields, schema.getCustomMetadata()); - metadataVersion = - MetadataVersion.fromFlatbufID(msg.asSchemaMessage().getMessage().version()); try { - MetadataV4UnionChecker.checkRead(schema, metadataVersion); - } catch (IOException e) { - ex = e; - enqueue(DONE_EX); - break; - } + Schema schema = msg.asSchema(); + + final List fields = new ArrayList<>(); + final Map dictionaryMap = new HashMap<>(); + for (final Field originalField : schema.getFields()) { + final Field updatedField = + DictionaryUtility.toMemoryFormat(originalField, allocator, dictionaryMap); + fields.add(updatedField); + } + for (final Map.Entry entry : dictionaryMap.entrySet()) { + dictionaries.put(entry.getValue()); + } + schema = new Schema(fields, schema.getCustomMetadata()); + metadataVersion = + MetadataVersion.fromFlatbufID(msg.asSchemaMessage().getMessage().version()); + try { + MetadataV4UnionChecker.checkRead(schema, metadataVersion); + } catch (IOException e) { + ex = e; + enqueue(DONE_EX); + break; + } - synchronized (completed) { - if (!completed.isDone()) { - fulfilledRoot = VectorSchemaRoot.create(schema, allocator); - loader = new VectorLoader(fulfilledRoot); - if (msg.getDescriptor() != null) { - descriptor.set(new FlightDescriptor(msg.getDescriptor())); + synchronized (completed) { + if (!completed.isDone()) { + fulfilledRoot = VectorSchemaRoot.create(schema, allocator); + loader = new VectorLoader(fulfilledRoot); + if (msg.getDescriptor() != null) { + descriptor.set(new FlightDescriptor(msg.getDescriptor())); + } + root.set(fulfilledRoot); } - root.set(fulfilledRoot); } + } finally { + AutoCloseables.closeNoChecked(msg); } break; } @@ -480,6 +484,7 @@ public void onNext(ArrowMessage msg) { ex = new UnsupportedOperationException( "Unable to handle message of type: " + msg.getMessageType()); + AutoCloseables.closeNoChecked(msg); enqueue(DONE_EX); } } diff --git a/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageParse.java b/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageParse.java index df9852d02e..b4b0669a2a 100644 --- a/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageParse.java +++ b/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageParse.java @@ -27,18 +27,28 @@ import io.grpc.Detachable; import io.grpc.HasByteBuffer; import io.grpc.protobuf.ProtoUtils; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; +import java.lang.reflect.Constructor; import java.nio.ByteBuffer; import java.util.Arrays; import java.util.List; +import org.apache.arrow.flight.FlightProducer.CallContext; +import org.apache.arrow.flight.FlightProducer.ServerStreamListener; import org.apache.arrow.flight.impl.Flight.FlightData; import org.apache.arrow.flight.impl.Flight.FlightDescriptor; import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; import org.apache.arrow.vector.ipc.message.IpcOption; import org.apache.arrow.vector.ipc.message.MessageSerializer; import org.apache.arrow.vector.types.pojo.ArrowType; @@ -305,8 +315,106 @@ public void testNullByteBufferFallbackToInputStream() throws Exception { assertEquals(0, allocator.getAllocatedMemory()); } + @Test + public void testRealFlightSmallBatchLifecycle() throws Exception { + try (BufferAllocator rootAllocator = new RootAllocator(Long.MAX_VALUE); + FlightServer server = + FlightServer.builder( + rootAllocator, + Location.forGrpcInsecure("localhost", 0), + new NoOpFlightProducer() { + @Override + public void getStream( + CallContext context, Ticket ticket, ServerStreamListener listener) { + try (VectorSchemaRoot root = + VectorSchemaRoot.of(new BigIntVector("a", rootAllocator))) { + BigIntVector vector = (BigIntVector) root.getVector(0); + vector.allocateNew(8); + for (int i = 0; i < 8; i++) { + vector.set(i, i); + } + root.setRowCount(8); + listener.start(root); + listener.putNext(); + listener.completed(); + } + } + }) + .build() + .start(); + FlightClient client = FlightClient.builder(rootAllocator, server.getLocation()).build(); + FlightStream stream = client.getStream(new Ticket(new byte[] {1}))) { + while (stream.next()) { + assertEquals(8, stream.getRoot().getRowCount()); + } + } + } + + @Test + public void testBufferInputStreamLargeRecordBatchLifecycle() throws Exception { + byte[] batchBytes; + try (BufferAllocator writerAllocator = + allocator.newChildAllocator("writer", 0, Long.MAX_VALUE); + VectorSchemaRoot root = VectorSchemaRoot.of(new BigIntVector("a", writerAllocator))) { + BigIntVector vector = (BigIntVector) root.getVector(0); + vector.allocateNew(4095); + for (int i = 0; i < 4095; i++) { + vector.set(i, i); + } + root.setRowCount(4095); + + try (ArrowRecordBatch batch = new VectorUnloader(root).getRecordBatch(); + InputStream grpcStream = + ArrowMessage.createMarshaller(writerAllocator).stream( + new ArrowMessage(batch, null, false, IpcOption.DEFAULT))) { + batchBytes = ByteStreams.toByteArray(grpcStream); + } + } + + try (BufferAllocator parseAllocator = allocator.newChildAllocator("parse", 0, Long.MAX_VALUE)) { + try (VectorSchemaRoot loadedRoot = + VectorSchemaRoot.of(new BigIntVector("a", parseAllocator))) { + ArrowMessage message = + ArrowMessage.createMarshaller(parseAllocator) + .parse(createGrpcBufferInputStream(batchBytes)); + assertEquals(batchBytes.length, parseAllocator.getAllocatedMemory()); + assertEquals(ArrowMessage.HeaderType.RECORD_BATCH, message.getMessageType()); + + try (ArrowRecordBatch batch = message.asRecordBatch()) { + new VectorLoader(loadedRoot).load(batch); + } finally { + message.close(); + } + + assertEquals(batchBytes.length, parseAllocator.getAllocatedMemory()); + assertEquals(4095, loadedRoot.getRowCount()); + assertEquals(4094L, ((BigIntVector) loadedRoot.getVector(0)).get(4095 - 1)); + } + } + + assertEquals(0, allocator.getAllocatedMemory()); + } + // Helper methods to build complete FlightData messages + private InputStream createGrpcBufferInputStream(byte[] data) throws Exception { + ByteBuf byteBuf = Unpooled.directBuffer(data.length); + byteBuf.writeBytes(data); + + Class readableBufferClass = Class.forName("io.grpc.internal.ReadableBuffer"); + Class nettyReadableBufferClass = Class.forName("io.grpc.netty.NettyReadableBuffer"); + Constructor readableBufferCtor = + nettyReadableBufferClass.getDeclaredConstructor(ByteBuf.class); + readableBufferCtor.setAccessible(true); + Object readableBuffer = readableBufferCtor.newInstance(byteBuf); + + Class bufferInputStreamClass = + Class.forName("io.grpc.internal.ReadableBuffers$BufferInputStream"); + Constructor streamCtor = bufferInputStreamClass.getDeclaredConstructor(readableBufferClass); + streamCtor.setAccessible(true); + return (InputStream) streamCtor.newInstance(readableBuffer); + } + private FlightDescriptor createTestDescriptor() { return FlightDescriptor.newBuilder() .setType(FlightDescriptor.DescriptorType.PATH) From bc07143ffe87669d3b192a17f8d6dded276eae55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9lder=20Greg=C3=B3rio?= Date: Mon, 13 Apr 2026 12:54:39 +0100 Subject: [PATCH 12/12] Apply spotless formatting to ArrowMessage --- .../src/main/java/org/apache/arrow/flight/ArrowMessage.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java b/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java index cd1fe3f316..c9deae68e1 100644 --- a/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java +++ b/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java @@ -513,6 +513,7 @@ public ArrowMessage parse(InputStream stream) { @Override public void close() throws Exception { - AutoCloseables.close(Iterables.concat(bufs, AutoCloseables.iter(appMetadata, messageAllocator))); + AutoCloseables.close( + Iterables.concat(bufs, AutoCloseables.iter(appMetadata, messageAllocator))); } }