diff --git a/omnioperator/omniop-native-reader/cpp/CMakeLists.txt b/omnioperator/omniop-native-reader/cpp/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..21cac76349bea878b57c222e4fd37713c727f6cf
--- /dev/null
+++ b/omnioperator/omniop-native-reader/cpp/CMakeLists.txt
@@ -0,0 +1,48 @@
+# project name
+project(native_reader)
+
+# required cmake version
+cmake_minimum_required(VERSION 3.10)
+
+# configure cmake
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_COMPILER "g++")
+
+set(root_directory ${PROJECT_BINARY_DIR})
+
+set(CMAKE_CXX_FLAGS_DEBUG "-pipe -g -Wall -fPIC -fno-common -fno-stack-protector")
+set(CMAKE_CXX_FLAGS_RELEASE "-O2 -pipe -Wall -Wtrampolines -D_FORTIFY_SOURCE=2 -fPIC -finline-functions -fstack-protector-strong -s -Wl,-z,noexecstack -Wl,-z,relro,-z,now")
+
+if (DEFINED COVERAGE)
+    if(${COVERAGE} STREQUAL "ON")
+        set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -ftest-coverage -fprofile-arcs")
+    endif()
+endif()
+# configure file
+configure_file(
+    "${PROJECT_SOURCE_DIR}/config.h.in"
+    "${PROJECT_SOURCE_DIR}/config.h"
+)
+
+aux_source_directory(${CMAKE_CURRENT_LIST_DIR} ROOT_SRCS)
+# for header searching
+include_directories(SYSTEM src)
+
+# compile library
+add_subdirectory(src)
+
+message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
+
+option(BUILD_CPP_TESTS "Build C++ unit tests" OFF)
+message(STATUS "Option BUILD_CPP_TESTS: ${BUILD_CPP_TESTS}")
+if(${BUILD_CPP_TESTS})
+    enable_testing()
+    add_subdirectory(test)
+endif ()
+
+# options
+option(DEBUG_RUNTIME "Debug" OFF)
+message(STATUS "Option DEBUG: ${DEBUG_RUNTIME}")
+
+option(TRACE_RUNTIME "Trace" OFF)
+message(STATUS "Option TRACE: ${TRACE_RUNTIME}")
diff --git a/omnioperator/omniop-native-reader/cpp/build.sh b/omnioperator/omniop-native-reader/cpp/build.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c21dba905a8ba6ba17c7d0448405ab82ba81d981
--- /dev/null
+++ b/omnioperator/omniop-native-reader/cpp/build.sh
@@ -0,0 +1,78 @@
+#!/bin/bash
+# Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved.
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
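+
+# Usage: bash build.sh [release|debug|trace|test|coverage]
+#   release  - optimized build (also the default when no or an unknown option is given)
+#   debug    - Debug build with DEBUG_RUNTIME enabled
+#   trace    - Debug build with TRACE_RUNTIME enabled
+#   test     - build including the C++ unit tests (BUILD_CPP_TESTS=TRUE)
+#   coverage - Debug build with coverage instrumentation; runs test/tptest and generates an lcov/genhtml report
+# OMNI_HOME defaults to /opt when it is not set in the environment.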
+ +set -e + +if [ -z "$OMNI_HOME" ]; then + echo "OMNI_HOME is empty" + OMNI_HOME=/opt +fi + +export OMNI_INCLUDE_PATH=$OMNI_HOME/lib/include +export OMNI_INCLUDE_PATH=$OMNI_INCLUDE_PATH:$OMNI_HOME/lib +export CPLUS_INCLUDE_PATH=$OMNI_INCLUDE_PATH:$CPLUS_INCLUDE_PATH +echo "OMNI_INCLUDE_PATH=$OMNI_INCLUDE_PATH" + +CURRENT_DIR=$(cd "$(dirname "$BASH_SOURCE")"; pwd) +echo $CURRENT_DIR +cd ${CURRENT_DIR} +if [ -d build ]; then + rm -r build +fi +mkdir build +cd build + +# options +if [ $# != 0 ] ; then + options="" + if [ $1 = 'debug' ]; then + echo "-- Enable Debug" + options="$options -DCMAKE_BUILD_TYPE=Debug -DDEBUG_RUNTIME=ON" + elif [ $1 = 'trace' ]; then + echo "-- Enable Trace" + options="$options -DCMAKE_BUILD_TYPE=Debug -DTRACE_RUNTIME=ON" + elif [ $1 = 'release' ];then + echo "-- Enable Release" + options="$options -DCMAKE_BUILD_TYPE=Release" + elif [ $1 = 'test' ];then + echo "-- Enable Test" + options="$options -DCMAKE_BUILD_TYPE=Test -DBUILD_CPP_TESTS=TRUE" + elif [ $1 = 'coverage' ]; then + echo "-- Enable Coverage" + options="$options -DCMAKE_BUILD_TYPE=Debug -DDEBUG_RUNTIME=ON -DCOVERAGE=ON" + else + echo "-- Enable Release" + options="$options -DCMAKE_BUILD_TYPE=Release" + fi + cmake .. $options +else + echo "-- Enable Release" + cmake .. -DCMAKE_BUILD_TYPE=Release +fi + +make -j5 + +if [ $# != 0 ] ; then + if [ $1 = 'coverage' ]; then + ./test/tptest --gtest_output=xml:test_detail.xml + lcov --d ../ --c --output-file test.info --rc lcov_branch_coverage=1 + lcov --remove test.info '*/opt/lib/include/*' '*test/*' '*build/src/*' '*/usr/include/*' '*/usr/lib/*' '*/usr/lib64/*' '*/usr/local/include/*' '*/usr/local/lib/*' '*/usr/local/lib64/*' -o final.info --rc lcov_branch_coverage=1 + genhtml final.info -o test_coverage --branch-coverage --rc lcov_branch_coverage=1 + fi +fi + +set +eu diff --git a/omnioperator/omniop-native-reader/cpp/config.h b/omnioperator/omniop-native-reader/cpp/config.h new file mode 100644 index 0000000000000000000000000000000000000000..71d819b34c775afde138dfb1db023fb20408c46b --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/config.h @@ -0,0 +1,20 @@ +/* + * Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+//#cmakedefine DEBUG_RUNTIME
+//#cmakedefine TRACE_RUNTIME
\ No newline at end of file
diff --git a/omnioperator/omniop-native-reader/cpp/config.h.in b/omnioperator/omniop-native-reader/cpp/config.h.in
new file mode 100644
index 0000000000000000000000000000000000000000..43c74967c62ab21066187cc4eaf6a692706f4a97
--- /dev/null
+++ b/omnioperator/omniop-native-reader/cpp/config.h.in
@@ -0,0 +1,2 @@
+#cmakedefine DEBUG_RUNTIME
+#cmakedefine TRACE_RUNTIME
\ No newline at end of file
diff --git a/omnioperator/omniop-native-reader/cpp/src/CMakeLists.txt b/omnioperator/omniop-native-reader/cpp/src/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7ba2967f87f8ebcaeb6959b51c23cef857462a07
--- /dev/null
+++ b/omnioperator/omniop-native-reader/cpp/src/CMakeLists.txt
@@ -0,0 +1,59 @@
+include_directories(SYSTEM "/usr/local/include")
+
+set (PROJ_TARGET native_reader)
+
+
+set (SOURCE_FILES
+        jni/OrcColumnarBatchJniReader.cpp
+        jni/jni_common.cpp
+        jni/ParquetColumnarBatchJniReader.cpp
+        parquet/ParquetReader.cpp
+        parquet/ParquetColumnReader.cpp
+        parquet/ParquetTypedRecordReader.cpp
+        parquet/ParquetDecoder.cpp
+        common/UriInfo.cc
+        orcfile/OrcFileOverride.cc
+        orcfile/OrcHdfsFileOverride.cc
+        filesystem/hdfs_file.cpp
+        filesystem/hdfs_filesystem.cpp
+        filesystem/io_exception.cpp
+        filesystem/status.cpp
+        arrowadapter/FileSystemAdapter.cc
+        arrowadapter/UtilInternal.cc
+        arrowadapter/HdfsAdapter.cc
+        arrowadapter/LocalfsAdapter.cc
+        )
+
+# Find required protobuf package
+find_package(Protobuf REQUIRED)
+if(PROTOBUF_FOUND)
+    message(STATUS "protobuf library found")
+else()
+    message(FATAL_ERROR "protobuf library is needed but cannot be found")
+endif()
+
+include_directories(${Protobuf_INCLUDE_DIRS})
+include_directories(${CMAKE_CURRENT_BINARY_DIR})
+add_library (${PROJ_TARGET} SHARED ${SOURCE_FILES} ${PROTO_SRCS} ${PROTO_HDRS} ${PROTO_SRCS_VB} ${PROTO_HDRS_VB})
+
+find_package(Arrow REQUIRED)
+find_package(Parquet REQUIRED)
+
+# JNI
+target_include_directories(${PROJ_TARGET} PUBLIC $ENV{JAVA_HOME}/include)
+target_include_directories(${PROJ_TARGET} PUBLIC $ENV{JAVA_HOME}/include/linux)
+target_include_directories(${PROJ_TARGET} PUBLIC ${CMAKE_CURRENT_BINARY_DIR})
+
+target_link_libraries (${PROJ_TARGET} PUBLIC
+        Arrow::arrow_shared
+        Parquet::parquet_shared
+        orc
+        boostkit-omniop-vector-1.4.0-aarch64
+        hdfs
+        )
+
+set_target_properties(${PROJ_TARGET} PROPERTIES
+        LIBRARY_OUTPUT_DIRECTORY ${root_directory}/releases
+)
+
+install(TARGETS ${PROJ_TARGET} DESTINATION lib)
diff --git a/omnioperator/omniop-native-reader/cpp/src/arrowadapter/FileSystemAdapter.cc b/omnioperator/omniop-native-reader/cpp/src/arrowadapter/FileSystemAdapter.cc
new file mode 100644
index 0000000000000000000000000000000000000000..5e0684e8098dc6662e0b506e7cde1e203fcdb306
--- /dev/null
+++ b/omnioperator/omniop-native-reader/cpp/src/arrowadapter/FileSystemAdapter.cc
@@ -0,0 +1,111 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "FileSystemAdapter.h" +#include "arrow/filesystem/hdfs.h" +#include "arrow/filesystem/localfs.h" +#include "arrow/filesystem/mockfs.h" +#include "arrow/filesystem/path_util.h" +#include "arrow/io/slow.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/macros.h" +#include "arrow/util/parallel.h" +#include "HdfsAdapter.h" +#include "LocalfsAdapter.h" +#include "UtilInternal.h" + +namespace arrow_adapter { + +using arrow::internal::Uri; +using arrow::fs::internal::RemoveLeadingSlash; +using arrow::fs::internal::ToSlashes; +using arrow::fs::FileSystem; +using arrow::fs::HadoopFileSystem; +using arrow::fs::LocalFileSystem; +using arrow::fs::internal::MockFileSystem; +using arrow::Result; + +namespace { + +Result> +FileSystemFromUriReal(const UriInfo &uri, const arrow::io::IOContext &io_context, std::string *out_path) +{ + const auto scheme = uri.Scheme(); + + if (scheme == "file") { + std::string path; + ARROW_ASSIGN_OR_RAISE(auto options, buildLocalfsOptionsFromUri(uri, &path)); + if (out_path != nullptr) { + *out_path = path; + } + return std::make_shared(options, io_context); + } + + if (scheme == "hdfs" || scheme == "viewfs") { + ARROW_ASSIGN_OR_RAISE(auto options, buildHdfsOptionsFromUri(uri)); + if (out_path != nullptr) { + *out_path = uri.Path(); + } + ARROW_ASSIGN_OR_RAISE(auto hdfs, HadoopFileSystem::Make(options, io_context)); + return hdfs; + } + + if (scheme == "mock") { + // MockFileSystem does not have an absolute / relative path distinction, + // normalize path by removing leading slash. + if (out_path != nullptr) { + *out_path = std::string(RemoveLeadingSlash(uri.Path())); + } + return std::make_shared(CurrentTimePoint(), + io_context); + } + + return arrow::fs::FileSystemFromUri(uri.ToString(), io_context, out_path); +} + +} // namespace + + +Result> FileSystemFromUriOrPath(const UriInfo &uri, + std::string *out_path) +{ + return FileSystemFromUriOrPath(uri, arrow::io::IOContext(), out_path); +} + +Result> FileSystemFromUriOrPath( + const UriInfo &uri, const arrow::io::IOContext &io_context, + std::string *out_path) +{ + const auto &uri_string = uri.ToString(); + if (arrow::fs::internal::DetectAbsolutePath(uri_string)) { + // Normalize path separators + if (out_path != nullptr) { + *out_path = ToSlashes(uri_string); + } + return std::make_shared(); + } + return FileSystemFromUriReal(uri, io_context, out_path); +} + +} +// namespace arrow \ No newline at end of file diff --git a/omnioperator/omniop-native-reader/cpp/src/arrowadapter/FileSystemAdapter.h b/omnioperator/omniop-native-reader/cpp/src/arrowadapter/FileSystemAdapter.h new file mode 100644 index 0000000000000000000000000000000000000000..246ac313dcf991014a29d1bde21cdbe099de5ff8 --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/arrowadapter/FileSystemAdapter.h @@ -0,0 +1,83 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "arrow/filesystem/type_fwd.h" +#include "arrow/io/interfaces.h" +#include "arrow/type_fwd.h" +#include "arrow/util/compare.h" +#include "arrow/util/macros.h" +#include "arrow/util/type_fwd.h" +#include "arrow/util/visibility.h" +#include "arrow/util/windows_fixup.h" +#include "common/UriInfo.h" + +namespace arrow_adapter { + +using arrow::Result; + +using arrow::fs::FileSystem; + +/// \defgroup filesystem-factories Functions for creating FileSystem instances + +/// @{ + +/// \brief Create a new FileSystem by URI +/// +/// Same as FileSystemFromUriOrPath, but it use uri that constructed by client +ARROW_EXPORT +Result> FileSystemFromUriOrPath(const UriInfo &uri, + std::string *out_path = NULLPTR); + + +/// \brief Create a new FileSystem by URI with a custom IO context +/// +/// Recognized schemes are "file", "mock", "hdfs", "viewfs", "s3", +/// "gs" and "gcs". +/// +/// \param[in] uri a URI-based path, ex: file:///some/local/path +/// \param[in] io_context an IOContext which will be associated with the filesystem +/// \param[out] out_path (optional) Path inside the filesystem. +/// \return out_fs FileSystem instance. + + +/// \brief Create a new FileSystem by URI with a custom IO context +/// +/// Same as FileSystemFromUri, but in addition also recognize non-URIs +/// and treat them as local filesystem paths. Only absolute local filesystem +/// paths are allowed. +ARROW_EXPORT +Result> FileSystemFromUriOrPath( + const UriInfo &uri, const arrow::io::IOContext &io_context, + std::string *out_path = NULLPTR); + +/// @} + +// namespace fs +} +// namespace arrow \ No newline at end of file diff --git a/omnioperator/omniop-native-reader/cpp/src/arrowadapter/HdfsAdapter.cc b/omnioperator/omniop-native-reader/cpp/src/arrowadapter/HdfsAdapter.cc new file mode 100644 index 0000000000000000000000000000000000000000..debadaa35e10f3ba8de4534373a54cdb1d58afa4 --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/arrowadapter/HdfsAdapter.cc @@ -0,0 +1,53 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "arrow/filesystem/hdfs.h" +#include "arrow/util/value_parsing.h" +#include "HdfsAdapter.h" + +namespace arrow_adapter { + +using arrow::internal::ParseValue; + +using arrow::Result; +using arrow::fs::HdfsOptions; + +Result buildHdfsOptionsFromUri(const UriInfo &uri) +{ + HdfsOptions options; + + std::string host; + host = uri.Scheme() + "://" + uri.Host(); + + // configure endpoint + int32_t port; + if (uri.Port().empty() || (port = atoi(uri.Port().c_str())) == -1) { + // default port will be determined by hdfs FileSystem impl + options.ConfigureEndPoint(host, 0); + } else { + options.ConfigureEndPoint(host, port); + } + + return options; +} + +} +// namespace arrow \ No newline at end of file diff --git a/omnioperator/omniop-native-reader/cpp/src/arrowadapter/HdfsAdapter.h b/omnioperator/omniop-native-reader/cpp/src/arrowadapter/HdfsAdapter.h new file mode 100644 index 0000000000000000000000000000000000000000..10aa9bc8e462e7f62a82bea6db1fbb501781592d --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/arrowadapter/HdfsAdapter.h @@ -0,0 +1,38 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include "arrow/filesystem/filesystem.h" +#include "arrow/filesystem/hdfs.h" +#include "common/UriInfo.h" + +namespace arrow_adapter { + +using arrow::Result; +using arrow::fs::HdfsOptions; + +ARROW_EXPORT +Result buildHdfsOptionsFromUri(const UriInfo &uri); + +} +// namespace arrow \ No newline at end of file diff --git a/omnioperator/omniop-native-reader/cpp/src/arrowadapter/LocalfsAdapter.cc b/omnioperator/omniop-native-reader/cpp/src/arrowadapter/LocalfsAdapter.cc new file mode 100644 index 0000000000000000000000000000000000000000..08e1f204cbd21699079a03f059e0bf8d095803d3 --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/arrowadapter/LocalfsAdapter.cc @@ -0,0 +1,51 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "arrow/filesystem/localfs.h" +#include "arrow/util/io_util.h" +#include "LocalfsAdapter.h" +#include "arrow/result.h" + +namespace arrow_adapter { + +using ::arrow::internal::IOErrorFromErrno; +using ::arrow::internal::NativePathString; +using ::arrow::internal::PlatformFilename; +using arrow::Result; +using arrow::fs::LocalFileSystemOptions; +using arrow::Status; + +Result buildLocalfsOptionsFromUri(const UriInfo &uri, std::string *out_path) +{ + std::string path; + const auto host = uri.Host(); + if (!host.empty()) { + return Status::Invalid("Unsupported hostname in non-Windows local URI: '", + uri.ToString(), "'"); + } else { + *out_path = uri.Path(); + } + + return LocalFileSystemOptions(); +} + +} +// namespace arrow \ No newline at end of file diff --git a/omnioperator/omniop-native-reader/cpp/src/arrowadapter/LocalfsAdapter.h b/omnioperator/omniop-native-reader/cpp/src/arrowadapter/LocalfsAdapter.h new file mode 100644 index 0000000000000000000000000000000000000000..26d3b60cf782bd7ed4d761327f8e15d784a451d6 --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/arrowadapter/LocalfsAdapter.h @@ -0,0 +1,39 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include "arrow/filesystem/filesystem.h" +#include "arrow/filesystem/localfs.h" +#include "common/UriInfo.h" + +namespace arrow_adapter { + +using arrow::Result; +using arrow::fs::LocalFileSystemOptions; +using arrow::Status; + +ARROW_EXPORT +Result buildLocalfsOptionsFromUri(const UriInfo &uri, std::string *out_path); + +} +// namespace arrow \ No newline at end of file diff --git a/omnioperator/omniop-native-reader/cpp/src/arrowadapter/UtilInternal.cc b/omnioperator/omniop-native-reader/cpp/src/arrowadapter/UtilInternal.cc new file mode 100644 index 0000000000000000000000000000000000000000..058aeb38b45442dfd16a1392ce9107afbf933640 --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/arrowadapter/UtilInternal.cc @@ -0,0 +1,33 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "UtilInternal.h" + +namespace arrow_adapter { + +using arrow::fs::TimePoint; + +TimePoint CurrentTimePoint() +{ + auto now = std::chrono::system_clock::now(); + return TimePoint( + std::chrono::duration_cast(now.time_since_epoch())); +} + +} +// namespace arrow \ No newline at end of file diff --git a/omnioperator/omniop-native-reader/cpp/src/arrowadapter/UtilInternal.h b/omnioperator/omniop-native-reader/cpp/src/arrowadapter/UtilInternal.h new file mode 100644 index 0000000000000000000000000000000000000000..67d51eb4646ff3b641c8164cc84d331797016fc0 --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/arrowadapter/UtilInternal.h @@ -0,0 +1,38 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include "arrow/filesystem/filesystem.h" +#include "arrow/io/interfaces.h" +#include "arrow/status.h" +#include "arrow/util/visibility.h" + +namespace arrow_adapter { + +using arrow::fs::TimePoint; + +ARROW_EXPORT +TimePoint CurrentTimePoint(); + +} +// namespace arrow \ No newline at end of file diff --git a/omnioperator/omniop-native-reader/cpp/src/common/UriInfo.cc b/omnioperator/omniop-native-reader/cpp/src/common/UriInfo.cc new file mode 100644 index 0000000000000000000000000000000000000000..a4598ce3a75b3701a0b8ac91878f146a78e8caf1 --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/common/UriInfo.cc @@ -0,0 +1,73 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include <string>
+#include <utility>
+
+#include "UriInfo.h"
+
+static const std::string LOCAL_FILE = "file";
+
+UriInfo::UriInfo(std::string _uri, std::string _scheme, std::string _path, std::string _host,
+                 std::string _port) : hostString(std::move(_host)),
+                                      schemeString(std::move(_scheme)),
+                                      portString(std::move(_port)),
+                                      pathString(std::move(_path)),
+                                      uriString(std::move(_uri))
+{
+    // For local files, use the plain path as the URI string.
+    if (schemeString == LOCAL_FILE) {
+        uriString = pathString;
+    }
+}
+
+UriInfo::UriInfo(std::string _scheme, std::string _path, std::string _host,
+                 std::string _port) : hostString(std::move(_host)),
+                                      schemeString(std::move(_scheme)),
+                                      portString(std::move(_port)),
+                                      pathString(std::move(_path)),
+                                      uriString("Origin URI not initialized!")
+{
+}
+
+UriInfo::~UriInfo() {}
+
+const std::string UriInfo::Scheme() const
+{
+    return schemeString;
+}
+
+const std::string UriInfo::Host() const
+{
+    return hostString;
+}
+
+const std::string UriInfo::Port() const
+{
+    return portString;
+}
+
+const std::string UriInfo::Path() const
+{
+    return pathString;
+}
+
+const std::string UriInfo::ToString() const
+{
+    return uriString;
+}
\ No newline at end of file
diff --git a/omnioperator/omniop-native-reader/cpp/src/common/UriInfo.h b/omnioperator/omniop-native-reader/cpp/src/common/UriInfo.h
new file mode 100644
index 0000000000000000000000000000000000000000..c9885a5bc657f993fd69f9473795c8803f3d3f88
--- /dev/null
+++ b/omnioperator/omniop-native-reader/cpp/src/common/UriInfo.h
@@ -0,0 +1,59 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef URI_INFO_H
+#define URI_INFO_H
+
+#include <string>
+#include <utility>
+
+/// \brief A parsed URI
+class UriInfo {
+public:
+    UriInfo(std::string _uri, std::string _scheme, std::string _path, std::string _host, std::string _port);
+
+    UriInfo(std::string _scheme, std::string _path, std::string _host, std::string _port);
+
+    ~UriInfo();
+
+    const std::string Scheme() const;
+
+    /// The URI host name, such as "localhost", "127.0.0.1" or "::1", or the empty
+    /// string if the URI does not have a host component.
+    const std::string Host() const;
+
+    /// The URI path component.
+    const std::string Path() const;
+
+    /// The URI port number, as a string such as "80", or the empty string if the URI
+    /// does not have a port number component.
+    const std::string Port() const;
+
+    /// Get the string representation of this URI.
+ const std::string ToString() const; + +private: + std::string hostString; + std::string schemeString; + std::string portString; + std::string pathString; + std::string uriString; + +}; + +#endif \ No newline at end of file diff --git a/omnioperator/omniop-native-reader/cpp/src/common/debug.h b/omnioperator/omniop-native-reader/cpp/src/common/debug.h new file mode 100644 index 0000000000000000000000000000000000000000..43a98d172faddc3c8886a622437d585c56d0ad49 --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/common/debug.h @@ -0,0 +1,74 @@ +/* + * Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "../../config.h" + +#ifdef TRACE_RUNTIME +#define LogsTrace(format, ...) \ + do { \ + printf("[TRACE][%s][%s][%d]:" format "\n", __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__); \ + } while (0) +#else +#define LogsTrace(format, ...) +#endif + +#if defined(TRACE_RUNTIME) || defined(DEBUG_RUNTIME) +#define LogsDebug(format, ...) \ + do { \ + if (static_cast(LogType::LOG_DEBUG) >= GetLogLevel()) { \ + char logBuf[GLOBAL_LOG_BUF_SIZE]; \ + LogsInfoVargMacro(logBuf, format, ##__VA_ARGS__); \ + std::string logString(logBuf); \ + Log(logString, LogType::LOG_DEBUG); \ + } \ + } while (0) +#else +#define LogsDebug(format, ...) +#endif + +#define LogsInfo(format, ...) \ + do { \ + if (static_cast(LogType::LOG_INFO) >= GetLogLevel()) { \ + char logBuf[GLOBAL_LOG_BUF_SIZE]; \ + LogsInfoVargMacro(logBuf, format, ##__VA_ARGS__); \ + std::string logString(logBuf); \ + Log(logString, LogType::LOG_INFO); \ + } \ + } while (0) + +#define LogsWarn(format, ...) \ + do { \ + if (static_cast(LogType::LOG_WARN) >= GetLogLevel()) { \ + char logBuf[GLOBAL_LOG_BUF_SIZE]; \ + LogsInfoVargMacro(logBuf, format, ##__VA_ARGS__); \ + std::string logString(logBuf); \ + Log(logString, LogType::LOG_WARN); \ + } \ + } while (0) + +#define LogsError(format, ...) \ + do { \ + if (static_cast(LogType::LOG_ERROR) >= GetLogLevel()) { \ + char logBuf[GLOBAL_LOG_BUF_SIZE]; \ + LogsInfoVargMacro(logBuf, format, ##__VA_ARGS__); \ + std::string logString(logBuf); \ + Log(logString, LogType::LOG_ERROR); \ + } \ + } while (0) diff --git a/omnioperator/omniop-native-reader/cpp/src/filesystem/file_interface.h b/omnioperator/omniop-native-reader/cpp/src/filesystem/file_interface.h new file mode 100644 index 0000000000000000000000000000000000000000..ba5e0af9dc2d501873a0e75b63997d5c9309dc08 --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/filesystem/file_interface.h @@ -0,0 +1,54 @@ +/** + * Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SPARK_THESTRAL_PLUGIN_FILE_INTERFACE_H +#define SPARK_THESTRAL_PLUGIN_FILE_INTERFACE_H + +#include "status.h" + +namespace fs { + +class ReadableFile { +public: + // Virtual destructor + virtual ~ReadableFile() = default; + + // Close the file + virtual Status Close() = 0; + + // Open the file + virtual Status OpenFile() = 0; + + // Read data from the specified offset into the buffer with the given length + virtual int64_t ReadAt(void *buffer, int32_t length, int64_t offset) = 0; + + // Get the size of the file + virtual int64_t GetFileSize() = 0; + + // Set the read position within the file + virtual Status Seek(int64_t position) = 0; + + // Read data from the current position into the buffer with the given length + virtual int64_t Read(void *buffer, int32_t length) = 0; +}; + +} + + +#endif //SPARK_THESTRAL_PLUGIN_FILE_INTERFACE_H diff --git a/omnioperator/omniop-native-reader/cpp/src/filesystem/filesystem.h b/omnioperator/omniop-native-reader/cpp/src/filesystem/filesystem.h new file mode 100644 index 0000000000000000000000000000000000000000..5cbc4568dd573c26fdd5dd80d2c853599ab0889e --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/filesystem/filesystem.h @@ -0,0 +1,131 @@ +/** + * Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SPARK_THESTRAL_PLUGIN_FILESYSTEM_H +#define SPARK_THESTRAL_PLUGIN_FILESYSTEM_H + +#include +#include +#include +#include "status.h" + +namespace fs { + +using TimePoint = + std::chrono::time_point; + +static const int64_t kNoSize = -1; +static const TimePoint kNoTime = TimePoint(TimePoint::duration(-1)); + +enum class FileType : int8_t { + /// Entry is not found + NotFound, + /// Entry exists but its type is unknown + /// + /// This can designate a special file such as a Unix socket or character + /// device, or Windows NUL / CON / ... 
+ Unknown, + /// Entry is a regular file + File, + /// Entry is a directory + Directory +}; + +std::string ToString(FileType); + +struct FileInfo { + /// The full file path in the filesystem + const std::string &path() const { return path_; } + + void setPath(std::string path) { path_ = std::move(path); } + + /// The file type + FileType type() const { return type_; } + + void setType(FileType type) { type_ = type; } + + /// The size in bytes, if available + int64_t size() const { return size_; } + + void setSize(int64_t size) { size_ = size; } + + /// The time of last modification, if available + TimePoint mtime() const { return mtime_; } + + void setMtime(TimePoint mtime) { mtime_ = mtime; } + + bool IsFile() const { return type_ == FileType::File; } + + bool IsDirectory() const { return type_ == FileType::Directory; } + + bool Equals(const FileInfo &other) const { + return type() == other.type() && path() == other.path() && size() == other.size() && + mtime() == other.mtime(); + } + +protected: + std::string path_; + FileType type_ = FileType::Unknown; + int64_t size_ = kNoSize; + TimePoint mtime_ = kNoTime; + +}; + +} + +namespace fs { + +class FileSystem { +public: + // Virtual destructor + virtual ~FileSystem() = default; + + // Get the type name of the file system + virtual std::string type_name() const = 0; + + /** + * Get information about the file at the specified path + * @param path the file path + */ + virtual FileInfo GetFileInfo(const std::string &path) = 0; + + /** + * Check if this file system is equal to another file system + * @param other the other filesystem + */ + virtual bool Equals(const FileSystem &other) const = 0; + + /** + * Check if this file system is equal to a shared pointer to another file system + * @param other the other filesystem pointer + */ + virtual bool Equals(const std::shared_ptr &other) const { + return Equals(*other); + } + + // Close the file system + virtual Status Close() = 0; +}; + +} // fs + + + + +#endif //SPARK_THESTRAL_PLUGIN_FILESYSTEM_H diff --git a/omnioperator/omniop-native-reader/cpp/src/filesystem/hdfs_file.cpp b/omnioperator/omniop-native-reader/cpp/src/filesystem/hdfs_file.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4b08d1b2152ae53967f368430a7fe38825ae6584 --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/filesystem/hdfs_file.cpp @@ -0,0 +1,101 @@ +/** + * Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "hdfs_file.h" +#include "iostream" + +namespace fs { + +HdfsReadableFile::HdfsReadableFile(std::shared_ptr fileSystemPtr, + const std::string &path, int64_t bufferSize) + : fileSystem_(fileSystemPtr), path_(path), bufferSize_(bufferSize) { +} + +HdfsReadableFile::~HdfsReadableFile() { + this->TryClose(); +} + +Status HdfsReadableFile::Close() { + return TryClose(); +} + +Status HdfsReadableFile::TryClose() { + if (!isOpen_) { + return Status::OK(); + } + int st = hdfsCloseFile(fileSystem_->getFileSystem(), file_); + if (st == -1) { + return Status::IOError("Fail to close hdfs file, path is " + path_); + } + this->isOpen_ = false; + return Status::OK(); +} + +Status HdfsReadableFile::OpenFile() { + if (isOpen_) { + return Status::OK(); + } + hdfsFile handle = hdfsOpenFile(fileSystem_->getFileSystem(), path_.c_str(), O_RDONLY, bufferSize_, 0, 0); + if (handle == nullptr) { + return Status::IOError("Fail to open hdfs file, path is " + path_); + } + + this->file_ = handle; + this->isOpen_ = true; + return Status::OK(); +} + +int64_t HdfsReadableFile::ReadAt(void *buffer, int32_t length, int64_t offset) { + if (!OpenFile().IsOk()) { + return -1; + } + + return hdfsPread(fileSystem_->getFileSystem(), file_, offset, buffer, length); +} + +int64_t HdfsReadableFile::GetFileSize() { + if (!OpenFile().IsOk()) { + return -1; + } + + FileInfo fileInfo = fileSystem_->GetFileInfo(path_); + return fileInfo.size(); +} + +Status HdfsReadableFile::Seek(int64_t position) { + if (!OpenFile().IsOk()) { + return Status::IOError("Fail to open and seek hdfs file, path is " + path_); + } + int st = hdfsSeek(fileSystem_->getFileSystem(), file_, position); + if (st == -1) { + return Status::IOError("Fail to seek hdfs file, path is " + path_); + } + return Status::OK(); +} + +int64_t HdfsReadableFile::Read(void *buffer, int32_t length) { + if (!OpenFile().IsOk()) { + return -1; + } + + return hdfsRead(fileSystem_->getFileSystem(), file_, buffer, length); +} + + +} \ No newline at end of file diff --git a/omnioperator/omniop-native-reader/cpp/src/filesystem/hdfs_file.h b/omnioperator/omniop-native-reader/cpp/src/filesystem/hdfs_file.h new file mode 100644 index 0000000000000000000000000000000000000000..ebfe0334fb2fc307612ed3fba8ff50f3710fcb4c --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/filesystem/hdfs_file.h @@ -0,0 +1,65 @@ +/** + * Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef SPARK_THESTRAL_PLUGIN_HDFS_FILE_H +#define SPARK_THESTRAL_PLUGIN_HDFS_FILE_H + +#include "file_interface.h" +#include "hdfs_filesystem.h" + +namespace fs { + +class HdfsReadableFile : public ReadableFile { + +public: + HdfsReadableFile(std::shared_ptr fileSystemPtr, const std::string &path, + int64_t bufferSize = 0); + + ~HdfsReadableFile(); + + Status Close() override; + + Status OpenFile() override; + + int64_t ReadAt(void *buffer, int32_t length, int64_t offset) override; + + int64_t GetFileSize() override; + + Status Seek(int64_t position) override; + + int64_t Read(void *buffer, int32_t length) override; + +private: + Status TryClose(); + + std::shared_ptr fileSystem_; + + const std::string &path_; + + int64_t bufferSize_; + + bool isOpen_ = false; + + hdfsFile file_; +}; + +} + + +#endif //SPARK_THESTRAL_PLUGIN_HDFS_FILE_H diff --git a/omnioperator/omniop-native-reader/cpp/src/filesystem/hdfs_filesystem.cpp b/omnioperator/omniop-native-reader/cpp/src/filesystem/hdfs_filesystem.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f1dc074060d79e5612414f268b3d50bb95c19ea4 --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/filesystem/hdfs_filesystem.cpp @@ -0,0 +1,147 @@ +/** + * Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include <iostream>
+#include <chrono>
+#include <map>
+#include <mutex>
+#include "hdfs_filesystem.h"
+#include "io_exception.h"
+
+namespace fs {
+
+void HdfsOptions::ConfigureHost(const std::string &host) {
+    this->host_ = host;
+}
+
+void HdfsOptions::ConfigurePort(int port) {
+    this->port_ = port;
+}
+
+bool HdfsOptions::Equals(const HdfsOptions &other) const {
+    return (this->host_ == other.host_ && this->port_ == other.port_);
+}
+
+HadoopFileSystem::HadoopFileSystem(HdfsOptions &options) {
+    this->options_ = options;
+    Status st = this->Init();
+    if (!st.IsOk()) {
+        throw IOException(st.ToString());
+    }
+}
+
+HadoopFileSystem::~HadoopFileSystem() = default;
+
+hdfsFS HadoopFileSystem::getFileSystem() {
+    return this->fs_;
+}
+
+HdfsOptions HadoopFileSystem::getOptions() const {
+    return this->options_;
+}
+
+bool HadoopFileSystem::Equals(const FileSystem &other) const {
+    if (this == &other) {
+        return true;
+    }
+    if (other.type_name() != type_name()) {
+        return false;
+    }
+    // TODO: verify that reinterpret_cast is suitable for this polymorphic downcast (type_name() is checked above)
+    const auto &hdfs = reinterpret_cast<const HadoopFileSystem &>(other);
+    return getOptions().Equals(hdfs.getOptions());
+}
+
+FileInfo HadoopFileSystem::GetFileInfo(const std::string &path) {
+    hdfsFileInfo *fileInfo = hdfsGetPathInfo(fs_, path.c_str());
+    if (fileInfo == nullptr) {
+        throw IOException(Status::FSError("Fail to get file info").ToString());
+    }
+    FileInfo info;
+    if (fileInfo->mKind == kObjectKindFile) {
+        info.setType(FileType::File);
+    } else if (fileInfo->mKind == kObjectKindDirectory) {
+        info.setType(FileType::Directory);
+    } else {
+        info.setType(FileType::Unknown);
+    }
+    info.setPath(path);
+    info.setSize(fileInfo->mSize);
+    info.setMtime(std::chrono::system_clock::from_time_t(fileInfo->mLastMod));
+    hdfsFreeFileInfo(fileInfo, 1); // release the info record allocated by hdfsGetPathInfo
+    return info;
+}
+
+Status HadoopFileSystem::Close() {
+    if (hdfsDisconnect(fs_) == 0) {
+        return Status::OK();
+    }
+    return Status::FSError("Fail to close hdfs filesystem");
+}
+
+Status HadoopFileSystem::Init() {
+    struct hdfsBuilder *bld = hdfsNewBuilder();
+    if (bld == nullptr) {
+        return Status::FSError("Fail to create hdfs builder");
+    }
+    hdfsBuilderSetNameNode(bld, options_.host_.c_str());
+    hdfsBuilderSetNameNodePort(bld, options_.port_);
+    hdfsBuilderSetForceNewInstance(bld);
+    hdfsFS fileSystem = hdfsBuilderConnect(bld);
+    if (fileSystem == nullptr) {
+        return Status::FSError("Fail to connect hdfs filesystem");
+    }
+    this->fs_ = fileSystem;
+    return Status::OK();
+}
+
+// the cache of hdfs filesystem
+static std::map<std::string, std::shared_ptr<HadoopFileSystem>> fsMap_;
+static std::mutex mutex_;
+
+std::shared_ptr<HadoopFileSystem> getHdfsFileSystem(const std::string &host, const std::string &port) {
+    std::shared_ptr<HadoopFileSystem> fileSystemPtr;
+
+    mutex_.lock();
+    std::string key = host + ":" + port;
+    auto iter = fsMap_.find(key);
+    if (iter != fsMap_.end()) {
+        fileSystemPtr = fsMap_[key];
+        mutex_.unlock();
+        return fileSystemPtr;
+    }
+
+    HdfsOptions options;
+    options.ConfigureHost(host);
+    int portInt = 0;
+    if (!port.empty()) {
+        portInt = std::stoi(port);
+    }
+    if (portInt > 0) {
+        options.ConfigurePort(portInt);
+    }
+
+    std::shared_ptr<HadoopFileSystem> fs(new HadoopFileSystem(options));
+    fileSystemPtr = fs;
+    fsMap_[key] = fs;
+    mutex_.unlock();
+
+    return fileSystemPtr;
+}
+
+}
\ No newline at end of file
diff --git a/omnioperator/omniop-native-reader/cpp/src/filesystem/hdfs_filesystem.h b/omnioperator/omniop-native-reader/cpp/src/filesystem/hdfs_filesystem.h
new file mode 100644
index 0000000000000000000000000000000000000000..bd122f6f7ea5436841e00d58b568049574529b3a
--- /dev/null
+++
b/omnioperator/omniop-native-reader/cpp/src/filesystem/hdfs_filesystem.h @@ -0,0 +1,96 @@ +/** + * Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SPARK_THESTRAL_PLUGIN_HDFS_FILESYSTEM_H +#define SPARK_THESTRAL_PLUGIN_HDFS_FILESYSTEM_H + +#include "filesystem.h" +#include "hdfs.h" +#include "status.h" + +namespace fs { + +struct HdfsOptions { + HdfsOptions() = default; + + ~HdfsOptions() = default; + + std::string host_; + int port_ = 0; + + void ConfigureHost(const std::string &host); + + void ConfigurePort(int port); + + bool Equals(const HdfsOptions &other) const; +}; + +class HadoopFileSystem : public FileSystem { +private: + // Hadoop file system handle + hdfsFS fs_; + // Options for Hadoop file system + HdfsOptions options_; + +public: + // Constructor with Hadoop options + HadoopFileSystem(HdfsOptions &options); + + // Destructor + ~HadoopFileSystem(); + + // Get the type name of the file system + std::string type_name() const override { return "HdfsFileSystem"; } + + /** + * Check if this file system is equal to another file system + * @param other the other filesystem + */ + bool Equals(const FileSystem &other) const override; + + /** + * Get file info from file system + * @param path the file path + */ + FileInfo GetFileInfo(const std::string &path) override; + + // Close the file system + Status Close(); + + // Get the Hadoop file system handle + hdfsFS getFileSystem(); + + // Get the Hadoop file system options + HdfsOptions getOptions() const; + +private: + // Initialize the Hadoop file system + Status Init(); +}; + +/** +* Get a shared pointer to a Hadoop file system +* @param host the host of hdfs filesystem +* @param port the port of hdfs filesystem +*/ +std::shared_ptr getHdfsFileSystem(const std::string &host, const std::string &port); + +} + +#endif //SPARK_THESTRAL_PLUGIN_HDFS_FILESYSTEM_H diff --git a/omnioperator/omniop-native-reader/cpp/src/filesystem/io_exception.cpp b/omnioperator/omniop-native-reader/cpp/src/filesystem/io_exception.cpp new file mode 100644 index 0000000000000000000000000000000000000000..70577ba69555dfa551ac49e192cf2db31e22bccd --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/filesystem/io_exception.cpp @@ -0,0 +1,34 @@ +/** + * Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "io_exception.h" + +namespace fs { + +IOException::IOException(const std::string &arg +) : runtime_error(arg) {} + +IOException::IOException(const char *arg +) : runtime_error(arg) {} + +IOException::IOException(const IOException &error) : runtime_error(error) {} + +IOException::~IOException() noexcept {} + +} \ No newline at end of file diff --git a/omnioperator/omniop-native-reader/cpp/src/filesystem/io_exception.h b/omnioperator/omniop-native-reader/cpp/src/filesystem/io_exception.h new file mode 100644 index 0000000000000000000000000000000000000000..50ab4200c30960230f44d045d7ed8408119e8b2c --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/filesystem/io_exception.h @@ -0,0 +1,44 @@ +/** + * Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SPARK_THESTRAL_PLUGIN_IO_EXCEPTION_H +#define SPARK_THESTRAL_PLUGIN_IO_EXCEPTION_H + +#include "stdexcept" + +namespace fs { + +class IOException : public std::runtime_error { +public: + explicit IOException(const std::string &arg); + + explicit IOException(const char *arg); + + virtual ~IOException() noexcept; + + IOException(const IOException &); + +private: + IOException &operator=(const IOException &); +}; + +} + + +#endif //SPARK_THESTRAL_PLUGIN_IO_EXCEPTION_H diff --git a/omnioperator/omniop-native-reader/cpp/src/filesystem/status.cpp b/omnioperator/omniop-native-reader/cpp/src/filesystem/status.cpp new file mode 100644 index 0000000000000000000000000000000000000000..57726418589c6b039e6af5792755a4552ac6c8fe --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/filesystem/status.cpp @@ -0,0 +1,51 @@ +/** + * Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "status.h" + +namespace fs { + +std::string Status::ToString() const { + std::string result(CodeAsString(state_->code)); + result += ": "; + result += state_->msg; + return result; +} + +std::string Status::CodeAsString(StatusCode code) { + const char *type; + switch (code) { + case StatusCode::OK: + type = "OK"; + break; + case StatusCode::FSError: + type = "FileSystem error"; + break; + case StatusCode::IOError: + type = "IO error"; + break; + default: + type = "Unknown"; + break; + } + return std::string(type); +} + + +} \ No newline at end of file diff --git a/omnioperator/omniop-native-reader/cpp/src/filesystem/status.h b/omnioperator/omniop-native-reader/cpp/src/filesystem/status.h new file mode 100644 index 0000000000000000000000000000000000000000..fcae2ab4c340ee43406e85cb76fc7ff2add40f2d --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/filesystem/status.h @@ -0,0 +1,107 @@ +/** + * Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef SPARK_THESTRAL_PLUGIN_STATUS_H +#define SPARK_THESTRAL_PLUGIN_STATUS_H + +#include + +namespace fs { + +// Enum to represent different status codes +enum class StatusCode : char { + OK = 0, + FSError = 1, + IOError = 2, + UnknownError = 3 +}; + +// Struct to hold status code and message +struct State { + StatusCode code; // Status code + std::string msg; // Status message +}; + +// Class to represent status +class Status { + +public: + // Default constructor + Status() noexcept: state_(nullptr) {} + + // Constructor with status code and message + Status(StatusCode code, const std::string &msg) { + State *state = new State(); + state->code = code; + state->msg = msg; + this->state_ = state; + } + + // Destructor + ~Status() noexcept { + delete state_; + state_ = nullptr; + } + + // Create a status from status code and message + static Status FromMsg(StatusCode code, const std::string &msg) { + return Status(code, msg); + } + + // Create a file system error status with message + static Status FSError(const std::string &msg) { + return Status::FromMsg(StatusCode::FSError, msg); + } + + // Create an I/O error status with message + static Status IOError(const std::string &msg) { + return Status::FromMsg(StatusCode::IOError, msg); + } + + // Create an unknown error status with message + static Status UnknownError(const std::string &msg) { + return Status::FromMsg(StatusCode::UnknownError, msg); + } + + // Create an OK status + static Status OK() { + return Status(); + } + + // Check if the status is OK + constexpr bool IsOk() const { + if (state_ == nullptr || state_->code == StatusCode::OK) { + return true; + } + return false; + } + + // Get the status as a string + std::string ToString() const; + + // Get the status code as a string + static std::string CodeAsString(StatusCode); + +private: + // Pointer to the status state + State *state_; +}; +} + +#endif //SPARK_THESTRAL_PLUGIN_STATUS_H \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/jni/OrcColumnarBatchJniReader.cpp b/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.cpp similarity index 58% rename from omnioperator/omniop-spark-extension/cpp/src/jni/OrcColumnarBatchJniReader.cpp rename to omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.cpp index 2efdc3ea0e626be860951755c4980833e8376f51..e1300c4e062857780c1c30d56fe9d366cce6f868 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/jni/OrcColumnarBatchJniReader.cpp +++ b/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.cpp @@ -1,5 +1,5 @@ /** - * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. 
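The Status class above is the error vehicle of the fs layer: the static helpers pair a StatusCode with a message, IsOk() treats both a null state and StatusCode::OK as success, and ToString() produces the text that is ultimately surfaced through JNI exceptions. A minimal usage sketch follows; OpenSomething and Demo are hypothetical names used only to illustrate the pattern, and returning Status by value here relies on C++17 copy elision, since the class carries a raw State pointer.

    #include "status.h"
    #include "io_exception.h"

    // Hypothetical helper: reports failure through fs::Status instead of throwing.
    fs::Status OpenSomething(bool exists)
    {
        if (!exists) {
            return fs::Status::IOError("file does not exist");
        }
        return fs::Status::OK();
    }

    void Demo()
    {
        fs::Status st = OpenSomething(false);
        if (!st.IsOk()) {
            // Callers that cannot propagate a Status convert it to an exception,
            // as the HDFS input stream later in this patch does.
            throw fs::IOException(st.ToString()); // "IO error: file does not exist"
        }
    }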
See the NOTICE file * distributed with this work for additional information @@ -19,184 +19,20 @@ #include "OrcColumnarBatchJniReader.h" #include +#include #include "jni_common.h" -#include "../io/OrcObsFile.hh" using namespace omniruntime::vec; using namespace omniruntime::type; using namespace std; using namespace orc; -using namespace hdfs; static constexpr int32_t MAX_DECIMAL64_DIGITS = 18; +bool isDecimal64Transfor128 = false; -bool isLegalHex(const char c) { - if ((c >= '0') && (c <= '9')) { - return true; - } - - if ((c >= 'a') && (c <= 'f')) { - return true; - } - - if ((c >= 'A') && (c <= 'F')) { - return true; - } - - return false; -} - -uint8_t hexStrToValue(const char c) { - if ((c >= '0') && (c <= '9')) { - return c - '0'; - } - - if ((c >= 'A') && (c <= 'F')) { - return c - 'A' + 10; - } - - return c - 'a' + 10; -} - -void transHexToByte(const std::string &origin, std::string &result) { - const uint32_t strLenPerByte = 2; - const char* srcStr = origin.c_str(); - char first; - char second; - - if (origin.size() % strLenPerByte) { - LogsError("Input string(%s) length(%u) must be multiple of 2.", srcStr, origin.size()); - return; - } - - result.resize(origin.size() / strLenPerByte); - for (uint32_t i = 0; i < origin.size(); i += strLenPerByte) { - first = srcStr[i]; - second = srcStr[i + 1]; - if (!isLegalHex(first) || !isLegalHex(second)) { - LogsError("Input string(%s) is not legal at about index=%d.", srcStr, i); - result.resize(0); - return; - } - - result[i / strLenPerByte] = ((hexStrToValue(first) & 0x0F) << 4) + (hexStrToValue(second) & 0x0F); - } - - return; -} - -void parseTokens(JNIEnv* env, jobject jsonObj, std::vector& tokenVector) { - const char* strTokens = "tokens"; - const char* strToken = "token"; - const char* strIdentifier = "identifier"; - const char* strPassword = "password"; - const char* strService = "service"; - const char* strTokenKind = "kind"; - - jboolean hasTokens = env->CallBooleanMethod(jsonObj, jsonMethodHas, env->NewStringUTF(strTokens)); - if (!hasTokens) { - return; - } - - jobject tokensObj = env->CallObjectMethod(jsonObj, jsonMethodObj, env->NewStringUTF(strTokens)); - if (tokensObj == NULL) { - return; - } - - jobjectArray tokenJsonArray = (jobjectArray)env->CallObjectMethod(tokensObj, jsonMethodObj, env->NewStringUTF(strToken)); - if (tokenJsonArray == NULL) { - return; - } - - uint32_t count = env->GetArrayLength(tokenJsonArray); - for (uint32_t i = 0; i < count; i++) { - jobject child = env->GetObjectArrayElement(tokenJsonArray, i); - - jstring jIdentifier = (jstring)env->CallObjectMethod(child, jsonMethodString, env->NewStringUTF(strIdentifier)); - jstring jPassword = (jstring)env->CallObjectMethod(child, jsonMethodString, env->NewStringUTF(strPassword)); - jstring jService = (jstring)env->CallObjectMethod(child, jsonMethodString, env->NewStringUTF(strService)); - jstring jKind = (jstring)env->CallObjectMethod(child, jsonMethodString, env->NewStringUTF(strTokenKind)); - - auto identifierStr = env->GetStringUTFChars(jIdentifier, nullptr); - std::string inIdentifier(identifierStr); - env->ReleaseStringUTFChars(jIdentifier, identifierStr); - transform(inIdentifier.begin(), inIdentifier.end(), inIdentifier.begin(), ::tolower); - std::string identifier; - transHexToByte(inIdentifier, identifier); - - auto passwordStr = env->GetStringUTFChars(jPassword, nullptr); - std::string inPassword(passwordStr); - env->ReleaseStringUTFChars(jPassword, passwordStr); - transform(inPassword.begin(), inPassword.end(), inPassword.begin(), ::tolower); - 
std::string password; - transHexToByte(inPassword, password); - - auto kindStr = env->GetStringUTFChars(jKind, nullptr); - std::string kind(kindStr); - env->ReleaseStringUTFChars(jKind, kindStr); - - auto serviceStr = env->GetStringUTFChars(jService, nullptr); - std::string service(serviceStr); - env->ReleaseStringUTFChars(jService, serviceStr); - - transform(kind.begin(), kind.end(), kind.begin(), ::tolower); - if (kind != "hdfs_delegation_token") { - continue; // only hdfs delegation token is useful for liborc - } - - Token* token = new Token(); - token->setIdentifier(identifier); - token->setPassword(password); - token->setService(service); - token->setKind(kind); - - tokenVector.push_back(token); - } -} - -void deleteTokens(std::vector& tokenVector) { - for (auto token : tokenVector) { - delete token; - } - - tokenVector.clear(); -} - -void parseObs(JNIEnv* env, jobject jsonObj, ObsConfig &obsInfo) { - jobject obsObject = env->CallObjectMethod(jsonObj, jsonMethodObj, env->NewStringUTF("obsInfo")); - if (obsObject == NULL) { - LogsWarn("get obs info failed, obs info is null."); - return; - } - - jstring jEndpoint = (jstring)env->CallObjectMethod(obsObject, jsonMethodString, env->NewStringUTF("endpoint")); - auto endpointCharPtr = env->GetStringUTFChars(jEndpoint, JNI_FALSE); - std::string endpoint = endpointCharPtr; - obsInfo.hostLen = endpoint.length() + 1; - strcpy_s(obsInfo.hostName, obsInfo.hostLen, endpoint.c_str()); - env->ReleaseStringUTFChars(jEndpoint, endpointCharPtr); - - jstring jAk = (jstring)env->CallObjectMethod(obsObject, jsonMethodString, env->NewStringUTF("ak")); - auto akCharPtr = env->GetStringUTFChars(jAk, JNI_FALSE); - std::string ak = akCharPtr; - strcpy_s(obsInfo.accessKey, ak.length() + 1, ak.c_str()); - env->ReleaseStringUTFChars(jAk, akCharPtr); - - jstring jSk = (jstring)env->CallObjectMethod(obsObject, jsonMethodString, env->NewStringUTF("sk")); - auto skCharPtr = env->GetStringUTFChars(jSk, JNI_FALSE); - std::string sk = skCharPtr; - strcpy_s(obsInfo.secretKey, sk.length() + 1, sk.c_str()); - env->ReleaseStringUTFChars(jSk, skCharPtr); - - jstring jToken = (jstring)env->CallObjectMethod(obsObject, jsonMethodString, env->NewStringUTF("token")); - auto tokenCharPtr = env->GetStringUTFChars(jToken, JNI_FALSE); - std::string token = tokenCharPtr; - strcpy_s(obsInfo.token, token.length() + 1, token.c_str()); - env->ReleaseStringUTFChars(jToken, tokenCharPtr); -} - -JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_initializeReader(JNIEnv *env, - jobject jObj, jstring path, jobject jsonObj) +// vecFildsNames stores the name of every column in the file; the names are obtained from the ORC reader on the C++ side and passed back for use on the Java side +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader_initializeReader(JNIEnv *env, + jobject jObj, jobject jsonObj, jobject vecFildsNames) { JNI_FUNC_START @@ -206,8 +42,6 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniRe jlong tailLocation = env->CallLongMethod(jsonObj, jsonMethodLong, env->NewStringUTF("tailLocation")); jstring serTailJstr = (jstring)env->CallObjectMethod(jsonObj, jsonMethodString, env->NewStringUTF("serializedTail")); - const char *pathPtr = env->GetStringUTFChars(path, nullptr); - std::string filePath(pathPtr); orc::MemoryPool *pool = orc::getDefaultPool(); orc::ReaderOptions readerOptions; readerOptions.setMemoryPool(*pool); @@ -219,21 +53,35 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniRe env->ReleaseStringUTFChars(serTailJstr, ptr); } - std::vector tokens; - 
parseTokens(env, jsonObj, tokens); + jstring schemaJstr = (jstring)env->CallObjectMethod(jsonObj, jsonMethodString, env->NewStringUTF("scheme")); + const char *schemaPtr = env->GetStringUTFChars(schemaJstr, nullptr); + std::string schemaStr(schemaPtr); + env->ReleaseStringUTFChars(schemaJstr, schemaPtr); + + jstring fileJstr = (jstring)env->CallObjectMethod(jsonObj, jsonMethodString, env->NewStringUTF("path")); + const char *filePtr = env->GetStringUTFChars(fileJstr, nullptr); + std::string fileStr(filePtr); + env->ReleaseStringUTFChars(fileJstr, filePtr); + + jstring hostJstr = (jstring)env->CallObjectMethod(jsonObj, jsonMethodString, env->NewStringUTF("host")); + const char *hostPtr = env->GetStringUTFChars(hostJstr, nullptr); + std::string hostStr(hostPtr); + env->ReleaseStringUTFChars(hostJstr, hostPtr); + + jint port = (jint)env->CallIntMethod(jsonObj, jsonMethodInt, env->NewStringUTF("port")); std::unique_ptr reader; - if (0 == strncmp(filePath.c_str(), "obs://", OBS_PROTOCOL_SIZE)) { - ObsConfig obsInfo; - parseObs(env, jsonObj, obsInfo); - reader = createReader(orc::readObsFile(filePath, &obsInfo), readerOptions); - } else { - reader = createReader(orc::readFileRewrite(filePath, tokens), readerOptions); + UriInfo uri{schemaStr, fileStr, hostStr, std::to_string(port)}; + reader = createReader(orc::readFileOverride(uri), readerOptions); + std::vector orcColumnNames = reader->getAllFiedsName(); + for (int i = 0; i < orcColumnNames.size(); i++) { + jstring fildname = env->NewStringUTF(orcColumnNames[i].c_str()); + // use ArrayList and function + env->CallBooleanMethod(vecFildsNames, arrayListAdd, fildname); + env->DeleteLocalRef(fildname); } - env->ReleaseStringUTFChars(path, pathPtr); orc::Reader *readerNew = reader.release(); - deleteTokens(tokens); return (jlong)(readerNew); JNI_FUNC_END(runtimeExceptionClass) } @@ -317,7 +165,14 @@ int BuildLeaves(PredicateOperatorType leafOp, vector &litList, Literal break; } case PredicateOperatorType::IN: { - builder.in(leafNameString, leafType, litList); + if (litList.empty()) { + // build.in方法第一个参数给定空值,即会认为该predictLeaf的TruthValue为YES_NO_NULL(不过滤数据) + // 即与java orc in中存在null的行为保持一致 + std::string emptyString; + builder.in(emptyString, leafType, litList); + } else { + builder.in(leafNameString, leafType, litList); + } break; } case PredicateOperatorType::BETWEEN: { @@ -341,8 +196,10 @@ int initLeaves(JNIEnv *env, SearchArgumentBuilder &builder, jobject &jsonExp, jo Literal lit(0L); jstring leafValue = (jstring)env->CallObjectMethod(leafJsonObj, jsonMethodString, env->NewStringUTF("literal")); if (leafValue != nullptr) { - std::string leafValueString(env->GetStringUTFChars(leafValue, nullptr)); - if (leafValueString.size() != 0) { + const char *leafChars = env->GetStringUTFChars(leafValue, nullptr); + std::string leafValueString(leafChars); + env->ReleaseStringUTFChars(leafValue, leafChars); + if (leafValueString.size() != 0 || (leafValueString.size() == 0 && (orc::PredicateDataType)leafType == orc::PredicateDataType::STRING)) { GetLiteral(lit, leafType, leafValueString); } } @@ -351,10 +208,20 @@ int initLeaves(JNIEnv *env, SearchArgumentBuilder &builder, jobject &jsonExp, jo if (litListValue != nullptr) { int childs = (int)env->CallIntMethod(litListValue, arrayListSize); for (int i = 0; i < childs; i++) { - jstring child = (jstring)env->CallObjectMethod(litListValue, arrayListGet, i); - std::string childString(env->GetStringUTFChars(child, nullptr)); - GetLiteral(lit, leafType, childString); - litList.push_back(lit); + jstring child = (jstring) 
env->CallObjectMethod(litListValue, arrayListGet, i); + if (child == nullptr) { + // 原生spark-sql PredicateLiteralList如果含有null元素,会捕获NPE,然后产生TruthValue.YES_NO或者TruthValue.YES_NO_NULL + // 这两者TruthValue在谓词下推都不会过滤该行组的数据 + // 此处将litList清空,作为BuildLeaves的标志,Build时传入相应参数产生上述TruthValue,使表现出的特性与原生保持一致 + litList.clear(); + break; + } else { + auto chars = env->GetStringUTFChars(child, nullptr); + std::string childString(chars); + env->ReleaseStringUTFChars(child, chars); + GetLiteral(lit, leafType, childString); + litList.push_back(lit); + } } } BuildLeaves((PredicateOperatorType)leafOp, litList, lit, leafNameString, (PredicateDataType)leafType, builder); @@ -396,11 +263,17 @@ int initExpressionTree(JNIEnv *env, SearchArgumentBuilder &builder, jobject &jso } -JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_initializeRecordReader(JNIEnv *env, +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader_initializeRecordReader(JNIEnv *env, jobject jObj, jlong reader, jobject jsonObj) { JNI_FUNC_START orc::Reader *readerPtr = (orc::Reader *)reader; + // Get if the decimal for spark or hive + jboolean jni_isDecimal64Transfor128 = env->CallBooleanMethod(jsonObj, jsonMethodHas, + env->NewStringUTF("isDecimal64Transfor128")); + if (jni_isDecimal64Transfor128) { + isDecimal64Transfor128 = true; + } // get offset from json obj jlong offset = env->CallLongMethod(jsonObj, jsonMethodLong, env->NewStringUTF("offset")); jlong length = env->CallLongMethod(jsonObj, jsonMethodLong, env->NewStringUTF("length")); @@ -443,7 +316,7 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniRe } -JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_initializeBatch(JNIEnv *env, +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader_initializeBatch(JNIEnv *env, jobject jObj, jlong rowReader, jlong batchSize) { JNI_FUNC_START @@ -455,75 +328,80 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniRe JNI_FUNC_END(runtimeExceptionClass) } -template uint64_t CopyFixedWidth(orc::ColumnVectorBatch *field) +template +std::unique_ptr CopyFixedWidth(orc::ColumnVectorBatch *field) { using T = typename NativeType::type; ORC_TYPE *lvb = dynamic_cast(field); auto numElements = lvb->numElements; auto values = lvb->data.data(); auto notNulls = lvb->notNull.data(); - auto originalVector = new Vector(numElements); + auto newVector = std::make_unique>(numElements); + auto newVectorPtr = newVector.get(); // Check ColumnVectorBatch has null or not firstly if (lvb->hasNulls) { for (uint i = 0; i < numElements; i++) { if (notNulls[i]) { - originalVector->SetValue(i, (T)(values[i])); + newVectorPtr->SetValue(i, (T)(values[i])); } else { - originalVector->SetNull(i); + newVectorPtr->SetNull(i); } } } else { for (uint i = 0; i < numElements; i++) { - originalVector->SetValue(i, (T)(values[i])); + newVectorPtr->SetValue(i, (T)(values[i])); } } - return (uint64_t)originalVector; + return newVector; } -template uint64_t CopyOptimizedForInt64(orc::ColumnVectorBatch *field) +template +std::unique_ptr CopyOptimizedForInt64(orc::ColumnVectorBatch *field) { using T = typename NativeType::type; ORC_TYPE *lvb = dynamic_cast(field); auto numElements = lvb->numElements; auto values = lvb->data.data(); auto notNulls = lvb->notNull.data(); - auto originalVector = new Vector(numElements); + auto newVector = std::make_unique>(numElements); + auto newVectorPtr = newVector.get(); // Check 
ColumnVectorBatch has null or not firstly if (lvb->hasNulls) { for (uint i = 0; i < numElements; i++) { if (!notNulls[i]) { - originalVector->SetNull(i); + newVectorPtr->SetNull(i); } } } - originalVector->SetValues(0, values, numElements); - return (uint64_t)originalVector; + newVectorPtr->SetValues(0, values, numElements); + return newVector; } -uint64_t CopyVarWidth(orc::ColumnVectorBatch *field) +std::unique_ptr CopyVarWidth(orc::ColumnVectorBatch *field) { orc::StringVectorBatch *lvb = dynamic_cast(field); auto numElements = lvb->numElements; auto values = lvb->data.data(); auto notNulls = lvb->notNull.data(); auto lens = lvb->length.data(); - auto originalVector = new Vector>(numElements); + auto newVector = std::make_unique>>(numElements); + auto newVectorPtr = newVector.get(); if (lvb->hasNulls) { for (uint i = 0; i < numElements; i++) { if (notNulls[i]) { auto data = std::string_view(reinterpret_cast(values[i]), lens[i]); - originalVector->SetValue(i, data); + newVectorPtr->SetValue(i, data); } else { - originalVector->SetNull(i); + newVectorPtr->SetNull(i); } } } else { for (uint i = 0; i < numElements; i++) { auto data = std::string_view(reinterpret_cast(values[i]), lens[i]); - originalVector->SetValue(i, data); + newVectorPtr->SetValue(i, data); } } - return (uint64_t)originalVector; + return newVector; } inline void FindLastNotEmpty(const char *chars, long &len) @@ -533,14 +411,15 @@ inline void FindLastNotEmpty(const char *chars, long &len) } } -uint64_t CopyCharType(orc::ColumnVectorBatch *field) +std::unique_ptr CopyCharType(orc::ColumnVectorBatch *field) { orc::StringVectorBatch *lvb = dynamic_cast(field); auto numElements = lvb->numElements; auto values = lvb->data.data(); auto notNulls = lvb->notNull.data(); auto lens = lvb->length.data(); - auto originalVector = new Vector>(numElements); + auto newVector = std::make_unique>>(numElements); + auto newVectorPtr = newVector.get(); if (lvb->hasNulls) { for (uint i = 0; i < numElements; i++) { if (notNulls[i]) { @@ -548,9 +427,9 @@ uint64_t CopyCharType(orc::ColumnVectorBatch *field) auto len = lens[i]; FindLastNotEmpty(chars, len); auto data = std::string_view(chars, len); - originalVector->SetValue(i, data); + newVectorPtr->SetValue(i, data); } else { - originalVector->SetNull(i); + newVectorPtr->SetNull(i); } } } else { @@ -559,158 +438,167 @@ uint64_t CopyCharType(orc::ColumnVectorBatch *field) auto len = lens[i]; FindLastNotEmpty(chars, len); auto data = std::string_view(chars, len); - originalVector->SetValue(i, data); - } - } - return (uint64_t)originalVector; -} - -inline void TransferDecimal128(int64_t &highbits, uint64_t &lowbits) -{ - if (highbits < 0) { // int128's 2s' complement code - lowbits = ~lowbits + 1; // 2s' complement code - highbits = ~highbits; //1s' complement code - if (lowbits == 0) { - highbits += 1; // carry a number as in adding + newVectorPtr->SetValue(i, data); } - highbits ^= ((uint64_t)1 << 63); } + return newVector; } -uint64_t CopyToOmniDecimal128Vec(orc::ColumnVectorBatch *field) +std::unique_ptr CopyToOmniDecimal128Vec(orc::ColumnVectorBatch *field) { orc::Decimal128VectorBatch *lvb = dynamic_cast(field); auto numElements = lvb->numElements; auto values = lvb->values.data(); auto notNulls = lvb->notNull.data(); - auto originalVector = new Vector(numElements); + auto newVector = std::make_unique>(numElements); + auto newVectorPtr = newVector.get(); if (lvb->hasNulls) { for (uint i = 0; i < numElements; i++) { if (notNulls[i]) { - auto highbits = values[i].getHighBits(); - auto lowbits = 
values[i].getLowBits(); - TransferDecimal128(highbits, lowbits); - Decimal128 d128(highbits, lowbits); - originalVector->SetValue(i, d128); + __int128_t dst = values[i].getHighBits(); + dst <<= 64; + dst |= values[i].getLowBits(); + newVectorPtr->SetValue(i, Decimal128(dst)); } else { - originalVector->SetNull(i); + newVectorPtr->SetNull(i); } } } else { for (uint i = 0; i < numElements; i++) { - auto highbits = values[i].getHighBits(); - auto lowbits = values[i].getLowBits(); - TransferDecimal128(highbits, lowbits); - Decimal128 d128(highbits, lowbits); - originalVector->SetValue(i, d128); + __int128_t dst = values[i].getHighBits(); + dst <<= 64; + dst |= values[i].getLowBits(); + newVectorPtr->SetValue(i, Decimal128(dst)); } } - return (uint64_t)originalVector; + return newVector; } -uint64_t CopyToOmniDecimal64Vec(orc::ColumnVectorBatch *field) +std::unique_ptr CopyToOmniDecimal64Vec(orc::ColumnVectorBatch *field) { orc::Decimal64VectorBatch *lvb = dynamic_cast(field); auto numElements = lvb->numElements; auto values = lvb->values.data(); auto notNulls = lvb->notNull.data(); - auto originalVector = new Vector(numElements); + auto newVector = std::make_unique>(numElements); + auto newVectorPtr = newVector.get(); if (lvb->hasNulls) { for (uint i = 0; i < numElements; i++) { if (!notNulls[i]) { - originalVector->SetNull(i); + newVectorPtr->SetNull(i); } } } - originalVector->SetValues(0, values, numElements); - return (uint64_t)originalVector; + newVectorPtr->SetValues(0, values, numElements); + return newVector; } -int CopyToOmniVec(const orc::Type *type, int &omniTypeId, uint64_t &omniVecId, orc::ColumnVectorBatch *field) +std::unique_ptr CopyToOmniDecimal128VecFrom64(orc::ColumnVectorBatch *field) +{ + orc::Decimal64VectorBatch *lvb = dynamic_cast(field); + auto numElements = lvb->numElements; + auto values = lvb->values.data(); + auto notNulls = lvb->notNull.data(); + auto newVector = std::make_unique>(numElements); + auto newVectorPtr = newVector.get(); + if (lvb->hasNulls) { + for (uint i = 0; i < numElements; i++) { + if (!notNulls[i]) { + newVectorPtr->SetNull(i); + } else { + Decimal128 d128(values[i]); + newVectorPtr->SetValue(i, d128); + } + } + } else { + for (uint i = 0; i < numElements; i++) { + Decimal128 d128(values[i]); + newVectorPtr->SetValue(i, d128); + } + } + + return newVector; +} + +std::unique_ptr CopyToOmniVec(const orc::Type *type, int &omniTypeId, orc::ColumnVectorBatch *field, + bool isDecimal64Transfor128) { switch (type->getKind()) { case orc::TypeKind::BOOLEAN: omniTypeId = static_cast(OMNI_BOOLEAN); - omniVecId = CopyFixedWidth(field); - break; + return CopyFixedWidth(field); case orc::TypeKind::SHORT: omniTypeId = static_cast(OMNI_SHORT); - omniVecId = CopyFixedWidth(field); - break; + return CopyFixedWidth(field); case orc::TypeKind::DATE: omniTypeId = static_cast(OMNI_DATE32); - omniVecId = CopyFixedWidth(field); - break; + return CopyFixedWidth(field); case orc::TypeKind::INT: omniTypeId = static_cast(OMNI_INT); - omniVecId = CopyFixedWidth(field); - break; + return CopyFixedWidth(field); case orc::TypeKind::LONG: omniTypeId = static_cast(OMNI_LONG); - omniVecId = CopyOptimizedForInt64(field); - break; + return CopyOptimizedForInt64(field); case orc::TypeKind::DOUBLE: omniTypeId = static_cast(OMNI_DOUBLE); - omniVecId = CopyOptimizedForInt64(field); - break; + return CopyOptimizedForInt64(field); case orc::TypeKind::CHAR: omniTypeId = static_cast(OMNI_VARCHAR); - omniVecId = CopyCharType(field); - break; + return CopyCharType(field); case 
orc::TypeKind::STRING: case orc::TypeKind::VARCHAR: omniTypeId = static_cast(OMNI_VARCHAR); - omniVecId = CopyVarWidth(field); - break; + return CopyVarWidth(field); case orc::TypeKind::DECIMAL: if (type->getPrecision() > MAX_DECIMAL64_DIGITS) { omniTypeId = static_cast(OMNI_DECIMAL128); - omniVecId = CopyToOmniDecimal128Vec(field); + return CopyToOmniDecimal128Vec(field); + } else if (isDecimal64Transfor128) { + omniTypeId = static_cast(OMNI_DECIMAL128); + return CopyToOmniDecimal128VecFrom64(field); } else { omniTypeId = static_cast(OMNI_DECIMAL64); - omniVecId = CopyToOmniDecimal64Vec(field); + return CopyToOmniDecimal64Vec(field); } - break; default: { throw std::runtime_error("Native ColumnarFileScan Not support For This Type: " + type->getKind()); } } - return 1; } -JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_recordReaderNext(JNIEnv *env, +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader_recordReaderNext(JNIEnv *env, jobject jObj, jlong rowReader, jlong batch, jintArray typeId, jlongArray vecNativeId) { - JNI_FUNC_START orc::RowReader *rowReaderPtr = (orc::RowReader *)rowReader; orc::ColumnVectorBatch *columnVectorBatch = (orc::ColumnVectorBatch *)batch; + std::vector> omniVecs; + const orc::Type &baseTp = rowReaderPtr->getSelectedType(); - int vecCnt = 0; - long batchRowSize = 0; + uint64_t batchRowSize = 0; if (rowReaderPtr->next(*columnVectorBatch)) { orc::StructVectorBatch *root = dynamic_cast(columnVectorBatch); - vecCnt = root->fields.size(); batchRowSize = root->fields[0]->numElements; - for (int id = 0; id < vecCnt; id++) { + int32_t vecCnt = root->fields.size(); + std::vector omniTypeIds(vecCnt, 0); + for (int32_t id = 0; id < vecCnt; id++) { auto type = baseTp.getSubtype(id); - int omniTypeId = 0; - uint64_t omniVecId = 0; - CopyToOmniVec(type, omniTypeId, omniVecId, root->fields[id]); - env->SetIntArrayRegion(typeId, id, 1, &omniTypeId); - jlong omniVec = static_cast(omniVecId); + omniVecs.emplace_back(CopyToOmniVec(type, omniTypeIds[id], root->fields[id], isDecimal64Transfor128)); + } + for (int32_t id = 0; id < vecCnt; id++) { + env->SetIntArrayRegion(typeId, id, 1, omniTypeIds.data() + id); + jlong omniVec = reinterpret_cast(omniVecs[id].release()); env->SetLongArrayRegion(vecNativeId, id, 1, &omniVec); } } - return (jlong)batchRowSize; - JNI_FUNC_END(runtimeExceptionClass) + return (jlong) batchRowSize; } /* - * Class: com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader + * Class: com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader * Method: recordReaderGetRowNumber * Signature: (J)J */ -JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_recordReaderGetRowNumber( +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader_recordReaderGetRowNumber( JNIEnv *env, jobject jObj, jlong rowReader) { JNI_FUNC_START @@ -721,11 +609,11 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniRe } /* - * Class: com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader + * Class: com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader * Method: recordReaderGetProgress * Signature: (J)F */ -JNIEXPORT jfloat JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_recordReaderGetProgress( +JNIEXPORT jfloat JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader_recordReaderGetProgress( JNIEnv *env, jobject jObj, jlong rowReader) { JNI_FUNC_START @@ -736,11 +624,11 @@ JNIEXPORT jfloat JNICALL 
Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniR } /* - * Class: com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader + * Class: com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader * Method: recordReaderClose * Signature: (J)F */ -JNIEXPORT void JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_recordReaderClose(JNIEnv *env, +JNIEXPORT void JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader_recordReaderClose(JNIEnv *env, jobject jObj, jlong rowReader, jlong reader, jlong batchReader) { JNI_FUNC_START @@ -763,11 +651,11 @@ JNIEXPORT void JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniRea } /* - * Class: com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader + * Class: com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader * Method: recordReaderSeekToRow * Signature: (JJ)F */ -JNIEXPORT void JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_recordReaderSeekToRow(JNIEnv *env, +JNIEXPORT void JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader_recordReaderSeekToRow(JNIEnv *env, jobject jObj, jlong rowReader, jlong rowNumber) { JNI_FUNC_START @@ -778,7 +666,7 @@ JNIEXPORT void JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniRea JNIEXPORT jobjectArray JNICALL -Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_getAllColumnNames(JNIEnv *env, jobject jObj, jlong reader) +Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader_getAllColumnNames(JNIEnv *env, jobject jObj, jlong reader) { JNI_FUNC_START orc::Reader *readerPtr = (orc::Reader *)reader; @@ -792,7 +680,7 @@ Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_getAllColumnNames(J JNI_FUNC_END(runtimeExceptionClass) } -JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_getNumberOfRows(JNIEnv *env, +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader_getNumberOfRows(JNIEnv *env, jobject jObj, jlong rowReader, jlong batch) { JNI_FUNC_START diff --git a/omnioperator/omniop-spark-extension/cpp/src/jni/OrcColumnarBatchJniReader.h b/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.h similarity index 58% rename from omnioperator/omniop-spark-extension/cpp/src/jni/OrcColumnarBatchJniReader.h rename to omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.h index 860effb7adfe1be969a37aa54590366011b12d4e..829f5c0744d3d563601ec6506ebfc82b5a020e93 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/jni/OrcColumnarBatchJniReader.h +++ b/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.h @@ -1,5 +1,5 @@ /** - * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information @@ -17,10 +17,10 @@ * limitations under the License. 
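In the rewritten recordReaderNext above, each column is first materialized into a std::unique_ptr held by the omniVecs vector, and only after every column has been copied is ownership released into the jlong handles written to vecNativeId. The sketch below isolates that handle pattern with a stand-in type; FakeVector and the surrounding main are illustrative only and not part of the patch.

    #include <cstdint>
    #include <memory>
    #include <vector>

    // Stand-in for an omniruntime vector; illustration only.
    struct FakeVector { int id; };

    int main()
    {
        std::vector<std::unique_ptr<FakeVector>> cols;
        cols.push_back(std::make_unique<FakeVector>(FakeVector{0}));
        cols.push_back(std::make_unique<FakeVector>(FakeVector{1}));

        // While the columns are being filled, unique_ptr owns them, so a failure
        // on column k automatically frees columns 0..k-1.
        std::vector<int64_t> handles;
        for (auto &col : cols) {
            // Mirrors reinterpret_cast<jlong>(omniVecs[id].release()): ownership
            // leaves C++ only once the column is complete.
            handles.push_back(reinterpret_cast<int64_t>(col.release()));
        }

        // Whoever receives the handles is now responsible for deleting them.
        for (int64_t handle : handles) {
            delete reinterpret_cast<FakeVector *>(handle);
        }
        return 0;
    }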
*/ -/* Header for class THESTRAL_PLUGIN_ORCCOLUMNARBATCHJNIREADER_H */ +/* Header for class OMNI_RUNTIME_ORCCOLUMNARBATCHJNIREADER_H */ -#ifndef THESTRAL_PLUGIN_ORCCOLUMNARBATCHJNIREADER_H -#define THESTRAL_PLUGIN_ORCCOLUMNARBATCHJNIREADER_H +#ifndef OMNI_RUNTIME_ORCCOLUMNARBATCHJNIREADER_H +#define OMNI_RUNTIME_ORCCOLUMNARBATCHJNIREADER_H #include #include @@ -36,8 +36,7 @@ #include #include #include -#include "io/orcfile/OrcFileRewrite.hh" -#include "hdfspp/options.h" +#include "orcfile/OrcFileOverride.hh" #include #include #include @@ -65,74 +64,74 @@ enum class PredicateOperatorType { }; /* - * Class: com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader + * Class: com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader * Method: initializeReader * Signature: (Ljava/lang/String;Lorg/json/simple/JSONObject;)J */ -JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_initializeReader - (JNIEnv* env, jobject jObj, jstring path, jobject job); +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader_initializeReader + (JNIEnv* env, jobject jObj, jobject job, jobject vecFildsNames); /* - * Class: com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader + * Class: com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader * Method: initializeRecordReader * Signature: (JLorg/json/simple/JSONObject;)J */ -JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_initializeRecordReader +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader_initializeRecordReader (JNIEnv* env, jobject jObj, jlong reader, jobject job); /* - * Class: com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader - * Method: initializeRecordReader + * Class: com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader + * Method: initializeBatch * Signature: (JLorg/json/simple/JSONObject;)J */ -JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_initializeBatch +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader_initializeBatch (JNIEnv* env, jobject jObj, jlong rowReader, jlong batchSize); /* - * Class: com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader + * Class: com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader * Method: recordReaderNext * Signature: (J[I[J)J */ -JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_recordReaderNext +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader_recordReaderNext (JNIEnv *, jobject, jlong, jlong, jintArray, jlongArray); /* - * Class: com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader + * Class: com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader * Method: recordReaderGetRowNumber * Signature: (J)J */ -JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_recordReaderGetRowNumber +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader_recordReaderGetRowNumber (JNIEnv *, jobject, jlong); /* - * Class: com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader + * Class: com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader * Method: recordReaderGetProgress * Signature: (J)F */ -JNIEXPORT jfloat JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_recordReaderGetProgress +JNIEXPORT jfloat JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader_recordReaderGetProgress (JNIEnv *, jobject, jlong); /* - * Class: 
com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader + * Class: com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader * Method: recordReaderClose - * Signature: (J)F + * Signature: (JJJ)F */ -JNIEXPORT void JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_recordReaderClose +JNIEXPORT void JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader_recordReaderClose (JNIEnv *, jobject, jlong, jlong, jlong); /* - * Class: com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader + * Class: com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader * Method: recordReaderSeekToRow * Signature: (JJ)F */ -JNIEXPORT void JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_recordReaderSeekToRow +JNIEXPORT void JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader_recordReaderSeekToRow (JNIEnv *, jobject, jlong, jlong); -JNIEXPORT jobjectArray JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_getAllColumnNames +JNIEXPORT jobjectArray JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader_getAllColumnNames (JNIEnv *, jobject, jlong); -JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_OrcColumnarBatchJniReader_getNumberOfRows(JNIEnv *env, +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader_getNumberOfRows(JNIEnv *env, jobject jObj, jlong rowReader, jlong batch); int GetLiteral(orc::Literal &lit, int leafType, const std::string &value); @@ -142,7 +141,8 @@ int BuildLeaves(PredicateOperatorType leafOp, std::vector &litList bool StringToBool(const std::string &boolStr); -int CopyToOmniVec(const orc::Type *type, int &omniTypeId, uint64_t &omniVecId, orc::ColumnVectorBatch *field); +int CopyToOmniVec(const orc::Type *type, int &omniTypeId, uint64_t &omniVecId, orc::ColumnVectorBatch *field, + bool isDecimal64Transfor128); #ifdef __cplusplus } diff --git a/omnioperator/omniop-spark-extension/cpp/src/jni/ParquetColumnarBatchJniReader.cpp b/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniReader.cpp similarity index 44% rename from omnioperator/omniop-spark-extension/cpp/src/jni/ParquetColumnarBatchJniReader.cpp rename to omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniReader.cpp index e24bff186a5ac7ad36076fd0b4346c79bbc18eb5..991699a7be573db1f191fe7b20de0a45baf4964d 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/jni/ParquetColumnarBatchJniReader.cpp +++ b/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniReader.cpp @@ -1,5 +1,5 @@ /** - * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. 
See the NOTICE file * distributed with this work for additional information @@ -19,14 +19,10 @@ #include "ParquetColumnarBatchJniReader.h" #include "jni_common.h" -#include "tablescan/ParquetReader.h" +#include "parquet/ParquetReader.h" +#include "common/UriInfo.h" -using namespace omniruntime::vec; -using namespace omniruntime::type; -using namespace std; -using namespace arrow; -using namespace parquet::arrow; -using namespace spark::reader; +using namespace omniruntime::reader; std::vector GetIndices(JNIEnv *env, jobject jsonObj, const char* name) { @@ -41,54 +37,40 @@ std::vector GetIndices(JNIEnv *env, jobject jsonObj, const char* name) return indices; } -void parseObs(JNIEnv* env, jobject jsonObj, ObsConfig &obsInfo) { - jobject obsObject = env->CallObjectMethod(jsonObj, jsonMethodObj, env->NewStringUTF("obsInfo")); - if (obsObject == NULL) { - LogsWarn("get obs info failed, obs info is null."); - return; - } - - jstring jEndpoint = (jstring)env->CallObjectMethod(obsObject, jsonMethodString, env->NewStringUTF("endpoint")); - auto endpointCharPtr = env->GetStringUTFChars(jEndpoint, JNI_FALSE); - std::string endpoint = endpointCharPtr; - obsInfo.hostLen = endpoint.length() + 1; - strcpy_s(obsInfo.hostName, obsInfo.hostLen, endpoint.c_str()); - env->ReleaseStringUTFChars(jEndpoint, endpointCharPtr); - - jstring jAk = (jstring)env->CallObjectMethod(obsObject, jsonMethodString, env->NewStringUTF("ak")); - auto akCharPtr = env->GetStringUTFChars(jAk, JNI_FALSE); - std::string ak = akCharPtr; - strcpy_s(obsInfo.accessKey, ak.length() + 1, ak.c_str()); - env->ReleaseStringUTFChars(jAk, akCharPtr); - - jstring jSk = (jstring)env->CallObjectMethod(obsObject, jsonMethodString, env->NewStringUTF("sk")); - auto skCharPtr = env->GetStringUTFChars(jSk, JNI_FALSE); - std::string sk = skCharPtr; - strcpy_s(obsInfo.secretKey, sk.length() + 1, sk.c_str()); - env->ReleaseStringUTFChars(jSk, skCharPtr); - - jstring jToken = (jstring)env->CallObjectMethod(obsObject, jsonMethodString, env->NewStringUTF("token")); - auto tokenCharPtr = env->GetStringUTFChars(jToken, JNI_FALSE); - std::string token = tokenCharPtr; - strcpy_s(obsInfo.token, token.length() + 1, token.c_str()); - env->ReleaseStringUTFChars(jToken, tokenCharPtr); -} - -JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_ParquetColumnarBatchJniReader_initializeReader(JNIEnv *env, +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_ParquetColumnarBatchJniReader_initializeReader(JNIEnv *env, jobject jObj, jobject jsonObj) { JNI_FUNC_START - // Get filePath - jstring path = (jstring)env->CallObjectMethod(jsonObj, jsonMethodString, env->NewStringUTF("filePath")); - const char *filePath = env->GetStringUTFChars(path, JNI_FALSE); - std::string file(filePath); - env->ReleaseStringUTFChars(path, filePath); + // Get uriStr + jstring uri = (jstring)env->CallObjectMethod(jsonObj, jsonMethodString, env->NewStringUTF("uri")); + const char *uriStr = env->GetStringUTFChars(uri, JNI_FALSE); + std::string uriString(uriStr); + env->ReleaseStringUTFChars(uri, uriStr); jstring ugiTemp = (jstring)env->CallObjectMethod(jsonObj, jsonMethodString, env->NewStringUTF("ugi")); const char *ugi = env->GetStringUTFChars(ugiTemp, JNI_FALSE); std::string ugiString(ugi); env->ReleaseStringUTFChars(ugiTemp, ugi); + jstring schemeTmp = (jstring)env->CallObjectMethod(jsonObj, jsonMethodString, env->NewStringUTF("scheme")); + const char *scheme = env->GetStringUTFChars(schemeTmp, JNI_FALSE); + std::string schemeString(scheme); + env->ReleaseStringUTFChars(schemeTmp, 
scheme); + + jstring hostTmp = (jstring)env->CallObjectMethod(jsonObj, jsonMethodString, env->NewStringUTF("host")); + const char *host = env->GetStringUTFChars(hostTmp, JNI_FALSE); + std::string hostString(host); + env->ReleaseStringUTFChars(hostTmp, host); + + jstring pathTmp = (jstring)env->CallObjectMethod(jsonObj, jsonMethodString, env->NewStringUTF("path")); + const char *path = env->GetStringUTFChars(pathTmp, JNI_FALSE); + std::string pathString(path); + env->ReleaseStringUTFChars(pathTmp, path); + + jint port = (jint)env->CallIntMethod(jsonObj, jsonMethodInt, env->NewStringUTF("port")); + + UriInfo uriInfo(uriString, schemeString, pathString, hostString, std::to_string(port)); + // Get capacity for each record batch int64_t capacity = (int64_t)env->CallLongMethod(jsonObj, jsonMethodLong, env->NewStringUTF("capacity")); @@ -96,11 +78,8 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_ParquetColumnarBatchJ auto row_group_indices = GetIndices(env, jsonObj, "rowGroupIndices"); auto column_indices = GetIndices(env, jsonObj, "columnIndices"); - ObsConfig obsInfo; - parseObs(env, jsonObj, obsInfo); - ParquetReader *pReader = new ParquetReader(); - auto state = pReader->InitRecordReader(file, capacity, row_group_indices, column_indices, ugiString, obsInfo); + auto state = pReader->InitRecordReader(uriInfo, capacity, row_group_indices, column_indices, ugiString); if (state != Status::OK()) { env->ThrowNew(runtimeExceptionClass, state.ToString().c_str()); return 0; @@ -109,42 +88,38 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_ParquetColumnarBatchJ JNI_FUNC_END(runtimeExceptionClass) } -JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_ParquetColumnarBatchJniReader_recordReaderNext(JNIEnv *env, - jobject jObj, jlong reader, jintArray typeId, jlongArray vecNativeId) +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_ParquetColumnarBatchJniReader_recordReaderNext(JNIEnv *env, + jobject jObj, jlong reader, jlongArray vecNativeId) { JNI_FUNC_START ParquetReader *pReader = (ParquetReader *)reader; - std::shared_ptr recordBatchPtr; - auto state = pReader->ReadNextBatch(&recordBatchPtr); + std::vector recordBatch(pReader->columnReaders.size(), 0); + long batchRowSize = 0; + auto state = pReader->ReadNextBatch(recordBatch, &batchRowSize); if (state != Status::OK()) { + for (auto vec : recordBatch) { + delete vec; + } + recordBatch.clear(); env->ThrowNew(runtimeExceptionClass, state.ToString().c_str()); return 0; } - int vecCnt = 0; - long batchRowSize = 0; - if (recordBatchPtr != NULL) { - batchRowSize = recordBatchPtr->num_rows(); - vecCnt = recordBatchPtr->num_columns(); - std::vector> fields = recordBatchPtr->schema()->fields(); - - for (int colIdx = 0; colIdx < vecCnt; colIdx++) { - std::shared_ptr array = recordBatchPtr->column(colIdx); - // One array in current batch - std::shared_ptr data = array->data(); - int omniTypeId = 0; - uint64_t omniVecId = 0; - spark::reader::CopyToOmniVec(data->type, omniTypeId, omniVecId, array); - - env->SetIntArrayRegion(typeId, colIdx, 1, &omniTypeId); - jlong omniVec = static_cast(omniVecId); - env->SetLongArrayRegion(vecNativeId, colIdx, 1, &omniVec); + + for (uint64_t colIdx = 0; colIdx < recordBatch.size(); colIdx++) { + auto vector = recordBatch[colIdx]; + // If vector is not initialized, meaning that all data had been read. 
+ if (vector == NULL) { + return 0; } + jlong omniVec = (jlong)(vector); + env->SetLongArrayRegion(vecNativeId, colIdx, 1, &omniVec); } + return (jlong)batchRowSize; JNI_FUNC_END(runtimeExceptionClass) } -JNIEXPORT void JNICALL Java_com_huawei_boostkit_spark_jni_ParquetColumnarBatchJniReader_recordReaderClose(JNIEnv *env, +JNIEXPORT void JNICALL Java_com_huawei_boostkit_scan_jni_ParquetColumnarBatchJniReader_recordReaderClose(JNIEnv *env, jobject jObj, jlong reader) { JNI_FUNC_START diff --git a/omnioperator/omniop-spark-extension/cpp/src/jni/ParquetColumnarBatchJniReader.h b/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniReader.h similarity index 58% rename from omnioperator/omniop-spark-extension/cpp/src/jni/ParquetColumnarBatchJniReader.h rename to omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniReader.h index 9f47c6fb7a4731a53191e0301ed64dc7da1a282b..a374567476487d13848fab31a82ffef7e038a106 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/jni/ParquetColumnarBatchJniReader.h +++ b/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniReader.h @@ -1,5 +1,5 @@ /** - * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information @@ -17,8 +17,8 @@ * limitations under the License. */ -#ifndef SPARK_THESTRAL_PLUGIN_PARQUETCOLUMNARBATCHJNIREADER_H -#define SPARK_THESTRAL_PLUGIN_PARQUETCOLUMNARBATCHJNIREADER_H +#ifndef OMNI_RUNTIME_PARQUETCOLUMNARBATCHJNIREADER_H +#define OMNI_RUNTIME_PARQUETCOLUMNARBATCHJNIREADER_H #include #include @@ -28,12 +28,8 @@ #include #include #include -#include #include #include -#include -#include -#include #include "common/debug.h" #ifdef __cplusplus @@ -41,28 +37,28 @@ extern "C" { #endif /* - * Class: com_huawei_boostkit_spark_jni_ParquetColumnarBatchJniReader + * Class: com_huawei_boostkit_scan_jni_ParquetColumnarBatchJniReader * Method: initializeReader * Signature: (Ljava/lang/String;Lorg/json/simple/JSONObject;)J */ -JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_ParquetColumnarBatchJniReader_initializeReader - (JNIEnv* env, jobject jObj, jobject job); +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_ParquetColumnarBatchJniReader_initializeReader + (JNIEnv* env, jobject jObj, jobject job); /* - * Class: com_huawei_boostkit_spark_jni_ParquetColumnarBatchJniReader + * Class: com_huawei_boostkit_scan_jni_ParquetColumnarBatchJniReader * Method: recordReaderNext * Signature: (J[I[J)J */ -JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_ParquetColumnarBatchJniReader_recordReaderNext - (JNIEnv *, jobject, jlong, jintArray, jlongArray); +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_ParquetColumnarBatchJniReader_recordReaderNext + (JNIEnv *, jobject, jlong, jlongArray); /* - * Class: com_huawei_boostkit_spark_jni_ParquetColumnarBatchJniReader + * Class: com_huawei_boostkit_scan_jni_ParquetColumnarBatchJniReader * Method: recordReaderClose * Signature: (J)F */ -JNIEXPORT void JNICALL Java_com_huawei_boostkit_spark_jni_ParquetColumnarBatchJniReader_recordReaderClose - (JNIEnv *, jobject, jlong); +JNIEXPORT void JNICALL Java_com_huawei_boostkit_scan_jni_ParquetColumnarBatchJniReader_recordReaderClose + (JNIEnv *, jobject, jlong); #ifdef __cplusplus } diff --git 
a/omnioperator/omniop-native-reader/cpp/src/jni/jni_common.cpp b/omnioperator/omniop-native-reader/cpp/src/jni/jni_common.cpp new file mode 100644 index 0000000000000000000000000000000000000000..13f57e45db23845db3bd061e6db81d326b74c816 --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/jni/jni_common.cpp @@ -0,0 +1,96 @@ +/** + * Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef OMNI_RUNTIME_JNI_COMMON_CPP +#define OMNI_RUNTIME_JNI_COMMON_CPP + +#include "jni_common.h" + +jclass runtimeExceptionClass; +jclass jsonClass; +jclass arrayListClass; +jclass threadClass; + +jmethodID jsonMethodInt; +jmethodID jsonMethodLong; +jmethodID jsonMethodHas; +jmethodID jsonMethodString; +jmethodID jsonMethodJsonObj; +jmethodID arrayListGet; +jmethodID arrayListAdd; +jmethodID arrayListSize; +jmethodID jsonMethodObj; +jmethodID currentThread; + +static jint JNI_VERSION = JNI_VERSION_1_8; + +jclass CreateGlobalClassReference(JNIEnv* env, const char* class_name) +{ + jclass local_class = env->FindClass(class_name); + jclass global_class = (jclass)env->NewGlobalRef(local_class); + env->DeleteLocalRef(local_class); + return global_class; +} + +jmethodID GetMethodID(JNIEnv* env, jclass this_class, const char* name, const char* sig) +{ + jmethodID ret = env->GetMethodID(this_class, name, sig); + return ret; +} + +jint JNI_OnLoad(JavaVM* vm, void* reserved) +{ + JNIEnv* env; + if (vm->GetEnv(reinterpret_cast(&env), JNI_VERSION) != JNI_OK) { + return JNI_ERR; + } + + runtimeExceptionClass = CreateGlobalClassReference(env, "Ljava/lang/RuntimeException;"); + + jsonClass = CreateGlobalClassReference(env, "org/json/JSONObject"); + jsonMethodInt = env->GetMethodID(jsonClass, "getInt", "(Ljava/lang/String;)I"); + jsonMethodLong = env->GetMethodID(jsonClass, "getLong", "(Ljava/lang/String;)J"); + jsonMethodHas = env->GetMethodID(jsonClass, "has", "(Ljava/lang/String;)Z"); + jsonMethodString = env->GetMethodID(jsonClass, "getString", "(Ljava/lang/String;)Ljava/lang/String;"); + jsonMethodJsonObj = env->GetMethodID(jsonClass, "getJSONObject", "(Ljava/lang/String;)Lorg/json/JSONObject;"); + jsonMethodObj = env->GetMethodID(jsonClass, "get", "(Ljava/lang/String;)Ljava/lang/Object;"); + + arrayListClass = CreateGlobalClassReference(env, "java/util/ArrayList"); + arrayListGet = env->GetMethodID(arrayListClass, "get", "(I)Ljava/lang/Object;"); + arrayListSize = env->GetMethodID(arrayListClass, "size", "()I"); + arrayListAdd = env->GetMethodID(arrayListClass, "add", "(Ljava/lang/Object;)Z"); + + threadClass = CreateGlobalClassReference(env, "java/lang/Thread"); + currentThread = env->GetStaticMethodID(threadClass, "currentThread", "()Ljava/lang/Thread;"); + + return JNI_VERSION; +} + 
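JNI_OnLoad above caches the JSONObject, ArrayList and Thread classes plus their method IDs in globals so that hot JNI paths do not pay a FindClass/GetMethodID lookup on every call. The helper below shows how the cached arrayListAdd ID is meant to be used from native code, mirroring the loop in the ORC initializeReader that returns field names through vecFildsNames; AppendStrings itself is illustrative and not part of the patch.

    #include <jni.h>
    #include <string>
    #include <vector>
    #include "jni_common.h"

    // Illustrative helper: append C++ strings to a java.util.ArrayList using the
    // jmethodID cached in JNI_OnLoad.
    void AppendStrings(JNIEnv *env, jobject javaList, const std::vector<std::string> &names)
    {
        for (const auto &name : names) {
            jstring jname = env->NewStringUTF(name.c_str());
            env->CallBooleanMethod(javaList, arrayListAdd, jname);
            // Local references pile up inside native loops, so release them eagerly.
            env->DeleteLocalRef(jname);
        }
    }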
+void JNI_OnUnload(JavaVM* vm, void* reserved) +{ + JNIEnv* env; + vm->GetEnv(reinterpret_cast(&env), JNI_VERSION); + + env->DeleteGlobalRef(runtimeExceptionClass); + env->DeleteGlobalRef(jsonClass); + env->DeleteGlobalRef(arrayListClass); + env->DeleteGlobalRef(threadClass); +} + +#endif //OMNI_RUNTIME_JNI_COMMON_CPP diff --git a/omnioperator/omniop-native-reader/cpp/src/jni/jni_common.h b/omnioperator/omniop-native-reader/cpp/src/jni/jni_common.h new file mode 100644 index 0000000000000000000000000000000000000000..6e8326bc37f2a64161bef6071e6c34966c2d988d --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/jni/jni_common.h @@ -0,0 +1,64 @@ +/** + * Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef OMNI_RUNTIME_JNI_COMMON_H +#define OMNI_RUNTIME_JNI_COMMON_H + +#include + +jclass CreateGlobalClassReference(JNIEnv* env, const char* class_name); + +jmethodID GetMethodID(JNIEnv* env, jclass this_class, const char* name, const char* sig); + +#define JNI_FUNC_START try { + +#define JNI_FUNC_END(exceptionClass) \ + } \ + catch (const std::exception &e) \ + { \ + env->ThrowNew(exceptionClass, e.what()); \ + return 0; \ + } \ + + +#define JNI_FUNC_END_VOID(exceptionClass) \ + } \ + catch (const std::exception &e) \ + { \ + env->ThrowNew(exceptionClass, e.what()); \ + return; \ + } \ + +extern jclass runtimeExceptionClass; +extern jclass jsonClass; +extern jclass arrayListClass; +extern jclass threadClass; + +extern jmethodID jsonMethodInt; +extern jmethodID jsonMethodLong; +extern jmethodID jsonMethodHas; +extern jmethodID jsonMethodString; +extern jmethodID jsonMethodJsonObj; +extern jmethodID arrayListGet; +extern jmethodID arrayListAdd; +extern jmethodID arrayListSize; +extern jmethodID jsonMethodObj; +extern jmethodID currentThread; + +#endif //OMNI_RUNTIME_JNI_COMMON_H diff --git a/omnioperator/omniop-native-reader/cpp/src/orcfile/Adaptor.hh b/omnioperator/omniop-native-reader/cpp/src/orcfile/Adaptor.hh new file mode 100644 index 0000000000000000000000000000000000000000..a57858416cd85083b2c1e42c365e9819c48a7543 --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/orcfile/Adaptor.hh @@ -0,0 +1,34 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef OMNI_RUNTIME_ADAPTER_HH +#define OMNI_RUNTIME_ADAPTER_HH + +#define PRAGMA(TXT) _Pragma(#TXT) + +#ifdef __clang__ + #define DIAGNOSTIC_IGNORE(XXX) PRAGMA(clang diagnostic ignored XXX) +#elif defined(__GNUC__) + #define DIAGNOSTIC_IGNORE(XXX) PRAGMA(GCC diagnostic ignored XXX) +#elif defined(_MSC_VER) + #define DIAGNOSTIC_IGNORE(XXX) __pragma(warning(disable : XXX)) +#else + #define DIAGNOSTIC_IGNORE(XXX) +#endif + +#endif \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/io/orcfile/OrcFileRewrite.cc b/omnioperator/omniop-native-reader/cpp/src/orcfile/OrcFileOverride.cc similarity index 52% rename from omnioperator/omniop-spark-extension/cpp/src/io/orcfile/OrcFileRewrite.cc rename to omnioperator/omniop-native-reader/cpp/src/orcfile/OrcFileOverride.cc index 8ec77da2ce30c96cbab5ab4f6dfd768f4648a502..b52401b1a3e3eea8cf5841cefdc871d96a1ec99d 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/io/orcfile/OrcFileRewrite.cc +++ b/omnioperator/omniop-native-reader/cpp/src/orcfile/OrcFileOverride.cc @@ -16,35 +16,16 @@ * limitations under the License. */ -#include "OrcFileRewrite.hh" -#include "orc/Exceptions.hh" -#include "io/Adaptor.hh" +#include "OrcFileOverride.hh" -#include -#include -#include -#include -#include - -#ifdef _MSC_VER -#include -#define S_IRUSR _S_IREAD -#define S_IWUSR _S_IWRITE -#define stat _stat64 -#define fstat _fstat64 -#else -#include #define O_BINARY 0 -#endif namespace orc { - std::unique_ptr readFileRewrite(const std::string& path, std::vector& tokens) { - if (strncmp(path.c_str(), "hdfs://", 7) == 0) { - return orc::readHdfsFileRewrite(std::string(path), tokens); - } else if (strncmp(path.c_str(), "file:", 5) == 0) { - return orc::readLocalFile(std::string(path.substr(5))); - } else { - return orc::readLocalFile(std::string(path)); + std::unique_ptr readFileOverride(const UriInfo &uri) { + if (uri.Scheme() == "hdfs") { + return orc::createHdfsFileInputStream(uri); + } else { + return orc::readLocalFile(std::string(uri.Path())); + } } - } } diff --git a/omnioperator/omniop-spark-extension/cpp/src/io/orcfile/OrcFileRewrite.hh b/omnioperator/omniop-native-reader/cpp/src/orcfile/OrcFileOverride.hh similarity index 76% rename from omnioperator/omniop-spark-extension/cpp/src/io/orcfile/OrcFileRewrite.hh rename to omnioperator/omniop-native-reader/cpp/src/orcfile/OrcFileOverride.hh index e7bcee95cecd9dd8b0ac7a120be74a507e47d8a5..8d038627d788b3371e24cd6d4651430457489589 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/io/orcfile/OrcFileRewrite.hh +++ b/omnioperator/omniop-native-reader/cpp/src/orcfile/OrcFileOverride.hh @@ -21,8 +21,8 @@ #include -#include "hdfspp/options.h" #include "orc/OrcFile.hh" +#include "common/UriInfo.h" /** /file orc/OrcFile.hh @brief The top level interface to ORC. 
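readFileOverride above is a thin scheme dispatcher: a UriInfo whose Scheme() is "hdfs" is routed to createHdfsFileInputStream, and anything else falls back to orc::readLocalFile on the bare path. A small sketch of the call path follows; the UriInfo argument order (scheme, path, host, port) is taken from its use in OrcColumnarBatchJniReader.cpp, the namespace qualification of UriInfo is assumed to match the JNI sources, and the path/host/port values are placeholders.

    #include "common/UriInfo.h"
    #include "orcfile/OrcFileOverride.hh"

    // Sketch only: opens an ORC reader over HDFS through the override layer.
    std::unique_ptr<orc::Reader> OpenOrcReader()
    {
        UriInfo uri{"hdfs", "/warehouse/t1/part-00000.orc", "namenode-host", "9000"};
        orc::ReaderOptions options;
        // "hdfs" selects the HDFS input stream; any other scheme is read as a
        // local file via orc::readLocalFile(uri.Path()).
        return orc::createReader(orc::readFileOverride(uri), options);
    }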
@@ -32,15 +32,15 @@ namespace orc { /** * Create a stream to a local file or HDFS file if path begins with "hdfs://" - * @param path the name of the file in the local file system or HDFS + * @param uri the UriInfo of HDFS */ - ORC_UNIQUE_PTR readFileRewrite(const std::string& path, std::vector& tokens); + ORC_UNIQUE_PTR readFileOverride(const UriInfo &uri); /** * Create a stream to an HDFS file. - * @param path the uri of the file in HDFS + * @param uri the UriInfo of HDFS */ - ORC_UNIQUE_PTR readHdfsFileRewrite(const std::string& path, std::vector& tokens); + ORC_UNIQUE_PTR createHdfsFileInputStream(const UriInfo &uri); } #endif diff --git a/omnioperator/omniop-native-reader/cpp/src/orcfile/OrcHdfsFileOverride.cc b/omnioperator/omniop-native-reader/cpp/src/orcfile/OrcHdfsFileOverride.cc new file mode 100644 index 0000000000000000000000000000000000000000..2a877087b3ef0fc02052601a4fbb60273a68f790 --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/orcfile/OrcHdfsFileOverride.cc @@ -0,0 +1,108 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "OrcFileOverride.hh" +#include +#include + +#include "filesystem/hdfs_file.h" +#include "filesystem/io_exception.h" + +namespace orc { + + using namespace fs; + + class HdfsFileInputStreamOverride : public InputStream { + private: + std::string filename_; + std::unique_ptr hdfs_file_; + uint64_t total_length_; + const uint64_t READ_SIZE_ = 1024 * 1024; //1 MB + + public: + HdfsFileInputStreamOverride(const UriInfo& uri) { + this->filename_ = uri.Path(); + std::shared_ptr fileSystemPtr = getHdfsFileSystem(uri.Host(), uri.Port()); + this->hdfs_file_ = std::make_unique(fileSystemPtr, this->filename_, 0); + + Status openFileSt = hdfs_file_->OpenFile(); + if (!openFileSt.IsOk()) { + throw IOException(openFileSt.ToString()); + } + + this->total_length_= hdfs_file_->GetFileSize(); + } + + ~HdfsFileInputStreamOverride() override { + } + + /** + * get the total length of the file in bytes + */ + uint64_t getLength() const override { + return total_length_; + } + + + /** + * get the natural size of reads + */ + uint64_t getNaturalReadSize() const override { + return READ_SIZE_; + } + + /** + * read length bytes from the file starting at offset into the buffer starting at buf + * @param buf buffer save data + * @param length the number of bytes to read + * @param offset read from + */ + void read(void *buf, + uint64_t length, + uint64_t offset) override { + + if (!buf) { + throw IOException(Status::IOError("Fail to read hdfs file, because read buffer is null").ToString()); + } + + char *buf_ptr = reinterpret_cast(buf); + int64_t total_bytes_read = 0; + int64_t last_bytes_read = 0; + + do { + last_bytes_read = hdfs_file_->ReadAt(buf_ptr, length - total_bytes_read,offset + total_bytes_read); + if (last_bytes_read < 0) { + throw IOException(Status::IOError("Error reading bytes the file").ToString()); + } + if (last_bytes_read == 0) { + break; + } + total_bytes_read += last_bytes_read; + buf_ptr += last_bytes_read; + } while (total_bytes_read < length); + } + + const std::string &getName() const override { + return filename_; + } + }; + + std::unique_ptr createHdfsFileInputStream(const UriInfo &uri) { + return std::unique_ptr(new HdfsFileInputStreamOverride(uri)); + } +} diff --git a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetColumnReader.cpp b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetColumnReader.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c0446411afcb8e2b9e790465ff43ccf74c31937b --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetColumnReader.cpp @@ -0,0 +1,62 @@ +/** + * Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ParquetColumnReader.h" + +using namespace omniruntime::vec; + +namespace omniruntime::reader { + +Status ParquetColumnReader::NextBatch(int64_t batch_size, BaseVector** out) +{ + RETURN_NOT_OK(LoadBatch(batch_size, out)); + return Status::OK(); +} + +Status ParquetColumnReader::LoadBatch(int64_t records_to_read, BaseVector** out) +{ + BEGIN_PARQUET_CATCH_EXCEPTIONS + record_reader_->Reset(); + record_reader_->Reserve(records_to_read); + while (records_to_read > 0) { + if (!record_reader_->HasMoreData()) { + break; + } + int64_t records_read = record_reader_->ReadRecords(records_to_read); + records_to_read -= records_read; + if (records_read == 0) { + NextRowGroup(); + } + } + + *out = record_reader_->GetBaseVec(); + if (*out == nullptr) { + return Status::Invalid("Parquet Read OmniVector is nullptr!"); + } + return Status::OK(); + END_PARQUET_CATCH_EXCEPTIONS +} + +void ParquetColumnReader::NextRowGroup() +{ + std::unique_ptr page_reader = input_->NextChunk(); + record_reader_->SetPageReader(std::move(page_reader)); +} + +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/io/OrcObsFile.hh b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetColumnReader.h similarity index 35% rename from omnioperator/omniop-spark-extension/cpp/src/io/OrcObsFile.hh rename to omnioperator/omniop-native-reader/cpp/src/parquet/ParquetColumnReader.h index 1c7af3669d513a3cb551c285c1271b7749558410..3061c6259a83b95cd28ebaf4c6a782b904469679 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/io/OrcObsFile.hh +++ b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetColumnReader.h @@ -1,5 +1,5 @@ /** - * Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved. + * Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information @@ -17,63 +17,43 @@ * limitations under the License. 
*/ -#include "orc/OrcFile.hh" +#ifndef OMNI_RUNTIME_COLUMN_READER_H +#define OMNI_RUNTIME_COLUMN_READER_H -#include "eSDKOBS.h" - -#define OBS_READ_SIZE 1024 -#define OBS_KEY_SIZE 2048 -#define OBS_TOKEN_SIZE 8192 -#define OBS_PROTOCOL_SIZE 6 - -namespace orc { - typedef struct ObsConfig { - char hostName[OBS_KEY_SIZE]; - char accessKey[OBS_KEY_SIZE]; - char secretKey[OBS_KEY_SIZE]; - char token[OBS_TOKEN_SIZE]; - char bucket[OBS_KEY_SIZE]; - char objectKey[OBS_KEY_SIZE]; - uint32_t hostLen; - } ObsConfig; - - std::unique_ptr readObsFile(const std::string& path, ObsConfig *obsInfo); - - class ObsFileInputStream : public InputStream { - private: - obs_options option; - obs_object_info objectInfo; - obs_get_conditions conditions; - ObsConfig obsInfo; - - std::string filename; - uint64_t totalLength; - const uint64_t READ_SIZE = OBS_READ_SIZE * OBS_READ_SIZE; - - static obs_status obsInitStatus; - - static obs_status obsInit(); - - void getObsInfo(ObsConfig *obsInfo); +#include "ParquetTypedRecordReader.h" +#include +#include +namespace omniruntime::reader { + class ParquetColumnReader { public: - ObsFileInputStream(std::string _filename, ObsConfig *obsInfo); - - uint64_t getLength() const override { - return totalLength; + ParquetColumnReader(std::shared_ptr<::parquet::arrow::ReaderContext> ctx, std::shared_ptr<::arrow::Field> field, + std::unique_ptr<::parquet::arrow::FileColumnIterator> input, ::parquet::internal::LevelInfo leaf_info) + : ctx_(std::move(ctx)), + field_(std::move(field)), + input_(std::move(input)), + descr_(input_->descr()) { + record_reader_ = MakeRecordReader(descr_, leaf_info, ctx_->pool, + field_->type()->id() == ::arrow::Type::DICTIONARY, field_->type()); + NextRowGroup(); } - uint64_t getNaturalReadSize() const override { - return READ_SIZE; - } + ::arrow::Status NextBatch(int64_t batch_size, omniruntime::vec::BaseVector** out); - void read(void* buf, uint64_t length, uint64_t offset) override; + ::arrow::Status LoadBatch(int64_t records_to_read, omniruntime::vec::BaseVector** out); - const std::string& getName() const override { - return filename; + const std::shared_ptr<::arrow::Field> field() { + return field_; } - ~ObsFileInputStream() override { - } + private: + void NextRowGroup(); + + std::shared_ptr<::parquet::arrow::ReaderContext> ctx_; + std::shared_ptr<::arrow::Field> field_; + std::unique_ptr<::parquet::arrow::FileColumnIterator> input_; + const ::parquet::ColumnDescriptor* descr_; + std::shared_ptr record_reader_; }; } +#endif // OMNI_RUNTIME_COLUMN_READER_H \ No newline at end of file diff --git a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetDecoder.cpp b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetDecoder.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b5c1d712dd179e362837347e76e11b3684a9af2f --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetDecoder.cpp @@ -0,0 +1,114 @@ +/** + * Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ParquetDecoder.h" + +using namespace parquet::arrow; +using namespace parquet; +using namespace omniruntime::vec; + +namespace omniruntime::reader { + + ParquetPlainBooleanDecoder::ParquetPlainBooleanDecoder(const ::parquet::ColumnDescriptor* descr) + : ParquetDecoderImpl(descr, ::parquet::Encoding::PLAIN) {} + + void ParquetPlainBooleanDecoder::SetData(int num_values, const uint8_t* data, int len) { + num_values_ = num_values; + bit_reader_ = std::make_unique<::arrow::bit_util::BitReader>(data, len); + } + + int ParquetPlainBooleanDecoder::Decode(uint8_t* buffer, int max_values) { + max_values = std::min(max_values, num_values_); + bool val; + ::arrow::internal::BitmapWriter bit_writer(buffer, 0, max_values); + for (int i = 0; i < max_values; ++i) { + if (!bit_reader_->GetValue(1, &val)) { + ParquetException::EofException(); + } + if (val) { + bit_writer.Set(); + } + bit_writer.Next(); + } + bit_writer.Finish(); + num_values_ -= max_values; + return max_values; + } + + int ParquetPlainBooleanDecoder::Decode(bool* buffer, int max_values) { + max_values = std::min(max_values, num_values_); + if (bit_reader_->GetBatch(1, buffer, max_values) != max_values) { + ::parquet::ParquetException::EofException(); + } + num_values_ -= max_values; + return max_values; + } + + template <> + void ParquetDictDecoderImpl<::parquet::BooleanType>::SetDict(ParquetTypedDecoder<::parquet::BooleanType>* dictionary) { + ParquetException::NYI("Dictionary encoding is not implemented for boolean values"); + } + + template <> + void ParquetDictDecoderImpl::SetDict(ParquetTypedDecoder* dictionary) { + DecodeDict(dictionary); + + auto dict_values = reinterpret_cast(dictionary_->mutable_data()); + + int total_size = 0; + for (int i = 0; i < dictionary_length_; ++i) { + total_size += dict_values[i].len; + } + PARQUET_THROW_NOT_OK(byte_array_data_->Resize(total_size, + /*shrink_to_fit=*/false)); + PARQUET_THROW_NOT_OK( + byte_array_offsets_->Resize((dictionary_length_ + 1) * sizeof(int32_t), + /*shrink_to_fit=*/false)); + + int32_t offset = 0; + uint8_t* bytes_data = byte_array_data_->mutable_data(); + int32_t* bytes_offsets = + reinterpret_cast(byte_array_offsets_->mutable_data()); + for (int i = 0; i < dictionary_length_; ++i) { + memcpy(bytes_data + offset, dict_values[i].ptr, dict_values[i].len); + bytes_offsets[i] = offset; + dict_values[i].ptr = bytes_data + offset; + offset += dict_values[i].len; + } + bytes_offsets[dictionary_length_] = offset; + } + + template <> + inline void ParquetDictDecoderImpl::SetDict(ParquetTypedDecoder* dictionary) { + DecodeDict(dictionary); + + auto dict_values = reinterpret_cast(dictionary_->mutable_data()); + + int fixed_len = descr_->type_length(); + int total_size = dictionary_length_ * fixed_len; + + PARQUET_THROW_NOT_OK(byte_array_data_->Resize(total_size, + /*shrink_to_fit=*/false)); + uint8_t* bytes_data = byte_array_data_->mutable_data(); + for (int32_t i = 0, offset = 0; i < dictionary_length_; ++i, offset += fixed_len) { + memcpy(bytes_data + offset, dict_values[i].ptr, fixed_len); + dict_values[i].ptr = bytes_data + offset; + } + } +} \ No 
newline at end of file diff --git a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetDecoder.h b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetDecoder.h new file mode 100644 index 0000000000000000000000000000000000000000..a36c2e2acb430d15e32f1a1da1be6c83700ecd7d --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetDecoder.h @@ -0,0 +1,651 @@ +/** + * Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef OMNI_RUNTIME_ENCODING_H +#define OMNI_RUNTIME_ENCODING_H + +#include +#include +#include +#include +#include +#include + +using namespace omniruntime::vec; +using namespace arrow; + +namespace omniruntime::reader { + + class ParquetDecoderImpl : virtual public ::parquet::Decoder { + public: + void SetData(int num_values, const uint8_t* data, int len) override { + num_values_ = num_values; + data_ = data; + len_ = len; + } + + int values_left() const override { return num_values_; } + ::parquet::Encoding::type encoding() const override { return encoding_; } + + protected: + explicit ParquetDecoderImpl(const ::parquet::ColumnDescriptor* descr, ::parquet::Encoding::type encoding) + : descr_(descr), encoding_(encoding), num_values_(0), data_(NULLPTR), len_(0) {} + + // For accessing type-specific metadata, like FIXED_LEN_BYTE_ARRAY + const ::parquet::ColumnDescriptor* descr_; + + const ::parquet::Encoding::type encoding_; + int num_values_; + const uint8_t* data_; + int len_; + int type_length_; + }; + + // TODO: optimize batch move + template + inline int SpacedExpand(T* buffer, int num_values, int null_count, + bool* nulls) { + int idx_decode = num_values - null_count; + std::memset(static_cast(buffer + idx_decode), 0, null_count * sizeof(T)); + if (idx_decode == 0) { + // All nulls, nothing more to do + return num_values; + } + for (int i = num_values - 1; i >= 0; --i) { + if (!nulls[i]) { + idx_decode--; + std::memmove(buffer + i, buffer + idx_decode, sizeof(T)); + } + } + assert(idx_decode == 0); + return num_values; + } + + template + class ParquetTypedDecoder : virtual public ::parquet::TypedDecoder { + public: + using T = typename DType::c_type; + + virtual int DecodeSpaced(T* buffer, int num_values, int null_count, + bool* nulls) { + if (null_count > 0) { + int values_to_read = num_values - null_count; + int values_read = Decode(buffer, values_to_read); + if (values_read != values_to_read) { + throw ::parquet::ParquetException("Number of values / definition_levels read did not match"); + } + + return SpacedExpand(buffer, num_values, null_count, nulls); + } else { + return Decode(buffer, num_values); + } + } + + int Decode(T* buffer, int num_values) override { + ::parquet::ParquetException::NYI("ParquetTypedDecoder 
for Decode"); + } + + virtual int DecodeArrowNonNull(int num_values, omniruntime::vec::BaseVector** outBaseVec, int64_t offset) { + ::parquet::ParquetException::NYI("ParquetTypedDecoder for DecodeArrowNonNull"); + } + + virtual int DecodeArrow(int num_values, int null_count, bool* nulls, + int64_t offset, omniruntime::vec::BaseVector** outBaseVec) { + ::parquet::ParquetException::NYI("ParquetTypedDecoder for DecodeArrow"); + } + }; + + template + class ParquetDictDecoder : virtual public ParquetTypedDecoder { + public: + using T = typename DType::c_type; + + virtual void SetDict(ParquetTypedDecoder* dictionary) = 0; + + virtual void InsertDictionary(::arrow::ArrayBuilder* builder) = 0; + + virtual int DecodeIndicesSpaced(int num_values, int null_count, + const uint8_t* valid_bits, int64_t valid_bits_offset, + ::arrow::ArrayBuilder* builder) = 0; + + virtual int DecodeIndices(int num_values, ::arrow::ArrayBuilder* builder) = 0; + + virtual int DecodeIndices(int num_values, int32_t* indices) = 0; + + virtual void GetDictionary(const T** dictionary, int32_t* dictionary_length) = 0; + }; + + template + class ParquetDictDecoderImpl : public ParquetDecoderImpl, virtual public ParquetDictDecoder { + public: + typedef typename Type::c_type T; + + explicit ParquetDictDecoderImpl(const ::parquet::ColumnDescriptor* descr, + ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()) + : ParquetDecoderImpl(descr, ::parquet::Encoding::RLE_DICTIONARY), + dictionary_(::parquet::AllocateBuffer(pool, 0)), + dictionary_length_(0), + byte_array_data_(::parquet::AllocateBuffer(pool, 0)), + byte_array_offsets_(::parquet::AllocateBuffer(pool, 0)) {} + + void SetDict(ParquetTypedDecoder* dictionary) override; + + void SetData(int num_values, const uint8_t* data, int len) override { + num_values_ = num_values; + if (len == 0) { + idx_decoder_ = ::arrow::util::RleDecoder(data, len, 1); + return; + } + uint8_t bit_width = *data; + if (ARROW_PREDICT_FALSE(bit_width > 32)) { + throw ::parquet::ParquetException("Invalid or corrupted bit_width " + + std::to_string(bit_width) + ". 
Maximum allowed is 32."); + } + idx_decoder_ = ::arrow::util::RleDecoder(++data, --len, bit_width); + } + + int Decode(T* buffer, int num_values) override { + num_values = std::min(num_values, num_values_); + int decoded_values = + idx_decoder_.GetBatchWithDict(reinterpret_cast(dictionary_->data()), + dictionary_length_, buffer, num_values); + if (decoded_values != num_values) { + ::parquet::ParquetException::EofException(); + } + num_values_ -= num_values; + return num_values; + } + + int DecodeSpaced(T* buffer, int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset) override { + num_values = std::min(num_values, num_values_); + if (num_values != idx_decoder_.GetBatchWithDictSpaced( + reinterpret_cast(dictionary_->data()), + dictionary_length_, buffer, num_values, null_count, valid_bits, + valid_bits_offset)) { + ::parquet::ParquetException::EofException(); + } + num_values_ -= num_values; + return num_values; + } + + int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset, + typename ::parquet::EncodingTraits::Accumulator* out) override { + ::parquet::ParquetException::NYI("DecodeArrow(Accumulator) for OmniDictDecoderImpl"); + } + + int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset, + typename ::parquet::EncodingTraits::DictAccumulator* out) override { + ::parquet::ParquetException::NYI("DecodeArrow(DictAccumulator) for OmniDictDecoderImpl"); + } + + void InsertDictionary(::arrow::ArrayBuilder* builder) override { + ::parquet::ParquetException::NYI("InsertDictionary ArrayBuilder"); + } + + int DecodeIndicesSpaced(int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset, + ::arrow::ArrayBuilder* builder) override { + ::parquet::ParquetException::NYI("DecodeIndicesSpaced ArrayBuilder"); + } + + int DecodeIndices(int num_values, ::arrow::ArrayBuilder* builder) override { + ::parquet::ParquetException::NYI("DecodeIndices ArrayBuilder"); + } + + int DecodeIndices(int num_values, int32_t* indices) override { + if (num_values != idx_decoder_.GetBatch(indices, num_values)) { + ::parquet::ParquetException::EofException(); + } + num_values_ -= num_values; + return num_values; + } + + void GetDictionary(const T** dictionary, int32_t* dictionary_length) override { + *dictionary_length = dictionary_length_; + *dictionary = reinterpret_cast(dictionary_->mutable_data()); + } + + virtual int DecodeArrowNonNull(int num_values, omniruntime::vec::BaseVector** outBaseVec, int64_t offset) { + ::parquet::ParquetException::NYI("ParquetTypedDecoder for DecodeArrowNonNull"); + } + + virtual int DecodeArrow(int num_values, int null_count, bool* nulls, + int64_t offset, omniruntime::vec::BaseVector** outBaseVec) { + ::parquet::ParquetException::NYI("ParquetTypedDecoder for DecodeArrow"); + } + + protected: + Status IndexInBounds(int32_t index) { + if (ARROW_PREDICT_TRUE(0 <= index && index < dictionary_length_)) { + return Status::OK(); + } + return Status::Invalid("Index not in dictionary bounds"); + } + + inline void DecodeDict(::parquet::TypedDecoder* dictionary) { + dictionary_length_ = static_cast(dictionary->values_left()); + PARQUET_THROW_NOT_OK(dictionary_->Resize(dictionary_length_ * sizeof(T), + /*shrink_to_fit=*/false)); + dictionary->Decode(reinterpret_cast(dictionary_->mutable_data()), dictionary_length_); + } + + std::shared_ptr<::parquet::ResizableBuffer> dictionary_; + + int32_t dictionary_length_; + + 
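+    // Note on the two buffers below: they are only populated by the
+    // ByteArrayType and FLBAType SetDict specializations (see
+    // ParquetDecoder.cpp), which copy the variable- or fixed-length dictionary
+    // payloads into one contiguous allocation and re-point dict_values[i].ptr
+    // into it, so decoded indices can be resolved without keeping the original
+    // dictionary page alive.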
std::shared_ptr<::parquet::ResizableBuffer> byte_array_data_; + + std::shared_ptr<::parquet::ResizableBuffer> byte_array_offsets_; + + ::arrow::util::RleDecoder idx_decoder_; + }; + + template + void ParquetDictDecoderImpl::SetDict(ParquetTypedDecoder* dictionary) { + DecodeDict(dictionary); + } + + class OmniDictByteArrayDecoderImpl : public ParquetDictDecoderImpl<::parquet::ByteArrayType> { + public: + using BASE = ParquetDictDecoderImpl<::parquet::ByteArrayType>; + using BASE::ParquetDictDecoderImpl; + + int DecodeArrowNonNull(int num_values, omniruntime::vec::BaseVector** outBaseVec, int64_t offset) override { + int result = 0; + PARQUET_THROW_NOT_OK(DecodeArrowNonNull(num_values, &result, outBaseVec, offset)); + return result; + } + + int DecodeArrow(int num_values, int null_count, bool* nulls, + int64_t offset, omniruntime::vec::BaseVector** vec) override { + int result = 0; + PARQUET_THROW_NOT_OK(DecodeArrowDense(num_values, null_count, nulls, + offset, &result, vec)); + return result; + } + + private: + Status DecodeArrowDense(int num_values, int null_count, bool* nulls, + int64_t offset, + int* out_num_values, omniruntime::vec::BaseVector** out) { + constexpr int32_t kBufferSize = 1024; + int32_t indices[kBufferSize]; + + auto vec = dynamic_cast>*>(*out); + + auto dict_values = reinterpret_cast(dictionary_->data()); + int values_decoded = 0; + int num_indices = 0; + int pos_indices = 0; + + for (int i = 0; i < num_values; i++) { + if (!nulls[offset + i]) { + if (num_indices == pos_indices) { + const auto batch_size = + std::min(kBufferSize, num_values - null_count - values_decoded); + num_indices = idx_decoder_.GetBatch(indices, batch_size); + if (ARROW_PREDICT_FALSE(num_indices < 1)) { + return Status::Invalid("Invalid number of indices: ", num_indices); + } + pos_indices = 0; + } + const auto index = indices[pos_indices++]; + RETURN_NOT_OK(IndexInBounds(index)); + const auto& val = dict_values[index]; + std::string_view value(reinterpret_cast(val.ptr), val.len); + vec->SetValue(offset + i, value); + ++values_decoded; + } else { + vec->SetNull(offset + i); + } + } + + *out_num_values = values_decoded; + return Status::OK(); + } + + Status DecodeArrowNonNull(int num_values, int* out_num_values, omniruntime::vec::BaseVector** out, int offset) { + constexpr int32_t kBufferSize = 2048; + int32_t indices[kBufferSize]; + + auto vec = dynamic_cast>*>(*out); + + auto dict_values = reinterpret_cast(dictionary_->data()); + + int values_decoded = 0; + while (values_decoded < num_values) { + int32_t batch_size = std::min(kBufferSize, num_values - values_decoded); + int num_indices = idx_decoder_.GetBatch(indices, batch_size); + if (num_indices == 0) ::parquet::ParquetException::EofException(); + for (int i = 0; i < num_indices; ++i) { + auto idx = indices[i]; + RETURN_NOT_OK(IndexInBounds(idx)); + const auto& val = dict_values[idx]; + std::string_view value(reinterpret_cast(val.ptr), val.len); + vec->SetValue(i + offset, value); + } + values_decoded += num_indices; + } + *out_num_values = values_decoded; + return Status::OK(); + } + }; + + template + class ParquetPlainDecoder : public ParquetDecoderImpl, virtual public ParquetTypedDecoder { + public: + using T = typename DType::c_type; + explicit ParquetPlainDecoder(const ::parquet::ColumnDescriptor* descr); + + int Decode(T* buffer, int max_values) override; + + int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset, + typename ::parquet::EncodingTraits::Accumulator* builder) override { + 
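+        // The Arrow-builder decode paths are deliberately left as NYI: this
+        // reader materializes values directly into omniruntime BaseVectors via
+        // the BaseVector** DecodeArrow/DecodeArrowNonNull overloads instead.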
::parquet::ParquetException::NYI("DecodeArrow(Accumulator) for ParquetPlainDecoder"); + } + + int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset, + typename ::parquet::EncodingTraits::DictAccumulator* builder) override { + ::parquet::ParquetException::NYI("DecodeArrow(DictAccumulator) for ParquetPlainDecoder"); + } + }; + + template + inline int DecodePlain(const uint8_t* data, int64_t data_size, int num_values, + int type_length, T* out) { + int64_t bytes_to_decode = num_values * static_cast(sizeof(T)); + if (bytes_to_decode > data_size || bytes_to_decode > INT_MAX) { + ::parquet::ParquetException::EofException(); + } + if (bytes_to_decode > 0) { + memcpy(out, data, bytes_to_decode); + } + return static_cast(bytes_to_decode); + } + + static inline int64_t ReadByteArray(const uint8_t* data, int64_t data_size, + ::parquet::ByteArray* out) { + if (ARROW_PREDICT_FALSE(data_size < 4)) { + parquet::ParquetException::EofException(); + } + const int32_t len = ::arrow::util::SafeLoadAs(data); + if (len < 0) { + throw parquet::ParquetException("Invalid BYTE_ARRAY value"); + } + const int64_t consumed_length = static_cast(len) + 4; + if (ARROW_PREDICT_FALSE(data_size < consumed_length)) { + parquet::ParquetException::EofException(); + } + *out = parquet::ByteArray{static_cast(len), data + 4}; + return consumed_length; + } + + template <> + inline int DecodePlain<::parquet::ByteArray>(const uint8_t* data, int64_t data_size, int num_values, + int type_length, ::parquet::ByteArray* out) { + int bytes_decoded = 0; + for (int i = 0; i < num_values; ++i) { + const auto increment = ReadByteArray(data, data_size, out + i); + if (ARROW_PREDICT_FALSE(increment > INT_MAX - bytes_decoded)) { + throw ::parquet::ParquetException("BYTE_ARRAY chunk too large"); + } + data += increment; + data_size -= increment; + bytes_decoded += static_cast(increment); + } + return bytes_decoded; + } + + template <> + inline int DecodePlain<::parquet::FixedLenByteArray>(const uint8_t* data, int64_t data_size, + int num_values, int type_length, + ::parquet::FixedLenByteArray* out) { + int64_t bytes_to_decode = static_cast(type_length) * num_values; + if (bytes_to_decode > data_size || bytes_to_decode > INT_MAX) { + ::parquet::ParquetException::EofException(); + } + + memcpy_s(reinterpret_cast(out), bytes_to_decode, data, bytes_to_decode); + + return static_cast(bytes_to_decode); + } + + template + ParquetPlainDecoder::ParquetPlainDecoder(const ::parquet::ColumnDescriptor* descr) + : ParquetDecoderImpl(descr, ::parquet::Encoding::PLAIN) { + if (descr_ && descr_->physical_type() == ::parquet::Type::FIXED_LEN_BYTE_ARRAY) { + type_length_ = descr_->type_length(); + } else { + type_length_ = -1; + } + } + + template + int ParquetPlainDecoder::Decode(T* buffer, int max_values) { + max_values = std::min(max_values, num_values_); + int bytes_consumed = DecodePlain(data_, len_, max_values, type_length_, buffer); + data_ += bytes_consumed; + len_ -= bytes_consumed; + num_values_ -= max_values; + return max_values; + } + + class ParquetPlainByteArrayDecoder : public ParquetPlainDecoder<::parquet::ByteArrayType> { + public: + using Base = ParquetPlainDecoder<::parquet::ByteArrayType>; + using Base::ParquetPlainDecoder; + + int DecodeArrowNonNull(int num_values, omniruntime::vec::BaseVector** outBaseVec, int64_t offset) override { + int result = 0; + PARQUET_THROW_NOT_OK(DecodeArrowDenseNonNull(num_values, &result, outBaseVec, offset)); + return result; + } + + int DecodeArrow(int num_values, 
int null_count, bool* nulls, + int64_t offset, omniruntime::vec::BaseVector** outBaseVec) { + int result = 0; + PARQUET_THROW_NOT_OK(DecodeArrowDense(num_values, null_count, nulls, + offset, &result, outBaseVec)); + return result; + } + + private: + Status DecodeArrowDense(int num_values, int null_count, bool* nulls, + int64_t offset, + int* out_values_decoded, omniruntime::vec::BaseVector** out) { + int values_decoded = 0; + auto vec = dynamic_cast>*>(*out); + + for (int i = 0; i < num_values; i++) { + if (!nulls[offset + i]) { + if (ARROW_PREDICT_FALSE(len_ < 4)) { + ::parquet::ParquetException::EofException(); + } + auto value_len = ::arrow::util::SafeLoadAs(data_); + if (ARROW_PREDICT_FALSE(value_len < 0 || value_len > INT32_MAX - 4)) { + return Status::Invalid("Invalid or corrupted value_len '", value_len, "'"); + } + auto increment = value_len + 4; + if (ARROW_PREDICT_FALSE(len_ < increment)) { + ::parquet::ParquetException::EofException(); + } + std::string_view value(reinterpret_cast(data_ + 4), value_len); + vec->SetValue(offset + i, value); + data_ += increment; + len_ -= increment; + ++values_decoded; + } else { + vec->SetNull(offset + i); + } + } + + num_values_ -= values_decoded; + *out_values_decoded = values_decoded; + return Status::OK(); + } + + Status DecodeArrowDenseNonNull(int num_values, + int* out_values_decoded, omniruntime::vec::BaseVector** out, int64_t offset) { + int values_decoded = 0; + auto vec = dynamic_cast>*>(*out); + + for (int i = 0; i < num_values; i++) { + if (ARROW_PREDICT_FALSE(len_ < 4)) { + ::parquet::ParquetException::EofException(); + } + auto value_len = ::arrow::util::SafeLoadAs(data_); + if (ARROW_PREDICT_FALSE(value_len < 0 || value_len > INT32_MAX - 4)) { + return Status::Invalid("Invalid or corrupted value_len '", value_len, "'"); + } + auto increment = value_len + 4; + if (ARROW_PREDICT_FALSE(len_ < increment)) { + ::parquet::ParquetException::EofException(); + } + std::string_view value(reinterpret_cast(data_ + 4), value_len); + (vec)->SetValue(offset + i, value); + data_ += increment; + len_ -= increment; + ++values_decoded; + } + num_values_ -= values_decoded; + *out_values_decoded = values_decoded; + return Status::OK(); + } + }; + + class ParquetBooleanDecoder : virtual public ParquetTypedDecoder<::parquet::BooleanType> { + public: + using ParquetTypedDecoder<::parquet::BooleanType>::Decode; + virtual int Decode(uint8_t* buffer, int max_values) = 0; + }; + + class ParquetPlainBooleanDecoder : public ParquetDecoderImpl, virtual public ParquetBooleanDecoder { + public: + explicit ParquetPlainBooleanDecoder(const ::parquet::ColumnDescriptor* descr); + void SetData(int num_values, const uint8_t* data, int len) override; + + int Decode(uint8_t* buffer, int max_values) override; + int Decode(bool* buffer, int max_values) override; + + int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset, + typename ::parquet::EncodingTraits<::parquet::BooleanType>::Accumulator* out) override { + ::parquet::ParquetException::NYI("DecodeArrow for ParquetPlainBooleanDecoder"); + } + + int DecodeArrow( + int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset, + typename ::parquet::EncodingTraits<::parquet::BooleanType>::DictAccumulator* builder) override { + ::parquet::ParquetException::NYI("DecodeArrow for ParquetPlainBooleanDecoder"); + } + + private: + std::unique_ptr<::arrow::bit_util::BitReader> bit_reader_; + }; + + class ParquetRleBooleanDecoder : public ParquetDecoderImpl, 
virtual public ParquetBooleanDecoder { + public: + explicit ParquetRleBooleanDecoder(const ::parquet::ColumnDescriptor* descr) + : ParquetDecoderImpl(descr, ::parquet::Encoding::RLE) {} + + void SetData(int num_values, const uint8_t* data, int len) override { + num_values_ = num_values; + uint32_t num_bytes = 0; + + if (len < 4) { + throw ::parquet::ParquetException("Received invalid length : " + std::to_string(len) + + " (corrupt data page?)"); + } + + num_bytes = + ::arrow::bit_util::ToLittleEndian(::arrow::util::SafeLoadAs(data)); + if (num_bytes < 0 || num_bytes > static_cast(len - 4)) { + throw ::parquet::ParquetException("Received invalid number of bytes : " + + std::to_string(num_bytes) + " (corrupt data page?)"); + } + + auto decoder_data = data + 4; + decoder_ = std::make_shared<::arrow::util::RleDecoder>(decoder_data, num_bytes, + /*bit_width=*/1); + } + + int Decode(bool* buffer, int max_values) override { + max_values = std::min(max_values, num_values_); + + if (decoder_->GetBatch(buffer, max_values) != max_values) { + ::parquet::ParquetException::EofException(); + } + num_values_ -= max_values; + return max_values; + } + + int Decode(uint8_t* buffer, int max_values) override { + ::parquet::ParquetException::NYI("Decode(uint8_t*, int) for RleBooleanDecoder"); + } + + int DecodeArrow(int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset, + typename ::parquet::EncodingTraits<::parquet::BooleanType>::Accumulator* out) override { + ::parquet::ParquetException::NYI("DecodeArrow for RleBooleanDecoder"); + } + + int DecodeArrow( + int num_values, int null_count, const uint8_t* valid_bits, + int64_t valid_bits_offset, + typename ::parquet::EncodingTraits<::parquet::BooleanType>::DictAccumulator* builder) override { + ::parquet::ParquetException::NYI("DecodeArrow for RleBooleanDecoder"); + } + + private: + std::shared_ptr<::arrow::util::RleDecoder> decoder_; + }; + + class ParquetPlainFLBADecoder : public ParquetPlainDecoder<::parquet::FLBAType>, virtual public ::parquet::FLBADecoder { + public: + using Base = ParquetPlainDecoder<::parquet::FLBAType>; + using Base::ParquetPlainDecoder; + + int DecodeSpaced(T* buffer, int num_values, int null_count, + bool* nulls) override { + int values_to_read = num_values - null_count; + Decode(buffer, values_to_read); + return num_values; + } + + int Decode(T* buffer, int max_values) override { + max_values = std::min(max_values, num_values_); + int bytes_consumed = DecodePlain(data_, len_, max_values, type_length_, buffer); + data_ += bytes_consumed; + len_ -= bytes_consumed; + num_values_ -= max_values; + return max_values; + } + }; +} +#endif // OMNI_RUNTIME_ENCODING_H diff --git a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetReader.cpp b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetReader.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8d4d6a8a489ced0f8328f334815e0e5078fec164 --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetReader.cpp @@ -0,0 +1,245 @@ +/** + * Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "jni/jni_common.h" +#include "ParquetReader.h" +#include "common/UriInfo.h" +#include "arrowadapter/FileSystemAdapter.h" + +using namespace arrow; +using namespace arrow::internal; +using namespace parquet::arrow; +using namespace omniruntime::reader; + +static std::mutex mutex_; +static std::unordered_map restore_filesysptr; +static constexpr int32_t LOCAL_FILE_PREFIX = 5; +static constexpr int32_t LOCAL_FILE_PREFIX_EXT = 7; +static const std::string LOCAL_FILE = "file:"; +static const std::string HDFS_FILE = "hdfs:"; + +// the ugi is UserGroupInformation +std::string omniruntime::reader::GetFileSystemKey(std::string& path, std::string& ugi) +{ + // if the local file, all the files are the same key "file:" + std::string result = ugi; + + // if the hdfs file, only get the ip and port just like the ugi + ip + port as key + if (path.substr(0, LOCAL_FILE_PREFIX) == HDFS_FILE) { + auto end = path.find("/", LOCAL_FILE_PREFIX_EXT); + std::string ip_and_port = path.substr(LOCAL_FILE_PREFIX_EXT, end - LOCAL_FILE_PREFIX_EXT); + result += ip_and_port; + return result; + } + + // if the local file, get the ugi + "file" as the key + if (path.substr(0, LOCAL_FILE_PREFIX) == LOCAL_FILE) { + // process the path "file://" head, the arrow could not read the head + path = path.substr(LOCAL_FILE_PREFIX); + result += "file:"; + return result; + } + + // if not the local, not the hdfs, get the ugi + path as the key + result += path; + return result; +} + +Filesystem* omniruntime::reader::GetFileSystemPtr(UriInfo &uri, std::string& ugi, arrow::Status &status) +{ + std::string fullPath = uri.ToString(); + auto key = GetFileSystemKey(fullPath, ugi); + + // if not find key, create the filesystem ptr + auto iter = restore_filesysptr.find(key); + if (iter == restore_filesysptr.end()) { + Filesystem* fs = new Filesystem(); + auto result = arrow_adapter::FileSystemFromUriOrPath(uri); + status = result.status(); + if (!status.ok()) { + return nullptr; + } + fs->filesys_ptr = std::move(result).ValueUnsafe(); + restore_filesysptr[key] = fs; + } + + return restore_filesysptr[key]; +} + +Status ParquetReader::InitRecordReader(UriInfo &uri, int64_t capacity, + const std::vector& row_group_indices, const std::vector& column_indices, + std::string& ugi) +{ + // Configure reader settings + auto reader_properties = parquet::ReaderProperties(pool); + + // Configure Arrow-specific reader settings + auto arrow_reader_properties = parquet::ArrowReaderProperties(); + arrow_reader_properties.set_batch_size(capacity); + + std::shared_ptr file; + + // Get the file from filesystem + Status result; + mutex_.lock(); + Filesystem* fs = GetFileSystemPtr(uri, ugi, result); + mutex_.unlock(); + if (fs == nullptr || fs->filesys_ptr == nullptr) { + return Status::IOError(result); + } + std::string path = uri.ToString(); + ARROW_ASSIGN_OR_RAISE(file, fs->filesys_ptr->OpenInputFile(path)); + + FileReaderBuilder reader_builder; + ARROW_RETURN_NOT_OK(reader_builder.Open(file, reader_properties)); + reader_builder.memory_pool(pool); + reader_builder.properties(arrow_reader_properties); + + 
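+    // Build() below materializes the parquet::arrow::FileReader that backs this
+    // ParquetReader. A minimal caller-side sketch (an assumption for
+    // illustration only, with a UriInfo and ugi already produced on the JNI
+    // side and int index vectors) would look roughly like:
+    //
+    //   ParquetReader reader;
+    //   std::vector<int> row_groups{0};
+    //   std::vector<int> columns{0, 1};
+    //   ARROW_RETURN_NOT_OK(reader.InitRecordReader(uri, 4096, row_groups, columns, ugi));
+    //   std::vector<omniruntime::vec::BaseVector*> batch(columns.size());
+    //   long batch_rows = 0;
+    //   ARROW_RETURN_NOT_OK(reader.ReadNextBatch(batch, &batch_rows));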
ARROW_ASSIGN_OR_RAISE(arrow_reader, reader_builder.Build()); + ARROW_RETURN_NOT_OK(GetRecordBatchReader(row_group_indices, column_indices)); + return arrow::Status::OK(); +} + +Status ParquetReader::ReadNextBatch(std::vector &batch, long *batchRowSize) +{ + ARROW_RETURN_NOT_OK(rb_reader->ReadNext(batch, batchRowSize)); + return arrow::Status::OK(); +} + +Status ParquetReader::GetRecordBatchReader(const std::vector &row_group_indices, + const std::vector &column_indices) +{ + std::shared_ptr<::arrow::Schema> batch_schema; + RETURN_NOT_OK(GetFieldReaders(row_group_indices, column_indices, &columnReaders, &batch_schema)); + + int64_t num_rows = 0; + for(int row_group : row_group_indices) { + num_rows += arrow_reader->parquet_reader()->metadata()->RowGroup(row_group)->num_rows(); + } + // Use lambda function to generate BaseVectors + auto batches = [num_rows, this](std::vector &batch, + long *batchRowSize) mutable -> Status { + int64_t read_size = std::min(arrow_reader->properties().batch_size(), num_rows); + num_rows -= read_size; + *batchRowSize = read_size; + + if (columnReaders.empty() || read_size <= 0) { + return Status::OK(); + } + + try { + for (uint64_t i = 0; i < columnReaders.size(); i++) { + RETURN_NOT_OK(columnReaders[i]->NextBatch(read_size, &batch[i])); + } + } catch (const std::exception &e) { + return Status::Invalid(e.what()); + } + + // Check BaseVector + for (const auto& column : batch) { + if (column == nullptr) { + return Status::Invalid("BaseVector should not be nullptr after reading"); + } + } + + return Status::OK(); + }; + + rb_reader = std::make_unique(std::move(batches)); + return Status::OK(); +} + +std::shared_ptr> VectorToSharedSet(const std::vector &values) { + std::shared_ptr> result(new std::unordered_set()); + result->insert(values.begin(), values.end()); + return result; +} + +Status ParquetReader::GetFieldReaders(const std::vector &row_group_indices, const std::vector &column_indices, + std::vector>* out, std::shared_ptr<::arrow::Schema>* out_schema) +{ + // We only read schema fields which have columns indicated in the indices vector + ARROW_ASSIGN_OR_RAISE(std::vector field_indices, arrow_reader->manifest().GetFieldIndices(column_indices)); + auto included_leaves = VectorToSharedSet(column_indices); + out->resize(field_indices.size()); + ::arrow::FieldVector out_fields(field_indices.size()); + + for (size_t i = 0; i < out->size(); i++) { + std::unique_ptr reader; + RETURN_NOT_OK(GetFieldReader(field_indices[i], included_leaves, row_group_indices, &reader)); + out_fields[i] = reader->field(); + out->at(i) = std::move(reader); + } + + *out_schema = ::arrow::schema(std::move(out_fields), arrow_reader->manifest().schema_metadata); + return Status::OK(); +} + +FileColumnIteratorFactory SomeRowGroupsFactory(std::vector row_group_indices) { + return [row_group_indices] (int i, parquet::ParquetFileReader* reader) { + return new FileColumnIterator(i, reader, row_group_indices); + }; +} + +Status ParquetReader::GetFieldReader(int i, const std::shared_ptr>& included_leaves, + const std::vector &row_group_indices, std::unique_ptr* out) +{ + if (ARROW_PREDICT_FALSE(i < 0 || static_cast(i) >= arrow_reader->manifest().schema_fields.size())) { + return Status::Invalid("Column index out of bounds (got ", i, + ", should be between 0 and ", arrow_reader->manifest().schema_fields.size(), ")"); + } + auto ctx = std::make_shared(); + ctx->reader = arrow_reader->parquet_reader(); + ctx->pool = pool; + ctx->iterator_factory = SomeRowGroupsFactory(row_group_indices); + 
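+    // The ReaderContext carries the projection: iterator_factory limits page
+    // iteration to the requested row groups, while filter_leaves together with
+    // included_leaves (set just below) prunes every leaf column that was not
+    // requested, so only projected columns are ever decoded.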
ctx->filter_leaves = true; + ctx->included_leaves = included_leaves; + auto field = arrow_reader->manifest().schema_fields[i]; + return GetReader(field, field.field, ctx, out); +} + +Status ParquetReader::GetReader(const SchemaField &field, const std::shared_ptr &arrow_field, + const std::shared_ptr &ctx, std::unique_ptr *out) +{ + BEGIN_PARQUET_CATCH_EXCEPTIONS + + auto type_id = arrow_field->type()->id(); + + if (type_id == ::arrow::Type::EXTENSION) { + return Status::Invalid("Unsupported type: ", arrow_field->ToString()); + } + + if (field.children.size() == 0) { + if (!field.is_leaf()) { + return Status::Invalid("Parquet non-leaf node has no children"); + } + if (!ctx->IncludesLeaf(field.column_index)) { + *out = nullptr; + return Status::OK(); + } + std::unique_ptr input(ctx->iterator_factory(field.column_index, ctx->reader)); + *out = std::make_unique(ctx, arrow_field, std::move(input), field.level_info); + } else { + return Status::Invalid("Unsupported type: ", arrow_field->ToString()); + } + return Status::OK(); + + END_PARQUET_CATCH_EXCEPTIONS +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/tablescan/ParquetReader.h b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetReader.h similarity index 35% rename from omnioperator/omniop-spark-extension/cpp/src/tablescan/ParquetReader.h rename to omnioperator/omniop-native-reader/cpp/src/parquet/ParquetReader.h index 9a55d785ca9a3e926cea28ebea392e7e73680da7..a0f475e5aca8a06e60cca4d18c99090eb9b24a77 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/tablescan/ParquetReader.h +++ b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetReader.h @@ -1,5 +1,5 @@ /** - * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information @@ -17,40 +17,63 @@ * limitations under the License. 
*/ -#ifndef SPARK_THESTRAL_PLUGIN_PARQUETREADER_H -#define SPARK_THESTRAL_PLUGIN_PARQUETREADER_H +#ifndef OMNI_RUNTIME_PARQUETREADER_H +#define OMNI_RUNTIME_PARQUETREADER_H -#include -#include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace spark::reader { +#include +#include "ParquetColumnReader.h" +#include "common/UriInfo.h" + +using namespace arrow::internal; + +namespace omniruntime::reader { + + class OmniRecordBatchReader { + public: + OmniRecordBatchReader(std::function &batch, long *batchRowSize)> batches) + : batches_(std::move(batches)) {} + + ~OmniRecordBatchReader() {} + + Status ReadNext(std::vector &out, long *batchRowSize) { + return batches_(out, batchRowSize); + } + + private: + std::function &batch, long *batchRowSize)> batches_; + }; + + class ParquetReader { public: ParquetReader() {} - arrow::Status InitRecordReader(std::string& path, int64_t capacity, - const std::vector& row_group_indices, const std::vector& column_indices, std::string& ugi, - ObsConfig& obsInfo); + arrow::Status InitRecordReader(UriInfo &uri, int64_t capacity, + const std::vector& row_group_indices, const std::vector& column_indices, std::string& ugi); - arrow::Status ReadNextBatch(std::shared_ptr *batch); + arrow::Status ReadNextBatch(std::vector &batch, long *batchRowSize); std::unique_ptr arrow_reader; - std::shared_ptr rb_reader; + std::unique_ptr rb_reader; + + std::vector> columnReaders; + + arrow::MemoryPool* pool = arrow::default_memory_pool(); + + private: + arrow::Status GetRecordBatchReader(const std::vector &row_group_indices, const std::vector &column_indices); + + arrow::Status GetFieldReaders(const std::vector &row_group_indices, const std::vector &column_indices, + std::vector>* out, std::shared_ptr<::arrow::Schema>* out_schema); + + arrow::Status GetFieldReader(int i, const std::shared_ptr>& included_leaves, + const std::vector &row_group_indices, std::unique_ptr* out); + + arrow::Status GetReader(const parquet::arrow::SchemaField &field, const std::shared_ptr &arrow_field, + const std::shared_ptr &ctx, std::unique_ptr* out); + }; class Filesystem { @@ -65,11 +88,6 @@ namespace spark::reader { std::string GetFileSystemKey(std::string& path, std::string& ugi); - Filesystem* GetFileSystemPtr(std::string& path, std::string& ugi); - - int CopyToOmniVec(std::shared_ptr vcType, int &omniTypeId, uint64_t &omniVecId, - std::shared_ptr array); - - std::pair TransferToOmniVecs(std::shared_ptr batch); + Filesystem* GetFileSystemPtr(UriInfo &uri, std::string& ugi, arrow::Status &status); } -#endif // SPARK_THESTRAL_PLUGIN_PARQUETREADER_H \ No newline at end of file +#endif // OMNI_RUNTIME_PARQUETREADER_H \ No newline at end of file diff --git a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetTypedRecordReader.cpp b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetTypedRecordReader.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6251044a85da44926b049f313bef92813e69552a --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetTypedRecordReader.cpp @@ -0,0 +1,503 @@ +/** + * Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "ParquetTypedRecordReader.h" +#include "ParquetDecoder.h" + +using namespace parquet::internal; +using namespace arrow; +using namespace parquet; + +namespace omniruntime::reader { + +constexpr int32_t DECIMAL64_LEN = 8; + +::parquet::Decoder* MakeOmniParquetDecoder(::parquet::Type::type type_num, ::parquet::Encoding::type encoding, + const ColumnDescriptor* descr) { + if (encoding == ::parquet::Encoding::PLAIN) { + switch (type_num) { + case ::parquet::Type::BOOLEAN: + return new ParquetPlainBooleanDecoder(descr); + case ::parquet::Type::INT32: + return new ParquetPlainDecoder<::parquet::Int32Type>(descr); + case ::parquet::Type::INT64: + return new ParquetPlainDecoder<::parquet::Int64Type>(descr); + case ::parquet::Type::DOUBLE: + return new ParquetPlainDecoder<::parquet::DoubleType>(descr); + case ::parquet::Type::BYTE_ARRAY: + return new ParquetPlainByteArrayDecoder(descr); + case ::parquet::Type::FIXED_LEN_BYTE_ARRAY: + return new ParquetPlainFLBADecoder(descr); + default: + ::parquet::ParquetException::NYI("Not supported decoder type: " + type_num); + } + } else if (encoding == ::parquet::Encoding::RLE) { + if (type_num == ::parquet::Type::BOOLEAN) { + return new ParquetRleBooleanDecoder(descr); + } + ::parquet::ParquetException::NYI("RLE encoding only supports BOOLEAN"); + } else { + ::parquet::ParquetException::NYI("Selected encoding is not supported"); + } + DCHECK(false) << "Should not be able to reach this code"; + return nullptr; +} + + +::parquet::Decoder* MakeOmniDictDecoder(::parquet::Type::type type_num, + const ColumnDescriptor* descr, ::arrow::MemoryPool* pool) { + switch (type_num) { + case ::parquet::Type::BOOLEAN: + ::parquet::ParquetException::NYI("Dictionary BOOLEAN encoding not implemented for boolean type"); + case ::parquet::Type::INT32: + return new ParquetDictDecoderImpl<::parquet::Int32Type>(descr, pool); + case ::parquet::Type::INT64: + return new ParquetDictDecoderImpl<::parquet::Int64Type>(descr, pool); + case ::parquet::Type::DOUBLE: + return new ParquetDictDecoderImpl<::parquet::DoubleType>(descr, pool); + case ::parquet::Type::BYTE_ARRAY: + return new OmniDictByteArrayDecoderImpl(descr, pool); + case ::parquet::Type::FIXED_LEN_BYTE_ARRAY: + return new ParquetDictDecoderImpl<::parquet::FLBAType>(descr, pool); + default: + ::parquet::ParquetException::NYI("Not supported dictionary decoder type: " + type_num); + } + DCHECK(false) << "Should not be able to reach this code"; + return nullptr; +} + +template +std::unique_ptr> MakeParquetDictDecoder( + const ColumnDescriptor* descr = NULLPTR, + ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()) { + using OutType = ParquetDictDecoder; + auto decoder = MakeOmniDictDecoder(DType::type_num, descr, pool); + return std::unique_ptr(dynamic_cast(decoder)); +} + +template +std::unique_ptr> MakeParquetTypedDecoder( + ::parquet::Encoding::type encoding, const ColumnDescriptor* descr = NULLPTR) { + using OutType 
= ParquetTypedDecoder; + auto base = MakeOmniParquetDecoder(DType::type_num, encoding, descr); + return std::unique_ptr(dynamic_cast(base)); +} + +// Advance to the next data page +template +bool ParquetColumnReaderBase::ReadNewPage() { + // Loop until we find the next data page. + while (true) { + current_page_ = pager_->NextPage(); + if (!current_page_) { + // EOS + return false; + } + + if (current_page_->type() == PageType::DICTIONARY_PAGE) { + ConfigureDictionary(static_cast(current_page_.get())); + continue; + } else if (current_page_->type() == PageType::DATA_PAGE) { + const auto page = std::static_pointer_cast(current_page_); + const int64_t levels_byte_size = InitializeLevelDecoders( + *page, page->repetition_level_encoding(), page->definition_level_encoding()); + InitializeDataDecoder(*page, levels_byte_size); + return true; + } else if (current_page_->type() == PageType::DATA_PAGE_V2) { + const auto page = std::static_pointer_cast(current_page_); + int64_t levels_byte_size = InitializeLevelDecodersV2(*page); + InitializeDataDecoder(*page, levels_byte_size); + return true; + } else { + // We don't know what this page type is. We're allowed to skip non-data + // pages. + continue; + } + } + return true; +} + +template +void ParquetColumnReaderBase::ConfigureDictionary(const DictionaryPage* page) { + int encoding = static_cast(page->encoding()); + if (page->encoding() == ::parquet::Encoding::PLAIN_DICTIONARY || + page->encoding() == ::parquet::Encoding::PLAIN) { + encoding = static_cast(::parquet::Encoding::RLE_DICTIONARY); + } + + auto it = decoders_.find(encoding); + if (it != decoders_.end()) { + throw ParquetException("Column cannot have more than one dictionary."); + } + + if (page->encoding() == ::parquet::Encoding::PLAIN_DICTIONARY || + page->encoding() == ::parquet::Encoding::PLAIN) { + auto dictionary = MakeParquetTypedDecoder(::parquet::Encoding::PLAIN, descr_); + dictionary->SetData(page->num_values(), page->data(), page->size()); + + // The dictionary is fully decoded during DictionaryDecoder::Init, so the + // DictionaryPage buffer is no longer required after this step + std::unique_ptr> decoder = MakeParquetDictDecoder(descr_, pool_); + decoder->SetDict(dynamic_cast(dictionary.get())); + decoders_[encoding] = + std::unique_ptr(dynamic_cast(decoder.release())); + } else { + ParquetException::NYI("only plain dictionary encoding has been implemented"); + } + + new_dictionary_ = true; + current_decoder_ = decoders_[encoding].get(); + DCHECK(current_decoder_); +} + +// Initialize repetition and definition level decoders on the next data page. + +// If the data page includes repetition and definition levels, we +// initialize the level decoders and return the number of encoded level bytes. +// The return value helps determine the number of bytes in the encoded data. +template +int64_t ParquetColumnReaderBase::InitializeLevelDecoders(const DataPage& page, + ::parquet::Encoding::type repetition_level_encoding, + ::parquet::Encoding::type definition_level_encoding) { + // Read a data page. + num_buffered_values_ = page.num_values(); + + // Have not decoded any values from the data page yet + num_decoded_values_ = 0; + + const uint8_t* buffer = page.data(); + int32_t levels_byte_size = 0; + int32_t max_size = page.size(); + + // Data page Layout: Repetition Levels - Definition Levels - encoded values. + // Levels are encoded as rle or bit-packed. 
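+  // For example, a plain nullable leaf has max_rep_level_ == 0 and
+  // max_def_level_ == 1, so only the definition-level branch below consumes
+  // bytes and levels_byte_size ends up equal to the size of that single
+  // RLE/bit-packed run.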
+ // Init repetition levels + if (max_rep_level_ > 0) { + int32_t rep_levels_bytes = repetition_level_decoder_.SetData( + repetition_level_encoding, max_rep_level_, + static_cast(num_buffered_values_), buffer, max_size); + buffer += rep_levels_bytes; + levels_byte_size += rep_levels_bytes; + max_size -= rep_levels_bytes; + } + + // Init definition levels + if (max_def_level_ > 0) { + int32_t def_levels_bytes = definition_level_decoder_.SetData( + definition_level_encoding, max_def_level_, + static_cast(num_buffered_values_), buffer, max_size); + levels_byte_size += def_levels_bytes; + max_size -= def_levels_bytes; + } + + return levels_byte_size; +} + +template +int64_t ParquetColumnReaderBase::InitializeLevelDecodersV2(const ::parquet::DataPageV2& page) { + // Read a data page. + num_buffered_values_ = page.num_values(); + + // Have not decoded any values from the data page yet + num_decoded_values_ = 0; + const uint8_t* buffer = page.data(); + + const int64_t total_levels_length = + static_cast(page.repetition_levels_byte_length()) + + page.definition_levels_byte_length(); + + if (total_levels_length > page.size()) { + throw ParquetException("Data page too small for levels (corrupt header?)"); + } + + if (max_rep_level_ > 0) { + repetition_level_decoder_.SetDataV2(page.repetition_levels_byte_length(), + max_rep_level_, static_cast(num_buffered_values_), buffer); + } + // ARROW-17453: Even if max_rep_level_ is 0, there may still be + // repetition level bytes written and/or reported in the header by + // some writers (e.g. Athena) + buffer += page.repetition_levels_byte_length(); + + if (max_def_level_ > 0) { + definition_level_decoder_.SetDataV2(page.definition_levels_byte_length(), + max_def_level_, static_cast(num_buffered_values_), buffer); + } + + return total_levels_length; +} + +static bool IsDictionaryIndexEncoding(const ::parquet::Encoding::type& e) { + return e == ::parquet::Encoding::RLE_DICTIONARY || e == ::parquet::Encoding::PLAIN_DICTIONARY; +} + +// Get a decoder object for this page or create a new decoder if this is the +// first page with this encoding. 
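
InitializeDataDecoder, which follows, implements that comment as a small per-column cache: decoders are keyed by their encoding, built lazily on the first page that uses the encoding, and reused for later pages, while dictionary decoders are dropped again when the row group changes. A stripped-down sketch of the same pattern, with a placeholder decoder type standing in for ParquetTypedDecoder:

    #include <functional>
    #include <memory>
    #include <unordered_map>

    struct FakeDecoder {              // stand-in for ParquetTypedDecoder<DType>
        virtual ~FakeDecoder() = default;
    };

    class EncodingKeyedCache {
    public:
        using Factory = std::function<std::unique_ptr<FakeDecoder>()>;

        // Look up the decoder for `encoding`; create and cache it on first use.
        FakeDecoder* GetOrCreate(int encoding, const Factory& make) {
            auto it = cache_.find(encoding);
            if (it != cache_.end()) {
                return it->second.get();
            }
            return cache_.emplace(encoding, make()).first->second.get();
        }

        // Analogue of ResetDecoders(): dictionary decoders are per row group.
        void Clear() { cache_.clear(); }

    private:
        std::unordered_map<int, std::unique_ptr<FakeDecoder>> cache_;
    };

The real code additionally rejects a second dictionary page for the same column, which is why ConfigureDictionary throws if a decoder is already registered under the RLE_DICTIONARY key.
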
+template +void ParquetColumnReaderBase::InitializeDataDecoder(const DataPage& page, int64_t levels_byte_size) { + const uint8_t* buffer = page.data() + levels_byte_size; + const int64_t data_size = page.size() - levels_byte_size; + + if (data_size < 0) { + throw ParquetException("Page smaller than size of encoded levels"); + } + + ::parquet::Encoding::type encoding = page.encoding(); + + if (IsDictionaryIndexEncoding(encoding)) { + encoding = ::parquet::Encoding::RLE_DICTIONARY; + } + + auto it = decoders_.find(static_cast(encoding)); + if (it != decoders_.end()) { + DCHECK(it->second.get() != nullptr); + current_decoder_ = it->second.get(); + } else { + switch (encoding) { + case ::parquet::Encoding::PLAIN: { + auto decoder = MakeParquetTypedDecoder(::parquet::Encoding::PLAIN, descr_); + current_decoder_ = decoder.get(); + decoders_[static_cast(encoding)] = std::move(decoder); + break; + } + case ::parquet::Encoding::RLE: { + auto decoder = MakeParquetTypedDecoder(::parquet::Encoding::PLAIN, descr_); + current_decoder_ = decoder.get(); + decoders_[static_cast(encoding)] = std::move(decoder); + break; + } + case ::parquet::Encoding::RLE_DICTIONARY: + case ::parquet::Encoding::BYTE_STREAM_SPLIT: + case ::parquet::Encoding::DELTA_BINARY_PACKED: + case ::parquet::Encoding::DELTA_BYTE_ARRAY: + case ::parquet::Encoding::DELTA_LENGTH_BYTE_ARRAY: + default: + throw ParquetException("Unknown encoding type."); + } + } + current_encoding_ = encoding; + current_decoding_type = DType::type_num; + current_decoder_->SetData(static_cast(num_buffered_values_), buffer,static_cast(data_size)); +} + +std::shared_ptr MakeByteArrayRecordReader(const ColumnDescriptor* descr, + LevelInfo leaf_info, + ::arrow::MemoryPool* pool, + bool read_dictionary) { + if (read_dictionary) { + std::stringstream ss; + ss << "Invalid ParquetByteArrayDictionary is not implement yet " << static_cast(descr->physical_type()); + throw ParquetException(ss.str()); + } else { + return std::make_shared(descr, leaf_info, pool); + } +} + +std::shared_ptr MakeRecordReader(const ColumnDescriptor* descr, + LevelInfo leaf_info, ::arrow::MemoryPool* pool, + bool read_dictionary, + const std::shared_ptr<::arrow::DataType>& type) { + switch (type->id()) { + case ::arrow::Type::BOOL: { + return std::make_shared>(descr, + leaf_info, pool); + } + case ::arrow::Type::INT16: { + return std::make_shared(descr, leaf_info, pool); + } + case ::arrow::Type::INT32: { + return std::make_shared>(descr, leaf_info, pool); + } + case ::arrow::Type::DATE32: { + return std::make_shared>(descr, + leaf_info, pool); + } + case ::arrow::Type::INT64: { + return std::make_shared>(descr, leaf_info, pool); + } + case ::arrow::Type::DATE64: { + return std::make_shared>(descr, + leaf_info, pool); + } + case ::arrow::Type::DOUBLE: { + return std::make_shared>(descr, + leaf_info, pool); + } + case ::arrow::Type::STRING: { + return MakeByteArrayRecordReader(descr, leaf_info, pool, read_dictionary); + } + case ::arrow::Type::DECIMAL: { + switch (descr->physical_type()) { + case ::parquet::Type::INT32: + return std::make_shared(descr, leaf_info, pool); + case ::parquet::Type::INT64: + return std::make_shared>(descr, leaf_info, pool); + case ::parquet::Type::FIXED_LEN_BYTE_ARRAY: { + int32_t precision = ::arrow::internal::checked_cast(*type).precision(); + if (precision > PARQUET_MAX_DECIMAL64_DIGITS) { + return std::make_shared(descr, leaf_info, pool); + } else { + return std::make_shared(descr, leaf_info, pool); + } + } + default: + std::stringstream ss; + ss << "RecordReader 
not support decimal type " << static_cast(descr->physical_type()); + throw ParquetException(ss.str()); + } + } + default: { + // PARQUET-1481: This can occur if the file is corrupt + std::stringstream ss; + ss << "Invalid physical column type: " << static_cast(descr->physical_type()); + throw ParquetException(ss.str()); + } + } + // Unreachable code, but suppress compiler warning + return nullptr; +} + +// Helper function used by Decimal128::FromBigEndian +static inline uint64_t UInt64FromBigEndian(const uint8_t* bytes, int32_t length) { + // We don't bounds check the length here because this is called by + // FromBigEndian that has a Decimal128 as its out parameters and + // that function is already checking the length of the bytes and only + // passes lengths between zero and eight. + uint64_t result = 0; + // Using memcpy instead of special casing for length + // and doing the conversion in 16, 32 parts, which could + // possibly create unaligned memory access on certain platforms + memcpy_s(reinterpret_cast(&result) + 8 - length, length, bytes, length); + return ::arrow::bit_util::FromBigEndian(result); +} + +static inline Result FromBigEndianToOmniDecimal128(const uint8_t* bytes, int32_t length) { + static constexpr int32_t kMinDecimalBytes = 1; + static constexpr int32_t kMaxDecimalBytes = 16; + + int64_t high, low; + + if (ARROW_PREDICT_FALSE(length < kMinDecimalBytes || length > kMaxDecimalBytes)) { + return Status::Invalid("Length of byte array passed to Decimal128::FromBigEndian ", + "was ", length, ", but must be between ", kMinDecimalBytes, + " and ", kMaxDecimalBytes); + } + + // Bytes are coming in big-endian, so the first byte is the MSB and therefore holds the + // sign bit. + const bool is_negative = static_cast(bytes[0]) < 0; + + // 1. Extract the high bytes + // Stop byte of the high bytes + const int32_t high_bits_offset = std::max(0, length - DECIMAL64_LEN); + const auto high_bits = UInt64FromBigEndian(bytes, high_bits_offset); + + if (high_bits_offset == DECIMAL64_LEN) { + // Avoid undefined shift by 64 below + high = high_bits; + } else { + high = -1 * (is_negative && length < kMaxDecimalBytes); + // Shift left enough bits to make room for the incoming int64_t + high = SafeLeftShift(high, high_bits_offset * CHAR_BIT); + // Preserve the upper bits by inplace OR-ing the int64_t + high |= high_bits; + } + + // 2. 
Extract the low bytes + // Stop byte of the low bytes + const int32_t low_bits_offset = std::min(length, DECIMAL64_LEN); + const auto low_bits = + UInt64FromBigEndian(bytes + high_bits_offset, length - high_bits_offset); + + if (low_bits_offset == DECIMAL64_LEN) { + // Avoid undefined shift by 64 below + low = low_bits; + } else { + // Sign extend the low bits if necessary + low = -1 * (is_negative && length < DECIMAL64_LEN); + // Shift left enough bits to make room for the incoming int64_t + low = SafeLeftShift(low, low_bits_offset * CHAR_BIT); + // Preserve the upper bits by inplace OR-ing the int64_t + low |= low_bits; + } + + __int128_t temp_high = high; + temp_high = temp_high << (8 * CHAR_BIT); + __int128_t val = temp_high | static_cast(low); + + return omniruntime::type::Decimal128(val); +} + +Status RawBytesToDecimal128Bytes(const uint8_t* bytes, int32_t length, + omniruntime::vec::BaseVector** out_buf, int64_t index) { + auto out = static_cast*>(*out_buf); + ARROW_ASSIGN_OR_RAISE(auto t, FromBigEndianToOmniDecimal128(bytes, length)); + out->SetValue(index, t); + return Status::OK(); +} + +Status RawBytesToDecimal64Bytes(const uint8_t* bytes, int32_t length, + omniruntime::vec::BaseVector** out_buf, int64_t index) { + auto out = static_cast*>(*out_buf); + + // Directly Extract the low bytes + // Stop byte of the low bytes + int64_t low = 0; + const bool is_negative = static_cast(bytes[0]) < 0; + const int32_t low_bits_offset = std::min(length, DECIMAL64_LEN); + auto low_bits = UInt64FromBigEndian(bytes, low_bits_offset); + + if (low_bits_offset == DECIMAL64_LEN) { + // Avoid undefined shift by 64 below + low = low_bits; + } else { + // Sign extend the low bits if necessary + low = -1 * (is_negative && length < DECIMAL64_LEN); + // Shift left enough bits to make room for the incoming int64_t + low = SafeLeftShift(low, low_bits_offset * CHAR_BIT); + // Preserve the upper bits by inplace OR-ing the int64_t + low |= low_bits; + } + + out->SetValue(index, low); + return Status::OK(); +} + +void DefLevelsToNullsSIMD(const int16_t* def_levels, int64_t num_def_levels, const int16_t max_def_level, + int64_t* values_read, int64_t* null_count, bool* nulls) { + for (int i = 0; i < num_def_levels; ++i) { + if (def_levels[i] < max_def_level) { + nulls[i] = true; + (*null_count)++; + } + } + *values_read = num_def_levels; +} + +void DefLevelsToNulls(const int16_t* def_levels, int64_t num_def_levels, LevelInfo level_info, + int64_t* values_read, int64_t* null_count, bool* nulls) { + if (level_info.rep_level == 0) { + DefLevelsToNullsSIMD(def_levels, num_def_levels, level_info.def_level, values_read, null_count, nulls); + } else { + ::ParquetException::NYI("rep_level > 0 NYI"); + } +} + +} \ No newline at end of file diff --git a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetTypedRecordReader.h b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetTypedRecordReader.h new file mode 100644 index 0000000000000000000000000000000000000000..3f602c979d71e76d2995e99c15f73355c956a98c --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetTypedRecordReader.h @@ -0,0 +1,847 @@ +/** + * Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +#ifndef OMNI_RUNTIME_COLUMN_TYPE_READER_H +#define OMNI_RUNTIME_COLUMN_TYPE_READER_H + +#include "ParquetDecoder.h" +#include +#include +#include + +using ResizableBuffer = ::arrow::ResizableBuffer; +using namespace omniruntime::vec; + +namespace omniruntime::reader { + constexpr int64_t kMinLevelBatchSize = 1024; + static constexpr int32_t PARQUET_MAX_DECIMAL64_DIGITS = 18; + + inline void CheckNumberDecoded(int64_t number_decoded, int64_t expected) { + if (ARROW_PREDICT_FALSE(number_decoded != expected)) { + ::parquet::ParquetException::EofException("Decoded values " + std::to_string(number_decoded) + + " does not match expected" + std::to_string(expected)); + } + } + + template + SignedInt SafeLeftShift(SignedInt u, Shift shift) { + using UnsignedInt = typename std::make_unsigned::type; + return static_cast(static_cast(u) << shift); + } + + ::arrow::Status RawBytesToDecimal128Bytes(const uint8_t* bytes, int32_t length, BaseVector** out_buf, int64_t index); + + ::arrow::Status RawBytesToDecimal64Bytes(const uint8_t* bytes, int32_t length, BaseVector** out_buf, int64_t index); + + void DefLevelsToNulls(const int16_t* def_levels, int64_t num_def_levels, ::parquet::internal::LevelInfo level_info, + int64_t* values_read, int64_t* null_count, bool* nulls); + + template + class ParquetColumnReaderBase { + public: + using T = typename DType::c_type; + + ParquetColumnReaderBase(const ::parquet::ColumnDescriptor* descr, ::arrow::MemoryPool* pool) + : descr_(descr), + max_def_level_(descr->max_definition_level()), + max_rep_level_(descr->max_repetition_level()), + num_buffered_values_(0), + num_decoded_values_(0), + pool_(pool), + current_decoder_(nullptr), + current_encoding_(::parquet::Encoding::UNKNOWN) {} + + virtual ~ParquetColumnReaderBase() = default; + + protected: + int64_t ReadDefinitionLevels(int64_t batch_size, int16_t* levels) { + if (max_def_level_ == 0) { + return 0; + } + return definition_level_decoder_.Decode(static_cast(batch_size), levels); + } + + bool HasNextInternal() { + if (num_buffered_values_ == 0 || num_decoded_values_ == num_buffered_values_) { + if (!ReadNewPage() || num_buffered_values_ == 0) { + return false; + } + } + return true; + } + + int64_t ReadRepetitionLevels(int64_t batch_size, int16_t* levels) { + if (max_rep_level_ == 0) { + return 0; + } + return repetition_level_decoder_.Decode(static_cast(batch_size), levels); + } + + bool ReadNewPage(); + + void ConfigureDictionary(const ::parquet::DictionaryPage* page); + + int64_t InitializeLevelDecoders(const ::parquet::DataPage& page, + ::parquet::Encoding::type repetition_level_encoding, + ::parquet::Encoding::type definition_level_encoding); + + int64_t InitializeLevelDecodersV2(const ::parquet::DataPageV2& page); + + void InitializeDataDecoder(const ::parquet::DataPage& page, int64_t levels_byte_size); + + int64_t available_values_current_page() const { + return num_buffered_values_ - num_decoded_values_; + } + + const 
::parquet::ColumnDescriptor* descr_; + const int16_t max_def_level_; + const int16_t max_rep_level_; + + std::unique_ptr<::parquet::PageReader> pager_; + std::shared_ptr<::parquet::Page> current_page_; + + ::parquet::LevelDecoder definition_level_decoder_; + ::parquet::LevelDecoder repetition_level_decoder_; + + int64_t num_buffered_values_; + int64_t num_decoded_values_; + + ::arrow::MemoryPool* pool_; + + using DecoderType = ParquetTypedDecoder; + DecoderType* current_decoder_; + ::parquet::Encoding::type current_encoding_; + ::parquet::Type::type current_decoding_type; + + bool new_dictionary_ = false; + + std::unordered_map> decoders_; + + void ConsumeBufferedValues(int64_t num_values) { + num_decoded_values_ += num_values; + } + }; + + class OmniRecordReader { + public: + virtual ~OmniRecordReader() = default; + + /// \brief Attempt to read indicated number of records from column chunk + /// Note that for repeated fields, a record may have more than one value + /// and all of them are read. + virtual int64_t ReadRecords(int64_t num_records) = 0; + + /// \brief Attempt to skip indicated number of records from column chunk. + /// Note that for repeated fields, a record may have more than one value + /// and all of them are skipped. + /// \return number of records skipped + virtual int64_t SkipRecords(int64_t num_records) = 0; + + /// \brief Pre-allocate space for data. Results in better flat read performance + virtual void Reserve(int64_t num_values) = 0; + + /// \brief Clear consumed values and repetition/definition levels as the + /// result of calling ReadRecords + virtual void Reset() = 0; + + /// \brief Return true if the record reader has more internal data yet to + /// process + virtual bool HasMoreData() const = 0; + + /// \brief Advance record reader to the next row group. Must be set before + /// any records could be read/skipped. + /// \param[in] reader obtained from RowGroupReader::GetColumnPageReader + virtual void SetPageReader(std::unique_ptr reader) = 0; + + virtual BaseVector* GetBaseVec() = 0; + + /// \brief Decoded definition levels + int16_t* def_levels() const { + return reinterpret_cast(def_levels_->mutable_data()); + } + + /// \brief Decoded repetition levels + int16_t* rep_levels() const { + return reinterpret_cast(rep_levels_->mutable_data()); + } + + /// \brief Decoded values, including nulls, if any + /// FLBA and ByteArray types do not use this array and read into their own + /// builders. + uint8_t* values() const { return values_->mutable_data(); } + + /// \brief Number of values written, including space left for nulls if any. + /// If this Reader was constructed with read_dense_for_nullable(), there is no space for + /// nulls and null_count() will be 0. There is no read-ahead/buffering for values. For + /// FLBA and ByteArray types this value reflects the values written with the last + /// ReadRecords call since those readers will reset the values after each call. + int64_t values_written() const { return values_written_; } + + /// \brief Number of definition / repetition levels (from those that have + /// been decoded) that have been consumed inside the reader. + int64_t levels_position() const { return levels_position_; } + + /// \brief Number of definition / repetition levels that have been written + /// internally in the reader. This may be larger than values_written() because + /// for repeated fields we need to look at the levels in advance to figure out + /// the record boundaries. 
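
As these accessors spell out, null tracking is driven purely by definition levels: a value slot is null exactly when its def level is below the column's max def level, and values_written() counts real values and null slots alike. A minimal standalone version of that mapping (the same logic as DefLevelsToNullsSIMD in the .cpp, stripped of reader state and shown only as an illustration):

    #include <cstdint>
    #include <vector>

    // For a flat column (rep_level == 0): mark null slots and count them.
    // Returns the number of value slots written, nulls included.
    inline int64_t DefLevelsToNullsExample(const int16_t* def_levels, int64_t n,
                                           int16_t max_def_level,
                                           std::vector<bool>* nulls,
                                           int64_t* null_count) {
        nulls->assign(n, false);
        *null_count = 0;
        for (int64_t i = 0; i < n; ++i) {
            if (def_levels[i] < max_def_level) {
                (*nulls)[i] = true;
                ++(*null_count);
            }
        }
        return n;
    }
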
+ int64_t levels_written() const { return levels_written_; } + + /// \brief Number of nulls in the leaf that we have read so far into the + /// values vector. This is only valid when !read_dense_for_nullable(). When + /// read_dense_for_nullable() it will always be 0. + int64_t null_count() const { return null_count_; } + + /// \brief True if the leaf values are nullable + bool nullable_values() const { return nullable_values_; } + + /// \brief True if reading directly as Arrow dictionary-encoded + bool read_dictionary() const { return read_dictionary_; } + + + /// \brief Indicates if we can have nullable values. Note that repeated fields + /// may or may not be nullable. + bool nullable_values_; + + bool at_record_start_; + int64_t records_read_; + int64_t values_decode_; + + /// \brief Stores values. These values are populated based on each ReadRecords + /// call. No extra values are buffered for the next call. SkipRecords will not + /// add any value to this buffer. + std::shared_ptr values_; + /// \brief False for BYTE_ARRAY, in which case we don't allocate the values + /// buffer and we directly read into builder classes. + bool uses_values_; + + /// \brief Values that we have read into 'values_' + 'null_count_'. + int64_t values_written_; + int64_t values_capacity_; + int64_t null_count_; + + /// \brief Buffer for definition levels. May contain more levels than + /// is actually read. This is because we read levels ahead to + /// figure out record boundaries for repeated fields. + /// For flat required fields, 'def_levels_' and 'rep_levels_' are not + /// populated. For non-repeated fields 'rep_levels_' is not populated. + /// 'def_levels_' and 'rep_levels_' must be of the same size if present. + std::shared_ptr def_levels_; + /// \brief Buffer for repetition levels. Only populated for repeated + /// fields. + std::shared_ptr rep_levels_; + + /// \brief Number of definition / repetition levels that have been written + /// internally in the reader. This may be larger than values_written() since + /// for repeated fields we need to look at the levels in advance to figure out + /// the record boundaries. + int64_t levels_written_; + /// \brief Position of the next level that should be consumed. + int64_t levels_position_; + int64_t levels_capacity_; + + bool read_dictionary_ = false; + }; + + /** + * ParquetTypedRecordReader is used to generate omnivector directly from the def_level/rep_level/values. + * And we directly use omnivector's nulls to store each null value flag instead of bitmap to reduce extra cost. + * When setting omnivector's values, it can choose whether transferring values according to the TYPE_ID and DType. + * @tparam TYPE_ID omni type + * @tparam DType parquet store type + */ + template + class ParquetTypedRecordReader : public ParquetColumnReaderBase, virtual public OmniRecordReader { + public: + using T = typename DType::c_type; + using V = typename NativeType::type; + using BASE = ParquetColumnReaderBase; + + explicit ParquetTypedRecordReader(const ::parquet::ColumnDescriptor* descr, + ::parquet::internal::LevelInfo leaf_info, ::arrow::MemoryPool* pool) + // Pager must be set using SetPageReader. 
+ : BASE(descr, pool) { + leaf_info_ = leaf_info; + nullable_values_ = leaf_info.HasNullableValues(); + at_record_start_ = true; + values_written_ = 0; + null_count_ = 0; + values_capacity_ = 0; + levels_written_ = 0; + levels_position_ = 0; + levels_capacity_ = 0; + uses_values_ = !(descr->physical_type() == ::parquet::Type::BYTE_ARRAY); + byte_width_ = descr->type_length(); + values_decode_ = 0; + + if (uses_values_) { + values_ = ::parquet::AllocateBuffer(pool); + } + def_levels_ = ::parquet::AllocateBuffer(pool); + rep_levels_ = ::parquet::AllocateBuffer(pool); + Reset(); + } + + ~ParquetTypedRecordReader() { + if (parquet_vec_ != nullptr) { + delete[] parquet_vec_; + } + } + + // Compute the values capacity in bytes for the given number of elements + int64_t bytes_for_values(int64_t nitems) const { + int64_t type_size = GetTypeByteSize(this->descr_->physical_type()); + int64_t bytes_for_values = -1; + if (::arrow::internal::MultiplyWithOverflow(nitems, type_size, &bytes_for_values)) { + throw ::parquet::ParquetException("Total size of items too large"); + } + return bytes_for_values; + } + + int64_t ReadRecords(int64_t num_records) override { + if (num_records == 0) return 0; + // Delimit records, then read values at the end + int64_t records_read = 0; + + if (has_values_to_process()) { + records_read += ReadRecordData(num_records); + } + + int64_t level_batch_size = std::max(kMinLevelBatchSize, num_records); + + // If we are in the middle of a record, we continue until reaching the + // desired number of records or the end of the current record if we've found + // enough records + while (!at_record_start_ || records_read < num_records) { + // Is there more data to read in this row group? + if (!this->HasNextInternal()) { + if (!at_record_start_) { + // We ended the row group while inside a record that we haven't seen + // the end of yet. So increment the record count for the last record in + // the row group + ++records_read; + at_record_start_ = true; + } + break; + } + + /// We perform multiple batch reads until we either exhaust the row group + /// or observe the desired number of records + int64_t batch_size = + std::min(level_batch_size, this->available_values_current_page()); + + // No more data in column + if (batch_size == 0) { + break; + } + + if (this->max_def_level_ > 0) { + ReserveLevels(batch_size); + + int16_t* def_levels = this->def_levels() + levels_written_; + int16_t* rep_levels = this->rep_levels() + levels_written_; + + // Not present for non-repeated fields + int64_t levels_read = 0; + if (this->max_rep_level_ > 0) { + levels_read = this->ReadDefinitionLevels(batch_size, def_levels); + if (this->ReadRepetitionLevels(batch_size, rep_levels) != levels_read) { + throw ::parquet::ParquetException("Number of decoded rep / def levels did not match"); + } + } else if (this->max_def_level_ > 0) { + levels_read = this->ReadDefinitionLevels(batch_size, def_levels); + } + + // Exhausted column chunk + if (levels_read == 0) { + break; + } + + levels_written_ += levels_read; + records_read += ReadRecordData(num_records - records_read); + } else { + // No repetition or definition levels + batch_size = std::min(num_records - records_read, batch_size); + records_read += ReadRecordData(batch_size); + } + } + + return records_read; + } + + // Throw away levels from start_levels_position to levels_position_. + // Will update levels_position_, levels_written_, and levels_capacity_ + // accordingly and move the levels to left to fill in the gap. 
+ // It will resize the buffer without releasing the memory allocation. + void ThrowAwayLevels(int64_t start_levels_position) { + ARROW_DCHECK_LE(levels_position_, levels_written_); + ARROW_DCHECK_LE(start_levels_position, levels_position_); + ARROW_DCHECK_GT(this->max_def_level_, 0); + ARROW_DCHECK_NE(def_levels_, nullptr); + + int64_t gap = levels_position_ - start_levels_position; + if (gap == 0) return; + + int64_t levels_remaining = levels_written_ - gap; + + auto left_shift = [&](ResizableBuffer* buffer) { + int16_t* data = reinterpret_cast(buffer->mutable_data()); + std::copy(data + levels_position_, data + levels_written_, + data + start_levels_position); + PARQUET_THROW_NOT_OK(buffer->Resize(levels_remaining * sizeof(int16_t), + /*shrink_to_fit=*/false)); + }; + + left_shift(def_levels_.get()); + + if (this->max_rep_level_ > 0) { + ARROW_DCHECK_NE(rep_levels_, nullptr); + left_shift(rep_levels_.get()); + } + + levels_written_ -= gap; + levels_position_ -= gap; + levels_capacity_ -= gap; + } + + + int64_t SkipRecords(int64_t num_records) override { + throw ::parquet::ParquetException("SkipRecords not implemented yet"); + } + + // We may outwardly have the appearance of having exhausted a column chunk + // when in fact we are in the middle of processing the last batch + bool has_values_to_process() const { return levels_position_ < levels_written_; } + + // Process written repetition/definition levels to reach the end of + // records. Only used for repeated fields. + // Process no more levels than necessary to delimit the indicated + // number of logical records. Updates internal state of RecordReader + // + // \return Number of records delimited + int64_t DelimitRecords(int64_t num_records, int64_t* values_seen) { + int64_t values_to_read = 0; + int64_t records_read = 0; + + const int16_t* def_levels = this->def_levels() + levels_position_; + const int16_t* rep_levels = this->rep_levels() + levels_position_; + + DCHECK_GT(this->max_rep_level_, 0); + + // Count logical records and number of values to read + while (levels_position_ < levels_written_) { + const int16_t rep_level = *rep_levels++; + if (rep_level == 0) { + // If at_record_start_ is true, we are seeing the start of a record + // for the second time, such as after repeated calls to + // DelimitRecords. In this case we must continue until we find + // another record start or exhausting the ColumnChunk + if (!at_record_start_) { + // We've reached the end of a record; increment the record count. + ++records_read; + if (records_read == num_records) { + // We've found the number of records we were looking for. 
Set + // at_record_start_ to true and break + at_record_start_ = true; + break; + } + } + } + // We have decided to consume the level at this position; therefore we + // must advance until we find another record boundary + at_record_start_ = false; + + const int16_t def_level = *def_levels++; + if (def_level == this->max_def_level_) { + ++values_to_read; + } + ++levels_position_; + } + *values_seen = values_to_read; + return records_read; + } + + void Reserve(int64_t capacity) override { + ReserveLevels(capacity); + ReserveValues(capacity); + InitVec(capacity); + } + + virtual void InitVec(int64_t capacity) { + vec_ = new Vector(capacity); + auto capacity_bytes = capacity * byte_width_; + if (parquet_vec_ != nullptr) { + memset(parquet_vec_, 0, capacity_bytes); + } else { + parquet_vec_ = new uint8_t[capacity_bytes]; + } + // Init nulls + if (nullable_values_) { + nulls_ = unsafe::UnsafeBaseVector::GetNulls(vec_); + } + } + + + int64_t UpdateCapacity(int64_t capacity, int64_t size, int64_t extra_size) { + if (extra_size < 0) { + throw ::parquet::ParquetException("Negative size (corrupt file?)"); + } + int64_t target_size = -1; + if (::arrow::internal::AddWithOverflow(size, extra_size, &target_size)) { + throw ::parquet::ParquetException("Allocation size too large (corrupt file?)"); + } + if (target_size >= (1LL << 62)) { + throw ::parquet::ParquetException("Allocation size too large (corrupt file?)"); + } + if (capacity >= target_size) { + return capacity; + } + return ::arrow::bit_util::NextPower2(target_size); + } + + void ReserveLevels(int64_t extra_levels) { + if (this->max_def_level_ > 0) { + const int64_t new_levels_capacity = + UpdateCapacity(levels_capacity_, levels_written_, extra_levels); + if (new_levels_capacity > levels_capacity_) { + constexpr auto kItemSize = static_cast(sizeof(int16_t)); + int64_t capacity_in_bytes = -1; + if (::arrow::internal::MultiplyWithOverflow(new_levels_capacity, kItemSize, &capacity_in_bytes)) { + throw ::parquet::ParquetException("Allocation size too large (corrupt file?)"); + } + PARQUET_THROW_NOT_OK( + def_levels_->Resize(capacity_in_bytes, /*shrink_to_fit=*/false)); + if (this->max_rep_level_ > 0) { + PARQUET_THROW_NOT_OK( + rep_levels_->Resize(capacity_in_bytes, /*shrink_to_fit=*/false)); + } + levels_capacity_ = new_levels_capacity; + } + } + } + + void ReserveValues(int64_t extra_values) { + const int64_t new_values_capacity = + UpdateCapacity(values_capacity_, values_written_, extra_values); + if (new_values_capacity > values_capacity_) { + // XXX(wesm): A hack to avoid memory allocation when reading directly + // into builder classes + if (uses_values_) { + PARQUET_THROW_NOT_OK(values_->Resize(bytes_for_values(new_values_capacity), + /*shrink_to_fit=*/false)); + } + values_capacity_ = new_values_capacity; + } + } + + void Reset() override { + ResetValues(); + if (levels_written_ > 0) { + // Throw away levels from 0 to levels_position_. 
+ ThrowAwayLevels(0); + } + + vec_ = nullptr; + } + + void SetPageReader(std::unique_ptr<::parquet::PageReader> reader) override { + at_record_start_ = true; + this->pager_ = std::move(reader); + ResetDecoders(); + } + + bool HasMoreData() const override { return this->pager_ != nullptr; } + + const ::parquet::ColumnDescriptor* descr() const { return this->descr_; } + + // Dictionary decoders must be reset when advancing row groups + void ResetDecoders() { this->decoders_.clear(); } + + virtual void ReadValuesSpaced(int64_t values_with_nulls, int64_t null_count) { + int64_t num_decoded = this->current_decoder_->DecodeSpaced( + ValuesHead(), static_cast(values_with_nulls), + static_cast(null_count), nulls_ + values_written_); + CheckNumberDecoded(num_decoded, values_with_nulls); + } + + virtual void ReadValuesDense(int64_t values_to_read) { + int64_t num_decoded = + this->current_decoder_->Decode(ValuesHead(), static_cast(values_to_read)); + CheckNumberDecoded(num_decoded, values_to_read); + } + + // Return number of logical records read. + int64_t ReadRecordData(int64_t num_records) { + // Conservative upper bound + const int64_t possible_num_values = + std::max(num_records, levels_written_ - levels_position_); + ReserveValues(possible_num_values); + + const int64_t start_levels_position = levels_position_; + + int64_t records_read = 0; + int64_t values_to_read = 0; + if (this->max_rep_level_ > 0) { + records_read = DelimitRecords(num_records, &values_to_read); + } else if (this->max_def_level_ > 0) { + records_read = std::min(levels_written_ - levels_position_, num_records); + levels_position_ += records_read; + } else { + records_read = values_to_read = num_records; + } + + int64_t null_count = 0; + if (leaf_info_.HasNullableValues()) { + int64_t values_read = 0; + DefLevelsToNulls(def_levels() + start_levels_position, levels_position_ - start_levels_position, leaf_info_, + &values_read, &null_count, nulls_ + start_levels_position); + values_to_read = values_read - null_count; + DCHECK_GE(values_to_read, 0); + ReadValuesSpaced(values_read, null_count); + } else { + DCHECK_GE(values_to_read, 0); + ReadValuesDense(values_to_read); + } + + if (this->leaf_info_.def_level > 0) { + // Optional, repeated, or some mix thereof + this->ConsumeBufferedValues(levels_position_ - start_levels_position); + } else { + // Flat, non-repeated + this->ConsumeBufferedValues(values_to_read); + } + // Total values, including null spaces, if any + values_written_ += values_to_read + null_count; + null_count_ += null_count; + + return records_read; + } + + void ResetValues() { + if (values_written_ <= 0) { + return; + } + // Resize to 0, but do not shrink to fit + if (uses_values_) { + PARQUET_THROW_NOT_OK(values_->Resize(0, /*shrink_to_fit=*/false)); + } + values_written_ = 0; + values_capacity_ = 0; + null_count_ = 0; + values_decode_ = 0; + } + + virtual BaseVector* GetBaseVec() { + if (vec_ == nullptr) { + throw ::parquet::ParquetException("BaseVector is nullptr!"); + } + auto res = dynamic_cast*>(vec_); + res->SetValues(0, Values(), values_written_); + return vec_; + } + + protected: + template + T* ValuesHead() { + return reinterpret_cast(values_->mutable_data()) + values_written_; + } + + template + T* Values() const { + return reinterpret_cast(values_->mutable_data()); + } + ::parquet::internal::LevelInfo leaf_info_; + omniruntime::vec::BaseVector* vec_ = nullptr; + uint8_t* parquet_vec_ = nullptr; + bool* nulls_ = nullptr; + int32_t byte_width_; + }; + + class ParquetShortRecordReader : public 
ParquetTypedRecordReader { + public: + using BASE = ParquetTypedRecordReader; + ParquetShortRecordReader(const ::parquet::ColumnDescriptor* descr, ::parquet::internal::LevelInfo leaf_info, + ::arrow::MemoryPool* pool) + : BASE(descr, leaf_info, pool) {} + + BaseVector* GetBaseVec() override { + if (vec_ == nullptr) { + throw ::parquet::ParquetException("GetBaseVec() is nullptr!"); + } + auto res = dynamic_cast *>(vec_); + auto values = Values(); + for (int i = 0; i < values_written_; i++) { + res->SetValue(i, static_cast(values[i])); + } + return vec_; + } + }; + + class ParquetIntDecimal64RecordReader : public ParquetTypedRecordReader { + public: + using BASE = ParquetTypedRecordReader; + ParquetIntDecimal64RecordReader(const ::parquet::ColumnDescriptor* descr, ::parquet::internal::LevelInfo leaf_info, + ::arrow::MemoryPool* pool) + : BASE(descr, leaf_info, pool) {} + + BaseVector* GetBaseVec() override { + if (vec_ == nullptr) { + throw ::parquet::ParquetException("GetBaseVec() is nullptr!"); + } + auto res = dynamic_cast *>(vec_); + auto values = Values(); + for (int i = 0; i < values_written_; i++) { + res->SetValue(i, static_cast(values[i])); + } + return vec_; + } + }; + + class ParquetFLBADecimal64RecordReader : public ParquetTypedRecordReader { + public: + ParquetFLBADecimal64RecordReader(const ::parquet::ColumnDescriptor* descr, ::parquet::internal::LevelInfo leaf_info, + ::arrow::MemoryPool* pool) + : ParquetTypedRecordReader(descr, leaf_info, pool) {} + + void ReadValuesDense(int64_t values_to_read) override { + uint8_t* values = GetParquetVecOffsetPtr(0); + int64_t num_decoded = this->current_decoder_->Decode( + reinterpret_cast<::parquet::FixedLenByteArray*>(values), static_cast(values_to_read)); + values_decode_ += num_decoded; + DCHECK_EQ(num_decoded, values_to_read); + } + + void ReadValuesSpaced(int64_t values_to_read, int64_t null_count) override { + uint8_t* values = GetParquetVecOffsetPtr(0); + int64_t no_null_values_to_read = values_to_read - null_count; + int64_t num_decoded = this->current_decoder_->Decode( + reinterpret_cast<::parquet::FixedLenByteArray*>(values), static_cast(no_null_values_to_read)); + values_decode_ += num_decoded; + DCHECK_EQ(num_decoded, no_null_values_to_read); + } + + uint8_t* GetParquetVecOffsetPtr(int index) { + return parquet_vec_ + (index + values_decode_) * byte_width_; + } + + uint8_t* GetParquetVecHeadPtr(int index) { + return parquet_vec_ + index * byte_width_; + } + + BaseVector* GetBaseVec() override { + if (vec_ == nullptr) { + throw ::parquet::ParquetException("GetBaseVector() is nullptr"); + } + int index = 0; + for (int64_t i = 0; i < values_written_; i++) { + if (nulls_ == nullptr || !nulls_[i]) { + PARQUET_THROW_NOT_OK(RawBytesToDecimal64Bytes(GetParquetVecHeadPtr(index++), byte_width_, &vec_, i)); + } + } + return vec_; + } + }; + + class ParquetFLBADecimal128RecordReader : public ParquetTypedRecordReader { + public: + ParquetFLBADecimal128RecordReader(const ::parquet::ColumnDescriptor* descr, ::parquet::internal::LevelInfo leaf_info, + ::arrow::MemoryPool* pool) + : ParquetTypedRecordReader(descr, leaf_info, pool) {} + + void ReadValuesDense(int64_t values_to_read) override { + uint8_t* values = GetParquetVecOffsetPtr(0); + int64_t num_decoded = this->current_decoder_->Decode( + reinterpret_cast<::parquet::FixedLenByteArray*>(values), static_cast(values_to_read)); + values_decode_ += num_decoded; + DCHECK_EQ(num_decoded, values_to_read); + } + + void ReadValuesSpaced(int64_t values_to_read, int64_t null_count) override { + 
uint8_t* values = GetParquetVecOffsetPtr(0); + int64_t no_null_values_to_read = values_to_read - null_count; + int64_t num_decoded = this->current_decoder_->Decode( + reinterpret_cast<::parquet::FixedLenByteArray*>(values), static_cast(no_null_values_to_read)); + values_decode_ += num_decoded; + DCHECK_EQ(num_decoded, no_null_values_to_read); + } + + uint8_t* GetParquetVecOffsetPtr(int index) { + return parquet_vec_ + (index + values_decode_) * byte_width_; + } + + uint8_t* GetParquetVecHeadPtr(int index) { + return parquet_vec_ + index * byte_width_; + } + + BaseVector* GetBaseVec() override { + if (vec_ == nullptr) { + throw ::parquet::ParquetException("GetBaseVector() is nullptr"); + } + int index = 0; + for (int64_t i = 0; i < values_written_; i++) { + if (nulls_ == nullptr || !nulls_[i]) { + PARQUET_THROW_NOT_OK(RawBytesToDecimal128Bytes(GetParquetVecHeadPtr(index++), byte_width_, &vec_, i)); + } + } + return vec_; + } + }; + + class ParquetByteArrayChunkedRecordReader : public ParquetTypedRecordReader { + public: + ParquetByteArrayChunkedRecordReader(const ::parquet::ColumnDescriptor* descr, ::parquet::internal::LevelInfo leaf_info, + ::arrow::MemoryPool* pool) + : ParquetTypedRecordReader(descr, leaf_info, pool) { + DCHECK_EQ(descr_->physical_type(), ::parquet::Type::BYTE_ARRAY); + } + + void InitVec(int64_t capacity) override { + vec_ = new Vector>(capacity); + if (nullable_values_) { + nulls_ = unsafe::UnsafeBaseVector::GetNulls(vec_); + } + } + + void ReadValuesDense(int64_t values_to_read) override { + int64_t num_decoded = this->current_decoder_->DecodeArrowNonNull(static_cast(values_to_read), + &vec_, values_written_); + CheckNumberDecoded(num_decoded, values_to_read); + } + + void ReadValuesSpaced(int64_t values_to_read, int64_t null_count) override { + int64_t num_decoded = this->current_decoder_->DecodeArrow( + static_cast(values_to_read), static_cast(null_count), + nulls_, values_written_, &vec_); + CheckNumberDecoded(num_decoded, values_to_read - null_count); + } + + BaseVector* GetBaseVec() { + if (vec_ == nullptr) { + throw ::parquet::ParquetException("GetBaseVec() is nullptr"); + } + return vec_; + } + }; + + std::shared_ptr MakeRecordReader(const ::parquet::ColumnDescriptor* descr, + ::parquet::internal::LevelInfo leaf_info, ::arrow::MemoryPool* pool, + const bool read_dictionary, const std::shared_ptr<::arrow::DataType>& type); +} +#endif //OMNI_RUNTIME_COLUMN_TYPE_READER_H \ No newline at end of file diff --git a/omnioperator/omniop-native-reader/cpp/test/CMakeLists.txt b/omnioperator/omniop-native-reader/cpp/test/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..3d1d559df94b1137db424b3318b86b956830e09c --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/test/CMakeLists.txt @@ -0,0 +1,44 @@ +aux_source_directory(${CMAKE_CURRENT_LIST_DIR} TEST_ROOT_SRCS) + +add_subdirectory(tablescan) +add_subdirectory(filesystem) +add_subdirectory(io/arrowadapter) +add_subdirectory(io/orcfile) + +# configure +set(TP_TEST_TARGET tptest) +set(MY_LINK + tablescantest + filesystemtest + arrowadaptertest + orcfiletest + ) + +# find gtest package +find_package(GTest REQUIRED) + +# compile a executable file +add_executable(${TP_TEST_TARGET} ${ROOT_SRCS} ${TEST_ROOT_SRCS}) + +# dependent libraries +target_link_libraries(${TP_TEST_TARGET} + ${GTEST_BOTH_LIBRARIES} + ${SOURCE_LINK} + -Wl,--whole-archive + ${MY_LINK} + -Wl,--no-whole-archive + gtest + pthread + stdc++ + dl + boostkit-omniop-vector-1.4.0-aarch64 + securec + spark_columnar_plugin) + 
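
The tptest binary assembled here is a plain GoogleTest aggregate, so each static library in MY_LINK only needs to contribute TEST/TEST_F cases that gtest_discover_tests can register. As a register check, a minimal self-contained case in the same style; the helper body is copied from SafeLeftShift in ParquetTypedRecordReader.h purely for illustration:

    #include <cstdint>
    #include <type_traits>
    #include "gtest/gtest.h"

    // Local copy of the SafeLeftShift helper: shifting through the unsigned
    // representation avoids undefined behaviour for negative inputs.
    template <typename SignedInt, typename Shift>
    SignedInt SafeLeftShiftExample(SignedInt u, Shift shift) {
        using UnsignedInt = typename std::make_unsigned<SignedInt>::type;
        return static_cast<SignedInt>(static_cast<UnsignedInt>(u) << shift);
    }

    TEST(SafeLeftShiftExample, BehavesLikeTwosComplementShift) {
        EXPECT_EQ(SafeLeftShiftExample<int64_t>(-1, 8), -256);
        EXPECT_EQ(SafeLeftShiftExample<int64_t>(1, 62), int64_t{1} << 62);
    }
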
+target_compile_options(${TP_TEST_TARGET} PUBLIC -g -O2 -fPIC) + +# dependent include +target_include_directories(${TP_TEST_TARGET} PRIVATE ${GTEST_INCLUDE_DIRS}) + +# discover tests +gtest_discover_tests(${TP_TEST_TARGET}) diff --git a/omnioperator/omniop-native-reader/cpp/test/filesystem/CMakeLists.txt b/omnioperator/omniop-native-reader/cpp/test/filesystem/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..00155fbf5112510ae5cb30f1ccd53581fc9814ff --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/test/filesystem/CMakeLists.txt @@ -0,0 +1,7 @@ +aux_source_directory(${CMAKE_CURRENT_LIST_DIR} FILESYSTEM_TESTS_LIST) +set(FILESYSTEM_TEST_TARGET filesystemtest) +add_library(${FILESYSTEM_TEST_TARGET} STATIC ${FILESYSTEM_TESTS_LIST}) +target_compile_options(${FILESYSTEM_TEST_TARGET} PUBLIC ) +target_include_directories(${FILESYSTEM_TEST_TARGET} PUBLIC ${CMAKE_BINARY_DIR}/src) +target_include_directories(${FILESYSTEM_TEST_TARGET} PUBLIC $ENV{JAVA_HOME}/include) +target_include_directories(${FILESYSTEM_TEST_TARGET} PUBLIC $ENV{JAVA_HOME}/include/linux) diff --git a/omnioperator/omniop-native-reader/cpp/test/filesystem/filesystem_test.cpp b/omnioperator/omniop-native-reader/cpp/test/filesystem/filesystem_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..15e11864969e73656a7629f051e797f784f37aeb --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/test/filesystem/filesystem_test.cpp @@ -0,0 +1,56 @@ +/** + * Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include "filesystem/hdfs_filesystem.h" +#include "filesystem/hdfs_file.h" + +namespace fs { + +// Test HdfsOptions +class HdfsOptionsTest : public ::testing::Test { +protected: + HdfsOptions options; +}; + +// Test HdfsOptions::ConfigureHost +TEST_F(HdfsOptionsTest, ConfigureHost) { + options.ConfigureHost("server1"); + ASSERT_EQ(options.host_, "server1"); +} + +// Test HdfsOptions::ConfigurePort +TEST_F(HdfsOptionsTest, ConfigurePort) { + options.ConfigurePort(9000); + ASSERT_EQ(options.port_, 9000); +} + +// Test HdfsOptions::Equals +TEST_F(HdfsOptionsTest, Equals) { + HdfsOptions options; + options.ConfigureHost("server1"); + options.ConfigurePort(9000); + + HdfsOptions otherOptions; + otherOptions.ConfigureHost("server1"); + otherOptions.ConfigurePort(9000); + ASSERT_TRUE(options.Equals(otherOptions)); +} + +} \ No newline at end of file diff --git a/omnioperator/omniop-native-reader/cpp/test/io/arrowadapter/CMakeLists.txt b/omnioperator/omniop-native-reader/cpp/test/io/arrowadapter/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..aec5bbc4032b32ea7840a61e10f043e70e3b007d --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/test/io/arrowadapter/CMakeLists.txt @@ -0,0 +1,7 @@ +aux_source_directory(${CMAKE_CURRENT_LIST_DIR} ARROW_ADAPTER_TESTS_LIST) +set(ARROW_ADAPTER_TARGET arrowadaptertest) +add_library(${ARROW_ADAPTER_TARGET} STATIC ${ARROW_ADAPTER_TESTS_LIST}) +target_compile_options(${ARROW_ADAPTER_TARGET} PUBLIC ) +target_include_directories(${ARROW_ADAPTER_TARGET} PUBLIC ${CMAKE_BINARY_DIR}/src) +target_include_directories(${ARROW_ADAPTER_TARGET} PUBLIC $ENV{JAVA_HOME}/include) +target_include_directories(${ARROW_ADAPTER_TARGET} PUBLIC $ENV{JAVA_HOME}/include/linux) diff --git a/omnioperator/omniop-native-reader/cpp/test/io/arrowadapter/FileSystemAdapterTest.cc b/omnioperator/omniop-native-reader/cpp/test/io/arrowadapter/FileSystemAdapterTest.cc new file mode 100644 index 0000000000000000000000000000000000000000..ad2b2be5f4bba64477082a2f21143351471e8e1f --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/test/io/arrowadapter/FileSystemAdapterTest.cc @@ -0,0 +1,200 @@ +/** + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + +#include +#include "gtest/gtest.h" +#include "arrowadapter/FileSystemAdapter.h" +#include "arrow/filesystem/filesystem.h" +#include "arrow/filesystem/mockfs.h" +#include "arrow/util/checked_cast.h" +#include "arrow/result.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/util/io_util.h" +#include "arrow/filesystem/path_util.h" +#include "arrow/filesystem/localfs.h" +#include "../../utils/test_utils.h" +#include "arrow/util/uri.h" + +using namespace arrow::fs::internal; +using arrow::fs::TimePoint; +using arrow::fs::FileSystem; +using arrow_adapter::FileSystemFromUriOrPath; +using arrow::internal::TemporaryDir; +using arrow::fs::LocalFileSystem; +using arrow::fs::LocalFileSystemOptions; +using arrow::internal::PlatformFilename; +using arrow::internal::FileDescriptor; +using arrow::Result; +using arrow::fs::HadoopFileSystem; +using arrow::fs::HdfsOptions; + +class TestMockFS : public ::testing::Test { +public: + void SetUp() override { + time_ = TimePoint(TimePoint::duration(42)); + fs_ = std::make_shared(time_); + } + + std::vector AllDirs() { + return arrow::internal::checked_pointer_cast(fs_)->AllDirs(); + } + + void CheckDirs(const std::vector& expected) { + ASSERT_EQ(AllDirs(), expected); + } + +protected: + TimePoint time_; + std::shared_ptr fs_; +}; + +TEST_F(TestMockFS, FileSystemFromUriOrPath) { + std::string path; + UriInfo uri1("mock", "", "", -1); + ASSERT_OK_AND_ASSIGN(fs_, FileSystemFromUriOrPath(uri1, &path)); + ASSERT_EQ(path, ""); + CheckDirs({}); // Ensures it's a MockFileSystem + + UriInfo uri2("mock", "foo/bar", "", -1); + ASSERT_OK_AND_ASSIGN(fs_, FileSystemFromUriOrPath(uri2, &path)); + ASSERT_EQ(path, "foo/bar"); + CheckDirs({}); + + UriInfo ur3("mock", "/foo/bar", "", -1); + ASSERT_OK_AND_ASSIGN(fs_, FileSystemFromUriOrPath(ur3, &path)); + ASSERT_EQ(path, "foo/bar"); + CheckDirs({}); +} + +struct CommonPathFormatter { + std::string operator()(std::string fn) { return fn; } + bool supports_uri() { return true; } +}; + +using PathFormatters = ::testing::Types; + +// Non-overloaded version of FileSystemFromUri, for template resolution +Result> FSFromUriOrPath(const UriInfo& uri, + std::string* out_path = NULLPTR) { + return arrow_adapter::FileSystemFromUriOrPath(uri, out_path); +} + + +template +class TestLocalFs : public ::testing::Test { +public: + void SetUp() override { + ASSERT_OK_AND_ASSIGN(temp_dir_, TemporaryDir::Make("test-localfs-")); + local_path_ = EnsureTrailingSlash(path_formatter_(temp_dir_->path().ToString())); + MakeFileSystem(); + } + + void MakeFileSystem() { + local_fs_ = std::make_shared(options_); + } + + template + void CheckFileSystemFromUriFunc(const UriInfo& uri, + FileSystemFromUriFunc&& fs_from_uri) { + if (!path_formatter_.supports_uri()) { + return; // skip + } + std::string path; + ASSERT_OK_AND_ASSIGN(fs_, fs_from_uri(uri, &path)); + ASSERT_EQ(path, local_path_); + + // Test that the right location on disk is accessed + CreateFile(fs_.get(), local_path_ + "abc", "some data"); + CheckConcreteFile(this->temp_dir_->path().ToString() + "abc", 9); + } + + void TestFileSystemFromUri(const UriInfo& uri) { + CheckFileSystemFromUriFunc(uri, FSFromUriOrPath); + } + + void CheckConcreteFile(const std::string& path, int64_t expected_size) { + ASSERT_OK_AND_ASSIGN(auto fn, PlatformFilename::FromString(path)); + ASSERT_OK_AND_ASSIGN(FileDescriptor fd, ::arrow::internal::FileOpenReadable(fn)); + auto result = ::arrow::internal::FileGetSize(fd.fd()); + ASSERT_OK_AND_ASSIGN(int64_t size, result); + ASSERT_EQ(size, expected_size); + } + + 
void TestLocalUri(const UriInfo& uri, const std::string& expected_path) { + CheckLocalUri(uri, expected_path, FSFromUriOrPath); + } + + template + void CheckLocalUri(const UriInfo& uri, const std::string& expected_path, + FileSystemFromUriFunc&& fs_from_uri) { + if (!path_formatter_.supports_uri()) { + return; // skip + } + std::string path; + ASSERT_OK_AND_ASSIGN(fs_, fs_from_uri(uri, &path)); + ASSERT_EQ(fs_->type_name(), "local"); + ASSERT_EQ(path, expected_path); + } + + void TestInvalidUri(const UriInfo& uri) { + if (!path_formatter_.supports_uri()) { + return; // skip + } + ASSERT_RAISES(Invalid, FSFromUriOrPath(uri)); + } + +protected: + std::unique_ptr temp_dir_; + std::shared_ptr fs_; + std::string local_path_; + PathFormatter path_formatter_; + std::shared_ptr local_fs_; + LocalFileSystemOptions options_ = LocalFileSystemOptions::Defaults(); +}; + +TYPED_TEST_SUITE(TestLocalFs, PathFormatters); + +TYPED_TEST(TestLocalFs, FileSystemFromUriFile){ + std::string path; + ASSERT_OK_AND_ASSIGN(auto uri_string, arrow::internal::UriFromAbsolutePath(this->local_path_)); + UriInfo uri1(uri_string, "", uri_string, "", -1); + this->TestFileSystemFromUri(uri1); + + path = "/foo/bar"; + UriInfo uri2("file", path, "", -1); + this->TestLocalUri(uri2, path); + + path = "/some path/%percent"; + UriInfo uri3("file", path, "", -1); + this->TestLocalUri(uri3, path); + + path = "/some path/%中文魑魅魍魉"; + UriInfo uri4("file", path, "", -1); + this->TestLocalUri(uri4, path); +} + +TYPED_TEST(TestLocalFs, FileSystemFromUriNoScheme){ + + UriInfo uri1(this->local_path_, "", "", "", -1); + this->TestFileSystemFromUri(uri1); + + UriInfo uri2("foo/bar", "", "", "", -1); + this->TestInvalidUri(uri2); +} diff --git a/omnioperator/omniop-native-reader/cpp/test/io/orcfile/CMakeLists.txt b/omnioperator/omniop-native-reader/cpp/test/io/orcfile/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..cdb765aa3581f7885e5df8716fffc343c58f0c20 --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/test/io/orcfile/CMakeLists.txt @@ -0,0 +1,11 @@ +aux_source_directory(${CMAKE_CURRENT_LIST_DIR} ORC_FILE_TESTS_LIST) +set(MAIN_PATH ${CMAKE_CURRENT_SOURCE_DIR}) + +configure_file(orcfile_test.h.in ${CMAKE_CURRENT_SOURCE_DIR}/orcfile_test.h) +set(ORC_FILE_TARGET orcfiletest) + +add_library(${ORC_FILE_TARGET} STATIC ${ORC_FILE_TESTS_LIST}) +target_compile_options(${ORC_FILE_TARGET} PUBLIC ) +target_include_directories(${ORC_FILE_TARGET} PUBLIC ${CMAKE_BINARY_DIR}/src) +target_include_directories(${ORC_FILE_TARGET} PUBLIC $ENV{JAVA_HOME}/include) +target_include_directories(${ORC_FILE_TARGET} PUBLIC $ENV{JAVA_HOME}/include/linux) diff --git a/omnioperator/omniop-native-reader/cpp/test/io/orcfile/OrcHdfsFileOverrideTest.cc b/omnioperator/omniop-native-reader/cpp/test/io/orcfile/OrcHdfsFileOverrideTest.cc new file mode 100644 index 0000000000000000000000000000000000000000..ed6fc9875d95fc5b0f51bb627a93393d2875a8d8 --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/test/io/orcfile/OrcHdfsFileOverrideTest.cc @@ -0,0 +1,40 @@ +/** + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "gtest/gtest.h" +#include "orcfile/OrcFileOverride.hh" +#include "orcfile_test.h" + +TEST(OrcReader, createLocalFileReader) { + std::string filename = "/resources/orc_data_all_type"; + filename = PROJECT_PATH + filename; + + std::unique_ptr reader; + std::unique_ptr rowReader; + std::unique_ptr batch; + orc::ReaderOptions readerOpts; + orc::RowReaderOptions rowReaderOpts; + std::list cols; + + cols.push_back(1); + rowReaderOpts.include(cols); + UriInfo uriInfo("file", filename, "", ""); + reader = orc::createReader(orc::readFileOverride(uriInfo), readerOpts); + EXPECT_NE(nullptr, reader); +} diff --git a/omnioperator/omniop-spark-extension/cpp/test/tablescan/scan_test.h.in b/omnioperator/omniop-native-reader/cpp/test/io/orcfile/orcfile_test.h.in similarity index 100% rename from omnioperator/omniop-spark-extension/cpp/test/tablescan/scan_test.h.in rename to omnioperator/omniop-native-reader/cpp/test/io/orcfile/orcfile_test.h.in diff --git a/omnioperator/omniop-spark-extension/cpp/test/tablescan/resources/orc_data_all_type b/omnioperator/omniop-native-reader/cpp/test/io/orcfile/resource/orc_data_all_type similarity index 100% rename from omnioperator/omniop-spark-extension/cpp/test/tablescan/resources/orc_data_all_type rename to omnioperator/omniop-native-reader/cpp/test/io/orcfile/resource/orc_data_all_type diff --git a/omnioperator/omniop-spark-extension/cpp/test/tablescan/CMakeLists.txt b/omnioperator/omniop-native-reader/cpp/test/tablescan/CMakeLists.txt similarity index 90% rename from omnioperator/omniop-spark-extension/cpp/test/tablescan/CMakeLists.txt rename to omnioperator/omniop-native-reader/cpp/test/tablescan/CMakeLists.txt index 2d8dcdbeb34e7e98075b1b513ae1bfd24920fe52..c18f9da39ac1c3732c15f3b29389f24e107945cb 100644 --- a/omnioperator/omniop-spark-extension/cpp/test/tablescan/CMakeLists.txt +++ b/omnioperator/omniop-native-reader/cpp/test/tablescan/CMakeLists.txt @@ -6,7 +6,7 @@ set(SCAN_TEST_TARGET tablescantest) add_library(${SCAN_TEST_TARGET} STATIC ${SCAN_TESTS_LIST} parquet_scan_test.cpp) target_compile_options(${SCAN_TEST_TARGET} PUBLIC ) -target_link_libraries(${SCAN_TEST_TARGET} eSDKOBS) +target_link_libraries(${SCAN_TEST_TARGET}) target_include_directories(${SCAN_TEST_TARGET} PUBLIC $ENV{JAVA_HOME}/include) target_include_directories(${SCAN_TEST_TARGET} PUBLIC $ENV{JAVA_HOME}/include/linux) diff --git a/omnioperator/omniop-spark-extension/cpp/test/tablescan/parquet_scan_test.cpp b/omnioperator/omniop-native-reader/cpp/test/tablescan/parquet_scan_test.cpp similarity index 63% rename from omnioperator/omniop-spark-extension/cpp/test/tablescan/parquet_scan_test.cpp rename to omnioperator/omniop-native-reader/cpp/test/tablescan/parquet_scan_test.cpp index 39c30151e3d4d81ffc8fa6fff7dc5e7766b153a0..287cb299601f86b5c05f22a85702d538bf581e69 100644 --- a/omnioperator/omniop-spark-extension/cpp/test/tablescan/parquet_scan_test.cpp +++ b/omnioperator/omniop-native-reader/cpp/test/tablescan/parquet_scan_test.cpp @@ -1,5 +1,5 @@ /** - * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Copyright (C) 2023-2023. 
Huawei Technologies Co., Ltd. All rights reserved. * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information @@ -19,11 +19,10 @@ #include #include -#include #include "scan_test.h" -#include "tablescan/ParquetReader.h" +#include "parquet/ParquetReader.h" -using namespace spark::reader; +using namespace omniruntime::reader; using namespace arrow; using namespace omniruntime::vec; @@ -44,50 +43,48 @@ TEST(read, test_parquet_reader) ParquetReader *reader = new ParquetReader(); std::string ugi = "root@sample"; - ObsConfig obsInfo; - auto state1 = reader->InitRecordReader(filename, 1024, row_group_indices, column_indices, ugi, obsInfo); + auto state1 = reader->InitRecordReader(filename, 1024, row_group_indices, column_indices, ugi); ASSERT_EQ(state1, Status::OK()); - std::shared_ptr batch; - auto state2 = reader->ReadNextBatch(&batch); + std::vector recordBatch(column_indices.size()); + long batchRowSize = 0; + auto state2 = reader->ReadNextBatch(recordBatch, &batchRowSize); ASSERT_EQ(state2, Status::OK()); - std::cout << "num_rows: " << batch->num_rows() << std::endl; - std::cout << "num_columns: " << batch->num_columns() << std::endl; - std::cout << "Print: " << batch->ToString() << std::endl; - auto pair = TransferToOmniVecs(batch); + std::cout << "num_rows: " << batchRowSize << std::endl; + std::cout << "num_columns: " << recordBatch.size() << std::endl; - BaseVector *intVector = reinterpret_cast(pair.second[0]); + BaseVector *intVector = reinterpret_cast(recordBatch[0]); auto int_result = static_cast(omniruntime::vec::VectorHelper::UnsafeGetValues(intVector)); ASSERT_EQ(*int_result, 10); - auto varCharVector = reinterpret_cast> *>(pair.second[1]); + auto varCharVector = reinterpret_cast> *>(recordBatch[1]); std::string str_expected = "varchar_1"; ASSERT_TRUE(str_expected == varCharVector->GetValue(0)); - BaseVector *longVector = reinterpret_cast(pair.second[2]); + BaseVector *longVector = reinterpret_cast(recordBatch[2]); auto long_result = static_cast(omniruntime::vec::VectorHelper::UnsafeGetValues(longVector)); ASSERT_EQ(*long_result, 10000); - BaseVector *doubleVector = reinterpret_cast(pair.second[3]); + BaseVector *doubleVector = reinterpret_cast(recordBatch[3]); auto double_result = static_cast(omniruntime::vec::VectorHelper::UnsafeGetValues(doubleVector)); ASSERT_EQ(*double_result, 1111.1111); - BaseVector *nullVector = reinterpret_cast(pair.second[4]); + BaseVector *nullVector = reinterpret_cast(recordBatch[4]); ASSERT_TRUE(nullVector->IsNull(0)); - BaseVector *decimal64Vector = reinterpret_cast(pair.second[5]); + BaseVector *decimal64Vector = reinterpret_cast(recordBatch[5]); auto decimal64_result = static_cast(omniruntime::vec::VectorHelper::UnsafeGetValues(decimal64Vector)); ASSERT_EQ(*decimal64_result, 13111110); - BaseVector *booleanVector = reinterpret_cast(pair.second[6]); + BaseVector *booleanVector = reinterpret_cast(recordBatch[6]); auto boolean_result = static_cast(omniruntime::vec::VectorHelper::UnsafeGetValues(booleanVector)); ASSERT_EQ(*boolean_result, true); - BaseVector *smallintVector = reinterpret_cast(pair.second[7]); + BaseVector *smallintVector = reinterpret_cast(recordBatch[7]); auto smallint_result = static_cast(omniruntime::vec::VectorHelper::UnsafeGetValues(smallintVector)); ASSERT_EQ(*smallint_result, 11); - BaseVector *dateVector = reinterpret_cast(pair.second[8]); + BaseVector *dateVector = reinterpret_cast(recordBatch[8]); auto 
date_result = static_cast(omniruntime::vec::VectorHelper::UnsafeGetValues(dateVector)); omniruntime::type::Date32 date32(*date_result); char chars[11]; @@ -107,23 +104,31 @@ TEST(read, test_parquet_reader) delete dateVector; } -TEST(read, test_decimal128_copy) +TEST(read, test_varchar) { - auto decimal_type = arrow::decimal(20, 1); - arrow::Decimal128Builder builder(decimal_type); - arrow::Decimal128 value(20230420); - auto s1 = builder.Append(value); - std::shared_ptr array; - auto s2 = builder.Finish(&array); - - int omniTypeId = 0; - uint64_t omniVecId = 0; - spark::reader::CopyToOmniVec(decimal_type, omniTypeId, omniVecId, array); - - BaseVector *decimal128Vector = reinterpret_cast(omniVecId); - auto decimal128_result = - static_cast(omniruntime::vec::VectorHelper::UnsafeGetValues(decimal128Vector)); - ASSERT_TRUE((*decimal128_result).ToString() == "20230420"); - - delete decimal128Vector; + std::string filename = "/../../../java/src/test/java/com/huawei/boostkit/spark/jni/parquetsrc/date_dim.parquet"; + filename = PROJECT_PATH + filename; + const std::vector row_group_indices = {0}; + const std::vector column_indices = {23, 24, 25, 26, 27}; + ParquetReader *reader = new ParquetReader(); + std::string ugi = "root@sample"; + auto state1 = reader->InitRecordReader(filename, 4096, row_group_indices, column_indices, ugi); + ASSERT_EQ(state1, Status::OK()); + int total_nums = 0; + int iter = 0; + while (true) { + std::vector recordBatch(column_indices.size()); + long batchRowSize = 0; + auto state2 = reader->ReadNextBatch(recordBatch, &batchRowSize); + if (batchRowSize == 0) { + break; + } + total_nums += batchRowSize; + std::cout << iter++ << " num rows: " << batchRowSize << std::endl; + for (auto vec : recordBatch) { + delete vec; + } + recordBatch.clear(); + } + std::cout << "total nums: " << total_nums << std::endl; } \ No newline at end of file diff --git a/omnioperator/omniop-native-reader/cpp/test/tablescan/resources/orc_data_all_type b/omnioperator/omniop-native-reader/cpp/test/tablescan/resources/orc_data_all_type new file mode 100644 index 0000000000000000000000000000000000000000..9cc57fa78ccdae728d2d902f587c30c337b0e4a5 Binary files /dev/null and b/omnioperator/omniop-native-reader/cpp/test/tablescan/resources/orc_data_all_type differ diff --git a/omnioperator/omniop-spark-extension/cpp/test/tablescan/resources/parquet_data_all_type b/omnioperator/omniop-native-reader/cpp/test/tablescan/resources/parquet_data_all_type similarity index 100% rename from omnioperator/omniop-spark-extension/cpp/test/tablescan/resources/parquet_data_all_type rename to omnioperator/omniop-native-reader/cpp/test/tablescan/resources/parquet_data_all_type diff --git a/omnioperator/omniop-spark-extension/cpp/test/tablescan/scan_test.cpp b/omnioperator/omniop-native-reader/cpp/test/tablescan/scan_test.cpp similarity index 97% rename from omnioperator/omniop-spark-extension/cpp/test/tablescan/scan_test.cpp rename to omnioperator/omniop-native-reader/cpp/test/tablescan/scan_test.cpp index 2ed604e50420c402e9184c0a4011f66d69c00158..e47ec373acb1c4ba259d7fac6b07e3201b0e92fd 100644 --- a/omnioperator/omniop-spark-extension/cpp/test/tablescan/scan_test.cpp +++ b/omnioperator/omniop-native-reader/cpp/test/tablescan/scan_test.cpp @@ -1,5 +1,5 @@ /** - * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. 
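The rewritten parquet tests above replace the arrow RecordBatch result of ReadNextBatch with a caller-owned vector of native column vectors plus a row count. A minimal consumption loop for that contract is sketched below; it assumes the ParquetReader interface exactly as the tests use it (InitRecordReader/ReadNextBatch from "parquet/ParquetReader.h") and that the batch holds omniruntime BaseVector pointers whose ownership passes to the caller, as the deletes in test_varchar suggest.

#include <vector>
#include "parquet/ParquetReader.h"

using omniruntime::reader::ParquetReader;
using omniruntime::vec::BaseVector;

// Sketch only: read every remaining batch and return the total row count.
long CountParquetRows(ParquetReader &reader, size_t columnCount)
{
    long totalRows = 0;
    while (true) {
        std::vector<BaseVector *> recordBatch(columnCount);
        long batchRowSize = 0;
        if (!reader.ReadNextBatch(recordBatch, &batchRowSize).ok()) {
            break;  // reader reported an error
        }
        if (batchRowSize == 0) {
            break;  // no rows left in the selected row groups
        }
        totalRows += batchRowSize;
        for (auto *vec : recordBatch) {
            delete vec;  // each column vector is owned by the caller
        }
    }
    return totalRows;
}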
See the NOTICE file * distributed with this work for additional information @@ -161,7 +161,7 @@ TEST_F(ScanTest, test_copy_intVec) int omniType = 0; uint64_t omniVecId = 0; // int type - CopyToOmniVec(types->getSubtype(0), omniType, omniVecId, root->fields[0]); + CopyToOmniVec(types->getSubtype(0), omniType, omniVecId, root->fields[0], false); ASSERT_EQ(omniType, omniruntime::type::OMNI_INT); auto *olbInt = (omniruntime::vec::Vector *)(omniVecId); ASSERT_EQ(olbInt->GetValue(0), 10); @@ -173,7 +173,7 @@ TEST_F(ScanTest, test_copy_varCharVec) int omniType = 0; uint64_t omniVecId = 0; // varchar type - CopyToOmniVec(types->getSubtype(1), omniType, omniVecId, root->fields[1]); + CopyToOmniVec(types->getSubtype(1), omniType, omniVecId, root->fields[1], false); ASSERT_EQ(omniType, omniruntime::type::OMNI_VARCHAR); auto *olbVc = (omniruntime::vec::Vector> *)( omniVecId); @@ -187,7 +187,7 @@ TEST_F(ScanTest, test_copy_stringVec) int omniType = 0; uint64_t omniVecId = 0; // string type - CopyToOmniVec(types->getSubtype(2), omniType, omniVecId, root->fields[2]); + CopyToOmniVec(types->getSubtype(2), omniType, omniVecId, root->fields[2], false); ASSERT_EQ(omniType, omniruntime::type::OMNI_VARCHAR); auto *olbStr = (omniruntime::vec::Vector> *)( omniVecId); @@ -201,7 +201,7 @@ TEST_F(ScanTest, test_copy_longVec) int omniType = 0; uint64_t omniVecId = 0; // bigint type - CopyToOmniVec(types->getSubtype(3), omniType, omniVecId, root->fields[3]); + CopyToOmniVec(types->getSubtype(3), omniType, omniVecId, root->fields[3], false); ASSERT_EQ(omniType, omniruntime::type::OMNI_LONG); auto *olbLong = (omniruntime::vec::Vector *)(omniVecId); ASSERT_EQ(olbLong->GetValue(0), 10000); @@ -213,7 +213,7 @@ TEST_F(ScanTest, test_copy_charVec) int omniType = 0; uint64_t omniVecId = 0; // char type - CopyToOmniVec(types->getSubtype(4), omniType, omniVecId, root->fields[4]); + CopyToOmniVec(types->getSubtype(4), omniType, omniVecId, root->fields[4], false); ASSERT_EQ(omniType, omniruntime::type::OMNI_VARCHAR); auto *olbChar = (omniruntime::vec::Vector> *)( omniVecId); @@ -227,7 +227,7 @@ TEST_F(ScanTest, test_copy_doubleVec) int omniType = 0; uint64_t omniVecId = 0; // double type - CopyToOmniVec(types->getSubtype(6), omniType, omniVecId, root->fields[6]); + CopyToOmniVec(types->getSubtype(6), omniType, omniVecId, root->fields[6], false); ASSERT_EQ(omniType, omniruntime::type::OMNI_DOUBLE); auto *olbDouble = (omniruntime::vec::Vector *)(omniVecId); ASSERT_EQ(olbDouble->GetValue(0), 1111.1111); @@ -239,7 +239,7 @@ TEST_F(ScanTest, test_copy_booleanVec) int omniType = 0; uint64_t omniVecId = 0; // boolean type - CopyToOmniVec(types->getSubtype(9), omniType, omniVecId, root->fields[9]); + CopyToOmniVec(types->getSubtype(9), omniType, omniVecId, root->fields[9], false); ASSERT_EQ(omniType, omniruntime::type::OMNI_BOOLEAN); auto *olbBoolean = (omniruntime::vec::Vector *)(omniVecId); ASSERT_EQ(olbBoolean->GetValue(0), true); @@ -251,7 +251,7 @@ TEST_F(ScanTest, test_copy_shortVec) int omniType = 0; uint64_t omniVecId = 0; // short type - CopyToOmniVec(types->getSubtype(10), omniType, omniVecId, root->fields[10]); + CopyToOmniVec(types->getSubtype(10), omniType, omniVecId, root->fields[10], false); ASSERT_EQ(omniType, omniruntime::type::OMNI_SHORT); auto *olbShort = (omniruntime::vec::Vector *)(omniVecId); ASSERT_EQ(olbShort->GetValue(0), 11); diff --git a/omnioperator/omniop-native-reader/cpp/test/tablescan/scan_test.h.in b/omnioperator/omniop-native-reader/cpp/test/tablescan/scan_test.h.in new file mode 100644 index 
0000000000000000000000000000000000000000..5ca616ec499c349478cb839213a4eb7bb289439c --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/test/tablescan/scan_test.h.in @@ -0,0 +1 @@ +#define PROJECT_PATH "@MAIN_PATH@" \ No newline at end of file diff --git a/omnioperator/omniop-native-reader/cpp/test/tptest.cpp b/omnioperator/omniop-native-reader/cpp/test/tptest.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2db15a1c0812a0a305cdce502c04b969bc6db8e2 --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/test/tptest.cpp @@ -0,0 +1,24 @@ +/* + * Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "gtest/gtest.h" + +int main(int argc, char **argv) +{ + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} \ No newline at end of file diff --git a/omnioperator/omniop-native-reader/cpp/test/utils/CMakeLists.txt b/omnioperator/omniop-native-reader/cpp/test/utils/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..d5ef3a300b492d20afe78550cab6fdc7635b5cc5 --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/test/utils/CMakeLists.txt @@ -0,0 +1,6 @@ +aux_source_directory(${CMAKE_CURRENT_LIST_DIR} UTILS_TESTS_LIST) +set(UTILS_TEST_TARGET utilstest) +add_library(${UTILS_TEST_TARGET} ${UTILS_TESTS_LIST}) +target_include_directories(${UTILS_TEST_TARGET} PUBLIC ${CMAKE_BINARY_DIR}/src) +target_include_directories(${UTILS_TEST_TARGET} PUBLIC $ENV{JAVA_HOME}/include) +target_include_directories(${UTILS_TEST_TARGET} PUBLIC $ENV{JAVA_HOME}/include/linux) diff --git a/omnioperator/omniop-native-reader/cpp/test/utils/test_utils.h b/omnioperator/omniop-native-reader/cpp/test/utils/test_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..40321316b3d544bf6a28fe91a43b871d3fe5458e --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/test/utils/test_utils.h @@ -0,0 +1,52 @@ +/** + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NATIVE_READER_TEST_UTILS_H +#define NATIVE_READER_TEST_UTILS_H + +#include +#include +#include +#include +#include "arrow/filesystem/filesystem.h" +#include "arrow/result.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/filesystem/type_fwd.h" + +using arrow::fs::FileSystem; +using arrow::fs::FileInfo; +using arrow::fs::FileType; + +void CreateFile(FileSystem *fs, const std::string &path, const std::string &data) { + ASSERT_OK_AND_ASSIGN(auto stream, fs->OpenOutputStream(path)); + ASSERT_OK(stream->Write(data)); + ASSERT_OK(stream->Close()); +} + +void AssertFileInfo(const FileInfo &info, const std::string &path, FileType type) { + ASSERT_EQ(info.path(), path); + ASSERT_EQ(info.type(), type) << "For path '" << info.path() << "'"; +} + +void AssertFileInfo(FileSystem *fs, const std::string &path, FileType type) { + ASSERT_OK_AND_ASSIGN(FileInfo info, fs->GetFileInfo(path)); + AssertFileInfo(info, path, type); +} + +#endif //NATIVE_READER_TEST_UTILS_H diff --git a/omnioperator/omniop-native-reader/java/pom.xml b/omnioperator/omniop-native-reader/java/pom.xml new file mode 100644 index 0000000000000000000000000000000000000000..99c66a43076a3c9b7f8b529749436e9c4ea10ed9 --- /dev/null +++ b/omnioperator/omniop-native-reader/java/pom.xml @@ -0,0 +1,135 @@ + + + + 4.0.0 + + com.huawei.boostkit + boostkit-omniop-native-reader + jar + 3.3.1-1.4.0 + + BoostKit Spark Native Sql Engine Extension With OmniOperator + + + 2.12 + 3.3.1 + FALSE + ../cpp/ + ../cpp/build/releases/ + ${cpp.test} + incremental + 0.6.1 + 3.0.0 + 1.6.2 + ${project.build.directory}/scala-${scala.binary.version}/jars + + + + + com.huawei.boostkit + boostkit-omniop-bindings + aarch64 + 1.4.0 + + + org.slf4j + slf4j-api + 1.7.32 + + + junit + junit + 4.12 + test + + + io.trino.tpcds + tpcds + 1.4 + test + + + com.tdunning + json + 1.8 + + + + ${artifactId}-${version}${dep.os.arch} + + + ../cpp/build/releases + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + kr.motd.maven + os-maven-plugin + ${os.plugin.version} + + + + + exec-maven-plugin + org.codehaus.mojo + 3.0.0 + + + Build CPP + generate-resources + + exec + + + bash + + ${cpp.dir}/build.sh + ${plugin.cpp.test} + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.8.0 + + 1.8 + 1.8 + + + + compile + + compile + + + + + + org.apache.maven.plugins + maven-assembly-plugin + 3.1.0 + + + jar-with-dependencies + + + + + make-assembly + package + + single + + + + + + + \ No newline at end of file diff --git a/omnioperator/omniop-native-reader/java/src/main/java/com/huawei/boostkit/scan/jni/NativeReaderLoader.java b/omnioperator/omniop-native-reader/java/src/main/java/com/huawei/boostkit/scan/jni/NativeReaderLoader.java new file mode 100644 index 0000000000000000000000000000000000000000..3d061452417508ff800c74aaa751428e1b118e90 --- /dev/null +++ b/omnioperator/omniop-native-reader/java/src/main/java/com/huawei/boostkit/scan/jni/NativeReaderLoader.java @@ -0,0 +1,77 @@ +/* + * Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
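The helpers declared in test_utils.h above wrap Arrow's filesystem calls for the tests. A hedged usage sketch follows; the LocalFileSystem setup, the temporary path, and the include paths are illustrative and not part of the patch.

#include <memory>
#include <string>
#include "arrow/filesystem/localfs.h"
#include "gtest/gtest.h"
#include "test_utils.h"

TEST(TestUtilsExample, CreateAndStatFile)
{
    auto fs = std::make_shared<arrow::fs::LocalFileSystem>();
    std::string path = "/tmp/native_reader_test_utils_example.txt";

    CreateFile(fs.get(), path, "hello");                          // write and close the file
    AssertFileInfo(fs.get(), path, arrow::fs::FileType::File);    // stat it back as a regular file
    ASSERT_OK(fs->DeleteFile(path));                              // clean up
}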
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.scan.jni; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; + +import nova.hetu.omniruntime.utils.NativeLog; + +/** + * @since 2021.08 + */ + +public class NativeReaderLoader { + + private static volatile NativeReaderLoader INSTANCE; + private static final String LIBRARY_NAME = "native_reader"; + private static final Logger LOG = LoggerFactory.getLogger(NativeReaderLoader.class); + private static final int BUFFER_SIZE = 1024; + + public static NativeReaderLoader getInstance() { + if (INSTANCE == null) { + synchronized (NativeReaderLoader.class) { + if (INSTANCE == null) { + INSTANCE = new NativeReaderLoader(); + } + } + } + return INSTANCE; + } + + private NativeReaderLoader() { + File tempFile = null; + try { + String nativeLibraryPath = File.separator + System.mapLibraryName(LIBRARY_NAME); + tempFile = File.createTempFile(LIBRARY_NAME, ".so"); + try (InputStream in = NativeReaderLoader.class.getResourceAsStream(nativeLibraryPath); + FileOutputStream fos = new FileOutputStream(tempFile)) { + int i; + byte[] buf = new byte[BUFFER_SIZE]; + while ((i = in.read(buf)) != -1) { + fos.write(buf, 0, i); + } + System.load(tempFile.getCanonicalPath()); + NativeLog.getInstance(); + } + } catch (IOException e) { + LOG.warn("fail to load library from Jar!errmsg:{}", e.getMessage()); + System.loadLibrary(LIBRARY_NAME); + } finally { + if (tempFile != null) { + tempFile.deleteOnExit(); + } + } + } +} diff --git a/omnioperator/omniop-native-reader/java/src/main/java/com/huawei/boostkit/scan/jni/OrcColumnarBatchJniReader.java b/omnioperator/omniop-native-reader/java/src/main/java/com/huawei/boostkit/scan/jni/OrcColumnarBatchJniReader.java new file mode 100644 index 0000000000000000000000000000000000000000..ca4e479f3c4280d08679a157ccce1a8f4ecb1948 --- /dev/null +++ b/omnioperator/omniop-native-reader/java/src/main/java/com/huawei/boostkit/scan/jni/OrcColumnarBatchJniReader.java @@ -0,0 +1,49 @@ +/* + * Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.scan.jni; +import org.json.JSONObject; +import java.util.ArrayList; + + +public class OrcColumnarBatchJniReader { + + public OrcColumnarBatchJniReader() { + NativeReaderLoader.getInstance(); + } + + public native long initializeReader(JSONObject job, ArrayList vecFildsNames); + + public native long initializeRecordReader(long reader, JSONObject job); + + public native long initializeBatch(long rowReader, long batchSize); + + public native long recordReaderNext(long rowReader, long batchReader, int[] typeId, long[] vecNativeId); + + public native long recordReaderGetRowNumber(long rowReader); + + public native float recordReaderGetProgress(long rowReader); + + public native void recordReaderClose(long rowReader, long reader, long batchReader); + + public native void recordReaderSeekToRow(long rowReader, long rowNumber); + + public native String[] getAllColumnNames(long reader); + + public native long getNumberOfRows(long rowReader, long batch); +} diff --git a/omnioperator/omniop-native-reader/java/src/main/java/com/huawei/boostkit/scan/jni/ParquetColumnarBatchJniReader.java b/omnioperator/omniop-native-reader/java/src/main/java/com/huawei/boostkit/scan/jni/ParquetColumnarBatchJniReader.java new file mode 100644 index 0000000000000000000000000000000000000000..b740b726ce9b3ea08d30ee516cbdb4d8c9ee7cdb --- /dev/null +++ b/omnioperator/omniop-native-reader/java/src/main/java/com/huawei/boostkit/scan/jni/ParquetColumnarBatchJniReader.java @@ -0,0 +1,34 @@ +/* + * Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.scan.jni; +import org.json.JSONObject; + +public class ParquetColumnarBatchJniReader { + + public ParquetColumnarBatchJniReader() { + NativeReaderLoader.getInstance(); + } + + public native long initializeReader(JSONObject job); + + public native long recordReaderNext(long parquetReader, long[] vecNativeId); + + public native void recordReaderClose(long parquetReader); + +} \ No newline at end of file diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/OmniLocalExecutionPlanner.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/OmniLocalExecutionPlanner.java index fd1240a20721eecc420c9eaf3e9b64d34c0382b1..42e7d6176646ed5e8cf53d8cbb14be0168575021 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/OmniLocalExecutionPlanner.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/OmniLocalExecutionPlanner.java @@ -1354,12 +1354,10 @@ public class OmniLocalExecutionPlanner } public JoinBridgeManager createLookupSourceFactory(JoinNode node, - PlanNode buildNode, List buildSymbols, Optional buildHashSymbol, - PhysicalOperation probeSource, LocalExecutionPlanContext context, boolean spillEnabled) + LocalExecutionPlanContext buildContext, PhysicalOperation buildSource, PlanNode buildNode, + List buildSymbols, Optional buildHashSymbol, PhysicalOperation probeSource, + LocalExecutionPlanContext context, boolean spillEnabled) { - LocalExecutionPlanContext buildContext = context.createSubContext(); - PhysicalOperation buildSource = buildNode.accept(this, buildContext); - if (buildSource.getPipelineExecutionStrategy() == GROUPED_EXECUTION) { checkState(probeSource.getPipelineExecutionStrategy() == GROUPED_EXECUTION, "Build execution is GROUPED_EXECUTION. 
Probe execution is expected be GROUPED_EXECUTION, but is UNGROUPED_EXECUTION."); @@ -1500,11 +1498,15 @@ public class OmniLocalExecutionPlanner // Plan build boolean spillEnabled = isSpillEnabled(session) && node.isSpillable().orElseThrow(() -> new IllegalArgumentException("spillable not yet set")); - JoinBridgeManager lookupSourceFactory = createLookupSourceFactory(node, + LocalExecutionPlanContext buildContext = context.createSubContext(); + PhysicalOperation buildSource = buildNode.accept(this, buildContext); + JoinBridgeManager lookupSourceFactory = createLookupSourceFactory(node, buildContext, buildSource, buildNode, buildSymbols, buildHashSymbol, probeSource, context, spillEnabled); - + Optional filterFunction = node.getFilter() + .map(filterExpression -> getTranslatedExpression(context, buildSource, probeSource, + filterExpression)); OperatorFactory operator = createLookupJoin(node, probeSource, probeSymbols, probeHashSymbol, - lookupSourceFactory, context, spillEnabled); + lookupSourceFactory, context, spillEnabled, filterFunction); ImmutableMap.Builder outputMappings = ImmutableMap.builder(); List outputSymbols = node.getOutputSymbols(); @@ -1518,7 +1520,7 @@ public class OmniLocalExecutionPlanner public OperatorFactory createLookupJoin(JoinNode node, PhysicalOperation probeSource, List probeSymbols, Optional probeHashSymbol, JoinBridgeManager lookupSourceFactoryManager, - LocalExecutionPlanContext context, boolean spillEnabled) + LocalExecutionPlanContext context, boolean spillEnabled, Optional filter) { List probeTypes = probeSource.getTypes(); List probeOutputSymbols = node.getOutputSymbols().stream() @@ -1537,7 +1539,7 @@ public class OmniLocalExecutionPlanner boolean buildOuter = node.getType() == RIGHT || node.getType() == FULL; if (!buildOuter) { return createOmniLookupJoin(node, lookupSourceFactoryManager, context, probeTypes, probeOutputChannels, - probeJoinChannels, probeHashChannel, totalOperatorsCount); + probeJoinChannels, probeHashChannel, totalOperatorsCount, filter); } return getLookUpJoinOperatorFactory(node, lookupSourceFactoryManager, context, probeTypes, probeOutputChannels, probeJoinChannels, probeHashChannel, totalOperatorsCount); @@ -1559,19 +1561,19 @@ public class OmniLocalExecutionPlanner public OperatorFactory createOmniLookupJoin(JoinNode node, JoinBridgeManager lookupSourceFactoryManager, LocalExecutionPlanContext context, List probeTypes, List probeOutputChannels, - List probeJoinChannels, OptionalInt probeHashChannel, OptionalInt totalOperatorsCount) + List probeJoinChannels, OptionalInt probeHashChannel, OptionalInt totalOperatorsCount, Optional filter) { List driverFactories = context.getDriverFactories(); DriverFactory driverFactory = driverFactories.get(driverFactories.size() - 1); List operatorFactories = driverFactory.getOperatorFactories(); OperatorFactory buildOperatorFactory = operatorFactories.get(operatorFactories.size() - 1); - + System.out.println(node.getType()); switch (node.getType()) { case INNER: return LookupJoinOmniOperators.innerJoin(context.getNextOperatorId(), node.getId(), lookupSourceFactoryManager, probeTypes, probeJoinChannels, probeHashChannel, Optional.of(probeOutputChannels), totalOperatorsCount, - (HashBuilderOmniOperatorFactory) buildOperatorFactory); + (HashBuilderOmniOperatorFactory) buildOperatorFactory, filter); case LEFT: return LookupJoinOmniOperators.probeOuterJoin(context.getNextOperatorId(), node.getId(), lookupSourceFactoryManager, probeTypes, probeJoinChannels, probeHashChannel, diff --git 
a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/Int128ArrayOmniBlock.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/Int128ArrayOmniBlock.java index b50ab1d0333d024e4507cbcd27cb1e8d2569b65e..7f57dce84d41728ca5fd6610cc672d67f1f27047 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/Int128ArrayOmniBlock.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/block/Int128ArrayOmniBlock.java @@ -119,6 +119,18 @@ public class Int128ArrayOmniBlock long[] values) { this.vecAllocator = vecAllocator; + for (int i = positionOffset; i < positionCount; i++) { + int first = 2 * i; + int second = first + 1; + if (values[second] < 0) { + values[first] = ~values[first] + 1; + values[second] = values[second] ^ 0x7FFFFFFFFFFFFFFFL; + if (values[first] == 0) { + values[second] = values[second] + 1; + } + } + } + if (positionOffset < 0) { throw new IllegalArgumentException("positionOffset is negative"); } diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/HashBuilderOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/HashBuilderOmniOperator.java index 3b9efaf18a244f13ed32e94aeea0f9096ce0e4cc..c193da450b0beda935a01c339058f55eebfe333f 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/HashBuilderOmniOperator.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/HashBuilderOmniOperator.java @@ -124,7 +124,7 @@ public class HashBuilderOmniOperator DataType[] omniBuildTypes = OperatorUtils.toDataTypes(buildTypes); String[] omniSearchFunctions = searchFunctions.stream().toArray(String[]::new); this.omniHashBuilderOperatorFactory = new OmniHashBuilderOperatorFactory(omniBuildTypes, - Ints.toArray(hashChannels), filterFunction, sortChannel, omniSearchFunctions, operatorCount); + Ints.toArray(hashChannels), operatorCount); } @Override diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/LookupJoinOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/LookupJoinOmniOperator.java index 0bed1c1b582989adf87921abfdcfd67692d52fc6..ad932eea879fb6e31da3df6203fe95863184081a 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/LookupJoinOmniOperator.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/LookupJoinOmniOperator.java @@ -332,7 +332,7 @@ public class LookupJoinOmniOperator JoinBridgeManager lookupSourceFactoryManager, List probeTypes, List probeOutputChannels, List probeOutputChannelTypes, JoinType joinType, OptionalInt totalOperatorsCount, List probeJoinChannel, OptionalInt probeHashChannel, - HashBuilderOmniOperator.HashBuilderOmniOperatorFactory hashBuilderOmniOperatorFactory) + HashBuilderOmniOperator.HashBuilderOmniOperatorFactory hashBuilderOmniOperatorFactory, Optional filter) { this.operatorId = operatorId; this.planNodeId = requireNonNull(planNodeId, "planNodeId is null"); @@ -364,7 +364,7 @@ public class LookupJoinOmniOperator this.omniLookupJoinOperatorFactory = new OmniLookupJoinOperatorFactory(types, Ints.toArray(probeOutputChannels), Ints.toArray(probeJoinChannel), Ints.toArray(buildOutputChannels), buildOutputDataTypes, getOmniJoinType(joinType), - hashBuilderOmniOperatorFactory.getOmniHashBuilderOperatorFactory()); + 
hashBuilderOmniOperatorFactory.getOmniHashBuilderOperatorFactory(), filter); } private nova.hetu.omniruntime.constants.JoinType getOmniJoinType(JoinType joinType) diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/LookupJoinOmniOperators.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/LookupJoinOmniOperators.java index ce39ada9f449d4c6f1928db04b3c5d98ebb14f5a..358d7c4ad4aa4234b32fb0b9ceabc3cde73a97fa 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/LookupJoinOmniOperators.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/LookupJoinOmniOperators.java @@ -73,11 +73,11 @@ public class LookupJoinOmniOperators JoinBridgeManager lookupSourceFactory, List probeTypes, List probeJoinChannel, OptionalInt probeHashChannel, Optional> probeOutputChannels, OptionalInt totalOperatorsCount, - HashBuilderOmniOperator.HashBuilderOmniOperatorFactory hashBuilderOmniOperatorFactory) + HashBuilderOmniOperator.HashBuilderOmniOperatorFactory hashBuilderOmniOperatorFactory, Optional filter) { return createOmniJoinOperatorFactory(operatorId, planNodeId, lookupSourceFactory, probeTypes, probeJoinChannel, probeHashChannel, probeOutputChannels.orElse(rangeList(probeTypes.size())), - LookupJoinOperators.JoinType.INNER, totalOperatorsCount, hashBuilderOmniOperatorFactory); + LookupJoinOperators.JoinType.INNER, totalOperatorsCount, hashBuilderOmniOperatorFactory, filter); } /** @@ -102,7 +102,7 @@ public class LookupJoinOmniOperators { return createOmniJoinOperatorFactory(operatorId, planNodeId, lookupSourceFactory, probeTypes, probeJoinChannel, probeHashChannel, probeOutputChannels.orElse(rangeList(probeTypes.size())), - LookupJoinOperators.JoinType.PROBE_OUTER, totalOperatorsCount, hashBuilderOmniOperatorFactory); + LookupJoinOperators.JoinType.PROBE_OUTER, totalOperatorsCount, hashBuilderOmniOperatorFactory, Optional.empty()); } /** @@ -127,7 +127,7 @@ public class LookupJoinOmniOperators { return createOmniJoinOperatorFactory(operatorId, planNodeId, lookupSourceFactory, probeTypes, probeJoinChannel, probeHashChannel, probeOutputChannels.orElse(rangeList(probeTypes.size())), - LookupJoinOperators.JoinType.LOOKUP_OUTER, totalOperatorsCount, hashBuilderOmniOperatorFactory); + LookupJoinOperators.JoinType.LOOKUP_OUTER, totalOperatorsCount, hashBuilderOmniOperatorFactory, Optional.empty()); } /** @@ -152,7 +152,7 @@ public class LookupJoinOmniOperators { return createOmniJoinOperatorFactory(operatorId, planNodeId, lookupSourceFactory, probeTypes, probeJoinChannel, probeHashChannel, probeOutputChannels.orElse(rangeList(probeTypes.size())), - LookupJoinOperators.JoinType.FULL_OUTER, totalOperatorsCount, hashBuilderOmniOperatorFactory); + LookupJoinOperators.JoinType.FULL_OUTER, totalOperatorsCount, hashBuilderOmniOperatorFactory, Optional.empty()); } private static List rangeList(int endExclusive) @@ -164,13 +164,13 @@ public class LookupJoinOmniOperators JoinBridgeManager lookupSourceFactoryManager, List probeTypes, List probeJoinChannel, OptionalInt probeHashChannel, List probeOutputChannels, LookupJoinOperators.JoinType joinType, OptionalInt totalOperatorsCount, - HashBuilderOmniOperator.HashBuilderOmniOperatorFactory hashBuilderOmniOperatorFactory) + HashBuilderOmniOperator.HashBuilderOmniOperatorFactory hashBuilderOmniOperatorFactory, Optional filter) { List probeOutputChannelTypes = probeOutputChannels.stream().map(probeTypes::get) 
.collect(toImmutableList()); return new LookupJoinOmniOperator.LookupJoinOmniOperatorFactory(operatorId, planNodeId, lookupSourceFactoryManager, probeTypes, probeOutputChannels, probeOutputChannelTypes, joinType, - totalOperatorsCount, probeJoinChannel, probeHashChannel, hashBuilderOmniOperatorFactory); + totalOperatorsCount, probeJoinChannel, probeHashChannel, hashBuilderOmniOperatorFactory, filter); } } diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/tool/OperatorUtils.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/tool/OperatorUtils.java index 59b556c781f79d4ad07e69eb4c2bbba07162e791..a5d54431f124bd71f34e43c839013a7e47cdc1bf 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/tool/OperatorUtils.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/tool/OperatorUtils.java @@ -930,8 +930,19 @@ public final class OperatorUtils private static Block buildInt128ArrayBlock(Block block, int positionCount) { Decimal128Vec decimal128Vec = (Decimal128Vec) block.getValues(); - return new Int128ArrayBlock(positionCount, Optional.of(decimal128Vec.getValuesNulls(0, positionCount)), - decimal128Vec.get(0, positionCount)); + long[] values = decimal128Vec.get(0, positionCount); + for (int i = 0; i < positionCount; i++) { + int first = 2 * i; + int second = first + 1; + if (values[second] < 0) { + values[first] = ~values[first] + 1; + values[second] = values[second] ^ 0x7FFFFFFFFFFFFFFFL; + if (values[first] == 0) { + values[second] = values[second] + 1; + } + } + } + return new Int128ArrayBlock(positionCount, Optional.of(decimal128Vec.getValuesNulls(0, positionCount)), values); } private static Block buildDoubleArrayBLock(Block block, int positionCount) diff --git a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/LookupJoinOmniOperatorTest.java b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/LookupJoinOmniOperatorTest.java index 747dab1c0b4bfc863df20de7e889e456fef0d377..9c2c607ab3463e98dad52ed5f319317be5e9d57b 100644 --- a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/LookupJoinOmniOperatorTest.java +++ b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/LookupJoinOmniOperatorTest.java @@ -128,7 +128,7 @@ public class LookupJoinOmniOperatorTest case INNER: operatorFactory = innerJoin(operatorId, planNodeId, lookupSourceFactoryManager, probeTypes, probeJoinChannels, empty, Optional.of(probeOutputChannels), totalOperatorsCount, - hashBuilderOmniOperatorFactory); + hashBuilderOmniOperatorFactory, Optional.empty()); break; case PROBE_OUTER: operatorFactory = probeOuterJoin(operatorId, planNodeId, lookupSourceFactoryManager, probeTypes, diff --git a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/benchmark/BenchmarkHashJoinOmniOperators.java b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/benchmark/BenchmarkHashJoinOmniOperators.java index f29c85116e570f4506ce09594fc4956a6f728f1b..bb0aba2b6363b41df864b4912532740077d743e6 100644 --- a/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/benchmark/BenchmarkHashJoinOmniOperators.java +++ b/omnioperator/omniop-openlookeng-extension/src/test/java/nova/hetu/olk/operator/benchmark/BenchmarkHashJoinOmniOperators.java @@ -205,7 +205,7 @@ public class BenchmarkHashJoinOmniOperators HashBuilderOmniOperatorFactory 
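The identical word-level loop added in Int128ArrayOmniBlock and in OperatorUtils.buildInt128ArrayBlock above rewrites a negative 128-bit decimal, stored as a [low, high] pair of longs, between a two's-complement layout and a sign-bit-plus-magnitude layout (the exact vector encodings on each side are an assumption here; for in-range values the same transform also maps back). A standalone C++ mirror of the Java loop, with a small worked example, is sketched below; all names are illustrative.

#include <cassert>
#include <cstdint>

// Mirror of the Java loop: words[0] is the low 64 bits, words[1] the high 64 bits.
void ConvertNegativeInt128(int64_t words[2])
{
    if (words[1] < 0) {                                   // only negative values are rewritten
        words[0] = ~words[0] + 1;                         // negate the low word
        words[1] = words[1] ^ 0x7FFFFFFFFFFFFFFFLL;       // flip every high bit except the sign bit
        if (words[0] == 0) {                              // carry out of the low word
            words[1] = words[1] + 1;
        }
    }
}

int main()
{
    int64_t v[2] = {-5, -1};                              // -5 in 128-bit two's complement
    ConvertNegativeInt128(v);
    assert(v[0] == 5);                                    // magnitude lands in the low word
    assert(static_cast<uint64_t>(v[1]) == 0x8000000000000000ULL);  // high word keeps only the sign bit
    return 0;
}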
hashBuilderOperatorFactory = createBuildOperatorFactory(); LookupJoinOmniOperators.innerJoin(HASH_JOIN_OPERATOR_ID, TEST_PLAN_NODE_ID, lookupSourceFactoryManager, getBuildTypes(), buildJoinChannels, buildHashChannel, - Optional.of(buildOutputChannels), OptionalInt.of(1), hashBuilderOperatorFactory); + Optional.of(buildOutputChannels), OptionalInt.of(1), hashBuilderOperatorFactory, Optional.empty()); return hashBuilderOperatorFactory; } @@ -560,7 +560,7 @@ public class BenchmarkHashJoinOmniOperators OperatorFactory operatorFactory = LookupJoinOmniOperators.innerJoin(HASH_JOIN_OPERATOR_ID, TEST_PLAN_NODE_ID, lookupSourceFactoryManager, getProbeTypes(), probeJoinChannels, probeHashChannel, - Optional.of(probeOutputChannels), OptionalInt.of(1), hashBuilderOperatorFactory); + Optional.of(probeOutputChannels), OptionalInt.of(1), hashBuilderOperatorFactory, Optional.empty()); buildDriverContext = super.createTaskContext().addPipelineContext(0, true, true, false) .addDriverContext(); buildOperator = hashBuilderOperatorFactory.createOperator(buildDriverContext); diff --git a/omnioperator/omniop-spark-extension-ock/cpp/CMakeLists.txt b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/CMakeLists.txt similarity index 97% rename from omnioperator/omniop-spark-extension-ock/cpp/CMakeLists.txt rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/CMakeLists.txt index 92d57e99819f7b21e42a01d242b044e8f667fe12..86d401d8384bb36b65aa75b6c10dde7abb74f8ba 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/CMakeLists.txt +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/CMakeLists.txt @@ -7,7 +7,7 @@ set(CMAKE_VERBOSE_MAKEFILE ON) cmake_minimum_required(VERSION 3.10) # configure cmake -set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD 17) set(root_directory ${PROJECT_BINARY_DIR}) diff --git a/omnioperator/omniop-spark-extension-ock/cpp/build.sh b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/build.sh similarity index 100% rename from omnioperator/omniop-spark-extension-ock/cpp/build.sh rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/build.sh diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/CMakeLists.txt b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/CMakeLists.txt similarity index 96% rename from omnioperator/omniop-spark-extension-ock/cpp/src/CMakeLists.txt rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/CMakeLists.txt index 4e3c3e2160cc8415575ed5c6745f29ac60fc298b..27a927fdb7c0fceae786683ccced1f396af59d1a 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/src/CMakeLists.txt +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/CMakeLists.txt @@ -38,8 +38,7 @@ target_include_directories(${PROJ_TARGET} PUBLIC /opt/lib/include) target_link_libraries (${PROJ_TARGET} PUBLIC protobuf.a z - boostkit-omniop-runtime-1.1.0-aarch64 - boostkit-omniop-vector-1.1.0-aarch64 + boostkit-omniop-vector-1.3.0-aarch64 ock_shuffle gcov ) diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/common/common.h b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/common/common.h similarity index 100% rename from omnioperator/omniop-spark-extension-ock/cpp/src/common/common.h rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/common/common.h diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/common/debug.h b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/common/debug.h 
similarity index 97% rename from omnioperator/omniop-spark-extension-ock/cpp/src/common/debug.h rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/common/debug.h index 65b69d4647ef8cd4424996e37b76578cc93ca5c0..ad3498061306d819aee05d53034a9f0f98d3987b 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/src/common/debug.h +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/common/debug.h @@ -1,44 +1,44 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. - */ - -#ifndef DEBUG_H -#define DEBUG_H - -#include -#include - -#ifdef TRACE_RUNTIME -#define LOG_TRACE(format, ...) \ - do { \ - printf("[TRACE][%s][%s][%d]:" format "\n", __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__); \ - } while (0) -#else -#define LOG_TRACE(format, ...) -#endif - -#if defined(DEBUG_RUNTIME) || defined(TRACE_RUNTIME) -#define LOG_DEBUG(format, ...) \ - do { \ - printf("[DEBUG][%s][%s][%d]:" format "\n", __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__); \ - } while (0) -#else -#define LOG_DEBUG(format, ...) -#endif - -#define LOG_INFO(format, ...) \ - do { \ - printf("[INFO][%s][%s][%d]:" format "\n", __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__); \ - } while (0) - -#define LOG_WARN(format, ...) \ - do { \ - printf("[WARN][%s][%s][%d]:" format "\n", __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__); \ - } while (0) - -#define LOG_ERROR(format, ...) \ - do { \ - printf("[ERROR][%s][%s][%d]:" format "\n", __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__); \ - } while (0) - +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + */ + +#ifndef DEBUG_H +#define DEBUG_H + +#include +#include + +#ifdef TRACE_RUNTIME +#define LOG_TRACE(format, ...) \ + do { \ + printf("[TRACE][%s][%s][%d]:" format "\n", __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__); \ + } while (0) +#else +#define LOG_TRACE(format, ...) +#endif + +#if defined(DEBUG_RUNTIME) || defined(TRACE_RUNTIME) +#define LOG_DEBUG(format, ...) \ + do { \ + printf("[DEBUG][%s][%s][%d]:" format "\n", __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__); \ + } while (0) +#else +#define LOG_DEBUG(format, ...) +#endif + +#define LOG_INFO(format, ...) \ + do { \ + printf("[INFO][%s][%s][%d]:" format "\n", __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__); \ + } while (0) + +#define LOG_WARN(format, ...) \ + do { \ + printf("[WARN][%s][%s][%d]:" format "\n", __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__); \ + } while (0) + +#define LOG_ERROR(format, ...) 
\ + do { \ + printf("[ERROR][%s][%s][%d]:" format "\n", __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__); \ + } while (0) + #endif // DEBUG_H \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/jni/OckShuffleJniReader.cpp b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/jni/OckShuffleJniReader.cpp similarity index 75% rename from omnioperator/omniop-spark-extension-ock/cpp/src/jni/OckShuffleJniReader.cpp rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/jni/OckShuffleJniReader.cpp index 456519e9a8ee7edac294289f84273244f50c9d62..21e482c8d2f2b3457d1167c83c3aaa7e7fc09da1 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/src/jni/OckShuffleJniReader.cpp +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/jni/OckShuffleJniReader.cpp @@ -10,6 +10,7 @@ #include "OckShuffleJniReader.h" using namespace omniruntime::vec; +using namespace omniruntime::type; using namespace ock::dopspark; static std::mutex gInitLock; @@ -20,11 +21,16 @@ static const char *exceptionClass = "java/lang/Exception"; static void JniInitialize(JNIEnv *env) { + if (UNLIKELY(env ==nullptr)) { + LOG_ERROR("JNIEnv is null."); + return; + } std::lock_guard lk(gInitLock); if (UNLIKELY(gLongClass == nullptr)) { gLongClass = env->FindClass("java/lang/Long"); if (UNLIKELY(gLongClass == nullptr)) { env->ThrowNew(env->FindClass(exceptionClass), "Failed to find class java/lang/Long"); + return; } gLongValueFieldId = env->GetFieldID(gLongClass, "value", "J"); @@ -38,24 +44,53 @@ static void JniInitialize(JNIEnv *env) JNIEXPORT jlong JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniReader_make(JNIEnv *env, jobject, jintArray jTypeIds) { + if (UNLIKELY(env == nullptr)) { + LOG_ERROR("JNIEnv is null."); + return 0; + } + if (UNLIKELY(jTypeIds == nullptr)) { + env->ThrowNew(env->FindClass(exceptionClass), "jTypeIds is null."); + return 0; + } std::shared_ptr instance = std::make_shared(); if (UNLIKELY(instance == nullptr)) { env->ThrowNew(env->FindClass(exceptionClass), "Failed to create instance for ock merge reader"); return 0; } - bool result = instance->Initialize(env->GetIntArrayElements(jTypeIds, nullptr), env->GetArrayLength(jTypeIds)); + auto typeIds = env->GetIntArrayElements(jTypeIds, nullptr); + if (UNLIKELY(typeIds == nullptr)) { + env->ThrowNew(env->FindClass(exceptionClass), "Failed to get int array elements."); + return 0; + } + bool result = instance->Initialize(typeIds, env->GetArrayLength(jTypeIds)); if (UNLIKELY(!result)) { + env->ReleaseIntArrayElements(jTypeIds, typeIds, JNI_ABORT); env->ThrowNew(env->FindClass(exceptionClass), "Failed to initialize ock merge reader"); return 0; } - + env->ReleaseIntArrayElements(jTypeIds, typeIds, JNI_ABORT); return gBlobReader.Insert(instance); } +JNIEXPORT void JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniReader_close(JNIEnv *env, jobject, jlong jReaderId) +{ + if (UNLIKELY(env == nullptr)) { + LOG_ERROR("JNIENV is null."); + return; + } + + gBlobReader.Erase(jReaderId); +} + JNIEXPORT jint JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniReader_nativeGetVectorBatch(JNIEnv *env, jobject, jlong jReaderId, jlong jAddress, jint jRemain, jint jMaxRow, jint jMaxSize, jobject jRowCnt) { + if (UNLIKELY(env == nullptr)) { + LOG_ERROR("JNIEnv is null."); + return -1; + } + auto mergeReader = gBlobReader.Lookup(jReaderId); if (UNLIKELY(!mergeReader)) { std::string errMsg = "Invalid reader id " + std::to_string(jReaderId); @@ -80,6 +115,10 @@ JNIEXPORT jint JNICALL 
Java_com_huawei_ock_spark_jni_OckShuffleJniReader_nativeG JNIEXPORT jint JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniReader_nativeGetVecValueLength(JNIEnv *env, jobject, jlong jReaderId, jint jColIndex) { + if (UNLIKELY(env == nullptr)) { + LOG_ERROR("JNIEnv is null."); + return 0; + } auto mergeReader = gBlobReader.Lookup(jReaderId); if (UNLIKELY(!mergeReader)) { std::string errMsg = "Invalid reader id " + std::to_string(jReaderId); @@ -100,7 +139,12 @@ JNIEXPORT jint JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniReader_nativeG JNIEXPORT void JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniReader_nativeCopyVecDataInVB(JNIEnv *env, jobject, jlong jReaderId, jlong dstNativeVec, jint jColIndex) { - auto dstVector = reinterpret_cast(dstNativeVec); // get from scala which is real vector + if (UNLIKELY(env == nullptr)) { + LOG_ERROR("JNIEnv is null."); + return; + } + + auto dstVector = reinterpret_cast(dstNativeVec); // get from scala which is real vector if (UNLIKELY(dstVector == nullptr)) { std::string errMsg = "Invalid dst vector address for reader id " + std::to_string(jReaderId); env->ThrowNew(env->FindClass(exceptionClass), errMsg.c_str()); diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/jni/OckShuffleJniReader.h b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/jni/OckShuffleJniReader.h similarity index 86% rename from omnioperator/omniop-spark-extension-ock/cpp/src/jni/OckShuffleJniReader.h rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/jni/OckShuffleJniReader.h index 80a63c403ef8ce43ee5be522ab6bfd5fea6c9b37..eb8a692a7dcde68fafed60820da44870c3fc3a3e 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/src/jni/OckShuffleJniReader.h +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/jni/OckShuffleJniReader.h @@ -18,6 +18,12 @@ extern "C" { */ JNIEXPORT jlong JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniReader_make(JNIEnv *, jobject, jintArray); +/* + * Class: com_huawei_ock_spark_jni_OckShuffleJniReader + * Method: close + * Signature: (JI)I + */ +JNIEXPORT void JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniReader_close(JNIEnv *, jobject, jlong); /* * Class: com_huawei_ock_spark_jni_OckShuffleJniReader * Method: nativeGetVectorBatch diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/jni/OckShuffleJniWriter.cpp b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/jni/OckShuffleJniWriter.cpp similarity index 83% rename from omnioperator/omniop-spark-extension-ock/cpp/src/jni/OckShuffleJniWriter.cpp rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/jni/OckShuffleJniWriter.cpp index 61633605eb8afbf26abeeea595fcfc48742f3498..346f1c5e4d8e7cbd70adca89042e903d25fe591f 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/src/jni/OckShuffleJniWriter.cpp +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/jni/OckShuffleJniWriter.cpp @@ -20,11 +20,15 @@ static const char *exceptionClass = "java/lang/Exception"; JNIEXPORT jboolean JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniWriter_initialize(JNIEnv *env, jobject) { + if (UNLIKELY(env == nullptr)) { + LOG_ERROR("JNIEnv is null."); + return JNI_FALSE; + } gSplitResultClass = CreateGlobalClassReference(env, "Lcom/huawei/boostkit/spark/vectorized/SplitResult;"); gSplitResultConstructor = GetMethodID(env, gSplitResultClass, "", "(JJJJJ[J)V"); if (UNLIKELY(!OckShuffleSdk::Initialize())) { - std::cout << "Failed to load ock shuffle library." 
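The JNI changes above follow one pattern throughout: validate env and the incoming arguments, pair every GetIntArrayElements with a ReleaseIntArrayElements on all exit paths, and return immediately after ThrowNew, since ThrowNew only marks the exception pending while the native code keeps running until it returns. A condensed sketch of that pattern, with illustrative names only:

#include <jni.h>

static jlong MakeReaderFromTypeIds(JNIEnv *env, jintArray jTypeIds)
{
    if (env == nullptr || jTypeIds == nullptr) {
        return 0;                                        // nothing usable to work with
    }
    jint *typeIds = env->GetIntArrayElements(jTypeIds, nullptr);
    if (typeIds == nullptr) {
        env->ThrowNew(env->FindClass("java/lang/Exception"), "Failed to get int array elements.");
        return 0;                                        // exception is pending; stop immediately
    }
    jlong handle = 0;
    // ... build the native reader from typeIds and env->GetArrayLength(jTypeIds) ...
    env->ReleaseIntArrayElements(jTypeIds, typeIds, JNI_ABORT);  // JNI_ABORT: no copy-back needed
    return handle;
}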
<< std::endl; + env->ThrowNew(env->FindClass(exceptionClass), std::string("Failed to load ock shuffle library.").c_str()); return JNI_FALSE; } @@ -36,9 +40,14 @@ JNIEXPORT jlong JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniWriter_native jstring jPartitioningMethod, jint jPartitionNum, jstring jColTypes, jint jColNum, jint jRegionSize, jint jMinCapacity, jint jMaxCapacity, jboolean jIsCompress) { + if (UNLIKELY(env == nullptr)) { + LOG_ERROR("JNIEnv is null."); + return 0; + } auto appIdStr = env->GetStringUTFChars(jAppId, JNI_FALSE); if (UNLIKELY(appIdStr == nullptr)) { env->ThrowNew(env->FindClass(exceptionClass), std::string("ApplicationId can't be empty").c_str()); + return 0; } auto appId = std::string(appIdStr); env->ReleaseStringUTFChars(jAppId, appIdStr); @@ -46,6 +55,7 @@ JNIEXPORT jlong JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniWriter_native auto partitioningMethodStr = env->GetStringUTFChars(jPartitioningMethod, JNI_FALSE); if (UNLIKELY(partitioningMethodStr == nullptr)) { env->ThrowNew(env->FindClass(exceptionClass), std::string("Partitioning method can't be empty").c_str()); + return 0; } auto partitionMethod = std::string(partitioningMethodStr); env->ReleaseStringUTFChars(jPartitioningMethod, partitioningMethodStr); @@ -53,6 +63,7 @@ JNIEXPORT jlong JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniWriter_native auto colTypesStr = env->GetStringUTFChars(jColTypes, JNI_FALSE); if (UNLIKELY(colTypesStr == nullptr)) { env->ThrowNew(env->FindClass(exceptionClass), std::string("Columns types can't be empty").c_str()); + return 0; } DataTypes colTypes = Deserialize(colTypesStr); @@ -63,7 +74,8 @@ JNIEXPORT jlong JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniWriter_native jmethodID jMethodId = env->GetStaticMethodID(jThreadCls, "currentThread", "()Ljava/lang/Thread;"); jobject jThread = env->CallStaticObjectMethod(jThreadCls, jMethodId); if (UNLIKELY(jThread == nullptr)) { - std::cout << "Failed to get current thread instance." 
<< std::endl; + env->ThrowNew(env->FindClass(exceptionClass), std::string("Failed to get current thread instance.").c_str()); + return 0; } else { jThreadId = env->CallLongMethod(jThread, env->GetMethodID(jThreadCls, "getId", "()J")); } @@ -71,16 +83,19 @@ JNIEXPORT jlong JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniWriter_native auto splitter = OckSplitter::Make(partitionMethod, jPartitionNum, colTypes.GetIds(), jColNum, (uint64_t)jThreadId); if (UNLIKELY(splitter == nullptr)) { env->ThrowNew(env->FindClass(exceptionClass), std::string("Failed to make ock splitter").c_str()); + return 0; } bool ret = splitter->SetShuffleInfo(appId, jShuffleId, jStageId, jStageAttemptNum, jMapId, jTaskAttemptId); if (UNLIKELY(!ret)) { env->ThrowNew(env->FindClass(exceptionClass), std::string("Failed to set shuffle information").c_str()); + return 0; } ret = splitter->InitLocalBuffer(jRegionSize, jMinCapacity, jMaxCapacity, (jIsCompress == JNI_TRUE)); if (UNLIKELY(!ret)) { env->ThrowNew(env->FindClass(exceptionClass), std::string("Failed to initialize local buffer").c_str()); + return 0; } return gOckSplitterMap.Insert(std::shared_ptr(splitter)); @@ -89,21 +104,28 @@ JNIEXPORT jlong JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniWriter_native JNIEXPORT void JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniWriter_split(JNIEnv *env, jobject, jlong splitterId, jlong nativeVectorBatch) { + if (UNLIKELY(env == nullptr)) { + LOG_ERROR("JNIEnv is null."); + return; + } auto splitter = gOckSplitterMap.Lookup(splitterId); if (UNLIKELY(!splitter)) { std::string errMsg = "Invalid splitter id " + std::to_string(splitterId); env->ThrowNew(env->FindClass(exceptionClass), errMsg.c_str()); + return; } auto vecBatch = (VectorBatch *)nativeVectorBatch; if (UNLIKELY(vecBatch == nullptr)) { std::string errMsg = "Invalid address for native vector batch."; env->ThrowNew(env->FindClass(exceptionClass), errMsg.c_str()); + return; } if (UNLIKELY(!splitter->Split(*vecBatch))) { std::string errMsg = "Failed to split vector batch by splitter id " + std::to_string(splitterId); env->ThrowNew(env->FindClass(exceptionClass), errMsg.c_str()); + return; } delete vecBatch; @@ -112,13 +134,22 @@ JNIEXPORT void JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniWriter_split(J JNIEXPORT jobject JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniWriter_stop(JNIEnv *env, jobject, jlong splitterId) { + if (UNLIKELY(env == nullptr)) { + LOG_ERROR("JNIEnv is null."); + return nullptr; + } auto splitter = gOckSplitterMap.Lookup(splitterId); if (UNLIKELY(!splitter)) { std::string error_message = "Invalid splitter id " + std::to_string(splitterId); env->ThrowNew(env->FindClass(exceptionClass), error_message.c_str()); + return nullptr; } - splitter->Stop(); // free resource + if (!splitter->Stop()) { + std::string errMsg = "Failed to Stop by splitter id " + std::to_string(splitterId); + env->ThrowNew(env->FindClass(exceptionClass), errMsg.c_str()); + return nullptr; + } const auto &partitionLengths = splitter->PartitionLengths(); auto jPartitionLengths = env->NewLongArray(partitionLengths.size()); @@ -132,10 +163,15 @@ JNIEXPORT jobject JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniWriter_stop JNIEXPORT void JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniWriter_close(JNIEnv *env, jobject, jlong splitterId) { + if (UNLIKELY(env == nullptr)) { + LOG_ERROR("JNIEnv is null."); + return; + } auto splitter = gOckSplitterMap.Lookup(splitterId); if (UNLIKELY(!splitter)) { std::string errMsg = "Invalid splitter id " + 
std::to_string(splitterId); env->ThrowNew(env->FindClass(exceptionClass), errMsg.c_str()); + return; } gOckSplitterMap.Erase(splitterId); diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/jni/OckShuffleJniWriter.h b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/jni/OckShuffleJniWriter.h similarity index 100% rename from omnioperator/omniop-spark-extension-ock/cpp/src/jni/OckShuffleJniWriter.h rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/jni/OckShuffleJniWriter.h diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/jni/concurrent_map.h b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/jni/concurrent_map.h similarity index 100% rename from omnioperator/omniop-spark-extension-ock/cpp/src/jni/concurrent_map.h rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/jni/concurrent_map.h diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/jni/jni_common.h b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/jni/jni_common.h similarity index 100% rename from omnioperator/omniop-spark-extension-ock/cpp/src/jni/jni_common.h rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/jni/jni_common.h diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/proto/vec_data.proto b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/proto/vec_data.proto similarity index 95% rename from omnioperator/omniop-spark-extension-ock/cpp/src/proto/vec_data.proto rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/proto/vec_data.proto index 785ac441ab3dfaa6c99abb0d310cd85850e5615d..c40472020171692ea7b0acde2dd873efeda691f4 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/src/proto/vec_data.proto +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/proto/vec_data.proto @@ -1,60 +1,60 @@ -syntax = "proto3"; - -package spark; -option java_package = "com.huawei.boostkit.spark.serialize"; -option java_outer_classname = "VecData"; - -message VecBatch { - int32 rowCnt = 1; - int32 vecCnt = 2; - repeated Vec vecs = 3; -} - -message Vec { - VecType vecType = 1; - bytes offset = 2; - bytes values = 3; - bytes nulls = 4; -} - -message VecType { - enum VecTypeId { - VEC_TYPE_NONE = 0; - VEC_TYPE_INT = 1; - VEC_TYPE_LONG = 2; - VEC_TYPE_DOUBLE = 3; - VEC_TYPE_BOOLEAN = 4; - VEC_TYPE_SHORT = 5; - VEC_TYPE_DECIMAL64 = 6; - VEC_TYPE_DECIMAL128 = 7; - VEC_TYPE_DATE32 = 8; - VEC_TYPE_DATE64 = 9; - VEC_TYPE_TIME32 = 10; - VEC_TYPE_TIME64 = 11; - VEC_TYPE_TIMESTAMP = 12; - VEC_TYPE_INTERVAL_MONTHS = 13; - VEC_TYPE_INTERVAL_DAY_TIME =14; - VEC_TYPE_VARCHAR = 15; - VEC_TYPE_CHAR = 16; - VEC_TYPE_DICTIONARY = 17; - VEC_TYPE_CONTAINER = 18; - VEC_TYPE_INVALID = 19; - } - - VecTypeId typeId = 1; - int32 width = 2; - uint32 precision = 3; - uint32 scale = 4; - enum DateUnit { - DAY = 0; - MILLI = 1; - } - DateUnit dateUnit = 5; - enum TimeUnit { - SEC = 0; - MILLISEC = 1; - MICROSEC = 2; - NANOSEC = 3; - } - TimeUnit timeUnit = 6; +syntax = "proto3"; + +package spark; +option java_package = "com.huawei.boostkit.spark.serialize"; +option java_outer_classname = "VecData"; + +message VecBatch { + int32 rowCnt = 1; + int32 vecCnt = 2; + repeated Vec vecs = 3; +} + +message Vec { + VecType vecType = 1; + bytes offset = 2; + bytes values = 3; + bytes nulls = 4; +} + +message VecType { + enum VecTypeId { + VEC_TYPE_NONE = 0; + VEC_TYPE_INT = 1; + VEC_TYPE_LONG = 2; + VEC_TYPE_DOUBLE = 3; + VEC_TYPE_BOOLEAN = 4; + 
VEC_TYPE_SHORT = 5; + VEC_TYPE_DECIMAL64 = 6; + VEC_TYPE_DECIMAL128 = 7; + VEC_TYPE_DATE32 = 8; + VEC_TYPE_DATE64 = 9; + VEC_TYPE_TIME32 = 10; + VEC_TYPE_TIME64 = 11; + VEC_TYPE_TIMESTAMP = 12; + VEC_TYPE_INTERVAL_MONTHS = 13; + VEC_TYPE_INTERVAL_DAY_TIME =14; + VEC_TYPE_VARCHAR = 15; + VEC_TYPE_CHAR = 16; + VEC_TYPE_DICTIONARY = 17; + VEC_TYPE_CONTAINER = 18; + VEC_TYPE_INVALID = 19; + } + + VecTypeId typeId = 1; + int32 width = 2; + uint32 precision = 3; + uint32 scale = 4; + enum DateUnit { + DAY = 0; + MILLI = 1; + } + DateUnit dateUnit = 5; + enum TimeUnit { + SEC = 0; + MILLISEC = 1; + MICROSEC = 2; + NANOSEC = 3; + } + TimeUnit timeUnit = 6; } \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/sdk/ock_shuffle_sdk.h b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/sdk/ock_shuffle_sdk.h similarity index 100% rename from omnioperator/omniop-spark-extension-ock/cpp/src/sdk/ock_shuffle_sdk.h rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/sdk/ock_shuffle_sdk.h diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_hash_write_buffer.cpp b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_hash_write_buffer.cpp similarity index 81% rename from omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_hash_write_buffer.cpp rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_hash_write_buffer.cpp index b9c6ced10a6742812d257ee6cf95c84b9e5b3ad0..d0fe8198b4eb15f8796e2e70ce4480761180cf59 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_hash_write_buffer.cpp +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_hash_write_buffer.cpp @@ -23,9 +23,21 @@ bool OckHashWriteBuffer::Initialize(uint32_t regionSize, uint32_t minCapacity, u mIsCompress = isCompress; uint32_t bufferNeed = regionSize * mPartitionNum; mDataCapacity = std::min(std::max(bufferNeed, minCapacity), maxCapacity); + if (UNLIKELY(mDataCapacity < mSinglePartitionAndRegionUsedSize * mPartitionNum)) { + LogError("mDataCapacity should be bigger than mSinglePartitionAndRegionUsedSize * mPartitionNum"); + return false; + } mRegionPtRecordOffset = mDataCapacity - mSinglePartitionAndRegionUsedSize * mPartitionNum; + if (UNLIKELY(mDataCapacity < mSingleRegionUsedSize * mPartitionNum)) { + LogError("mDataCapacity should be bigger than mSingleRegionUsedSize * mPartitionNum"); + return false; + } mRegionUsedRecordOffset = mDataCapacity - mSingleRegionUsedSize * mPartitionNum; + if (UNLIKELY(mDataCapacity / mPartitionNum < mSinglePartitionAndRegionUsedSize)) { + LogError("mDataCapacity / mPartitionNum should be bigger than mSinglePartitionAndRegionUsedSize"); + return false; + } mEachPartitionSize = mDataCapacity / mPartitionNum - mSinglePartitionAndRegionUsedSize; mDoublePartitionSize = reserveSize * mEachPartitionSize; @@ -76,6 +88,10 @@ OckHashWriteBuffer::ResultFlag OckHashWriteBuffer::PreoccupiedDataSpace(uint32_t return ResultFlag::UNEXPECTED; } + if (UNLIKELY(mTotalSize > UINT32_MAX - length)) { + LogError("mTotalSize + length exceed UINT32_MAX"); + return ResultFlag::UNEXPECTED; + } // 1. 
get the new region id for partitionId uint32_t regionId = UINT32_MAX; if (newRegion && !GetNewRegion(partitionId, regionId)) { @@ -98,7 +114,7 @@ OckHashWriteBuffer::ResultFlag OckHashWriteBuffer::PreoccupiedDataSpace(uint32_t (mDoublePartitionSize - mRegionUsedSize[regionId] - mRegionUsedSize[nearRegionId]); if (remainBufLength >= length) { mRegionUsedSize[regionId] += length; - mTotalSize += length; // todo check + mTotalSize += length; return ResultFlag::ENOUGH; } @@ -111,8 +127,16 @@ uint8_t *OckHashWriteBuffer::GetEndAddressOfRegion(uint32_t partitionId, uint32_ regionId = mPtCurrentRegionId[partitionId]; if ((regionId % groupSize) == 0) { + if (UNLIKELY(regionId * mEachPartitionSize + mRegionUsedSize[regionId] < length)) { + LogError("regionId * mEachPartitionSize + mRegionUsedSize[regionId] should be bigger than length"); + return nullptr; + } offset = regionId * mEachPartitionSize + mRegionUsedSize[regionId] - length; } else { + if (UNLIKELY((regionId + 1) * mEachPartitionSize < mRegionUsedSize[regionId])) { + LogError("(regionId + 1) * mEachPartitionSize should be bigger than mRegionUsedSize[regionId]"); + return nullptr; + } offset = (regionId + 1) * mEachPartitionSize - mRegionUsedSize[regionId]; } diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_hash_write_buffer.h b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_hash_write_buffer.h similarity index 100% rename from omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_hash_write_buffer.h rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_hash_write_buffer.h diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_merge_reader.cpp b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_merge_reader.cpp similarity index 52% rename from omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_merge_reader.cpp rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_merge_reader.cpp index 80ff1737977846dee4dad93049c35ffb44509f13..d1ef824c4a3032e3305ac5d7b16cc7838f5f8684 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_merge_reader.cpp +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_merge_reader.cpp @@ -8,19 +8,23 @@ #include "common/common.h" -using namespace omniruntime::type; using namespace omniruntime::vec; using namespace ock::dopspark; bool OckMergeReader::Initialize(const int32_t *typeIds, uint32_t colNum) { mColNum = colNum; - mVectorBatch = new (std::nothrow) VBDataDesc(colNum); + mVectorBatch = std::make_shared(); if (UNLIKELY(mVectorBatch == nullptr)) { LOG_ERROR("Failed to new instance for vector batch description"); return false; } + if (UNLIKELY(!mVectorBatch->Initialize(colNum))) { + LOG_ERROR("Failed to initialize vector batch."); + return false; + } + mColTypeIds.reserve(colNum); for (uint32_t index = 0; index < colNum; ++index) { mColTypeIds.emplace_back(typeIds[index]); @@ -29,44 +33,48 @@ bool OckMergeReader::Initialize(const int32_t *typeIds, uint32_t colNum) return true; } -bool OckMergeReader::GenerateVector(OckVector &vector, uint32_t rowNum, int32_t typeId, uint8_t *&startAddress) +bool OckMergeReader::GenerateVector(OckVectorPtr &vector, uint32_t rowNum, int32_t typeId, uint8_t *&startAddress) { uint8_t *address = startAddress; - vector.SetValueNulls(static_cast(address)); - vector.SetSize(rowNum); + vector->SetValueNulls(static_cast(address)); + vector->SetSize(rowNum);
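The new UNLIKELY guards in OckHashWriteBuffer exist because all of these sizes are unsigned, so a subtraction that logically goes negative silently wraps to a huge offset. A standalone sketch of the same check, with invented names, assuming 32-bit unsigned arithmetic as in the buffer code:

#include <cstdint>
#include <optional>

// Compute end - length only when the unsigned subtraction cannot wrap around;
// otherwise report failure so the caller can log and return nullptr.
static std::optional<uint32_t> CheckedBackOffset(uint32_t regionEnd, uint32_t length)
{
    if (regionEnd < length) {
        return std::nullopt;
    }
    return regionEnd - length;
}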
address += rowNum; switch (typeId) { case OMNI_BOOLEAN: { - vector.SetCapacityInBytes(sizeof(uint8_t) * rowNum); + vector->SetCapacityInBytes(sizeof(uint8_t) * rowNum); break; } case OMNI_SHORT: { - vector.SetCapacityInBytes(sizeof(uint16_t) * rowNum); + vector->SetCapacityInBytes(sizeof(uint16_t) * rowNum); break; } case OMNI_INT: case OMNI_DATE32: { - vector.SetCapacityInBytes(sizeof(uint32_t) * rowNum); + vector->SetCapacityInBytes(sizeof(uint32_t) * rowNum); break; } case OMNI_LONG: case OMNI_DOUBLE: case OMNI_DECIMAL64: case OMNI_DATE64: { - vector.SetCapacityInBytes(sizeof(uint64_t) * rowNum); + vector->SetCapacityInBytes(sizeof(uint64_t) * rowNum); break; } case OMNI_DECIMAL128: { - vector.SetCapacityInBytes(decimal128Size * rowNum); // 16 means value cost 16Byte + vector->SetCapacityInBytes(decimal128Size * rowNum); // 16 means value cost 16Byte break; } case OMNI_CHAR: case OMNI_VARCHAR: { // unknown length for value vector, calculate later // will add offset_vector_len when the length of values_vector is variable - vector.SetValueOffsets(static_cast(address)); + vector->SetValueOffsets(static_cast(address)); address += capacityOffset * (rowNum + 1); // 4 means value cost 4Byte - vector.SetCapacityInBytes(*reinterpret_cast(address - capacityOffset)); + vector->SetCapacityInBytes(*reinterpret_cast(address - capacityOffset)); + if (UNLIKELY(vector->GetCapacityInBytes() > maxCapacityInBytes)) { + LOG_ERROR("vector capacityInBytes exceed maxCapacityInBytes"); + return false; + } break; } default: { @@ -75,26 +83,26 @@ bool OckMergeReader::GenerateVector(OckVector &vector, uint32_t rowNum, int32_t } } - vector.SetValues(static_cast(address)); - address += vector.GetCapacityInBytes(); + vector->SetValues(static_cast(address)); + address += vector->GetCapacityInBytes(); startAddress = address; return true; } bool OckMergeReader::CalVectorValueLength(uint32_t colIndex, uint32_t &length) { - OckVector *vector = mVectorBatch->mColumnsHead[colIndex]; + auto vector = mVectorBatch->GetColumnHead(colIndex); + length = 0; for (uint32_t cnt = 0; cnt < mMergeCnt; ++cnt) { if (UNLIKELY(vector == nullptr)) { LOG_ERROR("Failed to calculate value length for column index %d", colIndex); return false; } - - mVectorBatch->mVectorValueLength[colIndex] += vector->GetCapacityInBytes(); + length += vector->GetCapacityInBytes(); vector = vector->GetNextVector(); } - length = mVectorBatch->mVectorValueLength[colIndex]; + mVectorBatch->SetColumnCapacity(colIndex, length); return true; } @@ -102,37 +110,27 @@ bool OckMergeReader::ScanOneVectorBatch(uint8_t *&startAddress) { uint8_t *address = startAddress; // get vector batch msg as vb_data_batch memory layout (upper) - mCurVBHeader = reinterpret_cast(address); - mVectorBatch->mHeader.rowNum += mCurVBHeader->rowNum; - mVectorBatch->mHeader.length += mCurVBHeader->length; + auto curVBHeader = reinterpret_cast(address); + mVectorBatch->AddTotalCapacity(curVBHeader->length); + mVectorBatch->AddTotalRowNum(curVBHeader->rowNum); address += sizeof(struct VBDataHeaderDesc); OckVector *curVector = nullptr; for (uint32_t colIndex = 0; colIndex < mColNum; colIndex++) { - curVector = mVectorBatch->mColumnsCur[colIndex]; - if (UNLIKELY(!GenerateVector(*curVector, mCurVBHeader->rowNum, mColTypeIds[colIndex], address))) { - LOG_ERROR("Failed to generate vector"); + auto curVector = mVectorBatch->GetCurColumn(colIndex); + if (UNLIKELY(curVector == nullptr)) { + LOG_ERROR("curVector is null, index %d", colIndex); return false; } - - if (curVector->GetNextVector() == nullptr) { 
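The switch above sizes each value buffer from the fixed element width of the column type; only CHAR/VARCHAR columns defer to the offsets array. A compact restatement of that mapping (the enum below is a placeholder, the real code switches on omniruntime's OMNI_* ids):

#include <cstdint>

enum class SketchTypeId { Boolean, Short, Int, Long, Decimal128, Varchar };

// Fixed bytes per value for each column type; 0 marks variable-width columns
// whose capacity is read from the trailing offset entry instead.
constexpr uint32_t FixedElementWidth(SketchTypeId id)
{
    switch (id) {
        case SketchTypeId::Boolean:    return 1;
        case SketchTypeId::Short:      return 2;
        case SketchTypeId::Int:        return 4;   // also DATE32
        case SketchTypeId::Long:       return 8;   // also DOUBLE, DECIMAL64, DATE64
        case SketchTypeId::Decimal128: return 16;  // two 64-bit lanes per value
        default:                       return 0;   // CHAR / VARCHAR
    }
}

static_assert(FixedElementWidth(SketchTypeId::Decimal128) == 16, "capacity = 16 * rowNum");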
- curVector = new (std::nothrow) OckVector(); - if (UNLIKELY(curVector == nullptr)) { - LOG_ERROR("Failed to new instance for ock vector"); - return false; - } - - // set next vector in the column merge list, and current column vector point to it - mVectorBatch->mColumnsCur[colIndex]->SetNextVector(curVector); - mVectorBatch->mColumnsCur[colIndex] = curVector; - } else { - mVectorBatch->mColumnsCur[colIndex] = curVector->GetNextVector(); + if (UNLIKELY(!GenerateVector(curVector, curVBHeader->rowNum, mColTypeIds[colIndex], address))) { + LOG_ERROR("Failed to generate vector"); + return false; + } } } - if (UNLIKELY((uint32_t)(address - startAddress) != mCurVBHeader->length)) { + if (UNLIKELY((uint32_t)(address - startAddress) != curVBHeader->length)) { LOG_ERROR("Failed to scan one vector batch as invalid data setting %d vs %d", - (uint32_t)(address - startAddress), mCurVBHeader->length); + (uint32_t)(address - startAddress), curVBHeader->length); return false; } @@ -159,49 +157,72 @@ bool OckMergeReader::GetMergeVectorBatch(uint8_t *&startAddress, uint32_t remain } mMergeCnt++; - if (mVectorBatch->mHeader.rowNum >= maxRowNum || mVectorBatch->mHeader.length >= maxSize) { + if (mVectorBatch->GetTotalRowNum() >= maxRowNum || mVectorBatch->GetTotalCapacity() >= maxSize) { break; } } startAddress = address; - return true; } -bool OckMergeReader::CopyPartDataToVector(uint8_t *&nulls, uint8_t *&values, - OckVector &srcVector, uint32_t colIndex) +bool OckMergeReader::CopyPartDataToVector(uint8_t *&nulls, uint8_t *&values, uint32_t &remainingSize, + uint32_t &remainingCapacity, OckVectorPtr &srcVector) { - errno_t ret = memcpy_s(nulls, srcVector.GetSize(), srcVector.GetValueNulls(), srcVector.GetSize()); + uint32_t srcSize = srcVector->GetSize(); + if (UNLIKELY(remainingSize < srcSize)) { + LOG_ERROR("Not enough resource. remainingSize %d, srcSize %d.", remainingSize, srcSize); + return false; + } + errno_t ret = memcpy_s(nulls, remainingSize, srcVector->GetValueNulls(), srcSize); if (UNLIKELY(ret != EOK)) { LOG_ERROR("Failed to copy null vector"); return false; } - nulls += srcVector.GetSize(); + nulls += srcSize; + remainingSize -= srcSize; - if (srcVector.GetCapacityInBytes() > 0) { - ret = memcpy_s(values, srcVector.GetCapacityInBytes(), srcVector.GetValues(), - srcVector.GetCapacityInBytes()); + uint32_t srcCapacity = srcVector->GetCapacityInBytes(); + if (UNLIKELY(remainingCapacity < srcCapacity)) { + LOG_ERROR("Not enough resource. 
remainingCapacity %d, srcCapacity %d", remainingCapacity, srcCapacity); + return false; + } + if (srcCapacity > 0) { + ret = memcpy_s(values, remainingCapacity, srcVector->GetValues(), srcCapacity); if (UNLIKELY(ret != EOK)) { LOG_ERROR("Failed to copy values vector"); return false; } - values += srcVector.GetCapacityInBytes(); + values += srcCapacity; + remainingCapacity -=srcCapacity; } return true; } -bool OckMergeReader::CopyDataToVector(Vector *dstVector, uint32_t colIndex) +bool OckMergeReader::CopyDataToVector(BaseVector *dstVector, uint32_t colIndex) { // point to first src vector in list - OckVector *srcVector = mVectorBatch->mColumnsHead[colIndex]; + auto srcVector = mVectorBatch->GetColumnHead(colIndex); - auto *nullsAddress = (uint8_t *)dstVector->GetValueNulls(); - auto *valuesAddress = (uint8_t *)dstVector->GetValues(); - uint32_t *offsetsAddress = (uint32_t *)dstVector->GetValueOffsets(); + auto *nullsAddress = (uint8_t *)omniruntime::vec::unsafe::UnsafeBaseVector::GetNulls(dstVector); + auto *valuesAddress = (uint8_t *)VectorHelper::UnsafeGetValues(dstVector); + uint32_t *offsetsAddress = (uint32_t *)VectorHelper::UnsafeGetOffsetsAddr(dstVector); + dstVector->SetNullFlag(true); uint32_t totalSize = 0; uint32_t currentSize = 0; + if (dstVector->GetSize() < 0) { + LOG_ERROR("Invalid vector size %d", dstVector->GetSize()); + return false; + } + uint32_t remainingSize = (uint32_t)dstVector->GetSize(); + uint32_t remainingCapacity = 0; + if (mColTypeIds[colIndex] == OMNI_CHAR || mColTypeIds[colIndex] == OMNI_VARCHAR) { + auto *varCharVector = reinterpret_cast> *>(dstVector); + remainingCapacity = omniruntime::vec::unsafe::UnsafeStringVector::GetContainer(varCharVector)->GetCapacityInBytes(); + } else { + remainingCapacity = GetDataSize(colIndex) * remainingSize; + } for (uint32_t cnt = 0; cnt < mMergeCnt; ++cnt) { if (UNLIKELY(srcVector == nullptr)) { @@ -209,7 +230,7 @@ bool OckMergeReader::CopyDataToVector(Vector *dstVector, uint32_t colIndex) return false; } - if (UNLIKELY(!CopyPartDataToVector(nullsAddress, valuesAddress, *srcVector, colIndex))) { + if (UNLIKELY(!CopyPartDataToVector(nullsAddress, valuesAddress, remainingSize, remainingCapacity, srcVector))) { return false; } @@ -226,9 +247,9 @@ bool OckMergeReader::CopyDataToVector(Vector *dstVector, uint32_t colIndex) if (mColTypeIds[colIndex] == OMNI_CHAR || mColTypeIds[colIndex] == OMNI_VARCHAR) { *offsetsAddress = totalSize; - if (UNLIKELY(totalSize != mVectorBatch->mVectorValueLength[colIndex])) { + if (UNLIKELY(totalSize != mVectorBatch->GetColumnCapacity(colIndex))) { LOG_ERROR("Failed to calculate variable vector value length, %d to %d", totalSize, - mVectorBatch->mVectorValueLength[colIndex]); + mVectorBatch->GetColumnCapacity(colIndex)); return false; } } diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_merge_reader.h b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_merge_reader.h similarity index 47% rename from omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_merge_reader.h rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_merge_reader.h index b5d5fba4d7ddd910146126201cc27776f6ad813b..838dd6a8d6e78b3557764869f1240c47b48aa398 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_merge_reader.h +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_merge_reader.h @@ -10,38 +10,69 @@ namespace ock { namespace dopspark { +using namespace omniruntime::type; class 
OckMergeReader { public: bool Initialize(const int32_t *typeIds, uint32_t colNum); bool GetMergeVectorBatch(uint8_t *&address, uint32_t remain, uint32_t maxRowNum, uint32_t maxSize); - bool CopyPartDataToVector(uint8_t *&nulls, uint8_t *&values, OckVector &srcVector, uint32_t colIndex); - bool CopyDataToVector(omniruntime::vec::Vector *dstVector, uint32_t colIndex); + bool CopyPartDataToVector(uint8_t *&nulls, uint8_t *&values, uint32_t &remainingSize, uint32_t &remainingCapacity, + OckVectorPtr &srcVector); + bool CopyDataToVector(omniruntime::vec::BaseVector *dstVector, uint32_t colIndex); [[nodiscard]] inline uint32_t GetVectorBatchLength() const { - return mVectorBatch->mHeader.length; + return mVectorBatch->GetTotalCapacity(); } [[nodiscard]] inline uint32_t GetRowNumAfterMerge() const { - return mVectorBatch->mHeader.rowNum; + return mVectorBatch->GetTotalRowNum(); } bool CalVectorValueLength(uint32_t colIndex, uint32_t &length); + inline uint32_t GetDataSize(int32_t colIndex) + { + switch (mColTypeIds[colIndex]) { + case OMNI_BOOLEAN: { + return sizeof(uint8_t); + } + case OMNI_SHORT: { + return sizeof(uint16_t); + } + case OMNI_INT: + case OMNI_DATE32: { + return sizeof(uint32_t); + } + case OMNI_LONG: + case OMNI_DOUBLE: + case OMNI_DECIMAL64: + case OMNI_DATE64: { + return sizeof(uint64_t); + } + case OMNI_DECIMAL128: { + return decimal128Size; + } + default: { + LOG_ERROR("Unsupported data type id %d", mColTypeIds[colIndex]); + return false; + } + } + } + private: - static bool GenerateVector(OckVector &vector, uint32_t rowNum, int32_t typeId, uint8_t *&startAddress); + static bool GenerateVector(OckVectorPtr &vector, uint32_t rowNum, int32_t typeId, uint8_t *&startAddress); bool ScanOneVectorBatch(uint8_t *&startAddress); static constexpr int capacityOffset = 4; static constexpr int decimal128Size = 16; + static constexpr int maxCapacityInBytes = 1073741824; private: // point to shuffle blob current vector batch data header uint32_t mColNum = 0; uint32_t mMergeCnt = 0; std::vector mColTypeIds {}; - VBHeaderPtr mCurVBHeader = nullptr; VBDataDescPtr mVectorBatch = nullptr; }; } diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_splitter.cpp b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_splitter.cpp similarity index 65% rename from omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_splitter.cpp rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_splitter.cpp index 5c046686755c88ccf3e0bdb39e70633c49015aca..ba1296be400e6def31bc02810f71eec3c67ae9f7 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_splitter.cpp +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_splitter.cpp @@ -23,39 +23,49 @@ bool OckSplitter::ToSplitterTypeId(const int32_t *vBColTypes) for (uint32_t colIndex = 0; colIndex < mColNum; ++colIndex) { switch (vBColTypes[colIndex]) { case OMNI_BOOLEAN: { - mVBColShuffleTypes.emplace_back(ShuffleTypeId::SHUFFLE_1BYTE); - mMinDataLenInVBByRow += uint8Size; + CastOmniToShuffleType(OMNI_BOOLEAN, ShuffleTypeId::SHUFFLE_1BYTE, uint8Size); break; } case OMNI_SHORT: { - mVBColShuffleTypes.emplace_back(ShuffleTypeId::SHUFFLE_2BYTE); - mMinDataLenInVBByRow += uint16Size; + CastOmniToShuffleType(OMNI_SHORT, ShuffleTypeId::SHUFFLE_2BYTE, uint16Size); + break; + } + case OMNI_DATE32: { + CastOmniToShuffleType(OMNI_DATE32, ShuffleTypeId::SHUFFLE_4BYTE, uint32Size); break; } - case OMNI_DATE32: case OMNI_INT: { - 
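The widened CopyPartDataToVector signature above threads explicit remaining-size and remaining-capacity budgets through every copy instead of trusting the source vector's own length. A self-contained sketch of that budgeted copy (plain memcpy stands in for the securec memcpy_s the real code uses):

#include <cstdint>
#include <cstring>

// Copy srcLen bytes only when the destination budget allows it, then advance
// the cursor and shrink the budget, mirroring the merge reader's copy loop.
static bool BoundedCopy(uint8_t *&dst, uint32_t &remaining, const uint8_t *src, uint32_t srcLen)
{
    if (remaining < srcLen) {
        return false; // "not enough resource": caller logs and aborts the merge
    }
    if (srcLen > 0) {
        std::memcpy(dst, src, srcLen);
    }
    dst += srcLen;
    remaining -= srcLen;
    return true;
}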
mVBColShuffleTypes.emplace_back(ShuffleTypeId::SHUFFLE_4BYTE); - mMinDataLenInVBByRow += uint32Size; // 4 means value cost 4Byte + CastOmniToShuffleType(OMNI_INT, ShuffleTypeId::SHUFFLE_4BYTE, uint32Size); + break; + } + case OMNI_DATE64: { + CastOmniToShuffleType(OMNI_DATE64, ShuffleTypeId::SHUFFLE_8BYTE, uint64Size); + break; + } + case OMNI_DOUBLE: { + CastOmniToShuffleType(OMNI_DOUBLE, ShuffleTypeId::SHUFFLE_8BYTE, uint64Size); + break; + } + case OMNI_DECIMAL64: { + CastOmniToShuffleType(OMNI_DECIMAL64, ShuffleTypeId::SHUFFLE_8BYTE, uint64Size); break; } - case OMNI_DATE64: - case OMNI_DOUBLE: - case OMNI_DECIMAL64: case OMNI_LONG: { - mVBColShuffleTypes.emplace_back(ShuffleTypeId::SHUFFLE_8BYTE); - mMinDataLenInVBByRow += uint64Size; // 8 means value cost 8Byte + CastOmniToShuffleType(OMNI_LONG, ShuffleTypeId::SHUFFLE_8BYTE, uint64Size); + break; + } + case OMNI_CHAR: { + CastOmniToShuffleType(OMNI_CHAR, ShuffleTypeId::SHUFFLE_BINARY, uint32Size); + mColIndexOfVarVec.emplace_back(colIndex); break; } - case OMNI_CHAR: - case OMNI_VARCHAR: { // unknown length for value vector, calculate later - mMinDataLenInVBByRow += uint32Size; // 4 means offset - mVBColShuffleTypes.emplace_back(ShuffleTypeId::SHUFFLE_BINARY); + case OMNI_VARCHAR: { // unknown length for value vector, calculate later + CastOmniToShuffleType(OMNI_VARCHAR, ShuffleTypeId::SHUFFLE_BINARY, uint32Size); mColIndexOfVarVec.emplace_back(colIndex); break; } case OMNI_DECIMAL128: { - mVBColShuffleTypes.emplace_back(ShuffleTypeId::SHUFFLE_DECIMAL128); - mMinDataLenInVBByRow += decimal128Size; // 16 means value cost 8Byte + CastOmniToShuffleType(OMNI_DECIMAL128, ShuffleTypeId::SHUFFLE_DECIMAL128, decimal128Size); break; } default: { @@ -70,11 +80,15 @@ bool OckSplitter::ToSplitterTypeId(const int32_t *vBColTypes) return true; } -void OckSplitter::InitCacheRegion() +bool OckSplitter::InitCacheRegion() { mCacheRegion.reserve(mPartitionNum); mCacheRegion.resize(mPartitionNum); + if (UNLIKELY(mOckBuffer->GetRegionSize() * 2 < mMinDataLenInVB || mMinDataLenInVBByRow == 0)) { + LOG_DEBUG("regionSize * doubleNum should be bigger than mMinDataLenInVB %d", mMinDataLenInVBByRow); + return false; + } uint32_t rowNum = (mOckBuffer->GetRegionSize() * 2 - mMinDataLenInVB) / mMinDataLenInVBByRow; LOG_INFO("Each region can cache row number is %d", rowNum); @@ -84,6 +98,7 @@ void OckSplitter::InitCacheRegion() region.mLength = 0; region.mRowNum = 0; } + return true; } bool OckSplitter::Initialize(const int32_t *colTypeIds) @@ -122,6 +137,10 @@ std::shared_ptr OckSplitter::Create(const int32_t *colTypeIds, int3 std::shared_ptr OckSplitter::Make(const std::string &partitionMethod, int partitionNum, const int32_t *colTypeIds, int32_t colNum, uint64_t threadId) { + if (UNLIKELY(colTypeIds == nullptr || colNum == 0)) { + LOG_ERROR("colTypeIds is null or colNum is 0, colNum %d", colNum); + return nullptr; + } if (partitionMethod == "hash" || partitionMethod == "rr" || partitionMethod == "range") { return Create(colTypeIds, colNum, partitionNum, false, threadId); } else if (UNLIKELY(partitionMethod == "single")) { @@ -132,35 +151,38 @@ std::shared_ptr OckSplitter::Make(const std::string &partitionMetho } } -uint32_t OckSplitter::GetVarVecValue(VectorBatch &vb, uint32_t rowIndex, uint32_t colIndex, uint8_t **address) const +uint32_t OckSplitter::GetVarVecValue(VectorBatch &vb, uint32_t rowIndex, uint32_t colIndex) const { - auto vector = mIsSinglePt ? 
vb.GetVector(colIndex) : vb.GetVector(static_cast(colIndex + 1)); - if (vector->GetEncoding() == OMNI_VEC_ENCODING_DICTIONARY) { - return reinterpret_cast(vector)->GetVarchar(rowIndex, address); + auto vector = mIsSinglePt ? vb.Get(colIndex) : vb.Get(static_cast(colIndex + 1)); + if (vector->GetEncoding() == OMNI_DICTIONARY) { + auto vc = reinterpret_cast> *>(vector); + std::string_view value = vc->GetValue(rowIndex); + return static_cast(value.length()); } else { - return reinterpret_cast(vector)->GetValue(rowIndex, address); + auto vc = reinterpret_cast> *>(vector); + std::string_view value = vc->GetValue(rowIndex); + return static_cast(value.length()); } } uint32_t OckSplitter::GetRowLengthInBytes(VectorBatch &vb, uint32_t rowIndex) const { - uint8_t *address = nullptr; uint32_t length = mMinDataLenInVBByRow; // calculate variable width value for (auto &colIndex : mColIndexOfVarVec) { - length += GetVarVecValue(vb, rowIndex, colIndex, &address); + length += GetVarVecValue(vb, rowIndex, colIndex); } return length; } -bool OckSplitter::WriteNullValues(Vector *vector, std::vector &rowIndexes, uint32_t rowNum, uint8_t *&address) +bool OckSplitter::WriteNullValues(BaseVector *vector, std::vector &rowIndexes, uint32_t rowNum, uint8_t *&address) { uint8_t *nullAddress = address; for (uint32_t index = 0; index < rowNum; ++index) { - *nullAddress = const_cast((uint8_t *)(VectorHelper::GetNullsAddr(vector)))[rowIndexes[index]]; + *nullAddress = const_cast((uint8_t *)(unsafe::UnsafeBaseVector::GetNulls(vector)))[rowIndexes[index]]; nullAddress++; } @@ -169,34 +191,45 @@ bool OckSplitter::WriteNullValues(Vector *vector, std::vector &rowInde } template -bool OckSplitter::WriteFixedWidthValueTemple(Vector *vector, bool isDict, std::vector &rowIndexes, +bool OckSplitter::WriteFixedWidthValueTemple(BaseVector *vector, bool isDict, std::vector &rowIndexes, uint32_t rowNum, T *&address) { T *dstValues = address; T *srcValues = nullptr; if (isDict) { - auto ids = static_cast(mAllocator->alloc(mCurrentVB->GetRowCount() * sizeof(int32_t))); - if (UNLIKELY(ids == nullptr)) { - LOG_ERROR("Failed to allocate space for fixed width value ids."); + int32_t idsNum = mCurrentVB->GetRowCount(); + int64_t idsSizeInBytes = idsNum * sizeof(int32_t); + auto ids = VectorHelper::UnsafeGetValues(vector); + srcValues = reinterpret_cast(VectorHelper::UnsafeGetDictionary(vector)); + if (UNLIKELY(srcValues == nullptr)) { + LOG_ERROR("Source values address is null."); return false; } - auto dictionary = - (reinterpret_cast(vector))->ExtractDictionaryAndIds(0, mCurrentVB->GetRowCount(), ids); - if (UNLIKELY(dictionary == nullptr)) { - LOG_ERROR("Failed to get dictionary"); - return false; - } - srcValues = reinterpret_cast(VectorHelper::GetValuesAddr(dictionary)); for (uint32_t index = 0; index < rowNum; ++index) { - *dstValues++ = srcValues[reinterpret_cast(ids)[rowIndexes[index]]]; // write value to local blob + uint32_t idIndex = rowIndexes[index]; + if (UNLIKELY(idIndex >= idsNum)) { + LOG_ERROR("Invalid idIndex %d, idsNum.", idIndex, idsNum); + return false; + } + uint32_t rowIndex = reinterpret_cast(ids)[idIndex]; + *dstValues++ = srcValues[rowIndex]; // write value to local blob } - mAllocator->free((uint8_t *)(ids), mCurrentVB->GetRowCount() * sizeof(int32_t)); } else { - srcValues = reinterpret_cast(VectorHelper::GetValuesAddr(vector)); + srcValues = reinterpret_cast(VectorHelper::UnsafeGetValues(vector)); + if (UNLIKELY(srcValues == nullptr)) { + LOG_ERROR("Source values address is null."); + return false; + } + 
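For dictionary-encoded columns the fixed-width writer now resolves each selected row through the ids array into the dictionary buffer, validating the index first. A generic sketch of that decode step (names and signature are illustrative; the real code obtains ids and the dictionary through omniruntime's unsafe helpers):

#include <cstdint>

// Write dictionary[ids[rowIndexes[i]]] for each selected row, rejecting any
// out-of-range row index instead of reading past the ids array.
template <typename T>
static bool DecodeDictionaryRows(const T *dictionary, const int32_t *ids, uint32_t idsNum,
                                 const uint32_t *rowIndexes, uint32_t rowNum, T *dst)
{
    for (uint32_t i = 0; i < rowNum; ++i) {
        uint32_t idIndex = rowIndexes[i];
        if (idIndex >= idsNum) {
            return false; // caller logs the invalid index and fails the split
        }
        *dst++ = dictionary[ids[idIndex]];
    }
    return true;
}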
int32_t srcRowCount = vector->GetSize(); for (uint32_t index = 0; index < rowNum; ++index) { - *dstValues++ = srcValues[rowIndexes[index]]; // write value to local blob + uint32_t rowIndex = rowIndexes[index]; + if (UNLIKELY(rowIndex >= srcRowCount)) { + LOG_ERROR("Invalid rowIndex %d, srcRowCount %d.", rowIndex, srcRowCount); + return false; + } + *dstValues++ = srcValues[rowIndex]; // write value to local blob } } @@ -205,37 +238,45 @@ bool OckSplitter::WriteFixedWidthValueTemple(Vector *vector, bool isDict, std::v return true; } -bool OckSplitter::WriteDecimal128(Vector *vector, bool isDict, std::vector &rowIndexes, - uint32_t rowNum, uint64_t *&address) +bool OckSplitter::WriteDecimal128(BaseVector *vector, bool isDict, std::vector &rowIndexes, uint32_t rowNum, + uint64_t *&address) { uint64_t *dstValues = address; uint64_t *srcValues = nullptr; if (isDict) { - auto ids = static_cast(mAllocator->alloc(mCurrentVB->GetRowCount() * sizeof(int32_t))); - if (UNLIKELY(ids == nullptr)) { - LOG_ERROR("Failed to allocate space for fixed width value ids."); + uint32_t idsNum = mCurrentVB->GetRowCount(); + auto ids = VectorHelper::UnsafeGetValues(vector); + srcValues = reinterpret_cast(VectorHelper::UnsafeGetDictionary(vector)); + if (UNLIKELY(srcValues == nullptr)) { + LOG_ERROR("Source values address is null."); return false; } - - auto dictionary = - (reinterpret_cast(vector))->ExtractDictionaryAndIds(0, mCurrentVB->GetRowCount(), ids); - if (UNLIKELY(dictionary == nullptr)) { - LOG_ERROR("Failed to get dictionary"); - return false; - } - - srcValues = reinterpret_cast(VectorHelper::GetValuesAddr(dictionary)); for (uint32_t index = 0; index < rowNum; ++index) { - *dstValues++ = srcValues[reinterpret_cast(ids)[rowIndexes[index]] << 1]; - *dstValues++ = srcValues[(reinterpret_cast(ids)[rowIndexes[index]] << 1) | 1]; + uint32_t idIndex = rowIndexes[index]; + if (UNLIKELY(idIndex >= idsNum)) { + LOG_ERROR("Invalid idIndex %d, idsNum.", idIndex, idsNum); + return false; + } + uint32_t rowIndex = reinterpret_cast(ids)[idIndex]; + *dstValues++ = srcValues[rowIndex << 1]; + *dstValues++ = srcValues[rowIndex << 1 | 1]; } - mAllocator->free((uint8_t *)(ids), mCurrentVB->GetRowCount() * sizeof(int32_t)); } else { - srcValues = reinterpret_cast(VectorHelper::GetValuesAddr(vector)); + srcValues = reinterpret_cast(VectorHelper::UnsafeGetValues(vector)); + if (UNLIKELY(srcValues == nullptr)) { + LOG_ERROR("Source values address is null."); + return false; + } + int32_t srcRowCount = vector->GetSize(); for (uint32_t index = 0; index < rowNum; ++index) { + uint32_t rowIndex = rowIndexes[index]; + if (UNLIKELY(rowIndex >= srcRowCount)) { + LOG_ERROR("Invalid rowIndex %d, srcRowCount %d.", rowIndex, srcRowCount); + return false; + } *dstValues++ = srcValues[rowIndexes[index] << 1]; // write value to local blob - *dstValues++ = srcValues[(rowIndexes[index] << 1) | 1]; // write value to local blob + *dstValues++ = srcValues[rowIndexes[index] << 1 | 1]; // write value to local blob } } @@ -243,10 +284,10 @@ bool OckSplitter::WriteDecimal128(Vector *vector, bool isDict, std::vector &rowIndexes, uint32_t rowNum, uint8_t *&address) +bool OckSplitter::WriteFixedWidthValue(BaseVector *vector, ShuffleTypeId typeId, std::vector &rowIndexes, + uint32_t rowNum, uint8_t *&address) { - bool isDict = (vector->GetEncoding() == OMNI_VEC_ENCODING_DICTIONARY); + bool isDict = (vector->GetEncoding() == OMNI_DICTIONARY); switch (typeId) { case ShuffleTypeId::SHUFFLE_1BYTE: { WriteFixedWidthValueTemple(vector, isDict, rowIndexes, 
rowNum, address); @@ -285,21 +326,33 @@ bool OckSplitter::WriteFixedWidthValue(Vector *vector, ShuffleTypeId typeId, return true; } -bool OckSplitter::WriteVariableWidthValue(Vector *vector, std::vector &rowIndexes, - uint32_t rowNum, uint8_t *&address) +bool OckSplitter::WriteVariableWidthValue(BaseVector *vector, std::vector &rowIndexes, uint32_t rowNum, + uint8_t *&address) { - bool isDict = (vector->GetEncoding() == OMNI_VEC_ENCODING_DICTIONARY); + bool isDict = (vector->GetEncoding() == OMNI_DICTIONARY); auto *offsetAddress = reinterpret_cast(address); // point the offset space base address uint8_t *valueStartAddress = address + (rowNum + 1) * sizeof(int32_t); // skip the offsets space uint8_t *valueAddress = valueStartAddress; - int32_t length = 0; + uint32_t length = 0; uint8_t *srcValues = nullptr; + int32_t vectorSize = vector->GetSize(); for (uint32_t rowCnt = 0; rowCnt < rowNum; rowCnt++) { + uint32_t rowIndex = rowIndexes[rowCnt]; + if (UNLIKELY(rowIndex >= vectorSize)) { + LOG_ERROR("Invalid rowIndex %d, vectorSize %d.", rowIndex, vectorSize); + return false; + } if (isDict) { - length = reinterpret_cast(vector)->GetVarchar(rowIndexes[rowCnt], &srcValues); + auto vc = reinterpret_cast> *>(vector); + std::string_view value = vc->GetValue(rowIndex); + srcValues = reinterpret_cast(reinterpret_cast(value.data())); + length = static_cast(value.length()); } else { - length = reinterpret_cast(vector)->GetValue(rowIndexes[rowCnt], &srcValues); + auto vc = reinterpret_cast> *>(vector); + std::string_view value = vc->GetValue(rowIndex); + srcValues = reinterpret_cast(reinterpret_cast(value.data())); + length = static_cast(value.length()); } // write the null value in the vector with row index to local blob if (UNLIKELY(length > 0 && memcpy_s(valueAddress, length, srcValues, length) != EOK)) { @@ -320,7 +373,7 @@ bool OckSplitter::WriteVariableWidthValue(Vector *vector, std::vector bool OckSplitter::WriteOneVector(VectorBatch &vb, uint32_t colIndex, std::vector &rowIndexes, uint32_t rowNum, uint8_t **address) { - Vector *vector = vb.GetVector(colIndex); + BaseVector *vector = vb.Get(colIndex); if (UNLIKELY(vector == nullptr)) { LOG_ERROR("Failed to get vector with index %d in current vector batch", colIndex); return false; @@ -353,6 +406,10 @@ bool OckSplitter::WritePartVectorBatch(VectorBatch &vb, uint32_t partitionId) uint32_t regionId = 0; // backspace from local blob the region end address to remove preoccupied bytes for the vector batch region auto address = mOckBuffer->GetEndAddressOfRegion(partitionId, regionId, vbRegion->mLength); + if (UNLIKELY(address == nullptr)) { + LOG_ERROR("Failed to get address with partitionId %d", partitionId); + return false; + } // write the header information of the vector batch in local blob auto header = reinterpret_cast(address); header->length = vbRegion->mLength; @@ -361,6 +418,10 @@ bool OckSplitter::WritePartVectorBatch(VectorBatch &vb, uint32_t partitionId) if (!mOckBuffer->IsCompress()) { // record write bytes when don't need compress mTotalWriteBytes += header->length; } + if (UNLIKELY(partitionId > mPartitionLengths.size())) { + LOG_ERROR("Illegal partitionId %d", partitionId); + return false; + } mPartitionLengths[partitionId] += header->length; // we can't get real length when compress address += vbHeaderSize; // 8 means header length so skip @@ -382,6 +443,10 @@ bool OckSplitter::WritePartVectorBatch(VectorBatch &vb, uint32_t partitionId) bool OckSplitter::FlushAllRegionAndGetNewBlob(VectorBatch &vb) { + if (UNLIKELY(mPartitionNum > 
mCacheRegion.size())) { + LOG_ERROR("Illegal mPartitionNum %d", mPartitionNum); + return false; + } for (uint32_t partitionId = 0; partitionId < mPartitionNum; ++partitionId) { if (mCacheRegion[partitionId].mRowNum == 0) { continue; @@ -421,6 +486,10 @@ bool OckSplitter::FlushAllRegionAndGetNewBlob(VectorBatch &vb) bool OckSplitter::PreoccupiedBufferSpace(VectorBatch &vb, uint32_t partitionId, uint32_t rowIndex, uint32_t rowLength, bool newRegion) { + if (UNLIKELY(partitionId > mCacheRegion.size())) { + LOG_ERROR("Illegal partitionId %d", partitionId); + return false; + } uint32_t preoccupiedSize = rowLength; if (mCacheRegion[partitionId].mRowNum == 0) { preoccupiedSize += mMinDataLenInVB; // means create a new vector batch, so will cost header @@ -472,7 +541,7 @@ bool OckSplitter::Split(VectorBatch &vb) ResetCacheRegion(); // clear the record about those partition regions in old vector batch mCurrentVB = &vb; // point to current native vector batch address // the first vector in vector batch that record partitionId about same index row when exist multiple partition - mPtViewInCurVB = mIsSinglePt ? nullptr : reinterpret_cast(vb.GetVector(0)); + mPtViewInCurVB = mIsSinglePt ? nullptr : reinterpret_cast *>(vb.Get(0)); // PROFILE_START_L1(PREOCCUPIED_STAGE) for (int rowIndex = 0; rowIndex < vb.GetRowCount(); ++rowIndex) { @@ -499,19 +568,19 @@ bool OckSplitter::Split(VectorBatch &vb) } // release data belong to the vector batch in memory after write it to local blob - vb.ReleaseAllVectors(); + vb.FreeAllVectors(); // PROFILE_END_L1(RELEASE_VECTOR) mCurrentVB = nullptr; return true; } -void OckSplitter::Stop() +bool OckSplitter::Stop() { uint32_t dataSize = 0; if (UNLIKELY(!mOckBuffer->Flush(true, dataSize))) { LogError("Failed to flush local blob when stop."); - return; + return false; } if (mOckBuffer->IsCompress()) { @@ -520,4 +589,5 @@ void OckSplitter::Stop() LOG_INFO("Time cost preoccupied: %lu write_data: %lu release_resource: %lu", mPreoccupiedTime, mWriteVBTime, mReleaseResource); + return true; } \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_splitter.h b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_splitter.h similarity index 86% rename from omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_splitter.h rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_splitter.h index fc81195099f49a2cae202a10324f0725ee5a08bb..9e239f7aac87cf5ba43c5942425197d1dc10ddf4 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_splitter.h +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_splitter.h @@ -20,8 +20,6 @@ #include "vec_data.pb.h" #include "ock_hash_write_buffer.h" -#include "memory/base_allocator.h" - using namespace spark; using namespace omniruntime::vec; using namespace omniruntime::type; @@ -47,7 +45,7 @@ public: const int32_t *colTypeIds, int32_t colNum, uint64_t threadId); bool Initialize(const int32_t *colTypeIds); bool Split(VectorBatch &vb); - void Stop(); + bool Stop(); inline bool SetShuffleInfo(const std::string &appId, uint32_t shuffleId, uint32_t stageId, uint32_t stageAttemptNum, uint32_t mapId, uint32_t taskAttemptId) @@ -70,7 +68,10 @@ public: return false; } - InitCacheRegion(); + if (UNLIKELY(!InitCacheRegion())) { + LOG_ERROR("Failed to initialize CacheRegion"); + return false; + } return true; } @@ -89,7 +90,7 @@ private: bool isSinglePt, uint64_t threadId); bool ToSplitterTypeId(const 
int32_t *vBColTypes); - uint32_t GetVarVecValue(VectorBatch &vb, uint32_t rowIndex, uint32_t colIndex, uint8_t **address) const; + uint32_t GetVarVecValue(VectorBatch &vb, uint32_t rowIndex, uint32_t colIndex) const; uint32_t GetRowLengthInBytes(VectorBatch &vb, uint32_t rowIndex) const; inline uint32_t GetPartitionIdOfRow(uint32_t rowIndex) @@ -98,7 +99,12 @@ private: return mIsSinglePt ? 0 : mPtViewInCurVB->GetValue(rowIndex); } - void InitCacheRegion(); + void CastOmniToShuffleType(DataTypeId omniType, ShuffleTypeId shuffleType, uint32_t size) + { + mVBColShuffleTypes.emplace_back(shuffleType); + mMinDataLenInVBByRow += size; + } + bool InitCacheRegion(); inline void ResetCacheRegion() { @@ -137,21 +143,19 @@ private: bool newRegion); bool WritePartVectorBatch(VectorBatch &vb, uint32_t partitionId); - static bool WriteNullValues(Vector *vector, std::vector &rowIndexes, uint32_t rowNum, uint8_t *&address); + static bool WriteNullValues(BaseVector *vector, std::vector &rowIndexes, uint32_t rowNum, uint8_t *&address); template - bool WriteFixedWidthValueTemple(Vector *vector, bool isDict, std::vector &rowIndexes, uint32_t rowNum, + bool WriteFixedWidthValueTemple(BaseVector *vector, bool isDict, std::vector &rowIndexes, uint32_t rowNum, T *&address); - bool WriteDecimal128(Vector *vector, bool isDict, std::vector &rowIndexes, uint32_t rowNum, uint64_t *&address); - bool WriteFixedWidthValue(Vector *vector, ShuffleTypeId typeId, std::vector &rowIndexes, + bool WriteDecimal128(BaseVector *vector, bool isDict, std::vector &rowIndexes, uint32_t rowNum, uint64_t *&address); + bool WriteFixedWidthValue(BaseVector *vector, ShuffleTypeId typeId, std::vector &rowIndexes, uint32_t rowNum, uint8_t *&address); - static bool WriteVariableWidthValue(Vector *vector, std::vector &rowIndexes, uint32_t rowNum, + static bool WriteVariableWidthValue(BaseVector *vector, std::vector &rowIndexes, uint32_t rowNum, uint8_t *&address); bool WriteOneVector(VectorBatch &vb, uint32_t colIndex, std::vector &rowIndexes, uint32_t rowNum, uint8_t **address); private: - BaseAllocator *mAllocator = omniruntime::mem::GetProcessRootAllocator(); - static constexpr uint32_t vbDataHeadLen = 8; // Byte static constexpr uint32_t uint8Size = 1; static constexpr uint32_t uint16Size = 2; @@ -159,6 +163,7 @@ private: static constexpr uint32_t uint64Size = 8; static constexpr uint32_t decimal128Size = 16; static constexpr uint32_t vbHeaderSize = 8; + static constexpr uint32_t doubleNum = 2; /* the region use for all vector batch ---------------------------------------------------------------- */ // this splitter which corresponding to one map task in one shuffle, so some params is same uint32_t mPartitionNum = 0; @@ -187,7 +192,7 @@ private: std::vector mCacheRegion {}; // the vector point to vector0 in current vb which record rowIndex -> ptId - IntVector *mPtViewInCurVB = nullptr; + Vector *mPtViewInCurVB = nullptr; /* ock shuffle resource -------------------------------------------------------------------------------- */ OckHashWriteBuffer *mOckBuffer = nullptr; diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_type.h b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_type.h similarity index 33% rename from omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_type.h rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_type.h index e07e67f17d7281f5df0e1d4ee17a4949bc1da697..03e444b6ce4e7284a36e859c327cc51546fb26ab 100644 --- 
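CastOmniToShuffleType, declared above, replaces the repeated emplace_back plus length bookkeeping in ToSplitterTypeId: each column contributes one shuffle type tag and one fixed per-row byte cost. A reduced model of that accumulation (types and names invented for illustration):

#include <cstdint>
#include <vector>

enum class SketchShuffleType { Bytes1, Bytes2, Bytes4, Bytes8, Decimal128, Binary };

struct SketchColumnPlan {
    std::vector<SketchShuffleType> colShuffleTypes;
    uint32_t minBytesPerRow = 0;

    // One call per column: remember its shuffle representation and add its
    // fixed per-row cost, as CastOmniToShuffleType does for the splitter.
    void AddColumn(SketchShuffleType type, uint32_t fixedBytes)
    {
        colShuffleTypes.emplace_back(type);
        minBytesPerRow += fixedBytes;
    }
};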
a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_type.h +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_type.h @@ -6,7 +6,7 @@ #define SPARK_THESTRAL_PLUGIN_OCK_TYPE_H #include "ock_vector.h" -#include "common/debug.h" +#include "common/common.h" namespace ock { namespace dopspark { @@ -33,58 +33,118 @@ enum class ShuffleTypeId : int { using VBHeaderPtr = struct VBDataHeaderDesc { uint32_t length = 0; // 4Byte uint32_t rowNum = 0; // 4Byte -} __attribute__((packed)) * ; +} __attribute__((packed)) *; -using VBDataDescPtr = struct VBDataDesc { - explicit VBDataDesc(uint32_t colNum) +class VBDataDesc { +public: + VBDataDesc() = default; + ~VBDataDesc() { + for (auto &vector : mColumnsHead) { + if (vector == nullptr) { + continue; + } + auto currVector = vector; + while (currVector->GetNextVector() != nullptr) { + auto nextVector = currVector->GetNextVector(); + currVector->SetNextVector(nullptr); + currVector = nextVector; + } + } + } + + bool Initialize(uint32_t colNum) + { + this->colNum = colNum; mHeader.rowNum = 0; mHeader.length = 0; - mColumnsHead.reserve(colNum); mColumnsHead.resize(colNum); - mColumnsCur.reserve(colNum); mColumnsCur.resize(colNum); - mVectorValueLength.reserve(colNum); - mVectorValueLength.resize(colNum); + mColumnsCapacity.resize(colNum); - for (auto &index : mColumnsHead) { - index = new (std::nothrow) OckVector(); + for (auto &vector : mColumnsHead) { + vector = std::make_shared(); + if (vector == nullptr) { + mColumnsHead.clear(); + return false; + } } + return true; } inline void Reset() { mHeader.rowNum = 0; mHeader.length = 0; - std::fill(mVectorValueLength.begin(), mVectorValueLength.end(), 0); + std::fill(mColumnsCapacity.begin(), mColumnsCapacity.end(), 0); for (uint32_t index = 0; index < mColumnsCur.size(); ++index) { mColumnsCur[index] = mColumnsHead[index]; } } + std::shared_ptr GetColumnHead(uint32_t colIndex) { + if (colIndex >= colNum) { + return nullptr; + } + return mColumnsHead[colIndex]; + } + + void SetColumnCapacity(uint32_t colIndex, uint32_t length) { + mColumnsCapacity[colIndex] = length; + } + + uint32_t GetColumnCapacity(uint32_t colIndex) { + return mColumnsCapacity[colIndex]; + } + + std::shared_ptr GetCurColumn(uint32_t colIndex) + { + if (colIndex >= colNum) { + return nullptr; + } + auto currVector = mColumnsCur[colIndex]; + if (currVector->GetNextVector() == nullptr) { + auto newCurVector = std::make_shared(); + if (UNLIKELY(newCurVector == nullptr)) { + LOG_ERROR("Failed to new instance for ock vector"); + return nullptr; + } + currVector->SetNextVector(newCurVector); + mColumnsCur[colIndex] = newCurVector; + } else { + mColumnsCur[colIndex] = currVector->GetNextVector(); + } + return currVector; + } + + uint32_t GetTotalCapacity() + { + return mHeader.length; + } + + uint32_t GetTotalRowNum() + { + return mHeader.rowNum; + } + + void AddTotalCapacity(uint32_t length) { + mHeader.length += length; + } + + void AddTotalRowNum(uint32_t rowNum) + { + mHeader.rowNum +=rowNum; + } + +private: + uint32_t colNum = 0; VBDataHeaderDesc mHeader; - std::vector mVectorValueLength; - std::vector mColumnsCur; - std::vector mColumnsHead; // Array[List[OckVector *]] -} * ; + std::vector mColumnsCapacity; + std::vector mColumnsCur; + std::vector mColumnsHead; // Array[List[OckVector *]] +}; +using VBDataDescPtr = std::shared_ptr; } } -#define PROFILE_START_L1(name) \ - long tcDiff##name = 0; \ - struct timespec tcStart##name = { 0, 0 }; \ - clock_gettime(CLOCK_MONOTONIC, &tcStart##name); - -#define 
PROFILE_END_L1(name) \ - struct timespec tcEnd##name = { 0, 0 }; \ - clock_gettime(CLOCK_MONOTONIC, &tcEnd##name); \ - \ - long diffSec##name = tcEnd##name.tv_sec - tcStart##name.tv_sec; \ - if (diffSec##name == 0) { \ - tcDiff##name = tcEnd##name.tv_nsec - tcStart##name.tv_nsec; \ - } else { \ - tcDiff##name = diffSec##name * 1000000000 + tcEnd##name.tv_nsec - tcStart##name.tv_nsec; \ - } - -#define PROFILE_VALUE(name) tcDiff##name #endif // SPARK_THESTRAL_PLUGIN_OCK_TYPE_H \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_vector.h b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_vector.h similarity index 88% rename from omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_vector.h rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_vector.h index 0cfca5d63173c04c37771900e1ac17c2c04e8bba..515f88db8355a58321a7290179e48b48802cb8cc 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_vector.h +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/src/shuffle/ock_vector.h @@ -69,12 +69,12 @@ public: valueOffsetsAddress = address; } - inline void SetNextVector(OckVector *next) + inline void SetNextVector(std::shared_ptr next) { mNext = next; } - inline OckVector *GetNextVector() + inline std::shared_ptr GetNextVector() { return mNext; } @@ -87,8 +87,9 @@ private: void *valueNullsAddress = nullptr; void *valueOffsetsAddress = nullptr; - OckVector *mNext = nullptr; + std::shared_ptr mNext = nullptr; }; +using OckVectorPtr = std::shared_ptr; } } #endif // SPARK_THESTRAL_PLUGIN_OCK_VECTOR_H diff --git a/omnioperator/omniop-spark-extension-ock/cpp/test/CMakeLists.txt b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/CMakeLists.txt similarity index 95% rename from omnioperator/omniop-spark-extension-ock/cpp/test/CMakeLists.txt rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/CMakeLists.txt index 53605f08556f538682e83427a130c1684318702f..dedb097bb17e65b3a6e42a602be15423c99e9652 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/test/CMakeLists.txt +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/CMakeLists.txt @@ -28,7 +28,7 @@ target_link_libraries(${TP_TEST_TARGET} pthread stdc++ dl - boostkit-omniop-vector-1.1.0-aarch64 + boostkit-omniop-vector-1.2.0-aarch64 securec ock_columnar_shuffle) diff --git a/omnioperator/omniop-spark-extension-ock/cpp/test/shuffle/CMakeLists.txt b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/shuffle/CMakeLists.txt similarity index 100% rename from omnioperator/omniop-spark-extension-ock/cpp/test/shuffle/CMakeLists.txt rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/shuffle/CMakeLists.txt diff --git a/omnioperator/omniop-spark-extension-ock/cpp/test/shuffle/ock_shuffle_test.cpp b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/shuffle/ock_shuffle_test.cpp similarity index 91% rename from omnioperator/omniop-spark-extension-ock/cpp/test/shuffle/ock_shuffle_test.cpp rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/shuffle/ock_shuffle_test.cpp index 7980cbf198d192488c313fd719f340fc71c0521a..cc02862fd1b91b1117bdfb07346af13d27db5259 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/test/shuffle/ock_shuffle_test.cpp +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/shuffle/ock_shuffle_test.cpp @@ -54,7 +54,7 @@ 
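With OckVector's next pointer now a shared_ptr, VBDataDesc keeps one lazily grown chain of vector segments per column, and GetCurColumn hands out the segment for the current merge round. A stripped-down model of that chaining (names invented):

#include <memory>

struct SegmentSketch {
    std::shared_ptr<SegmentSketch> next;
};

// Return the segment to fill now and move the cursor to a (possibly new)
// follow-up node, mirroring VBDataDesc::GetCurColumn's lazy chain extension.
static std::shared_ptr<SegmentSketch> NextSegment(std::shared_ptr<SegmentSketch> &cursor)
{
    auto current = cursor;
    if (current->next == nullptr) {
        current->next = std::make_shared<SegmentSketch>();
    }
    cursor = current->next;
    return current;
}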
bool PrintVectorBatch(uint8_t **startAddress, uint32_t &length) info << "vector_batch: { "; for (uint32_t colIndex = 0; colIndex < gColNum; colIndex++) { auto typeId = static_cast(gVecTypeIds[colIndex]); - Vector *vector = OckNewbuildVector(typeId, rowNum); + BaseVector *vector = OckNewbuildVector(typeId, rowNum); if (typeId == OMNI_VARCHAR) { uint32_t varlength = 0; instance->CalVectorValueLength(colIndex, varlength); @@ -75,29 +75,29 @@ bool PrintVectorBatch(uint8_t **startAddress, uint32_t &length) for (uint32_t rowIndex = 0; rowIndex < rowNum; rowIndex++) { LOG_DEBUG("%d", const_cast((uint8_t*)(VectorHelper::GetNullsAddr(vector)))[rowIndex]); info << "{ rowIndex: " << rowIndex << ", nulls: " << - std::to_string(const_cast((uint8_t*)(VectorHelper::GetNullsAddr(vector)))[rowIndex]); + std::to_string(const_cast((uint8_t*)(omniruntime::vec::unsafe::UnsafeBaseVector::GetNulls(vector)))[rowIndex]); switch (typeId) { case OMNI_SHORT: - info << ", value: " << static_cast(vector)->GetValue(rowIndex) << " }, "; + info << ", value: " << static_cast *>(vector)->GetValue(rowIndex) << " }, "; break; case OMNI_INT: { - info << ", value: " << static_cast(vector)->GetValue(rowIndex) << " }, "; + info << ", value: " << static_cast *>(vector)->GetValue(rowIndex) << " }, "; break; } case OMNI_LONG: { - info << ", value: " << static_cast(vector)->GetValue(rowIndex) << " }, "; + info << ", value: " << static_cast *>(vector)->GetValue(rowIndex) << " }, "; break; } case OMNI_DOUBLE: { - info << ", value: " << static_cast(vector)->GetValue(rowIndex) << " }, "; + info << ", value: " << static_cast *>(vector)->GetValue(rowIndex) << " }, "; break; } case OMNI_DECIMAL64: { - info << ", value: " << static_cast(vector)->GetValue(rowIndex) << " }, "; + info << ", value: " << static_cast *>(vector)->GetValue(rowIndex) << " }, "; break; } case OMNI_DECIMAL128: { - info << ", value: " << static_cast(vector)->GetValue(rowIndex) << " }, "; + info << ", value: " << static_cast *>(vector)->GetValue(rowIndex) << " }, "; break; } case OMNI_VARCHAR: { // unknown length for value vector, calculate later @@ -118,9 +118,16 @@ bool PrintVectorBatch(uint8_t **startAddress, uint32_t &length) valueAddress += vector->GetValueOffset(rowIndex); }*/ uint8_t *valueAddress = nullptr; - int32_t length = static_cast(vector)->GetValue(rowIndex, &valueAddress); + int32_t length = reinterpret_cast> *>(vector); std::string valueString(valueAddress, valueAddress + length); - info << ", value: " << valueString << " }, "; + uint32_t length = 0; + std::string_view value; + if (!vc->IsNull(rowIndex)) { + value = vc->GetValue(); + valueAddress = reinterpret_cast(reinterpret_cast(value.data())); + length = static_cast(value.length()); + } + info << ", value: " << value << " }, "; break; } default: @@ -314,7 +321,7 @@ TEST_F(OckShuffleTest, Split_Fixed_Long_Cols) sizeof(inputVecTypeIds) / sizeof(inputVecTypeIds[0]), false, 40960, 41943040, 134217728); gTempSplitId = splitterId; // very important // for (uint64_t j = 0; j < 999; j++) { - VectorBatch *vb = OckCreateVectorBatch_1fixedCols_withPid(partitionNum, 10000); + VectorBatch *vb = OckCreateVectorBatch_1fixedCols_withPid(partitionNum, 10000, LongType()); OckTest_splitter_split(splitterId, vb); // } OckTest_splitter_stop(splitterId); @@ -323,7 +330,7 @@ TEST_F(OckShuffleTest, Split_Fixed_Long_Cols) TEST_F(OckShuffleTest, Split_Fixed_Cols) { - int32_t inputVecTypeIds[] = {OMNI_INT, OMNI_LONG, OMNI_DOUBLE}; // 4Byte + 8Byte + 8Byte + 3Byte + int32_t inputVecTypeIds[] = {OMNI_BOOLEAN, OMNI_SHORT, 
OMNI_INT, OMNI_LONG, OMNI_DOUBLE}; // 4Byte + 8Byte + 8Byte + 3Byte gVecTypeIds = &inputVecTypeIds[0]; gColNum = sizeof(inputVecTypeIds) / sizeof(inputVecTypeIds[0]); int partitionNum = 4; @@ -331,7 +338,7 @@ TEST_F(OckShuffleTest, Split_Fixed_Cols) sizeof(inputVecTypeIds) / sizeof(inputVecTypeIds[0]), false, 40960, 41943040, 134217728); gTempSplitId = splitterId; // very important // for (uint64_t j = 0; j < 999; j++) { - VectorBatch *vb = OckCreateVectorBatch_3fixedCols_withPid(partitionNum, 999); + VectorBatch *vb = OckCreateVectorBatch_5fixedCols_withPid(partitionNum, 999); OckTest_splitter_split(splitterId, vb); // } OckTest_splitter_stop(splitterId); @@ -340,7 +347,7 @@ TEST_F(OckShuffleTest, Split_Fixed_Cols) TEST_F(OckShuffleTest, Split_Fixed_SinglePartition_SomeNullRow) { - int32_t inputVecTypeIds[] = {OMNI_INT, OMNI_LONG, OMNI_DOUBLE, OMNI_VARCHAR}; // 4 + 8 + 8 + 4 + 4 + int32_t inputVecTypeIds[] = {OMNI_BOOLEAN, OMNI_SHORT, OMNI_INT, OMNI_LONG, OMNI_DOUBLE, OMNI_VARCHAR}; // 4 + 8 + 8 + 4 + 4 gVecTypeIds = &inputVecTypeIds[0]; gColNum = sizeof(inputVecTypeIds) / sizeof(inputVecTypeIds[0]); int partitionNum = 1; @@ -399,7 +406,7 @@ TEST_F(OckShuffleTest, Split_Long_10WRows) sizeof(inputVecTypeIds) / sizeof(inputVecTypeIds[0]), false, 40960, 41943040, 134217728); gTempSplitId = splitterId; // very important for (uint64_t j = 0; j < 100; j++) { - VectorBatch *vb = OckCreateVectorBatch_1longCol_withPid(partitionNum, 10000); + VectorBatch *vb = OckCreateVectorBatch_1fixedCols_withPid(partitionNum, 10000, LongType()); OckTest_splitter_split(splitterId, vb); } OckTest_splitter_stop(splitterId); @@ -458,7 +465,7 @@ TEST_F(OckShuffleTest, Split_VarChar_First) TEST_F(OckShuffleTest, Split_Dictionary) { - int32_t inputVecTypeIds[] = {OMNI_INT, OMNI_LONG, OMNI_DECIMAL64, OMNI_DECIMAL128}; + int32_t inputVecTypeIds[] = {OMNI_INT, OMNI_LONG}; int partitionNum = 4; gVecTypeIds = &inputVecTypeIds[0]; gColNum = sizeof(inputVecTypeIds) / sizeof(inputVecTypeIds[0]); @@ -483,7 +490,7 @@ TEST_F(OckShuffleTest, Split_OMNI_DECIMAL128) sizeof(inputVecTypeIds) / sizeof(inputVecTypeIds[0]), false, 40960, 41943040, 134217728); gTempSplitId = splitterId; // very important for (uint64_t j = 0; j < 2; j++) { - VectorBatch *vb = OckCreateVectorBatch_1decimal128Col_withPid(partitionNum); + VectorBatch *vb = OckCreateVectorBatch_1decimal128Col_withPid(partitionNum, 999); OckTest_splitter_split(splitterId, vb); } OckTest_splitter_stop(splitterId); diff --git a/omnioperator/omniop-spark-extension-ock/cpp/test/tptest.cpp b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/tptest.cpp similarity index 95% rename from omnioperator/omniop-spark-extension-ock/cpp/test/tptest.cpp rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/tptest.cpp index a65c54095332ad5cb420b0691850f855a65d064d..e05871c767a52ea0c88e536cc789ce94445d11ec 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/test/tptest.cpp +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/tptest.cpp @@ -1,11 +1,11 @@ -/* - * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. - */ - -#include "gtest/gtest.h" - -int main(int argc, char **argv) -{ - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); +/* + * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. 
+ */ + +#include "gtest/gtest.h" + +int main(int argc, char **argv) +{ + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); } \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/cpp/test/utils/CMakeLists.txt b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/utils/CMakeLists.txt similarity index 100% rename from omnioperator/omniop-spark-extension-ock/cpp/test/utils/CMakeLists.txt rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/utils/CMakeLists.txt diff --git a/omnioperator/omniop-spark-extension-ock/cpp/test/utils/ock_test_utils.cpp b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/utils/ock_test_utils.cpp similarity index 39% rename from omnioperator/omniop-spark-extension-ock/cpp/test/utils/ock_test_utils.cpp rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/utils/ock_test_utils.cpp index 2b49ba28ffaaf79621de049fb59e39120cad5490..251aea490f144d386820074bf9101c92594022fc 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/test/utils/ock_test_utils.cpp +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/utils/ock_test_utils.cpp @@ -10,7 +10,7 @@ using namespace omniruntime::vec; using namespace omniruntime::type; -void OckToVectorTypes(const int32_t *dataTypeIds, int32_t dataTypeCount, std::vector &dataTypes) +/*void OckToVectorTypes(const int32_t *dataTypeIds, int32_t dataTypeCount, std::vector &dataTypes) { for (int i = 0; i < dataTypeCount; ++i) { if (dataTypeIds[i] == OMNI_VARCHAR) { @@ -22,125 +22,39 @@ void OckToVectorTypes(const int32_t *dataTypeIds, int32_t dataTypeCount, std::ve } dataTypes.emplace_back(DataType(dataTypeIds[i])); } -} +}*/ -VectorBatch *OckCreateInputData(const int32_t numRows, const int32_t numCols, int32_t *inputTypeIds, int64_t *allData) +VectorBatch *OckCreateInputData(const DataType &types, int32_t rowCount, ...) 
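The varargs-based body of OckCreateInputData that follows is difficult to read in this hunk (for example the type lookup and the Append call). A minimal sketch of the intended loop, using the DataTypes/DataTypePtr names and the CreateVector(DataType &, int32_t, va_list &) helper declared in ock_test_utils.h; the DataTypes parameter type follows that header declaration and is an assumption here:

    VectorBatch *OckCreateInputData(const DataTypes &types, int32_t rowCount, ...)
    {
        int32_t typesCount = types.GetSize();
        auto *vecBatch = new VectorBatch(rowCount);
        va_list args;
        va_start(args, rowCount);
        for (int32_t i = 0; i < typesCount; i++) {
            DataTypePtr type = types.GetType(i);                    // type of column i
            vecBatch->Append(CreateVector(*type, rowCount, args));  // consumes one vararg array per column
        }
        va_end(args);
        return vecBatch;
    }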
{ - auto *vecBatch = new VectorBatch(numCols, numRows); - std::vector inputTypes; - OckToVectorTypes(inputTypeIds, numCols, inputTypes); - vecBatch->NewVectors(VectorAllocator::GetGlobalAllocator(), inputTypes); - for (int i = 0; i < numCols; ++i) { - switch (inputTypeIds[i]) { - case OMNI_INT: - ((IntVector *)vecBatch->GetVector(i))->SetValues(0, (int32_t *)allData[i], numRows); - break; - case OMNI_LONG: - ((LongVector *)vecBatch->GetVector(i))->SetValues(0, (int64_t *)allData[i], numRows); - break; - case OMNI_DOUBLE: - ((DoubleVector *)vecBatch->GetVector(i))->SetValues(0, (double *)allData[i], numRows); - break; - case OMNI_SHORT: - ((IntVector *)vecBatch->GetVector(i))->SetValues(0, (int32_t *)allData[i], numRows); - break; - case OMNI_VARCHAR: - case OMNI_CHAR: { - for (int j = 0; j < numRows; ++j) { - int64_t addr = (reinterpret_cast(allData[i]))[j]; - std::string s(reinterpret_cast(addr)); - ((VarcharVector *)vecBatch->GetVector(i))->SetValue(j, (uint8_t *)(s.c_str()), s.length()); - } - break; - } - case OMNI_DECIMAL128: - ((Decimal128Vector *)vecBatch->GetVector(i))->SetValues(0, (int64_t *)allData[i], numRows); - break; - default: { - LogError("No such data type %d", inputTypeIds[i]); - } - } + int32_t typesCount = types.GetSize(); + auto *vecBatch = new VectorBatch(rowCount); + va_list args; + va_start(args, rowCount); + for (int32_t i = 0; i< typesCount; i++) { + dataTypePtr = type = types.GetType(i); + VectorBatch->Append(CreateVector(*type, rowCount, args)); } + va_end(args); return vecBatch; } -VarcharVector *OckCreateVarcharVector(VarcharDataType type, std::string *values, int32_t length) +BaseVector *CreateVector(DataType &dataType, int32_t rowCount, va_list &args) { - VectorAllocator *vecAllocator = VectorAllocator::GetGlobalAllocator(); - uint32_t width = type.GetWidth(); - VarcharVector *vector = std::make_unique(vecAllocator, length * width, length).release(); - uint32_t offset = 0; - for (int32_t i = 0; i < length; i++) { - vector->SetValue(i, reinterpret_cast(values[i].c_str()), values[i].length()); - bool isNull = values[i].empty() ? 
true : false; - vector->SetValueNull(i, isNull); - vector->SetValueOffset(i, offset); - offset += values[i].length(); - } - - if (length > 0) { - vector->SetValueOffset(values->size(), offset); - } - - std::stringstream offsetValue; - offsetValue << "{ "; - for (uint32_t index = 0; index < length; index++) { - offsetValue << vector->GetValueOffset(index) << ", "; - } - - offsetValue << vector->GetValueOffset(values->size()) << " }"; - - LOG_INFO("%s", offsetValue.str().c_str()); - - return vector; -} - -Decimal128Vector *OckCreateDecimal128Vector(Decimal128 *values, int32_t length) -{ - VectorAllocator *vecAllocator = VectorAllocator::GetGlobalAllocator(); - Decimal128Vector *vector = std::make_unique(vecAllocator, length).release(); - for (int32_t i = 0; i < length; i++) { - vector->SetValue(i, values[i]); - } - return vector; + return DYNAMIC_TYPE_DISPATCH(CreateFlatVector, dataType.GetId(), rowCount, args); } -Vector *OckCreateVector(DataType &vecType, int32_t rowCount, va_list &args) -{ - switch (vecType.GetId()) { - case OMNI_INT: - case OMNI_DATE32: - return OckCreateVector(va_arg(args, int32_t *), rowCount); - case OMNI_LONG: - case OMNI_DECIMAL64: - return OckCreateVector(va_arg(args, int64_t *), rowCount); - case OMNI_DOUBLE: - return OckCreateVector(va_arg(args, double *), rowCount); - case OMNI_BOOLEAN: - return OckCreateVector(va_arg(args, bool *), rowCount); - case OMNI_VARCHAR: - case OMNI_CHAR: - return OckCreateVarcharVector(static_cast(vecType), va_arg(args, std::string *), - rowCount); - case OMNI_DECIMAL128: - return OckCreateDecimal128Vector(va_arg(args, Decimal128 *), rowCount); - default: - std::cerr << "Unsupported type : " << vecType.GetId() << std::endl; - return nullptr; - } -} -DictionaryVector *OckCreateDictionaryVector(DataType &vecType, int32_t rowCount, int32_t *ids, int32_t idsCount, ...) +BaseVector *CreateDictionaryVector(DataType &dataType, int32_t rowCount, int32_t *ids, int32_t idsCount, + ..) 
{ va_list args; va_start(args, idsCount); - Vector *dictionary = OckCreateVector(vecType, rowCount, args); + BaseVector *dictionary = CreateVector(dataType, rowCount, args); va_end(args); - auto vec = std::make_unique(dictionary, ids, idsCount).release(); - delete dictionary; - return vec; + return DYNAMIC_TYPE_DISPATCH(CreateDictionary, dataType.GetId(), dictionary, ids, idsCount); } +/* Vector *OckbuildVector(const DataType &aggType, int32_t rowNumber) { VectorAllocator *vecAllocator = VectorAllocator::GetGlobalAllocator(); @@ -212,47 +126,37 @@ Vector *OckbuildVector(const DataType &aggType, int32_t rowNumber) return nullptr; } } -} +}*/ -Vector *OckNewbuildVector(const DataTypeId &typeId, int32_t rowNumber) +BaseVector *OckNewbuildVector(const DataTypeId &typeId, int32_t rowNumber) { - VectorAllocator *vecAllocator = VectorAllocator::GetGlobalAllocator(); - switch (typeId) { + switch (typeId) { case OMNI_SHORT: { - auto *col = new ShortVector(vecAllocator, rowNumber); - return col; + return new Vector(rowNumber); } case OMNI_NONE: { - auto *col = new LongVector(vecAllocator, rowNumber); - return col; + return new Vector(rowNumber); } case OMNI_INT: case OMNI_DATE32: { - auto *col = new IntVector(vecAllocator, rowNumber); - return col; + return new Vector(rowNumber); } case OMNI_LONG: case OMNI_DECIMAL64: { - auto *col = new LongVector(vecAllocator, rowNumber); - return col; + return new Vector(rowNumber); } case OMNI_DOUBLE: { - auto *col = new DoubleVector(vecAllocator, rowNumber); - return col; + return new Vector(rowNumber); } case OMNI_BOOLEAN: { - auto *col = new BooleanVector(vecAllocator, rowNumber); - return col; + return new Vector(rowNumber); } case OMNI_DECIMAL128: { - auto *col = new Decimal128Vector(vecAllocator, rowNumber); - return col; + return new Vector(rowNumber); } case OMNI_VARCHAR: case OMNI_CHAR: { - VarcharDataType charType = (VarcharDataType &)typeId; - auto *col = new VarcharVector(vecAllocator, charType.GetWidth() * rowNumber, rowNumber); - return col; + return new Vector>(rowNumber); } default: { LogError("No such %d type support", typeId); @@ -261,15 +165,15 @@ Vector *OckNewbuildVector(const DataTypeId &typeId, int32_t rowNumber) } } -VectorBatch *OckCreateVectorBatch(DataTypes &types, int32_t rowCount, ...) +VectorBatch *OckCreateVectorBatch(const DataTypes &types, int32_t rowCount, ...) 
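In the OckNewbuildVector rewrite above, the element types of the new Vector allocations are not visible in this hunk. A sketch of the presumable mapping from DataTypeId to the templated omniruntime Vector; the element types, and in particular the LargeStringContainer used for varchar, are assumptions inferred from the casts in ock_test_utils.h:

    BaseVector *OckNewbuildVector(const DataTypeId &typeId, int32_t rowNumber)
    {
        switch (typeId) {
            case OMNI_SHORT:      return new Vector<int16_t>(rowNumber);
            case OMNI_NONE:       return new Vector<int64_t>(rowNumber);
            case OMNI_INT:
            case OMNI_DATE32:     return new Vector<int32_t>(rowNumber);
            case OMNI_LONG:
            case OMNI_DECIMAL64:  return new Vector<int64_t>(rowNumber);
            case OMNI_DOUBLE:     return new Vector<double>(rowNumber);
            case OMNI_BOOLEAN:    return new Vector<bool>(rowNumber);
            case OMNI_DECIMAL128: return new Vector<Decimal128>(rowNumber);
            case OMNI_VARCHAR:
            case OMNI_CHAR:       return new Vector<LargeStringContainer<std::string_view>>(rowNumber);
            default:
                LogError("No such %d type support", typeId);
                return nullptr;
        }
    }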
{ int32_t typesCount = types.GetSize(); - VectorBatch *vectorBatch = std::make_unique(typesCount).release(); + auto *vectorBatch = new vecBatch(rowCount); va_list args; va_start(args, rowCount); for (int32_t i = 0; i < typesCount; i++) { - DataType type = types.Get()[i]; - vectorBatch->SetVector(i, OckCreateVector(type, rowCount, args)); + dataTypePtr type = types.GetType(i); + vectorBatch->Append(OckCreateVector(*type, rowCount, args)); } va_end(args); return vectorBatch; @@ -286,23 +190,46 @@ VectorBatch *OckCreateVectorBatch_1row_varchar_withPid(int pid, const std::strin { // gen vectorBatch const int32_t numCols = 2; - auto inputTypes = new int32_t[numCols]; - inputTypes[0] = OMNI_INT; - inputTypes[1] = OMNI_VARCHAR; - + DataTypes inputTypes(std::vector)({ IntType(), VarcharType()}); const int32_t numRows = 1; auto *col1 = new int32_t[numRows]; col1[0] = pid; - auto *col2 = new int64_t[numRows]; - auto *strTmp = new std::string(std::move(inputString)); - col2[0] = (int64_t)(strTmp->c_str()); + auto *col2 = new std::string[numRows]; + col2[0] = std::move(inputString); + VectorBatch *in = OckCreateInputData(inputTypes, numCols, col1, col2); + delete[] col1; + delete[] col2; + return in; +} - int64_t allData[numCols] = {reinterpret_cast(col1), - reinterpret_cast(col2)}; - VectorBatch *in = OckCreateInputData(numRows, numCols, inputTypes, allData); +VectorBatch *OckCreateVectorBatch_4varcharCols_withPid(int parNum, int rowNum) +{ + int partitionNum = parNum; + const int32_t numCols = 5; + DataTypes inputTypes(std::vector)({ IntType(), VarcharType(), VarcharType(), VarcharType(), VarcharType() }); + const int32_t numRows = rowNum; + auto *col0 = new int32_t[numRows]; + auto *col1 = new std::string[numRows]; + auto *col2 = new std::string[numRows]; + auto *col3 = new std::string[numRows]; + auto *col4 = new std::string[numRows]; + col0[i] = (i + 1) % partitionNum; + std::string strTmp1 = std::string("Col1_START_" + to_string(i + 1) + "_END_"); + col1[i] = std::move(strTmp1); + std::string strTmp2 = std::string("Col2_START_" + to_string(i + 1) + "_END_"); + col2[i] = std::move(strTmp2); + std::string strTmp3 = std::string("Col3_START_" + to_string(i + 1) + "_END_"); + col3[i] = std::move(strTmp3); + std::string strTmp4 = std::string("Col4_START_" + to_string(i + 1) + "_END_"); + col4[i] = std::move(strTmp4); + } + + VectorBatch* in = CreateVectorBatch(inputTypes, numRows, col0, col1, col2, col3, col4); + delete[] col0; delete[] col1; delete[] col2; - delete strTmp; + delete[] col3; + delete[] col4; return in; } @@ -316,229 +243,104 @@ VectorBatch *OckCreateVectorBatch_1row_varchar_withPid(int pid, const std::strin VectorBatch *OckCreateVectorBatch_4col_withPid(int parNum, int rowNum) { int partitionNum = parNum; - const int32_t numCols = 6; - auto *inputTypes = new int32_t[numCols]; - inputTypes[0] = OMNI_INT; - inputTypes[1] = OMNI_INT; - inputTypes[2] = OMNI_LONG; - inputTypes[3] = OMNI_DOUBLE; - inputTypes[4] = OMNI_VARCHAR; - inputTypes[5] = OMNI_SHORT; - + DataTypes inputTypes(std::vector)({ IntType(), VarcharType(), VarcharType(), VarcharType(), VarcharType() }); + const int32_t numRows = rowNum; auto *col0 = new int32_t[numRows]; auto *col1 = new int32_t[numRows]; auto *col2 = new int64_t[numRows]; auto *col3 = new double[numRows]; - auto *col4 = new int64_t[numRows]; - auto *col5 = new int16_t[numRows]; + auto *col4 = new std::string[numRows]; std::string startStr = "_START_"; std::string endStr = "_END_"; - - std::vector string_cache_test_; + std::vector string_cache_test_; for (int 
i = 0; i < numRows; i++) { col0[i] = (i + 1) % partitionNum; col1[i] = i + 1; col2[i] = i + 1; col3[i] = i + 1; - auto *strTmp = new std::string(startStr + std::to_string(i + 1) + endStr); - string_cache_test_.push_back(strTmp); - col4[i] = (int64_t)((*strTmp).c_str()); - col5[i] = i + 1; + std::string strTmp = std::string(startStr + to_string(i + 1) + endStr); + col4[i] = std::move(strTmp); } - int64_t allData[numCols] = {reinterpret_cast(col0), - reinterpret_cast(col1), - reinterpret_cast(col2), - reinterpret_cast(col3), - reinterpret_cast(col4), - reinterpret_cast(col5)}; - VectorBatch *in = OckCreateInputData(numRows, numCols, inputTypes, allData); - delete[] inputTypes; + VectorBatch* in = CreateVectorBatch(inputTypes, numRows, col0, col1, col2, col3, col4); delete[] col0; delete[] col1; delete[] col2; delete[] col3; delete[] col4; - - for (int p = 0; p < string_cache_test_.size(); p++) { - delete string_cache_test_[p]; // 释放内存 - } return in; } -VectorBatch *OckCreateVectorBatch_1longCol_withPid(int parNum, int rowNum) -{ - int partitionNum = parNum; - const int32_t numCols = 2; - auto *inputTypes = new int32_t[numCols]; - inputTypes[0] = OMNI_INT; - inputTypes[1] = OMNI_LONG; - - const int32_t numRows = rowNum; - auto *col0 = new int32_t[numRows]; - auto *col1 = new int64_t[numRows]; - for (int i = 0; i < numRows; i++) { - col0[i] = (i + 1) % partitionNum; - col1[i] = i + 1; - } - - int64_t allData[numCols] = {reinterpret_cast(col0), - reinterpret_cast(col1)}; - VectorBatch *in = OckCreateInputData(numRows, numCols, inputTypes, allData); - for (int i = 0; i < 2; i++) { - delete (int64_t *)allData[i]; // 释放内存 - } - return in; -} - -VectorBatch *OckCreateVectorBatch_2column_1row_withPid(int pid, std::string strVar, int intVar) -{ - const int32_t numCols = 3; - auto *inputTypes = new int32_t[numCols]; - inputTypes[0] = OMNI_INT; - inputTypes[1] = OMNI_VARCHAR; - inputTypes[2] = OMNI_INT; +VectorBatch* CreateVectorBatch_2column_1row_withPid(int pid, std::string strVar, int intVar) { + DataTypes inputTypes(std::vector({ IntType(), VarcharType(), IntType() })); const int32_t numRows = 1; - auto *col0 = new int32_t[numRows]; - auto *col1 = new int64_t[numRows]; - auto *col2 = new int32_t[numRows]; + auto* col0 = new int32_t[numRows]; + auto* col1 = new std::string[numRows]; + auto* col2 = new int32_t[numRows]; col0[0] = pid; - auto *strTmp = new std::string(strVar); - col1[0] = (int64_t)(strTmp->c_str()); + col1[0] = std::move(strVar); col2[0] = intVar; - int64_t allData[numCols] = {reinterpret_cast(col0), - reinterpret_cast(col1), - reinterpret_cast(col2)}; - VectorBatch *in = OckCreateInputData(numRows, numCols, inputTypes, allData); - delete[] inputTypes; + VectorBatch* in = CreateVectorBatch(inputTypes, numRows, col0, col1, col2); delete[] col0; delete[] col1; delete[] col2; - delete strTmp; return in; } -VectorBatch *OckCreateVectorBatch_4varcharCols_withPid(int parNum, int rowNum) +VectorBatch *OckCreateVectorBatch_1fixedCols_withPid(int parNum, int rowNum, dataTypePtr fixColType) { int partitionNum = parNum; - const int32_t numCols = 5; - auto *inputTypes = new int32_t[numCols]; - inputTypes[0] = OMNI_INT; - inputTypes[1] = OMNI_VARCHAR; - inputTypes[2] = OMNI_VARCHAR; - inputTypes[3] = OMNI_VARCHAR; - inputTypes[4] = OMNI_VARCHAR; + DataTypes inputTypes(std::vector({ IntType(), std::move(fixColType) })); const int32_t numRows = rowNum; - auto *col0 = new int32_t[numRows]; - auto *col1 = new int64_t[numRows]; - auto *col2 = new int64_t[numRows]; - auto *col3 = new int64_t[numRows]; - 
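In the new OckCreateVectorBatch_4varcharCols_withPid earlier in this file, the row loop's opening line does not appear in the hunk even though the body assigns col0[i] through col4[i]. The loop presumably mirrors the sibling builders; a sketch of the restored loop (std::to_string is assumed where the hunk writes the unqualified to_string):

    // Sketch only: fills the pid column and four varchar columns row by row.
    for (int i = 0; i < numRows; i++) {
        col0[i] = (i + 1) % partitionNum;
        col1[i] = "Col1_START_" + std::to_string(i + 1) + "_END_";
        col2[i] = "Col2_START_" + std::to_string(i + 1) + "_END_";
        col3[i] = "Col3_START_" + std::to_string(i + 1) + "_END_";
        col4[i] = "Col4_START_" + std::to_string(i + 1) + "_END_";
    }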
auto *col4 = new int64_t[numRows]; - - std::vector string_cache_test_; + auto* col0 = new int32_t[numRows]; + auto* col1 = new int64_t[numRows]; for (int i = 0; i < numRows; i++) { col0[i] = (i + 1) % partitionNum; - auto *strTmp1 = new std::string("Col1_START_" + std::to_string(i + 1) + "_END_"); - col1[i] = (int64_t)((*strTmp1).c_str()); - auto *strTmp2 = new std::string("Col2_START_" + std::to_string(i + 1) + "_END_"); - col2[i] = (int64_t)((*strTmp2).c_str()); - auto *strTmp3 = new std::string("Col3_START_" + std::to_string(i + 1) + "_END_"); - col3[i] = (int64_t)((*strTmp3).c_str()); - auto *strTmp4 = new std::string("Col4_START_" + std::to_string(i + 1) + "_END_"); - col4[i] = (int64_t)((*strTmp4).c_str()); - string_cache_test_.push_back(strTmp1); - string_cache_test_.push_back(strTmp2); - string_cache_test_.push_back(strTmp3); - string_cache_test_.push_back(strTmp4); - } - - int64_t allData[numCols] = {reinterpret_cast(col0), - reinterpret_cast(col1), - reinterpret_cast(col2), - reinterpret_cast(col3), - reinterpret_cast(col4)}; - VectorBatch *in = OckCreateInputData(numRows, numCols, inputTypes, allData); - delete[] inputTypes; - delete[] col0; - delete[] col1; - delete[] col2; - delete[] col3; - delete[] col4; - - for (int p = 0; p < string_cache_test_.size(); p++) { - delete string_cache_test_[p]; // 释放内存 - } - return in; -} - -VectorBatch *OckCreateVectorBatch_1fixedCols_withPid(int parNum, int32_t rowNum) -{ - int partitionNum = parNum; - - // gen vectorBatch - const int32_t numCols = 1; - auto *inputTypes = new int32_t[numCols]; - // inputTypes[0] = OMNI_INT; - inputTypes[0] = OMNI_LONG; - - const uint32_t numRows = rowNum; - - std::cout << "gen row " << numRows << std::endl; - // auto *col0 = new int32_t[numRows]; - auto *col1 = new int64_t[numRows]; - for (int i = 0; i < numRows; i++) { - // col0[i] = 0; // i % partitionNum; col1[i] = i + 1; } - int64_t allData[numCols] = {reinterpret_cast(col1)}; - VectorBatch *in = OckCreateInputData(numRows, numCols, inputTypes, allData); - delete[] inputTypes; - // delete[] col0; + VectorBatch* in = CreateVectorBatch(inputTypes, numRows, col0, col1); + delete[] col0; delete[] col1; - return in; + return in; } -VectorBatch *OckCreateVectorBatch_3fixedCols_withPid(int parNum, int rowNum) +VectorBatch *OckCreateVectorBatch_5fixedCols_withPid(int parNum, int rowNum) { int partitionNum = parNum; - // gen vectorBatch - const int32_t numCols = 4; - auto *inputTypes = new int32_t[numCols]; - inputTypes[0] = OMNI_INT; - inputTypes[1] = OMNI_INT; - inputTypes[2] = OMNI_LONG; - inputTypes[3] = OMNI_DOUBLE; + DataTypes inputTypes( + std::vector({ IntType(), BooleanType(), ShortType(), IntType(), LongType(), DoubleType() })); const int32_t numRows = rowNum; - auto *col0 = new int32_t[numRows]; - auto *col1 = new int32_t[numRows]; - auto *col2 = new int64_t[numRows]; - auto *col3 = new double[numRows]; + auto* col0 = new int32_t[numRows]; + auto* col1 = new bool[numRows]; + auto* col2 = new int16_t[numRows]; + auto* col3 = new int32_t[numRows]; + auto* col4 = new int64_t[numRows]; + auto* col5 = new double[numRows]; for (int i = 0; i < numRows; i++) { col0[i] = i % partitionNum; - col1[i] = i + 1; + col1[i] = (i % 2) == 0 ? 
true : false; col2[i] = i + 1; col3[i] = i + 1; + col4[i] = i + 1; + col5[i] = i + 1; } - int64_t allData[numCols] = {reinterpret_cast(col0), - reinterpret_cast(col1), - reinterpret_cast(col2), - reinterpret_cast(col3)}; - VectorBatch *in = OckCreateInputData(numRows, numCols, inputTypes, allData); - delete[] inputTypes; + VectorBatch* in = CreateVectorBatch(inputTypes, numRows, col0, col1, col2, col3, col4, col5); delete[] col0; delete[] col1; delete[] col2; delete[] col3; - return in; + delete[] col4; + delete[] col5; + return in; } VectorBatch *OckCreateVectorBatch_2dictionaryCols_withPid(int partitionNum) @@ -547,121 +349,121 @@ VectorBatch *OckCreateVectorBatch_2dictionaryCols_withPid(int partitionNum) // construct input data const int32_t dataSize = 6; // prepare data - int32_t data0[dataSize] = {111, 112, 113, 114, 115, 116}; - int64_t data1[dataSize] = {221, 222, 223, 224, 225, 226}; - int64_t data2[dataSize] = {111, 222, 333, 444, 555, 666}; - Decimal128 data3[dataSize] = {Decimal128(0, 1), Decimal128(0, 2), Decimal128(0, 3), Decimal128(0, 4), Decimal128(0, 5), Decimal128(0, 6)}; - void *datas[4] = {data0, data1, data2, data3}; + auto *col0 = new int32_t[dataSize]; + for (int32_t i = 0; i< dataSize; i++) { + col0[i] = (i + 1) % partitionNum; + } + int32_t col1[dataSize] = {111, 112, 113, 114, 115, 116}; + int64_t col2[dataSize] = {221, 222, 223, 224, 225, 226}; + void *datas[2] = {col1, col2}; + DataTypes sourceTypes(std::vector({ IntType(), LongType() })); + int32_t ids[] = {0, 1, 2, 3, 4, 5}; - DataTypes sourceTypes(std::vector({ IntDataType(), LongDataType(), Decimal64DataType(7, 2), Decimal128DataType(38, 2)})); + VectorBatch *vectorBatch = new VectorBatch(dataSize); + auto Vec0 = CreateVector(dataSize, col0); + vectorBatch->Append(Vec0); + auto dicVec0 = CreateDictionaryVector(*sourceTypes.GetType(0), dataSize, ids, dataSize, datas[0]); + auto dicVec1 = CreateDictionaryVector(*sourceTypes.GetType(1), dataSize, ids, dataSize, datas[1]); + vectorBatch->Append(dicVec0); + vectorBatch->Append(dicVec1); - int32_t ids[] = {0, 1, 2, 3, 4, 5}; - auto vectorBatch = new VectorBatch(5, dataSize); - VectorAllocator *allocator = omniruntime::vec::GetProcessGlobalVecAllocator(); - auto intVectorTmp = new IntVector(allocator, 6); - for (int i = 0; i < intVectorTmp->GetSize(); i++) { - intVectorTmp->SetValue(i, (i + 1) % partitionNum); - } - for (int32_t i = 0; i < 5; i++) { - if (i == 0) { - vectorBatch->SetVector(i, intVectorTmp); - } else { - omniruntime::vec::DataType dataType = sourceTypes.Get()[i - 1]; - vectorBatch->SetVector(i, OckCreateDictionaryVector(dataType, dataSize, ids, dataSize, datas[i - 1])); - } - } + delete[] col0; return vectorBatch; } VectorBatch *OckCreateVectorBatch_1decimal128Col_withPid(int partitionNum) { - int32_t ROW_PER_VEC_BATCH = 999; - auto decimal128InputVec = OckbuildVector(Decimal128DataType(38, 2), ROW_PER_VEC_BATCH); - VectorAllocator *allocator = omniruntime::vec::GetProcessGlobalVecAllocator(); - auto *intVectorPid = new IntVector(allocator, ROW_PER_VEC_BATCH); - for (int i = 0; i < intVectorPid->GetSize(); i++) { - intVectorPid->SetValue(i, (i + 1) % partitionNum); + const int32_t numRows = rowNum; + DataTypes inputTypes(std::vector({ IntType(), Decimal128Type(38, 2) })); + + auto *col0 = new int32_t[numRows]; + auto *col1 = new Decimal128[numRows]; + for (int32_t i = 0; i < numRows; i++) { + col0[i] = (i + 1) % partitionNum; + col1[i] = Decimal128(0, 1); } - auto *vecBatch = new VectorBatch(2); - vecBatch->SetVector(0, intVectorPid); - 
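The reworked OckCreateVectorBatch_1decimal128Col_withPid around here reads a rowNum that its unchanged signature does not declare, while ock_test_utils.h and the updated Split_OMNI_DECIMAL128 test both pass a row count. A sketch of the definition under that assumption; std::vector<DataTypePtr> and the call to OckCreateVectorBatch (the hunks alternate between CreateVectorBatch and OckCreateVectorBatch, and only the latter is declared in the header) are also assumptions:

    VectorBatch *OckCreateVectorBatch_1decimal128Col_withPid(int partitionNum, int rowNum)
    {
        const int32_t numRows = rowNum;
        DataTypes inputTypes(std::vector<DataTypePtr>({ IntType(), Decimal128Type(38, 2) }));

        auto *col0 = new int32_t[numRows];     // partition id column
        auto *col1 = new Decimal128[numRows];  // decimal128 payload column
        for (int32_t i = 0; i < numRows; i++) {
            col0[i] = (i + 1) % partitionNum;
            col1[i] = Decimal128(0, 1);
        }

        VectorBatch *in = OckCreateVectorBatch(inputTypes, numRows, col0, col1);
        delete[] col0;
        delete[] col1;
        return in;
    }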
vecBatch->SetVector(1, decimal128InputVec); - return vecBatch; + + VectorBatch* in = CreateVectorBatch(inputTypes, numRows, col0, col1); + delete[] col0; + delete[] col1; + return in; } VectorBatch *OckCreateVectorBatch_1decimal64Col_withPid(int partitionNum, int rowNum) { - auto decimal64InputVec = OckbuildVector(Decimal64DataType(7, 2), rowNum); - VectorAllocator *allocator = VectorAllocator::GetGlobalAllocator(); - IntVector *intVectorPid = new IntVector(allocator, rowNum); - for (int i = 0; i < intVectorPid->GetSize(); i++) { - intVectorPid->SetValue(i, (i+1) % partitionNum); + const int32_t numRows = rowNum; + DataTypes inputTypes(std::vector({ IntType(), Decimal64Type(7, 2) })); + + auto *col0 = new int32_t[numRows]; + auto *col1 = new int64_t[numRows]; + for (int32_t i = 0; i < numRows; i++) { + col0[i] = (i + 1) % partitionNum; + col1[i] = 1; } - VectorBatch *vecBatch = new VectorBatch(2); - vecBatch->SetVector(0, intVectorPid); - vecBatch->SetVector(1, decimal64InputVec); - return vecBatch; + + VectorBatch* in = CreateVectorBatch(inputTypes, numRows, col0, col1); + delete[] col0; + delete[] col1; + return in; } VectorBatch *OckCreateVectorBatch_2decimalCol_withPid(int partitionNum, int rowNum) { - auto decimal64InputVec = OckbuildVector(Decimal64DataType(7, 2), rowNum); - auto decimal128InputVec = OckbuildVector(Decimal128DataType(38, 2), rowNum); - VectorAllocator *allocator = VectorAllocator::GetGlobalAllocator(); - IntVector *intVectorPid = new IntVector(allocator, rowNum); - for (int i = 0; i < intVectorPid->GetSize(); i++) { - intVectorPid->SetValue(i, (i+1) % partitionNum); + const int32_t numRows = rowNum; + DataTypes inputTypes(std::vector({ IntType(), Decimal64Type(7, 2), Decimal128Type(38, 2) })); + + auto *col0 = new int32_t[numRows]; + auto *col1 = new int64_t[numRows]; + auto *col2 = new Decimal128[numRows]; + for (int32_t i = 0; i < numRows; i++) { + col0[i] = (i + 1) % partitionNum; + col1[i] = 1; + col2[i] = Decimal128(0, 1); } - VectorBatch *vecBatch = new VectorBatch(3); - vecBatch->SetVector(0, intVectorPid); - vecBatch->SetVector(1, decimal64InputVec); - vecBatch->SetVector(2, decimal128InputVec); - return vecBatch; + + VectorBatch* in = CreateVectorBatch(inputTypes, numRows, col0, col1, col2); + delete[] col0; + delete[] col1; + delete[] col2; + return in; } VectorBatch *OckCreateVectorBatch_someNullRow_vectorBatch() { const int32_t numRows = 6; - int32_t data1[numRows] = {0, 1, 2, 0, 1, 2}; - int64_t data2[numRows] = {0, 1, 2, 3, 4, 5}; - double data3[numRows] = {0.0, 1.1, 2.2, 3.3, 4.4, 5.5}; - std::string data4[numRows] = {"abcde", "fghij", "klmno", "pqrst", "", ""}; - - auto vec0 = OckCreateVector(data1, numRows); - auto vec1 = OckCreateVector(data2, numRows); - auto vec2 = OckCreateVector(data3, numRows); - auto vec3 = OckCreateVarcharVector(VarcharDataType(varcharType), data4, numRows); - for (int i = 0; i < numRows; i = i + 2) { - vec0->SetValueNull(i, false); - vec1->SetValueNull(i, false); - vec2->SetValueNull(i, false); + const int32_t numCols = 6; + bool data0[numRows] = {true, false, true, false, true, false}; + int16_t data1[numRows] = {0, 1, 2, 3, 4, 6}; + int32_t data2[numRows] = {0, 1, 2, 0, 1, 2}; + int64_t data3[numRows] = {0, 1, 2, 3, 4, 5}; + double data4[numRows] = {0.0, 1.1, 2.2, 3.3, 4.4, 5.5}; + std::string data5[numRows] = {"abcde", "fghij", "klmno", "pqrst", "", ""}; + + DataTypes inputTypes( + std::vector({ BooleanType(), ShortType(), IntType(), LongType(), DoubleType(), VarcharType(5) })); + VectorBatch* vecBatch = 
CreateVectorBatch(inputTypes, numRows, data0, data1, data2, data3, data4, data5); + for (int32_t i = 0; i < numCols; i++) { + for (int32_t j = 0; j < numRows; j = j + 2) { + vecBatch->Get(i)->SetNull(j); + } } - auto *vecBatch = new VectorBatch(4); - vecBatch->SetVector(0, vec0); - vecBatch->SetVector(1, vec1); - vecBatch->SetVector(2, vec2); - vecBatch->SetVector(3, vec3); return vecBatch; } VectorBatch *OckCreateVectorBatch_someNullCol_vectorBatch() { const int32_t numRows = 6; + const int32_t numCols = 4; int32_t data1[numRows] = {0, 1, 2, 0, 1, 2}; int64_t data2[numRows] = {0, 1, 2, 3, 4, 5}; double data3[numRows] = {0.0, 1.1, 2.2, 3.3, 4.4, 5.5}; std::string data4[numRows] = {"abcde", "fghij", "klmno", "pqrst", "", ""}; - auto vec0 = OckCreateVector(data1, numRows); - auto vec1 = OckCreateVector(data2, numRows); - auto vec2 = OckCreateVector(data3, numRows); - auto vec3 = OckCreateVarcharVector(VarcharDataType(varcharType), data4, numRows); - for (int i = 0; i < numRows; i = i + 1) { - vec1->SetValueNull(i); - vec3->SetValueNull(i); + DataTypes inputTypes(std::vector({ IntType(), LongType(), DoubleType(), VarcharType(5) })); + VectorBatch* vecBatch = CreateVectorBatch(inputTypes, numRows, data1, data2, data3, data4); + for (int32_t i = 0; i < numCols; i = i + 2) { + for (int32_t j = 0; j < numRows; j++) { + vecBatch->Get(i)->SetNull(j); + } } - auto *vecBatch = new VectorBatch(4); - vecBatch->SetVector(0, vec0); - vecBatch->SetVector(1, vec1); - vecBatch->SetVector(2, vec2); - vecBatch->SetVector(3, vec3); return vecBatch; } diff --git a/omnioperator/omniop-spark-extension-ock/cpp/test/utils/ock_test_utils.h b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/utils/ock_test_utils.h similarity index 52% rename from omnioperator/omniop-spark-extension-ock/cpp/test/utils/ock_test_utils.h rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/utils/ock_test_utils.h index 9695a5ad6f1015ed230e426c20b1981b98891a84..6ffb74492d39dd81c17c6b7c21fb1a9b557c3085 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/test/utils/ock_test_utils.h +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/cpp/test/utils/ock_test_utils.h @@ -11,7 +11,7 @@ #include #include #include - +#include #include "../../src/jni/concurrent_map.h" #define private public static const int varcharType = 5; @@ -22,29 +22,29 @@ static ock::dopspark::ConcurrentMap> static std::string Ocks_shuffle_tests_dir = "/tmp/OckshuffleTests"; -VectorBatch *OckCreateInputData(const int32_t numRows, const int32_t numCols, int32_t *inputTypeIds, int64_t *allData); +std::unique_ptr CreateVector(DataType &dataType, int32_t rowCount, va_list &args); + +VectorBatch *OckCreateInputData(const DataTypes &types, int32_t rowCount, ...); -Vector *OckbuildVector(const DataType &aggType, int32_t rowNumber); +VectorBatch *OckCreateVectorBatch(const DataTypes &types, int32_t rowCount, ...); -Vector *OckNewbuildVector(const DataTypeId &typeId, int32_t rowNumber); +BaseVector *OckNewbuildVector(const DataTypeId &typeId, int32_t rowNumber); + +VectorBatch *OckCreateVectorBatch_4varcharCols_withPid(int parNum, int rowNum); VectorBatch *OckCreateVectorBatch_1row_varchar_withPid(int pid, const std::string &inputChar); VectorBatch *OckCreateVectorBatch_4col_withPid(int parNum, int rowNum); -VectorBatch *OckCreateVectorBatch_1longCol_withPid(int parNum, int rowNum); - VectorBatch *OckCreateVectorBatch_2column_1row_withPid(int pid, std::string strVar, int intVar); -VectorBatch 
*OckCreateVectorBatch_4varcharCols_withPid(int parNum, int rowNum); - -VectorBatch *OckCreateVectorBatch_3fixedCols_withPid(int parNum, int rowNum); +VectorBatch *OckCreateVectorBatch_5fixedCols_withPid(int parNum, int rowNum); -VectorBatch *OckCreateVectorBatch_1fixedCols_withPid(int parNum, int32_t rowNum); +VectorBatch *OckCreateVectorBatch_1fixedCols_withPid(int parNum, int32_t rowNum, DataTypePtr fixColType); VectorBatch *OckCreateVectorBatch_2dictionaryCols_withPid(int partitionNum); -VectorBatch *OckCreateVectorBatch_1decimal128Col_withPid(int partitionNum); +VectorBatch *OckCreateVectorBatch_1decimal128Col_withPid(int partitionNum, int rowNum); VectorBatch *OckCreateVectorBatch_1decimal64Col_withPid(int partitionNum, int rowNum); @@ -67,6 +67,53 @@ void OckTest_splitter_stop(long splitter_id); void OckTest_splitter_close(long splitter_id); +template BaseVector *CreateVector(int32_t length, T *values) +{ + std::unique_ptr> vector = std::make_unique>(length); + for (int32_t i = 0; i < length; i++) { + vector->SetValue(i, values[i]); + } + return vector; +} + +template +BaseVector *CreateFlatVector(int32_t length, va_list &args) +{ + using namespace omniruntime::type; + using T = typename NativeType::type; + using VarcharVector = Vector>; + if constexpr (std::is_same_v) { + VarcharVector *vector = new VarcharVector(length); + std::string *str = va_arg(args, std::string *); + for (int32_t i = 0; i < length; i++) { + std::string_view value(str[i].data(), str[i].length()); + vector->SetValue(i, value); + } + return vector; + } else { + Vector *vector = new Vector(length); + T *value = va_arg(args, T *); + for (int32_t i = 0; i < length; i++) { + vector->SetValue(i, value[i]); + } + return vector; + } +} + +template +BaseVector *CreateDictionary(BaseVector *vector, int32_t *ids, int32_t size) +{ + using T = typename NativeType::type; + if constexpr (std::is_same_v) { + return VectorHelper::CreateStringDictionary(ids, size, + reinterpret_cast> *>(vector)); + } else { + return VectorHelper::CreateDictionary(ids, size, reinterpret_cast *>(vector)); + } +} + + + template T *OckCreateVector(V *values, int32_t length) { VectorAllocator *vecAllocator = VectorAllocator::GetGlobalAllocator(); diff --git a/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/pom.xml b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/pom.xml new file mode 100644 index 0000000000000000000000000000000000000000..b2fdb093d1a890acfe16eb154c522d1af04baf0e --- /dev/null +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/pom.xml @@ -0,0 +1,122 @@ + + + 4.0.0 + + com.huawei.ock + omniop-spark-extension-ock + 23.0.0 + + + cpp/ + cpp/build/releases/ + FALSE + 0.6.1 + + + ock-omniop-shuffle-manager + jar + Huawei Open Computing Kit for Spark, shuffle manager + 23.0.0 + + + + ${project.artifactId}-${project.version}-for-${input.version} + + + ${cpp.build.dir} + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + + + net.alchim31.maven + scala-maven-plugin + ${scala.plugin.version} + + all + + + + + compile + testCompile + + + + -dependencyfile + ${project.build.directory}/.scala_dependencies + + + + + + + org.apache.maven.plugins + maven-jar-plugin + + + org.apache.maven.plugins + maven-compiler-plugin + 3.1 + + 8 + 8 + true + + -Xlint:all + + + + + exec-maven-plugin + org.codehaus.mojo + 3.0.0 + + + Build CPP + generate-resources + + exec + + + bash + + ${cpp.dir}/build.sh + ${plugin.cpp.test} + + + + + + + org.xolstice.maven.plugins + 
protobuf-maven-plugin + ${protobuf.maven.version} + + ${project.basedir}/../cpp/src/proto + + + + + compile + + + + + + + + + org.apache.maven.plugins + maven-jar-plugin + ${maven.plugin.version} + + + + + \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/src/main/java/com/huawei/ock/spark/jni/NativeLoader.java b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/java/com/huawei/ock/spark/jni/NativeLoader.java similarity index 100% rename from omnioperator/omniop-spark-extension-ock/src/main/java/com/huawei/ock/spark/jni/NativeLoader.java rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/java/com/huawei/ock/spark/jni/NativeLoader.java diff --git a/omnioperator/omniop-spark-extension-ock/src/main/java/com/huawei/ock/spark/jni/OckShuffleJniReader.java b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/java/com/huawei/ock/spark/jni/OckShuffleJniReader.java similarity index 96% rename from omnioperator/omniop-spark-extension-ock/src/main/java/com/huawei/ock/spark/jni/OckShuffleJniReader.java rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/java/com/huawei/ock/spark/jni/OckShuffleJniReader.java index ec294bdbf2208361846b4576ba0559abb9cfabc2..462ad9d105a54374bc867a9d83e45133fc238332 100644 --- a/omnioperator/omniop-spark-extension-ock/src/main/java/com/huawei/ock/spark/jni/OckShuffleJniReader.java +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/java/com/huawei/ock/spark/jni/OckShuffleJniReader.java @@ -150,8 +150,18 @@ public class OckShuffleJniReader { nativeCopyVecDataInVB(nativeReader, dstVec.getNativeVector(), colIndex); } + /** + * close reader. + * + */ + public void doClose() { + close(nativeReader); + } + private native long make(int[] typeIds); + private native long close(long readerId); + private native int nativeGetVectorBatch(long readerId, long vbDataAddr, int capacity, int maxRow, int maxDataSize, Long rowCnt); diff --git a/omnioperator/omniop-spark-extension-ock/src/main/java/com/huawei/ock/spark/jni/OckShuffleJniWriter.java b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/java/com/huawei/ock/spark/jni/OckShuffleJniWriter.java similarity index 96% rename from omnioperator/omniop-spark-extension-ock/src/main/java/com/huawei/ock/spark/jni/OckShuffleJniWriter.java rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/java/com/huawei/ock/spark/jni/OckShuffleJniWriter.java index 5e6094019d25df8ba094cb626d762fc158352d7a..08813362a20b4aaecb8cd78e48a5e47b389a0b7b 100644 --- a/omnioperator/omniop-spark-extension-ock/src/main/java/com/huawei/ock/spark/jni/OckShuffleJniWriter.java +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/java/com/huawei/ock/spark/jni/OckShuffleJniWriter.java @@ -1,122 +1,122 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. - */ - -package com.huawei.ock.spark.jni; - -import com.huawei.boostkit.spark.vectorized.PartitionInfo; -import com.huawei.boostkit.spark.vectorized.SplitResult; - -import java.rmi.UnexpectedException; - -/** - * OckShuffleJniWriter. - * - * @since 2022-6-10 - */ -public class OckShuffleJniWriter { - /** - * OckShuffleJniWriter constructor. 
- * - * @throws UnexpectedException UnexpectedException - */ - public OckShuffleJniWriter() throws UnexpectedException { - NativeLoader.getInstance(); - boolean isInitSuc = doInitialize(); - if (!isInitSuc) { - throw new UnexpectedException("OckShuffleJniWriter initialization failed"); - } - } - - /** - * make - * - * @param appId appId - * @param shuffleId shuffleId - * @param stageId stageId - * @param stageAttemptNumber stageAttemptNumber - * @param mapId mapId - * @param taskAttemptId taskAttemptId - * @param part part - * @param capacity capacity - * @param maxCapacity maxCapacity - * @param minCapacity minCapacity - * @param isCompress isCompress - * @return splitterId - */ - public long make(String appId, int shuffleId, int stageId, int stageAttemptNumber, - int mapId, long taskAttemptId, PartitionInfo part, int capacity, int maxCapacity, - int minCapacity, boolean isCompress) { - return nativeMake( - appId, - shuffleId, - stageId, - stageAttemptNumber, - mapId, - taskAttemptId, - part.getPartitionName(), - part.getPartitionNum(), - part.getInputTypes(), - part.getNumCols(), - capacity, - maxCapacity, - minCapacity, - isCompress); - } - - /** - * Create ock shuffle native writer - * - * @param appId appId - * @param shuffleId shuffleId - * @param stageId stageId - * @param stageAttemptNumber stageAttemptNumber - * @param mapId mapId - * @param taskAttemptId taskAttemptId - * @param partitioningMethod partitioningMethod - * @param numPartitions numPartitions - * @param inputTpyes inputTpyes - * @param numCols numCols - * @param capacity capacity - * @param maxCapacity maxCapacity - * @param minCapacity minCapacity - * @param isCompress isCompress - * @return splitterId - */ - public native long nativeMake(String appId, int shuffleId, int stageId, int stageAttemptNumber, - int mapId, long taskAttemptId, String partitioningMethod, int numPartitions, - String inputTpyes, int numCols, int capacity, int maxCapacity, int minCapacity, - boolean isCompress); - - private boolean doInitialize() { - return initialize(); - } - - private native boolean initialize(); - - /** - * Split one record batch represented by bufAddrs and bufSizes into several batches. The batch is - * split according to the first column as partition id. During splitting, the data in native - * buffers will be write to disk when the buffers are full. - * - * @param splitterId splitter instance id - * @param nativeVectorBatch Addresses of nativeVectorBatch - */ - public native void split(long splitterId, long nativeVectorBatch); - - /** - * Write the data remained in the buffers hold by native splitter to each partition's temporary - * file. And stop processing splitting - * - * @param splitterId splitter instance id - * @return SplitResult - */ - public native SplitResult stop(long splitterId); - - /** - * Release resources associated with designated splitter instance. - * - * @param splitterId splitter instance id - */ - public native void close(long splitterId); +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + */ + +package com.huawei.ock.spark.jni; + +import com.huawei.boostkit.spark.vectorized.PartitionInfo; +import com.huawei.boostkit.spark.vectorized.SplitResult; + +import java.rmi.UnexpectedException; + +/** + * OckShuffleJniWriter. + * + * @since 2022-6-10 + */ +public class OckShuffleJniWriter { + /** + * OckShuffleJniWriter constructor. 
+ * + * @throws UnexpectedException UnexpectedException + */ + public OckShuffleJniWriter() throws UnexpectedException { + NativeLoader.getInstance(); + boolean isInitSuc = doInitialize(); + if (!isInitSuc) { + throw new UnexpectedException("OckShuffleJniWriter initialization failed"); + } + } + + /** + * make + * + * @param appId appId + * @param shuffleId shuffleId + * @param stageId stageId + * @param stageAttemptNumber stageAttemptNumber + * @param mapId mapId + * @param taskAttemptId taskAttemptId + * @param part part + * @param capacity capacity + * @param maxCapacity maxCapacity + * @param minCapacity minCapacity + * @param isCompress isCompress + * @return splitterId + */ + public long make(String appId, int shuffleId, int stageId, int stageAttemptNumber, + int mapId, long taskAttemptId, PartitionInfo part, int capacity, int maxCapacity, + int minCapacity, boolean isCompress) { + return nativeMake( + appId, + shuffleId, + stageId, + stageAttemptNumber, + mapId, + taskAttemptId, + part.getPartitionName(), + part.getPartitionNum(), + part.getInputTypes(), + part.getNumCols(), + capacity, + maxCapacity, + minCapacity, + isCompress); + } + + /** + * Create ock shuffle native writer + * + * @param appId appId + * @param shuffleId shuffleId + * @param stageId stageId + * @param stageAttemptNumber stageAttemptNumber + * @param mapId mapId + * @param taskAttemptId taskAttemptId + * @param partitioningMethod partitioningMethod + * @param numPartitions numPartitions + * @param inputTpyes inputTpyes + * @param numCols numCols + * @param capacity capacity + * @param maxCapacity maxCapacity + * @param minCapacity minCapacity + * @param isCompress isCompress + * @return splitterId + */ + public native long nativeMake(String appId, int shuffleId, int stageId, int stageAttemptNumber, + int mapId, long taskAttemptId, String partitioningMethod, int numPartitions, + String inputTpyes, int numCols, int capacity, int maxCapacity, int minCapacity, + boolean isCompress); + + private boolean doInitialize() { + return initialize(); + } + + private native boolean initialize(); + + /** + * Split one record batch represented by bufAddrs and bufSizes into several batches. The batch is + * split according to the first column as partition id. During splitting, the data in native + * buffers will be write to disk when the buffers are full. + * + * @param splitterId splitter instance id + * @param nativeVectorBatch Addresses of nativeVectorBatch + */ + public native void split(long splitterId, long nativeVectorBatch); + + /** + * Write the data remained in the buffers hold by native splitter to each partition's temporary + * file. And stop processing splitting + * + * @param splitterId splitter instance id + * @return SplitResult + */ + public native SplitResult stop(long splitterId); + + /** + * Release resources associated with designated splitter instance. 
+ * + * @param splitterId splitter instance id + */ + public native void close(long splitterId); } \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/src/main/java/com/huawei/ock/spark/serialize/OckShuffleDataSerializer.java b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/java/com/huawei/ock/spark/serialize/OckShuffleDataSerializer.java similarity index 97% rename from omnioperator/omniop-spark-extension-ock/src/main/java/com/huawei/ock/spark/serialize/OckShuffleDataSerializer.java rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/java/com/huawei/ock/spark/serialize/OckShuffleDataSerializer.java index 9cfce65da353bf4548e89314210892b4d599eb4a..efc2b764a3e1e8356faa023ad5b2b6fc1cd2d501 100644 --- a/omnioperator/omniop-spark-extension-ock/src/main/java/com/huawei/ock/spark/serialize/OckShuffleDataSerializer.java +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/java/com/huawei/ock/spark/serialize/OckShuffleDataSerializer.java @@ -1,159 +1,159 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. - */ - -package com.huawei.ock.spark.serialize; - -import com.huawei.ock.spark.jni.OckShuffleJniReader; - -import nova.hetu.omniruntime.type.Decimal128DataType; -import nova.hetu.omniruntime.type.Decimal64DataType; -import nova.hetu.omniruntime.vector.BooleanVec; -import nova.hetu.omniruntime.vector.Decimal128Vec; -import nova.hetu.omniruntime.vector.DoubleVec; -import nova.hetu.omniruntime.vector.IntVec; -import nova.hetu.omniruntime.vector.LongVec; -import nova.hetu.omniruntime.vector.ShortVec; -import nova.hetu.omniruntime.vector.VarcharVec; -import nova.hetu.omniruntime.vector.Vec; - -import org.apache.spark.sql.execution.vectorized.OmniColumnVector; -import org.apache.spark.sql.types.DataType; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.vectorized.ColumnVector; -import org.apache.spark.sql.vectorized.ColumnarBatch; - -import java.rmi.UnexpectedException; - -/** - * Ock Shuffle DataSerializer - * - * @since 2022-6-10 - */ -public class OckShuffleDataSerializer { - private boolean isFinish = false; - private final OckShuffleJniReader jniReader; - private final nova.hetu.omniruntime.type.DataType[] vectorTypes; - private final int maxLength; - private final int maxRowNum; - - OckShuffleDataSerializer(OckShuffleJniReader reader, - nova.hetu.omniruntime.type.DataType[] vectorTypes, - int maxLength, - int maxRowNum) { - this.jniReader = reader; - this.vectorTypes = vectorTypes; - this.maxLength = maxLength; - this.maxRowNum = maxRowNum; - } - - // must call this function before deserialize - public boolean isFinish() { - return isFinish; - } - - /** - * deserialize - * - * @return ColumnarBatch - * @throws UnexpectedException UnexpectedException - */ - public ColumnarBatch deserialize() throws UnexpectedException { - jniReader.getNewVectorBatch(maxLength, maxRowNum); - int rowCount = jniReader.rowCntInVB(); - int vecCount = jniReader.colCntInVB(); - ColumnVector[] vectors = new ColumnVector[vecCount]; - for (int index = 0; index < vecCount; index++) { // mutli value - vectors[index] = buildVec(vectorTypes[index], rowCount, index); - } - - isFinish = jniReader.readFinish(); - return new ColumnarBatch(vectors, rowCount); - } - - private ColumnVector buildVec(nova.hetu.omniruntime.type.DataType srcType, int rowNum, int colIndex) { - Vec dstVec; - switch (srcType.getId()) { - case OMNI_INT: - case OMNI_DATE32: - dstVec = new IntVec(rowNum); - 
break; - case OMNI_LONG: - case OMNI_DATE64: - case OMNI_DECIMAL64: - dstVec = new LongVec(rowNum); - break; - case OMNI_SHORT: - dstVec = new ShortVec(rowNum); - break; - case OMNI_BOOLEAN: - dstVec = new BooleanVec(rowNum); - break; - case OMNI_DOUBLE: - dstVec = new DoubleVec(rowNum); - break; - case OMNI_CHAR: - case OMNI_VARCHAR: - // values buffer length - dstVec = new VarcharVec(jniReader.getVectorValueLength(colIndex), rowNum); - break; - case OMNI_DECIMAL128: - dstVec = new Decimal128Vec(rowNum); - break; - case OMNI_TIME32: - case OMNI_TIME64: - case OMNI_INTERVAL_DAY_TIME: - case OMNI_INTERVAL_MONTHS: - default: - throw new IllegalStateException("Unexpected value: " + srcType.getId()); - } - - jniReader.copyVectorDataInVB(dstVec, colIndex); - OmniColumnVector vecTmp = new OmniColumnVector(rowNum, getRealType(srcType), false); - vecTmp.setVec(dstVec); - return vecTmp; - } - - private DataType getRealType(nova.hetu.omniruntime.type.DataType srcType) { - switch (srcType.getId()) { - case OMNI_INT: - return DataTypes.IntegerType; - case OMNI_DATE32: - return DataTypes.DateType; - case OMNI_LONG: - return DataTypes.LongType; - case OMNI_DATE64: - return DataTypes.DateType; - case OMNI_DECIMAL64: - // for example 123.45=> precision(data length) = 5 ,scale(decimal length) = 2 - if (srcType instanceof Decimal64DataType) { - return DataTypes.createDecimalType(((Decimal64DataType) srcType).getPrecision(), - ((Decimal64DataType) srcType).getScale()); - } else { - throw new IllegalStateException("Unexpected value: " + srcType.getId()); - } - case OMNI_SHORT: - return DataTypes.ShortType; - case OMNI_BOOLEAN: - return DataTypes.BooleanType; - case OMNI_DOUBLE: - return DataTypes.DoubleType; - case OMNI_CHAR: - case OMNI_VARCHAR: - return DataTypes.StringType; - case OMNI_DECIMAL128: - if (srcType instanceof Decimal128DataType) { - return DataTypes.createDecimalType(((Decimal128DataType) srcType).getPrecision(), - ((Decimal128DataType) srcType).getScale()); - } else { - throw new IllegalStateException("Unexpected value: " + srcType.getId()); - } - case OMNI_TIME32: - case OMNI_TIME64: - case OMNI_INTERVAL_DAY_TIME: - case OMNI_INTERVAL_MONTHS: - default: - throw new IllegalStateException("Unexpected value: " + srcType.getId()); - } - } +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. 
+ */ + +package com.huawei.ock.spark.serialize; + +import com.huawei.ock.spark.jni.OckShuffleJniReader; + +import nova.hetu.omniruntime.type.Decimal128DataType; +import nova.hetu.omniruntime.type.Decimal64DataType; +import nova.hetu.omniruntime.vector.BooleanVec; +import nova.hetu.omniruntime.vector.Decimal128Vec; +import nova.hetu.omniruntime.vector.DoubleVec; +import nova.hetu.omniruntime.vector.IntVec; +import nova.hetu.omniruntime.vector.LongVec; +import nova.hetu.omniruntime.vector.ShortVec; +import nova.hetu.omniruntime.vector.VarcharVec; +import nova.hetu.omniruntime.vector.Vec; + +import org.apache.spark.sql.execution.vectorized.OmniColumnVector; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +import java.rmi.UnexpectedException; + +/** + * Ock Shuffle DataSerializer + * + * @since 2022-6-10 + */ +public class OckShuffleDataSerializer { + private boolean isFinish = false; + private final OckShuffleJniReader jniReader; + private final nova.hetu.omniruntime.type.DataType[] vectorTypes; + private final int maxLength; + private final int maxRowNum; + + OckShuffleDataSerializer(OckShuffleJniReader reader, + nova.hetu.omniruntime.type.DataType[] vectorTypes, + int maxLength, + int maxRowNum) { + this.jniReader = reader; + this.vectorTypes = vectorTypes; + this.maxLength = maxLength; + this.maxRowNum = maxRowNum; + } + + // must call this function before deserialize + public boolean isFinish() { + return isFinish; + } + + /** + * deserialize + * + * @return ColumnarBatch + * @throws UnexpectedException UnexpectedException + */ + public ColumnarBatch deserialize() throws UnexpectedException { + jniReader.getNewVectorBatch(maxLength, maxRowNum); + int rowCount = jniReader.rowCntInVB(); + int vecCount = jniReader.colCntInVB(); + ColumnVector[] vectors = new ColumnVector[vecCount]; + for (int index = 0; index < vecCount; index++) { // mutli value + vectors[index] = buildVec(vectorTypes[index], rowCount, index); + } + + isFinish = jniReader.readFinish(); + return new ColumnarBatch(vectors, rowCount); + } + + private ColumnVector buildVec(nova.hetu.omniruntime.type.DataType srcType, int rowNum, int colIndex) { + Vec dstVec; + switch (srcType.getId()) { + case OMNI_INT: + case OMNI_DATE32: + dstVec = new IntVec(rowNum); + break; + case OMNI_LONG: + case OMNI_DATE64: + case OMNI_DECIMAL64: + dstVec = new LongVec(rowNum); + break; + case OMNI_SHORT: + dstVec = new ShortVec(rowNum); + break; + case OMNI_BOOLEAN: + dstVec = new BooleanVec(rowNum); + break; + case OMNI_DOUBLE: + dstVec = new DoubleVec(rowNum); + break; + case OMNI_CHAR: + case OMNI_VARCHAR: + // values buffer length + dstVec = new VarcharVec(jniReader.getVectorValueLength(colIndex), rowNum); + break; + case OMNI_DECIMAL128: + dstVec = new Decimal128Vec(rowNum); + break; + case OMNI_TIME32: + case OMNI_TIME64: + case OMNI_INTERVAL_DAY_TIME: + case OMNI_INTERVAL_MONTHS: + default: + throw new IllegalStateException("Unexpected value: " + srcType.getId()); + } + + jniReader.copyVectorDataInVB(dstVec, colIndex); + OmniColumnVector vecTmp = new OmniColumnVector(rowNum, getRealType(srcType), false); + vecTmp.setVec(dstVec); + return vecTmp; + } + + private DataType getRealType(nova.hetu.omniruntime.type.DataType srcType) { + switch (srcType.getId()) { + case OMNI_INT: + return DataTypes.IntegerType; + case OMNI_DATE32: + return DataTypes.DateType; + case OMNI_LONG: + return 
DataTypes.LongType; + case OMNI_DATE64: + return DataTypes.DateType; + case OMNI_DECIMAL64: + // for example 123.45=> precision(data length) = 5 ,scale(decimal length) = 2 + if (srcType instanceof Decimal64DataType) { + return DataTypes.createDecimalType(((Decimal64DataType) srcType).getPrecision(), + ((Decimal64DataType) srcType).getScale()); + } else { + throw new IllegalStateException("Unexpected value: " + srcType.getId()); + } + case OMNI_SHORT: + return DataTypes.ShortType; + case OMNI_BOOLEAN: + return DataTypes.BooleanType; + case OMNI_DOUBLE: + return DataTypes.DoubleType; + case OMNI_CHAR: + case OMNI_VARCHAR: + return DataTypes.StringType; + case OMNI_DECIMAL128: + if (srcType instanceof Decimal128DataType) { + return DataTypes.createDecimalType(((Decimal128DataType) srcType).getPrecision(), + ((Decimal128DataType) srcType).getScale()); + } else { + throw new IllegalStateException("Unexpected value: " + srcType.getId()); + } + case OMNI_TIME32: + case OMNI_TIME64: + case OMNI_INTERVAL_DAY_TIME: + case OMNI_INTERVAL_MONTHS: + default: + throw new IllegalStateException("Unexpected value: " + srcType.getId()); + } + } } \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/com/huawei/ock/spark/serialize/OckColumnarBatchSerialize.scala b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/com/huawei/ock/spark/serialize/OckColumnarBatchSerialize.scala similarity index 97% rename from omnioperator/omniop-spark-extension-ock/src/main/scala/com/huawei/ock/spark/serialize/OckColumnarBatchSerialize.scala rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/com/huawei/ock/spark/serialize/OckColumnarBatchSerialize.scala index 9acbf51ac62ac45e88eb3e8d4edb157579f0437f..309afd0b53e21e8376b25f2f0e728687f61010fa 100644 --- a/omnioperator/omniop-spark-extension-ock/src/main/scala/com/huawei/ock/spark/serialize/OckColumnarBatchSerialize.scala +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/com/huawei/ock/spark/serialize/OckColumnarBatchSerialize.scala @@ -1,103 +1,103 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. - */ - -package com.huawei.ock.spark.serialize - -import com.huawei.ock.spark.jni.OckShuffleJniReader -import nova.hetu.omniruntime.`type`.DataType -import org.apache.spark.internal.Logging -import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer, SerializerInstance} -import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.sql.vectorized.ColumnarBatch - -import java.io.{InputStream, OutputStream} -import java.nio.ByteBuffer -import scala.reflect.ClassTag - -class OckColumnarBatchSerializer(readBatchNumRows: SQLMetric, numOutputRows: SQLMetric) - extends Serializer with Serializable { - - /** Creates a new [[SerializerInstance]]. */ - override def newInstance(): SerializerInstance = - new OckColumnarBatchSerializerInstance(readBatchNumRows, numOutputRows) -} - -class OckColumnarBatchSerializerInstance( - readBatchNumRows: SQLMetric, - numOutputRows: SQLMetric) - extends SerializerInstance with Logging { - - override def deserializeStream(in: InputStream): DeserializationStream = { - // This method is never called by shuffle code. 
- throw new UnsupportedOperationException - } - - def deserializeReader(reader: OckShuffleJniReader, - vectorTypes: Array[DataType], - maxLength: Int, - maxRowNum: Int): DeserializationStream = { - new DeserializationStream { - val serializer = new OckShuffleDataSerializer(reader, vectorTypes, maxLength, maxRowNum) - - private var numBatchesTotal: Long = _ - private var numRowsTotal: Long = _ - - override def asKeyValueIterator: Iterator[(Int, ColumnarBatch)] = { - new Iterator[(Int, ColumnarBatch)] { - override def hasNext: Boolean = !serializer.isFinish() - - override def next(): (Int, ColumnarBatch) = { - val columnarBatch: ColumnarBatch = serializer.deserialize() - // todo check need count? - numBatchesTotal += 1 - numRowsTotal += columnarBatch.numRows() - (0, columnarBatch) - } - } - } - - override def asIterator: Iterator[Any] = { - // This method is never called by shuffle code. - throw new UnsupportedOperationException - } - - override def readKey[T: ClassTag](): T = { - // We skipped serialization of the key in writeKey(), so just return a dummy value since - // this is going to be discarded anyways. - null.asInstanceOf[T] - } - - override def readValue[T: ClassTag](): T = { - val columnarBatch: ColumnarBatch = serializer.deserialize() - numBatchesTotal += 1 - numRowsTotal += columnarBatch.numRows() - columnarBatch.asInstanceOf[T] - } - - override def readObject[T: ClassTag](): T = { - // This method is never called by shuffle code. - throw new UnsupportedOperationException - } - - override def close(): Unit = { - if (numBatchesTotal > 0) { - readBatchNumRows.set(numRowsTotal.toDouble / numBatchesTotal) - } - numOutputRows += numRowsTotal - } - } - } - - override def serialize[T: ClassTag](t: T): ByteBuffer = - throw new UnsupportedOperationException - - override def deserialize[T: ClassTag](bytes: ByteBuffer): T = - throw new UnsupportedOperationException - - override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = - throw new UnsupportedOperationException - - override def serializeStream(s: OutputStream): SerializationStream = - throw new UnsupportedOperationException +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + */ + +package com.huawei.ock.spark.serialize + +import com.huawei.ock.spark.jni.OckShuffleJniReader +import nova.hetu.omniruntime.`type`.DataType +import org.apache.spark.internal.Logging +import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer, SerializerInstance} +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.vectorized.ColumnarBatch + +import java.io.{InputStream, OutputStream} +import java.nio.ByteBuffer +import scala.reflect.ClassTag + +class OckColumnarBatchSerializer(readBatchNumRows: SQLMetric, numOutputRows: SQLMetric) + extends Serializer with Serializable { + + /** Creates a new [[SerializerInstance]]. */ + override def newInstance(): SerializerInstance = + new OckColumnarBatchSerializerInstance(readBatchNumRows, numOutputRows) +} + +class OckColumnarBatchSerializerInstance( + readBatchNumRows: SQLMetric, + numOutputRows: SQLMetric) + extends SerializerInstance with Logging { + + override def deserializeStream(in: InputStream): DeserializationStream = { + // This method is never called by shuffle code. 
+ throw new UnsupportedOperationException + } + + def deserializeReader(reader: OckShuffleJniReader, + vectorTypes: Array[DataType], + maxLength: Int, + maxRowNum: Int): DeserializationStream = { + new DeserializationStream { + val serializer = new OckShuffleDataSerializer(reader, vectorTypes, maxLength, maxRowNum) + + private var numBatchesTotal: Long = _ + private var numRowsTotal: Long = _ + + override def asKeyValueIterator: Iterator[(Int, ColumnarBatch)] = { + new Iterator[(Int, ColumnarBatch)] { + override def hasNext: Boolean = !serializer.isFinish() + + override def next(): (Int, ColumnarBatch) = { + val columnarBatch: ColumnarBatch = serializer.deserialize() + // todo check need count? + numBatchesTotal += 1 + numRowsTotal += columnarBatch.numRows() + (0, columnarBatch) + } + } + } + + override def asIterator: Iterator[Any] = { + // This method is never called by shuffle code. + throw new UnsupportedOperationException + } + + override def readKey[T: ClassTag](): T = { + // We skipped serialization of the key in writeKey(), so just return a dummy value since + // this is going to be discarded anyways. + null.asInstanceOf[T] + } + + override def readValue[T: ClassTag](): T = { + val columnarBatch: ColumnarBatch = serializer.deserialize() + numBatchesTotal += 1 + numRowsTotal += columnarBatch.numRows() + columnarBatch.asInstanceOf[T] + } + + override def readObject[T: ClassTag](): T = { + // This method is never called by shuffle code. + throw new UnsupportedOperationException + } + + override def close(): Unit = { + if (numBatchesTotal > 0) { + readBatchNumRows.set(numRowsTotal.toDouble / numBatchesTotal) + } + numOutputRows += numRowsTotal + } + } + } + + override def serialize[T: ClassTag](t: T): ByteBuffer = + throw new UnsupportedOperationException + + override def deserialize[T: ClassTag](bytes: ByteBuffer): T = + throw new UnsupportedOperationException + + override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = + throw new UnsupportedOperationException + + override def serializeStream(s: OutputStream): SerializationStream = + throw new UnsupportedOperationException } \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleBlockResolver.scala b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleBlockResolver.scala similarity index 87% rename from omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleBlockResolver.scala rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleBlockResolver.scala index b08652bdc7085d1705c9058d58adbd6ad406868f..153ba5607b92e94d29d9401002aa009f5e5a33be 100644 --- a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleBlockResolver.scala +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleBlockResolver.scala @@ -1,72 +1,81 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. 
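The close() hook above turns the running totals into an average: readBatchNumRows receives rows per deserialized batch rather than a raw row count, while numOutputRows gets the total. A small sketch of that arithmetic, using hypothetical counter names:

object BatchMetricsSketch {
  // e.g. batches of 100, 250 and 250 rows => 600.0 / 3 = 200.0 average rows per batch
  def averageRowsPerBatch(numRowsTotal: Long, numBatchesTotal: Long): Double =
    if (numBatchesTotal > 0) numRowsTotal.toDouble / numBatchesTotal else 0.0
}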
- */ - -package org.apache.spark.shuffle.ock - -import com.huawei.ock.spark.jni.OckShuffleJniReader -import org.apache.spark._ -import org.apache.spark.executor.TempShuffleReadMetrics -import org.apache.spark.internal.Logging -import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.shuffle.{FetchFailedException, ShuffleBlockResolver} -import org.apache.spark.storage.{BlockId, BlockManagerId} -import org.apache.spark.util.{OCKConf, OCKFunctions} - -class OckColumnarShuffleBlockResolver(conf: SparkConf, ockConf: OCKConf) - extends ShuffleBlockResolver with Logging { - - override def getBlockData(blockId: BlockId, dirs: Option[Array[String]]): ManagedBuffer = { - null - } - - /** - * Remove shuffle temp memory data that contain the output data from one map. - */ - def removeDataByMap(shuffleId: Int, mapId: Int): Unit = { - } - - override def stop(): Unit = {} -} - -object OckColumnarShuffleBlockResolver extends Logging { - def getShuffleData[T](ockConf: OCKConf, - appId: String, - shuffleId: Int, - readMetrics: TempShuffleReadMetrics, - startMapIndex: Int, - endMapIndex: Int, - startPartition: Int, - endPartition: Int, - numBuffers: Int, - bufferSize: Long, - typeIds: Array[Int], - context: TaskContext): Iterator[OckShuffleJniReader] = { - val blocksByAddresses = getMapSizes(shuffleId, startMapIndex, endMapIndex, startPartition, endPartition) - - new OckColumnarShuffleBufferIterator(ockConf, appId, shuffleId, readMetrics, startMapIndex, endMapIndex, startPartition, endPartition, numBuffers, bufferSize, - OCKFunctions.parseBlocksByHost(blocksByAddresses), typeIds, context) - } - - def CreateFetchFailedException( - address: BlockManagerId, - shuffleId: Int, - mapId: Long, - mapIndex: Int, - reduceId: Int, - message: String - ): FetchFailedException = { - new FetchFailedException(address, shuffleId, mapId, mapIndex, reduceId, message) - } - - def getMapSizes( - shuffleId: Int, - startMapIndex: Int, - endMapIndex: Int, - startPartition: Int, - endPartition: Int - ): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { - val mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker - mapOutputTracker.getMapSizesByExecutorId(shuffleId, startMapIndex, endMapIndex, startPartition, endPartition) - } +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + */ + +package org.apache.spark.shuffle.ock + +import com.huawei.ock.spark.jni.OckShuffleJniReader +import org.apache.spark._ +import org.apache.spark.executor.TempShuffleReadMetrics +import org.apache.spark.internal.Logging +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.shuffle.MergedBlockMeta +import org.apache.spark.shuffle.{FetchFailedException, ShuffleBlockResolver} +import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleMergedBlockId} +import org.apache.spark.util.{OCKConf, OCKFunctions} + +class OckColumnarShuffleBlockResolver(conf: SparkConf, ockConf: OCKConf) + extends ShuffleBlockResolver with Logging { + + override def getBlockData(blockId: BlockId, dirs: Option[Array[String]]): ManagedBuffer = { + null + } + + /** + * Remove shuffle temp memory data that contain the output data from one map. 
+ */ + def removeDataByMap(shuffleId: Int, mapId: Int): Unit = { + } + + override def stop(): Unit = {} + + override def getMergedBlockData(blockId: ShuffleMergedBlockId, dirs: Option[Array[String]]): Seq[ManagedBuffer] = { + null + } + + override def getMergedBlockMeta(blockId: ShuffleMergedBlockId, dirs: Option[Array[String]]): MergedBlockMeta = { + null + } +} + +object OckColumnarShuffleBlockResolver extends Logging { + def getShuffleData[T](ockConf: OCKConf, + appId: String, + shuffleId: Int, + readMetrics: TempShuffleReadMetrics, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + numBuffers: Int, + bufferSize: Long, + typeIds: Array[Int], + context: TaskContext): Iterator[OckShuffleJniReader] = { + val blocksByAddresses = getMapSizes(shuffleId, startMapIndex, endMapIndex, startPartition, endPartition) + + new OckColumnarShuffleBufferIterator(ockConf, appId, shuffleId, readMetrics, startMapIndex, endMapIndex, startPartition, endPartition, numBuffers, bufferSize, + OCKFunctions.parseBlocksByHost(blocksByAddresses), typeIds, context) + } + + def CreateFetchFailedException( + address: BlockManagerId, + shuffleId: Int, + mapId: Long, + mapIndex: Int, + reduceId: Int, + message: String + ): FetchFailedException = { + new FetchFailedException(address, shuffleId, mapId, mapIndex, reduceId, message) + } + + def getMapSizes( + shuffleId: Int, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int + ): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { + val mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker + mapOutputTracker.getMapSizesByExecutorId(shuffleId, startMapIndex, endMapIndex, startPartition, endPartition) + } } \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleBufferIterator.scala b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleBufferIterator.scala similarity index 96% rename from omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleBufferIterator.scala rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleBufferIterator.scala index dc7e081555dfed6646beed6b85fc1f8356b8aa86..827971e9cec92a0c89bee1871efd8e181006cf96 100644 --- a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleBufferIterator.scala +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleBufferIterator.scala @@ -1,153 +1,156 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. 
- */ - -package org.apache.spark.shuffle.ock - -import com.huawei.ock.spark.jni.OckShuffleJniReader -import com.huawei.ock.ucache.shuffle.NativeShuffle -import com.huawei.ock.ucache.shuffle.datatype.{FetchError, FetchResult, MapTasksInfo} -import org.apache.spark.TaskContext -import org.apache.spark.internal.Logging -import org.apache.spark.shuffle.ShuffleReadMetricsReporter -import org.apache.spark.shuffle.ock.OckColumnarShuffleBufferIterator.getAndIncReaderSequence -import org.apache.spark.util.{OCKConf, OCKException} - -import java.util.concurrent.TimeUnit -import java.util.concurrent.atomic.AtomicInteger - -class OckColumnarShuffleBufferIterator[T]( - ockConf: OCKConf, - appId: String, - shuffleId: Int, - readMetrics: ShuffleReadMetricsReporter, - startMapIndex: Int, - endMapIndex: Int, - startPartition: Int, - endPartition: Int, - numBuffers: Int, - bufferSize: Long, - mapTaskToHostInfo: MapTasksInfo, - typeIds: Array[Int], - context: TaskContext) - extends Iterator[OckShuffleJniReader] with Logging { - - private var totalFetchNum = 0L - private var blobMap: Map[Long, OckShuffleJniReader] = Map() - - private var usedBlobId = -1L - final private val FETCH_ERROR = -1L; - final private val FETCH_FINISH = 0L; - - private val taskContext = context - private val sequenceId: String = "Spark_%s_%d_%d_%d_%d_%d_%d".format(appId, shuffleId, startMapIndex, - endMapIndex, startPartition, endPartition, getAndIncReaderSequence()) - private var hasBlob: Boolean = false; - - initialize() - - private[this] def destroyMapTaskInfo(): Unit = { - if (mapTaskToHostInfo.getNativeObjHandle != 0) { - NativeShuffle.destroyMapTaskInfo(mapTaskToHostInfo.getNativeObjHandle) - mapTaskToHostInfo.setNativeObjHandle(0) - } - } - - private[this] def throwFetchException(fetchError: FetchError): Unit = { - NativeShuffle.shuffleStreamReadStop(sequenceId) - destroyMapTaskInfo() - if (fetchError.getExecutorId() > 0) { - logError("Fetch failed error occurred, mostly because ockd is killed in some stage, node id is: " - + fetchError.getNodeId + " executor id is: " + fetchError.getExecutorId() + " sequenceId is " + sequenceId) - NativeShuffle.markShuffleWorkerRemoved(appId, fetchError.getNodeId.toInt) - val blocksByAddress = OckColumnarShuffleBlockResolver.getMapSizes(shuffleId, startMapIndex, endMapIndex, - startPartition, endPartition) - OCKException.ThrowFetchFailed(appId, shuffleId, fetchError, blocksByAddress, taskContext) - } - - val errorMessage = "Other error occurred, mostly because mf copy is failed in some stage, copy from node: " - + fetchError.getNodeId + " sequenceId is " + sequenceId - OCKException.ThrowOckException(errorMessage) - } - - private[this] def initialize(): Unit = { - // if reduce task fetch data is empty, will construct empty iterator - if (mapTaskToHostInfo.recordNum() > 0) { - val ret = NativeShuffle.shuffleStreamReadSizesGet(sequenceId, shuffleId, context.stageId(), - context.stageAttemptNumber(), startMapIndex, endMapIndex, startPartition, endPartition, mapTaskToHostInfo) - if (ret == FETCH_ERROR) { - throwFetchException(NativeShuffle.shuffleStreamReaderGetError(sequenceId)) - } - totalFetchNum = ret - } - - // create buffers, or blobIds - // use bagName, numBuffers and bufferSize to create buffers in low level - if (totalFetchNum != 0) { - NativeShuffle.shuffleStreamReadStart(sequenceId) - hasBlob = true - } - - logDebug("Initialize OCKColumnarShuffleBufferIterator sequenceId " + sequenceId + " blobNum " + totalFetchNum) - } - - override def hasNext: Boolean = { - if (!hasBlob && totalFetchNum 
!= 0) { - val dataSize: Int = NativeShuffle.shuffleStreamReadStop(sequenceId) - if (OckColumnarShuffleManager.isCompress(ockConf.sparkConf) && dataSize > 0) { - readMetrics.incRemoteBytesRead(dataSize) - } - destroyMapTaskInfo() - } - - hasBlob - } - - override def next(): OckShuffleJniReader = { - logDebug(s"new next called, need to release last buffer and call next buffer") - if (usedBlobId != -1L) { - NativeShuffle.shuffleStreamReadGatherFlush(sequenceId, usedBlobId) - } - val startFetchWait = System.nanoTime() - val result: FetchResult = NativeShuffle.shuffleStreamReadGatherOneBlob(sequenceId) - val fetchWaitTime = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startFetchWait) - readMetrics.incFetchWaitTime(fetchWaitTime) - - if (result.getRet == FETCH_ERROR) { - throwFetchException(result.getError) - } else if (result.getRet == FETCH_FINISH) { - hasBlob = false - } - - usedBlobId = result.getBlobId - logDebug("Get info blobId " + result.getBlobId + " blobSize " + result.getDataSize + ", sequenceId " - + sequenceId + " getRet " + result.getRet) - if (result.getDataSize > 0) { - if (!OckColumnarShuffleManager.isCompress(ockConf.sparkConf)) { - readMetrics.incRemoteBytesRead(result.getDataSize) - } - if (blobMap.contains(result.getBlobId)) { - val record = blobMap(result.getBlobId) - record.upgradeValueLen(result.getDataSize) - record - } else { - val record = new OckShuffleJniReader(result.getBlobId, result.getCapacity.toInt, - result.getAddress, result.getDataSize, typeIds) - blobMap += (result.getBlobId -> record) - record - } - } else { - val errorMessage = "Get buffer capacity to read is zero, sequenceId is " + sequenceId - OCKException.ThrowOckException(errorMessage) - new OckShuffleJniReader(result.getBlobId, 0, result.getAddress, result.getDataSize, typeIds) - } - } -} - -private object OckColumnarShuffleBufferIterator { - var gReaderSequence : AtomicInteger = new AtomicInteger(0) - - def getAndIncReaderSequence(): Int = { - gReaderSequence.getAndIncrement() - } +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. 
+ */ + +package org.apache.spark.shuffle.ock + +import com.huawei.ock.spark.jni.OckShuffleJniReader +import com.huawei.ock.ucache.shuffle.NativeShuffle +import com.huawei.ock.ucache.shuffle.datatype.{FetchError, FetchResult, MapTasksInfo} +import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging +import org.apache.spark.shuffle.ShuffleReadMetricsReporter +import org.apache.spark.shuffle.ock.OckColumnarShuffleBufferIterator.getAndIncReaderSequence +import org.apache.spark.util.{OCKConf, OCKException} + +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicInteger + +class OckColumnarShuffleBufferIterator[T]( + ockConf: OCKConf, + appId: String, + shuffleId: Int, + readMetrics: ShuffleReadMetricsReporter, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + numBuffers: Int, + bufferSize: Long, + mapTaskToHostInfo: MapTasksInfo, + typeIds: Array[Int], + context: TaskContext) + extends Iterator[OckShuffleJniReader] with Logging { + + private var totalFetchNum = 0L + private var blobMap: Map[Long, OckShuffleJniReader] = Map() + + private var usedBlobId = -1L + final private val FETCH_ERROR = -1L; + final private val FETCH_FINISH = 0L; + + private val taskContext = context + private val sequenceId: String = "Spark_%s_%d_%d_%d_%d_%d_%d".format(appId, shuffleId, startMapIndex, + endMapIndex, startPartition, endPartition, getAndIncReaderSequence()) + private var hasBlob: Boolean = false; + + initialize() + + private[this] def destroyMapTaskInfo(): Unit = { + if (mapTaskToHostInfo.getNativeObjHandle != 0) { + NativeShuffle.destroyMapTaskInfo(mapTaskToHostInfo.getNativeObjHandle) + mapTaskToHostInfo.setNativeObjHandle(0) + } + blobMap.values.foreach(reader => { + reader.doClose() + }) + } + + private[this] def throwFetchException(fetchError: FetchError): Unit = { + NativeShuffle.shuffleStreamReadStop(sequenceId) + destroyMapTaskInfo() + if (fetchError.getExecutorId() > 0) { + logError("Fetch failed error occurred, mostly because ockd is killed in some stage, node id is: " + + fetchError.getNodeId + " executor id is: " + fetchError.getExecutorId() + " sequenceId is " + sequenceId) + NativeShuffle.markShuffleWorkerRemoved(appId, fetchError.getNodeId.toInt) + val blocksByAddress = OckColumnarShuffleBlockResolver.getMapSizes(shuffleId, startMapIndex, endMapIndex, + startPartition, endPartition) + OCKException.ThrowFetchFailed(appId, shuffleId, fetchError, blocksByAddress, taskContext) + } + + val errorMessage = "Other error occurred, mostly because mf copy is failed in some stage, copy from node: " + + fetchError.getNodeId + " sequenceId is " + sequenceId + OCKException.ThrowOckException(errorMessage) + } + + private[this] def initialize(): Unit = { + // if reduce task fetch data is empty, will construct empty iterator + if (mapTaskToHostInfo.recordNum() > 0) { + val ret = NativeShuffle.shuffleStreamReadSizesGet(sequenceId, shuffleId, context.stageId(), + context.stageAttemptNumber(), startMapIndex, endMapIndex, startPartition, endPartition, mapTaskToHostInfo) + if (ret == FETCH_ERROR) { + throwFetchException(NativeShuffle.shuffleStreamReaderGetError(sequenceId)) + } + totalFetchNum = ret + } + + // create buffers, or blobIds + // use bagName, numBuffers and bufferSize to create buffers in low level + if (totalFetchNum != 0) { + NativeShuffle.shuffleStreamReadStart(sequenceId, endPartition) + hasBlob = true + } + + logDebug("Initialize OCKColumnarShuffleBufferIterator sequenceId " + sequenceId + " blobNum " + totalFetchNum) 
+ } + + override def hasNext: Boolean = { + if (!hasBlob && totalFetchNum != 0) { + val dataSize: Int = NativeShuffle.shuffleStreamReadStop(sequenceId) + if (OckColumnarShuffleManager.isCompress(ockConf.sparkConf) && dataSize > 0) { + readMetrics.incRemoteBytesRead(dataSize) + } + destroyMapTaskInfo() + } + + hasBlob + } + + override def next(): OckShuffleJniReader = { + logDebug(s"new next called, need to release last buffer and call next buffer") + if (usedBlobId != -1L) { + NativeShuffle.shuffleStreamReadGatherFlush(sequenceId, usedBlobId) + } + val startFetchWait = System.nanoTime() + val result: FetchResult = NativeShuffle.shuffleStreamReadGatherOneBlob(sequenceId) + val fetchWaitTime = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startFetchWait) + readMetrics.incFetchWaitTime(fetchWaitTime) + + if (result.getRet == FETCH_ERROR) { + throwFetchException(result.getError) + } else if (result.getRet == FETCH_FINISH) { + hasBlob = false + } + + usedBlobId = result.getBlobId + logDebug("Get info blobId " + result.getBlobId + " blobSize " + result.getDataSize + ", sequenceId " + + sequenceId + " getRet " + result.getRet) + if (result.getDataSize > 0) { + if (!OckColumnarShuffleManager.isCompress(ockConf.sparkConf)) { + readMetrics.incRemoteBytesRead(result.getDataSize) + } + if (blobMap.contains(result.getBlobId)) { + val record = blobMap(result.getBlobId) + record.upgradeValueLen(result.getDataSize) + record + } else { + val record = new OckShuffleJniReader(result.getBlobId, result.getCapacity.toInt, + result.getAddress, result.getDataSize, typeIds) + blobMap += (result.getBlobId -> record) + record + } + } else { + val errorMessage = "Get buffer capacity to read is zero, sequenceId is " + sequenceId + OCKException.ThrowOckException(errorMessage) + new OckShuffleJniReader(result.getBlobId, 0, result.getAddress, result.getDataSize, typeIds) + } + } +} + +private object OckColumnarShuffleBufferIterator { + var gReaderSequence : AtomicInteger = new AtomicInteger(0) + + def getAndIncReaderSequence(): Int = { + gReaderSequence.getAndIncrement() + } } \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleHandle.scala b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleHandle.scala similarity index 96% rename from omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleHandle.scala rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleHandle.scala index 8dba25ea5ad06ef810d8416021322dbdf90ffe18..70530996a18f718c525436fda65bc66516433e81 100644 --- a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleHandle.scala +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleHandle.scala @@ -1,19 +1,19 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. 
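The next() implementation above brackets the blocking native gather with System.nanoTime() and reports the difference, converted to milliseconds, as fetch wait time. A minimal sketch of that timing pattern; `gather` and `reportMillis` are stand-ins for the native call and the metrics reporter:

import java.util.concurrent.TimeUnit

object FetchWaitSketch {
  // Times a blocking call on the nanosecond clock and reports elapsed milliseconds.
  def timedGather[T](reportMillis: Long => Unit)(gather: => T): T = {
    val start = System.nanoTime()
    val result = gather
    reportMillis(TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start))
    result
  }
}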
- */ - -package org.apache.spark.shuffle.ock - -import org.apache.spark.ShuffleDependency -import org.apache.spark.shuffle.BaseShuffleHandle - -class OckColumnarShuffleHandle[K, V]( - shuffleId: Int, - dependency: ShuffleDependency[K, V, V], - secureId: String, - _appAttemptId: String) - extends BaseShuffleHandle(shuffleId, dependency) { - var secCode: String = secureId - - def appAttemptId : String = _appAttemptId +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + */ + +package org.apache.spark.shuffle.ock + +import org.apache.spark.ShuffleDependency +import org.apache.spark.shuffle.BaseShuffleHandle + +class OckColumnarShuffleHandle[K, V]( + shuffleId: Int, + dependency: ShuffleDependency[K, V, V], + secureId: String, + _appAttemptId: String) + extends BaseShuffleHandle(shuffleId, dependency) { + var secCode: String = secureId + + def appAttemptId : String = _appAttemptId } \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleManager.scala b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleManager.scala similarity index 94% rename from omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleManager.scala rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleManager.scala index 3457f0da62f4db1ae14e614d8925ff2089e1d256..8111dc9046243962b74fe8fac5d1c3c6d8930ab7 100644 --- a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleManager.scala +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleManager.scala @@ -1,218 +1,216 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. - */ - -package org.apache.spark.shuffle.ock - -import com.huawei.ock.ucache.common.exception.ApplicationException -import com.huawei.ock.ucache.shuffle.NativeShuffle -import org.apache.spark._ -import org.apache.spark.executor.TempShuffleReadMetrics -import org.apache.spark.internal.config.IO_COMPRESSION_CODEC -import org.apache.spark.internal.{Logging, config} -import org.apache.spark.scheduler.OCKScheduler -import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle._ -import org.apache.spark.shuffle.sort.ColumnarShuffleManager -import org.apache.spark.util.{OCKConf, OCKFunctions, Utils} - -import java.util.concurrent.ConcurrentHashMap -import java.util.concurrent.atomic.AtomicBoolean - -class OckColumnarShuffleManager(conf: SparkConf) extends ColumnarShuffleManager with Logging { - /** - * A mapping from shuffle ids to the task ids of mappers producing output for those shuffles. 
- */ - private[this] val numMapsForOCKShuffle = new ConcurrentHashMap[Int, Long]() - private[this] val ockConf = new OCKConf(conf) - - - val shuffleBlockResolver = new OckColumnarShuffleBlockResolver(conf, ockConf) - - var appId = "" - var listenFlg: Boolean = false - var isOckBroadcast: Boolean = ockConf.isOckBroadcast - var heartBeatFlag = false - val applicationDefaultAttemptId = "1"; - - if (ockConf.excludeUnavailableNodes && ockConf.appId == "driver") { - OCKScheduler.waitAndBlacklistUnavailableNode(conf) - } - - OCKFunctions.shuffleInitialize(ockConf, isOckBroadcast) - val isShuffleCompress: Boolean = conf.get(config.SHUFFLE_COMPRESS) - val compressCodec: String = conf.get(IO_COMPRESSION_CODEC); - OCKFunctions.setShuffleCompress(OckColumnarShuffleManager.isCompress(conf), compressCodec) - - /** - * Obtains a [[ShuffleHandle]] to pass to tasks. - */ - override def registerShuffle[K, V, C]( - shuffleId: Int, - dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { - appId = OCKFunctions.genAppId(conf.getAppId, SparkContext.getActive.get.applicationAttemptId.getOrElse("1")) - if (!listenFlg) { - dependency.rdd.sparkContext.addSparkListener(new OCKShuffleStageListener(conf, appId, ockConf.removeShuffleDataAfterJobFinished)) - listenFlg = true - } - var tokenCode: String = "" - if (isOckBroadcast) { - tokenCode = OCKFunctions.getToken(ockConf.isIsolated) - OckColumnarShuffleManager.registerShuffle(shuffleId, dependency.partitioner.numPartitions, conf, ockConf) - } else { - tokenCode = OckColumnarShuffleManager.registerShuffle(shuffleId, dependency.partitioner.numPartitions, - conf, ockConf) - } - if (!heartBeatFlag && ockConf.appId == "driver") { - heartBeatFlag = true - OCKFunctions.tryStartHeartBeat(this, appId) - } - - if (dependency.isInstanceOf[ColumnarShuffleDependency[_, _, _]]) { - new OckColumnarShuffleHandle[K, V]( - shuffleId, - dependency.asInstanceOf[ColumnarShuffleDependency[K, V, V]], - tokenCode, - SparkContext.getActive.get.applicationAttemptId.getOrElse("1")) - } else { - new OCKShuffleHandle(shuffleId, dependency, tokenCode, - SparkContext.getActive.get.applicationAttemptId.getOrElse("1")) - } - } - - /** Get a writer for a given partition. Called on executors by map tasks. */ - override def getWriter[K, V]( - handle: ShuffleHandle, - mapId: Long, - context: TaskContext, - metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = { - logInfo(s"Map task get writer. 
Task info: shuffleId ${handle.shuffleId} mapId $mapId") - - handle match { - case ockColumnarShuffleHandle: OckColumnarShuffleHandle[K@unchecked, V@unchecked] => - appId = OCKFunctions.genAppId(ockConf.appId, handle.asInstanceOf[OckColumnarShuffleHandle[_, _]].appAttemptId) - //when ock shuffle work with memory cache will remove numMapsForOCKShuffle - OckColumnarShuffleManager.registerApp(appId, ockConf, handle.asInstanceOf[OckColumnarShuffleHandle[_, _]].secCode) - new OckColumnarShuffleWriter(appId, ockConf, ockColumnarShuffleHandle, mapId, context, metrics) - case ockShuffleHandle: OCKShuffleHandle[K@unchecked, V@unchecked, _] => - appId = OCKFunctions.genAppId(ockConf.appId, handle.asInstanceOf[OCKShuffleHandle[_, _, _]].appAttemptId) - //when ock shuffle work with memory cache will remove numMapsForOCKShuffle - OckColumnarShuffleManager.registerApp(appId, ockConf, handle.asInstanceOf[OCKShuffleHandle[_, _, _]].secCode) - val serializerClass: String = ockConf.serializerClass - val serializer: Serializer = Utils.classForName(serializerClass).newInstance().asInstanceOf[Serializer] - new OCKShuffleWriter(appId, ockConf, ockShuffleHandle.asInstanceOf[BaseShuffleHandle[K, V, _]], - serializer, mapId, context, metrics) - } - } - - /** - * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). - * Called on executors by reduce tasks. - */ - override def getReader[K, C]( - handle: ShuffleHandle, - startMapIndex: Int, - endMapIndex: Int, - startPartition: Int, - endPartition: Int, - context: TaskContext, - metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { - logInfo(s"Reduce task get reader. Task info: shuffleId ${handle.shuffleId} reduceId $startPartition - $endPartition ") - - if (handle.isInstanceOf[OckColumnarShuffleHandle[_, _]]) { - appId = OCKFunctions.genAppId(ockConf.appId, handle.asInstanceOf[OckColumnarShuffleHandle[_, _]].appAttemptId) - ShuffleManager.registerApp(appId, ockConf, handle.asInstanceOf[OckColumnarShuffleHandle[_, _]].secCode) - new OckColumnarShuffleReader(appId, handle.asInstanceOf[BaseShuffleHandle[K, _, C]], - startMapIndex, endMapIndex, startPartition, endPartition, context, conf, ockConf, metrics.asInstanceOf[TempShuffleReadMetrics]) - } else { - appId = OCKFunctions.genAppId(ockConf.appId, handle.asInstanceOf[OCKShuffleHandle[_, _, _]].appAttemptId) - ShuffleManager.registerApp(appId, ockConf, handle.asInstanceOf[OCKShuffleHandle[_, _, _]].secCode) - new OCKShuffleReader(appId, handle.asInstanceOf[BaseShuffleHandle[K, _, C]], - startMapIndex, endMapIndex, startPartition, endPartition, context, conf, ockConf, metrics.asInstanceOf[TempShuffleReadMetrics]) - } - } - - /** Remove a shuffle's metadata from the ShuffleManager. */ - override def unregisterShuffle(shuffleId: Int): Boolean = { - logInfo(s"Unregister shuffle. Task info: shuffleId $shuffleId") - Option(numMapsForOCKShuffle.remove(shuffleId)).foreach { numMaps => - (0 until numMaps.toInt).foreach { mapId => - shuffleBlockResolver.removeDataByMap(shuffleId, mapId) - } - } - true - } - - /** Shut down this ShuffleManager. 
*/ - override def stop(): Unit = { - logInfo("stop ShuffleManager") - if (ockConf.appId == "driver") { - if (SparkContext.getActive.isDefined) { - appId = OCKFunctions.genAppId(conf.getAppId, SparkContext.getActive.get.applicationAttemptId.getOrElse(applicationDefaultAttemptId)) - } - if (appId.nonEmpty) { - OCKFunctions.tryStopHeartBeat(this, appId) - OckColumnarShuffleManager.markComplete(ockConf, appId) - } - } - shuffleBlockResolver.stop() - } -} - -private[spark] object OckColumnarShuffleManager extends Logging { - - var externalShuffleServiceFlag :AtomicBoolean = new AtomicBoolean(false) - var isWR: AtomicBoolean = new AtomicBoolean(false) - - def registerShuffle( - shuffleId: Int, - numPartitions: Int, - conf: SparkConf, - ockConf: OCKConf): String = { - val appId = OCKFunctions.genAppId(conf.getAppId, SparkContext.getActive.get.applicationAttemptId.getOrElse("1")) - val bagPartName = OCKFunctions.concatBagPartName(appId, shuffleId) - NativeShuffle.shuffleBagBatchCreate(appId, bagPartName, numPartitions, ockConf.priority, 0) - - if (!externalShuffleServiceFlag.get()) { - try { - val blockManagerClass = Class.forName("org.apache.spark.storage.BlockManager") - val externalShuffleServiceEnabledField = blockManagerClass.getDeclaredField("externalShuffleServiceEnabled") - externalShuffleServiceEnabledField.setAccessible(true) - externalShuffleServiceEnabledField.set(SparkEnv.get.blockManager, true) - logInfo("success to change externalShuffleServiceEnabled in block manager to " + - SparkEnv.get.blockManager.externalShuffleServiceEnabled) - externalShuffleServiceFlag.set(true) - } catch { - case _: Exception => - logWarning("failed to change externalShuffleServiceEnabled in block manager," + - " maybe ockd could not be able to recover in shuffle process") - } - conf.set(config.SHUFFLE_SERVICE_ENABLED, true) - } - // generate token code. Need 32bytes. - OCKFunctions.getToken(ockConf.isIsolated) - } - - def registerApp(appId: String, ockConf: OCKConf, secCode: String): Unit = { - if (!isWR.get()) { - synchronized(if (!isWR.get()) { - val nodeId = NativeShuffle.registerShuffleApp(appId, ockConf.removeShuffleDataAfterJobFinished, secCode) - isWR.set(true) - OCKFunctions.setNodeId(nodeId) - }) - } - } - - def markComplete(ockConf: OCKConf, appId: String): Unit = { - try { - NativeShuffle.markApplicationCompleted(appId) - } catch { - case ex: ApplicationException => - logError("Failed to mark application completed") - } - } - - def isCompress(conf: SparkConf): Boolean = { - conf.get(config.SHUFFLE_COMPRESS) - } +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + */ + +package org.apache.spark.shuffle.ock + +import com.huawei.ock.common.exception.ApplicationException +import com.huawei.ock.ucache.shuffle.NativeShuffle +import org.apache.spark._ +import org.apache.spark.executor.TempShuffleReadMetrics +import org.apache.spark.internal.config.IO_COMPRESSION_CODEC +import org.apache.spark.internal.{Logging, config} +import org.apache.spark.scheduler.OCKScheduler +import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle._ +import org.apache.spark.shuffle.sort.ColumnarShuffleManager +import org.apache.spark.util.{OCKConf, OCKFunctions, Utils} + +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicBoolean + +class OckColumnarShuffleManager(conf: SparkConf) extends ColumnarShuffleManager with Logging { + /** + * A mapping from shuffle ids to the task ids of mappers producing output for those shuffles. 
+ */ + private[this] val numMapsForOCKShuffle = new ConcurrentHashMap[Int, Long]() + private[this] val ockConf = new OCKConf(conf) + + + val shuffleBlockResolver = new OckColumnarShuffleBlockResolver(conf, ockConf) + + var appId = "" + var listenFlg: Boolean = false + var isOckBroadcast: Boolean = ockConf.isOckBroadcast + @volatile var heartBeatFlag: AtomicBoolean = new AtomicBoolean(false) + val applicationDefaultAttemptId = "1"; + + if (ockConf.excludeUnavailableNodes && ockConf.appId == "driver") { + OCKScheduler.waitAndBlacklistUnavailableNode(conf) + } + + OCKFunctions.shuffleInitialize(ockConf) + val isShuffleCompress: Boolean = conf.get(config.SHUFFLE_COMPRESS) + val compressCodec: String = conf.get(IO_COMPRESSION_CODEC); + OCKFunctions.setShuffleCompress(OckColumnarShuffleManager.isCompress(conf), compressCodec) + + /** + * Obtains a [[ShuffleHandle]] to pass to tasks. + */ + override def registerShuffle[K, V, C]( + shuffleId: Int, + dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { + appId = OCKFunctions.genAppId(conf.getAppId, SparkContext.getActive.get.applicationAttemptId.getOrElse("1")) + if (!listenFlg) { + dependency.rdd.sparkContext.addSparkListener(new OCKShuffleStageListener(conf, appId, ockConf.removeShuffleDataAfterJobFinished)) + listenFlg = true + } + var tokenCode: String = "" + if (isOckBroadcast) { + tokenCode = OCKFunctions.getToken(ockConf.isIsolated) + OckColumnarShuffleManager.registerShuffle(shuffleId, dependency.partitioner.numPartitions, conf, ockConf) + } else { + tokenCode = OckColumnarShuffleManager.registerShuffle(shuffleId, dependency.partitioner.numPartitions, + conf, ockConf) + } + if (ockConf.appId == "driver" && !heartBeatFlag.getAndSet(true)) { + OCKFunctions.tryStartHeartBeat(this, appId) + } + + if (dependency.isInstanceOf[ColumnarShuffleDependency[_, _, _]]) { + new OckColumnarShuffleHandle[K, V]( + shuffleId, + dependency.asInstanceOf[ColumnarShuffleDependency[K, V, V]], + tokenCode, + SparkContext.getActive.get.applicationAttemptId.getOrElse("1")) + } else { + new OCKShuffleHandle(shuffleId, dependency, tokenCode, + SparkContext.getActive.get.applicationAttemptId.getOrElse("1")) + } + } + + /** Get a writer for a given partition. Called on executors by map tasks. */ + override def getWriter[K, V]( + handle: ShuffleHandle, + mapId: Long, + context: TaskContext, + metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = { + logInfo(s"Map task get writer. 
Task info: shuffleId ${handle.shuffleId} mapId $mapId") + + handle match { + case ockColumnarShuffleHandle: OckColumnarShuffleHandle[K@unchecked, V@unchecked] => + appId = OCKFunctions.genAppId(ockConf.appId, handle.asInstanceOf[OckColumnarShuffleHandle[_, _]].appAttemptId) + //when ock shuffle work with memory cache will remove numMapsForOCKShuffle + OckColumnarShuffleManager.registerApp(appId, ockConf, handle.asInstanceOf[OckColumnarShuffleHandle[_, _]].secCode) + new OckColumnarShuffleWriter(appId, ockConf, ockColumnarShuffleHandle, mapId, context, metrics) + case ockShuffleHandle: OCKShuffleHandle[K@unchecked, V@unchecked, _] => + appId = OCKFunctions.genAppId(ockConf.appId, handle.asInstanceOf[OCKShuffleHandle[_, _, _]].appAttemptId) + //when ock shuffle work with memory cache will remove numMapsForOCKShuffle + OckColumnarShuffleManager.registerApp(appId, ockConf, handle.asInstanceOf[OCKShuffleHandle[_, _, _]].secCode) + val serializerClass: String = ockConf.serializerClass + val serializer: Serializer = Utils.classForName(serializerClass).newInstance().asInstanceOf[Serializer] + new OCKShuffleWriter(appId, ockConf, ockShuffleHandle.asInstanceOf[BaseShuffleHandle[K, V, _]], + serializer, mapId, context, metrics) + } + } + + /** + * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). + * Called on executors by reduce tasks. + */ + override def getReader[K, C]( + handle: ShuffleHandle, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { + logInfo(s"Reduce task get reader. Task info: shuffleId ${handle.shuffleId} reduceId $startPartition - $endPartition ") + + if (handle.isInstanceOf[OckColumnarShuffleHandle[_, _]]) { + appId = OCKFunctions.genAppId(ockConf.appId, handle.asInstanceOf[OckColumnarShuffleHandle[_, _]].appAttemptId) + ShuffleManager.registerApp(appId, ockConf, handle.asInstanceOf[OckColumnarShuffleHandle[_, _]].secCode) + new OckColumnarShuffleReader(appId, handle.asInstanceOf[BaseShuffleHandle[K, _, C]], + startMapIndex, endMapIndex, startPartition, endPartition, context, conf, ockConf, metrics.asInstanceOf[TempShuffleReadMetrics]) + } else { + appId = OCKFunctions.genAppId(ockConf.appId, handle.asInstanceOf[OCKShuffleHandle[_, _, _]].appAttemptId) + ShuffleManager.registerApp(appId, ockConf, handle.asInstanceOf[OCKShuffleHandle[_, _, _]].secCode) + new OCKShuffleReader(appId, handle.asInstanceOf[BaseShuffleHandle[K, _, C]], + startMapIndex, endMapIndex, startPartition, endPartition, context, conf, ockConf, metrics.asInstanceOf[TempShuffleReadMetrics]) + } + } + + /** Remove a shuffle's metadata from the ShuffleManager. */ + override def unregisterShuffle(shuffleId: Int): Boolean = { + logInfo(s"Unregister shuffle. Task info: shuffleId $shuffleId") + Option(numMapsForOCKShuffle.remove(shuffleId)).foreach { numMaps => + (0 until numMaps.toInt).foreach { mapId => + shuffleBlockResolver.removeDataByMap(shuffleId, mapId) + } + } + true + } + + /** Shut down this ShuffleManager. 
*/ + override def stop(): Unit = { + logInfo("stop ShuffleManager") + if (ockConf.appId == "driver") { + if (SparkContext.getActive.isDefined) { + appId = OCKFunctions.genAppId(conf.getAppId, SparkContext.getActive.get.applicationAttemptId.getOrElse(applicationDefaultAttemptId)) + } + if (appId.nonEmpty) { + OCKFunctions.tryStopHeartBeat(this, appId) + OckColumnarShuffleManager.markComplete(ockConf, appId) + } + } + shuffleBlockResolver.stop() + } +} + +private[spark] object OckColumnarShuffleManager extends Logging { + + var externalShuffleServiceFlag :AtomicBoolean = new AtomicBoolean(false) + var isWR: AtomicBoolean = new AtomicBoolean(false) + + def registerShuffle( + shuffleId: Int, + numPartitions: Int, + conf: SparkConf, + ockConf: OCKConf): String = { + val appId = OCKFunctions.genAppId(conf.getAppId, SparkContext.getActive.get.applicationAttemptId.getOrElse("1")) + val bagPartName = OCKFunctions.concatBagPartName(appId, shuffleId) + NativeShuffle.shuffleBagBatchCreate(appId, bagPartName, numPartitions, ockConf.priority, 0) + + if (!externalShuffleServiceFlag.get()) { + try { + val blockManagerClass = Class.forName("org.apache.spark.storage.BlockManager") + val externalShuffleServiceEnabledField = blockManagerClass.getDeclaredField("externalShuffleServiceEnabled") + externalShuffleServiceEnabledField.setAccessible(true) + externalShuffleServiceEnabledField.set(SparkEnv.get.blockManager, true) + logInfo("success to change externalShuffleServiceEnabled in block manager to " + + SparkEnv.get.blockManager.externalShuffleServiceEnabled) + externalShuffleServiceFlag.set(true) + } catch { + case _: Exception => + logWarning("failed to change externalShuffleServiceEnabled in block manager," + + " maybe ockd could not be able to recover in shuffle process") + } + } + // generate token code. Need 32bytes. + OCKFunctions.getToken(ockConf.isIsolated) + } + + def registerApp(appId: String, ockConf: OCKConf, secCode: String): Unit = { + if (!isWR.get()) { + synchronized(if (!isWR.get()) { + val nodeId = NativeShuffle.registerShuffleApp(appId, ockConf.removeShuffleDataAfterJobFinished, secCode) + isWR.set(true) + OCKFunctions.setNodeId(nodeId) + }) + } + } + + def markComplete(ockConf: OCKConf, appId: String): Unit = { + try { + NativeShuffle.markApplicationCompleted(appId) + } catch { + case ex: ApplicationException => + logError("Failed to mark application completed") + } + } + + def isCompress(conf: SparkConf): Boolean = { + conf.get(config.SHUFFLE_COMPRESS) + } } \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleReader.scala b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleReader.scala similarity index 97% rename from omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleReader.scala rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleReader.scala index a1cf5ebe08c30281c062a45631edff65fea1aee8..723884dcb1c84d76248ea256773006aa113a67fb 100644 --- a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleReader.scala +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleReader.scala @@ -1,139 +1,139 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. 
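registerApp above combines an AtomicBoolean with a synchronized re-check so the native registration runs at most once per executor JVM. A minimal sketch of that register-once pattern under the same assumption; `register` stands in for NativeShuffle.registerShuffleApp:

import java.util.concurrent.atomic.AtomicBoolean

class RegisterOnce(register: () => Unit) {
  private val done = new AtomicBoolean(false)

  // Fast path reads the flag without locking; the lock only serializes first-time callers.
  def ensureRegistered(): Unit = {
    if (!done.get()) {
      synchronized {
        if (!done.get()) {
          register()
          done.set(true)
        }
      }
    }
  }
}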
- */ - -package org.apache.spark.shuffle.ock - -import com.huawei.boostkit.spark.ColumnarPluginConfig -import com.huawei.boostkit.spark.serialize.ColumnarBatchSerializer -import com.huawei.ock.spark.jni.OckShuffleJniReader -import com.huawei.ock.spark.serialize.{OckColumnarBatchSerializer, OckColumnarBatchSerializerInstance} -import nova.hetu.omniruntime.`type`.{DataType, DataTypeSerializer} -import org.apache.spark._ -import org.apache.spark.executor.TempShuffleReadMetrics -import org.apache.spark.internal.Logging -import org.apache.spark.serializer.JavaSerializerInstance -import org.apache.spark.shuffle.{BaseShuffleHandle, ColumnarShuffleDependency, ShuffleReader} -import org.apache.spark.sorter.OCKShuffleSorter -import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.util.{CompletionIterator, OCKConf, Utils} - -/** - * Fetches and reads the partitions in range [startPartition, endPartition) from a shuffle by - * requesting them from other nodes' block stores. - */ -class OckColumnarShuffleReader[K, C]( - appId: String, - handle: BaseShuffleHandle[K, _, C], - startMapIndex: Int, - endMapIndex: Int, - startPartition: Int, - endPartition: Int, - context: TaskContext, - conf: SparkConf, - ockConf: OCKConf, - readMetrics: TempShuffleReadMetrics) - extends ShuffleReader[K, C] with Logging { - logInfo(s"get OCKShuffleReader mapIndex $startMapIndex - $endMapIndex partition: $startPartition - $endPartition.") - - private val dep = handle.dependency.asInstanceOf[ColumnarShuffleDependency[K, C, C]] - - val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf - - private var recordsSize: Long = 0L - // some input stream may exist header, must handle for it - private var isInputStreamExistHeader: Boolean = false - - val shuffleSorterClass: String = ockConf.shuffleSorterClass - - val ockShuffleSorter: OCKShuffleSorter = - Utils.classForName(shuffleSorterClass).newInstance.asInstanceOf[OCKShuffleSorter] - - val readBatchNumRows = classOf[ColumnarBatchSerializer].getDeclaredField("readBatchNumRows") - val numOutputRows = classOf[ColumnarBatchSerializer].getDeclaredField("numOutputRows") - readBatchNumRows.setAccessible(true) - numOutputRows.setAccessible(true) - - private val serializerInstance = new OckColumnarBatchSerializer( - readBatchNumRows.get(dep.serializer).asInstanceOf[SQLMetric], - numOutputRows.get(dep.serializer).asInstanceOf[SQLMetric]) - .newInstance() - .asInstanceOf[OckColumnarBatchSerializerInstance] - - /** - * Read the combined key-values for this reduce task - */ - override def read(): Iterator[Product2[K, C]] = { - // Update the context task metrics for each record read. 
- val vectorTypes: Array[DataType] = DataTypeSerializer.deserialize(dep.partitionInfo.getInputTypes) - val typeIds: Array[Int] = vectorTypes.map { - vecType => vecType.getId.ordinal - } - - val gatherDataStart = System.currentTimeMillis() - val records: Iterator[OckShuffleJniReader] = OckColumnarShuffleBlockResolver.getShuffleData(ockConf, appId, - handle.shuffleId, readMetrics, startMapIndex, endMapIndex, - startPartition, endPartition, 3, 0L, typeIds, context) - val gatherDataEnd = System.currentTimeMillis() - - var aggregatedIter: Iterator[Product2[K, C]] = null - var deserializeStart: Long = 0L - var deserializeEnd: Long = 0L - var combineBranchEnd: Long = 0L - var branch: Int = 0 - - if (ockConf.useSparkSerializer) { - deserializeStart = System.currentTimeMillis() - val readIter = records.flatMap { shuffleJniReader => - recordsSize += shuffleJniReader.getValueLen - serializerInstance.deserializeReader(shuffleJniReader, vectorTypes, - columnarConf.maxBatchSizeInBytes, - columnarConf.maxRowCount).asKeyValueIterator - } - - val recordIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( - readIter.map { record => - readMetrics.incRecordsRead(1) - record - }, - context.taskMetrics().mergeShuffleReadMetrics()) - - // An interruptible iterator must be used here in order to support task cancellation - val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, recordIter) - - deserializeEnd = System.currentTimeMillis() - - aggregatedIter = if (dep.aggregator.isDefined) { - if (dep.mapSideCombine && ockConf.isMapSideCombineExt) { - branch = 1 - // We are reading values that are already combined - val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]] - dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context) - } else { - branch = 2 - val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]] - dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) - } - } else { - branch = 3 - interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] - } - combineBranchEnd = System.currentTimeMillis() - } - context.taskMetrics().mergeShuffleReadMetrics() - - val result = dep.keyOrdering match { - case Some(keyOrd: Ordering[K]) => - ockShuffleSorter.sort(context, keyOrd, dep.serializer, records, aggregatedIter) - case None => - aggregatedIter - } - val sortEnd = System.currentTimeMillis() - - logInfo("Time cost for shuffle read partitionId: " + startPartition + "; gather data cost " + (gatherDataEnd - gatherDataStart) - + "ms. data size: " + recordsSize + "Bytes. deserialize cost " + (deserializeEnd - deserializeStart) + "ms. combine branch: " - + branch + ", cost: " + (combineBranchEnd - deserializeEnd) + "ms. " + "sort: " + (sortEnd - combineBranchEnd) + "ms.") - - new InterruptibleIterator[Product2[K, C]](context, result) - } +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. 
+ */ + +package org.apache.spark.shuffle.ock + +import com.huawei.boostkit.spark.ColumnarPluginConfig +import com.huawei.boostkit.spark.serialize.ColumnarBatchSerializer +import com.huawei.ock.spark.jni.OckShuffleJniReader +import com.huawei.ock.spark.serialize.{OckColumnarBatchSerializer, OckColumnarBatchSerializerInstance} +import nova.hetu.omniruntime.`type`.{DataType, DataTypeSerializer} +import org.apache.spark._ +import org.apache.spark.executor.TempShuffleReadMetrics +import org.apache.spark.internal.Logging +import org.apache.spark.serializer.JavaSerializerInstance +import org.apache.spark.shuffle.{BaseShuffleHandle, ColumnarShuffleDependency, ShuffleReader} +import org.apache.spark.sorter.OCKShuffleSorter +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.util.{CompletionIterator, OCKConf, Utils} + +/** + * Fetches and reads the partitions in range [startPartition, endPartition) from a shuffle by + * requesting them from other nodes' block stores. + */ +class OckColumnarShuffleReader[K, C]( + appId: String, + handle: BaseShuffleHandle[K, _, C], + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + context: TaskContext, + conf: SparkConf, + ockConf: OCKConf, + readMetrics: TempShuffleReadMetrics) + extends ShuffleReader[K, C] with Logging { + logInfo(s"get OCKShuffleReader mapIndex $startMapIndex - $endMapIndex partition: $startPartition - $endPartition.") + + private val dep = handle.dependency.asInstanceOf[ColumnarShuffleDependency[K, C, C]] + + val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf + + private var recordsSize: Long = 0L + // some input stream may exist header, must handle for it + private var isInputStreamExistHeader: Boolean = false + + val shuffleSorterClass: String = ockConf.shuffleSorterClass + + val ockShuffleSorter: OCKShuffleSorter = + Utils.classForName(shuffleSorterClass).newInstance.asInstanceOf[OCKShuffleSorter] + + val readBatchNumRows = classOf[ColumnarBatchSerializer].getDeclaredField("readBatchNumRows") + val numOutputRows = classOf[ColumnarBatchSerializer].getDeclaredField("numOutputRows") + readBatchNumRows.setAccessible(true) + numOutputRows.setAccessible(true) + + private val serializerInstance = new OckColumnarBatchSerializer( + readBatchNumRows.get(dep.serializer).asInstanceOf[SQLMetric], + numOutputRows.get(dep.serializer).asInstanceOf[SQLMetric]) + .newInstance() + .asInstanceOf[OckColumnarBatchSerializerInstance] + + /** + * Read the combined key-values for this reduce task + */ + override def read(): Iterator[Product2[K, C]] = { + // Update the context task metrics for each record read. 
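The reader above pulls the private readBatchNumRows/numOutputRows SQLMetric fields out of the row-based serializer via reflection. A self-contained sketch of that getDeclaredField/setAccessible pattern; FieldOwner and metric are hypothetical stand-ins, not types from the plugin:

object PrivateFieldSketch {
  class FieldOwner {
    private val metric: Long = 42L
    override def toString: String = s"FieldOwner($metric)" // keep the field referenced
  }

  def main(args: Array[String]): Unit = {
    val field = classOf[FieldOwner].getDeclaredField("metric")
    field.setAccessible(true)
    println(field.get(new FieldOwner)) // prints 42
  }
}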
+ val vectorTypes: Array[DataType] = DataTypeSerializer.deserialize(dep.partitionInfo.getInputTypes) + val typeIds: Array[Int] = vectorTypes.map { + vecType => vecType.getId.ordinal + } + + val gatherDataStart = System.currentTimeMillis() + val records: Iterator[OckShuffleJniReader] = OckColumnarShuffleBlockResolver.getShuffleData(ockConf, appId, + handle.shuffleId, readMetrics, startMapIndex, endMapIndex, + startPartition, endPartition, 3, 0L, typeIds, context) + val gatherDataEnd = System.currentTimeMillis() + + var aggregatedIter: Iterator[Product2[K, C]] = null + var deserializeStart: Long = 0L + var deserializeEnd: Long = 0L + var combineBranchEnd: Long = 0L + var branch: Int = 0 + + if (ockConf.useSparkSerializer) { + deserializeStart = System.currentTimeMillis() + val readIter = records.flatMap { shuffleJniReader => + recordsSize += shuffleJniReader.getValueLen + serializerInstance.deserializeReader(shuffleJniReader, vectorTypes, + columnarConf.maxBatchSizeInBytes, + columnarConf.maxRowCount).asKeyValueIterator + } + + val recordIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( + readIter.map { record => + readMetrics.incRecordsRead(1) + record + }, + context.taskMetrics().mergeShuffleReadMetrics()) + + // An interruptible iterator must be used here in order to support task cancellation + val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, recordIter) + + deserializeEnd = System.currentTimeMillis() + + aggregatedIter = if (dep.aggregator.isDefined) { + if (dep.mapSideCombine && ockConf.isMapSideCombineExt) { + branch = 1 + // We are reading values that are already combined + val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]] + dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context) + } else { + branch = 2 + val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]] + dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) + } + } else { + branch = 3 + interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] + } + combineBranchEnd = System.currentTimeMillis() + } + context.taskMetrics().mergeShuffleReadMetrics() + + val result = dep.keyOrdering match { + case Some(keyOrd: Ordering[K]) => + ockShuffleSorter.sort(context, keyOrd, dep.serializer, records, aggregatedIter) + case None => + aggregatedIter + } + val sortEnd = System.currentTimeMillis() + + logInfo("Time cost for shuffle read partitionId: " + startPartition + "; gather data cost " + (gatherDataEnd - gatherDataStart) + + "ms. data size: " + recordsSize + "Bytes. deserialize cost " + (deserializeEnd - deserializeStart) + "ms. combine branch: " + + branch + ", cost: " + (combineBranchEnd - deserializeEnd) + "ms. 
" + "sort: " + (sortEnd - combineBranchEnd) + "ms.") + + new InterruptibleIterator[Product2[K, C]](context, result) + } } \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleWriter.scala b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleWriter.scala similarity index 95% rename from omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleWriter.scala rename to omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleWriter.scala index e7aaf0fdf7a21abf1737d5280a5d83733cf9d416..41daa661cd306fed6a0afeb3cdfcb151855a0600 100644 --- a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleWriter.scala +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-shuffle/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleWriter.scala @@ -1,155 +1,161 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. - */ - -package org.apache.spark.shuffle.ock - -import com.huawei.boostkit.spark.util.OmniAdaptorUtil.transColBatchToOmniVecs -import com.huawei.boostkit.spark.vectorized.SplitResult -import com.huawei.ock.spark.jni.OckShuffleJniWriter -import nova.hetu.omniruntime.vector.VecBatch -import org.apache.spark.internal.Logging -import org.apache.spark.scheduler.MapStatus -import org.apache.spark.shuffle._ -import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.{OCKConf, OCKFunctions} -import org.apache.spark.{SparkEnv, TaskContext} - -class OckColumnarShuffleWriter[K, V]( - applicationId: String, - ockConf: OCKConf, - handle: BaseShuffleHandle[K, V, V], - mapId: Long, - context: TaskContext, - writeMetrics: ShuffleWriteMetricsReporter) - extends ShuffleWriter[K, V] with Logging { - - private val dep = handle.dependency.asInstanceOf[ColumnarShuffleDependency[K, V, V]] - - private val blockManager = SparkEnv.get.blockManager - - private var stopping = false - - private var mapStatus: MapStatus = _ - - val enableShuffleCompress: Boolean = OckColumnarShuffleManager.isCompress(ockConf.sparkConf) - - val cap: Int = ockConf.capacity - val maxCapacityTotal: Int = ockConf.maxCapacityTotal - val minCapacityTotal: Int = ockConf.minCapacityTotal - - private val jniWritter = new OckShuffleJniWriter() - - private var nativeSplitter: Long = 0 - - private var splitResult: SplitResult = _ - - private var partitionLengths: Array[Long] = _ - - private var first: Boolean = true - private var readTime: Long = 0L - private var markTime: Long = 0L - private var splitTime: Long = 0L - private var changeTime: Long = 0L - private var rowNum: Int = 0 - private var vbCnt: Int = 0 - - override def write(records: Iterator[Product2[K, V]]): Unit = { - if (!records.hasNext) { - partitionLengths = new Array[Long](dep.partitioner.numPartitions) - mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId) - return - } - - val startMake = System.currentTimeMillis() - if (nativeSplitter == 0) { - nativeSplitter = jniWritter.make( - applicationId, - dep.shuffleId, - context.stageId(), - context.stageAttemptNumber(), - mapId.toInt, - context.taskAttemptId(), - dep.partitionInfo, - cap, - maxCapacityTotal, - minCapacityTotal, - enableShuffleCompress) - } - val makeTime = System.currentTimeMillis() 
- startMake - - while (records.hasNext) { - vbCnt += 1 - if (first) { - readTime = System.currentTimeMillis() - makeTime - first = false - } else { - readTime += (System.currentTimeMillis() - markTime) - } - val cb = records.next()._2.asInstanceOf[ColumnarBatch] - if (cb.numRows == 0 || cb.numCols == 0) { - logInfo(s"Skip ColumnarBatch of ${cb.numRows} rows, ${cb.numCols} cols") - System.out.println("Skip column") - markTime = System.currentTimeMillis() - } else { - val startTime = System.currentTimeMillis() - val input = transColBatchToOmniVecs(cb) - val endTime = System.currentTimeMillis() - changeTime += endTime - startTime - for( col <- 0 until cb.numCols()) { - dep.dataSize += input(col).getRealValueBufCapacityInBytes - dep.dataSize += input(col).getRealNullBufCapacityInBytes - dep.dataSize += input(col).getRealOffsetBufCapacityInBytes - } - val vb = new VecBatch(input, cb.numRows()) - if (rowNum == 0) { - rowNum = cb.numRows() - } - jniWritter.split(nativeSplitter, vb.getNativeVectorBatch) - dep.numInputRows.add(cb.numRows) - writeMetrics.incRecordsWritten(1) - markTime = System.currentTimeMillis() - splitTime += markTime - endTime - } - } - val flushStartTime = System.currentTimeMillis() - splitResult = jniWritter.stop(nativeSplitter) - - val stopTime = (System.currentTimeMillis() - flushStartTime) - dep.splitTime.add(splitTime) - writeMetrics.incBytesWritten(splitResult.getTotalBytesWritten) - writeMetrics.incWriteTime(splitResult.getTotalWriteTime) - - partitionLengths = splitResult.getPartitionLengths - - val blockManagerId = BlockManagerId.apply(blockManager.blockManagerId.executorId, - blockManager.blockManagerId.host, - blockManager.blockManagerId.port, - Option.apply(OCKFunctions.getNodeId + "#" + context.taskAttemptId())) - mapStatus = MapStatus(blockManagerId, partitionLengths, mapId) - - System.out.println("shuffle_write_tick makeTime " + makeTime + " readTime " + readTime + " splitTime " - + splitTime + " changeTime " + changeTime + " stopTime " + stopTime + " rowNum " + dep.numInputRows.value + " vbCnt " + vbCnt) - } - - override def stop(success: Boolean): Option[MapStatus] = { - try { - if (stopping) { - None - } else { - stopping = true - if (success) { - Option(mapStatus) - } else { - None - } - } - } finally { - if (nativeSplitter != 0) { - jniWritter.close(nativeSplitter) - nativeSplitter = 0 - } - } - } +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. 
+ */ + +package org.apache.spark.shuffle.ock + +import com.huawei.boostkit.spark.util.OmniAdaptorUtil.transColBatchToOmniVecs +import com.huawei.boostkit.spark.vectorized.SplitResult +import com.huawei.ock.spark.jni.OckShuffleJniWriter +import com.huawei.ock.ucache.shuffle.NativeShuffle +import nova.hetu.omniruntime.vector.VecBatch +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.MapStatus +import org.apache.spark.shuffle._ +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.{OCKConf, OCKFunctions} +import org.apache.spark.{SparkEnv, TaskContext} + +class OckColumnarShuffleWriter[K, V]( + applicationId: String, + ockConf: OCKConf, + handle: BaseShuffleHandle[K, V, V], + mapId: Long, + context: TaskContext, + writeMetrics: ShuffleWriteMetricsReporter) + extends ShuffleWriter[K, V] with Logging { + + private val dep = handle.dependency.asInstanceOf[ColumnarShuffleDependency[K, V, V]] + + private val blockManager = SparkEnv.get.blockManager + + private var stopping = false + + private var mapStatus: MapStatus = _ + + val enableShuffleCompress: Boolean = OckColumnarShuffleManager.isCompress(ockConf.sparkConf) + + val cap: Int = ockConf.capacity + val maxCapacityTotal: Int = ockConf.maxCapacityTotal + val minCapacityTotal: Int = ockConf.minCapacityTotal + + private val jniWritter = new OckShuffleJniWriter() + + private var nativeSplitter: Long = 0 + + private var splitResult: SplitResult = _ + + private var partitionLengths: Array[Long] = _ + + private var first: Boolean = true + private var readTime: Long = 0L + private var markTime: Long = 0L + private var splitTime: Long = 0L + private var changeTime: Long = 0L + private var rowNum: Int = 0 + private var vbCnt: Int = 0 + + override def write(records: Iterator[Product2[K, V]]): Unit = { + if (!records.hasNext) { + partitionLengths = new Array[Long](dep.partitioner.numPartitions) + mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId) + return + } + + val startMake = System.currentTimeMillis() + if (nativeSplitter == 0) { + nativeSplitter = jniWritter.make( + applicationId, + dep.shuffleId, + context.stageId(), + context.stageAttemptNumber(), + mapId.toInt, + context.taskAttemptId(), + dep.partitionInfo, + cap, + maxCapacityTotal, + minCapacityTotal, + enableShuffleCompress) + } + val makeTime = System.currentTimeMillis() - startMake + + while (records.hasNext) { + vbCnt += 1 + if (first) { + readTime = System.currentTimeMillis() - makeTime + first = false + } else { + readTime += (System.currentTimeMillis() - markTime) + } + val cb = records.next()._2.asInstanceOf[ColumnarBatch] + if (cb.numRows == 0 || cb.numCols == 0) { + logInfo(s"Skip ColumnarBatch of ${cb.numRows} rows, ${cb.numCols} cols") + System.out.println("Skip column") + markTime = System.currentTimeMillis() + } else { + val startTime = System.currentTimeMillis() + val input = transColBatchToOmniVecs(cb) + val endTime = System.currentTimeMillis() + changeTime += endTime - startTime + for( col <- 0 until cb.numCols()) { + dep.dataSize += input(col).getRealValueBufCapacityInBytes + dep.dataSize += input(col).getRealNullBufCapacityInBytes + dep.dataSize += input(col).getRealOffsetBufCapacityInBytes + } + val vb = new VecBatch(input, cb.numRows()) + if (rowNum == 0) { + rowNum = cb.numRows() + } + jniWritter.split(nativeSplitter, vb.getNativeVectorBatch) + dep.numInputRows.add(cb.numRows) + writeMetrics.incRecordsWritten(1) + markTime = 
System.currentTimeMillis() + splitTime += markTime - endTime + } + } + val flushStartTime = System.currentTimeMillis() + splitResult = jniWritter.stop(nativeSplitter) + + val stopTime = (System.currentTimeMillis() - flushStartTime) + dep.splitTime.add(splitTime) + writeMetrics.incBytesWritten(splitResult.getTotalBytesWritten) + writeMetrics.incWriteTime(splitResult.getTotalWriteTime) + + partitionLengths = splitResult.getPartitionLengths + + val blockManagerId = BlockManagerId.apply(blockManager.blockManagerId.executorId, + blockManager.blockManagerId.host, + blockManager.blockManagerId.port, + Option.apply(OCKFunctions.getNodeId + "#" + context.taskAttemptId())) + mapStatus = MapStatus(blockManagerId, partitionLengths, mapId) + + System.out.println("shuffle_write_tick makeTime " + makeTime + " readTime " + readTime + " splitTime " + + splitTime + " changeTime " + changeTime + " stopTime " + stopTime + " rowNum " + dep.numInputRows.value + " vbCnt " + vbCnt) + } + + override def stop(success: Boolean): Option[MapStatus] = { + try { + if (stopping) { + None + } else { + stopping = true + if (success) { + NativeShuffle.shuffleStageSetShuffleId("Spark_"+applicationId, context.stageId(), handle.shuffleId) + Option(mapStatus) + } else { + None + } + } + } finally { + if (nativeSplitter != 0) { + jniWritter.close(nativeSplitter) + nativeSplitter = 0 + } + } + } + + override def getPartitionLengths(): Array[Long] = { + partitionLengths + } } \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/pom.xml b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/pom.xml new file mode 100644 index 0000000000000000000000000000000000000000..608a3ca714fe707b61758f870d0b32877e22f9b2 --- /dev/null +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/pom.xml @@ -0,0 +1,138 @@ + + + 4.0.0 + + + 3.3.1 + 2.12.15 + 2.12 + 3.2.3 + org.apache.spark + spark-3.3 + 3.2.0 + 3.1.1 + 23.0.0 + + + com.huawei.ock + ock-omniop-tuning + jar + Huawei Open Computing Kit for Spark, BoostTuning for OmniOperator + 23.0.0 + + + + org.scala-lang + scala-library + ${scala.version} + provided + + + ${spark.groupId} + spark-core_${scala.compat.version} + ${spark.version} + provided + + + ${spark.groupId} + spark-catalyst_${scala.compat.version} + ${spark.version} + provided + + + ${spark.groupId} + spark-sql_${scala.compat.version} + ${spark.version} + provided + + + com.huawei.ock + ock-adaptive-tuning + ${global.version} + + + com.huawei.ock + ock-tuning-sdk + ${global.version} + + + com.huawei.ock + ock-shuffle-sdk + ${global.version} + + + com.huawei.boostkit + boostkit-omniop-bindings + 1.3.0 + + + com.huawei.kunpeng + boostkit-omniop-spark + 3.3.1-1.3.0 + + + org.scalatest + scalatest_${scala.compat.version} + ${scalaTest.version} + test + + + + + ${project.artifactId}-${project.version}-for-${input.version} + src/main/scala + + + + net.alchim31.maven + scala-maven-plugin + ${scala.plugin.version} + + all + + + + + compile + testCompile + + + + -dependencyfile + ${project.build.directory}/.scala_dependencies + + + + + + + org.apache.maven.plugins + maven-jar-plugin + + + org.apache.maven.plugins + maven-compiler-plugin + 3.1 + + 8 + 8 + true + + -Xlint:all + + + + + + + + org.apache.maven.plugins + maven-jar-plugin + ${maven.plugin.version} + + + + + \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/OmniOpBoostTuningExtension.scala 
b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/OmniOpBoostTuningExtension.scala new file mode 100644 index 0000000000000000000000000000000000000000..13c4cf45e357d34f60370a8c4056f5c3505eabc7 --- /dev/null +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/OmniOpBoostTuningExtension.scala @@ -0,0 +1,18 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + */ + +package org.apache.spark.sql.execution.adaptive.ock + +import org.apache.spark.SparkContext +import org.apache.spark.sql.SparkSessionExtensions +import org.apache.spark.sql.execution.adaptive.ock.rule._ + +class OmniOpBoostTuningExtension extends (SparkSessionExtensions => Unit) { + override def apply(extensions: SparkSessionExtensions): Unit = { + extensions.injectQueryStagePrepRule(_ => BoostTuningQueryStagePrepRule()) + extensions.injectColumnar(_ => OmniOpBoostTuningColumnarRule( + OmniOpBoostTuningPreColumnarRule(), OmniOpBoostTuningPostColumnarRule())) + SparkContext.getActive.get.addSparkListener(new BoostTuningListener()) + } +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/common/OmniOpBoostTuningDefine.scala b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/common/OmniOpBoostTuningDefine.scala new file mode 100644 index 0000000000000000000000000000000000000000..6213dd5878b1349e3c0c5608c9ab60159e14435e --- /dev/null +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/common/OmniOpBoostTuningDefine.scala @@ -0,0 +1,25 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. 
+ */ + +package org.apache.spark.sql.execution.adaptive.ock.common + +import com.huawei.boostkit.spark.ColumnarPluginConfig +import org.apache.spark.SparkEnv + +object OmniOpDefine { + final val COLUMNAR_SHUFFLE_MANAGER_DEFINE = "org.apache.spark.shuffle.sort.ColumnarShuffleManager" + + final val COLUMNAR_SORT_SPILL_ROW_THRESHOLD = "spark.omni.sql.columnar.sortSpill.rowThreshold" + final val COLUMNAR_SORT_SPILL_ROW_BASED_ENABLED = "spark.omni.sql.columnar.sortSpill.enabled" +} + +object OmniOCKShuffleDefine { + final val OCK_COLUMNAR_SHUFFLE_MANAGER_DEFINE = "org.apache.spark.shuffle.ock.OckColumnarShuffleManager" +} + +object OmniRuntimeConfiguration { + val enableColumnarShuffle: Boolean = ColumnarPluginConfig.getSessionConf.enableColumnarShuffle + val OMNI_SPILL_ROWS: Long = SparkEnv.get.conf.getLong(OmniOpDefine.COLUMNAR_SORT_SPILL_ROW_THRESHOLD, Integer.MAX_VALUE) + val OMNI_SPILL_ROW_ENABLED: Boolean = SparkEnv.get.conf.getBoolean(OmniOpDefine.COLUMNAR_SORT_SPILL_ROW_BASED_ENABLED, defaultValue = true) +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/exchange/BoostTuningColumnarShuffleExchangeExec.scala b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/exchange/BoostTuningColumnarShuffleExchangeExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..ed6ef1a1a1d168d1e7ce6096290025a2779f2a1b --- /dev/null +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/exchange/BoostTuningColumnarShuffleExchangeExec.scala @@ -0,0 +1,206 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + */ + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.adaptive.ock.exchange + +import com.huawei.boostkit.spark.ColumnarPluginConfig +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor._ +import com.huawei.boostkit.spark.serialize.ColumnarBatchSerializer + +import nova.hetu.omniruntime.`type`.DataType + +import org.apache.spark.rdd.RDD +import org.apache.spark.serializer.Serializer +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.plans.logical.Statistics +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.adaptive.ock.common.BoostTuningLogger._ +import org.apache.spark.sql.execution.adaptive.ock.common.BoostTuningUtil._ +import org.apache.spark.sql.execution.adaptive.ock.exchange.estimator._ +import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin} +import org.apache.spark.sql.execution.metric._ +import org.apache.spark.sql.execution.util.MergeIterator +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.{MapOutputStatistics, ShuffleDependency} +import org.apache.spark.util.MutablePair + +import scala.concurrent.Future + +case class BoostTuningColumnarShuffleExchangeExec( + override val outputPartitioning: Partitioning, + child: SparkPlan, + shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS, + @transient context: PartitionContext) extends BoostTuningShuffleExchangeLike{ + + private lazy val writeMetrics = + SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext) + lazy val readMetrics = + SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext) + override lazy val metrics: Map[String, SQLMetric] = Map( + "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), + "bytesSpilled" -> SQLMetrics.createSizeMetric(sparkContext, "shuffle bytes spilled"), + "splitTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "totaltime_split"), + "spillTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "shuffle spill time"), + "compressTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "totaltime_compress"), + "avgReadBatchNumRows" -> SQLMetrics + .createAverageMetric(sparkContext, "avg read batch num rows"), + "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"), + "numMergedVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of merged vecBatchs"), + "bypassVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of bypass vecBatchs"), + "numOutputRows" -> SQLMetrics + .createMetric(sparkContext, "number of output rows")) ++ readMetrics ++ writeMetrics + + override def nodeName: String = "BoostTuningOmniColumnarShuffleExchange" + + override def getContext: PartitionContext = context + + override def getDependency: ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = boostTuningColumnarShuffleDependency + + override def getUpStreamDataSize: Long = collectUpStreamInputDataSize(this.child) + + override def getPartitionEstimators: Seq[PartitionEstimator] = estimators + + @transient val helper: BoostTuningShuffleExchangeHelper = + new BoostTuningColumnarShuffleExchangeHelper(this, sparkContext) + + @transient lazy val estimators: Seq[PartitionEstimator] = Seq( + UpStreamPartitionEstimator(), + ColumnarSamplePartitionEstimator(helper.executionMem)) ++ Seq( + SinglePartitionEstimator(), + ColumnarElementsForceSpillPartitionEstimator() + ) 
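+  // Editor's sketch (assumption, not part of the original patch): the estimator chain above is only
+  // declared here; the actual selection policy lives in the exchange helper's replacePartitionWithNewNum(),
+  // whose body is not shown in this diff. Since each PartitionEstimator returns Option[Int], a minimal
+  // way to fold the candidates into a partition count might look like:
+  //   val candidates = estimators.flatMap(_.apply(this))
+  //   val newNumPartitions =
+  //     if (candidates.nonEmpty) candidates.max else outputPartitioning.numPartitions
+  // The real helper may weight estimator types (DataSizeBased vs. ElementNumBased) differently.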
+ + override def supportsColumnar: Boolean = true + + val serializer: Serializer = new ColumnarBatchSerializer( + longMetric("avgReadBatchNumRows"), + longMetric("numOutputRows")) + + @transient lazy val inputColumnarRDD: RDD[ColumnarBatch] = child.executeColumnar() + + // 'mapOutputStatisticsFuture' is only needed when enable AQE. + @transient override lazy val mapOutputStatisticsFuture: Future[MapOutputStatistics] = { + if (inputColumnarRDD.getNumPartitions == 0) { + context.setSelfAndDepPartitionNum(outputPartitioning.numPartitions) + Future.successful(null) + } else { + omniAdaptivePartitionWithMapOutputStatistics() + } + } + + private def omniAdaptivePartitionWithMapOutputStatistics(): Future[MapOutputStatistics] = { + helper.cachedSubmitMapStage() match { + case Some(f) => return f + case _ => + } + + helper.onlineSubmitMapStage() match { + case f: Future[MapOutputStatistics] => f + case _ => Future.failed(null) + } + } + + override def numMappers: Int = boostTuningColumnarShuffleDependency.rdd.getNumPartitions + + override def numPartitions: Int = boostTuningColumnarShuffleDependency.partitioner.numPartitions + + override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[InternalRow] = { + throw new IllegalArgumentException("Failed to getShuffleRDD, exec should use ColumnarBatch but not InternalRow") + } + + override def runtimeStatistics: Statistics = { + val dataSize = metrics("dataSize").value + val rowCount = metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_RECORDS_WRITTEN).value + Statistics(dataSize, Some(rowCount)) + } + + @transient + lazy val boostTuningColumnarShuffleDependency: ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = { + val partitionInitTime = System.currentTimeMillis() + val newOutputPartitioning = helper.replacePartitionWithNewNum() + val partitionReadyTime = System.currentTimeMillis() + val dep = ColumnarShuffleExchangeExec.prepareShuffleDependency( + inputColumnarRDD, + child.output, + newOutputPartitioning, + serializer, + writeMetrics, + longMetric("dataSize"), + longMetric("bytesSpilled"), + longMetric("numInputRows"), + longMetric("splitTime"), + longMetric("spillTime")) + val dependencyReadyTime = System.currentTimeMillis() + TLogInfo(s"BoostTuningShuffleExchange $id input partition ${inputColumnarRDD.getNumPartitions}" + + s" modify ${if (helper.isAdaptive) "adaptive" else "global"}" + + s" partitionNum ${outputPartitioning.numPartitions} -> ${newOutputPartitioning.numPartitions}" + + s" partition modify cost ${partitionReadyTime - partitionInitTime} ms" + + s" dependency prepare cost ${dependencyReadyTime - partitionReadyTime} ms") + dep + } + + var cachedShuffleRDD: ShuffledColumnarRDD = _ + + override def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException() + } + + def buildCheck(): Unit = { + val inputTypes = new Array[DataType](child.output.size) + child.output.zipWithIndex.foreach { + case (attr, i) => + inputTypes(i) = sparkTypeToOmniType(attr.dataType, attr.metadata) + } + + outputPartitioning match { + case HashPartitioning(expressions, numPartitions) => + val genHashExpressionFunc = ColumnarShuffleExchangeExec.genHashExpr() + val hashJSonExpressions = genHashExpressionFunc(expressions, numPartitions, ColumnarShuffleExchangeExec.defaultMm3HashSeed, child.output) + if (!isSimpleColumn(hashJSonExpressions)) { + checkOmniJsonWhiteList("", Array(hashJSonExpressions)) + } + case _ => + } + } + + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + if (cachedShuffleRDD == null) { + 
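+      // Editor's note: the shuffle RDD is built once from the prepared columnar dependency and cached,
+      // so repeated doExecuteColumnar() calls reuse the same shuffle rather than re-creating it.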
cachedShuffleRDD = new ShuffledColumnarRDD(boostTuningColumnarShuffleDependency, readMetrics) + } + val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf + val enableShuffleBatchMerge: Boolean = columnarConf.enableShuffleBatchMerge + if (enableShuffleBatchMerge) { + cachedShuffleRDD.mapPartitionsWithIndexInternal { (index, iter) => + new MergeIterator(iter, + StructType.fromAttributes(child.output), + longMetric("numMergedVecBatchs")) + } + } else { + cachedShuffleRDD + } + } + + protected def withNewChildInternal(newChild: SparkPlan): BoostTuningColumnarShuffleExchangeExec = { + copy(child = newChild) + } +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/exchange/BoostTuningColumnarShuffleExchangeHelper.scala b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/exchange/BoostTuningColumnarShuffleExchangeHelper.scala new file mode 100644 index 0000000000000000000000000000000000000000..4743b7e67da8bd16abd32e05f678bcf3e83053c8 --- /dev/null +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/exchange/BoostTuningColumnarShuffleExchangeHelper.scala @@ -0,0 +1,44 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + */ + +package org.apache.spark.sql.execution.adaptive.ock.exchange + +import org.apache.spark.SparkContext +import org.apache.spark.sql.execution.adaptive.ock.common.OmniRuntimeConfiguration._ +import org.apache.spark.sql.execution.adaptive.ock.common.RuntimeConfiguration._ +import org.apache.spark.sql.execution.adaptive.ock.common._ +import org.apache.spark.sql.execution.adaptive.ock.memory._ + +import java.util + +class BoostTuningColumnarShuffleExchangeHelper(exchange: BoostTuningShuffleExchangeLike, sparkContext: SparkContext) + extends BoostTuningShuffleExchangeHelper(exchange, sparkContext) { + + override val executionMem: Long = shuffleManager match { + case OCKBoostShuffleDefine.OCK_SHUFFLE_MANAGER_DEFINE => + BoostShuffleExecutionModel().apply() + case OmniOpDefine.COLUMNAR_SHUFFLE_MANAGER_DEFINE => + ColumnarExecutionModel().apply() + case OmniOCKShuffleDefine.OCK_COLUMNAR_SHUFFLE_MANAGER_DEFINE => + ColumnarExecutionModel().apply() + case _ => + OriginExecutionModel().apply() + } + + override protected def fillInput(input: util.LinkedHashMap[String, String]): Unit = { + input.put("executionSize", executionMem.toString) + input.put("upstreamDataSize", exchange.getUpStreamDataSize.toString) + input.put("partitionRatio", initPartitionRatio.toString) + var spillThreshold = if (OMNI_SPILL_ROW_ENABLED) { + Math.min(OMNI_SPILL_ROWS, numElementsForceSpillThreshold) + } else { + numElementsForceSpillThreshold + } + if (spillThreshold == Integer.MAX_VALUE) { + spillThreshold = -1 + } + + input.put("elementSpillThreshold", spillThreshold.toString) + } +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/exchange/estimator/ColumnarElementsForceSpillPartitionEstimator.scala b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/exchange/estimator/ColumnarElementsForceSpillPartitionEstimator.scala new file mode 100644 index 0000000000000000000000000000000000000000..3c2507b1abce20d0b514df6dacd8b055c2dd1fbf --- /dev/null +++ 
b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/exchange/estimator/ColumnarElementsForceSpillPartitionEstimator.scala @@ -0,0 +1,41 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + */ + +package org.apache.spark.sql.execution.adaptive.ock.exchange.estimator + +import org.apache.spark.sql.execution.adaptive.ock.common.OmniRuntimeConfiguration._ +import org.apache.spark.sql.execution.adaptive.ock.common.RuntimeConfiguration._ +import org.apache.spark.sql.execution.adaptive.ock.exchange._ +import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike + +case class ColumnarElementsForceSpillPartitionEstimator() extends PartitionEstimator { + + override def estimatorType: EstimatorType = ElementNumBased + + override def apply(exchange: ShuffleExchangeLike): Option[Int] = { + if (!sampleEnabled) { + return None + } + + if (!OMNI_SPILL_ROW_ENABLED && numElementsForceSpillThreshold == Integer.MAX_VALUE) { + return None + } + + val spillMinThreshold = if (OMNI_SPILL_ROW_ENABLED) { + Math.min(OMNI_SPILL_ROWS, numElementsForceSpillThreshold) + } else { + numElementsForceSpillThreshold + } + + exchange match { + case ex: BoostTuningColumnarShuffleExchangeExec => + val rowCount = ex.inputColumnarRDD + .sample(withReplacement = false, sampleRDDFraction) + .map(cb => cb.numRows()).first() + Some((initPartitionRatio * rowCount / spillMinThreshold).toInt) + case _ => + None + } + } +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/exchange/estimator/ColumnarSamplePartitionEstimator.scala b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/exchange/estimator/ColumnarSamplePartitionEstimator.scala new file mode 100644 index 0000000000000000000000000000000000000000..e8decd6a5407290bcb7de39b1ae24e46b70e94ae --- /dev/null +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/exchange/estimator/ColumnarSamplePartitionEstimator.scala @@ -0,0 +1,33 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. 
+ */ + +package org.apache.spark.sql.execution.adaptive.ock.exchange.estimator + +import com.huawei.boostkit.spark.util.OmniAdaptorUtil + +import org.apache.spark.sql.execution.adaptive.ock.common.RuntimeConfiguration._ +import org.apache.spark.sql.execution.adaptive.ock.exchange.BoostTuningColumnarShuffleExchangeExec +import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike + +case class ColumnarSamplePartitionEstimator(executionMem: Long) extends PartitionEstimator { + + override def estimatorType: EstimatorType = DataSizeBased + + override def apply(exchange: ShuffleExchangeLike): Option[Int] = { + if (!sampleEnabled) { + return None + } + + exchange match { + case ex: BoostTuningColumnarShuffleExchangeExec => + val inputPartitionNum = ex.inputColumnarRDD.getNumPartitions + val sampleRDD = ex.inputColumnarRDD + .sample(withReplacement = false, sampleRDDFraction) + .map(cb => OmniAdaptorUtil.transColBatchToOmniVecs(cb).map(_.getCapacityInBytes).sum) + Some(SamplePartitionEstimator(executionMem).sampleAndGenPartitionNum(ex, inputPartitionNum, sampleRDD)) + case _ => + None + } + } +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/memory/ColumnarExecutionModel.scala b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/memory/ColumnarExecutionModel.scala new file mode 100644 index 0000000000000000000000000000000000000000..b5edfc7ab3859a32b52ff6e9fae1b26ecb80de21 --- /dev/null +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/memory/ColumnarExecutionModel.scala @@ -0,0 +1,30 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. 
+ */ + +package org.apache.spark.sql.execution.adaptive.ock.memory + +import org.apache.spark.SparkEnv +import org.apache.spark.internal.config +import org.apache.spark.sql.execution.adaptive.ock.common.BoostTuningLogger._ +import org.apache.spark.sql.execution.adaptive.ock.common.RuntimeConfiguration._ + +case class ColumnarExecutionModel() extends ExecutionModel { + override def apply(): Long = { + val systemMem = executorMemory + val executorCores = SparkEnv.get.conf.get(config.EXECUTOR_CORES).toLong + val reservedMem = SparkEnv.get.conf.getLong("spark.testing.reservedMemory", 300 * 1024 * 1024) + val usableMem = systemMem - reservedMem + val shuffleMemFraction = SparkEnv.get.conf.get(config.MEMORY_FRACTION) * + (1 - SparkEnv.get.conf.get(config.MEMORY_STORAGE_FRACTION)) + val offHeapMem = if (offHeapEnabled) { + offHeapSize + } else { + 0 + } + val finalMem = ((usableMem * shuffleMemFraction + offHeapMem) / executorCores).toLong + TLogDebug(s"ExecutorMemory is $systemMem reserved $reservedMem offHeapMem is $offHeapMem" + + s" shuffleMemFraction is $shuffleMemFraction, execution memory of executor is $finalMem") + finalMem + } +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/reader/BoostTuningColumnarCustomShuffleReaderExec.scala b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/reader/BoostTuningColumnarCustomShuffleReaderExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..0cea10ba74ca12acc2a5b7370428948a8e4f9d33 --- /dev/null +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/reader/BoostTuningColumnarCustomShuffleReaderExec.scala @@ -0,0 +1,236 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + */ + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive.ock.reader + +import com.huawei.boostkit.spark.ColumnarPluginConfig +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.adaptive. 
ShuffleQueryStageExec +import org.apache.spark.sql.execution.adaptive.ock.exchange.BoostTuningColumnarShuffleExchangeExec +import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeLike} +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.util.MergeIterator +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch + +import scala.collection.mutable.ArrayBuffer + + +/** + * A wrapper of shuffle query stage, which follows the given partition arrangement. + * + * @param child It is usually `ShuffleQueryStageExec`, but can be the shuffle exchange + * node during canonicalization. + * @param partitionSpecs The partition specs that defines the arrangement. + */ +case class BoostTuningOmniAQEShuffleReadExec ( + child: SparkPlan, + partitionSpecs: Seq[ShufflePartitionSpec]) + extends UnaryExecNode { + // If this reader is to read shuffle files locally, then all partition specs should be + // `PartialMapperPartitionSpec`. + if (partitionSpecs.exists(_.isInstanceOf[PartialMapperPartitionSpec])) { + assert(partitionSpecs.forall(_.isInstanceOf[PartialMapperPartitionSpec])) + } + + override def nodeName: String = "BoostTuningOmniAQEShuffleReadeExec" + + override def supportsColumnar: Boolean = true + + override def output: Seq[Attribute] = child.output + override lazy val outputPartitioning: Partitioning = { + // If it is a local shuffle reader with one mapper per task, then the output partitioning is + // the same as the plan before shuffle. + if (partitionSpecs.nonEmpty && + partitionSpecs.forall(_.isInstanceOf[PartialMapperPartitionSpec]) && + partitionSpecs.map(_.asInstanceOf[PartialMapperPartitionSpec].mapIndex).toSet.size == + partitionSpecs.length) { + child match { + case ShuffleQueryStageExec(_, s: ShuffleExchangeLike, _) => + s.child.outputPartitioning + case ShuffleQueryStageExec(_, r @ ReusedExchangeExec(_, s: ShuffleExchangeLike), _) => + s.child.outputPartitioning match { + case e: Expression => r.updateAttr(e).asInstanceOf[Partitioning] + case other => other + } + case _ => + throw new IllegalStateException("operating on canonicalization plan") + } + } else { + UnknownPartitioning(partitionSpecs.length) + } + } + + override def stringArgs: Iterator[Any] = { + val desc = if (isLocalReader) { + "local" + } else if (hasCoalescedPartition && hasSkewedPartition) { + "coalesced and skewed" + } else if (hasCoalescedPartition) { + "coalesced" + } else if (hasSkewedPartition) { + "skewed" + } else { + "" + } + Iterator(desc) + } + + def hasCoalescedPartition: Boolean = + partitionSpecs.exists(_.isInstanceOf[CoalescedPartitionSpec]) + + def hasSkewedPartition: Boolean = + partitionSpecs.exists(_.isInstanceOf[PartialReducerPartitionSpec]) + + def isLocalReader: Boolean = + partitionSpecs.exists(_.isInstanceOf[PartialMapperPartitionSpec]) + + private def shuffleStage = child match { + case stage: ShuffleQueryStageExec => Some(stage) + case _ => None + } + + @transient private lazy val partitionDataSizes: Option[Seq[Long]] = { + if (partitionSpecs.nonEmpty && !isLocalReader && shuffleStage.get.mapStats.isDefined) { + val bytesByPartitionId = shuffleStage.get.mapStats.get.bytesByPartitionId + Some(partitionSpecs.map { + case CoalescedPartitionSpec(startReducerIndex, endReducerIndex, _) => + startReducerIndex.until(endReducerIndex).map(bytesByPartitionId).sum + case p: PartialReducerPartitionSpec => p.dataSize + case p => throw new IllegalStateException("unexpected " + p) + }) + } 
else { + None + } + } + + private def sendDriverMetrics(): Unit = { + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + val driverAccumUpdates = ArrayBuffer.empty[(Long, Long)] + + val numPartitionsMetric = metrics("numPartitions") + numPartitionsMetric.set(partitionSpecs.length) + driverAccumUpdates += (numPartitionsMetric.id -> partitionSpecs.length.toLong) + + if (hasSkewedPartition) { + val skewedSpecs = partitionSpecs.collect { + case p: PartialReducerPartitionSpec => p + } + + val skewedPartitions = metrics("numSkewedPartitions") + val skewedSplits = metrics("numSkewedSplits") + + val numSkewedPartitions = skewedSpecs.map(_.reducerIndex).distinct.length + val numSplits = skewedSpecs.length + + skewedPartitions.set(numSkewedPartitions) + driverAccumUpdates += (skewedPartitions.id -> numSkewedPartitions) + + skewedSplits.set(numSplits) + driverAccumUpdates += (skewedSplits.id -> numSplits) + } + + partitionDataSizes.foreach { dataSizes => + val partitionDataSizeMetrics = metrics("partitionDataSize") + driverAccumUpdates ++= dataSizes.map(partitionDataSizeMetrics.id -> _) + // Set sum value to "partitionDataSize" metric. + partitionDataSizeMetrics.set(dataSizes.sum) + } + + SQLMetrics.postDriverMetricsUpdatedByValue(sparkContext, executionId, driverAccumUpdates.toSeq) + } + + override lazy val metrics: Map[String, SQLMetric] = { + if (shuffleStage.isDefined) { + Map( + "numMergedVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of merged vecBatchs"), + "bypassVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of bypass vecBatchs"), + "numPartitions" -> SQLMetrics.createMetric(sparkContext, "number of partitions")) ++ { + if (isLocalReader) { + // We split the mapper partition evenly when creating local shuffle reader, so no + // data size info is available. + Map.empty + } else { + Map("partitionDataSize" -> + SQLMetrics.createSizeMetric(sparkContext, "partition data size")) + } + } ++ { + if (hasSkewedPartition) { + Map("numSkewedPartitions" -> + SQLMetrics.createMetric(sparkContext, "number of skewed partitions"), + "numSkewedSplits" -> + SQLMetrics.createMetric(sparkContext, "number of skewed partition splits")) + } else { + Map.empty + } + } + } else { + // It's a canonicalized plan, no need to report metrics. 
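+      // Editor's note: a canonicalized plan is only used for plan comparison/reuse and is never executed,
+      // so returning an empty metrics map here is safe.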
+ Map.empty + } + } + + private var cachedShuffleRDD: RDD[ColumnarBatch] = null + + private lazy val shuffleRDD: RDD[_] = { + sendDriverMetrics() + if (cachedShuffleRDD == null) { + cachedShuffleRDD = child match { + case stage: ShuffleQueryStageExec => + new ShuffledColumnarRDD( + stage.shuffle + .asInstanceOf[BoostTuningColumnarShuffleExchangeExec] + .boostTuningColumnarShuffleDependency, + stage.shuffle.asInstanceOf[BoostTuningColumnarShuffleExchangeExec].readMetrics, + partitionSpecs.toArray) + case _ => + throw new IllegalStateException("operating on canonicalized plan") + } + } + val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf + val enableShuffleBatchMerge: Boolean = columnarConf.enableShuffleBatchMerge + if (enableShuffleBatchMerge) { + cachedShuffleRDD.mapPartitionsWithIndexInternal { (index, iter) => + new MergeIterator(iter, + StructType.fromAttributes(child.output), + longMetric("numMergedVecBatchs")) + } + } else { + cachedShuffleRDD + } + } + + override protected def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException(s"This operator doesn't support doExecute().") + } + + override protected def doExecuteColumnar(): RDD[ColumnarBatch] = { + shuffleRDD.asInstanceOf[RDD[ColumnarBatch]] + } + + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = { + new BoostTuningOmniAQEShuffleReadExec(newChild, this.partitionSpecs) + } +} diff --git a/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/rule/OmniOpBoostTuningColumnarRule.scala b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/rule/OmniOpBoostTuningColumnarRule.scala new file mode 100644 index 0000000000000000000000000000000000000000..be6632fa7d4df47854bb5ddb187d9e1cd77bd9f9 --- /dev/null +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/rule/OmniOpBoostTuningColumnarRule.scala @@ -0,0 +1,155 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. 
+ */ + +package org.apache.spark.sql.execution.adaptive.ock.rule + +import com.huawei.boostkit.spark.ColumnarPluginConfig +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.adaptive.ock.BoostTuningQueryManager +import org.apache.spark.sql.execution.adaptive.ock.common.BoostTuningLogger.TLogWarning +import org.apache.spark.sql.execution.adaptive.ock.common.BoostTuningUtil.{getQueryExecutionId, normalizedSparkPlan} +import org.apache.spark.sql.execution.adaptive.ock.common.OmniRuntimeConfiguration.enableColumnarShuffle +import org.apache.spark.sql.execution.adaptive.ock.common.StringPrefix.SHUFFLE_PREFIX +import org.apache.spark.sql.execution.adaptive.ock.exchange._ +import org.apache.spark.sql.execution.adaptive.ock.reader._ +import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, QueryStageExec, ShuffleQueryStageExec} +import org.apache.spark.sql.execution.exchange.ReusedExchangeExec + +import scala.collection.mutable + +case class OmniOpBoostTuningColumnarRule(pre: Rule[SparkPlan], post: Rule[SparkPlan]) extends ColumnarRule { + override def preColumnarTransitions: Rule[SparkPlan] = pre + + override def postColumnarTransitions: Rule[SparkPlan] = post +} + +object OmniOpBoostTuningColumnarRule { + val rollBackExchangeIdents: mutable.Set[String] = mutable.Set.empty +} + +case class OmniOpBoostTuningPreColumnarRule() extends Rule[SparkPlan] { + + override val ruleName: String = "OmniOpBoostTuningPreColumnarRule" + + val delegate: BoostTuningPreNewQueryStageRule = BoostTuningPreNewQueryStageRule() + + override def apply(plan: SparkPlan): SparkPlan = { + val executionId = getQueryExecutionId(plan) + if (executionId < 0) { + TLogWarning(s"Skipped to apply BoostTuning new query stage rule for unneeded plan: $plan") + return plan + } + + val query = BoostTuningQueryManager.getOrCreateQueryManager(executionId) + + delegate.prepareQueryExecution(query, plan) + + delegate.reportQueryShuffleMetrics(query, plan) + + tryMarkRollBack(plan) + + replaceOmniQueryExchange(plan) + } + + private def tryMarkRollBack(plan: SparkPlan): Unit = { + plan.foreach { + case plan: BoostTuningShuffleExchangeLike => + if (!enableColumnarShuffle) { + OmniOpBoostTuningColumnarRule.rollBackExchangeIdents += plan.getContext.ident + } + try { + BoostTuningColumnarShuffleExchangeExec(plan.outputPartitioning, plan.child, plan.shuffleOrigin, null).buildCheck() + } catch { + case e: UnsupportedOperationException => + logDebug(s"[OPERATOR FALLBACK] ${e} ${plan.getClass} falls back to Spark operator") + OmniOpBoostTuningColumnarRule.rollBackExchangeIdents += plan.getContext.ident + case l: UnsatisfiedLinkError => + throw l + case f: NoClassDefFoundError => + throw f + case r: RuntimeException => + logDebug(s"[OPERATOR FALLBACK] ${r} ${plan.getClass} falls back to Spark operator") + OmniOpBoostTuningColumnarRule.rollBackExchangeIdents += plan.getContext.ident + case t: Throwable => + logDebug(s"[OPERATOR FALLBACK] ${t} ${plan.getClass} falls back to Spark operator") + OmniOpBoostTuningColumnarRule.rollBackExchangeIdents += plan.getContext.ident + } + case _ => + } + } + + def replaceOmniQueryExchange(plan: SparkPlan): SparkPlan = { + plan.transformUp { + case ex: ColumnarShuffleExchangeExec => + BoostTuningColumnarShuffleExchangeExec( + ex.outputPartitioning, ex.child, ex.shuffleOrigin, + PartitionContext(normalizedSparkPlan(ex, SHUFFLE_PREFIX))) + } + } +} + +case class OmniOpBoostTuningPostColumnarRule() extends Rule[SparkPlan] { + + override 
val ruleName: String = "OmniOpBoostTuningPostColumnarRule" + + override def apply(plan: SparkPlan): SparkPlan = { + + var newPlan = plan match { + case b: BoostTuningShuffleExchangeLike if !OmniOpBoostTuningColumnarRule.rollBackExchangeIdents.contains(b.getContext.ident) => + b.child match { + case ColumnarToRowExec(child) => + BoostTuningColumnarShuffleExchangeExec(b.outputPartitioning, child, b.shuffleOrigin, b.getContext) + case plan if !plan.supportsColumnar => + BoostTuningColumnarShuffleExchangeExec(b.outputPartitioning, RowToOmniColumnarExec(plan), b.shuffleOrigin, b.getContext) + case _ => b + } + case _ => plan + } + + newPlan = additionalReplaceWithColumnarPlan(newPlan) + + newPlan.transformUp { + case c: AQEShuffleReadExec if ColumnarPluginConfig.getConf.enableColumnarShuffle => + c.child match { + case shuffle: BoostTuningColumnarShuffleExchangeExec => + logDebug(s"Columnar Processing for ${c.getClass} is currently supported.") + BoostTuningOmniAQEShuffleReadExec(c.child, c.partitionSpecs) + case ShuffleQueryStageExec(_, shuffle: BoostTuningColumnarShuffleExchangeExec, _) => + logDebug(s"Columnar Processing for ${c.getClass} is currently supported.") + BoostTuningOmniAQEShuffleReadExec(c.child, c.partitionSpecs) + case ShuffleQueryStageExec(_, reused: ReusedExchangeExec, _) => + reused match { + case ReusedExchangeExec(_, shuffle: BoostTuningColumnarShuffleExchangeExec) => + logDebug(s"Columnar Processing for ${c.getClass} is currently supported.") + BoostTuningOmniAQEShuffleReadExec(c.child, c.partitionSpecs) + case _ => + c + } + case _ => + c + } + } + } + + def additionalReplaceWithColumnarPlan(plan: SparkPlan): SparkPlan = plan match { + case ColumnarToRowExec(child: BoostTuningShuffleExchangeLike) => + additionalReplaceWithColumnarPlan(child) + case r: SparkPlan + if !r.isInstanceOf[QueryStageExec] && !r.supportsColumnar && r.children.exists(c => + c.isInstanceOf[ColumnarToRowExec]) => + val children = r.children.map { + case c: ColumnarToRowExec => + val child = additionalReplaceWithColumnarPlan(c.child) + OmniColumnarToRowExec(child) + case other => + additionalReplaceWithColumnarPlan(other) + } + r.withNewChildren(children) + case p => + val children = p.children.map(additionalReplaceWithColumnarPlan) + p.withNewChildren(children) + } +} + diff --git a/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/rule/relation/ColumnarSMJRelationMarker.scala b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/rule/relation/ColumnarSMJRelationMarker.scala new file mode 100644 index 0000000000000000000000000000000000000000..380b6d55323e9e6f5ab9589da3c669fff966bba7 --- /dev/null +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/src/main/scala/org/apache/spark/execution/adaptive/ock/rule/relation/ColumnarSMJRelationMarker.scala @@ -0,0 +1,20 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. 
+ */ + +package org.apache.spark.sql.execution.adaptive.ock.rule.relation + +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.joins.{ColumnarSortMergeJoinExec, SortMergeJoinExec} + +object ColumnarSMJRelationMarker extends RelationMarker { + + override def solve(plan: SparkPlan): SparkPlan = plan.transformUp { + case csmj @ ColumnarSortMergeJoinExec(_, _, _, _, left, right, _, _) => + SMJRelationMarker.solveDepAndWorkGroupOfSMJExec(left, right) + csmj + case smj @ SortMergeJoinExec(_, _, _, _, left, right, _) => + SMJRelationMarker.solveDepAndWorkGroupOfSMJExec(left, right) + smj + } +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/pom.xml b/omnioperator/omniop-spark-extension-ock/pom.xml index 2d3f670bbc6407bc23bfef889f4aded7e1db108a..84c9208cc3caebf6a4a3dff7a9b4cd9b0b4d63ee 100644 --- a/omnioperator/omniop-spark-extension-ock/pom.xml +++ b/omnioperator/omniop-spark-extension-ock/pom.xml @@ -4,29 +4,25 @@ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> 4.0.0 + com.huawei.ock + omniop-spark-extension-ock + pom + Huawei Open Computing Kit for Spark + 23.0.0 + - cpp/ - cpp/build/releases/ - FALSE - 0.6.1 - 3.1.2 - 2.12.10 + 3.3.1 + 2.12.15 2.12 3.2.3 3.4.6 org.apache.spark - spark-3.1 + spark-3.3 3.2.0 3.1.1 - 22.0.0 + 23.0.0 - com.huawei.ock - ock-omniop-shuffle-manager - jar - Huawei Open Computing Kit for Spark, shuffle manager - 22.0.0 - org.scala-lang @@ -66,12 +62,12 @@ com.huawei.boostkit boostkit-omniop-bindings - 1.1.0 + 1.3.0 com.huawei.kunpeng boostkit-omniop-spark - 3.1.1-1.1.0 + 3.3.1-1.3.0 com.huawei.ock @@ -103,103 +99,8 @@ - - - ${project.artifactId}-${project.version}-for-${input.version} - - - ${cpp.build.dir} - - - target/scala-${scala.binary.version}/classes - target/scala-${scala.binary.version}/test-classes - - - - - net.alchim31.maven - scala-maven-plugin - ${scala.plugin.version} - - all - - - - - compile - testCompile - - - - -dependencyfile - ${project.build.directory}/.scala_dependencies - - - - - - - org.apache.maven.plugins - maven-jar-plugin - - - org.apache.maven.plugins - maven-compiler-plugin - 3.1 - - 8 - 8 - true - - -Xlint:all - - - - - exec-maven-plugin - org.codehaus.mojo - 3.0.0 - - - Build CPP - generate-resources - - exec - - - bash - - ${cpp.dir}/build.sh - ${plugin.cpp.test} - - - - - - - org.xolstice.maven.plugins - protobuf-maven-plugin - ${protobuf.maven.version} - - ${project.basedir}/../cpp/src/proto - - - - - compile - - - - - - - - - org.apache.maven.plugins - maven-jar-plugin - ${maven.plugin.version} - - - - - \ No newline at end of file + + ock-omniop-shuffle + ock-omniop-tuning + + \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/src/main/java/com/huawei/ock/spark/serialize/.keep b/omnioperator/omniop-spark-extension-ock/src/main/java/com/huawei/ock/spark/serialize/.keep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/com/huawei/.keep b/omnioperator/omniop-spark-extension-ock/src/main/scala/com/huawei/.keep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/com/huawei/ock/.keep b/omnioperator/omniop-spark-extension-ock/src/main/scala/com/huawei/ock/.keep deleted file mode 100644 index 
e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/com/huawei/ock/spark/.keep b/omnioperator/omniop-spark-extension-ock/src/main/scala/com/huawei/ock/spark/.keep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/.keep b/omnioperator/omniop-spark-extension-ock/src/main/scala/org/.keep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/.keep b/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/.keep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/.keep b/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/.keep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/.keep b/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/.keep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/.keep b/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/.keep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/omnioperator/omniop-spark-extension/cpp/CMakeLists.txt b/omnioperator/omniop-spark-extension/cpp/CMakeLists.txt index 491cfb7086037229608f2963cf6c278ca132b198..10f630ad13925922872540fb13b379a0b52e15b3 100644 --- a/omnioperator/omniop-spark-extension/cpp/CMakeLists.txt +++ b/omnioperator/omniop-spark-extension/cpp/CMakeLists.txt @@ -1,9 +1,9 @@ -# project name -project(spark-thestral-plugin) - # required cmake version cmake_minimum_required(VERSION 3.10) +# project name +project(spark-thestral-plugin) + # configure cmake set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_COMPILER "g++") diff --git a/omnioperator/omniop-spark-extension/cpp/src/CMakeLists.txt b/omnioperator/omniop-spark-extension/cpp/src/CMakeLists.txt index 45780185a3c4dad66193e71bc6b13e506be34591..26df3cb85b9255ee0969d8631deb6ab76488101d 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/CMakeLists.txt +++ b/omnioperator/omniop-spark-extension/cpp/src/CMakeLists.txt @@ -7,20 +7,14 @@ set (SOURCE_FILES io/ColumnWriter.cc io/Compression.cc io/MemoryPool.cc - io/OrcObsFile.cc io/OutputStream.cc io/SparkFile.cc io/WriterOptions.cc - io/orcfile/OrcFileRewrite.cc - io/orcfile/OrcHdfsFileRewrite.cc shuffle/splitter.cpp common/common.cpp jni/SparkJniWrapper.cpp - jni/OrcColumnarBatchJniReader.cpp jni/jni_common.cpp - jni/ParquetColumnarBatchJniReader.cpp - tablescan/ParquetReader.cpp - io/ParquetObsFile.cc) + ) #Find required protobuf package find_package(Protobuf REQUIRED) @@ -35,20 +29,12 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR}) protobuf_generate_cpp(PROTO_SRCS_VB PROTO_HDRS_VB proto/vec_data.proto) add_library (${PROJ_TARGET} SHARED ${SOURCE_FILES} ${PROTO_SRCS} ${PROTO_HDRS} ${PROTO_SRCS_VB} ${PROTO_HDRS_VB}) -find_package(Arrow REQUIRED) 
-find_package(ArrowDataset REQUIRED) -find_package(Parquet REQUIRED) - #JNI target_include_directories(${PROJ_TARGET} PUBLIC $ENV{JAVA_HOME}/include) target_include_directories(${PROJ_TARGET} PUBLIC $ENV{JAVA_HOME}/include/linux) target_include_directories(${PROJ_TARGET} PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) target_link_libraries (${PROJ_TARGET} PUBLIC - Arrow::arrow_shared - ArrowDataset::arrow_dataset_shared - Parquet::parquet_shared - orc crypto sasl2 protobuf @@ -56,8 +42,7 @@ target_link_libraries (${PROJ_TARGET} PUBLIC snappy lz4 zstd - eSDKOBS - boostkit-omniop-vector-1.3.0-aarch64 + boostkit-omniop-vector-1.4.0-aarch64 ) set_target_properties(${PROJ_TARGET} PROPERTIES diff --git a/omnioperator/omniop-spark-extension/cpp/src/common/BinaryLocation.h b/omnioperator/omniop-spark-extension/cpp/src/common/BinaryLocation.h index 683b0fa9d7fd5cb2f849b3637816758c29bf53d4..a9c8b4e974397347347dc677ad3fd2d7844a8acd 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/common/BinaryLocation.h +++ b/omnioperator/omniop-spark-extension/cpp/src/common/BinaryLocation.h @@ -67,11 +67,19 @@ public: return vc_list; } + bool hasNull() const { + return hasNullFlag; + } + + void SetNullFlag(bool hasNull) { + hasNullFlag = hasNull; + } + public: uint32_t vcb_capacity; uint32_t vcb_total_len; std::vector vc_list; - + bool hasNullFlag = false; }; #endif //SPARK_THESTRAL_PLUGIN_BINARYLOCATION_H \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/common/Buffer.h b/omnioperator/omniop-spark-extension/cpp/src/common/Buffer.h index 73fe13732d27dca87e12ac72900635f8f26cd5f4..ab8a52c229017b4277c7c3b5552477133aa27e4b 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/common/Buffer.h +++ b/omnioperator/omniop-spark-extension/cpp/src/common/Buffer.h @@ -16,29 +16,42 @@ * limitations under the License. 
*/ - #ifndef CPP_BUFFER_H - #define CPP_BUFFER_H +#ifndef CPP_BUFFER_H +#define CPP_BUFFER_H - #include - #include - #include - #include - #include - - class Buffer { - public: - Buffer(uint8_t* data, int64_t size, int64_t capacity) - : data_(data), - size_(size), - capacity_(capacity) { +#include +#include +#include +#include +#include +#include + +class Buffer { +public: + Buffer(uint8_t* data, int64_t size, int64_t capacity, bool isOmniAllocated = true) + : data_(data), + size_(size), + capacity_(capacity), + allocatedByOmni(isOmniAllocated) { + } + + ~Buffer() { + if (allocatedByOmni && not releaseFlag) { + auto *allocator = omniruntime::mem::Allocator::GetAllocator(); + allocator->Free(data_, capacity_); } + } - ~Buffer() {} + void SetReleaseFlag() { + releaseFlag = true; + } - public: - uint8_t * data_; - int64_t size_; - int64_t capacity_; - }; +public: + uint8_t * data_; + int64_t size_; + int64_t capacity_; + bool allocatedByOmni = true; + bool releaseFlag = false; +}; - #endif //CPP_BUFFER_H \ No newline at end of file +#endif //CPP_BUFFER_H \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/common/common.cpp b/omnioperator/omniop-spark-extension/cpp/src/common/common.cpp index 0f78c68cba5ea92d83d4eed2acaa21c96b633f33..6a6e5f9121cd44eec36f6f7dd0d7cd86bb637c2d 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/common/common.cpp +++ b/omnioperator/omniop-spark-extension/cpp/src/common/common.cpp @@ -20,35 +20,6 @@ using namespace omniruntime::vec; -int32_t BytesGen(uint64_t offsetsAddr, uint64_t nullsAddr, uint64_t valuesAddr, VCBatchInfo& vcb) -{ - int32_t* offsets = reinterpret_cast(offsetsAddr); - char *nulls = reinterpret_cast(nullsAddr); - char* values = reinterpret_cast(valuesAddr); - std::vector &lst = vcb.getVcList(); - int itemsTotalLen = lst.size(); - int valueTotalLen = 0; - for (int i = 0; i < itemsTotalLen; i++) { - char* addr = reinterpret_cast(lst[i].get_vc_addr()); - int len = lst[i].get_vc_len(); - if (i == 0) { - offsets[0] = 0; - } else { - offsets[i] = offsets[i -1] + lst[i - 1].get_vc_len(); - } - if (lst[i].get_is_null()) { - nulls[i] = 1; - } else { - nulls[i] = 0; - } - if (len != 0) { - memcpy((char *) (values + offsets[i]), addr, len); - valueTotalLen += len; - } - } - offsets[itemsTotalLen] = offsets[itemsTotalLen -1] + lst[itemsTotalLen - 1].get_vc_len(); - return valueTotalLen; -} uint32_t reversebytes_uint32t(uint32_t const value) { diff --git a/omnioperator/omniop-spark-extension/cpp/src/common/common.h b/omnioperator/omniop-spark-extension/cpp/src/common/common.h index 733dac920727489b205727d32300252bd32626c5..1578b85141ac6e9869da97ab2cb66b5359d52624 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/common/common.h +++ b/omnioperator/omniop-spark-extension/cpp/src/common/common.h @@ -37,7 +37,43 @@ #include "Buffer.h" #include "BinaryLocation.h" -int32_t BytesGen(uint64_t offsets, uint64_t nulls, uint64_t values, VCBatchInfo& vcb); +template +int32_t BytesGen(uint64_t offsetsAddr, std::string &nullStr, uint64_t valuesAddr, VCBatchInfo& vcb) +{ + int32_t* offsets = reinterpret_cast(offsetsAddr); + char *nulls = nullptr; + char* values = reinterpret_cast(valuesAddr); + std::vector &lst = vcb.getVcList(); + int itemsTotalLen = lst.size(); + + int valueTotalLen = 0; + if constexpr (hasNull) { + nullStr.resize(itemsTotalLen, 0); + nulls = nullStr.data(); + } + + for (int i = 0; i < itemsTotalLen; i++) { + char* addr = reinterpret_cast(lst[i].get_vc_addr()); + int len = lst[i].get_vc_len(); + if (i == 0) { + 
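+            // Editor's note: offsets[] is a running prefix sum of the variable-length values; the first
+            // entry starts at 0 and each subsequent entry adds the previous element's length.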
offsets[0] = 0; + } else { + offsets[i] = offsets[i -1] + lst[i - 1].get_vc_len(); + } + if constexpr(hasNull) { + if (lst[i].get_is_null()) { + nulls[i] = 1; + } + } + + if (len != 0) { + memcpy((char *) (values + offsets[i]), addr, len); + valueTotalLen += len; + } + } + offsets[itemsTotalLen] = offsets[itemsTotalLen -1] + lst[itemsTotalLen - 1].get_vc_len(); + return valueTotalLen; +} uint32_t reversebytes_uint32t(uint32_t value); diff --git a/omnioperator/omniop-spark-extension/cpp/src/io/OrcObsFile.cc b/omnioperator/omniop-spark-extension/cpp/src/io/OrcObsFile.cc deleted file mode 100644 index b3abc9eb35cd67eedf9cdf58cab7afab1cafec34..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/cpp/src/io/OrcObsFile.cc +++ /dev/null @@ -1,194 +0,0 @@ -/** - * Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "OrcObsFile.hh" - -#include - -#include "../common/debug.h" -#include "securec.h" - -namespace orc { - std::unique_ptr readObsFile(const std::string& path, ObsConfig *obsInfo) { - return std::unique_ptr(new ObsFileInputStream(path, obsInfo)); - } - - typedef struct CallbackData { - char *buf; - uint64_t length; - uint64_t readLength; - obs_status retStatus; - } CallbackData; - - obs_status responsePropertiesCallback(const obs_response_properties *properties, void *data) { - if (NULL == properties) { - LogsError("OBS error, obs_response_properties is null!"); - return OBS_STATUS_ErrorUnknown; - } - CallbackData *ret = (CallbackData *)data; - ret->length = properties->content_length; - return OBS_STATUS_OK; - } - - void commonErrorHandle(const obs_error_details *error) { - if (!error) { - return; - } - if (error->message) { - LogsError("OBS error message: %s", error->message); - } - if (error->resource) { - LogsError("OBS error resource: %s", error->resource); - } - if (error->further_details) { - LogsError("OBS error further details: %s", error->further_details); - } - if (error->extra_details_count) { - LogsError("OBS error extra details:"); - for (int i = 0; i < error->extra_details_count; i++) { - LogsError("[name] %s: [value] %s", error->extra_details[i].name, error->extra_details[i].value); - } - } - } - - void responseCompleteCallback(obs_status status, const obs_error_details *error, void *data) { - if (data) { - CallbackData *ret = (CallbackData *)data; - ret->retStatus = status; - } - commonErrorHandle(error); - } - - obs_status getObjectDataCallback(int buffer_size, const char *buffer, void *data) { - CallbackData *callbackData = (CallbackData *)data; - int read = buffer_size; - if (callbackData->readLength + buffer_size > callbackData->length) { - LogsError("OBS get object failed, read buffer size(%d) is bigger than the 
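Usage note for the templated BytesGen introduced above: the hasNull template parameter decides at compile time whether the null bitmap is materialized at all, so the no-null path skips both the resize of nullStr and the per-row null writes. A minimal sketch of the intended call pattern, mirroring the SerializingBinaryColumns change later in this patch (buffer names here are illustrative):

// offsets needs (n + 1) int32_t slots, values needs vcb.getVcbTotalLen() bytes.
std::string offsetsStr(sizeof(int32_t) * (vcb.getVcList().size() + 1), '\0');
std::string valuesStr(vcb.getVcbTotalLen(), '\0');
std::string nullsStr;                                  // stays empty on the no-null path
int32_t valueLen = vcb.hasNull()
    ? BytesGen<true>(reinterpret_cast<uint64_t>(offsetsStr.data()), nullsStr,
                     reinterpret_cast<uint64_t>(valuesStr.data()), vcb)
    : BytesGen<false>(reinterpret_cast<uint64_t>(offsetsStr.data()), nullsStr,
                      reinterpret_cast<uint64_t>(valuesStr.data()), vcb);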
remaining buffer\ - (totalLength[%ld] - readLength[%ld] = %ld).\n", - buffer_size, callbackData->length, callbackData->readLength, - callbackData->length - callbackData->readLength); - return OBS_STATUS_InvalidParameter; - } - memcpy_s(callbackData->buf + callbackData->readLength, read, buffer, read); - callbackData->readLength += read; - return OBS_STATUS_OK; - } - - obs_status ObsFileInputStream::obsInit() { - obs_status status = OBS_STATUS_BUTT; - status = obs_initialize(OBS_INIT_ALL); - if (OBS_STATUS_OK != status) { - LogsError("OBS initialize failed(%s).", obs_get_status_name(status)); - throw ParseError("OBS initialize failed."); - } - return status; - } - - obs_status ObsFileInputStream::obsInitStatus = obsInit(); - - void ObsFileInputStream::getObsInfo(ObsConfig *obsConf) { - memcpy_s(&obsInfo, sizeof(ObsConfig), obsConf, sizeof(ObsConfig)); - - std::string obsFilename = filename.substr(OBS_PROTOCOL_SIZE); - uint64_t splitNum = obsFilename.find_first_of("/"); - std::string bucket = obsFilename.substr(0, splitNum); - uint32_t bucketLen = bucket.length(); - strcpy_s(obsInfo.bucket, bucketLen + 1, bucket.c_str()); - option.bucket_options.bucket_name = obsInfo.bucket; - - memset_s(&objectInfo, sizeof(obs_object_info), 0, sizeof(obs_object_info)); - std::string key = obsFilename.substr(splitNum + 1); - strcpy_s(obsInfo.objectKey, key.length() + 1, key.c_str()); - objectInfo.key = obsInfo.objectKey; - - if (obsInfo.hostLen > bucketLen && strncmp(obsInfo.hostName, obsInfo.bucket, bucketLen) == 0) { - obsInfo.hostLen = obsInfo.hostLen - bucketLen - 1; - memcpy_s(obsInfo.hostName, obsInfo.hostLen, obsInfo.hostName + bucketLen + 1, obsInfo.hostLen); - obsInfo.hostName[obsInfo.hostLen - 1] = '\0'; - } - - option.bucket_options.host_name = obsInfo.hostName; - option.bucket_options.access_key = obsInfo.accessKey; - option.bucket_options.secret_access_key = obsInfo.secretKey; - option.bucket_options.token = obsInfo.token; - } - - ObsFileInputStream::ObsFileInputStream(std::string _filename, ObsConfig *obsConf) { - filename = _filename; - init_obs_options(&option); - - getObsInfo(obsConf); - - CallbackData data; - data.retStatus = OBS_STATUS_BUTT; - data.length = 0; - obs_response_handler responseHandler = { - &responsePropertiesCallback, - &responseCompleteCallback - }; - - get_object_metadata(&option, &objectInfo, 0, &responseHandler, &data); - if (OBS_STATUS_OK != data.retStatus) { - throw ParseError("get obs object(" + filename + ") metadata failed, error_code: " + - obs_get_status_name(data.retStatus)); - } - totalLength = data.length; - - memset_s(&conditions, sizeof(obs_get_conditions), 0, sizeof(obs_get_conditions)); - init_get_properties(&conditions); - } - - void ObsFileInputStream::read(void *buf, uint64_t length, uint64_t offset) { - if (!buf) { - throw ParseError("Buffer is null."); - } - conditions.start_byte = offset; - conditions.byte_count = length; - - obs_get_object_handler handler = { - { &responsePropertiesCallback, - &responseCompleteCallback}, - &getObjectDataCallback - }; - - CallbackData data; - data.retStatus = OBS_STATUS_BUTT; - data.length = length; - data.readLength = 0; - data.buf = reinterpret_cast(buf); - do { - // the data.buf offset is processed in the callback function getObjectDataCallback - uint64_t tmpRead = data.readLength; - get_object(&option, &objectInfo, &conditions, 0, &handler, &data); - if (OBS_STATUS_OK != data.retStatus) { - LogsError("get obs object failed, length=%ld, readLength=%ld, offset=%ld", - data.length, data.readLength, offset); - throw 
ParseError("get obs object(" + filename + ") failed, error_code: " + - obs_get_status_name(data.retStatus)); - } - - // read data buffer size = 0, no more remaining data need to read - if (tmpRead == data.readLength) { - break; - } - conditions.start_byte = offset + data.readLength; - conditions.byte_count = length - data.readLength; - } while (data.readLength < length); - } -} diff --git a/omnioperator/omniop-spark-extension/cpp/src/io/ParquetObsFile.cc b/omnioperator/omniop-spark-extension/cpp/src/io/ParquetObsFile.cc deleted file mode 100644 index 32b294853e6a7ccf0bf47da47856ee9db9763fbc..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/cpp/src/io/ParquetObsFile.cc +++ /dev/null @@ -1,208 +0,0 @@ -/** - * Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "ParquetObsFile.hh" -#include "securec.h" -#include "common/debug.h" - -using namespace arrow::io; -using namespace arrow; - -namespace spark::reader { - std::shared_ptr readObsFile(const std::string& path, ObsConfig *obsInfo) { - return std::shared_ptr(new ObsReadableFile(path, obsInfo)); - } - - typedef struct CallbackData { - char *buf; - uint64_t length; - uint64_t readLength; - obs_status retStatus; - } CallbackData; - - obs_status responsePropertiesCallback(const obs_response_properties *properties, void *data) { - if (NULL == properties) { - LogsError("OBS error, obs_response_properties is null!"); - return OBS_STATUS_ErrorUnknown; - } - CallbackData *ret = (CallbackData *)data; - ret->length = properties->content_length; - return OBS_STATUS_OK; - } - - void commonErrorHandle(const obs_error_details *error) { - if (!error) { - return; - } - if (error->message) { - LogsError("OBS error message: %s", error->message); - } - if (error->resource) { - LogsError("OBS error resource: %s", error->resource); - } - if (error->further_details) { - LogsError("OBS error further details: %s", error->further_details); - } - if (error->extra_details_count) { - LogsError("OBS error extra details:"); - for (int i = 0; i < error->extra_details_count; i++) { - LogsError("[name] %s: [value] %s", error->extra_details[i].name, error->extra_details[i].value); - } - } - } - - void responseCompleteCallback(obs_status status, const obs_error_details *error, void *data) { - if (data) { - CallbackData *ret = (CallbackData *)data; - ret->retStatus = status; - } - commonErrorHandle(error); - } - - obs_status getObjectDataCallback(int buffer_size, const char *buffer, void *data) { - CallbackData *callbackData = (CallbackData *)data; - int read = buffer_size; - if (callbackData->readLength + buffer_size > callbackData->length) { - LogsError("OBS get object failed, read buffer size(%d) is 
bigger than the remaining buffer\ - (totalLength[%ld] - readLength[%ld] = %ld).\n", - buffer_size, callbackData->length, callbackData->readLength, - callbackData->length - callbackData->readLength); - return OBS_STATUS_InvalidParameter; - } - memcpy_s(callbackData->buf + callbackData->readLength, read, buffer, read); - callbackData->readLength += read; - return OBS_STATUS_OK; - } - - obs_status ObsReadableFile::obsInit() { - obs_status status = OBS_STATUS_BUTT; - status = obs_initialize(OBS_INIT_ALL); - if (OBS_STATUS_OK != status) { - LogsError("OBS initialize failed(%s).", obs_get_status_name(status)); - throw std::runtime_error("OBS initialize failed."); - } - return status; - } - - obs_status ObsReadableFile::obsInitStatus = obsInit(); - - void ObsReadableFile::getObsInfo(ObsConfig *obsConf) { - memcpy_s(&obsInfo, sizeof(ObsConfig), obsConf, sizeof(ObsConfig)); - - std::string obsFilename = filename.substr(OBS_PROTOCOL_SIZE); - uint64_t splitNum = obsFilename.find_first_of("/"); - std::string bucket = obsFilename.substr(0, splitNum); - uint32_t bucketLen = bucket.length(); - strcpy_s(obsInfo.bucket, bucketLen + 1, bucket.c_str()); - option.bucket_options.bucket_name = obsInfo.bucket; - - memset_s(&objectInfo, sizeof(obs_object_info), 0, sizeof(obs_object_info)); - std::string key = obsFilename.substr(splitNum + 1); - strcpy_s(obsInfo.objectKey, key.length() + 1, key.c_str()); - objectInfo.key = obsInfo.objectKey; - - if (obsInfo.hostLen > bucketLen && strncmp(obsInfo.hostName, obsInfo.bucket, bucketLen) == 0) { - obsInfo.hostLen = obsInfo.hostLen - bucketLen - 1; - memcpy_s(obsInfo.hostName, obsInfo.hostLen, obsInfo.hostName + bucketLen + 1, obsInfo.hostLen); - obsInfo.hostName[obsInfo.hostLen - 1] = '\0'; - } - - option.bucket_options.host_name = obsInfo.hostName; - option.bucket_options.access_key = obsInfo.accessKey; - option.bucket_options.secret_access_key = obsInfo.secretKey; - option.bucket_options.token = obsInfo.token; - } - - ObsReadableFile::ObsReadableFile(std::string _filename, ObsConfig *obsConf) { - filename = _filename; - init_obs_options(&option); - - getObsInfo(obsConf); - - CallbackData data; - data.retStatus = OBS_STATUS_BUTT; - data.length = 0; - obs_response_handler responseHandler = { - &responsePropertiesCallback, - &responseCompleteCallback - }; - - get_object_metadata(&option, &objectInfo, 0, &responseHandler, &data); - if (OBS_STATUS_OK != data.retStatus) { - throw std::runtime_error("get obs object(" + filename + ") metadata failed, error_code: " + - obs_get_status_name(data.retStatus)); - } - totalLength = data.length; - - memset_s(&conditions, sizeof(obs_get_conditions), 0, sizeof(obs_get_conditions)); - init_get_properties(&conditions); - } - - Result> ObsReadableFile::ReadAt(int64_t position, int64_t nbytes) { - RETURN_NOT_OK(CheckClosed()); - ARROW_ASSIGN_OR_RAISE(auto buffer, AllocateResizableBuffer(nbytes, io::default_io_context().pool())); - ARROW_ASSIGN_OR_RAISE(int64_t bytes_read, ReadAt(position, nbytes, buffer->mutable_data())); - if (bytes_read < nbytes) { - RETURN_NOT_OK(buffer->Resize(bytes_read)); - buffer->ZeroPadding(); - } - return std::move(buffer); - } - - Result ObsReadableFile::ReadAt(int64_t offset, int64_t length, void* buf) { - if (!buf) { - throw std::runtime_error("Buffer is null."); - } - conditions.start_byte = offset; - conditions.byte_count = length; - - obs_get_object_handler handler = { - { &responsePropertiesCallback, - &responseCompleteCallback}, - &getObjectDataCallback - }; - - CallbackData data; - data.retStatus = 
OBS_STATUS_BUTT; - data.length = length; - data.readLength = 0; - data.buf = reinterpret_cast(buf); - do { - // the data.buf offset is processed in the callback function getObjectDataCallback - uint64_t tmpRead = data.readLength; - get_object(&option, &objectInfo, &conditions, 0, &handler, &data); - if (OBS_STATUS_OK != data.retStatus) { - LogsError("get obs object failed, length=%ld, readLength=%ld, offset=%ld", - data.length, data.readLength, offset); - throw std::runtime_error("get obs object(" + filename + ") failed, error_code: " + - obs_get_status_name(data.retStatus)); - } - - // read data buffer size = 0, no more remaining data need to read - if (tmpRead == data.readLength) { - break; - } - conditions.start_byte = offset + data.readLength; - conditions.byte_count = length - data.readLength; - } while (data.readLength < length); - - return data.readLength; - } -} diff --git a/omnioperator/omniop-spark-extension/cpp/src/io/ParquetObsFile.hh b/omnioperator/omniop-spark-extension/cpp/src/io/ParquetObsFile.hh deleted file mode 100644 index 143f0441ad59d3c36c8f2da7efb752bca97a9cf5..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/cpp/src/io/ParquetObsFile.hh +++ /dev/null @@ -1,119 +0,0 @@ -/** - * Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef PARQURTOBSFILE_H -#define PARQURTOBSFILE_H - -#include "eSDKOBS.h" -#include -#include -#include -#include - -#define OBS_READ_SIZE 1024 -#define OBS_KEY_SIZE 2048 -#define OBS_TOKEN_SIZE 8192 -#define OBS_PROTOCOL_SIZE 6 - -using namespace arrow::io; -using namespace arrow; - -namespace spark::reader { - typedef struct ObsConfig { - char hostName[OBS_KEY_SIZE]; - char accessKey[OBS_KEY_SIZE]; - char secretKey[OBS_KEY_SIZE]; - char token[OBS_TOKEN_SIZE]; - char bucket[OBS_KEY_SIZE]; - char objectKey[OBS_KEY_SIZE]; - uint32_t hostLen; - } ObsConfig; - - std::shared_ptr readObsFile(const std::string& path, ObsConfig *obsInfo); - - class ObsReadableFile : public RandomAccessFile { - private: - obs_options option; - obs_object_info objectInfo; - obs_get_conditions conditions; - ObsConfig obsInfo; - - std::string filename; - uint64_t totalLength; - const uint64_t READ_SIZE = OBS_READ_SIZE * OBS_READ_SIZE; - - static obs_status obsInitStatus; - - static obs_status obsInit(); - - bool is_open_ = true; - - void getObsInfo(ObsConfig *obsInfo); - - public: - ObsReadableFile(std::string _filename, ObsConfig *obsInfo); - - Result> ReadAt(int64_t position, int64_t nbytes) override; - - Result ReadAt(int64_t offset, int64_t length, void* buf) override; - - Status Close() override { - if (is_open_) { - is_open_ = false; - return Status::OK(); - } - return Status::OK(); - } - - bool closed() const override { - return !is_open_; - } - - Status CheckClosed() { - if (!is_open_) { - return Status::Invalid("Operation on closed OBS file"); - } - return Status::OK(); - } - - Result GetSize() override { - return totalLength; - } - - Result Read(int64_t nbytes, void* out) override { - return Result(Status::NotImplemented("Not implemented")); - } - - Result> Read(int64_t nbytes) override { - return Result>(Status::NotImplemented("Not implemented")); - } - - Status Seek(int64_t position) override { - return Status::NotImplemented("Not implemented"); - } - - Result Tell() const override { - return Result(Status::NotImplemented("Not implemented")); - } - - ~ObsReadableFile() {} - }; -} - -#endif diff --git a/omnioperator/omniop-spark-extension/cpp/src/io/orcfile/OrcHdfsFileRewrite.cc b/omnioperator/omniop-spark-extension/cpp/src/io/orcfile/OrcHdfsFileRewrite.cc deleted file mode 100644 index c0204162ad07f3c61a29f2ace0cf62964c37176b..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/cpp/src/io/orcfile/OrcHdfsFileRewrite.cc +++ /dev/null @@ -1,191 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "OrcFileRewrite.hh" - -#include "orc/Exceptions.hh" -#include "io/Adaptor.hh" - -#include -#include -#include -#include -#include -#include -#include - -#include "hdfspp/hdfspp.h" - -namespace orc { - - class HdfsFileInputStreamRewrite : public InputStream { - private: - std::string filename; - std::unique_ptr file; - std::unique_ptr file_system; - uint64_t totalLength; - const uint64_t READ_SIZE = 1024 * 1024; //1 MB - - public: - HdfsFileInputStreamRewrite(std::string _filename) { - std::vector tokens; - HdfsFileInputStreamRewrite(_filename, tokens); - } - - HdfsFileInputStreamRewrite(std::string _filename, std::vector& tokens) { - filename = _filename ; - - //Building a URI object from the given uri_path - hdfs::URI uri; - try { - uri = hdfs::URI::parse_from_string(filename); - } catch (const hdfs::uri_parse_error&) { - throw ParseError("Malformed URI: " + filename); - } - - //This sets conf path to default "$HADOOP_CONF_DIR" or "/etc/hadoop/conf" - //and loads configs core-site.xml and hdfs-site.xml from the conf path - hdfs::ConfigParser parser; - if(!parser.LoadDefaultResources()){ - throw ParseError("Could not load default resources. "); - } - auto stats = parser.ValidateResources(); - //validating core-site.xml - if(!stats[0].second.ok()){ - throw ParseError(stats[0].first + " is invalid: " + stats[0].second.ToString()); - } - //validating hdfs-site.xml - if(!stats[1].second.ok()){ - throw ParseError(stats[1].first + " is invalid: " + stats[1].second.ToString()); - } - hdfs::Options options; - if(!parser.get_options(options)){ - throw ParseError("Could not load Options object. "); - } - - if (!tokens.empty()) { - for (auto input : tokens) { - hdfs::Token token; - token.setIdentifier(input->getIdentifier()); - token.setPassword(input->getPassword()); - token.setKind(input->getKind()); - token.setService(input->getService()); - options.addToken(token); - } - } - hdfs::IoService * io_service = hdfs::IoService::New(); - //Wrapping file_system into a unique pointer to guarantee deletion - file_system = std::unique_ptr( - hdfs::FileSystem::New(io_service, "", options)); - if (file_system.get() == nullptr) { - throw ParseError("Can't create FileSystem object. "); - } - hdfs::Status status; - //Checking if the user supplied the host - if(!uri.get_host().empty()){ - //Using port if supplied, otherwise using "" to look up port in configs - std::string port = uri.has_port() ? - std::to_string(uri.get_port()) : ""; - status = file_system->Connect(uri.get_host(), port); - if (!status.ok()) { - throw ParseError("Can't connect to " + uri.get_host() - + ":" + port + ". " + status.ToString()); - } - } else { - status = file_system->ConnectToDefaultFs(); - if (!status.ok()) { - if(!options.defaultFS.get_host().empty()){ - throw ParseError("Error connecting to " + - options.defaultFS.str() + ". " + status.ToString()); - } else { - throw ParseError( - "Error connecting to the cluster: defaultFS is empty. " - + status.ToString()); - } - } - } - - if (file_system.get() == nullptr) { - throw ParseError("Can't connect the file system. "); - } - - hdfs::FileHandle *file_raw = nullptr; - status = file_system->Open(uri.get_path(), &file_raw); - if (!status.ok()) { - throw ParseError("Can't open " - + uri.get_path() + ". 
" + status.ToString()); - } - //Wrapping file_raw into a unique pointer to guarantee deletion - file.reset(file_raw); - - hdfs::StatInfo stat_info; - status = file_system->GetFileInfo(uri.get_path(), stat_info); - if (!status.ok()) { - throw ParseError("Can't stat " - + uri.get_path() + ". " + status.ToString()); - } - totalLength = stat_info.length; - } - - uint64_t getLength() const override { - return totalLength; - } - - uint64_t getNaturalReadSize() const override { - return READ_SIZE; - } - - void read(void* buf, - uint64_t length, - uint64_t offset) override { - - if (!buf) { - throw ParseError("Buffer is null"); - } - - char* buf_ptr = reinterpret_cast(buf); - hdfs::Status status; - size_t total_bytes_read = 0; - size_t last_bytes_read = 0; - - do { - status = file->PositionRead(buf_ptr, - static_cast(length) - total_bytes_read, - static_cast(offset + total_bytes_read), &last_bytes_read); - if(!status.ok()) { - throw ParseError("Error reading the file: " + status.ToString()); - } - total_bytes_read += last_bytes_read; - buf_ptr += last_bytes_read; - } while (total_bytes_read < length); - } - - const std::string& getName() const override { - return filename; - } - - ~HdfsFileInputStreamRewrite() override; - }; - - HdfsFileInputStreamRewrite::~HdfsFileInputStreamRewrite() { - } - - std::unique_ptr readHdfsFileRewrite(const std::string& path, std::vector& tokens) { - return std::unique_ptr(new HdfsFileInputStreamRewrite(path, tokens)); - } -} diff --git a/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.cpp b/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.cpp index ca982c0a4ca56100cb6c11599d6d0c334009da92..14785a9cf453f5925974f8b85dc80538a5b85a17 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.cpp +++ b/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.cpp @@ -131,7 +131,6 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_nativ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_split( JNIEnv *env, jobject jObj, jlong splitter_id, jlong jVecBatchAddress) { - JNI_FUNC_START auto splitter = g_shuffleSplitterHolder.Lookup(splitter_id); if (!splitter) { std::string error_message = "Invalid splitter id " + std::to_string(splitter_id); @@ -140,10 +139,11 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_split } auto vecBatch = (VectorBatch *) jVecBatchAddress; - + splitter->SetInputVecBatch(vecBatch); + JNI_FUNC_START splitter->Split(*vecBatch); return 0L; - JNI_FUNC_END(runtimeExceptionClass) + JNI_FUNC_END_WITH_VECBATCH(runtimeExceptionClass, splitter->GetInputVecBatch()) } JNIEXPORT jobject JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_stop( diff --git a/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.h b/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.h index 4b59296e152876062a06db3d69c81a7ed22b670b..964fab6dfc06ac692294fa212f40afc21a4d1041 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.h +++ b/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.h @@ -48,6 +48,15 @@ jmethodID GetMethodID(JNIEnv* env, jclass this_class, const char* name, const ch return; \ } \ +#define JNI_FUNC_END_WITH_VECBATCH(exceptionClass, toDeleteVecBatch) \ + } \ + catch (const std::exception &e) \ + { \ + VectorHelper::FreeVecBatch(toDeleteVecBatch); \ + env->ThrowNew(exceptionClass, e.what()); \ + return 0; \ + } + extern jclass runtimeExceptionClass; extern jclass splitResultClass; extern 
jclass jsonClass; diff --git a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp index 8cb7f2bc95eca005000b9a937786e956822d3e0d..c503c38f085cd51ed29061fbfb4d91d2e573cd8c 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp +++ b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp @@ -225,6 +225,89 @@ int Splitter::SplitFixedWidthValueBuffer(VectorBatch& vb) { return 0; } +void HandleNull(VCBatchInfo &vcbInfo, bool isNull) { + if(isNull) { + vcbInfo.SetNullFlag(isNull); + } +} + +template +void Splitter::SplitBinaryVector(BaseVector *varcharVector, int col_schema) { + int32_t num_rows = varcharVector->GetSize(); + bool is_null = false; + if (varcharVector->GetEncoding() == OMNI_DICTIONARY) { + auto vc = reinterpret_cast> *>( + varcharVector); + for (auto row = 0; row < num_rows; ++row) { + auto pid = partition_id_[row]; + uint8_t *dst = nullptr; + uint32_t str_len = 0; + if (!vc->IsNull(row)) { + std::string_view value = vc->GetValue(row); + dst = reinterpret_cast(reinterpret_cast(value.data())); + str_len = static_cast(value.length()); + } + if constexpr (hasNull) { + is_null = vc->IsNull(row); + } + cached_vectorbatch_size_ += str_len; // 累计变长部分cache数据 + VCLocation cl((uint64_t) dst, str_len, is_null); + if ((vc_partition_array_buffers_[pid][col_schema].size() != 0) && + (vc_partition_array_buffers_[pid][col_schema].back().getVcList().size() < + options_.spill_batch_row_num)) { + if constexpr(hasNull) { + HandleNull(vc_partition_array_buffers_[pid][col_schema].back(), is_null); + } + vc_partition_array_buffers_[pid][col_schema].back().getVcList().push_back(cl); + vc_partition_array_buffers_[pid][col_schema].back().vcb_total_len += str_len; + } else { + VCBatchInfo svc(options_.spill_batch_row_num); + svc.getVcList().push_back(cl); + svc.vcb_total_len += str_len; + if constexpr (hasNull) { + HandleNull(svc, is_null); + } + vc_partition_array_buffers_[pid][col_schema].push_back(svc); + } + } + } else { + auto vc = reinterpret_cast> *>(varcharVector); + for (auto row = 0; row < num_rows; ++row) { + auto pid = partition_id_[row]; + uint8_t *dst = nullptr; + uint32_t str_len = 0; + if (!vc->IsNull(row)) { + std::string_view value = vc->GetValue(row); + dst = reinterpret_cast(reinterpret_cast(value.data())); + str_len = static_cast(value.length()); + } + + if constexpr (hasNull) { + is_null = vc->IsNull(row); + } + cached_vectorbatch_size_ += str_len; // 累计变长部分cache数据 + VCLocation cl((uint64_t) dst, str_len, is_null); + if ((vc_partition_array_buffers_[pid][col_schema].size() != 0) && + (vc_partition_array_buffers_[pid][col_schema].back().getVcList().size() < + options_.spill_batch_row_num)) { + if constexpr(hasNull) { + HandleNull(vc_partition_array_buffers_[pid][col_schema].back(), is_null); + } + vc_partition_array_buffers_[pid][col_schema].back().getVcList().push_back(cl); + vc_partition_array_buffers_[pid][col_schema].back().vcb_total_len += str_len; + } else { + VCBatchInfo svc(options_.spill_batch_row_num); + svc.getVcList().push_back(cl); + if constexpr(hasNull) { + HandleNull(svc, is_null); + } + svc.vcb_total_len += str_len; + vc_partition_array_buffers_[pid][col_schema].push_back(svc); + } + } + } +} + int Splitter::SplitBinaryArray(VectorBatch& vb) { const auto num_rows = vb.GetRowCount(); @@ -234,60 +317,12 @@ int Splitter::SplitBinaryArray(VectorBatch& vb) switch (column_type_id_[col_schema]) { case SHUFFLE_BINARY: { auto col_vb = singlePartitionFlag ? 
col_schema : col_schema + 1; - varcharVectorCache.insert(vb.Get(col_vb)); - if (vb.Get(col_vb)->GetEncoding() == OMNI_DICTIONARY) { - auto vc = reinterpret_cast> *>( - vb.Get(col_vb)); - for (auto row = 0; row < num_rows; ++row) { - auto pid = partition_id_[row]; - uint8_t *dst = nullptr; - uint32_t str_len = 0; - if (!vc->IsNull(row)) { - std::string_view value = vc->GetValue(row); - dst = reinterpret_cast(reinterpret_cast(value.data())); - str_len = static_cast(value.length()); - } - bool is_null = vc->IsNull(row); - cached_vectorbatch_size_ += str_len; // 累计变长部分cache数据 - VCLocation cl((uint64_t) dst, str_len, is_null); - if ((vc_partition_array_buffers_[pid][col_schema].size() != 0) && - (vc_partition_array_buffers_[pid][col_schema].back().getVcList().size() < - options_.spill_batch_row_num)) { - vc_partition_array_buffers_[pid][col_schema].back().getVcList().push_back(cl); - vc_partition_array_buffers_[pid][col_schema].back().vcb_total_len += str_len; - } else { - VCBatchInfo svc(options_.spill_batch_row_num); - svc.getVcList().push_back(cl); - svc.vcb_total_len += str_len; - vc_partition_array_buffers_[pid][col_schema].push_back(svc); - } - } + auto *varcharVector = vb.Get(col_vb); + varcharVectorCache.insert(varcharVector); + if (varcharVector->HasNull()) { + this->template SplitBinaryVector(varcharVector, col_schema); } else { - auto vc = reinterpret_cast> *>(vb.Get(col_vb)); - for (auto row = 0; row < num_rows; ++row) { - auto pid = partition_id_[row]; - uint8_t *dst = nullptr; - uint32_t str_len = 0; - if (!vc->IsNull(row)) { - std::string_view value = vc->GetValue(row); - dst = reinterpret_cast(reinterpret_cast(value.data())); - str_len = static_cast(value.length()); - } - bool is_null = vc->IsNull(row); - cached_vectorbatch_size_ += str_len; // 累计变长部分cache数据 - VCLocation cl((uint64_t) dst, str_len, is_null); - if ((vc_partition_array_buffers_[pid][col_schema].size() != 0) && - (vc_partition_array_buffers_[pid][col_schema].back().getVcList().size() < - options_.spill_batch_row_num)) { - vc_partition_array_buffers_[pid][col_schema].back().getVcList().push_back(cl); - vc_partition_array_buffers_[pid][col_schema].back().vcb_total_len += str_len; - } else { - VCBatchInfo svc(options_.spill_batch_row_num); - svc.getVcList().push_back(cl); - svc.vcb_total_len += str_len; - vc_partition_array_buffers_[pid][col_schema].push_back(svc); - } - } + this->template SplitBinaryVector(varcharVector, col_schema); } break; } @@ -306,33 +341,35 @@ int Splitter::SplitFixedWidthValidityBuffer(VectorBatch& vb){ auto col_idx = fixed_width_array_idx_[col]; auto& dst_addrs = partition_fixed_width_validity_addrs_[col]; // 分配内存并初始化 - for (auto pid = 0; pid < num_partitions_; ++pid) { - if (partition_id_cnt_cur_[pid] > 0 && dst_addrs[pid] == nullptr) { - // init bitmap if it's null - auto new_size = partition_id_cnt_cur_[pid] > options_.buffer_size ? partition_id_cnt_cur_[pid] : options_.buffer_size; - auto ptr_tmp = static_cast(options_.allocator->Alloc(new_size)); - if (nullptr == ptr_tmp) { - throw std::runtime_error("Allocator for ValidityBuffer Failed! "); + if (vb.Get(col_idx)->HasNull()) { + for (auto pid = 0; pid < num_partitions_; ++pid) { + if (partition_id_cnt_cur_[pid] > 0 && dst_addrs[pid] == nullptr) { + // init bitmap if it's null + auto new_size = partition_id_cnt_cur_[pid] > options_.buffer_size ? 
partition_id_cnt_cur_[pid] : options_.buffer_size; + auto ptr_tmp = static_cast(options_.allocator->Alloc(new_size)); + if (nullptr == ptr_tmp) { + throw std::runtime_error("Allocator for ValidityBuffer Failed! "); + } + std::shared_ptr validity_buffer ( + new Buffer((uint8_t *)ptr_tmp, partition_id_cnt_cur_[pid], new_size)); + dst_addrs[pid] = const_cast(validity_buffer->data_); + std::memset(validity_buffer->data_, 0, new_size); + partition_fixed_width_buffers_[col][pid][0] = std::move(validity_buffer); + fixed_nullBuffer_size_[pid] = new_size; } - std::shared_ptr validity_buffer (new Buffer((uint8_t *)ptr_tmp, 0, new_size)); - dst_addrs[pid] = const_cast(validity_buffer->data_); - std::memset(validity_buffer->data_, 0, new_size); - partition_fixed_width_buffers_[col][pid][0] = std::move(validity_buffer); - fixed_nullBuffer_size_[pid] = new_size; - } - } + } - // 计算并填充数据 - auto src_addr = const_cast((uint8_t *)( - reinterpret_cast(omniruntime::vec::unsafe::UnsafeBaseVector::GetNulls(vb.Get(col_idx))))); - std::memset(partition_buffer_idx_offset_, 0, num_partitions_ * sizeof(int32_t)); - const auto num_rows = vb.GetRowCount(); - for (auto row = 0; row < num_rows; ++row) { - auto pid = partition_id_[row]; - auto dst_offset = partition_buffer_idx_base_[pid] + partition_buffer_idx_offset_[pid]; - dst_addrs[pid][dst_offset] = src_addr[row]; - partition_buffer_idx_offset_[pid]++; - partition_fixed_width_buffers_[col][pid][0]->size_ += 1; + // 计算并填充数据 + auto src_addr = const_cast((uint8_t *)( + reinterpret_cast(omniruntime::vec::unsafe::UnsafeBaseVector::GetNulls(vb.Get(col_idx))))); + std::memset(partition_buffer_idx_offset_, 0, num_partitions_ * sizeof(int32_t)); + const auto num_rows = vb.GetRowCount(); + for (auto row = 0; row < num_rows; ++row) { + auto pid = partition_id_[row]; + auto dst_offset = partition_buffer_idx_base_[pid] + partition_buffer_idx_offset_[pid]; + dst_addrs[pid][dst_offset] = src_addr[row]; + partition_buffer_idx_offset_[pid]++; + } } } return 0; @@ -358,7 +395,9 @@ int Splitter::CacheVectorBatch(int32_t partition_id, bool reset_buffers) { } default: { auto& buffers = partition_fixed_width_buffers_[fixed_width_idx][partition_id]; - batch_partition_size += buffers[0]->capacity_; // 累计null数组所占内存大小 + if (buffers[0] != nullptr) { + batch_partition_size += buffers[0]->capacity_; // 累计null数组所占内存大小 + } batch_partition_size += buffers[1]->capacity_; // 累计value数组所占内存大小 if (reset_buffers) { bufferArrayTotal[fixed_width_idx] = std::move(buffers); @@ -411,6 +450,7 @@ int Splitter::DoSplit(VectorBatch& vb) { num_row_splited_ += vb.GetRowCount(); // release the fixed width vector and release vectorBatch at the same time ReleaseVectorBatch(&vb); + this->ResetInputVecBatch(); // 阈值检查,是否溢写 if (num_row_splited_ >= SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD) { @@ -650,17 +690,18 @@ void Splitter::SerializingFixedColumns(int32_t partitionId, colIndexTmpSchema = singlePartitionFlag ? 
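A consequence of the HasNull() fast path above is that the validity slot (buffers[0]) of a cached fixed-width column can now be nullptr, so every consumer must guard that access. Restated in isolation for clarity, the CacheVectorBatch change follows this pattern:

auto &buffers = partition_fixed_width_buffers_[fixed_width_idx][partition_id];
if (buffers[0] != nullptr) {
    batch_partition_size += buffers[0]->capacity_;   // null bitmap exists only when the column had nulls
}
batch_partition_size += buffers[1]->capacity_;       // value buffer is always allocated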
fixed_width_array_idx_[fixColIndexTmp] : fixed_width_array_idx_[fixColIndexTmp] - 1; auto onceCopyLen = splitRowInfoTmp->onceCopyRow * (1 << column_type_id_[colIndexTmpSchema]); // 临时内存,拷贝拼接onceCopyRow批,用完释放 - void *ptr_value_tmp = static_cast(options_.allocator->Alloc(onceCopyLen)); - std::shared_ptr ptr_value (new Buffer((uint8_t*)ptr_value_tmp, 0, onceCopyLen)); - void *ptr_validity_tmp = static_cast(options_.allocator->Alloc(splitRowInfoTmp->onceCopyRow)); - std::shared_ptr ptr_validity (new Buffer((uint8_t*)ptr_validity_tmp, 0, splitRowInfoTmp->onceCopyRow)); - if (nullptr == ptr_value->data_ || nullptr == ptr_validity->data_) { - throw std::runtime_error("Allocator for tmp buffer Failed! "); - } + std::string valueStr; + valueStr.resize(onceCopyLen); + std::string nullStr; + + std::shared_ptr ptr_value (new Buffer((uint8_t*)valueStr.data(), 0, onceCopyLen, false)); + std::shared_ptr ptr_validity; + // options_.spill_batch_row_num长度切割与拼接 uint destCopyedLength = 0; uint memCopyLen = 0; uint cacheBatchSize = 0; + bool nullAllocated = false; while (destCopyedLength < onceCopyLen) { if (splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp] >= partition_cached_vectorbatch_[partitionId].size()) { // 数组越界保护 throw std::runtime_error("Columnar shuffle CacheBatchIndex out of bound."); @@ -674,20 +715,29 @@ void Splitter::SerializingFixedColumns(int32_t partitionId, onceCopyLen, destCopyedLength, splitRowInfoTmp->cacheBatchCopyedLen[fixColIndexTmp]); + if (not nullAllocated && partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][0] != nullptr) { + nullStr.resize(splitRowInfoTmp->onceCopyRow); + ptr_validity.reset(new Buffer((uint8_t*)nullStr.data(), 0, splitRowInfoTmp->onceCopyRow, false)); + nullAllocated = true; + } if ((onceCopyLen - destCopyedLength) >= (cacheBatchSize - splitRowInfoTmp->cacheBatchCopyedLen[fixColIndexTmp])) { memCopyLen = cacheBatchSize - splitRowInfoTmp->cacheBatchCopyedLen[fixColIndexTmp]; memcpy((uint8_t*)(ptr_value->data_) + destCopyedLength, partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][1]->data_ + splitRowInfoTmp->cacheBatchCopyedLen[fixColIndexTmp], memCopyLen); // (destCopyedLength / (1 << column_type_id_[colIndexTmpSchema])) 等比例计算null数组偏移 - memcpy((uint8_t*)(ptr_validity->data_) + (destCopyedLength / (1 << column_type_id_[colIndexTmpSchema])), - partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][0]->data_ + (splitRowInfoTmp->cacheBatchCopyedLen[fixColIndexTmp] / (1 << column_type_id_[colIndexTmpSchema])), - memCopyLen / (1 << column_type_id_[colIndexTmpSchema])); - // 释放内存 - options_.allocator->Free(partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][0]->data_, - partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][0]->capacity_); + if (partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][0] != nullptr) { + memcpy((uint8_t*)(ptr_validity->data_) + (destCopyedLength / (1 << column_type_id_[colIndexTmpSchema])), + partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][0]->data_ + (splitRowInfoTmp->cacheBatchCopyedLen[fixColIndexTmp] / (1 << column_type_id_[colIndexTmpSchema])), + memCopyLen / (1 << column_type_id_[colIndexTmpSchema])); + // 释放内存 + 
options_.allocator->Free(partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][0]->data_, + partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][0]->capacity_); + partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][0]->SetReleaseFlag(); + } options_.allocator->Free(partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][1]->data_, partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][1]->capacity_); + partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][1]->SetReleaseFlag(); destCopyedLength += memCopyLen; splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp] += 1; // cacheBatchIndex下标后移 splitRowInfoTmp->cacheBatchCopyedLen[fixColIndexTmp] = 0; // 初始化下一个cacheBatch的起始偏移 @@ -697,9 +747,12 @@ void Splitter::SerializingFixedColumns(int32_t partitionId, partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][1]->data_ + splitRowInfoTmp->cacheBatchCopyedLen[fixColIndexTmp], memCopyLen); // (destCopyedLength / (1 << column_type_id_[colIndexTmpSchema])) 等比例计算null数组偏移 - memcpy((uint8_t*)(ptr_validity->data_) + (destCopyedLength / (1 << column_type_id_[colIndexTmpSchema])), - partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][0]->data_ + (splitRowInfoTmp->cacheBatchCopyedLen[fixColIndexTmp] / (1 << column_type_id_[colIndexTmpSchema])), - memCopyLen / (1 << column_type_id_[colIndexTmpSchema])); + + if(partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][0] != nullptr) { + memcpy((uint8_t*)(ptr_validity->data_) + (destCopyedLength / (1 << column_type_id_[colIndexTmpSchema])), + partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][0]->data_ + (splitRowInfoTmp->cacheBatchCopyedLen[fixColIndexTmp] / (1 << column_type_id_[colIndexTmpSchema])), + memCopyLen / (1 << column_type_id_[colIndexTmpSchema])); + } destCopyedLength = onceCopyLen; // copy目标完成,结束while循环 splitRowInfoTmp->cacheBatchCopyedLen[fixColIndexTmp] += memCopyLen; } @@ -710,11 +763,11 @@ void Splitter::SerializingFixedColumns(int32_t partitionId, fixColIndexTmp, splitRowInfoTmp->cacheBatchCopyedLen[fixColIndexTmp]); } - vec.set_values(ptr_value->data_, onceCopyLen); - vec.set_nulls(ptr_validity->data_, splitRowInfoTmp->onceCopyRow); + auto *protoValue = vec.mutable_values(); + *protoValue = std::move(valueStr); + auto *protoNulls = vec.mutable_nulls(); + *protoNulls = std::move(nullStr); // 临时内存,拷贝拼接onceCopyRow批,用完释放 - options_.allocator->Free(ptr_value->data_, ptr_value->capacity_); - options_.allocator->Free(ptr_validity->data_, ptr_validity->capacity_); } // partition_cached_vectorbatch_[partition_id][cache_index][col][0]代表ByteMap, // partition_cached_vectorbatch_[partition_id][cache_index][col][1]代表value @@ -727,16 +780,27 @@ void Splitter::SerializingBinaryColumns(int32_t partitionId, spark::Vec& vec, in int valuesTotalLen = vcb.getVcbTotalLen(); std::vector lst = vcb.getVcList(); int itemsTotalLen = lst.size(); - auto OffsetsByte(std::make_unique(itemsTotalLen + 1)); - auto nullsByte(std::make_unique(itemsTotalLen)); - auto valuesByte(std::make_unique(valuesTotalLen)); - 
BytesGen(reinterpret_cast(OffsetsByte.get()), - reinterpret_cast(nullsByte.get()), - reinterpret_cast(valuesByte.get()), vcb); - vec.set_values(valuesByte.get(), valuesTotalLen); - // nulls add boolean array; serizelized tobytearray - vec.set_nulls((char *)nullsByte.get(), itemsTotalLen); - vec.set_offset(OffsetsByte.get(), (itemsTotalLen + 1) * sizeof(int32_t)); + + std::string offsetsStr; + offsetsStr.resize(sizeof(int32_t) * (itemsTotalLen + 1)); + std::string nullsStr; + std::string valuesStr; + valuesStr.resize(valuesTotalLen); + if(vcb.hasNull()) { + BytesGen(reinterpret_cast(offsetsStr.data()), + nullsStr, + reinterpret_cast(valuesStr.data()), vcb); + } else { + BytesGen(reinterpret_cast(offsetsStr.data()), + nullsStr, + reinterpret_cast(valuesStr.data()), vcb); + } + auto *protoValue = vec.mutable_values(); + *protoValue = std::move(valuesStr); + auto *protoNulls = vec.mutable_nulls(); + *protoNulls = std::move(nullsStr); + auto *protoOffset = vec.mutable_offset(); + *protoOffset = std::move(offsetsStr); } int32_t Splitter::ProtoWritePartition(int32_t partition_id, std::unique_ptr &bufferStream, void *bufferOut, int32_t &sizeOut) { @@ -976,7 +1040,7 @@ void Splitter::MergeSpilled() { } } - std::memset(partition_id_cnt_cache_, 0, num_partitions_ * sizeof(uint64_t)); + std::memset(partition_id_cnt_cache_, 0, num_partitions_ * sizeof(uint64_t)); ReleaseVarcharVector(); num_row_splited_ = 0; cached_vectorbatch_size_ = 0; @@ -1015,6 +1079,7 @@ int Splitter::DeleteSpilledTmpFile() { auto tmpDataFilePath = pair.first + ".data"; // 释放存储有各个临时文件的偏移数据内存 options_.allocator->Free(pair.second->data_, pair.second->capacity_); + pair.second->SetReleaseFlag(); if (IsFileExist(tmpDataFilePath)) { remove(tmpDataFilePath.c_str()); } diff --git a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h index cba14253b6132dedf9a7d646ed44af13a7c471d4..ec0cc661f0a49d531b47dab87299cb6a8dfbde2a 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h +++ b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h @@ -82,6 +82,9 @@ class Splitter { int SplitBinaryArray(VectorBatch& vb); + template + void SplitBinaryVector(BaseVector *varcharVector, int col_schema); + int CacheVectorBatch(int32_t partition_id, bool reset_buffers); void ToSplitterTypeId(int num_cols); @@ -160,6 +163,7 @@ private: } } vectorAddress.clear(); + vb->ClearVectors(); delete vb; } @@ -167,7 +171,7 @@ private: std::vector vector_batch_col_types_; InputDataTypes input_col_types; std::vector binary_array_empirical_size_; - + omniruntime::vec::VectorBatch *inputVecBatch = nullptr; public: bool singlePartitionFlag = false; int32_t num_partitions_; @@ -215,6 +219,22 @@ public: int64_t TotalComputePidTime() const { return total_compute_pid_time_; } const std::vector& PartitionLengths() const { return partition_lengths_; } + + omniruntime::vec::VectorBatch *GetInputVecBatch() + { + return inputVecBatch; + } + + void SetInputVecBatch(omniruntime::vec::VectorBatch *inVecBatch) + { + inputVecBatch = inVecBatch; + } + + // no need to clear memory when exception, so we have to reset + void ResetInputVecBatch() + { + inputVecBatch = nullptr; + } }; diff --git a/omnioperator/omniop-spark-extension/cpp/src/tablescan/ParquetReader.cpp b/omnioperator/omniop-spark-extension/cpp/src/tablescan/ParquetReader.cpp deleted file mode 100644 index a6049df84c2c1b5d13488d8844fbc315e276f2eb..0000000000000000000000000000000000000000 --- 
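One more note on the serialization change above: building offsets/nulls/values in std::string buffers and move-assigning them into the protobuf Vec removes both the unique_ptr staging buffers and the extra copy that set_values()/set_nulls()/set_offset() performed. A condensed sketch, using the same spark::Vec accessors as the patched code:

std::string offsetsStr(sizeof(int32_t) * (itemsTotalLen + 1), '\0');
std::string valuesStr(valuesTotalLen, '\0');
std::string nullsStr;                                  // filled only by BytesGen<true>
// ... BytesGen<true> or BytesGen<false> writes into the three buffers ...
*vec.mutable_values() = std::move(valuesStr);          // protobuf string fields take ownership, no memcpy
*vec.mutable_nulls()  = std::move(nullsStr);
*vec.mutable_offset() = std::move(offsetsStr);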
a/omnioperator/omniop-spark-extension/cpp/src/tablescan/ParquetReader.cpp +++ /dev/null @@ -1,294 +0,0 @@ -/** - * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include "jni/jni_common.h" -#include "ParquetReader.h" - -using namespace omniruntime::vec; -using namespace omniruntime::type; -using namespace arrow; -using namespace parquet::arrow; -using namespace arrow::compute; -using namespace spark::reader; - -static std::mutex mutex_; -static std::map restore_filesysptr; -static constexpr int32_t PARQUET_MAX_DECIMAL64_DIGITS = 18; -static constexpr int32_t INT128_BYTES = 16; -static constexpr int32_t INT64_BYTES = 8; -static constexpr int32_t BYTE_BITS = 8; -static constexpr int32_t LOCAL_FILE_PREFIX = 5; -static constexpr int32_t LOCAL_FILE_PREFIX_EXT = 7; -static const std::string LOCAL_FILE = "file:"; -static const std::string HDFS_FILE = "hdfs:"; - -std::string spark::reader::GetFileSystemKey(std::string& path, std::string& ugi) -{ - // if the local file, all the files are the same key "file:" - std::string result = ugi; - - // if the hdfs file, only get the ip and port just like the ugi + ip + port as key - if (path.substr(0, LOCAL_FILE_PREFIX) == HDFS_FILE) { - auto end = path.find("/", LOCAL_FILE_PREFIX_EXT); - std::string ip_and_port = path.substr(LOCAL_FILE_PREFIX_EXT, end - LOCAL_FILE_PREFIX_EXT); - result += ip_and_port; - return result; - } - - // if the local file, get the ugi + "file" as the key - if (path.substr(0, LOCAL_FILE_PREFIX) == LOCAL_FILE) { - // process the path "file://" head, the arrow could not read the head - path = path.substr(LOCAL_FILE_PREFIX); - result += "file:"; - return result; - } - - // if not the local, not the hdfs, get the ugi + path as the key - result += path; - return result; -} - -Filesystem* spark::reader::GetFileSystemPtr(std::string& path, std::string& ugi) -{ - auto key = GetFileSystemKey(path, ugi); - - // if not find key, creadte the filesystem ptr - auto iter = restore_filesysptr.find(key); - if (iter == restore_filesysptr.end()) { - Filesystem* fs = new Filesystem(); - fs->filesys_ptr = std::move(fs::FileSystemFromUriOrPath(path)).ValueUnsafe(); - restore_filesysptr[key] = fs; - } - - return restore_filesysptr[key]; -} - -Status ParquetReader::InitRecordReader(std::string& filePath, int64_t capacity, - const std::vector& row_group_indices, const std::vector& column_indices, - std::string& ugi, ObsConfig& obsInfo) -{ - arrow::MemoryPool* pool = default_memory_pool(); - - // Configure reader settings - auto reader_properties = parquet::ReaderProperties(pool); - - // Configure Arrow-specific reader settings - auto arrow_reader_properties = 
parquet::ArrowReaderProperties(); - arrow_reader_properties.set_batch_size(capacity); - - std::shared_ptr file; - if (0 == strncmp(filePath.c_str(), "obs://", OBS_PROTOCOL_SIZE)) { - file = readObsFile(filePath, &obsInfo); - } else { - // Get the file from filesystem - mutex_.lock(); - Filesystem* fs = GetFileSystemPtr(filePath, ugi); - mutex_.unlock(); - ARROW_ASSIGN_OR_RAISE(file, fs->filesys_ptr->OpenInputFile(filePath)); - } - - FileReaderBuilder reader_builder; - ARROW_RETURN_NOT_OK(reader_builder.Open(file, reader_properties)); - reader_builder.memory_pool(pool); - reader_builder.properties(arrow_reader_properties); - - ARROW_ASSIGN_OR_RAISE(arrow_reader, reader_builder.Build()); - ARROW_RETURN_NOT_OK(arrow_reader->GetRecordBatchReader(row_group_indices, column_indices, &rb_reader)); - return arrow::Status::OK(); -} - -Status ParquetReader::ReadNextBatch(std::shared_ptr *batch) -{ - ARROW_RETURN_NOT_OK(rb_reader->ReadNext(batch)); - return arrow::Status::OK(); -} - -/** - * For BooleanType, copy values one by one. - */ -uint64_t CopyBooleanType(std::shared_ptr array) -{ - arrow::BooleanArray *lvb = dynamic_cast(array.get()); - auto numElements = lvb->length(); - auto originalVector = new Vector(numElements); - for (int64_t i = 0; i < numElements; i++) { - if (lvb->IsNull(i)) { - originalVector->SetNull(i); - } else { - if (lvb->Value(i)) { - originalVector->SetValue(i, true); - } else { - originalVector->SetValue(i, false); - } - } - } - return (uint64_t)originalVector; -} - -/** - * For int16/int32/int64/double type, copy values in batches and skip setNull if there is no nulls. - */ -template uint64_t CopyFixedWidth(std::shared_ptr array) -{ - using T = typename NativeType::type; - PARQUET_TYPE *lvb = dynamic_cast(array.get()); - auto numElements = lvb->length(); - auto values = lvb->raw_values(); - auto originalVector = new Vector(numElements); - // Check ColumnVectorBatch has null or not firstly - if (lvb->null_count() != 0) { - for (int64_t i = 0; i < numElements; i++) { - if (lvb->IsNull(i)) { - originalVector->SetNull(i); - } - } - } - originalVector->SetValues(0, values, numElements); - return (uint64_t)originalVector; -} - -uint64_t CopyVarWidth(std::shared_ptr array) -{ - auto lvb = dynamic_cast(array.get()); - auto numElements = lvb->length(); - auto originalVector = new Vector>(numElements); - for (int64_t i = 0; i < numElements; i++) { - if (lvb->IsValid(i)) { - auto data = lvb->GetView(i); - originalVector->SetValue(i, data); - } else { - originalVector->SetNull(i); - } - } - return (uint64_t)originalVector; -} - -uint64_t CopyToOmniDecimal128Vec(std::shared_ptr array) -{ - auto lvb = dynamic_cast(array.get()); - auto numElements = lvb->length(); - auto originalVector = new Vector(numElements); - for (int64_t i = 0; i < numElements; i++) { - if (lvb->IsValid(i)) { - auto data = lvb->GetValue(i); - __int128_t val; - memcpy_s(&val, sizeof(val), data, INT128_BYTES); - omniruntime::type::Decimal128 d128(val); - originalVector->SetValue(i, d128); - } else { - originalVector->SetNull(i); - } - } - return (uint64_t)originalVector; -} - -uint64_t CopyToOmniDecimal64Vec(std::shared_ptr array) -{ - auto lvb = dynamic_cast(array.get()); - auto numElements = lvb->length(); - auto originalVector = new Vector(numElements); - for (int64_t i = 0; i < numElements; i++) { - if (lvb->IsValid(i)) { - auto data = lvb->GetValue(i); - int64_t val; - memcpy_s(&val, sizeof(val), data, INT64_BYTES); - originalVector->SetValue(i, val); - } else { - originalVector->SetNull(i); - } - } - return 
(uint64_t)originalVector; -} - -int spark::reader::CopyToOmniVec(std::shared_ptr vcType, int &omniTypeId, uint64_t &omniVecId, - std::shared_ptr array) -{ - switch (vcType->id()) { - case arrow::Type::BOOL: - omniTypeId = static_cast(OMNI_BOOLEAN); - omniVecId = CopyBooleanType(array); - break; - case arrow::Type::INT16: - omniTypeId = static_cast(OMNI_SHORT); - omniVecId = CopyFixedWidth(array); - break; - case arrow::Type::INT32: - omniTypeId = static_cast(OMNI_INT); - omniVecId = CopyFixedWidth(array); - break; - case arrow::Type::DATE32: - omniTypeId = static_cast(OMNI_DATE32); - omniVecId = CopyFixedWidth(array); - break; - case arrow::Type::INT64: - omniTypeId = static_cast(OMNI_LONG); - omniVecId = CopyFixedWidth(array); - break; - case arrow::Type::DATE64: - omniTypeId = static_cast(OMNI_DATE64); - omniVecId = CopyFixedWidth(array); - break; - case arrow::Type::DOUBLE: - omniTypeId = static_cast(OMNI_DOUBLE); - omniVecId = CopyFixedWidth(array); - break; - case arrow::Type::STRING: - omniTypeId = static_cast(OMNI_VARCHAR); - omniVecId = CopyVarWidth(array); - break; - case arrow::Type::DECIMAL128: { - auto decimalType = static_cast(vcType.get()); - if (decimalType->precision() > PARQUET_MAX_DECIMAL64_DIGITS) { - omniTypeId = static_cast(OMNI_DECIMAL128); - omniVecId = CopyToOmniDecimal128Vec(array); - } else { - omniTypeId = static_cast(OMNI_DECIMAL64); - omniVecId = CopyToOmniDecimal64Vec(array); - } - break; - } - default: { - throw std::runtime_error("Native ColumnarFileScan Not support For This Type: " + vcType->id()); - } - } - return 1; -} - -std::pair spark::reader::TransferToOmniVecs(std::shared_ptr batch) -{ - int64_t num_columns = batch->num_columns(); - std::vector> fields = batch->schema()->fields(); - auto vecTypes = new int64_t[num_columns]; - auto vecs = new int64_t[num_columns]; - for (int64_t colIdx = 0; colIdx < num_columns; colIdx++) { - std::shared_ptr array = batch->column(colIdx); - // One array in current batch - std::shared_ptr data = array->data(); - int omniTypeId = 0; - uint64_t omniVecId = 0; - spark::reader::CopyToOmniVec(data->type, omniTypeId, omniVecId, array); - vecTypes[colIdx] = omniTypeId; - vecs[colIdx] = omniVecId; - } - return std::make_pair(vecTypes, vecs); -} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/test/CMakeLists.txt b/omnioperator/omniop-spark-extension/cpp/test/CMakeLists.txt index ba1ad3a773c35a101cf728f00a19ba30b0dae607..f53ac2ad45e95bddfe3a15a4da78b245e373981e 100644 --- a/omnioperator/omniop-spark-extension/cpp/test/CMakeLists.txt +++ b/omnioperator/omniop-spark-extension/cpp/test/CMakeLists.txt @@ -2,14 +2,12 @@ aux_source_directory(${CMAKE_CURRENT_LIST_DIR} TEST_ROOT_SRCS) add_subdirectory(shuffle) add_subdirectory(utils) -add_subdirectory(tablescan) # configure set(TP_TEST_TARGET tptest) set(MY_LINK shuffletest utilstest - tablescantest ) # find gtest package @@ -29,7 +27,7 @@ target_link_libraries(${TP_TEST_TARGET} pthread stdc++ dl - boostkit-omniop-vector-1.3.0-aarch64 + boostkit-omniop-vector-1.4.0-aarch64 securec spark_columnar_plugin) diff --git a/omnioperator/omniop-spark-extension/java/pom.xml b/omnioperator/omniop-spark-extension/java/pom.xml index 32e13688864526905cd2602c4a40103f50beb563..62c407dc3dedf2df0c4ca7ab891083467a8279f7 100644 --- a/omnioperator/omniop-spark-extension/java/pom.xml +++ b/omnioperator/omniop-spark-extension/java/pom.xml @@ -7,7 +7,7 @@ com.huawei.kunpeng boostkit-omniop-spark-parent - 3.3.1-1.3.0 + 3.3.1-1.4.0 ../pom.xml @@ -46,28 +46,13 @@ 
com.huawei.boostkit boostkit-omniop-bindings + 1.4.0 aarch64 - com.huaweicloud - esdk-obs-java-optimised - - 3.21.8.2 - provided - - - jackson-databind - com.fasterxml.jackson.core - - - jackson-annotations - com.fasterxml.jackson.core - - - jackson-core - com.fasterxml.jackson.core - - + com.huawei.boostkit + boostkit-omniop-native-reader + 3.3.1-1.4.0 junit diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/ObsConf.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/ObsConf.java deleted file mode 100644 index 0c9228c88b02573d742706a46cecfeda8d40f258..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/ObsConf.java +++ /dev/null @@ -1,179 +0,0 @@ -/* - * Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved. - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.huawei.boostkit.spark; - -import com.huawei.boostkit.spark.ColumnarPluginConfig; - -import com.obs.services.IObsCredentialsProvider; -import com.obs.services.model.ISecurityKey; - -import org.apache.hadoop.conf.Configuration; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.json.JSONObject; - -public class ObsConf { - private static final Logger LOG = LoggerFactory.getLogger(ObsConf.class); - private static String endpoint; - private static String accessKey = ""; - private static String secretKey = ""; - private static String token = ""; - private static IObsCredentialsProvider securityProvider; - private static boolean syncToGetToken = false; - private static int retryTimes = 10; - private static Object lock = new Object(); - - private ObsConf() { - syncToGetToken = ColumnarPluginConfig.getConf().enableSyncGetObsToken(); - retryTimes = ColumnarPluginConfig.getConf().retryTimesGetObsToken(); - } - - private static void init() { - Configuration conf = new Configuration(); - String endpointConf = "fs.obs.endpoint"; - String accessKeyConf = "fs.obs.access.key"; - String secretKeyConf = "fs.obs.secret.key"; - String providerConf = "fs.obs.security.provider"; - endpoint = conf.get(endpointConf, ""); - if ("".equals(endpoint)) { - LOG.warn("Key parameter {} is missing in the configuration file.", endpointConf); - return; - } - accessKey = conf.get(accessKeyConf, ""); - secretKey = conf.get(secretKeyConf, ""); - if ("".equals(accessKey) && "".equals(secretKey)) { - if ("".equals(conf.get(providerConf, ""))) { - LOG.error("Key parameters such as {}, {}, or {} are missing or the parameter value is incorrect.", - accessKeyConf, secretKeyConf, providerConf); - } else { - getSecurityKey(conf, providerConf); - } - } - } - - private static void getSecurityKey(Configuration conf, String providerConf) { 
- try { - Class securityProviderClass = conf.getClass(providerConf, null); - if (securityProviderClass == null) { - LOG.error("Failed to get securityProviderClass {}.", conf.get(providerConf, "")); - return; - } - securityProvider = (IObsCredentialsProvider) securityProviderClass.getDeclaredConstructor().newInstance(); - updateSecurityKey(); - if (!syncToGetToken) { - timerGetSecurityKey(); - } - } catch (Exception e) { - LOG.error("get obs ak/sk/token failed."); - } - } - - private static boolean checkSecurityKeyValid(ISecurityKey iSecurityKey) { - if (null == iSecurityKey) { - LOG.error("iSecurityKey is null"); - return false; - } - if (null == iSecurityKey.getAccessKey() - || null == iSecurityKey.getSecretKey() - || null == iSecurityKey.getSecurityToken()) { - return false; - } - return true; - } - - private static void updateSecurityKey() { - ISecurityKey iSecurityKey = securityProvider.getSecurityKey(); - int count = 0; - while(!checkSecurityKeyValid(iSecurityKey) && count < retryTimes) { - LOG.error("Get securityKey failed,try again"); - iSecurityKey = securityProvider.getSecurityKey(); - count++; - } - synchronized (lock) { - accessKey = iSecurityKey.getAccessKey(); - secretKey = iSecurityKey.getSecretKey(); - token = iSecurityKey.getSecurityToken(); - } - } - - private static void timerGetSecurityKey() { - Thread updateKeyThread = new Thread(new MyRunnable()); - updateKeyThread.setUncaughtExceptionHandler(new Thread.UncaughtExceptionHandler() { - @Override - public void uncaughtException(Thread t, Throwable e) { - LOG.error("Failed to get securityKey: {}, {}", t.getName(), e.getMessage()); - } - }); - updateKeyThread.start(); - } - - public static String getEndpoint() { - if (endpoint == null) { - synchronized (lock) { - init(); - } - } - return endpoint; - } - - public static String getAk() { - return accessKey; - } - - public static String getSk() { - return secretKey; - } - - public static String getToken() { - return token; - } - - public static Object getLock() { - if (syncToGetToken) { - updateSecurityKey(); - } - return lock; - } - - private static class MyRunnable implements Runnable { - @Override - public void run() { - long sleepTime = ColumnarPluginConfig.getConf().timeGetObsToken(); - while (true) { - try { - updateSecurityKey(); - Thread.sleep(sleepTime); - } catch (InterruptedException e) { - break; - } - } - } - } - - public static JSONObject constructObsJSONObject() { - JSONObject obsJsonItem = new JSONObject(); - obsJsonItem.put("endpoint", ObsConf.getEndpoint()); - synchronized (ObsConf.getLock()) { - obsJsonItem.put("ak", ObsConf.getAk()); - obsJsonItem.put("sk", ObsConf.getSk()); - obsJsonItem.put("token", ObsConf.getToken()); - } - return obsJsonItem; - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/compress/DecompressionStream.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/compress/DecompressionStream.java index 4bbe922ca85907f0eda0e7820277d757a00d2ebe..66146e4980dd4c731821771ea2920eb02f3bad8f 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/compress/DecompressionStream.java +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/compress/DecompressionStream.java @@ -85,8 +85,8 @@ public class DecompressionStream extends InputStream { uncompressedLimit = chunkLength; return; } - if (uncompressed == null || UNCOMPRESSED_LENGTH > uncompressed.length) { - uncompressed = new 
byte[UNCOMPRESSED_LENGTH]; + if (uncompressed == null || compressBlockSize > uncompressed.length) { + uncompressed = new byte[compressBlockSize]; } int actualUncompressedLength = codec.decompress(compressed, chunkLength, uncompressed); diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/NativeLoader.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/NativeLoader.java index 7cd435f7ce052e1ece0c9b5140fc151a9e7152a0..49194e5a37d7b594d0b3dd6ac314af811630f4e9 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/NativeLoader.java +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/NativeLoader.java @@ -44,6 +44,7 @@ public class NativeLoader { synchronized (NativeLoader.class) { if (INSTANCE == null) { INSTANCE = new NativeLoader(); + NativeLog.getInstance(); } } } @@ -63,7 +64,6 @@ public class NativeLoader { fos.write(buf, 0, i); } System.load(tempFile.getCanonicalPath()); - NativeLog.getInstance(); } } catch (IOException e) { LOG.warn("fail to load library from Jar!errmsg:{}", e.getMessage()); diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReader.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java similarity index 75% rename from omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReader.java rename to omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java index 128ff6ca190012de4ece35e25e13bf3266fdfff2..1d858a5e3f22353ddfd18588ec69f50d96de5852 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReader.java +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java @@ -1,5 +1,5 @@ /* - * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Copyright (C) 2020-2023. Huawei Technologies Co., Ltd. All rights reserved. * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. 
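The DecompressionStream hunk above replaces the fixed UNCOMPRESSED_LENGTH allocation with one driven by the configured compress block size, so the scratch buffer is reused across chunks and only grows when a larger block arrives. A minimal standalone sketch of that grow-on-demand pattern follows; the class and method names are illustrative, not the plugin's API.

public final class ScratchBuffer {
    private byte[] uncompressed;

    /**
     * Returns a buffer of at least compressBlockSize bytes, reallocating only when the
     * current buffer is missing or too small -- the same reuse policy the patched
     * DecompressionStream applies before calling codec.decompress().
     */
    public byte[] ensureCapacity(int compressBlockSize) {
        if (uncompressed == null || compressBlockSize > uncompressed.length) {
            uncompressed = new byte[compressBlockSize];
        }
        return uncompressed;
    }
}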
@@ -17,8 +17,7 @@ */ package com.huawei.boostkit.spark.jni; - -import com.huawei.boostkit.spark.ObsConf; +import com.huawei.boostkit.scan.jni.OrcColumnarBatchJniReader; import nova.hetu.omniruntime.type.DataType; import nova.hetu.omniruntime.type.Decimal128DataType; @@ -37,14 +36,14 @@ import org.slf4j.LoggerFactory; import org.apache.orc.TypeDescription; import java.io.IOException; +import java.net.URI; import java.sql.Date; import java.util.ArrayList; import java.util.Arrays; import java.util.List; - -public class OrcColumnarBatchJniReader { - private static final Logger LOGGER = LoggerFactory.getLogger(OrcColumnarBatchJniReader.class); +public class OrcColumnarBatchScanReader { + private static final Logger LOGGER = LoggerFactory.getLogger(OrcColumnarBatchScanReader.class); public long reader; public long recordReader; @@ -52,19 +51,31 @@ public class OrcColumnarBatchJniReader { public int[] colsToGet; public int realColsCnt; - public OrcColumnarBatchJniReader() { - NativeLoader.getInstance(); + public ArrayList fildsNames; + + public ArrayList colToInclu; + + public String[] requiredfieldNames; + + public int[] precisionArray; + + public int[] scaleArray; + + public OrcColumnarBatchJniReader jniReader; + public OrcColumnarBatchScanReader() { + jniReader = new OrcColumnarBatchJniReader(); + fildsNames = new ArrayList(); } - public JSONObject getSubJson(ExpressionTree etNode) { + public JSONObject getSubJson(ExpressionTree node) { JSONObject jsonObject = new JSONObject(); - jsonObject.put("op", etNode.getOperator().ordinal()); - if (etNode.getOperator().toString().equals("LEAF")) { - jsonObject.put("leaf", etNode.toString()); + jsonObject.put("op", node.getOperator().ordinal()); + if (node.getOperator().toString().equals("LEAF")) { + jsonObject.put("leaf", node.toString()); return jsonObject; } ArrayList child = new ArrayList(); - for (ExpressionTree childNode : etNode.getChildren()) { + for (ExpressionTree childNode : node.getChildren()) { JSONObject rtnJson = getSubJson(childNode); child.add(rtnJson); } @@ -73,15 +84,35 @@ public class OrcColumnarBatchJniReader { } public String padZeroForDecimals(String [] decimalStrArray, int decimalScale) { - String decimalVal = ""; // Integer without decimals, eg: 12345 - if (decimalStrArray.length == 2) { // Integer with decimals, eg: 12345.6 + String decimalVal = ""; + if (decimalStrArray.length == 2) { decimalVal = decimalStrArray[1]; } // If the length of the formatted number string is insufficient, pad '0's. 
return String.format("%1$-" + decimalScale + "s", decimalVal).replace(' ', '0'); } - public JSONObject getLeavesJson(List leaves, TypeDescription schema) { + public int getPrecision(String colname) { + for (int i = 0; i < requiredfieldNames.length; i++) { + if (colname.equals(requiredfieldNames[i])) { + return precisionArray[i]; + } + } + + return -1; + } + + public int getScale(String colname) { + for (int i = 0; i < requiredfieldNames.length; i++) { + if (colname.equals(requiredfieldNames[i])) { + return scaleArray[i]; + } + } + + return -1; + } + + public JSONObject getLeavesJson(List leaves) { JSONObject jsonObjectList = new JSONObject(); for (int i = 0; i < leaves.size(); i++) { PredicateLeaf pl = leaves.get(i); @@ -93,8 +124,8 @@ public class OrcColumnarBatchJniReader { if (pl.getType() == PredicateLeaf.Type.DATE) { jsonObject.put("literal", ((int)Math.ceil(((Date)pl.getLiteral()).getTime()* 1.0/3600/24/1000)) + ""); } else if (pl.getType() == PredicateLeaf.Type.DECIMAL) { - int decimalP = schema.findSubtype(pl.getColumnName()).getPrecision(); - int decimalS = schema.findSubtype(pl.getColumnName()).getScale(); + int decimalP = getPrecision(pl.getColumnName()); + int decimalS = getScale(pl.getColumnName()); String[] spiltValues = pl.getLiteral().toString().split("\\."); if (decimalS == 0) { jsonObject.put("literal", spiltValues[0] + " " + decimalP + " " + decimalS); @@ -109,11 +140,15 @@ public class OrcColumnarBatchJniReader { jsonObject.put("literal", ""); } if ((pl.getLiteralList() != null) && (pl.getLiteralList().size() != 0)){ - List lst = new ArrayList(); + List lst = new ArrayList<>(); for (Object ob : pl.getLiteralList()) { + if (ob == null) { + lst.add(null); + continue; + } if (pl.getType() == PredicateLeaf.Type.DECIMAL) { - int decimalP = schema.findSubtype(pl.getColumnName()).getPrecision(); - int decimalS = schema.findSubtype(pl.getColumnName()).getScale(); + int decimalP = getPrecision(pl.getColumnName()); + int decimalS = getScale(pl.getColumnName()); String[] spiltValues = ob.toString().split("\\."); if (decimalS == 0) { lst.add(spiltValues[0] + " " + decimalP + " " + decimalS); @@ -139,10 +174,10 @@ public class OrcColumnarBatchJniReader { /** * Init Orc reader. * - * @param path split file path + * @param uri split file path * @param options split file options */ - public long initializeReaderJava(String path, ReaderOptions options) { + public long initializeReaderJava(URI uri, ReaderOptions options) { JSONObject job = new JSONObject(); if (options.getOrcTail() == null) { job.put("serializedTail", ""); @@ -152,16 +187,18 @@ public class OrcColumnarBatchJniReader { job.put("tailLocation", 9223372036854775807L); // handle delegate token for native orc reader - OrcColumnarBatchJniReader.tokenDebug("initializeReader"); - JSONObject tokensJsonObj = constructTokensJSONObject(); - if (null != tokensJsonObj) { - job.put("tokens", tokensJsonObj); + OrcColumnarBatchScanReader.tokenDebug("initializeReader"); + JSONObject tokenJsonObj = constructTokensJSONObject(); + if (null != tokenJsonObj) { + job.put("tokens", tokenJsonObj); } - // just used for obs - job.put("obsInfo", ObsConf.constructObsJSONObject()); + job.put("scheme", uri.getScheme() == null ? "" : uri.getScheme()); + job.put("host", uri.getHost() == null ? "" : uri.getHost()); + job.put("port", uri.getPort()); + job.put("path", uri.getPath() == null ? 
"" : uri.getPath()); - reader = initializeReader(path, job); + reader = jniReader.initializeReader(job, fildsNames); return reader; } @@ -179,67 +216,51 @@ public class OrcColumnarBatchJniReader { } job.put("offset", options.getOffset()); job.put("length", options.getLength()); - if (options.getSearchArgument() != null) { + // When the number of pushedFilters > hive.CNF_COMBINATIONS_THRESHOLD, the expression is rewritten to + // 'YES_NO_NULL'. Under the circumstances, filter push down will be skipped. + if (options.getSearchArgument() != null + && !options.getSearchArgument().toString().contains("YES_NO_NULL")) { LOGGER.debug("SearchArgument: {}", options.getSearchArgument().toString()); JSONObject jsonexpressionTree = getSubJson(options.getSearchArgument().getExpression()); job.put("expressionTree", jsonexpressionTree); - JSONObject jsonleaves = getLeavesJson(options.getSearchArgument().getLeaves(), options.getSchema()); + JSONObject jsonleaves = getLeavesJson(options.getSearchArgument().getLeaves()); job.put("leaves", jsonleaves); } - List allCols; - if (options.getColumnNames() == null) { - allCols = Arrays.asList(getAllColumnNames(reader)); - } else { - allCols = Arrays.asList(options.getColumnNames()); - } - ArrayList colToInclu = new ArrayList(); - List optionField = options.getSchema().getFieldNames(); - colsToGet = new int[optionField.size()]; - realColsCnt = 0; - for (int i = 0; i < optionField.size(); i++) { - if (allCols.contains(optionField.get(i))) { - colToInclu.add(optionField.get(i)); - colsToGet[i] = 0; - realColsCnt++; - } else { - colsToGet[i] = -1; - } - } job.put("includedColumns", colToInclu.toArray()); // handle delegate token for native orc reader - OrcColumnarBatchJniReader.tokenDebug("initializeRecordReader"); + OrcColumnarBatchScanReader.tokenDebug("initializeRecordReader"); JSONObject tokensJsonObj = constructTokensJSONObject(); if (null != tokensJsonObj) { job.put("tokens", tokensJsonObj); } - recordReader = initializeRecordReader(reader, job); + recordReader = jniReader.initializeRecordReader(reader, job); return recordReader; } public long initBatchJava(long batchSize) { - batchReader = initializeBatch(recordReader, batchSize); + batchReader = jniReader.initializeBatch(recordReader, batchSize); return 0; } public long getNumberOfRowsJava() { - return getNumberOfRows(recordReader, batchReader); + return jniReader.getNumberOfRows(recordReader, batchReader); } public long getRowNumber() { - return recordReaderGetRowNumber(recordReader); + return jniReader.recordReaderGetRowNumber(recordReader); } public float getProgress() { - return recordReaderGetProgress(recordReader); + return jniReader.recordReaderGetProgress(recordReader); } public void close() { - recordReaderClose(recordReader, reader, batchReader); + jniReader.recordReaderClose(recordReader, reader, batchReader); } public void seekToRow(long rowNumber) { - recordReaderSeekToRow(recordReader, rowNumber); + jniReader.recordReaderSeekToRow(recordReader, rowNumber); } public void convertJulianToGreGorian(IntVec intVec, long rowNumber) { @@ -251,15 +272,14 @@ public class OrcColumnarBatchJniReader { } public int next(Vec[] vecList) { - int vectorCnt = vecList.length; int[] typeIds = new int[realColsCnt]; long[] vecNativeIds = new long[realColsCnt]; - long rtn = recordReaderNext(recordReader, batchReader, typeIds, vecNativeIds); + long rtn = jniReader.recordReaderNext(recordReader, batchReader, typeIds, vecNativeIds); if (rtn == 0) { return 0; } int nativeGetId = 0; - for (int i = 0; i < vectorCnt; i++) { + 
for (int i = 0; i < realColsCnt; i++) { if (colsToGet[i] != 0) { continue; } @@ -308,26 +328,6 @@ public class OrcColumnarBatchJniReader { return (int)rtn; } - public native long initializeReader(String path, JSONObject job); - - public native long initializeRecordReader(long reader, JSONObject job); - - public native long initializeBatch(long rowReader, long batchSize); - - public native long recordReaderNext(long rowReader, long batchReader, int[] typeId, long[] vecNativeId); - - public native long recordReaderGetRowNumber(long rowReader); - - public native float recordReaderGetProgress(long rowReader); - - public native void recordReaderClose(long rowReader, long reader, long batchReader); - - public native void recordReaderSeekToRow(long rowReader, long rowNumber); - - public native String[] getAllColumnNames(long reader); - - public native long getNumberOfRows(long rowReader, long batch); - private static String bytesToHexString(byte[] bytes) { if (bytes == null || bytes.length < 1) { throw new IllegalArgumentException("this bytes must not be null or empty"); diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchJniReader.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchScanReader.java similarity index 42% rename from omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchJniReader.java rename to omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchScanReader.java index c45f33bb50a565d0937304817d9a8f6daf86e3e3..5a209a66d763902744d3d141b14f5e9130c15293 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchJniReader.java +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchScanReader.java @@ -18,103 +18,87 @@ package com.huawei.boostkit.spark.jni; -import com.huawei.boostkit.spark.ObsConf; -import nova.hetu.omniruntime.type.DataType; +import com.huawei.boostkit.scan.jni.ParquetColumnarBatchJniReader; import nova.hetu.omniruntime.vector.*; - -import org.apache.spark.sql.catalyst.util.RebaseDateTime; - +import org.apache.hadoop.fs.Path; +import org.apache.spark.sql.types.*; import org.json.JSONObject; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.UnsupportedEncodingException; +import java.net.URI; import java.util.List; -public class ParquetColumnarBatchJniReader { - private static final Logger LOGGER = LoggerFactory.getLogger(ParquetColumnarBatchJniReader.class); +public class ParquetColumnarBatchScanReader { + private static final Logger LOGGER = LoggerFactory.getLogger(ParquetColumnarBatchScanReader.class); public long parquetReader; - public ParquetColumnarBatchJniReader() { - NativeLoader.getInstance(); + public ParquetColumnarBatchJniReader jniReader; + public ParquetColumnarBatchScanReader() { + jniReader = new ParquetColumnarBatchJniReader(); } - public long initializeReaderJava(String path, int capacity, - List rowgroupIndices, List columnIndices, String ugi) { + public long initializeReaderJava(Path path, int capacity, + List rowgroupIndices, List columnIndices, String ugi) throws UnsupportedEncodingException { JSONObject job = new JSONObject(); - job.put("filePath", path); + URI uri = path.toUri(); + + job.put("uri", path.toString()); job.put("capacity", capacity); job.put("rowGroupIndices", 
rowgroupIndices.stream().mapToInt(Integer::intValue).toArray()); job.put("columnIndices", columnIndices.stream().mapToInt(Integer::intValue).toArray()); job.put("ugi", ugi); - // just used for obs - job.put("obsInfo", ObsConf.constructObsJSONObject()); - parquetReader = initializeReader(job); + + job.put("host", uri.getHost() == null ? "" : uri.getHost()); + job.put("scheme", uri.getScheme() == null ? "" : uri.getScheme()); + job.put("port", uri.getPort()); + job.put("path", uri.getPath() == null ? "" : uri.getPath()); + + parquetReader = jniReader.initializeReader(job); return parquetReader; } - public int next(Vec[] vecList) { + public int next(Vec[] vecList, List types) { int vectorCnt = vecList.length; - int[] typeIds = new int[vectorCnt]; long[] vecNativeIds = new long[vectorCnt]; - long rtn = recordReaderNext(parquetReader, typeIds, vecNativeIds); + long rtn = jniReader.recordReaderNext(parquetReader, vecNativeIds); if (rtn == 0) { return 0; } - int nativeGetId = 0; for (int i = 0; i < vectorCnt; i++) { - switch (DataType.DataTypeId.values()[typeIds[nativeGetId]]) { - case OMNI_BOOLEAN: { - vecList[i] = new BooleanVec(vecNativeIds[nativeGetId]); - break; - } - case OMNI_SHORT: { - vecList[i] = new ShortVec(vecNativeIds[nativeGetId]); - break; - } - case OMNI_DATE32: { - vecList[i] = new IntVec(vecNativeIds[nativeGetId]); - break; - } - case OMNI_INT: { - vecList[i] = new IntVec(vecNativeIds[nativeGetId]); - break; - } - case OMNI_LONG: - case OMNI_DECIMAL64: { - vecList[i] = new LongVec(vecNativeIds[nativeGetId]); - break; - } - case OMNI_DOUBLE: { - vecList[i] = new DoubleVec(vecNativeIds[nativeGetId]); - break; - } - case OMNI_VARCHAR: { - vecList[i] = new VarcharVec(vecNativeIds[nativeGetId]); - break; - } - case OMNI_DECIMAL128: { - vecList[i] = new Decimal128Vec(vecNativeIds[nativeGetId]); - break; - } - default: { - throw new RuntimeException("UnSupport type for ColumnarFileScan:" + - DataType.DataTypeId.values()[typeIds[i]]); + DataType type = types.get(i); + if (type instanceof LongType) { + vecList[i] = new LongVec(vecNativeIds[i]); + } else if (type instanceof BooleanType) { + vecList[i] = new BooleanVec(vecNativeIds[i]); + } else if (type instanceof ShortType) { + vecList[i] = new ShortVec(vecNativeIds[i]); + } else if (type instanceof IntegerType) { + vecList[i] = new IntVec(vecNativeIds[i]); + } else if (type instanceof DecimalType) { + if (DecimalType.is64BitDecimalType(type)) { + vecList[i] = new LongVec(vecNativeIds[i]); + } else { + vecList[i] = new Decimal128Vec(vecNativeIds[i]); } + } else if (type instanceof DoubleType) { + vecList[i] = new DoubleVec(vecNativeIds[i]); + } else if (type instanceof StringType) { + vecList[i] = new VarcharVec(vecNativeIds[i]); + } else if (type instanceof DateType) { + vecList[i] = new IntVec(vecNativeIds[i]); + } else if (type instanceof ByteType) { + vecList[i] = new VarcharVec(vecNativeIds[i]); + } else { + throw new RuntimeException("Unsupport type for ColumnarFileScan: " + type.typeName()); } - nativeGetId++; } return (int)rtn; } public void close() { - recordReaderClose(parquetReader); + jniReader.recordReaderClose(parquetReader); } - - public native long initializeReader(JSONObject job); - - public native long recordReaderNext(long parquetReader, int[] typeId, long[] vecNativeId); - - public native void recordReaderClose(long parquetReader); - } \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java 
b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java index 1b94c47b03f77f628e1a3a0e6f5d0a6ee2147597..6a0c1b27c4282016aecada2ba4ef0c48c320f20f 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java @@ -20,6 +20,7 @@ package com.huawei.boostkit.spark.serialize; import com.google.protobuf.InvalidProtocolBufferException; +import nova.hetu.omniruntime.utils.OmniRuntimeException; import nova.hetu.omniruntime.vector.BooleanVec; import nova.hetu.omniruntime.vector.Decimal128Vec; import nova.hetu.omniruntime.vector.DoubleVec; @@ -35,21 +36,31 @@ import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.sql.vectorized.ColumnarBatch; - public class ShuffleDataSerializer { public static ColumnarBatch deserialize(byte[] bytes) { + ColumnVector[] vecs = null; try { VecData.VecBatch vecBatch = VecData.VecBatch.parseFrom(bytes); int vecCount = vecBatch.getVecCnt(); int rowCount = vecBatch.getRowCnt(); - ColumnVector[] vecs = new ColumnVector[vecCount]; + vecs = new ColumnVector[vecCount]; for (int i = 0; i < vecCount; i++) { vecs[i] = buildVec(vecBatch.getVecs(i), rowCount); } return new ColumnarBatch(vecs, rowCount); } catch (InvalidProtocolBufferException e) { throw new RuntimeException("deserialize failed. errmsg:" + e.getMessage()); + } catch (OmniRuntimeException e) { + if (vecs != null) { + for (int i = 0; i < vecs.length; i++) { + ColumnVector vec = vecs[i]; + if (vec != null) { + vec.close(); + } + } + } + throw new RuntimeException("deserialize failed. 
errmsg:" + e.getMessage()); } } @@ -110,7 +121,9 @@ public class ShuffleDataSerializer { throw new IllegalStateException("Unexpected value: " + protoTypeId.getTypeId()); } vec.setValuesBuf(protoVec.getValues().toByteArray()); - vec.setNullsBuf(protoVec.getNulls().toByteArray()); + if(protoVec.getNulls().size() != 0) { + vec.setNullsBuf(protoVec.getNulls().toByteArray()); + } OmniColumnVector vecTmp = new OmniColumnVector(vecSize, type, false); vecTmp.setVec(vec); return vecTmp; diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java b/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java index c170b04e4a4b678d962200772cf0c542bed591c4..aeaa10faab50bb0490529fadebb331db2c60efa5 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java +++ b/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.orc; import com.google.common.annotations.VisibleForTesting; -import com.huawei.boostkit.spark.jni.OrcColumnarBatchJniReader; +import com.huawei.boostkit.spark.jni.OrcColumnarBatchScanReader; import nova.hetu.omniruntime.vector.Vec; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.mapreduce.InputSplit; @@ -29,18 +29,17 @@ import org.apache.hadoop.mapreduce.lib.input.FileSplit; import org.apache.orc.OrcConf; import org.apache.orc.OrcFile; import org.apache.orc.Reader; -import org.apache.orc.TypeDescription; import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils; +import org.apache.spark.sql.execution.vectorized.OmniColumnVectorUtils; import org.apache.spark.sql.execution.vectorized.OmniColumnVector; -import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.vectorized.ColumnarBatch; import java.io.IOException; +import java.util.ArrayList; /** * To support vectorization in WholeStageCodeGen, this reader returns ColumnarBatch. 
@@ -64,9 +63,10 @@ public class OmniOrcColumnarBatchReader extends RecordReader orcfieldNames = recordReader.fildsNames; + // save valid cols and numbers of valid cols + recordReader.colsToGet = new int[requiredfieldNames.length]; + recordReader.realColsCnt = 0; + // save valid cols fieldsNames + recordReader.colToInclu = new ArrayList(); + for (int i = 0; i < requiredfieldNames.length; i++) { + String target = requiredfieldNames[i]; + boolean is_find = false; + for (int j = 0; j < orcfieldNames.size(); j++) { + String temp = orcfieldNames.get(j); + if (target.equals(temp)) { + requestedDataColIds[i] = i; + recordReader.colsToGet[i] = 0; + recordReader.colToInclu.add(requiredfieldNames[i]); + recordReader.realColsCnt++; + is_find = true; + } + } + + // if invalid, set colsToGet value -1, else set colsToGet 0 + if (!is_find) { + recordReader.colsToGet[i] = -1; + } + } + + for (int i = 0; i < resultFields.length; i++) { + if (requestedPartitionColIds[i] != -1) { + requestedDataColIds[i] = -1; + } + } + + // set data members resultFields and requestedDataColIdS + this.resultFields = resultFields; + this.requestedDataColIds = requestedDataColIds; + + recordReader.requiredfieldNames = requiredfieldNames; + recordReader.precisionArray = precisionArray; + recordReader.scaleArray = scaleArray; recordReader.initializeRecordReaderJava(options); } @@ -142,43 +193,36 @@ public class OmniOrcColumnarBatchReader extends RecordReader { @@ -79,13 +80,15 @@ public class OmniParquetColumnarBatchReader extends RecordReader types = new ArrayList<>(); private boolean isFilterPredicate = false; public OmniParquetColumnarBatchReader(int capacity, ParquetMetadata fileFooter) { @@ -93,7 +96,7 @@ public class OmniParquetColumnarBatchReader extends RecordReader rowgroupIndices = getFilteredBlocks(split.getStart(), split.getEnd()); List columnIndices = getColumnIndices(requestedSchema.getColumns(), fileSchema.getColumns()); String ugi = UserGroupInformation.getCurrentUser().toString(); - reader.initializeReaderJava(split.getPath().toString(), capacity, rowgroupIndices, columnIndices, ugi); + reader.initializeReaderJava(split.getPath(), capacity, rowgroupIndices, columnIndices, ugi); // Add missing Cols flags. 
initializeInternal(); } @@ -242,6 +247,7 @@ public class OmniParquetColumnarBatchReader extends RecordReader if (!enableColumnarProject) return false @@ -172,7 +172,8 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] { plan.buildSide, plan.condition, plan.left, - plan.right).buildCheck() + plan.right, + plan.isNullAwareAntiJoin).buildCheck() case plan: SortMergeJoinExec => if (!enableColumnarSortMergeJoin) return false new ColumnarSortMergeJoinExec( @@ -205,6 +206,9 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] { if (!enableGlobalColumnarLimit) return false ColumnarGlobalLimitExec(plan.limit, plan.child).buildCheck() case plan: BroadcastNestedLoopJoinExec => return false + case plan: CoalesceExec => + if (!enableColumnarCoalesce) return false + ColumnarCoalesceExec(plan.numPartitions, plan.child).buildCheck() case p => p } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala index 5200431f43821f83c442ebe3170eebaa870a9277..108562dc66926aac1d5adadccc3d7568d79071da 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala @@ -1,4 +1,5 @@ /* + * Copyright (C) 2020-2024. Huawei Technologies Co., Ltd. All rights reserved. * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. @@ -19,20 +20,23 @@ package com.huawei.boostkit.spark import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor import com.huawei.boostkit.spark.util.PhysicalPlanSelector - import org.apache.spark.internal.Logging import org.apache.spark.sql.{SparkSession, SparkSessionExtensions} -import org.apache.spark.sql.catalyst.expressions.{Ascending, DynamicPruningSubquery, SortOrder} -import org.apache.spark.sql.catalyst.expressions.aggregate.Partial +import org.apache.spark.sql.catalyst.expressions.{Ascending, DynamicPruningSubquery, Expression, Literal, SortOrder} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Partial, PartialMerge} +import org.apache.spark.sql.catalyst.optimizer.{DelayCartesianProduct, HeuristicJoinReorder, MergeSubqueryFilters, RewriteSelfJoinInInPredicate} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{RowToOmniColumnarExec, _} -import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, OmniAQEShuffleReadExec, AQEShuffleReadExec, QueryStageExec, ShuffleQueryStageExec} -import org.apache.spark.sql.execution.aggregate.HashAggregateExec +import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, OmniAQEShuffleReadExec, QueryStageExec, ShuffleQueryStageExec} +import org.apache.spark.sql.execution.aggregate.{DummyLogicalPlan, ExtendedAggUtils, HashAggregateExec} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.execution.window.WindowExec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.ColumnarBatchSupportUtil.checkColumnarBatchSupport +import org.apache.spark.sql.catalyst.planning.PhysicalAggregation +import 
org.apache.spark.sql.catalyst.plans.LeftSemi +import org.apache.spark.sql.catalyst.plans.logical.Aggregate case class ColumnarPreOverrides() extends Rule[SparkPlan] { val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf @@ -59,6 +63,10 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { val enableColumnarProjectFusion: Boolean = columnarConf.enableColumnarProjectFusion val enableLocalColumnarLimit: Boolean = columnarConf.enableLocalColumnarLimit val enableGlobalColumnarLimit: Boolean = columnarConf.enableGlobalColumnarLimit + val enableDedupLeftSemiJoin: Boolean = columnarConf.enableDedupLeftSemiJoin + val dedupLeftSemiJoinThreshold: Int = columnarConf.dedupLeftSemiJoinThreshold + val enableColumnarCoalesce: Boolean = columnarConf.enableColumnarCoalesce + val enableRollupOptimization: Boolean = columnarConf.enableRollupOptimization def apply(plan: SparkPlan): SparkPlan = { replaceWithColumnarPlan(plan) @@ -115,9 +123,7 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { plan.optionalNumCoalescedBuckets, plan.dataFilters, plan.tableIdentifier, - plan.needPriv, - plan.disableBucketedScan, - plan.outputAllAttributes + plan.disableBucketedScan ) case range: RangeExec => new ColumnarRangeExec(range.range) @@ -197,7 +203,7 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { proj4 @ ColumnarProjectExec(_, join4 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, filter @ ColumnarFilterExec(_, - scan @ ColumnarFileSourceScanExec(_, _, _, _, _, _, _, _, _, _, _) + scan @ ColumnarFileSourceScanExec(_, _, _, _, _, _, _, _, _) ), _, _, _)), _, _, _)), _, _, _)), _, _, _)) if checkBhjRightChild( child.asInstanceOf[ColumnarProjectExec].child.children(1) @@ -229,7 +235,7 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { proj3 @ ColumnarProjectExec(_, join3 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, _, filter @ ColumnarFilterExec(_, - scan @ ColumnarFileSourceScanExec(_, _, _, _, _, _, _, _, _, _, _)), _, _)) , _, _, _)), _, _, _)) + scan @ ColumnarFileSourceScanExec(_, _, _, _, _, _, _, _, _)), _, _)) , _, _, _)), _, _, _)) if checkBhjRightChild( child.asInstanceOf[ColumnarProjectExec].child.children(1) .asInstanceOf[ColumnarBroadcastExchangeExec].child) => @@ -258,7 +264,7 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { proj3 @ ColumnarProjectExec(_, join3 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, filter @ ColumnarFilterExec(_, - scan @ ColumnarFileSourceScanExec(_, _, _, _, _, _, _, _, _, _, _)), _, _, _)) , _, _, _)), _, _, _)) + scan @ ColumnarFileSourceScanExec(_, _, _, _, _, _, _, _, _)), _, _, _)) , _, _, _)), _, _, _)) if checkBhjRightChild( child.asInstanceOf[ColumnarProjectExec].child.children(1) .asInstanceOf[ColumnarBroadcastExchangeExec].child) => @@ -305,16 +311,92 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { child) } } else { - new ColumnarHashAggregateExec( - plan.requiredChildDistributionExpressions, - plan.isStreaming, - plan.numShufflePartitions, - plan.groupingExpressions, - plan.aggregateExpressions, - plan.aggregateAttributes, - plan.initialInputBufferOffset, - plan.resultExpressions, - child) + if (child.isInstanceOf[ColumnarExpandExec]) { + var columnarExpandExec = child.asInstanceOf[ColumnarExpandExec] + val matchRollupOptimization: Boolean = columnarExpandExec.matchRollupOptimization() + if (matchRollupOptimization && enableRollupOptimization) { + // The sparkPlan: ColumnarExpandExec -> ColumnarHashAggExec => ColumnarExpandExec -> ColumnarHashAggExec -> 
ColumnarOptRollupExec. + // ColumnarHashAggExec handles the first combination by Partial mode, i.e. projections[0]. + // ColumnarOptRollupExec handles the residual combinations by PartialMerge mode, i.e. projections[1]~projections[n]. + val projections = columnarExpandExec.projections + val headProjections = projections.slice(0, 1) + var residualProjections = projections.slice(1, projections.length) + // replace parameters + columnarExpandExec = columnarExpandExec.replace(headProjections) + + // partial + val partialHashAggExec = new ColumnarHashAggregateExec( + plan.requiredChildDistributionExpressions, + plan.isStreaming, + plan.numShufflePartitions, + plan.groupingExpressions, + plan.aggregateExpressions, + plan.aggregateAttributes, + plan.initialInputBufferOffset, + plan.resultExpressions, + columnarExpandExec) + + + // If the aggregator has an expression, more than one column in the projection is used + // for expression calculation. Meanwhile, If the single distinct syntax exists, the + // sequence of group columns is disordered. Therefore, we need to calculate the sequence + // of expandSeq first to ensure the project operator correctly processes the columns. + val expectSeq = plan.resultExpressions + val expandSeq = columnarExpandExec.output + // the processing sequences of expandSeq + residualProjections = residualProjections.map(projection => { + val indexSeq: Seq[Expression] = expectSeq.map(expectExpr => { + val index = expandSeq.indexWhere(expandExpr => expectExpr.exprId.equals(expandExpr.exprId)) + if (index != -1) { + projection.apply(index) match { + case literal: Literal => literal + case _ => expectExpr + } + } else { + expectExpr + } + }) + indexSeq + }) + + // partial merge + val groupingExpressions = plan.resultExpressions.slice(0, plan.groupingExpressions.length) + val aggregateExpressions = plan.aggregateExpressions.map(expr => { + expr.copy(expr.aggregateFunction, PartialMerge, expr.isDistinct, expr.filter, expr.resultId) + }) + + // need ExpandExec parameters and HashAggExec parameters + new ColumnarOptRollupExec( + residualProjections, + plan.output, + groupingExpressions, + aggregateExpressions, + plan.aggregateAttributes, + partialHashAggExec) + } else { + new ColumnarHashAggregateExec( + plan.requiredChildDistributionExpressions, + plan.isStreaming, + plan.numShufflePartitions, + plan.groupingExpressions, + plan.aggregateExpressions, + plan.aggregateAttributes, + plan.initialInputBufferOffset, + plan.resultExpressions, + child) + } + } else { + new ColumnarHashAggregateExec( + plan.requiredChildDistributionExpressions, + plan.isStreaming, + plan.numShufflePartitions, + plan.groupingExpressions, + plan.aggregateExpressions, + plan.aggregateAttributes, + plan.initialInputBufferOffset, + plan.resultExpressions, + child) + } } case plan: TakeOrderedAndProjectExec if enableTakeOrderedAndProject => @@ -342,6 +424,72 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { plan.condition, left, right) + case plan: ShuffledHashJoinExec if enableShuffledHashJoin && enableDedupLeftSemiJoin => { + plan.joinType match { + case LeftSemi => { + if (plan.condition.isEmpty && plan.right.output.size >= dedupLeftSemiJoinThreshold) { + val left = replaceWithColumnarPlan(plan.left) + val right = replaceWithColumnarPlan(plan.right) + val partialAgg = PhysicalAggregation.unapply(Aggregate(plan.right.output, plan.right.output, new DummyLogicalPlan)) match { + case Some((groupingExpressions, aggExpressions, resultExpressions, _)) + if aggExpressions.forall(expr => 
expr.isInstanceOf[AggregateExpression]) => + ExtendedAggUtils.planPartialAggregateWithoutDistinct( + ExtendedAggUtils.normalizeGroupingExpressions(groupingExpressions), + aggExpressions.map(_.asInstanceOf[AggregateExpression]), + resultExpressions, + right).asInstanceOf[HashAggregateExec] + } + val newHashAgg = new ColumnarHashAggregateExec( + partialAgg.requiredChildDistributionExpressions, + partialAgg.isStreaming, + partialAgg.numShufflePartitions, + partialAgg.groupingExpressions, + partialAgg.aggregateExpressions, + partialAgg.aggregateAttributes, + partialAgg.initialInputBufferOffset, + partialAgg.resultExpressions, + right) + + ColumnarShuffledHashJoinExec( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.buildSide, + plan.condition, + left, + newHashAgg, + plan.isSkewJoin) + } else { + val left = replaceWithColumnarPlan(plan.left) + val right = replaceWithColumnarPlan(plan.right) + logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarShuffledHashJoinExec( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.buildSide, + plan.condition, + left, + right, + plan.isSkewJoin) + } + } + case _ => { + val left = replaceWithColumnarPlan(plan.left) + val right = replaceWithColumnarPlan(plan.right) + logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarShuffledHashJoinExec( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.buildSide, + plan.condition, + left, + right, + plan.isSkewJoin) + } + } + } case plan: ShuffledHashJoinExec if enableShuffledHashJoin => val left = replaceWithColumnarPlan(plan.left) val right = replaceWithColumnarPlan(plan.right) @@ -378,6 +526,9 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { ColumnarSortExec(plan.sortOrder, plan.global, child, plan.testSpillFrequency) case plan: WindowExec if enableColumnarWindow => val child = replaceWithColumnarPlan(plan.child) + if (child.output.isEmpty) { + return plan + } logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") child match { case ColumnarSortExec(sortOrder, _, sortChild, _) => @@ -395,8 +546,12 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { ColumnarUnionExec(children) case plan: ShuffleExchangeExec if enableColumnarShuffle => val child = replaceWithColumnarPlan(plan.child) - logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - new ColumnarShuffleExchangeExec(plan.outputPartitioning, child, plan.shuffleOrigin) + if (child.output.nonEmpty) { + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + new ColumnarShuffleExchangeExec(plan.outputPartitioning, child, plan.shuffleOrigin) + } else { + plan + } case plan: AQEShuffleReadExec if columnarConf.enableColumnarShuffle => plan.child match { case shuffle: ColumnarShuffleExchangeExec => @@ -426,6 +581,10 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { val child = replaceWithColumnarPlan(plan.child) logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") ColumnarGlobalLimitExec(plan.limit, child) + case plan: CoalesceExec if enableColumnarCoalesce => + val child = replaceWithColumnarPlan(plan.child) + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarCoalesceExec(plan.numPartitions, child) case p => val children = plan.children.map(replaceWithColumnarPlan) logInfo(s"Columnar Processing for ${p.getClass} is currently not supported.") @@ -460,12 +619,12 @@ case class ColumnarPostOverrides() extends 
Rule[SparkPlan] { var isSupportAdaptive: Boolean = true def apply(plan: SparkPlan): SparkPlan = { - handleColumnarToRowParitalFetch(replaceWithColumnarPlan(plan)) + handleColumnarToRowPartialFetch(replaceWithColumnarPlan(plan)) } - private def handleColumnarToRowParitalFetch(plan: SparkPlan): SparkPlan = { + private def handleColumnarToRowPartialFetch(plan: SparkPlan): SparkPlan = { // simple check plan tree have OmniColumnarToRow and no LimitExec and TakeOrderedAndProjectExec plan - val noParitalFetch = if (plan.find(_.isInstanceOf[OmniColumnarToRowExec]).isDefined) { + val noPartialFetch = if (plan.find(_.isInstanceOf[OmniColumnarToRowExec]).isDefined) { (!plan.find(node => node.isInstanceOf[LimitExec] || node.isInstanceOf[TakeOrderedAndProjectExec] || node.isInstanceOf[SortMergeJoinExec]).isDefined) @@ -473,7 +632,7 @@ case class ColumnarPostOverrides() extends Rule[SparkPlan] { false } val newPlan = plan.transformUp { - case c: OmniColumnarToRowExec if noParitalFetch => + case c: OmniColumnarToRowExec if noPartialFetch => c.copy(c.child, false) } newPlan @@ -582,11 +741,15 @@ case class ColumnarOverrideRules(session: SparkSession) extends ColumnarRule wit rule(plan) } } + class ColumnarPlugin extends (SparkSessionExtensions => Unit) with Logging { override def apply(extensions: SparkSessionExtensions): Unit = { logInfo("Using BoostKit Spark Native Sql Engine Extension to Speed Up Your Queries.") extensions.injectColumnar(session => ColumnarOverrideRules(session)) - extensions.injectPlannerStrategy(session => ShuffleJoinStrategy(session)) + extensions.injectPlannerStrategy(_ => ShuffleJoinStrategy) extensions.injectOptimizerRule(_ => RewriteSelfJoinInInPredicate) + extensions.injectOptimizerRule(_ => DelayCartesianProduct) + extensions.injectOptimizerRule(_ => HeuristicJoinReorder) + extensions.injectOptimizerRule(_ => MergeSubqueryFilters) } } \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala index 9ab6c52dac915805d0534d82866f09e2172333c2..e87122e87c6b675dfaeab352af390143111510d9 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala @@ -1,5 +1,5 @@ /* - * Copyright (C) 2020-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Copyright (C) 2020-2024. Huawei Technologies Co., Ltd. All rights reserved. * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. 
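The ColumnarPlugin hunk above now injects the additional optimizer rules (DelayCartesianProduct, HeuristicJoinReorder, MergeSubqueryFilters) alongside the columnar override rules and the ShuffleJoinStrategy planner strategy, and the ColumnarPluginConfig hunks that follow expose the new runtime switches. As a usage illustration only, assuming the extension is wired in through Spark's standard spark.sql.extensions entry point (which this patch does not show), a Java application could enable the plugin and override a few of the new keys like this; the specific values are examples, not recommendations.

import org.apache.spark.sql.SparkSession;

public final class OmniPluginSessionExample {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
                .appName("omni-columnar-example")
                .master("local[*]")
                // register the BoostKit columnar extension (assumed registration path)
                .config("spark.sql.extensions", "com.huawei.boostkit.spark.ColumnarPlugin")
                // switches read by ColumnarPluginConfig in the hunks below
                .config("spark.omni.sql.columnar.forceShuffledHashJoin", "true")
                .config("spark.omni.sql.columnar.dedupLeftSemiJoin", "false")
                .config("spark.omni.sql.columnar.coalesce", "true")
                .getOrCreate();

        // any query goes through the injected columnar rules when supported
        spark.sql("SELECT 1").show();
        spark.stop();
    }
}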
@@ -107,21 +107,6 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging { .getConfString("spark.omni.sql.columnar.parquetNativefilescan", "true") .toBoolean - // enable sync to get obs token - val enableSyncGetObsToken: Boolean = conf - .getConfString("spark.omni.sql.columnar.syncGetObsToken", "false") - .toBoolean - - // scheduled time to get obs token, the time unit is millisecond - val timeGetObsToken: Long = conf - .getConfString("spark.omni.sql.columnar.timeGetObsToken", "60000") - .toLong - - // retry times to get obs ak/sk/token - val retryTimesGetObsToken: Integer = conf - .getConfString("spark.omni.sql.columnar.retryTimesGetObsToken", "10") - .toInt - val enableColumnarSortMergeJoin: Boolean = conf .getConfString("spark.omni.sql.columnar.sortMergeJoin", "true") .toBoolean @@ -171,22 +156,37 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging { val columnarShuffleNativeBufferSize = conf.getConfString("spark.sql.execution.columnar.maxRecordsPerBatch", "4096").toInt + // columnar spill threshold - Percentage of memory usage, associate with the "spark.memory.offHeap" together + val columnarSpillMemPctThreshold: Integer = + conf.getConfString("spark.omni.sql.columnar.spill.memFraction", "90").toInt + + // columnar spill dir disk reserve Size, default 10GB + val columnarSpillDirDiskReserveSize:Long = + conf.getConfString("spark.omni.sql.columnar.spill.dirDiskReserveSize", "10737418240").toLong + + // enable or disable columnar sort spill + val enableSortSpill: Boolean = conf + .getConfString("spark.omni.sql.columnar.sortSpill.enabled", "true").toBoolean + // columnar sort spill threshold val columnarSortSpillRowThreshold: Integer = - conf.getConfString("spark.omni.sql.columnar.sortSpill.rowThreshold", Integer.MAX_VALUE.toString).toInt + conf.getConfString("spark.omni.sql.columnar.sortSpill.rowThreshold", Integer.MAX_VALUE.toString).toInt - // columnar sort spill threshold - Percentage of memory usage, associate with the "spark.memory.offHeap" together - val columnarSortSpillMemPctThreshold: Integer = - conf.getConfString("spark.omni.sql.columnar.sortSpill.memFraction", "90").toInt + // enable or disable columnar window spill + val enableWindowSpill: Boolean = conf + .getConfString("spark.omni.sql.columnar.windowSpill.enabled", "true").toBoolean - // columnar sort spill dir disk reserve Size, default 10GB - val columnarSortSpillDirDiskReserveSize:Long = - conf.getConfString("spark.omni.sql.columnar.sortSpill.dirDiskReserveSize", "10737418240").toLong + // columnar window spill threshold + val columnarWindowSpillRowThreshold: Integer = + conf.getConfString("spark.omni.sql.columnar.windowSpill.rowThreshold", Integer.MAX_VALUE.toString).toInt - // enable or disable columnar sortSpill - val enableSortSpill: Boolean = conf - .getConfString("spark.omni.sql.columnar.sortSpill.enabled", "false") - .toBoolean + // enable or disable columnar hash aggregate spill + val enableHashAggSpill: Boolean = conf + .getConfString("spark.omni.sql.columnar.hashAggSpill.enabled", "true").toBoolean + + // columnar hash aggregate spill threshold + val columnarHashAggSpillRowThreshold: Integer = + conf.getConfString("spark.omni.sql.columnar.hashAggSpill.rowThreshold", Integer.MAX_VALUE.toString).toInt // enable or disable columnar shuffledHashJoin val enableShuffledHashJoin: Boolean = conf @@ -195,7 +195,7 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging { // enable or disable force shuffle hash join val forceShuffledHashJoin: Boolean = conf - 
.getConfString("spark.omni.sql.columnar.forceShuffledHashJoin", "false") + .getConfString("spark.omni.sql.columnar.forceShuffledHashJoin", "true") .toBoolean // enable or disable rewrite self join in Predicate to aggregate @@ -231,6 +231,34 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging { val enableLocalColumnarLimit : Boolean = conf.getConfString("spark.omni.sql.columnar.localLimit", "true").toBoolean val enableGlobalColumnarLimit : Boolean = conf.getConfString("spark.omni.sql.columnar.globalLimit", "true").toBoolean + + val topNPushDownForWindowThreshold = conf.getConfString("spark.sql.execution.topNPushDownForWindow.threshold", "100").toInt + + val topNPushDownForWindowEnable: Boolean = conf.getConfString("spark.sql.execution.topNPushDownForWindow.enabled", "true").toBoolean + + // enable or disable deduplicate the right side of left semi join + val enableDedupLeftSemiJoin: Boolean = + conf.getConfString("spark.omni.sql.columnar.dedupLeftSemiJoin", "false").toBoolean + + val dedupLeftSemiJoinThreshold: Int = + conf.getConfString("spark.omni.sql.columnar.dedupLeftSemiJoinThreshold", "3").toInt + + val filterMergeEnable: Boolean = conf.getConfString("spark.sql.execution.filterMerge.enabled", "false").toBoolean + + val filterMergeThreshold: Double = conf.getConfString("spark.sql.execution.filterMerge.maxCost", "100.0").toDouble + + // enable or disable columnar CoalesceExec + val enableColumnarCoalesce: Boolean = conf + .getConfString("spark.omni.sql.columnar.coalesce", "true") + .toBoolean + val enableRollupOptimization: Boolean = conf.getConfString("spark.omni.sql.columnar.rollupOptimization.enabled", "true").toBoolean + + // enable or disable radix sort + val enableRadixSort: Boolean = + conf.getConfString("spark.omni.sql.columnar.radixSort.enabled", "false").toBoolean + + val radixSortThreshold: Int = + conf.getConfString("spark.omni.sql.columnar.radixSortThreshold", "1000000").toInt } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ShuffleJoinStrategy.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ShuffleJoinStrategy.scala index b54b652ed8c88774d249308d49035ef52947bcfd..a36c5bcfe643a512807bd8b7419f9e77ca424d83 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ShuffleJoinStrategy.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ShuffleJoinStrategy.scala @@ -17,18 +17,16 @@ package com.huawei.boostkit.spark -import com.huawei.boostkit.spark.util.LogicalPlanSelector - -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.Strategy import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide, JoinSelectionHelper} import org.apache.spark.sql.catalyst.planning._ +import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, InnerLike, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{joins, SparkPlan} -case class ShuffleJoinStrategy(session: SparkSession) extends Strategy +object ShuffleJoinStrategy extends Strategy with PredicateHelper with JoinSelectionHelper with SQLConfHelper { @@ -39,8 +37,7 @@ case class ShuffleJoinStrategy(session: SparkSession) extends Strategy private val columnarForceShuffledHashJoin = ColumnarPluginConfig.getConf.forceShuffledHashJoin - 
def apply(plan: LogicalPlan): Seq[SparkPlan] = LogicalPlanSelector.maybeNil(session, plan) { - plan match { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, nonEquiCond, _, left, right, hint) if columnarPreferShuffledHashJoin => val enable = getBroadcastBuildSide(left, right, joinType, hint, true, conf).isEmpty && @@ -96,8 +93,8 @@ case class ShuffleJoinStrategy(session: SparkSession) extends Strategy leftBuildable = canBuildShuffledHashJoinLeft(joinType) rightBuildable = canBuildShuffledHashJoinRight(joinType) } else { - leftBuildable = canBuildShuffledHashJoinLeft(joinType) && buildLeft - rightBuildable = canBuildShuffledHashJoinRight(joinType) && buildRight + leftBuildable = canBuildShuffledHashJoinLeft(joinType) + rightBuildable = canBuildShuffledHashJoinRight(joinType) } getBuildSide( leftBuildable, @@ -120,7 +117,6 @@ case class ShuffleJoinStrategy(session: SparkSession) extends Strategy Nil } case _ => Nil - } } private def getBuildSide( @@ -140,4 +136,14 @@ case class ShuffleJoinStrategy(session: SparkSession) extends Strategy None } } + + def supportHashBuildJoinTypeOnLeft: JoinType => Boolean = { + case _: InnerLike | RightOuter | FullOuter => true + case _ => false + } + + def supportHashBuildJoinTypeOnRight: JoinType => Boolean = { + case _: InnerLike | LeftOuter | FullOuter | LeftSemi | LeftAnti | _: ExistenceJoin => true + case _ => false + } } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala index 5c1ad0ef98564f8958d13e118b0c1775df7fe4d8..11ff8e12b54519484e6aa04885e3a50933d859a4 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala @@ -19,28 +19,26 @@ package com.huawei.boostkit.spark.expression import scala.collection.mutable.ArrayBuffer - import com.huawei.boostkit.spark.Constant.{DEFAULT_STRING_TYPE_LENGTH, IS_CHECK_OMNI_EXP, OMNI_BOOLEAN_TYPE, OMNI_DATE_TYPE, OMNI_DECIMAL128_TYPE, OMNI_DECIMAL64_TYPE, OMNI_DOUBLE_TYPE, OMNI_INTEGER_TYPE, OMNI_LONG_TYPE, OMNI_SHOR_TYPE, OMNI_VARCHAR_TYPE} import nova.hetu.omniruntime.`type`.{BooleanDataType, DataTypeSerializer, Date32DataType, Decimal128DataType, Decimal64DataType, DoubleDataType, IntDataType, LongDataType, ShortDataType, VarcharDataType} import nova.hetu.omniruntime.constants.FunctionType -import nova.hetu.omniruntime.constants.FunctionType.{OMNI_AGGREGATION_TYPE_AVG, OMNI_AGGREGATION_TYPE_COUNT_ALL, OMNI_AGGREGATION_TYPE_COUNT_COLUMN, OMNI_AGGREGATION_TYPE_FIRST_INCLUDENULL, OMNI_AGGREGATION_TYPE_FIRST_IGNORENULL, OMNI_AGGREGATION_TYPE_MAX, OMNI_AGGREGATION_TYPE_MIN, OMNI_AGGREGATION_TYPE_SUM, OMNI_WINDOW_TYPE_RANK, OMNI_WINDOW_TYPE_ROW_NUMBER} +import nova.hetu.omniruntime.constants.FunctionType.{OMNI_AGGREGATION_TYPE_AVG, OMNI_AGGREGATION_TYPE_COUNT_ALL, OMNI_AGGREGATION_TYPE_COUNT_COLUMN, OMNI_AGGREGATION_TYPE_FIRST_IGNORENULL, OMNI_AGGREGATION_TYPE_FIRST_INCLUDENULL, OMNI_AGGREGATION_TYPE_MAX, OMNI_AGGREGATION_TYPE_MIN, OMNI_AGGREGATION_TYPE_SUM, OMNI_WINDOW_TYPE_RANK, OMNI_WINDOW_TYPE_ROW_NUMBER} import nova.hetu.omniruntime.constants.JoinType._ import nova.hetu.omniruntime.operator.OmniExprVerify - import 
com.huawei.boostkit.spark.ColumnarPluginConfig import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical.Subquery +import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter} import org.apache.spark.sql.catalyst.util.CharVarcharUtils.getRawTypeString import org.apache.spark.sql.execution.ColumnarBloomFilterSubquery import org.apache.spark.sql.expression.ColumnarExpressionConverter import org.apache.spark.sql.hive.HiveUdfAdaptorUtil -import org.apache.spark.sql.types.{BooleanType, DataType, DateType, Decimal, DecimalType, DoubleType, IntegerType, LongType, Metadata, ShortType, StringType} +import org.apache.spark.sql.types.{BooleanType, DataType, DateType, Decimal, DecimalType, DoubleType, IntegerType, LongType, Metadata, NullType, ShortType, StringType, TimestampType} +import org.json.{JSONArray, JSONObject} import java.util.Locale -import scala.collection.mutable object OmniExpressionAdaptor extends Logging { @@ -53,6 +51,7 @@ object OmniExpressionAdaptor extends Logging { throw new UnsupportedOperationException(s"Unsupported expression: $expr") } } + def getExprIdMap(inputAttrs: Seq[Attribute]): Map[ExprId, Int] = { var attrMap: Map[ExprId, Int] = Map() inputAttrs.zipWithIndex.foreach { case (inputAttr, i) => @@ -75,237 +74,32 @@ object OmniExpressionAdaptor extends Logging { } } - def rewriteToOmniExpressionLiteral(expr: Expression, exprsIndexMap: Map[ExprId, Int]): String = { - expr match { - case unscaledValue: UnscaledValue => - "UnscaledValue:%s(%s, %d, %d)".format( - sparkTypeToOmniExpType(unscaledValue.dataType), - rewriteToOmniExpressionLiteral(unscaledValue.child, exprsIndexMap), - unscaledValue.child.dataType.asInstanceOf[DecimalType].precision, - unscaledValue.child.dataType.asInstanceOf[DecimalType].scale) - - // omni not support return null, now rewrite to if(IsOverflowDecimal())? 
NULL:MakeDecimal() - case checkOverflow: CheckOverflow => - ("IF:%s(IsOverflowDecimal:%s(%s,%d,%d,%d,%d), %s, MakeDecimal:%s(%s,%d,%d,%d,%d))") - .format(sparkTypeToOmniExpType(checkOverflow.dataType), - // IsOverflowDecimal returnType - sparkTypeToOmniExpType(BooleanType), - // IsOverflowDecimal arguments - rewriteToOmniExpressionLiteral(checkOverflow.child, exprsIndexMap), - checkOverflow.dataType.precision, checkOverflow.dataType.scale, - checkOverflow.dataType.precision, checkOverflow.dataType.scale, - // if_true - rewriteToOmniExpressionLiteral(Literal(null, checkOverflow.dataType), exprsIndexMap), - // if_false - sparkTypeToOmniExpJsonType(checkOverflow.dataType), - rewriteToOmniExpressionLiteral(checkOverflow.child, exprsIndexMap), - checkOverflow.dataType.precision, checkOverflow.dataType.scale, - checkOverflow.dataType.precision, checkOverflow.dataType.scale) - - case makeDecimal: MakeDecimal => - makeDecimal.child.dataType match { - case decimalChild: DecimalType => - ("MakeDecimal:%s(%s,%s,%s,%s,%s)") - .format(sparkTypeToOmniExpJsonType(makeDecimal.dataType), - rewriteToOmniExpressionLiteral(makeDecimal.child, exprsIndexMap), - decimalChild.precision, decimalChild.scale, - makeDecimal.precision, makeDecimal.scale) - case longChild: LongType => - ("MakeDecimal:%s(%s,%s,%s)") - .format(sparkTypeToOmniExpJsonType(makeDecimal.dataType), - rewriteToOmniExpressionLiteral(makeDecimal.child, exprsIndexMap), - makeDecimal.precision, makeDecimal.scale) - case _ => - throw new UnsupportedOperationException(s"Unsupported datatype for MakeDecimal: ${makeDecimal.child.dataType}") - } - - case promotePrecision: PromotePrecision => - rewriteToOmniExpressionLiteral(promotePrecision.child, exprsIndexMap) - - case sub: Subtract => - "$operator$SUBTRACT:%s(%s,%s)".format( - sparkTypeToOmniExpType(sub.dataType), - rewriteToOmniExpressionLiteral(sub.left, exprsIndexMap), - rewriteToOmniExpressionLiteral(sub.right, exprsIndexMap)) - - case add: Add => - "$operator$ADD:%s(%s,%s)".format( - sparkTypeToOmniExpType(add.dataType), - rewriteToOmniExpressionLiteral(add.left, exprsIndexMap), - rewriteToOmniExpressionLiteral(add.right, exprsIndexMap)) - - case mult: Multiply => - "$operator$MULTIPLY:%s(%s,%s)".format( - sparkTypeToOmniExpType(mult.dataType), - rewriteToOmniExpressionLiteral(mult.left, exprsIndexMap), - rewriteToOmniExpressionLiteral(mult.right, exprsIndexMap)) - - case divide: Divide => - "$operator$DIVIDE:%s(%s,%s)".format( - sparkTypeToOmniExpType(divide.dataType), - rewriteToOmniExpressionLiteral(divide.left, exprsIndexMap), - rewriteToOmniExpressionLiteral(divide.right, exprsIndexMap)) - - case mod: Remainder => - "$operator$MODULUS:%s(%s,%s)".format( - sparkTypeToOmniExpType(mod.dataType), - rewriteToOmniExpressionLiteral(mod.left, exprsIndexMap), - rewriteToOmniExpressionLiteral(mod.right, exprsIndexMap)) - - case greaterThan: GreaterThan => - "$operator$GREATER_THAN:%s(%s,%s)".format( - sparkTypeToOmniExpType(greaterThan.dataType), - rewriteToOmniExpressionLiteral(greaterThan.left, exprsIndexMap), - rewriteToOmniExpressionLiteral(greaterThan.right, exprsIndexMap)) - - case greaterThanOrEq: GreaterThanOrEqual => - "$operator$GREATER_THAN_OR_EQUAL:%s(%s,%s)".format( - sparkTypeToOmniExpType(greaterThanOrEq.dataType), - rewriteToOmniExpressionLiteral(greaterThanOrEq.left, exprsIndexMap), - rewriteToOmniExpressionLiteral(greaterThanOrEq.right, exprsIndexMap)) - - case lessThan: LessThan => - "$operator$LESS_THAN:%s(%s,%s)".format( - sparkTypeToOmniExpType(lessThan.dataType), - 
rewriteToOmniExpressionLiteral(lessThan.left, exprsIndexMap), - rewriteToOmniExpressionLiteral(lessThan.right, exprsIndexMap)) - - case lessThanOrEq: LessThanOrEqual => - "$operator$LESS_THAN_OR_EQUAL:%s(%s,%s)".format( - sparkTypeToOmniExpType(lessThanOrEq.dataType), - rewriteToOmniExpressionLiteral(lessThanOrEq.left, exprsIndexMap), - rewriteToOmniExpressionLiteral(lessThanOrEq.right, exprsIndexMap)) - - case equal: EqualTo => - "$operator$EQUAL:%s(%s,%s)".format( - sparkTypeToOmniExpType(equal.dataType), - rewriteToOmniExpressionLiteral(equal.left, exprsIndexMap), - rewriteToOmniExpressionLiteral(equal.right, exprsIndexMap)) - - case or: Or => - "OR:%s(%s,%s)".format( - sparkTypeToOmniExpType(or.dataType), - rewriteToOmniExpressionLiteral(or.left, exprsIndexMap), - rewriteToOmniExpressionLiteral(or.right, exprsIndexMap)) - - case and: And => - "AND:%s(%s,%s)".format( - sparkTypeToOmniExpType(and.dataType), - rewriteToOmniExpressionLiteral(and.left, exprsIndexMap), - rewriteToOmniExpressionLiteral(and.right, exprsIndexMap)) - - case alias: Alias => rewriteToOmniExpressionLiteral(alias.child, exprsIndexMap) - case literal: Literal => toOmniLiteral(literal) - case not: Not => - "not:%s(%s)".format( - sparkTypeToOmniExpType(BooleanType), - rewriteToOmniExpressionLiteral(not.child, exprsIndexMap)) - case isnotnull: IsNotNull => - "IS_NOT_NULL:%s(%s)".format( - sparkTypeToOmniExpType(BooleanType), - rewriteToOmniExpressionLiteral(isnotnull.child, exprsIndexMap)) - // Substring - case subString: Substring => - "substr:%s(%s,%s,%s)".format( - sparkTypeToOmniExpType(subString.dataType), - rewriteToOmniExpressionLiteral(subString.str, exprsIndexMap), - rewriteToOmniExpressionLiteral(subString.pos, exprsIndexMap), - rewriteToOmniExpressionLiteral(subString.len, exprsIndexMap)) - // Cast - case cast: Cast => - unsupportedCastCheck(expr, cast) - "CAST:%s(%s)".format( - sparkTypeToOmniExpType(cast.dataType), - rewriteToOmniExpressionLiteral(cast.child, exprsIndexMap)) - // Abs - case abs: Abs => - "abs:%s(%s)".format( - sparkTypeToOmniExpType(abs.dataType), - rewriteToOmniExpressionLiteral(abs.child, exprsIndexMap)) - // In - case in: In => - "IN:%s(%s)".format( - sparkTypeToOmniExpType(in.dataType), - in.children.map(child => rewriteToOmniExpressionLiteral(child, exprsIndexMap)) - .mkString(",")) - // coming from In expression with optimizerInSetConversionThreshold - case inSet: InSet => - "IN:%s(%s,%s)".format( - sparkTypeToOmniExpType(inSet.dataType), - rewriteToOmniExpressionLiteral(inSet.child, exprsIndexMap), - inSet.set.map(child => toOmniLiteral( - Literal(child, inSet.child.dataType))).mkString(",")) - // only support with one case condition, for omni rewrite to if(A, B, C) - case caseWhen: CaseWhen => - "IF:%s(%s, %s, %s)".format( - sparkTypeToOmniExpType(caseWhen.dataType), - rewriteToOmniExpressionLiteral(caseWhen.branches(0)._1, exprsIndexMap), - rewriteToOmniExpressionLiteral(caseWhen.branches(0)._2, exprsIndexMap), - rewriteToOmniExpressionLiteral(caseWhen.elseValue.get, exprsIndexMap)) - // Sum - case sum: Sum => - "SUM:%s(%s)".format( - sparkTypeToOmniExpType(sum.dataType), - sum.children.map(child => rewriteToOmniExpressionLiteral(child, exprsIndexMap)) - .mkString(",")) - // Max - case max: Max => - "MAX:%s(%s)".format( - sparkTypeToOmniExpType(max.dataType), - max.children.map(child => rewriteToOmniExpressionLiteral(child, exprsIndexMap)) - .mkString(",")) - // Average - case avg: Average => - "AVG:%s(%s)".format( - sparkTypeToOmniExpType(avg.dataType), - avg.children.map(child => 
rewriteToOmniExpressionLiteral(child, exprsIndexMap)) - .mkString(",")) - // Min - case min: Min => - "MIN:%s(%s)".format( - sparkTypeToOmniExpType(min.dataType), - min.children.map(child => rewriteToOmniExpressionLiteral(child, exprsIndexMap)) - .mkString(",")) - - case coalesce: Coalesce => - "COALESCE:%s(%s)".format( - sparkTypeToOmniExpType(coalesce.dataType), - coalesce.children.map(child => rewriteToOmniExpressionLiteral(child, exprsIndexMap)) - .mkString(",")) - - case concat: Concat => - getConcatStr(concat, exprsIndexMap) - - case attr: Attribute => s"#${exprsIndexMap(attr.exprId).toString}" - case _ => - throw new UnsupportedOperationException(s"Unsupported expression: $expr") + private def unsupportedCastCheck(expr: Expression, cast: Cast): Unit = { + def doSupportCastToString(dataType: DataType): Boolean = { + if (dataType.isInstanceOf[DecimalType] || dataType.isInstanceOf[StringType] || dataType.isInstanceOf[IntegerType] + || dataType.isInstanceOf[LongType]) { + true + } else { + false + } } - } - private def getConcatStr(concat: Concat, exprsIndexMap: Map[ExprId, Int]): String = { - val child: Seq[Expression] = concat.children - checkInputDataTypes(child) - val template = "concat:%s(%s,%s)" - val omniType = sparkTypeToOmniExpType(concat.dataType) - if (child.length == 1) { - return rewriteToOmniExpressionLiteral(child.head, exprsIndexMap) + def doSupportCastFromString(dataType: DataType): Boolean = { + if (dataType.isInstanceOf[DecimalType] || dataType.isInstanceOf[StringType] || dataType.isInstanceOf[DateType] + || dataType.isInstanceOf[IntegerType] || dataType.isInstanceOf[LongType]) { + true + } else { + false + } } - // (a, b, c) => concat(concat(a,b),c) - var res = template.format(omniType, - rewriteToOmniExpressionLiteral(child.head, exprsIndexMap), - rewriteToOmniExpressionLiteral(child(1), exprsIndexMap)) - for (i <- 2 until child.length) { - res = template.format(omniType, res, - rewriteToOmniExpressionLiteral(child(i), exprsIndexMap)) + + // support cast(decimal/string/int/long as string) + if (cast.dataType.isInstanceOf[StringType] && !doSupportCastToString(cast.child.dataType)) { + throw new UnsupportedOperationException(s"Unsupported expression: $expr") } - res - } - private def unsupportedCastCheck(expr: Expression, cast: Cast): Unit = { - def isDecimalOrStringType(dataType: DataType): Boolean = (dataType.isInstanceOf[DecimalType]) || (dataType.isInstanceOf[StringType] || (dataType.isInstanceOf[DateType])) - // not support Cast(string as !(decimal/string)) and Cast(!(decimal/string) as string) - if ((cast.dataType.isInstanceOf[StringType] && !isDecimalOrStringType(cast.child.dataType)) || - (!isDecimalOrStringType(cast.dataType) && cast.child.dataType.isInstanceOf[StringType])) { + // support cast(string as decimal/string/date/int/long), matching doSupportCastFromString + if (!doSupportCastFromString(cast.dataType) && cast.child.dataType.isInstanceOf[StringType]) { throw new UnsupportedOperationException(s"Unsupported expression: $expr") } @@ -315,281 +109,297 @@ object OmniExpressionAdaptor extends Logging { } } - def toOmniLiteral(literal: Literal): String = { - val omniType = sparkTypeToOmniExpType(literal.dataType) - literal.dataType match { - case null => s"null:${omniType}" - case StringType => s"\'${literal.toString}\':${omniType}" - case _ => literal.toString + s":${omniType}" - } - } - def rewriteToOmniJsonExpressionLiteral(expr: Expression, - exprsIndexMap: Map[ExprId, Int]): String = { + exprsIndexMap: Map[ExprId, Int]): String = { rewriteToOmniJsonExpressionLiteral(expr, 
exprsIndexMap, expr.dataType) } def rewriteToOmniJsonExpressionLiteral(expr: Expression, exprsIndexMap: Map[ExprId, Int], returnDatatype: DataType): String = { + rewriteToOmniJsonExpressionLiteralJsonObject(expr, exprsIndexMap, returnDatatype).toString + } + + private def rewriteToOmniJsonExpressionLiteralJsonObject(expr: Expression, + exprsIndexMap: Map[ExprId, Int]): JSONObject = { + rewriteToOmniJsonExpressionLiteralJsonObject(expr, exprsIndexMap, expr.dataType) + } + + private def rewriteToOmniJsonExpressionLiteralJsonObject(expr: Expression, + exprsIndexMap: Map[ExprId, Int], + returnDatatype: DataType): JSONObject = { expr match { case unscaledValue: UnscaledValue => - ("{\"exprType\":\"FUNCTION\",\"returnType\":%s," + - "\"function_name\":\"UnscaledValue\", \"arguments\":[%s]}") - .format(sparkTypeToOmniExpJsonType(unscaledValue.dataType), - rewriteToOmniJsonExpressionLiteral(unscaledValue.child, exprsIndexMap)) - + new JSONObject().put("exprType", "FUNCTION") + .addOmniExpJsonType("returnType", unscaledValue.dataType) + .put("function_name", "UnscaledValue") + .put("arguments", new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(unscaledValue.child, exprsIndexMap))) case checkOverflow: CheckOverflow => - rewriteToOmniJsonExpressionLiteral(checkOverflow.child, exprsIndexMap, returnDatatype) + rewriteToOmniJsonExpressionLiteralJsonObject(checkOverflow.child, exprsIndexMap, returnDatatype) case makeDecimal: MakeDecimal => makeDecimal.child.dataType match { case decimalChild: DecimalType => - ("{\"exprType\": \"FUNCTION\", \"returnType\":%s," + - "\"function_name\": \"MakeDecimal\", \"arguments\": [%s]}") - .format(sparkTypeToOmniExpJsonType(makeDecimal.dataType), - rewriteToOmniJsonExpressionLiteral(makeDecimal.child, exprsIndexMap)) - + new JSONObject().put("exprType", "FUNCTION") + .addOmniExpJsonType("returnType", makeDecimal.dataType) + .put("function_name", "MakeDecimal") + .put("arguments", new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(makeDecimal.child, exprsIndexMap))) case longChild: LongType => - ("{\"exprType\": \"FUNCTION\", \"returnType\":%s," + - "\"function_name\": \"MakeDecimal\", \"arguments\": [%s]}") - .format(sparkTypeToOmniExpJsonType(makeDecimal.dataType), - rewriteToOmniJsonExpressionLiteral(makeDecimal.child, exprsIndexMap)) + new JSONObject().put("exprType", "FUNCTION") + .put("function_name", "MakeDecimal") + .addOmniExpJsonType("returnType", makeDecimal.dataType) + .put("arguments", new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(makeDecimal.child, exprsIndexMap))) case _ => throw new UnsupportedOperationException(s"Unsupported datatype for MakeDecimal: ${makeDecimal.child.dataType}") } case promotePrecision: PromotePrecision => - rewriteToOmniJsonExpressionLiteral(promotePrecision.child, exprsIndexMap) + rewriteToOmniJsonExpressionLiteralJsonObject(promotePrecision.child, exprsIndexMap) case sub: Subtract => - ("{\"exprType\":\"BINARY\",\"returnType\":%s," + - "\"operator\":\"SUBTRACT\",\"left\":%s,\"right\":%s}").format( - sparkTypeToOmniExpJsonType(returnDatatype), - rewriteToOmniJsonExpressionLiteral(sub.left, exprsIndexMap), - rewriteToOmniJsonExpressionLiteral(sub.right, exprsIndexMap)) + new JSONObject().put("exprType", "BINARY") + .addOmniExpJsonType("returnType", returnDatatype) + .put("operator", "SUBTRACT") + .put("left", rewriteToOmniJsonExpressionLiteralJsonObject(sub.left, exprsIndexMap)) + .put("right", rewriteToOmniJsonExpressionLiteralJsonObject(sub.right, exprsIndexMap)) case add: Add => - 
("{\"exprType\":\"BINARY\",\"returnType\":%s," + - "\"operator\":\"ADD\",\"left\":%s,\"right\":%s}").format( - sparkTypeToOmniExpJsonType(returnDatatype), - rewriteToOmniJsonExpressionLiteral(add.left, exprsIndexMap), - rewriteToOmniJsonExpressionLiteral(add.right, exprsIndexMap)) + new JSONObject().put("exprType", "BINARY") + .addOmniExpJsonType("returnType", returnDatatype) + .put("operator", "ADD") + .put("left", rewriteToOmniJsonExpressionLiteralJsonObject(add.left, exprsIndexMap)) + .put("right", rewriteToOmniJsonExpressionLiteralJsonObject(add.right, exprsIndexMap)) case mult: Multiply => - ("{\"exprType\":\"BINARY\",\"returnType\":%s," + - "\"operator\":\"MULTIPLY\",\"left\":%s,\"right\":%s}").format( - sparkTypeToOmniExpJsonType(returnDatatype), - rewriteToOmniJsonExpressionLiteral(mult.left, exprsIndexMap), - rewriteToOmniJsonExpressionLiteral(mult.right, exprsIndexMap)) + new JSONObject().put("exprType", "BINARY") + .addOmniExpJsonType("returnType", returnDatatype) + .put("operator", "MULTIPLY") + .put("left", rewriteToOmniJsonExpressionLiteralJsonObject(mult.left, exprsIndexMap)) + .put("right", rewriteToOmniJsonExpressionLiteralJsonObject(mult.right, exprsIndexMap)) case divide: Divide => - ("{\"exprType\":\"BINARY\",\"returnType\":%s," + - "\"operator\":\"DIVIDE\",\"left\":%s,\"right\":%s}").format( - sparkTypeToOmniExpJsonType(returnDatatype), - rewriteToOmniJsonExpressionLiteral(divide.left, exprsIndexMap), - rewriteToOmniJsonExpressionLiteral(divide.right, exprsIndexMap)) + new JSONObject().put("exprType", "BINARY") + .addOmniExpJsonType("returnType", returnDatatype) + .put("operator", "DIVIDE") + .put("left", rewriteToOmniJsonExpressionLiteralJsonObject(divide.left, exprsIndexMap)) + .put("right", rewriteToOmniJsonExpressionLiteralJsonObject(divide.right, exprsIndexMap)) case mod: Remainder => - ("{\"exprType\":\"BINARY\",\"returnType\":%s," + - "\"operator\":\"MODULUS\",\"left\":%s,\"right\":%s}").format( - sparkTypeToOmniExpJsonType(returnDatatype), - rewriteToOmniJsonExpressionLiteral(mod.left, exprsIndexMap), - rewriteToOmniJsonExpressionLiteral(mod.right, exprsIndexMap)) + new JSONObject().put("exprType", "BINARY") + .addOmniExpJsonType("returnType", returnDatatype) + .put("operator", "MODULUS") + .put("left", rewriteToOmniJsonExpressionLiteralJsonObject(mod.left, exprsIndexMap)) + .put("right", rewriteToOmniJsonExpressionLiteralJsonObject(mod.right, exprsIndexMap)) case greaterThan: GreaterThan => - ("{\"exprType\":\"BINARY\",\"returnType\":%s," + - "\"operator\":\"GREATER_THAN\",\"left\":%s,\"right\":%s}").format( - sparkTypeToOmniExpJsonType(greaterThan.dataType), - rewriteToOmniJsonExpressionLiteral(greaterThan.left, exprsIndexMap), - rewriteToOmniJsonExpressionLiteral(greaterThan.right, exprsIndexMap)) + new JSONObject().put("exprType", "BINARY") + .addOmniExpJsonType("returnType", greaterThan.dataType) + .put("operator", "GREATER_THAN") + .put("left", rewriteToOmniJsonExpressionLiteralJsonObject(greaterThan.left, exprsIndexMap)) + .put("right", rewriteToOmniJsonExpressionLiteralJsonObject(greaterThan.right, exprsIndexMap)) case greaterThanOrEq: GreaterThanOrEqual => - ("{\"exprType\":\"BINARY\",\"returnType\":%s," + - "\"operator\":\"GREATER_THAN_OR_EQUAL\",\"left\":%s,\"right\":%s}").format( - sparkTypeToOmniExpJsonType(greaterThanOrEq.dataType), - rewriteToOmniJsonExpressionLiteral(greaterThanOrEq.left, exprsIndexMap), - rewriteToOmniJsonExpressionLiteral(greaterThanOrEq.right, exprsIndexMap)) + new JSONObject().put("exprType", "BINARY") + 
.addOmniExpJsonType("returnType", greaterThanOrEq.dataType) + .put("operator", "GREATER_THAN_OR_EQUAL") + .put("left", rewriteToOmniJsonExpressionLiteralJsonObject(greaterThanOrEq.left, exprsIndexMap)) + .put("right", rewriteToOmniJsonExpressionLiteralJsonObject(greaterThanOrEq.right, exprsIndexMap)) case lessThan: LessThan => - ("{\"exprType\":\"BINARY\",\"returnType\":%s," + - "\"operator\":\"LESS_THAN\",\"left\":%s,\"right\":%s}").format( - sparkTypeToOmniExpJsonType(lessThan.dataType), - rewriteToOmniJsonExpressionLiteral(lessThan.left, exprsIndexMap), - rewriteToOmniJsonExpressionLiteral(lessThan.right, exprsIndexMap)) + new JSONObject().put("exprType", "BINARY") + .addOmniExpJsonType("returnType", lessThan.dataType) + .put("operator", "LESS_THAN") + .put("left", rewriteToOmniJsonExpressionLiteralJsonObject(lessThan.left, exprsIndexMap)) + .put("right", rewriteToOmniJsonExpressionLiteralJsonObject(lessThan.right, exprsIndexMap)) case lessThanOrEq: LessThanOrEqual => - ("{\"exprType\":\"BINARY\",\"returnType\":%s," + - "\"operator\":\"LESS_THAN_OR_EQUAL\",\"left\":%s,\"right\":%s}").format( - sparkTypeToOmniExpJsonType(lessThanOrEq.dataType), - rewriteToOmniJsonExpressionLiteral(lessThanOrEq.left, exprsIndexMap), - rewriteToOmniJsonExpressionLiteral(lessThanOrEq.right, exprsIndexMap)) + new JSONObject().put("exprType", "BINARY") + .addOmniExpJsonType("returnType", lessThanOrEq.dataType) + .put("operator", "LESS_THAN_OR_EQUAL") + .put("left", rewriteToOmniJsonExpressionLiteralJsonObject(lessThanOrEq.left, exprsIndexMap)) + .put("right", rewriteToOmniJsonExpressionLiteralJsonObject(lessThanOrEq.right, exprsIndexMap)) case equal: EqualTo => - ("{\"exprType\":\"BINARY\",\"returnType\":%s," + - "\"operator\":\"EQUAL\",\"left\":%s,\"right\":%s}").format( - sparkTypeToOmniExpJsonType(equal.dataType), - rewriteToOmniJsonExpressionLiteral(equal.left, exprsIndexMap), - rewriteToOmniJsonExpressionLiteral(equal.right, exprsIndexMap)) + new JSONObject().put("exprType", "BINARY") + .addOmniExpJsonType("returnType", equal.dataType) + .put("operator", "EQUAL") + .put("left", rewriteToOmniJsonExpressionLiteralJsonObject(equal.left, exprsIndexMap)) + .put("right", rewriteToOmniJsonExpressionLiteralJsonObject(equal.right, exprsIndexMap)) case or: Or => - ("{\"exprType\":\"BINARY\",\"returnType\":%s," + - "\"operator\":\"OR\",\"left\":%s,\"right\":%s}").format( - sparkTypeToOmniExpJsonType(or.dataType), - rewriteToOmniJsonExpressionLiteral(or.left, exprsIndexMap), - rewriteToOmniJsonExpressionLiteral(or.right, exprsIndexMap)) + new JSONObject().put("exprType", "BINARY") + .addOmniExpJsonType("returnType", or.dataType) + .put("operator", "OR") + .put("left", rewriteToOmniJsonExpressionLiteralJsonObject(or.left, exprsIndexMap)) + .put("right", rewriteToOmniJsonExpressionLiteralJsonObject(or.right, exprsIndexMap)) case and: And => - ("{\"exprType\":\"BINARY\",\"returnType\":%s," + - "\"operator\":\"AND\",\"left\":%s,\"right\":%s}").format( - sparkTypeToOmniExpJsonType(and.dataType), - rewriteToOmniJsonExpressionLiteral(and.left, exprsIndexMap), - rewriteToOmniJsonExpressionLiteral(and.right, exprsIndexMap)) + new JSONObject().put("exprType", "BINARY") + .addOmniExpJsonType("returnType", and.dataType) + .put("operator", "AND") + .put("left", rewriteToOmniJsonExpressionLiteralJsonObject(and.left, exprsIndexMap)) + .put("right", rewriteToOmniJsonExpressionLiteralJsonObject(and.right, exprsIndexMap)) - case alias: Alias => rewriteToOmniJsonExpressionLiteral(alias.child, exprsIndexMap) + case alias: Alias => 
rewriteToOmniJsonExpressionLiteralJsonObject(alias.child, exprsIndexMap) case literal: Literal => toOmniJsonLiteral(literal) case not: Not => not.child match { case isnull: IsNull => - "{\"exprType\":\"UNARY\",\"returnType\":%s,\"operator\":\"not\",\"expr\":%s}".format( - sparkTypeToOmniExpJsonType(BooleanType), - rewriteToOmniJsonExpressionLiteral(isnull, exprsIndexMap)) + new JSONObject().put("exprType", "UNARY") + .addOmniExpJsonType("returnType", BooleanType) + .put("operator", "not") + .put("expr", rewriteToOmniJsonExpressionLiteralJsonObject(isnull, exprsIndexMap)) + case equal: EqualTo => - ("{\"exprType\":\"BINARY\",\"returnType\":%s," + - "\"operator\":\"NOT_EQUAL\",\"left\":%s,\"right\":%s}").format( - sparkTypeToOmniExpJsonType(equal.dataType), - rewriteToOmniJsonExpressionLiteral(equal.left, exprsIndexMap), - rewriteToOmniJsonExpressionLiteral(equal.right, exprsIndexMap)) - case _ => throw new UnsupportedOperationException(s"Unsupported expression: $expr") + new JSONObject().put("exprType", "BINARY") + .addOmniExpJsonType("returnType", equal.dataType) + .put("operator", "NOT_EQUAL") + .put("left", rewriteToOmniJsonExpressionLiteralJsonObject(equal.left, exprsIndexMap)) + .put("right", rewriteToOmniJsonExpressionLiteralJsonObject(equal.right, exprsIndexMap)) + + case _ => + new JSONObject().put("exprType", "UNARY") + .addOmniExpJsonType("returnType", BooleanType) + .put("operator", "not") + .put("expr", rewriteToOmniJsonExpressionLiteralJsonObject(not.child, exprsIndexMap)) } case isnotnull: IsNotNull => - ("{\"exprType\":\"UNARY\",\"returnType\":%s, \"operator\":\"not\"," - + "\"expr\":{\"exprType\":\"IS_NULL\",\"returnType\":%s," - + "\"arguments\":[%s]}}").format(sparkTypeToOmniExpJsonType(BooleanType), - sparkTypeToOmniExpJsonType(BooleanType), - rewriteToOmniJsonExpressionLiteral(isnotnull.child, exprsIndexMap)) + new JSONObject().put("exprType", "UNARY") + .addOmniExpJsonType("returnType", BooleanType) + .put("operator", "not") + .put("expr", new JSONObject() + .put("exprType", "IS_NULL") + .addOmniExpJsonType("returnType", BooleanType) + .put("arguments", new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(isnotnull.child, exprsIndexMap)))) case isNull: IsNull => - "{\"exprType\":\"IS_NULL\",\"returnType\":%s,\"arguments\":[%s]}".format( - sparkTypeToOmniExpJsonType(BooleanType), - rewriteToOmniJsonExpressionLiteral(isNull.child, exprsIndexMap)) + new JSONObject().put("exprType", "IS_NULL") + .put("arguments", new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(isNull.child, exprsIndexMap))) + .addOmniExpJsonType("returnType", BooleanType) // Substring case subString: Substring => - ("{\"exprType\":\"FUNCTION\",\"returnType\":%s," + - "\"function_name\":\"substr\", \"arguments\":[%s,%s,%s]}") - .format(sparkTypeToOmniExpJsonType(subString.dataType), - rewriteToOmniJsonExpressionLiteral(subString.str, exprsIndexMap), - rewriteToOmniJsonExpressionLiteral(subString.pos, exprsIndexMap), - rewriteToOmniJsonExpressionLiteral(subString.len, exprsIndexMap)) + new JSONObject().put("exprType", "FUNCTION") + .addOmniExpJsonType("returnType", subString.dataType) + .put("function_name", "substr") + .put("arguments", new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(subString.str, exprsIndexMap)). 
+ put(rewriteToOmniJsonExpressionLiteralJsonObject(subString.pos, exprsIndexMap)) + .put(rewriteToOmniJsonExpressionLiteralJsonObject(subString.len, exprsIndexMap))) // Cast case cast: Cast => unsupportedCastCheck(expr, cast) - val returnType = sparkTypeToOmniExpJsonType(cast.dataType) cast.dataType match { case StringType => - ("{\"exprType\":\"FUNCTION\",\"returnType\":%s," + - "\"width\":50,\"function_name\":\"CAST\", \"arguments\":[%s]}") - .format(returnType, rewriteToOmniJsonExpressionLiteral(cast.child, exprsIndexMap)) + new JSONObject().put("exprType", "FUNCTION") + .addOmniExpJsonType("returnType", cast.dataType) + .put("width", 50) + .put("function_name", "CAST") + .put("arguments", new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(cast.child, exprsIndexMap))) + case _ => - ("{\"exprType\":\"FUNCTION\",\"returnType\":%s," + - "\"function_name\":\"CAST\",\"arguments\":[%s]}") - .format(returnType, rewriteToOmniJsonExpressionLiteral(cast.child, exprsIndexMap)) + new JSONObject().put("exprType", "FUNCTION") + .addOmniExpJsonType("returnType", cast.dataType) + .put("function_name", "CAST") + .put("arguments", new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(cast.child, exprsIndexMap))) + } // Abs case abs: Abs => - "{\"exprType\":\"FUNCTION\",\"returnType\":%s,\"function_name\":\"abs\", \"arguments\":[%s]}" - .format(sparkTypeToOmniExpJsonType(abs.dataType), - rewriteToOmniJsonExpressionLiteral(abs.child, exprsIndexMap)) + new JSONObject().put("exprType", "FUNCTION") + .addOmniExpJsonType("returnType", abs.dataType) + .put("function_name", "abs") + .put("arguments", new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(abs.child, exprsIndexMap))) case lower: Lower => - "{\"exprType\":\"FUNCTION\",\"returnType\":%s,\"function_name\":\"lower\", \"arguments\":[%s]}" - .format(sparkTypeToOmniExpJsonType(lower.dataType), - rewriteToOmniJsonExpressionLiteral(lower.child, exprsIndexMap)) + new JSONObject().put("exprType", "FUNCTION") + .addOmniExpJsonType("returnType", lower.dataType) + .put("function_name", "lower") + .put("arguments", new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(lower.child, exprsIndexMap))) case upper: Upper => - "{\"exprType\":\"FUNCTION\",\"returnType\":%s,\"function_name\":\"upper\", \"arguments\":[%s]}" - .format(sparkTypeToOmniExpJsonType(upper.dataType), - rewriteToOmniJsonExpressionLiteral(upper.child, exprsIndexMap)) + new JSONObject().put("exprType", "FUNCTION") + .addOmniExpJsonType("returnType", upper.dataType) + .put("function_name", "upper") + .put("arguments", new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(upper.child, exprsIndexMap))) case length: Length => - "{\"exprType\":\"FUNCTION\",\"returnType\":%s,\"function_name\":\"length\", \"arguments\":[%s]}" - .format(sparkTypeToOmniExpJsonType(length.dataType), - rewriteToOmniJsonExpressionLiteral(length.child, exprsIndexMap)) + new JSONObject().put("exprType", "FUNCTION") + .addOmniExpJsonType("returnType", length.dataType) + .put("function_name", "length") + .put("arguments", new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(length.child, exprsIndexMap))) case replace: StringReplace => - "{\"exprType\":\"FUNCTION\",\"returnType\":%s,\"function_name\":\"replace\", \"arguments\":[%s,%s,%s]}" - .format(sparkTypeToOmniExpJsonType(replace.dataType), - rewriteToOmniJsonExpressionLiteral(replace.srcExpr, exprsIndexMap), - rewriteToOmniJsonExpressionLiteral(replace.searchExpr, exprsIndexMap), - 
rewriteToOmniJsonExpressionLiteral(replace.replaceExpr, exprsIndexMap)) + new JSONObject().put("exprType", "FUNCTION") + .addOmniExpJsonType("returnType", replace.dataType) + .put("function_name", "replace") + .put("arguments", new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(replace.srcExpr, exprsIndexMap)) + .put(rewriteToOmniJsonExpressionLiteralJsonObject(replace.searchExpr, exprsIndexMap)) + .put(rewriteToOmniJsonExpressionLiteralJsonObject(replace.replaceExpr, exprsIndexMap))) // In case in: In => - "{\"exprType\":\"IN\",\"returnType\":%s, \"arguments\":%s}".format( - sparkTypeToOmniExpJsonType(in.dataType), - in.children.map(child => rewriteToOmniJsonExpressionLiteral(child, exprsIndexMap)) - .mkString("[", ",", "]")) + new JSONObject().put("exprType", "IN") + .addOmniExpJsonType("returnType", in.dataType) + .put("arguments", new JSONArray(in.children.map(child => rewriteToOmniJsonExpressionLiteralJsonObject(child, exprsIndexMap)).toArray)) // coming from In expression with optimizerInSetConversionThreshold case inSet: InSet => - "{\"exprType\":\"IN\",\"returnType\":%s, \"arguments\":[%s, %s]}" - .format(sparkTypeToOmniExpJsonType(inSet.dataType), - rewriteToOmniJsonExpressionLiteral(inSet.child, exprsIndexMap), - inSet.set.map(child => - toOmniJsonLiteral(Literal(child, inSet.child.dataType))).mkString(",")) + val jSONObject = new JSONObject().put("exprType", "IN") + .addOmniExpJsonType("returnType", inSet.dataType) + val jsonArray = new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(inSet.child, exprsIndexMap)) + inSet.set.foreach(child => jsonArray.put(toOmniJsonLiteral(Literal(child, inSet.child.dataType)))) + jSONObject.put("arguments", jsonArray) + jSONObject case ifExp: If => - "{\"exprType\":\"IF\",\"returnType\":%s,\"condition\":%s,\"if_true\":%s,\"if_false\":%s}" - .format(sparkTypeToOmniExpJsonType(ifExp.dataType), - rewriteToOmniJsonExpressionLiteral(ifExp.predicate, exprsIndexMap), - rewriteToOmniJsonExpressionLiteral(ifExp.trueValue, exprsIndexMap), - rewriteToOmniJsonExpressionLiteral(ifExp.falseValue, exprsIndexMap)) + new JSONObject().put("exprType", "IF") + .addOmniExpJsonType("returnType", ifExp.dataType) + .put("condition", rewriteToOmniJsonExpressionLiteralJsonObject(ifExp.predicate, exprsIndexMap)) + .put("if_true", rewriteToOmniJsonExpressionLiteralJsonObject(ifExp.trueValue, exprsIndexMap)) + .put("if_false", rewriteToOmniJsonExpressionLiteralJsonObject(ifExp.falseValue, exprsIndexMap)) case caseWhen: CaseWhen => procCaseWhenExpression(caseWhen, exprsIndexMap) case coalesce: Coalesce => if (coalesce.children.length > 2) { - throw new UnsupportedOperationException(s"Number of parameters is ${coalesce.children.length}. Exceeds the maximum number of parameters, coalesce only supports up to 2 parameters") + throw new UnsupportedOperationException(s"Number of parameters is ${coalesce.children.length}. 
Exceeds the maximum number of parameters, coalesce only supports up to 2 parameters") } - "{\"exprType\":\"COALESCE\",\"returnType\":%s, \"value1\":%s,\"value2\":%s}".format( - sparkTypeToOmniExpJsonType(coalesce.dataType), - rewriteToOmniJsonExpressionLiteral(coalesce.children(0), exprsIndexMap), - rewriteToOmniJsonExpressionLiteral(coalesce.children(1), exprsIndexMap)) + new JSONObject().put("exprType", "COALESCE") + .addOmniExpJsonType("returnType", coalesce.dataType) + .put("value1", rewriteToOmniJsonExpressionLiteralJsonObject(coalesce.children.head, exprsIndexMap)) + .put("value2", rewriteToOmniJsonExpressionLiteralJsonObject(coalesce.children(1), exprsIndexMap)) case concat: Concat => getConcatJsonStr(concat, exprsIndexMap) case round: Round => - "{\"exprType\":\"FUNCTION\",\"returnType\":%s,\"function_name\":\"round\", \"arguments\":[%s,%s]}" - .format(sparkTypeToOmniExpJsonType(round.dataType), - rewriteToOmniJsonExpressionLiteral(round.child, exprsIndexMap), - rewriteToOmniJsonExpressionLiteral(round.scale, exprsIndexMap)) + new JSONObject().put("exprType", "FUNCTION") + .addOmniExpJsonType("returnType", round.dataType) + .put("function_name", "round") + .put("arguments", new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(round.child, exprsIndexMap)) + .put(rewriteToOmniJsonExpressionLiteralJsonObject(round.scale, exprsIndexMap))) case attr: Attribute => toOmniJsonAttribute(attr, exprsIndexMap(attr.exprId)) // might_contain case bloomFilterMightContain: BloomFilterMightContain => - ("{\"exprType\":\"FUNCTION\",\"returnType\":%s," + - "\"function_name\":\"might_contain\", \"arguments\":[%s,%s]}") - .format(sparkTypeToOmniExpJsonType(bloomFilterMightContain.dataType), - rewriteToOmniJsonExpressionLiteral( - ColumnarExpressionConverter.replaceWithColumnarExpression(bloomFilterMightContain.bloomFilterExpression), - exprsIndexMap - ), - rewriteToOmniJsonExpressionLiteral(bloomFilterMightContain.valueExpression, exprsIndexMap, returnDatatype)) + new JSONObject().put("exprType", "FUNCTION") + .addOmniExpJsonType("returnType", bloomFilterMightContain.dataType) + .put("function_name", "might_contain") + .put("arguments", new JSONArray() + .put(rewriteToOmniJsonExpressionLiteralJsonObject( + ColumnarExpressionConverter.replaceWithColumnarExpression(bloomFilterMightContain.bloomFilterExpression), + exprsIndexMap)) + .put(rewriteToOmniJsonExpressionLiteralJsonObject(bloomFilterMightContain.valueExpression, exprsIndexMap, returnDatatype))) case columnarBloomFilterSubquery: ColumnarBloomFilterSubquery => val bfAddress: Long = columnarBloomFilterSubquery.eval().asInstanceOf[Long] - if (bfAddress == 0L) { - ("{\"exprType\":\"LITERAL\",\"dataType\":2,\"isNull\":true,\"value\":%d}") - .format(bfAddress) - } else { - ("{\"exprType\":\"LITERAL\",\"dataType\":2,\"isNull\":false,\"value\":%d}") - .format(bfAddress) - } + new JSONObject().put("exprType", "LITERAL") + .put("isNull", bfAddress == 0L) + .put("dataType", 2) + .put("value", bfAddress) case hash: Murmur3Hash => genMurMur3HashExpr(hash.children, hash.seed, exprsIndexMap) @@ -597,32 +407,74 @@ object OmniExpressionAdaptor extends Logging { case xxHash: XxHash64 => genXxHash64Expr(xxHash.children, xxHash.seed, exprsIndexMap) + case inStr: StringInstr => + new JSONObject().put("exprType", "FUNCTION") + .addOmniExpJsonType("returnType", inStr.dataType) + .put("function_name", "instr") + .put("arguments", new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(inStr.str, exprsIndexMap)) + 
.put(rewriteToOmniJsonExpressionLiteralJsonObject(inStr.substr, exprsIndexMap))) + + // for floating numbers normalize + case normalizeNaNAndZero: NormalizeNaNAndZero => + new JSONObject().put("exprType", "FUNCTION") + .addOmniExpJsonType("returnType", normalizeNaNAndZero.dataType) + .put("function_name", "NormalizeNaNAndZero") + .put("arguments", new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(normalizeNaNAndZero.child, exprsIndexMap))) + + case knownFloatingPointNormalized: KnownFloatingPointNormalized => + rewriteToOmniJsonExpressionLiteralJsonObject(knownFloatingPointNormalized.child, exprsIndexMap) + + // for like + case startsWith: StartsWith => + startsWith.right match { + case literal: Literal => + new JSONObject().put("exprType", "FUNCTION") + .addOmniExpJsonType("returnType", startsWith.dataType) + .put("function_name", "StartsWith") + .put("arguments", new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(startsWith.left, exprsIndexMap)) + .put(rewriteToOmniJsonExpressionLiteralJsonObject(startsWith.right, exprsIndexMap))) + + case _ => + throw new UnsupportedOperationException(s"Unsupported right expression in like expression: $startsWith") + } + case endsWith: EndsWith => + endsWith.right match { + case literal: Literal => + new JSONObject().put("exprType", "FUNCTION") + .addOmniExpJsonType("returnType", endsWith.dataType) + .put("function_name", "EndsWith") + .put("arguments", new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(endsWith.left, exprsIndexMap)) + .put(rewriteToOmniJsonExpressionLiteralJsonObject(endsWith.right, exprsIndexMap))) + + case _ => + throw new UnsupportedOperationException(s"Unsupported right expression in like expression: $endsWith") + } + case _ => if (HiveUdfAdaptorUtil.isHiveUdf(expr) && ColumnarPluginConfig.getSessionConf.enableColumnarUdf) { val hiveUdf = HiveUdfAdaptorUtil.asHiveSimpleUDF(expr) val nameSplit = hiveUdf.name.split("\\.") val udfName = if (nameSplit.size == 1) nameSplit(0).toLowerCase(Locale.ROOT) else nameSplit(1).toLowerCase(Locale.ROOT) - return ("{\"exprType\":\"FUNCTION\",\"returnType\":%s,\"function_name\":\"%s\"," + - "\"arguments\":[%s]}").format(sparkTypeToOmniExpJsonType(hiveUdf.dataType), udfName, - getJsonExprArgumentsByChildren(hiveUdf.children, exprsIndexMap)) + return new JSONObject().put("exprType", "FUNCTION") + .addOmniExpJsonType("returnType", hiveUdf.dataType) + .put("function_name", udfName) + .put("arguments", getJsonExprArgumentsByChildren(hiveUdf.children, exprsIndexMap)) } throw new UnsupportedOperationException(s"Unsupported expression: $expr") } } private def getJsonExprArgumentsByChildren(children: Seq[Expression], - exprsIndexMap: Map[ExprId, Int]): String = { + exprsIndexMap: Map[ExprId, Int]): JSONArray = { val size = children.size - val stringBuild = new mutable.StringBuilder + val jsonArray = new JSONArray() if (size == 0) { - return stringBuild.toString() + return jsonArray } - for (i <- 0 until size - 1) { - stringBuild.append(rewriteToOmniJsonExpressionLiteral(children(i), exprsIndexMap)) - stringBuild.append(",") + for (i <- 0 until size) { + jsonArray.put(rewriteToOmniJsonExpressionLiteralJsonObject(children(i), exprsIndexMap)) } - stringBuild.append(rewriteToOmniJsonExpressionLiteral(children(size - 1), exprsIndexMap)) - stringBuild.toString() + jsonArray } private def checkInputDataTypes(children: Seq[Expression]): Unit = { @@ -634,101 +486,156 @@ object OmniExpressionAdaptor extends Logging { } } - private def getConcatJsonStr(concat: Concat, 
exprsIndexMap: Map[ExprId, Int]): String = { + private def getConcatJsonStr(concat: Concat, exprsIndexMap: Map[ExprId, Int]): JSONObject = { val children: Seq[Expression] = concat.children checkInputDataTypes(children) - val template = "{\"exprType\": \"FUNCTION\",\"returnType\":%s," + - "\"function_name\": \"concat\", \"arguments\": [%s, %s]}" - val returnType = sparkTypeToOmniExpJsonType(concat.dataType) + if (children.length == 1) { - return rewriteToOmniJsonExpressionLiteral(children.head, exprsIndexMap) + return rewriteToOmniJsonExpressionLiteralJsonObject(children.head, exprsIndexMap) } - var res = template.format(returnType, - rewriteToOmniJsonExpressionLiteral(children.head, exprsIndexMap), - rewriteToOmniJsonExpressionLiteral(children(1), exprsIndexMap)) + val res = new JSONObject().put("exprType", "FUNCTION") + .addOmniExpJsonType("returnType", concat.dataType) + .put("function_name", "concat") + .put("arguments", new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(children.head, exprsIndexMap)) + .put(rewriteToOmniJsonExpressionLiteralJsonObject(children(1), exprsIndexMap))) for (i <- 2 until children.length) { - res = template.format(returnType, res, - rewriteToOmniJsonExpressionLiteral(children(i), exprsIndexMap)) + val preResJson = new JSONObject(res, JSONObject.getNames(res)) + res.put("arguments", new JSONArray().put(preResJson) + .put(rewriteToOmniJsonExpressionLiteralJsonObject(children(i), exprsIndexMap))) } res } // gen murmur3hash partition expression - private def genMurMur3HashExpr(expressions: Seq[Expression], seed: Int, exprsIndexMap: Map[ExprId, Int]): String = { - var omniExpr: String = "" + private def genMurMur3HashExpr(expressions: Seq[Expression], seed: Int, exprsIndexMap: Map[ExprId, Int]): JSONObject = { + var jsonObject: JSONObject = new JSONObject() expressions.foreach { expr => - val colExpr = rewriteToOmniJsonExpressionLiteral(expr, exprsIndexMap) - if (omniExpr.isEmpty) { - omniExpr = ("{\"exprType\":\"FUNCTION\",\"returnType\":1,\"function_name\":\"%s\",\"arguments\":[" + - "%s,{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":%d}]}").format("mm3hash", colExpr, seed) + val colExprJsonObject = rewriteToOmniJsonExpressionLiteralJsonObject(expr, exprsIndexMap) + if (jsonObject.length() == 0) { + jsonObject = new JSONObject().put("exprType", "FUNCTION") + .put("returnType", 1) + .put("function_name", "mm3hash") + .put("arguments", new JSONArray() + .put(colExprJsonObject) + .put(new JSONObject() + .put("exprType", "LITERAL") + .put("dataType", 1) + .put("isNull", false) + .put("value", seed))) } else { - omniExpr = ("{\"exprType\":\"FUNCTION\",\"returnType\":1,\"function_name\":\"%s\",\"arguments\":[%s,%s]}") - .format("mm3hash", colExpr, omniExpr) + jsonObject = new JSONObject().put("exprType", "FUNCTION") + .put("returnType", 1) + .put("function_name", "mm3hash") + .put("arguments", new JSONArray().put(colExprJsonObject).put(jsonObject)) } } - omniExpr + jsonObject } // gen XxHash64 partition expression - private def genXxHash64Expr(expressions: Seq[Expression], seed: Long, exprsIndexMap: Map[ExprId, Int]): String = { - var omniExpr: String = "" + private def genXxHash64Expr(expressions: Seq[Expression], seed: Long, exprsIndexMap: Map[ExprId, Int]): JSONObject = { + var jsonObject: JSONObject = new JSONObject() expressions.foreach { expr => - val colExpr = rewriteToOmniJsonExpressionLiteral(expr, exprsIndexMap) - if (omniExpr.isEmpty) { - omniExpr = 
("{\"exprType\":\"FUNCTION\",\"returnType\":2,\"function_name\":\"%s\",\"arguments\":[" + - "%s,{\"exprType\":\"LITERAL\",\"dataType\":2,\"isNull\":false,\"value\":%d}]}").format("xxhash64", colExpr, seed) + val colExprJsonObject = rewriteToOmniJsonExpressionLiteralJsonObject(expr, exprsIndexMap) + if (jsonObject.length() == 0) { + jsonObject = new JSONObject().put("exprType", "FUNCTION") + .put("returnType", 2) + .put("function_name", "xxhash64") + .put("arguments", new JSONArray() + .put(colExprJsonObject) + .put(new JSONObject() + .put("exprType", "LITERAL") + .put("dataType", 2) + .put("isNull", false) + .put("value", seed))) } else { - omniExpr = ("{\"exprType\":\"FUNCTION\",\"returnType\":2,\"function_name\":\"%s\",\"arguments\":[%s,%s]}") - .format("xxhash64", colExpr, omniExpr) + jsonObject = new JSONObject().put("exprType", "FUNCTION") + .put("returnType", 2) + .put("function_name", "xxhash64") + .put("arguments", new JSONArray().put(colExprJsonObject).put(jsonObject)) } } - omniExpr + jsonObject } - def toOmniJsonAttribute(attr: Attribute, colVal: Int): String = { - - val omniDataType = sparkTypeToOmniExpType(attr.dataType) + def toOmniJsonAttribute(attr: Attribute, colVal: Int): JSONObject = { + val omniDataType = sparkTypeToOmniExpType(attr.dataType) attr.dataType match { case StringType => - ("{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":%s," + - "\"colVal\":%d,\"width\":%d}").format(omniDataType, colVal, - getStringLength(attr.metadata)) + new JSONObject().put("exprType", "FIELD_REFERENCE") + .put("dataType", omniDataType.toInt) + .put("colVal", colVal) + .put("width", getStringLength(attr.metadata)) case dt: DecimalType => - ("{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":%s," + - "\"colVal\":%d,\"precision\":%s, \"scale\":%s}").format(omniDataType, - colVal, dt.precision, dt.scale) - case _ => ("{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":%s," + - "\"colVal\":%d}").format(omniDataType, colVal) + new JSONObject().put("exprType", "FIELD_REFERENCE") + .put("colVal", colVal) + .put("dataType", omniDataType.toInt) + .put("precision", dt.precision) + .put("scale", dt.scale) + case _ => new JSONObject().put("exprType", "FIELD_REFERENCE") + .put("dataType", omniDataType.toInt) + .put("colVal", colVal) } } - def toOmniJsonLiteral(literal: Literal): String = { + def toOmniJsonLiteral(literal: Literal): JSONObject = { val omniType = sparkTypeToOmniExpType(literal.dataType) val value = literal.value if (value == null) { - return "{\"exprType\":\"LITERAL\",\"dataType\":%s,\"isNull\":%b}".format(sparkTypeToOmniExpJsonType(literal.dataType), true) + return new JSONObject().put("exprType", "LITERAL") + .addOmniExpJsonType("dataType", literal.dataType) + .put("isNull", true) } literal.dataType match { case StringType => - ("{\"exprType\":\"LITERAL\",\"dataType\":%s," + - "\"isNull\":%b, \"value\":\"%s\",\"width\":%d}") - .format(omniType, false, value.toString, value.toString.length) + new JSONObject().put("exprType", "LITERAL") + .put("dataType", omniType.toInt) + .put("isNull", false) + .put("value", value.toString) + .put("width", value.toString.length) case dt: DecimalType => if (DecimalType.is64BitDecimalType(dt)) { - ("{\"exprType\":\"LITERAL\",\"dataType\":%s," + - "\"isNull\":%b,\"value\":%s,\"precision\":%s, \"scale\":%s}").format(omniType, - false, value.asInstanceOf[Decimal].toUnscaledLong, dt.precision, dt.scale) + new JSONObject().put("exprType", "LITERAL") + .put("dataType", omniType.toInt) + .put("isNull", false) + .put("value", 
value.asInstanceOf[Decimal].toUnscaledLong) + .put("precision", dt.precision) + .put("scale", dt.scale) } else { // NOTES: decimal128 literal value need use string format - ("{\"exprType\":\"LITERAL\",\"dataType\":%s," + - "\"isNull\":%b, \"value\":\"%s\", \"precision\":%s, \"scale\":%s}").format(omniType, - false, value.asInstanceOf[Decimal].toJavaBigDecimal.unscaledValue().toString(), - dt.precision, dt.scale) + new JSONObject().put("exprType", "LITERAL") + .put("dataType", omniType.toInt) + .put("isNull", false) + .put("value", value.asInstanceOf[Decimal].toJavaBigDecimal.unscaledValue().toString()) + .put("precision", dt.precision) + .put("scale", dt.scale) } case _ => - "{\"exprType\":\"LITERAL\",\"dataType\":%s, \"isNull\":%b, \"value\":%s}" - .format(omniType, false, value) - } + new JSONObject().put("exprType", "LITERAL") + .put("dataType", omniType.toInt) + .put("isNull", false) + .put("value", value) + } + } + + def checkFirstParamType(agg: AggregateExpression): Unit = { + agg.aggregateFunction.children.map( + exp => { + val exprDataType = exp.dataType + exprDataType match { + case ShortType => + case IntegerType => + case LongType => + case DoubleType => + case BooleanType => + case DateType => + case dt: DecimalType => + case StringType => + case _ => + throw new UnsupportedOperationException(s"First_value does not support datatype: $exprDataType") + } + } + ) } def toOmniAggFunType(agg: AggregateExpression, isHashAgg: Boolean = false, isMergeCount: Boolean = false): FunctionType = { @@ -744,8 +651,12 @@ object OmniExpressionAdaptor extends Logging { OMNI_AGGREGATION_TYPE_COUNT_ALL } case Count(_) if agg.aggregateFunction.children.size == 1 => OMNI_AGGREGATION_TYPE_COUNT_COLUMN - case First(_, true) => OMNI_AGGREGATION_TYPE_FIRST_IGNORENULL - case First(_, false) => OMNI_AGGREGATION_TYPE_FIRST_INCLUDENULL + case First(_, true) => + checkFirstParamType(agg) + OMNI_AGGREGATION_TYPE_FIRST_IGNORENULL + case First(_, false) => + checkFirstParamType(agg) + OMNI_AGGREGATION_TYPE_FIRST_INCLUDENULL case _ => throw new UnsupportedOperationException(s"Unsupported aggregate function: $agg") } } @@ -759,19 +670,19 @@ object OmniExpressionAdaptor extends Logging { } def toOmniAggInOutJSonExp(attribute: Seq[Expression], exprsIndexMap: Map[ExprId, Int]): - Array[String] = { - attribute.map(attr => rewriteToOmniJsonExpressionLiteral(attr, exprsIndexMap)).toArray + Array[String] = { + attribute.map(attr => rewriteToOmniJsonExpressionLiteral(attr, exprsIndexMap)).toArray } def toOmniAggInOutType(attribute: Seq[AttributeReference]): - Array[nova.hetu.omniruntime.`type`.DataType] = { - attribute.map(attr => - sparkTypeToOmniType(attr.dataType, attr.metadata)).toArray + Array[nova.hetu.omniruntime.`type`.DataType] = { + attribute.map(attr => + sparkTypeToOmniType(attr.dataType, attr.metadata)).toArray } def toOmniAggInOutType(dataType: DataType, metadata: Metadata = Metadata.empty): - Array[nova.hetu.omniruntime.`type`.DataType] = { - Array[nova.hetu.omniruntime.`type`.DataType](sparkTypeToOmniType(dataType, metadata)) + Array[nova.hetu.omniruntime.`type`.DataType] = { + Array[nova.hetu.omniruntime.`type`.DataType](sparkTypeToOmniType(dataType, metadata)) } def sparkTypeToOmniExpType(datatype: DataType): String = { @@ -789,20 +700,26 @@ object OmniExpressionAdaptor extends Logging { } else { OMNI_DECIMAL128_TYPE } + case NullType => OMNI_BOOLEAN_TYPE case _ => throw new UnsupportedOperationException(s"Unsupported datatype: $datatype") } } - def sparkTypeToOmniExpJsonType(datatype: DataType): String 
= { - val omniTypeIdStr = sparkTypeToOmniExpType(datatype) - datatype match { - case StringType => - "%s,\"width\":%s".format(omniTypeIdStr, DEFAULT_STRING_TYPE_LENGTH) - case dt: DecimalType => - "%s,\"precision\":%s,\"scale\":%s".format(omniTypeIdStr, dt.precision, dt.scale) - case _ => - omniTypeIdStr + implicit private class JSONObjectExtension(val jsonObject: JSONObject) { + def addOmniExpJsonType(jsonAttributeKey: String, datatype: DataType): JSONObject = { + val omniTypeIdStr = sparkTypeToOmniExpType(datatype) + datatype match { + case StringType => + jsonObject.put(jsonAttributeKey, omniTypeIdStr.toInt) + .put("width", DEFAULT_STRING_TYPE_LENGTH) + case dt: DecimalType => + jsonObject.put(jsonAttributeKey, omniTypeIdStr.toInt) + .put("precision", dt.precision) + .put("scale", dt.scale) + case _ => + jsonObject.put(jsonAttributeKey, omniTypeIdStr.toInt) + } } } @@ -840,19 +757,24 @@ object OmniExpressionAdaptor extends Logging { val omniDataType: String = sparkTypeToOmniExpType(dataType) dataType match { case ShortType | IntegerType | LongType | DoubleType | BooleanType | DateType => - "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":%s,\"colVal\":%d}" - .format(omniDataType, colVal) + new JSONObject().put("exprType", "FIELD_REFERENCE") + .put("dataType", omniDataType.toInt) + .put("colVal", colVal).toString case StringType => - "{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":%s,\"colVal\":%d,\"width\":%d}" - .format(omniDataType, colVal, getStringLength(metadata)) + new JSONObject().put("exprType", "FIELD_REFERENCE") + .put("dataType", omniDataType.toInt) + .put("colVal", colVal) + .put("width", getStringLength(metadata)).toString case dt: DecimalType => var omniDataType = OMNI_DECIMAL128_TYPE if (DecimalType.is64BitDecimalType(dt)) { omniDataType = OMNI_DECIMAL64_TYPE } - ("{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":%s,\"colVal\":%d," + - "\"precision\":%s,\"scale\":%s}") - .format(omniDataType, colVal, dt.precision, dt.scale) + new JSONObject().put("exprType", "FIELD_REFERENCE") + .put("dataType", omniDataType.toInt) + .put("colVal", colVal) + .put("precision", dt.precision) + .put("scale", dt.scale).toString case _ => throw new UnsupportedOperationException(s"Unsupported datatype: $dataType") } @@ -874,160 +796,27 @@ object OmniExpressionAdaptor extends Logging { } def procCaseWhenExpression(caseWhen: CaseWhen, - exprsIndexMap: Map[ExprId, Int]): String = { - val exprStr = "{\"exprType\":\"IF\",\"returnType\":%s,\"condition\":%s,\"if_true\":%s,\"if_false\":%s}" - var exprStrRes = exprStr - for (i <- caseWhen.branches.indices) { - var ifFalseStr = "" + exprsIndexMap: Map[ExprId, Int]): JSONObject = { + var jsonObject = new JSONObject() + for (i <- caseWhen.branches.indices.reverse) { + val outerJson = new JSONObject().put("exprType", "IF") + .addOmniExpJsonType("returnType", caseWhen.dataType) + .put("condition", rewriteToOmniJsonExpressionLiteralJsonObject(caseWhen.branches(i)._1, exprsIndexMap)) + .put("if_true", rewriteToOmniJsonExpressionLiteralJsonObject(caseWhen.branches(i)._2, exprsIndexMap)) + if (i != caseWhen.branches.length - 1) { - ifFalseStr = exprStr + val innerJson = new JSONObject(jsonObject, JSONObject.getNames(jsonObject)) + outerJson.put("if_false", innerJson) } else { var elseValue = caseWhen.elseValue if (elseValue.isEmpty) { - elseValue = Some(Literal(null, caseWhen.dataType)) + elseValue = Some(Literal(null, caseWhen.dataType)) } - ifFalseStr = rewriteToOmniJsonExpressionLiteral(elseValue.get, exprsIndexMap) + outerJson.put("if_false", 
rewriteToOmniJsonExpressionLiteralJsonObject(elseValue.get, exprsIndexMap)) } - exprStrRes = exprStrRes.format(sparkTypeToOmniExpJsonType(caseWhen.dataType), - rewriteToOmniJsonExpressionLiteral(caseWhen.branches(i)._1, exprsIndexMap), - rewriteToOmniJsonExpressionLiteral(caseWhen.branches(i)._2, exprsIndexMap), - ifFalseStr) + jsonObject = outerJson } - exprStrRes - } - - def procLikeExpression(likeExpr: Expression, - exprsIndexMap: Map[ExprId, Int]): String = { - likeExpr match { - case like: Like => - val dataType = like.right.dataType - like.right match { - case literal: Literal => - ("{\"exprType\":\"FUNCTION\",\"returnType\":%s," + - "\"function_name\":\"LIKE\", \"arguments\":[%s, %s]}") - .format(sparkTypeToOmniExpJsonType(like.dataType), - rewriteToOmniJsonExpressionLiteral(like.left, exprsIndexMap), - generateLikeArg(literal,"")) - case _ => - throw new UnsupportedOperationException(s"Unsupported datatype in like expression: $dataType") - } - case startsWith: StartsWith => - val dataType = startsWith.right.dataType - startsWith.right match { - case literal: Literal => - ("{\"exprType\":\"FUNCTION\",\"returnType\":%s," + - "\"function_name\":\"LIKE\", \"arguments\":[%s, %s]}") - .format(sparkTypeToOmniExpJsonType(startsWith.dataType), - rewriteToOmniJsonExpressionLiteral(startsWith.left, exprsIndexMap), - generateLikeArg(literal, "startsWith")) - case _ => - throw new UnsupportedOperationException(s"Unsupported datatype in like expression: $dataType") - } - case endsWith: EndsWith => - val dataType = endsWith.right.dataType - endsWith.right match { - case literal: Literal => - ("{\"exprType\":\"FUNCTION\",\"returnType\":%s," + - "\"function_name\":\"LIKE\", \"arguments\":[%s, %s]}") - .format(sparkTypeToOmniExpJsonType(endsWith.dataType), - rewriteToOmniJsonExpressionLiteral(endsWith.left, exprsIndexMap), - generateLikeArg(literal, "endsWith")) - case _ => - throw new UnsupportedOperationException(s"Unsupported datatype in like expression: $dataType") - } - case contains: Contains => - val dataType = contains.right.dataType - contains.right match { - case literal: Literal => - ("{\"exprType\":\"FUNCTION\",\"returnType\":%s," + - "\"function_name\":\"LIKE\", \"arguments\":[%s, %s]}") - .format(sparkTypeToOmniExpJsonType(contains.dataType), - rewriteToOmniJsonExpressionLiteral(contains.left, exprsIndexMap), - generateLikeArg(literal, "contains")) - case _ => - throw new UnsupportedOperationException(s"Unsupported datatype in like expression: $dataType") - } - } - } - - def generateLikeArg(literal: Literal, exprFormat: String) : String = { - val value = literal.value - if (value == null) { - return "{\"exprType\":\"LITERAL\",\"dataType\":%s,\"isNull\":%b}".format(sparkTypeToOmniExpJsonType(literal.dataType), true) - } - var inputValue = value.toString - exprFormat match { - case "startsWith" => - inputValue = inputValue + "%" - case "endsWith" => - inputValue = "%" + inputValue - case "contains" => - inputValue = "%" + inputValue + "%" - case _ => - inputValue = value.toString - } - - val omniType = sparkTypeToOmniExpType(literal.dataType) - literal.dataType match { - case StringType => - val likeRegExpr = generateLikeRegExpr(inputValue) - ("{\"exprType\":\"LITERAL\",\"dataType\":%s," + - "\"isNull\":%b, \"value\":\"%s\",\"width\":%d}") - .format(omniType, false, likeRegExpr, likeRegExpr.length) - case dt: DecimalType => - toOmniJsonLiteral(literal) - case _ => - toOmniJsonLiteral(literal) - } - } - - def generateLikeRegExpr(value : String) : String = { - val regexString = new 
mutable.StringBuilder - regexString.append('^') - val valueArr = value.toCharArray - for (i <- 0 until valueArr.length) { - valueArr(i) match { - case '%' => - if (i - 1 < 0 || valueArr(i - 1) != '\\') { - regexString.append(".*") - } else { - regexString.append(valueArr(i)) - } - - case '_' => - if (i - 1 < 0 || valueArr(i - 1) != '\\') { - regexString.append(".") - } else { - regexString.append(valueArr(i)) - } - - case '\\' => - regexString.append("\\") - regexString.append(valueArr(i)) - - case '^' => - regexString.append("\\") - regexString.append(valueArr(i)) - - case '$' => - regexString.append("\\") - regexString.append(valueArr(i)) - - case '.' => - regexString.append("\\") - regexString.append(valueArr(i)) - - case '*' => - regexString.append("\\") - regexString.append(valueArr(i)) - - case _ => - regexString.append(valueArr(i)) - - } - } - regexString.append('$') - regexString.toString() + jsonObject } def toOmniJoinType(joinType: JoinType): nova.hetu.omniruntime.constants.JoinType = { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/serialize/ColumnarBatchSerializer.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/serialize/ColumnarBatchSerializer.scala index de5638f0a7d927380a129538ba6f664830765f90..07ac07e8f81a47811ce8b979c9efcc3b52f882e3 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/serialize/ColumnarBatchSerializer.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/serialize/ColumnarBatchSerializer.scala @@ -40,13 +40,13 @@ private class ColumnarBatchSerializerInstance( readBatchNumRows: SQLMetric, numOutputRows: SQLMetric) extends SerializerInstance with Logging { + private val columnarConf = ColumnarPluginConfig.getSessionConf + private val shuffleCompressBlockSize = columnarConf.columnarShuffleCompressBlockSize + private val enableShuffleCompress = columnarConf.enableShuffleCompress + private var shuffleCompressionCodec = columnarConf.columnarShuffleCompressionCodec + override def deserializeStream(in: InputStream): DeserializationStream = { new DeserializationStream { - val columnarConf = ColumnarPluginConfig.getSessionConf - val shuffleCompressBlockSize = columnarConf.columnarShuffleCompressBlockSize - val enableShuffleCompress = columnarConf.enableShuffleCompress - var shuffleCompressionCodec = columnarConf.columnarShuffleCompressionCodec - if (!enableShuffleCompress) { shuffleCompressionCodec = "uncompressed" } @@ -146,4 +146,4 @@ private class ColumnarBatchSerializerInstance( override def serializeStream(s: OutputStream): SerializationStream = throw new UnsupportedOperationException -} \ No newline at end of file +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala index ed99f6b4311a48492438095a87d450f7d9d89a5a..875fe939dcbd932877165079a6f16400fc968865 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala @@ -19,7 +19,10 @@ package com.huawei.boostkit.spark.util import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP +import java.io.{File, IOException} +import java.util import java.util.concurrent.TimeUnit.NANOSECONDS + import 
com.huawei.boostkit.spark.expression.OmniExpressionAdaptor._ import nova.hetu.omniruntime.constants.FunctionType import nova.hetu.omniruntime.operator.OmniOperator @@ -33,9 +36,10 @@ import org.apache.spark.sql.execution.vectorized.{OmniColumnVector, OnHeapColumn import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch} +import org.apache.spark.util.Utils import scala.collection.mutable.ListBuffer -import java.util +import scala.util.control.Breaks.{break, breakable} object OmniAdaptorUtil { def transColBatchToOmniVecs(cb: ColumnarBatch): Array[Vec] = { @@ -44,20 +48,34 @@ object OmniAdaptorUtil { def transColBatchToOmniVecs(cb: ColumnarBatch, isSlice: Boolean): Array[Vec] = { val input = new Array[Vec](cb.numCols()) - for (i <- 0 until cb.numCols()) { - val omniVec: Vec = cb.column(i) match { - case vector: OmniColumnVector => - if (!isSlice) { - vector.getVec - } else { - vector.getVec.slice(0, cb.numRows()) + try { + for (i <- 0 until cb.numCols()) { + val omniVec: Vec = cb.column(i) match { + case vector: OmniColumnVector => + if (!isSlice) { + vector.getVec + } else { + vector.getVec.slice(0, cb.numRows()) + } + case vector: ColumnVector => + transColumnVector(vector, cb.numRows()) + case _ => + throw new UnsupportedOperationException("unsupport column vector!") + } + input(i) = omniVec + } + } catch { + case e: Exception => { + for (j <- 0 until cb.numCols()) { + val vec = input(j) + if (vec != null) vec.close + cb.column(j) match { + case vector: OmniColumnVector => + vector.close() } - case vector: ColumnVector => - transColumnVector(vector, cb.numRows()) - case _ => - throw new UnsupportedOperationException("unsupport column vector!") + } + throw new RuntimeException("allocate memory failed!") } - input(i) = omniVec } input } @@ -182,7 +200,7 @@ object OmniAdaptorUtil { (Array[nova.hetu.omniruntime.`type`.DataType], Array[Int], Array[Int], Array[String]) = { val inputColSize: Int = output.size val sourceTypes = new Array[nova.hetu.omniruntime.`type`.DataType](inputColSize) - val ascendings = new Array[Int](sortOrder.size) + val ascending = new Array[Int](sortOrder.size) val nullFirsts = new Array[Int](sortOrder.size) val sortColsExp = new Array[String](sortOrder.size) val omniAttrExpsIdMap: Map[ExprId, Int] = getExprIdMap(output) @@ -192,7 +210,7 @@ object OmniAdaptorUtil { } sortOrder.zipWithIndex.foreach { case (sortAttr, i) => sortColsExp(i) = rewriteToOmniJsonExpressionLiteral(sortAttr.child, omniAttrExpsIdMap) - ascendings(i) = if (sortAttr.isAscending) { + ascending(i) = if (sortAttr.isAscending) { 1 } else { 0 @@ -205,18 +223,18 @@ object OmniAdaptorUtil { if (!isSimpleColumnForAll(sortColsExp)) { checkOmniJsonWhiteList("", sortColsExp.asInstanceOf[Array[AnyRef]]) } - (sourceTypes, ascendings, nullFirsts, sortColsExp) + (sourceTypes, ascending, nullFirsts, sortColsExp) } def addAllAndGetIterator(operator: OmniOperator, inputIter: Iterator[ColumnarBatch], schema: StructType, - addInputTime: SQLMetric, numInputVecBatchs: SQLMetric, + addInputTime: SQLMetric, numInputVecBatches: SQLMetric, numInputRows: SQLMetric, getOutputTime: SQLMetric, - numOutputVecBatchs: SQLMetric, numOutputRows: SQLMetric, + numOutputVecBatches: SQLMetric, numOutputRows: SQLMetric, outputDataSize: SQLMetric): Iterator[ColumnarBatch] = { while (inputIter.hasNext) { val batch: ColumnarBatch = inputIter.next() - numInputVecBatchs += 1 + numInputVecBatches+= 1 val input: Array[Vec] = transColBatchToOmniVecs(batch) 
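+      // Note: transColBatchToOmniVecs closes any Omni Vecs it has already materialized (and the
+      // source OmniColumnVectors) before rethrowing if conversion fails part-way, so no extra
+      // cleanup is needed at this call site.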
val vecBatch = new VecBatch(input, batch.numRows()) val startInput: Long = System.nanoTime() @@ -259,7 +277,7 @@ object OmniAdaptorUtil { // metrics val rowCnt: Int = vecBatch.getRowCount numOutputRows += rowCnt - numOutputVecBatchs += 1 + numOutputVecBatches+= 1 // close omni vecbetch vecBatch.close() new ColumnarBatch(vectors.toArray, rowCnt) @@ -282,7 +300,8 @@ object OmniAdaptorUtil { omniAggFunctionTypes: Array[FunctionType], omniAggOutputTypes: Array[Array[nova.hetu.omniruntime.`type`.DataType]], omniInputRaws: Array[Boolean], - omniOutputPartials: Array[Boolean]): OmniOperator = { + omniOutputPartials: Array[Boolean], + sparkSpillConf: SpillConfig = SpillConfig.NONE): OmniOperator = { var operator: OmniOperator = null if (groupingExpressions.nonEmpty) { operator = new OmniHashAggregationWithExprOperatorFactory( @@ -294,7 +313,8 @@ object OmniAdaptorUtil { omniAggOutputTypes, omniInputRaws, omniOutputPartials, - new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)).createOperator + new OperatorConfig(sparkSpillConf, new OverflowConfig(OmniAdaptorUtil.overflowConf()), + IS_SKIP_VERIFY_EXP)).createOperator } else { operator = new OmniAggregationWithExprOperatorFactory( omniGroupByChanel, @@ -344,14 +364,19 @@ object OmniAdaptorUtil { } def reorderVecs(prunedOutput: Seq[Attribute], projectList: Seq[NamedExpression], resultVecs: Array[nova.hetu.omniruntime.vector.Vec], vecs: Array[OmniColumnVector]) = { + val used = new Array[Boolean](resultVecs.length) for (index <- projectList.indices) { val project = projectList(index) - for (i <- prunedOutput.indices) { - val col = prunedOutput(i) - if (col.exprId.equals(getProjectAliasExprId(project))) { - val v = vecs(index) - v.reset() - v.setVec(resultVecs(i)) + breakable { + for (i <- prunedOutput.indices) { + val col = prunedOutput(i) + if (!used(i) && col.exprId.equals(getProjectAliasExprId(project))) { + val v = vecs(index) + v.reset() + v.setVec(resultVecs(i)) + used(i) = true; + break + } } } } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/OmniMapOutputTracker.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/OmniMapOutputTracker.scala new file mode 100644 index 0000000000000000000000000000000000000000..468b3036adf2206cf73216eab13145be2e91a41b --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/OmniMapOutputTracker.scala @@ -0,0 +1,1730 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark + +import java.io.{ByteArrayInputStream, InputStream, IOException, ObjectInputStream, ObjectOutputStream} +import java.nio.ByteBuffer +import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit} +import java.util.concurrent.locks.ReentrantReadWriteLock + +import scala.collection.JavaConverters._ +import scala.collection.mutable.{HashMap, ListBuffer, Map} +import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.duration.Duration +import scala.reflect.ClassTag +import scala.util.control.NonFatal + +import org.apache.commons.io.output.{ByteArrayOutputStream => ApacheByteArrayOutputStream} +import org.roaringbitmap.RoaringBitmap + +import org.apache.spark.broadcast.{Broadcast, BroadcastManager} +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ +import org.apache.spark.io.CompressionCodec +import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} +import org.apache.spark.scheduler.{MapStatus, MergeStatus, ShuffleOutputStatus} +import org.apache.spark.shuffle.MetadataFetchFailedException +import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId, ShuffleMergedBlockId} +import org.apache.spark.util._ +import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream} + +/** + * Helper class used by the [[MapOutputTrackerMaster]] to perform bookkeeping for a single + * ShuffleMapStage. + * + * This class maintains a mapping from map index to `MapStatus`. It also maintains a cache of + * serialized map statuses in order to speed up tasks' requests for map output statuses. + * + * All public methods of this class are thread-safe. + */ +private class ShuffleStatus( + numPartitions: Int, + numReducers: Int = -1) extends Logging { + + private val (readLock, writeLock) = { + val lock = new ReentrantReadWriteLock() + (lock.readLock(), lock.writeLock()) + } + + // All accesses to the following state must be guarded with `withReadLock` or `withWriteLock`. + private def withReadLock[B](fn: => B): B = { + readLock.lock() + try { + fn + } finally { + readLock.unlock() + } + } + + private def withWriteLock[B](fn: => B): B = { + writeLock.lock() + try { + fn + } finally { + writeLock.unlock() + } + } + + /** + * MapStatus for each partition. The index of the array is the map partition id. + * Each value in the array is the MapStatus for a partition, or null if the partition + * is not available. Even though in theory a task may run multiple times (due to speculation, + * stage retries, etc.), in practice the likelihood of a map output being available at multiple + * locations is so small that we choose to ignore that case and store only a single location + * for each output. + */ + // Exposed for testing + val mapStatuses = new Array[MapStatus](numPartitions) + + /** + * Keep the previous deleted MapStatus for recovery. + */ + val mapStatusesDeleted = new Array[MapStatus](numPartitions) + + /** + * MergeStatus for each shuffle partition when push-based shuffle is enabled. The index of the + * array is the shuffle partition id (reduce id). Each value in the array is the MergeStatus for + * a shuffle partition, or null if not available. When push-based shuffle is enabled, this array + * provides a reducer oriented view of the shuffle status specifically for the results of + * merging shuffle partition blocks into per-partition merged shuffle files. 
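+   * The array is only allocated when numReducers > 0, i.e. when push-based shuffle is enabled
+   * for this shuffle; otherwise it remains empty.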
+ */ + val mergeStatuses = if (numReducers > 0) { + new Array[MergeStatus](numReducers) + } else { + Array.empty[MergeStatus] + } + + /** + * The cached result of serializing the map statuses array. This cache is lazily populated when + * [[serializedMapStatus]] is called. The cache is invalidated when map outputs are removed. + */ + private[this] var cachedSerializedMapStatus: Array[Byte] = _ + + /** + * Broadcast variable holding serialized map output statuses array. When [[serializedMapStatus]] + * serializes the map statuses array it may detect that the result is too large to send in a + * single RPC, in which case it places the serialized array into a broadcast variable and then + * sends a serialized broadcast variable instead. This variable holds a reference to that + * broadcast variable in order to keep it from being garbage collected and to allow for it to be + * explicitly destroyed later on when the ShuffleMapStage is garbage-collected. + */ + private[spark] var cachedSerializedBroadcast: Broadcast[Array[Array[Byte]]] = _ + + /** + * Similar to cachedSerializedMapStatus and cachedSerializedBroadcast, but for MergeStatus. + */ + private[this] var cachedSerializedMergeStatus: Array[Byte] = _ + + private[this] var cachedSerializedBroadcastMergeStatus: Broadcast[Array[Array[Byte]]] = _ + + /** + * Counter tracking the number of partitions that have output. This is a performance optimization + * to avoid having to count the number of non-null entries in the `mapStatuses` array and should + * be equivalent to`mapStatuses.count(_ ne null)`. + */ + private[this] var _numAvailableMapOutputs: Int = 0 + + /** + * Counter tracking the number of MergeStatus results received so far from the shuffle services. + */ + private[this] var _numAvailableMergeResults: Int = 0 + + private[this] var shufflePushMergerLocations: Seq[BlockManagerId] = Seq.empty + + /** + * Register a map output. If there is already a registered location for the map output then it + * will be replaced by the new location. + */ + def addMapOutput(mapIndex: Int, status: MapStatus): Unit = withWriteLock { + if (mapStatuses(mapIndex) == null) { + _numAvailableMapOutputs += 1 + invalidateSerializedMapOutputStatusCache() + } + mapStatuses(mapIndex) = status + } + + /** + * Update the map output location (e.g. during migration). + */ + def updateMapOutput(mapId: Long, bmAddress: BlockManagerId): Unit = withWriteLock { + try { + val mapStatusOpt = mapStatuses.find(x => x != null && x.mapId == mapId) + mapStatusOpt match { + case Some(mapStatus) => + logInfo(s"Updating map output for ${mapId} to ${bmAddress}") + mapStatus.updateLocation(bmAddress) + invalidateSerializedMapOutputStatusCache() + case None => + val index = mapStatusesDeleted.indexWhere(x => x != null && x.mapId == mapId) + if (index >= 0 && mapStatuses(index) == null) { + val mapStatus = mapStatusesDeleted(index) + mapStatus.updateLocation(bmAddress) + mapStatuses(index) = mapStatus + _numAvailableMapOutputs += 1 + invalidateSerializedMapOutputStatusCache() + mapStatusesDeleted(index) = null + logInfo(s"Recover ${mapStatus.mapId} ${mapStatus.location}") + } else { + logWarning(s"Asked to update map output ${mapId} for untracked map status.") + } + } + } catch { + case e: java.lang.NullPointerException => + logWarning(s"Unable to update map output for ${mapId}, status removed in-flight") + } + } + + /** + * Remove the map output which was served by the specified block manager. 
+ * This is a no-op if there is no registered map output or if the registered output is from a + * different block manager. + */ + def removeMapOutput(mapIndex: Int, bmAddress: BlockManagerId): Unit = withWriteLock { + logDebug(s"Removing existing map output ${mapIndex} ${bmAddress}") + if (mapStatuses(mapIndex) != null && mapStatuses(mapIndex).location == bmAddress) { + _numAvailableMapOutputs -= 1 + mapStatusesDeleted(mapIndex) = mapStatuses(mapIndex) + mapStatuses(mapIndex) = null + invalidateSerializedMapOutputStatusCache() + } + } + + /** + * Register a merge result. + */ + def addMergeResult(reduceId: Int, status: MergeStatus): Unit = withWriteLock { + if (mergeStatuses(reduceId) != status) { + _numAvailableMergeResults += 1 + invalidateSerializedMergeOutputStatusCache() + } + mergeStatuses(reduceId) = status + } + + def registerShuffleMergerLocations(shuffleMergers: Seq[BlockManagerId]): Unit = withWriteLock { + if (shufflePushMergerLocations.isEmpty) { + shufflePushMergerLocations = shuffleMergers + } + } + + def removeShuffleMergerLocations(): Unit = withWriteLock { + shufflePushMergerLocations = Nil + } + + // TODO support updateMergeResult for similar use cases as updateMapOutput + + /** + * Remove the merge result which was served by the specified block manager. + */ + def removeMergeResult(reduceId: Int, bmAddress: BlockManagerId): Unit = withWriteLock { + if (mergeStatuses(reduceId) != null && mergeStatuses(reduceId).location == bmAddress) { + _numAvailableMergeResults -= 1 + mergeStatuses(reduceId) = null + invalidateSerializedMergeOutputStatusCache() + } + } + + /** + * Removes all shuffle outputs associated with this host. Note that this will also remove + * outputs which are served by an external shuffle server (if one exists). + */ + def removeOutputsOnHost(host: String): Unit = withWriteLock { + logDebug(s"Removing outputs for host ${host}") + removeOutputsByFilter(x => x.host == host) + removeMergeResultsByFilter(x => x.host == host) + } + + /** + * Removes all map outputs associated with the specified executor. Note that this will also + * remove outputs which are served by an external shuffle server (if one exists), as they are + * still registered with that execId. + */ + def removeOutputsOnExecutor(execId: String): Unit = withWriteLock { + logDebug(s"Removing outputs for execId ${execId}") + removeOutputsByFilter(x => x.executorId == execId) + } + + /** + * Removes all shuffle outputs which satisfies the filter. Note that this will also + * remove outputs which are served by an external shuffle server (if one exists). + */ + def removeOutputsByFilter(f: BlockManagerId => Boolean): Unit = withWriteLock { + for (mapIndex <- mapStatuses.indices) { + if (mapStatuses(mapIndex) != null && f(mapStatuses(mapIndex).location)) { + _numAvailableMapOutputs -= 1 + mapStatusesDeleted(mapIndex) = mapStatuses(mapIndex) + mapStatuses(mapIndex) = null + invalidateSerializedMapOutputStatusCache() + } + } + } + + /** + * Removes all shuffle merge result which satisfies the filter. + */ + def removeMergeResultsByFilter(f: BlockManagerId => Boolean): Unit = withWriteLock { + for (reduceId <- mergeStatuses.indices) { + if (mergeStatuses(reduceId) != null && f(mergeStatuses(reduceId).location)) { + _numAvailableMergeResults -= 1 + mergeStatuses(reduceId) = null + invalidateSerializedMergeOutputStatusCache() + } + } + } + + /** + * Number of partitions that have shuffle map outputs. 
+ */ + def numAvailableMapOutputs: Int = withReadLock { + _numAvailableMapOutputs + } + + /** + * Number of shuffle partitions that have already been merge finalized when push-based + * is enabled. + */ + def numAvailableMergeResults: Int = withReadLock { + _numAvailableMergeResults + } + + /** + * Returns the sequence of partition ids that are missing (i.e. needs to be computed). + */ + def findMissingPartitions(): Seq[Int] = withReadLock { + val missing = (0 until numPartitions).filter(id => mapStatuses(id) == null) + assert(missing.size == numPartitions - _numAvailableMapOutputs, + s"${missing.size} missing, expected ${numPartitions - _numAvailableMapOutputs}") + missing + } + + /** + * Serializes the mapStatuses array into an efficient compressed format. See the comments on + * `MapOutputTracker.serializeOutputStatuses()` for more details on the serialization format. + * + * This method is designed to be called multiple times and implements caching in order to speed + * up subsequent requests. If the cache is empty and multiple threads concurrently attempt to + * serialize the map statuses then serialization will only be performed in a single thread and + * all other threads will block until the cache is populated. + */ + def serializedMapStatus( + broadcastManager: BroadcastManager, + isLocal: Boolean, + minBroadcastSize: Int, + conf: SparkConf): Array[Byte] = { + var result: Array[Byte] = null + withReadLock { + if (cachedSerializedMapStatus != null) { + result = cachedSerializedMapStatus + } + } + + if (result == null) withWriteLock { + if (cachedSerializedMapStatus == null) { + val serResult = MapOutputTracker.serializeOutputStatuses[MapStatus]( + mapStatuses, broadcastManager, isLocal, minBroadcastSize, conf) + cachedSerializedMapStatus = serResult._1 + cachedSerializedBroadcast = serResult._2 + } + // The following line has to be outside if statement since it's possible that another thread + // initializes cachedSerializedMapStatus in-between `withReadLock` and `withWriteLock`. + result = cachedSerializedMapStatus + } + result + } + + /** + * Serializes the mapStatuses and mergeStatuses array into an efficient compressed format. + * See the comments on `MapOutputTracker.serializeOutputStatuses()` for more details + * on the serialization format. + * + * This method is designed to be called multiple times and implements caching in order to speed + * up subsequent requests. If the cache is empty and multiple threads concurrently attempt to + * serialize the statuses array then serialization will only be performed in a single thread and + * all other threads will block until the cache is populated. 
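+   * Returns a pair of (serialized map statuses, serialized merge statuses); the map-status half
+   * is obtained via serializedMapStatus and therefore shares its cache.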
+ */ + def serializedMapAndMergeStatus( + broadcastManager: BroadcastManager, + isLocal: Boolean, + minBroadcastSize: Int, + conf: SparkConf): (Array[Byte], Array[Byte]) = { + val mapStatusesBytes: Array[Byte] = + serializedMapStatus(broadcastManager, isLocal, minBroadcastSize, conf) + var mergeStatusesBytes: Array[Byte] = null + + withReadLock { + if (cachedSerializedMergeStatus != null) { + mergeStatusesBytes = cachedSerializedMergeStatus + } + } + + if (mergeStatusesBytes == null) withWriteLock { + if (cachedSerializedMergeStatus == null) { + val serResult = MapOutputTracker.serializeOutputStatuses[MergeStatus]( + mergeStatuses, broadcastManager, isLocal, minBroadcastSize, conf) + cachedSerializedMergeStatus = serResult._1 + cachedSerializedBroadcastMergeStatus = serResult._2 + } + + // The following line has to be outside if statement since it's possible that another + // thread initializes cachedSerializedMergeStatus in-between `withReadLock` and + // `withWriteLock`. + mergeStatusesBytes = cachedSerializedMergeStatus + } + (mapStatusesBytes, mergeStatusesBytes) + } + + // Used in testing. + def hasCachedSerializedBroadcast: Boolean = withReadLock { + cachedSerializedBroadcast != null + } + + /** + * Helper function which provides thread-safe access to the mapStatuses array. + * The function should NOT mutate the array. + */ + def withMapStatuses[T](f: Array[MapStatus] => T): T = withReadLock { + f(mapStatuses) + } + + def withMergeStatuses[T](f: Array[MergeStatus] => T): T = withReadLock { + f(mergeStatuses) + } + + def getShufflePushMergerLocations: Seq[BlockManagerId] = withReadLock { + shufflePushMergerLocations + } + + /** + * Clears the cached serialized map output statuses. + */ + def invalidateSerializedMapOutputStatusCache(): Unit = withWriteLock { + if (cachedSerializedBroadcast != null) { + // Prevent errors during broadcast cleanup from crashing the DAGScheduler (see SPARK-21444) + Utils.tryLogNonFatalError { + // Use `blocking = false` so that this operation doesn't hang while trying to send cleanup + // RPCs to dead executors. + cachedSerializedBroadcast.destroy() + } + cachedSerializedBroadcast = null + } + cachedSerializedMapStatus = null + } + + /** + * Clears the cached serialized merge result statuses. + */ + def invalidateSerializedMergeOutputStatusCache(): Unit = withWriteLock { + if (cachedSerializedBroadcastMergeStatus != null) { + Utils.tryLogNonFatalError { + // Use `blocking = false` so that this operation doesn't hang while trying to send cleanup + // RPCs to dead executors. 
+ cachedSerializedBroadcastMergeStatus.destroy() + } + cachedSerializedBroadcastMergeStatus = null + } + cachedSerializedMergeStatus = null + } +} + +private[spark] sealed trait MapOutputTrackerMessage +private[spark] case class GetMapOutputStatuses(shuffleId: Int) + extends MapOutputTrackerMessage +private[spark] case class GetMapAndMergeResultStatuses(shuffleId: Int) + extends MapOutputTrackerMessage +private[spark] case class GetShufflePushMergerLocations(shuffleId: Int) + extends MapOutputTrackerMessage +private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage + +private[spark] sealed trait MapOutputTrackerMasterMessage +private[spark] case class GetMapOutputMessage(shuffleId: Int, + context: RpcCallContext) extends MapOutputTrackerMasterMessage +private[spark] case class GetMapAndMergeOutputMessage(shuffleId: Int, + context: RpcCallContext) extends MapOutputTrackerMasterMessage +private[spark] case class GetShufflePushMergersMessage(shuffleId: Int, + context: RpcCallContext) extends MapOutputTrackerMasterMessage +private[spark] case class MapSizesByExecutorId( + iter: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], enableBatchFetch: Boolean) + +/** RpcEndpoint class for MapOutputTrackerMaster */ +private[spark] class MapOutputTrackerMasterEndpoint( + override val rpcEnv: RpcEnv, tracker: MapOutputTrackerMaster, conf: SparkConf) + extends RpcEndpoint with Logging { + + logDebug("init") // force eager creation of logger + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case GetMapOutputStatuses(shuffleId: Int) => + val hostPort = context.senderAddress.hostPort + logInfo(s"Asked to send map output locations for shuffle $shuffleId to $hostPort") + tracker.post(GetMapOutputMessage(shuffleId, context)) + + case GetMapAndMergeResultStatuses(shuffleId: Int) => + val hostPort = context.senderAddress.hostPort + logInfo(s"Asked to send map/merge result locations for shuffle $shuffleId to $hostPort") + tracker.post(GetMapAndMergeOutputMessage(shuffleId, context)) + + case GetShufflePushMergerLocations(shuffleId: Int) => + logInfo(s"Asked to send shuffle push merger locations for shuffle" + + s" $shuffleId to ${context.senderAddress.hostPort}") + tracker.post(GetShufflePushMergersMessage(shuffleId, context)) + + case StopMapOutputTracker => + logInfo("MapOutputTrackerMasterEndpoint stopped!") + context.reply(true) + stop() + } +} + +/** + * Class that keeps track of the location of the map output of a stage. This is abstract because the + * driver and executor have different versions of the MapOutputTracker. In principle the driver- + * and executor-side classes don't need to share a common base class; the current shared base class + * is maintained primarily for backwards-compatibility in order to avoid having to update existing + * test code. + */ +private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging { + /** Set to the MapOutputTrackerMasterEndpoint living on the driver. */ + var trackerEndpoint: RpcEndpointRef = _ + + /** + * The driver-side counter is incremented every time that a map output is lost. This value is sent + * to executors as part of tasks, where executors compare the new epoch number to the highest + * epoch number that they received in the past. If the new epoch number is higher then executors + * will clear their local caches of map output statuses and will re-fetch (possibly updated) + * statuses from the driver. 
+ */ + protected var epoch: Long = 0 + protected val epochLock = new AnyRef + + /** + * Send a message to the trackerEndpoint and get its result within a default timeout, or + * throw a SparkException if this fails. + */ + protected def askTracker[T: ClassTag](message: Any): T = { + try { + trackerEndpoint.askSync[T](message) + } catch { + case e: Exception => + logError("Error communicating with MapOutputTracker", e) + throw new SparkException("Error communicating with MapOutputTracker", e) + } + } + + /** Send a one-way message to the trackerEndpoint, to which we expect it to reply with true. */ + protected def sendTracker(message: Any): Unit = { + val response = askTracker[Boolean](message) + if (response != true) { + throw new SparkException( + "Error reply received from MapOutputTracker. Expecting true, got " + response.toString) + } + } + + // For testing + def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int) + : Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { + getMapSizesByExecutorId(shuffleId, 0, Int.MaxValue, reduceId, reduceId + 1) + } + + // For testing + def getPushBasedShuffleMapSizesByExecutorId(shuffleId: Int, reduceId: Int) + : MapSizesByExecutorId = { + getPushBasedShuffleMapSizesByExecutorId(shuffleId, 0, Int.MaxValue, reduceId, reduceId + 1) + } + + /** + * Called from executors to get the server URIs and output sizes for each shuffle block that + * needs to be read from a given range of map output partitions (startPartition is included but + * endPartition is excluded from the range) within a range of mappers (startMapIndex is included + * but endMapIndex is excluded) when push based shuffle is not enabled for the specific shuffle + * dependency. If endMapIndex=Int.MaxValue, the actual endMapIndex will be changed to the length + * of total map outputs. + * + * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, + * and the second item is a sequence of (shuffle block id, shuffle block size, map index) + * tuples describing the shuffle blocks that are stored at that block manager. + * Note that zero-sized blocks are excluded in the result. + */ + def getMapSizesByExecutorId( + shuffleId: Int, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] + + /** + * Called from executors to get the server URIs and output sizes for each shuffle block that + * needs to be read from a given range of map output partitions (startPartition is included but + * endPartition is excluded from the range) within a range of mappers (startMapIndex is included + * but endMapIndex is excluded) when push based shuffle is enabled for the specific shuffle + * dependency. If endMapIndex=Int.MaxValue, the actual endMapIndex will be changed to the length + * of total map outputs. + * + * @return A case class object which includes two attributes. The first attribute is a sequence + * of 2-item tuples, where the first item in the tuple is a BlockManagerId, and the + * second item is a sequence of (shuffle block id, shuffle block size, map index) tuples + * tuples describing the shuffle blocks that are stored at that block manager. Note that + * zero-sized blocks are excluded in the result. The second attribute is a boolean flag, + * indicating whether batch fetch can be enabled. 
+ */ + def getPushBasedShuffleMapSizesByExecutorId( + shuffleId: Int, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int): MapSizesByExecutorId + + /** + * Called from executors upon fetch failure on an entire merged shuffle reduce partition. + * Such failures can happen if the shuffle client fails to fetch the metadata for the given + * merged shuffle partition. This method is to get the server URIs and output sizes for each + * shuffle block that is merged in the specified merged shuffle block so fetch failure on a + * merged shuffle block can fall back to fetching the unmerged blocks. + * + * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, + * and the second item is a sequence of (shuffle block ID, shuffle block size, map index) + * tuples describing the shuffle blocks that are stored at that block manager. + */ + def getMapSizesForMergeResult( + shuffleId: Int, + partitionId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] + + /** + * Called from executors upon fetch failure on a merged shuffle reduce partition chunk. This is + * to get the server URIs and output sizes for each shuffle block that is merged in the specified + * merged shuffle partition chunk so fetch failure on a merged shuffle block chunk can fall back + * to fetching the unmerged blocks. + * + * chunkBitMap tracks the mapIds which are part of the current merged chunk, this way if there is + * a fetch failure on the merged chunk, it can fallback to fetching the corresponding original + * blocks part of this merged chunk. + * + * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, + * and the second item is a sequence of (shuffle block ID, shuffle block size, map index) + * tuples describing the shuffle blocks that are stored at that block manager. + */ + def getMapSizesForMergeResult( + shuffleId: Int, + partitionId: Int, + chunkBitmap: RoaringBitmap): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] + + /** + * Called from executors whenever a task with push based shuffle is enabled doesn't have shuffle + * mergers available. This typically happens when the initial stages doesn't have enough shuffle + * mergers available since very few executors got registered. This is on a best effort basis, + * if there is not enough shuffle mergers available for this stage then an empty sequence would + * be returned indicating the task to avoid shuffle push. + * @param shuffleId + */ + def getShufflePushMergerLocations(shuffleId: Int): Seq[BlockManagerId] + + /** + * Deletes map output status information for the specified shuffle stage. + */ + def unregisterShuffle(shuffleId: Int): Unit + + def stop(): Unit = {} +} + +/** + * Driver-side class that keeps track of the location of the map output of a stage. + * + * The DAGScheduler uses this class to (de)register map output statuses and to look up statistics + * for performing locality-aware reduce task scheduling. + * + * ShuffleMapStage uses this class for tracking available / missing outputs in order to determine + * which tasks need to be run. 
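+ *
+ * Compared with Spark's built-in MapOutputTrackerMaster, this copy additionally weighs candidate
+ * reduce-task locations by a rough estimate of the network transfer cost in
+ * getLocationsWithLargestOutputs, roughly:
+ *
+ * {{{
+ *   // same rack, different host
+ *   cost = 2 * (size / bandNS + latencyNS)
+ *   // different racks
+ *   cost = 2 * (size / bandNS + latencyNS + size / bandSS + latencySS)
+ * }}}
+ *
+ * and keeps only the cheaper half of the candidate locations.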
+ */ +private[spark] class OmniMapOutputTrackerMaster( + conf: SparkConf, + private[spark] val broadcastManager: BroadcastManager, + private[spark] val isLocal: Boolean) + extends MapOutputTracker(conf) { + + // The size at which we use Broadcast to send the map output statuses to the executors + private val minSizeForBroadcast = conf.get(SHUFFLE_MAPOUTPUT_MIN_SIZE_FOR_BROADCAST).toInt + + /** Whether to compute locality preferences for reduce tasks */ + private val shuffleLocalityEnabled = conf.get(SHUFFLE_REDUCE_LOCALITY_ENABLE) + + // Number of map and reduce tasks above which we do not assign preferred locations based on map + // output sizes. We limit the size of jobs for which assign preferred locations as computing the + // top locations by size becomes expensive. + private val SHUFFLE_PREF_MAP_THRESHOLD = 3000 + // NOTE: This should be less than 2000 as we use HighlyCompressedMapStatus beyond that + private val SHUFFLE_PREF_REDUCE_THRESHOLD = 3000 + + // Fraction of total map output that must be at a location for it to considered as a preferred + // location for a reduce task. Making this larger will focus on fewer locations where most data + // can be read locally, but may lead to more delay in scheduling if those locations are busy. + private val REDUCER_PREF_LOCS_FRACTION = 0.2 + + // HashMap for storing shuffleStatuses in the driver. + // Statuses are dropped only by explicit de-registering. + // Exposed for testing + val shuffleStatuses = new ConcurrentHashMap[Int, ShuffleStatus]().asScala + + private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf) + + // requests for MapOutputTrackerMasterMessages + private val mapOutputTrackerMasterMessages = + new LinkedBlockingQueue[MapOutputTrackerMasterMessage] + + private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(conf, isDriver = true) + + // Thread pool used for handling map output status requests. This is a separate thread pool + // to ensure we don't block the normal dispatcher threads. + private val threadpool: ThreadPoolExecutor = { + val numThreads = conf.get(SHUFFLE_MAPOUTPUT_DISPATCHER_NUM_THREADS) + val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "map-output-dispatcher") + for (i <- 0 until numThreads) { + pool.execute(new MessageLoop) + } + pool + } + + // Make sure that we aren't going to exceed the max RPC message size by making sure + // we use broadcast to send large map output statuses. + if (minSizeForBroadcast > maxRpcMessageSize) { + val msg = s"${SHUFFLE_MAPOUTPUT_MIN_SIZE_FOR_BROADCAST.key} ($minSizeForBroadcast bytes) " + + s"must be <= spark.rpc.message.maxSize ($maxRpcMessageSize bytes) to prevent sending an " + + "rpc message that is too large." + logError(msg) + throw new IllegalArgumentException(msg) + } + + def post(message: MapOutputTrackerMasterMessage): Unit = { + mapOutputTrackerMasterMessages.offer(message) + } + + /** Message loop used for dispatching messages. */ + private class MessageLoop extends Runnable { + private def handleStatusMessage( + shuffleId: Int, + context: RpcCallContext, + needMergeOutput: Boolean): Unit = { + val hostPort = context.senderAddress.hostPort + val shuffleStatus = shuffleStatuses.get(shuffleId).head + logDebug(s"Handling request to send ${if (needMergeOutput) "map" else "map/merge"}" + + s" output locations for shuffle $shuffleId to $hostPort") + if (needMergeOutput) { + context.reply( + shuffleStatus. 
+ serializedMapAndMergeStatus(broadcastManager, isLocal, minSizeForBroadcast, conf)) + } else { + context.reply( + shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minSizeForBroadcast, conf)) + } + } + + override def run(): Unit = { + try { + while (true) { + try { + val data = mapOutputTrackerMasterMessages.take() + if (data == PoisonPill) { + // Put PoisonPill back so that other MessageLoops can see it. + mapOutputTrackerMasterMessages.offer(PoisonPill) + return + } + + data match { + case GetMapOutputMessage(shuffleId, context) => + handleStatusMessage(shuffleId, context, false) + case GetMapAndMergeOutputMessage(shuffleId, context) => + handleStatusMessage(shuffleId, context, true) + case GetShufflePushMergersMessage(shuffleId, context) => + logDebug(s"Handling request to send shuffle push merger locations for shuffle" + + s" $shuffleId to ${context.senderAddress.hostPort}") + context.reply(shuffleStatuses.get(shuffleId).map(_.getShufflePushMergerLocations) + .getOrElse(Seq.empty[BlockManagerId])) + } + } catch { + case NonFatal(e) => logError(e.getMessage, e) + } + } + } catch { + case ie: InterruptedException => // exit + } + } + } + + /** A poison endpoint that indicates MessageLoop should exit its message loop. */ + private val PoisonPill = GetMapOutputMessage(-99, null) + + // Used only in unit tests. + private[spark] def getNumCachedSerializedBroadcast: Int = { + shuffleStatuses.valuesIterator.count(_.hasCachedSerializedBroadcast) + } + + def registerShuffle(shuffleId: Int, numMaps: Int, numReduces: Int): Unit = { + if (pushBasedShuffleEnabled) { + if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps, numReduces)).isDefined) { + throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") + } + } else { + if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps)).isDefined) { + throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") + } + } + } + + def updateMapOutput(shuffleId: Int, mapId: Long, bmAddress: BlockManagerId): Unit = { + shuffleStatuses.get(shuffleId) match { + case Some(shuffleStatus) => + shuffleStatus.updateMapOutput(mapId, bmAddress) + case None => + logError(s"Asked to update map output for unknown shuffle ${shuffleId}") + } + } + + def registerMapOutput(shuffleId: Int, mapIndex: Int, status: MapStatus): Unit = { + shuffleStatuses(shuffleId).addMapOutput(mapIndex, status) + } + + /** Unregister map output information of the given shuffle, mapper and block manager */ + def unregisterMapOutput(shuffleId: Int, mapIndex: Int, bmAddress: BlockManagerId): Unit = { + shuffleStatuses.get(shuffleId) match { + case Some(shuffleStatus) => + shuffleStatus.removeMapOutput(mapIndex, bmAddress) + incrementEpoch() + case None => + throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID") + } + } + + /** Unregister all map and merge output information of the given shuffle. 
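+   * Also clears any registered shuffle push merger locations and increments the epoch so that
+   * executors drop their cached statuses.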
*/ + def unregisterAllMapAndMergeOutput(shuffleId: Int): Unit = { + shuffleStatuses.get(shuffleId) match { + case Some(shuffleStatus) => + shuffleStatus.removeOutputsByFilter(x => true) + shuffleStatus.removeMergeResultsByFilter(x => true) + shuffleStatus.removeShuffleMergerLocations() + incrementEpoch() + case None => + throw new SparkException( + s"unregisterAllMapAndMergeOutput called for nonexistent shuffle ID $shuffleId.") + } + } + + def registerMergeResult(shuffleId: Int, reduceId: Int, status: MergeStatus): Unit = { + shuffleStatuses(shuffleId).addMergeResult(reduceId, status) + } + + def registerMergeResults(shuffleId: Int, statuses: Seq[(Int, MergeStatus)]): Unit = { + statuses.foreach { + case (reduceId, status) => registerMergeResult(shuffleId, reduceId, status) + } + } + + def registerShufflePushMergerLocations( + shuffleId: Int, + shuffleMergers: Seq[BlockManagerId]): Unit = { + shuffleStatuses(shuffleId).registerShuffleMergerLocations(shuffleMergers) + } + + /** + * Unregisters a merge result corresponding to the reduceId if present. If the optional mapIndex + * is specified, it will only unregister the merge result if the mapIndex is part of that merge + * result. + * + * @param shuffleId the shuffleId. + * @param reduceId the reduceId. + * @param bmAddress block manager address. + * @param mapIndex the optional mapIndex which should be checked to see it was part of the + * merge result. + */ + def unregisterMergeResult( + shuffleId: Int, + reduceId: Int, + bmAddress: BlockManagerId, + mapIndex: Option[Int] = None): Unit = { + shuffleStatuses.get(shuffleId) match { + case Some(shuffleStatus) => + val mergeStatus = shuffleStatus.mergeStatuses(reduceId) + if (mergeStatus != null && + (mapIndex.isEmpty || mergeStatus.tracker.contains(mapIndex.get))) { + shuffleStatus.removeMergeResult(reduceId, bmAddress) + incrementEpoch() + } + case None => + throw new SparkException("unregisterMergeResult called for nonexistent shuffle ID") + } + } + + def unregisterAllMergeResult(shuffleId: Int): Unit = { + shuffleStatuses.get(shuffleId) match { + case Some(shuffleStatus) => + shuffleStatus.removeMergeResultsByFilter(x => true) + incrementEpoch() + case None => + throw new SparkException( + s"unregisterAllMergeResult called for nonexistent shuffle ID $shuffleId.") + } + } + + /** Unregister shuffle data */ + def unregisterShuffle(shuffleId: Int): Unit = { + shuffleStatuses.remove(shuffleId).foreach { shuffleStatus => + // SPARK-39553: Add protection for Scala 2.13 due to https://github.com/scala/bug/issues/12613 + // We should revert this if Scala 2.13 solves this issue. + if (shuffleStatus != null) { + shuffleStatus.invalidateSerializedMapOutputStatusCache() + shuffleStatus.invalidateSerializedMergeOutputStatusCache() + } + } + } + + /** + * Removes all shuffle outputs associated with this host. Note that this will also remove + * outputs which are served by an external shuffle server (if one exists). + */ + def removeOutputsOnHost(host: String): Unit = { + shuffleStatuses.valuesIterator.foreach { _.removeOutputsOnHost(host) } + incrementEpoch() + } + + /** + * Removes all shuffle outputs associated with this executor. Note that this will also remove + * outputs which are served by an external shuffle server (if one exists), as they are still + * registered with this execId. 
+ */ + def removeOutputsOnExecutor(execId: String): Unit = { + shuffleStatuses.valuesIterator.foreach { _.removeOutputsOnExecutor(execId) } + incrementEpoch() + } + + /** Check if the given shuffle is being tracked */ + def containsShuffle(shuffleId: Int): Boolean = shuffleStatuses.contains(shuffleId) + + def getNumAvailableOutputs(shuffleId: Int): Int = { + shuffleStatuses.get(shuffleId).map(_.numAvailableMapOutputs).getOrElse(0) + } + + /** VisibleForTest. Invoked in test only. */ + private[spark] def getNumAvailableMergeResults(shuffleId: Int): Int = { + shuffleStatuses.get(shuffleId).map(_.numAvailableMergeResults).getOrElse(0) + } + + /** + * Returns the sequence of partition ids that are missing (i.e. needs to be computed), or None + * if the MapOutputTrackerMaster doesn't know about this shuffle. + */ + def findMissingPartitions(shuffleId: Int): Option[Seq[Int]] = { + shuffleStatuses.get(shuffleId).map(_.findMissingPartitions()) + } + + /** + * Grouped function of Range, this is to avoid traverse of all elements of Range using + * IterableLike's grouped function. + */ + def rangeGrouped(range: Range, size: Int): Seq[Range] = { + val start = range.start + val step = range.step + val end = range.end + for (i <- start.until(end, size * step)) yield { + i.until(i + size * step, step) + } + } + + /** + * To equally divide n elements into m buckets, basically each bucket should have n/m elements, + * for the remaining n%m elements, add one more element to the first n%m buckets each. + */ + def equallyDivide(numElements: Int, numBuckets: Int): Seq[Seq[Int]] = { + val elementsPerBucket = numElements / numBuckets + val remaining = numElements % numBuckets + val splitPoint = (elementsPerBucket + 1) * remaining + if (elementsPerBucket == 0) { + rangeGrouped(0.until(splitPoint), elementsPerBucket + 1) + } else { + rangeGrouped(0.until(splitPoint), elementsPerBucket + 1) ++ + rangeGrouped(splitPoint.until(numElements), elementsPerBucket) + } + } + + /** + * Return statistics about all of the outputs for a given shuffle. + */ + def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = { + shuffleStatuses(dep.shuffleId).withMapStatuses { statuses => + val totalSizes = new Array[Long](dep.partitioner.numPartitions) + val parallelAggThreshold = conf.get( + SHUFFLE_MAP_OUTPUT_PARALLEL_AGGREGATION_THRESHOLD) + val parallelism = math.min( + Runtime.getRuntime.availableProcessors(), + statuses.length.toLong * totalSizes.length / parallelAggThreshold + 1).toInt + if (parallelism <= 1) { + statuses.filter(_ != null).foreach { s => + for (i <- 0 until totalSizes.length) { + totalSizes(i) += s.getSizeForBlock(i) + } + } + } else { + val threadPool = ThreadUtils.newDaemonFixedThreadPool(parallelism, "map-output-aggregate") + try { + implicit val executionContext = ExecutionContext.fromExecutor(threadPool) + val mapStatusSubmitTasks = equallyDivide(totalSizes.length, parallelism).map { + reduceIds => Future { + statuses.filter(_ != null).foreach { s => + reduceIds.foreach(i => totalSizes(i) += s.getSizeForBlock(i)) + } + } + } + ThreadUtils.awaitResult(Future.sequence(mapStatusSubmitTasks), Duration.Inf) + } finally { + threadPool.shutdown() + } + } + new MapOutputStatistics(dep.shuffleId, totalSizes) + } + } + + /** + * Return the preferred hosts on which to run the given map output partition in a given shuffle, + * i.e. the nodes that the most outputs for that partition are on. 
If the map output is
+   * pre-merged, then return the node where the merged block is located if the merge ratio is
+   * above the threshold.
+   *
+   * @param dep shuffle dependency object
+   * @param partitionId map output partition that we want to read
+   * @return a sequence of host names
+   */
+  def getPreferredLocationsForShuffle(dep: ShuffleDependency[_, _, _], partitionId: Int)
+      : Seq[String] = {
+    val shuffleStatus = shuffleStatuses.get(dep.shuffleId).orNull
+    if (shuffleStatus != null) {
+      // Check if the map output is pre-merged and if the merge ratio is above the threshold.
+      // If so, the location of the merged block is the preferred location.
+      val preferredLoc = if (pushBasedShuffleEnabled) {
+        shuffleStatus.withMergeStatuses { statuses =>
+          val status = statuses(partitionId)
+          val numMaps = dep.rdd.partitions.length
+          if (status != null && status.getNumMissingMapOutputs(numMaps).toDouble / numMaps
+            <= (1 - REDUCER_PREF_LOCS_FRACTION)) {
+            Seq(status.location.host)
+          } else {
+            Nil
+          }
+        }
+      } else {
+        Nil
+      }
+      if (preferredLoc.nonEmpty) {
+        preferredLoc
+      } else {
+        if (shuffleLocalityEnabled && dep.rdd.partitions.length < SHUFFLE_PREF_MAP_THRESHOLD &&
+          dep.partitioner.numPartitions < SHUFFLE_PREF_REDUCE_THRESHOLD) {
+          val blockManagerIds = getLocationsWithLargestOutputs(dep.shuffleId, partitionId,
+            dep.partitioner.numPartitions, REDUCER_PREF_LOCS_FRACTION)
+          if (blockManagerIds.nonEmpty) {
+            blockManagerIds.get.map(_.host)
+          } else {
+            Nil
+          }
+        } else {
+          Nil
+        }
+      }
+    } else {
+      Nil
+    }
+  }
+
+  /**
+   * Return a list of locations that each have fraction of map output greater than the specified
+   * threshold.
+   *
+   * @param shuffleId id of the shuffle
+   * @param reducerId id of the reduce task
+   * @param numReducers total number of reducers in the shuffle
+   * @param fractionThreshold fraction of total map output size that a location must have
+   *                          for it to be considered large.
+   */
+  def getLocationsWithLargestOutputs(
+      shuffleId: Int,
+      reducerId: Int,
+      numReducers: Int,
+      fractionThreshold: Double)
+    : Option[Array[BlockManagerId]] = {
+    val shuffleStatus = shuffleStatuses.get(shuffleId).orNull
+    if (shuffleStatus != null) {
+      shuffleStatus.withMapStatuses { statuses =>
+        if (statuses.nonEmpty) {
+          // HashMap to add up sizes of all blocks at the same location
+          val locs = new HashMap[BlockManagerId, Long]
+          var totalOutputSize = 0L
+          var mapIdx = 0
+          while (mapIdx < statuses.length) {
+            val status = statuses(mapIdx)
+            // status may be null here if we are called between registerShuffle, which creates an
+            // array with null entries for each output, and registerMapOutputs, which populates it
+            // with valid status entries. This is possible if one thread schedules a job which
+            // depends on an RDD which is currently being computed by another thread.
+            if (status != null) {
+              val blockSize = status.getSizeForBlock(reducerId)
+              if (blockSize > 0) {
+                locs(status.location) = locs.getOrElse(status.location, 0L) + blockSize
+                totalOutputSize += blockSize
+              }
+            }
+            mapIdx = mapIdx + 1
+          }
+          val bandNS: Double = 10.0
+          val latencyNS: Double = 10.0
+          val bandSS: Double = 40.0
+          val latencySS: Double = 40.0
+          val netTransCost = new HashMap[BlockManagerId, Double]().withDefaultValue(0.0)
+          locs.foreach{ case (outerLoc, esize) =>
+            locs.foreach{ case (innerLoc, isize) =>
+              if ((outerLoc.topologyInfo != innerLoc.topologyInfo)
+                && !((outerLoc.topologyInfo.isEmpty && innerLoc.topologyInfo.exists(_ == ""))
+                || (innerLoc.topologyInfo.isEmpty && outerLoc.topologyInfo.exists(_ == "")))) {
+                // Different racks, excluding the case where one topologyInfo is "" and the other is None
+                val cost = 2 * (isize / bandNS + latencyNS + isize / bandSS + latencySS)
+                netTransCost(outerLoc) += cost
+              }
+              else if (outerLoc.host != innerLoc.host) {
+                // Same rack, different hosts
+                val cost = 2 * (isize / bandNS + latencyNS)
+                netTransCost(outerLoc) += cost
+              }
+            }
+          }
+          val sortedPairs = netTransCost.toSeq.sortBy(_._2)
+          val numToTake = (sortedPairs.length / 2)
+          val selectedPairs = sortedPairs.take(numToTake)
+          val topLocs = HashMap(selectedPairs: _*)
+          // Return the selected lower-cost locations, if any
+          if (topLocs.nonEmpty) {
+            return Some(topLocs.keys.toArray)
+          }
+        }
+      }
+    }
+    None
+  }
+
+  /**
+   * Return the locations where the Mappers ran. The locations each includes both a host and an
+   * executor id on that host.
+   *
+   * @param dep shuffle dependency object
+   * @param startMapIndex the start map index
+   * @param endMapIndex the end map index (exclusive)
+   * @return a sequence of locations where task runs.
+   */
+  def getMapLocation(
+      dep: ShuffleDependency[_, _, _],
+      startMapIndex: Int,
+      endMapIndex: Int): Seq[String] =
+  {
+    val shuffleStatus = shuffleStatuses.get(dep.shuffleId).orNull
+    if (shuffleStatus != null) {
+      shuffleStatus.withMapStatuses { statuses =>
+        if (startMapIndex < endMapIndex &&
+          (startMapIndex >= 0 && endMapIndex <= statuses.length)) {
+          val statusesPicked = statuses.slice(startMapIndex, endMapIndex).filter(_ != null)
+          statusesPicked.map(_.location.host).toSeq
+        } else {
+          Nil
+        }
+      }
+    } else {
+      Nil
+    }
+  }
+
+  def incrementEpoch(): Unit = {
+    epochLock.synchronized {
+      epoch += 1
+      logDebug("Increasing epoch to " + epoch)
+    }
+  }
+
+  /** Called to get current epoch number. */
+  def getEpoch: Long = {
+    epochLock.synchronized {
+      return epoch
+    }
+  }
+
+  // This method is only called in local-mode.
+  override def getMapSizesByExecutorId(
+      shuffleId: Int,
+      startMapIndex: Int,
+      endMapIndex: Int,
+      startPartition: Int,
+      endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+    val mapSizesByExecutorId = getPushBasedShuffleMapSizesByExecutorId(
+      shuffleId, startMapIndex, endMapIndex, startPartition, endPartition)
+    assert(mapSizesByExecutorId.enableBatchFetch == true)
+    mapSizesByExecutorId.iter
+  }
+
+  // This method is only called in local-mode.
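+  // In local mode the master serves these lookups directly from its in-memory shuffleStatuses
+  // instead of going through the RPC path used by executors.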
+ override def getPushBasedShuffleMapSizesByExecutorId( + shuffleId: Int, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int): MapSizesByExecutorId = { + logDebug(s"Fetching outputs for shuffle $shuffleId") + shuffleStatuses.get(shuffleId) match { + case Some(shuffleStatus) => + shuffleStatus.withMapStatuses { statuses => + val actualEndMapIndex = if (endMapIndex == Int.MaxValue) statuses.length else endMapIndex + logDebug(s"Convert map statuses for shuffle $shuffleId, " + + s"mappers $startMapIndex-$actualEndMapIndex, partitions $startPartition-$endPartition") + MapOutputTracker.convertMapStatuses( + shuffleId, startPartition, endPartition, statuses, startMapIndex, actualEndMapIndex) + } + case None => + MapSizesByExecutorId(Iterator.empty, true) + } + } + + // This method is only called in local-mode. Since push based shuffle won't be + // enabled in local-mode, this method returns empty list. + override def getMapSizesForMergeResult( + shuffleId: Int, + partitionId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { + Seq.empty.iterator + } + + // This method is only called in local-mode. Since push based shuffle won't be + // enabled in local-mode, this method returns empty list. + override def getMapSizesForMergeResult( + shuffleId: Int, + partitionId: Int, + chunkTracker: RoaringBitmap): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { + Seq.empty.iterator + } + + // This method is only called in local-mode. + override def getShufflePushMergerLocations(shuffleId: Int): Seq[BlockManagerId] = { + shuffleStatuses(shuffleId).getShufflePushMergerLocations + } + + override def stop(): Unit = { + mapOutputTrackerMasterMessages.offer(PoisonPill) + threadpool.shutdown() + try { + sendTracker(StopMapOutputTracker) + } catch { + case e: SparkException => + logError("Could not tell tracker we are stopping.", e) + } + trackerEndpoint = null + shuffleStatuses.clear() + } +} + +/** + * Executor-side client for fetching map output info from the driver's MapOutputTrackerMaster. + * Note that this is not used in local-mode; instead, local-mode Executors access the + * MapOutputTrackerMaster directly (which is possible because the master and worker share a common + * superclass). + */ +private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) { + + val mapStatuses: Map[Int, Array[MapStatus]] = + new ConcurrentHashMap[Int, Array[MapStatus]]().asScala + + val mergeStatuses: Map[Int, Array[MergeStatus]] = + new ConcurrentHashMap[Int, Array[MergeStatus]]().asScala + + // This must be lazy to ensure that it is initialized when the first task is run and not at + // executor startup time. At startup time, user-added libraries may not have been + // downloaded to the executor, causing `isPushBasedShuffleEnabled` to fail when it tries to + // instantiate a serializer. See the followup to SPARK-36705 for more details. + private lazy val fetchMergeResult = Utils.isPushBasedShuffleEnabled(conf, isDriver = false) + + /** + * [[shufflePushMergerLocations]] tracks shuffle push merger locations for the latest + * shuffle execution + * + * Exposed for testing + */ + val shufflePushMergerLocations = new ConcurrentHashMap[Int, Seq[BlockManagerId]]().asScala + + /** + * A [[KeyLock]] whose key is a shuffle id to ensure there is only one thread fetching + * the same shuffle block. 
+ */ + private val fetchingLock = new KeyLock[Int] + + override def getMapSizesByExecutorId( + shuffleId: Int, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { + val mapSizesByExecutorId = getMapSizesByExecutorIdImpl( + shuffleId, startMapIndex, endMapIndex, startPartition, endPartition, useMergeResult = false) + assert(mapSizesByExecutorId.enableBatchFetch == true) + mapSizesByExecutorId.iter + } + + override def getPushBasedShuffleMapSizesByExecutorId( + shuffleId: Int, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int): MapSizesByExecutorId = { + getMapSizesByExecutorIdImpl( + shuffleId, startMapIndex, endMapIndex, startPartition, endPartition, useMergeResult = true) + } + + private def getMapSizesByExecutorIdImpl( + shuffleId: Int, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + useMergeResult: Boolean): MapSizesByExecutorId = { + logDebug(s"Fetching outputs for shuffle $shuffleId") + val (mapOutputStatuses, mergedOutputStatuses) = getStatuses(shuffleId, conf, + // enableBatchFetch can be set to false during stage retry when the + // shuffleDependency.isShuffleMergeFinalizedMarked is set to false, and Driver + // has already collected the mergedStatus for its shuffle dependency. + // In this case, boolean check helps to ensure that the unnecessary + // mergeStatus won't be fetched, thus mergedOutputStatuses won't be + // passed to convertMapStatuses. See details in [SPARK-37023]. + if (useMergeResult) fetchMergeResult else false) + try { + val actualEndMapIndex = + if (endMapIndex == Int.MaxValue) mapOutputStatuses.length else endMapIndex + logDebug(s"Convert map statuses for shuffle $shuffleId, " + + s"mappers $startMapIndex-$actualEndMapIndex, partitions $startPartition-$endPartition") + MapOutputTracker.convertMapStatuses( + shuffleId, startPartition, endPartition, mapOutputStatuses, startMapIndex, + actualEndMapIndex, Option(mergedOutputStatuses)) + } catch { + case e: MetadataFetchFailedException => + // We experienced a fetch failure so our mapStatuses cache is outdated; clear it: + mapStatuses.clear() + mergeStatuses.clear() + throw e + } + } + + override def getMapSizesForMergeResult( + shuffleId: Int, + partitionId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { + logDebug(s"Fetching backup outputs for shuffle $shuffleId, partition $partitionId") + // Fetch the map statuses and merge statuses again since they might have already been + // cleared by another task running in the same executor. + val (mapOutputStatuses, mergeResultStatuses) = getStatuses(shuffleId, conf, fetchMergeResult) + try { + val mergeStatus = mergeResultStatuses(partitionId) + // If the original MergeStatus is no longer available, we cannot identify the list of + // unmerged blocks to fetch in this case. Throw MetadataFetchFailedException in this case. 
+ MapOutputTracker.validateStatus(mergeStatus, shuffleId, partitionId) + // Use the MergeStatus's partition level bitmap since we are doing partition level fallback + MapOutputTracker.getMapStatusesForMergeStatus(shuffleId, partitionId, + mapOutputStatuses, mergeStatus.tracker) + } catch { + // We experienced a fetch failure so our mapStatuses cache is outdated; clear it + case e: MetadataFetchFailedException => + mapStatuses.clear() + mergeStatuses.clear() + throw e + } + } + + override def getMapSizesForMergeResult( + shuffleId: Int, + partitionId: Int, + chunkTracker: RoaringBitmap): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { + logDebug(s"Fetching backup outputs for shuffle $shuffleId, partition $partitionId") + // Fetch the map statuses and merge statuses again since they might have already been + // cleared by another task running in the same executor. + val (mapOutputStatuses, _) = getStatuses(shuffleId, conf, fetchMergeResult) + try { + MapOutputTracker.getMapStatusesForMergeStatus(shuffleId, partitionId, mapOutputStatuses, + chunkTracker) + } catch { + // We experienced a fetch failure so our mapStatuses cache is outdated; clear it: + case e: MetadataFetchFailedException => + mapStatuses.clear() + mergeStatuses.clear() + throw e + } + } + + override def getShufflePushMergerLocations(shuffleId: Int): Seq[BlockManagerId] = { + shufflePushMergerLocations.getOrElse(shuffleId, getMergerLocations(shuffleId)) + } + + private def getMergerLocations(shuffleId: Int): Seq[BlockManagerId] = { + fetchingLock.withLock(shuffleId) { + var fetchedMergers = shufflePushMergerLocations.get(shuffleId).orNull + if (null == fetchedMergers) { + fetchedMergers = + askTracker[Seq[BlockManagerId]](GetShufflePushMergerLocations(shuffleId)) + if (fetchedMergers.nonEmpty) { + shufflePushMergerLocations(shuffleId) = fetchedMergers + } else { + fetchedMergers = Seq.empty[BlockManagerId] + } + } + fetchedMergers + } + } + + /** + * Get or fetch the array of MapStatuses and MergeStatuses if push based shuffle enabled + * for a given shuffle ID. NOTE: clients MUST synchronize + * on this array when reading it, because on the driver, we may be changing it in place. + * + * (It would be nice to remove this restriction in the future.) 
+ */ + private def getStatuses( + shuffleId: Int, + conf: SparkConf, + canFetchMergeResult: Boolean): (Array[MapStatus], Array[MergeStatus]) = { + if (canFetchMergeResult) { + val mapOutputStatuses = mapStatuses.get(shuffleId).orNull + val mergeOutputStatuses = mergeStatuses.get(shuffleId).orNull + + if (mapOutputStatuses == null || mergeOutputStatuses == null) { + logInfo("Don't have map/merge outputs for shuffle " + shuffleId + ", fetching them") + val startTimeNs = System.nanoTime() + fetchingLock.withLock(shuffleId) { + var fetchedMapStatuses = mapStatuses.get(shuffleId).orNull + var fetchedMergeStatuses = mergeStatuses.get(shuffleId).orNull + if (fetchedMapStatuses == null || fetchedMergeStatuses == null) { + logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) + val fetchedBytes = + askTracker[(Array[Byte], Array[Byte])](GetMapAndMergeResultStatuses(shuffleId)) + try { + fetchedMapStatuses = + MapOutputTracker.deserializeOutputStatuses[MapStatus](fetchedBytes._1, conf) + fetchedMergeStatuses = + MapOutputTracker.deserializeOutputStatuses[MergeStatus](fetchedBytes._2, conf) + } catch { + case e: SparkException => + throw new MetadataFetchFailedException(shuffleId, -1, + s"Unable to deserialize broadcasted map/merge statuses" + + s" for shuffle $shuffleId: " + e.getCause) + } + logInfo("Got the map/merge output locations") + mapStatuses.put(shuffleId, fetchedMapStatuses) + mergeStatuses.put(shuffleId, fetchedMergeStatuses) + } + logDebug(s"Fetching map/merge output statuses for shuffle $shuffleId took " + + s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms") + (fetchedMapStatuses, fetchedMergeStatuses) + } + } else { + (mapOutputStatuses, mergeOutputStatuses) + } + } else { + val statuses = mapStatuses.get(shuffleId).orNull + if (statuses == null) { + logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") + val startTimeNs = System.nanoTime() + fetchingLock.withLock(shuffleId) { + var fetchedStatuses = mapStatuses.get(shuffleId).orNull + if (fetchedStatuses == null) { + logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) + val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId)) + try { + fetchedStatuses = + MapOutputTracker.deserializeOutputStatuses[MapStatus](fetchedBytes, conf) + } catch { + case e: SparkException => + throw new MetadataFetchFailedException(shuffleId, -1, + s"Unable to deserialize broadcasted map statuses for shuffle $shuffleId: " + + e.getCause) + } + logInfo("Got the map output locations") + mapStatuses.put(shuffleId, fetchedStatuses) + } + logDebug(s"Fetching map output statuses for shuffle $shuffleId took " + + s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms") + (fetchedStatuses, null) + } + } else { + (statuses, null) + } + } + } + + /** Unregister shuffle data. */ + def unregisterShuffle(shuffleId: Int): Unit = { + mapStatuses.remove(shuffleId) + mergeStatuses.remove(shuffleId) + shufflePushMergerLocations.remove(shuffleId) + } + + /** + * Called from executors to update the epoch number, potentially clearing old outputs + * because of a fetch failure. Each executor task calls this with the latest epoch + * number on the driver at the time it was created. 
+ */ + def updateEpoch(newEpoch: Long): Unit = { + epochLock.synchronized { + if (newEpoch > epoch) { + logInfo("Updating epoch to " + newEpoch + " and clearing cache") + epoch = newEpoch + mapStatuses.clear() + mergeStatuses.clear() + shufflePushMergerLocations.clear() + } + } + } +} + +private[spark] object MapOutputTracker extends Logging { + + val ENDPOINT_NAME = "MapOutputTracker" + private val DIRECT = 0 + private val BROADCAST = 1 + + val SHUFFLE_PUSH_MAP_ID = -1 + + // Serialize an array of map/merge output locations into an efficient byte format so that we can + // send it to reduce tasks. We do this by compressing the serialized bytes using Zstd. They will + // generally be pretty compressible because many outputs will be on the same hostname. + def serializeOutputStatuses[T <: ShuffleOutputStatus]( + statuses: Array[T], + broadcastManager: BroadcastManager, + isLocal: Boolean, + minBroadcastSize: Int, + conf: SparkConf): (Array[Byte], Broadcast[Array[Array[Byte]]]) = { + // ByteArrayOutputStream has the 2GB limit so use ChunkedByteBufferOutputStream instead + val out = new ChunkedByteBufferOutputStream(1024 * 1024, ByteBuffer.allocate) + out.write(DIRECT) + val codec = CompressionCodec.createCodec(conf, conf.get(MAP_STATUS_COMPRESSION_CODEC)) + val objOut = new ObjectOutputStream(codec.compressedOutputStream(out)) + Utils.tryWithSafeFinally { + // Since statuses can be modified in parallel, sync on it + statuses.synchronized { + objOut.writeObject(statuses) + } + } { + objOut.close() + } + val chunkedByteBuf = out.toChunkedByteBuffer + val arrSize = out.size + if (arrSize >= minBroadcastSize) { + // Use broadcast instead. + // Important arr(0) is the tag == DIRECT, ignore that while deserializing ! + // arr is a nested Array so that it can handle over 2GB serialized data + val arr = chunkedByteBuf.getChunks().map(_.array()) + val bcast = broadcastManager.newBroadcast(arr, isLocal) + // Using `org.apache.commons.io.output.ByteArrayOutputStream` instead of the standard one + // This implementation doesn't reallocate the whole memory block but allocates + // additional buffers. This way no buffers need to be garbage collected and + // the contents don't have to be copied to the new buffer. + val out = new ApacheByteArrayOutputStream() + out.write(BROADCAST) + val oos = new ObjectOutputStream(codec.compressedOutputStream(out)) + Utils.tryWithSafeFinally { + oos.writeObject(bcast) + } { + oos.close() + } + val outArr = out.toByteArray + logInfo("Broadcast outputstatuses size = " + outArr.length + ", actual size = " + arrSize) + (outArr, bcast) + } else { + (chunkedByteBuf.toArray, null) + } + } + + // Opposite of serializeOutputStatuses. 
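+  // Illustrative round trip of the two methods (a sketch, not part of this patch; `env` and
+  // the 512 KiB threshold below are hypothetical):
+  //   val (bytes, bcast) = MapOutputTracker.serializeOutputStatuses(
+  //     statuses, env.broadcastManager, isLocal = false, minBroadcastSize = 512 * 1024, conf)
+  //   val restored = MapOutputTracker.deserializeOutputStatuses[MapStatus](bytes, conf)
+  // `bcast` is non-null only when the serialized size reaches `minBroadcastSize` and the
+  // BROADCAST path is taken; otherwise `bytes` carries the DIRECT-tagged payload itself.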
+ def deserializeOutputStatuses[T <: ShuffleOutputStatus]( + bytes: Array[Byte], conf: SparkConf): Array[T] = { + assert (bytes.length > 0) + + def deserializeObject(in: InputStream): AnyRef = { + val codec = CompressionCodec.createCodec(conf, conf.get(MAP_STATUS_COMPRESSION_CODEC)) + // The ZStd codec is wrapped in a `BufferedInputStream` which avoids overhead excessive + // of JNI call while trying to decompress small amount of data for each element + // of `MapStatuses` + val objIn = new ObjectInputStream(codec.compressedInputStream(in)) + Utils.tryWithSafeFinally { + objIn.readObject() + } { + objIn.close() + } + } + + val in = new ByteArrayInputStream(bytes, 1, bytes.length - 1) + bytes(0) match { + case DIRECT => + deserializeObject(in).asInstanceOf[Array[T]] + case BROADCAST => + try { + // deserialize the Broadcast, pull .value array out of it, and then deserialize that + val bcast = deserializeObject(in).asInstanceOf[Broadcast[Array[Array[Byte]]]] + logInfo("Broadcast outputstatuses size = " + bytes.length + + ", actual size = " + bcast.value.foldLeft(0L)(_ + _.length)) + val bcastIn = new ChunkedByteBuffer(bcast.value.map(ByteBuffer.wrap)).toInputStream() + // Important - ignore the DIRECT tag ! Start from offset 1 + bcastIn.skip(1) + deserializeObject(bcastIn).asInstanceOf[Array[T]] + } catch { + case e: IOException => + logWarning("Exception encountered during deserializing broadcasted" + + " output statuses: ", e) + throw new SparkException("Unable to deserialize broadcasted" + + " output statuses", e) + } + case _ => throw new IllegalArgumentException("Unexpected byte tag = " + bytes(0)) + } + } + + /** + * Given an array of map statuses and a range of map output partitions, returns a sequence that, + * for each block manager ID, lists the shuffle block IDs and corresponding shuffle block sizes + * stored at that block manager. + * Note that empty blocks are filtered in the result. + * + * If push-based shuffle is enabled and an array of merge statuses is available, prioritize + * the locations of the merged shuffle partitions over unmerged shuffle blocks. + * + * If any of the statuses is null (indicating a missing location due to a failed mapper), + * throws a FetchFailedException. + * + * @param shuffleId Identifier for the shuffle + * @param startPartition Start of map output partition ID range (included in range) + * @param endPartition End of map output partition ID range (excluded from range) + * @param mapStatuses List of map statuses, indexed by map partition index. + * @param startMapIndex Start Map index. + * @param endMapIndex End Map index. + * @param mergeStatuses List of merge statuses, index by reduce ID. + * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, + * and the second item is a sequence of (shuffle block id, shuffle block size, map index) + * tuples describing the shuffle blocks that are stored at that block manager. + */ + def convertMapStatuses( + shuffleId: Int, + startPartition: Int, + endPartition: Int, + mapStatuses: Array[MapStatus], + startMapIndex : Int, + endMapIndex: Int, + mergeStatuses: Option[Array[MergeStatus]] = None): MapSizesByExecutorId = { + assert (mapStatuses != null) + val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]] + var enableBatchFetch = true + // Only use MergeStatus for reduce tasks that fetch all map outputs. Since a merged shuffle + // partition consists of blocks merged in random order, we are unable to serve map index + // subrange requests. 
However, when a reduce task needs to fetch blocks from a subrange of + // map outputs, it usually indicates skewed partitions which push-based shuffle delegates + // to AQE to handle. + // TODO: SPARK-35036: Instead of reading map blocks in case of AQE with Push based shuffle, + // TODO: improve push based shuffle to read partial merged blocks satisfying the start/end + // TODO: map indexes + if (mergeStatuses.exists(_.exists(_ != null)) && startMapIndex == 0 + && endMapIndex == mapStatuses.length) { + enableBatchFetch = false + logDebug(s"Disable shuffle batch fetch as Push based shuffle is enabled for $shuffleId.") + // We have MergeStatus and full range of mapIds are requested so return a merged block. + val numMaps = mapStatuses.length + mergeStatuses.get.zipWithIndex.slice(startPartition, endPartition).foreach { + case (mergeStatus, partId) => + val remainingMapStatuses = if (mergeStatus != null && mergeStatus.totalSize > 0) { + // If MergeStatus is available for the given partition, add location of the + // pre-merged shuffle partition for this partition ID. Here we create a + // ShuffleMergedBlockId to indicate this is a merged shuffle block. + splitsByAddress.getOrElseUpdate(mergeStatus.location, ListBuffer()) += + ((ShuffleMergedBlockId(shuffleId, mergeStatus.shuffleMergeId, partId), + mergeStatus.totalSize, SHUFFLE_PUSH_MAP_ID)) + // For the "holes" in this pre-merged shuffle partition, i.e., unmerged mapper + // shuffle partition blocks, fetch the original map produced shuffle partition blocks + val mapStatusesWithIndex = mapStatuses.zipWithIndex + mergeStatus.getMissingMaps(numMaps).map(mapStatusesWithIndex) + } else { + // If MergeStatus is not available for the given partition, fall back to + // fetching all the original mapper shuffle partition blocks + mapStatuses.zipWithIndex.toSeq + } + // Add location for the mapper shuffle partition blocks + for ((mapStatus, mapIndex) <- remainingMapStatuses) { + validateStatus(mapStatus, shuffleId, partId) + val size = mapStatus.getSizeForBlock(partId) + if (size != 0) { + splitsByAddress.getOrElseUpdate(mapStatus.location, ListBuffer()) += + ((ShuffleBlockId(shuffleId, mapStatus.mapId, partId), size, mapIndex)) + } + } + } + } else { + val iter = mapStatuses.iterator.zipWithIndex + for ((status, mapIndex) <- iter.slice(startMapIndex, endMapIndex)) { + validateStatus(status, shuffleId, startPartition) + for (part <- startPartition until endPartition) { + val size = status.getSizeForBlock(part) + if (size != 0) { + splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) += + ((ShuffleBlockId(shuffleId, status.mapId, part), size, mapIndex)) + } + } + } + } + + MapSizesByExecutorId(splitsByAddress.mapValues(_.toSeq).iterator, enableBatchFetch) + } + + /** + * Given a shuffle ID, a partition ID, an array of map statuses, and bitmap corresponding + * to either a merged shuffle partition or a merged shuffle partition chunk, identify + * the metadata about the shuffle partition blocks that are merged into the merged shuffle + * partition or partition chunk represented by the bitmap. + * + * @param shuffleId Identifier for the shuffle + * @param partitionId The partition ID of the MergeStatus for which we look for the metadata + * of the merged shuffle partition blocks + * @param mapStatuses List of map statuses, indexed by map ID + * @param tracker bitmap containing mapIndexes that belong to the merged block or merged + * block chunk. 
+ * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, + * and the second item is a sequence of (shuffle block ID, shuffle block size) tuples + * describing the shuffle blocks that are stored at that block manager. + */ + def getMapStatusesForMergeStatus( + shuffleId: Int, + partitionId: Int, + mapStatuses: Array[MapStatus], + tracker: RoaringBitmap): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { + assert (mapStatuses != null && tracker != null) + val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]] + for ((status, mapIndex) <- mapStatuses.zipWithIndex) { + // Only add blocks that are merged + if (tracker.contains(mapIndex)) { + MapOutputTracker.validateStatus(status, shuffleId, partitionId) + splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) += + ((ShuffleBlockId(shuffleId, status.mapId, partitionId), + status.getSizeForBlock(partitionId), mapIndex)) + } + } + splitsByAddress.mapValues(_.toSeq).iterator + } + + def validateStatus(status: ShuffleOutputStatus, shuffleId: Int, partition: Int) : Unit = { + if (status == null) { + val errorMessage = s"Missing an output location for shuffle $shuffleId partition $partition" + logError(errorMessage) + throw new MetadataFetchFailedException(shuffleId, partition, errorMessage) + } + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorder.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorder.scala new file mode 100644 index 0000000000000000000000000000000000000000..f0dd04487fff7420a86e501c64a5345d0794738b --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorder.scala @@ -0,0 +1,357 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import scala.annotation.tailrec +import scala.collection.mutable + +import com.huawei.boostkit.spark.ColumnarPluginConfig + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, EqualNullSafe, EqualTo, Expression, IsNotNull, PredicateHelper} +import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.catalyst.util.sideBySide + + + + +/** + * Move all cartesian products to the root of the plan + */ +object DelayCartesianProduct extends Rule[LogicalPlan] with PredicateHelper { + + /** + * Extract cliques from the input plans. 
+ * A cliques is a sub-tree(sub-plan) which doesn't have any join with other sub-plan. + * The input plans are picked from left to right + * , until we can't find join condition in the remaining plans. + * The same logic is applied to the remaining plans, until all plans are picked. + * This function can produce a left-deep tree or a bushy tree. + * + * @param input a list of LogicalPlans to inner join and the type of inner join. + * @param conditions a list of condition for join. + */ + private def extractCliques(input: Seq[(LogicalPlan, InnerLike)], conditions: Seq[Expression]) + : Seq[(LogicalPlan, InnerLike)] = { + if (input.size == 1) { + input + } else { + val (leftPlan, leftInnerJoinType) :: linearSeq = input + // discover the initial join that contains at least one join condition + val conditionalOption = linearSeq.find { planJoinPair => + val plan = planJoinPair._1 + val refs = leftPlan.outputSet ++ plan.outputSet + conditions + .filterNot(l => l.references.nonEmpty && canEvaluate(l, leftPlan)) + .filterNot(r => r.references.nonEmpty && canEvaluate(r, plan)) + .exists(_.references.subsetOf(refs)) + } + + if (conditionalOption.isEmpty) { + Seq((leftPlan, leftInnerJoinType)) ++ extractCliques(linearSeq, conditions) + } else { + val (rightPlan, rightInnerJoinType) = conditionalOption.get + + val joinedRefs = leftPlan.outputSet ++ rightPlan.outputSet + val (joinConditions, otherConditions) = conditions.partition( + e => e.references.subsetOf(joinedRefs) && canEvaluateWithinJoin(e)) + val joined = Join(leftPlan, rightPlan, rightInnerJoinType, + joinConditions.reduceLeftOption(And), JoinHint.NONE) + + // must not make reference to the same logical plan + extractCliques(Seq((joined, Inner)) + ++ linearSeq.filterNot(_._1 eq rightPlan), otherConditions) + } + } + } + + /** + * Link cliques by cartesian product + * + * @param input + * @return + */ + private def linkCliques(input: Seq[(LogicalPlan, InnerLike)]) + : LogicalPlan = { + if (input.length == 1) { + input.head._1 + } else if (input.length == 2) { + val ((left, innerJoinType1), (right, innerJoinType2)) = (input(0), input(1)) + val joinType = resetJoinType(innerJoinType1, innerJoinType2) + Join(left, right, joinType, None, JoinHint.NONE) + } else { + val (left, innerJoinType1) :: (right, innerJoinType2) :: rest = input + val joinType = resetJoinType(innerJoinType1, innerJoinType2) + linkCliques(Seq((Join(left, right, joinType, None, JoinHint.NONE), joinType)) ++ rest) + } + } + + /** + * This is to reset the join type before reordering. + * + * @param leftJoinType + * @param rightJoinType + * @return + */ + private def resetJoinType(leftJoinType: InnerLike, rightJoinType: InnerLike): InnerLike = { + (leftJoinType, rightJoinType) match { + case (_, Cross) | (Cross, _) => Cross + case _ => Inner + } + } + + def apply(plan: LogicalPlan): LogicalPlan = { + if (!ColumnarPluginConfig.getSessionConf.enableDelayCartesianProduct) { + return plan + } + + // Reorder joins only when there are cartesian products. + var existCartesianProduct = false + plan foreach { + case Join(_, _, _: InnerLike, None, _) => existCartesianProduct = true + case _ => + } + + if (existCartesianProduct) { + plan.transform { + case originalPlan@ExtractFiltersAndInnerJoins(input, conditions) + if input.size > 2 && conditions.nonEmpty => + val cliques = extractCliques(input, conditions) + val reorderedPlan = linkCliques(cliques) + + reorderedPlan match { + // Generate a bushy tree after reordering. 
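+          // For example (illustration only, table names are hypothetical): for
+          //   SELECT * FROM a, b, c WHERE a.x = c.x
+          // extractCliques groups a and c into one clique joined on a.x = c.x and leaves b as
+          // its own clique, and linkCliques then joins the two cliques with no condition at
+          // the root, so the cartesian product is applied after the selective equi-join.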
+ case ExtractFiltersAndInnerJoinsForBushy(_, joinConditions) => + val primalConditions = conditions.flatMap(splitConjunctivePredicates) + val reorderedConditions = joinConditions.flatMap(splitConjunctivePredicates).toSet + val missingConditions = primalConditions.filterNot(reorderedConditions.contains) + if (missingConditions.nonEmpty) { + val comparedPlans = + sideBySide(originalPlan.treeString, reorderedPlan.treeString).mkString("\n") + logWarning("There are missing conditions after reordering, falling back to the " + + s"original plan. == Comparing two plans ===\n$comparedPlans") + originalPlan + } else { + reorderedPlan + } + case _ => throw new AnalysisException( + s"There is no join node in the plan, this should not happen: $reorderedPlan") + } + } + } else { + plan + } + } +} + +/** + * Firstly, Heuristic reorder join need to execute small joins with filters + * , which can reduce intermediate results + */ +object HeuristicJoinReorder extends Rule[LogicalPlan] + with PredicateHelper with JoinSelectionHelper { + + /** + * Join a list of plans together and push down the conditions into them. + * The joined plan are picked from left to right, thus the final result is a left-deep tree. + * + * @param input a list of LogicalPlans to inner join and the type of inner join. + * @param conditions a list of condition for join. + */ + @tailrec + final def createReorderJoin(input: Seq[(LogicalPlan, InnerLike)], conditions: Seq[Expression]) + : LogicalPlan = { + assert(input.size >= 2) + if (input.size == 2) { + val (joinConditions, others) = conditions.partition(canEvaluateWithinJoin) + val ((leftPlan, leftJoinType), (rightPlan, rightJoinType)) = (input(0), input(1)) + val innerJoinType = (leftJoinType, rightJoinType) match { + case (Inner, Inner) => Inner + case (_, _) => Cross + } + // Set the join node ordered so that we don't need to transform them again. + val orderJoin = OrderedJoin(leftPlan, rightPlan, innerJoinType, joinConditions.reduceLeftOption(And)) + if (others.nonEmpty) { + Filter(others.reduceLeft(And), orderJoin) + } else { + orderJoin + } + } else { + val (left, _) :: rest = input.toList + val candidates = rest.filter { planJoinPair => + val plan = planJoinPair._1 + // 1. it has join conditions with the left node + // 2. it has a filter + // 3. it can be broadcast + val isEqualJoinCondition = conditions.flatMap { + case EqualTo(l, r) if l.references.isEmpty || r.references.isEmpty => None + case EqualNullSafe(l, r) if l.references.isEmpty || r.references.isEmpty => None + case e@EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, plan) => Some(e) + case e@EqualTo(l, r) if canEvaluate(l, plan) && canEvaluate(r, left) => Some(e) + case e@EqualNullSafe(l, r) if canEvaluate(l, left) && canEvaluate(r, plan) => Some(e) + case e@EqualNullSafe(l, r) if canEvaluate(l, plan) && canEvaluate(r, left) => Some(e) + case _ => None + }.nonEmpty + + val hasFilter = plan match { + case f: Filter if hasValuableCondition(f.condition) => true + case Project(_, f: Filter) if hasValuableCondition(f.condition) => true + case _ => false + } + + isEqualJoinCondition && hasFilter + } + val (right, innerJoinType) = if (candidates.nonEmpty) { + candidates.minBy(_._1.stats.sizeInBytes) + } else { + rest.head + } + + val joinedRefs = left.outputSet ++ right.outputSet + val selectedJoinConditions = mutable.HashSet.empty[Expression] + val (joinConditions, others) = conditions.partition { e => + // If there are semantically equal conditions, they should come from two different joins. 
+ // So we should not put them into one join. + if (!selectedJoinConditions.contains(e.canonicalized) && e.references.subsetOf(joinedRefs) + && canEvaluateWithinJoin(e)) { + selectedJoinConditions.add(e.canonicalized) + true + } else { + false + } + } + // Set the join node ordered so that we don't need to transform them again. + val joined = OrderedJoin(left, right, innerJoinType, joinConditions.reduceLeftOption(And)) + + // should not have reference to same logical plan + createReorderJoin(Seq((joined, Inner)) ++ rest.filterNot(_._1 eq right), others) + } + } + + private def hasValuableCondition(condition: Expression): Boolean = { + val conditions = splitConjunctivePredicates(condition) + !conditions.forall(_.isInstanceOf[IsNotNull]) + } + + def apply(plan: LogicalPlan): LogicalPlan = { + if (ColumnarPluginConfig.getSessionConf.enableHeuristicJoinReorder) { + val newPlan = plan.transform { + case p@ExtractFiltersAndInnerJoinsByIgnoreProjects(input, conditions) + if input.size > 2 && conditions.nonEmpty => + val reordered = createReorderJoin(input, conditions) + if (p.sameOutput(reordered)) { + reordered + } else { + // Reordering the joins have changed the order of the columns. + // Inject a projection to make sure we restore to the expected ordering. + Project(p.output, reordered) + } + } + + // After reordering is finished, convert OrderedJoin back to Join + val result = newPlan.transformDown { + case OrderedJoin(left, right, jt, cond) => Join(left, right, jt, cond, JoinHint.NONE) + } + if (!result.resolved) { + // In some special cases related to subqueries, we find that after reordering, + val comparedPlans = sideBySide(plan.treeString, result.treeString).mkString("\n") + logWarning("The structural integrity of the plan is broken, falling back to the " + + s"original plan. == Comparing two plans ===\n$comparedPlans") + plan + } else { + result + } + } else { + plan + } + } +} + +/** + * This is different from [[ExtractFiltersAndInnerJoins]] in that it can collect filters and + * inner joins by ignoring projects on top of joins, which are produced by column pruning. + */ +private object ExtractFiltersAndInnerJoinsByIgnoreProjects extends PredicateHelper { + + /** + * Flatten all inner joins, which are next to each other. + * Return a list of logical plans to be joined with a boolean for each plan indicating if it + * was involved in an explicit cross join. Also returns the entire list of join conditions for + * the left-deep tree. 
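+   * For example (illustration only), flattening
+   * Join(Join(a, b, Inner, Some(cond1), JoinHint.NONE), c, Cross, None, JoinHint.NONE)
+   * yields (Seq((a, Inner), (b, Inner), (c, Cross)), Seq(cond1)).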
+ */ + def flattenJoin(plan: LogicalPlan, parentJoinType: InnerLike = Inner) + : (Seq[(LogicalPlan, InnerLike)], Seq[Expression]) = plan match { + case Join(left, right, joinType: InnerLike, cond, hint) if hint == JoinHint.NONE => + val (plans, conditions) = flattenJoin(left, joinType) + (plans ++ Seq((right, joinType)), conditions ++ + cond.toSeq.flatMap(splitConjunctivePredicates)) + case Filter(filterCondition, j@Join(_, _, _: InnerLike, _, hint)) if hint == JoinHint.NONE => + val (plans, conditions) = flattenJoin(j) + (plans, conditions ++ splitConjunctivePredicates(filterCondition)) + case Project(projectList, child) + if projectList.forall(_.isInstanceOf[Attribute]) => flattenJoin(child) + + case _ => (Seq((plan, parentJoinType)), Seq.empty) + } + + def unapply(plan: LogicalPlan): Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression])] + = plan match { + case f@Filter(_, Join(_, _, _: InnerLike, _, _)) => + Some(flattenJoin(f)) + case j@Join(_, _, _, _, hint) if hint == JoinHint.NONE => + Some(flattenJoin(j)) + case _ => None + } +} + +private object ExtractFiltersAndInnerJoinsForBushy extends PredicateHelper { + + /** + * This function works for both left-deep and bushy trees. + * + * @param plan + * @param parentJoinType + * @return + */ + def flattenJoin(plan: LogicalPlan, parentJoinType: InnerLike = Inner) + : (Seq[(LogicalPlan, InnerLike)], Seq[Expression]) = plan match { + case Join(left, right, joinType: InnerLike, cond, _) => + val (lPlans, lConds) = flattenJoin(left, joinType) + val (rPlans, rConds) = flattenJoin(right, joinType) + (lPlans ++ rPlans, lConds ++ rConds ++ cond.toSeq) + + case Filter(filterCondition, j@Join(_, _, _: InnerLike, _, _)) => + val (plans, conditions) = flattenJoin(j) + (plans, conditions ++ splitConjunctivePredicates(filterCondition)) + + case _ => (Seq((plan, parentJoinType)), Seq()) + } + + def unapply(plan: LogicalPlan): Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression])] = { + plan match { + case f@Filter(_, Join(_, _, _: InnerLike, _, _)) => + Some(flattenJoin(f)) + case j@Join(_, _, _, _, _) => + Some(flattenJoin(j)) + case _ => None + } + } +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFilters.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFilters.scala new file mode 100644 index 0000000000000000000000000000000000000000..1b5baa23080e816868934da02c8cbe87fa21c4d8 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFilters.scala @@ -0,0 +1,685 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import com.huawei.boostkit.spark.ColumnarPluginConfig + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, Count, Max, Min, Sum} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, CTERelationDef, CTERelationRef, Filter, Join, LogicalPlan, Project, Subquery, WithCTE} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.{SCALAR_SUBQUERY, SCALAR_SUBQUERY_REFERENCE, TreePattern} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.DataType + +/** + * The skeleton of this rule is just as same as MergeScalarSubqueries Rule. This rule relaxes the + * constraint of filters which can be merged. + */ +object MergeSubqueryFilters extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = { + plan match { + // Subquery reuse needs to be enabled for this optimization. + case _ if !conf.getConf(SQLConf.SUBQUERY_REUSE_ENABLED) => plan + + // This rule does a whole plan traversal, no need to run on subqueries. + case _: Subquery => plan + + // Plans with CTEs are not supported for now. + case _: WithCTE => plan + + case _ => extractCommonScalarSubqueries(plan) + } + } + + /** + * An item in the cache of merged scalar subqueries. + * + * @param attributes Attributes that form the struct scalar return value of a merged subquery. + * @param plan The plan of a merged scalar subquery. + * @param merged A flag to identify if this item is the result of merging subqueries. + * Please note that `attributes.size == 1` doesn't always mean that the plan + * is not merged as there can be subqueries that are different + * ([[checkIdenticalPlans]] is false) due to an extra [[Project]] node in + * one of them. In that case `attributes.size` remains 1 after merging, but + * the merged flag becomes true. + * @param references A set of subquery indexes in the cache to track all (including transitive) + * nested subqueries. + */ + case class Header( + attributes: Seq[Attribute], + plan: LogicalPlan, + merged: Boolean, + references: Set[Int]) + + private def extractCommonScalarSubqueries(plan: LogicalPlan) = { + val cache = ArrayBuffer.empty[Header] + val planWithReferences = insertReferences(plan, cache) + cache.zipWithIndex.foreach { case (header, i) => + cache(i) = cache(i).copy(plan = + if (header.merged) { + CTERelationDef( + createProject(header.attributes, + removeReferences(removePropagatedFilters(header.plan), cache)), + underSubquery = true) + } else { + removeReferences(header.plan, cache) + }) + } + val newPlan = removeReferences(planWithReferences, cache) + val subqueryCTEs = cache.filter(_.merged).map(_.plan.asInstanceOf[CTERelationDef]) + if (subqueryCTEs.nonEmpty) { + WithCTE(newPlan, subqueryCTEs.toSeq) + } else { + newPlan + } + } + + // First traversal builds up the cache and inserts `ScalarSubqueryReference`s to the plan. + private def insertReferences(plan: LogicalPlan, cache: ArrayBuffer[Header]): LogicalPlan = { + plan.transformUpWithSubqueries { + case n => n.transformExpressionsUpWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY)) { + // The subquery could contain a hint that is not propagated once we cache it, but as a + // non-correlated scalar subquery won't be turned into a Join the loss of hints is fine. 
+ case s: ScalarSubquery if !s.isCorrelated && s.deterministic => + val (subqueryIndex, headerIndex) = cacheSubquery(s.plan, cache) + ScalarSubqueryReference(subqueryIndex, headerIndex, s.dataType, s.exprId) + } + } + } + + // Caching returns the index of the subquery in the cache and the index of scalar member in the + // "Header". + private def cacheSubquery(plan: LogicalPlan, cache: ArrayBuffer[Header]): (Int, Int) = { + val output = plan.output.head + val references = mutable.HashSet.empty[Int] + plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY_REFERENCE)) { + case ssr: ScalarSubqueryReference => + references += ssr.subqueryIndex + references ++= cache(ssr.subqueryIndex).references + ssr + } + + cache.zipWithIndex.collectFirst(Function.unlift { + case (header, subqueryIndex) if !references.contains(subqueryIndex) => + checkIdenticalPlans(plan, header.plan).map { outputMap => + val mappedOutput = mapAttributes(output, outputMap) + val headerIndex = header.attributes.indexWhere(_.exprId == mappedOutput.exprId) + subqueryIndex -> headerIndex + }.orElse { + tryMergePlans(plan, header.plan, false).collect { + case (mergedPlan, outputMap, None, None, _) => + val mappedOutput = mapAttributes(output, outputMap) + var headerIndex = header.attributes.indexWhere(_.exprId == mappedOutput.exprId) + val newHeaderAttributes = if (headerIndex == -1) { + headerIndex = header.attributes.size + header.attributes :+ mappedOutput + } else { + header.attributes + } + cache(subqueryIndex) = + Header(newHeaderAttributes, mergedPlan, true, header.references ++ references) + subqueryIndex -> headerIndex + } + } + case _ => None + }).getOrElse { + cache += Header(Seq(output), plan, false, references.toSet) + cache.length - 1 -> 0 + } + } + + // If 2 plans are identical return the attribute mapping from the new to the cached version. + private def checkIdenticalPlans( + newPlan: LogicalPlan, + cachedPlan: LogicalPlan): Option[AttributeMap[Attribute]] = { + if (newPlan.canonicalized == cachedPlan.canonicalized) { + Some(AttributeMap(newPlan.output.zip(cachedPlan.output))) + } else { + None + } + } + + /** + * Recursively traverse down and try merging 2 plans. + * + * Please note that merging arbitrary plans can be complicated, the current version supports only + * some of the most important nodes. + * + * @param newPlan a new plan that we want to merge to an already processed plan + * @param cachedPlan a plan that we already processed, it can be either an + * original plan or a merged version of 2 or more plans + * @param filterPropagationSupported a boolean flag that we propagate down to signal we have seen + * an `Aggregate` node where propagated filters can be merged + * @return A tuple of: + * - the merged plan, + * - the attribute mapping from the new to the merged version, + * - the 2 optional filters of both plans that we need to propagate up and merge in + * an ancestor `Aggregate` node if possible, + * - the optional accumulated extra cost of merge that we need to propagate up and + * check in the ancestor `Aggregate` node. + * The cost is optional to signal if the cost needs to be taken into account up in the + * `Aggregate` node to decide about merge. 
+ */ + private def tryMergePlans( + newPlan: LogicalPlan, + cachedPlan: LogicalPlan, + filterPropagationSupported: Boolean): + Option[(LogicalPlan, AttributeMap[Attribute], Option[Expression], Option[Expression], + Option[Double])] = { + checkIdenticalPlans(newPlan, cachedPlan).map { outputMap => + // Currently the cost is always propagated up when `filterPropagationSupported` is true but + // later we can address cases when we don't need to take cost into account. Please find the + // details at the `Filter` node handling. + val mergeCost = if (filterPropagationSupported) Some(0d) else None + + (cachedPlan, outputMap, None, None, mergeCost) + }.orElse( + (newPlan, cachedPlan) match { + case (np: Project, cp: Project) => + tryMergePlans(np.child, cp.child, filterPropagationSupported).map { + case (mergedChild, outputMap, newChildFilter, mergedChildFilter, childMergeCost) => + val (mergedProjectList, newOutputMap, newPlanFilter, mergedPlanFilter, mergeCost) = + mergeNamedExpressions(np.projectList, outputMap, cp.projectList, newChildFilter, + mergedChildFilter, childMergeCost) + val mergedPlan = Project(mergedProjectList, mergedChild) + (mergedPlan, newOutputMap, newPlanFilter, mergedPlanFilter, mergeCost) + } + case (np, cp: Project) => + tryMergePlans(np, cp.child, filterPropagationSupported).map { + case (mergedChild, outputMap, newChildFilter, mergedChildFilter, childMergeCost) => + val (mergedProjectList, newOutputMap, newPlanFilter, mergedPlanFilter, mergeCost) = + mergeNamedExpressions(np.output, outputMap, cp.projectList, newChildFilter, + mergedChildFilter, childMergeCost) + val mergedPlan = Project(mergedProjectList, mergedChild) + (mergedPlan, newOutputMap, newPlanFilter, mergedPlanFilter, mergeCost) + } + case (np: Project, cp) => + tryMergePlans(np.child, cp, filterPropagationSupported).map { + case (mergedChild, outputMap, newChildFilter, mergedChildFilter, childMergeCost) => + val (mergedProjectList, newOutputMap, newPlanFilter, mergedPlanFilter, mergeCost) = + mergeNamedExpressions(np.projectList, outputMap, cp.output, newChildFilter, + mergedChildFilter, childMergeCost) + val mergedPlan = Project(mergedProjectList, mergedChild) + (mergedPlan, newOutputMap, newPlanFilter, mergedPlanFilter, mergeCost) + } + case (np: Aggregate, cp: Aggregate) if supportedAggregateMerge(np, cp) => + val filterPropagationSupported = + ColumnarPluginConfig.getConf.filterMergeEnable && + supportsFilterPropagation(np) && supportsFilterPropagation(cp) + tryMergePlans(np.child, cp.child, filterPropagationSupported).flatMap { + case (mergedChild, outputMap, None, None, _) => + val mappedNewGroupingExpression = + np.groupingExpressions.map(mapAttributes(_, outputMap)) + // Order of grouping expression does matter as merging different grouping orders can + // introduce "extra" shuffles/sorts that might not present in all of the original + // subqueries. 
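+              // For example, an aggregate grouped by (a, b) is not merged with one grouped by
+              // (b, a): the canonicalized grouping sequences differ, so the check below rejects
+              // the merge rather than risk an extra shuffle.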
+ if (mappedNewGroupingExpression.map(_.canonicalized) == + cp.groupingExpressions.map(_.canonicalized)) { + // No need to calculate and check costs as there is no propagated filter + val (mergedAggregateExpressions, newOutputMap, _, _, _) = + mergeNamedExpressions(np.aggregateExpressions, outputMap, cp.aggregateExpressions, + None, None, None) + val mergedPlan = + Aggregate(cp.groupingExpressions, mergedAggregateExpressions, mergedChild) + Some(mergedPlan, newOutputMap, None, None, None) + } else { + None + } + case (mergedChild, outputMap, newChildFilter, mergedChildFilter, childMergeCost) => + // No need to calculate cost in `mergeNamedExpressions()` + val (mergedAggregateExpressions, newOutputMap, _, _, _) = + mergeNamedExpressions( + filterAggregateExpressions(np.aggregateExpressions, newChildFilter), + outputMap, + filterAggregateExpressions(cp.aggregateExpressions, mergedChildFilter), + None, + None, + None) + + val mergeFilters = newChildFilter.isEmpty || mergedChildFilter.isEmpty || { + val mergeCost = childMergeCost.map { c => + val newPlanExtraCost = mergedChildFilter.map(getCost).getOrElse(0d) + + newChildFilter.map(getCost).getOrElse(0d) + val cachedPlanExtraCost = newPlanExtraCost + c + newPlanExtraCost + cachedPlanExtraCost + } + mergeCost.forall { c => + val maxCost = ColumnarPluginConfig.getConf.filterMergeThreshold + val enableMerge = maxCost < 0 || c <= maxCost + if (!enableMerge) { + logDebug( + s"Plan merge of\n${np}and\n${cp}failed as the merge cost is too high: $c") + } + enableMerge + } + } + if (mergeFilters) { + val mergedPlan = Aggregate(Seq.empty, mergedAggregateExpressions, mergedChild) + Some(mergedPlan, newOutputMap, None, None, None) + } else { + None + } + case _ => None + } + + // Here is the difference with MergeScalarSubqueries Rule. + // We can still merge the 'Filters' when they are not exactly the same. + // The differing `Filter`s can be merged if: + // - they both they have an ancestor `Aggregate` node that has no grouping and + // - there are only `Project` or `Filter` nodes in between the different `Filters` and the + // ancestor `Aggregate` nodes. + // + // For example, we can merge: + // + // SELECT avg(a) FROM t WHERE c = 1 + // + // and: + // + // SELECT sum(b) FROM t WHERE c = 2 + // + // into: + // + // SELECT + // avg(a) FILTER (WHERE c = 1), + // sum(b) FILTER (WHERE c = 2) + // FROM t + // WHERE c = 1 OR c = 2 + // + // But there are some special cases we need to consider: + // - The plans to be merged might contain multiple adjacent `Filter` nodes and the parent + // `Filter` nodes should incorporate the propagated filters from child ones during merge. + // For example, adjacent filters can appear in plans when some of the optimization rules + // (like `PushDownPredicates`) are disabled. 
+ // + // Let's consider we want to merge query 1: + // + // SELECT avg(a) + // FROM ( + // SELECT * FROM t WHERE c1 = 1 + // ) + // WHERE c2 = 1 + // + // and query 2: + // + // SELECT sum(b) + // FROM ( + // SELECT * FROM t WHERE c1 = 2 + // ) + // WHERE c2 = 2 + // + // Then the optimal merged query is: + // + // SELECT + // avg(a) FILTER (WHERE c1 = 1 AND c2 = 1), + // sum(b) FILTER (WHERE c1 = 2 AND c2 = 2) + // FROM ( + // SELECT * FROM t WHERE c1 = 1 OR c1 = 2 + // ) + // WHERE (c1 = 1 AND c2 = 1) OR (c1 = 2 AND c2 = 2) + case (np: Filter, cp: Filter) => + tryMergePlans(np.child, cp.child, filterPropagationSupported).flatMap { + case (mergedChild, outputMap, newChildFilter, mergedChildFilter, childMergeCost) => + val mappedNewCondition = mapAttributes(np.condition, outputMap) + // Comparing the canonicalized form is required to ignore different forms of the same + // expression. + if (mappedNewCondition.canonicalized == cp.condition.canonicalized) { + val filters = (mergedChildFilter.toSeq ++ newChildFilter.toSeq).reduceOption(Or) + .map(PropagatedFilter) + val mergedCondition = (filters.toSeq :+ cp.condition).reduce(And) + val mergedPlan = Filter(mergedCondition, mergedChild) + val mergeCost = addFilterCost(childMergeCost, mergedCondition, + getCost(np.condition), getCost(cp.condition)) + Some(mergedPlan, outputMap, newChildFilter, mergedChildFilter, mergeCost) + } else if (filterPropagationSupported) { + val newPlanFilter = (newChildFilter.toSeq :+ mappedNewCondition).reduce(And) + val cachedPlanFilter = (mergedChildFilter.toSeq :+ cp.condition).reduce(And) + val mergedCondition = PropagatedFilter(Or(cachedPlanFilter, newPlanFilter)) + val mergedPlan = Filter(mergedCondition, mergedChild) + val nonPropagatedCachedFilter = extractNonPropagatedFilter(cp.condition) + val mergedPlanFilter = + (mergedChildFilter.toSeq ++ nonPropagatedCachedFilter.toSeq).reduceOption(And) + val mergeCost = addFilterCost(childMergeCost, mergedCondition, + getCost(np.condition), getCost(cp.condition)) + Some(mergedPlan, outputMap, Some(newPlanFilter), mergedPlanFilter, mergeCost) + } else { + None + } + } + case (np, cp: Filter) if filterPropagationSupported => + tryMergePlans(np, cp.child, true).map { + case (mergedChild, outputMap, newChildFilter, mergedChildFilter, childMergeCost) => + val nonPropagatedCachedFilter = extractNonPropagatedFilter(cp.condition) + val mergedPlanFilter = + (mergedChildFilter.toSeq ++ nonPropagatedCachedFilter.toSeq).reduceOption(And) + if (newChildFilter.isEmpty) { + (mergedChild, outputMap, None, mergedPlanFilter, childMergeCost) + } else { + val cachedPlanFilter = (mergedChildFilter.toSeq :+ cp.condition).reduce(And) + val mergedCondition = PropagatedFilter(Or(cachedPlanFilter, newChildFilter.get)) + val mergedPlan = Filter(mergedCondition, mergedChild) + val mergeCost = + addFilterCost(childMergeCost, mergedCondition, 0d, getCost(cp.condition)) + (mergedPlan, outputMap, newChildFilter, mergedPlanFilter, mergeCost) + } + } + case (np: Filter, cp) if filterPropagationSupported => + tryMergePlans(np.child, cp, true).map { + case (mergedChild, outputMap, newChildFilter, mergedChildFilter, childMergeCost) => + val mappedNewCondition = mapAttributes(np.condition, outputMap) + val newPlanFilter = (newChildFilter.toSeq :+ mappedNewCondition).reduce(And) + if (mergedChildFilter.isEmpty) { + (mergedChild, outputMap, Some(newPlanFilter), None, childMergeCost) + } else { + val mergedCondition = PropagatedFilter(Or(mergedChildFilter.get, newPlanFilter)) + val mergedPlan = 
Filter(mergedCondition, mergedChild) + val mergeCost = + addFilterCost(childMergeCost, mergedCondition, getCost(np.condition), 0d) + (mergedPlan, outputMap, Some(newPlanFilter), mergedChildFilter, mergeCost) + } + } + + case (np: Join, cp: Join) if np.joinType == cp.joinType && np.hint == cp.hint => + // Filter propagation is not allowed through joins + tryMergePlans(np.left, cp.left, false).flatMap { + case (mergedLeft, leftOutputMap, None, None, _) => + tryMergePlans(np.right, cp.right, false).flatMap { + case (mergedRight, rightOutputMap, None, None, _) => + val outputMap = leftOutputMap ++ rightOutputMap + val mappedNewCondition = np.condition.map(mapAttributes(_, outputMap)) + // Comparing the canonicalized form is required to ignore different forms of the + // same expression and `AttributeReference.quailifier`s in `cp.condition`. + if (mappedNewCondition.map(_.canonicalized) == + cp.condition.map(_.canonicalized)) { + val mergedPlan = cp.withNewChildren(Seq(mergedLeft, mergedRight)) + Some(mergedPlan, outputMap, None, None, None) + } else { + None + } + case _ => None + } + case _ => None + } + + // Otherwise merging is not possible. + case _ => None + } + ) + } + + private def createProject(attributes: Seq[Attribute], plan: LogicalPlan): Project = { + Project( + Seq(Alias( + CreateNamedStruct(attributes.flatMap(a => Seq(Literal(a.name), a))), + "mergedValue")()), + plan) + } + + private def mapAttributes[T <: Expression](expr: T, outputMap: AttributeMap[Attribute]) = { + expr.transform { + case a: Attribute => outputMap.getOrElse(a, a) + }.asInstanceOf[T] + } + + /** + * Merges named expression lists of `Project` or `Aggregate` nodes of the new plan into the named + * expression list of a similar node of the cached plan. + * + * - Before we can merge the new expressions we need to take into account the propagated + * attribute mapping that describes the transformation from the input attributes of the new plan + * node to the output attributes of the already merged child plan node. + * - While merging the new expressions we need to build a new attribute mapping to propagate up. + * - If any filters are propagated from `Filter` nodes below then we could add all the referenced + * attributes of filter conditions to the merged expression list, but it is better if we alias + * whole filter conditions and propagate only the new boolean attributes. 
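+   * For example (illustration only), if the new plan projects `a + 1` and the cached plan
+   * already contains a semantically equal expression, the existing output attribute is reused;
+   * only genuinely new expressions are appended to the merged list.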
+ * + * @param newExpressions the expression list of the new plan node + * @param outputMap the propagated attribute mapping + * @param cachedExpressions the expression list of the cached plan node + * @param newChildFilter the propagated filters from `Filter` nodes of the new plan + * @param mergedChildFilter the propagated filters from `Filter` nodes of the merged child plan + * @param childMergeCost the optional accumulated extra costs of merge + * @return A tuple of: + * - the merged expression list, + * - the new attribute mapping to propagate, + * - the output attribute of the merged newChildFilter to propagate, + * - the output attribute of the merged mergedChildFilter to propagate, + * - the extra costs of merging new expressions and filters added to `childMergeCost` + */ + private def mergeNamedExpressions( + newExpressions: Seq[NamedExpression], + outputMap: AttributeMap[Attribute], + cachedExpressions: Seq[NamedExpression], + newChildFilter: Option[Expression], + mergedChildFilter: Option[Expression], + childMergeCost: Option[Double]): + (Seq[NamedExpression], AttributeMap[Attribute], Option[Attribute], Option[Attribute], + Option[Double]) = { + val mergedExpressions = ArrayBuffer[NamedExpression](cachedExpressions: _*) + val commonCachedExpressions = mutable.Set.empty[NamedExpression] + var cachedPlanExtraCost = 0d + val newOutputMap = AttributeMap(newExpressions.map { ne => + val mapped = mapAttributes(ne, outputMap) + val withoutAlias = mapped match { + case Alias(child, _) => child + case e => e + } + ne.toAttribute -> mergedExpressions.find { + case Alias(child, _) => child semanticEquals withoutAlias + case e => e semanticEquals withoutAlias + }.map { e => + if (childMergeCost.isDefined) { + commonCachedExpressions += e + } + e + }.getOrElse { + mergedExpressions += mapped + if (childMergeCost.isDefined) { + cachedPlanExtraCost += getCost(mapped) + } + mapped + }.toAttribute + }) + + def mergeFilter(filter: Option[Expression]) = { + filter.map { f => + mergedExpressions.find { + case Alias(child, _) => child semanticEquals f + case e => e semanticEquals f + }.map { e => + if (childMergeCost.isDefined) { + commonCachedExpressions += e + } + e + }.getOrElse { + val named = f match { + case ne: NamedExpression => ne + case o => Alias(o, "propagatedFilter")() + } + mergedExpressions += named + if (childMergeCost.isDefined) { + cachedPlanExtraCost += getCost(named) + } + named + }.toAttribute + } + } + + val mergedPlanFilter = mergeFilter(mergedChildFilter) + val newPlanFilter = mergeFilter(newChildFilter) + + val mergeCost = childMergeCost.map { c => + val newPlanExtraCost = cachedExpressions.collect { + case e if !commonCachedExpressions.contains(e) => getCost(e) + }.sum + c + newPlanExtraCost + cachedPlanExtraCost + } + + (mergedExpressions.toSeq, newOutputMap, newPlanFilter, mergedPlanFilter, mergeCost) + } + + /** + * Adds the extra cost of using `mergedCondition` (instead of the original cost of new and cached + * plan filter conditions) to the propagated extra cost from merged child plans. 
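+   * For example (illustration only), if the new and cached filter conditions have costs 2 and
+   * 3 and the merged condition has cost 6, the accumulated cost grows by (6 - 2) + (6 - 3) = 7.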
+ */ + private def addFilterCost( + childMergeCost: Option[Double], + mergedCondition: Expression, + newPlanFilterCost: Double, + cachedPlanFilterCost: Double) = { + childMergeCost.map { c => + val mergedConditionCost = getCost(mergedCondition) + val newPlanExtraCost = mergedConditionCost - newPlanFilterCost + val cachedPlanExtraCost = mergedConditionCost - cachedPlanFilterCost + c + newPlanExtraCost + cachedPlanExtraCost + } + } + + // Currently only the most basic expressions are supported. + private def getCost(e: Expression): Double = e match { + case _: Literal | _: Attribute => 0d + case PropagatedFilter(child) => getCost(child) + case Alias(child, _) => getCost(child) + case _: BinaryComparison | _: BinaryArithmetic | _: And | _: Or | _: IsNull | _: IsNotNull => + 1d + e.children.map(getCost).sum + case _ => Double.PositiveInfinity + } + + // Only allow aggregates of the same implementation because merging different implementations + // could cause performance regression. + private def supportedAggregateMerge(newPlan: Aggregate, cachedPlan: Aggregate) = { + val aggregateExpressionsSeq = Seq(newPlan, cachedPlan).map { plan => + plan.aggregateExpressions.flatMap(_.collect { + case a: AggregateExpression => a + }) + } + val Seq(newPlanSupportsHashAggregate, cachedPlanSupportsHashAggregate) = + aggregateExpressionsSeq.map(aggregateExpressions => Aggregate.supportsHashAggregate( + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))) + newPlanSupportsHashAggregate && cachedPlanSupportsHashAggregate || + newPlanSupportsHashAggregate == cachedPlanSupportsHashAggregate && { + val Seq(newPlanSupportsObjectHashAggregate, cachedPlanSupportsObjectHashAggregate) = + aggregateExpressionsSeq.map(aggregateExpressions => + Aggregate.supportsObjectHashAggregate(aggregateExpressions)) + newPlanSupportsObjectHashAggregate && cachedPlanSupportsObjectHashAggregate || + newPlanSupportsObjectHashAggregate == cachedPlanSupportsObjectHashAggregate + } + } + + private def extractNonPropagatedFilter(e: Expression) = { + e match { + case And(_: PropagatedFilter, e) => Some(e) + case _: PropagatedFilter => None + case o => Some(o) + } + } + + // We allow filter propagation into aggregates which: + // - doesn't have grouping expressions and + // - contains only the most basic aggregate functions. + private def supportsFilterPropagation(a: Aggregate) = { + a.groupingExpressions.isEmpty && + a.aggregateExpressions.forall { + !_.exists { + case ae: AggregateExpression => + ae.aggregateFunction match { + case _: Count | _: Sum | _: Average | _: Max | _: Min => false + case _ => true + } + case _ => false + } + } + } + + private def filterAggregateExpressions( + aggregateExpressions: Seq[NamedExpression], + filter: Option[Expression]) = { + if (filter.isDefined) { + aggregateExpressions.map(_.transform { + case ae: AggregateExpression => + ae.copy(filter = (filter.get +: ae.filter.toSeq).reduceOption(And)) + }.asInstanceOf[NamedExpression]) + } else { + aggregateExpressions + } + } + + private def removePropagatedFilters(plan: LogicalPlan) = { + plan.transformAllExpressions { + case pf: PropagatedFilter => pf.child + } + } + + // Second traversal replaces `ScalarSubqueryReference`s to either + // `GetStructField(ScalarSubquery(CTERelationRef to the merged plan)` if the plan is merged from + // multiple subqueries or `ScalarSubquery(original plan)` if it isn't. 
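+  // For example (illustration only), if `SELECT avg(a) FROM t WHERE c = 1` and
+  // `SELECT sum(b) FROM t WHERE c = 2` were merged above, both references now point at the
+  // same CTERelationDef that returns a single struct column, and each original subquery is
+  // rewritten to GetStructField(ScalarSubquery(CTERelationRef(...)), headerIndex), selecting
+  // its own field of that struct.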
+ private def removeReferences( + plan: LogicalPlan, + cache: ArrayBuffer[Header]) = { + plan.transformUpWithSubqueries { + case n => + n.transformExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY_REFERENCE)) { + case ssr: ScalarSubqueryReference => + val header = cache(ssr.subqueryIndex) + if (header.merged) { + val subqueryCTE = header.plan.asInstanceOf[CTERelationDef] + GetStructField( + ScalarSubquery( + CTERelationRef(subqueryCTE.id, _resolved = true, subqueryCTE.output), + exprId = ssr.exprId), + ssr.headerIndex) + } else { + ScalarSubquery(header.plan, exprId = ssr.exprId) + } + } + } + } +} + +/** + * Temporal reference to a cached subquery. + * + * @param subqueryIndex A subquery index in the cache. + * @param headerIndex An index in the output of merged subquery. + * @param dataType The dataType of origin scalar subquery. + */ +case class ScalarSubqueryReference( + subqueryIndex: Int, + headerIndex: Int, + dataType: DataType, + exprId: ExprId) extends LeafExpression with Unevaluable { + override def nullable: Boolean = true + + final override val nodePatterns: Seq[TreePattern] = Seq(SCALAR_SUBQUERY_REFERENCE) + + override def stringArgs: Iterator[Any] = Iterator(subqueryIndex, headerIndex, dataType, exprId.id) +} + + +/** + * Temporal wrapper around already propagated predicates. + */ +case class PropagatedFilter(child: Expression) extends UnaryExpression with Unevaluable { + override def dataType: DataType = child.dataType + + override protected def withNewChildInternal(newChild: Expression): PropagatedFilter = + copy(child = newChild) +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/RewriteSelfJoinInInPredicate.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSelfJoinInInPredicate.scala similarity index 99% rename from omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/RewriteSelfJoinInInPredicate.scala rename to omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSelfJoinInInPredicate.scala index 22557aeaf1080b6b8fbbc2b6fb6cf51cba3f7c72..9e402902518a90df7f68db4b420d325a53f9bf41 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/RewriteSelfJoinInInPredicate.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSelfJoinInInPredicate.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package com.huawei.boostkit.spark +package org.apache.spark.sql.catalyst.optimizer import com.huawei.boostkit.spark.ColumnarPluginConfig diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/AbstractUnsafeRowSorter.java b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/AbstractUnsafeRowSorter.java new file mode 100644 index 0000000000000000000000000000000000000000..9ddbd2bd135d97ef215c67802497dfb78788ab16 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/AbstractUnsafeRowSorter.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution;
+
+import java.io.IOException;
+
+import scala.collection.Iterator;
+import scala.math.Ordering;
+
+import com.google.common.annotations.VisibleForTesting;
+
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.util.collection.unsafe.sort.RecordComparator;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+
+public abstract class AbstractUnsafeRowSorter
+{
+    protected final StructType schema;
+
+    /**
+     * If positive, forces records to be spilled to disk at the given frequency (measured in numbers of records).
+     * This is only intended to be used in tests.
+     */
+    protected int testSpillFrequency = 0;
+
+    AbstractUnsafeRowSorter(final StructType schema) {
+        this.schema = schema;
+    }
+
+    // This flag makes sure that cleanupResources() has been called.
+    // After the cleanup work, iterator.next should always return false.
+    // Downstream operators trigger the resource cleanup when they find there is no need to keep the iterator anymore.
+    // See more detail in SPARK-21492.
+    boolean isReleased = false;
+
+    public abstract void insertRow(UnsafeRow row) throws IOException;
+
+    public abstract Iterator sort() throws IOException;
+
+    public abstract Iterator sort(Iterator inputIterator) throws IOException;
+
+    /**
+     * @return the peak memory used so far, in bytes.
+     */
+    public abstract long getPeakMemoryUsage();
+
+    /**
+     * @return the total amount of time spent sorting data (in-memory only).
+     */
+    public abstract long getSortTimeNanos();
+
+    public abstract void cleanupResources();
+
+    /**
+     * Forces spills to occur every 'frequency' records. Only for use in tests.
+     */
+    @VisibleForTesting
+    void setTestSpillFrequency(int frequency) {
+        assert frequency > 0 : "Frequency must be positive";
+        testSpillFrequency = frequency;
+    }
+
+    static final class RowComparator extends RecordComparator {
+        private final Ordering ordering;
+        private final UnsafeRow row1;
+        private final UnsafeRow row2;
+
+        RowComparator(Ordering ordering, int numFields) {
+            this.row1 = new UnsafeRow(numFields);
+            this.row2 = new UnsafeRow(numFields);
+            this.ordering = ordering;
+        }
+
+        @Override
+        public int compare(
+                Object baseObj1,
+                long baseOff1,
+                int baseLen1,
+                Object baseObj2,
+                long baseOff2,
+                int baseLen2) {
+            // Note that since ordering doesn't need the total length of the record, we just pass 0 into the row.
+ row1.pointTo(baseObj1, baseOff1, 0); + row2.pointTo(baseObj2, baseOff2, 0); + return ordering.compare(row1, row2); + } + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala index a421717549b2f87c211e7b1c1ed4693fe769ac12..4863698430e78896de76218897b30870bad46634 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala @@ -58,7 +58,7 @@ case class ColumnarProjectExec(projectList: Seq[NamedExpression], child: SparkPl "omniCodegenTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni codegen"), "getOutputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni getOutput"), "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), - "numOutputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatchs")) + "numOutputVecBatches" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatches")) def buildCheck(): Unit = { val omniAttrExpsIdMap = getExprIdMap(child.output) @@ -71,7 +71,7 @@ case class ColumnarProjectExec(projectList: Seq[NamedExpression], child: SparkPl override def doExecuteColumnar(): RDD[ColumnarBatch] = { val numOutputRows = longMetric("numOutputRows") - val numOutputVecBatchs = longMetric("numOutputVecBatchs") + val numOutputVecBatches= longMetric("numOutputVecBatches") val addInputTime = longMetric("addInputTime") val omniCodegenTime = longMetric("omniCodegenTime") val getOutputTime = longMetric("getOutputTime") @@ -83,7 +83,7 @@ case class ColumnarProjectExec(projectList: Seq[NamedExpression], child: SparkPl exp => rewriteToOmniJsonExpressionLiteral(exp, omniAttrExpsIdMap)).toArray child.executeColumnar().mapPartitionsWithIndexInternal { (index, iter) => - dealPartitionData(numOutputRows, numOutputVecBatchs, addInputTime, omniCodegenTime, + dealPartitionData(numOutputRows, numOutputVecBatches, addInputTime, omniCodegenTime, getOutputTime, omniInputTypes, omniExpressions, iter, this.schema) } } @@ -145,12 +145,12 @@ case class ColumnarFilterExec(condition: Expression, child: SparkPlan) override lazy val metrics = Map( "addInputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni addInput"), - "numInputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of input vecBatchs"), + "numInputVecBatches" -> SQLMetrics.createMetric(sparkContext, "number of input vecBatches"), "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"), "omniCodegenTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni codegen"), "getOutputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni getOutput"), "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), - "numOutputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatchs")) + "numOutputVecBatches" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatches")) protected override def doExecute(): RDD[InternalRow] = { @@ -188,9 +188,9 @@ case class ColumnarFilterExec(condition: Expression, child: SparkPlan) protected override def doExecuteColumnar(): RDD[ColumnarBatch] = { val numInputRows = longMetric("numInputRows") - val 
numInputVecBatchs = longMetric("numInputVecBatchs") + val numInputVecBatches= longMetric("numInputVecBatches") val numOutputRows = longMetric("numOutputRows") - val numOutputVecBatchs = longMetric("numOutputVecBatchs") + val numOutputVecBatches= longMetric("numOutputVecBatches") val addInputTime = longMetric("addInputTime") val omniCodegenTime = longMetric("omniCodegenTime") val getOutputTime = longMetric("getOutputTime") @@ -226,7 +226,7 @@ case class ColumnarFilterExec(condition: Expression, child: SparkPlan) val startInput = System.nanoTime() filterOperator.addInput(vecBatch) addInputTime += NANOSECONDS.toMillis(System.nanoTime() - startInput) - numInputVecBatchs += 1 + numInputVecBatches+= 1 numInputRows += batch.numRows() val startGetOp = System.nanoTime() @@ -254,7 +254,7 @@ case class ColumnarFilterExec(condition: Expression, child: SparkPlan) vector.setVec(vecBatch.getVectors()(i)) } numOutputRows += vecBatch.getRowCount - numOutputVecBatchs += 1 + numOutputVecBatches+= 1 vecBatch.close() new ColumnarBatch(vectors.toArray, vecBatch.getRowCount) } @@ -281,18 +281,18 @@ case class ColumnarConditionProjectExec(projectList: Seq[NamedExpression], override lazy val metrics = Map( "addInputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni addInput"), - "numInputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of input vecBatchs"), + "numInputVecBatches" -> SQLMetrics.createMetric(sparkContext, "number of input vecBatches"), "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"), "omniCodegenTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni codegen"), "getOutputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni getOutput"), "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), - "numOutputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatchs")) + "numOutputVecBatches" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatches")) override def doExecuteColumnar(): RDD[ColumnarBatch] = { val numInputRows = longMetric("numInputRows") - val numInputVecBatchs = longMetric("numInputVecBatchs") + val numInputVecBatches= longMetric("numInputVecBatches") val numOutputRows = longMetric("numOutputRows") - val numOutputVecBatchs = longMetric("numOutputVecBatchs") + val numOutputVecBatches= longMetric("numOutputVecBatches") val addInputTime = longMetric("addInputTime") val omniCodegenTime = longMetric("omniCodegenTime") val getOutputTime = longMetric("getOutputTime") @@ -327,7 +327,7 @@ case class ColumnarConditionProjectExec(projectList: Seq[NamedExpression], val startInput = System.nanoTime() operator.addInput(vecBatch) addInputTime += NANOSECONDS.toMillis(System.nanoTime() - startInput) - numInputVecBatchs += 1 + numInputVecBatches+= 1 numInputRows += batch.numRows() val startGetOp = System.nanoTime() @@ -355,7 +355,7 @@ case class ColumnarConditionProjectExec(projectList: Seq[NamedExpression], vector.setVec(vecBatch.getVectors()(i)) } numOutputRows += vecBatch.getRowCount - numOutputVecBatchs += 1 + numOutputVecBatches+= 1 vecBatch.close() new ColumnarBatch(vectors.toArray, vecBatch.getRowCount) } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBloomFilterSubquery.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBloomFilterSubquery.scala index 03ba89e33f0e8794a4104901ac47fd1a1d2d2be0..5b567abe6b9e74f36587897aae00d3b4a8aa3a02 100644 --- 
a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBloomFilterSubquery.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBloomFilterSubquery.scala @@ -43,14 +43,17 @@ case class ColumnarBloomFilterSubquery(plan: BaseSubqueryExec, exprId: ExprId, s override def eval(input: InternalRow): Any = { var ret = 0L // if eval at driver side, return 0 - if (SparkEnv.get.executorId != SparkContext.DRIVER_IDENTIFIER) { + try { result = scalarSubquery.eval(input) - if (result != null) { - ret = copyToNativeBloomFilter() - } + } catch { + case e: IllegalArgumentException => { return ret; } + } + if (result != null) { + ret = copyToNativeBloomFilter() } ret } + override def withNewPlan(query: BaseSubqueryExec): ColumnarBloomFilterSubquery = copy(plan = scalarSubquery.plan) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = scalarSubquery.doGenCode(ctx, ev) override def updateResult(): Unit = scalarSubquery.updateResult() @@ -72,7 +75,12 @@ case class ColumnarBloomFilterSubquery(plan: BaseSubqueryExec, exprId: ExprId, s // return BloomFilter off-heap address assert(outputs.hasNext, s"Expects bloom filter address value, but got nothing.") - bloomFilterNativeAddress = outputs.next().getVector(0).asInstanceOf[LongVec].get(0) + val outVecBatch = outputs.next() + bloomFilterNativeAddress = outVecBatch.getVector(0).asInstanceOf[LongVec].get(0) + // bloomFilterNativeAddress is used, but on one trace outVecBatch + outVecBatch.releaseAllVectors + outVecBatch.close + bloomFilterNativeAddress } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala index ce510b168ac2d88eee5f5514f5035b671e5c7c86..2a90769c30dbd89917dd2d72df893e87a361fb06 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala @@ -18,24 +18,31 @@ package org.apache.spark.sql.execution import java.util.concurrent._ - import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor import com.huawei.boostkit.spark.util.OmniAdaptorUtil.transColBatchToOmniVecs + import nova.hetu.omniruntime.vector.VecBatch import nova.hetu.omniruntime.vector.serialize.VecBatchSerializerFactory +import scala.collection.mutable.ArrayBuffer +import scala.collection.JavaConverters.asScalaIteratorConverter import scala.concurrent.{ExecutionContext, Promise} import scala.concurrent.duration.NANOSECONDS import scala.util.control.NonFatal -import org.apache.spark.{broadcast, SparkException} + +import org.apache.spark.{SparkException, broadcast} import org.apache.spark.launcher.SparkLauncher import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, Expression, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike} import org.apache.spark.sql.execution.joins.{EmptyHashedRelation, HashedRelationBroadcastMode, HashedRelationWithAllNullKeys} import org.apache.spark.sql.execution.metric.SQLMetrics +import 
org.apache.spark.sql.execution.vectorized.OmniColumnVector import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.util.{SparkFatalException, ThreadUtils} @@ -208,6 +215,67 @@ class ColumnarHashedRelation extends Serializable { } buildData = array } + + def transform(key: Expression, output: Seq[Attribute]): Array[InternalRow] = { + if (relation == EmptyHashedRelation) { + Iterator.empty.toArray + } else { + val deserializer = VecBatchSerializerFactory.create() + val columnNames = key.flatMap { + case expression: AttributeReference => Some(expression) + case _ => None + } + if (columnNames.isEmpty) { + throw new IllegalArgumentException(s"Key column not found in expression: $key") + } + if (columnNames.size != 1) { + throw new IllegalArgumentException(s"Multiple key columns found in expression: $key") + } + val columnExpr = columnNames.head + val oneColumnWithSameName = output.count(_.name == columnExpr.name) == 1 + val columnInOutput = output.zipWithIndex.filter { + p: (Attribute, Int) => + if (oneColumnWithSameName) { + // The comparison of exprId can be ignored when + // only one attribute name match is found. + p._1.name == columnExpr.name + } else { + // A case where output has multiple columns with same name + p._1.name == columnExpr.name && p._1.exprId == columnExpr.exprId + } + } + if (columnInOutput.isEmpty) { + throw new IllegalStateException( + s"Key $key not found from build side relation output: $output") + } + if (columnInOutput.size != 1) { + throw new IllegalStateException( + s"More than one key $key found from build side relation output: $output") + } + val replacement = + BoundReference(columnInOutput.head._2, columnExpr.dataType, columnExpr.nullable) + + val projExpr = key.transformDown { + case _: AttributeReference => + replacement + } + + val proj = UnsafeProjection.create(projExpr) + + val retRows = new ArrayBuffer[InternalRow]() + buildData.foreach { input => + val vecBatch: VecBatch = deserializer.deserialize(input) + val vectors: Seq[OmniColumnVector] = OmniColumnVector.allocateColumns( + vecBatch.getRowCount, StructType.fromAttributes(output), false) + vectors.zipWithIndex.foreach { case (vector, i) => + vector.reset() + vector.setVec(vecBatch.getVectors()(i))} + new ColumnarBatch(vectors.toArray, vecBatch.getRowCount) + .rowIterator().asScala.map(proj).foreach(retRows.append(_)) + } + retRows.toArray + } + } } object ColumnarBroadcastExchangeExec { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarCoalesceExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarCoalesceExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..166471fabbdac10ffe42f99a12488367d2e0db05 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarCoalesceExec.scala @@ -0,0 +1,95 @@ +/* + * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor.sparkTypeToOmniType
+
+import org.apache.spark.{Partition, SparkContext, TaskContext}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition, UnknownPartitioning}
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+/**
+ * Physical plan for returning a new RDD that has exactly `numPartitions` partitions.
+ * Similar to coalesce defined on an [[RDD]], this operation results in a narrow dependency, e.g.
+ * if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of
+ * the 100 new partitions will claim 10 of the current partitions. If a larger number of partitions
+ * is requested, it will stay at the current number of partitions.
+ *
+ * However, if you're doing a drastic coalesce, e.g. to numPartitions = 1,
+ * this may result in your computation taking place on fewer nodes than
+ * you like (e.g. one node in the case of numPartitions = 1). To avoid this,
+ * you can use a shuffle exchange instead. This will add a shuffle step, but means the
+ * current upstream partitions will be executed in parallel (per whatever
+ * the current partitioning is).
+ */
+case class ColumnarCoalesceExec(numPartitions: Int, child: SparkPlan) extends UnaryExecNode {
+
+  override def nodeName: String = "ColumnarCoalesceExec"
+
+  override def supportsColumnar: Boolean = true
+
+  def buildCheck(): Unit = {
+    child.output.foreach(attr => sparkTypeToOmniType(attr.dataType, attr.metadata))
+  }
+
+  override def output: Seq[Attribute] = child.output
+
+  override def outputPartitioning: Partitioning = {
+    if (numPartitions == 1) SinglePartition
+    else UnknownPartitioning(numPartitions)
+  }
+
+  protected override def doExecute(): RDD[InternalRow] = {
+    throw new UnsupportedOperationException("ColumnarCoalesceExec operator doesn't support doExecute().")
+  }
+
+  protected override def doExecuteColumnar(): RDD[ColumnarBatch] = {
+    val rdd = child.executeColumnar()
+    if (numPartitions == 1 && rdd.getNumPartitions < 1) {
+      // Make sure we don't output an RDD with 0 partitions, when claiming that we have a
+      // `SinglePartition`.
+      new ColumnarCoalesceExec.EmptyRDDWithPartitions(sparkContext, numPartitions)
+    } else {
+      rdd.coalesce(numPartitions, shuffle = false)
+    }
+  }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): ColumnarCoalesceExec =
+    copy(child = newChild)
+}
+
+object ColumnarCoalesceExec {
+  /** A simple RDD with no data, but with the given number of partitions.
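+   * ColumnarCoalesceExec falls back to it when coalescing to a single partition over a child RDD
+   * that has no partitions, so that the claimed `SinglePartition` output partitioning still holds.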
*/ + class EmptyRDDWithPartitions( + @transient private val sc: SparkContext, + numPartitions: Int) extends RDD[ColumnarBatch](sc, Nil) { + + override def getPartitions: Array[Partition] = + Array.tabulate(numPartitions)(i => EmptyPartition(i)) + + override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { + Iterator.empty + } + } + + case class EmptyPartition(index: Int) extends Partition +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala index 26c2dd7cf81e9c81e2d4c23d5376a7481f54a393..cec2012e6229d81822ad66abc2f0b74091b5efb1 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala @@ -37,29 +37,6 @@ import org.apache.spark.util.Utils import nova.hetu.omniruntime.vector.Vec -/** - * Holds a user defined rule that can be used to inject columnar implementations of various - * operators in the plan. The [[preColumnarTransitions]] [[Rule]] can be used to replace - * [[SparkPlan]] instances with versions that support a columnar implementation. After this - * Spark will insert any transitions necessary. This includes transitions from row to columnar - * [[RowToColumnarExec]] and from columnar to row [[ColumnarToRowExec]]. At this point the - * [[postColumnarTransitions]] [[Rule]] is called to allow replacing any of the implementations - * of the transitions or doing cleanup of the plan, like inserting stages to build larger batches - * for more efficient processing, or stages that transition the data to/from an accelerator's - * memory. - */ -class ColumnarRule { - def preColumnarTransitions: Rule[SparkPlan] = plan => plan - def postColumnarTransitions: Rule[SparkPlan] = plan => plan -} - -/** - * A trait that is used as a tag to indicate a transition from columns to rows. This allows plugins - * to replace the current [[ColumnarToRowExec]] with an optimized version and still have operations - * that walk a spark plan looking for this type of transition properly match it. 
- */ -trait ColumnarToRowTransition extends UnaryExecNode - /** * Provides an optimized set of APIs to append row based data to an array of @@ -348,9 +325,9 @@ object ColumnarBatchToInternalRow { val batchIter = batches.flatMap { batch => - // toClosedVecs closed case: + // toClosedVecs closed case: [Deprcated] // 1) all rows of batch fetched and closed - // 2) only fetch parital rows(eg: top-n, limit-n), closed at task CompletionListener callback + // 2) only fetch Partial rows(eg: top-n, limit-n), closed at task CompletionListener callback val toClosedVecs = new ListBuffer[Vec] for (i <- 0 until batch.numCols()) { batch.column(i) match { @@ -366,27 +343,22 @@ object ColumnarBatchToInternalRow { new Iterator[InternalRow] { val numOutputRowsMetric: SQLMetric = numOutputRows - var closed = false - - // only invoke if fetch partial rows of batch - if (mayPartialFetch) { - SparkMemoryUtils.addLeakSafeTaskCompletionListener { _ => - if (!closed) { - toClosedVecs.foreach {vec => - vec.close() - } + + + SparkMemoryUtils.addLeakSafeTaskCompletionListener { _ => + toClosedVecs.foreach {vec => + vec.close() } - } } override def hasNext: Boolean = { val has = iter.hasNext - // fetch all rows and closed - if (!has && !closed) { + // fetch all rows + if (!has) { toClosedVecs.foreach {vec => vec.close() + toClosedVecs.remove(toClosedVecs.indexOf(vec)) } - closed = true } has } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala index b25d97d604da1ae0cbaef04b34bbf53e61b8af83..24c74d600b9f360db932d25f0e3a60ea87ade19c 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala @@ -18,22 +18,28 @@ package org.apache.spark.sql.execution import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP -import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor.{checkOmniJsonWhiteList, getExprIdMap, rewriteToOmniJsonExpressionLiteral, sparkTypeToOmniType} +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor.{checkOmniJsonWhiteList, getExprIdMap, rewriteToOmniJsonExpressionLiteral, sparkTypeToOmniType, toOmniAggFunType, toOmniAggInOutJSonExp, toOmniAggInOutType} import com.huawei.boostkit.spark.util.OmniAdaptorUtil import com.huawei.boostkit.spark.util.OmniAdaptorUtil.transColBatchToOmniVecs +import nova.hetu.omniruntime.`type`.DataType +import nova.hetu.omniruntime.constants.FunctionType +import nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_COUNT_ALL import nova.hetu.omniruntime.operator.config.{OperatorConfig, OverflowConfig, SpillConfig} import nova.hetu.omniruntime.operator.project.OmniProjectOperatorFactory import nova.hetu.omniruntime.vector.{LongVec, Vec, VecBatch} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, Count, Final, First, Max, Min, Partial, PartialMerge, Sum} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, Literal, NamedExpression} import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import 
org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.util.SparkMemoryUtils.addLeakSafeTaskCompletionListener import org.apache.spark.sql.execution.vectorized.OmniColumnVector import org.apache.spark.sql.vectorized.ColumnarBatch +import scala.collection.mutable.ListBuffer import scala.concurrent.duration.NANOSECONDS +import scala.math.pow /** * Apply all of the GroupExpressions to every input row, hence we will get @@ -59,7 +65,7 @@ case class ColumnarExpandExec( "omniCodegenTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni codegen"), "getOutputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni getOutput"), "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), - "numOutputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatchs"), + "numOutputVecBatches" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatches"), ) // The GroupExpressions can output data with arbitrary partitioning, so set it @@ -82,9 +88,37 @@ case class ColumnarExpandExec( omniExpressions.foreach(exps => checkOmniJsonWhiteList("", exps)) } + def matchRollupOptimization(): Boolean = { + // Expand operator contains "count(distinct)", "rollup", "cube", "grouping sets", + // it checks whether match "rollup" operations and part "grouping sets" operation. + // For example, grouping columns a and b, such as rollup(a, b), grouping sets((a, b), (a)). + if (projections.length == 1){ + return false + } + var step = 0 + projections.foreach(projection => { + projection.last match { + case literal: Literal => + if (literal.value != (pow(2, step) - 1)) { + return false + } + case _ => + return false + } + step += 1 + }) + true + } + + def replace(newProjections: Seq[Seq[Expression]] = projections, + newOutput: Seq[Attribute] = output, + newChild: SparkPlan = child): ColumnarExpandExec = { + copy(projections = newProjections, output = newOutput, child = newChild) + } + override def doExecuteColumnar(): RDD[ColumnarBatch] = { val numOutputRowsMetric = longMetric("numOutputRows") - val numOutputVecBatchsMetric = longMetric("numOutputVecBatchs") + val numOutputVecBatchesMetric = longMetric("numOutputVecBatches") val addInputTimeMetric = longMetric("addInputTime") val omniCodegenTimeMetric = longMetric("omniCodegenTime") val getOutputTimeMetric = longMetric("getOutputTime") @@ -166,7 +200,7 @@ case class ColumnarExpandExec( val rowCount = result.getRowCount numOutputRowsMetric += rowCount - numOutputVecBatchsMetric += 1 + numOutputVecBatchesMetric += 1 result.close() new ColumnarBatch(vectors.toArray, rowCount) } @@ -181,3 +215,210 @@ case class ColumnarExpandExec( override protected def withNewChildInternal(newChild: SparkPlan): ColumnarExpandExec = copy(child = newChild) } + + +/** + * rollup optimization: handle 2~N combinations + * + * @param projections The group and aggregation of expressions, all of the group expressions should + * output the same schema specified bye the parameter `output` + * @param output The output Schema + * @param groupingExpressions The group of expressions + * @param aggregateExpressions The aggregation of expressions + * @param aggregateAttributes The aggregation of attributes + * @param child Child operator + */ +case class ColumnarOptRollupExec( + projections: Seq[Seq[Expression]], + output: Seq[Attribute], + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], + child: SparkPlan) + extends 
UnaryExecNode { + + override def supportsColumnar: Boolean = true + + override def nodeName: String = "OmniColumnarOptRollup" + + override lazy val metrics = Map( + "addInputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni addInput"), + "omniCodegenTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni codegen"), + "getOutputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni getOutput"), + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "numOutputVecBatches" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatches"), + ) + + // The GroupExpressions can output data with arbitrary partitioning, so set it + // as UNKNOWN partitioning + override def outputPartitioning: Partitioning = UnknownPartitioning(0) + + @transient + override lazy val references: AttributeSet = + AttributeSet(projections.flatten.flatMap(_.references)) + + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val numOutputRowsMetric = longMetric("numOutputRows") + val numOutputVecBatchesMetric = longMetric("numOutputVecBatches") + val addInputTimeMetric = longMetric("addInputTime") + val omniCodegenTimeMetric = longMetric("omniCodegenTime") + val getOutputTimeMetric = longMetric("getOutputTime") + + // handle expand logic + val projectAttrExpsIdMap = getExprIdMap(child.output) + val omniInputTypes = child.output.map(exp => sparkTypeToOmniType(exp.dataType, exp.metadata)).toArray + val omniExpressions = projections.map(exps => exps.map( + exp => rewriteToOmniJsonExpressionLiteral(exp, projectAttrExpsIdMap) + ).toArray).toArray + + // handle hashagg logic + val hashaggAttrExpsIdMap = getExprIdMap(child.output) + val omniGroupByChannel = groupingExpressions.map( + exp => rewriteToOmniJsonExpressionLiteral(exp, hashaggAttrExpsIdMap) + ).toArray + + val omniInputRaws = new Array[Boolean](aggregateExpressions.size) + val omniOutputPartials = new Array[Boolean](aggregateExpressions.size) + val omniAggFunctionTypes = new Array[FunctionType](aggregateExpressions.size) + val omniAggOutputTypes = new Array[Array[DataType]](aggregateExpressions.size) + var omniAggChannels = new Array[Array[String]](aggregateExpressions.size) + val omniAggChannelsFilter = new Array[String](aggregateExpressions.size) + + var index = 0 + for (exp <- aggregateExpressions) { + if (exp.filter.isDefined) { + omniAggChannelsFilter(index) = + rewriteToOmniJsonExpressionLiteral(exp.filter.get, hashaggAttrExpsIdMap) + } + if (exp.mode == PartialMerge) { + ColumnarHashAggregateExec.AssignOmniInfoWhenPartialMergeStage(exp, + hashaggAttrExpsIdMap, + index, + omniInputRaws, + omniOutputPartials, + omniAggFunctionTypes, + omniAggOutputTypes, + omniAggChannels) + } else { + throw new UnsupportedOperationException(s"Unsupported aggregate mode: ${exp.mode} in ColumnarOptRollupExec") + } + index += 1 + } + + omniAggChannels = omniAggChannels.filter(key => key != null) + val omniSourceTypes = new Array[DataType](child.output.size) + child.output.zipWithIndex.foreach { + case (attr, i) => + omniSourceTypes(i) = sparkTypeToOmniType(attr.dataType, attr.metadata) + } + + child.executeColumnar().mapPartitionsWithIndexInternal { (index, iter) => + val startCodegen = System.nanoTime() + val projectOperators = omniExpressions.map(exps => { + val factory = new OmniProjectOperatorFactory(exps, omniInputTypes, 1, + new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) + factory.createOperator + }) + + val hashaggOperator = 
OmniAdaptorUtil.getAggOperator(groupingExpressions, + omniGroupByChannel, + omniAggChannels, + omniAggChannelsFilter, + omniSourceTypes, + omniAggFunctionTypes, + omniAggOutputTypes, + omniInputRaws, + omniOutputPartials) + + omniCodegenTimeMetric += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) + + val results = new ListBuffer[VecBatch]() + var hashaggResults: java.util.Iterator[VecBatch] = null + + // close operator + addLeakSafeTaskCompletionListener[Unit](_ => { + projectOperators.foreach(operator => operator.close()) + hashaggOperator.close() + results.foreach(vecBatch => { + vecBatch.releaseAllVectors() + vecBatch.close() + }) + }) + + while (iter.hasNext) { + val batch = iter.next() + val input = transColBatchToOmniVecs(batch) + val vecBatch = new VecBatch(input, batch.numRows()) + results.append(vecBatch) + projectOperators.foreach(projectOperator => { + val vecs = transColBatchToOmniVecs(batch, true) + + val projectInput = new VecBatch(vecs, vecBatch.getRowCount) + var startInput = System.nanoTime() + projectOperator.addInput(projectInput) + addInputTimeMetric += NANOSECONDS.toMillis(System.nanoTime() - startInput) + + val startGetOutput = System.nanoTime() + val projectResults = projectOperator.getOutput + getOutputTimeMetric += NANOSECONDS.toMillis(System.nanoTime() - startGetOutput) + + if (!projectResults.hasNext) { + throw new RuntimeException("project operator failed!") + } + + val hashaggInput = projectResults.next() + + startInput = System.nanoTime() + hashaggOperator.addInput(hashaggInput) + addInputTimeMetric += NANOSECONDS.toMillis(System.nanoTime() - startInput) + }) + } + + if (results.nonEmpty) { + val startGetOutput = System.nanoTime() + hashaggResults = hashaggOperator.getOutput + getOutputTimeMetric += NANOSECONDS.toMillis(System.nanoTime() - startGetOutput) + } + + new Iterator[ColumnarBatch] { + override def hasNext: Boolean = { + val startGetOutput = System.nanoTime() + val hasNext = results.nonEmpty || (hashaggResults != null && hashaggResults.hasNext) + getOutputTimeMetric += NANOSECONDS.toMillis(System.nanoTime() - startGetOutput) + hasNext + } + + override def next(): ColumnarBatch = { + var vecBatch: VecBatch = null + if (results.nonEmpty) { + vecBatch = results.remove(0) + } else { + val startGetOutput = System.nanoTime() + vecBatch = hashaggResults.next() + getOutputTimeMetric += NANOSECONDS.toMillis(System.nanoTime() - startGetOutput) + } + + val vectors: Seq[OmniColumnVector] = OmniColumnVector.allocateColumns( + vecBatch.getRowCount, schema, false) + vectors.zipWithIndex.foreach { case (vector, i) => + vector.reset() + vector.setVec(vecBatch.getVectors()(i)) + } + + val rowCount = vecBatch.getRowCount + numOutputRowsMetric += rowCount + numOutputVecBatchesMetric += 1 + vecBatch.close() + new ColumnarBatch(vectors.toArray, rowCount) + } + } + } + } + + override protected def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException(s"ColumnarOptRollupExec operator doesn't support doExecute().") + } + + override protected def withNewChildInternal(newChild: SparkPlan): ColumnarOptRollupExec = + copy(child = newChild) +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala index 68ef1256212a7c85108cbb343248ba7ac61e58e4..ddabce36747a404dc95450b9c9f5e95ea9a1b508 100644 --- 
a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala @@ -72,9 +72,7 @@ abstract class BaseColumnarFileSourceScanExec( optionalNumCoalescedBuckets: Option[Int], dataFilters: Seq[Expression], tableIdentifier: Option[TableIdentifier], - needPriv: Boolean = false, - disableBucketedScan: Boolean = false, - outputAllAttributes: Seq[Attribute] = Seq.empty[Attribute]) + disableBucketedScan: Boolean = false) extends DataSourceScanExec { lazy val metadataColumns: Seq[AttributeReference] = @@ -193,12 +191,12 @@ abstract class BaseColumnarFileSourceScanExec( // If sort columns are (col0, col1), then sort ordering would be considered as (col0) // If sort columns are (col1, col0), then sort ordering would be empty as per rule #2 // above - val spec = relation.bucketSpec.get - val bucketColumns = spec.bucketColumnNames.flatMap(n => toAttribute(n)) - val numPartitions = optionalNumCoalescedBuckets.getOrElse(spec.numBuckets) + val bucketSpec = relation.bucketSpec.get + val bucketColumns = bucketSpec.bucketColumnNames.flatMap(n => toAttribute(n)) + val numPartitions = optionalNumCoalescedBuckets.getOrElse(bucketSpec.numBuckets) val partitioning = HashPartitioning(bucketColumns, numPartitions) val sortColumns = - spec.sortColumnNames.map(x => toAttribute(x)).takeWhile(x => x.isDefined).map(_.get) + bucketSpec.sortColumnNames.map(x => toAttribute(x)).takeWhile(x => x.isDefined).map(_.get) val shouldCalculateSortOrder = conf.getConf(SQLConf.LEGACY_BUCKETED_TABLE_SCAN_OUTPUT_ORDERING) && sortColumns.nonEmpty && @@ -335,17 +333,6 @@ abstract class BaseColumnarFileSourceScanExec( relation.fileFormat } - // Prepare conf for persist bad records - val userBadRecordsPath = BadRecordsWriterUtils.getUserBadRecordsPath(relation.sparkSession) - val options = if (userBadRecordsPath.isDefined) { - val badRecordsPathWithTableIdentifier = BadRecordsWriterUtils.addTableIdentifierToPath( - userBadRecordsPath.get, tableIdentifier) - relation.options ++ Map( - "badRecordsPath" -> badRecordsPathWithTableIdentifier) - } else { - relation.options - } - val readFile: (PartitionedFile) => Iterator[InternalRow] = fileFormat.buildReaderWithPartitionValues( sparkSession = relation.sparkSession, @@ -353,8 +340,8 @@ abstract class BaseColumnarFileSourceScanExec( partitionSchema = relation.partitionSchema, requiredSchema = requiredSchema, filters = pushedDownFilters, - options = options, - hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(options)) + options = relation.options, + hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options)) val readRDD = if (bucketedScan) { createBucketedReadRDD(relation.bucketSpec.get, readFile, dynamicallySelectedPartitions, @@ -401,7 +388,7 @@ abstract class BaseColumnarFileSourceScanExec( "numFiles" -> SQLMetrics.createMetric(sparkContext, "number of files read"), "metadataTime" -> SQLMetrics.createTimingMetric(sparkContext, "metadata time"), "filesSize" -> SQLMetrics.createSizeMetric(sparkContext, "size of files read"), - "numOutputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatchs") + "numOutputVecBatches" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatches") ) ++ { // Tracking scan time has overhead, we can't afford to do it for each row, and can only do // it for each batch. 
@@ -435,7 +422,7 @@ abstract class BaseColumnarFileSourceScanExec( protected override def doExecuteColumnar(): RDD[ColumnarBatch] = { val numOutputRows = longMetric("numOutputRows") val scanTime = longMetric("scanTime") - val numOutputVecBatchs = longMetric("numOutputVecBatchs") + val numOutputVecBatches= longMetric("numOutputVecBatches") val localSchema = this.schema inputRDD.asInstanceOf[RDD[ColumnarBatch]].mapPartitionsInternal { batches => new Iterator[ColumnarBatch] { @@ -458,7 +445,7 @@ abstract class BaseColumnarFileSourceScanExec( vector.setVec(input(i)) } numOutputRows += batch.numRows - numOutputVecBatchs += 1 + numOutputVecBatches+= 1 new ColumnarBatch(vectors.toArray, batch.numRows) } } @@ -552,7 +539,7 @@ abstract class BaseColumnarFileSourceScanExec( _ => true } - var splitFiles = selectedPartitions.flatMap { partition => + val splitFiles = selectedPartitions.flatMap { partition => partition.files.flatMap { file => // getPath() is very expensive so we only want to call it once in this block: val filePath = file.getPath @@ -572,13 +559,7 @@ abstract class BaseColumnarFileSourceScanExec( Seq.empty } } - } - - if (fsRelation.sparkSession.sessionState.conf.fileListSortBy == "length") { - splitFiles = splitFiles.sortBy(_.length)(implicitly[Ordering[Long]].reverse) - } else { - splitFiles = splitFiles.sortBy(_.filePath) - } + }.sortBy(_.length)(implicitly[Ordering[Long]].reverse) val partitions = FilePartition.getFilePartitions(relation.sparkSession, splitFiles, maxSplitBytes) @@ -607,17 +588,17 @@ abstract class BaseColumnarFileSourceScanExec( aggIndexOffset += agg.groupingExpressions.size val omniAggInputRaws = new Array[Boolean](agg.aggregateExpressions.size) - val omniAggOutputPartials =new Array[Boolean](agg.aggregateExpressions.size) + val omniPartialsAggOutput =new Array[Boolean](agg.aggregateExpressions.size) val omniAggTypes = new Array[DataType](agg.aggregateExpressions.size) val omniAggFunctionTypes = new Array[FunctionType](agg.aggregateExpressions.size) val omniAggOutputTypes = new Array[Array[DataType]](agg.aggregateExpressions.size) val omniAggChannels = new Array[Array[String]](agg.aggregateExpressions.size) val omniAggChannelsFilter = new Array[String](agg.aggregateExpressions.size) - var omniAggindex = 0 + var omniAggIndex = 0 for (exp <- agg.aggregateExpressions) { if (exp.filter.isDefined) { - omniAggChannelsFilter(omniAggindex) = + omniAggChannelsFilter(omniAggIndex) = rewriteToOmniJsonExpressionLiteral(exp.filter.get, attrAggExpsIdMap) } if (exp.mode == Final) { @@ -628,16 +609,16 @@ abstract class BaseColumnarFileSourceScanExec( val aggExp = exp.aggregateFunction.children.head omniOutputExressionOrder += { exp.aggregateFunction.inputAggBufferAttributes.head.exprId -> - (omniAggindex + aggIndexOffset) + (omniAggIndex + aggIndexOffset) } - omniAggTypes(omniAggindex) = sparkTypeToOmniType(aggExp.dataType) - omniAggFunctionTypes(omniAggindex) = toOmniAggFunType(exp, true) - omniAggOutputTypes(omniAggindex) = + omniAggTypes(omniAggIndex) = sparkTypeToOmniType(aggExp.dataType) + omniAggFunctionTypes(omniAggIndex) = toOmniAggFunType(exp, true) + omniAggOutputTypes(omniAggIndex) = toOmniAggInOutType(exp.aggregateFunction.inputAggBufferAttributes) - omniAggChannels(omniAggindex) = + omniAggChannels(omniAggIndex) = toOmniAggInOutJSonExp(exp.aggregateFunction.children, attrAggExpsIdMap) - omniAggInputRaws(omniAggindex) = true - omniAggOutputPartials(omniAggindex) = true + omniAggInputRaws(omniAggIndex) = true + omniPartialsAggOutput(omniAggIndex) = true case _ => throw 
new UnsupportedOperationException(s"Unsupported aggregate aggregateFunction: $exp") } } else if (exp.mode == PartialMerge) { @@ -646,22 +627,22 @@ abstract class BaseColumnarFileSourceScanExec( val aggExp = exp.aggregateFunction.children.head omniOutputExressionOrder += { exp.aggregateFunction.inputAggBufferAttributes.head.exprId -> - (omniAggindex + aggIndexOffset) + (omniAggIndex + aggIndexOffset) } - omniAggTypes(omniAggindex) = sparkTypeToOmniType(aggExp.dataType) - omniAggFunctionTypes(omniAggindex) = toOmniAggFunType(exp, true) - omniAggOutputTypes(omniAggindex) = + omniAggTypes(omniAggIndex) = sparkTypeToOmniType(aggExp.dataType) + omniAggFunctionTypes(omniAggIndex) = toOmniAggFunType(exp, true) + omniAggOutputTypes(omniAggIndex) = toOmniAggInOutType(exp.aggregateFunction.inputAggBufferAttributes) - omniAggChannels(omniAggindex) = + omniAggChannels(omniAggIndex) = toOmniAggInOutJSonExp(exp.aggregateFunction.inputAggBufferAttributes, attrAggExpsIdMap) - omniAggInputRaws(omniAggindex) = false - omniAggOutputPartials(omniAggindex) = true + omniAggInputRaws(omniAggIndex) = false + omniPartialsAggOutput(omniAggIndex) = true case _ => throw new UnsupportedOperationException(s"Unsupported aggregate aggregateFunction: $exp") } } else { throw new UnsupportedOperationException(s"Unsupported aggregate mode: $exp.mode") } - omniAggindex += 1 + omniAggIndex += 1 } var resultIdxToOmniResultIdxMap: Map[Int, Int] = Map() @@ -678,7 +659,7 @@ abstract class BaseColumnarFileSourceScanExec( omniAggSourceTypes(i) = sparkTypeToOmniType(attr.dataType, attr.metadata) } (omniGroupByChanel, omniAggChannels, omniAggChannelsFilter, omniAggSourceTypes, omniAggFunctionTypes, - omniAggOutputTypes, omniAggInputRaws, omniAggOutputPartials, resultIdxToOmniResultIdxMap) + omniAggOutputTypes, omniAggInputRaws, omniPartialsAggOutput, resultIdxToOmniResultIdxMap) } def genProjectOutput(project: ColumnarProjectExec) = { @@ -792,9 +773,7 @@ case class ColumnarFileSourceScanExec( optionalNumCoalescedBuckets: Option[Int], dataFilters: Seq[Expression], tableIdentifier: Option[TableIdentifier], - needPriv: Boolean = false, - disableBucketedScan: Boolean = false, - outputAllAttributes: Seq[Attribute] = Seq.empty[Attribute]) + disableBucketedScan: Boolean = false) extends BaseColumnarFileSourceScanExec( relation, output, @@ -804,9 +783,7 @@ case class ColumnarFileSourceScanExec( optionalNumCoalescedBuckets, dataFilters, tableIdentifier, - needPriv, - disableBucketedScan, - outputAllAttributes) { + disableBucketedScan) { override def doCanonicalize(): ColumnarFileSourceScanExec = { ColumnarFileSourceScanExec( relation, @@ -818,7 +795,6 @@ case class ColumnarFileSourceScanExec( optionalNumCoalescedBuckets, QueryPlan.normalizePredicates(dataFilters, output), None, - needPriv, disableBucketedScan) } } @@ -883,7 +859,7 @@ case class ColumnarMultipleOperatorExec( "addInputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni addInput"), "outputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni getOutput"), "omniJitTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni codegen"), - "numOutputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatchs") + "numOutputVecBatches" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatches") ) ++ { // Tracking scan time has overhead, we can't afford to do it for each row, and can only do // it for each batch. 
@@ -907,7 +883,7 @@ case class ColumnarMultipleOperatorExec( val numOutputRows = longMetric("numOutputRows") val scanTime = longMetric("scanTime") val numInputRows = longMetric("numInputRows") - val numOutputVecBatchs = longMetric("numOutputVecBatchs") + val numOutputVecBatches= longMetric("numOutputVecBatches") val addInputTime = longMetric("addInputTime") val omniCodegenTime = longMetric("omniJitTime") val getOutputTime = longMetric("outputTime") @@ -954,7 +930,7 @@ case class ColumnarMultipleOperatorExec( projectOperator1.close() }) - val buildOpFactory1 = new OmniHashBuilderWithExprOperatorFactory(buildTypes1, + val buildOpFactory1 = new OmniHashBuilderWithExprOperatorFactory(OMNI_JOIN_TYPE_INNER, buildTypes1, buildJoinColsExp1, 1, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val buildOp1 = buildOpFactory1.createOperator() @@ -970,7 +946,7 @@ case class ColumnarMultipleOperatorExec( } buildOp1.getOutput val lookupOpFactory1 = new OmniLookupJoinWithExprOperatorFactory(probeTypes1, probeOutputCols1, - probeHashColsExp1, buildOutputCols1, buildOutputTypes1, OMNI_JOIN_TYPE_INNER, buildOpFactory1, + probeHashColsExp1, buildOutputCols1, buildOutputTypes1, buildOpFactory1, if (joinFilter1.nonEmpty) {Optional.of(joinFilter1.get)} else {Optional.empty()}, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val lookupOp1 = lookupOpFactory1.createOperator() @@ -988,7 +964,7 @@ case class ColumnarMultipleOperatorExec( projectOperator2.close() }) - val buildOpFactory2 = new OmniHashBuilderWithExprOperatorFactory(buildTypes2, + val buildOpFactory2 = new OmniHashBuilderWithExprOperatorFactory(OMNI_JOIN_TYPE_INNER, buildTypes2, buildJoinColsExp2, 1, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val buildOp2 = buildOpFactory2.createOperator() @@ -1004,7 +980,7 @@ case class ColumnarMultipleOperatorExec( } buildOp2.getOutput val lookupOpFactory2 = new OmniLookupJoinWithExprOperatorFactory(probeTypes2, probeOutputCols2, - probeHashColsExp2, buildOutputCols2, buildOutputTypes2, OMNI_JOIN_TYPE_INNER, buildOpFactory2, + probeHashColsExp2, buildOutputCols2, buildOutputTypes2, buildOpFactory2, if (joinFilter2.nonEmpty) {Optional.of(joinFilter2.get)} else {Optional.empty()}, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val lookupOp2 = lookupOpFactory2.createOperator() @@ -1023,7 +999,7 @@ case class ColumnarMultipleOperatorExec( projectOperator3.close() }) - val buildOpFactory3 = new OmniHashBuilderWithExprOperatorFactory(buildTypes3, + val buildOpFactory3 = new OmniHashBuilderWithExprOperatorFactory(OMNI_JOIN_TYPE_INNER, buildTypes3, buildJoinColsExp3, 1, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val buildOp3 = buildOpFactory3.createOperator() @@ -1039,7 +1015,7 @@ case class ColumnarMultipleOperatorExec( } buildOp3.getOutput val lookupOpFactory3 = new OmniLookupJoinWithExprOperatorFactory(probeTypes3, probeOutputCols3, - probeHashColsExp3, buildOutputCols3, buildOutputTypes3, OMNI_JOIN_TYPE_INNER, buildOpFactory3, + probeHashColsExp3, buildOutputCols3, buildOutputTypes3, buildOpFactory3, if (joinFilter3.nonEmpty) {Optional.of(joinFilter3.get)} else {Optional.empty()}, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val lookupOp3 = 
lookupOpFactory3.createOperator() @@ -1058,7 +1034,7 @@ case class ColumnarMultipleOperatorExec( projectOperator4.close() }) - val buildOpFactory4 = new OmniHashBuilderWithExprOperatorFactory(buildTypes4, + val buildOpFactory4 = new OmniHashBuilderWithExprOperatorFactory(OMNI_JOIN_TYPE_INNER, buildTypes4, buildJoinColsExp4, 1, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val buildOp4 = buildOpFactory4.createOperator() @@ -1074,7 +1050,7 @@ case class ColumnarMultipleOperatorExec( } buildOp4.getOutput val lookupOpFactory4 = new OmniLookupJoinWithExprOperatorFactory(probeTypes4, probeOutputCols4, - probeHashColsExp4, buildOutputCols4, buildOutputTypes4, OMNI_JOIN_TYPE_INNER, buildOpFactory4, + probeHashColsExp4, buildOutputCols4, buildOutputTypes4, buildOpFactory4, if (joinFilter4.nonEmpty) {Optional.of(joinFilter4.get)} else {Optional.empty()}, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val lookupOp4 = lookupOpFactory4.createOperator() @@ -1171,7 +1147,7 @@ case class ColumnarMultipleOperatorExec( vector.setVec(vecBatch.getVectors()(resultIdxToOmniResultIdxMap(i))) } numOutputRows += vecBatch.getRowCount - numOutputVecBatchs += 1 + numOutputVecBatches+= 1 getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOp) vecBatch.close() new ColumnarBatch(vectors.toArray, vecBatch.getRowCount) @@ -1232,7 +1208,7 @@ case class ColumnarMultipleOperatorExec1( "addInputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni addInput"), "outputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni getOutput"), "omniJitTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni codegen"), - "numOutputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatchs"), + "numOutputVecBatches" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatches"), //operator metric "lookupAddInputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni lookup addInput"), // @@ -1259,7 +1235,7 @@ case class ColumnarMultipleOperatorExec1( val numOutputRows = longMetric("numOutputRows") val scanTime = longMetric("scanTime") val numInputRows = longMetric("numInputRows") - val numOutputVecBatchs = longMetric("numOutputVecBatchs") + val numOutputVecBatches= longMetric("numOutputVecBatches") val addInputTime = longMetric("addInputTime") val omniCodegenTime = longMetric("omniJitTime") val getOutputTime = longMetric("outputTime") @@ -1319,7 +1295,7 @@ case class ColumnarMultipleOperatorExec1( projectOperator1.close() }) - val buildOpFactory1 = new OmniHashBuilderWithExprOperatorFactory(buildTypes1, + val buildOpFactory1 = new OmniHashBuilderWithExprOperatorFactory(OMNI_JOIN_TYPE_INNER, buildTypes1, buildJoinColsExp1, 1, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val buildOp1 = buildOpFactory1.createOperator() @@ -1335,7 +1311,7 @@ case class ColumnarMultipleOperatorExec1( } buildOp1.getOutput val lookupOpFactory1 = new OmniLookupJoinWithExprOperatorFactory(probeTypes1, probeOutputCols1, - probeHashColsExp1, buildOutputCols1, buildOutputTypes1, OMNI_JOIN_TYPE_INNER, buildOpFactory1, + probeHashColsExp1, buildOutputCols1, buildOutputTypes1, buildOpFactory1, if (joinFilter1.nonEmpty) {Optional.of(joinFilter1.get)} else {Optional.empty()}, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val 
lookupOp1 = lookupOpFactory1.createOperator() @@ -1354,7 +1330,7 @@ case class ColumnarMultipleOperatorExec1( projectOperator2.close() }) - val buildOpFactory2 = new OmniHashBuilderWithExprOperatorFactory(buildTypes2, + val buildOpFactory2 = new OmniHashBuilderWithExprOperatorFactory(OMNI_JOIN_TYPE_INNER, buildTypes2, buildJoinColsExp2, 1, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val buildOp2 = buildOpFactory2.createOperator() @@ -1370,7 +1346,7 @@ case class ColumnarMultipleOperatorExec1( } buildOp2.getOutput val lookupOpFactory2 = new OmniLookupJoinWithExprOperatorFactory(probeTypes2, probeOutputCols2, - probeHashColsExp2, buildOutputCols2, buildOutputTypes2, OMNI_JOIN_TYPE_INNER, buildOpFactory2, + probeHashColsExp2, buildOutputCols2, buildOutputTypes2, buildOpFactory2, if (joinFilter2.nonEmpty) {Optional.of(joinFilter2.get)} else {Optional.empty()}, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val lookupOp2 = lookupOpFactory2.createOperator() @@ -1389,7 +1365,7 @@ case class ColumnarMultipleOperatorExec1( projectOperator3.close() }) - val buildOpFactory3 = new OmniHashBuilderWithExprOperatorFactory(buildTypes3, + val buildOpFactory3 = new OmniHashBuilderWithExprOperatorFactory(OMNI_JOIN_TYPE_INNER, buildTypes3, buildJoinColsExp3, 1, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val buildOp3 = buildOpFactory3.createOperator() @@ -1405,7 +1381,7 @@ case class ColumnarMultipleOperatorExec1( } buildOp3.getOutput val lookupOpFactory3 = new OmniLookupJoinWithExprOperatorFactory(probeTypes3, probeOutputCols3, - probeHashColsExp3, buildOutputCols3, buildOutputTypes3, OMNI_JOIN_TYPE_INNER, buildOpFactory3, + probeHashColsExp3, buildOutputCols3, buildOutputTypes3, buildOpFactory3, if (joinFilter3.nonEmpty) {Optional.of(joinFilter3.get)} else {Optional.empty()}, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val lookupOp3 = lookupOpFactory3.createOperator() @@ -1504,7 +1480,7 @@ case class ColumnarMultipleOperatorExec1( vector.setVec(vecBatch.getVectors()(resultIdxToOmniResultIdxMap(i))) } numOutputRows += vecBatch.getRowCount - numOutputVecBatchs += 1 + numOutputVecBatches+= 1 getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOp) vecBatch.close() new ColumnarBatch(vectors.toArray, vecBatch.getRowCount) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala index 71d79f5c27323530b58a9757b56bf8e690d33c36..8eff1774a10663f0e7249b9ad1f0abe991ced544 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql.execution +import com.huawei.boostkit.spark.ColumnarPluginConfig + +import java.io.File +import java.util.UUID import java.util.concurrent.TimeUnit.NANOSECONDS import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP @@ -27,8 +31,9 @@ import nova.hetu.omniruntime.`type`.DataType import nova.hetu.omniruntime.constants.FunctionType import 
nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_COUNT_ALL import nova.hetu.omniruntime.operator.aggregator.OmniHashAggregationWithExprOperatorFactory -import nova.hetu.omniruntime.operator.config.{OperatorConfig, OverflowConfig, SpillConfig} +import nova.hetu.omniruntime.operator.config.{OperatorConfig, OverflowConfig, SparkSpillConfig} import nova.hetu.omniruntime.vector.VecBatch +import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -41,6 +46,7 @@ import org.apache.spark.sql.execution.util.SparkMemoryUtils import org.apache.spark.sql.execution.vectorized.OmniColumnVector import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.Utils /** * Hash-based aggregate operator that can also fallback to sorting when data exceeds memory size. @@ -81,11 +87,12 @@ case class ColumnarHashAggregateExec( override lazy val metrics = Map( "addInputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni addInput"), "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"), - "numInputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of input vecBatchs"), + "numInputVecBatches" -> SQLMetrics.createMetric(sparkContext, "number of input vecBatches"), "omniCodegenTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni codegen"), "getOutputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni getOutput"), "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), - "numOutputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatchs")) + "numOutputVecBatches" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatches"), + "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size")) protected override def needHashTable: Boolean = true @@ -199,14 +206,25 @@ case class ColumnarHashAggregateExec( } } + val tmpSparkConf = sparkContext.conf + + def generateSpillDir(conf: SparkConf, subDir: String): String = { + val localDirs: Array[String] = Utils.getConfiguredLocalDirs(conf) + val hash = Utils.nonNegativeHash(UUID.randomUUID.toString) + val root = localDirs(hash % localDirs.length) + val dir = new File(root, subDir) + dir.getCanonicalPath + } + override def doExecuteColumnar(): RDD[ColumnarBatch] = { val addInputTime = longMetric("addInputTime") val numInputRows = longMetric("numInputRows") - val numInputVecBatchs = longMetric("numInputVecBatchs") + val numInputVecBatches= longMetric("numInputVecBatches") val omniCodegenTime = longMetric("omniCodegenTime") val getOutputTime = longMetric("getOutputTime") val numOutputRows = longMetric("numOutputRows") - val numOutputVecBatchs = longMetric("numOutputVecBatchs") + val numOutputVecBatches= longMetric("numOutputVecBatches") + val spillSize = longMetric("spillSize") val attrExpsIdMap = getExprIdMap(child.output) val omniGroupByChanel = groupingExpressions.map( @@ -241,20 +259,14 @@ case class ColumnarHashAggregateExec( case _ => throw new UnsupportedOperationException(s"Unsupported aggregate aggregateFunction: ${exp}") } } else if (exp.mode == PartialMerge) { - exp.aggregateFunction match { - case Sum(_, _) | Min(_) | Max(_) | Count(_) | Average(_, _) | First(_, _) => - omniAggFunctionTypes(index) = toOmniAggFunType(exp, true) - omniAggOutputTypes(index) = - toOmniAggInOutType(exp.aggregateFunction.inputAggBufferAttributes) - 
omniAggChannels(index) = - toOmniAggInOutJSonExp(exp.aggregateFunction.inputAggBufferAttributes, attrExpsIdMap) - omniInputRaws(index) = false - omniOutputPartials(index) = true - if (omniAggFunctionTypes(index) == OMNI_AGGREGATION_TYPE_COUNT_ALL) { - omniAggChannels(index) = null - } - case _ => throw new UnsupportedOperationException(s"Unsupported aggregate aggregateFunction: ${exp}") - } + ColumnarHashAggregateExec.AssignOmniInfoWhenPartialMergeStage(exp, + attrExpsIdMap, + index, + omniInputRaws, + omniOutputPartials, + omniAggFunctionTypes, + omniAggOutputTypes, + omniAggChannels) } else if (exp.mode == Partial) { exp.aggregateFunction match { case Sum(_, _) | Min(_) | Max(_) | Count(_) | Average(_, _) | First(_, _) => @@ -265,6 +277,9 @@ case class ColumnarHashAggregateExec( toOmniAggInOutJSonExp(exp.aggregateFunction.children, attrExpsIdMap) omniInputRaws(index) = true omniOutputPartials(index) = true + if (omniAggFunctionTypes(index) == OMNI_AGGREGATION_TYPE_COUNT_ALL) { + omniAggChannels(index) = null + } case _ => throw new UnsupportedOperationException(s"Unsupported aggregate aggregateFunction: ${exp}") } } else { @@ -281,6 +296,15 @@ case class ColumnarHashAggregateExec( } child.executeColumnar().mapPartitionsWithIndex { (index, iter) => + val columnarConf = ColumnarPluginConfig.getSessionConf + val hashAggSpillRowThreshold = columnarConf.columnarHashAggSpillRowThreshold + val spillMemPctThreshold = columnarConf.columnarSpillMemPctThreshold + val spillDirDiskReserveSize = columnarConf.columnarSpillDirDiskReserveSize + val hashAggSpillEnable = columnarConf.enableHashAggSpill + val spillDirectory = generateSpillDir(tmpSparkConf, "columnarHashAggSpill") + val sparkSpillConf = new SparkSpillConfig(hashAggSpillEnable, spillDirectory, + spillDirDiskReserveSize, hashAggSpillRowThreshold, spillMemPctThreshold) + val startCodegen = System.nanoTime() val operator = OmniAdaptorUtil.getAggOperator(groupingExpressions, omniGroupByChanel, @@ -290,11 +314,13 @@ case class ColumnarHashAggregateExec( omniAggFunctionTypes, omniAggOutputTypes, omniInputRaws, - omniOutputPartials) + omniOutputPartials, + sparkSpillConf) omniCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) // close operator SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { + spillSize += operator.getSpilledBytes() operator.close() }) @@ -305,7 +331,7 @@ case class ColumnarHashAggregateExec( val vecBatch = new VecBatch(input, batch.numRows()) operator.addInput(vecBatch) addInputTime += NANOSECONDS.toMillis(System.nanoTime() - startInput) - numInputVecBatchs += 1 + numInputVecBatches+= 1 numInputRows += batch.numRows() } val startGetOp = System.nanoTime() @@ -339,7 +365,7 @@ case class ColumnarHashAggregateExec( } numOutputRows += vecBatch.getRowCount - numOutputVecBatchs += 1 + numOutputVecBatches+= 1 vecBatch.close() new ColumnarBatch(vectors.toArray, vecBatch.getRowCount) @@ -364,3 +390,27 @@ case class ColumnarHashAggregateExec( throw new UnsupportedOperationException("This operator doesn't support doExecute().") } } + +object ColumnarHashAggregateExec { + def AssignOmniInfoWhenPartialMergeStage( + exp:AggregateExpression, + exprsIdMap: Map[ExprId, Int], + index: Int, + omniInputRaws : Array[Boolean], + omniOutputPartials : Array[Boolean], + omniAggFunctionTypes : Array[FunctionType], + omniAggOutputTypes : Array[Array[DataType]], + omniAggChannels : Array[Array[String]]): Unit ={ + exp.aggregateFunction match { + case Sum(_, _) | Min(_) | Max(_) | Count(_) | Average(_, _) | First(_, _) => + 
omniAggFunctionTypes(index) = toOmniAggFunType(exp, true) + omniAggOutputTypes(index) = + toOmniAggInOutType(exp.aggregateFunction.inputAggBufferAttributes) + omniAggChannels(index) = + toOmniAggInOutJSonExp(exp.aggregateFunction.inputAggBufferAttributes, exprsIdMap) + omniInputRaws(index) = false + omniOutputPartials(index) = true + case _ => throw new UnsupportedOperationException(s"Unsupported aggregate aggregateFunction: ${exp}") + } + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarLimit.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarLimit.scala index fcd0bb9e139170b27d71a7517f04262a98664590..4ce57e12ab8cb9f939664c4ce8a59b441d65ffd2 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarLimit.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarLimit.scala @@ -21,11 +21,14 @@ import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor.{checkOmniJsonWhiteList, getExprIdMap, isSimpleColumnForAll, rewriteToOmniJsonExpressionLiteral, sparkTypeToOmniType} import com.huawei.boostkit.spark.serialize.ColumnarBatchSerializer import com.huawei.boostkit.spark.util.OmniAdaptorUtil -import com.huawei.boostkit.spark.util.OmniAdaptorUtil.{addAllAndGetIterator, genSortParam} +import com.huawei.boostkit.spark.util.OmniAdaptorUtil.{addAllAndGetIterator, genSortParam, transColBatchToOmniVecs} import nova.hetu.omniruntime.`type`.DataType import nova.hetu.omniruntime.operator.config.{OperatorConfig, OverflowConfig, SpillConfig} +import nova.hetu.omniruntime.operator.limit.OmniLimitOperatorFactory import nova.hetu.omniruntime.operator.topn.OmniTopNWithExprOperatorFactory -import org.apache.spark.rdd.RDD +import nova.hetu.omniruntime.vector.VecBatch + +import org.apache.spark.rdd.{ParallelCollectionRDD, RDD} import org.apache.spark.serializer.Serializer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder} @@ -34,6 +37,7 @@ import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.ColumnarProjection.dealPartitionData import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} import org.apache.spark.sql.execution.util.SparkMemoryUtils +import org.apache.spark.sql.execution.vectorized.OmniColumnVector import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch @@ -49,30 +53,74 @@ trait ColumnarBaseLimitExec extends LimitExec { override def output: Seq[Attribute] = child.output + override lazy val metrics = Map( + "addInputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni addInput"), + "omniCodegenTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni codegen"), + "getOutputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni getOutput"), + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "numOutputVecBatches" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatches")) + protected override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val addInputTime = longMetric("addInputTime") + val omniCodegenTime = longMetric("omniCodegenTime") + val getOutputTime = longMetric("getOutputTime") + val 
numOutputRows = longMetric("numOutputRows") + val numOutputVecBatches= longMetric("numOutputVecBatches") + child.executeColumnar().mapPartitions { iter => - val hasInput = iter.hasNext - if (hasInput) { - new Iterator[ColumnarBatch] { - var rowCount = 0 - override def hasNext: Boolean = { - val hasNext = iter.hasNext - hasNext && (rowCount < limit) + + val startCodegen = System.nanoTime() + val limitOperatorFactory = new OmniLimitOperatorFactory(limit) + val limitOperator = limitOperatorFactory.createOperator + omniCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) + + // close operator + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { + limitOperator.close() + }) + + val localSchema = this.schema + new Iterator[ColumnarBatch] { + private var results: java.util.Iterator[VecBatch] = _ + + override def hasNext: Boolean = { + while ((results == null || !results.hasNext) && iter.hasNext) { + val batch = iter.next() + val input = transColBatchToOmniVecs(batch) + val vecBatch = new VecBatch(input, batch.numRows()) + val startInput = System.nanoTime() + limitOperator.addInput(vecBatch) + addInputTime += NANOSECONDS.toMillis(System.nanoTime() - startInput) + + val startGetOp = System.nanoTime() + results = limitOperator.getOutput + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOp) + } + if (results == null) { + false + } else { + val startGetOp: Long = System.nanoTime() + val hasNext = results.hasNext + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOp) + hasNext } + } - override def next(): ColumnarBatch = { - val output = iter.next() - val preRowCount = rowCount - rowCount += output.numRows - if (rowCount > limit) { - val newSize = limit - preRowCount - output.setNumRows(newSize) - } - output + override def next(): ColumnarBatch = { + val startGetOp = System.nanoTime() + val vecBatch = results.next() + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOp) + val vectors: Seq[OmniColumnVector] = OmniColumnVector.allocateColumns( + vecBatch.getRowCount, localSchema, false) + vectors.zipWithIndex.foreach { case (vector, i) => + vector.reset() + vector.setVec(vecBatch.getVectors()(i)) } + numOutputRows += vecBatch.getRowCount + numOutputVecBatches+= 1 + vecBatch.close() + new ColumnarBatch(vectors.toArray, vecBatch.getRowCount) } - } else { - Iterator.empty } } } @@ -147,8 +195,8 @@ case class ColumnarTakeOrderedAndProjectExec( "addInputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni addInput"), "getOutputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni getOutput"), "omniCodegenTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni codegen"), - "numInputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of input vecBatchs"), - "numOutputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatchs"), + "numInputVecBatches" -> SQLMetrics.createMetric(sparkContext, "number of input vecBatches"), + "numOutputVecBatches" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatches"), "addInputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni addInput") ) ++ readMetrics ++ writeMetrics @@ -181,65 +229,77 @@ case class ColumnarTakeOrderedAndProjectExec( } override def doExecuteColumnar(): RDD[ColumnarBatch] = { - val (sourceTypes, ascendings, nullFirsts, sortColsExp) = genSortParam(child.output, sortOrder) + val childRDD = child.executeColumnar() + val childRDDPartitions = childRDD.getNumPartitions + + if 
(childRDDPartitions == 0) { + new ParallelCollectionRDD(sparkContext, Seq.empty[ColumnarBatch], 1, Map.empty) + } else { + val (sourceTypes, ascending, nullFirsts, sortColsExp) = genSortParam(child.output, sortOrder) + + def computeTopN(iter: Iterator[ColumnarBatch], schema: StructType): Iterator[ColumnarBatch] = { + val startCodegen = System.nanoTime() + val topNOperatorFactory = new OmniTopNWithExprOperatorFactory(sourceTypes, limit, sortColsExp, ascending, nullFirsts, + new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) + val topNOperator = topNOperatorFactory.createOperator + longMetric("omniCodegenTime") += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { + topNOperator.close() + }) + addAllAndGetIterator(topNOperator, iter, schema, + longMetric("addInputTime"), longMetric("numInputVecBatches"), longMetric("numInputRows"), + longMetric("getOutputTime"), longMetric("numOutputVecBatches"), longMetric("numOutputRows"), + longMetric("outputDataSize")) + } - def computeTopN(iter: Iterator[ColumnarBatch], schema: StructType): Iterator[ColumnarBatch] = { - val startCodegen = System.nanoTime() - val topNOperatorFactory = new OmniTopNWithExprOperatorFactory(sourceTypes, limit, sortColsExp, ascendings, nullFirsts, - new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) - val topNOperator = topNOperatorFactory.createOperator - longMetric("omniCodegenTime") += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) - SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit]( _ => { - topNOperator.close() - }) - addAllAndGetIterator(topNOperator, iter, schema, - longMetric("addInputTime"), longMetric("numInputVecBatchs"), longMetric("numInputRows"), - longMetric("getOutputTime"), longMetric("numOutputVecBatchs"), longMetric("numOutputRows"), - longMetric("outputDataSize")) - } + val singlePartitionRDD = if (childRDDPartitions == 1) { + childRDD + } else { + val localTopK: RDD[ColumnarBatch] = { + child.executeColumnar().mapPartitionsWithIndexInternal { (_, iter) => + computeTopN(iter, this.child.schema) + } + } - val localTopK: RDD[ColumnarBatch] = { - child.executeColumnar().mapPartitionsWithIndexInternal { (_, iter) => - computeTopN(iter, this.child.schema) + new ShuffledColumnarRDD( + ColumnarShuffleExchangeExec.prepareShuffleDependency( + localTopK, + child.output, + SinglePartition, + serializer, + writeMetrics, + longMetric("dataSize"), + longMetric("bytesSpilled"), + longMetric("numInputRows"), + longMetric("splitTime"), + longMetric("spillTime")), + readMetrics) } - } - val shuffled = new ShuffledColumnarRDD( - ColumnarShuffleExchangeExec.prepareShuffleDependency( - localTopK, - child.output, - SinglePartition, - serializer, - writeMetrics, - longMetric("dataSize"), - longMetric("bytesSpilled"), - longMetric("numInputRows"), - longMetric("splitTime"), - longMetric("spillTime")), - readMetrics) - val projectEqualChildOutput = projectList == child.output - var omniInputTypes: Array[DataType] = null - var omniExpressions: Array[String] = null - var addInputTime: SQLMetric = null - var omniCodegenTime: SQLMetric = null - var getOutputTime: SQLMetric = null - if (!projectEqualChildOutput) { - omniInputTypes = child.output.map( - exp => sparkTypeToOmniType(exp.dataType, exp.metadata)).toArray - omniExpressions = projectList.map( - exp => rewriteToOmniJsonExpressionLiteral(exp, getExprIdMap(child.output))).toArray 
- addInputTime = longMetric("addInputTime") - omniCodegenTime = longMetric("omniCodegenTime") - getOutputTime = longMetric("getOutputTime") - } - shuffled.mapPartitions { iter => - // TopN = omni-top-n + omni-project - val topN: Iterator[ColumnarBatch] = computeTopN(iter, this.child.schema) + val projectEqualChildOutput = projectList == child.output + var omniInputTypes: Array[DataType] = null + var omniExpressions: Array[String] = null + var addInputTime: SQLMetric = null + var omniCodegenTime: SQLMetric = null + var getOutputTime: SQLMetric = null if (!projectEqualChildOutput) { - dealPartitionData(null, null, addInputTime, omniCodegenTime, - getOutputTime, omniInputTypes, omniExpressions, topN, this.schema) - } else { - topN + omniInputTypes = child.output.map( + exp => sparkTypeToOmniType(exp.dataType, exp.metadata)).toArray + omniExpressions = projectList.map( + exp => rewriteToOmniJsonExpressionLiteral(exp, getExprIdMap(child.output))).toArray + addInputTime = longMetric("addInputTime") + omniCodegenTime = longMetric("omniCodegenTime") + getOutputTime = longMetric("getOutputTime") + } + singlePartitionRDD.mapPartitions { iter => + // TopN = omni-top-n + omni-project + val topN: Iterator[ColumnarBatch] = computeTopN(iter, this.child.schema) + if (!projectEqualChildOutput) { + dealPartitionData(null, null, addInputTime, omniCodegenTime, + getOutputTime, omniInputTypes, omniExpressions, topN, this.schema) + } else { + topN + } } } } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarProjection.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarProjection.scala index 0ccdbd6de43c3cbd62b455ef7adf6e487cc69c45..49e6968685ccd2079776ace720aaa2342791ca9e 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarProjection.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarProjection.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch * @since 2022/3/5 */ object ColumnarProjection { - def dealPartitionData(numOutputRows: SQLMetric, numOutputVecBatchs: SQLMetric, + def dealPartitionData(numOutputRows: SQLMetric, numOutputVecBatches: SQLMetric, addInputTime: SQLMetric, omniCodegenTime: SQLMetric, getOutputTime: SQLMetric, omniInputTypes: Array[DataType], @@ -92,8 +92,8 @@ object ColumnarProjection { if(numOutputRows != null) { numOutputRows += result.getRowCount } - if (numOutputVecBatchs != null) { - numOutputVecBatchs += 1 + if (numOutputVecBatches!= null) { + numOutputVecBatches+= 1 } result.close() new ColumnarBatch(vectors.toArray, result.getRowCount) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala index ec3e6d5eae7177467e249c8314eb9821561ce733..6e6588304b3ba809fc9a9a561e66703a4f988122 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala @@ -45,7 +45,7 @@ import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExch import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec.createShuffleWriteProcessor import 
org.apache.spark.sql.execution.metric._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleWriteMetricsReporter} -import org.apache.spark.sql.execution.util.MergeIterator +import org.apache.spark.sql.execution.util.{MergeIterator, SparkMemoryUtils} import org.apache.spark.sql.execution.util.SparkMemoryUtils.addLeakSafeTaskCompletionListener import org.apache.spark.sql.execution.vectorized.OmniColumnVector import org.apache.spark.sql.internal.SQLConf @@ -73,9 +73,10 @@ case class ColumnarShuffleExchangeExec( "avgReadBatchNumRows" -> SQLMetrics .createAverageMetric(sparkContext, "avg read batch num rows"), "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"), - "numMergedVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of merged vecBatchs"), - "numOutputRows" -> SQLMetrics - .createMetric(sparkContext, "number of output rows")) ++ readMetrics ++ writeMetrics + "numMergedVecBatches" -> SQLMetrics.createMetric(sparkContext, "number of merged vecBatches"), + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "numPartitions" -> SQLMetrics.createMetric(sparkContext, "number of partitions") + ) ++ readMetrics ++ writeMetrics override def nodeName: String = "OmniColumnarShuffleExchange" @@ -123,9 +124,15 @@ case class ColumnarShuffleExchangeExec( longMetric("numInputRows"), longMetric("splitTime"), longMetric("spillTime")) + metrics("numPartitions").set(dep.partitioner.numPartitions) + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics("numPartitions") :: Nil) dep } - var cachedShuffleRDD: ShuffledColumnarRDD = _ + + private var cachedShuffleRDD: ShuffledColumnarRDD = null + + private val enableShuffleBatchMerge: Boolean = ColumnarPluginConfig.getSessionConf.enableShuffleBatchMerge override def doExecute(): RDD[InternalRow] = { throw new UnsupportedOperationException() @@ -153,18 +160,22 @@ case class ColumnarShuffleExchangeExec( if (cachedShuffleRDD == null) { cachedShuffleRDD = new ShuffledColumnarRDD(columnarShuffleDependency, readMetrics) } - val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf - val enableShuffleBatchMerge: Boolean = columnarConf.enableShuffleBatchMerge + if (enableShuffleBatchMerge) { cachedShuffleRDD.mapPartitionsWithIndexInternal { (index, iter) => - new MergeIterator(iter, + val mergeIterator = new MergeIterator(iter, StructType.fromAttributes(child.output), - longMetric("numMergedVecBatchs")) + longMetric("numMergedVecBatches")) + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { + mergeIterator.close() + }) + mergeIterator } } else { cachedShuffleRDD } } + override protected def withNewChildInternal(newChild: SparkPlan): ColumnarShuffleExchangeExec = copy(child = newChild) } @@ -194,10 +205,10 @@ object ColumnarShuffleExchangeExec extends Logging { val rddForSampling = rdd.mapPartitionsInternal { iter => // Internally, RangePartitioner runs a job on the RDD that samples keys to compute // partition bounds. To get accurate samples, we need to copy the mutable keys. 
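// The hunk below hoists UnsafeProjection.create(...) out of the per-batch flatMap, so the
// projection is built once per partition instead of once per batch. A minimal sketch of the
// same hoisting pattern, using purely illustrative names rather than the patch's own types:
object HoistingSketch {
  def doubled(batches: Iterator[Seq[Int]]): Iterator[Int] = {
    // Built once per partition; reused for every batch and every row below.
    val timesTwo: Int => Int = x => x * 2
    batches.flatMap(batch => batch.iterator.map(timesTwo))
  }
}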
+ val projection = + UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes) iter.flatMap(batch => { val rows: Iterator[InternalRow] = batch.rowIterator.asScala - val projection = - UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes) val mutablePair = new MutablePair[InternalRow, Null]() new Iterator[MutablePair[InternalRow, Null]] { var closed = false @@ -246,7 +257,7 @@ object ColumnarShuffleExchangeExec extends Logging { for (i <- 0 until columnarBatch.numRows()) { val partitionId = TaskContext.get().partitionId() val position = new XORShiftRandom(partitionId).nextInt(numPartitions) - pidArr(i) = position + 1 + pidArr(i) = position } val vec = new IntVec(columnarBatch.numRows()) vec.put(pidArr, 0, 0, pidArr.length) @@ -261,9 +272,8 @@ object ColumnarShuffleExchangeExec extends Logging { (0, new ColumnarBatch(newColumns, cb.numRows)) } - def computePartitionId( - cbIter: Iterator[ColumnarBatch], - partitionKeyExtractor: InternalRow => Any): Iterator[(Int, ColumnarBatch)] = { + def computePartitionId(cbIter: Iterator[ColumnarBatch], + partitionKeyExtractor: InternalRow => Any): Iterator[(Int, ColumnarBatch)] = { val addPid2ColumnBatch = addPidToColumnBatch() cbIter.filter(cb => cb.numRows != 0 && cb.numCols != 0).map { cb => @@ -314,7 +324,7 @@ object ColumnarShuffleExchangeExec extends Logging { newIter }, isOrderSensitive = isOrderSensitive) case h@HashPartitioning(expressions, numPartitions) => - if (containsRollUp(expressions)) { + if (containsRollUp(expressions) || expressions.length > 6) { rdd.mapPartitionsWithIndexInternal((_, cbIter) => { val partitionKeyExtractor: InternalRow => Any = { val projection = @@ -414,4 +424,4 @@ object ColumnarShuffleExchangeExec extends Logging { } } -} \ No newline at end of file +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala index 04955a9eff9c1c7959cf1af0cb2434827a3390e7..55e4c6d5d38e0c7e7d4b33943089ba5dabd8e4cc 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import java.io.{File, IOException} +import java.io.File import java.util.UUID import java.util.concurrent.TimeUnit.NANOSECONDS @@ -43,9 +43,6 @@ case class ColumnarSortExec( child: SparkPlan, testSpillFrequency: Int = 0) extends UnaryExecNode { - - private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 - override def supportsColumnar: Boolean = true override def nodeName: String = "OmniColumnarSort" @@ -63,76 +60,62 @@ case class ColumnarSortExec( if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil override lazy val metrics = Map( - "addInputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni addInput"), - "numInputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of input vecBatchs"), + "numInputVecBatches" -> SQLMetrics.createMetric(sparkContext, "number of input vecBatches"), "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"), "omniCodegenTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni codegen"), "getOutputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni getOutput"), "numOutputRows" -> 
SQLMetrics.createMetric(sparkContext, "number of output rows"), "outputDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "output data size"), - "numOutputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatchs")) + "numOutputVecBatches" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatches"), + "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size")) def buildCheck(): Unit = { genSortParam(child.output, sortOrder) } - val sparkConfTmp = sparkContext.conf + val tmpSparkConf = sparkContext.conf - private def generateLocalDirs(conf: SparkConf): Array[File] = { - Utils.getConfiguredLocalDirs(conf).flatMap { rootDir => - val localDir = generateDirs(rootDir, "columnarSortSpill") - Some(localDir) - } - } - - def generateDirs(root: String, namePrefix: String = "spark"):File = { - var attempts = 0 - val maxAttempts = MAX_DIR_CREATION_ATTEMPTS - var dir: File = null - while (dir == null) { - attempts += 1 - if (attempts > maxAttempts) { - throw new IOException("Directory conflict: failed to generate a temp directory for columnarSortSpill " + - "(under " + root + ") after " + maxAttempts + " attempts!") - } - dir = new File(root, namePrefix + "-" + UUID.randomUUID.toString) - if (dir.exists()) { - dir = null - } - } - dir.getCanonicalFile + def generateSpillDir(conf: SparkConf, subDir: String): String = { + val localDirs: Array[String] = Utils.getConfiguredLocalDirs(conf) + val hash = Utils.nonNegativeHash(UUID.randomUUID.toString) + val root = localDirs(hash % localDirs.length) + val dir = new File(root, subDir) + dir.getCanonicalPath } override def doExecuteColumnar(): RDD[ColumnarBatch] = { val omniCodegenTime = longMetric("omniCodegenTime") + val spillSize = longMetric("spillSize") - val (sourceTypes, ascendings, nullFirsts, sortColsExp) = genSortParam(child.output, sortOrder) + val (sourceTypes, ascending, nullFirsts, sortColsExp) = genSortParam(child.output, sortOrder) val outputCols = output.indices.toArray child.executeColumnar().mapPartitionsWithIndexInternal { (_, iter) => val columnarConf = ColumnarPluginConfig.getSessionConf val sortSpillRowThreshold = columnarConf.columnarSortSpillRowThreshold - val sortSpillMemPctThreshold = columnarConf.columnarSortSpillMemPctThreshold - val sortSpillDirDiskReserveSize = columnarConf.columnarSortSpillDirDiskReserveSize + val spillMemPctThreshold = columnarConf.columnarSpillMemPctThreshold + val spillDirDiskReserveSize = columnarConf.columnarSpillDirDiskReserveSize val sortSpillEnable = columnarConf.enableSortSpill - val sortlocalDirs: Array[File] = generateLocalDirs(sparkConfTmp) - val hash = Utils.nonNegativeHash(SparkEnv.get.executorId) - val dirId = hash % sortlocalDirs.length - val spillPathDir = sortlocalDirs(dirId).getCanonicalPath - val sparkSpillConf = new SparkSpillConfig(sortSpillEnable, spillPathDir, - sortSpillDirDiskReserveSize, sortSpillRowThreshold, sortSpillMemPctThreshold) + val spillDirectory = generateSpillDir(tmpSparkConf, "columnarSortSpill") + val sparkSpillConf = new SparkSpillConfig(sortSpillEnable, spillDirectory, spillDirDiskReserveSize, + sortSpillRowThreshold, spillMemPctThreshold) val startCodegen = System.nanoTime() - val sortOperatorFactory = new OmniSortWithExprOperatorFactory(sourceTypes, outputCols, sortColsExp, ascendings, nullFirsts, - new OperatorConfig(sparkSpillConf, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) + + val radixSortEnable = columnarConf.enableRadixSort + val radixSortRowCountThreshold = if(radixSortEnable) 
{columnarConf.radixSortThreshold} else {-1} + + val sortOperatorFactory = new OmniSortWithExprOperatorFactory(sourceTypes, outputCols, sortColsExp, ascending, nullFirsts, + new OperatorConfig(sparkSpillConf, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP, radixSortRowCountThreshold.asInstanceOf[Int])) val sortOperator = sortOperatorFactory.createOperator omniCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { + spillSize += sortOperator.getSpilledBytes() sortOperator.close() }) addAllAndGetIterator(sortOperator, iter, this.schema, - longMetric("addInputTime"), longMetric("numInputVecBatchs"), longMetric("numInputRows"), - longMetric("getOutputTime"), longMetric("numOutputVecBatchs"), longMetric("numOutputRows"), + longMetric("addInputTime"), longMetric("numInputVecBatches"), longMetric("numInputRows"), + longMetric("getOutputTime"), longMetric("numOutputVecBatches"), longMetric("numOutputRows"), longMetric("outputDataSize")) } } @@ -140,4 +123,4 @@ case class ColumnarSortExec( override protected def doExecute(): RDD[InternalRow] = { throw new UnsupportedOperationException(s"This operator doesn't support doExecute().") } -} \ No newline at end of file +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSubqueryBroadcastExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSubqueryBroadcastExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..f87d4bdee98bed76e734971114d8fdd7a19030fd --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSubqueryBroadcastExec.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution + + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec +import org.apache.spark.sql.execution.exchange.ReusedExchangeExec +import org.apache.spark.sql.execution.joins.{HashedRelation, HashJoin, LongHashedRelation} +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.util.ThreadUtils + +import scala.concurrent.{Future,ExecutionContext} +import scala.concurrent.duration.Duration + +case class SubqueryBroadcastExec( + name: String, + index: Int, + buildKeys: Seq[Expression], + child: SparkPlan) + extends BaseSubqueryExec with UnaryExecNode { + + override def nodeName: String = { + val exchangeChild = child match { + case exec: ReusedExchangeExec => + exec.child + case _ => + child + } + if (exchangeChild.isInstanceOf[ColumnarBroadcastExchangeExec] || + (exchangeChild.isInstanceOf[AdaptiveSparkPlanExec] + && exchangeChild.asInstanceOf[AdaptiveSparkPlanExec].supportsColumnar)) { + "OmniColumnarSubqueryBroadcastExec" + } else { + "SubqueryBroadcastExec" + } + } + + // `SubqueryBroadcastExec` is only used with `InSubqueryExec`. + // No one would reference this output, + // so the exprId doesn't matter here. But it's important to correctly report the output length, so + // that `InSubqueryExec` can know it's the single-column execution mode, not multi-column. + override def output: Seq[Attribute] = { + val key = buildKeys(index) + val name = key match { + case n: NamedExpression => n.name + case Cast(n: NamedExpression, _, _, _) => n.name + case _ => "key" + } + Seq(AttributeReference(name, key.dataType, key.nullable)()) + } + + // Note: "metrics" is made transient to avoid sending driver-side metrics to tasks. + override lazy val metrics = Map( + "dataSize" -> SQLMetrics.createMetric(sparkContext, "data size (bytes)"), + "collectTime" -> SQLMetrics.createMetric(sparkContext, "time to collect (ms)")) + + override def doCanonicalize(): SparkPlan = { + val keys = buildKeys.map(k => QueryPlan.normalizeExpressions(k, child.output)) + copy(name = "dpp", buildKeys = keys, child = child.canonicalized) + } + + @transient + private lazy val relationFuture: Future[Array[InternalRow]] = { + // relationFuture is used in "doExecute". Therefore we can get the execution id correctly here. + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + Future { + // This will run in another thread. Set the execution id so that we can connect these jobs + // with the correct execution. 
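// The keys collected in this future feed dynamic partition pruning: a columnar broadcast
// (ColumnarBroadcastExchangeExec, or a columnar AdaptiveSparkPlanExec) is read via
// ColumnarHashedRelation.transform, while the row-based path falls back to
// HashedRelation.keys() plus an UnsafeProjection; both paths de-duplicate the collected rows.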
+ SQLExecution.withExecutionId(session, executionId) { + val beforCollect = System.nanoTime() + val exchangeChild = child match { + case exec: ReusedExchangeExec => + exec.child + case _ => + child + } + val rows = if (exchangeChild.isInstanceOf[ColumnarBroadcastExchangeExec] || + (exchangeChild.isInstanceOf[AdaptiveSparkPlanExec] + && exchangeChild.asInstanceOf[AdaptiveSparkPlanExec].supportsColumnar)) { + // transform broadcasted columnar value to Array[InternalRow] by key + exchangeChild + .executeBroadcast[ColumnarHashedRelation] + .value + .transform(buildKeys(index), exchangeChild.output) + .distinct + } else { + val broadcastRelation = exchangeChild.executeBroadcast[HashedRelation]().value + val (iter, expr) = if (broadcastRelation.isInstanceOf[LongHashedRelation]) { + (broadcastRelation.keys(), HashJoin.extractKeyExprAt(buildKeys, index)) + } else { + (broadcastRelation.keys(), + BoundReference(index, buildKeys(index).dataType, buildKeys(index).nullable)) + } + + val proj = UnsafeProjection.create(expr) + val keyIter = iter.map(proj).map(_.copy()) + keyIter.toArray[InternalRow].distinct + } + val beforBuild = System.nanoTime() + longMetric("collectTime") += (beforBuild - beforCollect) / 1000000 + val dataSize = rows.map(_.asInstanceOf[UnsafeRow].getSizeInBytes).sum + longMetric("dataSize") += dataSize + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) + rows + } + }(SubqueryBroadcastExec.executionContext) + } + + override protected def doPrepare(): Unit = { + relationFuture + } + + override protected def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException( + "does not support the execute() code path.") + } + + override def executeCollect(): Array[InternalRow] = { + ThreadUtils.awaitResult(relationFuture, Duration.Inf) + } + + override def stringArgs: Iterator[Any] = super.stringArgs ++ Iterator(s"[id=#$id]") + + protected def withNewChildInternal(newChild: SparkPlan): SubqueryBroadcastExec = + copy(child = newChild) +} + +object SubqueryBroadcastExec { + private[execution] val executionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("dynamicpruning", 16)) +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarTopNSortExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarTopNSortExec.scala index 6fa91733415e5760be2e513ef7e12e4121b5626b..9e52282922c1aceb1da34d650349b7e9562804d7 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarTopNSortExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarTopNSortExec.scala @@ -61,20 +61,15 @@ case class ColumnarTopNSortExec( override lazy val metrics = Map( "addInputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni addInput"), - "numInputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of input vecBatchs"), + "numInputVecBatches" -> SQLMetrics.createMetric(sparkContext, "number of input vecBatches"), "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"), "omniCodegenTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni codegen"), "getOutputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni getOutput"), "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "outputDataSize" -> 
SQLMetrics.createSizeMetric(sparkContext, "output data size"), - "numOutputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatchs")) + "numOutputVecBatches" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatches")) def buildCheck(): Unit = { - // current only support rank function of window - // strictTopN true for row_number, false for rank - if (strictTopN) { - throw new UnsupportedOperationException(s"Unsupported strictTopN is true") - } val omniAttrExpsIdMap = getExprIdMap(child.output) val omniPartitionChanels: Array[AnyRef] = partitionSpec.map( exp => rewriteToOmniJsonExpressionLiteral(exp, omniAttrExpsIdMap)).toArray @@ -87,12 +82,12 @@ case class ColumnarTopNSortExec( val omniAttrExpsIdMap = getExprIdMap(child.output) val omniPartitionChanels = partitionSpec.map( exp => rewriteToOmniJsonExpressionLiteral(exp, omniAttrExpsIdMap)).toArray - val (sourceTypes, ascendings, nullFirsts, sortColsExp) = genSortParam(child.output, sortOrder) + val (sourceTypes, ascending, nullFirsts, sortColsExp) = genSortParam(child.output, sortOrder) child.executeColumnar().mapPartitionsWithIndexInternal { (_, iter) => val startCodegen = System.nanoTime() val topNSortOperatorFactory = new OmniTopNSortWithExprOperatorFactory(sourceTypes, n, - strictTopN, omniPartitionChanels, sortColsExp, ascendings, nullFirsts, + strictTopN, omniPartitionChanels, sortColsExp, ascending, nullFirsts, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val topNSortOperator = topNSortOperatorFactory.createOperator omniCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) @@ -100,8 +95,8 @@ case class ColumnarTopNSortExec( topNSortOperator.close() }) addAllAndGetIterator(topNSortOperator, iter, this.schema, - longMetric("addInputTime"), longMetric("numInputVecBatchs"), longMetric("numInputRows"), - longMetric("getOutputTime"), longMetric("numOutputVecBatchs"), longMetric("numOutputRows"), + longMetric("addInputTime"), longMetric("numInputVecBatches"), longMetric("numInputRows"), + longMetric("getOutputTime"), longMetric("numOutputVecBatches"), longMetric("numOutputRows"), longMetric("outputDataSize")) } } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala index 184bbdaf1e1fb1c84ee2316498890cc9ab94be51..7d1828c27d9fa009afb343acbce821fbf0ca1c5d 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala @@ -17,17 +17,21 @@ package org.apache.spark.sql.execution +import java.io.File +import java.util.UUID import java.util.concurrent.TimeUnit.NANOSECONDS +import com.huawei.boostkit.spark.ColumnarPluginConfig import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor._ import com.huawei.boostkit.spark.util.OmniAdaptorUtil import com.huawei.boostkit.spark.util.OmniAdaptorUtil.transColBatchToOmniVecs import nova.hetu.omniruntime.`type`.DataType import nova.hetu.omniruntime.constants.{FunctionType, OmniWindowFrameBoundType, OmniWindowFrameType} -import nova.hetu.omniruntime.operator.config.{OperatorConfig, OverflowConfig, SpillConfig} +import 
nova.hetu.omniruntime.operator.config.{OperatorConfig, OverflowConfig, SparkSpillConfig} import nova.hetu.omniruntime.operator.window.OmniWindowWithExprOperatorFactory import nova.hetu.omniruntime.vector.VecBatch +import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -40,6 +44,7 @@ import org.apache.spark.sql.execution.vectorized.OmniColumnVector import org.apache.spark.sql.execution.window.WindowExecBase import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.Utils case class ColumnarWindowExec(windowExpression: Seq[NamedExpression], partitionSpec: Seq[Expression], @@ -55,17 +60,28 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression], override lazy val metrics = Map( "addInputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni addInput"), - "numInputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of input vecBatchs"), + "numInputVecBatches" -> SQLMetrics.createMetric(sparkContext, "number of input vecBatches"), "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"), "omniCodegenTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni codegen"), "getOutputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni getOutput"), "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), - "numOutputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatchs")) + "numOutputVecBatches" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatches"), + "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size")) override protected def doExecute(): RDD[InternalRow] = { throw new UnsupportedOperationException(s"This operator doesn't support doExecute().") } + val tmpSparkConf = sparkContext.conf + + def generateSpillDir(conf: SparkConf, subDir: String): String = { + val localDirs: Array[String] = Utils.getConfiguredLocalDirs(conf) + val hash = Utils.nonNegativeHash(UUID.randomUUID.toString) + val root = localDirs(hash % localDirs.length) + val dir = new File(root, subDir) + dir.getCanonicalPath + } + def getWindowFrameParam(frame: SpecifiedWindowFrame): (OmniWindowFrameType, OmniWindowFrameBoundType, OmniWindowFrameBoundType, Int, Int) = { var windowFrameStartChannel = -1 @@ -113,10 +129,23 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression], val omniAttrExpsIdMap = getExprIdMap(child.output) val windowFrameTypes = new Array[OmniWindowFrameType](winExpressions.size) val windowFrameStartTypes = new Array[OmniWindowFrameBoundType](winExpressions.size) - val winddowFrameStartChannels = new Array[Int](winExpressions.size) + val windowFrameStartChannels = new Array[Int](winExpressions.size) val windowFrameEndTypes = new Array[OmniWindowFrameBoundType](winExpressions.size) - val winddowFrameEndChannels = new Array[Int](winExpressions.size) + val windowFrameEndChannels = new Array[Int](winExpressions.size) var attrMap: Map[String, Int] = Map() + + for (sortAttr <- orderSpec) { + if (!sortAttr.child.isInstanceOf[AttributeReference]) { + throw new UnsupportedOperationException(s"Unsupported sort col : ${sortAttr.child.nodeName}") + } + } + + for (partitionAttr <- partitionSpec) { + if (!partitionAttr.isInstanceOf[AttributeReference]) { + throw new UnsupportedOperationException(s"Unsupported partition col : ${partitionAttr.nodeName}") + } + } 
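// The two loops above reject window sort/partition keys that are not plain column references
// (AttributeReference), so unsupported expressions fail during buildCheck rather than later in
// doExecuteColumnar(), which performs the same AttributeReference match when resolving columns.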
+ child.output.zipWithIndex.foreach { case (inputIter, i) => sourceTypes(i) = sparkTypeToOmniType(inputIter.dataType, inputIter.metadata) @@ -131,12 +160,12 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression], case e@WindowExpression(function, spec) => if (spec.frameSpecification.isInstanceOf[SpecifiedWindowFrame]) { - val winFram = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] - if (winFram.lower != UnboundedPreceding && winFram.lower != CurrentRow) { - throw new UnsupportedOperationException(s"Unsupported Specified frame_start: ${winFram.lower}") + val winFrame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] + if (winFrame.lower != UnboundedPreceding && winFrame.lower != CurrentRow) { + throw new UnsupportedOperationException(s"Unsupported Specified frame_start: ${winFrame.lower}") } - if (winFram.upper != UnboundedFollowing && winFram.upper != CurrentRow) { - throw new UnsupportedOperationException(s"Unsupported Specified frame_end: ${winFram.upper}") + if (winFrame.upper != UnboundedFollowing && winFrame.upper != CurrentRow) { + throw new UnsupportedOperationException(s"Unsupported Specified frame_end: ${winFrame.upper}") } } windowFunRetType(index) = sparkTypeToOmniType(function.dataType) @@ -145,13 +174,13 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression], windowFrameTypes(index) = winFrameParam._1 windowFrameStartTypes(index) = winFrameParam._2 windowFrameEndTypes(index) = winFrameParam._3 - winddowFrameStartChannels(index) = winFrameParam._4 - winddowFrameEndChannels(index) = winFrameParam._5 + windowFrameStartChannels(index) = winFrameParam._4 + windowFrameEndChannels(index) = winFrameParam._5 function match { // AggregateWindowFunction - case winfunc: WindowFunction => - windowFunType(index) = toOmniWindowFunType(winfunc) - windowArgKeys = winfunc.children.map( + case winFunc: WindowFunction => + windowFunType(index) = toOmniWindowFunType(winFunc) + windowArgKeys = winFunc.children.map( exp => rewriteToOmniJsonExpressionLiteral(exp, omniAttrExpsIdMap)).toArray // AggregateExpression case agg@AggregateExpression(aggFunc, _, _, _, _) => @@ -191,11 +220,12 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression], override def doExecuteColumnar(): RDD[ColumnarBatch] = { val addInputTime = longMetric("addInputTime") val numInputRows = longMetric("numInputRows") - val numInputVecBatchs = longMetric("numInputVecBatchs") + val numInputVecBatches= longMetric("numInputVecBatches") val omniCodegenTime = longMetric("omniCodegenTime") val numOutputRows = longMetric("numOutputRows") - val numOutputVecBatchs = longMetric("numOutputVecBatchs") + val numOutputVecBatches= longMetric("numOutputVecBatches") val getOutputTime = longMetric("getOutputTime") + val spillSize = longMetric("spillSize") val sourceTypes = new Array[DataType](child.output.size) val sortCols = new Array[Int](orderSpec.size) @@ -220,24 +250,30 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression], sourceTypes(i) = sparkTypeToOmniType(inputIter.dataType, inputIter.metadata) attrMap += (inputIter.name -> i) } - // partition column parameters // sort column parameters var i = 0 for (sortAttr <- orderSpec) { - if (attrMap.contains(sortAttr.child.asInstanceOf[AttributeReference].name)) { - sortCols(i) = attrMap(sortAttr.child.asInstanceOf[AttributeReference].name) - ascendings(i) = sortAttr.isAscending match { - case true => 1 - case _ => 0 - } - nullFirsts(i) = sortAttr.nullOrdering.sql match { - case "NULLS LAST" => 0 - case _ => 1 
- } - } else { - throw new UnsupportedOperationException(s"Unsupported sort col not in inputset: ${sortAttr.nodeName}") + val sortExpr = sortAttr.child + sortExpr match { + case attr: AttributeReference => + if (attrMap.contains(attr.name)) { + sortCols(i) = attrMap(attr.name) + } else { + throw new UnsupportedOperationException(s"Unsupported sort col not in inputset: ${sortAttr.nodeName}") + } + case _ => + throw new UnsupportedOperationException(s"Unsupported sort col : ${sortExpr}") + } + ascendings(i) = sortAttr.isAscending match { + case true => 1 + case _ => 0 } + nullFirsts(i) = sortAttr.nullOrdering.sql match { + case "NULLS LAST" => 0 + case _ => 1 + } + i += 1 } @@ -253,14 +289,20 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression], i += 1 } - // partitionSpec: Seq[Expression] + // partition column parameters i = 0 for (partitionAttr <- partitionSpec) { - if (attrMap.contains(partitionAttr.asInstanceOf[AttributeReference].name)) { - omminPartitionChannels(i) = attrMap(partitionAttr.asInstanceOf[AttributeReference].name) - } else { - throw new UnsupportedOperationException(s"output col not in input cols: ${partitionAttr}") + partitionAttr match { + case attr: AttributeReference => + if (attrMap.contains(attr.name)) { + omminPartitionChannels(i) = attrMap(attr.name) + } else { + throw new UnsupportedOperationException(s"Partition col not in input cols: ${partitionAttr}") + } + case _ => + throw new UnsupportedOperationException(s"Unsupported partition col : ${partitionAttr}") } + i += 1 } @@ -270,12 +312,12 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression], x.foreach { case e@WindowExpression(function, spec) => if (spec.frameSpecification.isInstanceOf[SpecifiedWindowFrame]) { - val winFram = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] - if (winFram.lower != UnboundedPreceding && winFram.lower != CurrentRow) { - throw new UnsupportedOperationException(s"Unsupported Specified frame_start: ${winFram.lower}") + val winFrame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] + if (winFrame.lower != UnboundedPreceding && winFrame.lower != CurrentRow) { + throw new UnsupportedOperationException(s"Unsupported Specified frame_start: ${winFrame.lower}") } - if (winFram.upper != UnboundedFollowing && winFram.upper != CurrentRow) { - throw new UnsupportedOperationException(s"Unsupported Specified frame_end: ${winFram.upper}") + if (winFrame.upper != UnboundedFollowing && winFrame.upper != CurrentRow) { + throw new UnsupportedOperationException(s"Unsupported Specified frame_end: ${winFrame.upper}") } } windowFunRetType(index) = sparkTypeToOmniType(function.dataType) @@ -288,8 +330,8 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression], windowFrameEndChannels(index) = winFrameParam._5 function match { // AggregateWindowFunction - case winfunc: WindowFunction => - windowFunType(index) = toOmniWindowFunType(winfunc) + case winFunc: WindowFunction => + windowFunType(index) = toOmniWindowFunType(winFunc) windowArgKeys(index) = null // AggregateExpression case agg@AggregateExpression(aggFunc, _, _, _, _) => @@ -313,17 +355,27 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression], val windowExpressionWithProjectConstant = windowExpressionWithProject child.executeColumnar().mapPartitionsWithIndexInternal { (index, iter) => + val columnarConf = ColumnarPluginConfig.getSessionConf + val windowSpillEnable = columnarConf.enableWindowSpill + val spillDirDiskReserveSize = 
columnarConf.columnarSpillDirDiskReserveSize + val windowSpillRowThreshold = columnarConf.columnarWindowSpillRowThreshold + val spillMemPctThreshold = columnarConf.columnarSpillMemPctThreshold + val spillDirectory = generateSpillDir(tmpSparkConf, "columnarWindowSpill") + val sparkSpillConfig = new SparkSpillConfig(windowSpillEnable, spillDirectory, + spillDirDiskReserveSize, windowSpillRowThreshold, spillMemPctThreshold) + val startCodegen = System.nanoTime() val windowOperatorFactory = new OmniWindowWithExprOperatorFactory(sourceTypes, outputCols, windowFunType, omminPartitionChannels, preGroupedChannels, sortCols, ascendings, nullFirsts, 0, 10000, windowArgKeys, windowFunRetType, windowFrameTypes, windowFrameStartTypes, windowFrameStartChannels, windowFrameEndTypes, windowFrameEndChannels, - new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) + new OperatorConfig(sparkSpillConfig, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val windowOperator = windowOperatorFactory.createOperator omniCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) // close operator SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { + spillSize += windowOperator.getSpilledBytes windowOperator.close() }) @@ -334,7 +386,7 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression], val startInput = System.nanoTime() windowOperator.addInput(vecBatch) addInputTime += NANOSECONDS.toMillis(System.nanoTime() - startInput) - numInputVecBatchs += 1 + numInputVecBatches+= 1 numInputRows += batch.numRows() } @@ -344,8 +396,8 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression], var windowResultSchema = this.schema if (windowExpressionWithProjectConstant) { - val omnifinalOutSchema = child.output ++ winExpToReferences.map(_.toAttribute) - windowResultSchema = StructType.fromAttributes(omnifinalOutSchema) + val omniFinalOutSchema = child.output ++ winExpToReferences.map(_.toAttribute) + windowResultSchema = StructType.fromAttributes(omniFinalOutSchema) } val outputColSize = outputCols.length val omniWindowResultIter = new Iterator[ColumnarBatch] { @@ -376,7 +428,7 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression], vecBatch.getVectors()(i).close() } numOutputRows += vecBatch.getRowCount - numOutputVecBatchs += 1 + numOutputVecBatches+= 1 vecBatch.close() new ColumnarBatch(vectors.toArray, vecBatch.getRowCount) @@ -395,4 +447,4 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression], } } } -} \ No newline at end of file +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ShuffledColumnarRDD.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ShuffledColumnarRDD.scala index 7f664121bc7f309d3c3f1226ba49a3afb2e231b9..17c2d2ac08373df5c39ebb7ca5f3db3dec13b349 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ShuffledColumnarRDD.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ShuffledColumnarRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.{Dependency, MapOutputTrackerMaster, Partition, Partitioner, ShuffleDependency, SparkEnv, TaskContext} +import org.apache.spark.{Dependency, OmniMapOutputTrackerMaster, Partition, Partitioner, ShuffleDependency, SparkEnv, TaskContext} import org.apache.spark.rdd.RDD import 
org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleReadMetricsReporter} @@ -105,7 +105,7 @@ class ShuffledColumnarRDD( } override def getPreferredLocations(partition: Partition): Seq[String] = { - val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] + val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[OmniMapOutputTrackerMaster] partition.asInstanceOf[ShuffledColumnarRDDPartition].spec match { case CoalescedPartitionSpec(startReducerIndex, endReducerIndex, _) => startReducerIndex.until(endReducerIndex).flatMap { reducerIndex => diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/SortExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..0ddf89b8c1c3d36b63e7bebd2f9b12e0b1a7f385 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -0,0 +1,307 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.util.concurrent.TimeUnit._ +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS +import org.apache.spark.sql.execution.UnsafeExternalRowSorter.PrefixComputer +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.util.collection.unsafe.sort.PrefixComparator + + +/** + * Base class of [[SortExec]] and [[TopNSortExec]]. All subclasses of this class need to override + * their own sorter which inherits from [[org.apache.spark.sql.execution.AbstractUnsafeRowSorter]] + * to perform corresponding sorting. + * + * @param global when true performs a global sort of all partitions by shuffling the data first + * if necessary. + * @param testSpillFrequency Method for configuring periodic spilling in unit tests. + * If set, will spill every 'frequency' records. 
+ * */ +abstract class SortExecBase( + sortOrder: Seq[SortOrder], + global: Boolean, + child: SparkPlan, + testSpillFrequency: Int = 0) + extends UnaryExecNode with BlockingOperatorWithCodegen { + + override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder + + // sort performed is local within a given partition so will retain + // child operator's partitioning + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = + if (global) OrderedDistribution(sortOrder):: Nil else UnspecifiedDistribution :: Nil + + private val enableRadixSort = conf.enableRadixSort + + override lazy val metrics = Map( + "sortTime" -> SQLMetrics.createTimingMetric(sparkContext, "sort time"), + "peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"), + "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size") + ) + + protected val sorterClassName: String + + protected def newSorterInstance( + ordering: Ordering[InternalRow], + prefixComparator: PrefixComparator, + prefixComputer: PrefixComputer, + pageSize: Long, + canSortFullyWIthPrefix: Boolean): AbstractUnsafeRowSorter + + private[sql] var rowSorter: AbstractUnsafeRowSorter = _ + + /** + * This method gets invoked only once for each SortExec instance to initialize + * an AbstractUnsafeRowSorter, both 'plan.execute' and code generation are using it. + * In the code generation code path, we need to call this function outside the class + * so we should make it public + * */ + def createSorter(): AbstractUnsafeRowSorter = { + val ordering = RowOrdering.create(sortOrder, output) + + // THe comparator for comparing prefix + val boundSortExpression = BindReferences.bindReference(sortOrder.head, output) + val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression) + + val canSortFullyWIthPrefix = sortOrder.length == 1 && + SortPrefixUtils.canSortFullyWithPrefix(boundSortExpression) + + // The generator for prefix + val prefixExpr = SortPrefix(boundSortExpression) + val prefixProjection = UnsafeProjection.create(Seq(prefixExpr)) + val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer { + private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix + override def computePrefix(row: InternalRow): + UnsafeExternalRowSorter.PrefixComputer.Prefix = { + val prefix = prefixProjection.apply(row) + result.isNull = prefix.isNullAt(0) + result.value = if (result.isNull) prefixExpr.nullValue else prefix.getLong(0) + result + } + } + + val pageSize = SparkEnv.get.memoryManager.pageSizeBytes + rowSorter = newSorterInstance(ordering, prefixComparator, prefixComputer, + pageSize, canSortFullyWIthPrefix) + + if (testSpillFrequency > 0) { + rowSorter.setTestSpillFrequency(testSpillFrequency) + } + rowSorter + } + + protected override def doExecute(): RDD[InternalRow] = { + val peakMemory = longMetric("peakMemory") + val spillSize = longMetric("spillSize") + val sortTime = longMetric("sortTime") + + child.execute().mapPartitionsInternal { iter => + val sorter = createSorter() + val metrics = TaskContext.get().taskMetrics() + + // Remember spill data size of this task before execute this operator, + // so that we can figure out how many bytes we spilled for this operator. 
+ val spillSizeBefore = metrics.memoryBytesSpilled + val sortedIterator = sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]]) + sortTime += NANOSECONDS.toMillis(sorter.getSortTimeNanos) + peakMemory += sorter.getPeakMemoryUsage + spillSize += metrics.memoryBytesSpilled - spillSizeBefore + metrics.incPeakExecutionMemory(sorter.getPeakMemoryUsage) + + sortedIterator + } + } + + override def usedInputs: AttributeSet = AttributeSet(Seq.empty) + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].inputRDDs + } + + // Name of sorter variable used in codegen + private var sorterVariable: String = _ + + override protected def doProduce(ctx: CodegenContext): String = { + val needToSort = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, + "needToSort", v => s"$v = true;") + + // Initialize the class member variables. This includes the instance of the Sorter + // and the iterator to return sorted rows. + val thisPlan = ctx.addReferenceObj("plan", this) + // Inline mutable state since not many Sort operations in a task + sorterVariable = ctx.addMutableState(sorterClassName, "sorter", + v => s"$v = $thisPlan.createSorter();", forceInline = true) + val metrics = ctx.addMutableState(classOf[TaskMetrics].getName, "metrics", + v => s"$v = org.apache.spark.TaskContext.get().taskMetrics();", forceInline = true) + val sortedIterator = ctx.addMutableState("scala.collection.Iterator", + "sortedIter", forceInline = true) + + val addToSorter = ctx.freshName("addToSorter") + val addToSorterFuncName = ctx.addNewFunction(addToSorter, + s""" + | private void $addToSorter() throws java.io.IOException { + | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} + | } + """.stripMargin.trim) + + val outputRow = ctx.freshName("outputRow") + val peakMemory = metricTerm(ctx, "peakMemory") + val spillSize = metricTerm(ctx, "spillSize") + val spillSizeBefore = ctx.freshName("spillSizeBefore") + val sortTime = metricTerm(ctx, "sortTime") + s""" + | if ($needToSort) { + | long $spillSizeBefore = $metrics.memoryBytesSpilled(); + | $addToSorterFuncName(); + | $sortedIterator = $sorterVariable.sort(); + | $sortTime.add($sorterVariable.getSortTimeNanos() / $NANOS_PER_MILLIS); + | $peakMemory.add($sorterVariable.getPeakMemoryUsage()); + | $spillSize.add($metrics.memoryBytesSpilled() - $spillSizeBefore); + | $metrics.incPeakExecutionMemory($sorterVariable.getPeakMemoryUsage()); + | $needToSort = false; + | } + | + | while ($limitNotReachedCond $sortedIterator.hasNext()) { + | UnsafeRow $outputRow = (UnsafeRow)$sortedIterator.next(); + | ${consume(ctx, null, outputRow)} + | if (shouldStop()) return; + | } + """.stripMargin.trim + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { + s""" + | ${row.code} + | $sorterVariable.insertRow((UnsafeRow)${row.value}); + """.stripMargin + } + + /** + * In SortExecBase, we override cleanupResources to close the AbstractUnsafeRowSorter.
+ * */ + + override protected[sql] def cleanupResources(): Unit = { + if (rowSorter != null) { + // There's possible for rowSorter is null here, for example, in the scenario of empty + // iterator in the current task, the downstream physical node(like SortMergeJoinExec) will + // trigger cleanupResources before rowSorter initialized in createSorter + rowSorter.cleanupResources() + } + super.cleanupResources() + } +} + + +/** + * Performs (external) sorting + * */ +case class SortExec( + sortOrder: Seq[SortOrder], + global: Boolean, + child: SparkPlan, + testSpillFrequency: Int = 0) + extends SortExecBase(sortOrder, global, child, testSpillFrequency) { + private val enableRadixSort = conf.enableRadixSort + + + override val sorterClassName: String = classOf[UnsafeExternalRowSorter].getName + + override def newSorterInstance( + ordering: Ordering[InternalRow], + prefixComparator: PrefixComparator, + prefixComputer: PrefixComputer, + pageSize: Long, + canSortFullyWIthPrefix: Boolean): UnsafeExternalRowSorter = { + UnsafeExternalRowSorter.create( + schema, + ordering, + prefixComparator, + prefixComputer, + pageSize, + enableRadixSort && canSortFullyWIthPrefix) + } + + override def createSorter(): UnsafeExternalRowSorter = { + super.createSorter().asInstanceOf[UnsafeExternalRowSorter] + } + + override protected def withNewChildInternal(newChild: SparkPlan): SortExec = { + copy(child = newChild) + } +} + +/** + * Performs topN sort + * + * @param strictTopN when true it strictly returns n results. This param distinguishes + * [[RowNumber]] from [[Rank]]. [[RowNumber]] corresponds to true + * and [[Rank]] corresponds to false. + * @param partitionSpec partitionSpec of [[org.apache.spark.sql.execution.window.WindowExec]] + * @param sortOrder orderSpec of [[org.apache.spark.sql.execution.window.WindowExec]] + * */ +case class TopNSortExec( + n: Int, + strictTopN: Boolean, + partitionSpec: Seq[Expression], + sortOrder: Seq[SortOrder], + global: Boolean, + child: SparkPlan) + extends SortExecBase(sortOrder, global, child, 0) { + + override val sorterClassName: String = classOf[UnsafeTopNRowSorter].getName + + override def newSorterInstance( + ordering: Ordering[InternalRow], + prefixComparator: PrefixComparator, + prefixComputer: PrefixComputer, + pageSize: Long, + canSortFullyWIthPrefix: Boolean): UnsafeTopNRowSorter = { + val partitionSpecProjection = UnsafeProjection.create(partitionSpec, output) + UnsafeTopNRowSorter.create( + n, + strictTopN, + schema, + partitionSpecProjection, + ordering, + prefixComparator, + prefixComputer, + pageSize, + canSortFullyWIthPrefix) + } + + override def createSorter(): UnsafeTopNRowSorter = { + super.createSorter().asInstanceOf[UnsafeTopNRowSorter] + } + + override protected def withNewChildInternal(newChild: SparkPlan): TopNSortExec = { + copy(child = newChild) + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java new file mode 100644 index 0000000000000000000000000000000000000000..b36a424d22f54fa629e8dfd774c7d503ee75362c --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution; + +import java.io.IOException; +import java.util.function.Supplier; + +import scala.collection.Iterator; +import scala.math.Ordering; + +import org.apache.spark.SparkEnv; +import org.apache.spark.TaskContext; +import org.apache.spark.internal.config.package$; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.util.collection.unsafe.sort.PrefixComparator; +import org.apache.spark.util.collection.unsafe.sort.RecordComparator; +import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter; +import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator; + +public final class UnsafeExternalRowSorter extends AbstractUnsafeRowSorter { + private long numRowsInserted = 0; + private final UnsafeExternalRowSorter.PrefixComputer prefixComputer; + private final UnsafeExternalSorter sorter; + + public abstract static class PrefixComputer { + public static class Prefix { + // Key prefix value, or the null prefix value if isNull = true + public long value; + + // Whether the key is null + public boolean isNull; + } + + /** + * Computes prefix for the given row. For efficiency, the object may be reused in + * further calls to a given PrefixComputer. 
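The prefix computed here is what lets the external sorter resolve most comparisons on a single 8-byte key without touching the full record; only prefix ties fall back to the RecordComparator. A minimal Scala sketch of such a computer for a single ascending LongType sort key with NULLS FIRST (illustrative only; the real instance is built from SortPrefix in SortExecBase.createSorter, and Long.MinValue is just an assumed null sentinel):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.execution.UnsafeExternalRowSorter.PrefixComputer

    val longKeyPrefixComputer = new PrefixComputer {
      private val result = new PrefixComputer.Prefix
      override def computePrefix(row: InternalRow): PrefixComputer.Prefix = {
        result.isNull = row.isNullAt(0)          // prefix comes from the first sort column
        result.value =
          if (result.isNull) Long.MinValue       // assumed sentinel for NULLS FIRST
          else row.getLong(0)
        result                                   // reused across calls, as documented above
      }
    }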
+ * */ + public abstract Prefix computePrefix(InternalRow row); + } + + public static UnsafeExternalRowSorter createWithRecordComparator( + StructType schema, + Supplier recordComparatorSupplier, + PrefixComparator prefixComparator, + UnsafeExternalRowSorter.PrefixComputer prefixComputer, + long pageSizeBytes, + boolean canUseRadixSort) throws IOException { + return new UnsafeExternalRowSorter(schema, recordComparatorSupplier, prefixComparator, + prefixComputer, pageSizeBytes, canUseRadixSort); + } + + public static UnsafeExternalRowSorter create( + StructType schema, + Ordering ordering, + PrefixComparator prefixComparator, + UnsafeExternalRowSorter.PrefixComputer prefixComputer, + long pageSizeBytes, + boolean canUseRadixSort) throws IOException { + Supplier recordComparatorSupplier = () -> new RowComparator(ordering, schema.length()); + return new UnsafeExternalRowSorter(schema, recordComparatorSupplier, prefixComparator, + prefixComputer, pageSizeBytes, canUseRadixSort); + } + + private UnsafeExternalRowSorter( + StructType schema, + Supplier recordComparatorSupplier, + PrefixComparator prefixComparator, + UnsafeExternalRowSorter.PrefixComputer prefixComputer, + long pageSizeBytes, + boolean canUseRadixSort) { + super(schema); + this.prefixComputer = prefixComputer; + final SparkEnv sparkEnv = SparkEnv.get(); + final TaskContext taskContext = TaskContext.get(); + sorter = UnsafeExternalSorter.create( + taskContext.taskMemoryManager(), + sparkEnv.blockManager(), + sparkEnv.serializerManager(), + taskContext, + recordComparatorSupplier, + prefixComparator, + (int) (long) sparkEnv.conf().get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE()), + pageSizeBytes, + (int) sparkEnv.conf().get( + package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD()), + canUseRadixSort); + } + + @Override + public void insertRow(UnsafeRow row) throws IOException { + final PrefixComputer.Prefix prefix = prefixComputer.computePrefix(row); + sorter.insertRecord( + row.getBaseObject(), + row.getBaseOffset(), + row.getSizeInBytes(), + prefix.value, + prefix.isNull); + numRowsInserted++; + if (testSpillFrequency > 0 && (numRowsInserted % testSpillFrequency) == 0) { + sorter.spill(); + } + } + + @Override + public long getPeakMemoryUsage() { + return sorter.getPeakMemoryUsedBytes(); + } + + @Override + public long getSortTimeNanos() { + return sorter.getSortTimeNanos(); + } + + @Override + public void cleanupResources() { + isReleased = true; + sorter.cleanupResources(); + } + + @Override + public Iterator sort() throws IOException { + try { + final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator(); + if (!sortedIterator.hasNext()) { + // Since we won't ever call next() on an empty iterator, we need to clean up resources + // here in order to prevent memory leaks. + cleanupResources(); + } + return new RowIterator() { + private final int numFields = schema.length(); + private UnsafeRow row = new UnsafeRow(numFields); + + @Override + public boolean advanceNext() { + try { + if (!isReleased && sortedIterator.hasNext()) { + sortedIterator.loadNext(); + row.pointTo( + sortedIterator.getBaseObject(), + sortedIterator.getBaseOffset(), + sortedIterator.getRecordLength()); + // Here is the initial buf ifx in SPARK-9364: the bug fix of use-after-free bug + // when returning the last row from an iterator. 
For example, in + // [[GroupedIterator]], we still use the last row after traversing the iterator + // in 'fetchNextGroupIterator' + if (!sortedIterator.hasNext()) { + row = row.copy(); // so that we don't have dangling pointers to freed page + cleanupResources(); + } + return true; + } else { + row = null; // so that we don't keep reference to the base object + return false; + } + } catch (IOException e) { + cleanupResources(); + // Scala iterators don't declare any checked exceptions, so we need to use this hack + // to re-throw the exception. + Platform.throwException(e); + } + throw new RuntimeException("Exception should have been re-thrown in next()"); + } + + @Override + public UnsafeRow getRow() { return row; } + }.toScala(); + } catch (IOException e) { + cleanupResources(); + throw e; + } + } + + @Override + public Iterator sort(Iterator inputIterator) throws IOException { + while (inputIterator.hasNext()) { + insertRow(inputIterator.next()); + } + return sort(); + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/UnsafeTopNRowSorter.java b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/UnsafeTopNRowSorter.java new file mode 100644 index 0000000000000000000000000000000000000000..6a27c8edfa16042201f37addc0d0e0783fa81d5c --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/UnsafeTopNRowSorter.java @@ -0,0 +1,256 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution; + +import java.io.IOException; +import java.util.*; +import java.util.function.Supplier; + +import scala.collection.Iterator; +import scala.math.Ordering; + +import org.apache.spark.TaskContext; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.execution.topnsort.UnsafeInMemoryTopNSorter; +import org.apache.spark.sql.execution.topnsort.UnsafePartitionedTopNSorter; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.util.collection.unsafe.sort.PrefixComparator; +import org.apache.spark.util.collection.unsafe.sort.RecordComparator; +import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator; + +public final class UnsafeTopNRowSorter extends AbstractUnsafeRowSorter { + + private final UnsafePartitionedTopNSorter partitionedTopNSorter; + + // partition key + private final UnsafeProjection partitionSpecProjection; + + // order(rank) key + private final UnsafeExternalRowSorter.PrefixComputer prefixComputer; + + private long totalSortTimeNanos = 0L; + private final long timeNanosBeforeInsertRow; + + public static UnsafeTopNRowSorter create( + int n, + boolean strictTopN, + StructType schema, + UnsafeProjection partitionSpecProjection, + Ordering orderingOfRankKey, + PrefixComparator prefixComparator, + UnsafeExternalRowSorter.PrefixComputer prefixComputer, + long pageSizeBytes, + boolean canSortFullyWithPrefix) { + Supplier recordComparatorSupplier = + () -> new RowComparator(orderingOfRankKey, schema.length()); + return new UnsafeTopNRowSorter( + n, strictTopN, schema, partitionSpecProjection, recordComparatorSupplier, + prefixComparator, prefixComputer, pageSizeBytes, canSortFullyWithPrefix); + } + + private UnsafeTopNRowSorter( + int n, + boolean strictTopN, + StructType schema, + UnsafeProjection partitionSpecProjection, + Supplier recordComparatorSupplier, + PrefixComparator prefixComparator, + UnsafeExternalRowSorter.PrefixComputer prefixComputer, + long pageSizeBytes, + boolean canSortFullyWithPrefix) { + super(schema); + this.prefixComputer = prefixComputer; + final TaskContext taskContext = TaskContext.get(); + this.partitionSpecProjection = partitionSpecProjection; + this.partitionedTopNSorter = UnsafePartitionedTopNSorter.create( + n, + strictTopN, + taskContext.taskMemoryManager(), + taskContext, + recordComparatorSupplier, + prefixComparator, + pageSizeBytes, + canSortFullyWithPrefix); + timeNanosBeforeInsertRow = System.nanoTime(); + } + + @Override + public void insertRow(UnsafeRow row) throws IOException { + final UnsafeExternalRowSorter.PrefixComputer.Prefix prefix = prefixComputer.computePrefix(row); + UnsafeRow partKey = partitionSpecProjection.apply(row); + partitionedTopNSorter.insertRow(partKey, row, prefix.value); + } + + /** + * Return the peak memory used so far, in bytes. + * */ + @Override + public long getPeakMemoryUsage() { + return partitionedTopNSorter.getPeakMemoryUsedBytes(); + } + + /** + * @return the total amount of time spent sorting data (in-memory only). 
+ * */ + @Override + public long getSortTimeNanos() { + return totalSortTimeNanos; + } + + @Override + public Iterator sort() throws IOException + { + try { + Map partKeyToSorter = + partitionedTopNSorter.getPartKeyToSorter(); + if (partKeyToSorter.isEmpty()) { + // Since we won't ever call next() on an empty iterator, we need to clean up resources + // here in order to prevent memory leaks. + cleanupResources(); + return emptySortedIterator(); + } + + Queue sortedIteratorsForPartitions = new LinkedList<>(); + for (Map.Entry entry : partKeyToSorter.entrySet()) { + final UnsafeInMemoryTopNSorter topNSorter = entry.getValue(); + final UnsafeSorterIterator unsafeSorterIterator = topNSorter.getSortedIterator(); + + sortedIteratorsForPartitions.add(new RowIterator() + { + private final int numFields = schema.length(); + private UnsafeRow row = new UnsafeRow(numFields); + + @Override + public boolean advanceNext() + { + try { + if (!isReleased && unsafeSorterIterator.hasNext()) { + unsafeSorterIterator.loadNext(); + row.pointTo( + unsafeSorterIterator.getBaseObject(), + unsafeSorterIterator.getBaseOffset(), + unsafeSorterIterator.getRecordLength()); + // Here is the initial buf ifx in SPARK-9364: the bug fix of use-after-free bug + // when returning the last row from an iterator. For example, in + // [[GroupedIterator]], we still use the last row after traversing the iterator + // in 'fetchNextGroupIterator' + if (!unsafeSorterIterator.hasNext()) { + row = row.copy(); // so that we don't have dangling pointers to freed page + topNSorter.freeMemory(); + } + return true; + } + else { + row = null; // so that we don't keep reference to the base object + return false; + } + } catch (IOException e) { + topNSorter.freeMemory(); + // Scala iterators don't declare any checked exceptions, so we need to use this hack + // to re-throw the exception. + Platform.throwException(e); + } + throw new RuntimeException("Exception should have been re-thrown in next()"); + } + + @Override + public UnsafeRow getRow() + { + return row; + } + }); + } + + // Update total sort time. + if (totalSortTimeNanos == 0L) { + totalSortTimeNanos = System.nanoTime() - timeNanosBeforeInsertRow; + } + final ChainedIterator chainedIterator = new ChainedIterator(sortedIteratorsForPartitions); + return chainedIterator.toScala(); + } catch (Exception e) { + cleanupResources(); + throw e; + } + } + + private Iterator emptySortedIterator() { + return new RowIterator() { + @Override + public boolean advanceNext() { + return false; + } + + @Override + public UnsafeRow getRow() { + return null; + } + }.toScala(); + } + + /** + * Chain multiple UnsafeSorterIterators from PartSorterMap as single one. 
+ * */ + private static final class ChainedIterator extends RowIterator { + private final Queue iterators; + private RowIterator current; + private UnsafeRow row; + + ChainedIterator(Queue iterators) { + assert iterators.size() > 0; + this.iterators = iterators; + this.current = iterators.remove(); + } + + @Override + public boolean advanceNext() { + boolean result = this.current.advanceNext(); + while(!result && !this.iterators.isEmpty()) { + this.current = iterators.remove(); + result = this.current.advanceNext(); + } + if (!result) { + this.row = null; + } else { + this.row = (UnsafeRow) this.current.getRow(); + } + return result; + } + + @Override + public UnsafeRow getRow() { + return row; + } + } + + @Override + public Iterator sort(Iterator inputIterator) throws IOException { + while (inputIterator.hasNext()) { + insertRow(inputIterator.next()); + } + return sort(); + } + + @Override + public void cleanupResources() { + isReleased = true; + partitionedTopNSorter.cleanupResources(); + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/ColumnarCustomShuffleReaderExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/ColumnarCustomShuffleReaderExec.scala index 15e28ceb3825c7c91752fd06caf03b9c24876f4d..741f5f1da1e2ebadcba4ced651b63a67be429dec 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/ColumnarCustomShuffleReaderExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/ColumnarCustomShuffleReaderExec.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeLike} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import org.apache.spark.sql.execution.util.MergeIterator +import org.apache.spark.sql.execution.util.{MergeIterator, SparkMemoryUtils} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch @@ -54,6 +54,7 @@ case class OmniAQEShuffleReadExec( override def supportsColumnar: Boolean = true override def output: Seq[Attribute] = child.output + override lazy val outputPartitioning: Partitioning = { // If it is a local shuffle reader with one mapper per task, then the output partitioning is // the same as the plan before shuffle. 
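This file and the join operators changed further down apply the same pattern when vecBatch merging is enabled: the shuffled batch iterator is wrapped in a MergeIterator, and the wrapper is closed through a leak-safe task completion listener so native buffers are released even if the task fails. In outline, using the names from this patch:

    val merged = new MergeIterator(iter,
      StructType.fromAttributes(child.output),
      longMetric("numMergedVecBatches"))
    // release the merge buffers when the task completes, successfully or not
    SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => merged.close())
    merged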
@@ -210,7 +211,7 @@ case class OmniAQEShuffleReadExec( override lazy val metrics: Map[String, SQLMetric] = { if (shuffleStage.isDefined) { - Map("numMergedVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of merged vecBatchs"), + Map("numMergedVecBatches" -> SQLMetrics.createMetric(sparkContext, "number of merged vecBatches"), "numPartitions" -> SQLMetrics.createMetric(sparkContext, "number of partitions")) ++ { if (isLocalRead) { // We split the mapper partition evenly when creating local shuffle read, so no @@ -243,31 +244,25 @@ case class OmniAQEShuffleReadExec( } } + private val enableShuffleBatchMerge: Boolean = ColumnarPluginConfig.getSessionConf.enableShuffleBatchMerge + private lazy val shuffleRDD: RDD[_] = { shuffleStage match { case Some(stage) => sendDriverMetrics() - val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf - val enableShuffleBatchMerge: Boolean = columnarConf.enableShuffleBatchMerge + val rdd = stage.shuffle.asInstanceOf[ColumnarShuffleExchangeExec].getShuffleRDD(partitionSpecs.toArray) if (enableShuffleBatchMerge) { - new ShuffledColumnarRDD( - stage.shuffle - .asInstanceOf[ColumnarShuffleExchangeExec] - .columnarShuffleDependency, - stage.shuffle.asInstanceOf[ColumnarShuffleExchangeExec].readMetrics, - partitionSpecs.toArray).mapPartitionsWithIndexInternal { (index,iter) => - new MergeIterator(iter, - StructType.fromAttributes(child.output), - longMetric("numMergedVecBatchs")) - } - + rdd.mapPartitionsWithIndexInternal { (index,iter) => + val mergeIterator = new MergeIterator(iter, + StructType.fromAttributes(child.output), + longMetric("numMergedVecBatches")) + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { + mergeIterator.close() + }) + mergeIterator + } } else { - new ShuffledColumnarRDD( - stage.shuffle - .asInstanceOf[ColumnarShuffleExchangeExec] - .columnarShuffleDependency, - stage.shuffle.asInstanceOf[ColumnarShuffleExchangeExec].readMetrics, - partitionSpecs.toArray) + rdd } case _ => throw new IllegalStateException("operating on canonicalized plan") @@ -283,5 +278,5 @@ case class OmniAQEShuffleReadExec( } override protected def withNewChildInternal(newChild: SparkPlan): OmniAQEShuffleReadExec = - new OmniAQEShuffleReadExec(newChild, this.partitionSpecs) + copy(child = newChild) } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/aggregate/ExtendedAggUtils.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/aggregate/ExtendedAggUtils.scala new file mode 100644 index 0000000000000000000000000000000000000000..b30104e9d2eba23076157b1e2b62329842fc8cde --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/aggregate/ExtendedAggUtils.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Partial} +import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LeafNode, Statistics} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.Utils + +object ExtendedAggUtils { + def normalizeGroupingExpressions(groupingExpressions: Seq[NamedExpression]) = { + groupingExpressions.map { e => + NormalizeFloatingNumbers.normalize(e) match { + case n: NamedExpression => n + case other => Alias(other, e.name)(exprId = e.exprId) + } + } + } + + def planPartialAggregateWithoutDistinct( + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression], + resultExpressions: Seq[NamedExpression], + child: SparkPlan): SparkPlan = { + val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete)) + createAggregate( + requiredChildDistributionExpressions = None, + groupingExpressions = groupingExpressions.map(_.toAttribute), + aggregateExpressions = completeAggregateExpressions, + aggregateAttributes = completeAggregateExpressions.map(_.resultAttribute), + initialInputBufferOffset = groupingExpressions.length, + resultExpressions = resultExpressions, + child = child) + } + + private def createAggregate( + requiredChildDistributionExpressions: Option[Seq[Expression]] = None, + isStreaming: Boolean = false, + groupingExpressions: Seq[NamedExpression] = Nil, + aggregateExpressions: Seq[AggregateExpression] = Nil, + aggregateAttributes: Seq[Attribute] = Nil, + initialInputBufferOffset: Int = 0, + resultExpressions: Seq[NamedExpression] = Nil, + child: SparkPlan): SparkPlan = { + val useHash = Aggregate.supportsHashAggregate( + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) + + if (useHash) { + HashAggregateExec( + requiredChildDistributionExpressions = requiredChildDistributionExpressions, + isStreaming = isStreaming, + numShufflePartitions = None, + groupingExpressions = groupingExpressions, + aggregateExpressions = mayRemoveAggFilters(aggregateExpressions), + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = initialInputBufferOffset, + resultExpressions = resultExpressions, + child = child) + } else { + val objectHashEnabled = child.conf.useObjectHashAggregation + val useObjectHash = Aggregate.supportsObjectHashAggregate(aggregateExpressions) + + if (objectHashEnabled && useObjectHash) { + ObjectHashAggregateExec( + requiredChildDistributionExpressions = requiredChildDistributionExpressions, + isStreaming = isStreaming, + numShufflePartitions = None, + groupingExpressions = groupingExpressions, + aggregateExpressions = mayRemoveAggFilters(aggregateExpressions), + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = initialInputBufferOffset, + resultExpressions = resultExpressions, + child = child) + } else { + SortAggregateExec( + requiredChildDistributionExpressions = requiredChildDistributionExpressions, + isStreaming = isStreaming, + numShufflePartitions = None, + groupingExpressions = groupingExpressions, + aggregateExpressions = mayRemoveAggFilters(aggregateExpressions), + 
aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = initialInputBufferOffset, + resultExpressions = resultExpressions, + child = child) + } + } + } + + private def mayRemoveAggFilters(exprs: Seq[AggregateExpression]): Seq[AggregateExpression] = { + exprs.map { ae => + if (ae.filter.isDefined) { + ae.mode match { + case Partial | Complete => ae + case _ => ae.copy(filter = None) + } + } else { + ae + } + } + } +} + +case class DummyLogicalPlan() extends LeafNode { + override def output: Seq[Attribute] = Nil + + override def computeStats(): Statistics = throw new UnsupportedOperationException +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala index 7325635ff98811b37acaca8cc3fab7c9b1b3fe8d..334800f5111521feeda8ab1fa78a3ed9de436ac3 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala @@ -19,23 +19,25 @@ package org.apache.spark.sql.execution.datasources.orc import java.io.Serializable import java.net.URI - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.FileSplit import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl import org.apache.orc.{OrcConf, OrcFile, TypeDescription} +import org.apache.orc.TypeDescription.Category._ import org.apache.orc.mapreduce.OrcInputFormat - import org.apache.spark.TaskContext import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.util.SparkMemoryUtils -import org.apache.spark.sql.sources.{DataSourceRegister, Filter} +import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.sql.types.StringType + +import org.apache.spark.sql.types.DecimalType class OmniOrcFileFormat extends FileFormat with DataSourceRegister with Serializable { @@ -54,6 +56,44 @@ class OmniOrcFileFormat extends FileFormat with DataSourceRegister with Serializ OrcUtils.inferSchema(sparkSession, files, options) } + private def isPPDSafe(filters: Seq[Filter], dataSchema: StructType): Seq[Boolean] = { + def convertibleFiltersHelper(filter: Filter, + dataSchema: StructType): Boolean = filter match { + case And(left, right) => + convertibleFiltersHelper(left, dataSchema) && convertibleFiltersHelper(right, dataSchema) + case Or(left, right) => + convertibleFiltersHelper(left, dataSchema) && convertibleFiltersHelper(right, dataSchema) + case Not(pred) => + convertibleFiltersHelper(pred, dataSchema) + case other => + other match { + case EqualTo(name, _) => + dataSchema.apply(name).dataType != StringType + case EqualNullSafe(name, _) => + dataSchema.apply(name).dataType != StringType + case LessThan(name, _) => + dataSchema.apply(name).dataType != StringType + case LessThanOrEqual(name, _) => + dataSchema.apply(name).dataType != StringType + case GreaterThan(name, _) => + dataSchema.apply(name).dataType != StringType + case GreaterThanOrEqual(name, _) => + 
dataSchema.apply(name).dataType != StringType + case IsNull(name) => + dataSchema.apply(name).dataType != StringType + case IsNotNull(name) => + dataSchema.apply(name).dataType != StringType + case In(name, _) => + dataSchema.apply(name).dataType != StringType + case _ => false + } + } + + filters.map { filter => + convertibleFiltersHelper(filter, dataSchema) + } + } + override def buildReaderWithPartitionValues( sparkSession: SparkSession, dataSchema: StructType, @@ -79,60 +119,61 @@ class OmniOrcFileFormat extends FileFormat with DataSourceRegister with Serializ val conf = broadcastedConf.value.value val filePath = new Path(new URI(file.filePath)) + val isPPDSafeValue = isPPDSafe(filters, dataSchema).reduceOption(_ && _) - val fs = filePath.getFileSystem(conf) - val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) - val orcSchema = - Utils.tryWithResource(OrcFile.createReader(filePath, readerOptions))(_.getSchema) - val resultedColPruneInfo = OrcUtils.requestedColumnIds( - isCaseSensitive, dataSchema, requiredSchema, orcSchema, conf) - - if (resultedColPruneInfo.isEmpty) { - Iterator.empty - } else { - // ORC predicate pushdown - if (orcFilterPushDown && filters.nonEmpty) { - OrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).foreach { - fileSchema => OrcFilters.createFilter(fileSchema, filters).foreach { f => - OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames) - } + // ORC predicate pushdown + if (orcFilterPushDown && filters.nonEmpty && isPPDSafeValue.getOrElse(false)) { + OrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).foreach { + fileSchema => OrcFilters.createFilter(fileSchema, filters).foreach { f => + OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames) } } - - val (requestedColIds, canPruneCols) = resultedColPruneInfo.get - val resultSchemaString = OrcUtils.orcResultSchemaString(canPruneCols, - dataSchema, resultSchema, partitionSchema, conf) - assert(requestedColIds.length == requiredSchema.length, - "[BUG] requested column IDs do not match required schema") - val taskConf = new Configuration(conf) - - val includeColumns = requestedColIds.filter(_ != -1).sorted.mkString(",") - taskConf.set(OrcConf.INCLUDE_COLUMNS.getAttribute, includeColumns) - val fileSplit = new FileSplit(filePath, file.start, file.length, Array.empty) - val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) - val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId) - - // read data from vectorized reader - val batchReader = new OmniOrcColumnarBatchReader(capacity) - // SPARK-23399 Register a task completion listener first to call `close()` in all cases. - // There is a possibility that `initialize` and `initBatch` hit some errors (like OOM) - // after opening a file. 
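The isPPDSafe helper added above gates ORC predicate push-down: a filter tree is treated as safe only if every leaf references a non-string column, so a single string predicate disables the search argument for that filter. A hypothetical illustration (schema and filters invented for the example):

    import org.apache.spark.sql.sources._
    import org.apache.spark.sql.types._

    val schema = StructType(Seq(StructField("id", LongType), StructField("name", StringType)))
    // isPPDSafe(Seq(GreaterThan("id", 1L)), schema)                          -> Seq(true): push-down stays on
    // isPPDSafe(Seq(EqualTo("name", "a")), schema)                           -> Seq(false): string column, skipped
    // isPPDSafe(Seq(And(GreaterThan("id", 1L), IsNotNull("name"))), schema)  -> Seq(false): one string leg is enough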
- val iter = new RecordReaderIterator(batchReader) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => iter.close())) - val requestedDataColIds = requestedColIds ++ Array.fill(partitionSchema.length)(-1) - val requestedPartitionColIds = - Array.fill(requiredSchema.length)(-1) ++ Range(0, partitionSchema.length) - SparkMemoryUtils.init() - batchReader.initialize(fileSplit, taskAttemptContext) - batchReader.initBatch( - TypeDescription.fromString(resultSchemaString), - resultSchema.fields, - requestedDataColIds, - requestedPartitionColIds, - file.partitionValues) - - iter.asInstanceOf[Iterator[InternalRow]] + } + + val taskConf = new Configuration(conf) + val fileSplit = new FileSplit(filePath, file.start, file.length, Array.empty) + val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) + val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId) + + // read data from vectorized reader + val batchReader = new OmniOrcColumnarBatchReader(capacity) + // SPARK-23399 Register a task completion listener first to call `close()` in all cases. + // There is a possibility that `initialize` and `initBatch` hit some errors (like OOM) + // after opening a file. + val iter = new RecordReaderIterator(batchReader) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => iter.close())) + // fill requestedDataColIds with -1; the real values are filled in the initDataColIds function + val requestedDataColIds = Array.fill(requiredSchema.length)(-1) ++ Array.fill(partitionSchema.length)(-1) + val requestedPartitionColIds = + Array.fill(requiredSchema.length)(-1) ++ Range(0, partitionSchema.length) + + // Initialize the precision and scale arrays and pass them through to the Java side + val requiredFields = requiredSchema.fields + val fieldslength = requiredFields.length + val precisionArray : Array[Int] = Array.ofDim[Int](fieldslength) + val scaleArray : Array[Int] = Array.ofDim[Int](fieldslength) + for ((reqField, index) <- requiredFields.zipWithIndex) { + val reqdatatype = reqField.dataType + if (reqdatatype.isInstanceOf[DecimalType]) { + val precision = reqdatatype.asInstanceOf[DecimalType].precision + val scale = reqdatatype.asInstanceOf[DecimalType].scale + precisionArray(index) = precision + scaleArray(index) = scale } + } + + SparkMemoryUtils.init() + batchReader.initialize(fileSplit, taskAttemptContext) + batchReader.initDataColIds(requiredSchema, requestedPartitionColIds, requestedDataColIds, resultSchema.fields, + precisionArray, scaleArray) + batchReader.initBatch( + requiredSchema.fields, + resultSchema.fields, + requestedDataColIds, + requestedPartitionColIds, + file.partitionValues) + + iter.asInstanceOf[Iterator[InternalRow]] } } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala index f9e5937e7c564c562dcbc6b94bd4eb262ac43ef6..ed3ca244b4df110f0cf3f654cba90690ad50c673 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala @@ -105,8 +105,8 @@ case class ColumnarBroadcastHashJoinExec( SQLMetrics.createTimingMetric(sparkContext, "time in omni build getOutput"), "buildCodegenTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni build 
codegen"), - "numOutputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatchs"), - "numMergedVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of merged vecBatchs") + "numOutputVecBatches" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatches"), + "numMergedVecBatches" -> SQLMetrics.createMetric(sparkContext, "number of merged vecBatches") ) override def supportsColumnar: Boolean = true @@ -221,8 +221,12 @@ case class ColumnarBroadcastHashJoinExec( } def buildCheck(): Unit = { + if (isNullAwareAntiJoin) { + throw new UnsupportedOperationException(s"isNullAwareAntiJoin is not supported " + + s"in ${this.nodeName}") + } joinType match { - case LeftOuter | Inner | LeftSemi => + case LeftOuter | Inner | LeftSemi | LeftAnti | RightOuter => case _ => throw new UnsupportedOperationException(s"Join-type[${joinType}] is not supported " + s"in ${this.nodeName}") @@ -277,8 +281,8 @@ case class ColumnarBroadcastHashJoinExec( override def doExecuteColumnar(): RDD[ColumnarBatch] = { // input/output: {col1#10,col2#11,col1#12,col2#13} val numOutputRows = longMetric("numOutputRows") - val numOutputVecBatchs = longMetric("numOutputVecBatchs") - val numMergedVecBatchs = longMetric("numMergedVecBatchs") + val numOutputVecBatches= longMetric("numOutputVecBatches") + val numMergedVecBatches= longMetric("numMergedVecBatches") val buildAddInputTime = longMetric("buildAddInputTime") val buildCodegenTime = longMetric("buildCodegenTime") val buildGetOutputTime = longMetric("buildGetOutputTime") @@ -297,7 +301,7 @@ case class ColumnarBroadcastHashJoinExec( // {0}, buildKeys: col1#12 val buildOutputCols: Array[Int] = joinType match { - case Inner | LeftOuter => + case Inner | LeftOuter | RightOuter => getIndexArray(buildOutput, projectList) case LeftExistence(_) => Array[Int]() @@ -340,7 +344,7 @@ case class ColumnarBroadcastHashJoinExec( def createBuildOpFactoryAndOp(): (OmniHashBuilderWithExprOperatorFactory, OmniOperator) = { val startBuildCodegen = System.nanoTime() val opFactory = - new OmniHashBuilderWithExprOperatorFactory(buildTypes, buildJoinColsExp, 1, + new OmniHashBuilderWithExprOperatorFactory(lookupJoinType, buildTypes, buildJoinColsExp, 1, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val op = opFactory.createOperator() @@ -386,7 +390,7 @@ case class ColumnarBroadcastHashJoinExec( val startLookupCodegen = System.nanoTime() val lookupOpFactory = new OmniLookupJoinWithExprOperatorFactory(probeTypes, probeOutputCols, - probeHashColsExp, buildOutputCols, buildOutputTypes, lookupJoinType, buildOpFactory, filter, + probeHashColsExp, buildOutputCols, buildOutputTypes, buildOpFactory, filter, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val lookupOp = lookupOpFactory.createOperator() @@ -482,14 +486,18 @@ case class ColumnarBroadcastHashJoinExec( } val rowCnt: Int = result.getRowCount numOutputRows += rowCnt - numOutputVecBatchs += 1 + numOutputVecBatches+= 1 result.close() new ColumnarBatch(vecs.toArray, rowCnt) } } if (enableJoinBatchMerge) { - new MergeIterator(iterBatch, resultSchema, numMergedVecBatchs) + val mergeIterator = new MergeIterator(iterBatch, resultSchema, numMergedVecBatches) + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { + mergeIterator.close() + }) + mergeIterator } else { iterBatch } diff --git 
a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarShuffledHashJoinExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarShuffledHashJoinExec.scala index 4e1d91bebe0b946c75431d429e2deb08a7aa5739..e041a1fb119bc63bf689361534c207304383e83e 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarShuffledHashJoinExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarShuffledHashJoinExec.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildSide} -import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, JoinType, LeftAnti, LeftExistence, LeftOuter, LeftSemi} +import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, JoinType, LeftAnti, LeftExistence, LeftOuter, LeftSemi, RightOuter} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{ExplainUtils, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -87,7 +87,7 @@ case class ColumnarShuffledHashJoinExec( "time in omni build getOutput"), "buildCodegenTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni build codegen"), - "numOutputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatchs"), + "numOutputVecBatches" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatches"), "buildDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "build side input data size") ) @@ -121,7 +121,7 @@ case class ColumnarShuffledHashJoinExec( def buildCheck(): Unit = { joinType match { - case FullOuter | Inner | LeftAnti | LeftOuter | LeftSemi => + case FullOuter | Inner | LeftAnti | LeftOuter | LeftSemi | RightOuter => case _ => throw new UnsupportedOperationException(s"Join-type[${joinType}] is not supported " + s"in ${this.nodeName}") @@ -172,7 +172,7 @@ case class ColumnarShuffledHashJoinExec( */ override def doExecuteColumnar(): RDD[ColumnarBatch] = { val numOutputRows = longMetric("numOutputRows") - val numOutputVecBatchs = longMetric("numOutputVecBatchs") + val numOutputVecBatches= longMetric("numOutputVecBatches") val buildAddInputTime = longMetric("buildAddInputTime") val buildCodegenTime = longMetric("buildCodegenTime") val buildGetOutputTime = longMetric("buildGetOutputTime") @@ -187,7 +187,7 @@ case class ColumnarShuffledHashJoinExec( } val buildOutputCols: Array[Int] = joinType match { - case Inner | FullOuter | LeftOuter => + case Inner | FullOuter | LeftOuter | RightOuter => getIndexArray(buildOutput, projectList) case LeftExistence(_) => Array[Int]() @@ -226,17 +226,16 @@ case class ColumnarShuffledHashJoinExec( case _ => Optional.empty() } val startBuildCodegen = System.nanoTime() - val buildOpFactory = new OmniHashBuilderWithExprOperatorFactory(buildTypes, + val lookupJoinType = OmniExpressionAdaptor.toOmniJoinType(joinType) + val buildOpFactory = new OmniHashBuilderWithExprOperatorFactory(lookupJoinType, buildTypes, buildJoinColsExp, 1, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val buildOp = buildOpFactory.createOperator() buildCodegenTime += 
NANOSECONDS.toMillis(System.nanoTime() - startBuildCodegen) val startLookupCodegen = System.nanoTime() - val lookupJoinType = OmniExpressionAdaptor.toOmniJoinType(joinType) - val lookupOpFactory = new OmniLookupJoinWithExprOperatorFactory(probeTypes, - probeOutputCols, probeHashColsExp, buildOutputCols, buildOutputTypes, lookupJoinType, - buildOpFactory, filter, new OperatorConfig(SpillConfig.NONE, + val lookupOpFactory = new OmniLookupJoinWithExprOperatorFactory(probeTypes, probeOutputCols, probeHashColsExp, + buildOutputCols, buildOutputTypes, buildOpFactory, filter, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) val lookupOp = lookupOpFactory.createOperator() lookupCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startLookupCodegen) @@ -339,7 +338,7 @@ case class ColumnarShuffledHashJoinExec( } val rowCnt: Int = result.getRowCount numOutputRows += rowCnt - numOutputVecBatchs += 1 + numOutputVecBatches+= 1 result.close() new ColumnarBatch(vecs.toArray, rowCnt) } @@ -363,7 +362,9 @@ case class ColumnarShuffledHashJoinExec( override def hasNext: Boolean = { if (output == null) { + val startLookupOuterGetOp = System.nanoTime() output = lookupOuterOp.getOutput + lookupGetOutputTime += NANOSECONDS.toMillis((System.nanoTime() - startLookupOuterGetOp)) } output.hasNext } @@ -391,7 +392,7 @@ case class ColumnarShuffledHashJoinExec( } } numOutputRows += result.getRowCount - numOutputVecBatchs += 1 + numOutputVecBatches+= 1 new ColumnarBatch(vecs.toArray, result.getRowCount) } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala index 6718e5e7f8934f41b86db307213663a28d270175..c3a22b1ea3a8c8a250e75142af71d47df9f2bcb5 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala @@ -146,11 +146,11 @@ case class ColumnarSortMergeJoinExec( SQLMetrics.createTimingMetric(sparkContext, "time in omni buffered codegen"), "getOutputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni buffered getOutput"), - "numOutputVecBatchs" -> - SQLMetrics.createMetric(sparkContext, "number of output vecBatchs"), - "numMergedVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of merged vecBatchs"), - "numStreamVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of streamed vecBatchs"), - "numBufferVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of buffered vecBatchs") + "numOutputVecBatches" -> + SQLMetrics.createMetric(sparkContext, "number of output vecBatches"), + "numMergedVecBatches" -> SQLMetrics.createMetric(sparkContext, "number of merged vecBatches"), + "numStreamVecBatches" -> SQLMetrics.createMetric(sparkContext, "number of streamed vecBatches"), + "numBufferVecBatches" -> SQLMetrics.createMetric(sparkContext, "number of buffered vecBatches") ) override def verboseStringWithOperatorId(): String = { @@ -232,15 +232,15 @@ case class ColumnarSortMergeJoinExec( override def doExecuteColumnar(): RDD[ColumnarBatch] = { val numOutputRows = longMetric("numOutputRows") - val numOutputVecBatchs = longMetric("numOutputVecBatchs") - val numMergedVecBatchs = longMetric("numMergedVecBatchs") + val 
numOutputVecBatches= longMetric("numOutputVecBatches") + val numMergedVecBatches= longMetric("numMergedVecBatches") val streamedAddInputTime = longMetric("streamedAddInputTime") val streamedCodegenTime = longMetric("streamedCodegenTime") val bufferedAddInputTime = longMetric("bufferedAddInputTime") val bufferedCodegenTime = longMetric("bufferedCodegenTime") val getOutputTime = longMetric("getOutputTime") - val streamVecBatchs = longMetric("numStreamVecBatchs") - val bufferVecBatchs = longMetric("numBufferVecBatchs") + val streamVecBatches= longMetric("numStreamVecBatches") + val bufferVecBatches= longMetric("numBufferVecBatches") val streamedTypes = new Array[DataType](left.output.size) left.output.zipWithIndex.foreach { case (attr, i) => @@ -328,11 +328,11 @@ case class ColumnarSortMergeJoinExec( def checkAndClose() : Unit = { while (streamedIter.hasNext) { - streamVecBatchs += 1 + streamVecBatches+= 1 streamedIter.next().close() } while(bufferedIter.hasNext) { - bufferVecBatchs += 1 + bufferVecBatches+= 1 bufferedIter.next().close() } } @@ -366,7 +366,7 @@ case class ColumnarSortMergeJoinExec( val startBuildStreamedInput = System.nanoTime() if (!isStreamedFinished && streamedIter.hasNext) { val batch = streamedIter.next() - streamVecBatchs += 1 + streamVecBatches+= 1 val inputVecBatch = transColBatchToVecBatch(batch) decodeOpStatus(streamedOp.addInput(inputVecBatch)) } else { @@ -379,7 +379,7 @@ case class ColumnarSortMergeJoinExec( val startBuildBufferedInput = System.nanoTime() if (!isBufferedFinished && bufferedIter.hasNext) { val batch = bufferedIter.next() - bufferVecBatchs += 1 + bufferVecBatches+= 1 val inputVecBatch = transColBatchToVecBatch(batch) decodeOpStatus(bufferedOp.addInput(inputVecBatch)) } else { @@ -423,7 +423,7 @@ case class ColumnarSortMergeJoinExec( v.setVec(resultVecs(index)) } } - numOutputVecBatchs += 1 + numOutputVecBatches+= 1 numOutputRows += result.getRowCount result.close() new ColumnarBatch(vecs.toArray, result.getRowCount) @@ -462,7 +462,11 @@ case class ColumnarSortMergeJoinExec( } if (enableSortMergeJoinBatchMerge) { - new MergeIterator(iterBatch, resultSchema, numMergedVecBatchs) + val mergeIterator = new MergeIterator(iterBatch, resultSchema, numMergedVecBatches) + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { + mergeIterator.close() + }) + mergeIterator } else { iterBatch } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/topnsort/UnsafeInMemoryTopNSorter.java b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/topnsort/UnsafeInMemoryTopNSorter.java new file mode 100644 index 0000000000000000000000000000000000000000..7b14bb6694eec58c48cef5a96aa6626ff22ec431 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/topnsort/UnsafeInMemoryTopNSorter.java @@ -0,0 +1,272 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.topnsort; + +import org.apache.spark.TaskContext; +import org.apache.spark.memory.MemoryConsumer; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.unsafe.UnsafeAlignedOffset; +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator; + +public final class UnsafeInMemoryTopNSorter { + + private final MemoryConsumer consumer; + private final TaskMemoryManager memoryManager; + private final UnsafePartitionedTopNSorter.TopNSortComparator sortComparator; + + /** + * Within this buffer, position {@code 2 * i} holds a pointer to the record at index {@code i}, + * while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. + * + * Only part of the array will be used to store the pointers, the rest part is preserved as temporary buffer for sorting. + */ + private LongArray array; + + /** + * The position in the sort buffer where new records can be inserted. + */ + private int nextEmptyPos = 0; + + // Top n. + private final int n; + private final boolean strictTopN; + + // The capacity of array. + private final int capacity; + private static final int MIN_ARRAY_CAPACITY = 64; + + public UnsafeInMemoryTopNSorter( + final int n, + final boolean strictTopN, + final MemoryConsumer consumer, + final TaskMemoryManager memoryManager, + final UnsafePartitionedTopNSorter.TopNSortComparator sortComparator) { + this.n = n; + this.strictTopN = strictTopN; + this.consumer = consumer; + this.memoryManager = memoryManager; + this.sortComparator = sortComparator; + this.capacity = Math.max(MIN_ARRAY_CAPACITY, Integer.highestOneBit(n) << 1); + // The size of Long array is equal to twice capacity because each item consists of a prefix and a pointer. + this.array = consumer.allocateArray(capacity << 1); + } + + /** + * Free the memory used by pointer array + */ + public void freeMemory() { + if (consumer != null) { + if (array != null) { + consumer.freeArray(array); + } + array = null; + } + nextEmptyPos = 0; + } + + public long getMemoryUsage() { + if (array == null) { + return 0L; + } + return array.size() * 8; + } + + public int insert(UnsafeRow row, long prefix) { + if (nextEmptyPos < n) { + return insertIntoArray(nextEmptyPos -1, row, prefix); + } else { + // reach n candidates + final int compareResult = nthRecordCompareTo(row, prefix); + if (compareResult < 0) { + // skip this record + return -1; + } + else if (compareResult == 0) { + if (strictTopN) { + // For rows that have duplicate values, skip it if this is strict TopN (e.g. RowNumber). + return -1; + } + // append record + checkForInsert(); + array.set((nextEmptyPos << 1) + 1, prefix); + return nextEmptyPos++; + } + else { + checkForInsert(); + // The record at position n -1 should be excluded, so we start comparing with record at position n - 2. 
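The insertion policy implemented by insert() above is easier to follow in isolation: the sorter keeps an ascending array of at most n entries, drops any row that compares worse than the current nth entry, keeps ties of the nth entry only for non-strict top-n (RANK), and trims back to n entries once the first n become distinct again. Below is a minimal Scala sketch of that policy, assuming plain Long sort keys in place of the real (record pointer, key prefix) pairs and ignoring the capacity check and memory management; the class name BoundedTopN is illustrative only.

    import scala.collection.mutable.ArrayBuffer

    // Simplified stand-in for UnsafeInMemoryTopNSorter.insert(): smaller keys rank higher.
    final class BoundedTopN(n: Int, strictTopN: Boolean) {
      private val buf = new ArrayBuffer[Long]()

      // Returns true if the key is retained, false if it is discarded.
      def insert(key: Long): Boolean = {
        if (buf.length < n) {
          insertSorted(key); true                  // still filling the first n slots
        } else {
          val nth = buf(n - 1)
          if (key > nth) {
            false                                  // worse than the nth key: drop
          } else if (key == nth) {
            if (strictTopN) false                  // ROW_NUMBER-like: keep at most n rows
            else { buf.append(key); true }         // RANK-like: keep the tied ranking
          } else {
            insertSorted(key)                      // better than the nth key: insert in order
            // Trim the tail once the first n keys are distinct again (or always, for strict
            // top-n); this mirrors the nextEmptyPos / hasDistinctTopN handling above.
            if (strictTopN || buf(n - 1) != buf(n)) buf.remove(n, buf.length - n)
            true
          }
        }
      }

      private def insertSorted(key: Long): Unit = {
        var pos = buf.length
        while (pos > 0 && buf(pos - 1) > key) pos -= 1
        buf.insert(pos, key)
      }

      def kept: Seq[Long] = buf.toSeq
    }

For example, with n = 3 and inputs 5, 1, 4, 4, 2, the kept keys are 1, 2, 4, 4 when strictTopN = false (RANK semantics) and 1, 2, 4 when strictTopN = true (ROW_NUMBER semantics).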
+                final int insertPosition = insertIntoArray(n - 2, row, prefix);
+                if (strictTopN || insertPosition == n - 1 || hasDistinctTopN()) {
+                    nextEmptyPos = n;
+                }
+                // For other cases, 'nextEmptyPos' will move to the next empty position in 'insertIntoArray()'.
+                // e.g. given rank <= 4, and we already have 1, 2, 6, 6, so 'nextEmptyPos' is 4.
+                // If the new row is 3, then values in the array will be 1, 2, 3, 6, 6, and 'nextEmptyPos' will be 5.
+                return insertPosition;
+            }
+        }
+    }
+
+    public void updateRecordPointer(int position, long pointer) {
+        array.set(position << 1, pointer);
+    }
+
+    private int insertIntoArray(int position, UnsafeRow row, long prefix) {
+        // find insert position
+        while (position >= 0 && sortComparator.compare(array.get(position << 1), array.get((position << 1) + 1), row, prefix) > 0) {
+            --position;
+        }
+        final int insertPos = position + 1;
+
+        // move records between 'insertPos' and 'nextEmptyPos' to next positions
+        for (int i = nextEmptyPos; i > insertPos; --i) {
+            int src = (i - 1) << 1;
+            int dst = i << 1;
+            array.set(dst, array.get(src));
+            array.set(dst + 1, array.get(src + 1));
+        }
+
+        // Insert prefix of this row. Note that the address will be inserted by 'updateRecordPointer()'
+        // after we get its address from 'taskMemoryManager'
+        array.set((insertPos << 1) + 1, prefix);
+        ++nextEmptyPos;
+        return insertPos;
+    }
+
+    private void checkForInsert() {
+        if (nextEmptyPos >= capacity) {
+            throw new IllegalStateException("No space for new record.\n" +
+                "For RANK expressions with TOP-N filter (e.g. rk <= 100), we maintain a fixed capacity " +
+                "array for TOP-N sorting for each partition, and if there are too many rows with the same ranking, " +
+                "the result that needs to be retained will exceed the capacity of the array.\n" +
+                "Please consider using ROW_NUMBER expression or disabling TOP-N sorting by setting " +
+                "spark.sql.execution.topNPushDownForWindow.enabled to false.");
+        }
+    }
+
+    private int nthRecordCompareTo(UnsafeRow row, long prefix) {
+        int nthPos = n - 1;
+        return sortComparator.compare(array.get(nthPos << 1), array.get((nthPos << 1) + 1), row, prefix);
+    }
+
+    private boolean hasDistinctTopN() {
+        int nthPosition = (n - 1) << 1;
+        return sortComparator.compare(array.get(nthPosition), array.get(nthPosition + 1), // nth record
+            array.get(nthPosition + 2), array.get(nthPosition + 3)) // (n + 1)th record
+            != 0; // not eq
+    }
+
+    /**
+     * This is copied from
+     * {@link org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.SortedIterator}.
+ * */ + public final class TopNSortedIterator extends UnsafeSorterIterator implements Cloneable { + private final int numRecords; + private int position; + private int offset; + private Object baseObject; + private long baseOffset; + private long keyPrefix; + private int recordLength; + private long currentPageNumber; + private final TaskContext taskContext = TaskContext.get(); + + private TopNSortedIterator(int numRecords, int offset) { + this.numRecords = numRecords; + this.position = 0; + this.offset = offset; + } + + public TopNSortedIterator clone() { + TopNSortedIterator iter = new TopNSortedIterator(numRecords, offset); + iter.position = position; + iter.baseObject = baseObject; + iter.baseOffset = baseOffset; + iter.keyPrefix = keyPrefix; + iter.recordLength = recordLength; + iter.currentPageNumber = currentPageNumber; + return iter; + } + + @Override + public int getNumRecords() { + return numRecords; + } + + @Override + public boolean hasNext() { + return position / 2 < numRecords; + } + + @Override + public void loadNext() { + // Kill the task in case it has been marked as killed. This logic is from + // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order + // to avoid performance overhead. This check is added here in 'loadNext()' instead of in + // 'hasNext()' because it's technically possible for the caller to be relying on + // 'getNumRecords()' instead of 'hasNext()' to know when to stop. + if (taskContext != null) { + taskContext.killTaskIfInterrupted(); + } + // This pointer points to a 4-byte record length, followed by the record's bytes + final long recordPointer = array.get(offset + position); + currentPageNumber = TaskMemoryManager.decodePageNumber(recordPointer); + int uaoSize = UnsafeAlignedOffset.getUaoSize(); + baseObject = memoryManager.getPage(recordPointer); + // Skip over record length + baseOffset = memoryManager.getOffsetInPage(recordPointer) + uaoSize; + recordLength = UnsafeAlignedOffset.getSize(baseObject, baseOffset - uaoSize); + keyPrefix = array.get(offset + position + 1); + position += 2; + } + + @Override + public Object getBaseObject() { + return baseObject; + } + + @Override + public long getBaseOffset() { + return baseOffset; + } + + @Override + public long getCurrentPageNumber() { + return currentPageNumber; + } + + @Override + public int getRecordLength() { + return recordLength; + } + + @Override + public long getKeyPrefix() { + return keyPrefix; + } + } + + /** + * Return an iterator over record pointers in sorted order. For efficiency, all calls to + * {@code next()} will return the same mutable object. + * */ + public UnsafeSorterIterator getSortedIterator() { + return new TopNSortedIterator(nextEmptyPos, 0); + } +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/topnsort/UnsafePartitionedTopNSorter.java b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/topnsort/UnsafePartitionedTopNSorter.java new file mode 100644 index 0000000000000000000000000000000000000000..57941aefb4fc8a3234c0ab22b6d45294ae09c639 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/topnsort/UnsafePartitionedTopNSorter.java @@ -0,0 +1,263 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.topnsort;
+
+import java.util.*;
+import java.util.function.Supplier;
+
+import com.google.common.annotations.VisibleForTesting;
+
+import org.apache.spark.TaskContext;
+import org.apache.spark.memory.MemoryConsumer;
+import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.UnsafeAlignedOffset;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.util.collection.unsafe.sort.*;
+
+/**
+ * Partitioned top n sorter based on {@link org.apache.spark.sql.execution.topnsort.UnsafeInMemoryTopNSorter}.
+ * The implementation mostly refers to {@link UnsafeExternalSorter}.
+ * */
+public final class UnsafePartitionedTopNSorter extends MemoryConsumer {
+    private final TaskMemoryManager taskMemoryManager;
+    private TopNSortComparator sortComparator;
+
+    /**
+     * Memory pages that hold the records being sorted. The pages in this list are freed when
+     * spilling, although in principle we could recycle these pages across spills (on the other hand,
+     * this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager itself).
+     * */
+    private final LinkedList<MemoryBlock> allocatedPages = new LinkedList<>();
+    private final Map<UnsafeRow, UnsafeInMemoryTopNSorter> partToSorters = new LinkedHashMap<>();
+
+    private final int n;
+    private final boolean strictTopN;
+    private MemoryBlock currentPage = null;
+    private long pageCursor = -1;
+    private long peakMemoryUsedBytes = 0;
+
+    public static UnsafePartitionedTopNSorter create(
+            int n,
+            boolean strictTopN,
+            TaskMemoryManager taskMemoryManager,
+            TaskContext taskContext,
+            Supplier<RecordComparator> recordComparatorSupplier,
+            PrefixComparator prefixComparator,
+            long pageSizeBytes,
+            boolean canSortFullyWithPrefix) {
+        assert n > 0 : "Top n must be positive";
+        assert recordComparatorSupplier != null;
+        return new UnsafePartitionedTopNSorter(n, strictTopN, taskMemoryManager, taskContext,
+            recordComparatorSupplier, prefixComparator, pageSizeBytes, canSortFullyWithPrefix);
+    }
+
+    private UnsafePartitionedTopNSorter(
+            int n,
+            boolean strictTopN,
+            TaskMemoryManager taskMemoryManager,
+            TaskContext taskContext,
+            Supplier<RecordComparator> recordComparatorSupplier,
+            PrefixComparator prefixComparator,
+            long pageSizeBytes,
+            boolean canSortFullyWithPrefix) {
+        super(taskMemoryManager, pageSizeBytes, taskMemoryManager.getTungstenMemoryMode());
+        this.n = n;
+        this.strictTopN = strictTopN;
+        this.taskMemoryManager = taskMemoryManager;
+        this.sortComparator = new TopNSortComparator(recordComparatorSupplier.get(),
+            prefixComparator, taskMemoryManager, canSortFullyWithPrefix);
+
+        // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at
+        // the end of the task. This is necessary to avoid memory leaks when the downstream operator
+        // does not fully consume the sorter's output (e.g.
sort followed by limit). + taskContext.addTaskCompletionListener(context -> { + cleanupResources(); + }); + } + + @Override + public long spill(long size, MemoryConsumer trigger) { + throw new UnsupportedOperationException("Spill is unsupported operation in topN in-memory sorter"); + } + + /** + * Return the total memory usage of this sorter, including the data pages and the sorter's pointer array. + * */ + private long getMemoryUsage() { + long totalPageSize = 0; + for (MemoryBlock page : allocatedPages) { + totalPageSize += page.size(); + } + for (UnsafeInMemoryTopNSorter sorter : partToSorters.values()) { + totalPageSize += sorter.getMemoryUsage(); + } + return totalPageSize; + } + + private void updatePeakMemoryUsed() { + long mem = getMemoryUsage(); + if (mem > peakMemoryUsedBytes) { + peakMemoryUsedBytes = mem; + } + } + + /** + * Return the peak memory used so far, in bytes. + * */ + public long getPeakMemoryUsedBytes() { + updatePeakMemoryUsed(); + return peakMemoryUsedBytes; + } + + @VisibleForTesting + public int getNumberOfAllocatedPages() { + return allocatedPages.size(); + } + + /** + * Free this sorter's data pages. + * + * @return the number of bytes freed. + * */ + private long freeMemory() { + updatePeakMemoryUsed(); + long memoryFreed = 0; + for (MemoryBlock block : allocatedPages) { + memoryFreed += block.size(); + freePage(block); + } + allocatedPages.clear(); + currentPage = null; + pageCursor = 0; + for (UnsafeInMemoryTopNSorter sorter: partToSorters.values()) { + memoryFreed += sorter.getMemoryUsage(); + sorter.freeMemory(); + } + partToSorters.clear(); + sortComparator = null; + return memoryFreed; + } + + /** + * Frees this sorter's in-memory data structures and cleans up its spill files. + * */ + public void cleanupResources() { + synchronized (this) { + freeMemory(); + } + } + + /** + * Allocates an additional page in order to insert an additional record. This will request + * additional memory from the memory manager and spill if the requested memory can not be obtained. + * + * @param required the required space in the data page, in bytes, including space for storing the record size + * */ + private void acquireNewPageIfNecessary(int required) { + if (currentPage == null || + pageCursor + required > currentPage.getBaseOffset() + currentPage.size()) { + currentPage = allocatePage(required); + pageCursor = currentPage.getBaseOffset(); + allocatedPages.add(currentPage); + } + } + + public void insertRow(UnsafeRow partKey, UnsafeRow row, long prefix) { + UnsafeInMemoryTopNSorter sorter = + partToSorters.computeIfAbsent( + partKey, + k -> new UnsafeInMemoryTopNSorter(n, strictTopN, this, taskMemoryManager, sortComparator) + ); + final int position = sorter.insert(row, prefix); + if (position >= 0) { + final int uaoSize = UnsafeAlignedOffset.getUaoSize(); + // Need 4 or 8 bytes to store the record length. 
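The record layout assumed here, and read back by TopNSortedIterator.loadNext() above, is the usual Tungsten one: each record is stored in a task memory page as a 4- or 8-byte length header followed by the row bytes, and the value kept alongside the prefix in the sort array is the page-number/offset address encoded by TaskMemoryManager. A hedged Scala sketch of that round trip follows; the helper names writeRecord and readRecord are illustrative and not part of the patch, and numFields is assumed to match the schema of the written row.

    import org.apache.spark.memory.TaskMemoryManager
    import org.apache.spark.sql.catalyst.expressions.UnsafeRow
    import org.apache.spark.unsafe.{Platform, UnsafeAlignedOffset}
    import org.apache.spark.unsafe.memory.MemoryBlock

    // Copy `row` into `page` at `cursor`; return the encoded record address and the new cursor.
    def writeRecord(tmm: TaskMemoryManager, page: MemoryBlock, cursor: Long, row: UnsafeRow): (Long, Long) = {
      val uaoSize = UnsafeAlignedOffset.getUaoSize
      val length = row.getSizeInBytes
      val base = page.getBaseObject
      val recordAddress = tmm.encodePageNumberAndOffset(page, cursor)      // points at the length header
      UnsafeAlignedOffset.putSize(base, cursor, length)                    // write the length header
      Platform.copyMemory(row.getBaseObject, row.getBaseOffset, base, cursor + uaoSize, length)
      (recordAddress, cursor + uaoSize + length)
    }

    // Resolve an encoded address back to an UnsafeRow.
    def readRecord(tmm: TaskMemoryManager, recordAddress: Long, numFields: Int): UnsafeRow = {
      val uaoSize = UnsafeAlignedOffset.getUaoSize
      val base = tmm.getPage(recordAddress)
      val offset = tmm.getOffsetInPage(recordAddress) + uaoSize            // skip the length header
      val length = UnsafeAlignedOffset.getSize(base, offset - uaoSize)
      val row = new UnsafeRow(numFields)
      row.pointTo(base, offset, length)
      row
    }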
+ final int length = row.getSizeInBytes(); + final int required = length + uaoSize; + acquireNewPageIfNecessary(required); + + final Object base = currentPage.getBaseObject(); + final long recordAddress = + taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor); + UnsafeAlignedOffset.putSize(base, pageCursor, length); + pageCursor += uaoSize; + Platform.copyMemory(row.getBaseObject(), row.getBaseOffset(), base, pageCursor, length); + pageCursor += length; + + sorter.updateRecordPointer(position, recordAddress); + } + } + + public Map getPartKeyToSorter() { + return partToSorters; + } + + static final class TopNSortComparator { + private final RecordComparator recordComparator; + private final PrefixComparator prefixComparator; + private final TaskMemoryManager memoryManager; + private final boolean needCompareFully; + + TopNSortComparator( + RecordComparator recordComparator, + PrefixComparator prefixComparator, + TaskMemoryManager memoryManager, + boolean canSortFullyWithPrefix) { + this.recordComparator = recordComparator; + this.prefixComparator = prefixComparator; + this.memoryManager = memoryManager; + this.needCompareFully = !canSortFullyWithPrefix; + } + + public int compare(long pointer1, long prefix1, long pointer2, long prefix2) { + final int prefixComparisonResult = prefixComparator.compare(prefix1, prefix2); + if (needCompareFully && prefixComparisonResult == 0) { + final int uaoSize = UnsafeAlignedOffset.getUaoSize(); + final Object baseObject1 = memoryManager.getPage(pointer1); + final long baseOffset1 = memoryManager.getOffsetInPage(pointer1) + uaoSize; + final int baseLength1 = UnsafeAlignedOffset.getSize(baseObject1, baseOffset1 - uaoSize); + final Object baseObject2 = memoryManager.getPage(pointer2); + final long baseOffset2 = memoryManager.getOffsetInPage(pointer2) + uaoSize; + final int baseLength2 = UnsafeAlignedOffset.getSize(baseObject2, baseOffset2 - uaoSize); + return recordComparator.compare(baseObject1, baseOffset1, baseLength1, baseObject2, + baseOffset2, baseLength2); + } else { + return prefixComparisonResult; + } + } + + public int compare(long pointer, long prefix1, UnsafeRow row, long prefix2) { + final int prefixComparisonResult = prefixComparator.compare(prefix1, prefix2); + if (needCompareFully && prefixComparisonResult == 0) { + final int uaoSize = UnsafeAlignedOffset.getUaoSize(); + final Object baseObject1 = memoryManager.getPage(pointer); + final long baseOffset1 = memoryManager.getOffsetInPage(pointer) + uaoSize; + final int baseLength1 = UnsafeAlignedOffset.getSize(baseObject1, baseOffset1 - uaoSize); + final Object baseObject2 = row.getBaseObject(); + final long baseOffset2 = row.getBaseOffset(); + final int baseLength2 = row.getSizeInBytes(); + return recordComparator.compare(baseObject1, baseOffset1, baseLength1, baseObject2, + baseOffset2, baseLength2); + } else { + return prefixComparisonResult; + } + } + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/MergeIterator.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/MergeIterator.scala index 017eaba23ccf9a464e8f383c1e76d22c5f281953..f9a09780ddc124679df8220048860c4f41fa471b 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/MergeIterator.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/MergeIterator.scala @@ -31,7 +31,7 @@ import 
org.apache.spark.sql.types.{BooleanType, DateType, DecimalType, DoubleTyp import org.apache.spark.sql.vectorized.ColumnarBatch class MergeIterator(iter: Iterator[ColumnarBatch], localSchema: StructType, - numMergedVecBatchs: SQLMetric) extends Iterator[ColumnarBatch] { + numMergedVecBatches: SQLMetric) extends Iterator[ColumnarBatch] { private val outputQueue = new mutable.Queue[VecBatch] private val bufferedVecBatch = new ListBuffer[VecBatch]() @@ -44,29 +44,40 @@ class MergeIterator(iter: Iterator[ColumnarBatch], localSchema: StructType, private def createOmniVectors(schema: StructType, columnSize: Int): Array[Vec] = { val vecs = new Array[Vec](schema.fields.length) - schema.fields.zipWithIndex.foreach { case (field, index) => - field.dataType match { - case LongType => - vecs(index) = new LongVec(columnSize) - case DateType | IntegerType => - vecs(index) = new IntVec(columnSize) - case ShortType => - vecs(index) = new ShortVec(columnSize) - case DoubleType => - vecs(index) = new DoubleVec(columnSize) - case BooleanType => - vecs(index) = new BooleanVec(columnSize) - case StringType => - val vecType: DataType = sparkTypeToOmniType(field.dataType, field.metadata) - vecs(index) = new VarcharVec(columnSize) - case dt: DecimalType => - if (DecimalType.is64BitDecimalType(dt)) { + try { + schema.fields.zipWithIndex.foreach { case (field, index) => + field.dataType match { + case LongType => vecs(index) = new LongVec(columnSize) - } else { - vecs(index) = new Decimal128Vec(columnSize) + case DateType | IntegerType => + vecs(index) = new IntVec(columnSize) + case ShortType => + vecs(index) = new ShortVec(columnSize) + case DoubleType => + vecs(index) = new DoubleVec(columnSize) + case BooleanType => + vecs(index) = new BooleanVec(columnSize) + case StringType => + val vecType: DataType = sparkTypeToOmniType(field.dataType, field.metadata) + vecs(index) = new VarcharVec(columnSize) + case dt: DecimalType => + if (DecimalType.is64BitDecimalType(dt)) { + vecs(index) = new LongVec(columnSize) + } else { + vecs(index) = new Decimal128Vec(columnSize) + } + case _ => + throw new UnsupportedOperationException("Fail to create omni vector, unsupported fields") + } + } + } catch { + case e: Exception => { + for (vec <- vecs) { + if (vec != null) { + vec.close() } - case _ => - throw new UnsupportedOperationException("Fail to create omni vector, unsupported fields") + } + throw new RuntimeException("allocate memory failed!") } } vecs @@ -110,7 +121,7 @@ class MergeIterator(iter: Iterator[ColumnarBatch], localSchema: StructType, val resultBatch: VecBatch = new VecBatch(createOmniVectors(localSchema, totalRows), totalRows) merge(resultBatch, bufferedVecBatch) outputQueue.enqueue(resultBatch) - numMergedVecBatchs += 1 + numMergedVecBatches += 1 bufferedVecBatch.clear() currentBatchSizeInBytes = 0 @@ -122,8 +133,8 @@ class MergeIterator(iter: Iterator[ColumnarBatch], localSchema: StructType, val vectors: Seq[OmniColumnVector] = OmniColumnVector.allocateColumns( vecBatch.getRowCount, localSchema, false) vectors.zipWithIndex.foreach { case (vector, i) => - vector.reset() - vector.setVec(vecBatch.getVectors()(i)) + vector.reset() + vector.setVec(vecBatch.getVectors()(i)) } vecBatch.close() new ColumnarBatch(vectors.toArray, vecBatch.getRowCount) @@ -164,4 +175,15 @@ class MergeIterator(iter: Iterator[ColumnarBatch], localSchema: StructType, def isFull(): Boolean = { totalRows > maxRowCount || currentBatchSizeInBytes >= maxBatchSizeInBytes } + + def close(): Unit = { + for (elem <- bufferedVecBatch) { + 
elem.releaseAllVectors() + elem.close() + } + for (elem <- outputQueue) { + elem.releaseAllVectors() + elem.close() + } + } } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/window/TopNPushDownForWindow.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/window/TopNPushDownForWindow.scala new file mode 100644 index 0000000000000000000000000000000000000000..94e566f9b571c57c2a0e1cc17143c055d7be9229 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/window/TopNPushDownForWindow.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.window; + +import com.huawei.boostkit.spark.ColumnarPluginConfig +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{FilterExec, SortExec, SparkPlan, TopNSortExec} + +object TopNPushDownForWindow extends Rule[SparkPlan] with PredicateHelper { + override def apply(plan: SparkPlan): SparkPlan = { + if (!ColumnarPluginConfig.getConf.topNPushDownForWindowEnable) { + return plan + } + + plan.transform { + case f @ FilterExec(condition, + w @ WindowExec(Seq(windowExpression), _, orderSpec, sort: SortExec)) + if orderSpec.nonEmpty && isTopNExpression(windowExpression) => + var topn = Int.MaxValue + val nonTopNConditions = splitConjunctivePredicates(condition).filter { + case LessThan(e: NamedExpression, IntegerLiteral(n)) + if e.exprId == windowExpression.exprId => + topn = Math.min(topn, n - 1) + false + case LessThanOrEqual(e: NamedExpression, IntegerLiteral(n)) + if e.exprId == windowExpression.exprId => + topn = Math.min(topn, n) + false + case GreaterThan(IntegerLiteral(n), e: NamedExpression) + if e.exprId == windowExpression.exprId => + topn = Math.min(topn, n - 1) + false + case GreaterThanOrEqual(IntegerLiteral(n), e: NamedExpression) + if e.exprId == windowExpression.exprId => + topn = Math.min(topn, n) + false + case EqualTo(e: NamedExpression, IntegerLiteral(n)) + if n == 1 && e.exprId == windowExpression.exprId => + topn = 1 + false + case EqualTo(IntegerLiteral(n), e: NamedExpression) + if n == 1 && e.exprId == windowExpression.exprId => + topn = 1 + false + case _ => true + } + + // topn <= SQLConf.get.topNPushDownForWindowThreshold 100. 
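To make the effect of this rule concrete, the following is an illustrative query shape it targets and the plan rewrite it performs; the table and column names are made up, and the rewritten plan is only sketched in the trailing comments.

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().appName("topn-pushdown-example").getOrCreate()
    spark.sql(
      """
        |SELECT * FROM (
        |  SELECT item, price,
        |         rank() OVER (PARTITION BY store ORDER BY price) AS rk
        |  FROM sales
        |) t
        |WHERE rk <= 10
      """.stripMargin)
    // Before the rule: Filter(rk <= 10) <- Window(rank) <- Sort(store, price) <- scan
    // After the rule : Filter(true)     <- Window(rank) <- TopNSortExec(n = 10, strictTopN = false,
    //                      partitionSpec = [store], orderSpec = [price ASC]) <- scan
    // With row_number() instead of rank(), strictTopN would be true and at most 10 rows
    // per partition are kept before the window is evaluated.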
+ if (topn> 0 && topn <= ColumnarPluginConfig.getConf.topNPushDownForWindowThreshold) { + val strictTopN = isStrictTopN(windowExpression) + val topNSortExec = TopNSortExec( + topn, strictTopN, w.partitionSpec, w.orderSpec, sort.global, sort.child) + val newCondition = if (nonTopNConditions.isEmpty) { + Literal.TrueLiteral + } else { + nonTopNConditions.reduce(And) + } + FilterExec(newCondition, w.copy(child = topNSortExec)) + } else { + f + } + } + } + + private def isTopNExpression(e: Expression): Boolean = e match { + case Alias(child, _) => isTopNExpression(child) + case WindowExpression(windowFunction, _) + if windowFunction.isInstanceOf[Rank] || windowFunction.isInstanceOf[RowNumber] => true + case _ => false + } + + private def isStrictTopN(e: Expression): Boolean = e match { + case Alias(child, _) => isStrictTopN(child) + case WindowExpression(windowFunction, _) => windowFunction.isInstanceOf[RowNumber] + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/util/HadoopFSUtils.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/util/HadoopFSUtils.scala new file mode 100644 index 0000000000000000000000000000000000000000..5451091784bd9bcd1b01d1ff416c8279e0ee9a86 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/util/HadoopFSUtils.scala @@ -0,0 +1,370 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.FileNotFoundException + +import scala.collection.mutable + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs._ +import org.apache.hadoop.fs.viewfs.ViewFileSystem +import org.apache.hadoop.hdfs.DistributedFileSystem + +import org.apache.spark._ +import org.apache.spark.internal.Logging +import org.apache.spark.metrics.source.HiveCatalogMetrics + +/** + * Utility functions to simplify and speed-up file listing. + */ +private[spark] object HadoopFSUtils extends Logging { + /** + * Lists a collection of paths recursively. Picks the listing strategy adaptively depending + * on the number of paths to list. + * + * This may only be called on the driver. + * + * @param sc Spark context used to run parallel listing. + * @param paths Input paths to list + * @param hadoopConf Hadoop configuration + * @param filter Path filter used to exclude leaf files from result + * @param ignoreMissingFiles Ignore missing files that happen during recursive listing + * (e.g., due to race conditions) + * @param ignoreLocality Whether to fetch data locality info when listing leaf files. If false, + * this will return `FileStatus` without `BlockLocation` info. + * @param parallelismThreshold The threshold to enable parallelism. 
If the number of input paths + * is smaller than this value, this will fallback to use + * sequential listing. + * @param parallelismMax The maximum parallelism for listing. If the number of input paths is + * larger than this value, parallelism will be throttled to this value + * to avoid generating too many tasks. + * @return for each input path, the set of discovered files for the path + */ + def parallelListLeafFiles( + sc: SparkContext, + paths: Seq[Path], + hadoopConf: Configuration, + filter: PathFilter, + ignoreMissingFiles: Boolean, + ignoreLocality: Boolean, + parallelismThreshold: Int, + parallelismMax: Int): Seq[(Path, Seq[FileStatus])] = { + parallelListLeafFilesInternal(sc, paths, hadoopConf, filter, isRootLevel = true, + ignoreMissingFiles, ignoreLocality, parallelismThreshold, parallelismMax) + } + + private def parallelListLeafFilesInternal( + sc: SparkContext, + paths: Seq[Path], + hadoopConf: Configuration, + filter: PathFilter, + isRootLevel: Boolean, + ignoreMissingFiles: Boolean, + ignoreLocality: Boolean, + parallelismThreshold: Int, + parallelismMax: Int): Seq[(Path, Seq[FileStatus])] = { + + // Short-circuits parallel listing when serial listing is likely to be faster. + if (paths.size <= parallelismThreshold) { + return paths.map { path => + val leafFiles = listLeafFiles( + path, + hadoopConf, + filter, + Some(sc), + ignoreMissingFiles = ignoreMissingFiles, + ignoreLocality = ignoreLocality, + isRootPath = isRootLevel, + parallelismThreshold = parallelismThreshold, + parallelismMax = parallelismMax) + (path, leafFiles) + } + } + + logInfo(s"Listing leaf files and directories in parallel under ${paths.length} paths." + + s" The first several paths are: ${paths.take(10).mkString(", ")}.") + HiveCatalogMetrics.incrementParallelListingJobCount(1) + + val brSerializableConfiguration = sc.broadcast(new SerializableConfiguration(hadoopConf)); + val serializedPaths = paths.map(_.toString) + + // Set the number of parallelism to prevent following file listing from generating many tasks + // in case of large #defaultParallelism. + val numParallelism = Math.min(paths.size, parallelismMax) + + val previousJobDescription = sc.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION) + val statusMap = try { + val description = paths.size match { + case 0 => + "Listing leaf files and directories 0 paths" + case 1 => + s"Listing leaf files and directories for 1 path:
${paths(0)}" + case s => + s"Listing leaf files and directories for $s paths:
${paths(0)}, ..." + } + sc.setJobDescription(description) + sc + .parallelize(serializedPaths, numParallelism) + .mapPartitions { pathStrings => + val hadoopConf = brSerializableConfiguration.value.value + pathStrings.map(new Path(_)).toSeq.map { path => + val leafFiles = listLeafFiles( + path = path, + hadoopConf = hadoopConf, + filter = filter, + contextOpt = None, // Can't execute parallel scans on workers + ignoreMissingFiles = ignoreMissingFiles, + ignoreLocality = ignoreLocality, + isRootPath = isRootLevel, + parallelismThreshold = Int.MaxValue, + parallelismMax = 0) + (path, leafFiles) + }.iterator + }.map { case (path, statuses) => + val serializableStatuses = statuses.map { status => + // Turn FileStatus into SerializableFileStatus so we can send it back to the driver + val blockLocations = status match { + case f: LocatedFileStatus => + f.getBlockLocations.map { loc => + SerializableBlockLocation( + loc.getNames, + loc.getHosts, + loc.getOffset, + loc.getLength) + } + + case _ => + Array.empty[SerializableBlockLocation] + } + + SerializableFileStatus( + status.getPath.toString, + status.getLen, + status.isDirectory, + status.getReplication, + status.getBlockSize, + status.getModificationTime, + status.getAccessTime, + blockLocations) + } + (path.toString, serializableStatuses) + }.collect() + } finally { + sc.setJobDescription(previousJobDescription) + } + + // turn SerializableFileStatus back to Status + statusMap.map { case (path, serializableStatuses) => + val statuses = serializableStatuses.map { f => + val blockLocations = f.blockLocations.map { loc => + new BlockLocation(loc.names, loc.hosts, loc.offset, loc.length) + } + new LocatedFileStatus( + new FileStatus( + f.length, f.isDir, f.blockReplication, f.blockSize, f.modificationTime, + new Path(f.path)), + blockLocations) + } + (new Path(path), statuses) + } + } + + // scalastyle:off argcount + /** + * Lists a single filesystem path recursively. If a `SparkContext` object is specified, this + * function may launch Spark jobs to parallelize listing based on `parallelismThreshold`. + * + * If sessionOpt is None, this may be called on executors. + * + * @return all children of path that match the specified filter. + */ + private def listLeafFiles( + path: Path, + hadoopConf: Configuration, + filter: PathFilter, + contextOpt: Option[SparkContext], + ignoreMissingFiles: Boolean, + ignoreLocality: Boolean, + isRootPath: Boolean, + parallelismThreshold: Int, + parallelismMax: Int): Seq[FileStatus] = { + + logTrace(s"Listing $path") + val fs = path.getFileSystem(hadoopConf) + + // Note that statuses only include FileStatus for the files and dirs directly under path, + // and does not include anything else recursively. + val statuses: Array[FileStatus] = try { + fs match { + // DistributedFileSystem overrides listLocatedStatus to make 1 single call to namenode + // to retrieve the file status with the file block location. The reason to still fallback + // to listStatus is because the default implementation would potentially throw a + // FileNotFoundException which is better handled by doing the lookups manually below. + case (_: DistributedFileSystem | _: ViewFileSystem) if !ignoreLocality => + val remoteIter = fs.listLocatedStatus(path) + new Iterator[LocatedFileStatus]() { + def next(): LocatedFileStatus = remoteIter.next + def hasNext(): Boolean = remoteIter.hasNext + }.toArray + case _ => fs.listStatus(path) + } + } catch { + // If we are listing a root path for SQL (e.g. 
a top level directory of a table), we need to + // ignore FileNotFoundExceptions during this root level of the listing because + // + // (a) certain code paths might construct an InMemoryFileIndex with root paths that + // might not exist (i.e. not all callers are guaranteed to have checked + // path existence prior to constructing InMemoryFileIndex) and, + // (b) we need to ignore deleted root paths during REFRESH TABLE, otherwise we break + // existing behavior and break the ability drop SessionCatalog tables when tables' + // root directories have been deleted (which breaks a number of Spark's own tests). + // + // If we are NOT listing a root path then a FileNotFoundException here means that the + // directory was present in a previous level of file listing but is absent in this + // listing, likely indicating a race condition (e.g. concurrent table overwrite or S3 + // list inconsistency). + // + // The trade-off in supporting existing behaviors / use-cases is that we won't be + // able to detect race conditions involving root paths being deleted during + // InMemoryFileIndex construction. However, it's still a net improvement to detect and + // fail-fast on the non-root cases. For more info see the SPARK-27676 review discussion. + case _: FileNotFoundException if isRootPath || ignoreMissingFiles => + logWarning(s"The directory $path was not found. Was it deleted very recently?") + Array.empty[FileStatus] + } + + val filteredStatuses = + statuses.filterNot(status => shouldFilterOutPathName(status.getPath.getName)) + + val allLeafStatuses = { + val (dirs, topLevelFiles) = filteredStatuses.partition(_.isDirectory) + val nestedFiles: Seq[FileStatus] = contextOpt match { + case Some(context) if dirs.size > parallelismThreshold => + parallelListLeafFilesInternal( + context, + dirs.map(_.getPath), + hadoopConf = hadoopConf, + filter = filter, + isRootLevel = false, + ignoreMissingFiles = ignoreMissingFiles, + ignoreLocality = ignoreLocality, + parallelismThreshold = parallelismThreshold, + parallelismMax = parallelismMax + ).flatMap(_._2) + case _ => + dirs.flatMap { dir => + listLeafFiles( + path = dir.getPath, + hadoopConf = hadoopConf, + filter = filter, + contextOpt = contextOpt, + ignoreMissingFiles = ignoreMissingFiles, + ignoreLocality = ignoreLocality, + isRootPath = false, + parallelismThreshold = parallelismThreshold, + parallelismMax = parallelismMax) + } + } + val allFiles = topLevelFiles ++ nestedFiles + if (filter != null) allFiles.filter(f => filter.accept(f.getPath)) else allFiles + } + + val missingFiles = mutable.ArrayBuffer.empty[String] + val resolvedLeafStatuses = allLeafStatuses.flatMap { + case f: LocatedFileStatus => + Some(f) + + // NOTE: + // + // - Although S3/S3A/S3N file system can be quite slow for remote file metadata + // operations, calling `getFileBlockLocations` does no harm here since these file system + // implementations don't actually issue RPC for this method. + // + // - Here we are calling `getFileBlockLocations` in a sequential manner, but it should not + // be a big deal since we always use to `parallelListLeafFiles` when the number of + // paths exceeds threshold. + case f if !ignoreLocality => + // The other constructor of LocatedFileStatus will call FileStatus.getPermission(), + // which is very slow on some file system (RawLocalFileSystem, which is launch a + // subprocess and parse the stdout). 
+ try { + val locations = fs.getFileBlockLocations(f, 0, f.getLen).map { loc => + // Store BlockLocation objects to consume less memory + if (loc.getClass == classOf[BlockLocation]) { + loc + } else { + new BlockLocation(loc.getNames, loc.getHosts, loc.getOffset, loc.getLength) + } + } + val lfs = new LocatedFileStatus(f.getLen, f.isDirectory, f.getReplication, f.getBlockSize, + f.getModificationTime, 0, null, null, null, null, f.getPath, locations) + if (f.isSymlink) { + lfs.setSymlink(f.getSymlink) + } + Some(lfs) + } catch { + case _: FileNotFoundException if ignoreMissingFiles => + missingFiles += f.getPath.toString + None + } + + case f => Some(f) + } + + if (missingFiles.nonEmpty) { + logWarning( + s"the following files were missing during file scan:\n ${missingFiles.mkString("\n ")}") + } + + resolvedLeafStatuses + } + // scalastyle:on argcount + + /** A serializable variant of HDFS's BlockLocation. This is required by Hadoop 2.7. */ + private case class SerializableBlockLocation( + names: Array[String], + hosts: Array[String], + offset: Long, + length: Long) + + /** A serializable variant of HDFS's FileStatus. This is required by Hadoop 2.7. */ + private case class SerializableFileStatus( + path: String, + length: Long, + isDir: Boolean, + blockReplication: Short, + blockSize: Long, + modificationTime: Long, + accessTime: Long, + blockLocations: Array[SerializableBlockLocation]) + + /** Checks if we should filter out this path name. */ + def shouldFilterOutPathName(pathName: String): Boolean = { + // We filter follow paths: + // 1. everything that starts with _ and ., except _common_metadata and _metadata + // because Parquet needs to find those metadata files from leaf files returned by this method. + // We should refactor this logic to not mix metadata files with data files. + // 2. everything that ends with `._COPYING_`, because this is a intermediate state of file. we + // should skip this file in case of double reading. 
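A few concrete cases for the filter described above; the expected results follow directly from the two rules, and HadoopFSUtils.shouldFilterOutPathName is the method defined just below.

    assert(HadoopFSUtils.shouldFilterOutPathName("_SUCCESS"))                      // starts with "_", not metadata
    assert(HadoopFSUtils.shouldFilterOutPathName(".hidden"))                       // starts with "."
    assert(HadoopFSUtils.shouldFilterOutPathName("part-00000._COPYING_"))          // in-flight copy
    assert(!HadoopFSUtils.shouldFilterOutPathName("_metadata"))                    // Parquet metadata is kept
    assert(!HadoopFSUtils.shouldFilterOutPathName("_common_metadata"))             // Parquet metadata is kept
    assert(!HadoopFSUtils.shouldFilterOutPathName("dt=2023-01-01"))                // partition directory
    assert(!HadoopFSUtils.shouldFilterOutPathName("part-00000-c000.snappy.orc"))   // ordinary data file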
+ val exclude = (pathName.startsWith("_") && !pathName.contains("=")) || + pathName.startsWith(".") || pathName.endsWith("._COPYING_") + val include = pathName.startsWith("_common_metadata") || pathName.startsWith("_metadata") + exclude && !include + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderDataTypeTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderDataTypeTest.java index 73db9a981dba57a4551ee832ce32b4298983115d..77283e4d07eb7e1d77317a026f5d04ea4032c393 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderDataTypeTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderDataTypeTest.java @@ -19,11 +19,11 @@ package com.huawei.boostkit.spark.jni; import junit.framework.TestCase; -import nova.hetu.omniruntime.type.DataType; import nova.hetu.omniruntime.vector.IntVec; import nova.hetu.omniruntime.vector.LongVec; import nova.hetu.omniruntime.vector.VarcharVec; -import nova.hetu.omniruntime.vector.Vec; +import org.apache.hadoop.conf.Configuration; +import org.apache.orc.OrcFile; import org.json.JSONObject; import org.junit.After; import org.junit.Before; @@ -32,17 +32,17 @@ import org.junit.Test; import org.junit.runners.MethodSorters; import java.io.File; +import java.net.URI; +import java.net.URISyntaxException; import java.util.ArrayList; -import static org.junit.Assert.*; - @FixMethodOrder(value = MethodSorters.NAME_ASCENDING ) public class OrcColumnarBatchJniReaderDataTypeTest extends TestCase { - public OrcColumnarBatchJniReader orcColumnarBatchJniReader; + public OrcColumnarBatchScanReader orcColumnarBatchScanReader; @Before public void setUp() throws Exception { - orcColumnarBatchJniReader = new OrcColumnarBatchJniReader(); + orcColumnarBatchScanReader = new OrcColumnarBatchScanReader(); initReaderJava(); initRecordReaderJava(); initBatch(); @@ -50,17 +50,22 @@ public class OrcColumnarBatchJniReaderDataTypeTest extends TestCase { @After public void tearDown() throws Exception { - System.out.println("orcColumnarBatchJniReader test finished"); + System.out.println("orcColumnarBatchScanReader test finished"); } public void initReaderJava() { - JSONObject job = new JSONObject(); - job.put("serializedTail",""); - job.put("tailLocation",9223372036854775807L); File directory = new File("src/test/java/com/huawei/boostkit/spark/jni/orcsrc/000000_0"); - System.out.println(directory.getAbsolutePath()); - orcColumnarBatchJniReader.reader = orcColumnarBatchJniReader.initializeReader(directory.getAbsolutePath(), job); - assertTrue(orcColumnarBatchJniReader.reader != 0); + String absolutePath = directory.getAbsolutePath(); + System.out.println(absolutePath); + URI uri = null; + try { + uri = new URI(absolutePath); + } catch (URISyntaxException ignore) { + // if URISyntaxException thrown, next line assertNotNull will interrupt the test + } + assertNotNull(uri); + orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri, OrcFile.readerOptions(new Configuration())); + assertTrue(orcColumnarBatchScanReader.reader != 0); } public void initRecordReaderJava() { @@ -80,20 +85,20 @@ public class OrcColumnarBatchJniReaderDataTypeTest extends TestCase { includedColumns.add("i_current_price"); job.put("includedColumns", includedColumns.toArray()); - orcColumnarBatchJniReader.recordReader = 
orcColumnarBatchJniReader.initializeRecordReader(orcColumnarBatchJniReader.reader, job); - assertTrue(orcColumnarBatchJniReader.recordReader != 0); + orcColumnarBatchScanReader.recordReader = orcColumnarBatchScanReader.jniReader.initializeRecordReader(orcColumnarBatchScanReader.reader, job); + assertTrue(orcColumnarBatchScanReader.recordReader != 0); } public void initBatch() { - orcColumnarBatchJniReader.batchReader = orcColumnarBatchJniReader.initializeBatch(orcColumnarBatchJniReader.recordReader, 4096); - assertTrue(orcColumnarBatchJniReader.batchReader != 0); + orcColumnarBatchScanReader.batchReader = orcColumnarBatchScanReader.jniReader.initializeBatch(orcColumnarBatchScanReader.recordReader, 4096); + assertTrue(orcColumnarBatchScanReader.batchReader != 0); } @Test public void testNext() { int[] typeId = new int[4]; long[] vecNativeId = new long[4]; - long rtn = orcColumnarBatchJniReader.recordReaderNext(orcColumnarBatchJniReader.recordReader, orcColumnarBatchJniReader.batchReader, typeId, vecNativeId); + long rtn = orcColumnarBatchScanReader.jniReader.recordReaderNext(orcColumnarBatchScanReader.recordReader, orcColumnarBatchScanReader.batchReader, typeId, vecNativeId); assertTrue(rtn == 4096); LongVec vec1 = new LongVec(vecNativeId[0]); VarcharVec vec2 = new VarcharVec(vecNativeId[1]); diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java index d9fe13683343f4299ad2b4b2290b0cbf47d761e1..72587b3f36a469fe130abc76c51197fb2a16bd29 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java @@ -19,11 +19,10 @@ package com.huawei.boostkit.spark.jni; import junit.framework.TestCase; -import nova.hetu.omniruntime.type.DataType; -import nova.hetu.omniruntime.vector.IntVec; import nova.hetu.omniruntime.vector.LongVec; import nova.hetu.omniruntime.vector.VarcharVec; -import nova.hetu.omniruntime.vector.Vec; +import org.apache.hadoop.conf.Configuration; +import org.apache.orc.OrcFile; import org.json.JSONObject; import org.junit.After; import org.junit.Before; @@ -32,17 +31,17 @@ import org.junit.Test; import org.junit.runners.MethodSorters; import java.io.File; +import java.net.URI; +import java.net.URISyntaxException; import java.util.ArrayList; -import static org.junit.Assert.*; - @FixMethodOrder(value = MethodSorters.NAME_ASCENDING ) public class OrcColumnarBatchJniReaderNotPushDownTest extends TestCase { - public OrcColumnarBatchJniReader orcColumnarBatchJniReader; + public OrcColumnarBatchScanReader orcColumnarBatchScanReader; @Before public void setUp() throws Exception { - orcColumnarBatchJniReader = new OrcColumnarBatchJniReader(); + orcColumnarBatchScanReader = new OrcColumnarBatchScanReader(); initReaderJava(); initRecordReaderJava(); initBatch(); @@ -50,17 +49,22 @@ public class OrcColumnarBatchJniReaderNotPushDownTest extends TestCase { @After public void tearDown() throws Exception { - System.out.println("orcColumnarBatchJniReader test finished"); + System.out.println("OrcColumnarBatchScanReader test finished"); } public void initReaderJava() { - JSONObject job = new JSONObject(); - job.put("serializedTail",""); - 
job.put("tailLocation",9223372036854775807L); File directory = new File("src/test/java/com/huawei/boostkit/spark/jni/orcsrc/000000_0"); - System.out.println(directory.getAbsolutePath()); - orcColumnarBatchJniReader.reader = orcColumnarBatchJniReader.initializeReader(directory.getAbsolutePath(), job); - assertTrue(orcColumnarBatchJniReader.reader != 0); + String absolutePath = directory.getAbsolutePath(); + System.out.println(absolutePath); + URI uri = null; + try { + uri = new URI(absolutePath); + } catch (URISyntaxException ignore) { + // if URISyntaxException thrown, next line assertNotNull will interrupt the test + } + assertNotNull(uri); + orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri, OrcFile.readerOptions(new Configuration())); + assertTrue(orcColumnarBatchScanReader.reader != 0); } public void initRecordReaderJava() { @@ -74,20 +78,20 @@ public class OrcColumnarBatchJniReaderNotPushDownTest extends TestCase { includedColumns.add("i_item_id"); job.put("includedColumns", includedColumns.toArray()); - orcColumnarBatchJniReader.recordReader = orcColumnarBatchJniReader.initializeRecordReader(orcColumnarBatchJniReader.reader, job); - assertTrue(orcColumnarBatchJniReader.recordReader != 0); + orcColumnarBatchScanReader.recordReader = orcColumnarBatchScanReader.jniReader.initializeRecordReader(orcColumnarBatchScanReader.reader, job); + assertTrue(orcColumnarBatchScanReader.recordReader != 0); } public void initBatch() { - orcColumnarBatchJniReader.batchReader = orcColumnarBatchJniReader.initializeBatch(orcColumnarBatchJniReader.recordReader, 4096); - assertTrue(orcColumnarBatchJniReader.batchReader != 0); + orcColumnarBatchScanReader.batchReader = orcColumnarBatchScanReader.jniReader.initializeBatch(orcColumnarBatchScanReader.recordReader, 4096); + assertTrue(orcColumnarBatchScanReader.batchReader != 0); } @Test public void testNext() { int[] typeId = new int[2]; long[] vecNativeId = new long[2]; - long rtn = orcColumnarBatchJniReader.recordReaderNext(orcColumnarBatchJniReader.recordReader, orcColumnarBatchJniReader.batchReader, typeId, vecNativeId); + long rtn = orcColumnarBatchScanReader.jniReader.recordReaderNext(orcColumnarBatchScanReader.recordReader, orcColumnarBatchScanReader.batchReader, typeId, vecNativeId); assertTrue(rtn == 4096); LongVec vec1 = new LongVec(vecNativeId[0]); VarcharVec vec2 = new VarcharVec(vecNativeId[1]); diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java index 87f0cc1d2920982de3b73d9046d173a8f2c8fbb8..6c75eda79e38332e48e67df8a27c0e1394e4e477 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java @@ -18,37 +18,30 @@ package com.huawei.boostkit.spark.jni; -import static org.junit.Assert.*; import junit.framework.TestCase; -import org.apache.hadoop.mapred.join.ArrayListBackedIterator; -import org.apache.orc.OrcFile.ReaderOptions; -import org.apache.orc.Reader.Options; -import org.hamcrest.Condition; +import nova.hetu.omniruntime.vector.LongVec; +import nova.hetu.omniruntime.vector.VarcharVec; +import org.apache.hadoop.conf.Configuration; +import org.apache.orc.OrcFile; import 
org.json.JSONObject; import org.junit.After; import org.junit.Before; import org.junit.FixMethodOrder; import org.junit.Test; import org.junit.runners.MethodSorters; -import nova.hetu.omniruntime.type.DataType; -import nova.hetu.omniruntime.vector.IntVec; -import nova.hetu.omniruntime.vector.LongVec; -import nova.hetu.omniruntime.vector.VarcharVec; -import nova.hetu.omniruntime.vector.Vec; import java.io.File; -import java.lang.reflect.Array; +import java.net.URI; +import java.net.URISyntaxException; import java.util.ArrayList; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; @FixMethodOrder(value = MethodSorters.NAME_ASCENDING ) public class OrcColumnarBatchJniReaderPushDownTest extends TestCase { - public OrcColumnarBatchJniReader orcColumnarBatchJniReader; + public OrcColumnarBatchScanReader orcColumnarBatchScanReader; @Before public void setUp() throws Exception { - orcColumnarBatchJniReader = new OrcColumnarBatchJniReader(); + orcColumnarBatchScanReader = new OrcColumnarBatchScanReader(); initReaderJava(); initRecordReaderJava(); initBatch(); @@ -56,17 +49,22 @@ public class OrcColumnarBatchJniReaderPushDownTest extends TestCase { @After public void tearDown() throws Exception { - System.out.println("orcColumnarBatchJniReader test finished"); + System.out.println("orcColumnarBatchScanReader test finished"); } public void initReaderJava() { - JSONObject job = new JSONObject(); - job.put("serializedTail",""); - job.put("tailLocation",9223372036854775807L); File directory = new File("src/test/java/com/huawei/boostkit/spark/jni/orcsrc/000000_0"); - System.out.println(directory.getAbsolutePath()); - orcColumnarBatchJniReader.reader = orcColumnarBatchJniReader.initializeReader(directory.getAbsolutePath(), job); - assertTrue(orcColumnarBatchJniReader.reader != 0); + String absolutePath = directory.getAbsolutePath(); + System.out.println(absolutePath); + URI uri = null; + try { + uri = new URI(absolutePath); + } catch (URISyntaxException ignore) { + // if URISyntaxException thrown, next line assertNotNull will interrupt the test + } + assertNotNull(uri); + orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri, OrcFile.readerOptions(new Configuration())); + assertTrue(orcColumnarBatchScanReader.reader != 0); } public void initRecordReaderJava() { @@ -126,20 +124,20 @@ public class OrcColumnarBatchJniReaderPushDownTest extends TestCase { includedColumns.add("i_item_id"); job.put("includedColumns", includedColumns.toArray()); - orcColumnarBatchJniReader.recordReader = orcColumnarBatchJniReader.initializeRecordReader(orcColumnarBatchJniReader.reader, job); - assertTrue(orcColumnarBatchJniReader.recordReader != 0); + orcColumnarBatchScanReader.recordReader = orcColumnarBatchScanReader.jniReader.initializeRecordReader(orcColumnarBatchScanReader.reader, job); + assertTrue(orcColumnarBatchScanReader.recordReader != 0); } public void initBatch() { - orcColumnarBatchJniReader.batchReader = orcColumnarBatchJniReader.initializeBatch(orcColumnarBatchJniReader.recordReader, 4096); - assertTrue(orcColumnarBatchJniReader.batchReader != 0); + orcColumnarBatchScanReader.batchReader = orcColumnarBatchScanReader.jniReader.initializeBatch(orcColumnarBatchScanReader.recordReader, 4096); + assertTrue(orcColumnarBatchScanReader.batchReader != 0); } @Test public void testNext() { int[] typeId = new int[2]; long[] vecNativeId = new long[2]; - long rtn = orcColumnarBatchJniReader.recordReaderNext(orcColumnarBatchJniReader.recordReader, orcColumnarBatchJniReader.batchReader, 
typeId, vecNativeId); + long rtn = orcColumnarBatchScanReader.jniReader.recordReaderNext(orcColumnarBatchScanReader.recordReader, orcColumnarBatchScanReader.batchReader, typeId, vecNativeId); assertTrue(rtn == 4096); LongVec vec1 = new LongVec(vecNativeId[0]); VarcharVec vec2 = new VarcharVec(vecNativeId[1]); diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java index 484365c537231b46816e139b090d2384f08b5588..7fb87efa3a2899bc5375168f56d516243a10881d 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java @@ -22,6 +22,8 @@ import junit.framework.TestCase; import nova.hetu.omniruntime.vector.IntVec; import nova.hetu.omniruntime.vector.LongVec; import nova.hetu.omniruntime.vector.VarcharVec; +import org.apache.hadoop.conf.Configuration; +import org.apache.orc.OrcFile; import org.json.JSONObject; import org.junit.After; import org.junit.Before; @@ -30,17 +32,17 @@ import org.junit.Test; import org.junit.runners.MethodSorters; import java.io.File; +import java.net.URI; +import java.net.URISyntaxException; import java.util.ArrayList; -import static org.junit.Assert.*; - @FixMethodOrder(value = MethodSorters.NAME_ASCENDING ) public class OrcColumnarBatchJniReaderSparkORCNotPushDownTest extends TestCase { - public OrcColumnarBatchJniReader orcColumnarBatchJniReader; + public OrcColumnarBatchScanReader orcColumnarBatchScanReader; @Before public void setUp() throws Exception { - orcColumnarBatchJniReader = new OrcColumnarBatchJniReader(); + orcColumnarBatchScanReader = new OrcColumnarBatchScanReader(); initReaderJava(); initRecordReaderJava(); initBatch(); @@ -48,17 +50,22 @@ public class OrcColumnarBatchJniReaderSparkORCNotPushDownTest extends TestCase { @After public void tearDown() throws Exception { - System.out.println("orcColumnarBatchJniReader test finished"); + System.out.println("orcColumnarBatchScanReader test finished"); } public void initReaderJava() { - JSONObject job = new JSONObject(); - job.put("serializedTail",""); - job.put("tailLocation",9223372036854775807L); File directory = new File("src/test/java/com/huawei/boostkit/spark/jni/orcsrc/part-00000-2d6ca713-08b0-4b40-828c-f7ee0c81bb9a-c000.snappy.orc"); - System.out.println(directory.getAbsolutePath()); - orcColumnarBatchJniReader.reader = orcColumnarBatchJniReader.initializeReader(directory.getAbsolutePath(), job); - assertTrue(orcColumnarBatchJniReader.reader != 0); + String absolutePath = directory.getAbsolutePath(); + System.out.println(absolutePath); + URI uri = null; + try { + uri = new URI(absolutePath); + } catch (URISyntaxException ignore) { + // if a URISyntaxException is thrown, the assertNotNull below will fail the test + } + assertNotNull(uri); + orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri, OrcFile.readerOptions(new Configuration())); + assertTrue(orcColumnarBatchScanReader.reader != 0); } public void initRecordReaderJava() { @@ -78,20 +85,20 @@ public class OrcColumnarBatchJniReaderSparkORCNotPushDownTest extends TestCase { includedColumns.add("i_current_price"); job.put("includedColumns", 
includedColumns.toArray()); - orcColumnarBatchJniReader.recordReader = orcColumnarBatchJniReader.initializeRecordReader(orcColumnarBatchJniReader.reader, job); - assertTrue(orcColumnarBatchJniReader.recordReader != 0); + orcColumnarBatchScanReader.recordReader = orcColumnarBatchScanReader.jniReader.initializeRecordReader(orcColumnarBatchScanReader.reader, job); + assertTrue(orcColumnarBatchScanReader.recordReader != 0); } public void initBatch() { - orcColumnarBatchJniReader.batchReader = orcColumnarBatchJniReader.initializeBatch(orcColumnarBatchJniReader.recordReader, 4096); - assertTrue(orcColumnarBatchJniReader.batchReader != 0); + orcColumnarBatchScanReader.batchReader = orcColumnarBatchScanReader.jniReader.initializeBatch(orcColumnarBatchScanReader.recordReader, 4096); + assertTrue(orcColumnarBatchScanReader.batchReader != 0); } @Test public void testNext() { int[] typeId = new int[4]; long[] vecNativeId = new long[4]; - long rtn = orcColumnarBatchJniReader.recordReaderNext(orcColumnarBatchJniReader.recordReader, orcColumnarBatchJniReader.batchReader, typeId, vecNativeId); + long rtn = orcColumnarBatchScanReader.jniReader.recordReaderNext(orcColumnarBatchScanReader.recordReader, orcColumnarBatchScanReader.batchReader, typeId, vecNativeId); assertTrue(rtn == 4096); LongVec vec1 = new LongVec(vecNativeId[0]); VarcharVec vec2 = new VarcharVec(vecNativeId[1]); diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java index b03d60aac4b61291c614bce9f7a52503918a1106..4ba4579cc9340ea410d1d1cdbbf5e0a88ebe2888 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java @@ -19,11 +19,11 @@ package com.huawei.boostkit.spark.jni; import junit.framework.TestCase; -import nova.hetu.omniruntime.type.DataType; import nova.hetu.omniruntime.vector.IntVec; import nova.hetu.omniruntime.vector.LongVec; import nova.hetu.omniruntime.vector.VarcharVec; -import nova.hetu.omniruntime.vector.Vec; +import org.apache.hadoop.conf.Configuration; +import org.apache.orc.OrcFile; import org.json.JSONObject; import org.junit.After; import org.junit.Before; @@ -32,17 +32,17 @@ import org.junit.Test; import org.junit.runners.MethodSorters; import java.io.File; +import java.net.URI; +import java.net.URISyntaxException; import java.util.ArrayList; -import static org.junit.Assert.*; - @FixMethodOrder(value = MethodSorters.NAME_ASCENDING ) public class OrcColumnarBatchJniReaderSparkORCPushDownTest extends TestCase { - public OrcColumnarBatchJniReader orcColumnarBatchJniReader; + public OrcColumnarBatchScanReader orcColumnarBatchScanReader; @Before public void setUp() throws Exception { - orcColumnarBatchJniReader = new OrcColumnarBatchJniReader(); + orcColumnarBatchScanReader = new OrcColumnarBatchScanReader(); initReaderJava(); initRecordReaderJava(); initBatch(); @@ -50,17 +50,22 @@ public class OrcColumnarBatchJniReaderSparkORCPushDownTest extends TestCase { @After public void tearDown() throws Exception { - System.out.println("orcColumnarBatchJniReader test finished"); + System.out.println("orcColumnarBatchScanReader test finished"); } public void initReaderJava() { - 
JSONObject job = new JSONObject(); - job.put("serializedTail",""); - job.put("tailLocation",9223372036854775807L); File directory = new File("src/test/java/com/huawei/boostkit/spark/jni/orcsrc/part-00000-2d6ca713-08b0-4b40-828c-f7ee0c81bb9a-c000.snappy.orc"); - System.out.println(directory.getAbsolutePath()); - orcColumnarBatchJniReader.reader = orcColumnarBatchJniReader.initializeReader(directory.getAbsolutePath(), job); - assertTrue(orcColumnarBatchJniReader.reader != 0); + String absolutePath = directory.getAbsolutePath(); + System.out.println(absolutePath); + URI uri = null; + try { + uri = new URI(absolutePath); + } catch (URISyntaxException ignore) { + // if a URISyntaxException is thrown, the assertNotNull below will fail the test + } + assertNotNull(uri); + orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri, OrcFile.readerOptions(new Configuration())); + assertTrue(orcColumnarBatchScanReader.reader != 0); } public void initRecordReaderJava() { @@ -126,20 +131,20 @@ public class OrcColumnarBatchJniReaderSparkORCPushDownTest extends TestCase { includedColumns.add("i_current_price"); job.put("includedColumns", includedColumns.toArray()); - orcColumnarBatchJniReader.recordReader = orcColumnarBatchJniReader.initializeRecordReader(orcColumnarBatchJniReader.reader, job); - assertTrue(orcColumnarBatchJniReader.recordReader != 0); + orcColumnarBatchScanReader.recordReader = orcColumnarBatchScanReader.jniReader.initializeRecordReader(orcColumnarBatchScanReader.reader, job); + assertTrue(orcColumnarBatchScanReader.recordReader != 0); } public void initBatch() { - orcColumnarBatchJniReader.batchReader = orcColumnarBatchJniReader.initializeBatch(orcColumnarBatchJniReader.recordReader, 4096); - assertTrue(orcColumnarBatchJniReader.batchReader != 0); + orcColumnarBatchScanReader.batchReader = orcColumnarBatchScanReader.jniReader.initializeBatch(orcColumnarBatchScanReader.recordReader, 4096); + assertTrue(orcColumnarBatchScanReader.batchReader != 0); } @Test public void testNext() { int[] typeId = new int[4]; long[] vecNativeId = new long[4]; - long rtn = orcColumnarBatchJniReader.recordReaderNext(orcColumnarBatchJniReader.recordReader, orcColumnarBatchJniReader.batchReader, typeId, vecNativeId); + long rtn = orcColumnarBatchScanReader.jniReader.recordReaderNext(orcColumnarBatchScanReader.recordReader, orcColumnarBatchScanReader.batchReader, typeId, vecNativeId); assertTrue(rtn == 4096); LongVec vec1 = new LongVec(vecNativeId[0]); VarcharVec vec2 = new VarcharVec(vecNativeId[1]); diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java index 99801bcfb86567a5a2cb44dc43e4428496b00ed3..c8581f35ebc605845896b5f35731b0908779f326 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java @@ -42,7 +42,11 @@ import org.junit.Test; import org.junit.runners.MethodSorters; import org.apache.hadoop.conf.Configuration; import java.io.File; +import java.net.URI; +import java.net.URISyntaxException; import java.util.ArrayList; +import java.util.List; +import java.util.Arrays; import org.apache.orc.Reader.Options; import static org.junit.Assert.*; @@ -50,7 +54,7 @@ import static 
org.junit.Assert.*; @FixMethodOrder(value = MethodSorters.NAME_ASCENDING ) public class OrcColumnarBatchJniReaderTest extends TestCase { public Configuration conf = new Configuration(); - public OrcColumnarBatchJniReader orcColumnarBatchJniReader; + public OrcColumnarBatchScanReader orcColumnarBatchScanReader; public int batchSize = 4096; @Before @@ -77,39 +81,78 @@ public class OrcColumnarBatchJniReaderTest extends TestCase { sarg.getExpression().toString(); } - orcColumnarBatchJniReader = new OrcColumnarBatchJniReader(); + orcColumnarBatchScanReader = new OrcColumnarBatchScanReader(); initReaderJava(); + initDataColIds(options, orcColumnarBatchScanReader); initRecordReaderJava(options); initBatch(options); } + public void initDataColIds( + Options options, OrcColumnarBatchScanReader orcColumnarBatchScanReader) { + List allCols; + allCols = Arrays.asList(options.getColumnNames()); + orcColumnarBatchScanReader.colToInclu = new ArrayList(); + List optionField = options.getSchema().getFieldNames(); + orcColumnarBatchScanReader.colsToGet = new int[optionField.size()]; + orcColumnarBatchScanReader.realColsCnt = 0; + for (int i = 0; i < optionField.size(); i++) { + if (allCols.contains(optionField.get(i))) { + orcColumnarBatchScanReader.colToInclu.add(optionField.get(i)); + orcColumnarBatchScanReader.colsToGet[i] = 0; + orcColumnarBatchScanReader.realColsCnt++; + } else { + orcColumnarBatchScanReader.colsToGet[i] = -1; + } + } + + orcColumnarBatchScanReader.requiredfieldNames = new String[optionField.size()]; + TypeDescription schema = options.getSchema(); + int[] precisionArray = new int[optionField.size()]; + int[] scaleArray = new int[optionField.size()]; + for (int i = 0; i < optionField.size(); i++) { + precisionArray[i] = schema.findSubtype(optionField.get(i)).getPrecision(); + scaleArray[i] = schema.findSubtype(optionField.get(i)).getScale(); + orcColumnarBatchScanReader.requiredfieldNames[i] = optionField.get(i); + } + orcColumnarBatchScanReader.precisionArray = precisionArray; + orcColumnarBatchScanReader.scaleArray = scaleArray; + } + @After public void tearDown() throws Exception { System.out.println("orcColumnarBatchJniReader test finished"); } - public void initReaderJava() { + public void initReaderJava() throws URISyntaxException { OrcFile.ReaderOptions readerOptions = OrcFile.readerOptions(conf); File directory = new File("src/test/java/com/huawei/boostkit/spark/jni/orcsrc/000000_0"); String path = directory.getAbsolutePath(); - orcColumnarBatchJniReader.reader = orcColumnarBatchJniReader.initializeReaderJava(path, readerOptions); - assertTrue(orcColumnarBatchJniReader.reader != 0); + URI uri = null; + try { + uri = new URI(path); + } catch (URISyntaxException ignore) { + // if a URISyntaxException is thrown, the assertNotNull below will fail the test + } + assertNotNull(uri); + orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri, readerOptions); + assertTrue(orcColumnarBatchScanReader.reader != 0); } public void initRecordReaderJava(Options options) { - orcColumnarBatchJniReader.recordReader = orcColumnarBatchJniReader.initializeRecordReaderJava(options); - assertTrue(orcColumnarBatchJniReader.recordReader != 0); + orcColumnarBatchScanReader.recordReader = orcColumnarBatchScanReader.initializeRecordReaderJava(options); + assertTrue(orcColumnarBatchScanReader.recordReader != 0); } public void initBatch(Options options) { - orcColumnarBatchJniReader.initBatchJava(batchSize); - assertTrue(orcColumnarBatchJniReader.batchReader != 0); + 
orcColumnarBatchScanReader.initBatchJava(batchSize); + assertTrue(orcColumnarBatchScanReader.batchReader != 0); } @Test public void testNext() { Vec[] vecs = new Vec[2]; - long rtn = orcColumnarBatchJniReader.next(vecs); + long rtn = orcColumnarBatchScanReader.next(vecs); assertTrue(rtn == 4096); assertTrue(((LongVec) vecs[0]).get(0) == 1); String str = new String(((VarcharVec) vecs[1]).get(0)); @@ -122,7 +165,7 @@ public class OrcColumnarBatchJniReaderTest extends TestCase { public void testGetProgress() { String tmp = ""; try { - double progressValue = orcColumnarBatchJniReader.getProgress(); + double progressValue = orcColumnarBatchScanReader.getProgress(); } catch (Exception e) { tmp = e.getMessage(); } finally { diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchJniReaderTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchJniReaderTest.java index 5996413555c00cd1dedc2fe81bf50da5efe3c097..dee2cea90333695790db49cbca225ef9493840fe 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchJniReaderTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/ParquetColumnarBatchJniReaderTest.java @@ -19,7 +19,9 @@ package com.huawei.boostkit.spark.jni; import junit.framework.TestCase; -import nova.hetu.omniruntime.vector.*; +import nova.hetu.omniruntime.vector.Vec; +import org.apache.hadoop.fs.Path; +import org.apache.spark.sql.types.DataType; import org.junit.After; import org.junit.Before; import org.junit.FixMethodOrder; @@ -31,29 +33,36 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import static org.apache.spark.sql.types.DataTypes.*; + @FixMethodOrder(value = MethodSorters.NAME_ASCENDING) public class ParquetColumnarBatchJniReaderTest extends TestCase { - private ParquetColumnarBatchJniReader parquetColumnarBatchJniReader; + private ParquetColumnarBatchScanReader parquetColumnarBatchScanReader; private Vec[] vecs; + private List types; + @Before public void setUp() throws Exception { - parquetColumnarBatchJniReader = new ParquetColumnarBatchJniReader(); + parquetColumnarBatchScanReader = new ParquetColumnarBatchScanReader(); List rowGroupIndices = new ArrayList<>(); rowGroupIndices.add(0); List columnIndices = new ArrayList<>(); Collections.addAll(columnIndices, 0, 1, 3, 6, 7, 8, 9, 10, 12); - File file = new File("../cpp/test/tablescan/resources/parquet_data_all_type"); + types = new ArrayList<>(); + Collections.addAll(types, IntegerType, StringType, LongType, DoubleType, createDecimalType(9, 8), + createDecimalType(18, 5), BooleanType, ShortType, DateType); + File file = new File("../../omniop-native-reader/cpp/test/tablescan/resources/parquet_data_all_type"); String path = file.getAbsolutePath(); - parquetColumnarBatchJniReader.initializeReaderJava(path, 100000, rowGroupIndices, columnIndices, "root@sample"); + parquetColumnarBatchScanReader.initializeReaderJava(new Path(path), 100000, rowGroupIndices, columnIndices, "root@sample"); vecs = new Vec[9]; } @After public void tearDown() throws Exception { - parquetColumnarBatchJniReader.close(); + parquetColumnarBatchScanReader.close(); for (Vec vec : vecs) { vec.close(); } @@ -61,7 +70,7 @@ public class ParquetColumnarBatchJniReaderTest extends TestCase { @Test public void testRead() { - long num = parquetColumnarBatchJniReader.next(vecs); + long num = 
parquetColumnarBatchScanReader.next(vecs, types); assertTrue(num == 1); } } diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptorSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptorSuite.scala index a4131e3ef869c301b56a72b48f3d2994884241f3..ded676538759023db54d956f4c4449ae427e76b6 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptorSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptorSuite.scala @@ -18,10 +18,10 @@ package com.huawei.boostkit.spark.expression -import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor.{getExprIdMap, procCaseWhenExpression, procLikeExpression, rewriteToOmniExpressionLiteral, rewriteToOmniJsonExpressionLiteral} +import com.fasterxml.jackson.databind.{ObjectMapper, SerializationFeature} +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor.{getExprIdMap, procCaseWhenExpression, rewriteToOmniJsonExpressionLiteral} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Max, Min, Sum} import org.apache.spark.sql.types.{BooleanType, IntegerType, LongType, StringType} /** @@ -36,79 +36,6 @@ class OmniExpressionAdaptorSuite extends SparkFunSuite { AttributeReference("d", BooleanType)(), AttributeReference("e", IntegerType)(), AttributeReference("f", StringType)(), AttributeReference("g", StringType)()) - test("expression rewrite") { - checkExpressionRewrite("$operator$ADD:1(#0,#1)", Add(allAttribute(0), allAttribute(1))) - checkExpressionRewrite("$operator$ADD:1(#0,1:1)", Add(allAttribute(0), Literal(1))) - - checkExpressionRewrite("$operator$SUBTRACT:1(#0,#1)", - Subtract(allAttribute(0), allAttribute(1))) - checkExpressionRewrite("$operator$SUBTRACT:1(#0,1:1)", Subtract(allAttribute(0), Literal(1))) - - checkExpressionRewrite("$operator$MULTIPLY:1(#0,#1)", - Multiply(allAttribute(0), allAttribute(1))) - checkExpressionRewrite("$operator$MULTIPLY:1(#0,1:1)", Multiply(allAttribute(0), Literal(1))) - - checkExpressionRewrite("$operator$DIVIDE:1(#0,#1)", Divide(allAttribute(0), allAttribute(1))) - checkExpressionRewrite("$operator$DIVIDE:1(#0,1:1)", Divide(allAttribute(0), Literal(1))) - - checkExpressionRewrite("$operator$MODULUS:1(#0,#1)", - Remainder(allAttribute(0), allAttribute(1))) - checkExpressionRewrite("$operator$MODULUS:1(#0,1:1)", Remainder(allAttribute(0), Literal(1))) - - checkExpressionRewrite("$operator$GREATER_THAN:4(#0,#1)", - GreaterThan(allAttribute(0), allAttribute(1))) - checkExpressionRewrite("$operator$GREATER_THAN:4(#0,1:1)", - GreaterThan(allAttribute(0), Literal(1))) - - checkExpressionRewrite("$operator$GREATER_THAN_OR_EQUAL:4(#0,#1)", - GreaterThanOrEqual(allAttribute(0), allAttribute(1))) - checkExpressionRewrite("$operator$GREATER_THAN_OR_EQUAL:4(#0,1:1)", - GreaterThanOrEqual(allAttribute(0), Literal(1))) - - checkExpressionRewrite("$operator$LESS_THAN:4(#0,#1)", - LessThan(allAttribute(0), allAttribute(1))) - checkExpressionRewrite("$operator$LESS_THAN:4(#0,1:1)", - LessThan(allAttribute(0), Literal(1))) - - checkExpressionRewrite("$operator$LESS_THAN_OR_EQUAL:4(#0,#1)", - LessThanOrEqual(allAttribute(0), allAttribute(1))) - checkExpressionRewrite("$operator$LESS_THAN_OR_EQUAL:4(#0,1:1)", - 
LessThanOrEqual(allAttribute(0), Literal(1))) - - checkExpressionRewrite("$operator$EQUAL:4(#0,#1)", EqualTo(allAttribute(0), allAttribute(1))) - checkExpressionRewrite("$operator$EQUAL:4(#0,1:1)", EqualTo(allAttribute(0), Literal(1))) - - checkExpressionRewrite("OR:4(#2,#3)", Or(allAttribute(2), allAttribute(3))) - checkExpressionRewrite("OR:4(#2,3:1)", Or(allAttribute(2), Literal(3))) - - checkExpressionRewrite("AND:4(#2,#3)", And(allAttribute(2), allAttribute(3))) - checkExpressionRewrite("AND:4(#2,3:1)", And(allAttribute(2), Literal(3))) - - checkExpressionRewrite("not:4(#3)", Not(allAttribute(3))) - - checkExpressionRewrite("IS_NOT_NULL:4(#4)", IsNotNull(allAttribute(4))) - - checkExpressionRewrite("substr:15(#5,#0,#1)", - Substring(allAttribute(5), allAttribute(0), allAttribute(1))) - - checkExpressionRewrite("CAST:2(#1)", Cast(allAttribute(1), LongType)) - - checkExpressionRewrite("abs:1(#0)", Abs(allAttribute(0))) - - checkExpressionRewrite("SUM:2(#0)", Sum(allAttribute(0))) - - checkExpressionRewrite("MAX:1(#0)", Max(allAttribute(0))) - - checkExpressionRewrite("AVG:3(#0)", Average(allAttribute(0))) - - checkExpressionRewrite("MIN:1(#0)", Min(allAttribute(0))) - - checkExpressionRewrite("IN:4(#0,#0,#1)", - In(allAttribute(0), Seq(allAttribute(0), allAttribute(1)))) - - // checkExpressionRewrite("IN:4(#0, #0, #1)", InSet(allAttribute(0), Set(allAttribute(0), allAttribute(1)))) - } - test("json expression rewrite") { checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":1,\"operator\":\"ADD\"," + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + @@ -117,7 +44,7 @@ class OmniExpressionAdaptorSuite extends SparkFunSuite { checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":1,\"operator\":\"ADD\"," + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + - "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":1}}", + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":1}}", Add(allAttribute(0), Literal(1))) checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":1,\"operator\":\"SUBTRACT\"," + @@ -127,7 +54,7 @@ class OmniExpressionAdaptorSuite extends SparkFunSuite { checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":1,\"operator\":\"SUBTRACT\"," + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + - "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":1}}", + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":1}}", Subtract(allAttribute(0), Literal(1))) checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":1,\"operator\":\"MULTIPLY\"," + @@ -137,7 +64,7 @@ class OmniExpressionAdaptorSuite extends SparkFunSuite { checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":1,\"operator\":\"MULTIPLY\"," + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + - "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":1}}", + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":1}}", Multiply(allAttribute(0), Literal(1))) checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":1,\"operator\":\"DIVIDE\"," + @@ -147,7 +74,7 @@ class OmniExpressionAdaptorSuite extends SparkFunSuite { checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":1,\"operator\":\"DIVIDE\"," + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + - 
"\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":1}}", + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":1}}", Divide(allAttribute(0), Literal(1))) checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":1,\"operator\":\"MODULUS\"," + @@ -157,7 +84,7 @@ class OmniExpressionAdaptorSuite extends SparkFunSuite { checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":1,\"operator\":\"MODULUS\"," + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + - "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":1}}", + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":1}}", Remainder(allAttribute(0), Literal(1))) checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4," + @@ -169,7 +96,7 @@ class OmniExpressionAdaptorSuite extends SparkFunSuite { checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4," + "\"operator\":\"GREATER_THAN\"," + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + - "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":1}}", + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":1}}", GreaterThan(allAttribute(0), Literal(1))) checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4," + @@ -181,7 +108,7 @@ class OmniExpressionAdaptorSuite extends SparkFunSuite { checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4," + "\"operator\":\"GREATER_THAN_OR_EQUAL\"," + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + - "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":1}}", + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":1}}", GreaterThanOrEqual(allAttribute(0), Literal(1))) checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"LESS_THAN\"," + @@ -191,7 +118,7 @@ class OmniExpressionAdaptorSuite extends SparkFunSuite { checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"LESS_THAN\"," + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + - "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":1}}", + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":1}}", LessThan(allAttribute(0), Literal(1))) checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4," + @@ -203,7 +130,7 @@ class OmniExpressionAdaptorSuite extends SparkFunSuite { checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4," + "\"operator\":\"LESS_THAN_OR_EQUAL\"," + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + - "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":1}}", + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":1}}", LessThanOrEqual(allAttribute(0), Literal(1))) checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"EQUAL\"," + @@ -213,7 +140,7 @@ class OmniExpressionAdaptorSuite extends SparkFunSuite { checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"EQUAL\"," + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0}," + - "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":1}}", + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":1}}", EqualTo(allAttribute(0), Literal(1))) 
checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"OR\"," + @@ -223,7 +150,7 @@ class OmniExpressionAdaptorSuite extends SparkFunSuite { checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"OR\"," + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":4,\"colVal\":2}," + - "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":3}}", + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":3}}", Or(allAttribute(2), Literal(3))) checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"AND\"," + @@ -233,10 +160,10 @@ class OmniExpressionAdaptorSuite extends SparkFunSuite { checkJsonExprRewrite("{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"AND\"," + "\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":4,\"colVal\":2}," + - "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":3}}", + "\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":3}}", And(allAttribute(2), Literal(3))) - checkJsonExprRewrite("{\"exprType\":\"UNARY\",\"returnType\":4, \"operator\":\"not\"," + + checkJsonExprRewrite("{\"exprType\":\"UNARY\",\"returnType\":4,\"operator\":\"not\"," + "\"expr\":{\"exprType\":\"IS_NULL\",\"returnType\":4," + "\"arguments\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":4}]}}", IsNotNull(allAttribute(4))) @@ -250,25 +177,25 @@ class OmniExpressionAdaptorSuite extends SparkFunSuite { Abs(allAttribute(0))) checkJsonExprRewrite("{\"exprType\":\"FUNCTION\",\"returnType\":1,\"function_name\":\"round\"," + - " \"arguments\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0},{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":2}]}", + " \"arguments\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":0},{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":2}]}", Round(allAttribute(0), Literal(2))) } - protected def checkExpressionRewrite(expected: Any, expression: Expression): Unit = { - { - val runResult = rewriteToOmniExpressionLiteral(expression, getExprIdMap(allAttribute)) - if (!expected.equals(runResult)) { - fail(s"expression($expression) not match with expected value:$expected," + - s"running value:$runResult") - } - } - } - protected def checkJsonExprRewrite(expected: Any, expression: Expression): Unit = { val runResult = rewriteToOmniJsonExpressionLiteral(expression, getExprIdMap(allAttribute)) - if (!expected.equals(runResult)) { - fail(s"expression($expression) not match with expected value:$expected," + - s"running value:$runResult") + checkJsonKeyValueIgnoreKeySequence(expected.asInstanceOf[String], runResult, expression) + } + + private def checkJsonKeyValueIgnoreKeySequence(expected: String, runResult: String, expression: Expression) : Unit = { + // sort the keys in the expected and runResult JSON strings, then compare whether the two JSON strings are the same + val objectMapper = new ObjectMapper().configure(SerializationFeature.ORDER_MAP_ENTRIES_BY_KEYS, true) + val expectedJsonNode = objectMapper.readTree(expected) + val runResultJsonNode = objectMapper.readTree(runResult) + val expectedIgnoreKeySequence = objectMapper.writeValueAsString(objectMapper.treeToValue(expectedJsonNode, classOf[Object])) + val runResultIgnoreKeySequence = objectMapper.writeValueAsString(objectMapper.treeToValue(runResultJsonNode, classOf[Object])) + if (!expectedIgnoreKeySequence.equals(runResultIgnoreKeySequence)) { + fail(s"expression($expression) not match with expected 
value:$expectedIgnoreKeySequence," + + s"running value:$runResultIgnoreKeySequence") } } @@ -282,28 +209,22 @@ class OmniExpressionAdaptorSuite extends SparkFunSuite { val elseValue = Some(Not(EqualTo(cnAttribute(3), Literal("啊水水水水")))) val caseWhen = CaseWhen(branch, elseValue); val caseWhenResult = rewriteToOmniJsonExpressionLiteral(caseWhen, getExprIdMap(cnAttribute)) - val caseWhenExp = "{\"exprType\":\"IF\",\"returnType\":4,\"condition\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":0,\"width\":50},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"新\",\"width\":1}},\"if_true\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":1,\"width\":50},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"官方爸爸\",\"width\":4}},\"if_false\":{\"exprType\":\"IF\",\"returnType\":4,\"condition\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":2,\"width\":50},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"爱你三千遍\",\"width\":5}},\"if_true\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":2,\"width\":50},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"新\",\"width\":1}},\"if_false\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":3,\"width\":50},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"啊水水水水\",\"width\":5}}}}" - if (!caseWhenExp.equals(caseWhenResult)) { - fail(s"expression($caseWhen) not match with expected value:$caseWhenExp," + - s"running value:$caseWhenResult") - } + val caseWhenExp = "{\"exprType\":\"IF\",\"returnType\":4,\"condition\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":0,\"width\":50},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false,\"value\":\"新\",\"width\":1}},\"if_true\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":1,\"width\":50},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false,\"value\":\"官方爸爸\",\"width\":4}},\"if_false\":{\"exprType\":\"IF\",\"returnType\":4,\"condition\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":2,\"width\":50},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false,\"value\":\"爱你三千遍\",\"width\":5}},\"if_true\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":2,\"width\":50},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false,\"value\":\"新\",\"width\":1}},\"if_false\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":3,\"width\":50},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false,\"value\":\"啊水水水水\",\"width\":5}}}}" + + checkJsonKeyValueIgnoreKeySequence(caseWhenExp, caseWhenResult, caseWhen) val isNull = 
IsNull(cnAttribute(0)); val isNullResult = rewriteToOmniJsonExpressionLiteral(isNull, getExprIdMap(cnAttribute)) val isNullExp = "{\"exprType\":\"IS_NULL\",\"returnType\":4,\"arguments\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":0,\"width\":50}]}" - if (!isNullExp.equals(isNullResult)) { - fail(s"expression($isNull) not match with expected value:$isNullExp," + - s"running value:$isNullResult") - } + + checkJsonKeyValueIgnoreKeySequence(isNullExp, isNullResult, isNull) val children = Seq(cnAttribute(0), cnAttribute(1)) val coalesce = Coalesce(children); val coalesceResult = rewriteToOmniJsonExpressionLiteral(coalesce, getExprIdMap(cnAttribute)) - val coalesceExp = "{\"exprType\":\"COALESCE\",\"returnType\":15,\"width\":50, \"value1\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":0,\"width\":50},\"value2\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":1,\"width\":50}}" - if (!coalesceExp.equals(coalesceResult)) { - fail(s"expression($coalesce) not match with expected value:$coalesceExp," + - s"running value:$coalesceResult") - } + val coalesceExp = "{\"exprType\":\"COALESCE\",\"returnType\":15,\"width\":50,\"value1\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":0,\"width\":50},\"value2\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":1,\"width\":50}}" + + checkJsonKeyValueIgnoreKeySequence(coalesceExp, coalesceResult, coalesce) val children2 = Seq(cnAttribute(0), cnAttribute(1), cnAttribute(2)) val coalesce2 = Coalesce(children2); @@ -327,36 +248,67 @@ class OmniExpressionAdaptorSuite extends SparkFunSuite { val branch = Seq(t1, t2) val elseValue = Some(Not(EqualTo(caseWhenAttribute(3), Literal("啊水水水水")))) val expression = CaseWhen(branch, elseValue); - val runResult = procCaseWhenExpression(expression, getExprIdMap(caseWhenAttribute)) - val filterExp = "{\"exprType\":\"IF\",\"returnType\":4,\"condition\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":0,\"width\":50},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"新\",\"width\":1}},\"if_true\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":1,\"width\":50},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"官方爸爸\",\"width\":4}},\"if_false\":{\"exprType\":\"IF\",\"returnType\":4,\"condition\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":2,\"width\":50},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"爱你三千遍\",\"width\":5}},\"if_true\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":2,\"width\":50},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"新\",\"width\":1}},\"if_false\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":3,\"width\":50},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false, \"value\":\"啊水水水水\",\"width\":5}}}}" - if (!filterExp.equals(runResult)) { - fail(s"expression($expression) not match with expected value:$filterExp," + - s"running value:$runResult") - } + val runResult = procCaseWhenExpression(expression, getExprIdMap(caseWhenAttribute)).toString() 
+ val filterExp = "{\"exprType\":\"IF\",\"returnType\":4,\"condition\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":0,\"width\":50},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false,\"value\":\"新\",\"width\":1}},\"if_true\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":1,\"width\":50},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false,\"value\":\"官方爸爸\",\"width\":4}},\"if_false\":{\"exprType\":\"IF\",\"returnType\":4,\"condition\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":2,\"width\":50},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false,\"value\":\"爱你三千遍\",\"width\":5}},\"if_true\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":2,\"width\":50},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false,\"value\":\"新\",\"width\":1}},\"if_false\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":3,\"width\":50},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":15,\"isNull\":false,\"value\":\"啊水水水水\",\"width\":5}}}}" + + checkJsonKeyValueIgnoreKeySequence(filterExp, runResult, expression) val t3 = new Tuple2(Not(EqualTo(caseWhenAttribute(4), Literal(5))), Not(EqualTo(caseWhenAttribute(5), Literal(10)))) val t4 = new Tuple2(LessThan(caseWhenAttribute(4), Literal(15)), GreaterThan(caseWhenAttribute(5), Literal(20))) val branch2 = Seq(t3, t4) val elseValue2 = Some(Not(EqualTo(caseWhenAttribute(5), Literal(25)))) val numExpression = CaseWhen(branch2, elseValue2); - val numResult = procCaseWhenExpression(numExpression, getExprIdMap(caseWhenAttribute)) - val numFilterExp = "{\"exprType\":\"IF\",\"returnType\":4,\"condition\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":4},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":5}},\"if_true\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":5},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":10}},\"if_false\":{\"exprType\":\"IF\",\"returnType\":4,\"condition\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"LESS_THAN\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":4},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":15}},\"if_true\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"GREATER_THAN\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":5},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":20}},\"if_false\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":5},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":25}}}}" - if (!numFilterExp.equals(numResult)) { - fail(s"expression($numExpression) not match with expected value:$numFilterExp," + - s"running value:$numResult") - } + val numResult = procCaseWhenExpression(numExpression, 
getExprIdMap(caseWhenAttribute)).toString() + val numFilterExp = "{\"exprType\":\"IF\",\"returnType\":4,\"condition\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":4},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":5}},\"if_true\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":5},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":10}},\"if_false\":{\"exprType\":\"IF\",\"returnType\":4,\"condition\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"LESS_THAN\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":4},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":15}},\"if_true\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"GREATER_THAN\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":5},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":20}},\"if_false\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":5},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":25}}}}" + + checkJsonKeyValueIgnoreKeySequence(numFilterExp, numResult, numExpression) val t5 = new Tuple2(Not(EqualTo(caseWhenAttribute(4), Literal(5))), Not(EqualTo(caseWhenAttribute(5), Literal(10)))) val t6 = new Tuple2(LessThan(caseWhenAttribute(4), Literal(15)), GreaterThan(caseWhenAttribute(5), Literal(20))) val branch3 = Seq(t5, t6) val elseValue3 = None val noneExpression = CaseWhen(branch3, elseValue3); - val noneResult = procCaseWhenExpression(noneExpression, getExprIdMap(caseWhenAttribute)) - val noneFilterExp = "{\"exprType\":\"IF\",\"returnType\":4,\"condition\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":4},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":5}},\"if_true\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":5},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":10}},\"if_false\":{\"exprType\":\"IF\",\"returnType\":4,\"condition\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"LESS_THAN\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":4},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":15}},\"if_true\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"GREATER_THAN\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":5},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1, \"isNull\":false, \"value\":20}},\"if_false\":{\"exprType\":\"LITERAL\",\"dataType\":4,\"isNull\":true}}}" - if (!noneFilterExp.equals(noneResult)) { - fail(s"expression($noneExpression) not match with expected value:$noneFilterExp," + - s"running value:$noneResult") - } + val noneResult = procCaseWhenExpression(noneExpression, getExprIdMap(caseWhenAttribute)).toString() + val noneFilterExp = 
"{\"exprType\":\"IF\",\"returnType\":4,\"condition\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":4},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":5}},\"if_true\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"NOT_EQUAL\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":5},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":10}},\"if_false\":{\"exprType\":\"IF\",\"returnType\":4,\"condition\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"LESS_THAN\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":4},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":15}},\"if_true\":{\"exprType\":\"BINARY\",\"returnType\":4,\"operator\":\"GREATER_THAN\",\"left\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":1,\"colVal\":5},\"right\":{\"exprType\":\"LITERAL\",\"dataType\":1,\"isNull\":false,\"value\":20}},\"if_false\":{\"exprType\":\"LITERAL\",\"dataType\":4,\"isNull\":true}}}" + + checkJsonKeyValueIgnoreKeySequence(noneFilterExp, noneResult, noneExpression) + + val t7 = Tuple2(Not(EqualTo(caseWhenAttribute(0), Literal("\"\\\\t/\\b\\f\\n\\r\\t123"))), Not(EqualTo(caseWhenAttribute(1), Literal("\"\\\\t/\\b\\f\\n\\r\\t234")))) + val t8 = Tuple2(Not(EqualTo(caseWhenAttribute(2), Literal("\"\\\\t/\\b\\f\\n\\r\\t345"))), Not(EqualTo(caseWhenAttribute(2), Literal("\"\\\\t/\\b\\f\\n\\r\\t123")))) + val branch4 = Seq(t7, t8) + val elseValue4 = Some(Not(EqualTo(caseWhenAttribute(3), Literal("\"\\\\t/\\b\\f\\n\\r\\t456")))) + val specialCharacterExpression = CaseWhen(branch4, elseValue4); + val specialCharacterRunResult = procCaseWhenExpression(specialCharacterExpression, getExprIdMap(caseWhenAttribute)).toString() + val specialCharacterFilterExp = "{\"condition\":{\"exprType\":\"BINARY\",\"left\":{\"colVal\":0,\"dataType\":15,\"exprType\":\"FIELD_REFERENCE\",\"width\":50},\"operator\":\"NOT_EQUAL\",\"returnType\":4,\"right\":{\"dataType\":15,\"exprType\":\"LITERAL\",\"isNull\":false,\"value\":\"\\\"\\\\\\\\t/\\\\b\\\\f\\\\n\\\\r\\\\t123\",\"width\":18}},\"exprType\":\"IF\",\"if_false\":{\"condition\":{\"exprType\":\"BINARY\",\"left\":{\"colVal\":2,\"dataType\":15,\"exprType\":\"FIELD_REFERENCE\",\"width\":50},\"operator\":\"NOT_EQUAL\",\"returnType\":4,\"right\":{\"dataType\":15,\"exprType\":\"LITERAL\",\"isNull\":false,\"value\":\"\\\"\\\\\\\\t/\\\\b\\\\f\\\\n\\\\r\\\\t345\",\"width\":18}},\"exprType\":\"IF\",\"if_false\":{\"exprType\":\"BINARY\",\"left\":{\"colVal\":3,\"dataType\":15,\"exprType\":\"FIELD_REFERENCE\",\"width\":50},\"operator\":\"NOT_EQUAL\",\"returnType\":4,\"right\":{\"dataType\":15,\"exprType\":\"LITERAL\",\"isNull\":false,\"value\":\"\\\"\\\\\\\\t/\\\\b\\\\f\\\\n\\\\r\\\\t456\",\"width\":18}},\"if_true\":{\"exprType\":\"BINARY\",\"left\":{\"colVal\":2,\"dataType\":15,\"exprType\":\"FIELD_REFERENCE\",\"width\":50},\"operator\":\"NOT_EQUAL\",\"returnType\":4,\"right\":{\"dataType\":15,\"exprType\":\"LITERAL\",\"isNull\":false,\"value\":\"\\\"\\\\\\\\t/\\\\b\\\\f\\\\n\\\\r\\\\t123\",\"width\":18}},\"returnType\":4},\"if_true\":{\"exprType\":\"BINARY\",\"left\":{\"colVal\":1,\"dataType\":15,\"exprType\":\"FIELD_REFERENCE\",\"width\":50},\"operator\":\"NOT_EQUAL\",\"returnType\":4,\"right\":{\"dataType\":15,\"exprType\":\"LITERAL\",\"isNull\":false,\"value\":\"\\\"\\\\\\\\t/\\\\b\\\\f\\\\n\\\\r\\\\t234\",\"width\":18}},\"returnType\":4} " + + 
checkJsonKeyValueIgnoreKeySequence(specialCharacterFilterExp, specialCharacterRunResult, specialCharacterExpression) + + } + + test("test special character rewrite") { + val specialCharacterAttribute = Seq(AttributeReference("char_1", StringType)(), AttributeReference("char_20", StringType)(), + AttributeReference("varchar_1", StringType)(), AttributeReference("varchar_20", StringType)()) + + val t1 = new Tuple2(Not(EqualTo(specialCharacterAttribute(0), Literal("\"\\\\t/\\b\\f\\n\\r\\t123"))), Not(EqualTo(specialCharacterAttribute(1), Literal("\"\\\\t/\\b\\f\\n\\r\\t234")))) + val t2 = new Tuple2(Not(EqualTo(specialCharacterAttribute(2), Literal("\"\\\\t/\\b\\f\\n\\r\\t345"))), Not(EqualTo(specialCharacterAttribute(2), Literal("\"\\\\t/\\b\\f\\n\\r\\t456")))) + val branch = Seq(t1, t2) + val elseValue = Some(Not(EqualTo(specialCharacterAttribute(3), Literal("\"\\\\t/\\b\\f\\n\\r\\t456")))) + val caseWhen = CaseWhen(branch, elseValue); + val caseWhenResult = rewriteToOmniJsonExpressionLiteral(caseWhen, getExprIdMap(specialCharacterAttribute)) + val caseWhenExp = "{\"condition\":{\"exprType\":\"BINARY\",\"left\":{\"colVal\":0,\"dataType\":15,\"exprType\":\"FIELD_REFERENCE\",\"width\":50},\"operator\":\"NOT_EQUAL\",\"returnType\":4,\"right\":{\"dataType\":15,\"exprType\":\"LITERAL\",\"isNull\":false,\"value\":\"\\\"\\\\\\\\t/\\\\b\\\\f\\\\n\\\\r\\\\t123\",\"width\":18}},\"exprType\":\"IF\",\"if_false\":{\"condition\":{\"exprType\":\"BINARY\",\"left\":{\"colVal\":2,\"dataType\":15,\"exprType\":\"FIELD_REFERENCE\",\"width\":50},\"operator\":\"NOT_EQUAL\",\"returnType\":4,\"right\":{\"dataType\":15,\"exprType\":\"LITERAL\",\"isNull\":false,\"value\":\"\\\"\\\\\\\\t/\\\\b\\\\f\\\\n\\\\r\\\\t345\",\"width\":18}},\"exprType\":\"IF\",\"if_false\":{\"exprType\":\"BINARY\",\"left\":{\"colVal\":3,\"dataType\":15,\"exprType\":\"FIELD_REFERENCE\",\"width\":50},\"operator\":\"NOT_EQUAL\",\"returnType\":4,\"right\":{\"dataType\":15,\"exprType\":\"LITERAL\",\"isNull\":false,\"value\":\"\\\"\\\\\\\\t/\\\\b\\\\f\\\\n\\\\r\\\\t456\",\"width\":18}},\"if_true\":{\"exprType\":\"BINARY\",\"left\":{\"colVal\":2,\"dataType\":15,\"exprType\":\"FIELD_REFERENCE\",\"width\":50},\"operator\":\"NOT_EQUAL\",\"returnType\":4,\"right\":{\"dataType\":15,\"exprType\":\"LITERAL\",\"isNull\":false,\"value\":\"\\\"\\\\\\\\t/\\\\b\\\\f\\\\n\\\\r\\\\t456\",\"width\":18}},\"returnType\":4},\"if_true\":{\"exprType\":\"BINARY\",\"left\":{\"colVal\":1,\"dataType\":15,\"exprType\":\"FIELD_REFERENCE\",\"width\":50},\"operator\":\"NOT_EQUAL\",\"returnType\":4,\"right\":{\"dataType\":15,\"exprType\":\"LITERAL\",\"isNull\":false,\"value\":\"\\\"\\\\\\\\t/\\\\b\\\\f\\\\n\\\\r\\\\t234\",\"width\":18}},\"returnType\":4}" + checkJsonKeyValueIgnoreKeySequence(caseWhenExp, caseWhenResult, caseWhen) + + val isNull = IsNull(specialCharacterAttribute(0)); + val isNullResult = rewriteToOmniJsonExpressionLiteral(isNull, getExprIdMap(specialCharacterAttribute)) + val isNullExp = "{\"exprType\":\"IS_NULL\",\"returnType\":4,\"arguments\":[{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":0,\"width\":50}]}" + + checkJsonKeyValueIgnoreKeySequence(isNullExp, isNullResult, isNull) + + val children = Seq(specialCharacterAttribute(0), specialCharacterAttribute(1)) + val coalesce = Coalesce(children); + val coalesceResult = rewriteToOmniJsonExpressionLiteral(coalesce, getExprIdMap(specialCharacterAttribute)) + val coalesceExp = 
"{\"exprType\":\"COALESCE\",\"returnType\":15,\"width\":50,\"value1\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":0,\"width\":50},\"value2\":{\"exprType\":\"FIELD_REFERENCE\",\"dataType\":15,\"colVal\":1,\"width\":50}}" + checkJsonKeyValueIgnoreKeySequence(coalesceExp, coalesceResult, coalesce) } diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuiteBase.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuiteBase.scala new file mode 100644 index 0000000000000000000000000000000000000000..a9d668628eb1e05f3747aedc19ce2c6501dc5e1d --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuiteBase.scala @@ -0,0 +1,371 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.scalatest.GivenWhenThen + +import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Expression} +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.adaptive._ +import org.apache.spark.sql.execution.datasources.v2.BatchScanExec +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ReusedExchangeExec} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SQLTestUtils + +/** + * Test suite for the filtering ratio policy used to trigger dynamic partition pruning (DPP). 
+ */ +class DynamicPartitionPruningSuite + extends ColumnarSparkPlanTest + with SQLTestUtils + with GivenWhenThen + with AdaptiveSparkPlanHelper { + + val tableFormat: String = "parquet" + + import testImplicits._ + + protected def initState(): Unit = {} + protected def runAnalyzeColumnCommands: Boolean = true + + override protected def beforeAll(): Unit = { + super.beforeAll() + + initState() + + val factData = Seq[(Int, Int, Int, Int)]( + (1000, 1, 1, 10), + (1010, 2, 1, 10), + (1020, 2, 1, 10), + (1030, 3, 2, 10), + (1040, 3, 2, 50), + (1050, 3, 2, 50), + (1060, 3, 2, 50), + (1070, 4, 2, 10), + (1080, 4, 3, 20), + (1090, 4, 3, 10), + (1100, 4, 3, 10), + (1110, 5, 3, 10), + (1120, 6, 4, 10), + (1130, 7, 4, 50), + (1140, 8, 4, 50), + (1150, 9, 1, 20), + (1160, 10, 1, 20), + (1170, 11, 1, 30), + (1180, 12, 2, 20), + (1190, 13, 2, 20), + (1200, 14, 3, 40), + (1200, 15, 3, 70), + (1210, 16, 4, 10), + (1220, 17, 4, 20), + (1230, 18, 4, 20), + (1240, 19, 5, 40), + (1250, 20, 5, 40), + (1260, 21, 5, 40), + (1270, 22, 5, 50), + (1280, 23, 1, 50), + (1290, 24, 1, 50), + (1300, 25, 1, 50) + ) + + val storeData = Seq[(Int, String, String)]( + (1, "North-Holland", "NL"), + (2, "South-Holland", "NL"), + (3, "Bavaria", "DE"), + (4, "California", "US"), + (5, "Texas", "US"), + (6, "Texas", "US") + ) + + val storeCode = Seq[(Int, Int)]( + (1, 10), + (2, 20), + (3, 30), + (4, 40), + (5, 50), + (6, 60) + ) + + if (tableFormat == "hive") { + spark.sql("set hive.exec.dynamic.partition.mode=nonstrict") + } + + spark.range(1000) + .select($"id" as "product_id", ($"id" % 10) as "store_id", ($"id" + 1) as "code") + .write + .format(tableFormat) + .mode("overwrite") + .saveAsTable("product") + + factData.toDF("date_id", "store_id", "product_id", "units_sold") + .write + .format(tableFormat) + .saveAsTable("fact_np") + + factData.toDF("date_id", "store_id", "product_id", "units_sold") + .write + .partitionBy("store_id") + .format(tableFormat) + .saveAsTable("fact_sk") + + factData.toDF("date_id", "store_id", "product_id", "units_sold") + .write + .partitionBy("store_id") + .format(tableFormat) + .saveAsTable("fact_stats") + + storeData.toDF("store_id", "state_province", "country") + .write + .format(tableFormat) + .saveAsTable("dim_store") + + storeData.toDF("store_id", "state_province", "country") + .write + .format(tableFormat) + .saveAsTable("dim_stats") + + storeCode.toDF("store_id", "code") + .write + .partitionBy("store_id") + .format(tableFormat) + .saveAsTable("code_stats") + + if (runAnalyzeColumnCommands) { + sql("ANALYZE TABLE fact_stats COMPUTE STATISTICS FOR COLUMNS store_id") + sql("ANALYZE TABLE dim_stats COMPUTE STATISTICS FOR COLUMNS store_id") + sql("ANALYZE TABLE dim_store COMPUTE STATISTICS FOR COLUMNS store_id") + sql("ANALYZE TABLE code_stats COMPUTE STATISTICS FOR COLUMNS store_id") + } + } + + override protected def afterAll(): Unit = { + try { + sql("DROP TABLE IF EXISTS fact_np") + sql("DROP TABLE IF EXISTS fact_sk") + sql("DROP TABLE IF EXISTS product") + sql("DROP TABLE IF EXISTS dim_store") + sql("DROP TABLE IF EXISTS fact_stats") + sql("DROP TABLE IF EXISTS dim_stats") + sql("DROP TABLE IF EXISTS code_stats") + } finally { + spark.sessionState.conf.unsetConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED) + spark.sessionState.conf.unsetConf(SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY) + super.afterAll() + } + } + + /** + * Check if the query plan has a partition pruning filter inserted as + * a subquery duplicate or as a custom broadcast exchange. 
+ */ + def checkPartitionPruningPredicate( + df: DataFrame, + withSubquery: Boolean, + withBroadcast: Boolean): Unit = { + df.collect() + + val plan = df.queryExecution.executedPlan + val dpExprs = collectDynamicPruningExpressions(plan) + val hasSubquery = dpExprs.exists { + case InSubqueryExec(_, _: SubqueryExec, _, _, _, _) => true + case _ => false + } + val subqueryBroadcast = dpExprs.collect { + case InSubqueryExec(_, b: SubqueryBroadcastExec, _, _, _, _) => b + } + + val hasFilter = if (withSubquery) "Should" else "Shouldn't" + assert(hasSubquery == withSubquery, + s"$hasFilter trigger DPP with a subquery duplicate:\n${df.queryExecution}") + val hasBroadcast = if (withBroadcast) "Should" else "Shouldn't" + assert(subqueryBroadcast.nonEmpty == withBroadcast, + s"$hasBroadcast trigger DPP with a reused broadcast exchange:\n${df.queryExecution}") + + subqueryBroadcast.foreach { s => + s.child match { + case _: ReusedExchangeExec => // reuse check ok. + case BroadcastQueryStageExec(_, _: ReusedExchangeExec, _) => // reuse check ok. + case b: BroadcastExchangeLike => + val hasReuse = plan.exists { + case ReusedExchangeExec(_, e) => e eq b + case _ => false + } + assert(hasReuse, s"$s\nshould have been reused in\n$plan") + case a: AdaptiveSparkPlanExec => + val broadcastQueryStage = collectFirst(a) { + case b: BroadcastQueryStageExec => b + } + val broadcastPlan = broadcastQueryStage.get.broadcast + val hasReuse = find(plan) { + case ReusedExchangeExec(_, e) => e eq broadcastPlan + case b: BroadcastExchangeLike => b eq broadcastPlan + case _ => false + }.isDefined + assert(hasReuse, s"$s\nshould have been reused in\n$plan") + case _ => + fail(s"Invalid child node found in\n$s") + } + } + + val isMainQueryAdaptive = plan.isInstanceOf[AdaptiveSparkPlanExec] + subqueriesAll(plan).filterNot(subqueryBroadcast.contains).foreach { s => + val subquery = s match { + case r: ReusedSubqueryExec => r.child + case o => o + } + assert(subquery.exists(_.isInstanceOf[AdaptiveSparkPlanExec]) == isMainQueryAdaptive) + } + } + + /** + * Check if the plan has the given number of distinct broadcast exchange subqueries. + */ + def checkDistinctSubqueries(df: DataFrame, n: Int): Unit = { + df.collect() + + val buf = collectDynamicPruningExpressions(df.queryExecution.executedPlan).collect { + case InSubqueryExec(_, b: SubqueryBroadcastExec, _, _, _, _) => + b.index + } + assert(buf.distinct.size == n) + } + + /** + * Collect the children of all correctly pushed down dynamic pruning expressions in a spark plan. + */ + protected def collectDynamicPruningExpressions(plan: SparkPlan): Seq[Expression] = { + flatMap(plan) { + case s: ColumnarFileSourceScanExec => s.partitionFilters.collect { + case d: DynamicPruningExpression => d.child + } + case s: BatchScanExec => s.runtimeFilters.collect { + case d: DynamicPruningExpression => d.child + } + case _ => Nil + } + } + + /** + * Check if the plan contains unpushed dynamic pruning filters. 
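+ *
+ * "Unpushed" here means a DynamicPruningExpression that still sits in a FilterExec
+ * condition instead of a scan's partitionFilters/runtimeFilters; as a sketch (the
+ * surrounding predicate is illustrative):
+ * {{{
+ *   FilterExec(somePredicate AND DynamicPruningExpression(...), child)
+ * }}}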
+ */ + def checkUnpushedFilters(df: DataFrame): Boolean = { + find(df.queryExecution.executedPlan) { + case FilterExec(condition, _) => + splitConjunctivePredicates(condition).exists { + case _: DynamicPruningExpression => true + case _ => false + } + case _ => false + }.isDefined + } + + test("broadcast a single key in a HashedRelation") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { + withTable("fact", "dim") { + spark.range(100).select( + $"id", + ($"id" + 1).cast("int").as("one"), + ($"id" + 2).cast("byte").as("two"), + ($"id" + 3).cast("short").as("three"), + (($"id" * 20) % 100).as("mod"), + ($"id" + 1).cast("string").as("str")) + .write.partitionBy("one", "two", "three", "str") + .format(tableFormat).mode("overwrite").saveAsTable("fact") + + spark.range(10).select( + $"id", + ($"id" + 1).cast("int").as("one"), + ($"id" + 2).cast("byte").as("two"), + ($"id" + 3).cast("short").as("three"), + ($"id" * 10).as("prod"), + ($"id" + 1).cast("string").as("str")) + .write.format(tableFormat).mode("overwrite").saveAsTable("dim") + + // broadcast a single Long key + val dfLong = sql( + """ + |SELECT f.id, f.one, f.two, f.str FROM fact f + |JOIN dim d + |ON (f.one = d.one) + |WHERE d.prod > 80 + """.stripMargin) + + checkAnswer(dfLong, Row(9, 10, 11, "10") :: Nil) + + // reuse a single Byte key + val dfByte = sql( + """ + |SELECT f.id, f.one, f.two, f.str FROM fact f + |JOIN dim d + |ON (f.two = d.two) + |WHERE d.prod > 80 + """.stripMargin) + + checkAnswer(dfByte, Row(9, 10, 11, "10") :: Nil) + + // reuse a single String key + val dfStr = sql( + """ + |SELECT f.id, f.one, f.two, f.str FROM fact f + |JOIN dim d + |ON (f.str = d.str) + |WHERE d.prod > 80 + """.stripMargin) + + checkAnswer(dfStr, Row(9, 10, 11, "10") :: Nil) + } + } + } + + test("broadcast multiple keys in a LongHashedRelation") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { + withTable("fact", "dim") { + spark.range(100).select( + $"id", + ($"id" + 1).cast("int").as("one"), + ($"id" + 2).cast("byte").as("two"), + ($"id" + 3).cast("short").as("three"), + (($"id" * 20) % 100).as("mod"), + ($"id" % 10).cast("string").as("str")) + .write.partitionBy("one", "two", "three") + .format(tableFormat).mode("overwrite").saveAsTable("fact") + + spark.range(10).select( + $"id", + ($"id" + 1).cast("int").as("one"), + ($"id" + 2).cast("byte").as("two"), + ($"id" + 3).cast("short").as("three"), + ($"id" * 10).as("prod")) + .write.format(tableFormat).mode("overwrite").saveAsTable("dim") + + // broadcast multiple keys + val dfLong = sql( + """ + |SELECT f.id, f.one, f.two, f.str FROM fact f + |JOIN dim d + |ON (f.one = d.one and f.two = d.two and f.three = d.three) + |WHERE d.prod > 80 + """.stripMargin) + + checkAnswer(dfLong, Row(9, 10, 11, "9") :: Nil) + } + } + } +} + + diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderPlanTestBase.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderPlanTestBase.scala new file mode 100644 index 0000000000000000000000000000000000000000..d8d7d0bd97807cb6bdcaad024e54232687b43937 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderPlanTestBase.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.util.sideBySide + +trait HeuristicJoinReorderPlanTestBase extends PlanTest { + + def outputsOf(plans: LogicalPlan*): Seq[Attribute] = { + plans.map(_.output).reduce(_ ++ _) + } + + def assertEqualJoinPlans( + optimizer: RuleExecutor[LogicalPlan], + originalPlan: LogicalPlan, + groundTruthBestPlan: LogicalPlan): Unit = { + val analyzed = originalPlan.analyze + val optimized = optimizer.execute(analyzed) + val expected = EliminateResolvedHint.apply(groundTruthBestPlan.analyze) + + assert(equivalentOutput(analyzed, expected)) + assert(equivalentOutput(analyzed, optimized)) + + compareJoinOrder(optimized, expected) + } + + protected def equivalentOutput(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = { + normalizeExprIds(plan1).output == normalizeExprIds(plan2).output + } + + protected def compareJoinOrder(plan1: LogicalPlan, plan2: LogicalPlan): Unit = { + val normalized1 = normalizePlan(normalizeExprIds(plan1)) + val normalized2 = normalizePlan(normalizeExprIds(plan2)) + if (!sameJoinPlan(normalized1, normalized2)) { + fail( + s""" + |== FAIL: Plans do not match === + |${sideBySide( + rewriteNameFromAttrNullability(normalized1).treeString, + rewriteNameFromAttrNullability(normalized2).treeString).mkString("\n")} + """.stripMargin) + } + } + + private def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = { + (plan1, plan2) match { + case (j1: Join, j2: Join) => + (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right) + && j1.hint.leftHint == j2.hint.leftHint && j1.hint.rightHint == j2.hint.rightHint) || + (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left) + && j1.hint.leftHint == j2.hint.rightHint && j1.hint.rightHint == j2.hint.leftHint) + case (p1: Project, p2: Project) => + p1.projectList == p2.projectList && sameJoinPlan(p1.child, p2.child) + case _ => + plan1 == plan2 + } + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..c7ea9bd95ad953b136ff9f5f196a0e0bace24027 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/HeuristicJoinReorderSuite.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.logical.ColumnStat +import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan} + +class HeuristicJoinReorderSuite + extends HeuristicJoinReorderPlanTestBase with StatsEstimationTestBase { + + private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( + attr("t1.k-1-2") -> rangeColumnStat(2, 0), + attr("t1.v-1-10") -> rangeColumnStat(10, 0), + attr("t2.k-1-5") -> rangeColumnStat(5, 0), + attr("t3.v-1-100") -> rangeColumnStat(100, 0), + attr("t4.k-1-2") -> rangeColumnStat(2, 0), + attr("t4.v-1-10") -> rangeColumnStat(10, 0), + attr("t5.k-1-5") -> rangeColumnStat(5, 0), + attr("t5.v-1-5") -> rangeColumnStat(5, 0) + )) + + private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) + private val nameToColInfo: Map[String, (Attribute, ColumnStat)] = + columnInfo.map(kv => kv._1.name -> kv) + + private val t1 = StatsTestPlan( + outputList = Seq("t1.k-1-2", "t1.v-1-10").map(nameToAttr), + rowCount = 1000, + size = Some(1000 * (8 + 4 + 4)), + attributeStats = AttributeMap(Seq("t1.k-1-2", "t1.v-1-10").map(nameToColInfo))) + + private val t2 = StatsTestPlan( + outputList = Seq("t2.k-1-5").map(nameToAttr), + rowCount = 20, + size = Some(20 * (8 + 4)), + attributeStats = AttributeMap(Seq("t2.k-1-5").map(nameToColInfo))) + + private val t3 = StatsTestPlan( + outputList = Seq("t3.v-1-100").map(nameToAttr), + rowCount = 100, + size = Some(100 * (8 + 4)), + attributeStats = AttributeMap(Seq("t3.v-1-100").map(nameToColInfo))) + + test("reorder 3 tables") { + val originalPlan = + t1.join(t2).join(t3) + .where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) && + (nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + + val analyzed = originalPlan.analyze + val optimized = HeuristicJoinReorder.apply(analyzed).select(outputsOf(t1, t2, t3): _*) + val expected = + t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .select(outputsOf(t1, t2, t3): _*) + + assert(equivalentOutput(analyzed, expected)) + assert(equivalentOutput(analyzed, optimized)) + + compareJoinOrder(optimized, expected) + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFiltersSuite.scala 
b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFiltersSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..aaa244cdf65c04699ba2bf4c5c443e8727faf9f5 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeSubqueryFiltersSuite.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import com.huawei.boostkit.spark.ColumnarPluginConfig + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{Attribute, CreateNamedStruct, GetStructField, Literal, ScalarSubquery} +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ + +class MergeSubqueryFiltersSuite extends PlanTest { + + override def beforeEach(): Unit = { + CTERelationDef.curId.set(0) + } + + private object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("MergeSubqueryFilters", Once, MergeSubqueryFilters) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.string) + + private def definitionNode(plan: LogicalPlan, cteIndex: Int) = { + CTERelationDef(plan, cteIndex, underSubquery = true) + } + + private def extractorExpression(cteIndex: Int, output: Seq[Attribute], fieldIndex: Int) = { + GetStructField(ScalarSubquery(CTERelationRef(cteIndex, _resolved = true, output)), fieldIndex) + .as("scalarsubquery()") + } + + test("Merging subqueries with different filters") { + val subquery1 = ScalarSubquery(testRelation.where('b > 0).groupBy()(max('a).as("max_a"))) + val subquery2 = ScalarSubquery(testRelation.where('b < 0).groupBy()(sum('a).as("sum_a"))) + val subquery3 = ScalarSubquery(testRelation.where('b === 0).groupBy()(avg('a).as("avg_a"))) + val originalQuery = testRelation + .select( + subquery1, + subquery2, + subquery3) + + val correctAnswer = if (ColumnarPluginConfig.getConf.filterMergeEnable) { + val mergedSubquery = testRelation + .where('b > 0 || 'b < 0 || 'b === 0) + .groupBy()( + max('a, Some('b > 0)).as("max_a"), + sum('a, Some('b < 0)).as("sum_a"), + avg('a, Some('b === 0)).as("avg_a")) + .select(CreateNamedStruct(Seq( + Literal("max_a"), 'max_a, + Literal("sum_a"), 'sum_a, + Literal("avg_a"), 'avg_a + )).as("mergedValue")) + val analyzedMergedSubquery = mergedSubquery.analyze + WithCTE( + testRelation + .select( + extractorExpression(0, analyzedMergedSubquery.output, 0), + extractorExpression(0, analyzedMergedSubquery.output, 1), + extractorExpression(0, analyzedMergedSubquery.output, 2)), + Seq(definitionNode(analyzedMergedSubquery, 0))) 
+ } else { + originalQuery + } + + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } + + test("Merging subqueries with same condition in filter and in having") { + val subquery1 = ScalarSubquery(testRelation.where('b > 0).groupBy()(max('a).as("max_a"))) + val subquery2 = ScalarSubquery(testRelation.groupBy()(max('a, Some('b > 0)).as("max_a_2"))) + val originalQuery = testRelation + .select( + subquery1, + subquery2) + + val correctAnswer = if (ColumnarPluginConfig.getConf.filterMergeEnable) { + val mergedSubquery = testRelation + .groupBy()( + max('a, Some('b > 0)).as("max_a")) + .select(CreateNamedStruct(Seq( + Literal("max_a"), 'max_a)).as("mergedValue")) + val analyzedMergedSubquery = mergedSubquery.analyze + + WithCTE(testRelation.select( + extractorExpression(0, analyzedMergedSubquery.output, 0), + extractorExpression(0, analyzedMergedSubquery.output, 0)), + Seq(definitionNode(analyzedMergedSubquery, 0))) + } else { + originalQuery + } + + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } + + test("Merging subqueries with different filters, multiple filters propagated") { + val subquery1 = + ScalarSubquery(testRelation.where('b > 0).where('c === "a").groupBy()(max('a).as("max_a"))) + val subquery2 = + ScalarSubquery(testRelation.where('b > 0).where('c === "b").groupBy()(avg('a).as("avg_a"))) + val subquery3 = ScalarSubquery( + testRelation.where('b < 0).where('c === "c").groupBy()(count('a).as("cnt_a"))) + val originalQuery = testRelation + .select( + subquery1, + subquery2, + subquery3) + + val correctAnswer = if (ColumnarPluginConfig.getConf.filterMergeEnable) { + val mergedSubquery = testRelation + .where('b > 0 || 'b < 0) + .where('b > 0 && ('c === "a" || 'c === "b") || 'b < 0 && 'c === "c") + .groupBy()( + max('a, Some('b > 0 && 'c === "a")).as("max_a"), + avg('a, Some('b > 0 && 'c === "b")).as("avg_a"), + count('a, Some('b < 0 && 'c === "c")).as("cnt_a")) + .select(CreateNamedStruct(Seq( + Literal("max_a"), 'max_a, + Literal("avg_a"), 'avg_a, + Literal("cnt_a"), 'cnt_a + )).as("mergedValue")) + val analyzedMergedSubquery = mergedSubquery.analyze + + WithCTE(testRelation.select( + extractorExpression(0, analyzedMergedSubquery.output, 0), + extractorExpression(0, analyzedMergedSubquery.output, 1), + extractorExpression(0, analyzedMergedSubquery.output, 2)), + Seq(definitionNode(analyzedMergedSubquery, 0))) + } else { + originalQuery + } + + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarCoalesceExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarCoalesceExecSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..9a2c51f043ada20edbceabca37a9b04b90f6bb6f --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarCoalesceExecSuite.scala @@ -0,0 +1,166 @@ +/* + * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.{DataFrame, Row} + +class ColumnarCoalesceExecSuite extends ColumnarSparkPlanTest { + + import testImplicits.{localSeqToDatasetHolder, newProductEncoder} + + private var dealerDf: DataFrame = _ + private var dealerExpect: Seq[Row] = _ + private var floatDealerDf: DataFrame = _ + private var floatDealerExpect: Seq[Row] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + // for normal case + dealerDf = Seq[(Int, String, String, Int)]( + (100, "Fremont", "Honda Civic", 10), + (100, "Fremont", "Honda Accord", 15), + (100, "Fremont", "Honda CRV", 7), + (200, "Dublin", "Honda Civic", 20), + (200, "Dublin", "Honda Accord", 10), + (200, "Dublin", "Honda CRV", 3), + (300, "San Jose", "Honda Civic", 5), + (300, "San Jose", "Honda Accord", 8), + ).toDF("id", "city", "car_model", "quantity") + dealerDf.createOrReplaceTempView("dealer") + + dealerExpect = Seq( + Row(100, "Fremont", 10), + Row(100, "Fremont", 15), + Row(100, "Fremont", 7), + Row(200, "Dublin", 20), + Row(200, "Dublin", 10), + Row(200, "Dublin", 3), + Row(300, "San Jose", 5), + Row(300, "San Jose", 8), + ) + + // for rollback case + floatDealerDf = Seq[(Int, String, String, Float)]( + (100, "Fremont", "Honda Civic", 10.00F), + (100, "Fremont", "Honda Accord", 15.00F), + (100, "Fremont", "Honda CRV", 7.00F), + (200, "Dublin", "Honda Civic", 20.00F), + (200, "Dublin", "Honda Accord", 10.00F), + (200, "Dublin", "Honda CRV", 3.00F), + (300, "San Jose", "Honda Civic", 5.00F), + (300, "San Jose", "Honda Accord", 8.00F), + ).toDF("id", "city", "car_model", "quantity") + floatDealerDf.createOrReplaceTempView("float_dealer") + + floatDealerExpect = Seq( + Row(100, "Fremont", 10.00F), + Row(100, "Fremont", 15.00F), + Row(100, "Fremont", 7.00F), + Row(200, "Dublin", 20.00F), + Row(200, "Dublin", 10.00F), + Row(200, "Dublin", 3.00F), + Row(300, "San Jose", 5.00F), + Row(300, "San Jose", 8.00F), + ) + } + + test("use ColumnarCoalesceExec with normal input") { + val result = spark.sql("SELECT /*+ COALESCE(3) */ id, city, quantity FROM dealer") + val plan = result.queryExecution.executedPlan + assert(plan.find(_.isInstanceOf[ColumnarCoalesceExec]).isDefined) + assert(plan.find(_.isInstanceOf[CoalesceExec]).isEmpty) + checkAnswer(result, dealerExpect) + } + + test("use ColumnarCoalesceExec with normal input and not enable ColumnarExpandExec") { + // default is true + spark.conf.set("spark.omni.sql.columnar.coalesce", false) + val result = spark.sql("SELECT /*+ COALESCE(3) */ id, city, quantity FROM dealer") + val plan = result.queryExecution.executedPlan + assert(plan.find(_.isInstanceOf[ColumnarCoalesceExec]).isEmpty) + assert(plan.find(_.isInstanceOf[CoalesceExec]).isDefined) + spark.conf.set("spark.omni.sql.columnar.coalesce", true) + checkAnswer(result, dealerExpect) + } + + test("use ColumnarCoalesceExec with input not support and rollback") { + val result = spark.sql("SELECT /*+ COALESCE(3) */ id, city, quantity FROM float_dealer") + val plan = result.queryExecution.executedPlan + assert(plan.find(_.isInstanceOf[ColumnarCoalesceExec]).isEmpty) 
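+ // float_dealer carries a Float quantity column (the "rollback case" prepared in beforeAll),
+ // so the columnar coalesce is expected to roll back to the row-based CoalesceExec checked next.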
+ assert(plan.find(_.isInstanceOf[CoalesceExec]).isDefined) + checkAnswer(result, floatDealerExpect) + } + + test("ColumnarCoalesceExec and CoalesceExec return the same result") { + val sql1 = "SELECT /*+ COALESCE(3) */ id, city, car_model, quantity FROM dealer" + checkCoalesceExecAndColumnarCoalesceExecAgree(sql1) + + val sql2 = "SELECT /*+ COALESCE(3) */ id, city, car_model, quantity FROM float_dealer" + checkCoalesceExecAndColumnarCoalesceExecAgree(sql2, true) + } + + // check CoalesceExec and ColumnarCoalesceExec return the same result + private def checkCoalesceExecAndColumnarCoalesceExecAgree(sql: String, + rollBackByInputCase: Boolean = false): Unit = { + spark.conf.set("spark.omni.sql.columnar.coalesce", true) + val omniResult = spark.sql(sql) + val omniPlan = omniResult.queryExecution.executedPlan + if (rollBackByInputCase) { + assert(omniPlan.find(_.isInstanceOf[ColumnarCoalesceExec]).isEmpty, + s"SQL:${sql}\n@SparkEnv not have ColumnarCoalesceExec, sparkPlan:${omniPlan}") + assert(omniPlan.find(_.isInstanceOf[CoalesceExec]).isDefined, + s"SQL:${sql}\n@SparkEnv have CoalesceExec, sparkPlan:${omniPlan}") + } else { + assert(omniPlan.find(_.isInstanceOf[ColumnarCoalesceExec]).isDefined, + s"SQL:${sql}\n@SparkEnv have ColumnarCoalesceExec, sparkPlan:${omniPlan}") + assert(omniPlan.find(_.isInstanceOf[CoalesceExec]).isEmpty, + s"SQL:${sql}\n@SparkEnv not have CoalesceExec, sparkPlan:${omniPlan}") + } + + spark.conf.set("spark.omni.sql.columnar.coalesce", false) + val sparkResult = spark.sql(sql) + val sparkPlan = sparkResult.queryExecution.executedPlan + assert(sparkPlan.find(_.isInstanceOf[ColumnarCoalesceExec]).isEmpty, + s"SQL:${sql}\n@SparkEnv not have ColumnarCoalesceExec, sparkPlan:${sparkPlan}") + assert(sparkPlan.find(_.isInstanceOf[CoalesceExec]).isDefined, + s"SQL:${sql}\n@SparkEnv have CoalesceExec, sparkPlan:${sparkPlan}") + // DataFrame do not support comparing with equals method, use DataFrame.except instead + assert(omniResult.except(sparkResult).isEmpty) + spark.conf.set("spark.omni.sql.columnar.coalesce", true) + } + + test("use ColumnarCoalesceExec by RDD api to check repartition") { + // reinit to 6 partitions + val dealerDf6P = dealerDf.repartition(6) + assert(dealerDf6P.rdd.partitions.length == 6) + + // coalesce to 2 partitions + val dealerDfCoalesce2P = dealerDf6P.coalesce(2) + assert(dealerDfCoalesce2P.rdd.partitions.length == 2) + val dealerDfCoalesce2Plan = dealerDfCoalesce2P.queryExecution.executedPlan + assert(dealerDfCoalesce2Plan.find(_.isInstanceOf[ColumnarCoalesceExec]).isDefined, + s"sparkPlan:${dealerDfCoalesce2Plan}") + assert(dealerDfCoalesce2Plan.find(_.isInstanceOf[CoalesceExec]).isEmpty, + s"sparkPlan:${dealerDfCoalesce2Plan}") + // always return 8 rows + assert(dealerDfCoalesce2P.collect().length == 8) + } +} + diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarExecSuite.scala index cc724b31ad30b1e4f0d152194bd927ed71d7b3f3..311d7a990eaa70e4dae1e36c2edb832494cbf595 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarExecSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarExecSuite.scala @@ -47,7 +47,7 @@ class ColumnarExecSuite extends ColumnarSparkPlanTest { test("spark limit with columnarToRow as child") { - // fetch parital + // fetch Partial val sql1 = "select 
* from (select a, b+2 from dealer order by a, b+2) limit 2" assertColumnarToRowOmniAndSparkResultEqual(sql1, false) @@ -59,7 +59,7 @@ class ColumnarExecSuite extends ColumnarSparkPlanTest { val sql3 = "select a, b+2 from dealer limit 10" assertColumnarToRowOmniAndSparkResultEqual(sql3, true) - // fetch parital + // fetch Partial val sql4 = "select a, b+2 from dealer order by a limit 2" assertColumnarToRowOmniAndSparkResultEqual(sql4, false) diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarExpandExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarExpandExecSuite.scala index 5c39c04851ebd02e4504575861a87a90223341d1..4e0c0768a9944cd6c882af3dfb1a43efce709e7d 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarExpandExecSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarExpandExecSuite.scala @@ -69,11 +69,41 @@ class ColumnarExpandExecSuite extends ColumnarSparkPlanTest { } - test("use ColumnarExpandExec in Grouping Sets clause when default") { + test("use ColumnarExpandExec in Grouping Sets clause when default, case1 can't match rollup optimization") { val result = spark.sql("SELECT city, car_model, sum(quantity) AS sum FROM dealer " + "GROUP BY GROUPING SETS ((city, car_model), (city), (car_model), ()) ORDER BY city, car_model;") val plan = result.queryExecution.executedPlan assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isDefined) + // GROUP BY GROUPING SETS ((city, car_model), (city), (car_model), ()) is equal to CUBE (city, car_model) + assert(plan.find(_.isInstanceOf[ColumnarOptRollupExec]).isEmpty) + assert(plan.find(_.isInstanceOf[ExpandExec]).isEmpty) + } + + test("use ColumnarExpandExec in Grouping Sets clause when default, case2 can't match rollup optimization") { + val result = spark.sql("SELECT city, car_model, sum(quantity) AS sum FROM dealer " + + "GROUP BY GROUPING SETS ((city, car_model), ()) ORDER BY city, car_model;") + val plan = result.queryExecution.executedPlan + assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isDefined) + assert(plan.find(_.isInstanceOf[ColumnarOptRollupExec]).isEmpty) + assert(plan.find(_.isInstanceOf[ExpandExec]).isEmpty) + } + + test("use ColumnarExpandExec in Grouping Sets clause when default, case3 matches rollup optimization") { + val result = spark.sql("SELECT city, car_model, sum(quantity) AS sum FROM dealer " + + "GROUP BY GROUPING SETS ((city, car_model), (city), ()) ORDER BY city, car_model;") + val plan = result.queryExecution.executedPlan + assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isDefined) + // GROUP BY GROUPING SETS ((city, car_model), (city), (car_model), ()) is equal to ROLLUP (city, car_model) + assert(plan.find(_.isInstanceOf[ColumnarOptRollupExec]).isDefined) + assert(plan.find(_.isInstanceOf[ExpandExec]).isEmpty) + } + + test("use ColumnarExpandExec in Grouping Sets clause when default, case4 matches rollup optimization") { + val result = spark.sql("SELECT city, car_model, sum(quantity) AS sum FROM dealer " + + "GROUP BY GROUPING SETS ((city, car_model), (city)) ORDER BY city, car_model;") + val plan = result.queryExecution.executedPlan + assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isDefined) + assert(plan.find(_.isInstanceOf[ColumnarOptRollupExec]).isDefined) assert(plan.find(_.isInstanceOf[ExpandExec]).isEmpty) } @@ -82,6 +112,7 @@ class ColumnarExpandExecSuite 
extends ColumnarSparkPlanTest { "GROUP BY GROUPING SETS ((city, car_model), (city), (car_model), ()) ORDER BY city, car_model;") val plan = result.queryExecution.executedPlan assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isEmpty) + assert(plan.find(_.isInstanceOf[ColumnarOptRollupExec]).isEmpty) assert(plan.find(_.isInstanceOf[ExpandExec]).isDefined) } @@ -91,16 +122,29 @@ class ColumnarExpandExecSuite extends ColumnarSparkPlanTest { "GROUP BY GROUPING SETS ((city, car_model), (city), (car_model), ()) ORDER BY city, car_model;") val plan = result.queryExecution.executedPlan assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isEmpty) + assert(plan.find(_.isInstanceOf[ColumnarOptRollupExec]).isEmpty) assert(plan.find(_.isInstanceOf[ExpandExec]).isDefined) spark.conf.set("spark.omni.sql.columnar.expand", true) } - test("use ColumnarExpandExec in Rollup clause when default") { + test("use ColumnarExpandExec in Rollup clause when default, default use rollup optimization") { + val result = spark.sql("SELECT city, car_model, sum(quantity) AS sum FROM dealer " + + "GROUP BY ROLLUP(city, car_model) ORDER BY city, car_model;") + val plan = result.queryExecution.executedPlan + assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isDefined) + assert(plan.find(_.isInstanceOf[ColumnarOptRollupExec]).isDefined) + assert(plan.find(_.isInstanceOf[ExpandExec]).isEmpty) + } + + test("use ColumnarExpandExec in Rollup clause when default, not use rollup optimization") { val result = spark.sql("SELECT city, car_model, sum(quantity) AS sum FROM dealer " + "GROUP BY ROLLUP(city, car_model) ORDER BY city, car_model;") + spark.conf.set("spark.omni.sql.columnar.rollupOptimization.enabled", false) val plan = result.queryExecution.executedPlan assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isDefined) + assert(plan.find(_.isInstanceOf[ColumnarOptRollupExec]).isEmpty) assert(plan.find(_.isInstanceOf[ExpandExec]).isEmpty) + spark.conf.set("spark.omni.sql.columnar.rollupOptimization.enabled", true) } test("use ExpandExec in Rollup clause when SparkExtension rollback") { @@ -108,6 +152,7 @@ class ColumnarExpandExecSuite extends ColumnarSparkPlanTest { "GROUP BY ROLLUP(city, car_model) ORDER BY city, car_model;") val plan = result.queryExecution.executedPlan assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isEmpty) + assert(plan.find(_.isInstanceOf[ColumnarOptRollupExec]).isEmpty) assert(plan.find(_.isInstanceOf[ExpandExec]).isDefined) } @@ -117,6 +162,7 @@ class ColumnarExpandExecSuite extends ColumnarSparkPlanTest { "GROUP BY ROLLUP(city, car_model) ORDER BY city, car_model;") val plan = result.queryExecution.executedPlan assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isEmpty) + assert(plan.find(_.isInstanceOf[ColumnarOptRollupExec]).isEmpty) assert(plan.find(_.isInstanceOf[ExpandExec]).isDefined) spark.conf.set("spark.omni.sql.columnar.expand", true) } @@ -126,6 +172,7 @@ class ColumnarExpandExecSuite extends ColumnarSparkPlanTest { "GROUP BY CUBE(city, car_model) ORDER BY city, car_model;") val plan = result.queryExecution.executedPlan assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isDefined) + assert(plan.find(_.isInstanceOf[ColumnarOptRollupExec]).isEmpty) assert(plan.find(_.isInstanceOf[ExpandExec]).isEmpty) } @@ -134,6 +181,7 @@ class ColumnarExpandExecSuite extends ColumnarSparkPlanTest { "GROUP BY CUBE(city, car_model) ORDER BY city, car_model;") val plan = result.queryExecution.executedPlan assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isEmpty) + 
assert(plan.find(_.isInstanceOf[ColumnarOptRollupExec]).isEmpty) assert(plan.find(_.isInstanceOf[ExpandExec]).isDefined) } @@ -143,15 +191,17 @@ class ColumnarExpandExecSuite extends ColumnarSparkPlanTest { "GROUP BY CUBE(city, car_model) ORDER BY city, car_model;") val plan = result.queryExecution.executedPlan assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isEmpty) + assert(plan.find(_.isInstanceOf[ColumnarOptRollupExec]).isEmpty) assert(plan.find(_.isInstanceOf[ExpandExec]).isDefined) spark.conf.set("spark.omni.sql.columnar.expand", true) } - test("ColumnarExpandExec exec correctly in Grouping Sets clause") { + test("ColumnarExpandExec exec correctly in Grouping Sets clause, case1 can't match rollup optimization") { val result = spark.sql("SELECT city, car_model, sum(quantity) AS sum FROM dealer " + "GROUP BY GROUPING SETS ((city, car_model), (city), (car_model), ()) ORDER BY city, car_model;") val plan = result.queryExecution.executedPlan assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isDefined) + assert(plan.find(_.isInstanceOf[ColumnarOptRollupExec]).isEmpty) val expect = Seq( Row(null, null, 78), @@ -173,11 +223,60 @@ class ColumnarExpandExecSuite extends ColumnarSparkPlanTest { checkAnswer(result, expect) } - test("ColumnarExpandExec exec correctly in Rollup clause") { + test("ColumnarExpandExec exec correctly in Grouping Sets clause, case2 matches rollup optimization") { + val result = spark.sql("SELECT city, car_model, sum(quantity) AS sum FROM dealer " + + "GROUP BY GROUPING SETS ((city, car_model), (city)) ORDER BY city, car_model;") + val plan = result.queryExecution.executedPlan + assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isDefined) + assert(plan.find(_.isInstanceOf[ColumnarOptRollupExec]).isDefined) + + val expect = Seq( + Row("Dublin", null, 33), + Row("Dublin", "Honda Accord", 10), + Row("Dublin", "Honda CRV", 3), + Row("Dublin", "Honda Civic", 20), + Row("Fremont", null, 32), + Row("Fremont", "Honda Accord", 15), + Row("Fremont", "Honda CRV", 7), + Row("Fremont", "Honda Civic", 10), + Row("San Jose", null, 13), + Row("San Jose", "Honda Accord", 8), + Row("San Jose", "Honda Civic", 5), + ) + checkAnswer(result, expect) + } + + test("ColumnarExpandExec exec correctly in Rollup clause, default use rollup optimization") { + val result = spark.sql("SELECT city, car_model, sum(quantity) AS sum FROM dealer " + + "GROUP BY ROLLUP (city, car_model) ORDER BY city, car_model;") + val plan = result.queryExecution.executedPlan + assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isDefined) + assert(plan.find(_.isInstanceOf[ColumnarOptRollupExec]).isDefined) + + val expect = Seq( + Row(null, null, 78), + Row("Dublin", null, 33), + Row("Dublin", "Honda Accord", 10), + Row("Dublin", "Honda CRV", 3), + Row("Dublin", "Honda Civic", 20), + Row("Fremont", null, 32), + Row("Fremont", "Honda Accord", 15), + Row("Fremont", "Honda CRV", 7), + Row("Fremont", "Honda Civic", 10), + Row("San Jose", null, 13), + Row("San Jose", "Honda Accord", 8), + Row("San Jose", "Honda Civic", 5), + ) + checkAnswer(result, expect) + } + + test("ColumnarExpandExec exec correctly in Rollup clause, not use rollup optimization") { val result = spark.sql("SELECT city, car_model, sum(quantity) AS sum FROM dealer " + "GROUP BY ROLLUP (city, car_model) ORDER BY city, car_model;") + spark.conf.set("spark.omni.sql.columnar.rollupOptimization.enabled", false) val plan = result.queryExecution.executedPlan assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isDefined) + 
assert(plan.find(_.isInstanceOf[ColumnarOptRollupExec]).isEmpty) val expect = Seq( Row(null, null, 78), @@ -194,6 +293,7 @@ class ColumnarExpandExecSuite extends ColumnarSparkPlanTest { Row("San Jose", "Honda Civic", 5), ) checkAnswer(result, expect) + spark.conf.set("spark.omni.sql.columnar.rollupOptimization.enabled", true) } test("ColumnarExpandExec exec correctly in Cube clause") { @@ -201,6 +301,7 @@ class ColumnarExpandExecSuite extends ColumnarSparkPlanTest { "GROUP BY CUBE (city, car_model) ORDER BY city, car_model;") val plan = result.queryExecution.executedPlan assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isDefined) + assert(plan.find(_.isInstanceOf[ColumnarOptRollupExec]).isEmpty) val expect = Seq( Row(null, null, 78), @@ -222,11 +323,12 @@ class ColumnarExpandExecSuite extends ColumnarSparkPlanTest { checkAnswer(result, expect) } - test("ColumnarExpandExec exec correctly in Grouping Sets clause with GROUPING__ID column") { + test("ColumnarExpandExec exec correctly in Grouping Sets clause with GROUPING__ID column, case1 can't match rollup optimization") { val result = spark.sql("SELECT city, car_model, sum(quantity) AS sum, GROUPING__ID FROM dealer " + "GROUP BY GROUPING SETS ((city, car_model), (city), (car_model), ()) ORDER BY city, car_model;") val plan = result.queryExecution.executedPlan assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isDefined) + assert(plan.find(_.isInstanceOf[ColumnarOptRollupExec]).isEmpty) val expect = Seq( Row(null, null, 78, 3), @@ -248,11 +350,60 @@ class ColumnarExpandExecSuite extends ColumnarSparkPlanTest { checkAnswer(result, expect) } - test("ColumnarExpandExec exec correctly in Rollup clause with GROUPING__ID column") { + test("ColumnarExpandExec exec correctly in Grouping Sets clause with GROUPING__ID column, case2 matches rollup optimization") { + val result = spark.sql("SELECT city, car_model, sum(quantity) AS sum, GROUPING__ID FROM dealer " + + "GROUP BY GROUPING SETS ((city, car_model), (city)) ORDER BY city, car_model;") + val plan = result.queryExecution.executedPlan + assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isDefined) + assert(plan.find(_.isInstanceOf[ColumnarOptRollupExec]).isDefined) + + val expect = Seq( + Row("Dublin", null, 33, 1), + Row("Dublin", "Honda Accord", 10, 0), + Row("Dublin", "Honda CRV", 3, 0), + Row("Dublin", "Honda Civic", 20, 0), + Row("Fremont", null, 32, 1), + Row("Fremont", "Honda Accord", 15, 0), + Row("Fremont", "Honda CRV", 7, 0), + Row("Fremont", "Honda Civic", 10, 0), + Row("San Jose", null, 13, 1), + Row("San Jose", "Honda Accord", 8, 0), + Row("San Jose", "Honda Civic", 5, 0), + ) + checkAnswer(result, expect) + } + + test("ColumnarExpandExec exec correctly in Rollup clause with GROUPING__ID column, default use rollup optimization") { + val result = spark.sql("SELECT city, car_model, sum(quantity) AS sum, GROUPING__ID FROM dealer " + + "GROUP BY ROLLUP (city, car_model) ORDER BY city, car_model;") + val plan = result.queryExecution.executedPlan + assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isDefined) + assert(plan.find(_.isInstanceOf[ColumnarOptRollupExec]).isDefined) + + val expect = Seq( + Row(null, null, 78, 3), + Row("Dublin", null, 33, 1), + Row("Dublin", "Honda Accord", 10, 0), + Row("Dublin", "Honda CRV", 3, 0), + Row("Dublin", "Honda Civic", 20, 0), + Row("Fremont", null, 32, 1), + Row("Fremont", "Honda Accord", 15, 0), + Row("Fremont", "Honda CRV", 7, 0), + Row("Fremont", "Honda Civic", 10, 0), + Row("San Jose", null, 13, 1), + Row("San Jose", "Honda Accord", 8, 
0), + Row("San Jose", "Honda Civic", 5, 0), + ) + checkAnswer(result, expect) + } + + test("ColumnarExpandExec exec correctly in Rollup clause with GROUPING__ID column, not use rollup optimization") { val result = spark.sql("SELECT city, car_model, sum(quantity) AS sum, GROUPING__ID FROM dealer " + "GROUP BY ROLLUP (city, car_model) ORDER BY city, car_model;") + spark.conf.set("spark.omni.sql.columnar.rollupOptimization.enabled", false) val plan = result.queryExecution.executedPlan assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isDefined) + assert(plan.find(_.isInstanceOf[ColumnarOptRollupExec]).isEmpty) val expect = Seq( Row(null, null, 78, 3), @@ -269,6 +420,7 @@ class ColumnarExpandExecSuite extends ColumnarSparkPlanTest { Row("San Jose", "Honda Civic", 5, 0), ) checkAnswer(result, expect) + spark.conf.set("spark.omni.sql.columnar.rollupOptimization.enabled", true) } test("ColumnarExpandExec exec correctly in Cube clause with GROUPING__ID column") { @@ -276,6 +428,7 @@ class ColumnarExpandExecSuite extends ColumnarSparkPlanTest { "GROUP BY CUBE (city, car_model) ORDER BY city, car_model;") val plan = result.queryExecution.executedPlan assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isDefined) + assert(plan.find(_.isInstanceOf[ColumnarOptRollupExec]).isEmpty) val expect = Seq( Row(null, null, 78, 3), @@ -298,16 +451,30 @@ class ColumnarExpandExecSuite extends ColumnarSparkPlanTest { } - test("ColumnarExpandExec and ExpandExec return the same result when use Grouping Sets clause") { + test("ColumnarExpandExec and ExpandExec return the same result when use Grouping Sets clause, case1 can't match rollup optimization") { val sql = "SELECT city, car_model, sum(quantity) AS sum FROM dealer " + "GROUP BY GROUPING SETS ((city, car_model), (city), (car_model), ()) ORDER BY city, car_model;" checkExpandExecAndColumnarExpandExecAgree(sql) } - test("ColumnarExpandExec and ExpandExec return the same result when use Rollup clause") { + test("ColumnarExpandExec and ExpandExec return the same result when use Grouping Sets clause, case2 matches rollup optimization") { + val sql = "SELECT city, car_model, sum(quantity) AS sum FROM dealer " + + "GROUP BY GROUPING SETS ((city, car_model), (city)) ORDER BY city, car_model;" + checkExpandExecAndColumnarExpandExecAgree(sql) + } + + test("ColumnarExpandExec and ExpandExec return the same result when use Rollup clause, default use rollup optimization") { + val sql = "SELECT city, car_model, sum(quantity) AS sum FROM dealer " + + "GROUP BY ROLLUP(city, car_model) ORDER BY city, car_model;" + checkExpandExecAndColumnarExpandExecAgree(sql) + } + + test("ColumnarExpandExec and ExpandExec return the same result when use Rollup clause, not use rollup optimization") { val sql = "SELECT city, car_model, sum(quantity) AS sum FROM dealer " + "GROUP BY ROLLUP(city, car_model) ORDER BY city, car_model;" + spark.conf.set("spark.omni.sql.columnar.rollupOptimization.enabled", false) checkExpandExecAndColumnarExpandExecAgree(sql) + spark.conf.set("spark.omni.sql.columnar.rollupOptimization.enabled", true) } test("ColumnarExpandExec and ExpandExec return the same result when use Cube clause") { @@ -316,16 +483,30 @@ class ColumnarExpandExecSuite extends ColumnarSparkPlanTest { checkExpandExecAndColumnarExpandExecAgree(sql) } - test("ColumnarExpandExec and ExpandExec return the same result when use Grouping Sets clause with null value") { + test("ColumnarExpandExec and ExpandExec return the same result when use Grouping Sets clause with null value, case1 can't match 
rollup optimization") { val sql = "SELECT city, car_model, sum(quantity) AS sum FROM null_dealer " + "GROUP BY GROUPING SETS ((city, car_model), (city), (car_model), ()) ORDER BY city, car_model;" checkExpandExecAndColumnarExpandExecAgree(sql) } - test("ColumnarExpandExec and ExpandExec return the same result when use Rollup clause with null value") { + test("ColumnarExpandExec and ExpandExec return the same result when use Grouping Sets clause with null value, case2 matches rollup optimization") { + val sql = "SELECT city, car_model, sum(quantity) AS sum FROM null_dealer " + + "GROUP BY GROUPING SETS ((city, car_model), (city)) ORDER BY city, car_model;" + checkExpandExecAndColumnarExpandExecAgree(sql) + } + + test("ColumnarExpandExec and ExpandExec return the same result when use Rollup clause with null value, default use rollup optimization") { + val sql = "SELECT city, car_model, sum(quantity) AS sum FROM null_dealer " + + "GROUP BY ROLLUP(city, car_model) ORDER BY city, car_model;" + checkExpandExecAndColumnarExpandExecAgree(sql) + } + + test("ColumnarExpandExec and ExpandExec return the same result when use Rollup clause with null value, not use rollup optimization") { val sql = "SELECT city, car_model, sum(quantity) AS sum FROM null_dealer " + "GROUP BY ROLLUP(city, car_model) ORDER BY city, car_model;" + spark.conf.set("spark.omni.sql.columnar.rollupOptimization.enabled", false) checkExpandExecAndColumnarExpandExecAgree(sql) + spark.conf.set("spark.omni.sql.columnar.rollupOptimization.enabled", true) } test("ColumnarExpandExec and ExpandExec return the same result when use Cube clause with null value") { @@ -334,23 +515,83 @@ class ColumnarExpandExecSuite extends ColumnarSparkPlanTest { checkExpandExecAndColumnarExpandExecAgree(sql) } - test("ColumnarExpandExec and ExpandExec return the same result when use Grouping Sets clause with GROUPING__ID column") { + test("ColumnarExpandExec and ExpandExec return the same result when use Grouping Sets clause with GROUPING__ID column, case1 can't match rollup optimization") { val sql = "SELECT city, car_model, sum(quantity) AS sum, GROUPING__ID FROM dealer " + "GROUP BY GROUPING SETS ((city, car_model), (city), (car_model), ()) ORDER BY city, car_model;" checkExpandExecAndColumnarExpandExecAgree(sql) } - test("ColumnarExpandExec and ExpandExec return the same result when use Rollup clause with GROUPING__ID column") { + test("ColumnarExpandExec and ExpandExec return the same result when use Grouping Sets clause with GROUPING__ID column, case2 matches rollup optimization") { + val sql = "SELECT city, car_model, sum(quantity) AS sum, GROUPING__ID FROM dealer " + + "GROUP BY GROUPING SETS ((city, car_model), (city)) ORDER BY city, car_model;" + checkExpandExecAndColumnarExpandExecAgree(sql) + } + + test("ColumnarExpandExec and ExpandExec return the same result when use Rollup clause with GROUPING__ID column, default use rollup optimization") { + val sql = "SELECT city, car_model, sum(quantity) AS sum, GROUPING__ID FROM dealer " + + "GROUP BY ROLLUP(city, car_model) ORDER BY city, car_model;" + checkExpandExecAndColumnarExpandExecAgree(sql) + } + + test("ColumnarExpandExec and ExpandExec return the same result when use Rollup clause with GROUPING__ID column, not use rollup optimization") { val sql = "SELECT city, car_model, sum(quantity) AS sum, GROUPING__ID FROM dealer " + "GROUP BY ROLLUP(city, car_model) ORDER BY city, car_model;" + spark.conf.set("spark.omni.sql.columnar.rollupOptimization.enabled", false) 
checkExpandExecAndColumnarExpandExecAgree(sql) + spark.conf.set("spark.omni.sql.columnar.rollupOptimization.enabled", true) } + test("ColumnarExpandExec and ExpandExec return the same result when use Cube clause with GROUPING__ID column") { val sql = "SELECT city, car_model, sum(quantity) AS sum, GROUPING__ID FROM dealer " + "GROUP BY CUBE (city, car_model) ORDER BY city, car_model;" checkExpandExecAndColumnarExpandExecAgree(sql) } + // test distinct + rollup + test("ColumnarExpandExec and ExpandExec return the same result when use distinct + Rollup clause, case1") { + val result = spark.sql("SELECT city, car_model, sum(DISTINCT id) AS sum, count(DISTINCT quantity) AS count FROM dealer " + + "GROUP BY city, car_model WITH ROLLUP ORDER BY city, car_model;") + val plan = result.queryExecution.executedPlan + assert(plan.find(_.isInstanceOf[ColumnarExpandExec]).isDefined) + assert(plan.find(_.isInstanceOf[ColumnarOptRollupExec]).isEmpty) + } + + test("ColumnarExpandExec and ExpandExec return the same result when use distinct + Rollup clause, case2") { + val sql = "SELECT city, car_model, count(DISTINCT quantity) AS count FROM dealer " + + "GROUP BY city, car_model WITH ROLLUP ORDER BY city, car_model;" + checkExpandExecAndColumnarExpandExecAgree(sql) + } + + test("ColumnarExpandExec and ExpandExec return the same result when use distinct + Rollup clause, case3") { + val sql = "SELECT city, car_model, count(DISTINCT quantity) AS count, sum(id) AS sum FROM dealer " + + "GROUP BY city, car_model WITH ROLLUP ORDER BY city, car_model;" + checkExpandExecAndColumnarExpandExecAgree(sql) + } + + test("ColumnarExpandExec and ExpandExec return the same result when use distinct + Rollup clause, case4") { + val sql = "SELECT city, car_model, sum(id) AS sum, count(DISTINCT quantity) AS count FROM dealer " + + "GROUP BY city, car_model WITH ROLLUP ORDER BY city, car_model;" + checkExpandExecAndColumnarExpandExecAgree(sql) + } + + test("ColumnarExpandExec and ExpandExec return the same result when use distinct + Rollup clause, case5") { + val sql = "SELECT city, car_model, sum(DISTINCT coalesce(id * quantity, 0)) AS sum FROM dealer " + + "GROUP BY city, car_model WITH ROLLUP ORDER BY city, car_model;" + checkExpandExecAndColumnarExpandExecAgree(sql) + } + + test("ColumnarExpandExec and ExpandExec return the same result when use distinct + Rollup clause, case6") { + val sql = "SELECT city, car_model, sum(coalesce(id * quantity, 0)) AS sum, count(DISTINCT quantity) AS count FROM dealer " + + "GROUP BY city, car_model WITH ROLLUP ORDER BY city, car_model;" + checkExpandExecAndColumnarExpandExecAgree(sql) + } + + test("ColumnarExpandExec and ExpandExec return the same result when use distinct + Rollup clause, case7") { + val sql = "SELECT city, count(car_model) AS count1, count(DISTINCT quantity) AS count2, sum(id) AS sum FROM dealer " + + "GROUP BY city WITH ROLLUP ORDER BY city;" + checkExpandExecAndColumnarExpandExecAgree(sql) + } + // check ExpandExec and ColumnarExpandExec return the same result def checkExpandExecAndColumnarExpandExecAgree(sql: String): Unit = { spark.conf.set("spark.omni.sql.columnar.expand", true) diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarJoinExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarJoinExecSuite.scala index ad0fe196ab03cc706b0cdc908c19514ff4aa6b06..a3eee279a30287ddd48bf624bbc9035d82fab6e5 100644 --- 
a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarJoinExecSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarJoinExecSuite.scala @@ -251,16 +251,17 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { val df = leftWithNull.join(rightWithNull.hint("broadcast"), col("q").isNotNull === col("c").isNotNull, "leftouter") checkAnswer(df, _ => df.queryExecution.executedPlan, Seq( - Row("abc", null, 4, 2.0, " add", null, 1, null), - Row("abc", null, 4, 2.0, "", "Hello", 2, 2.0), Row("abc", null, 4, 2.0, "abc", "", 4, 1.0), + Row("abc", null, 4, 2.0, "", "Hello", 2, 2.0), + Row("abc", null, 4, 2.0, " add", null, 1, null), Row("", "Hello", null, 1.0, " yeah ", null, null, 4.0), - Row(" add", "World", 8, 3.0, " add", null, 1, null), - Row(" add", "World", 8, 3.0, "", "Hello", 2, 2.0), Row(" add", "World", 8, 3.0, "abc", "", 4, 1.0), - Row(" yeah ", "yeah", 10, 8.0, " add", null, 1, null), + Row(" add", "World", 8, 3.0, "", "Hello", 2, 2.0), + Row(" add", "World", 8, 3.0, " add", null, 1, null), + Row(" yeah ", "yeah", 10, 8.0, "abc", "", 4, 1.0), Row(" yeah ", "yeah", 10, 8.0, "", "Hello", 2, 2.0), - Row(" yeah ", "yeah", 10, 8.0, "abc", "", 4, 1.0) + Row(" yeah ", "yeah", 10, 8.0, " add", null, 1, null) + ), false) } @@ -289,15 +290,15 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { col("q").isNotNull === col("c").isNotNull, "fullouter") checkAnswer(df, _ => df.queryExecution.executedPlan, Seq( Row("", "Hello", null, 1.0, " yeah ", null, null, 4.0), - Row("abc", null, 4, 2.0, " add", null, 1, null), - Row("abc", null, 4, 2.0, "", "Hello", 2, 2.0), Row("abc", null, 4, 2.0, "abc", "", 4, 1.0), - Row(" add", "World", 8, 3.0, " add", null, 1, null), - Row(" add", "World", 8, 3.0, "", "Hello", 2, 2.0), + Row("abc", null, 4, 2.0, "", "Hello", 2, 2.0), + Row("abc", null, 4, 2.0, " add", null, 1, null), Row(" add", "World", 8, 3.0, "abc", "", 4, 1.0), - Row(" yeah ", "yeah", 10, 8.0, " add", null, 1, null), + Row(" add", "World", 8, 3.0, "", "Hello", 2, 2.0), + Row(" add", "World", 8, 3.0, " add", null, 1, null), + Row(" yeah ", "yeah", 10, 8.0, "abc", "", 4, 1.0), Row(" yeah ", "yeah", 10, 8.0, "", "Hello", 2, 2.0), - Row(" yeah ", "yeah", 10, 8.0, "abc", "", 4, 1.0) + Row(" yeah ", "yeah", 10, 8.0, " add", null, 1, null) ), false) } @@ -382,10 +383,10 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { assert(omniPlan.find(_.isInstanceOf[ColumnarProjectExec]).isEmpty, s"SQL:\n@OmniEnv no ColumnarProjectExec,omniPlan:${omniPlan}") checkAnswer(omniResult, _ => omniPlan, Seq( - Row("Carter", 44678), Row("Carter", 77895), - Row("Adams", 22456), + Row("Carter", 44678), Row("Adams", 24562), + Row("Adams", 22456), Row("Bush", null) ), false) } @@ -397,10 +398,10 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { assert(omniPlan.find(_.isInstanceOf[ColumnarProjectExec]).isEmpty, s"SQL:\n@OmniEnv no ColumnarProjectExec,omniPlan:${omniPlan}") checkAnswer(omniResult, _ => omniPlan, Seq( - Row("Carter", 44678, 3), Row("Carter", 77895, 3), - Row("Adams", 22456, 1), + Row("Carter", 44678, 3), Row("Adams", 24562, 1), + Row("Adams", 22456, 1), Row("Bush", null, null) ), false) } @@ -412,10 +413,10 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { assert(omniPlan.find(_.isInstanceOf[ColumnarProjectExec]).isEmpty, s"SQL:\n@OmniEnv no ColumnarProjectExec,omniPlan:${omniPlan}") checkAnswer(omniResult, _ => omniPlan, Seq( - Row(44678, "Carter", 3), Row(77895, 
"Carter", 3), - Row(22456, "Adams", 1), + Row(44678, "Carter", 3), Row(24562, "Adams", 1), + Row(22456, "Adams", 1), Row(null, "Bush", null) ), false) } @@ -427,10 +428,10 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { assert(omniPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"SQL:\n@OmniEnv have ColumnarProjectExec,omniPlan:${omniPlan}") checkAnswer(omniResult, _ => omniPlan, Seq( - Row(44679, "Carter"), Row(77896, "Carter"), - Row(22457, "Adams"), + Row(44679, "Carter"), Row(24563, "Adams"), + Row(22457, "Adams"), Row(null, "Bush") ), false) } @@ -442,10 +443,10 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { assert(omniPlan.find(_.isInstanceOf[ColumnarProjectExec]).isEmpty, s"SQL:\n@OmniEnv no ColumnarProjectExec,omniPlan:${omniPlan}") checkAnswer(omniResult, _ => omniPlan, Seq( - Row("Carter", 44678), Row("Carter", 77895), - Row("Adams", 22456), + Row("Carter", 44678), Row("Adams", 24562), + Row("Adams", 22456), Row("Bush", null) ), false) } @@ -476,6 +477,47 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { ), false) } + test("columnar ShuffledHashJoin right outer join is equal to native") { + val df = left.join(right.hint("SHUFFLE_HASH"), col("q") === col("c"), "rightouter") + checkAnswer(df, _ => df.queryExecution.executedPlan, Seq( + Row("abc", "", 4, 2.0, "abc", "", 4, 1.0), + Row(null, null, null, null, "", "Hello", 2, 2.0), + Row("", "Hello", 1, 1.0, " add", "World", 1, 3.0), + Row(null, null, null, null, " yeah ", "yeah", 0, 4.0) + ), false) + } + + test("columnar ShuffledHashJoin right outer join is equal to native with null") { + val df = leftWithNull.join(rightWithNull.hint("SHUFFLE_HASH"), + col("q") === col("c"), "rightouter") + checkAnswer(df, _ => df.queryExecution.executedPlan, Seq( + Row("abc", null, 4, 2.0, "abc", "", 4, 1.0), + Row(null, null, null, null, "", "Hello", 2, 2.0), + Row(null, null, null, null, " add", null, 1, null), + Row(null, null, null, null, " yeah ", null, null, 4.0) + ), false) + } + + test("columnar BroadcastHashJoin right outer join is equal to native") { + val df = left.join(right.hint("broadcast"), col("q") === col("c"), "rightouter") + checkAnswer(df, _ => df.queryExecution.executedPlan, Seq( + Row("abc", "", 4, 2.0, "abc", "", 4, 1.0), + Row(null, null, null, null, "", "Hello", 2, 2.0), + Row("", "Hello", 1, 1.0, " add", "World", 1, 3.0), + Row(null, null, null, null, " yeah ", "yeah", 0, 4.0) + ), false) + } + + test("columnar BroadcastHashJoin right outer join is equal to native with null") { + val df = leftWithNull.join(rightWithNull.hint("broadcast"), col("q") === col("c"), "rightouter") + checkAnswer(df, _ => df.queryExecution.executedPlan, Seq( + Row("abc", null, 4, 2.0, "abc", "", 4, 1.0), + Row(null, null, null, null, "", "Hello", 2, 2.0), + Row(null, null, null, null, " add", null, 1, null), + Row(null, null, null, null, " yeah ", null, null, 4.0) + ), false) + } + test("shuffledHashJoin and project funsion test") { val omniResult = person_test.join(order_test.hint("SHUFFLE_HASH"), person_test("id_p") === order_test("id_p"), "inner") .select(person_test("name"), order_test("order_no")) @@ -483,10 +525,10 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { assert(omniPlan.find(_.isInstanceOf[ColumnarProjectExec]).isEmpty, s"SQL:\n@OmniEnv no ColumnarProjectExec,omniPlan:${omniPlan}") checkAnswer(omniResult, _ => omniPlan, Seq( - Row("Carter", 44678), Row("Carter", 77895), - Row("Adams", 22456), - Row("Adams", 24562) + Row("Carter", 44678), + Row("Adams", 24562), + 
Row("Adams", 22456) ), false) } @@ -497,10 +539,10 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { assert(omniPlan.find(_.isInstanceOf[ColumnarProjectExec]).isEmpty, s"SQL:\n@OmniEnv no ColumnarProjectExec,omniPlan:${omniPlan}") checkAnswer(omniResult, _ => omniPlan, Seq( - Row("Carter", 44678, 3), Row("Carter", 77895, 3), - Row("Adams", 22456, 1), - Row("Adams", 24562, 1) + Row("Carter", 44678, 3), + Row("Adams", 24562, 1), + Row("Adams", 22456, 1) ), false) } @@ -511,10 +553,10 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { assert(omniPlan.find(_.isInstanceOf[ColumnarProjectExec]).isEmpty, s"SQL:\n@OmniEnv no ColumnarProjectExec,omniPlan:${omniPlan}") checkAnswer(omniResult, _ => omniPlan, Seq( - Row(44678, "Carter", 3), Row(77895, "Carter", 3), - Row(22456, "Adams", 1), - Row(24562, "Adams", 1) + Row(44678, "Carter", 3), + Row(24562, "Adams", 1), + Row(22456, "Adams", 1) ), false) } @@ -525,10 +567,10 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { assert(omniPlan.find(_.isInstanceOf[ColumnarProjectExec]).isDefined, s"SQL:\n@OmniEnv have ColumnarProjectExec,omniPlan:${omniPlan}") checkAnswer(omniResult, _ => omniPlan, Seq( - Row(44679, "Carter"), Row(77896, "Carter"), - Row(22457, "Adams"), - Row(24563, "Adams") + Row(44679, "Carter"), + Row(24563, "Adams"), + Row(22457, "Adams") ), false) } @@ -539,10 +581,10 @@ class ColumnarJoinExecSuite extends ColumnarSparkPlanTest { assert(omniPlan.find(_.isInstanceOf[ColumnarProjectExec]).isEmpty, s"SQL:\n@OmniEnv no ColumnarProjectExec,omniPlan:${omniPlan}") checkAnswer(omniResult, _ => omniPlan, Seq( - Row("Carter", 44678), Row("Carter", 77895), - Row("Adams", 22456), - Row("Adams", 24562) + Row("Carter", 44678), + Row("Adams", 24562), + Row("Adams", 22456) ), false) } diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarLimitExecSuit.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarLimitExecSuit.scala index 09d7a75c48ed9cdc9d75d750b95c8c5a83aca6d4..5944618785ba67a869cba691d5ef223b7c13c045 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarLimitExecSuit.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarLimitExecSuit.scala @@ -40,12 +40,48 @@ class ColumnarLimitExecSuit extends ColumnarSparkPlanTest { (3, 3, 3), (4, 5, 6) ).toDF("a", "b", "c") + left.createOrReplaceTempView("left") right = Seq[(java.lang.Integer, java.lang.Integer, java.lang.Integer)]( (1, 1, 1), (2, 2, 2), (3, 3, 3) ).toDF("x", "y", "z") + right.createOrReplaceTempView("right") + } + + test("limit with local and global limit columnar exec") { + val result = spark.sql("SELECT y FROM right WHERE x in " + + "(SELECT a FROM left WHERE a = 4 LIMIT 2)") + val plan = result.queryExecution.executedPlan + assert(plan.find(_.isInstanceOf[ColumnarLocalLimitExec]).isDefined, + s"not match ColumnarLocalLimitExec, real plan: ${plan}") + assert(plan.find(_.isInstanceOf[LocalLimitExec]).isEmpty, + s"real plan: ${plan}") + assert(plan.find(_.isInstanceOf[ColumnarGlobalLimitExec]).isDefined, + s"not match ColumnarGlobalLimitExec, real plan: ${plan}") + assert(plan.find(_.isInstanceOf[GlobalLimitExec]).isEmpty, + s"real plan: ${plan}") + // 0 rows return + assert(result.count() == 0) + } + + test("limit with rollback global limit to row-based exec") { + spark.conf.set("spark.omni.sql.columnar.globalLimit", false) + val result = 
spark.sql("SELECT a FROM left WHERE a in " + + "(SELECT x FROM right LIMIT 2)") + val plan = result.queryExecution.executedPlan + assert(plan.find(_.isInstanceOf[ColumnarLocalLimitExec]).isDefined, + s"not match ColumnarLocalLimitExec, real plan: ${plan}") + assert(plan.find(_.isInstanceOf[LocalLimitExec]).isEmpty, + s"real plan: ${plan}") + assert(plan.find(_.isInstanceOf[ColumnarGlobalLimitExec]).isEmpty, + s"match ColumnarGlobalLimitExec, real plan: ${plan}") + assert(plan.find(_.isInstanceOf[GlobalLimitExec]).isDefined, + s"real plan: ${plan}") + // 2 rows return + assert(result.count() == 2) + spark.conf.set("spark.omni.sql.columnar.globalLimit", true) } test("Push down limit through LEFT SEMI and LEFT ANTI join") { diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarTopNSortExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarTopNSortExecSuite.scala index 679da5a6f9d947de8929130334cb188103a92d5e..a788501ed8ed7d06e8939cd45560f59648a56acf 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarTopNSortExecSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarTopNSortExecSuite.scala @@ -48,14 +48,14 @@ class ColumnarTopNSortExecSuite extends ColumnarSparkPlanTest { } test("Test topNSort") { - val sql1 ="select * from (SELECT city, rank() OVER (ORDER BY sales) AS rk FROM dealer) where rk < 4 order by rk;" + val sql1 = "select * from (SELECT city, rank() OVER (ORDER BY sales) AS rk FROM dealer) where rk < 4 order by rk;" assertColumnarTopNSortExecAndSparkResultEqual(sql1, true) val sql2 = "select * from (SELECT city, row_number() OVER (ORDER BY sales) AS rn FROM dealer) where rn < 4 order by rn;" assertColumnarTopNSortExecAndSparkResultEqual(sql2, false) val sql3 = "select * from (SELECT city, rank() OVER (PARTITION BY city ORDER BY sales) AS rk FROM dealer) where rk < 4 order by rk;" - assertColumnarTopNSortExecAndSparkResultEqual(sql3, false) + assertColumnarTopNSortExecAndSparkResultEqual(sql3, true) } private def assertColumnarTopNSortExecAndSparkResultEqual(sql: String, hasColumnarTopNSortExec: Boolean = true): Unit = { @@ -63,20 +63,23 @@ class ColumnarTopNSortExecSuite extends ColumnarSparkPlanTest { spark.conf.set("spark.omni.sql.columnar.topNSort", true) spark.conf.set("spark.sql.execution.topNPushDownForWindow.enabled", true) spark.conf.set("spark.sql.execution.topNPushDownForWindow.threshold", 100) + spark.conf.set("spark.sql.adaptive.enabled", true) val omniResult = spark.sql(sql) - val omniPlan = omniResult.queryExecution.executedPlan + omniResult.collect() + val omniPlan = omniResult.queryExecution.executedPlan.toString() if (hasColumnarTopNSortExec) { - assert(omniPlan.find(_.isInstanceOf[ColumnarTopNSortExec]).isDefined, + assert(omniPlan.contains("ColumnarTopNSort"), s"SQL:${sql}\n@OmniEnv no ColumnarTopNSortExec, omniPlan:${omniPlan}") } // run TopNSortExec config spark.conf.set("spark.omni.sql.columnar.topNSort", false) val sparkResult = spark.sql(sql) - val sparkPlan = sparkResult.queryExecution.executedPlan - assert(sparkPlan.find(_.isInstanceOf[ColumnarTopNSortExec]).isEmpty, + sparkResult.collect() + val sparkPlan = sparkResult.queryExecution.executedPlan.toString() + assert(!sparkPlan.contains("ColumnarTopNSort"), s"SQL:${sql}\n@SparkEnv have ColumnarTopNSortExec, sparkPlan:${sparkPlan}") - 
-    assert(sparkPlan.find(_.isInstanceOf[TopNSortExec]).isDefined,
+    assert(sparkPlan.contains("TopNSort"),
       s"SQL:${sql}\n@SparkEnv no TopNSortExec, sparkPlan:${sparkPlan}")
     // DataFrame do not support comparing with equals method, use DataFrame.except instead
     // DataFrame.except can do equal for rows misorder(with and without order by are same)
diff --git a/omnioperator/omniop-spark-extension/pom.xml b/omnioperator/omniop-spark-extension/pom.xml
index 81043d4afdf8e62ca6ecb3bdaca521d9518b2720..b7315c5b49145c0805d940cd088eb70cbc9c6265 100644
--- a/omnioperator/omniop-spark-extension/pom.xml
+++ b/omnioperator/omniop-spark-extension/pom.xml
@@ -8,19 +8,19 @@
     com.huawei.kunpeng
     boostkit-omniop-spark-parent
    pom
-    3.3.1-1.3.0
+    3.3.1-1.4.0
    BoostKit Spark Native Sql Engine Extension Parent Pom
        2.12.10
        2.12
-       3.3.1-h0.cbu.mrs.321.r7
-       3.3.1-h0.cbu.mrs.321.r7
+       3.3.1
+       3.2.2
        UTF-8
        UTF-8
        3.13.0-h19
        FALSE
-       1.3.0
+       1.4.0
    java
@@ -171,4 +171,13 @@
+
+
+
+        hadoop-3.2
+
+            3.2.0
+
+
+
\ No newline at end of file