From 98150a8a55bd19d5d53052b5aed2ed094b3b0940 Mon Sep 17 00:00:00 2001 From: Surya Sumanth N Date: Thu, 9 Feb 2023 17:30:00 +0530 Subject: [PATCH 1/4] Hybrid Exchange Buffer --- .../filesystem/HetuHdfsFileSystemClient.java | 10 + .../filesystem/HetuLocalFileSystemClient.java | 7 + .../io/prestosql/exchange/ExchangeSink.java | 15 ++ .../exchange/FileSystemExchangeManager.java | 6 +- .../exchange/FileSystemExchangeSink.java | 37 ++- .../FileSystemExchangeSinkInstanceHandle.java | 6 + .../io/prestosql/exchange/RetryPolicy.java | 1 + .../storage/HetuFileSystemExchangeWriter.java | 25 +- .../execution/SqlStageExecution.java | 6 +- .../buffer/HybridSpoolingBuffer.java | 210 +++++++++++++++ .../execution/buffer/LazyOutputBuffer.java | 82 +++++- .../execution/buffer/OutputBuffer.java | 27 ++ .../scheduler/SqlQueryScheduler.java | 3 +- .../operator/ExchangeClientFactory.java | 1 + .../operator/TaskOutputOperator.java | 23 ++ .../operator/output/PagePartitioner.java | 19 ++ .../testing/TestingRecoveryUtils.java | 6 + .../buffer/TestHybridSpoolingBuffer.java | 249 ++++++++++++++++++ .../TestSpoolingExchangeOutputBuffer.java | 36 ++- .../spi/filesystem/HetuFileSystemClient.java | 3 + 20 files changed, 753 insertions(+), 19 deletions(-) create mode 100644 presto-main/src/main/java/io/prestosql/execution/buffer/HybridSpoolingBuffer.java create mode 100644 presto-main/src/test/java/io/prestosql/execution/buffer/TestHybridSpoolingBuffer.java diff --git a/hetu-filesystem-client/src/main/java/io/hetu/core/filesystem/HetuHdfsFileSystemClient.java b/hetu-filesystem-client/src/main/java/io/hetu/core/filesystem/HetuHdfsFileSystemClient.java index 61bacc659..16c634a87 100644 --- a/hetu-filesystem-client/src/main/java/io/hetu/core/filesystem/HetuHdfsFileSystemClient.java +++ b/hetu-filesystem-client/src/main/java/io/hetu/core/filesystem/HetuHdfsFileSystemClient.java @@ -18,6 +18,7 @@ import com.google.common.base.Throwables; import io.airlift.log.Logger; import io.prestosql.spi.filesystem.SupportedFileAttributes; import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FSDataOutputStream; import org.apache.hadoop.fs.FileStatus; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.PathIsNotEmptyDirectoryException; @@ -267,6 +268,15 @@ public class HetuHdfsFileSystemClient } } + @Override + public void flush(OutputStream outputStream) + throws IOException + { + if (outputStream instanceof FSDataOutputStream) { + ((FSDataOutputStream) outputStream).hflush(); + } + } + @Override public void close() throws IOException diff --git a/hetu-filesystem-client/src/main/java/io/hetu/core/filesystem/HetuLocalFileSystemClient.java b/hetu-filesystem-client/src/main/java/io/hetu/core/filesystem/HetuLocalFileSystemClient.java index f9e0c9eb8..c4a52c158 100644 --- a/hetu-filesystem-client/src/main/java/io/hetu/core/filesystem/HetuLocalFileSystemClient.java +++ b/hetu-filesystem-client/src/main/java/io/hetu/core/filesystem/HetuLocalFileSystemClient.java @@ -249,4 +249,11 @@ public class HetuLocalFileSystemClient String glob = prefix + "*" + suffix; return StreamSupport.stream(newDirectoryStream(path, glob).spliterator(), false); } + + @Override + public void flush(OutputStream outputStream) + throws IOException + { + outputStream.flush(); + } } diff --git a/presto-main/src/main/java/io/prestosql/exchange/ExchangeSink.java b/presto-main/src/main/java/io/prestosql/exchange/ExchangeSink.java index eb174df1a..0386dcd41 100644 --- a/presto-main/src/main/java/io/prestosql/exchange/ExchangeSink.java +++ b/presto-main/src/main/java/io/prestosql/exchange/ExchangeSink.java @@ -16,8 +16,13 @@ package io.prestosql.exchange; import io.airlift.slice.Slice; import io.hetu.core.transport.execution.buffer.PagesSerde; import io.prestosql.exchange.FileSystemExchangeConfig.DirectSerialisationType; +import io.prestosql.exchange.storage.FileSystemExchangeStorage; import io.prestosql.spi.Page; +import javax.crypto.SecretKey; + +import java.net.URI; +import java.util.Optional; import java.util.concurrent.CompletableFuture; public interface ExchangeSink @@ -40,4 +45,14 @@ public interface ExchangeSink { return DirectSerialisationType.OFF; } + + FileSystemExchangeStorage getExchangeStorage(); + + URI getOutputDirectory(); + + Optional getSecretKey(); + + boolean isExchangeCompressionEnabled(); + + int getPartitionId(); } diff --git a/presto-main/src/main/java/io/prestosql/exchange/FileSystemExchangeManager.java b/presto-main/src/main/java/io/prestosql/exchange/FileSystemExchangeManager.java index c7fa8b459..ec4424b4b 100644 --- a/presto-main/src/main/java/io/prestosql/exchange/FileSystemExchangeManager.java +++ b/presto-main/src/main/java/io/prestosql/exchange/FileSystemExchangeManager.java @@ -137,7 +137,8 @@ public class FileSystemExchangeManager exchangeSinkBuffersPerPartition, exchangeSinkMaxFileSizeInBytes, directSerialisationType, - directSerialisationBufferSize); + directSerialisationBufferSize, + instanceHandle.getPartiitonId()); } @Override @@ -157,7 +158,8 @@ public class FileSystemExchangeManager exchangeSinkBuffersPerPartition, exchangeSinkMaxFileSizeInBytes, serType, - directSerialisationBufferSize); + directSerialisationBufferSize, + instanceHandle.getPartiitonId()); } @Override diff --git a/presto-main/src/main/java/io/prestosql/exchange/FileSystemExchangeSink.java b/presto-main/src/main/java/io/prestosql/exchange/FileSystemExchangeSink.java index d985a4d46..b2f8e77f8 100644 --- a/presto-main/src/main/java/io/prestosql/exchange/FileSystemExchangeSink.java +++ b/presto-main/src/main/java/io/prestosql/exchange/FileSystemExchangeSink.java @@ -91,6 +91,7 @@ public class FileSystemExchangeSink private boolean closed; private final DirectSerialisationType directSerialisationType; private final int directSerialisationBufferSize; + private final int partitionId; public FileSystemExchangeSink( FileSystemExchangeStorage exchangeStorage, @@ -105,7 +106,8 @@ public class FileSystemExchangeSink int exchangeSinkBuffersPerPartition, long maxFileSizeInBytes, DirectSerialisationType directSerialisationType, - int directSerialisationBufferSize) + int directSerialisationBufferSize, + int partitionId) { checkArgument(maxPageStorageSizeInBytes <= maxFileSizeInBytes, format("maxPageStorageSizeInBytes %s exceeded maxFileSizeInBytes %s", succinctBytes(maxPageStorageSizeInBytes), succinctBytes(maxFileSizeInBytes))); @@ -129,6 +131,7 @@ public class FileSystemExchangeSink else { this.bufferPool = null; } + this.partitionId = partitionId; } @Override @@ -142,6 +145,36 @@ public class FileSystemExchangeSink return directSerialisationType; } + @Override + public URI getOutputDirectory() + { + return outputDirectory; + } + + @Override + public Optional getSecretKey() + { + return secretKey; + } + + @Override + public boolean isExchangeCompressionEnabled() + { + return exchangeCompressionEnabled; + } + + @Override + public FileSystemExchangeStorage getExchangeStorage() + { + return exchangeStorage; + } + + @Override + public int getPartitionId() + { + return partitionId; + } + @Override public void add(int partitionId, Slice data) { @@ -428,7 +461,7 @@ public class FileSystemExchangeSink currentBuffer.writeBytes(slice.getBytes(position, writableBytes)); position += writableBytes; - flushIfNeeded(false); + flushIfNeeded(true); } } diff --git a/presto-main/src/main/java/io/prestosql/exchange/FileSystemExchangeSinkInstanceHandle.java b/presto-main/src/main/java/io/prestosql/exchange/FileSystemExchangeSinkInstanceHandle.java index 9e86443ac..74c11ae6b 100644 --- a/presto-main/src/main/java/io/prestosql/exchange/FileSystemExchangeSinkInstanceHandle.java +++ b/presto-main/src/main/java/io/prestosql/exchange/FileSystemExchangeSinkInstanceHandle.java @@ -55,4 +55,10 @@ public class FileSystemExchangeSinkInstanceHandle { return outputPartitionCount; } + + @JsonProperty + public int getPartiitonId() + { + return sinkHandle.getPartitionId(); + } } diff --git a/presto-main/src/main/java/io/prestosql/exchange/RetryPolicy.java b/presto-main/src/main/java/io/prestosql/exchange/RetryPolicy.java index d03e98e94..459598ee7 100644 --- a/presto-main/src/main/java/io/prestosql/exchange/RetryPolicy.java +++ b/presto-main/src/main/java/io/prestosql/exchange/RetryPolicy.java @@ -19,6 +19,7 @@ public enum RetryPolicy { TASK(RetryMode.RETRIES_ENABLED), NONE(RetryMode.NO_RETRIES), + TASK_ASYNC(RetryMode.RETRIES_ENABLED), /**/; private final RetryMode retryMode; diff --git a/presto-main/src/main/java/io/prestosql/exchange/storage/HetuFileSystemExchangeWriter.java b/presto-main/src/main/java/io/prestosql/exchange/storage/HetuFileSystemExchangeWriter.java index 79e8b0902..49bbfc769 100644 --- a/presto-main/src/main/java/io/prestosql/exchange/storage/HetuFileSystemExchangeWriter.java +++ b/presto-main/src/main/java/io/prestosql/exchange/storage/HetuFileSystemExchangeWriter.java @@ -55,35 +55,39 @@ public class HetuFileSystemExchangeWriter private final OutputStream outputStream; private final DirectSerialisationType directSerialisationType; private final int directSerialisationBufferSize; + private final HetuFileSystemClient fileSystemClient; + private final OutputStream delegateOutputStream; public HetuFileSystemExchangeWriter(URI file, HetuFileSystemClient fileSystemClient, Optional secretKey, boolean exchangeCompressionEnabled, AlgorithmParameterSpec algorithmParameterSpec, FileSystemExchangeConfig.DirectSerialisationType directSerialisationType, int directSerialisationBufferSize) { this.directSerialisationBufferSize = directSerialisationBufferSize; this.directSerialisationType = directSerialisationType; + this.fileSystemClient = fileSystemClient; try { Path path = Paths.get(file.toString()); + this.delegateOutputStream = fileSystemClient.newOutputStream(path); if (secretKey.isPresent() && exchangeCompressionEnabled) { Cipher cipher = Cipher.getInstance(CIPHER_TRANSFORMATION); cipher.init(Cipher.ENCRYPT_MODE, secretKey.get(), algorithmParameterSpec); - this.outputStream = new SnappyFramedOutputStream(new CipherOutputStream(fileSystemClient.newOutputStream(path), cipher)); + this.outputStream = new SnappyFramedOutputStream(new CipherOutputStream(delegateOutputStream, cipher)); } else if (secretKey.isPresent()) { Cipher cipher = Cipher.getInstance(CIPHER_TRANSFORMATION); cipher.init(Cipher.ENCRYPT_MODE, secretKey.get(), algorithmParameterSpec); - this.outputStream = new CipherOutputStream(fileSystemClient.newOutputStream(path), cipher); + this.outputStream = new CipherOutputStream(delegateOutputStream, cipher); } else if (exchangeCompressionEnabled) { - this.outputStream = new SnappyFramedOutputStream(new OutputStreamSliceOutput(fileSystemClient.newOutputStream(path), directSerialisationBufferSize)); + this.outputStream = new SnappyFramedOutputStream(new OutputStreamSliceOutput(delegateOutputStream, directSerialisationBufferSize)); } else { if (directSerialisationType == DirectSerialisationType.KRYO) { - this.outputStream = new Output(fileSystemClient.newOutputStream(path), directSerialisationBufferSize); + this.outputStream = new Output(delegateOutputStream, directSerialisationBufferSize); } else if (directSerialisationType == DirectSerialisationType.JAVA) { - this.outputStream = new OutputStreamSliceOutput(fileSystemClient.newOutputStream(path), directSerialisationBufferSize); + this.outputStream = new OutputStreamSliceOutput(delegateOutputStream, directSerialisationBufferSize); } else { - this.outputStream = new OutputStreamSliceOutput(fileSystemClient.newOutputStream(path), directSerialisationBufferSize); + this.outputStream = new OutputStreamSliceOutput(delegateOutputStream, directSerialisationBufferSize); } } } @@ -98,6 +102,8 @@ public class HetuFileSystemExchangeWriter { try { outputStream.write(slice.getBytes()); + outputStream.flush(); + fileSystemClient.flush(delegateOutputStream); } catch (IOException | RuntimeException e) { return immediateFailedFuture(e); @@ -110,6 +116,13 @@ public class HetuFileSystemExchangeWriter { checkState(directSerialisationType != DirectSerialisationType.OFF, "Should be used with direct serialization is enabled!"); serde.serialize(outputStream, page); + try { + outputStream.flush(); + fileSystemClient.flush(delegateOutputStream); + } + catch (IOException | RuntimeException e) { + return immediateFailedFuture(e); + } return immediateFuture(null); } diff --git a/presto-main/src/main/java/io/prestosql/execution/SqlStageExecution.java b/presto-main/src/main/java/io/prestosql/execution/SqlStageExecution.java index ad847703e..b3a873dd7 100644 --- a/presto-main/src/main/java/io/prestosql/execution/SqlStageExecution.java +++ b/presto-main/src/main/java/io/prestosql/execution/SqlStageExecution.java @@ -550,7 +550,9 @@ public final class SqlStageExecution if (this.outputBuffers.compareAndSet(currentOutputBuffers, outputBuffers)) { for (RemoteTask task : getAllTasks()) { - task.setOutputBuffers(outputBuffers); + OutputBuffers localOutputBuffers = new OutputBuffers(outputBuffers.getType(), outputBuffers.getVersion(), + outputBuffers.isNoMoreBufferIds(), outputBuffers.getBuffers(), outputBuffers.getExchangeSinkInstanceHandle()); + task.setOutputBuffers(localOutputBuffers); } return; } @@ -658,6 +660,8 @@ public final class SqlStageExecution }); OutputBuffers localOutputBuffers = this.outputBuffers.get(); + localOutputBuffers = new OutputBuffers(localOutputBuffers.getType(), localOutputBuffers.getVersion(), + localOutputBuffers.isNoMoreBufferIds(), localOutputBuffers.getBuffers(), localOutputBuffers.getExchangeSinkInstanceHandle()); checkState(localOutputBuffers != null, "Initial output buffers must be set before a task can be scheduled"); if (sinkExchange.isPresent()) { diff --git a/presto-main/src/main/java/io/prestosql/execution/buffer/HybridSpoolingBuffer.java b/presto-main/src/main/java/io/prestosql/execution/buffer/HybridSpoolingBuffer.java new file mode 100644 index 000000000..98beb1bc2 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/execution/buffer/HybridSpoolingBuffer.java @@ -0,0 +1,210 @@ +/* + * Copyright (C) 2018-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.execution.buffer; + +import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import io.airlift.log.Logger; +import io.airlift.slice.Slice; +import io.airlift.units.DataSize; +import io.hetu.core.transport.execution.buffer.PagesSerde; +import io.hetu.core.transport.execution.buffer.PagesSerdeUtil; +import io.hetu.core.transport.execution.buffer.SerializedPage; +import io.prestosql.exchange.ExchangeManager; +import io.prestosql.exchange.ExchangeSink; +import io.prestosql.exchange.ExchangeSource; +import io.prestosql.exchange.ExchangeSourceHandle; +import io.prestosql.exchange.FileStatus; +import io.prestosql.exchange.FileSystemExchangeConfig; +import io.prestosql.exchange.FileSystemExchangeSourceHandle; +import io.prestosql.exchange.storage.FileSystemExchangeStorage; +import io.prestosql.memory.context.LocalMemoryContext; +import io.prestosql.spi.Page; + +import javax.crypto.SecretKey; + +import java.net.URI; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.ExecutorService; +import java.util.function.Supplier; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.util.concurrent.Futures.immediateFuture; +import static io.airlift.concurrent.Threads.daemonThreadsNamed; +import static io.prestosql.exchange.FileSystemExchangeSink.DATA_FILE_SUFFIX; +import static java.util.concurrent.Executors.newCachedThreadPool; + +public class HybridSpoolingBuffer + extends SpoolingExchangeOutputBuffer +{ + private static final String PARENT_URI = ".."; + private static final Logger LOG = Logger.get(HybridSpoolingBuffer.class); + + private final OutputBufferStateMachine stateMachine; + private final OutputBuffers outputBuffers; + private final ExchangeSink exchangeSink; + private ExchangeSource exchangeSource; + private final Supplier memoryContextSupplier; + private final ExecutorService executor; + private final ExchangeManager exchangeManager; + private int token; + private PagesSerde serde; + private PagesSerde javaSerde; + private PagesSerde kryoSerde; + + private URI outputDirectory; + + public HybridSpoolingBuffer(OutputBufferStateMachine stateMachine, OutputBuffers outputBuffers, ExchangeSink exchangeSink, Supplier memoryContextSupplier, ExchangeManager exchangeManager) + { + super(stateMachine, outputBuffers, exchangeSink, memoryContextSupplier); + this.stateMachine = stateMachine; + this.outputBuffers = outputBuffers; + this.exchangeSink = exchangeSink; + this.memoryContextSupplier = memoryContextSupplier; + this.outputDirectory = exchangeSink.getOutputDirectory().resolve(PARENT_URI); + this.executor = newCachedThreadPool(daemonThreadsNamed("exchange-source-handles-creation-%s")); + this.exchangeManager = exchangeManager; + } + + @Override + public ListenableFuture get(OutputBuffers.OutputBufferId bufferId, long token, DataSize maxSize) + { + if (exchangeSource == null) { + exchangeSource = createExchangeSource(exchangeSink.getPartitionId()); + if (exchangeSource == null) { + return immediateFuture(BufferResult.emptyResults(token, false)); + } + } + return immediateFuture(readPages(token)); + } + + private BufferResult readPages(long tokenId) + { + List result = new ArrayList<>(); + FileSystemExchangeConfig.DirectSerialisationType directSerialisationType = exchangeSource.getDirectSerialisationType(); + if (directSerialisationType != FileSystemExchangeConfig.DirectSerialisationType.OFF) { + PagesSerde pagesSerde = (directSerialisationType == FileSystemExchangeConfig.DirectSerialisationType.JAVA) ? javaSerde : kryoSerde; + while (token < tokenId) { + exchangeSource.readPage(pagesSerde); + token++; + } + Page page = exchangeSource.readPage(pagesSerde); + result.add(serde.serialize(page)); + } + else { + while (token < tokenId) { + exchangeSource.read(); + token++; + } + Slice slice = exchangeSource.read(); + SerializedPage serializedPage = slice != null ? PagesSerdeUtil.readSerializedPage(slice) : null; + result.add(serializedPage); + } + return new BufferResult(tokenId, tokenId + result.size(), false, result); + } + + public ExchangeSource createExchangeSource(int partitionId) + { + ExchangeHandleInfo exchangeHandleInfo = getExchangeHandleInfo(exchangeSink); + ListenableFuture> fileStatus = getFileStatus(outputDirectory); + List completedFileStatus = new ArrayList<>(); + if (!fileStatus.isDone()) { + return null; + } + try { + completedFileStatus = fileStatus.get(); + } + catch (Exception e) { + LOG.debug("Failed in creating exchange source with outputDirectory" + outputDirectory); + } + List handles = ImmutableList.of(new FileSystemExchangeSourceHandle(partitionId, + completedFileStatus, exchangeHandleInfo.getSecretKey(), exchangeHandleInfo.isExchangeCompressionEnabled())); + + return exchangeManager.createSource(handles); + } + + private ListenableFuture> getFileStatus(URI path) + { + FileSystemExchangeStorage exchangeStorage = exchangeSink.getExchangeStorage(); + return Futures.transform(exchangeStorage.listFilesRecursively(path), + sinkOutputFiles -> sinkOutputFiles.stream().filter(file -> file.getFilePath().endsWith(DATA_FILE_SUFFIX)).collect(toImmutableList()), + executor); + } + + private ExchangeHandleInfo getExchangeHandleInfo(ExchangeSink exchangeSink) + { + return new ExchangeHandleInfo(exchangeSink.getOutputDirectory(), + exchangeSink.getExchangeStorage(), + exchangeSink.getSecretKey().map(SecretKey::getEncoded), + exchangeSink.isExchangeCompressionEnabled()); + } + + @Override + public void setJavaSerde(PagesSerde javaSerde) + { + this.javaSerde = javaSerde; + } + + @Override + public void setKryoSerde(PagesSerde kryoSerde) + { + this.kryoSerde = kryoSerde; + } + + @Override + public void setSerde(PagesSerde serde) + { + this.serde = serde; + } + + private static class ExchangeHandleInfo + { + URI outputDirectory; + FileSystemExchangeStorage exchangeStorage; + Optional secretKey; + boolean exchangeCompressionEnabled; + + ExchangeHandleInfo(URI outputDirectory, FileSystemExchangeStorage exchangeStorage, Optional secretKey, boolean exchangeCompressionEnabled) + { + this.outputDirectory = outputDirectory; + this.exchangeStorage = exchangeStorage; + this.secretKey = secretKey; + this.exchangeCompressionEnabled = exchangeCompressionEnabled; + } + + public URI getOutputDirectory() + { + return outputDirectory; + } + + public Optional getSecretKey() + { + return secretKey; + } + + public boolean isExchangeCompressionEnabled() + { + return exchangeCompressionEnabled; + } + + public FileSystemExchangeStorage getExchangeStorage() + { + return exchangeStorage; + } + } +} diff --git a/presto-main/src/main/java/io/prestosql/execution/buffer/LazyOutputBuffer.java b/presto-main/src/main/java/io/prestosql/execution/buffer/LazyOutputBuffer.java index 1e3e05712..86e7caad0 100644 --- a/presto-main/src/main/java/io/prestosql/execution/buffer/LazyOutputBuffer.java +++ b/presto-main/src/main/java/io/prestosql/execution/buffer/LazyOutputBuffer.java @@ -64,6 +64,9 @@ public class LazyOutputBuffer @GuardedBy("this") private OutputBuffer delegate; + @GuardedBy("this") + private OutputBuffer hybridSpoolingDelegate; + @GuardedBy("this") private final Set abortedBuffers = new HashSet<>(); @@ -72,6 +75,10 @@ public class LazyOutputBuffer private final ExchangeManagerRegistry exchangeManagerRegistry; private Optional exchangeSink; + private PagesSerde serde; + private PagesSerde javaSerde; + private PagesSerde kryoSerde; + public LazyOutputBuffer( TaskId taskId, Executor executor, @@ -161,6 +168,7 @@ public class LazyOutputBuffer Set abortedBuffersIds = ImmutableSet.of(); List bufferPendingReads = ImmutableList.of(); OutputBuffer outputBuffer; + ExchangeManager exchangeManager = null; synchronized (this) { if (delegate == null) { // ignore set output if buffer was already destroyed or failed @@ -185,7 +193,7 @@ public class LazyOutputBuffer if (newOutputBuffers.getExchangeSinkInstanceHandle().isPresent()) { ExchangeSinkInstanceHandle exchangeSinkInstanceHandle = newOutputBuffers.getExchangeSinkInstanceHandle() .orElseThrow(() -> new IllegalArgumentException("exchange sink handle is expected to be present for buffer type EXTERNAL")); - ExchangeManager exchangeManager = exchangeManagerRegistry.getExchangeManager(); + exchangeManager = exchangeManagerRegistry.getExchangeManager(); ExchangeSink exchangeSinkInstance = exchangeManager.createSink(exchangeSinkInstanceHandle, false); //TODO: create directories this.exchangeSink = Optional.ofNullable(exchangeSinkInstance); } @@ -199,11 +207,30 @@ public class LazyOutputBuffer bufferPendingReads = ImmutableList.copyOf(this.pendingReads); this.pendingReads.clear(); } - this.exchangeSink.ifPresent(sink -> delegate = new SpoolingExchangeOutputBuffer( - stateMachine, - newOutputBuffers, - sink, - systemMemoryContextSupplier)); + ExchangeManager finalExchangeManager = exchangeManager; + this.exchangeSink.ifPresent(sink -> { + if (delegate == null) { + delegate = new SpoolingExchangeOutputBuffer( + stateMachine, + newOutputBuffers, + sink, + systemMemoryContextSupplier); + } + else { + if (hybridSpoolingDelegate == null) { + hybridSpoolingDelegate = new HybridSpoolingBuffer(stateMachine, + newOutputBuffers, + sink, + systemMemoryContextSupplier, + finalExchangeManager); + if (hybridSpoolingDelegate != null) { + hybridSpoolingDelegate.setSerde(serde); + hybridSpoolingDelegate.setJavaSerde(javaSerde); + hybridSpoolingDelegate.setKryoSerde(kryoSerde); + } + } + } + }); outputBuffer = delegate; } @@ -325,6 +352,9 @@ public class LazyOutputBuffer checkState(delegate != null, "Buffer has not been initialized"); outputBuffer = delegate; } + if (hybridSpoolingDelegate != null) { + hybridSpoolingDelegate.setNoMorePages(); + } outputBuffer.setNoMorePages(); } @@ -488,4 +518,44 @@ public class LazyOutputBuffer } outputBuffer.enqueuePages(partition, pages, id, directSerde); } + + @Override + public boolean isSpoolingDelegateAvailable() + { + return delegate != null && hybridSpoolingDelegate != null && hybridSpoolingDelegate.isSpoolingOutputBuffer(); + } + + @Override + public OutputBuffer getSpoolingDelegate() + { + return hybridSpoolingDelegate; + } + + @Override + public DirectSerialisationType getDelegateSpoolingExchangeDirectSerializationType() + { + DirectSerialisationType type = DirectSerialisationType.JAVA; + if (hybridSpoolingDelegate != null) { + type = hybridSpoolingDelegate.getExchangeDirectSerialisationType(); + } + return type; + } + + @Override + public void setSerde(PagesSerde serde) + { + this.serde = serde; + } + + @Override + public void setJavaSerde(PagesSerde javaSerde) + { + this.javaSerde = javaSerde; + } + + @Override + public void setKryoSerde(PagesSerde kryoSerde) + { + this.kryoSerde = kryoSerde; + } } diff --git a/presto-main/src/main/java/io/prestosql/execution/buffer/OutputBuffer.java b/presto-main/src/main/java/io/prestosql/execution/buffer/OutputBuffer.java index 3e6cfbeb1..773b27056 100644 --- a/presto-main/src/main/java/io/prestosql/execution/buffer/OutputBuffer.java +++ b/presto-main/src/main/java/io/prestosql/execution/buffer/OutputBuffer.java @@ -176,4 +176,31 @@ public interface OutputBuffer { return; } + + default boolean isSpoolingDelegateAvailable() + { + return false; + } + + default OutputBuffer getSpoolingDelegate() + { + return null; + } + + default DirectSerialisationType getDelegateSpoolingExchangeDirectSerializationType() + { + return DirectSerialisationType.JAVA; + } + + default void setSerde(PagesSerde pagesSerde) + { + } + + default void setJavaSerde(PagesSerde pagesSerde) + { + } + + default void setKryoSerde(PagesSerde pagesSerde) + { + } } diff --git a/presto-main/src/main/java/io/prestosql/execution/scheduler/SqlQueryScheduler.java b/presto-main/src/main/java/io/prestosql/execution/scheduler/SqlQueryScheduler.java index 621438ed9..9fa71a151 100644 --- a/presto-main/src/main/java/io/prestosql/execution/scheduler/SqlQueryScheduler.java +++ b/presto-main/src/main/java/io/prestosql/execution/scheduler/SqlQueryScheduler.java @@ -149,6 +149,7 @@ import static io.prestosql.SystemSessionProperties.getWriterMinSize; import static io.prestosql.SystemSessionProperties.isQueryResourceTrackingEnabled; import static io.prestosql.SystemSessionProperties.isReuseTableScanEnabled; import static io.prestosql.exchange.RetryPolicy.TASK; +import static io.prestosql.exchange.RetryPolicy.TASK_ASYNC; import static io.prestosql.execution.BasicStageStats.aggregateBasicStageStats; import static io.prestosql.execution.QueryState.FINISHING; import static io.prestosql.execution.QueryState.RECOVERING; @@ -633,7 +634,7 @@ public class SqlQueryScheduler StageId stageId = new StageId(queryStateMachine.getQueryId(), nextStageId.getAndIncrement()); Optional exchange = Optional.empty(); - if (retryPolicy.equals(TASK)) { + if (retryPolicy.equals(TASK) || retryPolicy.equals(TASK_ASYNC)) { ExchangeManager exchangeManager = exchangeManagerRegistry.getExchangeManager(); exchange = createSqlStageExchange(exchangeManager, stageId); } diff --git a/presto-main/src/main/java/io/prestosql/operator/ExchangeClientFactory.java b/presto-main/src/main/java/io/prestosql/operator/ExchangeClientFactory.java index a41957238..ee89af2b0 100644 --- a/presto-main/src/main/java/io/prestosql/operator/ExchangeClientFactory.java +++ b/presto-main/src/main/java/io/prestosql/operator/ExchangeClientFactory.java @@ -138,6 +138,7 @@ public class ExchangeClientFactory buffer = new DeduplicatingDirectExchangeBuffer(scheduler, deduplicationBufferSize, retryPolicy, exchangeManagerRegistry, queryId, exchangeId); break; case NONE: + case TASK_ASYNC: buffer = null; break; default: diff --git a/presto-main/src/main/java/io/prestosql/operator/TaskOutputOperator.java b/presto-main/src/main/java/io/prestosql/operator/TaskOutputOperator.java index b4410f694..737ee270f 100644 --- a/presto-main/src/main/java/io/prestosql/operator/TaskOutputOperator.java +++ b/presto-main/src/main/java/io/prestosql/operator/TaskOutputOperator.java @@ -123,6 +123,8 @@ public class TaskOutputOperator private final SingleInputSnapshotState snapshotState; private final boolean isStage0; private final PagesSerde serde; + private final PagesSerde javaSerde; + private final PagesSerde kryoSerde; private boolean finished; public TaskOutputOperator(String id, OperatorContext operatorContext, OutputBuffer outputBuffer, Function pagePreprocessor) @@ -132,6 +134,11 @@ public class TaskOutputOperator this.outputBuffer = requireNonNull(outputBuffer, "outputBuffer is null"); this.pagePreprocessor = requireNonNull(pagePreprocessor, "pagePreprocessor is null"); this.serde = requireNonNull(operatorContext.getDriverContext().getSerde(), "serde is null"); + this.javaSerde = requireNonNull(operatorContext.getDriverContext().getJavaSerde(), "javaSerde is null"); + this.kryoSerde = requireNonNull(operatorContext.getDriverContext().getKryoSerde(), "kryoSerde is null"); + this.outputBuffer.setSerde(serde); + this.outputBuffer.setJavaSerde(javaSerde); + this.outputBuffer.setKryoSerde(kryoSerde); this.snapshotState = operatorContext.isSnapshotEnabled() ? SingleInputSnapshotState.forOperator(this, operatorContext) : null; this.isStage0 = operatorContext.getDriverContext().getPipelineContext().getTaskContext().getTaskId().getStageId().getId() == 0; } @@ -192,6 +199,7 @@ public class TaskOutputOperator } DirectSerialisationType serialisationType = outputBuffer.getExchangeDirectSerialisationType(); + DirectSerialisationType spoolingSerialisationType = outputBuffer.getDelegateSpoolingExchangeDirectSerializationType(); if (outputBuffer.isSpoolingOutputBuffer() && serialisationType != DirectSerialisationType.OFF) { PagesSerde directSerde = (serialisationType == DirectSerialisationType.JAVA) ? operatorContext.getDriverContext().getJavaSerde() : operatorContext.getDriverContext().getKryoSerde(); List pages = splitPage(inputPage, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); @@ -215,6 +223,21 @@ public class TaskOutputOperator } } else { + if (outputBuffer.isSpoolingDelegateAvailable()) { + OutputBuffer spoolingBuffer = outputBuffer.getSpoolingDelegate(); + if (spoolingSerialisationType != DirectSerialisationType.OFF) { + PagesSerde directSerde = (spoolingSerialisationType == DirectSerialisationType.JAVA) ? operatorContext.getDriverContext().getJavaSerde() : operatorContext.getDriverContext().getKryoSerde(); + List pages = splitPage(inputPage, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + if (spoolingBuffer != null) { + spoolingBuffer.enqueuePages(0, pages, id, directSerde); + } + } + else { + if (spoolingBuffer != null) { + spoolingBuffer.enqueue(serializedPages, id); + } + } + } outputBuffer.enqueue(serializedPages, id); } } diff --git a/presto-main/src/main/java/io/prestosql/operator/output/PagePartitioner.java b/presto-main/src/main/java/io/prestosql/operator/output/PagePartitioner.java index bbc68e1c8..2debb2841 100644 --- a/presto-main/src/main/java/io/prestosql/operator/output/PagePartitioner.java +++ b/presto-main/src/main/java/io/prestosql/operator/output/PagePartitioner.java @@ -99,6 +99,9 @@ public class PagePartitioner this.outputBuffer = requireNonNull(outputBuffer, "outputBuffer is null"); this.sourceTypes = requireNonNull(sourceTypes, "sourceTypes is null").toArray(new Type[0]); this.operatorContext = requireNonNull(operatorContext, "serde is null"); + this.outputBuffer.setSerde(requireNonNull(operatorContext.getDriverContext().getSerde(), "serde is null")); + this.outputBuffer.setJavaSerde(requireNonNull(operatorContext.getDriverContext().getJavaSerde(), "java serde is null")); + this.outputBuffer.setKryoSerde(requireNonNull(operatorContext.getDriverContext().getKryoSerde(), "kryo serde is null")); int partitionCount = partitionFunction.getPartitionCount(); int pageSize = min(DEFAULT_MAX_PAGE_SIZE_IN_BYTES, ((int) maxMemory.toBytes()) / partitionCount); @@ -456,6 +459,7 @@ public class PagePartitioner partitionPageBuilder.reset(); FileSystemExchangeConfig.DirectSerialisationType serialisationType = outputBuffer.getExchangeDirectSerialisationType(); + FileSystemExchangeConfig.DirectSerialisationType spoolingSerialisationType = outputBuffer.getDelegateSpoolingExchangeDirectSerializationType(); if (outputBuffer.isSpoolingOutputBuffer() && serialisationType != FileSystemExchangeConfig.DirectSerialisationType.OFF) { PagesSerde directSerde = (serialisationType == FileSystemExchangeConfig.DirectSerialisationType.JAVA) ? operatorContext.getDriverContext().getJavaSerde() : operatorContext.getDriverContext().getKryoSerde(); List pages = splitPage(pagePartition, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); @@ -466,6 +470,21 @@ public class PagePartitioner .map(page -> operatorContext.getDriverContext().getSerde().serialize(page)) .collect(toImmutableList()); + if (outputBuffer.isSpoolingDelegateAvailable()) { + OutputBuffer spoolingBuffer = outputBuffer.getSpoolingDelegate(); + if (spoolingSerialisationType != FileSystemExchangeConfig.DirectSerialisationType.OFF) { + PagesSerde directSerde = (spoolingSerialisationType == FileSystemExchangeConfig.DirectSerialisationType.JAVA) ? operatorContext.getDriverContext().getJavaSerde() : operatorContext.getDriverContext().getKryoSerde(); + List pages = splitPage(pagePartition, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + if (spoolingBuffer != null) { + spoolingBuffer.enqueuePages(partition, pages, id, directSerde); + } + } + else { + if (spoolingBuffer != null) { + spoolingBuffer.enqueue(partition, serializedPages, id); + } + } + } outputBuffer.enqueue(partition, serializedPages, id); } pagesAdded.incrementAndGet(); diff --git a/presto-main/src/main/java/io/prestosql/testing/TestingRecoveryUtils.java b/presto-main/src/main/java/io/prestosql/testing/TestingRecoveryUtils.java index c97655402..8203dcd53 100644 --- a/presto-main/src/main/java/io/prestosql/testing/TestingRecoveryUtils.java +++ b/presto-main/src/main/java/io/prestosql/testing/TestingRecoveryUtils.java @@ -95,6 +95,12 @@ public class TestingRecoveryUtils { } + @Override + public void flush(OutputStream outputStream) + throws IOException + { + } + @Override public InputStream newInputStream(Path path) { diff --git a/presto-main/src/test/java/io/prestosql/execution/buffer/TestHybridSpoolingBuffer.java b/presto-main/src/test/java/io/prestosql/execution/buffer/TestHybridSpoolingBuffer.java new file mode 100644 index 000000000..51451966b --- /dev/null +++ b/presto-main/src/test/java/io/prestosql/execution/buffer/TestHybridSpoolingBuffer.java @@ -0,0 +1,249 @@ +/* + * Copyright (C) 2018-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.execution.buffer; + +import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.ListenableFuture; +import io.airlift.units.DataSize; +import io.hetu.core.filesystem.HetuLocalFileSystemClient; +import io.hetu.core.filesystem.LocalConfig; +import io.hetu.core.transport.execution.buffer.PagesSerde; +import io.hetu.core.transport.execution.buffer.PagesSerdeFactory; +import io.hetu.core.transport.execution.buffer.SerializedPage; +import io.prestosql.exchange.ExchangeManager; +import io.prestosql.exchange.ExchangeSink; +import io.prestosql.exchange.ExchangeSinkInstanceHandle; +import io.prestosql.exchange.FileSystemExchangeConfig; +import io.prestosql.exchange.FileSystemExchangeManager; +import io.prestosql.exchange.FileSystemExchangeSinkHandle; +import io.prestosql.exchange.FileSystemExchangeSinkInstanceHandle; +import io.prestosql.exchange.FileSystemExchangeStats; +import io.prestosql.exchange.storage.FileSystemExchangeStorage; +import io.prestosql.exchange.storage.HetuFileSystemExchangeStorage; +import io.prestosql.execution.StageId; +import io.prestosql.execution.TaskId; +import io.prestosql.operator.PageAssertions; +import io.prestosql.spi.Page; +import io.prestosql.spi.QueryId; +import io.prestosql.spi.block.Block; +import io.prestosql.spi.block.BlockBuilder; +import io.prestosql.testing.TestingPagesSerdeFactory; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.Properties; +import java.util.concurrent.ExecutionException; +import java.util.stream.Collectors; + +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; +import static io.prestosql.metadata.MetadataManager.createTestMetadataManager; +import static io.prestosql.spi.type.IntegerType.INTEGER; +import static org.testng.Assert.assertEquals; + +@Test(singleThreaded = true) +public class TestHybridSpoolingBuffer +{ + private final PagesSerde serde = new TestingPagesSerdeFactory().createPagesSerde(); + private final PagesSerde javaSerde = new PagesSerdeFactory(createTestMetadataManager().getFunctionAndTypeManager().getBlockEncodingSerde(), false) + .createDirectPagesSerde(Optional.empty(), true, false); + private final PagesSerde kryoSerde = new PagesSerdeFactory(createTestMetadataManager().getFunctionAndTypeManager().getBlockKryoEncodingSerde(), false) + .createDirectPagesSerde(Optional.empty(), true, true); + private ExchangeSink exchangeSink; + private ExchangeManager exchangeManager; + private ExchangeSinkInstanceHandle exchangeSinkInstanceHandle; + private final String baseDir = "/tmp/hetu/spooling"; + private final String accessDir = "/tmp/hetu"; + private final Path accessPath = Paths.get(accessDir); + + @BeforeMethod + public void setUp() + throws IOException, InterruptedException + { + Path basePath = Paths.get(baseDir); + File base = new File(baseDir); + if (!base.exists()) { + Files.createDirectories(basePath); + } + else { + deleteDirectory(base); + Files.createDirectories(basePath); + } + } + + @AfterMethod + public void cleanUp() + { + File base = new File(baseDir); + if (base.exists()) { + deleteDirectory(base); + } + } + + private void setConfig(FileSystemExchangeConfig.DirectSerialisationType type) + { + FileSystemExchangeConfig config = new FileSystemExchangeConfig() + .setExchangeEncryptionEnabled(false) + .setDirectSerializationType(type) + .setBaseDirectories(baseDir); + + FileSystemExchangeStorage exchangeStorage = new HetuFileSystemExchangeStorage(); + exchangeStorage.setFileSystemClient(new HetuLocalFileSystemClient(new LocalConfig(new Properties()), accessPath)); + exchangeManager = new FileSystemExchangeManager(exchangeStorage, new FileSystemExchangeStats(), config); + exchangeSinkInstanceHandle = new FileSystemExchangeSinkInstanceHandle( + new FileSystemExchangeSinkHandle(0, Optional.empty(), false), + config.getBaseDirectories().get(0), + 10); + exchangeSink = exchangeManager.createSink(exchangeSinkInstanceHandle, false); + } + + @Test + public void testHybridSpoolingBufferWithSerializationOff() + throws ExecutionException, InterruptedException + { + setConfig(FileSystemExchangeConfig.DirectSerialisationType.OFF); + HybridSpoolingBuffer hybridSpoolingBuffer = createHybridSpoolingBuffer(); + List pages = new ArrayList<>(); + pages.add(generateSerializedPage()); + pages.add(generateSerializedPage()); + hybridSpoolingBuffer.enqueue(0, pages, null); + ListenableFuture result = hybridSpoolingBuffer.get(new OutputBuffers.OutputBufferId(0), 0, new DataSize(100, DataSize.Unit.MEGABYTE)); + while (result.get().equals(BufferResult.emptyResults(0, false))) { + result = hybridSpoolingBuffer.get(new OutputBuffers.OutputBufferId(0), 0, new DataSize(100, DataSize.Unit.MEGABYTE)); + } + List actualPages = result.get().getSerializedPages().stream().map(page -> serde.deserialize(page)).collect(Collectors.toList()); + assertEquals(actualPages.size(), 1); + for (int pageCount = 0; pageCount < actualPages.size(); pageCount++) { + PageAssertions.assertPageEquals(ImmutableList.of(INTEGER, INTEGER), actualPages.get(pageCount), generatePage()); + } + result = hybridSpoolingBuffer.get(new OutputBuffers.OutputBufferId(0), 0, new DataSize(100, DataSize.Unit.MEGABYTE)); + actualPages = result.get().getSerializedPages().stream().map(page -> serde.deserialize(page)).collect(Collectors.toList()); + assertEquals(actualPages.size(), 1); + for (int pageCount = 0; pageCount < actualPages.size(); pageCount++) { + PageAssertions.assertPageEquals(ImmutableList.of(INTEGER, INTEGER), actualPages.get(pageCount), generatePage()); + } + hybridSpoolingBuffer.setNoMorePages(); + } + + @Test + public void testHybridSpoolingBufferWithSerializationJava() + throws ExecutionException, InterruptedException + { + setConfig(FileSystemExchangeConfig.DirectSerialisationType.JAVA); + HybridSpoolingBuffer hybridSpoolingBuffer = createHybridSpoolingBuffer(); + hybridSpoolingBuffer.setSerde(serde); + hybridSpoolingBuffer.setJavaSerde(javaSerde); + hybridSpoolingBuffer.setKryoSerde(kryoSerde); + List pages = new ArrayList<>(); + pages.add(generatePage()); + pages.add(generatePage()); + hybridSpoolingBuffer.enqueuePages(0, pages, null, javaSerde); + ListenableFuture result = hybridSpoolingBuffer.get(new OutputBuffers.OutputBufferId(0), 0, new DataSize(100, DataSize.Unit.MEGABYTE)); + while (result.get().equals(BufferResult.emptyResults(0, false))) { + result = hybridSpoolingBuffer.get(new OutputBuffers.OutputBufferId(0), 0, new DataSize(100, DataSize.Unit.MEGABYTE)); + } + List actualPages = result.get().getSerializedPages().stream().map(page -> serde.deserialize(page)).collect(Collectors.toList()); + assertEquals(actualPages.size(), 1); + for (int pageCount = 0; pageCount < actualPages.size(); pageCount++) { + PageAssertions.assertPageEquals(ImmutableList.of(INTEGER, INTEGER), actualPages.get(pageCount), generatePage()); + } + result = hybridSpoolingBuffer.get(new OutputBuffers.OutputBufferId(0), 0, new DataSize(100, DataSize.Unit.MEGABYTE)); + actualPages = result.get().getSerializedPages().stream().map(page -> serde.deserialize(page)).collect(Collectors.toList()); + assertEquals(actualPages.size(), 1); + for (int pageCount = 0; pageCount < actualPages.size(); pageCount++) { + PageAssertions.assertPageEquals(ImmutableList.of(INTEGER, INTEGER), actualPages.get(pageCount), generatePage()); + } + hybridSpoolingBuffer.setNoMorePages(); + } + + @Test + public void testHybridSpoolingBufferWithSerializationKryo() + throws ExecutionException, InterruptedException + { + setConfig(FileSystemExchangeConfig.DirectSerialisationType.KRYO); + HybridSpoolingBuffer hybridSpoolingBuffer = createHybridSpoolingBuffer(); + hybridSpoolingBuffer.setSerde(serde); + hybridSpoolingBuffer.setJavaSerde(javaSerde); + hybridSpoolingBuffer.setKryoSerde(kryoSerde); + List pages = new ArrayList<>(); + pages.add(generatePage()); + pages.add(generatePage()); + hybridSpoolingBuffer.enqueuePages(0, pages, null, kryoSerde); + ListenableFuture result = hybridSpoolingBuffer.get(new OutputBuffers.OutputBufferId(0), 0, new DataSize(100, DataSize.Unit.MEGABYTE)); + while (result.get().equals(BufferResult.emptyResults(0, false))) { + result = hybridSpoolingBuffer.get(new OutputBuffers.OutputBufferId(0), 0, new DataSize(100, DataSize.Unit.MEGABYTE)); + } + List actualPages = result.get().getSerializedPages().stream().map(page -> serde.deserialize(page)).collect(Collectors.toList()); + assertEquals(actualPages.size(), 1); + for (int pageCount = 0; pageCount < actualPages.size(); pageCount++) { + PageAssertions.assertPageEquals(ImmutableList.of(INTEGER, INTEGER), actualPages.get(pageCount), generatePage()); + } + result = hybridSpoolingBuffer.get(new OutputBuffers.OutputBufferId(0), 0, new DataSize(100, DataSize.Unit.MEGABYTE)); + actualPages = result.get().getSerializedPages().stream().map(page -> serde.deserialize(page)).collect(Collectors.toList()); + assertEquals(actualPages.size(), 1); + for (int pageCount = 0; pageCount < actualPages.size(); pageCount++) { + PageAssertions.assertPageEquals(ImmutableList.of(INTEGER, INTEGER), actualPages.get(pageCount), generatePage()); + } + hybridSpoolingBuffer.setNoMorePages(); + } + + private HybridSpoolingBuffer createHybridSpoolingBuffer() + { + OutputBuffers outputBuffers = OutputBuffers.createInitialEmptyOutputBuffers(OutputBuffers.BufferType.PARTITIONED); + outputBuffers.setExchangeSinkInstanceHandle(exchangeSinkInstanceHandle); + + return new HybridSpoolingBuffer(new OutputBufferStateMachine(new TaskId(new StageId(new QueryId("query"), 0), 0, 0), directExecutor()), + outputBuffers, + exchangeSink, + TestSpoolingExchangeOutputBuffer.TestingLocalMemoryContext::new, + exchangeManager); + } + + private SerializedPage generateSerializedPage() + { + Page expectedPage = generatePage(); + SerializedPage page = serde.serialize(expectedPage); + return page; + } + + private Page generatePage() + { + BlockBuilder expectedBlockBuilder = INTEGER.createBlockBuilder(null, 2); + INTEGER.writeLong(expectedBlockBuilder, 10); + INTEGER.writeLong(expectedBlockBuilder, 20); + Block expectedBlock = expectedBlockBuilder.build(); + + return new Page(expectedBlock, expectedBlock); + } + + private boolean deleteDirectory(File dir) + { + File[] allContents = dir.listFiles(); + if (allContents != null) { + for (File file : allContents) { + deleteDirectory(file); + } + } + return dir.delete(); + } +} diff --git a/presto-main/src/test/java/io/prestosql/execution/buffer/TestSpoolingExchangeOutputBuffer.java b/presto-main/src/test/java/io/prestosql/execution/buffer/TestSpoolingExchangeOutputBuffer.java index b5ad674d3..e0706a856 100644 --- a/presto-main/src/test/java/io/prestosql/execution/buffer/TestSpoolingExchangeOutputBuffer.java +++ b/presto-main/src/test/java/io/prestosql/execution/buffer/TestSpoolingExchangeOutputBuffer.java @@ -24,6 +24,7 @@ import io.hetu.core.transport.execution.buffer.PagesSerde; import io.hetu.core.transport.execution.buffer.SerializedPage; import io.prestosql.exchange.ExchangeSink; import io.prestosql.exchange.ExchangeSinkInstanceHandle; +import io.prestosql.exchange.storage.FileSystemExchangeStorage; import io.prestosql.execution.StageId; import io.prestosql.execution.TaskId; import io.prestosql.memory.context.LocalMemoryContext; @@ -31,6 +32,9 @@ import io.prestosql.spi.Page; import io.prestosql.spi.QueryId; import org.testng.annotations.Test; +import javax.crypto.SecretKey; + +import java.net.URI; import java.util.Optional; import java.util.concurrent.CompletableFuture; @@ -358,6 +362,36 @@ public class TestSpoolingExchangeOutputBuffer return abort; } + @Override + public FileSystemExchangeStorage getExchangeStorage() + { + return null; + } + + @Override + public URI getOutputDirectory() + { + return null; + } + + @Override + public Optional getSecretKey() + { + return Optional.empty(); + } + + @Override + public boolean isExchangeCompressionEnabled() + { + return false; + } + + @Override + public int getPartitionId() + { + return 0; + } + public void setAbort(CompletableFuture abort) { this.abort = requireNonNull(abort, "abort is null"); @@ -370,7 +404,7 @@ public class TestSpoolingExchangeOutputBuffer INSTANCE } - private static class TestingLocalMemoryContext + protected static class TestingLocalMemoryContext implements LocalMemoryContext { @Override diff --git a/presto-spi/src/main/java/io/prestosql/spi/filesystem/HetuFileSystemClient.java b/presto-spi/src/main/java/io/prestosql/spi/filesystem/HetuFileSystemClient.java index b7eb581f1..9d0bb8497 100644 --- a/presto-spi/src/main/java/io/prestosql/spi/filesystem/HetuFileSystemClient.java +++ b/presto-spi/src/main/java/io/prestosql/spi/filesystem/HetuFileSystemClient.java @@ -210,4 +210,7 @@ public interface HetuFileSystemClient Stream getDirectoryStream(Path path, String prefix, String suffix) throws IOException; + + void flush(OutputStream outputStream) + throws IOException; } -- Gitee From ff092fa66006ce90b5155142fb5d8d21d8d01c89 Mon Sep 17 00:00:00 2001 From: Surya Sumanth N Date: Tue, 28 Feb 2023 17:56:10 +0530 Subject: [PATCH 2/4] Marker Index File Interface Implementation --- .../filesystem/HetuHdfsFileSystemClient.java | 5 +- .../io/prestosql/exchange/ExchangeSink.java | 6 + .../exchange/FileSystemExchangeSink.java | 24 ++ .../storage/ExchangeStorageWriter.java | 7 + .../storage/FileSystemExchangeStorage.java | 2 + .../HetuFileSystemExchangeStorage.java | 6 + .../storage/HetuFileSystemExchangeWriter.java | 23 ++ .../execution/MarkerDataFileFactory.java | 266 ++++++++++++++++++ .../execution/MarkerIndexFileFactory.java | 191 +++++++++++++ .../buffer/HybridSpoolingBuffer.java | 246 ++++++++++++++++ .../buffer/TestHybridSpoolingBuffer.java | 95 ++++++- .../TestSpoolingExchangeOutputBuffer.java | 14 + 12 files changed, 876 insertions(+), 9 deletions(-) create mode 100644 presto-main/src/main/java/io/prestosql/execution/MarkerDataFileFactory.java create mode 100644 presto-main/src/main/java/io/prestosql/execution/MarkerIndexFileFactory.java diff --git a/hetu-filesystem-client/src/main/java/io/hetu/core/filesystem/HetuHdfsFileSystemClient.java b/hetu-filesystem-client/src/main/java/io/hetu/core/filesystem/HetuHdfsFileSystemClient.java index 16c634a87..69328b206 100644 --- a/hetu-filesystem-client/src/main/java/io/hetu/core/filesystem/HetuHdfsFileSystemClient.java +++ b/hetu-filesystem-client/src/main/java/io/hetu/core/filesystem/HetuHdfsFileSystemClient.java @@ -22,6 +22,7 @@ import org.apache.hadoop.fs.FSDataOutputStream; import org.apache.hadoop.fs.FileStatus; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.PathIsNotEmptyDirectoryException; +import org.apache.hadoop.hdfs.client.HdfsDataOutputStream; import org.apache.hadoop.hdfs.protocol.AlreadyBeingCreatedException; import org.apache.hadoop.ipc.RemoteException; @@ -36,6 +37,7 @@ import java.nio.file.NoSuchFileException; import java.nio.file.OpenOption; import java.nio.file.Path; import java.nio.file.Paths; +import java.util.EnumSet; import java.util.UUID; import java.util.stream.Stream; @@ -272,7 +274,8 @@ public class HetuHdfsFileSystemClient public void flush(OutputStream outputStream) throws IOException { - if (outputStream instanceof FSDataOutputStream) { + if (outputStream instanceof HdfsDataOutputStream) { + ((HdfsDataOutputStream) outputStream).hsync(EnumSet.of(HdfsDataOutputStream.SyncFlag.UPDATE_LENGTH)); ((FSDataOutputStream) outputStream).hflush(); } } diff --git a/presto-main/src/main/java/io/prestosql/exchange/ExchangeSink.java b/presto-main/src/main/java/io/prestosql/exchange/ExchangeSink.java index 0386dcd41..f8310479f 100644 --- a/presto-main/src/main/java/io/prestosql/exchange/ExchangeSink.java +++ b/presto-main/src/main/java/io/prestosql/exchange/ExchangeSink.java @@ -17,11 +17,13 @@ import io.airlift.slice.Slice; import io.hetu.core.transport.execution.buffer.PagesSerde; import io.prestosql.exchange.FileSystemExchangeConfig.DirectSerialisationType; import io.prestosql.exchange.storage.FileSystemExchangeStorage; +import io.prestosql.execution.MarkerDataFileFactory; import io.prestosql.spi.Page; import javax.crypto.SecretKey; import java.net.URI; +import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; @@ -55,4 +57,8 @@ public interface ExchangeSink boolean isExchangeCompressionEnabled(); int getPartitionId(); + + List getSinkFiles(); + + void enqueueMarkerInfo(MarkerDataFileFactory.MarkerDataFileFooterInfo markerDataFileFooterInfo); } diff --git a/presto-main/src/main/java/io/prestosql/exchange/FileSystemExchangeSink.java b/presto-main/src/main/java/io/prestosql/exchange/FileSystemExchangeSink.java index b2f8e77f8..71e9f2648 100644 --- a/presto-main/src/main/java/io/prestosql/exchange/FileSystemExchangeSink.java +++ b/presto-main/src/main/java/io/prestosql/exchange/FileSystemExchangeSink.java @@ -25,6 +25,7 @@ import io.hetu.core.transport.execution.buffer.PagesSerde; import io.prestosql.exchange.FileSystemExchangeConfig.DirectSerialisationType; import io.prestosql.exchange.storage.ExchangeStorageWriter; import io.prestosql.exchange.storage.FileSystemExchangeStorage; +import io.prestosql.execution.MarkerDataFileFactory; import io.prestosql.spi.Page; import io.prestosql.spi.PrestoException; import io.prestosql.spi.util.SizeOf; @@ -46,6 +47,7 @@ import java.util.Queue; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; @@ -175,6 +177,18 @@ public class FileSystemExchangeSink return partitionId; } + @Override + public List getSinkFiles() + { + return writerMap.values().stream().flatMap(bufferedStorageWriter -> bufferedStorageWriter.getSpoolFiles().stream()).collect(Collectors.toList()); + } + + @Override + public void enqueueMarkerInfo(MarkerDataFileFactory.MarkerDataFileFooterInfo markerDataFileFooterInfo) + { + writerMap.values().stream().forEach(bufferedStorageWriter -> bufferedStorageWriter.enqueueMarkerMetadata(markerDataFileFooterInfo)); + } + @Override public void add(int partitionId, Slice data) { @@ -509,6 +523,16 @@ public class FileSystemExchangeSink addExceptionCallback(writeFuture, throwable -> failure.compareAndSet(null, throwable)); } } + + public List getSpoolFiles() + { + return writers.stream().map(writer -> writer.getFile()).collect(Collectors.toList()); + } + + public void enqueueMarkerMetadata(MarkerDataFileFactory.MarkerDataFileFooterInfo markerDataFileFooterInfo) + { + writers.stream().forEach(exchangeStorageWriter -> exchangeStorageWriter.writeMarkerMetadata(markerDataFileFooterInfo)); + } } @ThreadSafe diff --git a/presto-main/src/main/java/io/prestosql/exchange/storage/ExchangeStorageWriter.java b/presto-main/src/main/java/io/prestosql/exchange/storage/ExchangeStorageWriter.java index 1c3efd091..eb8f5a63a 100644 --- a/presto-main/src/main/java/io/prestosql/exchange/storage/ExchangeStorageWriter.java +++ b/presto-main/src/main/java/io/prestosql/exchange/storage/ExchangeStorageWriter.java @@ -16,8 +16,11 @@ package io.prestosql.exchange.storage; import com.google.common.util.concurrent.ListenableFuture; import io.airlift.slice.Slice; import io.hetu.core.transport.execution.buffer.PagesSerde; +import io.prestosql.execution.MarkerDataFileFactory; import io.prestosql.spi.Page; +import java.net.URI; + public interface ExchangeStorageWriter { ListenableFuture write(Slice slice); @@ -29,4 +32,8 @@ public interface ExchangeStorageWriter ListenableFuture abort(); long getRetainedSize(); + + URI getFile(); + + void writeMarkerMetadata(MarkerDataFileFactory.MarkerDataFileFooterInfo markerDataFileFooterInfo); } diff --git a/presto-main/src/main/java/io/prestosql/exchange/storage/FileSystemExchangeStorage.java b/presto-main/src/main/java/io/prestosql/exchange/storage/FileSystemExchangeStorage.java index 69c3ff968..e8812671c 100644 --- a/presto-main/src/main/java/io/prestosql/exchange/storage/FileSystemExchangeStorage.java +++ b/presto-main/src/main/java/io/prestosql/exchange/storage/FileSystemExchangeStorage.java @@ -32,6 +32,8 @@ public interface FileSystemExchangeStorage { public void setFileSystemClient(HetuFileSystemClient fsClient); + HetuFileSystemClient getFileSystemClient(); + void createDirectories(URI dir) throws IOException; ExchangeStorageReader createExchangeReader(Queue sourceFiles, int maxPageSize, DirectSerialisationType directSerialisationType, int directSerialisationBufferSize); diff --git a/presto-main/src/main/java/io/prestosql/exchange/storage/HetuFileSystemExchangeStorage.java b/presto-main/src/main/java/io/prestosql/exchange/storage/HetuFileSystemExchangeStorage.java index 5c1bce39a..e8d56dcc9 100644 --- a/presto-main/src/main/java/io/prestosql/exchange/storage/HetuFileSystemExchangeStorage.java +++ b/presto-main/src/main/java/io/prestosql/exchange/storage/HetuFileSystemExchangeStorage.java @@ -59,6 +59,12 @@ public class HetuFileSystemExchangeStorage fileSystemClient = fsClient; } + @Override + public HetuFileSystemClient getFileSystemClient() + { + return fileSystemClient; + } + @Override public void createDirectories(URI dir) throws IOException { diff --git a/presto-main/src/main/java/io/prestosql/exchange/storage/HetuFileSystemExchangeWriter.java b/presto-main/src/main/java/io/prestosql/exchange/storage/HetuFileSystemExchangeWriter.java index 49bbfc769..83e5f40bf 100644 --- a/presto-main/src/main/java/io/prestosql/exchange/storage/HetuFileSystemExchangeWriter.java +++ b/presto-main/src/main/java/io/prestosql/exchange/storage/HetuFileSystemExchangeWriter.java @@ -21,6 +21,8 @@ import io.airlift.slice.Slice; import io.hetu.core.transport.execution.buffer.PagesSerde; import io.prestosql.exchange.FileSystemExchangeConfig; import io.prestosql.exchange.FileSystemExchangeConfig.DirectSerialisationType; +import io.prestosql.execution.MarkerDataFileFactory; +import io.prestosql.snapshot.RecoveryUtils; import io.prestosql.spi.Page; import io.prestosql.spi.PrestoException; import io.prestosql.spi.filesystem.HetuFileSystemClient; @@ -57,12 +59,14 @@ public class HetuFileSystemExchangeWriter private final int directSerialisationBufferSize; private final HetuFileSystemClient fileSystemClient; private final OutputStream delegateOutputStream; + private final URI file; public HetuFileSystemExchangeWriter(URI file, HetuFileSystemClient fileSystemClient, Optional secretKey, boolean exchangeCompressionEnabled, AlgorithmParameterSpec algorithmParameterSpec, FileSystemExchangeConfig.DirectSerialisationType directSerialisationType, int directSerialisationBufferSize) { this.directSerialisationBufferSize = directSerialisationBufferSize; this.directSerialisationType = directSerialisationType; this.fileSystemClient = fileSystemClient; + this.file = file; try { Path path = Paths.get(file.toString()); this.delegateOutputStream = fileSystemClient.newOutputStream(path); @@ -155,4 +159,23 @@ public class HetuFileSystemExchangeWriter { return INSTANCE_SIZE; } + + @Override + public URI getFile() + { + return file; + } + + @Override + public void writeMarkerMetadata(MarkerDataFileFactory.MarkerDataFileFooterInfo markerDataFileFooterInfo) + { + try { + RecoveryUtils.serializeState(markerDataFileFooterInfo, outputStream, false); + outputStream.flush(); + fileSystemClient.flush(delegateOutputStream); + } + catch (Exception e) { + e.printStackTrace(); + } + } } diff --git a/presto-main/src/main/java/io/prestosql/execution/MarkerDataFileFactory.java b/presto-main/src/main/java/io/prestosql/execution/MarkerDataFileFactory.java new file mode 100644 index 000000000..1be5c428f --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/execution/MarkerDataFileFactory.java @@ -0,0 +1,266 @@ +/* + * Copyright (C) 2018-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.execution; + +import io.prestosql.snapshot.RecoveryUtils; +import io.prestosql.snapshot.SnapshotStateId; +import io.prestosql.spi.PrestoException; +import io.prestosql.spi.filesystem.HetuFileSystemClient; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.net.URI; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.HashMap; +import java.util.Map; + +import static com.google.common.base.Preconditions.checkArgument; +import static io.prestosql.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; +import static io.prestosql.spi.filesystem.SupportedFileAttributes.SIZE; +import static java.util.Objects.requireNonNull; + +public class MarkerDataFileFactory +{ + private final URI targetOutputDirectory; + private final HetuFileSystemClient hetuFileSystemClient; + + public MarkerDataFileFactory(URI taskOutputDirectory, HetuFileSystemClient hetuFileSystemClient) + { + this.targetOutputDirectory = taskOutputDirectory; + this.hetuFileSystemClient = requireNonNull(hetuFileSystemClient, "hetuFileSystemClient is null"); + } + + public MarkerDataFileWriter createWriter(String file, boolean useKryo) + { + URI fileURI = targetOutputDirectory.resolve(file); + return new MarkerDataFileWriter(hetuFileSystemClient, fileURI, useKryo); + } + + public MarkerDataFileReader createReader(URI file, boolean useKryo) + { + return new MarkerDataFileReader(hetuFileSystemClient, file, useKryo); + } + + public static class MarkerDataFileWriter + { + private final OutputStream outputStream; + private final boolean useKryo; + private HetuFileSystemClient hetuFileSystemClient; + private Path path; + private URI fileURI; + + public MarkerDataFileWriter(HetuFileSystemClient hetuFileSystemClient, URI fileURI, boolean useKryo) + { + this.hetuFileSystemClient = hetuFileSystemClient; + this.path = Paths.get(fileURI.getPath()); + this.fileURI = fileURI; + OutputStream localOutputStream = null; + try { + localOutputStream = hetuFileSystemClient.newOutputStream(path); + } + catch (IOException e) { + // add logs + } + this.outputStream = localOutputStream; + this.useKryo = useKryo; + } + + public MarkerDataFileFooterInfo writeDataFile(int markerId, Map state, long previousFooterOffset, long previousFooterSize) + { + Map operatorStateInfoMap = new HashMap<>(); + try { + for (Map.Entry entry : state.entrySet()) { + long offset = (Long) hetuFileSystemClient.getAttribute(path, SIZE); + RecoveryUtils.serializeState(entry.getValue(), outputStream, useKryo); + outputStream.flush(); + hetuFileSystemClient.flush(outputStream); + long size = (Long) hetuFileSystemClient.getAttribute(path, SIZE) - offset; + operatorStateInfoMap.put(entry.getKey().getId(), new OperatorStateInfo(offset, size)); + } + // write footer + MarkerDataFileFooter footer = new MarkerDataFileFooter(markerId, previousFooterOffset, previousFooterSize, state.size(), operatorStateInfoMap); + long footerOffset = (Long) hetuFileSystemClient.getAttribute(path, SIZE); + RecoveryUtils.serializeState(footer, outputStream, useKryo); + outputStream.flush(); + hetuFileSystemClient.flush(outputStream); + long footerSize = (Long) hetuFileSystemClient.getAttribute(path, SIZE) - footerOffset; + return new MarkerDataFileFooterInfo(fileURI, footer, footerOffset, footerSize); + } + catch (IOException e) { + // add logs + System.out.println(1); + } + return null; + } + } + + public static class MarkerDataFileReader + { + private final boolean useKryo; + private Path path; + private HetuFileSystemClient hetuFileSystemClient; + + public MarkerDataFileReader(HetuFileSystemClient hetuFileSystemClient, URI fileURI, boolean useKryo) + { + this.path = Paths.get(fileURI.getPath()); + this.useKryo = useKryo; + this.hetuFileSystemClient = hetuFileSystemClient; + } + + public Map readDataFile(long markerDataOffset, long markerDataLength) + { + try { + InputStream inputStream = hetuFileSystemClient.newInputStream(path); + inputStream.skip(markerDataOffset); + MarkerDataFileFooter markerDataFileFooter = (MarkerDataFileFooter) RecoveryUtils.deserializeState(inputStream, useKryo); + int operatorCount = markerDataFileFooter.getOperatorCount(); + Map operatorStateInfoMap = markerDataFileFooter.getOperatorStateInfo(); + checkArgument(operatorCount == operatorStateInfoMap.size(), "operator data mismatch"); + inputStream.close(); + + Map states = new HashMap<>(); + for (Map.Entry operatorStateInfoEntry : operatorStateInfoMap.entrySet()) { + inputStream = hetuFileSystemClient.newInputStream(path); + inputStream.skip(operatorStateInfoEntry.getValue().getStateOffset()); + states.put(SnapshotStateId.fromString(operatorStateInfoEntry.getKey()), RecoveryUtils.deserializeState(inputStream, useKryo)); + inputStream.close(); + } + return states; + } + catch (Exception e) { + throw new PrestoException(GENERIC_INTERNAL_ERROR, "Failed reading data file"); + } + } + } + + public static class MarkerDataFileFooterInfo + implements Serializable + { + private URI path; + private MarkerDataFileFooter footer; + private long footerOffset; + private long footerSize; + + public MarkerDataFileFooterInfo(URI path, MarkerDataFileFooter footer, long footerOffset, long footerSize) + { + this.path = path; + this.footer = footer; + this.footerOffset = footerOffset; + this.footerSize = footerSize; + } + + public long getFooterOffset() + { + return footerOffset; + } + + public long getFooterSize() + { + return footerSize; + } + + public MarkerDataFileFooter getFooter() + { + return footer; + } + + public URI getPath() + { + return path; + } + } + + public static class MarkerDataFileFooter + implements Serializable + { + private int markerId; + private int version; + private long previousTailOffset; + private long previousTailSize; + private int operatorCount; + private Map operatorStateInfo; + + public MarkerDataFileFooter(int markerId, long previousTailOffset, long previousTailSize, int operatorCount, Map operatorStateInfo) + { + this(markerId, 1, previousTailOffset, previousTailSize, operatorCount, operatorStateInfo); + } + + public MarkerDataFileFooter(int markerId, int version, long previousTailOffset, long previousTailSize, int operatorCount, Map operatorStateInfo) + { + this.markerId = markerId; + this.version = version; + this.previousTailOffset = previousTailOffset; + this.operatorCount = operatorCount; + this.operatorStateInfo = operatorStateInfo; + this.previousTailSize = previousTailSize; + } + + public int getMarkerId() + { + return markerId; + } + + public int getVersion() + { + return version; + } + + public long getPreviousTailOffset() + { + return previousTailOffset; + } + + public long getPreviousTailSize() + { + return previousTailSize; + } + + public int getOperatorCount() + { + return operatorCount; + } + + public Map getOperatorStateInfo() + { + return operatorStateInfo; + } + } + + public static class OperatorStateInfo + implements Serializable + { + private long stateOffset; + private long stateSize; + + public OperatorStateInfo(long stateOffset, long stateSize) + { + this.stateOffset = stateOffset; + this.stateSize = stateSize; + } + + public long getStateOffset() + { + return stateOffset; + } + + public long getStateSize() + { + return stateSize; + } + } +} diff --git a/presto-main/src/main/java/io/prestosql/execution/MarkerIndexFileFactory.java b/presto-main/src/main/java/io/prestosql/execution/MarkerIndexFileFactory.java new file mode 100644 index 000000000..be77bd030 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/execution/MarkerIndexFileFactory.java @@ -0,0 +1,191 @@ +/* + * Copyright (C) 2018-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.execution; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.JsonFactory; +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.prestosql.execution.buffer.HybridSpoolingBuffer; +import io.prestosql.spi.PrestoException; +import io.prestosql.spi.filesystem.HetuFileSystemClient; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Map; + +import static io.prestosql.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; +import static java.util.Objects.requireNonNull; + +public class MarkerIndexFileFactory +{ + private final URI targetOutputDirectory; + private final HetuFileSystemClient hetuFileSystemClient; + + public MarkerIndexFileFactory(URI taskOutputDirectory, HetuFileSystemClient hetuFileSystemClient) + { + this.targetOutputDirectory = taskOutputDirectory; + this.hetuFileSystemClient = requireNonNull(hetuFileSystemClient, "hetuFileSystemClient is null"); + } + + public MarkerIndexFileWriter createWriter(String file) + { + Path path = Paths.get(targetOutputDirectory.resolve(file).getPath()); + return new MarkerIndexFileWriter(hetuFileSystemClient, path); + } + + public static class MarkerIndexFileWriter + { + private final OutputStream outputStream; + private final HetuFileSystemClient hetuFileSystemClient; + + public MarkerIndexFileWriter(HetuFileSystemClient hetuFileSystemClient, Path path) + { + OutputStream localOutputStream = null; + try { + localOutputStream = hetuFileSystemClient.newOutputStream(path); + } + catch (IOException e) { + // add logs + } + outputStream = localOutputStream; + this.hetuFileSystemClient = hetuFileSystemClient; + } + + public void writeIndexFile(MarkerIndexFile markerIndexFile) + { + JsonFactory jsonFactory = new JsonFactory(); + jsonFactory.configure(JsonGenerator.Feature.AUTO_CLOSE_TARGET, false); + ObjectMapper objectMapper = new ObjectMapper(jsonFactory); + try { + objectMapper.writeValue(outputStream, markerIndexFile); + outputStream.write("\n".getBytes(StandardCharsets.UTF_8)); + outputStream.flush(); + hetuFileSystemClient.flush(outputStream); + } + catch (IOException e) { + // add logs + } + } + } + + public static class MarkerIndexFileReader + { + private final InputStream inputStream; + + public MarkerIndexFileReader(HetuFileSystemClient hetuFileSystemClient, Path path) + { + InputStream localInputStream = null; + try { + localInputStream = hetuFileSystemClient.newInputStream(path); + } + catch (IOException e) { + throw new PrestoException(GENERIC_INTERNAL_ERROR, "Failed to create marker index file reader"); + } + inputStream = localInputStream; + } + + public MarkerIndexFile readIndexFile(int markerId) + { + JsonFactory jsonFactory = new JsonFactory(); + jsonFactory.configure(JsonParser.Feature.AUTO_CLOSE_SOURCE, false); + MarkerIndexFile markerIndexFile = null; + try { + BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream)); + String line = reader.readLine(); + while (line != null) { + ObjectMapper objectMapper = new ObjectMapper(jsonFactory); + markerIndexFile = objectMapper.readValue(line, MarkerIndexFile.class); + if (markerId == markerIndexFile.getMarkerId()) { + return markerIndexFile; + } + line = reader.readLine(); + } + throw new PrestoException(GENERIC_INTERNAL_ERROR, "Marker ID: " + markerId + " is not present in index file"); + } + catch (IOException e) { + throw new PrestoException(GENERIC_INTERNAL_ERROR, "Failed to read marker index file"); + } + } + } + + public static class MarkerIndexFile + { + private final int markerId; + private final URI markerDataFile; + private final long markerStartOffset; + private final long markerLength; + private final Map spoolingInfoMap; + + @JsonCreator + public MarkerIndexFile( + @JsonProperty("markerID") int markerId, + @JsonProperty("markerDataFile") URI markerDataFile, + @JsonProperty("markerStartOffset") long markerStartOffset, + @JsonProperty("markerLength") long markerLength, + @JsonProperty("spoolingInfoMap") Map spoolingInfoMap) + { + this.markerId = markerId; + this.markerDataFile = markerDataFile; + this.markerStartOffset = markerStartOffset; + this.markerLength = markerLength; + this.spoolingInfoMap = spoolingInfoMap; + } + + public static MarkerIndexFile createMarkerIndexFile(int markerId, URI markerDataFile, long markerStartOffset, long markerLength, Map spoolingInfoMap) + { + return new MarkerIndexFile(markerId, markerDataFile, markerStartOffset, markerLength, spoolingInfoMap); + } + + @JsonProperty + public int getMarkerId() + { + return markerId; + } + + @JsonProperty + public URI getMarkerDataFile() + { + return markerDataFile; + } + + @JsonProperty + public long getMarkerStartOffset() + { + return markerStartOffset; + } + + @JsonProperty + public long getMarkerLength() + { + return markerLength; + } + + @JsonProperty + public Map getSpoolingInfoMap() + { + return spoolingInfoMap; + } + } +} diff --git a/presto-main/src/main/java/io/prestosql/execution/buffer/HybridSpoolingBuffer.java b/presto-main/src/main/java/io/prestosql/execution/buffer/HybridSpoolingBuffer.java index 98beb1bc2..e412c55ca 100644 --- a/presto-main/src/main/java/io/prestosql/execution/buffer/HybridSpoolingBuffer.java +++ b/presto-main/src/main/java/io/prestosql/execution/buffer/HybridSpoolingBuffer.java @@ -14,6 +14,8 @@ */ package io.prestosql.execution.buffer; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; @@ -31,14 +33,28 @@ import io.prestosql.exchange.FileStatus; import io.prestosql.exchange.FileSystemExchangeConfig; import io.prestosql.exchange.FileSystemExchangeSourceHandle; import io.prestosql.exchange.storage.FileSystemExchangeStorage; +import io.prestosql.execution.MarkerDataFileFactory; +import io.prestosql.execution.MarkerIndexFileFactory; import io.prestosql.memory.context.LocalMemoryContext; +import io.prestosql.snapshot.RecoveryUtils; +import io.prestosql.snapshot.SnapshotStateId; import io.prestosql.spi.Page; +import io.prestosql.spi.PrestoException; +import io.prestosql.spi.filesystem.HetuFileSystemClient; import javax.crypto.SecretKey; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; import java.net.URI; +import java.nio.file.Path; +import java.nio.file.Paths; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.concurrent.ExecutorService; import java.util.function.Supplier; @@ -47,12 +63,17 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.util.concurrent.Futures.immediateFuture; import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.prestosql.exchange.FileSystemExchangeSink.DATA_FILE_SUFFIX; +import static io.prestosql.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; +import static io.prestosql.spi.filesystem.SupportedFileAttributes.SIZE; import static java.util.concurrent.Executors.newCachedThreadPool; public class HybridSpoolingBuffer extends SpoolingExchangeOutputBuffer { private static final String PARENT_URI = ".."; + private static final String MARKER_INDEX_FILE = "marker_index.json"; + private static final String MARKER_DATA_FILE = "marker_data_file.data"; + private static final String TEMP_FILE = "marker_temp_txn.data"; private static final Logger LOG = Logger.get(HybridSpoolingBuffer.class); private final OutputBufferStateMachine stateMachine; @@ -62,6 +83,13 @@ public class HybridSpoolingBuffer private final Supplier memoryContextSupplier; private final ExecutorService executor; private final ExchangeManager exchangeManager; + private MarkerIndexFileFactory.MarkerIndexFileWriter markerIndexFileWriter; + private MarkerDataFileFactory.MarkerDataFileWriter markerDataFileWriter; + private long previousFooterOffset; + private long previousFooterSize; + private int previousSuccessfulMarkerId; + private Map spoolingInfoMap = new HashMap<>(); + private final HetuFileSystemClient fsClient; private int token; private PagesSerde serde; private PagesSerde javaSerde; @@ -79,6 +107,154 @@ public class HybridSpoolingBuffer this.outputDirectory = exchangeSink.getOutputDirectory().resolve(PARENT_URI); this.executor = newCachedThreadPool(daemonThreadsNamed("exchange-source-handles-creation-%s")); this.exchangeManager = exchangeManager; + this.fsClient = exchangeSink.getExchangeStorage().getFileSystemClient(); + } + + public void enqueueMarkerInfo(int markerId, Map states) + { + createTempTransactionFile(); // this is to ensure transaction is successful. Deleted at end. + List exchangeSinkFiles = exchangeSink.getSinkFiles(); + MarkerDataFileFactory.MarkerDataFileFooterInfo markerDataFileFooterInfo = enqueueMarkerData(markerId, states, exchangeSinkFiles); + Map spoolingFileInfo = new HashMap<>(); + try { + spoolingFileInfo.putAll(createSpoolingInfo(exchangeSinkFiles)); + } + catch (Exception e) { + throw new PrestoException(GENERIC_INTERNAL_ERROR, "Failed creating spooling Info"); + } + + // write marker metadata to spooling file + enqueueMarkerDataToSpoolingFile(markerDataFileFooterInfo); + try { + updateSpoolingInfo(exchangeSinkFiles); + } + catch (Exception e) { + throw new PrestoException(GENERIC_INTERNAL_ERROR, "Failed updating spooling Info"); + } + enqueueMarkerIndex(markerId, outputDirectory.resolve(MARKER_DATA_FILE), previousFooterOffset, previousFooterSize, spoolingFileInfo); + deleteTempTransactionFile(); + previousSuccessfulMarkerId = markerId; + } + + public Map dequeueMarkerInfo(int markerId) + { + if (checkIfTransactionFileExists()) { + //todo(SURYA): add logic to get recent successful marker and offsets + throw new PrestoException(GENERIC_INTERNAL_ERROR, "transaction file exists for marker: " + markerId); + } + MarkerIndexFileFactory.MarkerIndexFile markerIndexFile = dequeueMarkerIndex(markerId, outputDirectory.resolve(MARKER_INDEX_FILE)); + Map markerData = dequeueMarkerData(markerIndexFile); + Map spoolingInfoMap = dequeueSpoolingInfo(markerIndexFile); + return markerData; + } + + private boolean checkIfTransactionFileExists() + { + Path tempTransactionFile = Paths.get(outputDirectory.resolve(TEMP_FILE).getPath()); + return fsClient.exists(tempTransactionFile); + } + + private void updateSpoolingInfo(List exchangeSinkFiles) + throws IOException + { + for (URI sinkFile : exchangeSinkFiles) { + if (!spoolingInfoMap.containsKey(sinkFile)) { + throw new PrestoException(GENERIC_INTERNAL_ERROR, "sink file suppose to be present in spoolingInfoMap"); + } + else { + long spoolingOffset = spoolingInfoMap.get(sinkFile).getSpoolingFileOffset() + spoolingInfoMap.get(sinkFile).getSpoolingFileSize(); + long spoolingSize = ((Long) fsClient.getAttribute(Paths.get(sinkFile.getPath()), SIZE)) - spoolingOffset; + spoolingInfoMap.put(sinkFile, new SpoolingInfo(spoolingOffset, spoolingSize)); + } + } + } + + private void enqueueMarkerDataToSpoolingFile(MarkerDataFileFactory.MarkerDataFileFooterInfo markerDataFileFooterInfo) + { + exchangeSink.enqueueMarkerInfo(markerDataFileFooterInfo); + } + + private Map createSpoolingInfo(List sinkFiles) + throws IOException + { + Map localSpoolingInfoMap = new HashMap<>(); + for (URI sinkFile : sinkFiles) { + if (!spoolingInfoMap.containsKey(sinkFile)) { + localSpoolingInfoMap.put(sinkFile, new SpoolingInfo(0, (Long) fsClient.getAttribute(Paths.get(sinkFile.getPath()), SIZE))); + spoolingInfoMap.put(sinkFile, new SpoolingInfo(0, (Long) fsClient.getAttribute(Paths.get(sinkFile.getPath()), SIZE))); + } + else { + long spoolingOffset = spoolingInfoMap.get(sinkFile).getSpoolingFileOffset() + spoolingInfoMap.get(sinkFile).getSpoolingFileSize(); + long spoolingSize = ((Long) fsClient.getAttribute(Paths.get(sinkFile.getPath()), SIZE)) - spoolingOffset; + if (spoolingSize > 0) { + localSpoolingInfoMap.put(sinkFile, new SpoolingInfo(spoolingOffset, spoolingSize)); + spoolingInfoMap.put(sinkFile, new SpoolingInfo(spoolingOffset, spoolingSize)); + } + } + } + return localSpoolingInfoMap; + } + + public void createTempTransactionFile() + { + Path tempTransactionFile = Paths.get(outputDirectory.resolve(TEMP_FILE).getPath()); + try (OutputStream outputStream = fsClient.newOutputStream(tempTransactionFile)) { + RecoveryUtils.serializeState(previousSuccessfulMarkerId, outputStream, false); + RecoveryUtils.serializeState(previousFooterOffset, outputStream, false); + } + catch (Exception e) { + e.printStackTrace(); + } + } + + private void deleteTempTransactionFile() + { + Path tempTransactionFile = Paths.get(outputDirectory.resolve(TEMP_FILE).getPath()); + try { + fsClient.delete(tempTransactionFile); + } + catch (Exception e) { + e.printStackTrace(); + } + } + + public MarkerDataFileFactory.MarkerDataFileFooterInfo enqueueMarkerData(int markerId, Map statesMap, List sinkFiles) + { + if (markerDataFileWriter == null) { + MarkerDataFileFactory markerDataFileFactory = new MarkerDataFileFactory(outputDirectory, fsClient); + this.markerDataFileWriter = markerDataFileFactory.createWriter(MARKER_DATA_FILE, false); //todo(SURYA): add JAVA/KRYO config based decision. + } + MarkerDataFileFactory.MarkerDataFileFooterInfo markerDataFileFooterInfo = markerDataFileWriter.writeDataFile(markerId, statesMap, previousFooterOffset, previousFooterSize); + previousFooterOffset = markerDataFileFooterInfo.getFooterOffset(); + previousFooterSize = markerDataFileFooterInfo.getFooterSize(); + return markerDataFileFooterInfo; + } + + public Map dequeueMarkerData(MarkerIndexFileFactory.MarkerIndexFile markerIndexFile) + { + MarkerDataFileFactory.MarkerDataFileReader markerDataFileReader = new MarkerDataFileFactory.MarkerDataFileReader(fsClient, markerIndexFile.getMarkerDataFile(), false); + return markerDataFileReader.readDataFile(markerIndexFile.getMarkerStartOffset(), markerIndexFile.getMarkerLength()); + } + + public Map dequeueSpoolingInfo(MarkerIndexFileFactory.MarkerIndexFile markerIndexFile) + { + return markerIndexFile.getSpoolingInfoMap(); + } + + public void enqueueMarkerIndex(int markerId, URI markerDataFile, long markerStartOffset, long markerLength, Map spoolingInfoMap) + { + if (markerIndexFileWriter == null) { + MarkerIndexFileFactory markerIndexFileWriterFactory = new MarkerIndexFileFactory(outputDirectory, fsClient); + this.markerIndexFileWriter = markerIndexFileWriterFactory.createWriter(MARKER_INDEX_FILE); + } + MarkerIndexFileFactory.MarkerIndexFile markerIndexFile = new MarkerIndexFileFactory.MarkerIndexFile(markerId, markerDataFile, markerStartOffset, markerLength, spoolingInfoMap); + markerIndexFileWriter.writeIndexFile(markerIndexFile); + } + + private MarkerIndexFileFactory.MarkerIndexFile dequeueMarkerIndex(int markerId, URI indexFile) + { + MarkerIndexFileFactory.MarkerIndexFileReader markerIndexFileReader = new MarkerIndexFileFactory.MarkerIndexFileReader(fsClient, Paths.get(indexFile.getPath())); + return markerIndexFileReader.readIndexFile(markerId); } @Override @@ -207,4 +383,74 @@ public class HybridSpoolingBuffer return exchangeStorage; } } + + public static class SpoolingInfo + implements Serializable + { + private long spoolingFileOffset; + private long spoolingFileSize; + + @JsonCreator + public SpoolingInfo( + @JsonProperty("spoolingFileOffset") long spoolingFileOffset, + @JsonProperty("spoolingFileSize") long spoolingFileSize) + { + this.spoolingFileOffset = spoolingFileOffset; + this.spoolingFileSize = spoolingFileSize; + } + + @JsonProperty + public long getSpoolingFileOffset() + { + return spoolingFileOffset; + } + + @JsonProperty + public long getSpoolingFileSize() + { + return spoolingFileSize; + } + } + + public Object getMarkerDataFileFooter() + { + try { + InputStream inputStream = fsClient.newInputStream(Paths.get(outputDirectory.resolve(MARKER_DATA_FILE).getPath())); + inputStream.skip(previousFooterOffset); + Object o = RecoveryUtils.deserializeState(inputStream, false); + return o; + } + catch (Exception e) { + e.printStackTrace(); + } + return null; + } + + public Object getMarkerDataFileFooter(long previousFooterOffset) + { + try { + InputStream inputStream = fsClient.newInputStream(Paths.get(outputDirectory.resolve(MARKER_DATA_FILE).getPath())); + inputStream.skip(previousFooterOffset); + Object o = RecoveryUtils.deserializeState(inputStream, false); + return o; + } + catch (Exception e) { + e.printStackTrace(); + } + return null; + } + + public Object getMarkerData(long stateOffset) + { + try { + InputStream inputStream = fsClient.newInputStream(Paths.get(outputDirectory.resolve(MARKER_DATA_FILE).getPath())); + inputStream.skip(stateOffset); + Object o = RecoveryUtils.deserializeState(inputStream, false); + return o; + } + catch (Exception e) { + e.printStackTrace(); + } + return null; + } } diff --git a/presto-main/src/test/java/io/prestosql/execution/buffer/TestHybridSpoolingBuffer.java b/presto-main/src/test/java/io/prestosql/execution/buffer/TestHybridSpoolingBuffer.java index 51451966b..02220aa47 100644 --- a/presto-main/src/test/java/io/prestosql/execution/buffer/TestHybridSpoolingBuffer.java +++ b/presto-main/src/test/java/io/prestosql/execution/buffer/TestHybridSpoolingBuffer.java @@ -15,6 +15,7 @@ package io.prestosql.execution.buffer; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.ListenableFuture; import io.airlift.units.DataSize; import io.hetu.core.filesystem.HetuLocalFileSystemClient; @@ -32,13 +33,16 @@ import io.prestosql.exchange.FileSystemExchangeSinkInstanceHandle; import io.prestosql.exchange.FileSystemExchangeStats; import io.prestosql.exchange.storage.FileSystemExchangeStorage; import io.prestosql.exchange.storage.HetuFileSystemExchangeStorage; +import io.prestosql.execution.MarkerDataFileFactory; import io.prestosql.execution.StageId; import io.prestosql.execution.TaskId; import io.prestosql.operator.PageAssertions; +import io.prestosql.snapshot.SnapshotStateId; import io.prestosql.spi.Page; import io.prestosql.spi.QueryId; import io.prestosql.spi.block.Block; import io.prestosql.spi.block.BlockBuilder; +import io.prestosql.spi.filesystem.HetuFileSystemClient; import io.prestosql.testing.TestingPagesSerdeFactory; import org.testng.annotations.AfterMethod; import org.testng.annotations.BeforeMethod; @@ -46,11 +50,13 @@ import org.testng.annotations.Test; import java.io.File; import java.io.IOException; +import java.net.URI; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Properties; import java.util.concurrent.ExecutionException; @@ -58,6 +64,7 @@ import java.util.stream.Collectors; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static io.prestosql.metadata.MetadataManager.createTestMetadataManager; +import static io.prestosql.spi.type.BigintType.BIGINT; import static io.prestosql.spi.type.IntegerType.INTEGER; import static org.testng.Assert.assertEquals; @@ -72,29 +79,28 @@ public class TestHybridSpoolingBuffer private ExchangeSink exchangeSink; private ExchangeManager exchangeManager; private ExchangeSinkInstanceHandle exchangeSinkInstanceHandle; + private final String baseURI = "file:///tmp/hetu/spooling"; private final String baseDir = "/tmp/hetu/spooling"; private final String accessDir = "/tmp/hetu"; private final Path accessPath = Paths.get(accessDir); + private final HetuFileSystemClient hetuFileSystemClient = new HetuLocalFileSystemClient(new LocalConfig(new Properties()), accessPath); @BeforeMethod public void setUp() throws IOException, InterruptedException { Path basePath = Paths.get(baseDir); - File base = new File(baseDir); - if (!base.exists()) { - Files.createDirectories(basePath); - } - else { + File base = new File(accessDir); + if (base.exists()) { deleteDirectory(base); - Files.createDirectories(basePath); } + Files.createDirectories(basePath); } @AfterMethod public void cleanUp() { - File base = new File(baseDir); + File base = new File(accessDir); if (base.exists()) { deleteDirectory(base); } @@ -108,7 +114,7 @@ public class TestHybridSpoolingBuffer .setBaseDirectories(baseDir); FileSystemExchangeStorage exchangeStorage = new HetuFileSystemExchangeStorage(); - exchangeStorage.setFileSystemClient(new HetuLocalFileSystemClient(new LocalConfig(new Properties()), accessPath)); + exchangeStorage.setFileSystemClient(hetuFileSystemClient); exchangeManager = new FileSystemExchangeManager(exchangeStorage, new FileSystemExchangeStats(), config); exchangeSinkInstanceHandle = new FileSystemExchangeSinkInstanceHandle( new FileSystemExchangeSinkHandle(0, Optional.empty(), false), @@ -207,6 +213,79 @@ public class TestHybridSpoolingBuffer hybridSpoolingBuffer.setNoMorePages(); } + @Test + public void testHybridSpoolingMarkerIndexFile() + { + setConfig(FileSystemExchangeConfig.DirectSerialisationType.OFF); + HybridSpoolingBuffer hybridSpoolingBuffer = createHybridSpoolingBuffer(); + URI markerDataFile = URI.create(baseURI).resolve("marker_data_file.data"); + URI spoolingDataFile = URI.create(baseURI).resolve("spooling_data_file.data"); + HybridSpoolingBuffer.SpoolingInfo spoolingInfo = new HybridSpoolingBuffer.SpoolingInfo(1000, 2000); + hybridSpoolingBuffer.enqueueMarkerIndex(1, markerDataFile, 1000, 2000, ImmutableMap.of(spoolingDataFile, spoolingInfo)); + } + + @Test + public void testHybridSpoolingMarkerDataFile() + { + SnapshotStateId snapshotStateId = new SnapshotStateId(1, new TaskId(new StageId("query", 1), 1, 1)); + SnapshotStateId snapshotStateId1 = new SnapshotStateId(2, new TaskId(new StageId("query", 1), 1, 1)); + setConfig(FileSystemExchangeConfig.DirectSerialisationType.OFF); + HybridSpoolingBuffer hybridSpoolingBuffer = createHybridSpoolingBuffer(); + hybridSpoolingBuffer.enqueueMarkerData(1, ImmutableMap.of(snapshotStateId, "marker1"), exchangeSink.getSinkFiles()); + hybridSpoolingBuffer.enqueueMarkerData(2, ImmutableMap.of(snapshotStateId1, "marker2"), exchangeSink.getSinkFiles()); + MarkerDataFileFactory.MarkerDataFileFooter footer = (MarkerDataFileFactory.MarkerDataFileFooter) hybridSpoolingBuffer.getMarkerDataFileFooter(); + Map operatorStateInfoMap = footer.getOperatorStateInfo(); + MarkerDataFileFactory.OperatorStateInfo operatorStateInfo = operatorStateInfoMap.get(snapshotStateId1.getId()); + assertEquals(hybridSpoolingBuffer.getMarkerData(operatorStateInfo.getStateOffset()), "marker2"); + MarkerDataFileFactory.MarkerDataFileFooter previousFooter = (MarkerDataFileFactory.MarkerDataFileFooter) hybridSpoolingBuffer.getMarkerDataFileFooter(footer.getPreviousTailOffset()); + Map previousOperatorStateInfoMap = previousFooter.getOperatorStateInfo(); + MarkerDataFileFactory.OperatorStateInfo previousOperatorStateInfo = previousOperatorStateInfoMap.get(snapshotStateId.getId()); + assertEquals(hybridSpoolingBuffer.getMarkerData(previousOperatorStateInfo.getStateOffset()), "marker1"); + } + + @Test + public void testMarkerSpoolingInfo() + { + int entries = 10; + BlockBuilder blockBuilder = BIGINT.createBlockBuilder(null, entries); + for (int i = 0; i < entries; i++) { + BIGINT.writeLong(blockBuilder, i); + } + Block block = blockBuilder.build(); + Page page = new Page(block); + SerializedPage serializedPage = javaSerde.serialize(page); + SnapshotStateId snapshotStateId = new SnapshotStateId(1, new TaskId(new StageId("query", 1), 1, 1)); + SnapshotStateId snapshotStateId1 = new SnapshotStateId(1, new TaskId(new StageId("query", 1), 2, 1)); + setConfig(FileSystemExchangeConfig.DirectSerialisationType.OFF); + HybridSpoolingBuffer hybridSpoolingBuffer = createHybridSpoolingBuffer(); + hybridSpoolingBuffer.enqueue(0, ImmutableList.of(serializedPage), null); + hybridSpoolingBuffer.enqueueMarkerInfo(1, ImmutableMap.of(snapshotStateId, "marker1", snapshotStateId1, "marker2")); + assertEquals(hybridSpoolingBuffer.dequeueMarkerInfo(1), ImmutableMap.of(snapshotStateId, "marker1", snapshotStateId1, "marker2")); + } + + @Test + public void testMarkerSpoolingInfoMultiPartition() + { + int entries = 10; + BlockBuilder blockBuilder = BIGINT.createBlockBuilder(null, entries); + for (int i = 0; i < entries; i++) { + BIGINT.writeLong(blockBuilder, i); + } + Block block = blockBuilder.build(); + Page page = new Page(block); + SerializedPage serializedPage = javaSerde.serialize(page); + SnapshotStateId snapshotStateId = new SnapshotStateId(1, new TaskId(new StageId("query", 1), 1, 1)); + SnapshotStateId snapshotStateId1 = new SnapshotStateId(1, new TaskId(new StageId("query", 1), 2, 1)); + setConfig(FileSystemExchangeConfig.DirectSerialisationType.OFF); + HybridSpoolingBuffer hybridSpoolingBuffer = createHybridSpoolingBuffer(); + hybridSpoolingBuffer.enqueue(0, ImmutableList.of(serializedPage), null); + hybridSpoolingBuffer.enqueue(1, ImmutableList.of(serializedPage), null); + hybridSpoolingBuffer.enqueueMarkerInfo(1, ImmutableMap.of(snapshotStateId, "marker1")); + hybridSpoolingBuffer.enqueueMarkerInfo(2, ImmutableMap.of(snapshotStateId1, "marker2")); + assertEquals(hybridSpoolingBuffer.dequeueMarkerInfo(1), ImmutableMap.of(snapshotStateId, "marker1")); + assertEquals(hybridSpoolingBuffer.dequeueMarkerInfo(2), ImmutableMap.of(snapshotStateId1, "marker2")); + } + private HybridSpoolingBuffer createHybridSpoolingBuffer() { OutputBuffers outputBuffers = OutputBuffers.createInitialEmptyOutputBuffers(OutputBuffers.BufferType.PARTITIONED); diff --git a/presto-main/src/test/java/io/prestosql/execution/buffer/TestSpoolingExchangeOutputBuffer.java b/presto-main/src/test/java/io/prestosql/execution/buffer/TestSpoolingExchangeOutputBuffer.java index e0706a856..69661c57e 100644 --- a/presto-main/src/test/java/io/prestosql/execution/buffer/TestSpoolingExchangeOutputBuffer.java +++ b/presto-main/src/test/java/io/prestosql/execution/buffer/TestSpoolingExchangeOutputBuffer.java @@ -25,6 +25,7 @@ import io.hetu.core.transport.execution.buffer.SerializedPage; import io.prestosql.exchange.ExchangeSink; import io.prestosql.exchange.ExchangeSinkInstanceHandle; import io.prestosql.exchange.storage.FileSystemExchangeStorage; +import io.prestosql.execution.MarkerDataFileFactory; import io.prestosql.execution.StageId; import io.prestosql.execution.TaskId; import io.prestosql.memory.context.LocalMemoryContext; @@ -35,6 +36,7 @@ import org.testng.annotations.Test; import javax.crypto.SecretKey; import java.net.URI; +import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; @@ -392,6 +394,18 @@ public class TestSpoolingExchangeOutputBuffer return 0; } + @Override + public List getSinkFiles() + { + return ImmutableList.of(); + } + + @Override + public void enqueueMarkerInfo(MarkerDataFileFactory.MarkerDataFileFooterInfo markerDataFileFooterInfo) + { + return; + } + public void setAbort(CompletableFuture abort) { this.abort = requireNonNull(abort, "abort is null"); -- Gitee From d33a3d977a7f67ab290a21361f0199957946368a Mon Sep 17 00:00:00 2001 From: Surya Sumanth N Date: Wed, 5 Apr 2023 12:36:57 +0530 Subject: [PATCH 3/4] Hybrid Spooling Parallel Write Support --- .../io/prestosql/SystemSessionProperties.java | 10 ++ .../execution/QueryManagerConfig.java | 13 ++ .../java/io/prestosql/execution/SqlTask.java | 1 + .../buffer/ArbitraryOutputBuffer.java | 153 ++++++++++++++++-- .../buffer/HybridSpoolingBuffer.java | 77 +++++++++ .../execution/buffer/LazyOutputBuffer.java | 149 ++++++++++++++++- .../execution/buffer/OutputBuffer.java | 38 +++++ .../buffer/PartitionedOutputBuffer.java | 45 ++++++ .../operator/TaskOutputOperator.java | 6 +- .../operator/output/PagePartitioner.java | 6 +- 10 files changed, 478 insertions(+), 20 deletions(-) diff --git a/presto-main/src/main/java/io/prestosql/SystemSessionProperties.java b/presto-main/src/main/java/io/prestosql/SystemSessionProperties.java index 41d990d2f..1e02f2dab 100644 --- a/presto-main/src/main/java/io/prestosql/SystemSessionProperties.java +++ b/presto-main/src/main/java/io/prestosql/SystemSessionProperties.java @@ -204,6 +204,7 @@ public final class SystemSessionProperties public static final String FAULT_TOLERANT_EXECUTION_PARTITION_COUNT = "fault_tolerant_execution_partition_count"; public static final String FAULT_TOLERANT_EXECUTION_TASK_MEMORY_GROWTH_FACTOR = "fault_tolerant_execution_task_memory_growth_factor"; public static final String FAULT_TOLERANT_EXECUTION_TASK_MEMORY_ESTIMATION_QUANTILE = "fault_tolerant_execution_task_memory_estimation_quantile"; + public static final String TASK_ASYNC_PARALLEL_WRITE_ENABLED = "task_async_parallel_write_enabled"; public static final String ADAPTIVE_PARTIAL_AGGREGATION_ENABLED = "adaptive_partial_aggregation_enabled"; public static final String ADAPTIVE_PARTIAL_AGGREGATION_MIN_ROWS = "adaptive_partial_aggregation_min_rows"; @@ -1087,6 +1088,10 @@ public final class SystemSessionProperties CTE_MATERIALIZATION_SCHEMA_NAME, "Name of the table schema to store cached result data", hetuConfig.getCachingSchemaName(), + false), + booleanProperty(TASK_ASYNC_PARALLEL_WRITE_ENABLED, + "Parallel Writes to client buffer and spool when retry policy is TASK_ASYNC", + queryManagerConfig.isTaskAsyncParallelWriteEnabled(), false)); } @@ -1927,4 +1932,9 @@ public final class SystemSessionProperties { return session.getSystemProperty(CTE_MATERIALIZATION_THRESHOLD_SIZE, DataSize.class); } + + public static boolean isTaskAsyncParallelWriteEnabled(Session session) + { + return session.getSystemProperty(TASK_ASYNC_PARALLEL_WRITE_ENABLED, Boolean.class); + } } diff --git a/presto-main/src/main/java/io/prestosql/execution/QueryManagerConfig.java b/presto-main/src/main/java/io/prestosql/execution/QueryManagerConfig.java index 5c816a73f..846ac327e 100644 --- a/presto-main/src/main/java/io/prestosql/execution/QueryManagerConfig.java +++ b/presto-main/src/main/java/io/prestosql/execution/QueryManagerConfig.java @@ -96,6 +96,7 @@ public class QueryManagerConfig private String exchangeFilesystemBaseDirectory = "/tmp/hetu-exchange-manager"; private boolean queryResourceTracking; + private boolean taskAsyncParallelWriteEnabled; @Min(1) public int getScheduleSplitBatchSize() @@ -612,4 +613,16 @@ public class QueryManagerConfig this.queryResourceTracking = queryResourceTracking; return this; } + + public boolean isTaskAsyncParallelWriteEnabled() + { + return taskAsyncParallelWriteEnabled; + } + + @Config("task-async-parallel-write-enabled") + public QueryManagerConfig setTaskAsyncParallelWriteEnabled(Boolean taskAsyncParallelWriteEnabled) + { + this.taskAsyncParallelWriteEnabled = taskAsyncParallelWriteEnabled; + return this; + } } diff --git a/presto-main/src/main/java/io/prestosql/execution/SqlTask.java b/presto-main/src/main/java/io/prestosql/execution/SqlTask.java index 0d3e1dee1..736105938 100644 --- a/presto-main/src/main/java/io/prestosql/execution/SqlTask.java +++ b/presto-main/src/main/java/io/prestosql/execution/SqlTask.java @@ -421,6 +421,7 @@ public class SqlTask // The LazyOutput buffer does not support write methods, so the actual // output buffer must be established before drivers are created (e.g. // a VALUES query). + outputBuffer.setSession(session); outputBuffer.setOutputBuffers(outputBuffers); // assure the task execution is only created once diff --git a/presto-main/src/main/java/io/prestosql/execution/buffer/ArbitraryOutputBuffer.java b/presto-main/src/main/java/io/prestosql/execution/buffer/ArbitraryOutputBuffer.java index 77f77e6f6..189bc57f4 100644 --- a/presto-main/src/main/java/io/prestosql/execution/buffer/ArbitraryOutputBuffer.java +++ b/presto-main/src/main/java/io/prestosql/execution/buffer/ArbitraryOutputBuffer.java @@ -39,10 +39,12 @@ import java.io.Serializable; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.LinkedList; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; @@ -85,6 +87,7 @@ public class ArbitraryOutputBuffer private OutputBuffers outputBuffers = createInitialEmptyOutputBuffers(ARBITRARY); private final MasterBuffer masterBuffer; + private Map asyncMasterBuffers = new HashMap<>(); @GuardedBy("this") private final ConcurrentMap buffers = new ConcurrentHashMap<>(); @@ -97,12 +100,25 @@ public class ArbitraryOutputBuffer private final AtomicLong totalPagesAdded = new AtomicLong(); private final AtomicLong totalRowsAdded = new AtomicLong(); + private final boolean isTaskAsyncParallelWrite; + private Map masterBuffersInfo = new HashMap<>(); + private long pageCounter; public ArbitraryOutputBuffer( OutputBufferStateMachine state, DataSize maxBufferSize, Supplier systemMemoryContextSupplier, Executor notificationExecutor) + { + this (state, maxBufferSize, systemMemoryContextSupplier, notificationExecutor, false); + } + + public ArbitraryOutputBuffer( + OutputBufferStateMachine state, + DataSize maxBufferSize, + Supplier systemMemoryContextSupplier, + Executor notificationExecutor, + boolean isTaskAsyncParallelWrite) { this.stateMachine = requireNonNull(state, "state is null"); requireNonNull(maxBufferSize, "maxBufferSize is null"); @@ -112,6 +128,7 @@ public class ArbitraryOutputBuffer requireNonNull(systemMemoryContextSupplier, "systemMemoryContextSupplier is null"), requireNonNull(notificationExecutor, "notificationExecutor is null")); this.masterBuffer = new MasterBuffer(); + this.isTaskAsyncParallelWrite = isTaskAsyncParallelWrite; } @Override @@ -153,7 +170,15 @@ public class ArbitraryOutputBuffer @SuppressWarnings("FieldAccessNotGuarded") Collection clientBuffers = this.buffers.values(); - int totalBufferedPages = masterBuffer.getBufferedPages(); + int totalBufferedPages = 0; + if (isTaskAsyncParallelWrite) { + for (AsyncMasterBuffer localAsyncMasterBuffers : asyncMasterBuffers.values()) { + totalBufferedPages = localAsyncMasterBuffers.getBufferedPages(); + } + } + else { + totalBufferedPages = masterBuffer.getBufferedPages(); + } ImmutableList.Builder infos = ImmutableList.builder(); for (ClientBuffer buffer : clientBuffers) { BufferInfo bufferInfo = buffer.getInfo(); @@ -266,15 +291,31 @@ public class ArbitraryOutputBuffer .map(pageSplit -> new SerializedPageReference(pageSplit, 1, () -> memoryManager.updateMemoryUsage(-pageSplit.getRetainedSizeInBytes()))) .collect(toImmutableList()); - // add pages to the buffer - masterBuffer.addPages(serializedPageReferences, targetClients); + if (isTaskAsyncParallelWrite) { + enqueuePagesRoundRobin(serializedPageReferences, targetClients); + } + else { + // add pages to the buffer + masterBuffer.addPages(serializedPageReferences, targetClients); + } // process any pending reads from the client buffers - for (ClientBuffer clientBuffer : safeGetBuffersSnapshot()) { - if (masterBuffer.isEmpty()) { - break; + if (isTaskAsyncParallelWrite) { + for (ClientBuffer clientBuffer : safeGetBuffersSnapshot()) { + if (masterBuffer.isEmpty()) { + break; + } + AsyncMasterBuffer localAsyncMasterBuffer = asyncMasterBuffers.get(clientBuffer.getInfo().getBufferId().getId()); + clientBuffer.loadPagesIfNecessary(localAsyncMasterBuffer); + } + } + else { + for (ClientBuffer clientBuffer : safeGetBuffersSnapshot()) { + if (masterBuffer.isEmpty()) { + break; + } + clientBuffer.loadPagesIfNecessary(masterBuffer); } - clientBuffer.loadPagesIfNecessary(masterBuffer); } if (targetClients != null && stateMachine.getState().canAddBuffers()) { @@ -285,6 +326,21 @@ public class ArbitraryOutputBuffer } } + @Override + public long getTokenId(int bufferId, long token) + { + return masterBuffersInfo.get(bufferId + "-" + token); + } + + @Override + public boolean checkIfAcknowledged(int bufferId, long token) + { + if (masterBuffersInfo.containsKey(bufferId + "-" + token)) { + return false; + } + return true; + } + @Override public void enqueue(int partition, List pages, String origin) { @@ -299,6 +355,11 @@ public class ArbitraryOutputBuffer requireNonNull(bufferId, "bufferId is null"); checkArgument(maxSize.toBytes() > 0, "maxSize must be at least 1 byte"); + if (isTaskAsyncParallelWrite) { + AsyncMasterBuffer localAsyncMasterBuffer = asyncMasterBuffers.get(bufferId); + return getBuffer(bufferId).getPages(startingSequenceId, maxSize, Optional.of(localAsyncMasterBuffer)); + } + return getBuffer(bufferId).getPages(startingSequenceId, maxSize, Optional.of(masterBuffer)); } @@ -309,6 +370,10 @@ public class ArbitraryOutputBuffer requireNonNull(bufferId, "bufferId is null"); getBuffer(bufferId).acknowledgePages(sequenceId); + + if (isTaskAsyncParallelWrite) { + masterBuffersInfo.remove(bufferId + "-" + sequenceId); + } } @Override @@ -345,11 +410,22 @@ public class ArbitraryOutputBuffer stateMachine.compareAndSet(NO_MORE_BUFFERS, FLUSHING); memoryManager.setNoBlockOnFull(); - masterBuffer.setNoMorePages(); + if (isTaskAsyncParallelWrite) { + asyncMasterBuffers.values().forEach(buffer -> buffer.setNoMorePages()); + } + else { + masterBuffer.setNoMorePages(); + } // process any pending reads from the client buffers for (ClientBuffer clientBuffer : safeGetBuffersSnapshot()) { - clientBuffer.loadPagesIfNecessary(masterBuffer); + if (isTaskAsyncParallelWrite) { + AsyncMasterBuffer localAsyncMasterBuffer = asyncMasterBuffers.get(clientBuffer.getInfo().getBufferId().getId()); + clientBuffer.loadPagesIfNecessary(localAsyncMasterBuffer); + } + else { + clientBuffer.loadPagesIfNecessary(masterBuffer); + } } checkFlushComplete(); @@ -370,7 +446,12 @@ public class ArbitraryOutputBuffer if (stateMachine.setIf(FINISHED, oldState -> !oldState.isTerminal())) { noMoreBuffers(); - masterBuffer.destroy(); + if (isTaskAsyncParallelWrite) { + asyncMasterBuffers.values().forEach(buffer -> buffer.destroy()); + } + else { + masterBuffer.destroy(); + } safeGetBuffersSnapshot().forEach(ClientBuffer::destroy); @@ -402,6 +483,16 @@ public class ArbitraryOutputBuffer memoryManager.close(); } + private synchronized void enqueuePagesRoundRobin(List serializedPageReferences, Collection targetClients) + { + for (SerializedPageReference serializedPageReference : serializedPageReferences) { + long pageBuffer = pageCounter % buffers.size(); + AsyncMasterBuffer localAsyncMasterBuffer = asyncMasterBuffers.get(pageBuffer); + localAsyncMasterBuffer.addPages(ImmutableList.of(serializedPageReference), targetClients); + masterBuffersInfo.put(pageBuffer + "-" + localAsyncMasterBuffer.getToken(), pageCounter++); + } + } + private synchronized ClientBuffer getBuffer(OutputBufferId id) { ClientBuffer buffer = buffers.get(id); @@ -424,7 +515,17 @@ public class ArbitraryOutputBuffer // add pending markers if (!markersForNewBuffers.isEmpty()) { markersForNewBuffers.forEach(SerializedPageReference::addReference); - masterBuffer.insertMarkers(markersForNewBuffers, buffer); + if (isTaskAsyncParallelWrite) { + AsyncMasterBuffer localAsyncMasterBuffer = asyncMasterBuffers.get(id.getId()); + if (localAsyncMasterBuffer == null) { + localAsyncMasterBuffer = new AsyncMasterBuffer(id.getId()); + asyncMasterBuffers.put(id.getId(), localAsyncMasterBuffer); + } + localAsyncMasterBuffer.insertMarkers(markersForNewBuffers, buffer); + } + else { + masterBuffer.insertMarkers(markersForNewBuffers, buffer); + } } // buffer may have finished immediately before calling this method @@ -479,6 +580,31 @@ public class ArbitraryOutputBuffer } } + @ThreadSafe + private class AsyncMasterBuffer + extends MasterBuffer + { + @GuardedBy("this") + private final int id; + private final AtomicLong token = new AtomicLong(); + + public AsyncMasterBuffer(int id) + { + super(); + this.id = id; + } + + public int getId() + { + return id; + } + + public long getToken() + { + return token.getAndAdd(1); + } + } + @ThreadSafe private class MasterBuffer implements PagesSupplier @@ -654,6 +780,11 @@ public class ArbitraryOutputBuffer this.snapshotState = SystemSessionProperties.isSnapshotEnabled(taskContext.getSession()) ? MultiInputSnapshotState.forTaskComponent(this, taskContext, snapshotId -> SnapshotStateId.forTaskComponent(snapshotId, taskContext, "OutputBuffer")) : null; + if (isTaskAsyncParallelWrite && asyncMasterBuffers.size() == 0) { + for (OutputBufferId outputBufferId : buffers.keySet()) { + asyncMasterBuffers.put(outputBufferId.getId(), new AsyncMasterBuffer(outputBufferId.getId())); + } + } } @Override diff --git a/presto-main/src/main/java/io/prestosql/execution/buffer/HybridSpoolingBuffer.java b/presto-main/src/main/java/io/prestosql/execution/buffer/HybridSpoolingBuffer.java index e412c55ca..8e0bde245 100644 --- a/presto-main/src/main/java/io/prestosql/execution/buffer/HybridSpoolingBuffer.java +++ b/presto-main/src/main/java/io/prestosql/execution/buffer/HybridSpoolingBuffer.java @@ -17,8 +17,10 @@ package io.prestosql.execution.buffer; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.ListeningExecutorService; import io.airlift.log.Logger; import io.airlift.slice.Slice; import io.airlift.units.DataSize; @@ -41,6 +43,7 @@ import io.prestosql.snapshot.SnapshotStateId; import io.prestosql.spi.Page; import io.prestosql.spi.PrestoException; import io.prestosql.spi.filesystem.HetuFileSystemClient; +import org.checkerframework.checker.nullness.qual.Nullable; import javax.crypto.SecretKey; @@ -53,19 +56,24 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.util.ArrayList; import java.util.HashMap; +import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.concurrent.ExecutorService; +import java.util.concurrent.atomic.AtomicLong; import java.util.function.Supplier; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.util.concurrent.Futures.immediateFuture; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; +import static com.google.common.util.concurrent.MoreExecutors.listeningDecorator; import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.prestosql.exchange.FileSystemExchangeSink.DATA_FILE_SUFFIX; import static io.prestosql.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static io.prestosql.spi.filesystem.SupportedFileAttributes.SIZE; import static java.util.concurrent.Executors.newCachedThreadPool; +import static java.util.concurrent.Executors.newFixedThreadPool; public class HybridSpoolingBuffer extends SpoolingExchangeOutputBuffer @@ -96,6 +104,8 @@ public class HybridSpoolingBuffer private PagesSerde kryoSerde; private URI outputDirectory; + private Map writeTokenPerPartition = new HashMap<>(); + private final ListeningExecutorService spoolingExecutorService; public HybridSpoolingBuffer(OutputBufferStateMachine stateMachine, OutputBuffers outputBuffers, ExchangeSink exchangeSink, Supplier memoryContextSupplier, ExchangeManager exchangeManager) { @@ -108,6 +118,73 @@ public class HybridSpoolingBuffer this.executor = newCachedThreadPool(daemonThreadsNamed("exchange-source-handles-creation-%s")); this.exchangeManager = exchangeManager; this.fsClient = exchangeSink.getExchangeStorage().getFileSystemClient(); + this.spoolingExecutorService = listeningDecorator(newFixedThreadPool(1, daemonThreadsNamed("spooling-thread-%s"))); + } + + @Override + public void enqueue(List pages, String origin, boolean isTaskAsyncParallelWrite) + { + if (isTaskAsyncParallelWrite) { + enqueue(0, pages, origin); + } + else { + super.enqueue(0, pages, origin); + } + } + + @Override + public void enqueue(int partition, List pages, String origin, boolean isTaskAsyncParallelWrite) + { + if (isTaskAsyncParallelWrite) { + enqueue(partition, pages, origin); + } + else { + super.enqueue(partition, pages, origin); + } + } + + @Override + public void enqueue(List pages, String origin) + { + enqueue(0, pages, origin); + } + + @Override + public void enqueue(int partition, List pages, String origin) + { + enqueueImpl(partition, pages, origin); + } + + public long getWriteToken(int partition) + { + if (writeTokenPerPartition.containsKey(partition)) { + return writeTokenPerPartition.get(partition).get(); + } + else { + return -1; + } + } + + private void enqueueImpl(int partition, List localPages, String origin) + { + List spoolPages = new LinkedList<>(localPages); + ListenableFuture future = spoolingExecutorService.submit(() -> super.enqueue(partition, spoolPages, origin)); + Futures.addCallback(future, new FutureCallback() { + @Override + public void onSuccess(@Nullable Object result) + { + if (!writeTokenPerPartition.containsKey(partition)) { + writeTokenPerPartition.put(partition, new AtomicLong()); + } + writeTokenPerPartition.get(partition).getAndAdd(spoolPages.size()); + } + + @Override + public void onFailure(Throwable t) + { + throw new PrestoException(GENERIC_INTERNAL_ERROR, "spooling failed"); + } + }, directExecutor()); } public void enqueueMarkerInfo(int markerId, Map states) diff --git a/presto-main/src/main/java/io/prestosql/execution/buffer/LazyOutputBuffer.java b/presto-main/src/main/java/io/prestosql/execution/buffer/LazyOutputBuffer.java index 86e7caad0..98d435826 100644 --- a/presto-main/src/main/java/io/prestosql/execution/buffer/LazyOutputBuffer.java +++ b/presto-main/src/main/java/io/prestosql/execution/buffer/LazyOutputBuffer.java @@ -21,11 +21,13 @@ import io.airlift.log.Logger; import io.airlift.units.DataSize; import io.hetu.core.transport.execution.buffer.PagesSerde; import io.hetu.core.transport.execution.buffer.SerializedPage; +import io.prestosql.Session; import io.prestosql.exchange.ExchangeManager; import io.prestosql.exchange.ExchangeManagerRegistry; import io.prestosql.exchange.ExchangeSink; import io.prestosql.exchange.ExchangeSinkInstanceHandle; import io.prestosql.exchange.FileSystemExchangeConfig.DirectSerialisationType; +import io.prestosql.exchange.RetryPolicy; import io.prestosql.execution.StateMachine.StateChangeListener; import io.prestosql.execution.TaskId; import io.prestosql.execution.buffer.OutputBuffers.OutputBufferId; @@ -38,14 +40,20 @@ import javax.annotation.concurrent.GuardedBy; import java.util.ArrayList; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicLong; import java.util.function.Supplier; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.util.concurrent.Futures.immediateFuture; +import static io.prestosql.SystemSessionProperties.getRetryPolicy; +import static io.prestosql.SystemSessionProperties.isTaskAsyncParallelWriteEnabled; import static io.prestosql.execution.buffer.BufferResult.emptyResults; import static io.prestosql.execution.buffer.BufferState.FAILED; import static io.prestosql.execution.buffer.BufferState.FINISHED; @@ -72,9 +80,23 @@ public class LazyOutputBuffer @GuardedBy("this") private final List pendingReads = new ArrayList<>(); + + @GuardedBy("this") + private Map partitionedTokenIds = new ConcurrentHashMap<>(); + + @GuardedBy("this") + private AtomicLong bufferSize = new AtomicLong(); + + @GuardedBy("this") + private Map> serializedPartitionedPages = new ConcurrentHashMap<>(); + private final ExchangeManagerRegistry exchangeManagerRegistry; private Optional exchangeSink; + private long tokenRemoved; + private boolean noMorePages; + private boolean isTaskAsyncParallelWriteEnabled; + private PagesSerde serde; private PagesSerde javaSerde; private PagesSerde kryoSerde; @@ -177,13 +199,14 @@ public class LazyOutputBuffer } switch (newOutputBuffers.getType()) { case PARTITIONED: - delegate = new PartitionedOutputBuffer(stateMachine, newOutputBuffers, maxBufferSize, systemMemoryContextSupplier, executor); + delegate = new PartitionedOutputBuffer(stateMachine, newOutputBuffers, maxBufferSize, systemMemoryContextSupplier, executor, isTaskAsyncParallelWriteEnabled); break; case BROADCAST: + isTaskAsyncParallelWriteEnabled = false; delegate = new BroadcastOutputBuffer(stateMachine, maxBufferSize, systemMemoryContextSupplier, executor); break; case ARBITRARY: - delegate = new ArbitraryOutputBuffer(stateMachine, maxBufferSize, systemMemoryContextSupplier, executor); + delegate = new ArbitraryOutputBuffer(stateMachine, maxBufferSize, systemMemoryContextSupplier, executor, isTaskAsyncParallelWriteEnabled); break; default: //TODO(Alex): decide the spool output buffer @@ -259,6 +282,26 @@ public class LazyOutputBuffer } outputBuffer = delegate; } + if (noMorePages && isTaskAsyncParallelWriteEnabled && hybridSpoolingDelegate != null && delegate != null && !(delegate instanceof BroadcastOutputBuffer)) { + if (serializedPartitionedPages.containsKey(bufferId.getId()) && serializedPartitionedPages.get(bufferId.getId()).size() == 0) { + serializedPartitionedPages.remove(bufferId.getId()); + } + if (serializedPartitionedPages.size() != 0) { + long tokenAcknowledged = hybridSpoolingDelegate.getWriteToken(bufferId.getId()) >= 0 ? hybridSpoolingDelegate.getWriteToken(bufferId.getId()) : tokenRemoved; + for (int count = 0; serializedPartitionedPages.containsKey(bufferId.getId()) && partitionedTokenIds.containsKey(bufferId.getId()) && count < (tokenAcknowledged - partitionedTokenIds.get(bufferId.getId())); count++) { + SerializedPage serializedPage = serializedPartitionedPages.get(bufferId.getId()).poll(); + bufferSize.getAndAdd(-serializedPage.getSizeInBytes()); + if (serializedPartitionedPages.get(bufferId.getId()).size() == 0) { + serializedPartitionedPages.remove(bufferId.getId()); + } + } + partitionedTokenIds.put(bufferId.getId(), tokenAcknowledged); + } + else { + hybridSpoolingDelegate.setNoMorePages(); + outputBuffer.setNoMorePages(); + } + } return outputBuffer.get(bufferId, token, maxSize); } @@ -270,6 +313,69 @@ public class LazyOutputBuffer checkState(delegate != null, "delegate is null"); outputBuffer = delegate; } + if (hybridSpoolingDelegate != null && isTaskAsyncParallelWriteEnabled && delegate != null && !(delegate instanceof BroadcastOutputBuffer)) { + boolean clientAcknowledged = outputBuffer.checkIfAcknowledged(bufferId.getId(), token); + if (outputBuffer instanceof ArbitraryOutputBuffer) { + if (!clientAcknowledged) { + long tokenAcknowledged = Math.min(outputBuffer.getTokenId(bufferId.getId(), token), hybridSpoolingDelegate.getWriteToken(0)); + for (int count = 0; count < (tokenAcknowledged - tokenRemoved); count++) { + SerializedPage serializedPage = serializedPartitionedPages.get(0).poll(); + bufferSize.getAndAdd(-serializedPage.getSizeInBytes()); + } + tokenRemoved = tokenAcknowledged; + } + if (noMorePages) { + if (serializedPartitionedPages.size() != 0) { + long tokenAcknowledged = hybridSpoolingDelegate.getWriteToken(0); + if (serializedPartitionedPages.containsKey(0) && serializedPartitionedPages.get(0).size() == 0) { + serializedPartitionedPages.remove(0); + } + for (int count = 0; count < (tokenAcknowledged - tokenRemoved) && serializedPartitionedPages.containsKey(bufferId.getId()); count++) { + SerializedPage serializedPage = serializedPartitionedPages.get(0).poll(); + bufferSize.getAndAdd(-serializedPage.getSizeInBytes()); + if (serializedPartitionedPages.get(0).size() == 0) { + serializedPartitionedPages.remove(0); + } + } + tokenRemoved = tokenAcknowledged; + } + else { + hybridSpoolingDelegate.setNoMorePages(); + outputBuffer.setNoMorePages(); + } + } + } + if (outputBuffer instanceof PartitionedOutputBuffer) { + if (!clientAcknowledged) { + long tokenAcknowledged = Math.min(outputBuffer.getTokenId(bufferId.getId(), token), hybridSpoolingDelegate.getWriteToken(bufferId.getId())); + for (int count = 0; partitionedTokenIds.containsKey(bufferId.getId()) && count < (tokenAcknowledged - partitionedTokenIds.get(bufferId.getId())); count++) { + SerializedPage serializedPage = serializedPartitionedPages.get(bufferId.getId()).poll(); + bufferSize.getAndAdd(-serializedPage.getSizeInBytes()); + } + partitionedTokenIds.put(bufferId.getId(), tokenAcknowledged); + } + if (noMorePages) { + if (serializedPartitionedPages .size() != 0) { + long tokenAcknowledged = hybridSpoolingDelegate.getWriteToken(bufferId.getId()); + if (serializedPartitionedPages.get(bufferId.getId()).size() == 0) { + serializedPartitionedPages.remove(bufferId.getId()); + } + for (int count = 0; partitionedTokenIds.containsKey(bufferId.getId()) && count < (tokenAcknowledged - partitionedTokenIds.get(bufferId.getId())); count++) { + SerializedPage serializedPage = serializedPartitionedPages.get(bufferId.getId()).poll(); + bufferSize.getAndAdd(-serializedPage.getSizeInBytes()); + if (serializedPartitionedPages.get(bufferId.getId()).size() == 0) { + serializedPartitionedPages.remove(bufferId.getId()); + } + } + partitionedTokenIds.put(bufferId.getId(), tokenAcknowledged); + } + } + else { + hybridSpoolingDelegate.setNoMorePages(); + outputBuffer.setNoMorePages(); + } + } + } outputBuffer.acknowledge(bufferId, token); } @@ -330,6 +436,14 @@ public class LazyOutputBuffer checkState(delegate != null, "Buffer has not been initialized"); outputBuffer = delegate; } + if (hybridSpoolingDelegate != null && isTaskAsyncParallelWriteEnabled && delegate != null && !(delegate instanceof BroadcastOutputBuffer)) { + if (!serializedPartitionedPages.containsKey(0)) { + serializedPartitionedPages.put(0, new ConcurrentLinkedQueue<>()); + } + serializedPartitionedPages.get(0).addAll(pages); + pages.stream().forEach(serializedPage -> bufferSize.addAndGet(serializedPage.getSizeInBytes())); + hybridSpoolingDelegate.enqueue(pages, origin); + } outputBuffer.enqueue(pages, origin); } @@ -341,6 +455,14 @@ public class LazyOutputBuffer checkState(delegate != null, "Buffer has not been initialized"); outputBuffer = delegate; } + if (hybridSpoolingDelegate != null && isTaskAsyncParallelWriteEnabled && delegate != null && !(delegate instanceof BroadcastOutputBuffer)) { + if (!serializedPartitionedPages.containsKey(partition)) { + serializedPartitionedPages.put(partition, new ConcurrentLinkedQueue<>()); + } + serializedPartitionedPages.get(partition).addAll(pages); + pages.stream().forEach(serializedPage -> bufferSize.addAndGet(serializedPage.getSizeInBytes())); + hybridSpoolingDelegate.enqueue(partition, pages, origin); + } outputBuffer.enqueue(partition, pages, origin); } @@ -352,10 +474,21 @@ public class LazyOutputBuffer checkState(delegate != null, "Buffer has not been initialized"); outputBuffer = delegate; } - if (hybridSpoolingDelegate != null) { - hybridSpoolingDelegate.setNoMorePages(); + if (hybridSpoolingDelegate != null && isTaskAsyncParallelWriteEnabled && delegate != null && !(delegate instanceof BroadcastOutputBuffer)) { + noMorePages = true; + } + else { + if (hybridSpoolingDelegate != null) { + hybridSpoolingDelegate.setNoMorePages(); + } + outputBuffer.setNoMorePages(); } - outputBuffer.setNoMorePages(); + } + + @Override + public void setSession(Session session) + { + isTaskAsyncParallelWriteEnabled = (getRetryPolicy(session) == RetryPolicy.TASK_ASYNC && isTaskAsyncParallelWriteEnabled(session)); } @Override @@ -531,6 +664,12 @@ public class LazyOutputBuffer return hybridSpoolingDelegate; } + @Override + public OutputBuffer getDelegate() + { + return delegate; + } + @Override public DirectSerialisationType getDelegateSpoolingExchangeDirectSerializationType() { diff --git a/presto-main/src/main/java/io/prestosql/execution/buffer/OutputBuffer.java b/presto-main/src/main/java/io/prestosql/execution/buffer/OutputBuffer.java index 773b27056..dcf4a6b16 100644 --- a/presto-main/src/main/java/io/prestosql/execution/buffer/OutputBuffer.java +++ b/presto-main/src/main/java/io/prestosql/execution/buffer/OutputBuffer.java @@ -17,6 +17,7 @@ import com.google.common.util.concurrent.ListenableFuture; import io.airlift.units.DataSize; import io.hetu.core.transport.execution.buffer.PagesSerde; import io.hetu.core.transport.execution.buffer.SerializedPage; +import io.prestosql.Session; import io.prestosql.exchange.FileSystemExchangeConfig.DirectSerialisationType; import io.prestosql.execution.StateMachine.StateChangeListener; import io.prestosql.execution.buffer.OutputBuffers.OutputBufferId; @@ -132,6 +133,20 @@ public interface OutputBuffer */ void enqueue(int partition, List pages, String origin); + /** + * Adds a split-up page to an unpartitioned buffer. If no-more-pages has been set, the enqueue + * page call is ignored. This can happen with limit queries. + */ + default void enqueue(List pages, String origin, boolean isTaskAsyncParallelWrite) + {} + + /** + * Adds a split-up page to a specific partition. If no-more-pages has been set, the enqueue + * page call is ignored. This can happen with limit queries. + */ + default void enqueue(int partition, List pages, String origin, boolean isTaskAsyncParallelWrite) + {} + /** * Notify buffer that no more pages will be added. Any future calls to enqueue a * page are ignored. @@ -187,6 +202,11 @@ public interface OutputBuffer return null; } + default OutputBuffer getDelegate() + { + return null; + } + default DirectSerialisationType getDelegateSpoolingExchangeDirectSerializationType() { return DirectSerialisationType.JAVA; @@ -203,4 +223,22 @@ public interface OutputBuffer default void setKryoSerde(PagesSerde pagesSerde) { } + + default boolean checkIfAcknowledged(int bufferId, long token) + { + return true; + } + + default long getTokenId(int bufferId, long token) + { + return -1; + } + + default long getWriteToken(int partition) + { + return -1; + } + + default void setSession(Session session) + {} } diff --git a/presto-main/src/main/java/io/prestosql/execution/buffer/PartitionedOutputBuffer.java b/presto-main/src/main/java/io/prestosql/execution/buffer/PartitionedOutputBuffer.java index f671b3ff9..899a198c5 100644 --- a/presto-main/src/main/java/io/prestosql/execution/buffer/PartitionedOutputBuffer.java +++ b/presto-main/src/main/java/io/prestosql/execution/buffer/PartitionedOutputBuffer.java @@ -33,7 +33,9 @@ import io.prestosql.spi.snapshot.RestorableConfig; import java.io.Serializable; import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.concurrent.Executor; @@ -72,12 +74,27 @@ public class PartitionedOutputBuffer private final AtomicLong totalPagesAdded = new AtomicLong(); private final AtomicLong totalRowsAdded = new AtomicLong(); + private final boolean isTaskAsyncParallelWrite; + private Map partitionTokenMap = new HashMap<>(); + private Map tokenPerPartition = new HashMap<>(); + public PartitionedOutputBuffer( OutputBufferStateMachine state, OutputBuffers outputBuffers, DataSize maxBufferSize, Supplier systemMemoryContextSupplier, Executor notificationExecutor) + { + this(state, outputBuffers, maxBufferSize, systemMemoryContextSupplier, notificationExecutor, false); + } + + public PartitionedOutputBuffer( + OutputBufferStateMachine state, + OutputBuffers outputBuffers, + DataSize maxBufferSize, + Supplier systemMemoryContextSupplier, + Executor notificationExecutor, + boolean isTaskAsyncParallelWrite) { this.stateMachine = requireNonNull(state, "state is null"); requireNonNull(outputBuffers, "outputBuffers is null"); @@ -94,8 +111,12 @@ public class PartitionedOutputBuffer for (OutputBufferId bufferId : outputBuffers.getBuffers().keySet()) { ClientBuffer partition = new ClientBuffer(bufferId); partitionsBuffer.add(partition); + if (isTaskAsyncParallelWrite) { + tokenPerPartition.put(bufferId.getId(), new AtomicLong()); + } } this.partitions = partitionsBuffer.build(); + this.isTaskAsyncParallelWrite = isTaskAsyncParallelWrite; state.compareAndSet(OPEN, NO_MORE_BUFFERS); state.compareAndSet(NO_MORE_PAGES, FLUSHING); @@ -179,6 +200,21 @@ public class PartitionedOutputBuffer return memoryManager.getBufferBlockedFuture(); } + @Override + public boolean checkIfAcknowledged(int bufferId, long token) + { + if (partitionTokenMap.containsKey(bufferId + "-" + token)) { + return false; + } + return true; + } + + @Override + public long getTokenId(int bufferId, long token) + { + return partitionTokenMap.get(bufferId + "-" + token); + } + @Override public void enqueue(List pages, String origin) { @@ -245,6 +281,12 @@ public class PartitionedOutputBuffer // add pages to the buffer (this will increase the reference count by one) partitions.get(partitionNumber).enqueuePages(serializedPageReferences); + if (isTaskAsyncParallelWrite) { + if (!partitionTokenMap.containsKey(partitionNumber + "-" + tokenPerPartition.get(partitionNumber).get())) { + partitionTokenMap.put(partitionNumber + "-" + tokenPerPartition.get(partitionNumber).get(), tokenPerPartition.get(partitionNumber).get()); + } + partitionTokenMap.put(partitionNumber + "-" + tokenPerPartition.get(partitionNumber), tokenPerPartition.get(partitionNumber).addAndGet(1)); + } // drop the initial reference serializedPageReferences.forEach(SerializedPageReference::dereferencePage); @@ -265,6 +307,9 @@ public class PartitionedOutputBuffer requireNonNull(outputBufferId, "bufferId is null"); partitions.get(outputBufferId.getId()).acknowledgePages(sequenceId); + if (isTaskAsyncParallelWrite) { + partitionTokenMap.remove(outputBufferId.getId() + "-" + sequenceId); + } } @Override diff --git a/presto-main/src/main/java/io/prestosql/operator/TaskOutputOperator.java b/presto-main/src/main/java/io/prestosql/operator/TaskOutputOperator.java index 737ee270f..60e59e209 100644 --- a/presto-main/src/main/java/io/prestosql/operator/TaskOutputOperator.java +++ b/presto-main/src/main/java/io/prestosql/operator/TaskOutputOperator.java @@ -17,6 +17,7 @@ import com.google.common.util.concurrent.ListenableFuture; import io.hetu.core.transport.execution.buffer.PagesSerde; import io.hetu.core.transport.execution.buffer.SerializedPage; import io.prestosql.exchange.FileSystemExchangeConfig.DirectSerialisationType; +import io.prestosql.execution.buffer.BroadcastOutputBuffer; import io.prestosql.execution.buffer.OutputBuffer; import io.prestosql.snapshot.SingleInputSnapshotState; import io.prestosql.spi.Page; @@ -33,6 +34,7 @@ import java.util.function.Function; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.prestosql.SystemSessionProperties.isTaskAsyncParallelWriteEnabled; import static io.prestosql.execution.buffer.PageSplitterUtil.splitPage; import static io.prestosql.spi.block.PageBuilderStatus.DEFAULT_MAX_PAGE_SIZE_IN_BYTES; import static java.util.Objects.requireNonNull; @@ -223,7 +225,7 @@ public class TaskOutputOperator } } else { - if (outputBuffer.isSpoolingDelegateAvailable()) { + if (outputBuffer.isSpoolingDelegateAvailable() && (!isTaskAsyncParallelWriteEnabled(operatorContext.getDriverContext().getPipelineContext().getSession()) || outputBuffer.getDelegate() instanceof BroadcastOutputBuffer)) { OutputBuffer spoolingBuffer = outputBuffer.getSpoolingDelegate(); if (spoolingSerialisationType != DirectSerialisationType.OFF) { PagesSerde directSerde = (spoolingSerialisationType == DirectSerialisationType.JAVA) ? operatorContext.getDriverContext().getJavaSerde() : operatorContext.getDriverContext().getKryoSerde(); @@ -234,7 +236,7 @@ public class TaskOutputOperator } else { if (spoolingBuffer != null) { - spoolingBuffer.enqueue(serializedPages, id); + spoolingBuffer.enqueue(serializedPages, id, false); } } } diff --git a/presto-main/src/main/java/io/prestosql/operator/output/PagePartitioner.java b/presto-main/src/main/java/io/prestosql/operator/output/PagePartitioner.java index 2debb2841..295640f9d 100644 --- a/presto-main/src/main/java/io/prestosql/operator/output/PagePartitioner.java +++ b/presto-main/src/main/java/io/prestosql/operator/output/PagePartitioner.java @@ -18,6 +18,7 @@ import io.airlift.units.DataSize; import io.hetu.core.transport.execution.buffer.PagesSerde; import io.hetu.core.transport.execution.buffer.SerializedPage; import io.prestosql.exchange.FileSystemExchangeConfig; +import io.prestosql.execution.buffer.BroadcastOutputBuffer; import io.prestosql.execution.buffer.OutputBuffer; import io.prestosql.operator.OperatorContext; import io.prestosql.operator.PartitionFunction; @@ -45,6 +46,7 @@ import java.util.function.IntUnaryOperator; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.prestosql.SystemSessionProperties.isTaskAsyncParallelWriteEnabled; import static io.prestosql.execution.buffer.PageSplitterUtil.splitPage; import static io.prestosql.spi.block.PageBuilderStatus.DEFAULT_MAX_PAGE_SIZE_IN_BYTES; import static java.lang.Math.max; @@ -470,7 +472,7 @@ public class PagePartitioner .map(page -> operatorContext.getDriverContext().getSerde().serialize(page)) .collect(toImmutableList()); - if (outputBuffer.isSpoolingDelegateAvailable()) { + if (outputBuffer.isSpoolingDelegateAvailable() && (!isTaskAsyncParallelWriteEnabled(operatorContext.getDriverContext().getPipelineContext().getSession()) || outputBuffer.getDelegate() instanceof BroadcastOutputBuffer)) { OutputBuffer spoolingBuffer = outputBuffer.getSpoolingDelegate(); if (spoolingSerialisationType != FileSystemExchangeConfig.DirectSerialisationType.OFF) { PagesSerde directSerde = (spoolingSerialisationType == FileSystemExchangeConfig.DirectSerialisationType.JAVA) ? operatorContext.getDriverContext().getJavaSerde() : operatorContext.getDriverContext().getKryoSerde(); @@ -481,7 +483,7 @@ public class PagePartitioner } else { if (spoolingBuffer != null) { - spoolingBuffer.enqueue(partition, serializedPages, id); + spoolingBuffer.enqueue(partition, serializedPages, id, false); } } } -- Gitee From e55abcc2c0a7bb5261db87d7919d2bc4115c16c8 Mon Sep 17 00:00:00 2001 From: Surya Sumanth N Date: Thu, 13 Apr 2023 15:14:17 +0530 Subject: [PATCH 4/4] Fix Parallel Write Hang Issue and Buffer Pool Issue --- .../execution/buffer/LazyOutputBuffer.java | 22 +++++++++++-------- .../operator/TaskOutputOperator.java | 4 +++- .../operator/output/PagePartitioner.java | 4 +++- 3 files changed, 19 insertions(+), 11 deletions(-) diff --git a/presto-main/src/main/java/io/prestosql/execution/buffer/LazyOutputBuffer.java b/presto-main/src/main/java/io/prestosql/execution/buffer/LazyOutputBuffer.java index 98d435826..2f54076e7 100644 --- a/presto-main/src/main/java/io/prestosql/execution/buffer/LazyOutputBuffer.java +++ b/presto-main/src/main/java/io/prestosql/execution/buffer/LazyOutputBuffer.java @@ -313,7 +313,8 @@ public class LazyOutputBuffer checkState(delegate != null, "delegate is null"); outputBuffer = delegate; } - if (hybridSpoolingDelegate != null && isTaskAsyncParallelWriteEnabled && delegate != null && !(delegate instanceof BroadcastOutputBuffer)) { + if (hybridSpoolingDelegate != null && isTaskAsyncParallelWriteEnabled && delegate != null && !(delegate instanceof BroadcastOutputBuffer) + && hybridSpoolingDelegate.getExchangeDirectSerialisationType() == DirectSerialisationType.OFF) { boolean clientAcknowledged = outputBuffer.checkIfAcknowledged(bufferId.getId(), token); if (outputBuffer instanceof ArbitraryOutputBuffer) { if (!clientAcknowledged) { @@ -355,7 +356,7 @@ public class LazyOutputBuffer partitionedTokenIds.put(bufferId.getId(), tokenAcknowledged); } if (noMorePages) { - if (serializedPartitionedPages .size() != 0) { + if (serializedPartitionedPages.size() != 0) { long tokenAcknowledged = hybridSpoolingDelegate.getWriteToken(bufferId.getId()); if (serializedPartitionedPages.get(bufferId.getId()).size() == 0) { serializedPartitionedPages.remove(bufferId.getId()); @@ -369,10 +370,10 @@ public class LazyOutputBuffer } partitionedTokenIds.put(bufferId.getId(), tokenAcknowledged); } - } - else { - hybridSpoolingDelegate.setNoMorePages(); - outputBuffer.setNoMorePages(); + else { + hybridSpoolingDelegate.setNoMorePages(); + outputBuffer.setNoMorePages(); + } } } } @@ -436,7 +437,8 @@ public class LazyOutputBuffer checkState(delegate != null, "Buffer has not been initialized"); outputBuffer = delegate; } - if (hybridSpoolingDelegate != null && isTaskAsyncParallelWriteEnabled && delegate != null && !(delegate instanceof BroadcastOutputBuffer)) { + if (hybridSpoolingDelegate != null && isTaskAsyncParallelWriteEnabled && delegate != null && !(delegate instanceof BroadcastOutputBuffer) + && hybridSpoolingDelegate.getExchangeDirectSerialisationType() == DirectSerialisationType.OFF) { if (!serializedPartitionedPages.containsKey(0)) { serializedPartitionedPages.put(0, new ConcurrentLinkedQueue<>()); } @@ -455,7 +457,8 @@ public class LazyOutputBuffer checkState(delegate != null, "Buffer has not been initialized"); outputBuffer = delegate; } - if (hybridSpoolingDelegate != null && isTaskAsyncParallelWriteEnabled && delegate != null && !(delegate instanceof BroadcastOutputBuffer)) { + if (hybridSpoolingDelegate != null && isTaskAsyncParallelWriteEnabled && delegate != null && !(delegate instanceof BroadcastOutputBuffer) + && hybridSpoolingDelegate.getExchangeDirectSerialisationType() == DirectSerialisationType.OFF) { if (!serializedPartitionedPages.containsKey(partition)) { serializedPartitionedPages.put(partition, new ConcurrentLinkedQueue<>()); } @@ -474,7 +477,8 @@ public class LazyOutputBuffer checkState(delegate != null, "Buffer has not been initialized"); outputBuffer = delegate; } - if (hybridSpoolingDelegate != null && isTaskAsyncParallelWriteEnabled && delegate != null && !(delegate instanceof BroadcastOutputBuffer)) { + if (hybridSpoolingDelegate != null && isTaskAsyncParallelWriteEnabled && delegate != null && !(delegate instanceof BroadcastOutputBuffer) + && hybridSpoolingDelegate.getExchangeDirectSerialisationType() == DirectSerialisationType.OFF) { noMorePages = true; } else { diff --git a/presto-main/src/main/java/io/prestosql/operator/TaskOutputOperator.java b/presto-main/src/main/java/io/prestosql/operator/TaskOutputOperator.java index 60e59e209..2f2561328 100644 --- a/presto-main/src/main/java/io/prestosql/operator/TaskOutputOperator.java +++ b/presto-main/src/main/java/io/prestosql/operator/TaskOutputOperator.java @@ -225,7 +225,9 @@ public class TaskOutputOperator } } else { - if (outputBuffer.isSpoolingDelegateAvailable() && (!isTaskAsyncParallelWriteEnabled(operatorContext.getDriverContext().getPipelineContext().getSession()) || outputBuffer.getDelegate() instanceof BroadcastOutputBuffer)) { + if (outputBuffer.isSpoolingDelegateAvailable() && ((!isTaskAsyncParallelWriteEnabled(operatorContext.getDriverContext().getPipelineContext().getSession()) + || outputBuffer.getSpoolingDelegate().getExchangeDirectSerialisationType() != DirectSerialisationType.OFF) + || outputBuffer.getDelegate() instanceof BroadcastOutputBuffer)) { OutputBuffer spoolingBuffer = outputBuffer.getSpoolingDelegate(); if (spoolingSerialisationType != DirectSerialisationType.OFF) { PagesSerde directSerde = (spoolingSerialisationType == DirectSerialisationType.JAVA) ? operatorContext.getDriverContext().getJavaSerde() : operatorContext.getDriverContext().getKryoSerde(); diff --git a/presto-main/src/main/java/io/prestosql/operator/output/PagePartitioner.java b/presto-main/src/main/java/io/prestosql/operator/output/PagePartitioner.java index 295640f9d..dc56dd549 100644 --- a/presto-main/src/main/java/io/prestosql/operator/output/PagePartitioner.java +++ b/presto-main/src/main/java/io/prestosql/operator/output/PagePartitioner.java @@ -472,7 +472,9 @@ public class PagePartitioner .map(page -> operatorContext.getDriverContext().getSerde().serialize(page)) .collect(toImmutableList()); - if (outputBuffer.isSpoolingDelegateAvailable() && (!isTaskAsyncParallelWriteEnabled(operatorContext.getDriverContext().getPipelineContext().getSession()) || outputBuffer.getDelegate() instanceof BroadcastOutputBuffer)) { + if (outputBuffer.isSpoolingDelegateAvailable() && ((!isTaskAsyncParallelWriteEnabled(operatorContext.getDriverContext().getPipelineContext().getSession()) + || outputBuffer.getSpoolingDelegate().getExchangeDirectSerialisationType() != FileSystemExchangeConfig.DirectSerialisationType.OFF) + || outputBuffer.getDelegate() instanceof BroadcastOutputBuffer)) { OutputBuffer spoolingBuffer = outputBuffer.getSpoolingDelegate(); if (spoolingSerialisationType != FileSystemExchangeConfig.DirectSerialisationType.OFF) { PagesSerde directSerde = (spoolingSerialisationType == FileSystemExchangeConfig.DirectSerialisationType.JAVA) ? operatorContext.getDriverContext().getJavaSerde() : operatorContext.getDriverContext().getKryoSerde(); -- Gitee