diff --git a/hetu-filesystem-client/src/main/java/io/hetu/core/filesystem/HetuHdfsFileSystemClient.java b/hetu-filesystem-client/src/main/java/io/hetu/core/filesystem/HetuHdfsFileSystemClient.java index 61bacc659902413cd552e40d720558021fdb3e1c..69328b20614fd3176fd11f4dfc9adb3265adb1f0 100644 --- a/hetu-filesystem-client/src/main/java/io/hetu/core/filesystem/HetuHdfsFileSystemClient.java +++ b/hetu-filesystem-client/src/main/java/io/hetu/core/filesystem/HetuHdfsFileSystemClient.java @@ -18,9 +18,11 @@ import com.google.common.base.Throwables; import io.airlift.log.Logger; import io.prestosql.spi.filesystem.SupportedFileAttributes; import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FSDataOutputStream; import org.apache.hadoop.fs.FileStatus; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.PathIsNotEmptyDirectoryException; +import org.apache.hadoop.hdfs.client.HdfsDataOutputStream; import org.apache.hadoop.hdfs.protocol.AlreadyBeingCreatedException; import org.apache.hadoop.ipc.RemoteException; @@ -35,6 +37,7 @@ import java.nio.file.NoSuchFileException; import java.nio.file.OpenOption; import java.nio.file.Path; import java.nio.file.Paths; +import java.util.EnumSet; import java.util.UUID; import java.util.stream.Stream; @@ -267,6 +270,16 @@ public class HetuHdfsFileSystemClient } } + @Override + public void flush(OutputStream outputStream) + throws IOException + { + if (outputStream instanceof HdfsDataOutputStream) { + ((HdfsDataOutputStream) outputStream).hsync(EnumSet.of(HdfsDataOutputStream.SyncFlag.UPDATE_LENGTH)); + ((FSDataOutputStream) outputStream).hflush(); + } + } + @Override public void close() throws IOException diff --git a/hetu-filesystem-client/src/main/java/io/hetu/core/filesystem/HetuLocalFileSystemClient.java b/hetu-filesystem-client/src/main/java/io/hetu/core/filesystem/HetuLocalFileSystemClient.java index f9e0c9eb85a7ba868c9ceeb3954cd7d97d2cf5b4..c4a52c1583e6cef353b09de8280ad1016d24f6f5 100644 --- a/hetu-filesystem-client/src/main/java/io/hetu/core/filesystem/HetuLocalFileSystemClient.java +++ b/hetu-filesystem-client/src/main/java/io/hetu/core/filesystem/HetuLocalFileSystemClient.java @@ -249,4 +249,11 @@ public class HetuLocalFileSystemClient String glob = prefix + "*" + suffix; return StreamSupport.stream(newDirectoryStream(path, glob).spliterator(), false); } + + @Override + public void flush(OutputStream outputStream) + throws IOException + { + outputStream.flush(); + } } diff --git a/presto-main/src/main/java/io/prestosql/SystemSessionProperties.java b/presto-main/src/main/java/io/prestosql/SystemSessionProperties.java index 41d990d2f80dbb95ff7098eb7c9aee3fec60f98f..3733bfab496b250501a86b3a1f7ff17551ea3fc4 100644 --- a/presto-main/src/main/java/io/prestosql/SystemSessionProperties.java +++ b/presto-main/src/main/java/io/prestosql/SystemSessionProperties.java @@ -210,7 +210,7 @@ public final class SystemSessionProperties public static final String ADAPTIVE_PARTIAL_AGGREGATION_UNIQUE_ROWS_RATIO_THRESHOLD = "adaptive_partial_aggregation_unique_rows_ratio_threshold"; public static final String RETRY_POLICY = "retry_policy"; - + public static final String TASK_SNAPSHOT_TIME_INTERVAL = "task_snapshot_time_interval"; public static final String EXCHANGE_FILESYSTEM_BASE_DIRECTORY = "exchange_filesystem_base_directory"; public static final String QUERY_RESOURCE_TRACKING = "query_resource_tracking_enabled"; @@ -940,6 +940,11 @@ public final class SystemSessionProperties RetryPolicy.class, queryManagerConfig.getRetryPolicy(), true), + 
durationProperty( + TASK_SNAPSHOT_TIME_INTERVAL, + "Task snapshot time interval", + queryManagerConfig.getTaskSnapshotTimeInterval(), + false), integerProperty( TASK_RETRY_ATTEMPTS_OVERALL, "Maximum number of task retry attempts overall", @@ -1773,6 +1778,11 @@ public final class SystemSessionProperties return retryPolicy; } + public static Duration getTaskSnapshotTimeInterval(Session session) + { + return session.getSystemProperty(TASK_SNAPSHOT_TIME_INTERVAL, Duration.class); + } + public static int getTaskRetryAttemptsOverall(Session session) { return session.getSystemProperty(TASK_RETRY_ATTEMPTS_OVERALL, Integer.class); diff --git a/presto-main/src/main/java/io/prestosql/exchange/ExchangeSink.java b/presto-main/src/main/java/io/prestosql/exchange/ExchangeSink.java index eb174df1adab07ffc5da2142576057e4264f872e..f8310479f7f0aa3a1f7fac885ecdf80d5986212d 100644 --- a/presto-main/src/main/java/io/prestosql/exchange/ExchangeSink.java +++ b/presto-main/src/main/java/io/prestosql/exchange/ExchangeSink.java @@ -16,8 +16,15 @@ package io.prestosql.exchange; import io.airlift.slice.Slice; import io.hetu.core.transport.execution.buffer.PagesSerde; import io.prestosql.exchange.FileSystemExchangeConfig.DirectSerialisationType; +import io.prestosql.exchange.storage.FileSystemExchangeStorage; +import io.prestosql.execution.MarkerDataFileFactory; import io.prestosql.spi.Page; +import javax.crypto.SecretKey; + +import java.net.URI; +import java.util.List; +import java.util.Optional; import java.util.concurrent.CompletableFuture; public interface ExchangeSink @@ -40,4 +47,18 @@ public interface ExchangeSink { return DirectSerialisationType.OFF; } + + FileSystemExchangeStorage getExchangeStorage(); + + URI getOutputDirectory(); + + Optional getSecretKey(); + + boolean isExchangeCompressionEnabled(); + + int getPartitionId(); + + List getSinkFiles(); + + void enqueueMarkerInfo(MarkerDataFileFactory.MarkerDataFileFooterInfo markerDataFileFooterInfo); } diff --git a/presto-main/src/main/java/io/prestosql/exchange/FileSystemExchangeManager.java b/presto-main/src/main/java/io/prestosql/exchange/FileSystemExchangeManager.java index c7fa8b459dd08137bf08abbeb87d57712110e0be..ec4424b4bd91ab2a3806d852a25fc9d926dfb492 100644 --- a/presto-main/src/main/java/io/prestosql/exchange/FileSystemExchangeManager.java +++ b/presto-main/src/main/java/io/prestosql/exchange/FileSystemExchangeManager.java @@ -137,7 +137,8 @@ public class FileSystemExchangeManager exchangeSinkBuffersPerPartition, exchangeSinkMaxFileSizeInBytes, directSerialisationType, - directSerialisationBufferSize); + directSerialisationBufferSize, + instanceHandle.getPartiitonId()); } @Override @@ -157,7 +158,8 @@ public class FileSystemExchangeManager exchangeSinkBuffersPerPartition, exchangeSinkMaxFileSizeInBytes, serType, - directSerialisationBufferSize); + directSerialisationBufferSize, + instanceHandle.getPartiitonId()); } @Override diff --git a/presto-main/src/main/java/io/prestosql/exchange/FileSystemExchangeSink.java b/presto-main/src/main/java/io/prestosql/exchange/FileSystemExchangeSink.java index d985a4d4624f6b9082800177a5660e1c7c09284f..71e9f26480efb0e08a20afb5a30c8b10588e1a12 100644 --- a/presto-main/src/main/java/io/prestosql/exchange/FileSystemExchangeSink.java +++ b/presto-main/src/main/java/io/prestosql/exchange/FileSystemExchangeSink.java @@ -25,6 +25,7 @@ import io.hetu.core.transport.execution.buffer.PagesSerde; import io.prestosql.exchange.FileSystemExchangeConfig.DirectSerialisationType; import 
io.prestosql.exchange.storage.ExchangeStorageWriter; import io.prestosql.exchange.storage.FileSystemExchangeStorage; +import io.prestosql.execution.MarkerDataFileFactory; import io.prestosql.spi.Page; import io.prestosql.spi.PrestoException; import io.prestosql.spi.util.SizeOf; @@ -46,6 +47,7 @@ import java.util.Queue; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; @@ -91,6 +93,7 @@ public class FileSystemExchangeSink private boolean closed; private final DirectSerialisationType directSerialisationType; private final int directSerialisationBufferSize; + private final int partitionId; public FileSystemExchangeSink( FileSystemExchangeStorage exchangeStorage, @@ -105,7 +108,8 @@ public class FileSystemExchangeSink int exchangeSinkBuffersPerPartition, long maxFileSizeInBytes, DirectSerialisationType directSerialisationType, - int directSerialisationBufferSize) + int directSerialisationBufferSize, + int partitionId) { checkArgument(maxPageStorageSizeInBytes <= maxFileSizeInBytes, format("maxPageStorageSizeInBytes %s exceeded maxFileSizeInBytes %s", succinctBytes(maxPageStorageSizeInBytes), succinctBytes(maxFileSizeInBytes))); @@ -129,6 +133,7 @@ public class FileSystemExchangeSink else { this.bufferPool = null; } + this.partitionId = partitionId; } @Override @@ -142,6 +147,48 @@ public class FileSystemExchangeSink return directSerialisationType; } + @Override + public URI getOutputDirectory() + { + return outputDirectory; + } + + @Override + public Optional getSecretKey() + { + return secretKey; + } + + @Override + public boolean isExchangeCompressionEnabled() + { + return exchangeCompressionEnabled; + } + + @Override + public FileSystemExchangeStorage getExchangeStorage() + { + return exchangeStorage; + } + + @Override + public int getPartitionId() + { + return partitionId; + } + + @Override + public List getSinkFiles() + { + return writerMap.values().stream().flatMap(bufferedStorageWriter -> bufferedStorageWriter.getSpoolFiles().stream()).collect(Collectors.toList()); + } + + @Override + public void enqueueMarkerInfo(MarkerDataFileFactory.MarkerDataFileFooterInfo markerDataFileFooterInfo) + { + writerMap.values().stream().forEach(bufferedStorageWriter -> bufferedStorageWriter.enqueueMarkerMetadata(markerDataFileFooterInfo)); + } + @Override public void add(int partitionId, Slice data) { @@ -428,7 +475,7 @@ public class FileSystemExchangeSink currentBuffer.writeBytes(slice.getBytes(position, writableBytes)); position += writableBytes; - flushIfNeeded(false); + flushIfNeeded(true); } } @@ -476,6 +523,16 @@ public class FileSystemExchangeSink addExceptionCallback(writeFuture, throwable -> failure.compareAndSet(null, throwable)); } } + + public List getSpoolFiles() + { + return writers.stream().map(writer -> writer.getFile()).collect(Collectors.toList()); + } + + public void enqueueMarkerMetadata(MarkerDataFileFactory.MarkerDataFileFooterInfo markerDataFileFooterInfo) + { + writers.stream().forEach(exchangeStorageWriter -> exchangeStorageWriter.writeMarkerMetadata(markerDataFileFooterInfo)); + } } @ThreadSafe diff --git a/presto-main/src/main/java/io/prestosql/exchange/FileSystemExchangeSinkInstanceHandle.java b/presto-main/src/main/java/io/prestosql/exchange/FileSystemExchangeSinkInstanceHandle.java index 
9e86443ac88ff5721cb98aaba637f6794a1eada8..74c11ae6b8385e081d1a12b12295ee0accdb2113 100644 --- a/presto-main/src/main/java/io/prestosql/exchange/FileSystemExchangeSinkInstanceHandle.java +++ b/presto-main/src/main/java/io/prestosql/exchange/FileSystemExchangeSinkInstanceHandle.java @@ -55,4 +55,10 @@ public class FileSystemExchangeSinkInstanceHandle { return outputPartitionCount; } + + @JsonProperty + public int getPartiitonId() + { + return sinkHandle.getPartitionId(); + } } diff --git a/presto-main/src/main/java/io/prestosql/exchange/RetryPolicy.java b/presto-main/src/main/java/io/prestosql/exchange/RetryPolicy.java index d03e98e941010e25b52b6723d88d94f7b1f6faa1..803ee70c5e7f450dc09099250ec6b01eb308b718 100644 --- a/presto-main/src/main/java/io/prestosql/exchange/RetryPolicy.java +++ b/presto-main/src/main/java/io/prestosql/exchange/RetryPolicy.java @@ -18,7 +18,9 @@ import io.prestosql.spi.connector.RetryMode; public enum RetryPolicy { TASK(RetryMode.RETRIES_ENABLED), + TASK_WITH_SNAPSHOT(RetryMode.RETRIES_ENABLED), NONE(RetryMode.NO_RETRIES), + TASK_ASYNC(RetryMode.RETRIES_ENABLED), /**/; private final RetryMode retryMode; diff --git a/presto-main/src/main/java/io/prestosql/exchange/storage/ExchangeStorageWriter.java b/presto-main/src/main/java/io/prestosql/exchange/storage/ExchangeStorageWriter.java index 1c3efd0910a0471f422fcf1782c0fbb9c554951f..eb8f5a63a396cdc46b86efe732aafe09511a8bf6 100644 --- a/presto-main/src/main/java/io/prestosql/exchange/storage/ExchangeStorageWriter.java +++ b/presto-main/src/main/java/io/prestosql/exchange/storage/ExchangeStorageWriter.java @@ -16,8 +16,11 @@ package io.prestosql.exchange.storage; import com.google.common.util.concurrent.ListenableFuture; import io.airlift.slice.Slice; import io.hetu.core.transport.execution.buffer.PagesSerde; +import io.prestosql.execution.MarkerDataFileFactory; import io.prestosql.spi.Page; +import java.net.URI; + public interface ExchangeStorageWriter { ListenableFuture write(Slice slice); @@ -29,4 +32,8 @@ public interface ExchangeStorageWriter ListenableFuture abort(); long getRetainedSize(); + + URI getFile(); + + void writeMarkerMetadata(MarkerDataFileFactory.MarkerDataFileFooterInfo markerDataFileFooterInfo); } diff --git a/presto-main/src/main/java/io/prestosql/exchange/storage/FileSystemExchangeStorage.java b/presto-main/src/main/java/io/prestosql/exchange/storage/FileSystemExchangeStorage.java index 69c3ff968c7ccd62f88a942792f73bed56e37519..e8812671c92a7c7b204d7944e11df9453b71e4d3 100644 --- a/presto-main/src/main/java/io/prestosql/exchange/storage/FileSystemExchangeStorage.java +++ b/presto-main/src/main/java/io/prestosql/exchange/storage/FileSystemExchangeStorage.java @@ -32,6 +32,8 @@ public interface FileSystemExchangeStorage { public void setFileSystemClient(HetuFileSystemClient fsClient); + HetuFileSystemClient getFileSystemClient(); + void createDirectories(URI dir) throws IOException; ExchangeStorageReader createExchangeReader(Queue sourceFiles, int maxPageSize, DirectSerialisationType directSerialisationType, int directSerialisationBufferSize); diff --git a/presto-main/src/main/java/io/prestosql/exchange/storage/HetuFileSystemExchangeStorage.java b/presto-main/src/main/java/io/prestosql/exchange/storage/HetuFileSystemExchangeStorage.java index 5c1bce39a1f49d49f3a8018765e3ad6be959a900..e8d56dcc9ca2b3358ab7f39253b966e4378eacbc 100644 --- a/presto-main/src/main/java/io/prestosql/exchange/storage/HetuFileSystemExchangeStorage.java +++ 
b/presto-main/src/main/java/io/prestosql/exchange/storage/HetuFileSystemExchangeStorage.java @@ -59,6 +59,12 @@ public class HetuFileSystemExchangeStorage fileSystemClient = fsClient; } + @Override + public HetuFileSystemClient getFileSystemClient() + { + return fileSystemClient; + } + @Override public void createDirectories(URI dir) throws IOException { diff --git a/presto-main/src/main/java/io/prestosql/exchange/storage/HetuFileSystemExchangeWriter.java b/presto-main/src/main/java/io/prestosql/exchange/storage/HetuFileSystemExchangeWriter.java index 79e8b0902adf7543e0febd14b63603a513ba9b78..b68700d21c280985a9d5a4d18a23d762e5e0f268 100644 --- a/presto-main/src/main/java/io/prestosql/exchange/storage/HetuFileSystemExchangeWriter.java +++ b/presto-main/src/main/java/io/prestosql/exchange/storage/HetuFileSystemExchangeWriter.java @@ -21,6 +21,8 @@ import io.airlift.slice.Slice; import io.hetu.core.transport.execution.buffer.PagesSerde; import io.prestosql.exchange.FileSystemExchangeConfig; import io.prestosql.exchange.FileSystemExchangeConfig.DirectSerialisationType; +import io.prestosql.execution.MarkerDataFileFactory; +import io.prestosql.snapshot.RecoveryUtils; import io.prestosql.spi.Page; import io.prestosql.spi.PrestoException; import io.prestosql.spi.filesystem.HetuFileSystemClient; @@ -55,35 +57,41 @@ public class HetuFileSystemExchangeWriter private final OutputStream outputStream; private final DirectSerialisationType directSerialisationType; private final int directSerialisationBufferSize; + private final HetuFileSystemClient fileSystemClient; + private final OutputStream delegateOutputStream; + private final URI file; public HetuFileSystemExchangeWriter(URI file, HetuFileSystemClient fileSystemClient, Optional secretKey, boolean exchangeCompressionEnabled, AlgorithmParameterSpec algorithmParameterSpec, FileSystemExchangeConfig.DirectSerialisationType directSerialisationType, int directSerialisationBufferSize) { this.directSerialisationBufferSize = directSerialisationBufferSize; this.directSerialisationType = directSerialisationType; + this.fileSystemClient = fileSystemClient; + this.file = file; try { Path path = Paths.get(file.toString()); + this.delegateOutputStream = fileSystemClient.newOutputStream(path); if (secretKey.isPresent() && exchangeCompressionEnabled) { Cipher cipher = Cipher.getInstance(CIPHER_TRANSFORMATION); cipher.init(Cipher.ENCRYPT_MODE, secretKey.get(), algorithmParameterSpec); - this.outputStream = new SnappyFramedOutputStream(new CipherOutputStream(fileSystemClient.newOutputStream(path), cipher)); + this.outputStream = new SnappyFramedOutputStream(new CipherOutputStream(delegateOutputStream, cipher)); } else if (secretKey.isPresent()) { Cipher cipher = Cipher.getInstance(CIPHER_TRANSFORMATION); cipher.init(Cipher.ENCRYPT_MODE, secretKey.get(), algorithmParameterSpec); - this.outputStream = new CipherOutputStream(fileSystemClient.newOutputStream(path), cipher); + this.outputStream = new CipherOutputStream(delegateOutputStream, cipher); } else if (exchangeCompressionEnabled) { - this.outputStream = new SnappyFramedOutputStream(new OutputStreamSliceOutput(fileSystemClient.newOutputStream(path), directSerialisationBufferSize)); + this.outputStream = new SnappyFramedOutputStream(new OutputStreamSliceOutput(delegateOutputStream, directSerialisationBufferSize)); } else { if (directSerialisationType == DirectSerialisationType.KRYO) { - this.outputStream = new Output(fileSystemClient.newOutputStream(path), directSerialisationBufferSize); + this.outputStream = new 
Output(delegateOutputStream, directSerialisationBufferSize); } else if (directSerialisationType == DirectSerialisationType.JAVA) { - this.outputStream = new OutputStreamSliceOutput(fileSystemClient.newOutputStream(path), directSerialisationBufferSize); + this.outputStream = new OutputStreamSliceOutput(delegateOutputStream, directSerialisationBufferSize); } else { - this.outputStream = new OutputStreamSliceOutput(fileSystemClient.newOutputStream(path), directSerialisationBufferSize); + this.outputStream = new OutputStreamSliceOutput(delegateOutputStream, directSerialisationBufferSize); } } } @@ -98,6 +106,8 @@ public class HetuFileSystemExchangeWriter { try { outputStream.write(slice.getBytes()); + outputStream.flush(); + fileSystemClient.flush(delegateOutputStream); } catch (IOException | RuntimeException e) { return immediateFailedFuture(e); @@ -110,6 +120,13 @@ public class HetuFileSystemExchangeWriter { checkState(directSerialisationType != DirectSerialisationType.OFF, "Should be used with direct serialization is enabled!"); serde.serialize(outputStream, page); + try { + outputStream.flush(); + fileSystemClient.flush(delegateOutputStream); + } + catch (IOException | RuntimeException e) { + return immediateFailedFuture(e); + } return immediateFuture(null); } @@ -142,4 +159,23 @@ public class HetuFileSystemExchangeWriter { return INSTANCE_SIZE; } + + @Override + public URI getFile() + { + return file; + } + + @Override + public void writeMarkerMetadata(MarkerDataFileFactory.MarkerDataFileFooterInfo markerDataFileFooterInfo) + { + try { + RecoveryUtils.serializeState(markerDataFileFooterInfo, outputStream, false); + outputStream.flush(); + fileSystemClient.flush(delegateOutputStream); + } + catch (Exception e) { + throw new PrestoException(GENERIC_INTERNAL_ERROR, "Failed to write marker data"); + } + } } diff --git a/presto-main/src/main/java/io/prestosql/execution/MarkerDataFileFactory.java b/presto-main/src/main/java/io/prestosql/execution/MarkerDataFileFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..351f758e7865fb26e8c229621d01e0da99fe6f36 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/execution/MarkerDataFileFactory.java @@ -0,0 +1,277 @@ +/* + * Copyright (C) 2018-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.prestosql.execution; + +import io.prestosql.snapshot.RecoveryUtils; +import io.prestosql.spi.PrestoException; +import io.prestosql.spi.filesystem.HetuFileSystemClient; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.net.URI; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.HashMap; +import java.util.Map; + +import static com.google.common.base.Preconditions.checkArgument; +import static io.prestosql.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; +import static io.prestosql.spi.filesystem.SupportedFileAttributes.SIZE; +import static java.util.Objects.requireNonNull; + +public class MarkerDataFileFactory +{ + private final URI targetOutputDirectory; + private final HetuFileSystemClient hetuFileSystemClient; + + public MarkerDataFileFactory(URI taskOutputDirectory, HetuFileSystemClient hetuFileSystemClient) + { + this.targetOutputDirectory = taskOutputDirectory; + this.hetuFileSystemClient = requireNonNull(hetuFileSystemClient, "hetuFileSystemClient is null"); + } + + public MarkerDataFileWriter createWriter(String file, boolean useKryo) + { + URI fileURI = targetOutputDirectory.resolve(file); + return new MarkerDataFileWriter(hetuFileSystemClient, fileURI, useKryo); + } + + public MarkerDataFileReader createReader(URI file, boolean useKryo) + { + return new MarkerDataFileReader(hetuFileSystemClient, file, useKryo); + } + + public static class MarkerDataFileWriter + { + private final OutputStream outputStream; + private final boolean useKryo; + private HetuFileSystemClient hetuFileSystemClient; + private Path path; + private URI fileURI; + + public MarkerDataFileWriter(HetuFileSystemClient hetuFileSystemClient, URI fileURI, boolean useKryo) + { + this.hetuFileSystemClient = hetuFileSystemClient; + this.path = Paths.get(fileURI.getPath()); + this.fileURI = fileURI; + try { + outputStream = hetuFileSystemClient.newOutputStream(path); + } + catch (IOException e) { + throw new PrestoException(GENERIC_INTERNAL_ERROR, "Failed to create MarkerDataFileWriter"); + } + this.useKryo = useKryo; + } + + public MarkerDataFileFooterInfo writeDataFile(long markerId, Map state, long previousFooterOffset, long previousFooterSize) + { + Map operatorStateInfoMap = new HashMap<>(); + try { + for (Map.Entry entry : state.entrySet()) { + long offset = (Long) hetuFileSystemClient.getAttribute(path, SIZE); + RecoveryUtils.serializeState(entry.getValue(), outputStream, useKryo); + outputStream.flush(); + hetuFileSystemClient.flush(outputStream); + long size = (Long) hetuFileSystemClient.getAttribute(path, SIZE) - offset; + operatorStateInfoMap.put(entry.getKey(), new OperatorStateInfo(offset, size)); + } + // write footer + MarkerDataFileFooter footer = new MarkerDataFileFooter(markerId, previousFooterOffset, previousFooterSize, state.size(), operatorStateInfoMap); + long footerOffset = (Long) hetuFileSystemClient.getAttribute(path, SIZE); + RecoveryUtils.serializeState(footer, outputStream, useKryo); + outputStream.flush(); + hetuFileSystemClient.flush(outputStream); + long footerSize = (Long) hetuFileSystemClient.getAttribute(path, SIZE) - footerOffset; + return new MarkerDataFileFooterInfo(fileURI, footer, footerOffset, footerSize); + } + catch (IOException e) { + // add logs + System.out.println(1); + } + return null; + } + + public void close() + { + try { + outputStream.close(); + } + catch (IOException e) { + throw new PrestoException(GENERIC_INTERNAL_ERROR, "failed to close output 
stream"); + } + } + } + + public static class MarkerDataFileReader + { + private final boolean useKryo; + private Path path; + private HetuFileSystemClient hetuFileSystemClient; + + public MarkerDataFileReader(HetuFileSystemClient hetuFileSystemClient, URI fileURI, boolean useKryo) + { + this.path = Paths.get(fileURI.getPath()); + this.useKryo = useKryo; + this.hetuFileSystemClient = hetuFileSystemClient; + } + + public Map readDataFile(long markerDataOffset, long markerDataLength) + { + try { + InputStream inputStream = hetuFileSystemClient.newInputStream(path); + inputStream.skip(markerDataOffset); + MarkerDataFileFooter markerDataFileFooter = (MarkerDataFileFooter) RecoveryUtils.deserializeState(inputStream, useKryo); + int operatorCount = markerDataFileFooter.getOperatorCount(); + Map operatorStateInfoMap = markerDataFileFooter.getOperatorStateInfo(); + checkArgument(operatorCount == operatorStateInfoMap.size(), "operator data mismatch"); + inputStream.close(); + + Map states = new HashMap<>(); + try { + for (Map.Entry operatorStateInfoEntry : operatorStateInfoMap.entrySet()) { + inputStream = hetuFileSystemClient.newInputStream(path); + inputStream.skip(operatorStateInfoEntry.getValue().getStateOffset()); + states.put(operatorStateInfoEntry.getKey(), RecoveryUtils.deserializeState(inputStream, useKryo)); + } + return states; + } + finally { + inputStream.close(); + } + } + catch (Exception e) { + throw new PrestoException(GENERIC_INTERNAL_ERROR, "Failed reading data file"); + } + } + } + + public static class MarkerDataFileFooterInfo + implements Serializable + { + private URI path; + private MarkerDataFileFooter footer; + private long footerOffset; + private long footerSize; + + public MarkerDataFileFooterInfo(URI path, MarkerDataFileFooter footer, long footerOffset, long footerSize) + { + this.path = path; + this.footer = footer; + this.footerOffset = footerOffset; + this.footerSize = footerSize; + } + + public long getFooterOffset() + { + return footerOffset; + } + + public long getFooterSize() + { + return footerSize; + } + + public MarkerDataFileFooter getFooter() + { + return footer; + } + + public URI getPath() + { + return path; + } + } + + public static class MarkerDataFileFooter + implements Serializable + { + private long markerId; + private int version; + private long previousTailOffset; + private long previousTailSize; + private int operatorCount; + private Map operatorStateInfo; + + public MarkerDataFileFooter(long markerId, long previousTailOffset, long previousTailSize, int operatorCount, Map operatorStateInfo) + { + this(markerId, 1, previousTailOffset, previousTailSize, operatorCount, operatorStateInfo); + } + + public MarkerDataFileFooter(long markerId, int version, long previousTailOffset, long previousTailSize, int operatorCount, Map operatorStateInfo) + { + this.markerId = markerId; + this.version = version; + this.previousTailOffset = previousTailOffset; + this.operatorCount = operatorCount; + this.operatorStateInfo = operatorStateInfo; + this.previousTailSize = previousTailSize; + } + + public long getMarkerId() + { + return markerId; + } + + public int getVersion() + { + return version; + } + + public long getPreviousTailOffset() + { + return previousTailOffset; + } + + public long getPreviousTailSize() + { + return previousTailSize; + } + + public int getOperatorCount() + { + return operatorCount; + } + + public Map getOperatorStateInfo() + { + return operatorStateInfo; + } + } + + public static class OperatorStateInfo + implements Serializable + { + private 
long stateOffset; + private long stateSize; + + public OperatorStateInfo(long stateOffset, long stateSize) + { + this.stateOffset = stateOffset; + this.stateSize = stateSize; + } + + public long getStateOffset() + { + return stateOffset; + } + + public long getStateSize() + { + return stateSize; + } + } +} diff --git a/presto-main/src/main/java/io/prestosql/execution/MarkerIndexFileFactory.java b/presto-main/src/main/java/io/prestosql/execution/MarkerIndexFileFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..41a0fbedb3868f6d054337e32bf1250009998952 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/execution/MarkerIndexFileFactory.java @@ -0,0 +1,206 @@ +/* + * Copyright (C) 2018-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.execution; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.JsonFactory; +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.prestosql.execution.buffer.HybridSpoolingBuffer; +import io.prestosql.spi.PrestoException; +import io.prestosql.spi.filesystem.HetuFileSystemClient; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.net.URI; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import static io.prestosql.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; +import static java.util.Objects.requireNonNull; + +public class MarkerIndexFileFactory +{ + private final URI targetOutputDirectory; + private final HetuFileSystemClient hetuFileSystemClient; + + public MarkerIndexFileFactory(URI taskOutputDirectory, HetuFileSystemClient hetuFileSystemClient) + { + this.targetOutputDirectory = taskOutputDirectory; + this.hetuFileSystemClient = requireNonNull(hetuFileSystemClient, "hetuFileSystemClient is null"); + } + + public MarkerIndexFileWriter createWriter(String file) + { + Path path = Paths.get(targetOutputDirectory.resolve(file).getPath()); + return new MarkerIndexFileWriter(hetuFileSystemClient, path); + } + + public static class MarkerIndexFileWriter + { + private final OutputStream outputStream; + private final HetuFileSystemClient hetuFileSystemClient; + + public MarkerIndexFileWriter(HetuFileSystemClient hetuFileSystemClient, Path path) + { + try { + outputStream = hetuFileSystemClient.newOutputStream(path); + } + catch (IOException e) { + throw new PrestoException(GENERIC_INTERNAL_ERROR, "failed in MarkerIndexFileWriter"); + } + this.hetuFileSystemClient = hetuFileSystemClient; + } + + public void writeIndexFile(MarkerIndexFile markerIndexFile) + { + JsonFactory 
jsonFactory = new JsonFactory(); + jsonFactory.configure(JsonGenerator.Feature.AUTO_CLOSE_TARGET, false); + ObjectMapper objectMapper = new ObjectMapper(jsonFactory); + try { + objectMapper.writeValue(outputStream, markerIndexFile); + outputStream.write("\n".getBytes(StandardCharsets.UTF_8)); + outputStream.flush(); + hetuFileSystemClient.flush(outputStream); + } + catch (IOException e) { + throw new PrestoException(GENERIC_INTERNAL_ERROR, "Failed in writeIndexFile"); + } + } + + public void close() + { + try { + outputStream.close(); + } + catch (IOException e) { + throw new PrestoException(GENERIC_INTERNAL_ERROR, "failed to close output stream"); + } + } + } + + public static class MarkerIndexFileReader + { + private final InputStream inputStream; + + public MarkerIndexFileReader(HetuFileSystemClient hetuFileSystemClient, Path path) + { + try { + inputStream = hetuFileSystemClient.newInputStream(path); + } + catch (IOException e) { + throw new PrestoException(GENERIC_INTERNAL_ERROR, "Failed to create marker index file reader"); + } + } + + public List readIndexFile(long markerId) + { + List result = new ArrayList<>(); + JsonFactory jsonFactory = new JsonFactory(); + jsonFactory.configure(JsonParser.Feature.AUTO_CLOSE_SOURCE, false); + MarkerIndexFile markerIndexFile = null; + try { + BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, Charset.forName("UTF-8"))); + String line = reader.readLine(); + while (line != null) { + ObjectMapper objectMapper = new ObjectMapper(jsonFactory); + markerIndexFile = objectMapper.readValue(line, MarkerIndexFile.class); + if (markerId == markerIndexFile.getMarkerId()) { + result.add(markerIndexFile); + } + line = reader.readLine(); + } + if (result.isEmpty()) { + reader.close(); + throw new PrestoException(GENERIC_INTERNAL_ERROR, "Marker ID: " + markerId + " is not present in index file"); + } + reader.close(); + return result; + } + catch (IOException e) { + throw new PrestoException(GENERIC_INTERNAL_ERROR, "Failed to read marker index file"); + } + } + } + + public static class MarkerIndexFile + { + private final long markerId; + private final URI markerDataFile; + private final long markerStartOffset; + private final long markerLength; + private final Map spoolingInfoMap; + + @JsonCreator + public MarkerIndexFile( + @JsonProperty("markerID") long markerId, + @JsonProperty("markerDataFile") URI markerDataFile, + @JsonProperty("markerStartOffset") long markerStartOffset, + @JsonProperty("markerLength") long markerLength, + @JsonProperty("spoolingInfoMap") Map spoolingInfoMap) + { + this.markerId = markerId; + this.markerDataFile = markerDataFile; + this.markerStartOffset = markerStartOffset; + this.markerLength = markerLength; + this.spoolingInfoMap = spoolingInfoMap; + } + + public static MarkerIndexFile createMarkerIndexFile(int markerId, URI markerDataFile, long markerStartOffset, long markerLength, Map spoolingInfoMap) + { + return new MarkerIndexFile(markerId, markerDataFile, markerStartOffset, markerLength, spoolingInfoMap); + } + + @JsonProperty + public long getMarkerId() + { + return markerId; + } + + @JsonProperty + public URI getMarkerDataFile() + { + return markerDataFile; + } + + @JsonProperty + public long getMarkerStartOffset() + { + return markerStartOffset; + } + + @JsonProperty + public long getMarkerLength() + { + return markerLength; + } + + @JsonProperty + public Map getSpoolingInfoMap() + { + return spoolingInfoMap; + } + } +} diff --git 
a/presto-main/src/main/java/io/prestosql/execution/QueryManagerConfig.java b/presto-main/src/main/java/io/prestosql/execution/QueryManagerConfig.java index 5c816a73fa98ec877e558498b36a44bc076e4c29..c9f6b6b45579f29ff29175d4bf3368f8e3b86c00 100644 --- a/presto-main/src/main/java/io/prestosql/execution/QueryManagerConfig.java +++ b/presto-main/src/main/java/io/prestosql/execution/QueryManagerConfig.java @@ -74,6 +74,7 @@ public class QueryManagerConfig private Duration requiredWorkersMaxWait = new Duration(5, TimeUnit.MINUTES); private RetryPolicy retryPolicy = RetryPolicy.NONE; + private Duration taskSnapshotTimeInterval = new Duration(10, SECONDS); private int taskRetryAttemptsPerTask = 4; private int taskRetryAttemptsOverall = Integer.MAX_VALUE; @@ -382,6 +383,19 @@ public class QueryManagerConfig return this; } + @Config("task-snapshot-time-interval") + public QueryManagerConfig setTaskSnapshotTimeInterval(Duration taskSnapshotTimeInterval) + { + this.taskSnapshotTimeInterval = taskSnapshotTimeInterval; + return this; + } + + @NotNull + public Duration getTaskSnapshotTimeInterval() + { + return taskSnapshotTimeInterval; + } + @NotNull public RetryPolicy getRetryPolicy() { diff --git a/presto-main/src/main/java/io/prestosql/execution/ScheduledSplit.java b/presto-main/src/main/java/io/prestosql/execution/ScheduledSplit.java index 1a6e05a9cd03978dac05a16af27a171071904188..736b10a6aa2bc9793fb3f124cfaf0d60e1d5fba7 100644 --- a/presto-main/src/main/java/io/prestosql/execution/ScheduledSplit.java +++ b/presto-main/src/main/java/io/prestosql/execution/ScheduledSplit.java @@ -73,6 +73,10 @@ public class ScheduledSplit return false; } final ScheduledSplit other = (ScheduledSplit) obj; + //MarkerSplit + if (this.sequenceId == -1 || other.sequenceId == -1) { + return false; + } return this.sequenceId == other.sequenceId; } diff --git a/presto-main/src/main/java/io/prestosql/execution/SqlQueryExecution.java b/presto-main/src/main/java/io/prestosql/execution/SqlQueryExecution.java index 776da2dcd7e56608c6665cbe9dfc7bf6af8bdea8..3d80f9252355c855daaa9a4c749c824fc959c288 100644 --- a/presto-main/src/main/java/io/prestosql/execution/SqlQueryExecution.java +++ b/presto-main/src/main/java/io/prestosql/execution/SqlQueryExecution.java @@ -902,7 +902,8 @@ public class SqlQueryExecution DistributedExecutionPlanner distributedPlanner = new DistributedExecutionPlanner(splitManager, metadata); StageExecutionPlan outputStageExecutionPlan; Session session = stateMachine.getSession(); - if (SystemSessionProperties.isRecoveryEnabled(session)) { + if (SystemSessionProperties.isRecoveryEnabled(session) + || SystemSessionProperties.getRetryPolicy(session) == RetryPolicy.TASK_ASYNC) { // Recovery: need to plan different when recovery is enabled. // See the "plan" method for difference between the different modes. 
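            // With RetryPolicy.TASK_ASYNC the query takes the same planning path as recovery:
            // the MarkerAnnouncer obtained below injects marker splits into the task sources, and
            // the matching TASK_ASYNC checks added in SqlTaskExecution and the output buffers
            // treat those markers as snapshot points.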
MarkerAnnouncer announcer = splitManager.getMarkerAnnouncer(session); diff --git a/presto-main/src/main/java/io/prestosql/execution/SqlStageExecution.java b/presto-main/src/main/java/io/prestosql/execution/SqlStageExecution.java index ad847703ee6b4e4e6123d942fc50dc63b7f1060e..b3a873dd7c27f87f95706e050e5893a741bd879c 100644 --- a/presto-main/src/main/java/io/prestosql/execution/SqlStageExecution.java +++ b/presto-main/src/main/java/io/prestosql/execution/SqlStageExecution.java @@ -550,7 +550,9 @@ public final class SqlStageExecution if (this.outputBuffers.compareAndSet(currentOutputBuffers, outputBuffers)) { for (RemoteTask task : getAllTasks()) { - task.setOutputBuffers(outputBuffers); + OutputBuffers localOutputBuffers = new OutputBuffers(outputBuffers.getType(), outputBuffers.getVersion(), + outputBuffers.isNoMoreBufferIds(), outputBuffers.getBuffers(), outputBuffers.getExchangeSinkInstanceHandle()); + task.setOutputBuffers(localOutputBuffers); } return; } @@ -658,6 +660,8 @@ public final class SqlStageExecution }); OutputBuffers localOutputBuffers = this.outputBuffers.get(); + localOutputBuffers = new OutputBuffers(localOutputBuffers.getType(), localOutputBuffers.getVersion(), + localOutputBuffers.isNoMoreBufferIds(), localOutputBuffers.getBuffers(), localOutputBuffers.getExchangeSinkInstanceHandle()); checkState(localOutputBuffers != null, "Initial output buffers must be set before a task can be scheduled"); if (sinkExchange.isPresent()) { diff --git a/presto-main/src/main/java/io/prestosql/execution/SqlTaskExecution.java b/presto-main/src/main/java/io/prestosql/execution/SqlTaskExecution.java index d11ede481d9181c6789df4e24b67b2d9faff7e04..5648bbd535f49c2ec1648b597ccb1f9000e75a85 100644 --- a/presto-main/src/main/java/io/prestosql/execution/SqlTaskExecution.java +++ b/presto-main/src/main/java/io/prestosql/execution/SqlTaskExecution.java @@ -28,6 +28,7 @@ import io.airlift.units.Duration; import io.hetu.core.transport.execution.buffer.SerializedPage; import io.prestosql.SystemSessionProperties; import io.prestosql.event.SplitMonitor; +import io.prestosql.exchange.RetryPolicy; import io.prestosql.execution.StateMachine.StateChangeListener; import io.prestosql.execution.buffer.BufferState; import io.prestosql.execution.buffer.OutputBuffer; @@ -211,7 +212,7 @@ public class SqlTaskExecution this.taskId = taskStateMachine.getTaskId(); this.taskContext = requireNonNull(taskContext, "taskContext is null"); this.outputBuffer = requireNonNull(outputBuffer, "outputBuffer is null"); - recoveryEnabled = SystemSessionProperties.isRecoveryEnabled(taskContext.getSession()); + recoveryEnabled = SystemSessionProperties.isRecoveryEnabled(taskContext.getSession()) || (SystemSessionProperties.getRetryPolicy(taskContext.getSession()) == RetryPolicy.TASK_ASYNC); this.taskExecutor = requireNonNull(taskExecutor, "driverExecutor is null"); this.notificationExecutor = requireNonNull(notificationExecutor, "notificationExecutor is null"); @@ -358,13 +359,13 @@ public class SqlTaskExecution Map updatedUnpartitionedSources = new HashMap<>(); List sources = inputSources; - // first remove any split that was already acknowledged + // first remove any split that was already acknowledged, skip checking sequence id for markerTaskSource long currentMaxAcknowledgedSplit = this.maxAcknowledgedSplit; sources = sources.stream() .map(source -> new TaskSource( source.getPlanNodeId(), source.getSplits().stream() - .filter(scheduledSplit -> scheduledSplit.getSequenceId() > currentMaxAcknowledgedSplit) + .filter(scheduledSplit 
-> (scheduledSplit.getSequenceId() > currentMaxAcknowledgedSplit)) .collect(Collectors.toSet()), // Like splits, noMoreSplitsForLifespan could be pruned so that only new items will be processed. // This is not happening here because correctness won't be compromised due to duplicate events for noMoreSplitsForLifespan. @@ -372,7 +373,7 @@ public class SqlTaskExecution source.isNoMoreSplits())) .collect(toList()); - // update maxAcknowledgedSplit + // update maxAcknowledgedSplit, skip for MarkerTaskSource maxAcknowledgedSplit = sources.stream() .flatMap(source -> source.getSplits().stream()) .mapToLong(ScheduledSplit::getSequenceId) @@ -1203,6 +1204,11 @@ public class SqlTaskExecution return driverFactory.getPipelineExecutionStrategy(); } + public Optional getSourceId() + { + return driverFactory.getSourceId(); + } + public OptionalInt getDriverInstances() { return driverFactory.getDriverInstances(); diff --git a/presto-main/src/main/java/io/prestosql/execution/buffer/ArbitraryOutputBuffer.java b/presto-main/src/main/java/io/prestosql/execution/buffer/ArbitraryOutputBuffer.java index 77f77e6f64f4264f4c08ebbcfed8feed5df9c007..79ff62c65fb4c312d07ce55aba4f48c6923f4fa7 100644 --- a/presto-main/src/main/java/io/prestosql/execution/buffer/ArbitraryOutputBuffer.java +++ b/presto-main/src/main/java/io/prestosql/execution/buffer/ArbitraryOutputBuffer.java @@ -21,6 +21,7 @@ import com.google.common.util.concurrent.ListenableFuture; import io.airlift.units.DataSize; import io.hetu.core.transport.execution.buffer.SerializedPage; import io.prestosql.SystemSessionProperties; +import io.prestosql.exchange.RetryPolicy; import io.prestosql.execution.StateMachine.StateChangeListener; import io.prestosql.execution.buffer.ClientBuffer.PagesSupplier; import io.prestosql.execution.buffer.OutputBuffers.OutputBufferId; @@ -651,7 +652,7 @@ public class ArbitraryOutputBuffer checkArgument(taskContext != null, "taskContext is null"); checkState(this.taskContext == null, "setTaskContext is called multiple times"); this.taskContext = taskContext; - this.snapshotState = SystemSessionProperties.isSnapshotEnabled(taskContext.getSession()) + this.snapshotState = SystemSessionProperties.isSnapshotEnabled(taskContext.getSession()) || (SystemSessionProperties.getRetryPolicy(taskContext.getSession()) == RetryPolicy.TASK_ASYNC) ? 
MultiInputSnapshotState.forTaskComponent(this, taskContext, snapshotId -> SnapshotStateId.forTaskComponent(snapshotId, taskContext, "OutputBuffer")) : null; } diff --git a/presto-main/src/main/java/io/prestosql/execution/buffer/BroadcastOutputBuffer.java b/presto-main/src/main/java/io/prestosql/execution/buffer/BroadcastOutputBuffer.java index cfb212a40dcb5eb7e11b15cde679e255cb010275..6d286baaf97bd0d307429e808c5cb832759916c4 100644 --- a/presto-main/src/main/java/io/prestosql/execution/buffer/BroadcastOutputBuffer.java +++ b/presto-main/src/main/java/io/prestosql/execution/buffer/BroadcastOutputBuffer.java @@ -21,6 +21,7 @@ import com.google.common.util.concurrent.ListenableFuture; import io.airlift.units.DataSize; import io.hetu.core.transport.execution.buffer.SerializedPage; import io.prestosql.SystemSessionProperties; +import io.prestosql.exchange.RetryPolicy; import io.prestosql.execution.StateMachine.StateChangeListener; import io.prestosql.execution.buffer.OutputBuffers.OutputBufferId; import io.prestosql.memory.context.LocalMemoryContext; @@ -461,7 +462,7 @@ public class BroadcastOutputBuffer checkArgument(taskContext != null, "taskContext is null"); checkState(this.taskContext == null, "setTaskContext is called multiple times"); this.taskContext = taskContext; - this.snapshotState = SystemSessionProperties.isSnapshotEnabled(taskContext.getSession()) + this.snapshotState = SystemSessionProperties.isSnapshotEnabled(taskContext.getSession()) || SystemSessionProperties.getRetryPolicy(taskContext.getSession()) == RetryPolicy.TASK_ASYNC ? MultiInputSnapshotState.forTaskComponent(this, taskContext, snapshotId -> SnapshotStateId.forTaskComponent(snapshotId, taskContext, "OutputBuffer")) : null; } diff --git a/presto-main/src/main/java/io/prestosql/execution/buffer/HybridSpoolingBuffer.java b/presto-main/src/main/java/io/prestosql/execution/buffer/HybridSpoolingBuffer.java new file mode 100644 index 0000000000000000000000000000000000000000..f448cb7a968e37650c0550cf20430408a6792469 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/execution/buffer/HybridSpoolingBuffer.java @@ -0,0 +1,462 @@ +/* + * Copyright (C) 2018-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.prestosql.execution.buffer; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import io.airlift.log.Logger; +import io.airlift.slice.Slice; +import io.airlift.units.DataSize; +import io.hetu.core.transport.execution.buffer.PagesSerde; +import io.hetu.core.transport.execution.buffer.PagesSerdeUtil; +import io.hetu.core.transport.execution.buffer.SerializedPage; +import io.prestosql.exchange.ExchangeManager; +import io.prestosql.exchange.ExchangeSink; +import io.prestosql.exchange.ExchangeSource; +import io.prestosql.exchange.ExchangeSourceHandle; +import io.prestosql.exchange.FileStatus; +import io.prestosql.exchange.FileSystemExchangeConfig; +import io.prestosql.exchange.FileSystemExchangeSourceHandle; +import io.prestosql.exchange.storage.FileSystemExchangeStorage; +import io.prestosql.execution.MarkerDataFileFactory; +import io.prestosql.execution.MarkerIndexFileFactory; +import io.prestosql.memory.context.LocalMemoryContext; +import io.prestosql.snapshot.RecoveryUtils; +import io.prestosql.spi.Page; +import io.prestosql.spi.PrestoException; +import io.prestosql.spi.filesystem.HetuFileSystemClient; + +import javax.crypto.SecretKey; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; +import java.net.URI; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ExecutorService; +import java.util.function.Supplier; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.util.concurrent.Futures.immediateFuture; +import static io.airlift.concurrent.Threads.daemonThreadsNamed; +import static io.prestosql.exchange.FileSystemExchangeSink.DATA_FILE_SUFFIX; +import static io.prestosql.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; +import static io.prestosql.spi.filesystem.SupportedFileAttributes.SIZE; +import static java.util.concurrent.Executors.newCachedThreadPool; + +public class HybridSpoolingBuffer + extends SpoolingExchangeOutputBuffer +{ + private static final String PARENT_URI = ".."; + private static final String MARKER_INDEX_FILE = "marker_index.json"; + private static final String MARKER_DATA_FILE = "marker_data_file.data"; + private static final String TEMP_FILE = "marker_temp_txn.data"; + private static final Logger LOG = Logger.get(HybridSpoolingBuffer.class); + + private final OutputBufferStateMachine stateMachine; + private final OutputBuffers outputBuffers; + private final ExchangeSink exchangeSink; + private ExchangeSource exchangeSource; + private final Supplier memoryContextSupplier; + private final ExecutorService executor; + private final ExchangeManager exchangeManager; + private MarkerIndexFileFactory.MarkerIndexFileWriter markerIndexFileWriter; + private MarkerDataFileFactory.MarkerDataFileWriter markerDataFileWriter; + private long previousFooterOffset; + private long previousFooterSize; + private long previousSuccessfulMarkerId; + private Map spoolingInfoMap = new HashMap<>(); + private final HetuFileSystemClient fsClient; + private int token; + private PagesSerde serde; + private PagesSerde javaSerde; + private PagesSerde kryoSerde; + + private URI 
outputDirectory; + + public HybridSpoolingBuffer(OutputBufferStateMachine stateMachine, OutputBuffers outputBuffers, ExchangeSink exchangeSink, Supplier memoryContextSupplier, ExchangeManager exchangeManager) + { + super(stateMachine, outputBuffers, exchangeSink, memoryContextSupplier); + this.stateMachine = stateMachine; + this.outputBuffers = outputBuffers; + this.exchangeSink = exchangeSink; + this.memoryContextSupplier = memoryContextSupplier; + this.outputDirectory = exchangeSink.getOutputDirectory().resolve(PARENT_URI); + this.executor = newCachedThreadPool(daemonThreadsNamed("exchange-source-handles-creation-%s")); + this.exchangeManager = exchangeManager; + this.fsClient = exchangeSink.getExchangeStorage().getFileSystemClient(); + } + + public void enqueueMarkerInfo(long markerId, Map states) + { + createTempTransactionFile(); // this is to ensure transaction is successful. Deleted at end. + List exchangeSinkFiles = exchangeSink.getSinkFiles(); + MarkerDataFileFactory.MarkerDataFileFooterInfo markerDataFileFooterInfo = enqueueMarkerData(markerId, states, exchangeSinkFiles); + Map spoolingFileInfo = new HashMap<>(); + try { + spoolingFileInfo.putAll(createSpoolingInfo(exchangeSinkFiles)); + } + catch (Exception e) { + throw new PrestoException(GENERIC_INTERNAL_ERROR, "Failed creating spooling Info"); + } + + // write marker metadata to spooling file + enqueueMarkerDataToSpoolingFile(markerDataFileFooterInfo); + try { + updateSpoolingInfo(exchangeSinkFiles); + } + catch (Exception e) { + throw new PrestoException(GENERIC_INTERNAL_ERROR, "Failed updating spooling Info"); + } + enqueueMarkerIndex(markerId, outputDirectory.resolve(MARKER_DATA_FILE), previousFooterOffset, previousFooterSize, spoolingFileInfo); + deleteTempTransactionFile(); + previousSuccessfulMarkerId = markerId; + } + + public Map dequeueMarkerInfo(long markerId) + { + if (checkIfTransactionFileExists()) { + //todo(SURYA): add logic to get recent successful marker and offsets + throw new PrestoException(GENERIC_INTERNAL_ERROR, "transaction file exists for marker: " + markerId); + } + List markerIndexFileList = dequeueMarkerIndex(markerId, outputDirectory.resolve(MARKER_INDEX_FILE)); + Map resultMarkerData = new HashMap<>(); + for (MarkerIndexFileFactory.MarkerIndexFile markerIndexFile : markerIndexFileList) { + resultMarkerData.putAll(dequeueMarkerData(markerIndexFile)); + } + return resultMarkerData; + } + + private boolean checkIfTransactionFileExists() + { + Path tempTransactionFile = Paths.get(outputDirectory.resolve(TEMP_FILE).getPath()); + return fsClient.exists(tempTransactionFile); + } + + private void updateSpoolingInfo(List exchangeSinkFiles) + throws IOException + { + for (URI sinkFile : exchangeSinkFiles) { + if (!spoolingInfoMap.containsKey(sinkFile)) { + throw new PrestoException(GENERIC_INTERNAL_ERROR, "sink file suppose to be present in spoolingInfoMap"); + } + else { + long spoolingOffset = spoolingInfoMap.get(sinkFile).getSpoolingFileOffset() + spoolingInfoMap.get(sinkFile).getSpoolingFileSize(); + long spoolingSize = ((Long) fsClient.getAttribute(Paths.get(sinkFile.getPath()), SIZE)) - spoolingOffset; + spoolingInfoMap.put(sinkFile, new SpoolingInfo(spoolingOffset, spoolingSize)); + } + } + } + + private void enqueueMarkerDataToSpoolingFile(MarkerDataFileFactory.MarkerDataFileFooterInfo markerDataFileFooterInfo) + { + exchangeSink.enqueueMarkerInfo(markerDataFileFooterInfo); + } + + private Map createSpoolingInfo(List sinkFiles) + throws IOException + { + Map localSpoolingInfoMap = new HashMap<>(); + 
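            // Track only the byte range spooled since the previous marker for each sink file:
            // on first sight the range is [0, currentSize); afterwards the offset advances past the
            // previously recorded range and the size is currentSize minus that offset. Files with
            // no new bytes (spoolingSize == 0) are left out of the per-marker map.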
for (URI sinkFile : sinkFiles) { + if (!spoolingInfoMap.containsKey(sinkFile)) { + localSpoolingInfoMap.put(sinkFile, new SpoolingInfo(0, (Long) fsClient.getAttribute(Paths.get(sinkFile.getPath()), SIZE))); + spoolingInfoMap.put(sinkFile, new SpoolingInfo(0, (Long) fsClient.getAttribute(Paths.get(sinkFile.getPath()), SIZE))); + } + else { + long spoolingOffset = spoolingInfoMap.get(sinkFile).getSpoolingFileOffset() + spoolingInfoMap.get(sinkFile).getSpoolingFileSize(); + long spoolingSize = ((Long) fsClient.getAttribute(Paths.get(sinkFile.getPath()), SIZE)) - spoolingOffset; + if (spoolingSize > 0) { + localSpoolingInfoMap.put(sinkFile, new SpoolingInfo(spoolingOffset, spoolingSize)); + spoolingInfoMap.put(sinkFile, new SpoolingInfo(spoolingOffset, spoolingSize)); + } + } + } + return localSpoolingInfoMap; + } + + public void createTempTransactionFile() + { + Path tempTransactionFile = Paths.get(outputDirectory.resolve(TEMP_FILE).getPath()); + try (OutputStream outputStream = fsClient.newOutputStream(tempTransactionFile)) { + RecoveryUtils.serializeState(previousSuccessfulMarkerId, outputStream, false); + RecoveryUtils.serializeState(previousFooterOffset, outputStream, false); + } + catch (Exception e) { + throw new PrestoException(GENERIC_INTERNAL_ERROR, "Failed to create temp transaction file"); + } + } + + private void deleteTempTransactionFile() + { + Path tempTransactionFile = Paths.get(outputDirectory.resolve(TEMP_FILE).getPath()); + try { + fsClient.delete(tempTransactionFile); + } + catch (Exception e) { + throw new PrestoException(GENERIC_INTERNAL_ERROR, "Failed to delete temp transaction file"); + } + } + + public MarkerDataFileFactory.MarkerDataFileFooterInfo enqueueMarkerData(long markerId, Map statesMap, List sinkFiles) + { + if (markerDataFileWriter == null) { + MarkerDataFileFactory markerDataFileFactory = new MarkerDataFileFactory(outputDirectory, fsClient); + this.markerDataFileWriter = markerDataFileFactory.createWriter(MARKER_DATA_FILE, false); //todo(SURYA): add JAVA/KRYO config based decision. 
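        // The writeDataFile() call below appends to marker_data_file.data, producing:
        //   [operator state 1]...[operator state N][footer]
        // The footer records each state's (offset, size) plus previousFooterOffset/Size, so the
        // footers form a backward chain from the newest marker. Offsets are read via
        // getAttribute(path, SIZE) right after fsClient.flush(), which is why
        // HetuFileSystemClient.flush() (hsync/hflush on HDFS, plain flush() locally) was added.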
+ } + MarkerDataFileFactory.MarkerDataFileFooterInfo markerDataFileFooterInfo = markerDataFileWriter.writeDataFile(markerId, statesMap, previousFooterOffset, previousFooterSize); + previousFooterOffset = markerDataFileFooterInfo.getFooterOffset(); + previousFooterSize = markerDataFileFooterInfo.getFooterSize(); + return markerDataFileFooterInfo; + } + + public Map dequeueMarkerData(MarkerIndexFileFactory.MarkerIndexFile markerIndexFile) + { + MarkerDataFileFactory.MarkerDataFileReader markerDataFileReader = new MarkerDataFileFactory.MarkerDataFileReader(fsClient, markerIndexFile.getMarkerDataFile(), false); + return markerDataFileReader.readDataFile(markerIndexFile.getMarkerStartOffset(), markerIndexFile.getMarkerLength()); + } + + public Map dequeueSpoolingInfo(MarkerIndexFileFactory.MarkerIndexFile markerIndexFile) + { + return markerIndexFile.getSpoolingInfoMap(); + } + + public void enqueueMarkerIndex(long markerId, URI markerDataFile, long markerStartOffset, long markerLength, Map spoolingInfoMap) + { + if (markerIndexFileWriter == null) { + MarkerIndexFileFactory markerIndexFileWriterFactory = new MarkerIndexFileFactory(outputDirectory, fsClient); + this.markerIndexFileWriter = markerIndexFileWriterFactory.createWriter(MARKER_INDEX_FILE); + } + MarkerIndexFileFactory.MarkerIndexFile markerIndexFile = new MarkerIndexFileFactory.MarkerIndexFile(markerId, markerDataFile, markerStartOffset, markerLength, spoolingInfoMap); + markerIndexFileWriter.writeIndexFile(markerIndexFile); + } + + private List dequeueMarkerIndex(long markerId, URI indexFile) + { + MarkerIndexFileFactory.MarkerIndexFileReader markerIndexFileReader = new MarkerIndexFileFactory.MarkerIndexFileReader(fsClient, Paths.get(indexFile.getPath())); + return markerIndexFileReader.readIndexFile(markerId); + } + + @Override + public ListenableFuture get(OutputBuffers.OutputBufferId bufferId, long token, DataSize maxSize) + { + if (exchangeSource == null) { + exchangeSource = createExchangeSource(exchangeSink.getPartitionId()); + if (exchangeSource == null) { + return immediateFuture(BufferResult.emptyResults(token, false)); + } + } + return immediateFuture(readPages(token)); + } + + private BufferResult readPages(long tokenId) + { + List result = new ArrayList<>(); + FileSystemExchangeConfig.DirectSerialisationType directSerialisationType = exchangeSource.getDirectSerialisationType(); + if (directSerialisationType != FileSystemExchangeConfig.DirectSerialisationType.OFF) { + PagesSerde pagesSerde = (directSerialisationType == FileSystemExchangeConfig.DirectSerialisationType.JAVA) ? javaSerde : kryoSerde; + while (token < tokenId) { + exchangeSource.readPage(pagesSerde); + token++; + } + Page page = exchangeSource.readPage(pagesSerde); + result.add(serde.serialize(page)); + } + else { + while (token < tokenId) { + exchangeSource.read(); + token++; + } + Slice slice = exchangeSource.read(); + SerializedPage serializedPage = slice != null ? 
PagesSerdeUtil.readSerializedPage(slice) : null; + result.add(serializedPage); + } + return new BufferResult(tokenId, tokenId + result.size(), false, result); + } + + public ExchangeSource createExchangeSource(int partitionId) + { + ExchangeHandleInfo exchangeHandleInfo = getExchangeHandleInfo(exchangeSink); + ListenableFuture> fileStatus = getFileStatus(outputDirectory); + List completedFileStatus = new ArrayList<>(); + if (!fileStatus.isDone()) { + return null; + } + try { + completedFileStatus = fileStatus.get(); + } + catch (Exception e) { + LOG.debug("Failed in creating exchange source with outputDirectory" + outputDirectory); + } + List handles = ImmutableList.of(new FileSystemExchangeSourceHandle(partitionId, + completedFileStatus, exchangeHandleInfo.getSecretKey(), exchangeHandleInfo.isExchangeCompressionEnabled())); + + return exchangeManager.createSource(handles); + } + + private ListenableFuture> getFileStatus(URI path) + { + FileSystemExchangeStorage exchangeStorage = exchangeSink.getExchangeStorage(); + return Futures.transform(exchangeStorage.listFilesRecursively(path), + sinkOutputFiles -> sinkOutputFiles.stream().filter(file -> file.getFilePath().endsWith(DATA_FILE_SUFFIX)).collect(toImmutableList()), + executor); + } + + private ExchangeHandleInfo getExchangeHandleInfo(ExchangeSink exchangeSink) + { + return new ExchangeHandleInfo(exchangeSink.getOutputDirectory(), + exchangeSink.getExchangeStorage(), + exchangeSink.getSecretKey().map(SecretKey::getEncoded), + exchangeSink.isExchangeCompressionEnabled()); + } + + @Override + public void setJavaSerde(PagesSerde javaSerde) + { + this.javaSerde = javaSerde; + } + + @Override + public void setKryoSerde(PagesSerde kryoSerde) + { + this.kryoSerde = kryoSerde; + } + + @Override + public void setSerde(PagesSerde serde) + { + this.serde = serde; + } + + private static class ExchangeHandleInfo + { + URI outputDirectory; + FileSystemExchangeStorage exchangeStorage; + Optional secretKey; + boolean exchangeCompressionEnabled; + + ExchangeHandleInfo(URI outputDirectory, FileSystemExchangeStorage exchangeStorage, Optional secretKey, boolean exchangeCompressionEnabled) + { + this.outputDirectory = outputDirectory; + this.exchangeStorage = exchangeStorage; + this.secretKey = secretKey; + this.exchangeCompressionEnabled = exchangeCompressionEnabled; + } + + public URI getOutputDirectory() + { + return outputDirectory; + } + + public Optional getSecretKey() + { + return secretKey; + } + + public boolean isExchangeCompressionEnabled() + { + return exchangeCompressionEnabled; + } + + public FileSystemExchangeStorage getExchangeStorage() + { + return exchangeStorage; + } + } + + public static class SpoolingInfo + implements Serializable + { + private long spoolingFileOffset; + private long spoolingFileSize; + + @JsonCreator + public SpoolingInfo( + @JsonProperty("spoolingFileOffset") long spoolingFileOffset, + @JsonProperty("spoolingFileSize") long spoolingFileSize) + { + this.spoolingFileOffset = spoolingFileOffset; + this.spoolingFileSize = spoolingFileSize; + } + + @JsonProperty + public long getSpoolingFileOffset() + { + return spoolingFileOffset; + } + + @JsonProperty + public long getSpoolingFileSize() + { + return spoolingFileSize; + } + } + + public Object getMarkerDataFileFooter() + { + try { + InputStream inputStream = fsClient.newInputStream(Paths.get(outputDirectory.resolve(MARKER_DATA_FILE).getPath())); + inputStream.skip(previousFooterOffset); + Object o = RecoveryUtils.deserializeState(inputStream, false); + return o; + } + catch 
(Exception e) { + throw new PrestoException(GENERIC_INTERNAL_ERROR, "Failed to get marker data file footer"); + } + } + + public Object getMarkerDataFileFooter(long previousFooterOffset) + { + try { + InputStream inputStream = fsClient.newInputStream(Paths.get(outputDirectory.resolve(MARKER_DATA_FILE).getPath())); + inputStream.skip(previousFooterOffset); + Object o = RecoveryUtils.deserializeState(inputStream, false); + return o; + } + catch (Exception e) { + throw new PrestoException(GENERIC_INTERNAL_ERROR, "Failed to get marker data file footer"); + } + } + + public Object getMarkerData(long stateOffset) + { + try { + InputStream inputStream = fsClient.newInputStream(Paths.get(outputDirectory.resolve(MARKER_DATA_FILE).getPath())); + inputStream.skip(stateOffset); + Object o = RecoveryUtils.deserializeState(inputStream, false); + return o; + } + catch (Exception e) { + throw new PrestoException(GENERIC_INTERNAL_ERROR, "Failed to get marker data"); + } + } + + @Override + public void destroy() + { + markerIndexFileWriter.close(); + markerDataFileWriter.close(); + super.destroy(); + } +} diff --git a/presto-main/src/main/java/io/prestosql/execution/buffer/LazyOutputBuffer.java b/presto-main/src/main/java/io/prestosql/execution/buffer/LazyOutputBuffer.java index 1e3e05712691d330f339e76f04da059b1b04bf80..4a92cee6772a3aaa00488660a2d7ed024aeb539c 100644 --- a/presto-main/src/main/java/io/prestosql/execution/buffer/LazyOutputBuffer.java +++ b/presto-main/src/main/java/io/prestosql/execution/buffer/LazyOutputBuffer.java @@ -64,6 +64,9 @@ public class LazyOutputBuffer @GuardedBy("this") private OutputBuffer delegate; + @GuardedBy("this") + private OutputBuffer hybridSpoolingDelegate; + @GuardedBy("this") private final Set abortedBuffers = new HashSet<>(); @@ -72,6 +75,10 @@ public class LazyOutputBuffer private final ExchangeManagerRegistry exchangeManagerRegistry; private Optional exchangeSink; + private PagesSerde serde; + private PagesSerde javaSerde; + private PagesSerde kryoSerde; + public LazyOutputBuffer( TaskId taskId, Executor executor, @@ -161,6 +168,7 @@ public class LazyOutputBuffer Set abortedBuffersIds = ImmutableSet.of(); List bufferPendingReads = ImmutableList.of(); OutputBuffer outputBuffer; + ExchangeManager exchangeManager = null; synchronized (this) { if (delegate == null) { // ignore set output if buffer was already destroyed or failed @@ -185,7 +193,7 @@ public class LazyOutputBuffer if (newOutputBuffers.getExchangeSinkInstanceHandle().isPresent()) { ExchangeSinkInstanceHandle exchangeSinkInstanceHandle = newOutputBuffers.getExchangeSinkInstanceHandle() .orElseThrow(() -> new IllegalArgumentException("exchange sink handle is expected to be present for buffer type EXTERNAL")); - ExchangeManager exchangeManager = exchangeManagerRegistry.getExchangeManager(); + exchangeManager = exchangeManagerRegistry.getExchangeManager(); ExchangeSink exchangeSinkInstance = exchangeManager.createSink(exchangeSinkInstanceHandle, false); //TODO: create directories this.exchangeSink = Optional.ofNullable(exchangeSinkInstance); } @@ -199,11 +207,30 @@ public class LazyOutputBuffer bufferPendingReads = ImmutableList.copyOf(this.pendingReads); this.pendingReads.clear(); } - this.exchangeSink.ifPresent(sink -> delegate = new SpoolingExchangeOutputBuffer( - stateMachine, - newOutputBuffers, - sink, - systemMemoryContextSupplier)); + ExchangeManager finalExchangeManager = exchangeManager; + this.exchangeSink.ifPresent(sink -> { + if (delegate == null) { + delegate = new SpoolingExchangeOutputBuffer( + 
stateMachine, + newOutputBuffers, + sink, + systemMemoryContextSupplier); + } + else { + if (hybridSpoolingDelegate == null) { + hybridSpoolingDelegate = new HybridSpoolingBuffer(stateMachine, + newOutputBuffers, + sink, + systemMemoryContextSupplier, + finalExchangeManager); + if (hybridSpoolingDelegate != null) { + hybridSpoolingDelegate.setSerde(serde); + hybridSpoolingDelegate.setJavaSerde(javaSerde); + hybridSpoolingDelegate.setKryoSerde(kryoSerde); + } + } + } + }); outputBuffer = delegate; } @@ -325,6 +352,9 @@ public class LazyOutputBuffer checkState(delegate != null, "Buffer has not been initialized"); outputBuffer = delegate; } + if (hybridSpoolingDelegate != null) { + hybridSpoolingDelegate.setNoMorePages(); + } outputBuffer.setNoMorePages(); } @@ -361,6 +391,9 @@ public class LazyOutputBuffer } outputBuffer.destroy(); + if (hybridSpoolingDelegate != null) { + hybridSpoolingDelegate.destroy(); + } } @Override @@ -488,4 +521,44 @@ public class LazyOutputBuffer } outputBuffer.enqueuePages(partition, pages, id, directSerde); } + + @Override + public boolean isSpoolingDelegateAvailable() + { + return delegate != null && hybridSpoolingDelegate != null && hybridSpoolingDelegate.isSpoolingOutputBuffer(); + } + + @Override + public OutputBuffer getSpoolingDelegate() + { + return hybridSpoolingDelegate; + } + + @Override + public DirectSerialisationType getDelegateSpoolingExchangeDirectSerializationType() + { + DirectSerialisationType type = DirectSerialisationType.JAVA; + if (hybridSpoolingDelegate != null) { + type = hybridSpoolingDelegate.getExchangeDirectSerialisationType(); + } + return type; + } + + @Override + public void setSerde(PagesSerde serde) + { + this.serde = serde; + } + + @Override + public void setJavaSerde(PagesSerde javaSerde) + { + this.javaSerde = javaSerde; + } + + @Override + public void setKryoSerde(PagesSerde kryoSerde) + { + this.kryoSerde = kryoSerde; + } } diff --git a/presto-main/src/main/java/io/prestosql/execution/buffer/OutputBuffer.java b/presto-main/src/main/java/io/prestosql/execution/buffer/OutputBuffer.java index 3e6cfbeb10018bf474319c90d123a24aeaa24367..773b270568c2fd5e96c0146120644bc6c9291edc 100644 --- a/presto-main/src/main/java/io/prestosql/execution/buffer/OutputBuffer.java +++ b/presto-main/src/main/java/io/prestosql/execution/buffer/OutputBuffer.java @@ -176,4 +176,31 @@ public interface OutputBuffer { return; } + + default boolean isSpoolingDelegateAvailable() + { + return false; + } + + default OutputBuffer getSpoolingDelegate() + { + return null; + } + + default DirectSerialisationType getDelegateSpoolingExchangeDirectSerializationType() + { + return DirectSerialisationType.JAVA; + } + + default void setSerde(PagesSerde pagesSerde) + { + } + + default void setJavaSerde(PagesSerde pagesSerde) + { + } + + default void setKryoSerde(PagesSerde pagesSerde) + { + } } diff --git a/presto-main/src/main/java/io/prestosql/execution/buffer/PartitionedOutputBuffer.java b/presto-main/src/main/java/io/prestosql/execution/buffer/PartitionedOutputBuffer.java index f671b3ff9748afa3771cab7743b82e2f3bbbf52d..28348964c4c66ebd9583b6b328a742a0eda788d7 100644 --- a/presto-main/src/main/java/io/prestosql/execution/buffer/PartitionedOutputBuffer.java +++ b/presto-main/src/main/java/io/prestosql/execution/buffer/PartitionedOutputBuffer.java @@ -20,6 +20,7 @@ import com.google.common.util.concurrent.ListenableFuture; import io.airlift.units.DataSize; import io.hetu.core.transport.execution.buffer.SerializedPage; import io.prestosql.SystemSessionProperties; +import 
io.prestosql.exchange.RetryPolicy; import io.prestosql.execution.StateMachine.StateChangeListener; import io.prestosql.execution.buffer.OutputBuffers.OutputBufferId; import io.prestosql.memory.context.LocalMemoryContext; @@ -367,7 +368,8 @@ public class PartitionedOutputBuffer checkArgument(taskContext != null, "taskContext is null"); checkState(this.taskContext == null, "setTaskContext is called multiple times"); this.taskContext = taskContext; - if (SystemSessionProperties.isSnapshotEnabled(taskContext.getSession())) { + if (SystemSessionProperties.isSnapshotEnabled(taskContext.getSession()) + || (SystemSessionProperties.getRetryPolicy(taskContext.getSession()) == RetryPolicy.TASK_ASYNC)) { isSnapshotEnabled = true; for (int i = 0; i < partitions.size(); i++) { // Create one snapshot state for each partition, because restored pages need to be associated with the same partition diff --git a/presto-main/src/main/java/io/prestosql/execution/scheduler/SourcePartitionedScheduler.java b/presto-main/src/main/java/io/prestosql/execution/scheduler/SourcePartitionedScheduler.java index 1e7e041cf31951f6ecdfc20b505566290fd1df1a..bd221662477915f4a50fb6dee5a69fb1284e48f5 100644 --- a/presto-main/src/main/java/io/prestosql/execution/scheduler/SourcePartitionedScheduler.java +++ b/presto-main/src/main/java/io/prestosql/execution/scheduler/SourcePartitionedScheduler.java @@ -25,6 +25,7 @@ import com.google.common.util.concurrent.SettableFuture; import io.prestosql.Session; import io.prestosql.SystemSessionProperties; import io.prestosql.dynamicfilter.DynamicFilterService; +import io.prestosql.exchange.RetryPolicy; import io.prestosql.execution.Lifespan; import io.prestosql.execution.RemoteTask; import io.prestosql.execution.SqlStageExecution; @@ -338,7 +339,8 @@ public class SourcePartitionedScheduler // add split filter to filter out split has no valid rows Pair, Map> pair = SplitFiltering.getExpression(stage); - if (SystemSessionProperties.isRecoveryEnabled(session)) { + if (SystemSessionProperties.isRecoveryEnabled(session) + || SystemSessionProperties.getRetryPolicy(session) == RetryPolicy.TASK_ASYNC) { List batchSplits = nextSplits.getSplits(); // Don't apply filter to MarkerSplit if (batchSplits.size() == 1 && batchSplits.get(0).getConnectorSplit() instanceof MarkerSplit) { @@ -410,7 +412,8 @@ public class SourcePartitionedScheduler splitAssignment = splitPlacementResult.getAssignments(); - if (SystemSessionProperties.isRecoveryEnabled(session)) { + if (SystemSessionProperties.isRecoveryEnabled(session) + || SystemSessionProperties.getRetryPolicy(session) == RetryPolicy.TASK_ASYNC) { Split firstSplit = pendingSplits.iterator().next(); if (pendingSplits.size() == 1 && firstSplit.getConnectorSplit() instanceof MarkerSplit) { // We'll create a new assignment, but still need to call computeAssignments above, and cannot modify the returned assignment map directly diff --git a/presto-main/src/main/java/io/prestosql/execution/scheduler/SqlQueryScheduler.java b/presto-main/src/main/java/io/prestosql/execution/scheduler/SqlQueryScheduler.java index 621438ed92677583cc4c4c1ecb6e9483cdcd57c5..9fa71a151fbc4ea8f6a3a79b87c97e20da950556 100644 --- a/presto-main/src/main/java/io/prestosql/execution/scheduler/SqlQueryScheduler.java +++ b/presto-main/src/main/java/io/prestosql/execution/scheduler/SqlQueryScheduler.java @@ -149,6 +149,7 @@ import static io.prestosql.SystemSessionProperties.getWriterMinSize; import static io.prestosql.SystemSessionProperties.isQueryResourceTrackingEnabled; import static 
io.prestosql.SystemSessionProperties.isReuseTableScanEnabled; import static io.prestosql.exchange.RetryPolicy.TASK; +import static io.prestosql.exchange.RetryPolicy.TASK_ASYNC; import static io.prestosql.execution.BasicStageStats.aggregateBasicStageStats; import static io.prestosql.execution.QueryState.FINISHING; import static io.prestosql.execution.QueryState.RECOVERING; @@ -633,7 +634,7 @@ public class SqlQueryScheduler StageId stageId = new StageId(queryStateMachine.getQueryId(), nextStageId.getAndIncrement()); Optional exchange = Optional.empty(); - if (retryPolicy.equals(TASK)) { + if (retryPolicy.equals(TASK) || retryPolicy.equals(TASK_ASYNC)) { ExchangeManager exchangeManager = exchangeManagerRegistry.getExchangeManager(); exchange = createSqlStageExchange(exchangeManager, stageId); } diff --git a/presto-main/src/main/java/io/prestosql/memory/QueryContext.java b/presto-main/src/main/java/io/prestosql/memory/QueryContext.java index ab54be696494d8fde8188ada5d4072c81dedf294..e72daf7556cb496e2979206b7db4df7fea7014f5 100644 --- a/presto-main/src/main/java/io/prestosql/memory/QueryContext.java +++ b/presto-main/src/main/java/io/prestosql/memory/QueryContext.java @@ -280,7 +280,7 @@ public class QueryContext parent.orElse(null), serdeFactory, kryoSerdeFactory, - new TaskSnapshotManager(taskStateMachine.getTaskId(), resumeCount, recoveryUtils), + new TaskSnapshotManager(taskStateMachine.getTaskId(), resumeCount, recoveryUtils, session), recoveryUtils.getRecoveryManager(queryId)); taskContexts.put(taskInstanceId, taskContext); return taskContext; diff --git a/presto-main/src/main/java/io/prestosql/operator/Driver.java b/presto-main/src/main/java/io/prestosql/operator/Driver.java index cb80093f5dbd109ae76074b846a24edc7b520300..1abceb32f378364fe35b0ed816d36ca8270e03ef 100644 --- a/presto-main/src/main/java/io/prestosql/operator/Driver.java +++ b/presto-main/src/main/java/io/prestosql/operator/Driver.java @@ -23,6 +23,7 @@ import com.google.common.util.concurrent.SettableFuture; import io.airlift.log.Logger; import io.airlift.units.Duration; import io.prestosql.SystemSessionProperties; +import io.prestosql.exchange.RetryPolicy; import io.prestosql.execution.ScheduledSplit; import io.prestosql.execution.TaskSource; import io.prestosql.execution.TaskState; @@ -128,7 +129,8 @@ public class Driver private Driver(DriverContext driverContext, List operators) { this.driverContext = requireNonNull(driverContext, "driverContext is null"); - this.isSnapshotEnabled = SystemSessionProperties.isSnapshotEnabled(driverContext.getSession()); + this.isSnapshotEnabled = SystemSessionProperties.isSnapshotEnabled(driverContext.getSession()) + || SystemSessionProperties.getRetryPolicy(driverContext.getSession()) == RetryPolicy.TASK_ASYNC; this.allOperators = ImmutableList.copyOf(requireNonNull(operators, "operators is null")); checkArgument(allOperators.size() > 1, "At least two operators are required"); this.activeOperators = new ArrayList<>(operators); diff --git a/presto-main/src/main/java/io/prestosql/operator/ExchangeClientFactory.java b/presto-main/src/main/java/io/prestosql/operator/ExchangeClientFactory.java index a41957238c705f947c839f0a57149b03eee25bff..ee89af2b0f33061a25c4b03b86318d5d711501e6 100644 --- a/presto-main/src/main/java/io/prestosql/operator/ExchangeClientFactory.java +++ b/presto-main/src/main/java/io/prestosql/operator/ExchangeClientFactory.java @@ -138,6 +138,7 @@ public class ExchangeClientFactory buffer = new DeduplicatingDirectExchangeBuffer(scheduler, deduplicationBufferSize, 
retryPolicy, exchangeManagerRegistry, queryId, exchangeId); break; case NONE: + case TASK_ASYNC: buffer = null; break; default: diff --git a/presto-main/src/main/java/io/prestosql/operator/ExchangeOperator.java b/presto-main/src/main/java/io/prestosql/operator/ExchangeOperator.java index 27f1be3b196c13e0fe87eacb6ad07f3d77e23d13..46e2ea77b562857222fa417d176cab579ef7d71d 100644 --- a/presto-main/src/main/java/io/prestosql/operator/ExchangeOperator.java +++ b/presto-main/src/main/java/io/prestosql/operator/ExchangeOperator.java @@ -33,6 +33,7 @@ import io.prestosql.execution.TaskFailureListener; import io.prestosql.execution.TaskId; import io.prestosql.memory.context.LocalMemoryContext; import io.prestosql.metadata.Split; +import io.prestosql.snapshot.MarkerSplit; import io.prestosql.snapshot.MultiInputRestorable; import io.prestosql.snapshot.MultiInputSnapshotState; import io.prestosql.snapshot.QueryRecoveryManager; @@ -123,7 +124,8 @@ public class ExchangeOperator exchangeManagerRegistry, uniqueId, addOperatorContext.isRecoveryEnabled(), - driverContext.getPipelineContext().getTaskContext().getRecoveryManager()); + driverContext.getPipelineContext().getTaskContext().getRecoveryManager(), + addOperatorContext.isTaskRetryAsyncExecutionEnabled()); // if recovery is enabled exchange client is required at the time of exchange operator creation to add Targets. // So if recovery is enabled DirectExchangeDataSource is set as dataSource. if (addOperatorContext.isRecoveryEnabled()) { @@ -163,6 +165,7 @@ public class ExchangeOperator private Optional> inputChannels = Optional.empty(); private final SettableFuture blockedOnSplits = SettableFuture.create(); + private MarkerSplit markerSplit; public ExchangeOperator( String id, @@ -171,10 +174,11 @@ public class ExchangeOperator ExchangeDataSource exchangeDataSource) { this.id = requireNonNull(id, "id is null"); + this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); this.sourceId = requireNonNull(sourceId, "sourceId is null"); this.exchangeDataSource = requireNonNull(exchangeDataSource, "exchangeDataSource is null"); - this.snapshotState = operatorContext.isSnapshotEnabled() + this.snapshotState = operatorContext.isSnapshotEnabled() || operatorContext.isTaskRetryAsyncExecutionEnabled() ? 
MultiInputSnapshotState.forOperator(this, operatorContext) : null; operatorContext.setInfoSupplier(exchangeDataSource::getInfo); @@ -384,6 +388,7 @@ public class ExchangeOperator private final ExchangeManagerRegistry exchangeManagerRegistry; private final boolean recoveryEnabled; private final QueryRecoveryManager queryRecoveryManager; + private final boolean asyncExecutionEnabled; private final SettableFuture initializationFuture = SettableFuture.create(); private final AtomicReference delegate = new AtomicReference<>(); @@ -400,7 +405,8 @@ public class ExchangeOperator ExchangeManagerRegistry exchangeManagerRegistry, String uniqueId, boolean recoveryEnabled, - QueryRecoveryManager queryRecoveryManager) + QueryRecoveryManager queryRecoveryManager, + boolean asyncExecutionEnabled) { this.taskId = requireNonNull(taskId, "taskId is null"); this.sourceId = requireNonNull(sourceId, "sourceId is null"); @@ -412,6 +418,7 @@ public class ExchangeOperator this.queryRecoveryManager = queryRecoveryManager; this.uniqueId = uniqueId; this.recoveryEnabled = recoveryEnabled; + this.asyncExecutionEnabled = asyncExecutionEnabled; } @Override diff --git a/presto-main/src/main/java/io/prestosql/operator/OperatorContext.java b/presto-main/src/main/java/io/prestosql/operator/OperatorContext.java index 98572b816eb481ae3a722ca885ae3cc5225dd5fd..f6be9bab89299eb6ae819d8bb9f056c13f38c0b6 100644 --- a/presto-main/src/main/java/io/prestosql/operator/OperatorContext.java +++ b/presto-main/src/main/java/io/prestosql/operator/OperatorContext.java @@ -125,6 +125,7 @@ public class OperatorContext private final boolean snapshotEnabled; private final boolean recoveryEnabled; private final RetryPolicy retryPolicy; + private final boolean taskRetryAsyncExecutionEnabled; public OperatorContext( int operatorId, @@ -148,9 +149,10 @@ public class OperatorContext this.operatorMemoryContext = requireNonNull(operatorMemoryContext, "operatorMemoryContext is null"); operatorMemoryContext.initializeLocalMemoryContexts(operatorType); - this.snapshotEnabled = SystemSessionProperties.isSnapshotEnabled(driverContext.getSession()); - this.recoveryEnabled = SystemSessionProperties.isRecoveryEnabled(driverContext.getSession()); + this.snapshotEnabled = SystemSessionProperties.isSnapshotEnabled(driverContext.getSession()) || (SystemSessionProperties.getRetryPolicy(driverContext.getSession()) == RetryPolicy.TASK_ASYNC); + this.recoveryEnabled = SystemSessionProperties.isRecoveryEnabled(driverContext.getSession()) || (SystemSessionProperties.getRetryPolicy(driverContext.getSession()) == RetryPolicy.TASK_ASYNC); this.retryPolicy = SystemSessionProperties.getRetryPolicy(driverContext.getSession()); + this.taskRetryAsyncExecutionEnabled = (SystemSessionProperties.getRetryPolicy(driverContext.getSession()) == RetryPolicy.TASK_ASYNC); } public int getOperatorId() @@ -826,6 +828,11 @@ public class OperatorContext return recoveryEnabled; } + public boolean isTaskRetryAsyncExecutionEnabled() + { + return taskRetryAsyncExecutionEnabled; + } + public RetryPolicy getRetryPolicy() { return retryPolicy; diff --git a/presto-main/src/main/java/io/prestosql/operator/TableScanOperator.java b/presto-main/src/main/java/io/prestosql/operator/TableScanOperator.java index 76e516cb75613906c44a3668faab6d286f0c3dcf..33688132741c60b7bcbf3859509302f1844f97a7 100644 --- a/presto-main/src/main/java/io/prestosql/operator/TableScanOperator.java +++ b/presto-main/src/main/java/io/prestosql/operator/TableScanOperator.java @@ -25,6 +25,7 @@ import 
io.prestosql.memory.context.LocalMemoryContext; import io.prestosql.memory.context.MemoryTrackingContext; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.Split; +import io.prestosql.snapshot.MarkerSplit; import io.prestosql.spi.Page; import io.prestosql.spi.PrestoException; import io.prestosql.spi.QueryId; @@ -293,6 +294,7 @@ public class TableScanOperator private static final ConcurrentMap reuseExchangeTableScanMappingIdUtilsMap = new ConcurrentHashMap<>(); private ReuseExchangeTableScanMappingIdState reuseExchangeTableScanMappingIdState; private ListenableFuture spillInProgress = immediateFuture(null); + private MarkerSplit markerSplit; public TableScanOperator( OperatorContext operatorContext, @@ -598,12 +600,17 @@ public class TableScanOperator public Supplier> addSplit(Split split) { requireNonNull(split, "split is null"); - checkState(this.split == null, "Table scan split already set"); + checkState(this.split == null || split.getConnectorSplit() instanceof MarkerSplit, "Table scan split already set"); if (finished) { return Optional::empty; } + if (split.getConnectorSplit() instanceof MarkerSplit) { + this.markerSplit = (MarkerSplit) split.getConnectorSplit(); + return Optional::empty; + } + this.split = split; Object splitInfo = split.getInfo(); @@ -694,9 +701,16 @@ public class TableScanOperator if (strategy.equals(REUSE_STRATEGY_CONSUMER)) { return getPage(); } - if (split == null) { + if (split == null && markerSplit == null) { return null; } + + Page page; + if (markerSplit != null) { + page = markerSplit.toMarkerPage(); + markerSplit = null; + return page; + } if (source == null) { if (isDcTable) { source = pageSourceProvider.createPageSource(operatorContext.getSession(), @@ -710,7 +724,7 @@ public class TableScanOperator } } - Page page = source.getNextPage(); + page = source.getNextPage(); // if pageSource.getCompletedPositionCount is present, get operator statistics from pageSource if (source.getCompletedPositionCount().isPresent()) { diff --git a/presto-main/src/main/java/io/prestosql/operator/TaskOutputOperator.java b/presto-main/src/main/java/io/prestosql/operator/TaskOutputOperator.java index b4410f6943fd84d976d2c24eb4b8d40a161fbe94..fdf36c5125e71487b1bacc64ad916dcebf811095 100644 --- a/presto-main/src/main/java/io/prestosql/operator/TaskOutputOperator.java +++ b/presto-main/src/main/java/io/prestosql/operator/TaskOutputOperator.java @@ -17,6 +17,7 @@ import com.google.common.util.concurrent.ListenableFuture; import io.hetu.core.transport.execution.buffer.PagesSerde; import io.hetu.core.transport.execution.buffer.SerializedPage; import io.prestosql.exchange.FileSystemExchangeConfig.DirectSerialisationType; +import io.prestosql.execution.buffer.HybridSpoolingBuffer; import io.prestosql.execution.buffer.OutputBuffer; import io.prestosql.snapshot.SingleInputSnapshotState; import io.prestosql.spi.Page; @@ -123,6 +124,8 @@ public class TaskOutputOperator private final SingleInputSnapshotState snapshotState; private final boolean isStage0; private final PagesSerde serde; + private final PagesSerde javaSerde; + private final PagesSerde kryoSerde; private boolean finished; public TaskOutputOperator(String id, OperatorContext operatorContext, OutputBuffer outputBuffer, Function pagePreprocessor) @@ -132,8 +135,19 @@ public class TaskOutputOperator this.outputBuffer = requireNonNull(outputBuffer, "outputBuffer is null"); this.pagePreprocessor = requireNonNull(pagePreprocessor, "pagePreprocessor is null"); this.serde = 
requireNonNull(operatorContext.getDriverContext().getSerde(), "serde is null"); + this.javaSerde = requireNonNull(operatorContext.getDriverContext().getJavaSerde(), "javaSerde is null"); + this.kryoSerde = requireNonNull(operatorContext.getDriverContext().getKryoSerde(), "kryoSerde is null"); + this.outputBuffer.setSerde(serde); + this.outputBuffer.setJavaSerde(javaSerde); + this.outputBuffer.setKryoSerde(kryoSerde); this.snapshotState = operatorContext.isSnapshotEnabled() ? SingleInputSnapshotState.forOperator(this, operatorContext) : null; this.isStage0 = operatorContext.getDriverContext().getPipelineContext().getTaskContext().getTaskId().getStageId().getId() == 0; + + // when async exec is enabled + if (this.snapshotState != null && this.outputBuffer.isSpoolingDelegateAvailable()) { + HybridSpoolingBuffer hybridSpoolingBuffer = (HybridSpoolingBuffer) this.outputBuffer.getSpoolingDelegate(); + this.snapshotState.setHybridSpoolingBuffer(hybridSpoolingBuffer); + } } @Override @@ -192,6 +206,7 @@ public class TaskOutputOperator } DirectSerialisationType serialisationType = outputBuffer.getExchangeDirectSerialisationType(); + DirectSerialisationType spoolingSerialisationType = outputBuffer.getDelegateSpoolingExchangeDirectSerializationType(); if (outputBuffer.isSpoolingOutputBuffer() && serialisationType != DirectSerialisationType.OFF) { PagesSerde directSerde = (serialisationType == DirectSerialisationType.JAVA) ? operatorContext.getDriverContext().getJavaSerde() : operatorContext.getDriverContext().getKryoSerde(); List pages = splitPage(inputPage, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); @@ -215,6 +230,21 @@ public class TaskOutputOperator } } else { + if (outputBuffer.isSpoolingDelegateAvailable()) { + OutputBuffer spoolingBuffer = outputBuffer.getSpoolingDelegate(); + if (spoolingSerialisationType != DirectSerialisationType.OFF) { + PagesSerde directSerde = (spoolingSerialisationType == DirectSerialisationType.JAVA) ? 
operatorContext.getDriverContext().getJavaSerde() : operatorContext.getDriverContext().getKryoSerde(); + List pages = splitPage(inputPage, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + if (spoolingBuffer != null) { + spoolingBuffer.enqueuePages(0, pages, id, directSerde); + } + } + else { + if (spoolingBuffer != null) { + spoolingBuffer.enqueue(serializedPages, id); + } + } + } outputBuffer.enqueue(serializedPages, id); } } diff --git a/presto-main/src/main/java/io/prestosql/operator/output/PagePartitioner.java b/presto-main/src/main/java/io/prestosql/operator/output/PagePartitioner.java index bbc68e1c8604c20a298fcca1a8aaed6acfe81e13..2debb28414a5f78228b47cfd32eba17e9957ea2e 100644 --- a/presto-main/src/main/java/io/prestosql/operator/output/PagePartitioner.java +++ b/presto-main/src/main/java/io/prestosql/operator/output/PagePartitioner.java @@ -99,6 +99,9 @@ public class PagePartitioner this.outputBuffer = requireNonNull(outputBuffer, "outputBuffer is null"); this.sourceTypes = requireNonNull(sourceTypes, "sourceTypes is null").toArray(new Type[0]); this.operatorContext = requireNonNull(operatorContext, "serde is null"); + this.outputBuffer.setSerde(requireNonNull(operatorContext.getDriverContext().getSerde(), "serde is null")); + this.outputBuffer.setJavaSerde(requireNonNull(operatorContext.getDriverContext().getJavaSerde(), "java serde is null")); + this.outputBuffer.setKryoSerde(requireNonNull(operatorContext.getDriverContext().getKryoSerde(), "kryo serde is null")); int partitionCount = partitionFunction.getPartitionCount(); int pageSize = min(DEFAULT_MAX_PAGE_SIZE_IN_BYTES, ((int) maxMemory.toBytes()) / partitionCount); @@ -456,6 +459,7 @@ public class PagePartitioner partitionPageBuilder.reset(); FileSystemExchangeConfig.DirectSerialisationType serialisationType = outputBuffer.getExchangeDirectSerialisationType(); + FileSystemExchangeConfig.DirectSerialisationType spoolingSerialisationType = outputBuffer.getDelegateSpoolingExchangeDirectSerializationType(); if (outputBuffer.isSpoolingOutputBuffer() && serialisationType != FileSystemExchangeConfig.DirectSerialisationType.OFF) { PagesSerde directSerde = (serialisationType == FileSystemExchangeConfig.DirectSerialisationType.JAVA) ? operatorContext.getDriverContext().getJavaSerde() : operatorContext.getDriverContext().getKryoSerde(); List pages = splitPage(pagePartition, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); @@ -466,6 +470,21 @@ public class PagePartitioner .map(page -> operatorContext.getDriverContext().getSerde().serialize(page)) .collect(toImmutableList()); + if (outputBuffer.isSpoolingDelegateAvailable()) { + OutputBuffer spoolingBuffer = outputBuffer.getSpoolingDelegate(); + if (spoolingSerialisationType != FileSystemExchangeConfig.DirectSerialisationType.OFF) { + PagesSerde directSerde = (spoolingSerialisationType == FileSystemExchangeConfig.DirectSerialisationType.JAVA) ? 
operatorContext.getDriverContext().getJavaSerde() : operatorContext.getDriverContext().getKryoSerde(); + List pages = splitPage(pagePartition, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + if (spoolingBuffer != null) { + spoolingBuffer.enqueuePages(partition, pages, id, directSerde); + } + } + else { + if (spoolingBuffer != null) { + spoolingBuffer.enqueue(partition, serializedPages, id); + } + } + } outputBuffer.enqueue(partition, serializedPages, id); } pagesAdded.incrementAndGet(); diff --git a/presto-main/src/main/java/io/prestosql/snapshot/MultiInputSnapshotState.java b/presto-main/src/main/java/io/prestosql/snapshot/MultiInputSnapshotState.java index 8e57e1353fb757f6b4be491a545326ce83849038..b6d787ba17b37f57173e2f022e455e3efcb70a23 100644 --- a/presto-main/src/main/java/io/prestosql/snapshot/MultiInputSnapshotState.java +++ b/presto-main/src/main/java/io/prestosql/snapshot/MultiInputSnapshotState.java @@ -19,6 +19,9 @@ import com.google.common.collect.Iterators; import io.airlift.log.Logger; import io.hetu.core.transport.execution.buffer.PagesSerde; import io.hetu.core.transport.execution.buffer.SerializedPage; +import io.prestosql.Session; +import io.prestosql.SystemSessionProperties; +import io.prestosql.exchange.RetryPolicy; import io.prestosql.operator.OperatorContext; import io.prestosql.operator.TaskContext; import io.prestosql.spi.Page; @@ -69,6 +72,7 @@ public class MultiInputSnapshotState // For an operator, this is typically /query-id/snapshot-id/stage-id/task-id/pipeline-id/driver-id/operator-id // For a task component, this is typically /query-id/snapshot-id/stage-id/task-id/component-id private final Function snapshotStateIdGenerator; + private final Session session; // All input channels, if known private Optional> inputChannels; @@ -90,7 +94,8 @@ public class MultiInputSnapshotState return new MultiInputSnapshotState(restorable, operatorContext.getDriverContext().getPipelineContext().getTaskContext().getSnapshotManager(), operatorContext.getDriverContext().getSerde(), - snapshotId -> SnapshotStateId.forOperator(snapshotId, operatorContext)); + snapshotId -> SnapshotStateId.forOperator(snapshotId, operatorContext), + operatorContext.getSession()); } /** @@ -105,7 +110,8 @@ public class MultiInputSnapshotState return new MultiInputSnapshotState(restorable, taskContext.getSnapshotManager(), taskContext.getSerdeFactory().createPagesSerde(), - snapshotStateIdGenerator); + snapshotStateIdGenerator, + taskContext.getSession()); } /** @@ -120,13 +126,15 @@ public class MultiInputSnapshotState MultiInputRestorable restorable, TaskSnapshotManager snapshotManager, PagesSerde serde, - Function snapshotStateIdGenerator) + Function snapshotStateIdGenerator, + Session session) { this.restorable = restorable; this.restorableId = String.format("%s (%s)", restorable.getClass().getSimpleName(), snapshotStateIdGenerator.apply(0L).getId()); this.snapshotManager = snapshotManager; this.pagesSerde = serde; this.snapshotStateIdGenerator = snapshotStateIdGenerator; + this.session = session; } /** @@ -389,18 +397,20 @@ public class MultiInputSnapshotState long snapshotId = marker.getSnapshotId(); SnapshotState snapshot = snapshotStateById(marker); - if (snapshot == null) { + if (snapshot == null && (SystemSessionProperties.getRetryPolicy(session) != RetryPolicy.TASK_ASYNC || !pendingMarkers.contains(marker))) { // First time a marker is received for a snapshot. Store operator state and make marker available. 
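Stepping back to the TaskOutputOperator and PagePartitioner hunks above, both follow the same fan-out: when a hybrid spooling delegate is attached, each page batch is first mirrored to it (direct-serialized or already serialized, depending on the delegate's serialisation type) and then enqueued to the regular output buffer unchanged. A hedged sketch of that pattern, with a simplified Buffer interface standing in for OutputBuffer:

    import java.util.Arrays;
    import java.util.List;

    // Sketch only: simplified stand-in types; the real code works with OutputBuffer and SerializedPage.
    final class SpoolingFanOutSketch
    {
        interface Buffer
        {
            void enqueue(int partition, List<String> serializedPages);
        }

        static void enqueueWithMirror(Buffer primary, Buffer spoolingDelegate, int partition, List<String> serializedPages)
        {
            if (spoolingDelegate != null) {
                // mirrored copy backs marker/snapshot replay for async task retry
                spoolingDelegate.enqueue(partition, serializedPages);
            }
            // the regular exchange path is unchanged
            primary.enqueue(partition, serializedPages);
        }

        public static void main(String[] args)
        {
            Buffer primary = (partition, pages) -> System.out.println("primary  <- " + pages);
            Buffer spooling = (partition, pages) -> System.out.println("spooling <- " + pages);
            enqueueWithMirror(primary, spooling, 0, Arrays.asList("page-0", "page-1"));
        }
    }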
- snapshot = new SnapshotState(marker); pendingMarkers.add(marker); - try { - snapshot.addState(restorable, pagesSerde); - } - catch (Exception e) { - LOG.warn(e, "Failed to capture and store snapshot state"); - snapshotManager.failedToCapture(snapshotStateIdGenerator.apply(snapshotId)); + if (SystemSessionProperties.getRetryPolicy(session) != RetryPolicy.TASK_ASYNC || !snapshotManager.isCoordinatorTask()) { + snapshot = new SnapshotState(marker); + try { + snapshot.addState(restorable, pagesSerde); + } + catch (Exception e) { + LOG.warn(e, "Failed to capture and store snapshot state"); + snapshotManager.failedToCapture(snapshotStateIdGenerator.apply(snapshotId)); + } + states.add(snapshot); } - states.add(snapshot); } else { // Seen marker with same snapshot id. Don't need to pass on this marker. @@ -408,44 +418,46 @@ public class MultiInputSnapshotState } // No longer need to capture inputs on this channel - if (!snapshot.markedChannels.add(channel)) { + if (snapshot != null && !snapshot.markedChannels.add(channel)) { String message = String.format("Received duplicate marker '%s' from source '%s' to target '%s'", marker.toString(), channel, restorableId); LOG.error(message); return true; } - inputChannels = restorable.getInputChannels(); - if (inputChannels.isPresent()) { - checkState(inputChannels.get().containsAll(snapshot.markedChannels)); - if (inputChannels.get().size() == snapshot.markedChannels.size()) { - // Received marker from all channels. Operator state is complete. - SnapshotStateId componentId = snapshotStateIdGenerator.apply(snapshotId); - try { - if (restorable.supportsConsolidatedWrites()) { - snapshotManager.storeConsolidatedState(componentId, snapshot.states, snapshot.serTime); - } - else { - snapshotManager.storeState(componentId, snapshot.states, snapshot.serTime); + if (snapshot != null && (SystemSessionProperties.getRetryPolicy(session) != RetryPolicy.TASK_ASYNC || !snapshotManager.isCoordinatorTask())) { + inputChannels = restorable.getInputChannels(); + if (inputChannels.isPresent()) { + checkState(inputChannels.get().containsAll(snapshot.markedChannels)); + if (inputChannels.get().size() == snapshot.markedChannels.size()) { + // Received marker from all channels. Operator state is complete. 
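The completion rule the rewritten block above preserves is the usual marker alignment: a snapshot is complete once every input channel has delivered its marker, and any older snapshot still pending at that point can no longer complete and is abandoned. A self-contained sketch of that rule with simplified types (the real code also distinguishes capture failures from restore failures):

    import java.util.HashSet;
    import java.util.Iterator;
    import java.util.LinkedHashMap;
    import java.util.Map;
    import java.util.Set;

    // Sketch only: marker alignment across input channels, simplified.
    final class MarkerAlignmentSketch
    {
        // pending snapshots in arrival order: snapshotId -> channels that have delivered the marker
        private final Map<Long, Set<String>> pending = new LinkedHashMap<>();
        private final Set<String> allChannels;

        MarkerAlignmentSketch(Set<String> allChannels)
        {
            this.allChannels = allChannels;
        }

        // Returns true when this marker completes the snapshot.
        boolean onMarker(long snapshotId, String channel)
        {
            Set<String> marked = pending.computeIfAbsent(snapshotId, id -> new HashSet<>());
            marked.add(channel);
            if (!marked.containsAll(allChannels)) {
                return false;
            }
            // complete: drop this snapshot and abandon every older one still pending,
            // since those can no longer receive markers from all channels
            Iterator<Long> ids = pending.keySet().iterator();
            while (ids.hasNext()) {
                long id = ids.next();
                ids.remove();
                if (id == snapshotId) {
                    break;
                }
                // in the real code these are reported via snapshotManager.failedToCapture(...)
            }
            return true;
        }

        public static void main(String[] args)
        {
            MarkerAlignmentSketch state = new MarkerAlignmentSketch(new HashSet<>(java.util.Arrays.asList("a", "b")));
            System.out.println(state.onMarker(1, "a")); // false: still waiting for channel b
            System.out.println(state.onMarker(2, "a")); // false
            System.out.println(state.onMarker(2, "b")); // true: snapshot 2 complete, snapshot 1 abandoned
        }
    }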
+ SnapshotStateId componentId = snapshotStateIdGenerator.apply(snapshotId); + try { + if (restorable.supportsConsolidatedWrites()) { + snapshotManager.storeConsolidatedState(componentId, snapshot.states, snapshot.serTime); + } + else { + snapshotManager.storeState(componentId, snapshot.states, snapshot.serTime); + } + snapshotManager.succeededToCapture(componentId); + LOG.debug("Successfully saved state to snapshot %d for %s", snapshotId, restorableId); } - snapshotManager.succeededToCapture(componentId); - LOG.debug("Successfully saved state to snapshot %d for %s", snapshotId, restorableId); - } - catch (Exception e) { - LOG.warn(e, "Failed to capture and store snapshot state"); - snapshotManager.failedToCapture(componentId); - } - int index = states.indexOf(snapshot); - states.remove(index); - for (int i = 0; i < index; i++) { - // All previous pending snapshots can't be complete - SnapshotState failedState = states.remove(0); - componentId = snapshotStateIdGenerator.apply(failedState.snapshotId); - if (failedState.resuming) { - snapshotManager.failedToRestore(componentId, false); - } - else { + catch (Exception e) { + LOG.warn(e, "Failed to capture and store snapshot state"); snapshotManager.failedToCapture(componentId); } + int index = states.indexOf(snapshot); + states.remove(index); + for (int i = 0; i < index; i++) { + // All previous pending snapshots can't be complete + SnapshotState failedState = states.remove(0); + componentId = snapshotStateIdGenerator.apply(failedState.snapshotId); + if (failedState.resuming) { + snapshotManager.failedToRestore(componentId, false); + } + else { + snapshotManager.failedToCapture(componentId); + } + } } } } diff --git a/presto-main/src/main/java/io/prestosql/snapshot/SingleInputSnapshotState.java b/presto-main/src/main/java/io/prestosql/snapshot/SingleInputSnapshotState.java index 44f5ebef1fdc6c8f1fcee06190537932c84c22e5..9bc7240c1f9214665c363d9e69c3a787821d0581 100644 --- a/presto-main/src/main/java/io/prestosql/snapshot/SingleInputSnapshotState.java +++ b/presto-main/src/main/java/io/prestosql/snapshot/SingleInputSnapshotState.java @@ -17,6 +17,10 @@ package io.prestosql.snapshot; import com.google.common.base.Stopwatch; import io.airlift.log.Logger; import io.hetu.core.transport.execution.buffer.PagesSerde; +import io.prestosql.Session; +import io.prestosql.SystemSessionProperties; +import io.prestosql.exchange.RetryPolicy; +import io.prestosql.execution.buffer.HybridSpoolingBuffer; import io.prestosql.memory.context.LocalMemoryContext; import io.prestosql.operator.Operator; import io.prestosql.operator.OperatorContext; @@ -68,6 +72,7 @@ public class SingleInputSnapshotState Map> snapshotSpillPaths = new LinkedHashMap<>(); private final boolean isEliminateDuplicateSpillFilesEnabled; long lastSnapshotId = -1; + private final Session session; public static SingleInputSnapshotState forOperator(Operator operator, OperatorContext operatorContext) { @@ -78,7 +83,8 @@ public class SingleInputSnapshotState snapshotId -> SnapshotStateId.forOperator(snapshotId, operatorContext), snapshotId -> SnapshotStateId.forDriverComponent(snapshotId, operatorContext, operatorContext.getOperatorId() + "-spill"), operatorContext.newLocalUserMemoryContext(SingleInputSnapshotState.class.getSimpleName()), - isEliminateDuplicateSpillFilesEnabled(operatorContext.getDriverContext().getSession())); + isEliminateDuplicateSpillFilesEnabled(operatorContext.getDriverContext().getSession()), + operatorContext.getDriverContext().getSession()); } 
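Both MultiInputSnapshotState and SingleInputSnapshotState now carry the Session for a single decision, visible in the hunks above and below: under RetryPolicy.TASK_ASYNC a coordinator task only forwards markers downstream and skips capturing or storing operator state, while worker tasks and all other retry policies behave as before. A minimal sketch of that predicate (class and method names are illustrative):

    // Sketch only: the capture decision the surrounding hunks implement inline.
    final class MarkerCaptureDecisionSketch
    {
        enum RetryPolicy { NONE, TASK, TASK_ASYNC }

        // State is captured unless this is a coordinator task running under async task retry.
        static boolean shouldCaptureState(RetryPolicy retryPolicy, boolean isCoordinatorTask)
        {
            return retryPolicy != RetryPolicy.TASK_ASYNC || !isCoordinatorTask;
        }

        public static void main(String[] args)
        {
            System.out.println(shouldCaptureState(RetryPolicy.TASK_ASYNC, true));  // false: only forward the marker
            System.out.println(shouldCaptureState(RetryPolicy.TASK_ASYNC, false)); // true: worker tasks still capture
            System.out.println(shouldCaptureState(RetryPolicy.TASK, true));        // true: unchanged for other policies
        }
    }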
SingleInputSnapshotState(Restorable restorable, @@ -87,7 +93,8 @@ public class SingleInputSnapshotState Function snapshotStateIdGenerator, Function spillStateIdGenerator, LocalMemoryContext snapshotMemoryContext, - boolean isEliminateDuplicateSpillFilesEnabled) + boolean isEliminateDuplicateSpillFilesEnabled, + Session session) { this.restorable = requireNonNull(restorable, "restorable is null"); this.restorableId = String.format("%s (%s)", restorable.getClass().getSimpleName(), snapshotStateIdGenerator.apply(0L).getId()); @@ -97,6 +104,7 @@ public class SingleInputSnapshotState this.pagesSerde = pagesSerde; this.snapshotMemoryContext = snapshotMemoryContext; this.isEliminateDuplicateSpillFilesEnabled = isEliminateDuplicateSpillFilesEnabled; + this.session = session; } public void close() @@ -104,6 +112,15 @@ public class SingleInputSnapshotState snapshotMemoryContext.close(); } + private boolean needToCaptureSnapshot() + { + boolean result = true; + if ((SystemSessionProperties.getRetryPolicy(session) == RetryPolicy.TASK_ASYNC) && snapshotManager.isCoordinatorTask()) { + result = false; + } + return result; + } + /** * Perform marker and snapshot related processing on an incoming input * @@ -117,6 +134,12 @@ public class SingleInputSnapshotState } MarkerPage marker = (MarkerPage) input; + if (!needToCaptureSnapshot()) { + // ensure marker is propagated and ignore capturing snapshot + markers.add(marker); + return true; + } + long snapshotId = marker.getSnapshotId(); SnapshotStateId componentId = snapshotStateIdGenerator.apply(snapshotId); if (marker.isResuming()) { @@ -352,4 +375,9 @@ public class SingleInputSnapshotState } return true; } + + public void setHybridSpoolingBuffer(HybridSpoolingBuffer hybridSpoolingBuffer) + { + this.snapshotManager.setHybridSpoolingBuffer(hybridSpoolingBuffer); + } } diff --git a/presto-main/src/main/java/io/prestosql/snapshot/TaskSnapshotManager.java b/presto-main/src/main/java/io/prestosql/snapshot/TaskSnapshotManager.java index 94bf0884a8f69600e9eecf70ae793a8ce03c02b1..beb1fcdfd2ad1a68f4ee36d794ee5c225b480c84 100644 --- a/presto-main/src/main/java/io/prestosql/snapshot/TaskSnapshotManager.java +++ b/presto-main/src/main/java/io/prestosql/snapshot/TaskSnapshotManager.java @@ -16,7 +16,11 @@ package io.prestosql.snapshot; import com.google.common.collect.ImmutableMap; import io.airlift.log.Logger; +import io.prestosql.Session; +import io.prestosql.SystemSessionProperties; +import io.prestosql.exchange.RetryPolicy; import io.prestosql.execution.TaskId; +import io.prestosql.execution.buffer.HybridSpoolingBuffer; import io.prestosql.operator.Operator; import io.prestosql.operator.exchange.LocalMergeSourceOperator; @@ -64,12 +68,15 @@ public class TaskSnapshotManager private final Map> loadCache = Collections.synchronizedMap(new HashMap<>()); private Set createdConsolidatedFiles; + private HybridSpoolingBuffer hybridSpoolingBuffer; + private final Session session; - public TaskSnapshotManager(TaskId taskId, long resumeCount, RecoveryUtils recoveryUtils) + public TaskSnapshotManager(TaskId taskId, long resumeCount, RecoveryUtils recoveryUtils, Session session) { this.taskId = taskId; this.resumeCount = resumeCount; this.recoveryUtils = recoveryUtils; + this.session = session; } public long getResumeCount() @@ -95,7 +102,12 @@ public class TaskSnapshotManager public void storeConsolidatedState(SnapshotStateId snapshotStateId, Object state, long serCpuTime) { Map map = storeCache.computeIfAbsent(snapshotStateId.getSnapshotId(), (x) -> 
Collections.synchronizedMap(new HashMap<>())); - map.put(snapshotStateId.toString(), state); + if (hybridSpoolingBuffer == null) { + map.put(snapshotStateId.toString(), state); + } + else { + map.put(createConsolidatedId(snapshotStateId.getSnapshotId(), snapshotStateId.getTaskId()).toString(), state); + } updateSnapshotCaptureCpuTime(snapshotStateId.getSnapshotId(), serCpuTime); } @@ -105,11 +117,23 @@ public class TaskSnapshotManager public void storeState(SnapshotStateId snapshotStateId, Object state, long serCpuTime) throws Exception { - recoveryUtils.storeState(snapshotStateId, state, this); + if (hybridSpoolingBuffer == null) { + recoveryUtils.storeState(snapshotStateId, state, this); + } + else { + Map markerInfoMap = new HashMap<>(); + markerInfoMap.put(snapshotStateId.toString(), state); + hybridSpoolingBuffer.enqueueMarkerInfo(snapshotStateId.getSnapshotId(), markerInfoMap); + } // store dummy value Map map = storeCache.computeIfAbsent(snapshotStateId.getSnapshotId(), (x) -> Collections.synchronizedMap(new HashMap<>())); - map.put(snapshotStateId.toString(), snapshotStateId.toString()); + if (hybridSpoolingBuffer == null) { + map.put(snapshotStateId.toString(), snapshotStateId.toString()); + } + else { + map.put(createConsolidatedId(snapshotStateId.getSnapshotId(), snapshotStateId.getTaskId()).toString(), snapshotStateId.toString()); + } updateSnapshotCaptureCpuTime(snapshotStateId.getSnapshotId(), serCpuTime); } @@ -122,17 +146,27 @@ public class TaskSnapshotManager if (!loadCache.containsKey(snapshotId)) { String queryId = taskId.getQueryId().getId(); SnapshotStateId stateId = createConsolidatedId(snapshotId, taskId); - Optional loadedState = recoveryUtils.loadState(stateId, this); - if (createdConsolidatedFiles == null) { - createdConsolidatedFiles = recoveryUtils.loadConsolidatedFiles(queryId); + Optional loadedState; + Object map; + if (hybridSpoolingBuffer == null) { + loadedState = recoveryUtils.loadState(stateId, this); + if (createdConsolidatedFiles == null) { + createdConsolidatedFiles = recoveryUtils.loadConsolidatedFiles(queryId); + } + // if it is still null after loading, that means it is deleted, and we need to fail + if (createdConsolidatedFiles == null || (createdConsolidatedFiles.contains(stateId.toString()) && !loadedState.isPresent())) { + // we created the consolidated file, but it has been deleted. non-recoverable failure + failedToRestore(stateId, true); + // continue so that the failure can be detected + } + map = loadedState.orElse(Collections.emptyMap()); } - // if it is still null after loading, that means it is deleted, and we need to fail - if (createdConsolidatedFiles == null || (createdConsolidatedFiles.contains(stateId.toString()) && !loadedState.isPresent())) { - // we created the consolidated file, but it has been deleted. non-recoverable failure - failedToRestore(stateId, true); - // continue so that the failure can be detected + else { + Map snapshots = hybridSpoolingBuffer.dequeueMarkerInfo(snapshotId); + Map resultMap = new HashMap(); + resultMap.put(stateId.toString(), snapshots.get(stateId.toString())); + map = (Object) resultMap; } - Object map = loadedState.orElse(Collections.emptyMap()); loadCache.put(snapshotId, (Map) map); } } @@ -151,7 +185,12 @@ public class TaskSnapshotManager // Need to check previous snapshots for their stored states. 
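The storeState and storeConsolidatedState changes above introduce a second storage path: when a hybrid spooling buffer is attached, captured state is written into the marker data file (enqueueMarkerInfo) keyed by the consolidated id, instead of going through the recovery file system. A simplified sketch of that routing; MarkerStore and FileStore are illustrative stand-ins for HybridSpoolingBuffer and RecoveryUtils:

    import java.util.HashMap;
    import java.util.Map;

    // Sketch only: store-side routing between the recovery file system and the marker data file.
    final class SnapshotStoreRoutingSketch
    {
        interface MarkerStore { void enqueueMarkerInfo(long snapshotId, Map<String, Object> states); }
        interface FileStore { void storeState(String stateId, Object state); }

        private final MarkerStore spoolingBuffer; // null when hybrid spooling is not in use
        private final FileStore recoveryStore;

        SnapshotStoreRoutingSketch(MarkerStore spoolingBuffer, FileStore recoveryStore)
        {
            this.spoolingBuffer = spoolingBuffer;
            this.recoveryStore = recoveryStore;
        }

        void storeState(long snapshotId, String stateId, Object state)
        {
            if (spoolingBuffer == null) {
                recoveryStore.storeState(stateId, state);
                return;
            }
            Map<String, Object> markerInfo = new HashMap<>();
            markerInfo.put(stateId, state);
            // state travels with the exchange data instead of the recovery file system
            spoolingBuffer.enqueueMarkerInfo(snapshotId, markerInfo);
        }

        public static void main(String[] args)
        {
            FileStore files = (stateId, state) -> System.out.println("recovery store   <- " + stateId);
            MarkerStore markers = (snapshotId, states) -> System.out.println("marker data file <- snapshot " + snapshotId + " " + states.keySet());

            new SnapshotStoreRoutingSketch(null, files).storeState(3, "task-1/op-2", "state");
            new SnapshotStoreRoutingSketch(markers, files).storeState(3, "task-1/op-2", "state");
        }
    }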
Optional state; loadMapIfNecessary(snapshotId, snapshotStateIdTaskId); - state = Optional.ofNullable(loadCache.get(snapshotId).get(newSnapshotStateId.toString())); + if (hybridSpoolingBuffer == null) { + state = Optional.ofNullable(loadCache.get(snapshotId).get(newSnapshotStateId.toString())); + } + else { + state = Optional.ofNullable(loadCache.get(snapshotId).get(createConsolidatedId(snapshotId, snapshotStateIdTaskId).toString())); + } Map snapshotToSnapshotResultMap = null; while (!state.isPresent()) { // Snapshot is complete but no entry for this id, then the component must have finished @@ -195,7 +234,13 @@ public class TaskSnapshotManager { Optional loadedValue = loadWithBacktrack(snapshotStateId); if (loadedValue.isPresent() && loadedValue.get() != NO_STATE) { - return recoveryUtils.loadState(SnapshotStateId.fromString((String) loadedValue.get()), this); + if (hybridSpoolingBuffer == null) { + return recoveryUtils.loadState(SnapshotStateId.fromString((String) loadedValue.get()), this); + } + else { + Map states = hybridSpoolingBuffer.dequeueMarkerInfo(snapshotStateId.getSnapshotId()); + return Optional.of(states.get(snapshotStateId.toString())); + } } return loadedValue; } @@ -349,7 +394,12 @@ public class TaskSnapshotManager else { map = Collections.emptyMap(); } - recoveryUtils.storeState(newId, map, this); + if (hybridSpoolingBuffer == null) { + recoveryUtils.storeState(newId, map, this); + } + else { + hybridSpoolingBuffer.enqueueMarkerInfo(newId.getSnapshotId(), map); + } } catch (Exception e) { LOG.error(e, "Failed to store state for " + newId); @@ -357,7 +407,7 @@ public class TaskSnapshotManager updateSnapshotStatus(snapshotId, snapshotResult); } } - if (recoveryUtils.isCoordinator()) { + if (recoveryUtils.isCoordinator() && (SystemSessionProperties.getRetryPolicy(session) != RetryPolicy.TASK_ASYNC)) { // Results on coordinator won't be reported through remote task. Send to the query side. 
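The load path mirrors the store routing sketched above: with a hybrid spooling buffer attached, loadWithBacktrack and loadConsolidatedState read state back from the marker data file (dequeueMarkerInfo) under the same consolidated key, and otherwise fall back to the recovery store as before. A matching sketch of the read side, reusing the same illustrative stand-in types:

    import java.util.Collections;
    import java.util.Map;
    import java.util.Optional;

    // Sketch only: load-side counterpart of the store routing shown earlier.
    final class SnapshotLoadRoutingSketch
    {
        interface MarkerStore { Map<String, Object> dequeueMarkerInfo(long snapshotId); }
        interface FileStore { Optional<Object> loadState(String stateId); }

        static Optional<Object> loadState(MarkerStore spoolingBuffer, FileStore recoveryStore, long snapshotId, String stateId)
        {
            if (spoolingBuffer == null) {
                return recoveryStore.loadState(stateId);
            }
            // the key must match what the store path wrote for this snapshot
            return Optional.ofNullable(spoolingBuffer.dequeueMarkerInfo(snapshotId).get(stateId));
        }

        public static void main(String[] args)
        {
            FileStore files = stateId -> Optional.of((Object) "from recovery store");
            MarkerStore markers = snapshotId -> Collections.<String, Object>singletonMap("task-1/op-2", "from marker file");

            System.out.println(loadState(null, files, 3, "task-1/op-2"));    // Optional[from recovery store]
            System.out.println(loadState(markers, files, 3, "task-1/op-2")); // Optional[from marker file]
        }
    }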
QuerySnapshotManager querySnapshotManager = recoveryUtils.getQuerySnapshotManager(componentIdTaskId.getQueryId()); if (querySnapshotManager != null) { @@ -501,4 +551,14 @@ public class TaskSnapshotManager return recoveryUtils.loadSpilledPathInfo(SnapshotStateId.fromString((String) loadedValue.get())); } } + + public void setHybridSpoolingBuffer(HybridSpoolingBuffer hybridSpoolingBuffer) + { + this.hybridSpoolingBuffer = hybridSpoolingBuffer; + } + + public boolean isCoordinatorTask() + { + return recoveryUtils.isCoordinator(); + } } diff --git a/presto-main/src/main/java/io/prestosql/split/SplitManager.java b/presto-main/src/main/java/io/prestosql/split/SplitManager.java index 4270c23c57bcc954c83a8b2b0d4bb3e578f742fb..eb99665bf75372dc61e841fc53cd9f9ef6de350a 100644 --- a/presto-main/src/main/java/io/prestosql/split/SplitManager.java +++ b/presto-main/src/main/java/io/prestosql/split/SplitManager.java @@ -17,6 +17,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import io.prestosql.Session; import io.prestosql.SystemSessionProperties; +import io.prestosql.exchange.RetryPolicy; import io.prestosql.execution.QueryManagerConfig; import io.prestosql.metadata.Metadata; import io.prestosql.snapshot.MarkerAnnouncer; @@ -94,7 +95,8 @@ public class SplitManager boolean partOfReuse, PlanNodeId nodeId) { MarkerAnnouncer announcer = null; - if (SystemSessionProperties.isRecoveryEnabled(session)) { + if (SystemSessionProperties.isRecoveryEnabled(session) + || SystemSessionProperties.getRetryPolicy(session) == RetryPolicy.TASK_ASYNC) { announcer = getMarkerAnnouncer(session); SplitSource splitSource = announcer.getSplitSource(nodeId); if (splitSource != null) { @@ -125,7 +127,8 @@ public class SplitManager if (minScheduleSplitBatchSize > 1) { splitSource = new BufferingSplitSource(splitSource, minScheduleSplitBatchSize); } - if (SystemSessionProperties.isRecoveryEnabled(session)) { + if (SystemSessionProperties.isRecoveryEnabled(session) + || SystemSessionProperties.getRetryPolicy(session) == RetryPolicy.TASK_ASYNC) { splitSource = announcer.createMarkerSplitSource(splitSource, nodeId); } return splitSource; @@ -146,7 +149,8 @@ public class SplitManager public MarkerAnnouncer getMarkerAnnouncer(Session session) { return announcers.computeIfAbsent(session.getQueryId(), queryId -> { - if (!SystemSessionProperties.isSnapshotEnabled(session)) { + if (!SystemSessionProperties.isSnapshotEnabled(session) + && SystemSessionProperties.getRetryPolicy(session) != RetryPolicy.TASK_ASYNC) { return new MarkerAnnouncer(); } else if (SystemSessionProperties.getSnapshotIntervalType(session) == RecoveryConfig.IntervalType.TIME) { diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java index 598a1b9ab9402e0d4b61f7fcd19bf5ab8fabbc2c..85dfba4073b31c53483f9fee9985f2cf7a94f31e 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java @@ -36,6 +36,7 @@ import io.prestosql.cache.elements.CachedDataStorage; import io.prestosql.cube.CubeManager; import io.prestosql.dynamicfilter.DynamicFilterCacheManager; import io.prestosql.exchange.ExchangeManagerRegistry; +import io.prestosql.exchange.RetryPolicy; import io.prestosql.execution.ExplainAnalyzeContext; import io.prestosql.execution.StageId; import io.prestosql.execution.TableExecuteContextManager; @@ 
-762,7 +763,7 @@ public class LocalExecutionPlanner .forEach(LocalPlannerAware::localPlannerComplete); // calculate total number of components to be captured and add to snapshotManager - if (SystemSessionProperties.isSnapshotEnabled(session)) { + if (SystemSessionProperties.isSnapshotEnabled(session) || (SystemSessionProperties.getRetryPolicy(session) == RetryPolicy.TASK_ASYNC)) { taskContext.getSnapshotManager().setTotalComponents(calculateTotalCountOfTaskComponentToBeCaptured(taskContext, context, outputBuffer)); } @@ -1922,7 +1923,8 @@ public class LocalExecutionPlanner List pages; int from = 0; - if (SystemSessionProperties.isSnapshotEnabled(context.taskContext.getSession())) { + if (SystemSessionProperties.isSnapshotEnabled(context.taskContext.getSession()) + || SystemSessionProperties.getRetryPolicy(context.taskContext.getSession()) == RetryPolicy.TASK_ASYNC) { ImmutableList.Builder builder = ImmutableList.builder(); // Always include the data page, for debugging purposes Page page = valuesPage(node, context); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/AddExchangeAboveCTENode.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/AddExchangeAboveCTENode.java index 366b60415afa6b30e6600ec3f5f4a0b1c112de33..f321ea20e80efd13b4beccd53464bce0cc76689c 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/AddExchangeAboveCTENode.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/AddExchangeAboveCTENode.java @@ -16,6 +16,8 @@ package io.prestosql.sql.planner.iterative.rule; import com.google.common.collect.ImmutableList; import io.prestosql.Session; +import io.prestosql.SystemSessionProperties; +import io.prestosql.exchange.RetryPolicy; import io.prestosql.matching.Captures; import io.prestosql.matching.Pattern; import io.prestosql.spi.plan.FilterNode; @@ -41,7 +43,7 @@ public class AddExchangeAboveCTENode @Override public boolean isEnabled(Session session) { - return isCTEReuseEnabled(session) && !isSnapshotEnabled(session); + return isCTEReuseEnabled(session) && !isSnapshotEnabled(session) && (SystemSessionProperties.getRetryPolicy(session) != RetryPolicy.TASK_ASYNC); } @Override diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddReuseExchange.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddReuseExchange.java index 97caae6c04fbd678dcd31a280616b157cbac132d..55c2a60f71360fcd6f4f6acdbc0081bfa9921c4e 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddReuseExchange.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddReuseExchange.java @@ -15,6 +15,8 @@ package io.prestosql.sql.planner.optimizations; import io.prestosql.Session; +import io.prestosql.SystemSessionProperties; +import io.prestosql.exchange.RetryPolicy; import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.spi.connector.ColumnHandle; @@ -79,7 +81,7 @@ public class AddReuseExchange requireNonNull(planSymbolAllocator, "symbolAllocator is null"); requireNonNull(idAllocator, "idAllocator is null"); - if (!isReuseTableScanEnabled(session) || isColocatedJoinEnabled(session) || isSnapshotEnabled(session)) { + if (!isReuseTableScanEnabled(session) || isColocatedJoinEnabled(session) || isSnapshotEnabled(session) || (SystemSessionProperties.getRetryPolicy(session) == RetryPolicy.TASK_ASYNC)) { return plan; } else { diff --git 
a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddSortBasedAggregation.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddSortBasedAggregation.java index f8536a4e4196f66174f3579c4c758594e0049a0b..1e1930db29f78b38c882962260f6a208813adc8c 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddSortBasedAggregation.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/AddSortBasedAggregation.java @@ -18,6 +18,7 @@ package io.prestosql.sql.planner.optimizations; import com.google.common.collect.ImmutableList; import io.airlift.log.Logger; import io.prestosql.Session; +import io.prestosql.SystemSessionProperties; import io.prestosql.cost.CachingCostProvider; import io.prestosql.cost.CachingStatsProvider; import io.prestosql.cost.CostCalculator; @@ -26,6 +27,7 @@ import io.prestosql.cost.CostProvider; import io.prestosql.cost.PlanCostEstimate; import io.prestosql.cost.StatsCalculator; import io.prestosql.cost.StatsProvider; +import io.prestosql.exchange.RetryPolicy; import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.metadata.Metadata; import io.prestosql.spi.PartialAndFinalAggregationType; @@ -81,7 +83,7 @@ public class AddSortBasedAggregation @Override public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, PlanSymbolAllocator planSymbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) { - if (!isSortBasedAggregationEnabled(session) || isSnapshotEnabled(session)) { + if (!isSortBasedAggregationEnabled(session) || isSnapshotEnabled(session) || (SystemSessionProperties.getRetryPolicy(session) == RetryPolicy.TASK_ASYNC)) { return plan; } diff --git a/presto-main/src/main/java/io/prestosql/testing/TestingRecoveryUtils.java b/presto-main/src/main/java/io/prestosql/testing/TestingRecoveryUtils.java index c976554020ae46c04574f434765df9240ad6ef06..8203dcd537005f57b20d31ffea0cd2c6f217da56 100644 --- a/presto-main/src/main/java/io/prestosql/testing/TestingRecoveryUtils.java +++ b/presto-main/src/main/java/io/prestosql/testing/TestingRecoveryUtils.java @@ -95,6 +95,12 @@ public class TestingRecoveryUtils { } + @Override + public void flush(OutputStream outputStream) + throws IOException + { + } + @Override public InputStream newInputStream(Path path) { diff --git a/presto-main/src/test/java/io/prestosql/SessionTestUtils.java b/presto-main/src/test/java/io/prestosql/SessionTestUtils.java index 62c1e6626437bc450064d61cd8d600232ac0eea5..3ebaeddf30f28d1026a5f3ad09819faedc045ee6 100644 --- a/presto-main/src/test/java/io/prestosql/SessionTestUtils.java +++ b/presto-main/src/test/java/io/prestosql/SessionTestUtils.java @@ -24,6 +24,7 @@ import io.prestosql.utils.HetuConfig; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.prestosql.SystemSessionProperties.RECOVERY_ENABLED; +import static io.prestosql.SystemSessionProperties.RETRY_POLICY; import static io.prestosql.SystemSessionProperties.REUSE_TABLE_SCAN; import static io.prestosql.SystemSessionProperties.SNAPSHOT_ENABLED; import static io.prestosql.plugin.tpch.TpchMetadata.TINY_SCHEMA_NAME; @@ -50,6 +51,7 @@ public final class SessionTestUtils .build(); public static final Session TEST_SNAPSHOT_SESSION; + public static final Session TEST_TASK_SNAPSHOT_SESSION; static { @@ -71,6 +73,17 @@ public final class SessionTestUtils .setSystemProperty(RECOVERY_ENABLED, "true") .setSystemProperty(SNAPSHOT_ENABLED, "true") .build(); + + TEST_TASK_SNAPSHOT_SESSION = 
testSessionBuilder(new SessionPropertyManager(properties)) + .setCatalog("tpch") + .setSchema(TINY_SCHEMA_NAME) + .setClientCapabilities(stream(ClientCapabilities.values()) + .map(ClientCapabilities::toString) + .collect(toImmutableSet())) + .setSystemProperty(RECOVERY_ENABLED, "false") + .setSystemProperty(SNAPSHOT_ENABLED, "false") + .setSystemProperty(RETRY_POLICY, "TASK_ASYNC") + .build(); } private SessionTestUtils() diff --git a/presto-main/src/test/java/io/prestosql/execution/buffer/TestHybridSpoolingBuffer.java b/presto-main/src/test/java/io/prestosql/execution/buffer/TestHybridSpoolingBuffer.java new file mode 100644 index 0000000000000000000000000000000000000000..3a1c7a161605ed7c896859432c7e36f33204b33d --- /dev/null +++ b/presto-main/src/test/java/io/prestosql/execution/buffer/TestHybridSpoolingBuffer.java @@ -0,0 +1,328 @@ +/* + * Copyright (C) 2018-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.execution.buffer; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.util.concurrent.ListenableFuture; +import io.airlift.units.DataSize; +import io.hetu.core.filesystem.HetuLocalFileSystemClient; +import io.hetu.core.filesystem.LocalConfig; +import io.hetu.core.transport.execution.buffer.PagesSerde; +import io.hetu.core.transport.execution.buffer.PagesSerdeFactory; +import io.hetu.core.transport.execution.buffer.SerializedPage; +import io.prestosql.exchange.ExchangeManager; +import io.prestosql.exchange.ExchangeSink; +import io.prestosql.exchange.ExchangeSinkInstanceHandle; +import io.prestosql.exchange.FileSystemExchangeConfig; +import io.prestosql.exchange.FileSystemExchangeManager; +import io.prestosql.exchange.FileSystemExchangeSinkHandle; +import io.prestosql.exchange.FileSystemExchangeSinkInstanceHandle; +import io.prestosql.exchange.FileSystemExchangeStats; +import io.prestosql.exchange.storage.FileSystemExchangeStorage; +import io.prestosql.exchange.storage.HetuFileSystemExchangeStorage; +import io.prestosql.execution.MarkerDataFileFactory; +import io.prestosql.execution.StageId; +import io.prestosql.execution.TaskId; +import io.prestosql.operator.PageAssertions; +import io.prestosql.snapshot.SnapshotStateId; +import io.prestosql.spi.Page; +import io.prestosql.spi.QueryId; +import io.prestosql.spi.block.Block; +import io.prestosql.spi.block.BlockBuilder; +import io.prestosql.spi.filesystem.HetuFileSystemClient; +import io.prestosql.testing.TestingPagesSerdeFactory; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import java.io.File; +import java.io.IOException; +import java.net.URI; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Properties; +import 
java.util.concurrent.ExecutionException; +import java.util.stream.Collectors; + +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; +import static io.prestosql.metadata.MetadataManager.createTestMetadataManager; +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.spi.type.IntegerType.INTEGER; +import static org.testng.Assert.assertEquals; + +@Test(singleThreaded = true) +public class TestHybridSpoolingBuffer +{ + private final PagesSerde serde = new TestingPagesSerdeFactory().createPagesSerde(); + private final PagesSerde javaSerde = new PagesSerdeFactory(createTestMetadataManager().getFunctionAndTypeManager().getBlockEncodingSerde(), false) + .createDirectPagesSerde(Optional.empty(), true, false); + private final PagesSerde kryoSerde = new PagesSerdeFactory(createTestMetadataManager().getFunctionAndTypeManager().getBlockKryoEncodingSerde(), false) + .createDirectPagesSerde(Optional.empty(), true, true); + private ExchangeSink exchangeSink; + private ExchangeManager exchangeManager; + private ExchangeSinkInstanceHandle exchangeSinkInstanceHandle; + private final String baseURI = "file:///tmp/hetu/spooling"; + private final String baseDir = "/tmp/hetu/spooling"; + private final String accessDir = "/tmp/hetu"; + private final Path accessPath = Paths.get(accessDir); + private final HetuFileSystemClient hetuFileSystemClient = new HetuLocalFileSystemClient(new LocalConfig(new Properties()), accessPath); + + @BeforeMethod + public void setUp() + throws IOException, InterruptedException + { + Path basePath = Paths.get(baseDir); + File base = new File(accessDir); + if (base.exists()) { + deleteDirectory(base); + } + Files.createDirectories(basePath); + } + + @AfterMethod + public void cleanUp() + { + File base = new File(accessDir); + if (base.exists()) { + deleteDirectory(base); + } + } + + private void setConfig(FileSystemExchangeConfig.DirectSerialisationType type) + { + FileSystemExchangeConfig config = new FileSystemExchangeConfig() + .setExchangeEncryptionEnabled(false) + .setDirectSerializationType(type) + .setBaseDirectories(baseDir); + + FileSystemExchangeStorage exchangeStorage = new HetuFileSystemExchangeStorage(); + exchangeStorage.setFileSystemClient(hetuFileSystemClient); + exchangeManager = new FileSystemExchangeManager(exchangeStorage, new FileSystemExchangeStats(), config); + exchangeSinkInstanceHandle = new FileSystemExchangeSinkInstanceHandle( + new FileSystemExchangeSinkHandle(0, Optional.empty(), false), + config.getBaseDirectories().get(0), + 10); + exchangeSink = exchangeManager.createSink(exchangeSinkInstanceHandle, false); + } + + @Test + public void testHybridSpoolingBufferWithSerializationOff() + throws ExecutionException, InterruptedException + { + setConfig(FileSystemExchangeConfig.DirectSerialisationType.OFF); + HybridSpoolingBuffer hybridSpoolingBuffer = createHybridSpoolingBuffer(); + List pages = new ArrayList<>(); + pages.add(generateSerializedPage()); + pages.add(generateSerializedPage()); + hybridSpoolingBuffer.enqueue(0, pages, null); + ListenableFuture result = hybridSpoolingBuffer.get(new OutputBuffers.OutputBufferId(0), 0, new DataSize(100, DataSize.Unit.MEGABYTE)); + while (result.get().equals(BufferResult.emptyResults(0, false))) { + result = hybridSpoolingBuffer.get(new OutputBuffers.OutputBufferId(0), 0, new DataSize(100, DataSize.Unit.MEGABYTE)); + } + List actualPages = result.get().getSerializedPages().stream().map(page -> serde.deserialize(page)).collect(Collectors.toList()); + 
assertEquals(actualPages.size(), 1); + for (int pageCount = 0; pageCount < actualPages.size(); pageCount++) { + PageAssertions.assertPageEquals(ImmutableList.of(INTEGER, INTEGER), actualPages.get(pageCount), generatePage()); + } + result = hybridSpoolingBuffer.get(new OutputBuffers.OutputBufferId(0), 0, new DataSize(100, DataSize.Unit.MEGABYTE)); + actualPages = result.get().getSerializedPages().stream().map(page -> serde.deserialize(page)).collect(Collectors.toList()); + assertEquals(actualPages.size(), 1); + for (int pageCount = 0; pageCount < actualPages.size(); pageCount++) { + PageAssertions.assertPageEquals(ImmutableList.of(INTEGER, INTEGER), actualPages.get(pageCount), generatePage()); + } + hybridSpoolingBuffer.setNoMorePages(); + } + + @Test + public void testHybridSpoolingBufferWithSerializationJava() + throws ExecutionException, InterruptedException + { + setConfig(FileSystemExchangeConfig.DirectSerialisationType.JAVA); + HybridSpoolingBuffer hybridSpoolingBuffer = createHybridSpoolingBuffer(); + hybridSpoolingBuffer.setSerde(serde); + hybridSpoolingBuffer.setJavaSerde(javaSerde); + hybridSpoolingBuffer.setKryoSerde(kryoSerde); + List pages = new ArrayList<>(); + pages.add(generatePage()); + pages.add(generatePage()); + hybridSpoolingBuffer.enqueuePages(0, pages, null, javaSerde); + ListenableFuture result = hybridSpoolingBuffer.get(new OutputBuffers.OutputBufferId(0), 0, new DataSize(100, DataSize.Unit.MEGABYTE)); + while (result.get().equals(BufferResult.emptyResults(0, false))) { + result = hybridSpoolingBuffer.get(new OutputBuffers.OutputBufferId(0), 0, new DataSize(100, DataSize.Unit.MEGABYTE)); + } + List actualPages = result.get().getSerializedPages().stream().map(page -> serde.deserialize(page)).collect(Collectors.toList()); + assertEquals(actualPages.size(), 1); + for (int pageCount = 0; pageCount < actualPages.size(); pageCount++) { + PageAssertions.assertPageEquals(ImmutableList.of(INTEGER, INTEGER), actualPages.get(pageCount), generatePage()); + } + result = hybridSpoolingBuffer.get(new OutputBuffers.OutputBufferId(0), 0, new DataSize(100, DataSize.Unit.MEGABYTE)); + actualPages = result.get().getSerializedPages().stream().map(page -> serde.deserialize(page)).collect(Collectors.toList()); + assertEquals(actualPages.size(), 1); + for (int pageCount = 0; pageCount < actualPages.size(); pageCount++) { + PageAssertions.assertPageEquals(ImmutableList.of(INTEGER, INTEGER), actualPages.get(pageCount), generatePage()); + } + hybridSpoolingBuffer.setNoMorePages(); + } + + @Test + public void testHybridSpoolingBufferWithSerializationKryo() + throws ExecutionException, InterruptedException + { + setConfig(FileSystemExchangeConfig.DirectSerialisationType.KRYO); + HybridSpoolingBuffer hybridSpoolingBuffer = createHybridSpoolingBuffer(); + hybridSpoolingBuffer.setSerde(serde); + hybridSpoolingBuffer.setJavaSerde(javaSerde); + hybridSpoolingBuffer.setKryoSerde(kryoSerde); + List pages = new ArrayList<>(); + pages.add(generatePage()); + pages.add(generatePage()); + hybridSpoolingBuffer.enqueuePages(0, pages, null, kryoSerde); + ListenableFuture result = hybridSpoolingBuffer.get(new OutputBuffers.OutputBufferId(0), 0, new DataSize(100, DataSize.Unit.MEGABYTE)); + while (result.get().equals(BufferResult.emptyResults(0, false))) { + result = hybridSpoolingBuffer.get(new OutputBuffers.OutputBufferId(0), 0, new DataSize(100, DataSize.Unit.MEGABYTE)); + } + List actualPages = result.get().getSerializedPages().stream().map(page -> serde.deserialize(page)).collect(Collectors.toList()); + 
assertEquals(actualPages.size(), 1); + for (int pageCount = 0; pageCount < actualPages.size(); pageCount++) { + PageAssertions.assertPageEquals(ImmutableList.of(INTEGER, INTEGER), actualPages.get(pageCount), generatePage()); + } + result = hybridSpoolingBuffer.get(new OutputBuffers.OutputBufferId(0), 0, new DataSize(100, DataSize.Unit.MEGABYTE)); + actualPages = result.get().getSerializedPages().stream().map(page -> serde.deserialize(page)).collect(Collectors.toList()); + assertEquals(actualPages.size(), 1); + for (int pageCount = 0; pageCount < actualPages.size(); pageCount++) { + PageAssertions.assertPageEquals(ImmutableList.of(INTEGER, INTEGER), actualPages.get(pageCount), generatePage()); + } + hybridSpoolingBuffer.setNoMorePages(); + } + + @Test + public void testHybridSpoolingMarkerIndexFile() + { + setConfig(FileSystemExchangeConfig.DirectSerialisationType.OFF); + HybridSpoolingBuffer hybridSpoolingBuffer = createHybridSpoolingBuffer(); + URI markerDataFile = URI.create(baseURI).resolve("marker_data_file.data"); + URI spoolingDataFile = URI.create(baseURI).resolve("spooling_data_file.data"); + HybridSpoolingBuffer.SpoolingInfo spoolingInfo = new HybridSpoolingBuffer.SpoolingInfo(1000, 2000); + hybridSpoolingBuffer.enqueueMarkerIndex(1, markerDataFile, 1000, 2000, ImmutableMap.of(spoolingDataFile, spoolingInfo)); + } + + @Test + public void testHybridSpoolingMarkerDataFile() + { + SnapshotStateId snapshotStateId = new SnapshotStateId(1, new TaskId(new StageId("query", 1), 1, 1)); + SnapshotStateId snapshotStateId1 = new SnapshotStateId(2, new TaskId(new StageId("query", 1), 1, 1)); + setConfig(FileSystemExchangeConfig.DirectSerialisationType.OFF); + HybridSpoolingBuffer hybridSpoolingBuffer = createHybridSpoolingBuffer(); + hybridSpoolingBuffer.enqueueMarkerData(1, ImmutableMap.of(snapshotStateId.toString(), "marker1"), exchangeSink.getSinkFiles()); + hybridSpoolingBuffer.enqueueMarkerData(2, ImmutableMap.of(snapshotStateId1.toString(), "marker2"), exchangeSink.getSinkFiles()); + MarkerDataFileFactory.MarkerDataFileFooter footer = (MarkerDataFileFactory.MarkerDataFileFooter) hybridSpoolingBuffer.getMarkerDataFileFooter(); + Map operatorStateInfoMap = footer.getOperatorStateInfo(); + MarkerDataFileFactory.OperatorStateInfo operatorStateInfo = operatorStateInfoMap.get(snapshotStateId1.getId()); + assertEquals(hybridSpoolingBuffer.getMarkerData(operatorStateInfo.getStateOffset()), "marker2"); + MarkerDataFileFactory.MarkerDataFileFooter previousFooter = (MarkerDataFileFactory.MarkerDataFileFooter) hybridSpoolingBuffer.getMarkerDataFileFooter(footer.getPreviousTailOffset()); + Map previousOperatorStateInfoMap = previousFooter.getOperatorStateInfo(); + MarkerDataFileFactory.OperatorStateInfo previousOperatorStateInfo = previousOperatorStateInfoMap.get(snapshotStateId.getId()); + assertEquals(hybridSpoolingBuffer.getMarkerData(previousOperatorStateInfo.getStateOffset()), "marker1"); + } + + @Test + public void testMarkerSpoolingInfo() + { + int entries = 10; + BlockBuilder blockBuilder = BIGINT.createBlockBuilder(null, entries); + for (int i = 0; i < entries; i++) { + BIGINT.writeLong(blockBuilder, i); + } + Block block = blockBuilder.build(); + Page page = new Page(block); + SerializedPage serializedPage = javaSerde.serialize(page); + SnapshotStateId snapshotStateId = new SnapshotStateId(1, new TaskId(new StageId("query", 1), 1, 1)); + SnapshotStateId snapshotStateId1 = new SnapshotStateId(1, new TaskId(new StageId("query", 1), 2, 1)); + 
setConfig(FileSystemExchangeConfig.DirectSerialisationType.OFF); + HybridSpoolingBuffer hybridSpoolingBuffer = createHybridSpoolingBuffer(); + hybridSpoolingBuffer.enqueue(0, ImmutableList.of(serializedPage), null); + hybridSpoolingBuffer.enqueueMarkerInfo(1, ImmutableMap.of(snapshotStateId.toString(), "marker1", snapshotStateId1.toString(), "marker2")); + assertEquals(hybridSpoolingBuffer.dequeueMarkerInfo(1), ImmutableMap.of(snapshotStateId.toString(), "marker1", snapshotStateId1.toString(), "marker2")); + } + + @Test + public void testMarkerSpoolingInfoMultiPartition() + { + int entries = 10; + BlockBuilder blockBuilder = BIGINT.createBlockBuilder(null, entries); + for (int i = 0; i < entries; i++) { + BIGINT.writeLong(blockBuilder, i); + } + Block block = blockBuilder.build(); + Page page = new Page(block); + SerializedPage serializedPage = javaSerde.serialize(page); + SnapshotStateId snapshotStateId = new SnapshotStateId(1, new TaskId(new StageId("query", 1), 1, 1)); + SnapshotStateId snapshotStateId1 = new SnapshotStateId(1, new TaskId(new StageId("query", 1), 2, 1)); + setConfig(FileSystemExchangeConfig.DirectSerialisationType.OFF); + HybridSpoolingBuffer hybridSpoolingBuffer = createHybridSpoolingBuffer(); + hybridSpoolingBuffer.enqueue(0, ImmutableList.of(serializedPage), null); + hybridSpoolingBuffer.enqueue(1, ImmutableList.of(serializedPage), null); + hybridSpoolingBuffer.enqueueMarkerInfo(1, ImmutableMap.of(snapshotStateId.toString(), "marker1")); + hybridSpoolingBuffer.enqueueMarkerInfo(2, ImmutableMap.of(snapshotStateId1.toString(), "marker2")); + assertEquals(hybridSpoolingBuffer.dequeueMarkerInfo(1), ImmutableMap.of(snapshotStateId, "marker1")); + assertEquals(hybridSpoolingBuffer.dequeueMarkerInfo(2), ImmutableMap.of(snapshotStateId1, "marker2")); + } + + private HybridSpoolingBuffer createHybridSpoolingBuffer() + { + OutputBuffers outputBuffers = OutputBuffers.createInitialEmptyOutputBuffers(OutputBuffers.BufferType.PARTITIONED); + outputBuffers.setExchangeSinkInstanceHandle(exchangeSinkInstanceHandle); + + return new HybridSpoolingBuffer(new OutputBufferStateMachine(new TaskId(new StageId(new QueryId("query"), 0), 0, 0), directExecutor()), + outputBuffers, + exchangeSink, + TestSpoolingExchangeOutputBuffer.TestingLocalMemoryContext::new, + exchangeManager); + } + + private SerializedPage generateSerializedPage() + { + Page expectedPage = generatePage(); + SerializedPage page = serde.serialize(expectedPage); + return page; + } + + private Page generatePage() + { + BlockBuilder expectedBlockBuilder = INTEGER.createBlockBuilder(null, 2); + INTEGER.writeLong(expectedBlockBuilder, 10); + INTEGER.writeLong(expectedBlockBuilder, 20); + Block expectedBlock = expectedBlockBuilder.build(); + + return new Page(expectedBlock, expectedBlock); + } + + private boolean deleteDirectory(File dir) + { + File[] allContents = dir.listFiles(); + if (allContents != null) { + for (File file : allContents) { + deleteDirectory(file); + } + } + return dir.delete(); + } +} diff --git a/presto-main/src/test/java/io/prestosql/execution/buffer/TestSpoolingExchangeOutputBuffer.java b/presto-main/src/test/java/io/prestosql/execution/buffer/TestSpoolingExchangeOutputBuffer.java index b5ad674d3e284d2240e974b4eba9a32c6e6922c1..37776fc3fe6e74d10e430d7d809c07046347fa55 100644 --- a/presto-main/src/test/java/io/prestosql/execution/buffer/TestSpoolingExchangeOutputBuffer.java +++ b/presto-main/src/test/java/io/prestosql/execution/buffer/TestSpoolingExchangeOutputBuffer.java @@ -24,6 +24,8 @@ import 
io.hetu.core.transport.execution.buffer.PagesSerde; import io.hetu.core.transport.execution.buffer.SerializedPage; import io.prestosql.exchange.ExchangeSink; import io.prestosql.exchange.ExchangeSinkInstanceHandle; +import io.prestosql.exchange.storage.FileSystemExchangeStorage; +import io.prestosql.execution.MarkerDataFileFactory; import io.prestosql.execution.StageId; import io.prestosql.execution.TaskId; import io.prestosql.memory.context.LocalMemoryContext; @@ -31,6 +33,10 @@ import io.prestosql.spi.Page; import io.prestosql.spi.QueryId; import org.testng.annotations.Test; +import javax.crypto.SecretKey; + +import java.net.URI; +import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; @@ -358,6 +364,48 @@ public class TestSpoolingExchangeOutputBuffer return abort; } + @Override + public FileSystemExchangeStorage getExchangeStorage() + { + return null; + } + + @Override + public URI getOutputDirectory() + { + return null; + } + + @Override + public Optional getSecretKey() + { + return Optional.empty(); + } + + @Override + public boolean isExchangeCompressionEnabled() + { + return false; + } + + @Override + public int getPartitionId() + { + return 0; + } + + @Override + public List getSinkFiles() + { + return ImmutableList.of(); + } + + @Override + public void enqueueMarkerInfo(MarkerDataFileFactory.MarkerDataFileFooterInfo markerDataFileFooterInfo) + { + return; + } + public void setAbort(CompletableFuture abort) { this.abort = requireNonNull(abort, "abort is null"); @@ -370,7 +418,7 @@ public class TestSpoolingExchangeOutputBuffer INSTANCE } - private static class TestingLocalMemoryContext + public static class TestingLocalMemoryContext implements LocalMemoryContext { @Override diff --git a/presto-main/src/test/java/io/prestosql/operator/TestOrderByOperator.java b/presto-main/src/test/java/io/prestosql/operator/TestOrderByOperator.java index d1d8419b7c413c56d60265b71c0b2f7451c83c8c..ca41d415d52b64b8f692d9364e798177ebe59149 100644 --- a/presto-main/src/test/java/io/prestosql/operator/TestOrderByOperator.java +++ b/presto-main/src/test/java/io/prestosql/operator/TestOrderByOperator.java @@ -22,12 +22,30 @@ import io.hetu.core.filesystem.HetuLocalFileSystemClient; import io.hetu.core.filesystem.LocalConfig; import io.prestosql.ExceededMemoryLimitException; import io.prestosql.Session; +import io.prestosql.exchange.ExchangeManager; +import io.prestosql.exchange.ExchangeSink; +import io.prestosql.exchange.ExchangeSinkInstanceHandle; +import io.prestosql.exchange.FileSystemExchangeConfig; +import io.prestosql.exchange.FileSystemExchangeManager; +import io.prestosql.exchange.FileSystemExchangeSinkHandle; +import io.prestosql.exchange.FileSystemExchangeSinkInstanceHandle; +import io.prestosql.exchange.FileSystemExchangeStats; +import io.prestosql.exchange.storage.FileSystemExchangeStorage; +import io.prestosql.exchange.storage.HetuFileSystemExchangeStorage; +import io.prestosql.execution.StageId; +import io.prestosql.execution.TaskId; +import io.prestosql.execution.buffer.HybridSpoolingBuffer; +import io.prestosql.execution.buffer.OutputBufferStateMachine; +import io.prestosql.execution.buffer.OutputBuffers; +import io.prestosql.execution.buffer.TestSpoolingExchangeOutputBuffer; import io.prestosql.filesystem.FileSystemClientManager; import io.prestosql.metadata.InMemoryNodeManager; import io.prestosql.operator.OrderByOperator.OrderByOperatorFactory; import io.prestosql.snapshot.RecoveryConfig; import io.prestosql.snapshot.RecoveryUtils; import 
io.prestosql.spi.Page; +import io.prestosql.spi.QueryId; +import io.prestosql.spi.filesystem.HetuFileSystemClient; import io.prestosql.spi.plan.PlanNodeId; import io.prestosql.spi.snapshot.MarkerPage; import io.prestosql.spiller.FileSingleStreamSpillerFactory; @@ -53,6 +71,7 @@ import java.util.Properties; import java.util.concurrent.ExecutorService; import java.util.concurrent.ScheduledExecutorService; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static com.google.common.util.concurrent.MoreExecutors.listeningDecorator; import static io.airlift.concurrent.MoreFutures.getFutureValue; import static io.airlift.concurrent.Threads.daemonThreadsNamed; @@ -61,6 +80,7 @@ import static io.airlift.units.DataSize.succinctBytes; import static io.prestosql.RowPagesBuilder.rowPagesBuilder; import static io.prestosql.SessionTestUtils.TEST_SESSION; import static io.prestosql.SessionTestUtils.TEST_SNAPSHOT_SESSION; +import static io.prestosql.SessionTestUtils.TEST_TASK_SNAPSHOT_SESSION; import static io.prestosql.metadata.MetadataManager.createTestMetadataManager; import static io.prestosql.operator.OperatorAssertion.assertOperatorEquals; import static io.prestosql.operator.OperatorAssertion.assertOperatorEqualsWithSimpleSelfStateComparison; @@ -93,6 +113,14 @@ public class TestOrderByOperator private DummySpillerFactory spillerFactory; private RecoveryUtils recoveryUtils = NOOP_RECOVERY_UTILS; private FileSystemClientManager fileSystemClientManager = mock(FileSystemClientManager.class); + private ExchangeSink exchangeSink; + private ExchangeManager exchangeManager; + private ExchangeSinkInstanceHandle exchangeSinkInstanceHandle; + private final String baseURI = "file:///tmp/hetu/spooling"; + private final String baseDir = "/tmp/hetu/spooling"; + private final String accessDir = "/tmp/hetu"; + private final Path accessPath = Paths.get(accessDir); + private final HetuFileSystemClient hetuFileSystemClient = new HetuLocalFileSystemClient(new LocalConfig(new Properties()), accessPath); @DataProvider public static Object[][] spillEnabled() @@ -415,6 +443,149 @@ public class TestOrderByOperator return new GenericSpillerFactory(streamSpillerFactory); } + private void setConfig(FileSystemExchangeConfig.DirectSerialisationType type) + { + if (!hetuFileSystemClient.exists(accessPath)) { + try { + hetuFileSystemClient.createDirectory(accessPath); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + FileSystemExchangeConfig config = new FileSystemExchangeConfig() + .setExchangeEncryptionEnabled(false) + .setDirectSerializationType(type) + .setBaseDirectories(baseDir); + + FileSystemExchangeStorage exchangeStorage = new HetuFileSystemExchangeStorage(); + exchangeStorage.setFileSystemClient(hetuFileSystemClient); + exchangeManager = new FileSystemExchangeManager(exchangeStorage, new FileSystemExchangeStats(), config); + exchangeSinkInstanceHandle = new FileSystemExchangeSinkInstanceHandle( + new FileSystemExchangeSinkHandle(0, Optional.empty(), false), + config.getBaseDirectories().get(0), + 10); + exchangeSink = exchangeManager.createSink(exchangeSinkInstanceHandle, false); + } + + private HybridSpoolingBuffer createHybridSpoolingBuffer() + { + OutputBuffers outputBuffers = OutputBuffers.createInitialEmptyOutputBuffers(OutputBuffers.BufferType.PARTITIONED); + outputBuffers.setExchangeSinkInstanceHandle(exchangeSinkInstanceHandle); + + return new HybridSpoolingBuffer(new OutputBufferStateMachine(new TaskId(new StageId(new QueryId("query"), 0), 0, 0), 
directExecutor()), + outputBuffers, + exchangeSink, + TestSpoolingExchangeOutputBuffer.TestingLocalMemoryContext::new, + exchangeManager); + } + + @Test + public void TestTaskSnapshotConsolidatedStoreAndLoad() + throws Exception + { + setConfig(FileSystemExchangeConfig.DirectSerialisationType.OFF); + HybridSpoolingBuffer hybridSpoolingBuffer = createHybridSpoolingBuffer(); + + List<Page> input1 = rowPagesBuilder(VARCHAR, BIGINT) + .row("a", 1L) + .row("b", 2L) + .pageBreak() + .row("b", 3L) + .row("a", 4L) + .build(); + + List<Page> input2 = rowPagesBuilder(VARCHAR, BIGINT) + .row("c", 4L) + .row("d", 6L) + .pageBreak() + .row("c", 2L) + .row("d", 3L) + .build(); + + OrderByOperatorFactory operatorFactory = new OrderByOperatorFactory( + 0, + new PlanNodeId("test"), + ImmutableList.of(VARCHAR, BIGINT), + ImmutableList.of(0, 1), + 10, + ImmutableList.of(0, 1), + ImmutableList.of(ASC_NULLS_LAST, DESC_NULLS_LAST), + new PagesIndex.TestingFactory(false), + false, + Optional.empty(), + new OrderingCompiler(), + false, false); + + DriverContext driverContext = createDriverContext(defaultMemoryLimit, TEST_TASK_SNAPSHOT_SESSION); + driverContext.getPipelineContext().getTaskContext().getSnapshotManager().setTotalComponents(1); + driverContext.getPipelineContext().getTaskContext().getSnapshotManager().setHybridSpoolingBuffer(hybridSpoolingBuffer); + OrderByOperator orderByOperator = (OrderByOperator) operatorFactory.createOperator(driverContext); + + // Step1: add the first 2 pages + for (Page page : input1) { + orderByOperator.addInput(page); + } + + MarkerPage marker = MarkerPage.snapshotPage(1); + orderByOperator.addInput(marker); + + // Step4: add another 2 pages + for (Page page : input2) { + orderByOperator.addInput(page); + } + + // Step5: assume the task is rescheduled due to failure and everything is re-constructed + driverContext = createDriverContext(defaultMemoryLimit, TEST_TASK_SNAPSHOT_SESSION); + driverContext.getPipelineContext().getTaskContext().getSnapshotManager().setTotalComponents(1); + driverContext.getPipelineContext().getTaskContext().getSnapshotManager().setHybridSpoolingBuffer(hybridSpoolingBuffer); + operatorFactory = new OrderByOperatorFactory( + 0, + new PlanNodeId("test"), + ImmutableList.of(VARCHAR, BIGINT), + ImmutableList.of(0, 1), + 10, + ImmutableList.of(0, 1), + ImmutableList.of(ASC_NULLS_LAST, DESC_NULLS_LAST), + new PagesIndex.TestingFactory(false), + false, + Optional.empty(), + new OrderingCompiler(), + false, false); + orderByOperator = (OrderByOperator) operatorFactory.createOperator(driverContext); + + // Step6: restore to snapshot 1; the operator should now contain only the state captured after the first 2 pages.
+ MarkerPage resumeMarker = MarkerPage.resumePage(1); + orderByOperator.addInput(resumeMarker); + + // Step7: continue to add another 2 pages + for (Page page : input2) { + orderByOperator.addInput(page); + } + orderByOperator.finish(); + + // Compare the results + MaterializedResult expected = resultBuilder(driverContext.getSession(), VARCHAR, BIGINT) + .row("a", 4L) + .row("a", 1L) + .row("b", 3L) + .row("b", 2L) + .row("c", 4L) + .row("c", 2L) + .row("d", 6L) + .row("d", 3L) + .build(); + + ImmutableList.Builder outputPages = ImmutableList.builder(); + Page p = orderByOperator.getOutput(); + while (p instanceof MarkerPage) { + p = orderByOperator.getOutput(); + } + outputPages.add(p); + MaterializedResult actual = toMaterializedResult(driverContext.getSession(), expected.getTypes(), outputPages.build()); + Assert.assertEquals(actual, expected); + } + /** * This test is supposed to consume 4 pages and produce the output page with sorted ordering. * The spilling and capturing('capture1') happened after the first 2 pages added into the operator. @@ -458,8 +629,8 @@ public class TestOrderByOperator ImmutableList.of(0, 1), ImmutableList.of(ASC_NULLS_LAST, DESC_NULLS_LAST), new PagesIndex.TestingFactory(false), - true, - Optional.of(genericSpillerFactory), + false, + Optional.empty(), new OrderingCompiler(), false, false); diff --git a/presto-main/src/test/java/io/prestosql/snapshot/TestMultiInputSnapshotState.java b/presto-main/src/test/java/io/prestosql/snapshot/TestMultiInputSnapshotState.java index e812ba40d11f1524fcaff3e6baf5b913aec98779..5b111de9b6eb0c1af9d31da9c4e034a7f85ca45a 100644 --- a/presto-main/src/test/java/io/prestosql/snapshot/TestMultiInputSnapshotState.java +++ b/presto-main/src/test/java/io/prestosql/snapshot/TestMultiInputSnapshotState.java @@ -89,7 +89,7 @@ public class TestMultiInputSnapshotState snapshotManager = mock(TaskSnapshotManager.class); restorable = new TestingRestorable(); restorable.state = 100; - state = new MultiInputSnapshotState(restorable, snapshotManager, serde, TestMultiInputSnapshotState::createSnapshotStateId); + state = new MultiInputSnapshotState(restorable, snapshotManager, serde, TestMultiInputSnapshotState::createSnapshotStateId, TEST_SNAPSHOT_SESSION); } private Optional processPage(String source, Page page) @@ -265,7 +265,7 @@ public class TestMultiInputSnapshotState throws Exception { TestingRestorableUndeterminedInputs restorableUndeterminedInputs = new TestingRestorableUndeterminedInputs(); - MultiInputSnapshotState inputSnapshotState = new MultiInputSnapshotState(restorableUndeterminedInputs, snapshotManager, serde, TestMultiInputSnapshotState::createSnapshotStateId); + MultiInputSnapshotState inputSnapshotState = new MultiInputSnapshotState(restorableUndeterminedInputs, snapshotManager, serde, TestMultiInputSnapshotState::createSnapshotStateId, TEST_SNAPSHOT_SESSION); processPage(inputSnapshotState, source1, marker1); processPage(inputSnapshotState, source2, marker1); diff --git a/presto-main/src/test/java/io/prestosql/snapshot/TestSingleInputSnapshotState.java b/presto-main/src/test/java/io/prestosql/snapshot/TestSingleInputSnapshotState.java index 53e9514cdfab3d1726f17aaf6f4bb04b748b295d..d67d10bf42a84a9c4e4431e6814342e3c9d050b6 100644 --- a/presto-main/src/test/java/io/prestosql/snapshot/TestSingleInputSnapshotState.java +++ b/presto-main/src/test/java/io/prestosql/snapshot/TestSingleInputSnapshotState.java @@ -80,7 +80,7 @@ public class TestSingleInputSnapshotState when(snapshotMemoryContext.trySetBytes(anyLong())).thenReturn(true); 
restorable = new TestingRestorable(); restorable.state = 100; - state = new SingleInputSnapshotState(restorable, snapshotManager, null, TestSingleInputSnapshotState::createSnapshotStateId, TestSingleInputSnapshotState::createSnapshotStateId, snapshotMemoryContext, false); + state = new SingleInputSnapshotState(restorable, snapshotManager, null, TestSingleInputSnapshotState::createSnapshotStateId, TestSingleInputSnapshotState::createSnapshotStateId, snapshotMemoryContext, false, TEST_SNAPSHOT_SESSION); } private boolean processPage(Page page) @@ -207,7 +207,7 @@ public class TestSingleInputSnapshotState public void testResumeBacktrack() throws Exception { - SingleInputSnapshotState singleInputSnapshotState = new SingleInputSnapshotState(restorable, snapshotManager, null, TestSingleInputSnapshotState::createSnapshotStateId, TestSingleInputSnapshotState::createSnapshotStateId, snapshotMemoryContext, false); + SingleInputSnapshotState singleInputSnapshotState = new SingleInputSnapshotState(restorable, snapshotManager, null, TestSingleInputSnapshotState::createSnapshotStateId, TestSingleInputSnapshotState::createSnapshotStateId, snapshotMemoryContext, false, TEST_SNAPSHOT_SESSION); singleInputSnapshotState.processPage(regularPage); restorable.state++; int saved1 = restorable.state; @@ -233,7 +233,7 @@ public class TestSingleInputSnapshotState null, TestSingleInputSnapshotState::createSnapshotStateId, TestSingleInputSnapshotState::createSnapshotStateId, - snapshotMemoryContext, true); + snapshotMemoryContext, true, TEST_SNAPSHOT_SESSION); singleInputSnapshotState.processPage(marker1); when(snapshotManager.loadState(anyObject())).thenReturn(Optional.of(1)); when(snapshotManager.loadSpilledPathInfo(anyObject())) @@ -260,7 +260,7 @@ public class TestSingleInputSnapshotState null, TestSingleInputSnapshotState::createSnapshotStateId, TestSingleInputSnapshotState::createSnapshotStateId, - snapshotMemoryContext, true); + snapshotMemoryContext, true, TEST_SNAPSHOT_SESSION); singleInputSnapshotState.processPage(marker1); singleInputSnapshotState.processPage(marker2); when(snapshotManager.loadState(anyObject())).thenReturn(Optional.of(1)); @@ -289,7 +289,7 @@ public class TestSingleInputSnapshotState null, TestSingleInputSnapshotState::createSnapshotStateId, TestSingleInputSnapshotState::createSnapshotStateId, - snapshotMemoryContext, false); + snapshotMemoryContext, false, TEST_SNAPSHOT_SESSION); singleInputSnapshotState.processPage(marker1); when(snapshotManager.loadConsolidatedState(anyObject())).thenReturn(Optional.of(0)); singleInputSnapshotState.processPage(resume1); @@ -311,7 +311,7 @@ public class TestSingleInputSnapshotState null, TestSingleInputSnapshotState::createSnapshotStateId, TestSingleInputSnapshotState::createSnapshotStateId, - snapshotMemoryContext, false); + snapshotMemoryContext, false, TEST_SNAPSHOT_SESSION); singleInputSnapshotState.processPage(marker1); when(snapshotManager.loadState(anyObject())).thenReturn(Optional.of(0)); singleInputSnapshotState.processPage(resume1); diff --git a/presto-main/src/test/java/io/prestosql/snapshot/TestTaskSnapshotManager.java b/presto-main/src/test/java/io/prestosql/snapshot/TestTaskSnapshotManager.java index 0b4864fa1c2a69d970fb50ab479bcab9e7c937b9..9fb096ae4d9dc88faba92ec4e6f6edcd710043a9 100644 --- a/presto-main/src/test/java/io/prestosql/snapshot/TestTaskSnapshotManager.java +++ b/presto-main/src/test/java/io/prestosql/snapshot/TestTaskSnapshotManager.java @@ -17,12 +17,30 @@ package io.prestosql.snapshot; import 
com.google.common.collect.ImmutableList; import io.hetu.core.filesystem.HetuLocalFileSystemClient; import io.hetu.core.filesystem.LocalConfig; +import io.hetu.core.transport.execution.buffer.PagesSerde; +import io.hetu.core.transport.execution.buffer.PagesSerdeFactory; +import io.prestosql.exchange.ExchangeManager; +import io.prestosql.exchange.ExchangeSink; +import io.prestosql.exchange.ExchangeSinkInstanceHandle; +import io.prestosql.exchange.FileSystemExchangeConfig; +import io.prestosql.exchange.FileSystemExchangeManager; +import io.prestosql.exchange.FileSystemExchangeSinkHandle; +import io.prestosql.exchange.FileSystemExchangeSinkInstanceHandle; +import io.prestosql.exchange.FileSystemExchangeStats; +import io.prestosql.exchange.storage.FileSystemExchangeStorage; +import io.prestosql.exchange.storage.HetuFileSystemExchangeStorage; import io.prestosql.execution.StageId; import io.prestosql.execution.TaskId; +import io.prestosql.execution.buffer.HybridSpoolingBuffer; +import io.prestosql.execution.buffer.OutputBufferStateMachine; +import io.prestosql.execution.buffer.OutputBuffers; +import io.prestosql.execution.buffer.TestSpoolingExchangeOutputBuffer; import io.prestosql.filesystem.FileSystemClientManager; import io.prestosql.metadata.InMemoryNodeManager; import io.prestosql.operator.Operator; import io.prestosql.spi.QueryId; +import io.prestosql.spi.filesystem.HetuFileSystemClient; +import io.prestosql.testing.TestingPagesSerdeFactory; import io.prestosql.testing.assertions.Assert; import org.apache.commons.io.FileUtils; import org.testng.annotations.AfterMethod; @@ -40,7 +58,10 @@ import java.util.Map; import java.util.Optional; import java.util.Properties; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static io.prestosql.SessionTestUtils.TEST_SNAPSHOT_SESSION; +import static io.prestosql.SessionTestUtils.TEST_TASK_SNAPSHOT_SESSION; +import static io.prestosql.metadata.MetadataManager.createTestMetadataManager; import static org.mockito.Matchers.any; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; @@ -60,6 +81,20 @@ public class TestTaskSnapshotManager RecoveryUtils recoveryUtils; QueryId queryId; + private final PagesSerde serde = new TestingPagesSerdeFactory().createPagesSerde(); + private final PagesSerde javaSerde = new PagesSerdeFactory(createTestMetadataManager().getFunctionAndTypeManager().getBlockEncodingSerde(), false) + .createDirectPagesSerde(Optional.empty(), true, false); + private final PagesSerde kryoSerde = new PagesSerdeFactory(createTestMetadataManager().getFunctionAndTypeManager().getBlockKryoEncodingSerde(), false) + .createDirectPagesSerde(Optional.empty(), true, true); + private ExchangeSink exchangeSink; + private ExchangeManager exchangeManager; + private ExchangeSinkInstanceHandle exchangeSinkInstanceHandle; + private final String baseURI = "file:///tmp/hetu/spooling"; + private final String baseDir = "/tmp/hetu/spooling"; + private final String accessDir = "/tmp/hetu"; + private final Path accessPath = Paths.get(accessDir); + private final HetuFileSystemClient hetuFileSystemClient = new HetuLocalFileSystemClient(new LocalConfig(new Properties()), accessPath); + @BeforeMethod public void setup() throws Exception @@ -74,6 +109,13 @@ public class TestTaskSnapshotManager recoveryUtils = new RecoveryUtils(fileSystemClientManager, recoveryConfig, new InMemoryNodeManager()); recoveryUtils.rootPath = SNAPSHOT_FILE_SYSTEM_DIR; recoveryUtils.initialize(); + + Path basePath = 
Paths.get(baseDir); + File base = new File(accessDir); + if (base.exists()) { + deleteDirectory(base); + } + Files.createDirectories(basePath); } @AfterMethod @@ -82,6 +124,63 @@ public class TestTaskSnapshotManager { // Cleanup files recoveryUtils.removeQuerySnapshotManager(queryId); + + File base = new File(accessDir); + if (base.exists()) { + deleteDirectory(base); + } + } + + private void setConfig(FileSystemExchangeConfig.DirectSerialisationType type) + { + FileSystemExchangeConfig config = new FileSystemExchangeConfig() + .setExchangeEncryptionEnabled(false) + .setDirectSerializationType(type) + .setBaseDirectories(baseDir); + + FileSystemExchangeStorage exchangeStorage = new HetuFileSystemExchangeStorage(); + exchangeStorage.setFileSystemClient(hetuFileSystemClient); + exchangeManager = new FileSystemExchangeManager(exchangeStorage, new FileSystemExchangeStats(), config); + exchangeSinkInstanceHandle = new FileSystemExchangeSinkInstanceHandle( + new FileSystemExchangeSinkHandle(0, Optional.empty(), false), + config.getBaseDirectories().get(0), + 10); + exchangeSink = exchangeManager.createSink(exchangeSinkInstanceHandle, false); + } + + private HybridSpoolingBuffer createHybridSpoolingBuffer() + { + OutputBuffers outputBuffers = OutputBuffers.createInitialEmptyOutputBuffers(OutputBuffers.BufferType.PARTITIONED); + outputBuffers.setExchangeSinkInstanceHandle(exchangeSinkInstanceHandle); + + return new HybridSpoolingBuffer(new OutputBufferStateMachine(new TaskId(new StageId(new QueryId("query"), 0), 0, 0), directExecutor()), + outputBuffers, + exchangeSink, + TestSpoolingExchangeOutputBuffer.TestingLocalMemoryContext::new, + exchangeManager); + } + + @Test + public void TestTaskSnapshotConsolidatedStoreAndLoad() + throws Exception + { + setConfig(FileSystemExchangeConfig.DirectSerialisationType.OFF); + HybridSpoolingBuffer hybridSpoolingBuffer = createHybridSpoolingBuffer(); + + queryId = new QueryId("saveandload"); + TaskId taskId1 = new TaskId(queryId.getId(), 1, 2, 0); + TaskSnapshotManager snapshotManager = new TaskSnapshotManager(taskId1, 0, recoveryUtils, TEST_TASK_SNAPSHOT_SESSION); + recoveryUtils.getOrCreateQuerySnapshotManager(queryId, TEST_SNAPSHOT_SESSION); + + snapshotManager.setTotalComponents(1); + snapshotManager.setHybridSpoolingBuffer(hybridSpoolingBuffer); + MockState operatorState = new MockState("operator-state"); + SnapshotStateId operatorStateId = SnapshotStateId.forOperator(1L, taskId1, 3, 4, 5); + snapshotManager.storeConsolidatedState(operatorStateId, operatorState, 0); + snapshotManager.succeededToCapture(operatorStateId); + + MockState newOperatorState = (MockState) snapshotManager.loadConsolidatedState(TaskSnapshotManager.createConsolidatedId(operatorStateId.getSnapshotId(), operatorStateId.getTaskId())).get(); + Assert.assertEquals(operatorState.getState(), newOperatorState.getState()); } @Test @@ -90,7 +189,7 @@ public class TestTaskSnapshotManager { queryId = new QueryId("saveandload"); TaskId taskId1 = new TaskId(queryId.getId(), 1, 2, 0); - TaskSnapshotManager snapshotManager = new TaskSnapshotManager(taskId1, 0, recoveryUtils); + TaskSnapshotManager snapshotManager = new TaskSnapshotManager(taskId1, 0, recoveryUtils, TEST_SNAPSHOT_SESSION); recoveryUtils.getOrCreateQuerySnapshotManager(queryId, TEST_SNAPSHOT_SESSION); @@ -120,7 +219,7 @@ public class TestTaskSnapshotManager queryId = new QueryId("loadbacktrack"); StageId stageId = new StageId(queryId, 0); TaskId taskId = new TaskId(stageId, 0, 0); - TaskSnapshotManager snapshotManager = new 
TaskSnapshotManager(taskId, 0, recoveryUtils); + TaskSnapshotManager snapshotManager = new TaskSnapshotManager(taskId, 0, recoveryUtils, TEST_SNAPSHOT_SESSION); recoveryUtils.getOrCreateQuerySnapshotManager(queryId, TEST_SNAPSHOT_SESSION); @@ -151,7 +250,7 @@ public class TestTaskSnapshotManager { queryId = new QueryId("file"); TaskId taskId = new TaskId(queryId.getId(), 1, 5, 0); - TaskSnapshotManager snapshotManager = new TaskSnapshotManager(taskId, 0, recoveryUtils); + TaskSnapshotManager snapshotManager = new TaskSnapshotManager(taskId, 0, recoveryUtils, TEST_SNAPSHOT_SESSION); snapshotManager.setTotalComponents(1); // Create a file @@ -190,7 +289,7 @@ public class TestTaskSnapshotManager { queryId = new QueryId("filebacktrack"); TaskId taskId = new TaskId(queryId.getId(), 2, 3, 0); - TaskSnapshotManager snapshotManager = new TaskSnapshotManager(taskId, 0, recoveryUtils); + TaskSnapshotManager snapshotManager = new TaskSnapshotManager(taskId, 0, recoveryUtils, TEST_SNAPSHOT_SESSION); snapshotManager.setTotalComponents(1); // Create a file @@ -241,8 +340,8 @@ public class TestTaskSnapshotManager queryId = new QueryId("query"); TaskId taskId1 = new TaskId(queryId.getId(), 1, 1, 0); TaskId taskId2 = new TaskId(queryId.getId(), 1, 2, 0); - TaskSnapshotManager snapshotManager1 = new TaskSnapshotManager(taskId1, 0, recoveryUtils); - TaskSnapshotManager snapshotManager2 = new TaskSnapshotManager(taskId2, 0, recoveryUtils); + TaskSnapshotManager snapshotManager1 = new TaskSnapshotManager(taskId1, 0, recoveryUtils, TEST_SNAPSHOT_SESSION); + TaskSnapshotManager snapshotManager2 = new TaskSnapshotManager(taskId2, 0, recoveryUtils, TEST_SNAPSHOT_SESSION); snapshotManager1.setTotalComponents(2); snapshotManager2.setTotalComponents(2); @@ -267,8 +366,8 @@ public class TestTaskSnapshotManager { TaskId taskId1 = new TaskId("query", 1, 1, 0); TaskId taskId2 = new TaskId("query", 1, 2, 0); - TaskSnapshotManager snapshotManager1 = new TaskSnapshotManager(taskId1, 0, recoveryUtils); - TaskSnapshotManager snapshotManager2 = new TaskSnapshotManager(taskId2, 0, recoveryUtils); + TaskSnapshotManager snapshotManager1 = new TaskSnapshotManager(taskId1, 0, recoveryUtils, TEST_SNAPSHOT_SESSION); + TaskSnapshotManager snapshotManager2 = new TaskSnapshotManager(taskId2, 0, recoveryUtils, TEST_SNAPSHOT_SESSION); snapshotManager1.setTotalComponents(2); snapshotManager2.setTotalComponents(3); @@ -304,7 +403,7 @@ public class TestTaskSnapshotManager { queryId = new QueryId("query"); TaskId taskId = new TaskId(queryId.getId(), 1, 1, 0); - TaskSnapshotManager sm = new TaskSnapshotManager(taskId, 0, recoveryUtils); + TaskSnapshotManager sm = new TaskSnapshotManager(taskId, 0, recoveryUtils, TEST_SNAPSHOT_SESSION); sm.setTotalComponents(2); recoveryUtils.getOrCreateQuerySnapshotManager(queryId, TEST_SNAPSHOT_SESSION); @@ -320,7 +419,7 @@ public class TestTaskSnapshotManager { queryId = new QueryId("consolidatesimplequery"); TaskId taskId1 = new TaskId(queryId.getId(), 1, 0, 0); - TaskSnapshotManager snapshotManager = new TaskSnapshotManager(taskId1, 0, recoveryUtils); + TaskSnapshotManager snapshotManager = new TaskSnapshotManager(taskId1, 0, recoveryUtils, TEST_SNAPSHOT_SESSION); snapshotManager.setTotalComponents(1); recoveryUtils.getOrCreateQuerySnapshotManager(queryId, TEST_SNAPSHOT_SESSION); @@ -339,7 +438,7 @@ public class TestTaskSnapshotManager { queryId = new QueryId("consolidatebacktrackquery"); TaskId taskId1 = new TaskId(queryId.getId(), 1, 0, 0); - TaskSnapshotManager snapshotManager = new 
TaskSnapshotManager(taskId1, 0, recoveryUtils); + TaskSnapshotManager snapshotManager = new TaskSnapshotManager(taskId1, 0, recoveryUtils, TEST_SNAPSHOT_SESSION); snapshotManager.setTotalComponents(1); recoveryUtils.getOrCreateQuerySnapshotManager(queryId, TEST_SNAPSHOT_SESSION); @@ -374,8 +473,8 @@ public class TestTaskSnapshotManager queryId = new QueryId("consolidateoverlapquery"); TaskId taskId1 = new TaskId(queryId.getId(), 1, 0, 0); TaskId taskId2 = new TaskId(queryId.getId(), 1, 1, 0); - TaskSnapshotManager snapshotManager1 = new TaskSnapshotManager(taskId1, 0, recoveryUtils); - TaskSnapshotManager snapshotManager2 = new TaskSnapshotManager(taskId2, 0, recoveryUtils); + TaskSnapshotManager snapshotManager1 = new TaskSnapshotManager(taskId1, 0, recoveryUtils, TEST_SNAPSHOT_SESSION); + TaskSnapshotManager snapshotManager2 = new TaskSnapshotManager(taskId2, 0, recoveryUtils, TEST_SNAPSHOT_SESSION); snapshotManager1.setTotalComponents(1); snapshotManager2.setTotalComponents(1); @@ -421,7 +520,7 @@ public class TestTaskSnapshotManager { queryId = new QueryId("deletedfileloadquery"); TaskId taskId1 = new TaskId(queryId.getId(), 1, 0, 0); - TaskSnapshotManager snapshotManager = new TaskSnapshotManager(taskId1, 0, recoveryUtils); + TaskSnapshotManager snapshotManager = new TaskSnapshotManager(taskId1, 0, recoveryUtils, TEST_SNAPSHOT_SESSION); snapshotManager.setTotalComponents(1); QuerySnapshotManager querySnapshotManager = new QuerySnapshotManager(queryId, recoveryUtils, TEST_SNAPSHOT_SESSION); @@ -467,7 +566,7 @@ public class TestTaskSnapshotManager doThrow(new NullPointerException()).when(faultyRecoveryUtils).storeState(any(), any(), any()); TaskId taskId1 = new TaskId(queryId.getId(), 1, 0, 0); - TaskSnapshotManager snapshotManager = new TaskSnapshotManager(taskId1, 0, faultyRecoveryUtils); + TaskSnapshotManager snapshotManager = new TaskSnapshotManager(taskId1, 0, faultyRecoveryUtils, TEST_SNAPSHOT_SESSION); snapshotManager.setTotalComponents(1); MockState state = new MockState("mockstate"); @@ -486,7 +585,7 @@ public class TestTaskSnapshotManager { queryId = new QueryId("spilleddeletedquery"); TaskId taskId1 = new TaskId(queryId.getId(), 1, 0, 0); - TaskSnapshotManager snapshotManager = new TaskSnapshotManager(taskId1, 0, recoveryUtils); + TaskSnapshotManager snapshotManager = new TaskSnapshotManager(taskId1, 0, recoveryUtils, TEST_SNAPSHOT_SESSION); snapshotManager.setTotalComponents(1); QuerySnapshotManager querySnapshotManager = new QuerySnapshotManager(queryId, recoveryUtils, TEST_SNAPSHOT_SESSION); @@ -528,4 +627,15 @@ public class TestTaskSnapshotManager assertFalse(snapshotManager.loadFile(secondId, secondFile.toPath())); } } + + private boolean deleteDirectory(File dir) + { + File[] allContents = dir.listFiles(); + if (allContents != null) { + for (File file : allContents) { + deleteDirectory(file); + } + } + return dir.delete(); + } } diff --git a/presto-spi/src/main/java/io/prestosql/spi/filesystem/HetuFileSystemClient.java b/presto-spi/src/main/java/io/prestosql/spi/filesystem/HetuFileSystemClient.java index b7eb581f16359aa133d61f0b16c3bb6136c26e32..9d0bb8497c96b48febc5785ccdf6654de9d0e9f0 100644 --- a/presto-spi/src/main/java/io/prestosql/spi/filesystem/HetuFileSystemClient.java +++ b/presto-spi/src/main/java/io/prestosql/spi/filesystem/HetuFileSystemClient.java @@ -210,4 +210,7 @@ public interface HetuFileSystemClient Stream getDirectoryStream(Path path, String prefix, String suffix) throws IOException; + + void flush(OutputStream outputStream) + throws IOException; }
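Usage note on the flush(OutputStream) method added to the HetuFileSystemClient SPI above: the sketch below is a minimal, hypothetical illustration of how a caller might force buffered bytes to durable storage before treating a snapshot marker as complete. It assumes the client exposes a java.nio-style newOutputStream(Path) factory (mirroring the newInputStream(Path) shown in TestingRecoveryUtils); the helper class and method names (MarkerMetadataWriter, writeDurably) are invented for illustration and are not part of this change. Each implementation maps flush() to whatever durability primitive its filesystem offers, e.g. a plain stream flush locally versus an fsync-style call on a distributed filesystem, while test stubs such as TestingRecoveryUtils keep it as a no-op.

import io.prestosql.spi.filesystem.HetuFileSystemClient;

import java.io.IOException;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;

// Hypothetical helper, not part of this change: writes a small metadata payload and
// uses HetuFileSystemClient#flush(OutputStream) to push it to durable storage before
// the caller records the corresponding snapshot marker as complete.
public final class MarkerMetadataWriter
{
    private final HetuFileSystemClient fsClient;

    public MarkerMetadataWriter(HetuFileSystemClient fsClient)
    {
        this.fsClient = fsClient;
    }

    public void writeDurably(Path target, String payload)
            throws IOException
    {
        // newOutputStream(Path) is assumed here, mirroring the SPI's newInputStream(Path).
        try (OutputStream out = fsClient.newOutputStream(target)) {
            out.write(payload.getBytes(StandardCharsets.UTF_8));
            // flush() lets each filesystem choose its durability guarantee: a plain
            // OutputStream.flush() for the local client, an fsync-style call elsewhere.
            fsClient.flush(out);
        }
    }
}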