diff --git a/README.md b/README.md index 218a76cb77ba1bebe4a4744b8cd8837bd755d390..8ee84cadedd7862a1540242ca1a5d604a576dd0f 100644 --- a/README.md +++ b/README.md @@ -110,7 +110,7 @@ ms_custom_ops/ ### 1. 环境准备 确保已安装: -- MindSpore >= 2.6 +- MindSpore br_infer_iter分支日构建包 - 昇腾 CANN 工具包 - CMake >= 3.14 - Python >= 3.9 diff --git a/ccsrc/CMakeLists.txt b/ccsrc/CMakeLists.txt index b74143bf35c8e574ff61e36d1d9e6d4862c882f6..2f566fcb765d37bdc9b17a3abe268dd1be1af1f5 100644 --- a/ccsrc/CMakeLists.txt +++ b/ccsrc/CMakeLists.txt @@ -74,7 +74,7 @@ endfunction() file(GLOB_RECURSE OPS_YAML_FILES "${CMAKE_CURRENT_SOURCE_DIR}/../yaml/*_op.yaml") message(STATUS "OPS_YAML_FILES: ${OPS_YAML_FILES}") -get_yaml_files("${OPS_YAML_FILES}" OPS_YAML_STRING) +get_yaml_files("${OPS_YAML_FILES}" DEF_YAML_STRING) file(GLOB_RECURSE DOC_YAML_FILES "${CMAKE_CURRENT_SOURCE_DIR}/../yaml/*_doc.yaml") message(STATUS "DOC_YAML_FILES: ${DOC_YAML_FILES}") @@ -92,8 +92,8 @@ src_files = '${SRC_FILES}'.split(';') ms.ops.CustomOpBuilder( name='${MS_EXTENSION_NAME}', sources=src_files, - yaml=${OPS_YAML_STRING}, - doc=${DOC_YAML_STRING}, + op_def=${DEF_YAML_STRING}, + op_doc=${DOC_YAML_STRING}, backend='Ascend', cflags='${CFLAGS_INCLUDES}', ldflags='-L${INTERNAL_KERNEL_LIB_PATH} -l${LIBS}', diff --git a/ccsrc/ops/ms_kernels_internal/reshape_and_cache.cc b/ccsrc/ops/ms_kernels_internal/reshape_and_cache.cc index b8f5d631e5ee0889aad1e428259dbc59a48d929a..764fddd3912491fde41c460bef1071d5f4a73238 100644 --- a/ccsrc/ops/ms_kernels_internal/reshape_and_cache.cc +++ b/ccsrc/ops/ms_kernels_internal/reshape_and_cache.cc @@ -77,18 +77,8 @@ protected: const internal::OutputsImmutableInfoList &outputs, const std::vector &ms_inputs, const std::vector &ms_outputs) override { - internal::ReshapeAndCacheParam param; - auto head_num = ms_inputs.at(internal::kIndex5); - if (head_num->dtype_id() == TypeId::kNumberTypeInt64) { - param.head_num = - static_cast(head_num->GetValue().value()); - } else { - MS_LOG(EXCEPTION) - << "ReshapeAndCache [head_num]'s dtype wrong, expect int64, but got: " - << head_num->dtype_id(); - } return internal::CreateReshapeAndCacheOp( - inputs, outputs, param, internal::kInternalReshapeAndCacheOpName); + inputs, outputs, internal::kInternalReshapeAndCacheOpName); } }; } // namespace ms_custom_ops @@ -114,10 +104,8 @@ protected: internal::InternalOpPtr CreateKernel(const internal::InputsImmutableInfoList &inputs, const internal::OutputsImmutableInfoList &outputs) override { - internal::ReshapeAndCacheParam param; - param.head_num = this->head_num_; return internal::CreateReshapeAndCacheOp( - inputs, outputs, param, internal::kInternalReshapeAndCacheOpName); + inputs, outputs, internal::kInternalReshapeAndCacheOpName); } private: